Skip to content

Commit 84d3c06

Browse files
v0.6.23
show option for some plots
1 parent c7e9ea9 commit 84d3c06

2 files changed

Lines changed: 41 additions & 11 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.6.22"
10+
version = "0.6.23"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/plot/validation.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pandas as pd
1010

1111

12-
def plot_cv_predictions(model: Any, fun_control: Dict) -> None:
12+
def plot_cv_predictions(model: Any, fun_control: Dict, show=True) -> None:
1313
"""
1414
Plots cross-validated predictions for regression.
1515
@@ -23,6 +23,8 @@ def plot_cv_predictions(model: Any, fun_control: Dict) -> None:
2323
Sklearn model. The model to be used for cross-validation.
2424
fun_control (Dict):
2525
Dictionary containing the data and the target column.
26+
show (bool, optional):
27+
If True, the plot is shown. Defaults to True.
2628
2729
Returns:
2830
(NoneType): None
@@ -59,14 +61,16 @@ def plot_cv_predictions(model: Any, fun_control: Dict) -> None:
5961
axs[1].set_title("Residuals vs. Predicted Values")
6062
fig.suptitle("Plotting cross-validated predictions")
6163
plt.tight_layout()
62-
plt.show()
64+
if show:
65+
plt.show()
6366

6467

6568
def plot_roc(
6669
model_list: List[BaseEstimator],
6770
fun_control: Dict[str, Union[str, pd.DataFrame]],
6871
alpha: float = 0.8,
6972
model_names: List[str] = None,
73+
show=True,
7074
) -> None:
7175
"""
7276
Plots ROC curves for a list of models using the Visualization API from scikit-learn.
@@ -80,6 +84,8 @@ def plot_roc(
8084
The alpha value for the ROC curve. Defaults to 0.8.
8185
model_names (List[str], optional):
8286
A list of names for the models. Defaults to None.
87+
show (bool, optional):
88+
If True, the plot is shown. Defaults to True.
8389
8490
Returns:
8591
(NoneType): None
@@ -113,23 +119,31 @@ def plot_roc(
113119
model_name = None
114120
y_pred = model.predict(X_test)
115121
RocCurveDisplay.from_predictions(y_test, y_pred, ax=ax, alpha=alpha, name=model_name)
116-
plt.show()
122+
if show:
123+
plt.show()
117124

118125

119126
def plot_roc_from_dataframes(
120127
df_list: List[pd.DataFrame],
121128
alpha: float = 0.8,
122129
model_names: List[str] = None,
123130
target_column: str = None,
131+
show=True,
124132
) -> None:
125133
"""
126134
Plot ROC curve for a list of dataframes from model evaluations.
127135
128136
Args:
129-
df_list: List of dataframes with results from models.
130-
alpha: Transparency of the plotted lines.
131-
model_names: List of model names.
132-
target_column: Name of the target column.
137+
df_list:
138+
List of dataframes with results from models.
139+
alpha:
140+
Transparency of the plotted lines.
141+
model_names:
142+
List of model names.
143+
target_column:
144+
Name of the target column.
145+
show:
146+
If True, the plot is shown.
133147
134148
Returns:
135149
None
@@ -157,7 +171,7 @@ def plot_roc_from_dataframes(
157171

158172

159173
def plot_confusion_matrix(
160-
model=None, fun_control=None, df=None, title=None, target_names=None, y_true_name=None, y_pred_name=None
174+
model=None, fun_control=None, df=None, title=None, target_names=None, y_true_name=None, y_pred_name=None, show=False
161175
):
162176
"""
163177
Plotting a confusion matrix. If a model and the fun_control dictionary are passed,
@@ -180,6 +194,8 @@ def plot_confusion_matrix(
180194
Name of the column with the true values if a dataframe is specified. Defaults to None.
181195
y_pred_name (str, optional):
182196
Name of the column with the predicted values if a dataframe is specified. Defaults to None.
197+
show (bool, optional):
198+
If True, the plot is shown. Defaults to False.
183199
184200
Returns:
185201
(NoneType): None
@@ -202,9 +218,22 @@ def plot_confusion_matrix(
202218
ax.yaxis.set_ticklabels(target_names)
203219
if title is not None:
204220
_ = ax.set_title(title)
221+
if show:
222+
plt.show()
223+
205224

225+
def plot_actual_vs_predicted(y_test, y_pred, title=None, show=True) -> None:
226+
"""Plot actual vs. predicted values.
227+
228+
Args:
229+
y_test (np.ndarray): True values.
230+
y_pred (np.ndarray): Predicted values.
231+
title (str, optional): Title of the plot. Defaults to None.
232+
show (bool, optional): If True, the plot is shown. Defaults to True.
206233
207-
def plot_actual_vs_predicted(y_test, y_pred, title=None):
234+
Returns:
235+
(NoneType): None
236+
"""
208237
fig, axs = plt.subplots(ncols=2, figsize=(8, 4))
209238
PredictionErrorDisplay.from_predictions(
210239
y_test,
@@ -228,4 +257,5 @@ def plot_actual_vs_predicted(y_test, y_pred, title=None):
228257
if title is not None:
229258
fig.suptitle(title)
230259
plt.tight_layout()
231-
plt.show()
260+
if show:
261+
plt.show()

0 commit comments

Comments
 (0)