33from spotPython .light .crossvalidationdatamodule import CrossValidationDataModule
44from spotPython .utils .eda import generate_config_id
55from pytorch_lightning .loggers import TensorBoardLogger
6+ from lightning .pytorch .callbacks import ModelCheckpoint
67from lightning .pytorch .callbacks .early_stopping import EarlyStopping
78from spotPython .torch .initialization import kaiming_init , xavier_init
9+ import os
810
911
1012def train_model (config , fun_control ):
13+ _L_in = fun_control ["_L_in" ]
14+ _L_out = fun_control ["_L_out" ]
15+ print (f"_L_in: { _L_in } " )
16+ print (f"_L_out: { _L_out } " )
1117 if fun_control ["enable_progress_bar" ] is None :
1218 enable_progress_bar = False
1319 else :
1420 enable_progress_bar = fun_control ["enable_progress_bar" ]
1521 config_id = generate_config_id (config )
16- # Init DataModule
17- dm = CSVDataModule (
18- batch_size = config ["batch_size" ], num_workers = fun_control ["num_workers" ], data_dir = fun_control ["data_dir" ]
19- )
20- # Init model from datamodule's attributes
21- # model = LitModel(*dm.dims, dm.num_classes)
22- model = fun_control ["core_model" ](** config , _L_in = 64 , _L_out = 11 )
22+ model = fun_control ["core_model" ](** config , _L_in = _L_in , _L_out = _L_out )
2323 initialization = config ["initialization" ]
2424 if initialization == "Xavier" :
2525 xavier_init (model )
@@ -28,12 +28,22 @@ def train_model(config, fun_control):
2828 else :
2929 pass
3030 print (f"model: { model } " )
31+
32+ # Init DataModule
33+ dm = CSVDataModule (
34+ batch_size = config ["batch_size" ],
35+ num_workers = fun_control ["num_workers" ],
36+ DATASET_PATH = fun_control ["DATASET_PATH" ],
37+ )
38+
3139 # Init trainer
3240 trainer = L .Trainer (
33- max_epochs = model .epochs ,
41+ # Where to save models
42+ default_root_dir = os .path .join (fun_control ["CHECKPOINT_PATH" ], config_id ),
43+ max_epochs = model .hparams .epochs ,
3444 accelerator = "auto" ,
3545 devices = 1 ,
36- logger = TensorBoardLogger (save_dir = fun_control ["tensorboard_path " ], version = config_id , default_hp_metric = True ),
46+ logger = TensorBoardLogger (save_dir = fun_control ["TENSORBOARD_PATH " ], version = config_id , default_hp_metric = True ),
3747 callbacks = [
3848 EarlyStopping (monitor = "val_loss" , patience = config ["patience" ], mode = "min" , strict = False , verbose = False )
3949 ],
@@ -51,18 +61,22 @@ def train_model(config, fun_control):
5161
5262
5363def test_model (config , fun_control ):
64+ _L_in = fun_control ["_L_in" ]
65+ _L_out = fun_control ["_L_out" ]
5466 if fun_control ["enable_progress_bar" ] is None :
5567 enable_progress_bar = False
5668 else :
5769 enable_progress_bar = fun_control ["enable_progress_bar" ]
58- config_id = generate_config_id (config )
70+ # Add "TEST" postfix to config_id
71+ config_id = generate_config_id (config ) + "_TEST"
5972 # Init DataModule
6073 dm = CSVDataModule (
61- batch_size = config ["batch_size" ], num_workers = fun_control ["num_workers" ], data_dir = fun_control ["data_dir" ]
74+ batch_size = config ["batch_size" ],
75+ num_workers = fun_control ["num_workers" ],
76+ DATASET_PATH = fun_control ["DATASET_PATH" ],
6277 )
6378 # Init model from datamodule's attributes
64- # model = LitModel(*dm.dims, dm.num_classes)
65- model = fun_control ["core_model" ](** config , _L_in = 64 , _L_out = 11 )
79+ model = fun_control ["core_model" ](** config , _L_in = _L_in , _L_out = _L_out )
6680 initialization = config ["initialization" ]
6781 if initialization == "Xavier" :
6882 xavier_init (model )
@@ -73,12 +87,15 @@ def test_model(config, fun_control):
7387 print (f"model: { model } " )
7488 # Init trainer
7589 trainer = L .Trainer (
76- max_epochs = model .epochs ,
90+ # Where to save models
91+ default_root_dir = os .path .join (fun_control ["CHECKPOINT_PATH" ], config_id ),
92+ max_epochs = model .hparams .epochs ,
7793 accelerator = "auto" ,
7894 devices = 1 ,
79- logger = TensorBoardLogger (save_dir = fun_control ["tensorboard_path " ], version = config_id , default_hp_metric = True ),
95+ logger = TensorBoardLogger (save_dir = fun_control ["TENSORBOARD_PATH " ], version = config_id , default_hp_metric = True ),
8096 callbacks = [
81- EarlyStopping (monitor = "val_loss" , patience = config ["patience" ], mode = "min" , strict = False , verbose = False )
97+ EarlyStopping (monitor = "val_loss" , patience = config ["patience" ], mode = "min" , strict = False , verbose = False ),
98+ ModelCheckpoint (save_last = True ), # Save the last checkpoint
8299 ],
83100 enable_progress_bar = enable_progress_bar ,
84101 )
@@ -91,15 +108,18 @@ def test_model(config, fun_control):
91108
92109
93110def cv_model (config , fun_control ):
94- config_id = generate_config_id (config )
111+ _L_in = fun_control ["_L_in" ]
112+ _L_out = fun_control ["_L_out" ]
95113 if fun_control ["enable_progress_bar" ] is None :
96114 enable_progress_bar = False
97115 else :
98116 enable_progress_bar = fun_control ["enable_progress_bar" ]
117+ # Add "CV" postfix to config_id
118+ config_id = generate_config_id (config ) + "_CV"
99119 results = []
100120 num_folds = 10
101121 split_seed = 12345
102- model = fun_control ["core_model" ](** config , _L_in = 64 , _L_out = 11 )
122+ model = fun_control ["core_model" ](** config , _L_in = _L_in , _L_out = _L_out )
103123 initialization = config ["initialization" ]
104124 if initialization == "Xavier" :
105125 xavier_init (model )
@@ -116,7 +136,7 @@ def cv_model(config, fun_control):
116136 num_splits = num_folds ,
117137 split_seed = split_seed ,
118138 batch_size = config ["batch_size" ],
119- data_dir = fun_control ["data_dir " ],
139+ DATASET_PATH = fun_control ["DATASET_PATH " ],
120140 )
121141 dm .prepare_data ()
122142 dm .setup ()
@@ -125,11 +145,13 @@ def cv_model(config, fun_control):
125145 print (f"model: { model } " )
126146 # Init trainer
127147 trainer = L .Trainer (
128- max_epochs = model .epochs ,
148+ # Where to save models
149+ default_root_dir = os .path .join (fun_control ["CHECKPOINT_PATH" ], config_id ),
150+ max_epochs = model .hparams .epochs ,
129151 accelerator = "auto" ,
130152 devices = 1 ,
131153 logger = TensorBoardLogger (
132- save_dir = fun_control ["tensorboard_path " ], version = config_id , default_hp_metric = True
154+ save_dir = fun_control ["TENSORBOARD_PATH " ], version = config_id , default_hp_metric = True
133155 ),
134156 callbacks = [
135157 EarlyStopping (monitor = "val_loss" , patience = config ["patience" ], mode = "min" , strict = False , verbose = False )
@@ -150,3 +172,16 @@ def cv_model(config, fun_control):
150172 mapk_score = sum (results ) / num_folds
151173 print (f"cv_model mapk result: { mapk_score } " )
152174 return mapk_score
175+
176+
def load_light_from_checkpoint(config, fun_control, postfix="_TEST"):
    """Load the core model from the last checkpoint saved for a configuration.

    The checkpoint file is looked up below the TensorBoard directory at
    ``<TENSORBOARD_PATH>/lightning_logs/<config_id><postfix>/checkpoints/last.ckpt``
    (not below ``CHECKPOINT_PATH``).

    Args:
        config (dict): Hyperparameter configuration; passed to
            ``generate_config_id`` to derive the run identifier.
        fun_control (dict): Control dictionary. The keys ``"TENSORBOARD_PATH"``,
            ``"core_model"``, ``"_L_in"`` and ``"_L_out"`` are read.
        postfix (str): Suffix appended to the generated config id.
            Defaults to ``"_TEST"``.

    Returns:
        The object returned by ``core_model.load_from_checkpoint``, after
        ``eval()`` has been called on it.
    """
    config_id = generate_config_id(config) + postfix
    # Join the path components instead of concatenating strings so the result
    # is valid whether or not TENSORBOARD_PATH ends with a path separator.
    checkpoint_path = os.path.join(
        fun_control["TENSORBOARD_PATH"],
        "lightning_logs",
        config_id,
        "checkpoints",
        "last.ckpt",
    )
    print(f"Loading model from {checkpoint_path}")
    model = fun_control["core_model"].load_from_checkpoint(
        checkpoint_path,
        _L_in=fun_control["_L_in"],
        _L_out=fun_control["_L_out"],
    )
    # disable randomness, dropout, etc...
    model.eval()
    return model
0 commit comments