@@ -2625,7 +2625,7 @@ def print_importance(self, threshold=0.1, print_screen=True) -> list:
26252625 print ("Importance requires more than one theta values (n_theta>1)." )
26262626 return output
26272627
2628- def plot_importance (self , threshold = 0.1 , filename = None , dpi = 300 , show = True , tkagg = False ) -> None :
2628+ def plot_importance (self , threshold = 0.1 , filename = None , dpi = 300 , show = True , tkagg = False , figsize = ( 9 , 6 ) ) -> None :
26292629 """Plot the importance of each variable.
26302630
26312631 Args:
@@ -2637,23 +2637,33 @@ def plot_importance(self, threshold=0.1, filename=None, dpi=300, show=True, tkag
26372637 The dpi of the plot.
26382638 show (bool):
26392639 Show the plot. Default is `True`.
2640+ tkagg (bool):
2641+ Use TkAgg backend. Default is `False`.
2642+ figsize (tuple):
2643+ Figure size (width, height) in inches. Default is (9, 6).
26402644
26412645 Returns:
26422646 None
26432647 """
26442648 if self .surrogate .n_theta > 1 :
26452649 if tkagg :
26462650 matplotlib .use ("TkAgg" )
2651+
2652+ # Create figure with specified size
2653+ plt .figure (figsize = figsize )
2654+
26472655 theta = np .power (10 , self .surrogate .theta )
26482656 imp = 100 * theta / np .max (theta )
26492657 idx = np .where (imp > threshold )[0 ]
2658+
26502659 if self .var_name is None :
26512660 plt .bar (range (len (imp [idx ])), imp [idx ])
26522661 plt .xticks (range (len (imp [idx ])), ["x" + str (i ) for i in idx ])
26532662 else :
26542663 var_name = [self .var_name [i ] for i in idx ]
26552664 plt .bar (range (len (imp [idx ])), imp [idx ])
26562665 plt .xticks (range (len (imp [idx ])), var_name )
2666+
26572667 if filename is not None :
26582668 plt .savefig (filename , bbox_inches = "tight" , dpi = dpi )
26592669 if show :
0 commit comments