@@ -2448,6 +2448,7 @@ def plot_contour(
24482448 use_min = False ,
24492449 use_max = True ,
24502450 tkagg = False ,
2451+ cmap = "jet" ,
24512452 ) -> None :
24522453 """
24532454 Plot the contour and 3D surface for any pair of dimensions of the surrogate model.
@@ -2471,19 +2472,25 @@ def plot_contour(
24712472 use_min (bool, optional): If True, fix hidden dimensions to their minimum values. Default is False.
24722473 use_max (bool, optional): If True, fix hidden dimensions to their maximum values. Default is True.
24732474 tkagg (bool, optional): If True, use TkAgg backend for matplotlib. Default is False.
2475+ cmap (str, optional): Colormap to use for the contour plot. Default is "jet".
24742476
24752477 Returns:
24762478 None
24772479 """
2478- plot_data = self .prepare_plot (
2480+ X , Y , Z = self .prepare_plot (
24792481 i = i ,
24802482 j = j ,
24812483 n_grid = n_grid ,
24822484 use_min = use_min ,
24832485 use_max = use_max ,
24842486 )
24852487 plot_3d_contour (
2486- plot_data ,
2488+ X = X ,
2489+ Y = Y ,
2490+ Z = Z ,
2491+ vmin = min_z if min_z is not None else np .min (Z ),
2492+ vmax = max_z if max_z is not None else np .max (Z ),
2493+ var_name = self .var_name ,
24872494 i = i ,
24882495 j = j ,
24892496 show = show ,
@@ -2493,6 +2500,7 @@ def plot_contour(
24932500 title = title ,
24942501 figsize = figsize ,
24952502 tkagg = tkagg ,
2503+ cmap = cmap ,
24962504 )
24972505
24982506 def prepare_plot (
@@ -2523,7 +2531,7 @@ def prepare_plot(
25232531 def generate_mesh_grid (lower , upper , grid_points ):
25242532 x = np .linspace (lower [i ], upper [i ], num = grid_points )
25252533 y = np .linspace (lower [j ], upper [j ], num = grid_points )
2526- return np .meshgrid (x , y ), x , y
2534+ return np .meshgrid (x , y )
25272535
25282536 def validate_types (var_type , lower , upper ):
25292537 if var_type is not None :
@@ -2540,7 +2548,7 @@ def predict_contour_values(X, Y, z0):
25402548 Z = np .array (predictions ).reshape (X .shape )
25412549 return Z
25422550
2543- (X , Y ), x , y = generate_mesh_grid (self .lower , self .upper , n_grid )
2551+ (X , Y ) = generate_mesh_grid (self .lower , self .upper , n_grid )
25442552 validate_types (self .var_type , self .lower , self .upper )
25452553
25462554 z00 = np .array ([self .lower , self .upper ])
@@ -2567,16 +2575,7 @@ def predict_contour_values(X, Y, z0):
25672575 else :
25682576 raise ValueError ("No data to plot." )
25692577
2570- min_z = np .min (Z_combined )
2571- max_z = np .max (Z_combined )
2572-
2573- return {
2574- "X_combined" : X_combined ,
2575- "Y_combined" : Y_combined ,
2576- "Z_combined" : Z_combined ,
2577- "min_z" : min_z ,
2578- "max_z" : max_z ,
2579- }
2578+ return X_combined , Y_combined , Z_combined
25802579
25812580 def plot_important_hyperparameter_contour (
25822581 self ,
0 commit comments