Skip to content

Commit e43807d

Browse files
sklearn compatibility
1 parent 26ea7f4 commit e43807d

2 files changed

Lines changed: 125 additions & 35 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.26.0"
10+
version = "0.26.1"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/gp/gp_sep.py

Lines changed: 124 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -241,7 +241,31 @@ def garg(g, y: np.ndarray = None) -> dict:
241241

242242

243243
class GPsep:
244-
"""A class to represent a Gaussian Process with separable covariance."""
244+
"""A class to represent a Gaussian Process with separable covariance.
245+
246+
Attributes:
247+
m: Number of input dimensions.
248+
n: Number of observations.
249+
X: Input data matrix.
250+
Z: Output data vector.
251+
d: Length-scale parameters.
252+
g: Nugget parameter.
253+
K: Covariance matrix.
254+
Ki: Inverse of covariance matrix.
255+
KiZ: Product of Ki and Z.
256+
phi: Scalar value from Z^T Ki Z calculation.
257+
dK: Boolean flag for calculating derivatives.
258+
DK: Matrix of derivatives.
259+
ldetK: Log determinant of K.
260+
nlsep_method: Method for likelihood computation.
261+
gradnlsep_method: Method for gradient computation.
262+
n_restarts_optimizer: Number of restarts for optimization.
263+
samp_size: Sample size for distance calculations.
264+
maxit: Maximum number of optimization iterations.
265+
verbosity: Verbosity level.
266+
auto_optimize: Whether to automatically optimize hyperparameters.
267+
max_points: Maximum number of points for model building.
268+
"""
245269

