Skip to content

Commit d731523

Browse files
0.14.59
documentation traintest
1 parent 8a94b9c commit d731523

2 files changed

Lines changed: 52 additions & 10 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.14.58"
10+
version = "0.14.59"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/sklearn/traintest.py

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,28 @@
66
import pandas as pd
77

88

9-
def evaluate_model(model, fun_control):
9+
def evaluate_model(model, fun_control) -> np.ndarray:
10+
"""Evaluate a model using the test set.
11+
First, the model is trained on the training set. If a scaler
12+
is provided, the training data is transformed with the scaler's `fit_transform(X_train)`.
13+
Then, the model is evaluated using the test set from `fun_control`,
14+
the scaler with `transform(X_test)`,
15+
the model.predict() method and the
16+
`metric_params` specified in `fun_control`.
17+
18+
Note:
19+
In contrast to `evaluate_hold_out()`, this function uses the test set.
20+
It can be selected by setting `fun_control["eval"] = "eval_test"`.
21+
22+
Args:
23+
model (sklearn model):
24+
sklearn model.
25+
fun_control (dict):
26+
dictionary containing control parameters for the function.
27+
28+
Returns:
29+
(tuple): `(df_eval, df_preds)`, i.e. the evaluation result and the model predictions.
30+
"""
1031
try:
1132
X_train, y_train = get_Xy_from_df(fun_control["train"], fun_control["target_column"])
1233
X_test, y_test = get_Xy_from_df(fun_control["test"], fun_control["target_column"])
@@ -26,11 +47,32 @@ def evaluate_model(model, fun_control):
2647
return df_eval, df_preds
2748

2849

29-
def evaluate_hold_out(model, fun_control):
50+
def evaluate_hold_out(model, fun_control) -> np.ndarray:
51+
"""Evaluate a model using hold-out validation.
52+
A validation set is created from the training set.
53+
The test set is not used in this evaluation.
54+
55+
Note:
56+
In contrast to `evaluate_model()`, this function creates a validation set as
57+
a subset of the training set.
58+
It can be selected by setting `fun_control["eval"] = "evaluate_hold_out"`.
59+
60+
Args:
61+
model (sklearn model):
62+
sklearn model.
63+
fun_control (dict):
64+
dictionary containing control parameters for the function.
65+
66+
Returns:
67+
(np.ndarray): array containing evaluation results.
68+
69+
Raises:
70+
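Exception: if the call to train_test_split(), fit(), or predict() fails. Note: failures in fit() and predict() appear to be caught and reported via print() rather than re-raised; a failed predict() sets the evaluation result to `np.nan`.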
71+
"""
3072
train_df = fun_control["train"]
3173
target_column = fun_control["target_column"]
3274
try:
33-
X_train, X_test, y_train, y_test = train_test_split(
75+
X_train, X_val, y_train, y_val = train_test_split(
3476
train_df.drop(target_column, axis=1),
3577
train_df[target_column],
3678
random_state=42,
@@ -51,14 +93,14 @@ def evaluate_hold_out(model, fun_control):
5193
print(f"Error in evaluate_hold_out(). Call to fit() failed. {err=}, {type(err)=}")
5294
try:
5395
if fun_control["scaler"] is not None:
54-
X_test = scaler.transform(X_test)
55-
X_test = pd.DataFrame(X_test, columns=train_df.drop(target_column, axis=1).columns) # Maintain column names
56-
y_test = np.array(y_test)
96+
X_val = scaler.transform(X_val)
97+
X_val = pd.DataFrame(X_val, columns=train_df.drop(target_column, axis=1).columns) # Maintain column names
98+
y_val = np.array(y_val)
5799
if fun_control["predict_proba"] or fun_control["task"] == "classification":
58-
df_preds = model.predict_proba(X_test)
100+
df_preds = model.predict_proba(X_val)
59101
else:
60-
df_preds = model.predict(X_test)
61-
df_eval = fun_control["metric_sklearn"](y_test, df_preds, **fun_control["metric_params"])
102+
df_preds = model.predict(X_val)
103+
df_eval = fun_control["metric_sklearn"](y_val, df_preds, **fun_control["metric_params"])
62104
except Exception as err:
63105
print(f"Error in evaluate_hold_out(). Call to predict() failed. {err=}, {type(err)=}")
64106
df_eval = np.nan

0 commit comments

Comments
 (0)