Skip to content

Commit 11c2ab9

Browse files
0.18.12
predict model
1 parent c9b7e5c commit 11c2ab9

3 files changed

Lines changed: 10 additions & 5 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.18.11"
10+
version = "0.18.12"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/light/predictmodel.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ def predict_model(config: dict, fun_control: dict) -> Tuple[float, float]:
3838
import spotpython.light.testmodel as tm
3939
fun_control = fun_control_init(
4040
_L_in=10,
41-
_L_out=1,)
41+
_L_out=1,
42+
_torchmetric="mean_squared_error")
4243
dataset = Diabetes()
4344
set_control_key_value(control_dict=fun_control,
4445
key="data_set",
@@ -107,9 +108,12 @@ def predict_model(config: dict, fun_control: dict) -> Tuple[float, float]:
107108
# Pass the datamodule as arg to trainer.fit to override model hooks :)
108109
trainer.fit(model=model, datamodule=dm)
109110

110-
dm.setup(stage="predict")
111-
predictions = trainer.predict(model=model, datamodule=dm)
112-
# predictions = trainer.predict(datamodule=dm)
111+
# Changed in spotpython 0.18.12: commented out the following line
112+
# dm.setup(stage="predict")
113+
114+
# predictions = trainer.predict(model=model, datamodule=dm)
115+
# Changed in spotpython 0.18.12: use ckpt_path="last" to load the last checkpoint and not the model
116+
predictions = trainer.predict(datamodule=dm, ckpt_path="last")
113117

114118
# # Load the last checkpoint
115119
# test_result = trainer.test(datamodule=dm, ckpt_path="last")

src/spotpython/light/testmodel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def test_model(config: dict, fun_control: dict) -> Tuple[float, float]:
107107
)
108108
# Pass the datamodule as arg to trainer.fit to override model hooks :)
109109
trainer.fit(model=model, datamodule=dm)
110+
110111
# Load the last checkpoint
111112
test_result = trainer.test(datamodule=dm, ckpt_path="last")
112113
test_result = test_result[0]

0 commit comments

Comments
 (0)