Skip to content

Commit 398d8b1

Browse files
conversion to light
1 parent 1fcbfbe commit 398d8b1

8 files changed

Lines changed: 184 additions & 81 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.2.51"
10+
version = "0.2.52"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/light/crossvalidationdatamodule.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@ def __init__(
1313
k: int = 1,
1414
split_seed: int = 42,
1515
num_splits: int = 10,
16-
data_dir: str = "./data",
16+
DATASET_PATH: str = "./data",
1717
num_workers: int = 0,
1818
pin_memory: bool = False,
1919
):
2020
super().__init__()
2121
self.batch_size = batch_size
22-
self.data_dir = data_dir
22+
self.DATASET_PATH = DATASET_PATH
2323
self.num_workers = num_workers
2424
self.k = k
2525
self.split_seed = split_seed
@@ -35,6 +35,7 @@ def __init__(
3535
self.data_val: Optional[Dataset] = None
3636

3737
def prepare_data(self):
38+
# download
3839
pass
3940

4041
def setup(self, stage=None):

src/spotPython/light/csvdatamodule.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@
55

66

77
class CSVDataModule(L.LightningDataModule):
8-
def __init__(self, batch_size, data_dir: str = "./data", num_workers: int = 0):
8+
def __init__(self, batch_size, DATASET_PATH: str = "./data", num_workers: int = 0):
99
super().__init__()
1010
self.batch_size = batch_size
11-
self.data_dir = data_dir
1211
self.num_workers = num_workers
1312

1413
def prepare_data(self):

src/spotPython/light/netlightbase.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,38 +18,38 @@ def __init__(
1818
optimizer,
1919
dropout_prob,
2020
lr_mult,
21-
patience=3,
22-
_L_in=64,
23-
_L_out=11,
21+
patience,
22+
_L_in,
23+
_L_out,
2424
):
2525
super().__init__()
26-
27-
# Attribute 'act_fn' is an instance of `nn.Module` and is already saved during checkpointing.
28-
# It is recommended to ignore them using `self.save_hyperparameters(ignore=['act_fn'])`
29-
self.save_hyperparameters(ignore=["act_fn"])
26+
# Attribute 'act_fn' is an instance of `nn.Module` and is already saved during
27+
# checkpointing. It is recommended to ignore them
28+
# using `self.save_hyperparameters(ignore=['act_fn'])`
29+
# self.save_hyperparameters(ignore=["act_fn"])
30+
#
31+
self._L_in = _L_in
3032
self._L_out = _L_out
31-
if l1 < 4:
33+
# _L_in and _L_out are not hyperparameters, but are needed to create the network
34+
self.save_hyperparameters(ignore=["_L_in", "_L_out"])
35+
if self.hparams.l1 < 4:
3236
raise ValueError("l1 must be at least 4")
33-
self.l1 = l1
34-
hidden_sizes = [l1, l1 // 2, l1 // 2, l1 // 4]
35-
self.epochs = epochs
36-
self.patience = patience
37-
self.batch_size = batch_size
38-
self.initialization = initialization
39-
self.act_fn = act_fn
40-
self.optimizer = optimizer
41-
self.dropout_prob = dropout_prob
42-
self.lr_mult = lr_mult
37+
38+
hidden_sizes = [self.hparams.l1, self.hparams.l1 // 2, self.hparams.l1 // 2, self.hparams.l1 // 4]
4339
self.train_mapk = MAPK(k=3)
4440
self.valid_mapk = MAPK(k=3)
4541
self.test_mapk = MAPK(k=3)
4642

4743
# Create the network based on the specified hidden sizes
4844
layers = []
49-
layer_sizes = [_L_in] + hidden_sizes
45+
layer_sizes = [self._L_in] + hidden_sizes
5046
layer_size_last = layer_sizes[0]
5147
for layer_size in layer_sizes[1:]:
52-
layers += [nn.Linear(layer_size_last, layer_size), act_fn, nn.Dropout(self.dropout_prob)]
48+
layers += [
49+
nn.Linear(layer_size_last, layer_size),
50+
self.hparams.act_fn,
51+
nn.Dropout(self.hparams.dropout_prob),
52+
]
5353
layer_size_last = layer_size
5454
layers += [nn.Linear(layer_sizes[-1], self._L_out)]
5555
# nn.Sequential summarizes a list of modules into a single module, applying them in sequence
@@ -98,5 +98,7 @@ def test_step(self, batch, batch_idx, prog_bar=False):
9898

9999
def configure_optimizers(self):
100100
# optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
101-
optimizer = optimizer_handler(optimizer_name=self.optimizer, params=self.parameters(), lr_mult=self.lr_mult)
101+
optimizer = optimizer_handler(
102+
optimizer_name=self.hparams.optimizer, params=self.parameters(), lr_mult=self.hparams.lr_mult
103+
)
102104
return optimizer

src/spotPython/light/traintest.py

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,23 @@
33
from spotPython.light.crossvalidationdatamodule import CrossValidationDataModule
44
from spotPython.utils.eda import generate_config_id
55
from pytorch_lightning.loggers import TensorBoardLogger
6+
from lightning.pytorch.callbacks import ModelCheckpoint
67
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
78
from spotPython.torch.initialization import kaiming_init, xavier_init
9+
import os
810

911

1012
def train_model(config, fun_control):
13+
_L_in = fun_control["_L_in"]
14+
_L_out = fun_control["_L_out"]
15+
print(f"_L_in: {_L_in}")
16+
print(f"_L_out: {_L_out}")
1117
if fun_control["enable_progress_bar"] is None:
1218
enable_progress_bar = False
1319
else:
1420
enable_progress_bar = fun_control["enable_progress_bar"]
1521
config_id = generate_config_id(config)
16-
# Init DataModule
17-
dm = CSVDataModule(
18-
batch_size=config["batch_size"], num_workers=fun_control["num_workers"], data_dir=fun_control["data_dir"]
19-
)
20-
# Init model from datamodule's attributes
21-
# model = LitModel(*dm.dims, dm.num_classes)
22-
model = fun_control["core_model"](**config, _L_in=64, _L_out=11)
22+
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out)
2323
initialization = config["initialization"]
2424
if initialization == "Xavier":
2525
xavier_init(model)
@@ -28,12 +28,22 @@ def train_model(config, fun_control):
2828
else:
2929
pass
3030
print(f"model: {model}")
31+
32+
# Init DataModule
33+
dm = CSVDataModule(
34+
batch_size=config["batch_size"],
35+
num_workers=fun_control["num_workers"],
36+
DATASET_PATH=fun_control["DATASET_PATH"],
37+
)
38+
3139
# Init trainer
3240
trainer = L.Trainer(
33-
max_epochs=model.epochs,
41+
# Where to save models
42+
default_root_dir=os.path.join(fun_control["CHECKPOINT_PATH"], config_id),
43+
max_epochs=model.hparams.epochs,
3444
accelerator="auto",
3545
devices=1,
36-
logger=TensorBoardLogger(save_dir=fun_control["tensorboard_path"], version=config_id, default_hp_metric=True),
46+
logger=TensorBoardLogger(save_dir=fun_control["TENSORBOARD_PATH"], version=config_id, default_hp_metric=True),
3747
callbacks=[
3848
EarlyStopping(monitor="val_loss", patience=config["patience"], mode="min", strict=False, verbose=False)
3949
],
@@ -51,18 +61,22 @@ def train_model(config, fun_control):
5161

5262

5363
def test_model(config, fun_control):
64+
_L_in = fun_control["_L_in"]
65+
_L_out = fun_control["_L_out"]
5466
if fun_control["enable_progress_bar"] is None:
5567
enable_progress_bar = False
5668
else:
5769
enable_progress_bar = fun_control["enable_progress_bar"]
58-
config_id = generate_config_id(config)
70+
# Add "TEST" postfix to config_id
71+
config_id = generate_config_id(config) + "_TEST"
5972
# Init DataModule
6073
dm = CSVDataModule(
61-
batch_size=config["batch_size"], num_workers=fun_control["num_workers"], data_dir=fun_control["data_dir"]
74+
batch_size=config["batch_size"],
75+
num_workers=fun_control["num_workers"],
76+
DATASET_PATH=fun_control["DATASET_PATH"],
6277
)
6378
# Init model from datamodule's attributes
64-
# model = LitModel(*dm.dims, dm.num_classes)
65-
model = fun_control["core_model"](**config, _L_in=64, _L_out=11)
79+
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out)
6680
initialization = config["initialization"]
6781
if initialization == "Xavier":
6882
xavier_init(model)
@@ -73,12 +87,15 @@ def test_model(config, fun_control):
7387
print(f"model: {model}")
7488
# Init trainer
7589
trainer = L.Trainer(
76-
max_epochs=model.epochs,
90+
# Where to save models
91+
default_root_dir=os.path.join(fun_control["CHECKPOINT_PATH"], config_id),
92+
max_epochs=model.hparams.epochs,
7793
accelerator="auto",
7894
devices=1,
79-
logger=TensorBoardLogger(save_dir=fun_control["tensorboard_path"], version=config_id, default_hp_metric=True),
95+
logger=TensorBoardLogger(save_dir=fun_control["TENSORBOARD_PATH"], version=config_id, default_hp_metric=True),
8096
callbacks=[
81-
EarlyStopping(monitor="val_loss", patience=config["patience"], mode="min", strict=False, verbose=False)
97+
EarlyStopping(monitor="val_loss", patience=config["patience"], mode="min", strict=False, verbose=False),
98+
ModelCheckpoint(save_last=True), # Save the last checkpoint
8299
],
83100
enable_progress_bar=enable_progress_bar,
84101
)
@@ -91,15 +108,18 @@ def test_model(config, fun_control):
91108

92109

93110
def cv_model(config, fun_control):
94-
config_id = generate_config_id(config)
111+
_L_in = fun_control["_L_in"]
112+
_L_out = fun_control["_L_out"]
95113
if fun_control["enable_progress_bar"] is None:
96114
enable_progress_bar = False
97115
else:
98116
enable_progress_bar = fun_control["enable_progress_bar"]
117+
# Add "CV" postfix to config_id
118+
config_id = generate_config_id(config) + "_CV"
99119
results = []
100120
num_folds = 10
101121
split_seed = 12345
102-
model = fun_control["core_model"](**config, _L_in=64, _L_out=11)
122+
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out)
103123
initialization = config["initialization"]
104124
if initialization == "Xavier":
105125
xavier_init(model)
@@ -116,7 +136,7 @@ def cv_model(config, fun_control):
116136
num_splits=num_folds,
117137
split_seed=split_seed,
118138
batch_size=config["batch_size"],
119-
data_dir=fun_control["data_dir"],
139+
DATASET_PATH=fun_control["DATASET_PATH"],
120140
)
121141
dm.prepare_data()
122142
dm.setup()
@@ -125,11 +145,13 @@ def cv_model(config, fun_control):
125145
print(f"model: {model}")
126146
# Init trainer
127147
trainer = L.Trainer(
128-
max_epochs=model.epochs,
148+
# Where to save models
149+
default_root_dir=os.path.join(fun_control["CHECKPOINT_PATH"], config_id),
150+
max_epochs=model.hparams.epochs,
129151
accelerator="auto",
130152
devices=1,
131153
logger=TensorBoardLogger(
132-
save_dir=fun_control["tensorboard_path"], version=config_id, default_hp_metric=True
154+
save_dir=fun_control["TENSORBOARD_PATH"], version=config_id, default_hp_metric=True
133155
),
134156
callbacks=[
135157
EarlyStopping(monitor="val_loss", patience=config["patience"], mode="min", strict=False, verbose=False)
@@ -150,3 +172,16 @@ def cv_model(config, fun_control):
150172
mapk_score = sum(results) / num_folds
151173
print(f"cv_model mapk result: {mapk_score}")
152174
return mapk_score
175+
176+
177+
def load_light_from_checkpoint(config, fun_control, postfix="_TEST"):
178+
config_id = generate_config_id(config) + postfix
179+
default_root_dir = fun_control["TENSORBOARD_PATH"] + "lightning_logs/" + config_id + "/checkpoints/last.ckpt"
180+
# default_root_dir = os.path.join(fun_control["CHECKPOINT_PATH"], config_id)
181+
print(f"Loading model from {default_root_dir}")
182+
model = fun_control["core_model"].load_from_checkpoint(
183+
default_root_dir, _L_in=fun_control["_L_in"], _L_out=fun_control["_L_out"]
184+
)
185+
# disable randomness, dropout, etc...
186+
model.eval()
187+
return model

src/spotPython/light/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from spotPython.hyperparameters.values import get_one_config_from_X
2+
3+
4+
def get_tuned_architecture(spot_tuner, fun_control):
5+
X = spot_tuner.to_all_dim(spot_tuner.min_X.reshape(1, -1))
6+
config = get_one_config_from_X(X, fun_control)
7+
return config

src/spotPython/utils/eda.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def visualize_activations(net, device="cpu", color="C0"):
167167
sns.histplot(data=activations[key], bins=50, ax=key_ax, color=color, kde=True, stat="density")
168168
key_ax.set_title(f"Layer {key} - {net.layers[key].__class__.__name__}")
169169
fig_index += 1
170-
fig.suptitle(f"Activation distribution for activation function {net.act_fn}", fontsize=14)
170+
fig.suptitle(f"Activation distribution for activation function {net.hparams.act_fn}", fontsize=14)
171171
fig.subplots_adjust(hspace=0.4, wspace=0.4)
172172
plt.show()
173173
plt.close()

0 commit comments

Comments
 (0)