Skip to content

Commit bc23cea

Browse files
0.18.8
cleanup
1 parent cc4910c commit bc23cea

6 files changed

Lines changed: 21 additions & 13 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.18.7"
10+
version = "0.18.8"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/data/lightdatamodule.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ class LightDataModule(L.LightningDataModule):
2828
The number of workers. Defaults to 0.
2929
scaler (object, optional):
3030
The spot scaler object (e.g. TorchStandardScaler). Defaults to None.
31+
verbosity (int):
32+
The verbosity level. Defaults to 0.
3133
3234
Examples:
3335
>>> from spotpython.data.lightdatamodule import LightDataModule

src/spotpython/light/cvmodel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def cv_model(config: dict, fun_control: dict) -> float:
6767
batch_size=config["batch_size"],
6868
data_dir=fun_control["DATASET_PATH"],
6969
scaler=fun_control["scaler"],
70+
verbosity=fun_control["verbosity"],
7071
)
7172
dm.setup()
7273
dm.prepare_data()

src/spotpython/light/predictmodel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def predict_model(config: dict, fun_control: dict) -> Tuple[float, float]:
7373
test_size=fun_control["test_size"],
7474
test_seed=fun_control["test_seed"],
7575
scaler=fun_control["scaler"],
76+
verbosity=fun_control["verbosity"],
7677
)
7778
# TODO: Check if this is necessary:
7879
# dm.setup(stage="train")

src/spotpython/light/testmodel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def test_model(config: dict, fun_control: dict) -> Tuple[float, float]:
7474
test_size=fun_control["test_size"],
7575
test_seed=fun_control["test_seed"],
7676
scaler=fun_control["scaler"],
77+
verbosity=fun_control["verbosity"],
7778
)
7879
# TODO: Check if this is necessary:
7980
# dm.setup()

src/spotpython/light/trainmodel.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -105,18 +105,21 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
105105
pprint.pprint(config)
106106
y = train_model(config, fun_control)
107107
"""
108-
config_id = generate_config_id_with_timestamp(config=config, timestamp=timestamp)
108+
if fun_control["data_module"] is None:
109+
dm = LightDataModule(
110+
dataset=fun_control["data_set"],
111+
data_full_train=fun_control["data_full_train"],
112+
data_test=fun_control["data_test"],
113+
batch_size=config["batch_size"],
114+
num_workers=fun_control["num_workers"],
115+
test_size=fun_control["test_size"],
116+
test_seed=fun_control["test_seed"],
117+
scaler=fun_control["scaler"],
118+
verbosity=fun_control["verbosity"],
119+
)
120+
else:
121+
dm = fun_control["data_module"]
109122
model = build_model_instance(config, fun_control)
110-
dm = LightDataModule(
111-
dataset=fun_control["data_set"],
112-
data_full_train=fun_control["data_full_train"],
113-
data_test=fun_control["data_test"],
114-
batch_size=config["batch_size"],
115-
num_workers=fun_control["num_workers"],
116-
test_size=fun_control["test_size"],
117-
test_seed=fun_control["test_seed"],
118-
scaler=fun_control["scaler"],
119-
)
120123
# TODO: Check if this is necessary or if this is handled by the trainer
121124
# dm.setup()
122125
# print(f"train_model(): Test set size: {len(dm.data_test)}")
@@ -168,7 +171,7 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
168171
# Can be set to 'link' on a local filesystem to create a symbolic link.
169172
# This allows accessing the latest checkpoint in a deterministic manner.
170173
# Default: None.
171-
174+
config_id = generate_config_id_with_timestamp(config=config, timestamp=timestamp)
172175
callbacks = [EarlyStopping(monitor="val_loss", patience=config["patience"], mode="min", strict=False, verbose=False)]
173176
if not timestamp:
174177
# add ModelCheckpoint only if timestamp is False

0 commit comments

Comments
 (0)