@@ -22,7 +22,32 @@ class Kriging(BaseEstimator, RegressorMixin):
2222 y_ (np.ndarray): The training target values (n,).
2323 """
2424
25- def __init__ (self , eps : float = None , penalty : float = 1e4 , method = "regression" ):
25+ def __init__ (
26+ self ,
27+ eps : float = None ,
28+ penalty : float = 1e4 ,
29+ method = "regression" ,
30+ noise : bool = False ,
31+ var_type : List [str ] = ["num" ],
32+ name : str = "Kriging" ,
33+ seed : int = 124 ,
34+ model_optimizer = None ,
35+ model_fun_evals : Optional [int ] = None ,
36+ min_theta : float = - 3.0 ,
37+ max_theta : float = 2.0 ,
38+ n_theta : int = 1 ,
39+ theta_init_zero : bool = False ,
40+ p_val : float = 2.0 ,
41+ n_p : int = 1 ,
42+ optim_p : bool = False ,
43+ min_Lambda : float = 1e-9 ,
44+ max_Lambda : float = 1.0 ,
45+ log_level : int = 50 ,
46+ spot_writer = None ,
47+ counter = None ,
48+ metric_factorial = "canberra" ,
49+ ** kwargs ,
50+ ):
2651 """
2752 Initializes the Kriging model.
2853
@@ -46,6 +71,31 @@ def __init__(self, eps: float = None, penalty: float = 1e4, method="regression")
4671 raise ValueError ("eps must be positive" )
4772 self .eps = eps
4873 self .penalty = penalty
74+
75+ self .noise = noise
76+ self .var_type = var_type
77+ self .name = name
78+ self .seed = seed
79+ self .log_level = log_level
80+ self .spot_writer = spot_writer
81+ self .counter = counter
82+ self .metric_factorial = metric_factorial
83+ self .min_theta = min_theta
84+ self .max_theta = max_theta
85+ self .min_Lambda = min_Lambda
86+ self .max_Lambda = max_Lambda
87+ self .n_theta = n_theta
88+ self .p_val = p_val
89+ self .n_p = n_p
90+ self .optim_p = optim_p
91+ self .theta_init_zero = theta_init_zero
92+ self .model_optimizer = model_optimizer
93+ if self .model_optimizer is None :
94+ self .model_optimizer = differential_evolution
95+ self .model_fun_evals = model_fun_evals
96+ if self .model_fun_evals is None :
97+ self .model_fun_evals = 100
98+
4999 self .logtheta_lambda_ = None
50100 self .U_ = None
51101 self .X_ = None
@@ -124,7 +174,7 @@ def fit(self, X: np.ndarray, y: np.ndarray, bounds: Optional[List[Tuple[float, f
124174 self .NegLnLike_ , self .Psi_ , self .U_ = self .likelihood (self .logtheta_lambda_ )
125175 return self
126176
127- def predict (self , X : np .ndarray , return_std = False , return_ei = False ) -> np .ndarray :
177+ def predict (self , X : np .ndarray , return_std = False , return_val : str = "y" ) -> np .ndarray :
128178 """
129179 Predicts the Kriging response at a set of points X. This method is compatible
130180 with scikit-learn and returns predictions for the input points.
@@ -135,10 +185,11 @@ def predict(self, X: np.ndarray, return_std=False, return_ei=False) -> np.ndarra
135185 to predict the Kriging response.
136186 return_std (bool, optional):
137187 If True, returns the standard deviation of the predictions as well.
188+ Implememented for compatibility with scikit-learn.
138189 Defaults to False.
139- return_ei (bool, optional ):
140- If True, returns the expected improvement at each point .
141- Defaults to False .
190+ return_val (str ):
191+ Specifies which prediction values to return .
192+ It can be "y", "s", "ei", or "all" .
142193
143194 Returns:
144195 np.ndarray:
@@ -160,26 +211,28 @@ def predict(self, X: np.ndarray, return_std=False, return_ei=False) -> np.ndarra
160211 >>> # Predict responses
161212 >>> y_pred, sd, ei = model.predict(X_test)
162213 >>> print("Predictions:", y_pred)
163- >>> print("Standard deviations:", sd)
164- >>> print("Expected improvement:", ei)
165214 """
166215 self .return_std = return_std
167- self .return_ei = return_ei
168216 X = np .atleast_2d (X )
169- if return_std and return_ei :
170- # Return predictions, standard deviations, and expected improvements
171- predictions , std_devs , eis = zip (* [self ._pred (x_i ) for x_i in X ])
172- return np .array (predictions ), np .array (std_devs ), np .array (eis )
173- elif return_std :
217+ if return_std :
174218 # Return predictions and standard deviations
219+ # Compatibility with scikit-learn
175220 predictions , std_devs = zip (* [self ._pred (x_i )[:2 ] for x_i in X ])
176221 return np .array (predictions ), np .array (std_devs )
177- elif return_ei :
178- # Return predictions and expected improvements
222+ if return_val == "s" :
223+ # Return only standard deviations
224+ predictions , std_devs = zip (* [self ._pred (x_i )[:2 ] for x_i in X ])
225+ return np .array (std_devs )
226+ elif return_val == "all" :
227+ # Return predictions, standard deviations, and expected improvements
228+ predictions , std_devs , eis = zip (* [self ._pred (x_i ) for x_i in X ])
229+ return np .array (predictions ), np .array (std_devs ), np .array (eis )
230+ elif return_val == "ei" :
231+ # Return only neg. expected improvements
179232 predictions , eis = zip (* [(self ._pred (x_i )[0 ], self ._pred (x_i )[2 ]) for x_i in X ])
180- return np . array ( predictions ), np .array (eis )
233+ return - 1.0 * np .array (eis )
181234 else :
182- # Return only predictions
235+ # Return only predictions (case "y")
183236 predictions = [self ._pred (x_i )[0 ] for x_i in X ]
184237 return np .array (predictions )
185238
0 commit comments