|
1 | 1 | import lightning as L |
2 | | -from spotpython.data.lightdatamodule import LightDataModule |
| 2 | +from spotpython.data.lightdatamodule import LightDataModule, PadSequenceManyToMany |
3 | 3 | from spotpython.utils.eda import generate_config_id |
4 | 4 | from pytorch_lightning.loggers import TensorBoardLogger |
5 | 5 | from lightning.pytorch.callbacks.early_stopping import EarlyStopping |
6 | 6 | from lightning.pytorch.callbacks import ModelCheckpoint |
| 7 | +from torch.utils.data import DataLoader |
| 8 | +import torch |
7 | 9 | import os |
8 | 10 |
|
| 11 | +import numpy as np |
| 12 | + |
9 | 13 |
|
10 | 14 | def generate_config_id_with_timestamp(config: dict, timestamp: bool) -> str: |
11 | 15 | """ |
@@ -124,6 +128,7 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa |
124 | 128 | ) |
125 | 129 | else: |
126 | 130 | dm = fun_control["data_module"] |
| 131 | + |
127 | 132 | model = build_model_instance(config, fun_control) |
128 | 133 | # TODO: Check if this is necessary or if this is handled by the trainer |
129 | 134 | # dm.setup() |
@@ -183,6 +188,63 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa |
183 | 188 | dirpath = os.path.join(fun_control["CHECKPOINT_PATH"], config_id) |
184 | 189 | callbacks.append(ModelCheckpoint(dirpath=dirpath, monitor=None, verbose=False, save_last=True)) # Save the last checkpoint |
185 | 190 |
|
| 191 | + if fun_control["hacky"]: |
| 192 | + verbose = fun_control["verbosity"] > 0 |
| 193 | + ds = fun_control["data_full_train"] |
| 194 | + indices = list(range(len(ds))) |
| 195 | + indice_results_val_loss = [] |
| 196 | + indice_results_hp_metric = [] |
| 197 | + for i in indices: |
| 198 | + print(f"train_model(): Hacky Implementation with Index {i}") |
| 199 | + test_indices = [indices[i]] |
| 200 | + train_indices = [index for index in indices if index != test_indices[0]] |
| 201 | + |
| 202 | + train_dataset = torch.utils.data.Subset(ds, train_indices) |
| 203 | + test_dataset = torch.utils.data.Subset(ds, test_indices) |
| 204 | + |
| 205 | + train_dl = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=False, collate_fn=PadSequenceManyToMany()) |
| 206 | + test_dl = DataLoader(test_dataset, batch_size=config["batch_size"], shuffle=False, collate_fn=PadSequenceManyToMany()) |
| 207 | + |
| 208 | + model = build_model_instance(config, fun_control) |
| 209 | + |
| 210 | + enable_progress_bar = fun_control["enable_progress_bar"] or False |
| 211 | + trainer = L.Trainer( |
| 212 | + # Where to save models |
| 213 | + default_root_dir=os.path.join(fun_control["CHECKPOINT_PATH"], config_id), |
| 214 | + max_epochs=model.hparams.epochs, |
| 215 | + accelerator=fun_control["accelerator"], |
| 216 | + devices=fun_control["devices"], |
| 217 | + strategy=fun_control["strategy"], |
| 218 | + num_nodes=fun_control["num_nodes"], |
| 219 | + precision=fun_control["precision"], |
| 220 | + logger=TensorBoardLogger(save_dir=fun_control["TENSORBOARD_PATH"], version=config_id, default_hp_metric=True, log_graph=fun_control["log_graph"], name=""), |
| 221 | + callbacks=callbacks, |
| 222 | + enable_progress_bar=enable_progress_bar, |
| 223 | + num_sanity_val_steps=fun_control["num_sanity_val_steps"], |
| 224 | + log_every_n_steps=fun_control["log_every_n_steps"], |
| 225 | + gradient_clip_val=None, |
| 226 | + gradient_clip_algorithm="norm", |
| 227 | + ) |
| 228 | + |
| 229 | + trainer.fit(model=model, train_dataloaders=train_dl, ckpt_path=None) |
| 230 | + result = trainer.validate(model=model, dataloaders=test_dl, ckpt_path=None, verbose=verbose) |
| 231 | + result = result[0] |
| 232 | + |
| 233 | + print(f"results_dict: {result}") |
| 234 | + |
| 235 | + indice_results_val_loss.append(result["val_loss"]) |
| 236 | + indice_results_hp_metric.append(result["hp_metric"]) |
| 237 | + |
| 238 | + mean_val_loss = np.mean(indice_results_val_loss) |
| 239 | + mean_hp_metric = np.mean(indice_results_hp_metric) |
| 240 | + |
| 241 | + print(f"train_model(): Mean Validation Loss: {mean_val_loss}") |
| 242 | + print(f"train_model(): Mean Hyperparameter Metric: {mean_hp_metric}") |
| 243 | + |
| 244 | + results_dict = {"val_loss": mean_val_loss, "hp_metric": mean_hp_metric} |
| 245 | + |
| 246 | + return results_dict["val_loss"] |
| 247 | + |
186 | 248 | # Tensorboard logger. The tensorboard is passed to the trainer. |
187 | 249 | # See: https://lightning.ai/docs/pytorch/stable/extensions/generated/lightning.pytorch.loggers.TensorBoardLogger.html |
188 | 250 | # It uses the following arguments: |
|
0 commit comments