1717from sklearn .base import BaseEstimator , RegressorMixin
1818
1919
20- def crude_reset (theta , tmin , tmax , m ):
20+ def crude_reset (theta , tmin , tmax , m ) -> dict :
2121 """
2222 Check whether any elements of the parameter vector ``theta`` lie below the
2323 corresponding elements of the lower bound ``tmin``. If so, reset ``theta``
@@ -31,12 +31,12 @@ def crude_reset(theta, tmin, tmax, m):
3131 m (int): The dimensionality or number of parameters (used to adjust negative ``tmax`` entries).
3232
3333 Returns:
34- dict or None: A dictionary containing:
34+ ( dict) or None: A dictionary containing:
3535 - "theta" (np.ndarray): The reset parameter values.
3636 - "its" (int): Number of iterations (0, indicating immediate reset).
3737 - "msg" (str): Reason for the reset.
3838 - "conv" (int): Reset code (102).
39- Returns None if no reset is needed.
39+ Returns None if no reset is needed.
4040 """
4141 if np .any (theta < tmin ):
4242 print ("resetting due to init on lower boundary" )
@@ -201,7 +201,7 @@ def garg(g, y: np.ndarray = None) -> dict:
201201 and priors for the nugget parameter.
202202
203203 Args:
204- g: Could be a dictionary, numeric, or None. If numeric, turn it into {"start": g}.
204+ g (dict} : Could be a dictionary, numeric, or None. If numeric, turn it into {"start": g}.
205205 y (np.ndarray): The response vector.
206206
207207 Returns:
@@ -393,13 +393,13 @@ def __init__(
393393 }
394394
395395 # Add these two methods required by scikit-learn
396- def get_params (self , deep = True ):
396+ def get_params (self , deep = True ) -> dict :
397397 """Get parameters for this estimator.
398398
399399 This method is required for scikit-learn compatibility.
400400
401401 Args:
402- deep: If True, will return the parameters for this estimator and
402+ deep (bool) : If True, will return the parameters for this estimator and
403403 contained subobjects that are estimators. Defaults to True.
404404
405405 Returns:
@@ -419,16 +419,16 @@ def get_params(self, deep=True):
419419 "seed" : self .seed ,
420420 }
421421
422- def set_params (self , ** parameters ) :
422+ def set_params (self , ** parameters : dict ) -> "GPsep" :
423423 """Set the parameters of this estimator.
424424
425425 This method is required for scikit-learn compatibility.
426426
427427 Args:
428- **parameters: Estimator parameters as keyword arguments.
428+ **parameters (dict) : Estimator parameters as keyword arguments.
429429
430430 Returns:
431- self: Estimator instance.
431+ self (GPsep) : Estimator instance.
432432 """
433433 for parameter , value in parameters .items ():
434434 setattr (self , parameter , value )
@@ -442,18 +442,24 @@ def fit(self, X: np.ndarray, y: np.ndarray, d=None, g=None, dK: bool = True, aut
442442 """Fit the GP model with training data and optionally auto-optimize hyperparameters.
443443
444444 Args:
445- X: array-like of shape (n_samples, n_features)
446- y: array-like of shape (n_samples,)
447- d: The length-scale parameters. If None, will be determined
445+ X (np.ndarray):
446+ Array-like of shape (n_samples, n_features).
447+ y (np.ndarray):
448+ Array-like of shape (n_samples,).
449+ d (Optional[Union[np.ndarray, float]]):
450+ The length-scale parameters. If None, will be determined
448451 automatically. Defaults to None.
449- g: The nugget parameter. If None, will be determined automatically.
450- Defaults to None.
451- dK: Flag to indicate whether to calculate derivatives.
452+ g (Optional[float]):
453+ The nugget parameter. If None, will be determined automatically. Defaults to None.
454+ dK (bool):
455+ Flag to indicate whether to calculate derivatives.
452456 Defaults to True.
453- auto_optimize: Whether to automatically optimize hyperparameters
457+ auto_optimize (Optional[bool]):
458+ Whether to automatically optimize hyperparameters
454459 using MLE. If None, uses the default value from the object.
455460 Defaults to None.
456- verbosity: Verbosity level for optimization output. Defaults to 0.
461+ verbosity (int):
462+ Verbosity level for optimization output. Defaults to 0.
457463
458464 Returns:
459465 GPsep: The fitted GPsep object.
@@ -685,22 +691,30 @@ def _build(self) -> None:
685691 # # raise RuntimeError("dK calculations have already been initialized.")
686692 # self.DK = diff_covar_sep_symm(self.m, self.X, self.n, self.d, self.K)
687693
688- def _check_is_fitted (self ):
694+ def _check_is_fitted (self ) -> None :
695+ """
696+ Check if the GPsep instance is fitted.
697+ """
689698 if not self ._is_fitted :
690699 raise ValueError ("This GPsep instance is not fitted yet. Call 'fit' with " "appropriate arguments before using 'predict'." )
691700
692701 def predict (self , X : np .ndarray , lite : bool = False , nonug : bool = False , return_full = False , return_std = False ) -> float :
693702 """Predict the Gaussian Process output at new input points.
694703
695704 Args:
696- X: The predictive locations.
697- lite: Flag to indicate whether to compute only the diagonal
705+ X (np.ndarray):
706+ The predictive locations.
707+ lite (bool):
708+ Flag to indicate whether to compute only the diagonal
698709 of Sigma. Defaults to False.
699- nonug: Flag to indicate whether to exclude nugget.
710+ nonug (bool):
711+ Flag to indicate whether to exclude nugget.
700712 Defaults to False.
701- return_full: Flag to indicate whether to return the full dictionary,
713+ return_full (bool):
714+ Flag to indicate whether to return the full dictionary,
702715 which includes the mean, Sigma, df, and llik. Defaults to False.
703- return_std: Flag to indicate whether to return the standard deviation.
716+ return_std (bool):
717+ Flag to indicate whether to return the standard deviation.
704718 Only applicable when return_full is False. Defaults to False.
705719
706720 Returns:
@@ -710,34 +724,34 @@ def predict(self, X: np.ndarray, lite: bool = False, nonug: bool = False, return
710724 - Otherwise: Mean predictions
711725
712726 Examples:
713- import numpy as np
714- from spotpython.gp.gp_sep import newGPsep
715- import matplotlib.pyplot as plt
716- # Simple sine data
717- X = np.linspace(0, 2 * np.pi, 7).reshape(-1, 1)
718- y = np.sin(X)
719- # New GP fit
720- gpsep = newGPsep(X, y, d=2, g=0.000001)
721- # Make predictions
722- XX = np.linspace(-1, 2 * np.pi + 1, 499).reshape(-1, 1)
723- p = gpsep.predict(XX, lite=False)
724- # Sample from the predictive distribution
725- N = 100
726- mean = p["mean"]
727- Sigma = p["Sigma"]
728- df = p["df"]
729- # Generate samples from the multivariate t-distribution
730- yy = np.random.multivariate_normal(mean, Sigma, N)
731- yy = yy.T
732- # Plot the results
733- plt.figure(figsize=(10, 6))
734- for i in range(N):
735- plt.plot(XX, yy[:, i], color="gray", linewidth=0.5)
736- plt.scatter(X, y, color="black", s=50, zorder=5)
737- plt.xlabel("x")
738- plt.ylabel("f-hat(x)")
739- plt.title("Predictive Distribution")
740- plt.show()
727+ import numpy as np
728+ from spotpython.gp.gp_sep import newGPsep
729+ import matplotlib.pyplot as plt
730+ # Simple sine data
731+ X = np.linspace(0, 2 * np.pi, 7).reshape(-1, 1)
732+ y = np.sin(X)
733+ # New GP fit
734+ gpsep = newGPsep(X, y, d=2, g=0.000001)
735+ # Make predictions
736+ XX = np.linspace(-1, 2 * np.pi + 1, 499).reshape(-1, 1)
737+ p = gpsep.predict(XX, lite=False)
738+ # Sample from the predictive distribution
739+ N = 100
740+ mean = p["mean"]
741+ Sigma = p["Sigma"]
742+ df = p["df"]
743+ # Generate samples from the multivariate t-distribution
744+ yy = np.random.multivariate_normal(mean, Sigma, N)
745+ yy = yy.T
746+ # Plot the results
747+ plt.figure(figsize=(10, 6))
748+ for i in range(N):
749+ plt.plot(XX, yy[:, i], color="gray", linewidth=0.5)
750+ plt.scatter(X, y, color="black", s=50, zorder=5)
751+ plt.xlabel("x")
752+ plt.ylabel("f-hat(x)")
753+ plt.title("Predictive Distribution")
754+ plt.show()
741755 """
742756 self ._check_is_fitted ()
743757 # if X is a pandas dataframe, convert it to a numpy array
0 commit comments