Skip to content

Commit b3b7872

Browse files
Model Checkpoint adapted to new tensorflow handling
1 parent 247e1e1 commit b3b7872

3 files changed

Lines changed: 8 additions & 3 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.6.16"
10+
version = "0.6.17"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/light/traintest.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,9 @@ def test_model(config: dict, fun_control: dict) -> Tuple[float, float]:
152152
logger=TensorBoardLogger(save_dir=fun_control["TENSORBOARD_PATH"], version=config_id, default_hp_metric=True),
153153
callbacks=[
154154
EarlyStopping(monitor="val_loss", patience=config["patience"], mode="min", strict=False, verbose=False),
155-
ModelCheckpoint(save_last=True), # Save the last checkpoint
155+
ModelCheckpoint(
156+
dirpath=os.path.join(fun_control["CHECKPOINT_PATH"], config_id), save_last=True
157+
), # Save the last checkpoint
156158
],
157159
enable_progress_bar=enable_progress_bar,
158160
)
@@ -287,7 +289,8 @@ def load_light_from_checkpoint(config: dict, fun_control: dict, postfix: str = "
287289
>>> model = load_light_from_checkpoint(config, fun_control)
288290
"""
289291
config_id = generate_config_id(config) + postfix
290-
default_root_dir = fun_control["TENSORBOARD_PATH"] + "lightning_logs/" + config_id + "/checkpoints/last.ckpt"
292+
# default_root_dir = fun_control["TENSORBOARD_PATH"] + "lightning_logs/" + config_id + "/checkpoints/last.ckpt"
293+
default_root_dir = os.path.join(fun_control["CHECKPOINT_PATH"], config_id, "last.ckpt")
291294
# default_root_dir = os.path.join(fun_control["CHECKPOINT_PATH"], config_id)
292295
print(f"Loading model from {default_root_dir}")
293296
model = fun_control["core_model"].load_from_checkpoint(

src/spotPython/utils/eda.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@ def generate_config_id(config):
173173
config_id = ""
174174
for key in config:
175175
config_id += str(config[key]) + "_"
176+
# hash the config_id to make it shorter and unique
177+
config_id = str(hash(config_id)) + "_"
176178
return config_id[:-1]
177179

178180

0 commit comments

Comments
 (0)