Skip to content

Commit 5516035

Browse files
v0.6.6
prepare general data modules
1 parent b4d1409 commit 5516035

6 files changed

Lines changed: 333 additions & 24 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.6.5"
10+
version = "0.6.6"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/fun/hyperlight.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,9 @@ def fun(self, X: np.ndarray, fun_control: dict = None) -> np.ndarray:
119119
# extract parameters like epochs, batch_size, lr, etc. from config
120120
# config_id = generate_config_id(config)
121121
try:
122+
print("fun: Calling train_model")
122123
df_eval = train_model(config, self.fun_control)
124+
print("fun: train_model returned")
123125
except Exception as err:
124126
logger.error(f"Error in fun(). Call to train_model failed. {err=}, {type(err)=}")
125127
logger.error("Setting df_eval to np.nan")

src/spotPython/light/cifar10datamodule.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def setup(self, stage: Optional[str] = None) -> None:
4848
data_full = CIFAR10(root=self.data_dir, train=True, transform=transform)
4949
# self.data_train, self.data_val = random_split(daata_full, [45000, 5000])
5050
test_abs = int(len(data_full) * 0.6)
51-
print("test_abs", test_abs)
51+
print("dm.setup(): test_abs", test_abs)
5252
self.data_train, self.data_val = random_split(data_full, [test_abs, len(data_full) - test_abs])
5353

