@@ -204,6 +204,8 @@ def plot_3d_surface(
204204 cmap : str = "jet" ,
205205 error_surface : bool = False ,
206206 add_points : bool = False ,
207+ vmin : float = None ,
208+ vmax : float = None ,
207209) -> None :
208210 """
209211 Plot a 3D surface and scatter input points, colored by prediction error.
@@ -229,7 +231,7 @@ def plot_3d_surface(
229231 Returns:
230232 None
231233 """
232- ax .plot_surface (* Z [:2 ], Z [2 ], cmap = cmap , alpha = alpha ) if isinstance (Z , tuple ) else ax .plot_surface (Z [0 ], Z [1 ], Z [2 ], cmap = cmap , alpha = alpha )
234+ ax .plot_surface (* Z [:2 ], Z [2 ], cmap = cmap , alpha = alpha , vmin = vmin , vmax = vmax ) if isinstance (Z , tuple ) else ax .plot_surface (Z [0 ], Z [1 ], Z [2 ], cmap = cmap , alpha = alpha , vmin = vmin , vmax = vmax )
233235 ax .set_title (surface_label )
234236 ax .set_xlabel (var_names [0 ] if var_names else f"Dimension { i } " )
235237 ax .set_ylabel (var_names [1 ] if var_names else f"Dimension { j } " )
@@ -238,7 +240,7 @@ def plot_3d_surface(
238240 plot_error_points (ax , X , y , model , i , j , eps , max_error , var_names , z_mode = "error" if error_surface else "actual" )
239241
240242
241- def plot_contour (
243+ def plot_contour_and_err (
242244 X_i : np .ndarray ,
243245 X_j : np .ndarray ,
244246 Z : np .ndarray ,
@@ -255,6 +257,8 @@ def plot_contour(
255257 levels : int = 30 ,
256258 title : str = "Prediction Contour" ,
257259 add_points : bool = False ,
260+ vmin : float = None ,
261+ vmax : float = None ,
258262) -> None :
259263 """
260264 Plot a filled contour plot with scatter points colored by prediction error.
@@ -276,11 +280,13 @@ def plot_contour(
276280 levels (int): Number of contour levels.
277281 title (str): Title for the plot.
278282 add_points (bool): If True, adds scatter points to the contour plot.
283+ vmin (float): Minimum value for color scaling.
284+ vmax (float): Maximum value for color scaling.
279285
280286 Returns:
281287 None
282288 """
283- contour = ax .contourf (X_i , X_j , Z , cmap = cmap , levels = levels )
289+ contour = ax .contourf (X_i , X_j , Z , cmap = cmap , levels = levels , vmin = vmin , vmax = vmax )
284290 plt .colorbar (contour , ax = ax )
285291 if add_points and X is not None and y is not None and model is not None :
286292 plot_error_points (ax , X , y , model , i , j , eps , max_error , var_names , title , z_mode = None )
@@ -313,6 +319,9 @@ def plotkd(
313319 var_names : Optional [List [str ]] = None ,
314320 cmap : str = "jet" ,
315321 n_grid : int = 100 ,
322+ vmin : Optional [float ] = None ,
323+ vmax : Optional [float ] = None ,
324+ add_points : bool = False ,
316325) -> None :
317326 """
318327 Plots the Kriging surrogate model for k-dimensional input data by varying two dimensions (i, j).
@@ -330,6 +339,9 @@ def plotkd(
330339 var_names (list of str, optional): List of variable names for axis labeling. If None, generic labels are used.
331340 cmap (str): Colormap for the surface and contour plots. Default is "jet".
332341 n_grid (int): Number of grid points per dimension for the mesh grid. Default is 100.
342+ vmin (float, optional): Minimum value for the color scale. If None, determined from predictions.
343+ vmax (float, optional): Maximum value for the color scale. If None, determined from predictions.
344+ add_points (bool): If True, adds scatter points to the surface and contour plots. Default is False.
333345
334346 Examples:
335347 >>> import numpy as np
@@ -373,6 +385,9 @@ def plotkd(
373385 max_error = max_error ,
374386 cmap = cmap ,
375387 error_surface = False ,
388+ vmin = vmin ,
389+ vmax = vmax ,
390+ add_points = add_points ,
376391 )
377392
378393 # Plot prediction error
@@ -393,11 +408,14 @@ def plotkd(
393408 max_error = max_error ,
394409 cmap = cmap ,
395410 error_surface = True ,
411+ vmin = vmin ,
412+ vmax = vmax ,
413+ add_points = add_points ,
396414 )
397415
398416 # Contour plot of predicted values
399417 ax3 = fig .add_subplot (223 )
400- plot_contour (
418+ plot_contour_and_err (
401419 X_i ,
402420 X_j ,
403421 Z_pred ,
@@ -413,11 +431,14 @@ def plotkd(
413431 cmap = cmap ,
414432 levels = 30 ,
415433 title = "Prediction Contour" ,
434+ vmin = vmin ,
435+ vmax = vmax ,
436+ add_points = add_points ,
416437 )
417438
418439 # Contour plot of prediction error
419440 ax4 = fig .add_subplot (224 )
420- plot_contour (
441+ plot_contour_and_err (
421442 X_i ,
422443 X_j ,
423444 Z_std ,
@@ -433,6 +454,9 @@ def plotkd(
433454 cmap = cmap ,
434455 levels = 30 ,
435456 title = "Error Contour" ,
457+ vmin = vmin ,
458+ vmax = vmax ,
459+ add_points = add_points ,
436460 )
437461
438462 if show :
0 commit comments