Skip to content

Commit 97c82f6

Browse files
0.6.19
plot_roc and plot_confusion updated
1 parent b3b7872 commit 97c82f6

3 files changed

Lines changed: 80 additions & 8 deletions

File tree

makeSpot.sh

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,3 +2,4 @@
22
cd ~/workspace/spotPython
33
rm -f dist/spotPython*; python -m build; python -m pip install dist/spotPython*.tar.gz
44
python -m mkdocs build
5+
pytest

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.6.17"
10+
version = "0.6.19"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/plot/validation.py

Lines changed: 78 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -116,16 +116,87 @@ def plot_roc(
116116
plt.show()
117117

118118

119-
def plot_confusion_matrix(model, fun_control, target_names=None, title=None):
119+
def plot_roc_from_dataframes(
120+
df_list: List[pd.DataFrame],
121+
alpha: float = 0.8,
122+
model_names: List[str] = None,
123+
target_column: str = None,
124+
) -> None:
120125
"""
121-
Plotting a confusion matrix
126+
Plot ROC curve for a list of dataframes from model evaluations.
127+
128+
Args:
129+
df_list: List of dataframes with results from models.
130+
alpha: Transparency of the plotted lines.
131+
model_names: List of model names.
132+
target_column: Name of the target column.
133+
134+
Returns:
135+
None
136+
137+
Examples:
138+
>>> import pandas as pd
139+
from spotPython.plot.validation import plot_roc_from_dataframes
140+
df1 = pd.DataFrame({"y": [1, 0, 0, 1], "Prediction": [1,0,0,0]})
141+
df2 = pd.DataFrame({"y": [1, 0, 0, 1], "Prediction": [1,0,1,1]})
142+
df_list = [df1, df2]
143+
model_names = ["Model 1", "Model 2"]
144+
plot_roc_from_dataframes(df_list, model_names=model_names, target_column="y")
145+
122146
"""
123-
X_train, y_train = get_Xy_from_df(fun_control["train"], fun_control["target_column"])
124-
X_test, y_test = get_Xy_from_df(fun_control["test"], fun_control["target_column"])
125-
model.fit(X_train, y_train)
126-
pred = model.predict(X_test)
147+
ax = plt.gca()
148+
for i, df in enumerate(df_list):
149+
y_test = df[target_column]
150+
y_pred = df["Prediction"]
151+
if model_names is not None:
152+
model_name = model_names[i]
153+
else:
154+
model_name = None
155+
RocCurveDisplay.from_predictions(y_test, y_pred, ax=ax, alpha=alpha, name=model_name)
156+
plt.show()
157+
158+
159+
def plot_confusion_matrix(
160+
model=None, df=None, target_names=None, fun_control=None, title=None, y_true_name=None, y_pred_name=None
161+
):
162+
"""
163+
Plotting a confusion matrix. If a model and the fun_control dictionary are passed,
164+
the confusion matrix is computed. If a dataframe is passed, the confusion matrix is
165+
computed from the dataframe. In this case, the names of the columns with the true and
166+
the predicted values must be specified. By default, the dataframe is None.
167+
168+
Args:
169+
model (Any, optional):
170+
Sklearn model. The model to be used for cross-validation. Defaults to None.
171+
df (pd.DataFrame, optional):
172+
Dataframe containing the predictions and the target column. Defaults to None.
173+
fun_control (Dict, optional):
174+
Dictionary containing the data and the target column. Defaults to None.
175+
target_names (List[str], optional):
176+
List of target names. Defaults to None.
177+
title (str, optional):
178+
Title of the plot. Defaults to None.
179+
y_true_name (str, optional):
180+
Name of the column with the true values if a dataframe is specified. Defaults to None.
181+
y_pred_name (str, optional):
182+
Name of the column with the predicted values if a dataframe is specified. Defaults to None.
183+
184+
Returns:
185+
(NoneType): None
186+
187+
"""
188+
if df is not None:
189+
# assign the column y_true_name from df to y_true
190+
y_true = df[y_true_name]
191+
# assign the column y_pred_name from df to y_pred
192+
y_pred = df[y_pred_name]
193+
else:
194+
X_train, y_train = get_Xy_from_df(fun_control["train"], fun_control["target_column"])
195+
X_test, y_true = get_Xy_from_df(fun_control["test"], fun_control["target_column"])
196+
model.fit(X_train, y_train)
197+
y_pred = model.predict(X_test)
127198
fig, ax = plt.subplots(figsize=(10, 5))
128-
ConfusionMatrixDisplay.from_predictions(y_test, pred, ax=ax)
199+
ConfusionMatrixDisplay.from_predictions(y_true=y_true, y_pred=y_pred, ax=ax)
129200
if target_names is not None:
130201
ax.xaxis.set_ticklabels(target_names)
131202
ax.yaxis.set_ticklabels(target_names)

0 commit comments

Comments
 (0)