Skip to content

Commit 2f4c111

Browse files
0.30.5
EarlyStopping refined
1 parent 4ee660d commit 2f4c111

4 files changed

Lines changed: 54 additions & 4 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.30.4"
10+
version = "0.30.5"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/light/trainmodel.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,18 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
185185
# This allows accessing the latest checkpoint in a deterministic manner.
186186
# Default: None.
187187
config_id = generate_config_id_with_timestamp(config=config, timestamp=timestamp)
188-
callbacks = [EarlyStopping(monitor="val_loss", patience=config["patience"], mode="min", strict=False, verbose=False)]
188+
callbacks = [
189+
EarlyStopping(
190+
monitor="val_loss",
191+
patience=config["patience"],
192+
divergence_threshold=fun_control["divergence_threshold"],
193+
check_finite=fun_control["check_finite"],
194+
stopping_threshold=fun_control["stopping_threshold"],
195+
mode="min",
196+
strict=False,
197+
verbose=False,
198+
)
199+
]
189200
if not timestamp:
190201
# add ModelCheckpoint only if timestamp is False
191202
dirpath = os.path.join(fun_control["CHECKPOINT_PATH"], config_id)
@@ -500,7 +511,18 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
500511
# This allows accessing the latest checkpoint in a deterministic manner.
501512
# Default: None.
502513
config_id = generate_config_id_with_timestamp(config=config, timestamp=timestamp)
503-
callbacks = [EarlyStopping(monitor="val_loss", patience=config["patience"], mode="min", strict=False, verbose=False)]
514+
callbacks = [
515+
EarlyStopping(
516+
monitor="val_loss",
517+
patience=config["patience"],
518+
divergence_threshold=fun_control["divergence_threshold"],
519+
check_finite=fun_control["check_finite"],
520+
stopping_threshold=fun_control["stopping_threshold"],
521+
mode="min",
522+
strict=False,
523+
verbose=False,
524+
)
525+
]
504526
if not timestamp:
505527
# add ModelCheckpoint only if timestamp is False
506528
dirpath = os.path.join(fun_control["CHECKPOINT_PATH"], config_id)

src/spotpython/spot/spot.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1936,7 +1936,9 @@ def infill(self, x) -> float:
19361936
else:
19371937
return self.surrogate.predict(X)
19381938

