@@ -34,8 +34,8 @@ def plot1d(model, X: np.ndarray, y: np.ndarray, show: Optional[bool] = True) ->
3434 raise ValueError ("plot1d is only supported for 1D input data." )
3535
3636 _ = plt .figure (figsize = (9 , 6 ))
37- n_grid = 100
38- x = linspace (X [:, 0 ].min (), X [:, 0 ].max (), num = n_grid ).reshape (- 1 , 1 )
37+ num = 100
38+ x = linspace (X [:, 0 ].min (), X [:, 0 ].max (), num = num ).reshape (- 1 , 1 )
3939 y_pred , y_std = model .predict (x , return_std = True )
4040
4141 plt .plot (x , y_pred , "k" , label = "Prediction" )
@@ -55,30 +55,47 @@ def plot1d(model, X: np.ndarray, y: np.ndarray, show: Optional[bool] = True) ->
5555 plt .show ()
5656
5757
58- def generate_mesh_grid (X : np .ndarray , i : int , j : int , n_grid : int = 100 ):
58+ def generate_mesh_grid (
59+ X : Optional [np .ndarray ] = None ,
60+ i : int = 0 ,
61+ j : int = 1 ,
62+ num : int = 100 ,
63+ lower : Optional [np .ndarray ] = None ,
64+ upper : Optional [np .ndarray ] = None ,
65+ ):
5966 """
60- Generate a mesh grid for two selected dimensions of X, and fill the remaining dimensions with their mean values.
67+ Generate a mesh grid for two selected dimensions, filling remaining dimensions with their mean values
68+ (if X is given) or the mean of the lower and upper bound (if lower and upper are given).
6169
6270 Args:
63- X (np.ndarray): Input data of shape (n_samples, k).
71+ X (np.ndarray, optional ): Input data of shape (n_samples, k). Required if lower/upper are not given .
6472 i (int): Index of the first dimension to vary.
6573 j (int): Index of the second dimension to vary.
66- n_grid (int): Number of grid points per dimension.
74+ num (int): Number of grid points per dimension.
75+ lower (np.ndarray, optional): Lower bounds for each dimension (shape (k,)).
76+ upper (np.ndarray, optional): Upper bounds for each dimension (shape (k,)).
6777
6878 Returns:
6979 X_i (np.ndarray): Meshgrid for the i-th dimension.
7080 X_j (np.ndarray): Meshgrid for the j-th dimension.
71- grid_points (np.ndarray): Grid points of shape (n_grid*n_grid , k) for prediction.
81+ grid_points (np.ndarray): Grid points of shape (num*num , k) for prediction.
7282 """
73- k = X .shape [1 ]
74- mean_values = X .mean (axis = 0 )
83+ # Check that exactly one of (X) or (lower and upper) is provided
84+ if (X is not None and (lower is not None or upper is not None )) or (X is None and (lower is None or upper is None )):
85+ raise ValueError ("Provide either X or both lower and upper, but not both or neither." )
86+
87+ if X is not None :
88+ k = X .shape [1 ]
89+ mean_values = X .mean (axis = 0 )
90+ x_i = linspace (X [:, i ].min (), X [:, i ].max (), num = num )
91+ x_j = linspace (X [:, j ].min (), X [:, j ].max (), num = num )
92+ else :
93+ k = len (lower )
94+ mean_values = (np .array (lower ) + np .array (upper )) / 2.0
95+ x_i = linspace (lower [i ], upper [i ], num = num )
96+ x_j = linspace (lower [j ], upper [j ], num = num )
7597
76- # Create a grid for the two varied dimensions
77- x_i = linspace (X [:, i ].min (), X [:, i ].max (), num = n_grid )
78- x_j = linspace (X [:, j ].min (), X [:, j ].max (), num = n_grid )
7998 X_i , X_j = meshgrid (x_i , x_j )
80-
81- # Prepare the grid points for prediction
8299 grid_points = np .zeros ((X_i .size , k ))
83100 grid_points [:, i ] = X_i .ravel ()
84101 grid_points [:, j ] = X_j .ravel ()
@@ -318,7 +335,7 @@ def plotkd(
318335 max_error : float = 1e-3 ,
319336 var_names : Optional [List [str ]] = None ,
320337 cmap : str = "jet" ,
321- n_grid : int = 100 ,
338+ num : int = 100 ,
322339 vmin : Optional [float ] = None ,
323340 vmax : Optional [float ] = None ,
324341 add_points : bool = False ,
@@ -338,7 +355,7 @@ def plotkd(
338355 max_error (float): Maximum error for color scaling. Default is 1e-3.
339356 var_names (list of str, optional): List of variable names for axis labeling. If None, generic labels are used.
340357 cmap (str): Colormap for the surface and contour plots. Default is "jet".
341- n_grid (int): Number of grid points per dimension for the mesh grid. Default is 100.
358+ num (int): Number of grid points per dimension for the mesh grid. Default is 100.
342359 vmin (float, optional): Minimum value for the color scale. If None, determined from predictions.
343360 vmax (float, optional): Maximum value for the color scale. If None, determined from predictions.
344361 add_points (bool): If True, adds scatter points to the surface and contour plots. Default is False.
@@ -358,7 +375,7 @@ def plotkd(
358375 """
359376 k = X .shape [1 ]
360377 check_ij (i , j , k )
361- X_i , X_j , grid_points = generate_mesh_grid (X , i , j , n_grid )
378+ X_i , X_j , grid_points = generate_mesh_grid (X , i , j , num )
362379
363380 # Predict the values and standard deviations
364381 y_pred , y_std = model .predict (grid_points , return_std = True )
@@ -491,7 +508,7 @@ def plot_3d_contour(X, Y, Z, vmin, vmax, var_name=None, i=0, j=1, show=True, fil
491508 Examples:
492509 >>> # Example 1: Using output from Spot
493510 >>> # Assume S is a Spot object with a fitted surrogate
494- >>> plot_data = S.prepare_plot(i=0, j=1, n_grid =100)
511+ >>> plot_data = S.prepare_plot(i=0, j=1, num =100)
495512 >>> from spotpython.surrogate.plot import plot_3d_contour
496513 >>> plot_3d_contour(
497514 ... plot_data,
0 commit comments