44from scipy .optimize import differential_evolution
55from sklearn .base import BaseEstimator , RegressorMixin
66from scipy .special import erf
7+ import matplotlib .pyplot as plt
8+ from numpy import linspace , meshgrid , array
79
810
911class Kriging (BaseEstimator , RegressorMixin ):
@@ -53,6 +55,8 @@ def __init__(self, eps: float = None, penalty: float = 1e4, method="regression")
5355 if method not in ["interpolation" , "regression" , "reinterpolation" ]:
5456 raise ValueError ("method must be one of 'interpolation', 'regression', or 'reinterpolation']" )
5557 self .method = method
58+ self .return_ei = False
59+ self .return_std = False
5660
5761 def _get_eps (self ) -> float :
5862 """
@@ -159,6 +163,8 @@ def predict(self, X: np.ndarray, return_std=False, return_ei=False) -> np.ndarra
159163 >>> print("Standard deviations:", sd)
160164 >>> print("Expected improvement:", ei)
161165 """
166+ self .return_std = return_std
167+ self .return_ei = return_ei
162168 X = np .atleast_2d (X )
163169 if return_std and return_ei :
164170 # Return predictions, standard deviations, and expected improvements
@@ -325,12 +331,14 @@ def _pred(self, x: np.ndarray) -> float:
325331 f = mu + psi @ resid_tilde
326332
327333 # Compute ExpImp
328- yBest = np .min (y )
329- EITermOne = (yBest - f ) * (0.5 + 0.5 * erf ((1 / np .sqrt (2 )) * ((yBest - f ) / s )))
330- EITermTwo = s * (1 / np .sqrt (2 * np .pi )) * np .exp (- 0.5 * ((yBest - f ) ** 2 / SSqr ))
331- ExpImp = np .log10 (EITermOne + EITermTwo + self .eps )
332-
333- return float (f ), float (s ), float (- ExpImp )
334+ if self .return_ei :
335+ yBest = np .min (y )
336+ EITermOne = (yBest - f ) * (0.5 + 0.5 * erf ((1 / np .sqrt (2 )) * ((yBest - f ) / s )))
337+ EITermTwo = s * (1 / np .sqrt (2 * np .pi )) * np .exp (- 0.5 * ((yBest - f ) ** 2 / SSqr ))
338+ ExpImp = np .log10 (EITermOne + EITermTwo + self .eps )
339+ return float (f ), float (s ), float (- ExpImp )
340+ else :
341+ return float (f ), float (s )
334342
335343 def max_likelihood (self , bounds : List [Tuple [float , float ]]) -> Tuple [np .ndarray , float ]:
336344 """
@@ -351,3 +359,103 @@ def objective(logtheta_lambda):
351359
352360 result = differential_evolution (objective , bounds )
353361 return result .x , result .fun
362+
363+ def plot (self , show : Optional [bool ] = True , alpha = 0.8 ) -> None :
364+ """
365+ This function plots 1D and 2D surrogates.
366+
367+ Args:
368+ show (bool):
369+ If `True`, the plots are displayed.
370+ If `False`, `plt.show()` should be called outside this function.
371+
372+ Returns:
373+ None
374+
375+ Examples:
376+ >>> model = Kriging()
377+ >>> model.fit(X_train, y_train)
378+ >>> model.plot()
379+ """
380+ if self .X_ is None or self .y_ is None :
381+ raise ValueError ("The model must be fitted before calling the plot method." )
382+
383+ k = self .X_ .shape [1 ] # Number of dimensions
384+
385+ if k == 1 :
386+ # 1D Plot
387+ fig = plt .figure (figsize = (9 , 6 ))
388+ n_grid = 100
389+ x = linspace (self .X_ [:, 0 ].min (), self .X_ [:, 0 ].max (), num = n_grid ).reshape (- 1 , 1 )
390+ y_pred , y_std = self .predict (x , return_std = True )
391+
392+ plt .plot (x , y_pred , "k" , label = "Prediction" )
393+ plt .fill_between (
394+ x .ravel (),
395+ y_pred - 1.96 * y_std ,
396+ y_pred + 1.96 * y_std ,
397+ alpha = 0.2 ,
398+ label = "95% Confidence Interval" ,
399+ )
400+ plt .scatter (self .X_ , self .y_ , color = "red" , label = "Training Data" )
401+ plt .xlabel ("X" )
402+ plt .ylabel ("Prediction" )
403+ plt .title ("1D Kriging Surrogate" )
404+ plt .legend ()
405+ if show :
406+ plt .show ()
407+
408+ elif k == 2 :
409+ # 2D Plot
410+ fig = plt .figure (figsize = (12 , 10 ))
411+ n_grid = 100
412+ x1 = linspace (self .X_ [:, 0 ].min (), self .X_ [:, 0 ].max (), num = n_grid )
413+ x2 = linspace (self .X_ [:, 1 ].min (), self .X_ [:, 1 ].max (), num = n_grid )
414+ X1 , X2 = meshgrid (x1 , x2 )
415+ grid_points = array ([X1 .ravel (), X2 .ravel ()]).T
416+
417+ y_pred , y_std = self .predict (grid_points , return_std = True )
418+ Z_pred = y_pred .reshape (X1 .shape )
419+ Z_std = y_std .reshape (X1 .shape )
420+
421+ # Plot predicted values
422+ ax1 = fig .add_subplot (221 , projection = "3d" )
423+ ax1 .plot_surface (X1 , X2 , Z_pred , cmap = "viridis" , alpha = alpha )
424+ ax1 .set_title ("Prediction Surface" )
425+ ax1 .set_xlabel ("X1" )
426+ ax1 .set_ylabel ("X2" )
427+ ax1 .set_zlabel ("Prediction" )
428+
429+ # Plot prediction error
430+ ax2 = fig .add_subplot (222 , projection = "3d" )
431+ ax2 .plot_surface (X1 , X2 , Z_std , cmap = "viridis" , alpha = alpha )
432+ ax2 .set_title ("Prediction Error Surface" )
433+ ax2 .set_xlabel ("X1" )
434+ ax2 .set_ylabel ("X2" )
435+ ax2 .set_zlabel ("Error" )
436+
437+ # Contour plot of predicted values
438+ ax3 = fig .add_subplot (223 )
439+ contour = ax3 .contourf (X1 , X2 , Z_pred , cmap = "viridis" , levels = 30 )
440+ plt .colorbar (contour , ax = ax3 )
441+ ax3 .scatter (self .X_ [:, 0 ], self .X_ [:, 1 ], color = "red" , label = "Training Data" )
442+ ax3 .set_title ("Prediction Contour" )
443+ ax3 .set_xlabel ("X1" )
444+ ax3 .set_ylabel ("X2" )
445+ ax3 .legend ()
446+
447+ # Contour plot of prediction error
448+ ax4 = fig .add_subplot (224 )
449+ contour = ax4 .contourf (X1 , X2 , Z_std , cmap = "viridis" , levels = 30 )
450+ plt .colorbar (contour , ax = ax4 )
451+ ax4 .scatter (self .X_ [:, 0 ], self .X_ [:, 1 ], color = "red" , label = "Training Data" )
452+ ax4 .set_title ("Error Contour" )
453+ ax4 .set_xlabel ("X1" )
454+ ax4 .set_ylabel ("X2" )
455+ ax4 .legend ()
456+
457+ if show :
458+ plt .show ()
459+
460+ else :
461+ raise ValueError ("Plotting is only supported for 1D or 2D input data." )
0 commit comments