|
6 | 6 | from scipy.special import erf |
7 | 7 | import matplotlib.pyplot as plt |
8 | 8 | from numpy import linspace, meshgrid, array |
| 9 | +import pylab |
| 10 | +from numpy import ravel |
| 11 | +from spotpython.utils.aggregate import aggregate_mean_var |
9 | 12 |
|
10 | 13 |
|
11 | 14 | class Kriging(BaseEstimator, RegressorMixin): |
@@ -159,17 +162,33 @@ def fit(self, X: np.ndarray, y: np.ndarray, bounds: Optional[List[Tuple[float, f |
159 | 162 | y = np.asarray(y).flatten() |
160 | 163 | self.X_ = X |
161 | 164 | self.y_ = y |
| 165 | + self.n, self.k = X.shape |
| 166 | + # Calculate and store min and max of X |
| 167 | + self.min_X = np.min(self.X_, axis=0) |
| 168 | + self.max_X = np.max(self.X_, axis=0) |
162 | 169 |
|
163 | | - k = X.shape[1] |
| 170 | + _, aggregated_mean_y, _ = aggregate_mean_var(X=self.X_, y=self.y_) |
| 171 | + self.aggregated_mean_y = np.copy(aggregated_mean_y) |
164 | 172 | if bounds is None: |
165 | 173 | if self.method == "interpolation": |
166 | | - bounds = [(-3.0, 2.0)] * k |
| 174 | + bounds = [(-3.0, 2.0)] * self.k |
167 | 175 | else: |
168 | 176 | # regression and reinterpolation use lambda_ as well |
169 | | - bounds = [(-3.0, 2.0)] * k + [(-6.0, 0.0)] |
| 177 | + bounds = [(-3.0, 2.0)] * self.k + [(-6.0, 0.0)] |
170 | 178 |
|
171 | 179 | self.logtheta_lambda_, _ = self.max_likelihood(bounds) |
172 | 180 |
|
| 181 | + # store theta and Lambda in log scale |
| 182 | + if (self.method == "regression") or (self.method == "reinterpolation"): |
| 183 | + # case noise is True |
| 184 | + self.theta = self.logtheta_lambda_[:-1] |
| 185 | + self.Lambda = self.logtheta_lambda_[-1] |
| 186 | + else: |
| 187 | + self.theta = self.logtheta_lambda_ |
| 188 | + self.Lambda = None |
| 189 | + # store p for future use |
| 190 | + self.p = 2 |
| 191 | + |
173 | 192 | # Once logtheta_lambda is found, compute the final correlation matrix |
174 | 193 | self.NegLnLike_, self.Psi_, self.U_ = self.likelihood(self.logtheta_lambda_) |
175 | 194 | return self |
@@ -217,20 +236,25 @@ def predict(self, X: np.ndarray, return_std=False, return_val: str = "y") -> np. |
217 | 236 | if return_std: |
218 | 237 | # Return predictions and standard deviations |
219 | 238 | # Compatibility with scikit-learn |
| 239 | + self.return_std = True |
220 | 240 | predictions, std_devs = zip(*[self._pred(x_i)[:2] for x_i in X]) |
221 | 241 | return np.array(predictions), np.array(std_devs) |
222 | 242 | if return_val == "s": |
223 | 243 | # Return only standard deviations |
| 244 | + self.return_std = True |
224 | 245 | predictions, std_devs = zip(*[self._pred(x_i)[:2] for x_i in X]) |
225 | 246 | return np.array(std_devs) |
226 | 247 | elif return_val == "all": |
227 | 248 | # Return predictions, standard deviations, and expected improvements |
| 249 | + self.return_std = True |
| 250 | + self.return_ei = True |
228 | 251 | predictions, std_devs, eis = zip(*[self._pred(x_i) for x_i in X]) |
229 | 252 | return np.array(predictions), np.array(std_devs), np.array(eis) |
230 | 253 | elif return_val == "ei": |
231 | 254 | # Return only neg. expected improvements |
| 255 | + self.return_ei = True |
232 | 256 | predictions, eis = zip(*[(self._pred(x_i)[0], self._pred(x_i)[2]) for x_i in X]) |
233 | | - return -1.0 * np.array(eis) |
| 257 | + return np.array(eis) |
234 | 258 | else: |
235 | 259 | # Return only predictions (case "y") |
236 | 260 | predictions = [self._pred(x_i)[0] for x_i in X] |
@@ -259,6 +283,7 @@ def likelihood(self, x: np.ndarray) -> Tuple[float, np.ndarray, np.ndarray]: |
259 | 283 | y = self.y_.flatten() |
260 | 284 |
|
261 | 285 | if (self.method == "regression") or (self.method == "reinterpolation"): |
| 286 | + # case noise is True |
262 | 287 | theta = x[:-1] |
263 | 288 | # theta is in log scale, so transform it back: |
264 | 289 | theta = 10.0**theta |
@@ -413,6 +438,108 @@ def objective(logtheta_lambda): |
413 | 438 | result = differential_evolution(objective, bounds) |
414 | 439 | return result.x, result.fun |
415 | 440 |
|
| 441 | + def plot(self, show: Optional[bool] = True) -> None: |
| 442 | + """ |
| 443 | + This function plots 1D and 2D surrogates. |
| 444 | + Only for compatibility with the old Kriging implementation. |
| 445 | +
|
| 446 | + Args: |
| 447 | + self (object): |
| 448 | + The Kriging object. |
| 449 | + show (bool): |
| 450 | + If `True`, the plots are displayed. |
| 451 | + If `False`, `plt.show()` should be called outside this function. |
| 452 | +
|
| 453 | + Returns: |
| 454 | + None |
| 455 | +
|
| 456 | + Note: |
| 457 | + * This method provides only a basic plot. For more advanced plots, |
| 458 | + use the `plot_contour()` method of the `Spot` class. |
| 459 | +
|
| 460 | + Examples: |
| 461 | + >>> import numpy as np |
| 462 | + from spotpython.fun.objectivefunctions import Analytical |
| 463 | + from spotpython.spot import spot |
| 464 | + from spotpython.utils.init import fun_control_init, design_control_init |
| 465 | + # 1-dimensional example |
| 466 | + fun = analytical().fun_sphere |
| 467 | + fun_control=fun_control_init(lower = np.array([-1]), |
| 468 | + upper = np.array([1]), |
| 469 | + noise=False) |
| 470 | + design_control=design_control_init(init_size=10) |
| 471 | + S = spot.Spot(fun=fun, |
| 472 | + fun_control=fun_control, |
| 473 | + design_control=design_control) |
| 474 | + S.initialize_design() |
| 475 | + S.update_stats() |
| 476 | + S.fit_surrogate() |
| 477 | + S.surrogate.plot() |
| 478 | + # 2-dimensional example |
| 479 | + fun = analytical().fun_sphere |
| 480 | + fun_control=fun_control_init(lower = np.array([-1, -1]), |
| 481 | + upper = np.array([1, 1]), |
| 482 | + noise=False) |
| 483 | + design_control=design_control_init(init_size=10) |
| 484 | + S = spot.Spot(fun=fun, |
| 485 | + fun_control=fun_control, |
| 486 | + design_control=design_control) |
| 487 | + S.initialize_design() |
| 488 | + S.update_stats() |
| 489 | + S.fit_surrogate() |
| 490 | + S.surrogate.plot() |
| 491 | + """ |
| 492 | + if self.k == 1: |
| 493 | + # TODO: Improve plot (add conf. interval etc.) |
| 494 | + fig = pylab.figure(figsize=(9, 6)) |
| 495 | + n_grid = 100 |
| 496 | + x = linspace(self.min_X[0], self.max_X[0], num=n_grid) |
| 497 | + y = self.predict(x) |
| 498 | + plt.figure() |
| 499 | + plt.plot(x, y, "k") |
| 500 | + if show: |
| 501 | + plt.show() |
| 502 | + |
| 503 | + if self.k == 2: |
| 504 | + fig = pylab.figure(figsize=(9, 6)) |
| 505 | + n_grid = 100 |
| 506 | + x = linspace(self.min_X[0], self.max_X[0], num=n_grid) |
| 507 | + y = linspace(self.min_X[1], self.max_X[1], num=n_grid) |
| 508 | + X, Y = meshgrid(x, y) |
| 509 | + # Predict based on the optimized results |
| 510 | + zz = array([self.predict(array([x, y]), return_val="all") for x, y in zip(ravel(X), ravel(Y))]) |
| 511 | + zs = zz[:, 0, :] |
| 512 | + zse = zz[:, 1, :] |
| 513 | + Z = zs.reshape(X.shape) |
| 514 | + Ze = zse.reshape(X.shape) |
| 515 | + |
| 516 | + nat_point_X = self.X_[:, 0] |
| 517 | + nat_point_Y = self.X_[:, 1] |
| 518 | + contour_levels = 30 |
| 519 | + ax = fig.add_subplot(224) |
| 520 | + # plot predicted values: |
| 521 | + pylab.contourf(X, Y, Ze, contour_levels, cmap="jet") |
| 522 | + pylab.title("Error") |
| 523 | + pylab.colorbar() |
| 524 | + # plot observed points: |
| 525 | + pylab.plot(nat_point_X, nat_point_Y, "ow") |
| 526 | + # |
| 527 | + ax = fig.add_subplot(223) |
| 528 | + # plot predicted values: |
| 529 | + plt.contourf(X, Y, Z, contour_levels, zorder=1, cmap="jet") |
| 530 | + plt.title("Surrogate") |
| 531 | + # plot observed points: |
| 532 | + pylab.plot(nat_point_X, nat_point_Y, "ow", zorder=3) |
| 533 | + pylab.colorbar() |
| 534 | + # |
| 535 | + ax = fig.add_subplot(221, projection="3d") |
| 536 | + ax.plot_surface(X, Y, Z, rstride=3, cstride=3, alpha=0.9, cmap="jet") |
| 537 | + # |
| 538 | + ax = fig.add_subplot(222, projection="3d") |
| 539 | + ax.plot_surface(X, Y, Ze, rstride=3, cstride=3, alpha=0.9, cmap="jet") |
| 540 | + # |
| 541 | + pylab.show() |
| 542 | + |
416 | 543 |
|
417 | 544 | # Additional functions for plotting the Kriging surrogate model |
418 | 545 | # ------------------------------------------------------------ |
|
0 commit comments