5454
# Assign test dataset for use in dataloader(s)
@@ -66,7 +66,7 @@ def train_dataloader(self) -> DataLoader:
6666
DataLoader: The training dataloader.
6767
6868
"""
69-
print("self.batch_size", self.batch_size)
69+
print("train_dataloader: self.batch_size", self.batch_size)
7070
return DataLoader(self.data_train, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
7171

7272
def val_dataloader(self) -> DataLoader:
@@ -78,6 +78,7 @@ def val_dataloader(self) -> DataLoader:
7878
7979
8080
"""
81+
print("val_dataloader: self.batch_size", self.batch_size)
8182
return DataLoader(self.data_val, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
8283

8384
def test_dataloader(self) -> DataLoader:
@@ -89,4 +90,5 @@ def test_dataloader(self) -> DataLoader:
8990
9091
9192
"""
93+
print("train_data_loader: self.batch_size", self.batch_size)
9294
return DataLoader(self.data_test, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

src/spotPython/light/netlinearbase.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import lightning as L
22
import torch
33
from torch import nn
4-
import torchmetrics
4+
5+
# import torchmetrics
56
import torch.nn.functional as F
67
from torchmetrics.functional import accuracy
78
from spotPython.hyperparameters.optimizer import optimizer_handler
@@ -77,8 +78,8 @@ def __init__(
7778
patience: int,
7879
_L_in: int,
7980
_L_out: int,
80-
_metric: torchmetrics.functional = accuracy,
81-
_loss: torch.nn.functional = F.cross_entropy,
81+
# _metric: torchmetrics.functional = accuracy,
82+
# _loss: torch.nn.functional = F.cross_entropy,
8283
):
8384
"""
8485
Initializes the NetLightBase object.
@@ -131,6 +132,7 @@ def __init__(
131132
132133
"""
133134
super().__init__()
135+
print("NetLinearBase.__init__(): l1", l1)
134136
# Attribute 'act_fn' is an instance of `nn.Module` and is already saved during
135137
# checkpointing. It is recommended to ignore them
136138
# using `self.save_hyperparameters(ignore=['act_fn'])`
@@ -139,8 +141,8 @@ def __init__(
139141
self._L_in = _L_in
140142
self._L_out = _L_out
141143
# _L_in and _L_out are not hyperparameters, but are needed to create the network
142-
self._metric = _metric
143-
self._loss = _loss
144+
self._metric = accuracy
145+
self._loss = F.cross_entropy
144146
self.save_hyperparameters(ignore=["_L_in", "_L_out", "_metric", "_loss"])
145147
if self.hparams.l1 < 4:
146148
raise ValueError("l1 must be at least 4")
@@ -161,6 +163,7 @@ def __init__(
161163
layers += [nn.Linear(layer_sizes[-1], self._L_out)]
162164
# nn.Sequential summarizes a list of modules into a single module, applying them in sequence
163165
self.layers = nn.Sequential(*layers)
166+
print("Leaving NetLinearBase.__init__()")
164167

165168
def forward(self, x: torch.Tensor) -> torch.Tensor:
166169
"""
@@ -185,6 +188,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
185188
patience=5)
186189
187190
"""
191+
print("Entering NetLinearBase.forward()")
188192
x = self.layers(x)
189193
return F.softmax(x, dim=1)
190194

@@ -213,13 +217,17 @@ def training_step(self, batch: tuple) -> torch.Tensor:
213217
>>> trainer.fit(net_light_base, train_loader)
214218
215219
"""
216-
x, y = batch
217-
logits = self(x)
218-
# compute loss (default: cross entropy loss) from logits and y
219-
loss = self._loss(logits, y)
220-
# self.train_mapk(logits, y)
221-
# self.log("train_mapk", self.train_mapk, on_step=True, on_epoch=False)
222-
return loss
220+
print("Entering NetLinearBase.training_step()")
221+
# x, y = batch
222+
# print("NetLinearBase.training_step(): batch")
223+
# logits = self(x)
224+
# print("NetLinearBase.training_step(): logits")
225+
# # compute loss (default: cross entropy loss) from logits and y
226+
# loss = self._loss(logits, y)
227+
# # self.train_mapk(logits, y)
228+
# # self.log("train_mapk", self.train_mapk, on_step=True, on_epoch=False)
229+
# return loss
230+
return 0.1234
223231

224232
def validation_step(self, batch: tuple, batch_idx: int, prog_bar: bool = False):
225233
"""
@@ -248,6 +256,7 @@ def validation_step(self, batch: tuple, batch_idx: int, prog_bar: bool = False):
248256
>>> trainer.fit(net_light_base, val_loader)
249257
250258
"""
259+
print("Entering NetLinearBase.validation_step()")
251260
x, y = batch
252261
logits = self(x)
253262
# compute cross entropy loss from logits and y
@@ -271,6 +280,7 @@ def test_step(self, batch: tuple, batch_idx: int, prog_bar: bool = False) -> tup
271280
Returns:
272281
tuple: A tuple containing the loss and accuracy for this batch.
273282
"""
283+
print("Entering NetLinearBase.test_step()")
274284
x, y = batch
275285
logits = self(x)
276286
# compute cross entropy loss from logits and y
@@ -291,7 +301,9 @@ def configure_optimizers(self) -> torch.optim.Optimizer:
291301
292302
"""
293303
# optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
304+
print("Entering NetLinearBase.configure_optimizers()")
294305
optimizer = optimizer_handler(
295306
optimizer_name=self.hparams.optimizer, params=self.parameters(), lr_mult=self.hparams.lr_mult
296307
)
308+
print("Leaving NetLinearBase.configure_optimizers()")
297309
return optimizer
Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,22 +60,15 @@ def train_model(config: dict, fun_control: dict) -> float:
6060
kaiming_init(model)
6161
else:
6262
pass
63-
print(f"model: {model}")
63+
# print(f"model: {model}")
6464

6565
# Init DataModule
6666
dm = CIFAR10DataModule(
6767
batch_size=config["batch_size"], data_dir=fun_control["DATASET_PATH"], num_workers=fun_control["num_workers"]
6868
)
6969
dm.prepare_data()
7070
dm.setup()
71-
72-
dataiter = iter(dm)
73-
images, labels = next(dataiter)
74-
batch_size = 3
75-
# print labels
76-
classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")
77-
print(" ".join(f"{classes[labels[j]]:5s}" for j in range(batch_size)))
78-
71+
print("Leaving dm.setup()")
7972
# Init trainer
8073
trainer = L.Trainer(
8174
# Where to save models
@@ -90,7 +83,9 @@ def train_model(config: dict, fun_control: dict) -> float:
9083
enable_progress_bar=enable_progress_bar,
9184
)
9285
# Pass the datamodule as arg to trainer.fit to override model hooks :)
86+
print("train.model: Entering trainer.fit()")
9387
trainer.fit(model=model, datamodule=dm)
88+
print("train.model: Leaving trainer.fit()")
9489
# Test best model on validation and test set
9590
# result = trainer.validate(model=model, datamodule=dm, ckpt_path="last")
9691
result = trainer.validate(model=model, datamodule=dm)

0 commit comments

Comments
 (0)