11class ListGenerator :
2+ """
3+ Generates a list of hidden layer sizes based on the input/output layer sizes and specified network shape.
4+
5+ Args:
6+ hparams: An object containing network hyperparameters such as the layer sizes and the shape of the network.
7+ L_in (int): The size of the input layer.
8+ L_out (int): The size of the output layer.
9+
10+ Methods:
11+ _get_hidden_sizes() -> list:
12+ Generates and returns a list of hidden layer sizes based on the specified network shape (e.g., Funnel, Diamond, Hourglass, Wave, Block).
13+
14+ Attributes:
15+ hparams.nn_shape (str): The shape of the network. Options include "Funnel", "Diamond", "Hourglass", "Wave", "Block".
16+ hparams.l1 (int): The size of the first hidden layer.
17+ hparams.l_n (int): The total number of hidden layers.
18+ """
19+
220 def __init__ (self , hparams , L_in , L_out ):
321 self .hparams = hparams
422 self ._L_in = L_in
523 self ._L_out = L_out
624
7- def _get_hidden_sizes (self ) -> list :
25+ def _get_hidden_sizes (self ):
826 """
9- Generate the hidden layer sizes for the network based on nn_shape .
27+ Generate the hidden layer sizes for the network based on the specified shape .
1028
1129 Returns:
12- list: A list of hidden layer sizes.
30+ list: A list of hidden layer sizes that defines the architecture of the neural network.
31+
32+ Raises:
33+ ValueError: If an unknown `nn_shape` is provided in the `hparams`.
1334 """
14- n_low = self ._L_in // 4 # Minimum number of neurons
15- # n_high = max(self.hparams.l1, 2 * n_low) # Maximum number of neurons
35+
36+ if self ._L_in < 8 :
37+ n_low = self ._L_in # Minimum number of neurons
38+ elif self ._L_in < 16 :
39+ n_low = self ._L_in // 2 # Minimum number of neurons
40+ else :
41+ n_low = self ._L_in // 4 # Minimum number of neurons
42+ n_high = max (self .hparams .l1 , 2 * n_low ) # Maximum number of neurons
1643
1744 # TODO: Überlegen, wie rum es besser ist
1845 if self .hparams .l_n > self .hparams .l1 :
@@ -25,18 +52,21 @@ def _get_hidden_sizes(self) -> list:
2552
2653 elif self .hparams .nn_shape == "Diamond" :
2754 mid_point = (self .hparams .l_n + 1 ) // 2
55+ upper_limit = self .hparams .l1 * 2
56+ step_size_up = (upper_limit - self .hparams .l1 ) // (mid_point - 1 )
57+
2858 increasing_part = [self .hparams .l1 ]
2959 for _ in range (1 , mid_point ):
30- next_size = int ( increasing_part [- 1 ] * 1.2 )
31- increasing_part .append (next_size )
60+ next_size = increasing_part [- 1 ] + step_size_up
61+ increasing_part .append (min ( upper_limit , next_size ) )
3262
3363 remaining_layers = self .hparams .l_n - mid_point
34- step_size = (increasing_part [- 1 ] - self ._L_out ) // (remaining_layers + 1 )
64+ step_size_down = (increasing_part [- 1 ] - self ._L_out ) // (remaining_layers + 1 )
3565
3666 decreasing_part = []
3767 current_size = increasing_part [- 1 ]
3868 for _ in range (remaining_layers ):
39- current_size = max (self ._L_out , current_size - step_size )
69+ current_size = max (self ._L_out , current_size - step_size_down )
4070 decreasing_part .append (current_size )
4171
4272 hidden_sizes = increasing_part + decreasing_part
@@ -53,7 +83,7 @@ def _get_hidden_sizes(self) -> list:
5383 increasing_part = [decreasing_part [- 1 ] + step_size ]
5484 for _ in range (mid_point , self .hparams .l_n - 2 ):
5585 next_size = increasing_part [- 1 ] + step_size
56- increasing_part .append (min (self . hparams . l1 , next_size ))
86+ increasing_part .append (min (n_high , next_size ))
5787
5888 last_step_size = (increasing_part [- 1 ] - self ._L_out ) // 2
5989 decreasing_to_output = max (self ._L_out , increasing_part [- 1 ] - last_step_size )
@@ -72,7 +102,7 @@ def _get_hidden_sizes(self) -> list:
72102 increasing_part_1 = [decreasing_part_1 [- 1 ] + step_size ]
73103 for _ in range (half_wave , 2 * half_wave - 1 ):
74104 next_size = increasing_part_1 [- 1 ] + step_size
75- increasing_part_1 .append (next_size )
105+ increasing_part_1 .append (min ( n_high , next_size ) )
76106
77107 decreasing_part_2 = [increasing_part_1 [- 1 ] - step_size ]
78108 for _ in range (2 * half_wave , 3 * half_wave - 1 ):
@@ -82,15 +112,15 @@ def _get_hidden_sizes(self) -> list:
82112 increasing_part_2 = [decreasing_part_2 [- 1 ] + step_size ]
83113 for _ in range (3 * half_wave , self .hparams .l_n - 2 ):
84114 next_size = increasing_part_2 [- 1 ] + step_size
85- increasing_part_2 .append (next_size )
115+ increasing_part_2 .append (min ( n_high , next_size ) )
86116
87117 last_step_size = (increasing_part_2 [- 1 ] - self ._L_out ) // 2
88118 decreasing_to_output = max (self ._L_out , increasing_part_2 [- 1 ] - last_step_size )
89119
90120 hidden_sizes = decreasing_part_1 + increasing_part_1 + decreasing_part_2 + increasing_part_2 + [decreasing_to_output ]
91121
92122 elif self .hparams .nn_shape == "Block" :
93- hidden_sizes = [self .hparams .l1 ] * self .hparams .l_n
123+ hidden_sizes = [min ( n_high , self .hparams .l1 ) ] * self .hparams .l_n
94124
95125 else :
96126 raise ValueError (f"Unknown nn_shape: { self .hparams .nn_shape } " )
0 commit comments