Skip to content

Commit e8d5a12

Browse files
v0.2.49
1 parent 256491d commit e8d5a12

5 files changed

Lines changed: 80 additions & 27 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.2.48"
10+
version = "0.2.49"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/data/light_hyper_dict.json

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,32 @@
2020
"lower": 1,
2121
"upper": 4},
2222
"act_fn": {
23-
"levels": ["Sigmoid", "Tanh", "ReLU", "LeakyReLU", "ELU", "Swish"],
23+
"levels": ["Sigmoid",
24+
"Tanh",
25+
"ReLU",
26+
"LeakyReLU",
27+
"ELU",
28+
"Swish"],
2429
"type": "factor",
2530
"default": "ReLU",
2631
"transform": "None",
27-
"class_name": "spotPython.torch.activation",
32+
"class_name": "spotPython.torch.activation",
2833
"core_model_parameter_type": "instance()",
2934
"lower": 0,
3035
"upper": 2},
3136
"optimizer": {
32-
"levels": ["Adadelta", "Adagrad", "Adam", "AdamW", "SparseAdam", "Adamax", "ASGD", "NAdam", "RAdam", "RMSprop", "Rprop", "SGD"],
37+
"levels": ["Adadelta",
38+
"Adagrad",
39+
"Adam",
40+
"AdamW",
41+
"SparseAdam",
42+
"Adamax",
43+
"ASGD",
44+
"NAdam",
45+
"RAdam",
46+
"RMSprop",
47+
"Rprop",
48+
"SGD"],
3349
"type": "factor",
3450
"default": "SGD",
3551
"transform": "None",
@@ -48,7 +64,14 @@
4864
"default": 1.0,
4965
"transform": "None",
5066
"lower": 0.1,
51-
"upper": 10.0}
67+
"upper": 10.0},
68+
"patience": {
69+
"type": "int",
70+
"default": 2,
71+
"transform": "transform_power_2_int",
72+
"lower": 2,
73+
"upper": 6
74+
}
5275
},
5376
"LitModel":
5477
{
@@ -87,6 +110,6 @@
87110
"class_name": "torch.optim",
88111
"core_model_parameter_type": "str",
89112
"lower": 0,
90-
"upper": 12}
113+
"upper": 11}
91114
}
92115
}

src/spotPython/light/netlightbase.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,18 @@
88

99

