Skip to content

Commit df97531

Browse files
optimizer act_fn
1 parent 7915bf1 commit df97531

2 files changed

Lines changed: 16 additions & 10 deletions

File tree

src/spotPython/data/light_hyper_dict.json

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@
2020
"lower": 1,
2121
"upper": 4},
2222
"act_fn": {
23-
"levels": ["ReLU"],
23+
"levels": ["Sigmoid", "Tanh", "ReLU", "LeakyReLU", "ELU", "Swish"],
2424
"type": "factor",
2525
"default": "ReLU",
2626
"transform": "None",
27-
"class_name": "torch.nn",
27+
"class_name": "spotPython.torch.activation",
2828
"core_model_parameter_type": "instance()",
2929
"lower": 0,
30-
"upper": 0},
30+
"upper": 2},
3131
"optimizer": {
3232
"levels": ["Adadelta", "Adagrad", "Adam", "AdamW", "SparseAdam", "Adamax", "ASGD", "NAdam", "RAdam", "RMSprop", "Rprop", "SGD"],
3333
"type": "factor",
@@ -36,13 +36,19 @@
3636
"class_name": "torch.optim",
3737
"core_model_parameter_type": "str",
3838
"lower": 0,
39-
"upper": 12},
39+
"upper": 11},
4040
"dropout_prob": {
4141
"type": "float",
4242
"default": 0.01,
4343
"transform": "None",
4444
"lower": 0.0,
45-
"upper": 0.1}
45+
"upper": 0.1},
46+
"lr_mult": {
47+
"type": "float",
48+
"default": 1.0,
49+
"transform": "None",
50+
"lower": 0.1,
51+
"upper": 10.0}
4652
},
4753
"LitModel":
4854
{

src/spotPython/light/netlightbase.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44
from torch import nn
55
from torchmetrics.functional import accuracy
66
from spotPython.torch.mapk import MAPK
7+
from spotPython.hyperparameters.optimizer import optimizer_handler
78

89

910
class NetLightBase(L.LightningModule):
10-
def __init__(
11-
self, l1, epochs, batch_size, act_fn, optimizer, dropout_prob, learning_rate=2e-4, _L_in=64, _L_out=11
12-
):
11+
def __init__(self, l1, epochs, batch_size, act_fn, optimizer, dropout_prob, lr_mult, _L_in=64, _L_out=11):
1312
super().__init__()
1413

1514
# We take in input dimensions as parameters and use those to dynamically build model.
@@ -23,7 +22,7 @@ def __init__(
2322
self.act_fn = act_fn
2423
self.optimizer = optimizer
2524
self.dropout_prob = dropout_prob
26-
self.learning_rate = learning_rate
25+
self.lr_mult = lr_mult
2726
self.train_mapk = MAPK(k=3)
2827
self.valid_mapk = MAPK(k=3)
2928
self.test_mapk = MAPK(k=3)
@@ -79,5 +78,6 @@ def test_step(self, batch, batch_idx, prog_bar=False):
7978
return loss, acc
8079

8180
def configure_optimizers(self):
82-
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
81+
# optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
82+
optimizer = optimizer_handler(optimizer_name=self.optimizer, params=self.parameters(), lr_mult=self.lr_mult)
8383
return optimizer

0 commit comments

Comments
 (0)