Skip to content

Commit 90f0a37

Browse files
0.20.3
1 parent 987aad3 commit 90f0a37

2 files changed

Lines changed: 44 additions & 14 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.20.2"
10+
version = "0.20.3"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/hyperparameters/listgenerator.py

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,45 @@
11
class ListGenerator:
2+
"""
3+
Generates a list of hidden layer sizes based on the input/output layer sizes and specified network shape.
4+
5+
Args:
6+
hparams: An object containing network hyperparameters such as the layer sizes and the shape of the network.
7+
L_in (int): The size of the input layer.
8+
L_out (int): The size of the output layer.
9+
10+
Methods:
11+
_get_hidden_sizes() -> list:
12+
Generates and returns a list of hidden layer sizes based on the specified network shape (e.g., Funnel, Diamond, Hourglass, Wave, Block).
13+
14+
Attributes:
15+
hparams.nn_shape (str): The shape of the network. Options include "Funnel", "Diamond", "Hourglass", "Wave", "Block".
16+
hparams.l1 (int): The size of the first hidden layer.
17+
hparams.l_n (int): The total number of hidden layers.
18+
"""
19+
220
def __init__(self, hparams, L_in, L_out):
321
self.hparams = hparams
422
self._L_in = L_in
523
self._L_out = L_out
624

7-
def _get_hidden_sizes(self) -> list:
25+
def _get_hidden_sizes(self):
826
"""
9-
Generate the hidden layer sizes for the network based on nn_shape.
27+
Generate the hidden layer sizes for the network based on the specified shape.
1028
1129
Returns:
12-
list: A list of hidden layer sizes.
30+
list: A list of hidden layer sizes that defines the architecture of the neural network.
31+
32+
Raises:
33+
ValueError: If an unknown `nn_shape` is provided in the `hparams`.
1334
"""
14-
n_low = self._L_in // 4 # Minimum number of neurons
15-
# n_high = max(self.hparams.l1, 2 * n_low) # Maximum number of neurons
35+
36+
if self._L_in < 8:
37+
n_low = self._L_in # Minimum number of neurons
38+
elif self._L_in < 16:
39+
n_low = self._L_in // 2 # Minimum number of neurons
40+
else:
41+
n_low = self._L_in // 4 # Minimum number of neurons
42+
n_high = max(self.hparams.l1, 2 * n_low) # Maximum number of neurons
1643

1744
# TODO: Think about which way round is better
1845
if self.hparams.l_n > self.hparams.l1:
@@ -25,18 +52,21 @@ def _get_hidden_sizes(self) -> list:
2552

2653
elif self.hparams.nn_shape == "Diamond":
2754
mid_point = (self.hparams.l_n + 1) // 2
55+
upper_limit = self.hparams.l1 * 2
56+
step_size_up = (upper_limit - self.hparams.l1) // (mid_point - 1)
57+
2858
increasing_part = [self.hparams.l1]
2959
for _ in range(1, mid_point):
30-
next_size = int(increasing_part[-1] * 1.2)
31-
increasing_part.append(next_size)
60+
next_size = increasing_part[-1] + step_size_up
61+
increasing_part.append(min(upper_limit, next_size))
3262

3363
remaining_layers = self.hparams.l_n - mid_point
34-
step_size = (increasing_part[-1] - self._L_out) // (remaining_layers + 1)
64+
step_size_down = (increasing_part[-1] - self._L_out) // (remaining_layers + 1)
3565

3666
decreasing_part = []
3767
current_size = increasing_part[-1]
3868
for _ in range(remaining_layers):
39-
current_size = max(self._L_out, current_size - step_size)
69+
current_size = max(self._L_out, current_size - step_size_down)
4070
decreasing_part.append(current_size)
4171

4272
hidden_sizes = increasing_part + decreasing_part
@@ -53,7 +83,7 @@ def _get_hidden_sizes(self) -> list:
5383
increasing_part = [decreasing_part[-1] + step_size]
5484
for _ in range(mid_point, self.hparams.l_n - 2):
5585
next_size = increasing_part[-1] + step_size
56-
increasing_part.append(min(self.hparams.l1, next_size))
86+
increasing_part.append(min(n_high, next_size))
5787

5888
last_step_size = (increasing_part[-1] - self._L_out) // 2
5989
decreasing_to_output = max(self._L_out, increasing_part[-1] - last_step_size)
@@ -72,7 +102,7 @@ def _get_hidden_sizes(self) -> list:
72102
increasing_part_1 = [decreasing_part_1[-1] + step_size]
73103
for _ in range(half_wave, 2 * half_wave - 1):
74104
next_size = increasing_part_1[-1] + step_size
75-
increasing_part_1.append(next_size)
105+
increasing_part_1.append(min(n_high, next_size))
76106

77107
decreasing_part_2 = [increasing_part_1[-1] - step_size]
78108
for _ in range(2 * half_wave, 3 * half_wave - 1):
@@ -82,15 +112,15 @@ def _get_hidden_sizes(self) -> list:
82112
increasing_part_2 = [decreasing_part_2[-1] + step_size]
83113
for _ in range(3 * half_wave, self.hparams.l_n - 2):
84114
next_size = increasing_part_2[-1] + step_size
85-
increasing_part_2.append(next_size)
115+
increasing_part_2.append(min(n_high, next_size))
86116

87117
last_step_size = (increasing_part_2[-1] - self._L_out) // 2
88118
decreasing_to_output = max(self._L_out, increasing_part_2[-1] - last_step_size)
89119

90120
hidden_sizes = decreasing_part_1 + increasing_part_1 + decreasing_part_2 + increasing_part_2 + [decreasing_to_output]
91121

92122
elif self.hparams.nn_shape == "Block":
93-
hidden_sizes = [self.hparams.l1] * self.hparams.l_n
123+
hidden_sizes = [min(n_high, self.hparams.l1)] * self.hparams.l_n
94124

95125
else:
96126
raise ValueError(f"Unknown nn_shape: {self.hparams.nn_shape}")

0 commit comments

Comments
 (0)