Skip to content

Commit 861b413

Browse files
committed
Add optional batch normalization and an optional learning-rate scheduler to the funnel regression model
1 parent 94e08b9 commit 861b413

2 files changed

Lines changed: 53 additions & 5 deletions

File tree

src/spotpython/hyperdict/light_hyper_dict.json

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -794,6 +794,30 @@
794794
"core_model_parameter_type": "str",
795795
"lower": 0,
796796
"upper": 2
797+
},
798+
"batch_norm": {
799+
"levels": [
800+
0,
801+
1
802+
],
803+
"type": "factor",
804+
"default": 0,
805+
"transform": "None",
806+
"core_model_parameter_type": "bool",
807+
"lower": 0,
808+
"upper": 1
809+
},
810+
"lr_sched": {
811+
"levels": [
812+
0,
813+
1
814+
],
815+
"type": "factor",
816+
"default": 0,
817+
"transform": "None",
818+
"core_model_parameter_type": "bool",
819+
"lower": 0,
820+
"upper": 1
797821
}
798822
},
799823
"NNLinearRegressor": {

src/spotpython/light/regression/nn_funnel_regressor.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from torch import nn
44
from spotpython.hyperparameters.optimizer import optimizer_handler
55
import torchmetrics.functional.regression
6+
import torch.optim as optim
67

78

89
class NNFunnelRegressor(L.LightningModule):
@@ -117,10 +118,15 @@ def __init__(
117118

118119
for i in range(self.hparams.num_layers):
119120
out_features = max(hidden_size // 2, 8) # Enforce minimum of 8 units
120-
layers += [
121-
nn.Linear(in_features, hidden_size),
122-
self.hparams.act_fn,
123-
nn.Dropout(self.hparams.dropout_prob),]
121+
122+
layers.append(nn.Linear(in_features, hidden_size))
123+
124+
if self.hparams.batch_norm:
125+
layers.append(nn.BatchNorm1d(hidden_size)) # Add BatchNorm if enabled
126+
127+
layers.append(self.hparams.act_fn)
128+
layers.append(nn.Dropout(self.hparams.dropout_prob))
129+
124130
in_features = hidden_size
125131
hidden_size = out_features
126132

@@ -258,4 +264,22 @@ def configure_optimizers(self) -> torch.optim.Optimizer:
258264
"""
259265
# optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
260266
optimizer = optimizer_handler(optimizer_name=self.hparams.optimizer, params=self.parameters(), lr_mult=self.hparams.lr_mult)
261-
return optimizer
267+
268+
# If the lr_sched hyperparameter is enabled, attach a MultiStepLR scheduler that multiplies the learning rate by 0.1 at three evenly spaced epoch milestones.
269+
if self.hparams.lr_sched:
270+
num_milestones = 3 # Number of milestones to divide the epochs
271+
milestones = [int(self.hparams.epochs / (num_milestones + 1) * (i + 1)) for i in range(num_milestones)]
272+
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1) # Decay factor
273+
274+
lr_scheduler_config = {
275+
"scheduler": scheduler,
276+
"interval": "epoch",
277+
"frequency": 1,
278+
}
279+
return {
280+
"optimizer": optimizer,
281+
"lr_scheduler": lr_scheduler_config,
282+
}
283+
# If the lr_sched hyperparameter is disabled, return only the optimizer (no scheduler).
284+
else:
285+
return optimizer

0 commit comments

Comments
 (0)