|
3 | 3 | from torch import nn |
4 | 4 | from spotpython.hyperparameters.optimizer import optimizer_handler |
5 | 5 | import torchmetrics.functional.regression |
| 6 | +import torch.optim as optim |
6 | 7 |
|
7 | 8 |
|
8 | 9 | class NNFunnelRegressor(L.LightningModule): |
@@ -117,10 +118,15 @@ def __init__( |
117 | 118 |
|
118 | 119 | for i in range(self.hparams.num_layers): |
119 | 120 | out_features = max(hidden_size // 2, 8) # Enforce minimum of 8 units |
120 | | - layers += [ |
121 | | - nn.Linear(in_features, hidden_size), |
122 | | - self.hparams.act_fn, |
123 | | - nn.Dropout(self.hparams.dropout_prob),] |
| 121 | + |
| 122 | + layers.append(nn.Linear(in_features, hidden_size)) |
| 123 | + |
| 124 | + if self.hparams.batch_norm: |
| 125 | + layers.append(nn.BatchNorm1d(hidden_size)) # Add BatchNorm if enabled |
| 126 | + |
| 127 | + layers.append(self.hparams.act_fn) |
| 128 | + layers.append(nn.Dropout(self.hparams.dropout_prob)) |
| 129 | + |
124 | 130 | in_features = hidden_size |
125 | 131 | hidden_size = out_features |
126 | 132 |
|
@@ -258,4 +264,22 @@ def configure_optimizers(self) -> torch.optim.Optimizer: |
258 | 264 | """ |
259 | 265 | # optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) |
260 | 266 | optimizer = optimizer_handler(optimizer_name=self.hparams.optimizer, params=self.parameters(), lr_mult=self.hparams.lr_mult) |
261 | | - return optimizer |
| 267 | + |
| 268 | + # If the lr_sched hyperparameter is set to True, we will use a learning rate scheduler. |
| 269 | + if self.hparams.lr_sched: |
| 270 | + num_milestones = 3 # Number of milestones to divide the epochs |
| 271 | + milestones = [int(self.hparams.epochs / (num_milestones + 1) * (i + 1)) for i in range(num_milestones)] |
| 272 | + scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1) # Decay factor |
| 273 | + |
| 274 | + lr_scheduler_config = { |
| 275 | + "scheduler": scheduler, |
| 276 | + "interval": "epoch", |
| 277 | + "frequency": 1, |
| 278 | + } |
| 279 | + return { |
| 280 | + "optimizer": optimizer, |
| 281 | + "lr_scheduler": lr_scheduler_config, |
| 282 | + } |
| 283 | + # If the lr_sched hyperparameter is not set to True, we return the optimizer only. |
| 284 | + else: |
| 285 | + return optimizer |
0 commit comments