Skip to content

Commit 54e1368

Browse files
0.16.9
1 parent f9dc086 commit 54e1368

7 files changed

Lines changed: 371 additions & 193 deletions

File tree

.gitignore

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,3 +327,13 @@ notebooks/spot_000_experiment.pickle
327327
notebooks/spot_100_experiment.pickle
328328
notebooks/spot_TEST_experiment.pickle
329329
notebooks/spot_TEST_SAVE_experiment.pickle
330+
notebooks/model_architecture
331+
notebooks/model_architecture.png
332+
notebooks/model_architecture1
333+
notebooks/model_architecture1.png
334+
notebooks/model_architecture2
335+
notebooks/model_architecture2.png
336+
notebooks/model_architecture3
337+
notebooks/model_architecture3.png
338+
notebooks/model_architecture4
339+
notebooks/model_architecture4.png

RELEASE_NOTES.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
spotpython-0.16.9:
2+
3+
- xai.py: add new function viz_net to visualize the network architecture (linear nets)
4+
- dimensions.py: add new function extract_linear_dims that extracts the input and output dimensions of the Linear layers in a PyTorch model.
5+
16
spotpython-0.16.8:
27

38
- xai.py: automatically handle the orientation of the colorbar in the plot_nn_values_scatter function

notebooks/00_spotPython_tests.ipynb

Lines changed: 191 additions & 192 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.16.8"
10+
version = "0.16.9"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/plot/xai.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import torch
2+
from torchviz import make_dot
23
from torch.utils.data import DataLoader
34
import matplotlib.pyplot as plt
45
import math
@@ -16,6 +17,7 @@
1617
from captum.attr import IntegratedGradients, DeepLift, GradientShap, NoiseTunnel, FeatureAblation
1718
from matplotlib.ticker import MaxNLocator
1819
from spotpython.data.lightdatamodule import LightDataModule
20+
from spotpython.torch.dimensions import extract_linear_dims
1921

2022

2123
def check_for_nans(data, layer_index) -> bool:
@@ -1017,3 +1019,101 @@ def sort_layers(data_dict) -> dict:
10171019
# Create a new dictionary from the sorted items
10181020
sorted_dict = dict(sorted_items)
10191021
return sorted_dict
1022+
1023+
1024+
def viz_net(
1025+
net,
1026+
device="cpu",
1027+
show_attrs=False,
1028+
show_saved=False,
1029+
max_attr_chars=50,
1030+
filename="model_architecture",
1031+
format="png",
1032+
) -> None:
1033+
"""
1034+
Visualize the architecture of a linear neural network.
1035+
Produces Graphviz representation of PyTorch autograd graph.
1036+
If a node represents a backward function, it is gray. Otherwise, the node represents a tensor and is either blue, orange, or green:
1037+
- Blue: reachable leaf tensors that requires grad (tensors whose .grad fields will be populated during .backward())
1038+
- Orange: saved tensors of custom autograd functions as well as those saved by built-in backward nodes
1039+
- Green: tensor passed in as outputs
1040+
- Dark green: if any output is a view, we represent its base tensor with a dark green node.
1041+
If `show_attrs`=True and `show_saved`=True it is shown what autograd saves for the backward pass.
1042+
1043+
Args:
1044+
net (nn.Module):
1045+
The neural network model.
1046+
device (str, optional):
1047+
The device to use. Defaults to "cpu".
1048+
show_attrs (bool, optional):
1049+
whether to display non-tensor attributes of backward nodes (Requires PyTorch version >= 1.9)
1050+
show_saved (bool, optional):
1051+
whether to display saved tensor nodes that are not by custom autograd functions. Saved tensor nodes for custom functions, if present, are always displayed. (Requires PyTorch version >= 1.9)
1052+
max_attr_chars (int, optional):
1053+
if show_attrs is True, sets max number of characters to display for any given attribute. Defaults to 50.
1054+
filename (str, optional):
1055+
The filename. Defaults to "model_architecture".
1056+
format (str, optional):
1057+
The output format. Defaults to "png".
1058+
1059+
Returns:
1060+
None
1061+
1062+
Raises:
1063+
ValueError: If the model does not have a linear layer.
1064+
TypeError: If the network structure or parameters are invalid.
1065+
RuntimeError: If an unexpected error occurs.
1066+
1067+
Examples:
1068+
>>> from spotpython.plot.xai import viz_net
1069+
from spotpython.utils.init import fun_control_init
1070+
from spotpython.data.diabetes import Diabetes
1071+
from spotpython.light.regression.nn_linear_regressor import NNLinearRegressor
1072+
from spotpython.hyperdict.light_hyper_dict import LightHyperDict
1073+
from spotpython.hyperparameters.values import (
1074+
get_default_hyperparameters_as_array, get_one_config_from_X)
1075+
from spotpython.hyperdict.light_hyper_dict import LightHyperDict
1076+
_L_in=10
1077+
_L_out=1
1078+
_torchmetric="mean_squared_error"
1079+
fun_control = fun_control_init(
1080+
_L_in=_L_in,
1081+
_L_out=_L_out,
1082+
_torchmetric=_torchmetric,
1083+
data_set=Diabetes(),
1084+
core_model=NNLinearRegressor,
1085+
hyperdict=LightHyperDict)
1086+
X = get_default_hyperparameters_as_array(fun_control)
1087+
config = get_one_config_from_X(X, fun_control)
1088+
# _L_in = fun_control["_L_in"]
1089+
# _L_out = fun_control["_L_out"]
1090+
# _torchmetric = fun_control["_torchmetric"]
1091+
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out, _torchmetric=_torchmetric)
1092+
viz_net(net=model, device="cpu", show_attrs=True, show_saved=True, filename="model_architecture3", format="png")
1093+
1094+
"""
1095+
try:
1096+
dim = extract_linear_dims(net)
1097+
except ValueError as ve:
1098+
error_message = "The model does not have a linear layer: " + str(ve)
1099+
raise ValueError(error_message)
1100+
except TypeError as te:
1101+
error_message = "Invalid network structure or parameters: " + str(te)
1102+
raise TypeError(error_message)
1103+
except Exception as e:
1104+
# Catch any other unforeseen exceptions and log them for debugging purposes
1105+
error_message = "An unexpected error occurred: " + str(e)
1106+
raise RuntimeError(error_message)
1107+
1108+
# Proceed with the rest of the logic if dimensions were extracted successfully
1109+
x = torch.randn(1, dim[0]).requires_grad_(True)
1110+
x = x.to(device)
1111+
output = net(x)
1112+
dot = make_dot(
1113+
output,
1114+
params=dict(net.named_parameters()),
1115+
show_attrs=show_attrs,
1116+
show_saved=show_saved,
1117+
max_attr_chars=max_attr_chars,
1118+
)
1119+
dot.render(filename, format=format)