246270
def __init__(
247271
self,
@@ -321,22 +345,91 @@ def __init__(
321345
self.auto_optimize = auto_optimize
322346
self.max_points = max_points
323347

324-
def fit(self, X: np.ndarray, Z: np.ndarray, d=None, g=None, dK: bool = True, auto_optimize: bool = None, verbosity=0) -> "GPsep":
348+
# need to store the initial parameters for the fit method (sklearn compatibility)
349+
self.init_params = {
350+
"X": X,
351+
"Z": Z,
352+
"d": d,
353+
"g": g,
354+
"nlsep_method": nlsep_method,
355+
"gradnlsep_method": gradnlsep_method,
356+
"n_restarts_optimizer": n_restarts_optimizer,
357+
"samp_size": samp_size,
358+
"maxit": maxit,
359+
"verbosity": verbosity,
360+
"auto_optimize": auto_optimize,
361+
"max_points": max_points,
362+
}
363+
364+
# Add these two methods required by scikit-learn
365+
def get_params(self, deep=True):
366+
"""Get parameters for this estimator.
367+
368+
This method is required for scikit-learn compatibility.
369+
370+
Args:
371+
deep: If True, will return the parameters for this estimator and
372+
contained subobjects that are estimators. Defaults to True.
373+
374+
Returns:
375+
dict: Parameter names mapped to their values.
376+
"""
377+
return {
378+
"X": self.X,
379+
"Z": self.Z,
380+
"d": self.d,
381+
"g": self.g,
382+
"nlsep_method": self.nlsep_method,
383+
"gradnlsep_method": self.gradnlsep_method,
384+
"n_restarts_optimizer": self.n_restarts_optimizer,
385+
"samp_size": self.samp_size,
386+
"maxit": self.maxit,
387+
"verbosity": self.verbosity,
388+
"auto_optimize": self.auto_optimize,
389+
"max_points": self.max_points,
390+
}
391+
392+
def set_params(self, **parameters):
393+
"""Set the parameters of this estimator.
394+
395+
This method is required for scikit-learn compatibility.
396+
397+
Args:
398+
**parameters: Estimator parameters as keyword arguments.
399+
400+
Returns:
401+
self: Estimator instance.
325402
"""
326-
Fits the GP model with training data and optionally auto-optimizes hyperparameters.
403+
for parameter, value in parameters.items():
404+
setattr(self, parameter, value)
405+
406+
# Update the stored parameters for potential re-initialization
407+
self.init_params.update(parameters)
408+
409+
return self
410+
411+
def fit(self, X: np.ndarray, Z: np.ndarray, d=None, g=None, dK: bool = True, auto_optimize: bool = None, verbosity=0) -> "GPsep":
412+
"""Fit the GP model with training data and optionally auto-optimize hyperparameters.
327413
328414
Args:
329-
X (np.ndarray): The input data matrix of shape (n, m).
330-
Z (np.ndarray): The output data vector of length n.
331-
d (Union[np.ndarray, float, None]): The length-scale parameters. If None, will be determined automatically.
332-
g (Union[float, None]): The nugget parameter. If None, will be determined automatically.
333-
dK (bool): Flag to indicate whether to calculate derivatives. Default is True.
334-
auto_optimize (bool): Whether to automatically optimize hyperparameters using MLE.
335-
verbosity (int): Verbosity level for optimization output.
336-
auto_optimize (bool): Whether to automatically optimize hyperparameters using MLE. If None, uses the default value from the object, which is True.
415+
X: The input data matrix of shape (n, m).
416+
Z: The output data vector of length n.
417+
d: The length-scale parameters. If None, will be determined
418+
automatically. Defaults to None.
419+
g: The nugget parameter. If None, will be determined automatically.
420+
Defaults to None.
421+
dK: Flag to indicate whether to calculate derivatives.
422+
Defaults to True.
423+
auto_optimize: Whether to automatically optimize hyperparameters
424+
using MLE. If None, uses the default value from the object.
425+
Defaults to None.
426+
verbosity: Verbosity level for optimization output. Defaults to 0.
337427
338428
Returns:
339429
GPsep: The fitted GPsep object.
430+
431+
Raises:
432+
ValueError: If X has no rows or if X and Z dimensions mismatch.
340433
"""
341434
# if X or Z are pandas dataframes or series, convert them to numpy arrays
342435
if hasattr(X, "to_numpy"):
@@ -484,7 +577,6 @@ def gradient(par):
484577
d = result.x[:-1]
485578
g = result.x[-1]
486579

487-
488580
# set new parameters and build
489581
self.set_new_params(d, g)
490582
if self.verbosity > 0:
@@ -547,34 +639,32 @@ def build(self) -> None:
547639
self.Ki = matrix_inversion_dispatcher(self.K, method=self.nlsep_method)
548640
self.ldetK = np.log(det(self.K))
549641
self.calc_ZtKiZ()
550-
if self.dK:
551-
# TODO: Check if this is necessary
552-
# if self.dK is not None:
553-
# raise RuntimeError("dK calculations have already been initialized.")
554-
self.DK = diff_covar_sep_symm(self.m, self.X, self.n, self.d, self.K)
642+
# TODO: Check if this is necessary
643+
# if self.dK:
644+
# # TODO: Check if this is necessary
645+
# # if self.dK is not None:
646+
# # raise RuntimeError("dK calculations have already been initialized.")
647+
# self.DK = diff_covar_sep_symm(self.m, self.X, self.n, self.d, self.K)
555648

556649
def predict(self, XX: np.ndarray, lite: bool = False, nonug: bool = False, return_full=False, return_std=False) -> float:
557-
"""
558-
Predict the Gaussian Process output at new input points.
650+
"""Predict the Gaussian Process output at new input points.
559651
560652
Args:
561-
XX (np.ndarray):
562-
The predictive locations.
563-
lite (bool):
564-
Flag to indicate whether to compute only the diagonal of Sigma.
565-
nonug (bool):
566-
Flag to indicate whether to use nugget.
567-
return_full (bool): Flag to indicate whether to return the full dictionry, which
568-
includes the mean, Sigma, df, and llik. Default is False.
569-
return_std (bool):
570-
Flag to indicate whether to return the standard deviation. Only applicable when
571-
return_full is False. Default is False.
653+
XX: The predictive locations.
654+
lite: Flag to indicate whether to compute only the diagonal
655+
of Sigma. Defaults to False.
656+
nonug: Flag to indicate whether to exclude nugget.
657+
Defaults to False.
658+
return_full: Flag to indicate whether to return the full dictionary,
659+
which includes the mean, Sigma, df, and llik. Defaults to False.
660+
return_std: Flag to indicate whether to return the standard deviation.
661+
Only applicable when return_full is False. Defaults to False.
572662
573663
Returns:
574-
float:
575-
The predicted output at the new input points.
576-
If return_full is True, returns a containing the mean, Sigma (or s2), df, and llik.
577-
If return_std is True, returns a tuple containing the mean and standard deviation.
664+
Various formats based on arguments:
665+
- If return_full=True: Dictionary with 'mean', 'Sigma'/'s2', 'df', 'llik'
666+
- If return_std=True: Tuple (mean, std_deviation)
667+
- Otherwise: Mean predictions
578668
579669
Examples:
580670
import numpy as np

0 commit comments

Comments
 (0)