@@ -344,6 +344,21 @@ def _get_theta10_from_logtheta(self) -> np.ndarray:
344344 theta10 = theta10 * np .ones (self .k )
345345 return theta10
346346
347+ def _reshape_X (self , X : np .ndarray ) -> np .ndarray :
348+ # Ensure X has shape (n_samples, n_features=self.k)
349+ X = np .asarray (X )
350+ if X .ndim == 1 :
351+ X = X .reshape (- 1 , self .k )
352+ else :
353+ if X .shape [1 ] != self .k :
354+ if X .shape [0 ] == self .k : # common case: row/col swap for 1D
355+ X = X .T
356+ elif self .k == 1 :
357+ X = X .reshape (- 1 , 1 )
358+ else :
359+ raise ValueError (f"X has shape { X .shape } , expected (*, { self .k } )." )
360+ return X
361+
347362 def fit (self , X : np .ndarray , y : np .ndarray , bounds : Optional [List [Tuple [float , float ]]] = None ) -> "Kriging" :
348363 """
349364 Fits the Kriging model to training data X and y. This method is compatible
@@ -461,19 +476,7 @@ def predict(self, X: np.ndarray, return_std=False, return_val: str = "y") -> np.
461476 >>> print("Predictions:", y_pred)
462477 """
463478 self .return_std = return_std
464- X = np .asarray (X )
465-
466- # Ensure X has shape (n_samples, n_features=self.k)
467- if X .ndim == 1 :
468- X = X .reshape (- 1 , self .k )
469- else :
470- if X .shape [1 ] != self .k :
471- if X .shape [0 ] == self .k : # common case: row/col swap for 1D
472- X = X .T
473- elif self .k == 1 :
474- X = X .reshape (- 1 , 1 )
475- else :
476- raise ValueError (f"X has shape { X .shape } , expected (*, { self .k } )." )
479+ X = self ._reshape_X (X )
477480
478481 if return_std :
479482 # Return predictions and standard deviations
0 commit comments