11import lightning as L
22import torch
33from torch import nn
4- import torchmetrics
4+
5+ # import torchmetrics
56import torch .nn .functional as F
67from torchmetrics .functional import accuracy
78from spotPython .hyperparameters .optimizer import optimizer_handler
@@ -77,8 +78,8 @@ def __init__(
7778 patience : int ,
7879 _L_in : int ,
7980 _L_out : int ,
80- _metric : torchmetrics .functional = accuracy ,
81- _loss : torch .nn .functional = F .cross_entropy ,
81+ # _metric: torchmetrics.functional = accuracy,
82+ # _loss: torch.nn.functional = F.cross_entropy,
8283 ):
8384 """
8485 Initializes the NetLightBase object.
@@ -131,6 +132,7 @@ def __init__(
131132
132133 """
133134 super ().__init__ ()
135+ print ("NetLinearBase.__init__(): l1" , l1 )
134136 # Attribute 'act_fn' is an instance of `nn.Module` and is already saved during
135137 # checkpointing. It is recommended to ignore them
136138 # using `self.save_hyperparameters(ignore=['act_fn'])`
@@ -139,8 +141,8 @@ def __init__(
139141 self ._L_in = _L_in
140142 self ._L_out = _L_out
141143 # _L_in and _L_out are not hyperparameters, but are needed to create the network
142- self ._metric = _metric
143- self ._loss = _loss
144+ self ._metric = accuracy
145+ self ._loss = F . cross_entropy
144146 self .save_hyperparameters (ignore = ["_L_in" , "_L_out" , "_metric" , "_loss" ])
145147 if self .hparams .l1 < 4 :
146148 raise ValueError ("l1 must be at least 4" )
@@ -161,6 +163,7 @@ def __init__(
161163 layers += [nn .Linear (layer_sizes [- 1 ], self ._L_out )]
162164 # nn.Sequential summarizes a list of modules into a single module, applying them in sequence
163165 self .layers = nn .Sequential (* layers )
166+ print ("Leaving NetLinearBase.__init__()" )
164167
165168 def forward (self , x : torch .Tensor ) -> torch .Tensor :
166169 """
@@ -185,6 +188,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
185188 patience=5)
186189
187190 """
191+ print ("Entering NetLinearBase.forward()" )
188192 x = self .layers (x )
189193 return F .softmax (x , dim = 1 )
190194
@@ -213,13 +217,17 @@ def training_step(self, batch: tuple) -> torch.Tensor:
213217 >>> trainer.fit(net_light_base, train_loader)
214218
215219 """
216- x , y = batch
217- logits = self (x )
218- # compute loss (default: cross entropy loss) from logits and y
219- loss = self ._loss (logits , y )
220- # self.train_mapk(logits, y)
221- # self.log("train_mapk", self.train_mapk, on_step=True, on_epoch=False)
222- return loss
220+ print ("Entering NetLinearBase.training_step()" )
221+ # x, y = batch
222+ # print("NetLinearBase.training_step(): batch")
223+ # logits = self(x)
224+ # print("NetLinearBase.training_step(): logits")
225+ # # compute loss (default: cross entropy loss) from logits and y
226+ # loss = self._loss(logits, y)
227+ # # self.train_mapk(logits, y)
228+ # # self.log("train_mapk", self.train_mapk, on_step=True, on_epoch=False)
229+ # return loss
230+ return 0.1234
223231
224232 def validation_step (self , batch : tuple , batch_idx : int , prog_bar : bool = False ):
225233 """
@@ -248,6 +256,7 @@ def validation_step(self, batch: tuple, batch_idx: int, prog_bar: bool = False):
248256 >>> trainer.fit(net_light_base, val_loader)
249257
250258 """
259+ print ("Entering NetLinearBase.validation_step()" )
251260 x , y = batch
252261 logits = self (x )
253262 # compute cross entropy loss from logits and y
@@ -271,6 +280,7 @@ def test_step(self, batch: tuple, batch_idx: int, prog_bar: bool = False) -> tup
271280 Returns:
272281 tuple: A tuple containing the loss and accuracy for this batch.
273282 """
283+ print ("Entering NetLinearBase.test_step()" )
274284 x , y = batch
275285 logits = self (x )
276286 # compute cross entropy loss from logits and y
@@ -291,7 +301,9 @@ def configure_optimizers(self) -> torch.optim.Optimizer:
291301
292302 """
293303 # optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
304+ print ("Entering NetLinearBase.configure_optimizers()" )
294305 optimizer = optimizer_handler (
295306 optimizer_name = self .hparams .optimizer , params = self .parameters (), lr_mult = self .hparams .lr_mult
296307 )
308+ print ("Leaving NetLinearBase.configure_optimizers()" )
297309 return optimizer
0 commit comments