@@ -241,7 +241,31 @@ def garg(g, y: np.ndarray = None) -> dict:
241241
242242
243243class GPsep :
244- """A class to represent a Gaussian Process with separable covariance."""
244+ """A class to represent a Gaussian Process with separable covariance.
245+
246+ Attributes:
247+ m: Number of input dimensions.
248+ n: Number of observations.
249+ X: Input data matrix.
250+ Z: Output data vector.
251+ d: Length-scale parameters.
252+ g: Nugget parameter.
253+ K: Covariance matrix.
254+ Ki: Inverse of covariance matrix.
255+ KiZ: Product of Ki and Z.
256+ phi: Scalar value from Z^T Ki Z calculation.
257+ dK: Boolean flag for calculating derivatives.
258+ DK: Matrix of derivatives.
259+ ldetK: Log determinant of K.
260+ nlsep_method: Method for likelihood computation.
261+ gradnlsep_method: Method for gradient computation.
262+ n_restarts_optimizer: Number of restarts for optimization.
263+ samp_size: Sample size for distance calculations.
264+ maxit: Maximum number of optimization iterations.
265+ verbosity: Verbosity level.
266+ auto_optimize: Whether to automatically optimize hyperparameters.
267+ max_points: Maximum number of points for model building.
268+ """
245269
246270 def __init__ (
247271 self ,
@@ -321,22 +345,91 @@ def __init__(
321345 self .auto_optimize = auto_optimize
322346 self .max_points = max_points
323347
324- def fit (self , X : np .ndarray , Z : np .ndarray , d = None , g = None , dK : bool = True , auto_optimize : bool = None , verbosity = 0 ) -> "GPsep" :
348+ # need to store the initial parameters for the fit method (sklearn compatibility)
349+ self .init_params = {
350+ "X" : X ,
351+ "Z" : Z ,
352+ "d" : d ,
353+ "g" : g ,
354+ "nlsep_method" : nlsep_method ,
355+ "gradnlsep_method" : gradnlsep_method ,
356+ "n_restarts_optimizer" : n_restarts_optimizer ,
357+ "samp_size" : samp_size ,
358+ "maxit" : maxit ,
359+ "verbosity" : verbosity ,
360+ "auto_optimize" : auto_optimize ,
361+ "max_points" : max_points ,
362+ }
363+
364+ # Add these two methods required by scikit-learn
365+ def get_params (self , deep = True ):
366+ """Get parameters for this estimator.
367+
368+ This method is required for scikit-learn compatibility.
369+
370+ Args:
371+ deep: If True, will return the parameters for this estimator and
372+ contained subobjects that are estimators. Defaults to True.
373+
374+ Returns:
375+ dict: Parameter names mapped to their values.
376+ """
377+ return {
378+ "X" : self .X ,
379+ "Z" : self .Z ,
380+ "d" : self .d ,
381+ "g" : self .g ,
382+ "nlsep_method" : self .nlsep_method ,
383+ "gradnlsep_method" : self .gradnlsep_method ,
384+ "n_restarts_optimizer" : self .n_restarts_optimizer ,
385+ "samp_size" : self .samp_size ,
386+ "maxit" : self .maxit ,
387+ "verbosity" : self .verbosity ,
388+ "auto_optimize" : self .auto_optimize ,
389+ "max_points" : self .max_points ,
390+ }
391+
392+ def set_params (self , ** parameters ):
393+ """Set the parameters of this estimator.
394+
395+ This method is required for scikit-learn compatibility.
396+
397+ Args:
398+ **parameters: Estimator parameters as keyword arguments.
399+
400+ Returns:
401+ self: Estimator instance.
325402 """
326- Fits the GP model with training data and optionally auto-optimizes hyperparameters.
403+ for parameter , value in parameters .items ():
404+ setattr (self , parameter , value )
405+
406+ # Update the stored parameters for potential re-initialization
407+ self .init_params .update (parameters )
408+
409+ return self
410+
411+ def fit (self , X : np .ndarray , Z : np .ndarray , d = None , g = None , dK : bool = True , auto_optimize : bool = None , verbosity = 0 ) -> "GPsep" :
412+ """Fit the GP model with training data and optionally auto-optimize hyperparameters.
327413
328414 Args:
329- X (np.ndarray): The input data matrix of shape (n, m).
330- Z (np.ndarray): The output data vector of length n.
331- d (Union[np.ndarray, float, None]): The length-scale parameters. If None, will be determined automatically.
332- g (Union[float, None]): The nugget parameter. If None, will be determined automatically.
333- dK (bool): Flag to indicate whether to calculate derivatives. Default is True.
334- auto_optimize (bool): Whether to automatically optimize hyperparameters using MLE.
335- verbosity (int): Verbosity level for optimization output.
336- auto_optimize (bool): Whether to automatically optimize hyperparameters using MLE. If None, uses the default value from the object, which is True.
415+ X: The input data matrix of shape (n, m).
416+ Z: The output data vector of length n.
417+ d: The length-scale parameters. If None, will be determined
418+ automatically. Defaults to None.
419+ g: The nugget parameter. If None, will be determined automatically.
420+ Defaults to None.
421+ dK: Flag to indicate whether to calculate derivatives.
422+ Defaults to True.
423+ auto_optimize: Whether to automatically optimize hyperparameters
424+ using MLE. If None, uses the default value from the object.
425+ Defaults to None.
426+ verbosity: Verbosity level for optimization output. Defaults to 0.
337427
338428 Returns:
339429 GPsep: The fitted GPsep object.
430+
431+ Raises:
432+ ValueError: If X has no rows or if X and Z dimensions mismatch.
340433 """
341434 # if X or Z are pandas dataframes or series, convert them to numpy arrays
342435 if hasattr (X , "to_numpy" ):
@@ -484,7 +577,6 @@ def gradient(par):
484577 d = result .x [:- 1 ]
485578 g = result .x [- 1 ]
486579
487-
488580 # set new parameters and build
489581 self .set_new_params (d , g )
490582 if self .verbosity > 0 :
@@ -547,34 +639,32 @@ def build(self) -> None:
547639 self .Ki = matrix_inversion_dispatcher (self .K , method = self .nlsep_method )
548640 self .ldetK = np .log (det (self .K ))
549641 self .calc_ZtKiZ ()
550- if self .dK :
551- # TODO: Check if this is necessary
552- # if self.dK is not None:
553- # raise RuntimeError("dK calculations have already been initialized.")
554- self .DK = diff_covar_sep_symm (self .m , self .X , self .n , self .d , self .K )
642+ # TODO: Check if this is necessary
643+ # if self.dK:
644+ # # TODO: Check if this is necessary
645+ # # if self.dK is not None:
646+ # # raise RuntimeError("dK calculations have already been initialized.")
647+ # self.DK = diff_covar_sep_symm(self.m, self.X, self.n, self.d, self.K)
555648
556649 def predict (self , XX : np .ndarray , lite : bool = False , nonug : bool = False , return_full = False , return_std = False ) -> float :
557- """
558- Predict the Gaussian Process output at new input points.
650+ """Predict the Gaussian Process output at new input points.
559651
560652 Args:
561- XX (np.ndarray):
562- The predictive locations.
563- lite (bool):
564- Flag to indicate whether to compute only the diagonal of Sigma.
565- nonug (bool):
566- Flag to indicate whether to use nugget.
567- return_full (bool): Flag to indicate whether to return the full dictionry, which
568- includes the mean, Sigma, df, and llik. Default is False.
569- return_std (bool):
570- Flag to indicate whether to return the standard deviation. Only applicable when
571- return_full is False. Default is False.
653+ XX: The predictive locations.
654+ lite: Flag to indicate whether to compute only the diagonal
655+ of Sigma. Defaults to False.
656+ nonug: Flag to indicate whether to exclude nugget.
657+ Defaults to False.
658+ return_full: Flag to indicate whether to return the full dictionary,
659+ which includes the mean, Sigma, df, and llik. Defaults to False.
660+ return_std: Flag to indicate whether to return the standard deviation.
661+ Only applicable when return_full is False. Defaults to False.
572662
573663 Returns:
574- float :
575- The predicted output at the new input points.
576- If return_full is True, returns a containing the mean, Sigma (or s2), df, and llik.
577- If return_std is True, returns a tuple containing the mean and standard deviation.
664+ Various formats based on arguments :
665+ - If return_full=True: Dictionary with 'mean', 'Sigma'/'s2', 'df', 'llik'
666+ - If return_std= True: Tuple ( mean, std_deviation)
667+ - Otherwise: Mean predictions
578668
579669 Examples:
580670 import numpy as np
0 commit comments