Skip to content

Commit c51ca8e

Browse files
0.15.38
1 parent d2cc635 commit c51ca8e

6 files changed

Lines changed: 299 additions & 198 deletions

File tree

notebooks/00_spotPython_tests.ipynb

Lines changed: 187 additions & 164 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.15.37"
10+
version = "0.15.38"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/light/regression/nn_linear_regressor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class NNLinearRegressor(L.LightningModule):
8686
8787
| Name | Type | Params | Mode | In sizes | Out sizes
8888
----------------------------------------------------------------------
89-
0 | layers | Sequential | 20.8 K | train | [128, 10] | [128, 1]
89+
0 | layers | Sequential | 20.8 K | train | [128, 10] | [128, 1]
9090
----------------------------------------------------------------------
9191
20.8 K Trainable params
9292
0 Non-trainable params

src/spotpython/plot/xai.py

Lines changed: 53 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def check_for_nans(data, layer_index) -> bool:
3434
return False
3535

3636

37-
def get_activations(net, fun_control, batch_size, device="cpu", normalize=True) -> tuple:
37+
def get_activations(net, fun_control, batch_size, device="cpu", normalize=False) -> tuple:
3838
"""Computes the activations for each layer of the network and
3939
the mean activations for each layer. Both are returned as a dictionary.
4040
@@ -43,7 +43,7 @@ def get_activations(net, fun_control, batch_size, device="cpu", normalize=True)
4343
fun_control (dict): A dictionary containing the dataset.
4444
device (str): The device to run the model on. Defaults to "cpu".
4545
batch_size (int): The batch size for the data loader.
46-
normalize (bool): Whether to normalize the input data. Defaults to True.
46+
normalize (bool): Whether to normalize the input data. Defaults to False.
4747
4848
Returns:
4949
tuple: A tuple containing the activations and mean activations for each layer.
@@ -63,9 +63,9 @@ def get_activations(net, fun_control, batch_size, device="cpu", normalize=True)
6363
scaler=fun_control["scaler"],
6464
verbosity=10,
6565
)
66-
data_module.setup(stage="test")
67-
test_loader = data_module.test_dataloader()
68-
inputs, _ = next(iter(test_loader))
66+
data_module.setup(stage="fit")
67+
train_loader = data_module.train_dataloader()
68+
inputs, _ = next(iter(train_loader))
6969
inputs = inputs.to(device)
7070
if normalize:
7171
inputs = (inputs - inputs.mean()) / inputs.std()
@@ -199,7 +199,7 @@ def get_weights(net, return_index=False) -> dict:
199199
return weights
200200

201201

202-
def get_gradients(net, fun_control, batch_size, device="cpu", normalize=True) -> dict:
202+
def get_gradients(net, fun_control, batch_size, device="cpu", normalize=False) -> dict:
203203
"""
204204
Get the gradients of a neural network.
205205
@@ -213,46 +213,67 @@ def get_gradients(net, fun_control, batch_size, device="cpu", normalize=True) ->
213213
device (str, optional):
214214
The device to use. Defaults to "cpu".
215215
normalize (bool, optional):
216-
Whether to normalize the input data. Defaults to True.
216+
Whether to normalize the input data. Defaults to False.
217217
218218
Returns:
219219
dict: A dictionary with the gradients of the neural network.
220220
221221
Examples:
222-
>>> from torch.utils.data import DataLoader
223-
from spotpython.utils.init import fun_control_init
224-
from spotpython.hyperparameters.values import set_control_key_value
222+
>>> from spotpython.utils.init import fun_control_init
225223
from spotpython.data.diabetes import Diabetes
226-
from spotpython.light.regression.netlightregression import NetLightRegression
224+
from spotpython.light.regression.nn_linear_regressor import NNLinearRegressor
227225
from spotpython.hyperdict.light_hyper_dict import LightHyperDict
228-
from spotpython.hyperparameters.values import add_core_model_to_fun_control
229226
from spotpython.hyperparameters.values import (
230227
get_default_hyperparameters_as_array, get_one_config_from_X)
231-
from spotpython.hyperparameters.values import set_control_key_value
232-
from spotpython.plot.xai import get_activations
228+
from spotpython.plot.xai import get_gradients
233229
fun_control = fun_control_init(
234230
_L_in=10, # 10: diabetes
235231
_L_out=1,
236-
)
237-
dataset = Diabetes()
238-
set_control_key_value(control_dict=fun_control,
239-
key="data_set",
240-
value=dataset,
241-
replace=True)
242-
add_core_model_to_fun_control(fun_control=fun_control,
243-
core_model=NetLightRegression,
244-
hyper_dict=LightHyperDict)
232+
_torchmetric="mean_squared_error",
233+
data_set=Diabetes(),
234+
core_model=NNLinearRegressor,
235+
hyperdict=LightHyperDict)
245236
X = get_default_hyperparameters_as_array(fun_control)
246237
config = get_one_config_from_X(X, fun_control)
247238
_L_in = fun_control["_L_in"]
248239
_L_out = fun_control["_L_out"]
249-
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out)
250-
batch_size= config["batch_size"]
251-
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
240+
_torchmetric = fun_control["_torchmetric"]
241+
batch_size = 16
242+
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out, _torchmetric=_torchmetric)
252243
get_gradients(model, fun_control=fun_control, batch_size=batch_size, device = "cpu")
253-
{'layers.0.weight': array([ 0.10417588, -0.04161512, 0.10597267, 0.02180895, 0.12001498,
254-
0.02890352, 0.0114617 , 0.08183316, 0.2495192 , 0.5108763 ,
255-
0.14668094, -0.07902834, 0.00912531, 0.02640062, 0.14108546, ...
244+
{'layers.0.weight': array([-18.91906 , -15.034285 , -9.014692 , -11.67453 , -17.93505 ,
245+
-18.900719 , 3.181451 , -7.079934 , -8.781589 , -19.415773 ,
246+
-31.762537 , -25.240526 , -15.134445 , -19.59995 , -30.110514 ,
247+
-31.731745 , 5.3412247, -11.88625 , -14.743096 , -32.596447 ,
248+
-16.250072 , -19.540495 , -12.840339 , -12.497604 , -24.44074 ,
249+
-26.738008 , 7.0891356, -14.540221 , -12.63131 , -20.33385 ,
250+
-16.617418 , -19.537054 , -12.366335 , -11.95286 , -22.170914 ,
251+
-24.224556 , 7.333409 , -13.811482 , -12.374348 , -19.54898 ,
252+
-12.489107 , -14.683411 , -9.294134 , -8.983377 , -16.662935 ,
253+
-18.20638 , 5.5115504, -10.380258 , -9.300154 , -14.692373 ,
254+
-10.237142 , -12.03578 , -7.618267 , -7.363545 , -13.658367 ,
255+
-14.92351 , 4.517738 , -8.508549 , -7.6232023, -12.043128 ,
256+
-20.709038 , -26.502258 , -16.64915 , -14.087446 , -28.602673 ,
257+
-31.098864 , 8.91061 , -17.756905 , -15.304844 , -24.48614 ,
258+
-31.866945 , -25.78516 , -15.80128 , -16.71967 , -30.365 ,
259+
-30.903124 , 1.2193708, -10.1255665, -12.155798 , -31.34386 ],
260+
dtype=float32),
261+
'layers.3.weight': array([-33.59704 , -30.819086, -28.372812, -27.846645, -34.799633,
262+
-31.002586, -30.067335, -39.82912 , -54.281433, -49.7932 ,
263+
-45.840855, -44.99075 , -56.22442 , -50.089676, -48.578625,
264+
-64.350365, -26.605227, -24.405384, -22.4682 , -22.051537,
265+
-27.557549, -24.550695, -23.810078, -31.540358, -31.579184,
266+
-28.968073, -26.668724, -26.174164, -32.709553, -29.140554,
267+
-28.261475, -37.436962], dtype=float32),
268+
'layers.6.weight': array([ -68.05522 , -74.10879 , -77.15874 , -43.79848 , -102.948906,
269+
-112.10627 , -116.72002 , -66.25509 , -75.09263 , -81.77218 ,
270+
-85.13751 , -48.327564, -48.758083, -53.09515 , -55.280285,
271+
-31.379366], dtype=float32),
272+
'layers.9.weight': array([-104.8834 , -129.18658 , -136.66594 , -120.37764 , -92.72068 ,
273+
-114.20557 , -120.817566, -106.41813 ], dtype=float32),
274+
'layers.12.weight': array([-424.17743, -439.30273, -263.17206, -272.55627], dtype=float32),
275+
'layers.15.weight': array([-516.952 , -194.23613 , -154.84027 , -58.178665], dtype=float32),
276+
'layers.18.weight': array([-489.72256, -405.1883 ], dtype=float32)}
256277
"""
257278
net.eval()
258279
dataset = fun_control["data_set"]
@@ -263,9 +284,9 @@ def get_gradients(net, fun_control, batch_size, device="cpu", normalize=True) ->
263284
scaler=fun_control["scaler"],
264285
verbosity=10,
265286
)
266-
data_module.setup(stage="test")
267-
test_loader = data_module.test_dataloader()
268-
inputs, targets = next(iter(test_loader))
287+
data_module.setup(stage="fit")
288+
train_loader = data_module.train_dataloader()
289+
inputs, targets = next(iter(train_loader))
269290
if normalize:
270291
inputs = (inputs - inputs.mean()) / inputs.std()
271292
inputs, targets = inputs.to(device), targets.to(device)

src/spotpython/utils/init.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,13 @@ def fun_control_init(
466466
hyper_dict=hyperdict,
467467
filename=None,
468468
)
469+
if hyperdict is not None and core_model is not None:
470+
add_core_model_to_fun_control(
471+
core_model=core_model,
472+
fun_control=fun_control,
473+
hyper_dict=hyperdict,
474+
filename=None,
475+
)
469476
return fun_control
470477

471478

test/test_get_gradients.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import pytest
2+
import numpy as np
3+
from spotpython.utils.init import fun_control_init
4+
from spotpython.data.diabetes import Diabetes
5+
from spotpython.light.regression.nn_linear_regressor import NNLinearRegressor
6+
from spotpython.hyperdict.light_hyper_dict import LightHyperDict
7+
from spotpython.hyperparameters.values import (
8+
get_default_hyperparameters_as_array, get_one_config_from_X)
9+
from spotpython.plot.xai import get_gradients
10+
11+
def test_gradients_computation():
12+
# Initialize the control function
13+
fun_control = fun_control_init(
14+
_L_in=10, # 10: diabetes
15+
_L_out=1,
16+
_torchmetric="mean_squared_error",
17+
data_set=Diabetes(),
18+
core_model=NNLinearRegressor,
19+
hyperdict=LightHyperDict
20+
)
21+
22+
# Get hyperparameters and model configuration
23+
X = get_default_hyperparameters_as_array(fun_control)
24+
config = get_one_config_from_X(X, fun_control)
25+
26+
# Retrieve specific parameters from the control
27+
_L_in = fun_control["_L_in"]
28+
_L_out = fun_control["_L_out"]
29+
_torchmetric = fun_control["_torchmetric"]
30+
batch_size = 16
31+
32+
# Instantiate the core model with the setup configuration
33+
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out, _torchmetric=_torchmetric)
34+
35+
# Compute gradients using the defined function
36+
gradients = get_gradients(
37+
model,
38+
fun_control=fun_control,
39+
batch_size=batch_size,
40+
device="cpu"
41+
)
42+
43+
# Conduct necessary assertions to validate gradient results
44+
assert isinstance(gradients, dict), "Gradients should be a dictionary."
45+
# Checking that all keys in gradients dictionary contain the string 'layers'
46+
assert all('layers' in key for key in gradients.keys()), \
47+
"All keys should include 'layers' in their description."
48+
# Ensuring all values within the gradients are numpy arrays
49+
assert all(isinstance(value, np.ndarray) for value in gradients.values()), \
50+
"All gradient values should be numpy arrays."

0 commit comments

Comments
 (0)