Skip to content

Commit 15cb116

Browse files
033.7
trainmodel updated for 3.13
1 parent 852dd57 commit 15cb116

2 files changed

Lines changed: 12 additions & 6 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.33.6"
10+
version = "0.33.7"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]
@@ -45,9 +45,7 @@ dependencies = [
4545
"plotly",
4646
"pytest",
4747
"pytest-mock",
48-
"PyQt6",
4948
"python-markdown-math",
50-
"pytorch-lightning>=1.4",
5149
"river>=0.22.0",
5250
"scikit-learn",
5351
"scipy",

src/spotpython/light/trainmodel.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
129129
)
130130
else:
131131
dm = fun_control["data_module"]
132+
dm.setup() # Manually call setup to prepare the datasets
132133

133134
model = build_model_instance(config, fun_control)
134135
# TODO: Check if this is necessary or if this is handled by the trainer
@@ -238,7 +239,7 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
238239
gradient_clip_algorithm="norm",
239240
)
240241

241-
trainer.fit(model=model, train_dataloaders=train_dl, ckpt_path=None)
242+
trainer.fit(model=model, train_dataloaders=train_dl, val_dataloaders=test_dl, ckpt_path=None)
242243
result = trainer.validate(model=model, dataloaders=test_dl, ckpt_path=None, verbose=verbose)
243244
result = result[0]
244245

@@ -350,10 +351,13 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
350351
# Could also be one of two special keywords "last" and "hpc".
351352
# If there is no checkpoint file at the path, an exception is raised.
352353
try:
353-
trainer.fit(model=model, datamodule=dm, ckpt_path=None)
354+
trainer.fit(model=model, train_dataloaders=dm.train_dataloader(), val_dataloaders=dm.val_dataloader(), ckpt_path=None)
354355
except Exception as e:
355356
print(f"train_model(): trainer.fit failed with exception: {e}")
357+
return None
356358
# Test best model on validation and test set
359+
# The validate and test methods expect a datamodule or dataloaders.
360+
# Using the datamodule is cleaner.
357361
verbose = fun_control["verbosity"] > 0
358362

359363
# Validate the model
@@ -455,6 +459,7 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
455459
)
456460
else:
457461
dm = fun_control["data_module"]
462+
dm.setup() # Manually call setup to prepare the datasets
458463

459464
model = build_model_instance(config, fun_control)
460465
# TODO: Check if this is necessary or if this is handled by the trainer
@@ -619,10 +624,13 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
619624
# Could also be one of two special keywords "last" and "hpc".
620625
# If there is no checkpoint file at the path, an exception is raised.
621626
try:
622-
trainer.fit(model=model, datamodule=dm, ckpt_path=None)
627+
trainer.fit(model=model, train_dataloaders=dm.train_dataloader(), val_dataloaders=dm.val_dataloader(), ckpt_path=None)
623628
except Exception as e:
624629
print(f"train_model(): trainer.fit failed with exception: {e}")
630+
return None
625631
# Test best model on validation and test set
632+
# The validate and test methods expect a datamodule or dataloaders.
633+
# Using the datamodule is cleaner.
626634
verbose = fun_control["verbosity"] > 0
627635

628636
# Validate the model

0 commit comments

Comments
 (0)