Skip to content

Commit d9f13db

Browse files
0.33.15
1 parent 2df8dee commit d9f13db

2 files changed

Lines changed: 17 additions & 14 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.33.14"
10+
version = "0.33.15"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/surrogate/kriging.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,21 @@ def _get_theta10_from_logtheta(self) -> np.ndarray:
344344
theta10 = theta10 * np.ones(self.k)
345345
return theta10
346346

347+
def _reshape_X(self, X: np.ndarray) -> np.ndarray:
348+
# Ensure X has shape (n_samples, n_features=self.k)
349+
X = np.asarray(X)
350+
if X.ndim == 1:
351+
X = X.reshape(-1, self.k)
352+
else:
353+
if X.shape[1] != self.k:
354+
if X.shape[0] == self.k: # common case: row/col swap for 1D
355+
X = X.T
356+
elif self.k == 1:
357+
X = X.reshape(-1, 1)
358+
else:
359+
raise ValueError(f"X has shape {X.shape}, expected (*, {self.k}).")
360+
return X
361+
347362
def fit(self, X: np.ndarray, y: np.ndarray, bounds: Optional[List[Tuple[float, float]]] = None) -> "Kriging":
348363
"""
349364
Fits the Kriging model to training data X and y. This method is compatible
@@ -461,19 +476,7 @@ def predict(self, X: np.ndarray, return_std=False, return_val: str = "y") -> np.
461476
>>> print("Predictions:", y_pred)
462477
"""
463478
self.return_std = return_std
464-
X = np.asarray(X)
465-
466-
# Ensure X has shape (n_samples, n_features=self.k)
467-
if X.ndim == 1:
468-
X = X.reshape(-1, self.k)
469-
else:
470-
if X.shape[1] != self.k:
471-
if X.shape[0] == self.k: # common case: row/col swap for 1D
472-
X = X.T
473-
elif self.k == 1:
474-
X = X.reshape(-1, 1)
475-
else:
476-
raise ValueError(f"X has shape {X.shape}, expected (*, {self.k}).")
479+
X = self._reshape_X(X)
477480

478481
if return_std:
479482
# Return predictions and standard deviations

0 commit comments

Comments
 (0)