@@ -849,7 +849,7 @@ def plot_contour(
849849 if show :
850850 pylab .show ()
851851
852- def plot_important_hyperparameter_contour (self , threshold = 0.025 , filename = None ) :
852+ def plot_important_hyperparameter_contour (self , threshold = 0.025 , filename = None , show = True ) -> None :
853853 impo = self .print_importance (threshold = threshold , print_screen = True )
854854 var_plots = [i for i , x in enumerate (impo ) if x [1 ] > threshold ]
855855 min_z = min (self .y )
@@ -861,7 +861,7 @@ def plot_important_hyperparameter_contour(self, threshold=0.025, filename=None):
861861 filename_full = filename + "_contour_" + str (i ) + "_" + str (j ) + ".png"
862862 else :
863863 filename_full = None
864- self .plot_contour (i = i , j = j , min_z = min_z , max_z = max_z , filename = filename_full )
864+ self .plot_contour (i = i , j = j , min_z = min_z , max_z = max_z , filename = filename_full , show = show )
865865
866866 def get_importance (self ) -> list :
867867 """Get importance of each variable and return the results as a list.
@@ -934,7 +934,17 @@ def plot_importance(self, threshold=0.1, filename=None, dpi=300) -> None:
934934 plt .savefig (filename , bbox_inches = "tight" , dpi = dpi )
935935 plt .show ()
936936
937- def parallel_plot (self ):
937+ def parallel_plot (self , show = True ) -> go .Figure :
938+ """
939+ Parallel plot.
940+
941+ Args:
942+ show (bool): show the plot
943+
944+ Returns:
945+ fig (plotly.graph_objects.Figure): figure object
946+
947+ """
938948 X = self .X
939949 y = self .y
940950 df = pd .DataFrame (np .concatenate ((X , y .reshape (- 1 , 1 )), axis = 1 ), columns = self .var_name + ["y" ])
@@ -950,4 +960,6 @@ def parallel_plot(self):
950960 ),
951961 )
952962 )
953- fig .show ()
963+ if show :
964+ fig .show ()
965+ return fig
0 commit comments