Skip to content

Commit 76011e0

Browse files
0.10.64
plot_confusion_matrix layout improved fun_control_init accepts weights arguments new get_metric_sgn function
1 parent e92d684 commit 76011e0

5 files changed

Lines changed: 93 additions & 5 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.10.63"
10+
version = "0.10.64"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/plot/validation.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,15 @@ def plot_roc_from_dataframes(
172172

173173

174174
def plot_confusion_matrix(
175-
model=None, fun_control=None, df=None, title=None, target_names=None, y_true_name=None, y_pred_name=None, show=False
175+
model=None,
176+
fun_control=None,
177+
df=None,
178+
title=None,
179+
target_names=None,
180+
y_true_name=None,
181+
y_pred_name=None,
182+
show=False,
183+
ax=None,
176184
):
177185
"""
178186
Plotting a confusion matrix. If a model and the fun_control dictionary are passed,
@@ -197,6 +205,8 @@ def plot_confusion_matrix(
197205
Name of the column with the predicted values if a dataframe is specified. Defaults to None.
198206
show (bool, optional):
199207
If True, the plot is shown. Defaults to False.
208+
ax (matplotlib.axes._subplots.AxesSubplot, optional):
209+
Axes to plot the confusion matrix. Defaults to None.
200210
201211
Returns:
202212
(NoneType): None
@@ -212,8 +222,9 @@ def plot_confusion_matrix(
212222
X_test, y_true = get_Xy_from_df(fun_control["test"], fun_control["target_column"])
213223
model.fit(X_train, y_train)
214224
y_pred = model.predict(X_test)
215-
fig, ax = plt.subplots(figsize=(10, 5))
216-
ConfusionMatrixDisplay.from_predictions(y_true=y_true, y_pred=y_pred, ax=ax)
225+
if ax is None:
226+
fig, ax = plt.subplots(figsize=(10, 5))
227+
ConfusionMatrixDisplay.from_predictions(y_true=y_true, y_pred=y_pred, ax=ax, colorbar=False)
217228
if target_names is not None:
218229
ax.xaxis.set_ticklabels(target_names)
219230
ax.yaxis.set_ticklabels(target_names)

src/spotPython/utils/init.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ def fun_control_init(
5555
var_name=None,
5656
var_type=["num"],
5757
verbosity=0,
58+
weights=1.0,
59+
weight_coeff=0.0,
5860
):
5961
"""Initialize fun_control dictionary.
6062
@@ -171,6 +173,12 @@ def fun_control_init(
171173
verbosity (int):
172174
The verbosity level. Determines print output to console. Higher values
173175
result in more output. Default is 0.
176+
weights (float):
177+
The weight coefficient of the objective function. Positive values mean minimization.
178+
If set to -1, scores that are better when maximized will be minimized, e.g, accuracy.
179+
Can be an array, so that different weights can be used for different (multiple) objectives.
180+
weight_coeff (float):
181+
Determines how to weight older measures. Default is 1.0. Used in the OML algorithm eval_oml.py.
174182
175183
Returns:
176184
fun_control (dict):
@@ -327,7 +335,8 @@ def fun_control_init(
327335
"var_name": var_name,
328336
"var_type": var_type,
329337
"verbosity": verbosity,
330-
"weights": 1.0,
338+
"weights": weights,
339+
"weight_coeff": weight_coeff,
331340
}
332341
# lower = X_reshape(lower)
333342
# fun_control.update({"lower": lower})

src/spotPython/utils/metrics.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,52 @@ def mapk_scorer(estimator, X, y):
128128
y_pred = estimator.predict_proba(X)
129129
score = mapk_score(y, y_pred, k=3)
130130
return score
131+
132+
133+
def get_metric_sign(metric_name):
134+
"""Returns the sign of a metric.
135+
136+
Args:
137+
metric_name (str):
138+
The name of the metric. Can be one of the following:
139+
- "accuracy_score"
140+
- "cohen_kappa_score"
141+
- "f1_score"
142+
- "hamming_loss"
143+
- "hinge_loss"
144+
-"jaccard_score"
145+
- "matthews_corrcoef"
146+
- "precision_score"
147+
- "recall_score"
148+
- "roc_auc_score"
149+
- "zero_one_loss"
150+
151+
Returns:
152+
sign (float): The sign of the metric. -1 for max, +1 for min.
153+
154+
Raises:
155+
ValueError: If the metric is not found.
156+
157+
Examples:
158+
>>> from spotPython.metrics import get_metric_sign
159+
>>> get_metric_sign("accuracy_score")
160+
-1
161+
>>> get_metric_sign("hamming_loss")
162+
+1
163+
164+
"""
165+
if metric_name in [
166+
"accuracy_score",
167+
"cohen_kappa_score",
168+
"f1_score",
169+
"jaccard_score",
170+
"matthews_corrcoef",
171+
"precision_score",
172+
"recall_score",
173+
"roc_auc_score",
174+
]:
175+
return -1
176+
elif metric_name in ["hamming_loss", "hinge_loss", "zero_one_loss"]:
177+
return +1
178+
else:
179+
raise ValueError(f"Metric '{metric_name}' not found.")

test/test_metrics_sign.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from spotPython.utils.metrics import get_metric_sign
2+
import pytest
3+
4+
def test_get_metric_sign():
5+
# Test for accuracy_score
6+
assert get_metric_sign("accuracy_score") == -1
7+
8+
# Test for hamming_loss
9+
assert get_metric_sign("hamming_loss") == +1
10+
11+
# Test for f1_score
12+
assert get_metric_sign("f1_score") == -1
13+
14+
# Test for roc_auc_score
15+
assert get_metric_sign("roc_auc_score") == -1
16+
17+
# Test for unknown metric
18+
with pytest.raises(ValueError):
19+
get_metric_sign("unknown_metric")

0 commit comments

Comments
 (0)