1939-
def plot_progress(self, show=True, log_x=False, log_y=False, filename="plot.png", style=["ko", "k", "ro-"], dpi=300, tkagg=False) -> None:
1939+
def plot_progress(
1940+
self, show=True, log_x=False, log_y=False, filename="plot.png", style=["ko", "k", "ro-"], dpi=300, tkagg=False, title="Objective function value over iterations", y_label="y"
1941+
) -> None:
19401942
"""Plot the progress of the hyperparameter tuning (optimization).
19411943
19421944
Args:
@@ -1951,6 +1953,10 @@ def plot_progress(self, show=True, log_x=False, log_y=False, filename="plot.png"
19511953
style (list):
19521954
Style of the plot. Default: ['ko', 'k', 'ro-'], i.e., the initial points are plotted in black
19531955
and the subsequent points as red dots connected by a line.
1956+
title (str):
1957+
Title of the plot. Default: "Objective function value over iterations".
1958+
y_label (str):
1959+
Label for the y-axis. Default: "y".
19541960
19551961
Returns:
19561962
None
@@ -2012,6 +2018,10 @@ def plot_progress(self, show=True, log_x=False, log_y=False, filename="plot.png"
20122018
ax.set_xscale("log")
20132019
if log_y:
20142020
ax.set_yscale("log")
2021+
# add a grid
2022+
ax.grid(True, which="both", linestyle="--", linewidth=0.5)
2023+
ax.set_ylabel(y_label)
2024+
ax.set_title(title)
20152025
if filename is not None:
20162026
pylab.savefig(filename, dpi=dpi, bbox_inches="tight")
20172027
if show:

src/spotpython/utils/init.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,14 @@ def fun_control_init(
2727
PREFIX=None,
2828
TENSORBOARD_CLEAN=False,
2929
accelerator="auto",
30+
check_finite=True,
3031
collate_fn_name=None,
3132
converters=None,
3233
core_model=None,
3334
core_model_name=None,
3435
data=None,
3536
data_full_train=None,
37+
divergence_threshold=None,
3638
hacky=False, # !TODO: Documentation
3739
data_val=None,
3840
data_dir="./data",
@@ -90,6 +92,7 @@ def fun_control_init(
9092
shuffle_val=False,
9193
shuffle_test=False,
9294
sigma=0.0,
95+
stopping_threshold=None,
9396
strategy="auto",
9497
surrogate=None,
9598
target_column=None,
@@ -129,6 +132,9 @@ def fun_control_init(
129132
The accelerator to be used by the Lightning Trainer.
130133
It can be either "auto", "dp", "ddp", "ddp2", "ddp_spawn", "ddp_cpu", "gpu", "tpu".
131134
Default is "auto".
135+
check_finite (bool):
136+
When set to True, stops training when the monitor becomes NaN or infinite.
137+
Default is True.
132138
collate_fn_name (str):
133139
The name of the collate function. Default is None.
134140
converters (dict):
@@ -164,6 +170,9 @@ def fun_control_init(
164170
Default is 1. Can be "auto" or an integer.
165171
design (object):
166172
The experimental design object. Default is None.
173+
divergence_threshold (float):
174+
Stop training as soon as the monitored quantity becomes worse than this threshold.
175+
Default is None.
167176
enable_progress_bar (bool):
168177
Whether to enable the progress bar or not.
169178
eval (str):
@@ -284,6 +293,9 @@ def fun_control_init(
284293
Whether the test data were shuffled or not. Default is False.
285294
surrogate (object):
286295
The surrogate model object. Default is None.
296+
stopping_threshold (float):
297+
Stop training immediately once the monitored quantity reaches this threshold.
298+
Default is None.
287299
strategy (str):
288300
The strategy to use. Default is "auto".
289301
target_column (str):
@@ -355,13 +367,15 @@ def fun_control_init(
355367
'_L_out': 11,
356368
'_L_cond': None,
357369
'accelerator': "auto",
370+
'check_finite': True,
358371
'core_model': None,
359372
'core_model_name': None,
360373
'data': None,
361374
'data_dir': './data',
362375
'db_dict_name': None,
363376
'device': None,
364377
'devices': "auto",
378+
'divergence_threshold': None,
365379
'enable_progress_bar': False,
366380
'eval': None,
367381
'horizon': 7,
@@ -391,6 +405,7 @@ def fun_control_init(
391405
'show_batch_interval': 1000000,
392406
'shuffle': None,
393407
'sigma': 0.0,
408+
'stopping_threshold': None,
394409
'target_column': None,
395410
'target_type': None,
396411
'train': None,
@@ -425,6 +440,7 @@ def fun_control_init(
425440
"_L_cond": _L_cond,
426441
"_torchmetric": _torchmetric,
427442
"accelerator": accelerator,
443+
"check_finite": check_finite,
428444
"collate_fn_name": collate_fn_name,
429445
"converters": converters,
430446
"core_model": core_model,
@@ -433,6 +449,7 @@ def fun_control_init(
433449
"data": data,
434450
"data_dir": data_dir,
435451
"data_full_train": data_full_train,
452+
"divergence_threshold": divergence_threshold,
436453
"hacky": hacky,
437454
"data_module": data_module,
438455
"data_set": data_set,
@@ -497,6 +514,7 @@ def fun_control_init(
497514
"shuffle_val": shuffle_val,
498515
"shuffle_test": shuffle_test,
499516
"sigma": sigma,
517+
"stopping_threshold": stopping_threshold,
500518
"spot_tensorboard_path": spot_tensorboard_path,
501519
"strategy": strategy,
502520
"target_column": target_column,

0 commit comments

Comments
 (0)