@@ -172,7 +172,15 @@ def plot_roc_from_dataframes(
172172
173173
174174def plot_confusion_matrix (
175- model = None , fun_control = None , df = None , title = None , target_names = None , y_true_name = None , y_pred_name = None , show = False
175+ model = None ,
176+ fun_control = None ,
177+ df = None ,
178+ title = None ,
179+ target_names = None ,
180+ y_true_name = None ,
181+ y_pred_name = None ,
182+ show = False ,
183+ ax = None ,
176184):
177185 """
178186 Plotting a confusion matrix. If a model and the fun_control dictionary are passed,
@@ -197,6 +205,8 @@ def plot_confusion_matrix(
197205 Name of the column with the predicted values if a dataframe is specified. Defaults to None.
198206 show (bool, optional):
199207 If True, the plot is shown. Defaults to False.
208+ ax (matplotlib.axes._subplots.AxesSubplot, optional):
209+ Axes to plot the confusion matrix. Defaults to None.
200210
201211 Returns:
202212 (NoneType): None
@@ -212,8 +222,9 @@ def plot_confusion_matrix(
212222 X_test , y_true = get_Xy_from_df (fun_control ["test" ], fun_control ["target_column" ])
213223 model .fit (X_train , y_train )
214224 y_pred = model .predict (X_test )
215- fig , ax = plt .subplots (figsize = (10 , 5 ))
216- ConfusionMatrixDisplay .from_predictions (y_true = y_true , y_pred = y_pred , ax = ax )
225+ if ax is None :
226+ fig , ax = plt .subplots (figsize = (10 , 5 ))
227+ ConfusionMatrixDisplay .from_predictions (y_true = y_true , y_pred = y_pred , ax = ax , colorbar = False )
217228 if target_names is not None :
218229 ax .xaxis .set_ticklabels (target_names )
219230 ax .yaxis .set_ticklabels (target_names )
0 commit comments