@@ -129,6 +129,7 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
129129 )
130130 else :
131131 dm = fun_control ["data_module" ]
132+ dm .setup () # Manually call setup to prepare the datasets
132133
133134 model = build_model_instance (config , fun_control )
134135 # TODO: Check if this is necessary or if this is handled by the trainer
@@ -238,7 +239,7 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
238239 gradient_clip_algorithm = "norm" ,
239240 )
240241
241- trainer .fit (model = model , train_dataloaders = train_dl , ckpt_path = None )
242+ trainer .fit (model = model , train_dataloaders = train_dl , val_dataloaders = test_dl , ckpt_path = None )
242243 result = trainer .validate (model = model , dataloaders = test_dl , ckpt_path = None , verbose = verbose )
243244 result = result [0 ]
244245
@@ -350,10 +351,13 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
350351 # Could also be one of two special keywords "last" and "hpc".
351352 # If there is no checkpoint file at the path, an exception is raised.
352353 try :
353- trainer .fit (model = model , datamodule = dm , ckpt_path = None )
354+ trainer .fit (model = model , train_dataloaders = dm . train_dataloader (), val_dataloaders = dm . val_dataloader () , ckpt_path = None )
354355 except Exception as e :
355356 print (f"train_model(): trainer.fit failed with exception: { e } " )
357+ return None
356358 # Test best model on validation and test set
359+ # The validate and test methods expect a datamodule or dataloaders.
360+ # Using the datamodule is cleaner.
357361 verbose = fun_control ["verbosity" ] > 0
358362
359363 # Validate the model
@@ -455,6 +459,7 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
455459 )
456460 else :
457461 dm = fun_control ["data_module" ]
462+ dm .setup () # Manually call setup to prepare the datasets
458463
459464 model = build_model_instance (config , fun_control )
460465 # TODO: Check if this is necessary or if this is handled by the trainer
@@ -619,10 +624,13 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
619624 # Could also be one of two special keywords "last" and "hpc".
620625 # If there is no checkpoint file at the path, an exception is raised.
621626 try :
622- trainer .fit (model = model , datamodule = dm , ckpt_path = None )
627+ trainer .fit (model = model , train_dataloaders = dm . train_dataloader (), val_dataloaders = dm . val_dataloader () , ckpt_path = None )
623628 except Exception as e :
624629 print (f"train_model(): trainer.fit failed with exception: { e } " )
630+ return None
625631 # Test best model on validation and test set
632+ # The validate and test methods expect a datamodule or dataloaders.
633+ # Using the datamodule is cleaner.
626634 verbose = fun_control ["verbosity" ] > 0
627635
628636 # Validate the model
0 commit comments