Skip to content

Commit 68efa82

Browse files
0.10.52
Skip Linear uses Normal Dist Weights init
1 parent 299e322 commit 68efa82

3 files changed

Lines changed: 72 additions & 17 deletions

File tree

notebooks/00_spotPython_tests.ipynb

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2852,6 +2852,65 @@
28522852
" result = result.reshape(-1, self.n_out)\n",
28532853
" return result\n"
28542854
]
2855+
},
2856+
{
2857+
"cell_type": "code",
2858+
"execution_count": null,
2859+
"metadata": {},
2860+
"outputs": [],
2861+
"source": [
2862+
"class SkipLinear(torch.nn.Module):\n",
2863+
"\n",
2864+
" class Core(torch.nn.Module):\n",
2865+
" \"\"\"A simple linear layer with n outputs.\"\"\"\n",
2866+
"\n",
2867+
" def __init__(self, n):\n",
2868+
" \"\"\"\n",
2869+
" Initialize the layer.\n",
2870+
"\n",
2871+
" Args:\n",
2872+
" n (int): The number of output nodes.\n",
2873+
" \"\"\"\n",
2874+
" super().__init__()\n",
2875+
" self.weights = torch.nn.Parameter(torch.zeros((n, 1), dtype=torch.float32))\n",
2876+
" self.biases = torch.nn.Parameter(torch.zeros(n, dtype=torch.float32))\n",
2877+
" lim = 0.01\n",
2878+
" torch.nn.init.uniform_(self.weights, -lim, lim)\n",
2879+
"\n",
2880+
" def forward(self, x) -> torch.Tensor:\n",
2881+
" \"\"\"\n",
2882+
" Forward pass through the layer.\n",
2883+
"\n",
2884+
" Args:\n",
2885+
" x (torch.Tensor): The input tensor.\n",
2886+
"\n",
2887+
" Returns:\n",
2888+
" torch.Tensor: The output of the layer.\n",
2889+
" \"\"\"\n",
2890+
" return x @ self.weights.t() + self.biases\n",
2891+
"\n",
2892+
" def __init__(self, n_in, n_out):\n",
2893+
" super().__init__()\n",
2894+
" self.n_in = n_in\n",
2895+
" self.n_out = n_out\n",
2896+
" if n_out % n_in != 0:\n",
2897+
" raise ValueError(\"n_out % n_in != 0\")\n",
2898+
" n = n_out // n_in # num nodes per input\n",
2899+
"\n",
2900+
" self.lst_modules = torch.nn.ModuleList([SkipLinear.Core(n) for i in range(n_in)])\n",
2901+
"\n",
2902+
" def forward(self, x):\n",
2903+
" lst_nodes = []\n",
2904+
" for i in range(self.n_in):\n",
2905+
" xi = x[:, i].reshape(-1, 1)\n",
2906+
" oupt = self.lst_modules[i](xi)\n",
2907+
" lst_nodes.append(oupt)\n",
2908+
" result = torch.cat((lst_nodes[0], lst_nodes[1]), 1)\n",
2909+
" for i in range(2, self.n_in):\n",
2910+
" result = torch.cat((result, lst_nodes[i]), 1)\n",
2911+
" result = result.reshape(-1, self.n_out)\n",
2912+
" return result"
2913+
]
28552914
}
28562915
],
28572916
"metadata": {

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.10.53"
10+
version = "0.10.54"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/light/transformer/skiplinear.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ def __init__(self, n):
6161
n (int): The number of output nodes.
6262
"""
6363
super().__init__()
64-
self.weights = torch.nn.Parameter(torch.zeros((n, 1), dtype=torch.float32))
65-
self.biases = torch.nn.Parameter(torch.zeros(n, dtype=torch.float32))
66-
lim = 0.01
67-
torch.nn.init.uniform_(self.weights, -lim, lim)
64+
# initialize with random weights using normal distribution
65+
self.weights = torch.nn.Parameter(torch.randn(1, n))
66+
# self.weights = torch.nn.Parameter(torch.rand(1, n) * 2 - 1)
67+
self.linear = torch.nn.Linear(1, n)
6868

6969
def forward(self, x) -> torch.Tensor:
7070
"""
@@ -76,7 +76,7 @@ def forward(self, x) -> torch.Tensor:
7676
Returns:
7777
torch.Tensor: The output of the layer.
7878
"""
79-
return x @ self.weights.t() + self.biases
79+
return self.linear(x)
8080

8181
def __init__(self, n_in, n_out):
8282
super().__init__()
@@ -86,16 +86,12 @@ def __init__(self, n_in, n_out):
8686
raise ValueError("n_out % n_in != 0")
8787
n = n_out // n_in # num nodes per input
8888

89-
self.lst_modules = torch.nn.ModuleList([SkipLinear.Core(n) for i in range(n_in)])
89+
self.lst_modules = torch.nn.ModuleList([SkipLinear.Core(n) for _ in range(n_in)])
9090

9191
def forward(self, x):
92-
lst_nodes = []
93-
for i in range(self.n_in):
94-
xi = x[:, i].reshape(-1, 1)
95-
oupt = self.lst_modules[i](xi)
96-
lst_nodes.append(oupt)
97-
result = torch.cat((lst_nodes[0], lst_nodes[1]), 1)
98-
for i in range(2, self.n_in):
99-
result = torch.cat((result, lst_nodes[i]), 1)
100-
result = result.reshape(-1, self.n_out)
101-
return result
92+
# We want to apply each module to a slice of the input tensor x and collect the outputs.
93+
# This applies the i-th module to the i-th column of x, reshaped as a column vector.
94+
# The result is a list of output tensors, which are then concatenated to form the final output.
95+
lst_nodes = [self.lst_modules[i](x[:, i].unsqueeze(1)) for i in range(self.n_in)]
96+
result = torch.cat(lst_nodes, dim=1)
97+
return result.reshape(-1, self.n_out)

0 commit comments

Comments
 (0)