Skip to content

Commit 25d5335

Browse files
v0.2.50
initialization
1 parent f69ec00 commit 25d5335

6 files changed

Lines changed: 133 additions & 7 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.2.49"
10+
version = "0.2.50"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/data/light_hyper_dict.json

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
"class_name": "spotPython.torch.activation",
3333
"core_model_parameter_type": "instance()",
3434
"lower": 0,
35-
"upper": 2},
35+
"upper": 5},
3636
"optimizer": {
3737
"levels": ["Adadelta",
3838
"Adagrad",
@@ -71,7 +71,15 @@
7171
"transform": "transform_power_2_int",
7272
"lower": 2,
7373
"upper": 6
74-
}
74+
},
75+
"initialization": {
76+
"levels": ["Default", "Kaiming", "Xavier"],
77+
"type": "factor",
78+
"default": "Default",
79+
"transform": "None",
80+
"core_model_parameter_type": "str",
81+
"lower": 0,
82+
"upper": 2}
7583
},
7684
"LitModel":
7785
{

src/spotPython/light/netlightbase.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,24 @@
99

1010
class NetLightBase(L.LightningModule):
1111
def __init__(
12-
self, l1, epochs, batch_size, act_fn, optimizer, dropout_prob, lr_mult, patience=3, _L_in=64, _L_out=11
12+
self,
13+
l1,
14+
epochs,
15+
batch_size,
16+
initialization,
17+
act_fn,
18+
optimizer,
19+
dropout_prob,
20+
lr_mult,
21+
patience=3,
22+
_L_in=64,
23+
_L_out=11,
1324
):
1425
super().__init__()
15-
self.save_hyperparameters()
26+
27+
# Attribute 'act_fn' is an instance of `nn.Module` and is already saved during checkpointing.
28+
# It is recommended to ignore them using `self.save_hyperparameters(ignore=['act_fn'])`
29+
self.save_hyperparameters(ignore=["act_fn"])
1630
self._L_out = _L_out
1731
if l1 < 4:
1832
raise ValueError("l1 must be at least 4")
@@ -21,6 +35,7 @@ def __init__(
2135
self.epochs = epochs
2236
self.patience = patience
2337
self.batch_size = batch_size
38+
self.initialization = initialization
2439
self.act_fn = act_fn
2540
self.optimizer = optimizer
2641
self.dropout_prob = dropout_prob
@@ -38,10 +53,10 @@ def __init__(
3853
layer_size_last = layer_size
3954
layers += [nn.Linear(layer_sizes[-1], self._L_out)]
4055
# nn.Sequential summarizes a list of modules into a single module, applying them in sequence
41-
self.model = nn.Sequential(*layers)
56+
self.layers = nn.Sequential(*layers)
4257

4358
def forward(self, x):
44-
x = self.model(x)
59+
x = self.layers(x)
4560
return F.softmax(x, dim=1)
4661

4762
def training_step(self, batch):

src/spotPython/light/traintest.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from spotPython.utils.eda import generate_config_id
55
from pytorch_lightning.loggers import TensorBoardLogger
66
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
7+
from spotPython.torch.initialization import kaiming_init, xavier_init
78

89

910
def train_model(config, fun_control):
@@ -19,6 +20,13 @@ def train_model(config, fun_control):
1920
# Init model from datamodule's attributes
2021
# model = LitModel(*dm.dims, dm.num_classes)
2122
model = fun_control["core_model"](**config, _L_in=64, _L_out=11)
23+
initialization = config["initialization"]
24+
if initialization == "Xavier":
25+
xavier_init(model)
26+
elif initialization == "Kaiming":
27+
kaiming_init(model)
28+
else:
29+
pass
2230
print(f"model: {model}")
2331
# Init trainer
2432
trainer = L.Trainer(
@@ -55,6 +63,13 @@ def test_model(config, fun_control):
5563
# Init model from datamodule's attributes
5664
# model = LitModel(*dm.dims, dm.num_classes)
5765
model = fun_control["core_model"](**config, _L_in=64, _L_out=11)
66+
initialization = config["initialization"]
67+
if initialization == "Xavier":
68+
xavier_init(model)
69+
elif initialization == "Kaiming":
70+
kaiming_init(model)
71+
else:
72+
pass
5873
print(f"model: {model}")
5974
# Init trainer
6075
trainer = L.Trainer(
@@ -85,6 +100,13 @@ def cv_model(config, fun_control):
85100
num_folds = 10
86101
split_seed = 12345
87102
model = fun_control["core_model"](**config, _L_in=64, _L_out=11)
103+
initialization = config["initialization"]
104+
if initialization == "Xavier":
105+
xavier_init(model)
106+
elif initialization == "Kaiming":
107+
kaiming_init(model)
108+
else:
109+
pass
88110
print(f"model: {model}")
89111

90112
for k in range(num_folds):
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import math
2+
3+
4+
def kaiming_init(model):
    """Initialize the parameters of ``model`` in place, Kaiming style.

    Every parameter whose name ends with ``.bias`` is filled with zeros.
    Every other parameter with at least two dimensions is drawn from a
    zero-mean normal distribution with standard deviation

    * ``1 / sqrt(fan_in)`` if its name starts with ``layers.0``, and
    * ``sqrt(2) / sqrt(fan_in)`` otherwise,

    where ``fan_in`` is ``param.shape[1]``.

    Non-bias parameters with fewer than two dimensions (for example the
    scale vector of a normalization layer) have no ``shape[1]``; they are
    left untouched instead of raising an ``IndexError``.

    Args:
        model: An object providing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs, e.g. a ``torch.nn.Module``.

    Returns:
        None. The parameters are modified in place.
    """
    for name, param in model.named_parameters():
        if name.endswith(".bias"):
            param.data.fill_(0)
            continue
        if param.dim() < 2:
            # No fan-in can be read from a 0-D/1-D tensor; keep its current values.
            continue
        fan_in = param.shape[1]
        # The first layer does not have ReLU applied on its input,
        # so it does not get the sqrt(2) gain.
        gain = 1 if name.startswith("layers.0") else math.sqrt(2)
        param.data.normal_(0, gain / math.sqrt(fan_in))
12+
13+
14+
def xavier_init(model):
    """Initialize the parameters of ``model`` in place, Xavier (Glorot) style.

    Every parameter whose name ends with ``.bias`` is filled with zeros.
    Every other parameter with at least two dimensions is drawn from a
    uniform distribution on ``[-bound, bound]`` with
    ``bound = sqrt(6) / sqrt(param.shape[0] + param.shape[1])``.

    Non-bias parameters with fewer than two dimensions (for example the
    scale vector of a normalization layer) have no ``shape[1]``; they are
    left untouched instead of raising an ``IndexError``.

    Args:
        model: An object providing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs, e.g. a ``torch.nn.Module``.

    Returns:
        None. The parameters are modified in place.
    """
    for name, param in model.named_parameters():
        if name.endswith(".bias"):
            param.data.fill_(0)
            continue
        if param.dim() < 2:
            # The bound needs both shape[0] and shape[1]; keep 0-D/1-D tensors as they are.
            continue
        bound = math.sqrt(6) / math.sqrt(param.shape[0] + param.shape[1])
        param.data.uniform_(-bound, bound)

src/spotPython/utils/eda.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@
66
get_var_type,
77
get_transform,
88
)
9+
import torch
10+
from spotPython.light.csvdataset import CSVDataset
11+
from torch.utils.data import DataLoader
12+
import matplotlib.pyplot as plt
13+
import math
14+
import seaborn as sns
915

1016

1117
def get_stars(input_list) -> list:
@@ -110,3 +116,58 @@ def generate_config_id(config):
110116
for key in config:
111117
config_id += str(config[key]) + "_"
112118
return config_id[:-1]
119+
120+
121+
def visualize_activations(net, device="cpu", color="C0"):
    """Visualizes the activations of a neural network.

    One shuffled batch of 128 samples is loaded from
    ``./data/VBDP/train.csv`` and passed through every entry of
    ``net.layers`` except the last one. The flattened output of each of
    these layers is shown as a density histogram (with KDE) in a grid
    that is four columns wide.

    Code is based on:
    PyTorch Lightning TUTORIAL 2: ACTIVATION FUNCTIONS,
    Author: Phillip Lippe,
    License: CC BY-SA.

    Args:
        net (object): A neural network. It must provide ``eval()``, an
            indexable and sliceable ``layers`` attribute and an
            ``act_fn`` attribute (used in the figure title).
        device (str, optional): The device the input batch is moved to.
            Defaults to "cpu". NOTE(review): ``net`` itself is not moved
            here; presumably the caller places it on the same device.
        color (str, optional): The color to use. Defaults to "C0".

    Returns:
        None. The figure is shown with ``plt.show()`` and then closed.

    Example:
        >>> from spotPython.hyperparameters.values import get_one_config_from_X
        >>> X = spot_tuner.to_all_dim(spot_tuner.min_X.reshape(1,-1))
        >>> config = get_one_config_from_X(X, fun_control)
        >>> model = fun_control["core_model"](**config, _L_in=64, _L_out=11)
        >>> visualize_activations(model, device="cpu", color=f"C{0}")
    """
    activations = {}
    net.eval()
    # Create an instance of CSVDataset
    dataset = CSVDataset(csv_file="./data/VBDP/train.csv", train=True)
    # Set batch size for DataLoader
    batch_size = 128
    # Create DataLoader
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # Only the first batch is used for the visualization
    inputs, _ = next(iter(dataloader))
    with torch.no_grad():
        inputs = inputs.to(device)
        inputs = inputs.view(inputs.size(0), -1)
        # We need to manually loop through the layers to save all activations
        for layer_index, layer in enumerate(net.layers[:-1]):
            inputs = layer(inputs)
            activations[layer_index] = inputs.view(-1).cpu().numpy()

    # Plotting
    columns = 4
    # At least one row, so that the figure can be created even without activations
    rows = max(1, math.ceil(len(activations) / columns))
    # squeeze=False always yields a 2-D array of axes; without it a single
    # row comes back as a 1-D array and the [row][column] lookup below fails.
    fig, ax = plt.subplots(rows, columns, figsize=(columns * 2.7, rows * 2.5), squeeze=False)
    for fig_index, key in enumerate(activations):
        key_ax = ax[fig_index // columns][fig_index % columns]
        sns.histplot(data=activations[key], bins=50, ax=key_ax, color=color, kde=True, stat="density")
        key_ax.set_title(f"Layer {key} - {net.layers[key].__class__.__name__}")
    fig.suptitle(f"Activation distribution for activation function {net.act_fn}", fontsize=14)
    fig.subplots_adjust(hspace=0.4, wspace=0.4)
    plt.show()
    plt.close()

0 commit comments

Comments
 (0)