1010
class NetLightBase(L.LightningModule):
11-
def __init__(self, l1, epochs, batch_size, act_fn, optimizer, dropout_prob, lr_mult, _L_in=64, _L_out=11):
11+
def __init__(
12+
self, l1, epochs, batch_size, act_fn, optimizer, dropout_prob, lr_mult, patience=3, _L_in=64, _L_out=11
13+
):
1214
super().__init__()
13-
14-
# We take in input dimensions as parameters and use those to dynamically build model.
15+
self.save_hyperparameters()
1516
self._L_out = _L_out
1617
if l1 < 4:
1718
raise ValueError("l1 must be at least 4")
1819
self.l1 = l1
1920
hidden_sizes = [l1, l1 // 2, l1 // 2, l1 // 4]
2021
self.epochs = epochs
22+
self.patience = patience
2123
self.batch_size = batch_size
2224
self.act_fn = act_fn
2325
self.optimizer = optimizer
@@ -63,6 +65,7 @@ def validation_step(self, batch, batch_idx, prog_bar=False):
6365
self.log("valid_mapk", self.valid_mapk, on_step=False, on_epoch=True, prog_bar=prog_bar)
6466
self.log("val_loss", loss, prog_bar=prog_bar)
6567
self.log("val_acc", acc, prog_bar=prog_bar)
68+
self.log("hp_metric", loss)
6669

6770
def test_step(self, batch, batch_idx, prog_bar=False):
6871
x, y = batch
@@ -75,6 +78,7 @@ def test_step(self, batch, batch_idx, prog_bar=False):
7578
self.log("test_mapk", self.test_mapk, on_step=True, on_epoch=True, prog_bar=prog_bar)
7679
self.log("val_loss", loss, prog_bar=prog_bar)
7780
self.log("val_acc", acc, prog_bar=prog_bar)
81+
# self.log("hp_metric", loss, prog_bar=prog_bar)
7882
return loss, acc
7983

8084
def configure_optimizers(self):

src/spotPython/light/traintest.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77

88

99
def train_model(config, fun_control):
10+
if fun_control["enable_progress_bar"] is None:
11+
enable_progress_bar = False
12+
else:
13+
enable_progress_bar = fun_control["enable_progress_bar"]
1014
config_id = generate_config_id(config)
1115
# Init DataModule
1216
dm = CSVDataModule(
@@ -21,8 +25,11 @@ def train_model(config, fun_control):
2125
max_epochs=model.epochs,
2226
accelerator="auto",
2327
devices=1,
24-
logger=TensorBoardLogger(save_dir=fun_control["tensorboard_path"], version=config_id),
25-
callbacks=[EarlyStopping(monitor="val_loss", patience=3, mode="min", strict=False, verbose=False)],
28+
logger=TensorBoardLogger(save_dir=fun_control["tensorboard_path"], version=config_id, default_hp_metric=True),
29+
callbacks=[
30+
EarlyStopping(monitor="val_loss", patience=config["patience"], mode="min", strict=False, verbose=False)
31+
],
32+
enable_progress_bar=enable_progress_bar,
2633
)
2734
# Pass the datamodule as arg to trainer.fit to override model hooks :)
2835
trainer.fit(model=model, datamodule=dm)
@@ -36,6 +43,10 @@ def train_model(config, fun_control):
3643

3744

3845
def test_model(config, fun_control):
46+
if fun_control["enable_progress_bar"] is None:
47+
enable_progress_bar = False
48+
else:
49+
enable_progress_bar = fun_control["enable_progress_bar"]
3950
config_id = generate_config_id(config)
4051
# Init DataModule
4152
dm = CSVDataModule(
@@ -50,7 +61,11 @@ def test_model(config, fun_control):
5061
max_epochs=model.epochs,
5162
accelerator="auto",
5263
devices=1,
53-
logger=TensorBoardLogger(save_dir=fun_control["tensorboard_path"], version=config_id),
64+
logger=TensorBoardLogger(save_dir=fun_control["tensorboard_path"], version=config_id, default_hp_metric=True),
65+
callbacks=[
66+
EarlyStopping(monitor="val_loss", patience=config["patience"], mode="min", strict=False, verbose=False)
67+
],
68+
enable_progress_bar=enable_progress_bar,
5469
)
5570
# Pass the datamodule as arg to trainer.fit to override model hooks :)
5671
trainer.fit(model=model, datamodule=dm)
@@ -61,7 +76,11 @@ def test_model(config, fun_control):
6176

6277

6378
def cv_model(config, fun_control):
64-
# config_id = generate_config_id(config)
79+
config_id = generate_config_id(config)
80+
if fun_control["enable_progress_bar"] is None:
81+
enable_progress_bar = False
82+
else:
83+
enable_progress_bar = fun_control["enable_progress_bar"]
6584
results = []
6685
num_folds = 10
6786
split_seed = 12345
@@ -87,7 +106,13 @@ def cv_model(config, fun_control):
87106
max_epochs=model.epochs,
88107
accelerator="auto",
89108
devices=1,
90-
# logger=TensorBoardLogger(save_dir=fun_control["tensorboard_path"], version=config_id),
109+
logger=TensorBoardLogger(
110+
save_dir=fun_control["tensorboard_path"], version=config_id, default_hp_metric=True
111+
),
112+
callbacks=[
113+
EarlyStopping(monitor="val_loss", patience=config["patience"], mode="min", strict=False, verbose=False)
114+
],
115+
enable_progress_bar=enable_progress_bar,
91116
)
92117
# Pass the datamodule as arg to trainer.fit to override model hooks :)
93118
trainer.fit(model=model, datamodule=dm)

src/spotPython/utils/init.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch.utils.tensorboard import SummaryWriter
55

66

7-
def fun_control_init(task, tensorboard_path=None, num_workers=0, device=None):
7+
def fun_control_init(task, enable_progress_bar=False, tensorboard_path=None, num_workers=0, device=None):
88
"""Initialize fun_control dictionary.
99
Args:
1010
None
@@ -52,27 +52,28 @@ def fun_control_init(task, tensorboard_path=None, num_workers=0, device=None):
5252
fun_control = {
5353
"data": None,
5454
"data_dir": "./data",
55-
"train": None,
56-
"test": None,
55+
"device": device,
56+
"enable_progress_bar": enable_progress_bar,
57+
"eval": None,
58+
"k_folds": None,
5759
"loss_function": None,
58-
"metric_sklearn": None,
5960
"metric_river": None,
61+
"metric_sklearn": None,
6062
"metric_torch": None,
6163
"metric_params": {},
62-
"num_workers": num_workers,
63-
"prep_model": None,
6464
"n_samples": None,
65-
"target_column": None,
66-
"shuffle": None,
67-
"eval": None,
68-
"k_folds": None,
65+
"num_workers": num_workers,
6966
"optimizer": None,
70-
"device": device,
71-
"show_batch_interval": 1_000_000,
7267
"path": None,
68+
"prep_model": None,
69+
"save_model": False,
70+
"show_batch_interval": 1_000_000,
71+
"shuffle": None,
72+
"target_column": None,
73+
"train": None,
74+
"test": None,
7375
"task": task,
7476
"tensorboard_path": tensorboard_path,
75-
"save_model": False,
7677
"weights": 1.0,
7778
"writer": writer,
7879
}

0 commit comments

Comments
 (0)