@@ -194,12 +194,18 @@ def fit(self, X: np.ndarray, y: np.ndarray, bounds: Optional[List[Tuple[float, f
194194 y (np.ndarray):
195195 Target values of shape (n_samples,) or (n_samples, 1).
196196 bounds (Optional[List[Tuple[float, float]]]):
197- Bounds for each dimension of log(theta). If None, defaults to [(-3, 2)] * n_features.
197+ Bounds for each dimension of log(theta). If None, defaults to
198+ [(-3, 2)] * n_features for interpolation, or
199+ [(-3, 2)] * n_features + [(-6, 0)] for regression/reinterpolation.
198200
199201 Returns:
200202 Kriging:
201203 The fitted Kriging model instance (self).
202204
205+ Raises:
206+ ValueError: If input data has invalid shape or contains invalid values.
207+ RuntimeError: If optimization fails or correlation matrix is singular.
208+
203209 Examples:
204210 >>> import numpy as np
205211 >>> from spotpython.surrogate.kriging import Kriging
@@ -211,42 +217,141 @@ def fit(self, X: np.ndarray, y: np.ndarray, bounds: Optional[List[Tuple[float, f
211217 >>> model.fit(X_train, y_train)
212218 >>> print("Fitted log(theta):", model.logtheta_lambda_)
213219 """
214- X = np .asarray (X )
215- y = np .asarray (y ).flatten ()
216- self .X_ = X
217- self .y_ = y
220+ # Input validation and preprocessing
221+ X = np .asarray (X , dtype = np .float64 )
222+ y = np .asarray (y , dtype = np .float64 ).flatten ()
223+
224+ # Validate input shapes
225+ if X .ndim != 2 :
226+ raise ValueError (f"X must be a 2D array, got { X .ndim } D array with shape { X .shape } " )
227+
228+ if y .ndim != 1 :
229+ raise ValueError (f"y must be a 1D array, got { y .ndim } D array with shape { y .shape } " )
230+
231+ if X .shape [0 ] != y .shape [0 ]:
232+ raise ValueError (f"Number of samples in X ({ X .shape [0 ]} ) must match number of samples in y ({ y .shape [0 ]} )" )
233+
234+ # Check for minimum number of samples
235+ if X .shape [0 ] < 2 :
236+ raise ValueError ("At least 2 samples are required for fitting" )
237+
238+ # Check for invalid values
239+ if not np .all (np .isfinite (X )):
240+ raise ValueError ("X contains non-finite values (NaN or inf)" )
241+
242+ if not np .all (np .isfinite (y )):
243+ raise ValueError ("y contains non-finite values (NaN or inf)" )
244+
245+ # Store training data FIRST before aggregation
246+ self .X_ = X .copy () # Create a copy to avoid external modifications
247+ self .y_ = y .copy ()
218248 self .n , self .k = X .shape
219- # Calculate and store min and max of X
249+
250+ # Calculate and store min and max of X for plotting and validation
220251 self .min_X = np .min (self .X_ , axis = 0 )
221252 self .max_X = np .max (self .X_ , axis = 0 )
222253
223- _ , aggregated_mean_y , _ = aggregate_mean_var (X = self .X_ , y = self .y_ )
224- self .aggregated_mean_y = np .copy (aggregated_mean_y )
254+ # Aggregate data for duplicates (if any) - NOW self.X_ and self.y_ are available
255+ try :
256+ _ , aggregated_mean_y , _ = aggregate_mean_var (X = self .X_ , y = self .y_ )
257+ self .aggregated_mean_y = np .copy (aggregated_mean_y )
258+ except Exception as e :
259+ raise RuntimeError (f"Failed to aggregate training data: { e } " )
260+
261+ # Check for duplicate rows (which can cause numerical issues)
262+ if X .shape [0 ] > 1 :
263+ unique_rows = np .unique (X , axis = 0 )
264+ if len (unique_rows ) != X .shape [0 ] and self .method == "interpolation" :
265+ logger .warning (f"Found { X .shape [0 ] - len (unique_rows )} duplicate rows in X. " "This may cause numerical issues with interpolation method." )
266+
267+ # Check for zero variance in any dimension
268+ if np .any (self .max_X - self .min_X == 0 ):
269+ zero_var_dims = np .where (self .max_X - self .min_X == 0 )[0 ]
270+ logger .warning (f"Zero variance detected in dimensions { zero_var_dims } . " "This may cause numerical issues." )
271+
272+ # Set optimization bounds
225273 if bounds is None :
226274 if self .method == "interpolation" :
227- bounds = [(- 3.0 , 2.0 )] * self .k
275+ bounds = [(self . min_theta , self . max_theta )] * self .k
228276 else :
229277 # regression and reinterpolation use lambda_ as well
230- bounds = [(- 3.0 , 2.0 )] * self .k + [(- 6.0 , 0.0 )]
278+ bounds = [(self .min_theta , self .max_theta )] * self .k + [(np .log10 (self .min_Lambda ), np .log10 (self .max_Lambda ))]
279+ else :
280+ # Validate user-provided bounds
281+ expected_length = self .k if self .method == "interpolation" else self .k + 1
282+ if len (bounds ) != expected_length :
283+ raise ValueError (f"bounds must have length { expected_length } for method '{ self .method } ', " f"got { len (bounds )} " )
284+
285+ # Validate individual bounds
286+ for i , (low , high ) in enumerate (bounds ):
287+ if not (isinstance (low , (int , float )) and isinstance (high , (int , float ))):
288+ raise ValueError (f"bounds[{ i } ] must contain numeric values" )
289+ if low >= high :
290+ raise ValueError (f"bounds[{ i } ]: lower bound ({ low } ) must be less than upper bound ({ high } )" )
291+
292+ # Optimize hyperparameters
293+ try :
294+ logger .info (f"Starting hyperparameter optimization with bounds: { bounds } " )
295+ self .logtheta_lambda_ , final_likelihood = self .max_likelihood (bounds )
296+ logger .info (f"Optimization completed. Final likelihood: { final_likelihood } " )
297+ except Exception as e :
298+ raise RuntimeError (f"Hyperparameter optimization failed: { e } " )
231299
232- self .logtheta_lambda_ , _ = self .max_likelihood (bounds )
300+ # Validate optimization results
301+ if not np .all (np .isfinite (self .logtheta_lambda_ )):
302+ raise RuntimeError ("Optimization resulted in non-finite hyperparameters" )
233303
234- # store theta and Lambda in log scale
304+ # Extract and store theta and Lambda parameters
235305 if (self .method == "regression" ) or (self .method == "reinterpolation" ):
236- # case noise is True
237306 self .theta = self .logtheta_lambda_ [:- 1 ]
238307 self .Lambda = self .logtheta_lambda_ [- 1 ]
239308 else :
240309 self .theta = self .logtheta_lambda_
241310 self .Lambda = None
242- # store p for future use
311+
312+ # Store p for future use (currently fixed at 2)
243313 self .p = 2
244314
245- # Once logtheta_lambda is found, compute the final correlation matrix
246- self .negLnLike , self .Psi_ , self .U_ = self .likelihood (self .logtheta_lambda_ )
315+ # Compute final correlation matrix and validate
316+ try :
317+ self .negLnLike , self .Psi_ , self .U_ = self .likelihood (self .logtheta_lambda_ )
318+
319+ # Check if correlation matrix is well-conditioned
320+ if self .U_ is None :
321+ raise RuntimeError ("Failed to compute Cholesky decomposition of correlation matrix" )
322+
323+ # Check condition number
324+ if hasattr (self , "Psi_" ) and self .Psi_ is not None :
325+ try :
326+ cond_num = np .linalg .cond (self .Psi_ )
327+ if cond_num > 1e12 :
328+ logger .warning (f"Correlation matrix is ill-conditioned (condition number: { cond_num :.2e} )" )
329+ except np .linalg .LinAlgError :
330+ logger .warning ("Could not compute condition number of correlation matrix" )
331+
332+ except Exception as e :
333+ raise RuntimeError (f"Failed to compute final correlation matrix: { e } " )
334+
335+ # Final validation
336+ if not np .isfinite (self .negLnLike ):
337+ raise RuntimeError ("Final likelihood is not finite" )
338+
339+ # Update logging information
340+ try :
341+ self ._update_log ()
342+ except Exception as e :
343+ logger .warning (f"Failed to update log: { e } " )
344+
345+ # Log fitting summary
346+ logger .info ("Kriging model fitted successfully:" )
347+ logger .info (f" - Method: { self .method } " )
348+ logger .info (f" - Training samples: { self .n } " )
349+ logger .info (f" - Features: { self .k } " )
350+ logger .info (f" - Final negative log-likelihood: { self .negLnLike :.6f} " )
351+ logger .info (f" - Theta parameters: { self .theta } " )
352+ if self .Lambda is not None :
353+ logger .info (f" - Lambda parameter: { self .Lambda :.6f} " )
247354
248- # Update log with the current values
249- self ._update_log ()
250355 return self
251356
252357 def predict (self , X : np .ndarray , return_std = False , return_val : str = "y" ) -> np .ndarray :
0 commit comments