|
1 | 1 | import torch |
| 2 | +from torchviz import make_dot |
2 | 3 | from torch.utils.data import DataLoader |
3 | 4 | import matplotlib.pyplot as plt |
4 | 5 | import math |
|
16 | 17 | from captum.attr import IntegratedGradients, DeepLift, GradientShap, NoiseTunnel, FeatureAblation |
17 | 18 | from matplotlib.ticker import MaxNLocator |
18 | 19 | from spotpython.data.lightdatamodule import LightDataModule |
| 20 | +from spotpython.torch.dimensions import extract_linear_dims |
19 | 21 |
|
20 | 22 |
|
21 | 23 | def check_for_nans(data, layer_index) -> bool: |
@@ -1017,3 +1019,101 @@ def sort_layers(data_dict) -> dict: |
1017 | 1019 | # Create a new dictionary from the sorted items |
1018 | 1020 | sorted_dict = dict(sorted_items) |
1019 | 1021 | return sorted_dict |
| 1022 | + |
| 1023 | + |
def viz_net(
    net,
    device="cpu",
    show_attrs=False,
    show_saved=False,
    max_attr_chars=50,
    filename="model_architecture",
    format="png",
) -> None:
    """
    Visualize the architecture of a linear neural network.

    Produces a Graphviz representation of the PyTorch autograd graph that results
    from a single forward pass of a random input through `net`, and renders it via
    `dot.render(filename, format=format)`.

    If a node represents a backward function, it is gray. Otherwise, the node represents a tensor and is either blue, orange, or green:
        - Blue: reachable leaf tensors that require grad (tensors whose .grad fields will be populated during .backward())
        - Orange: saved tensors of custom autograd functions as well as those saved by built-in backward nodes
        - Green: tensor passed in as outputs
        - Dark green: if any output is a view, we represent its base tensor with a dark green node.
    If `show_attrs`=True and `show_saved`=True, it is shown what autograd saves for the backward pass.

    Notes:
        The random input has shape `(1, dim[0])`, where `dim` is the result of
        `extract_linear_dims(net)`. It is created directly on `device`.
        The network itself is NOT moved by this function, so `net` must already
        live on `device`.

    Args:
        net (nn.Module):
            The neural network model.
        device (str, optional):
            The device on which the random input tensor is created. Defaults to "cpu".
        show_attrs (bool, optional):
            Whether to display non-tensor attributes of backward nodes (requires PyTorch version >= 1.9).
        show_saved (bool, optional):
            Whether to display saved tensor nodes that are not by custom autograd functions.
            Saved tensor nodes for custom functions, if present, are always displayed (requires PyTorch version >= 1.9).
        max_attr_chars (int, optional):
            If show_attrs is True, sets the max number of characters to display for any given attribute. Defaults to 50.
        filename (str, optional):
            The filename passed to the Graphviz renderer. Defaults to "model_architecture".
        format (str, optional):
            The output format passed to the Graphviz renderer. Defaults to "png".

    Returns:
        None

    Raises:
        ValueError: If `extract_linear_dims` raises a ValueError (the model does not have a linear layer).
        TypeError: If `extract_linear_dims` raises a TypeError (invalid network structure or parameters).
        RuntimeError: If `extract_linear_dims` raises any other exception.

    Examples:
        >>> from spotpython.plot.xai import viz_net
        >>> from spotpython.utils.init import fun_control_init
        >>> from spotpython.data.diabetes import Diabetes
        >>> from spotpython.light.regression.nn_linear_regressor import NNLinearRegressor
        >>> from spotpython.hyperdict.light_hyper_dict import LightHyperDict
        >>> from spotpython.hyperparameters.values import (
        ...     get_default_hyperparameters_as_array, get_one_config_from_X)
        >>> _L_in = 10
        >>> _L_out = 1
        >>> _torchmetric = "mean_squared_error"
        >>> fun_control = fun_control_init(
        ...     _L_in=_L_in,
        ...     _L_out=_L_out,
        ...     _torchmetric=_torchmetric,
        ...     data_set=Diabetes(),
        ...     core_model=NNLinearRegressor,
        ...     hyperdict=LightHyperDict)
        >>> X = get_default_hyperparameters_as_array(fun_control)
        >>> config = get_one_config_from_X(X, fun_control)
        >>> model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out, _torchmetric=_torchmetric)
        >>> viz_net(net=model, device="cpu", show_attrs=True, show_saved=True, filename="model_architecture3", format="png")

    """
    # Only the dimension lookup is guarded; each handler re-raises with the
    # documented exception type and chains the original as the explicit cause.
    try:
        dim = extract_linear_dims(net)
    except ValueError as ve:
        error_message = "The model does not have a linear layer: " + str(ve)
        raise ValueError(error_message) from ve
    except TypeError as te:
        error_message = "Invalid network structure or parameters: " + str(te)
        raise TypeError(error_message) from te
    except Exception as e:
        # Any other failure is surfaced as a RuntimeError, as documented above.
        error_message = "An unexpected error occurred: " + str(e)
        raise RuntimeError(error_message) from e

    # Create the probe input directly on the target device so it stays a leaf
    # tensor; building it on CPU and calling .to(device) would insert an extra
    # copy node into the visualized graph whenever device is not the CPU.
    x = torch.randn(1, dim[0], device=device, requires_grad=True)
    output = net(x)
    dot = make_dot(
        output,
        params=dict(net.named_parameters()),
        show_attrs=show_attrs,
        show_saved=show_saved,
        max_attr_chars=max_attr_chars,
    )
    dot.render(filename, format=format)
0 commit comments