@@ -105,18 +105,21 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
105105 pprint.pprint(config)
106106 y = train_model(config, fun_control)
107107 """
108- config_id = generate_config_id_with_timestamp (config = config , timestamp = timestamp )
108+ if fun_control ["data_module" ] is None :
109+ dm = LightDataModule (
110+ dataset = fun_control ["data_set" ],
111+ data_full_train = fun_control ["data_full_train" ],
112+ data_test = fun_control ["data_test" ],
113+ batch_size = config ["batch_size" ],
114+ num_workers = fun_control ["num_workers" ],
115+ test_size = fun_control ["test_size" ],
116+ test_seed = fun_control ["test_seed" ],
117+ scaler = fun_control ["scaler" ],
118+ verbosity = fun_control ["verbosity" ],
119+ )
120+ else :
121+ dm = fun_control ["data_module" ]
109122 model = build_model_instance (config , fun_control )
110- dm = LightDataModule (
111- dataset = fun_control ["data_set" ],
112- data_full_train = fun_control ["data_full_train" ],
113- data_test = fun_control ["data_test" ],
114- batch_size = config ["batch_size" ],
115- num_workers = fun_control ["num_workers" ],
116- test_size = fun_control ["test_size" ],
117- test_seed = fun_control ["test_seed" ],
118- scaler = fun_control ["scaler" ],
119- )
120123 # TODO: Check if this is necessary or if this is handled by the trainer
121124 # dm.setup()
122125 # print(f"train_model(): Test set size: {len(dm.data_test)}")
@@ -168,7 +171,7 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
168171 # Can be set to 'link' on a local filesystem to create a symbolic link.
169172 # This allows accessing the latest checkpoint in a deterministic manner.
170173 # Default: None.
171-
174+ config_id = generate_config_id_with_timestamp ( config = config , timestamp = timestamp )
172175 callbacks = [EarlyStopping (monitor = "val_loss" , patience = config ["patience" ], mode = "min" , strict = False , verbose = False )]
173176 if not timestamp :
174177 # add ModelCheckpoint only if timestamp is False
0 commit comments