Skip to content

Commit a0d4d90

Browse files
0.14.17
attributions
1 parent 391293d commit a0d4d90

3 files changed

Lines changed: 133 additions & 8 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.14.16"
10+
version = "0.14.17"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/data/diabetes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,4 +152,4 @@ def get_names(self) -> list:
152152
print(dataset.get_names())
153153
["age", "sex", "bmi", "bp", "tc", "ldl", "hdl", "tch", "ltg", "glu"]
154154
"""
155-
return ["age", "sex", "bmi", "bp", "tc", "ldl", "hdl", "tch", "ltg", "glu"]
155+
return ["age", "sex", "bmi", "bp", "s1_tc", "s2_ldl", "s3_hdl", "s4_tch", "s5_ltg", "s6_glu"]

src/spotPython/plot/xai.py

Lines changed: 131 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import pandas as pd
1515
from captum.attr import LayerConductance, LayerActivation, LayerIntegratedGradients
1616
from captum.attr import IntegratedGradients, DeepLift, GradientShap, NoiseTunnel, FeatureAblation
17+
from matplotlib.ticker import MaxNLocator
1718

1819

1920
def get_activations(net, fun_control, batch_size, device="cpu") -> dict:
@@ -87,16 +88,21 @@ def get_activations(net, fun_control, batch_size, device="cpu") -> dict:
8788
return activations
8889

8990

90-
def get_weights(net) -> dict:
91+
def get_weights(net, return_index=False) -> dict:
9192
"""
9293
Get the weights of a neural network.
9394
9495
Args:
9596
net (object):
9697
A neural network.
98+
return_index (bool, optional):
99+
Whether to return the index. Defaults to False.
97100
98101
Returns:
99-
dict: A dictionary with the weights of the neural network.
102+
dict:
103+
A dictionary with the weights of the neural network.
104+
index (list):
105+
The layer index list.
100106
101107
Examples:
102108
>>> from torch.utils.data import DataLoader
@@ -150,13 +156,19 @@ def get_weights(net) -> dict:
150156
151157
"""
152158
weights = {}
159+
index = []
153160
for name, param in net.named_parameters():
154161
if name.endswith(".bias"):
155162
continue
163+
# add (int(name.split(".")[1])) to the index list
164+
index.append(int(name.split(".")[1]))
156165
key_name = f"Layer {name.split('.')[1]}"
157166
weights[key_name] = param.detach().view(-1).cpu().numpy()
158167
# print(f"weights: {weights}")
159-
return weights
168+
if return_index:
169+
return weights, index
170+
else:
171+
return weights
160172

161173

162174
def get_gradients(net, fun_control, batch_size, device="cpu") -> dict:
@@ -573,7 +585,6 @@ def get_attributions(
573585
baseline=None,
574586
abs_attr=True,
575587
n_rel=5,
576-
feature_names=None,
577588
):
578589
"""Get the attributions of a neural network.
579590
@@ -590,12 +601,15 @@ def get_attributions(
590601
Whether the method should sort by the absolute attribution values. Defaults to True.
591602
n_rel (int, optional):
592603
The number of relevant features. Defaults to 5.
593-
feature_names (list, optional):
594-
The feature names. Defaults to None.
595604
596605
Returns:
597606
pd.DataFrame: A DataFrame with the attributions.
598607
"""
608+
try:
609+
fun_control["data_set"].names
610+
except AttributeError:
611+
fun_control["data_set"].names = None
612+
feature_names = fun_control["data_set"].names
599613
total_attributions = None
600614
config = get_tuned_architecture(spot_tuner, fun_control)
601615
train_model(config, fun_control, timestamp=False)
@@ -698,3 +712,114 @@ def is_square(n):
698712
False
699713
"""
700714
return n == int(math.sqrt(n)) ** 2
715+
716+
717+
def get_layer_conductance(spot_tuner, fun_control, layer_idx):
718+
"""
719+
Compute the average layer conductance attributions for a specified layer in the model.
720+
721+
Args:
722+
spot_tuner (spot.Spot):
723+
The spot tuner object containing the trained model.
724+
fun_control (dict):
725+
The fun_control dictionary containing the hyperparameters used to train the model.
726+
layer_idx (int):
727+
Index of the layer for which to compute layer conductance attributions.
728+
729+
Returns:
730+
numpy.ndarray:
731+
An array containing the average layer conductance attributions for the specified layer.
732+
The shape of the array corresponds to the shape of the attributions.
733+
"""
734+
try:
735+
fun_control["data_set"].names
736+
except AttributeError:
737+
fun_control["data_set"].names = None
738+
feature_names = fun_control["data_set"].names
739+
740+
config = get_tuned_architecture(spot_tuner, fun_control)
741+
train_model(config, fun_control, timestamp=False)
742+
model_loaded = load_light_from_checkpoint(config, fun_control, postfix="_TRAIN")
743+
removed_attributes, model = get_removed_attributes_and_base_net(net=model_loaded)
744+
model = model.to("cpu")
745+
model.eval()
746+
747+
dataset = fun_control["data_set"]
748+
n_features = dataset.data.shape[1]
749+
if feature_names is None:
750+
feature_names = [f"x{i}" for i in range(n_features)]
751+
batch_size = config["batch_size"]
752+
# train_loader = DataLoader(dataset, batch_size=batch_size)
753+
test_loader = DataLoader(dataset, batch_size=batch_size)
754+
755+
total_layer_attributions = None
756+
layers = model.layers
757+
print("Conductance analysis for layer: ", layers[layer_idx])
758+
lc = LayerConductance(model, layers[layer_idx])
759+
760+
for inputs, labels in test_loader:
761+
lc_attr_test = lc.attribute(inputs, n_steps=10, attribute_to_layer_input=True)
762+
if total_layer_attributions is None:
763+
total_layer_attributions = lc_attr_test
764+
else:
765+
if len(lc_attr_test) == len(total_layer_attributions):
766+
total_layer_attributions += lc_attr_test
767+
768+
avg_layer_attributions = total_layer_attributions.mean(dim=0).detach().numpy()
769+
770+
return avg_layer_attributions
771+
772+
773+
def get_weights_conductance_last_layer(spot_tuner, fun_control):
774+
"""
775+
Get the weights and the conductance of the last layer.
776+
"""
777+
config = get_tuned_architecture(spot_tuner, fun_control)
778+
train_model(config, fun_control, timestamp=False)
779+
model_loaded = load_light_from_checkpoint(config, fun_control, postfix="_TRAIN")
780+
removed_attributes, model = get_removed_attributes_and_base_net(net=model_loaded)
781+
model = model.to("cpu")
782+
model.eval()
783+
784+
weights, index = get_weights(model, return_index=True)
785+
layer_idx = index[-1]
786+
weights_last = weights[f"Layer {layer_idx}"]
787+
weights_last
788+
layer_conductance_last = get_layer_conductance(spot_tuner, fun_control, layer_idx=layer_idx)
789+
790+
return weights_last, layer_conductance_last
791+
792+
793+
def plot_conductance_last_layer(weights_last, layer_conductance_last, show=True):
794+
"""
795+
Plot the conductance of the last layer.
796+
"""
797+
fig, ax = plt.subplots(figsize=(12, 6))
798+
ax.bar(range(len(weights_last)), weights_last / weights_last.max(), label="Weights", alpha=0.5)
799+
ax.bar(
800+
range(len(layer_conductance_last)),
801+
layer_conductance_last / layer_conductance_last.max(),
802+
label="Layer Conductance",
803+
alpha=0.5,
804+
)
805+
ax.set_xlabel("Weight Index")
806+
ax.set_ylabel("Normalized Value")
807+
ax.set_title("Layer Conductance vs. Weights")
808+
ax.legend()
809+
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
810+
if show:
811+
plt.show()
812+
813+
814+
def get_all_layers_conductance(spot_tuner, fun_control):
815+
config = get_tuned_architecture(spot_tuner, fun_control)
816+
train_model(config, fun_control, timestamp=False)
817+
model_loaded = load_light_from_checkpoint(config, fun_control, postfix="_TRAIN")
818+
removed_attributes, model = get_removed_attributes_and_base_net(net=model_loaded)
819+
model = model.to("cpu")
820+
model.eval()
821+
_, index = get_weights(model, return_index=True)
822+
layer_conductance = {}
823+
for i in index:
824+
layer_conductance[i] = get_layer_conductance(spot_tuner, fun_control, layer_idx=i)
825+
return layer_conductance

0 commit comments

Comments
 (0)