@@ -75,6 +75,20 @@ def generate_mesh_grid(
7575 X_i (np.ndarray): Meshgrid for the i-th dimension.
7676 X_j (np.ndarray): Meshgrid for the j-th dimension.
7777 grid_points (np.ndarray): Grid points of shape (num*num, k) for prediction.
78+
79+ Examples:
80+ >>> import numpy as np
81+ >>> from spotpython.surrogate.plot import generate_mesh_grid
82+ >>> # Example 1: Using input data
83+ >>> X = np.random.rand(4, 3) # 5 samples with 3 dimensions
84+ >>> print(f"X:\n {X}")
85+ >>> X_i, X_j, grid_points = generate_mesh_grid(X, i=0, j=1, num=5)
86+ >>> print(f"X_i:\n {X_i},\n X_j:\n {X_j},\n grid_points:\n {grid_points}")
87+ >>> # Example 2: Using lower and upper bounds
88+ >>> lower = np.array([-5, 0, 0])
89+ >>> upper = np.array([10, 15, 3])
90+ >>> X_i, X_j, grid_points = generate_mesh_grid(lower=lower, upper=upper, i=0, j=1, num=5)
91+ >>> print(f"X_i:\n {X_i},\n X_j:\n {X_j},\n grid_points:\n {grid_points}")
7892 """
7993 # Check that exactly one of (X) or (lower and upper) is provided
8094 if (X is not None and (lower is not None or upper is not None )) or (X is None and (lower is None or upper is None )):
@@ -91,17 +105,6 @@ def generate_mesh_grid(
91105 x_i = linspace (lower [i ], upper [i ], num = num )
92106 x_j = linspace (lower [j ], upper [j ], num = num )
93107
94- # Masked rounding (using floor) for integer or factor variables
95- if var_type is not None :
96- # For x_i
97- if hasattr (x_i , "__len__" ):
98- mask_i = np .array ([var_type [i ] != "num" ] * len (x_i ))
99- x_i = np .where (mask_i , np .floor (x_i ), x_i )
100- # For x_j
101- if hasattr (x_j , "__len__" ):
102- mask_j = np .array ([var_type [j ] != "num" ] * len (x_j ))
103- x_j = np .where (mask_j , np .floor (x_j ), x_j )
104-
105108 X_i , X_j = meshgrid (x_i , x_j )
106109 grid_points = np .zeros ((X_i .size , k ))
107110 grid_points [:, i ] = X_i .ravel ()
@@ -112,6 +115,10 @@ def generate_mesh_grid(
112115 if dim != i and dim != j :
113116 grid_points [:, dim ] = mean_values [dim ]
114117
118+ # Apply floor to mean_values for non-"num" columns if var_type is provided
119+ if var_type is not None :
120+ grid_points = np .where (np .array ([vt != "num" for vt in var_type ]), np .floor (grid_points + 0.5 ), grid_points )
121+
115122 return X_i , X_j , grid_points
116123
117124
@@ -167,7 +174,7 @@ def plot_error_points(
167174 j : int ,
168175 eps : float = 1e-4 ,
169176 max_error : float = 1e-3 ,
170- var_names : Optional [List [str ]] = None ,
177+ var_name : Optional [List [str ]] = None ,
171178 title : Optional [str ] = None ,
172179 z_mode : str = "actual" , # "actual", "error", or None
173180) -> None :
@@ -183,7 +190,7 @@ def plot_error_points(
183190 j (int): Index of second varied dimension.
184191 eps (float): Tolerance for coloring points based on prediction error.
185192 max_error (float): Maximum error for color scaling.
186- var_names (list of str or None): List of axis labels or None.
193+ var_name (list of str or None): List of axis labels or None.
187194 title (str or None): Title for the plot.
188195 z_mode (str): "actual" for z_actual (for 3D), "error" for abs error (for 3D error surface), or None (for 2D).
189196 """
@@ -203,9 +210,9 @@ def plot_error_points(
203210 ax .scatter (x_point , y_point_ , color = color , s = 50 , edgecolor = "black" )
204211 if title is not None :
205212 ax .set_title (title )
206- if var_names is not None :
207- ax .set_xlabel (var_names [0 ])
208- ax .set_ylabel (var_names [1 ])
213+ if var_name is not None :
214+ ax .set_xlabel (var_name [0 ])
215+ ax .set_ylabel (var_name [1 ])
209216 else :
210217 ax .set_xlabel (f"Dimension { i } " )
211218 ax .set_ylabel (f"Dimension { j } " )
@@ -220,8 +227,8 @@ def plot_3d_surface(
220227 y : np .ndarray = None ,
221228 model = None ,
222229 surface_label : str = "Prediction Surface" ,
223- zlabel : str = "Prediction " ,
224- var_names : Optional [List [str ]] = None ,
230+ zlabel : str = "y " ,
231+ var_name : Optional [List [str ]] = None ,
225232 alpha : float = 0.8 ,
226233 eps : float = 1e-4 ,
227234 max_error : float = 1e-3 ,
@@ -244,7 +251,7 @@ def plot_3d_surface(
244251 model (object): Fitted model with predict().
245252 surface_label (str): Title for the surface.
246253 zlabel (str): Label for the z-axis.
247- var_names (list of str or None): List of axis labels or None.
254+ var_name (list of str or None): List of axis labels or None.
248255 alpha (float): Surface transparency.
249256 eps (float): Tolerance for error coloring.
250257 max_error (float): Maximum error for color scaling.
@@ -257,11 +264,11 @@ def plot_3d_surface(
257264 """
258265 ax .plot_surface (* Z [:2 ], Z [2 ], cmap = cmap , alpha = alpha , vmin = vmin , vmax = vmax ) if isinstance (Z , tuple ) else ax .plot_surface (Z [0 ], Z [1 ], Z [2 ], cmap = cmap , alpha = alpha , vmin = vmin , vmax = vmax )
259266 ax .set_title (surface_label )
260- ax .set_xlabel (var_names [ 0 ] if var_names else f"Dimension { i } " )
261- ax .set_ylabel (var_names [ 1 ] if var_names else f"Dimension { j } " )
262- ax .set_zlabel (var_names [ 2 ] if var_names else zlabel )
267+ ax .set_xlabel (var_name [ i ] if var_name else f"Dimension { i } " )
268+ ax .set_ylabel (var_name [ j ] if var_name else f"Dimension { j } " )
269+ ax .set_zlabel (zlabel )
263270 if add_points and X is not None and y is not None and model is not None :
264- plot_error_points (ax , X , y , model , i , j , eps , max_error , var_names , z_mode = "error" if error_surface else "actual" )
271+ plot_error_points (ax , X , y , model , i , j , eps , max_error , var_name , z_mode = "error" if error_surface else "actual" )
265272
266273
267274def plot_contour_and_err (
@@ -276,7 +283,7 @@ def plot_contour_and_err(
276283 model = None ,
277284 eps : float = 1e-4 ,
278285 max_error : float = 1e-3 ,
279- var_names : Optional [List [str ]] = None ,
286+ var_name : Optional [List [str ]] = None ,
280287 cmap : str = "jet" ,
281288 levels : int = 30 ,
282289 title : str = "Prediction Contour" ,
@@ -299,7 +306,7 @@ def plot_contour_and_err(
299306 model (object): Fitted model with predict().
300307 eps (float): Tolerance for coloring points based on prediction error.
301308 max_error (float): Maximum error for color scaling.
302- var_names (list of str or None): List of axis labels or None.
309+ var_name (list of str or None): List of axis labels or None.
303310 cmap (str): Colormap for the contour plot.
304311 levels (int): Number of contour levels.
305312 title (str): Title for the plot.
@@ -311,9 +318,11 @@ def plot_contour_and_err(
311318 None
312319 """
313320 contour = ax .contourf (X_i , X_j , Z , cmap = cmap , levels = levels , vmin = vmin , vmax = vmax )
321+ ax .set_xlabel (var_name [i ] if var_name else f"Dimension { i } " )
322+ ax .set_ylabel (var_name [j ] if var_name else f"Dimension { j } " )
314323 plt .colorbar (contour , ax = ax )
315324 if add_points and X is not None and y is not None and model is not None :
316- plot_error_points (ax , X , y , model , i , j , eps , max_error , var_names , title , z_mode = None )
325+ plot_error_points (ax , X , y , model , i , j , eps , max_error , var_name , title , z_mode = None )
317326
318327
319328def check_ij (i : int , j : int , k : int ) -> None :
@@ -340,7 +349,7 @@ def plotkd(
340349 alpha : float = 0.8 ,
341350 eps : float = 1e-4 ,
342351 max_error : float = 1e-3 ,
343- var_names : Optional [List [str ]] = None ,
352+ var_name : Optional [List [str ]] = None ,
344353 var_type : Optional [List [str ]] = None ,
345354 cmap : str = "jet" ,
346355 num : int = 100 ,
@@ -361,7 +370,7 @@ def plotkd(
361370 alpha (float): Transparency of the surface plot. Default is 0.8.
362371 eps (float): Tolerance for coloring points based on prediction error. Default is 1e-4.
363372 max_error (float): Maximum error for color scaling. Default is 1e-3.
364- var_names (list of str, optional): List of variable names for axis labeling. If None, generic labels are used.
373+ var_name (list of str, optional): List of variable names for axis labeling. If None, generic labels are used.
365374 var_type (list of str, optional): List of variable types for each dimension. Can be either "num", "int", or "factor".
366375 cmap (str): Colormap for the surface and contour plots. Default is "jet".
367376 num (int): Number of grid points per dimension for the mesh grid. Default is 100.
@@ -405,7 +414,7 @@ def plotkd(
405414 model ,
406415 surface_label = "Prediction Surface" ,
407416 zlabel = "Prediction" ,
408- var_names = var_names ,
417+ var_name = var_name ,
409418 alpha = alpha ,
410419 eps = eps ,
411420 max_error = max_error ,
@@ -428,7 +437,7 @@ def plotkd(
428437 model ,
429438 surface_label = "Prediction Error Surface" ,
430439 zlabel = "Error" ,
431- var_names = var_names ,
440+ var_name = var_name ,
432441 alpha = alpha ,
433442 eps = eps ,
434443 max_error = max_error ,
@@ -453,7 +462,7 @@ def plotkd(
453462 model ,
454463 eps = eps ,
455464 max_error = max_error ,
456- var_names = var_names ,
465+ var_name = var_name ,
457466 cmap = cmap ,
458467 levels = 30 ,
459468 title = "Prediction Contour" ,
@@ -476,7 +485,7 @@ def plotkd(
476485 model ,
477486 eps = eps ,
478487 max_error = max_error ,
479- var_names = var_names ,
488+ var_name = var_name ,
480489 cmap = cmap ,
481490 levels = 30 ,
482491 title = "Error Contour" ,
0 commit comments