@@ -152,7 +152,9 @@ def test_model(config: dict, fun_control: dict) -> Tuple[float, float]:
152152 logger = TensorBoardLogger (save_dir = fun_control ["TENSORBOARD_PATH" ], version = config_id , default_hp_metric = True ),
153153 callbacks = [
154154 EarlyStopping (monitor = "val_loss" , patience = config ["patience" ], mode = "min" , strict = False , verbose = False ),
155- ModelCheckpoint (save_last = True ), # Save the last checkpoint
155+ ModelCheckpoint (
156+ dirpath = os .path .join (fun_control ["CHECKPOINT_PATH" ], config_id ), save_last = True
157+ ), # Save the last checkpoint
156158 ],
157159 enable_progress_bar = enable_progress_bar ,
158160 )
@@ -287,7 +289,8 @@ def load_light_from_checkpoint(config: dict, fun_control: dict, postfix: str = "
287289 >>> model = load_light_from_checkpoint(config, fun_control)
288290 """
289291 config_id = generate_config_id (config ) + postfix
290- default_root_dir = fun_control ["TENSORBOARD_PATH" ] + "lightning_logs/" + config_id + "/checkpoints/last.ckpt"
292+ # default_root_dir = fun_control["TENSORBOARD_PATH"] + "lightning_logs/" + config_id + "/checkpoints/last.ckpt"
293+ default_root_dir = os .path .join (fun_control ["CHECKPOINT_PATH" ], config_id , "last.ckpt" )
291294 # default_root_dir = os.path.join(fun_control["CHECKPOINT_PATH"], config_id)
292295 print (f"Loading model from { default_root_dir } " )
293296 model = fun_control ["core_model" ].load_from_checkpoint (
0 commit comments