Skip to content

Commit c607929

Browse files
0.18.13
1 parent 11c2ab9 commit c607929

4 files changed

Lines changed: 105 additions & 2 deletions

File tree

RELEASE_NOTES.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
spotpython 0.18.13:
2+
3+
- listgenerator.py:
4+
- New class ListGenerator
5+
16
spotpython 0.18.11:
27

38
- testmodel, predictmodel, and cvmodel functions updated, so that they can handle DataModules specified by the user in fun_control.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.18.12"
10+
version = "0.18.13"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
class ListGenerator:
    """Build the list of hidden-layer sizes for a network of a given shape.

    The generator reads ``l1``, ``l_n`` and ``nn_shape`` from ``hparams`` and
    combines them with the ``L_in`` and ``L_out`` values passed to the
    constructor.
    """

    def __init__(self, hparams, L_in, L_out):
        """Store the hyperparameters and the input/output sizes.

        Args:
            hparams: Object exposing the attributes ``l1``, ``l_n`` and
                ``nn_shape``. Its ``l1`` attribute may be overwritten by
                ``_get_hidden_sizes``.
            L_in: Presumably the input dimension of the network; a quarter
                of it (floor division) is the lower bound for layer sizes in
                the "Hourglass" and "Wave" shapes.
            L_out: Presumably the output dimension of the network; used as
                the lower bound / end point for shrinking layer sizes.
        """
        self.hparams = hparams
        self._L_in = L_in
        self._L_out = L_out

    def _get_hidden_sizes(self) -> list:
        """Generate the hidden layer sizes for the network based on nn_shape.

        Supported values of ``hparams.nn_shape`` are "Funnel", "Diamond",
        "Hourglass", "Wave" and "Block". Every shape starts with a layer of
        size ``hparams.l1``.

        Side effect: if ``hparams.l_n`` is larger than ``hparams.l1``,
        ``hparams.l1`` is overwritten with ``hparams.l_n`` before any sizes
        are computed.

        Returns:
            list: A list of hidden layer sizes.

        Raises:
            ValueError: If ``nn_shape`` is unknown, or if ``l_n`` is too small
                for the requested shape ("Funnel" with ``l_n == 0``,
                "Hourglass" with ``l_n`` in {2, 3}, "Wave" with ``l_n`` in
                4..7) so that no step size can be derived.
        """
        n_low = self._L_in // 4  # Minimum number of neurons

        # TODO: decide which direction of this adjustment is preferable.
        # For now the first layer is widened to l_n when more layers than
        # first-layer neurons are requested.
        if self.hparams.l_n > self.hparams.l1:
            self.hparams.l1 = self.hparams.l_n

        l1 = self.hparams.l1
        l_n = self.hparams.l_n
        l_out = self._L_out
        shape = self.hparams.nn_shape

        if shape == "Funnel":
            if l_n == 0:
                raise ValueError("nn_shape 'Funnel' requires l_n != 0")
            step_size = (l1 - l_out) // l_n
            if step_size == 0:
                # Fewer than l_n neurons lie between l1 and L_out, so the
                # floor division yields 0 and range() would reject the step.
                # Shrink by one neuron per layer instead; keep at least the
                # first layer when l1 already equals L_out.
                hidden_sizes = list(range(l1, l_out, -1)) or [l1]
            else:
                hidden_sizes = list(range(l1, l_out, -step_size))

        elif shape == "Diamond":
            mid_point = (l_n + 1) // 2
            # First half: grow each layer by 20 % (truncated to int).
            increasing_part = [l1]
            for _ in range(1, mid_point):
                increasing_part.append(int(increasing_part[-1] * 1.2))

            # Second half: shrink linearly towards L_out, never below it.
            remaining_layers = l_n - mid_point
            step_size = (increasing_part[-1] - l_out) // (remaining_layers + 1)

            decreasing_part = []
            current_size = increasing_part[-1]
            for _ in range(remaining_layers):
                current_size = max(l_out, current_size - step_size)
                decreasing_part.append(current_size)

            hidden_sizes = increasing_part + decreasing_part

        elif shape == "Hourglass":
            mid_point = l_n // 2
            if mid_point == 1:
                # mid_point - 1 would be 0 and the step size undefined.
                raise ValueError(
                    f"nn_shape 'Hourglass' requires l_n >= 4, got l_n={l_n}"
                )
            step_size = (l1 - n_low) // (mid_point - 1)

            # Shrink from l1 towards n_low ...
            decreasing_part = [l1]
            for _ in range(1, mid_point):
                decreasing_part.append(max(n_low, decreasing_part[-1] - step_size))

            # ... then grow back, capped at l1 ...
            increasing_part = [decreasing_part[-1] + step_size]
            for _ in range(mid_point, l_n - 2):
                increasing_part.append(min(l1, increasing_part[-1] + step_size))

            # ... and finish with one layer halfway towards L_out.
            last_step_size = (increasing_part[-1] - l_out) // 2
            decreasing_to_output = max(l_out, increasing_part[-1] - last_step_size)

            hidden_sizes = decreasing_part + increasing_part + [decreasing_to_output]

        elif shape == "Wave":
            half_wave = l_n // 4
            if half_wave == 1:
                # half_wave - 1 would be 0 and the step size undefined.
                raise ValueError(
                    f"nn_shape 'Wave' requires l_n >= 8, got l_n={l_n}"
                )
            step_size = (l1 - n_low) // (half_wave - 1)

            # First descent, bounded below by n_low.
            decreasing_part_1 = [l1]
            for _ in range(1, half_wave):
                decreasing_part_1.append(max(n_low, decreasing_part_1[-1] - step_size))

            # First ascent (not capped).
            increasing_part_1 = [decreasing_part_1[-1] + step_size]
            for _ in range(half_wave, 2 * half_wave - 1):
                increasing_part_1.append(increasing_part_1[-1] + step_size)

            # Second descent, bounded below by n_low.
            decreasing_part_2 = [increasing_part_1[-1] - step_size]
            for _ in range(2 * half_wave, 3 * half_wave - 1):
                decreasing_part_2.append(max(n_low, decreasing_part_2[-1] - step_size))

            # Second ascent (not capped).
            increasing_part_2 = [decreasing_part_2[-1] + step_size]
            for _ in range(3 * half_wave, l_n - 2):
                increasing_part_2.append(increasing_part_2[-1] + step_size)

            # Final layer halfway towards L_out.
            last_step_size = (increasing_part_2[-1] - l_out) // 2
            decreasing_to_output = max(l_out, increasing_part_2[-1] - last_step_size)

            hidden_sizes = (
                decreasing_part_1
                + increasing_part_1
                + decreasing_part_2
                + increasing_part_2
                + [decreasing_to_output]
            )

        elif shape == "Block":
            hidden_sizes = [l1] * l_n

        else:
            raise ValueError(f"Unknown nn_shape: {self.hparams.nn_shape}")

        return hidden_sizes

src/spotpython/light/testmodel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def test_model(config: dict, fun_control: dict) -> Tuple[float, float]:
107107
)
108108
# Pass the datamodule as arg to trainer.fit to override model hooks :)
109109
trainer.fit(model=model, datamodule=dm)
110-
110+
111111
# Load the last checkpoint
112112
test_result = trainer.test(datamodule=dm, ckpt_path="last")
113113
test_result = test_result[0]

0 commit comments

Comments
 (0)