src/spotpython/torch/dimensions.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import numpy as np
2+
import torch.nn as nn
3+
4+
5+
def extract_linear_dims(model) -> np.array:
6+
"""Extracts the input and output dimensions of the Linear layers in a PyTorch model.
7+
8+
Args:
9+
model (nn.Module): PyTorch model.
10+
11+
Returns:
12+
np.array: Array with the input and output dimensions of the Linear layers.
13+
14+
Examples:
15+
>>> from spotpython.torch.dimensions import extract_linear_dims
16+
>>> net = NNLinearRegressor()
17+
>>> result = extract_linear_dims(net)
18+
19+
"""
20+
dims = []
21+
for layer in model.layers:
22+
if isinstance(layer, nn.Linear):
23+
# Append input and output features of the Linear layer
24+
dims.append(layer.in_features)
25+
dims.append(layer.out_features)
26+
return np.array(dims)

test/test_extract_linear_dims.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import numpy as np
2+
import torch.nn as nn
3+
import pytest
4+
from spotpython.torch.dimensions import extract_linear_dims
5+
6+
class NNLinearRegressor(nn.Module):
7+
def __init__(self):
8+
super(NNLinearRegressor, self).__init__()
9+
self.layers = nn.Sequential(
10+
nn.Linear(10, 8),
11+
nn.ReLU(),
12+
nn.Dropout(0.01),
13+
nn.Linear(8, 4),
14+
nn.ReLU(),
15+
nn.Dropout(0.01),
16+
nn.Linear(4, 4),
17+
nn.ReLU(),
18+
nn.Dropout(0.01),
19+
nn.Linear(4, 2),
20+
nn.ReLU(),
21+
nn.Dropout(0.01),
22+
nn.Linear(2, 2),
23+
nn.ReLU(),
24+
nn.Dropout(0.01),
25+
nn.Linear(2, 2),
26+
nn.ReLU(),
27+
nn.Dropout(0.01),
28+
nn.Linear(2, 1),
29+
)
30+
31+
def test_extract_linear_dims():
32+
net = NNLinearRegressor()
33+
expected_dims = np.array([10, 8, 8, 4, 4, 4, 4, 2, 2, 2, 2, 2, 2, 1])
34+
result = extract_linear_dims(net)
35+
assert np.array_equal(result, expected_dims), f"Expected {expected_dims}, but got {result}"
36+
37+
if __name__ == "__main__":
38+
pytest.main([__file__])

0 commit comments

Comments
 (0)