Skip to content

Commit 256491d

Browse files
v0.2.48
early stopping
1 parent 44523da commit 256491d

2 files changed

Lines changed: 3 additions & 6 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.2.47"
10+
version = "0.2.48"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/light/traintest.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
import lightning as L
2-
3-
# from spotPython.light.mnistdatamodule import MNISTDataModule
42
from spotPython.light.csvdatamodule import CSVDataModule
53
from spotPython.light.crossvalidationdatamodule import CrossValidationDataModule
64
from spotPython.utils.eda import generate_config_id
7-
8-
# from spotPython.light.litmodel import LitModel
9-
105
from pytorch_lightning.loggers import TensorBoardLogger
6+
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
117

128

139
def train_model(config, fun_control):
@@ -26,6 +22,7 @@ def train_model(config, fun_control):
2622
accelerator="auto",
2723
devices=1,
2824
logger=TensorBoardLogger(save_dir=fun_control["tensorboard_path"], version=config_id),
25+
callbacks=[EarlyStopping(monitor="val_loss", patience=3, mode="min", strict=False, verbose=False)],
2926
)
3027
# Pass the datamodule as arg to trainer.fit to override model hooks :)
3128
trainer.fit(model=model, datamodule=dm)

0 commit comments

Comments
 (0)