99import pandas as pd
1010
1111
12- def plot_cv_predictions (model : Any , fun_control : Dict ) -> None :
12+ def plot_cv_predictions (model : Any , fun_control : Dict , show = True ) -> None :
1313 """
1414 Plots cross-validated predictions for regression.
1515
@@ -23,6 +23,8 @@ def plot_cv_predictions(model: Any, fun_control: Dict) -> None:
2323 Sklearn model. The model to be used for cross-validation.
2424 fun_control (Dict):
2525 Dictionary containing the data and the target column.
26+ show (bool, optional):
27+ If True, the plot is shown. Defaults to True.
2628
2729 Returns:
2830 (NoneType): None
@@ -59,14 +61,16 @@ def plot_cv_predictions(model: Any, fun_control: Dict) -> None:
5961 axs [1 ].set_title ("Residuals vs. Predicted Values" )
6062 fig .suptitle ("Plotting cross-validated predictions" )
6163 plt .tight_layout ()
62- plt .show ()
64+ if show :
65+ plt .show ()
6366
6467
6568def plot_roc (
6669 model_list : List [BaseEstimator ],
6770 fun_control : Dict [str , Union [str , pd .DataFrame ]],
6871 alpha : float = 0.8 ,
6972 model_names : List [str ] = None ,
73+ show = True ,
7074) -> None :
7175 """
7276 Plots ROC curves for a list of models using the Visualization API from scikit-learn.
@@ -80,6 +84,8 @@ def plot_roc(
8084 The alpha value for the ROC curve. Defaults to 0.8.
8185 model_names (List[str], optional):
8286 A list of names for the models. Defaults to None.
87+ show (bool, optional):
88+ If True, the plot is shown. Defaults to True.
8389
8490 Returns:
8591 (NoneType): None
@@ -113,23 +119,31 @@ def plot_roc(
113119 model_name = None
114120 y_pred = model .predict (X_test )
115121 RocCurveDisplay .from_predictions (y_test , y_pred , ax = ax , alpha = alpha , name = model_name )
116- plt .show ()
122+ if show :
123+ plt .show ()
117124
118125
119126def plot_roc_from_dataframes (
120127 df_list : List [pd .DataFrame ],
121128 alpha : float = 0.8 ,
122129 model_names : List [str ] = None ,
123130 target_column : str = None ,
131+ show = True ,
124132) -> None :
125133 """
126134 Plot ROC curve for a list of dataframes from model evaluations.
127135
128136 Args:
129- df_list: List of dataframes with results from models.
130- alpha: Transparency of the plotted lines.
131- model_names: List of model names.
132- target_column: Name of the target column.
137+ df_list:
138+ List of dataframes with results from models.
139+ alpha:
140+ Transparency of the plotted lines.
141+ model_names:
142+ List of model names.
143+ target_column:
144+ Name of the target column.
145+ show:
146+ If True, the plot is shown.
133147
134148 Returns:
135149 None
@@ -157,7 +171,7 @@ def plot_roc_from_dataframes(
157171
158172
159173def plot_confusion_matrix (
160- model = None , fun_control = None , df = None , title = None , target_names = None , y_true_name = None , y_pred_name = None
174+ model = None , fun_control = None , df = None , title = None , target_names = None , y_true_name = None , y_pred_name = None , show = False
161175):
162176 """
163177 Plotting a confusion matrix. If a model and the fun_control dictionary are passed,
@@ -180,6 +194,8 @@ def plot_confusion_matrix(
180194 Name of the column with the true values if a dataframe is specified. Defaults to None.
181195 y_pred_name (str, optional):
182196 Name of the column with the predicted values if a dataframe is specified. Defaults to None.
197+ show (bool, optional):
198+ If True, the plot is shown. Defaults to False.
183199
184200 Returns:
185201 (NoneType): None
@@ -202,9 +218,22 @@ def plot_confusion_matrix(
202218 ax .yaxis .set_ticklabels (target_names )
203219 if title is not None :
204220 _ = ax .set_title (title )
221+ if show :
222+ plt .show ()
223+
205224
225+ def plot_actual_vs_predicted (y_test , y_pred , title = None , show = True ) -> None :
226+ """Plot actual vs. predicted values.
227+
228+ Args:
229+ y_test (np.ndarray): True values.
230+ y_pred (np.ndarray): Predicted values.
231+ title (str, optional): Title of the plot. Defaults to None.
232+ show (bool, optional): If True, the plot is shown. Defaults to True.
206233
207- def plot_actual_vs_predicted (y_test , y_pred , title = None ):
234+ Returns:
235+ (NoneType): None
236+ """
208237 fig , axs = plt .subplots (ncols = 2 , figsize = (8 , 4 ))
209238 PredictionErrorDisplay .from_predictions (
210239 y_test ,
@@ -228,4 +257,5 @@ def plot_actual_vs_predicted(y_test, y_pred, title=None):
228257 if title is not None :
229258 fig .suptitle (title )
230259 plt .tight_layout ()
231- plt .show ()
260+ if show :
261+ plt .show ()
0 commit comments