Skip to content

Commit a7e4602

Browse files
0.15.31
1 parent 26d4248 commit a7e4602

2 files changed

Lines changed: 203 additions & 14 deletions

File tree

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.15.30"
10+
version = "0.15.31"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]
@@ -53,7 +53,8 @@ dependencies = [
5353
"torch",
5454
"torch-tb-profiler",
5555
"torchmetrics",
56-
"torchvision"
56+
"torchvision",
57+
"torchviz",
5758
]
5859
# dynamic = ["version"]
5960

src/spotpython/plot/xai.py

Lines changed: 200 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from captum.attr import LayerConductance, LayerActivation, LayerIntegratedGradients
1616
from captum.attr import IntegratedGradients, DeepLift, GradientShap, NoiseTunnel, FeatureAblation
1717
from matplotlib.ticker import MaxNLocator
18+
from spotpython.data.lightdatamodule import LightDataModule
1819

1920

2021
def get_activations(net, fun_control, batch_size, device="cpu") -> dict:
@@ -72,7 +73,7 @@ def get_activations(net, fun_control, batch_size, device="cpu") -> dict:
7273
"""
7374
activations = {}
7475
net.eval()
75-
print(f"net: {net}")
76+
# print(f"net: {net}")
7677
dataset = fun_control["data_set"]
7778
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
7879
inputs, _ = next(iter(dataloader))
@@ -549,7 +550,7 @@ def visualize_weights(net, absolute=True, cmap="gray", figsize=(6, 6)) -> None:
549550
plot_nn_values_scatter(nn_values=weights, nn_values_names="Weights", absolute=absolute, cmap=cmap, figsize=figsize)
550551

551552

552-
def visualize_gradients(net, fun_control, batch_size, absolute=True, cmap="gray", figsize=(6, 6)) -> None:
553+
def visualize_gradients(net, fun_control, batch_size, absolute=True, cmap="gray", figsize=(6, 6), device="cpu") -> None:
553554
"""
554555
Scatter plots the gradients of a neural network.
555556
@@ -566,6 +567,8 @@ def visualize_gradients(net, fun_control, batch_size, absolute=True, cmap="gray"
566567
The colormap to use. Defaults to "gray".
567568
figsize (tuple, optional):
568569
The figure size. Defaults to (6, 6).
570+
device (str, optional):
571+
The device to use. Defaults to "cpu".
569572
570573
Returns:
571574
None
@@ -574,6 +577,7 @@ def visualize_gradients(net, fun_control, batch_size, absolute=True, cmap="gray"
574577
net,
575578
fun_control,
576579
batch_size=batch_size,
580+
device=device,
577581
)
578582
plot_nn_values_scatter(nn_values=grads, nn_values_names="Gradients", absolute=absolute, cmap=cmap, figsize=figsize)
579583

@@ -585,6 +589,7 @@ def get_attributions(
585589
baseline=None,
586590
abs_attr=True,
587591
n_rel=5,
592+
device="cpu",
588593
) -> pd.DataFrame:
589594
"""Get the attributions of a neural network.
590595
@@ -601,6 +606,8 @@ def get_attributions(
601606
Whether the method should sort by the absolute attribution values. Defaults to True.
602607
n_rel (int, optional):
603608
The number of relevant features. Defaults to 5.
609+
device (str, optional):
610+
The device to use. Defaults to "cpu".
604611
605612
Returns:
606613
pd.DataFrame (object): A DataFrame with the attributions.
@@ -615,7 +622,121 @@ def get_attributions(
615622
train_model(config, fun_control, timestamp=False)
616623
model_loaded = load_light_from_checkpoint(config, fun_control, postfix="_TRAIN")
617624
removed_attributes, model = get_removed_attributes_and_base_net(net=model_loaded)
618-
model = model.to("cpu")
625+
model = model.to(device)
626+
model.eval()
627+
# get feature names
628+
dataset = fun_control["data_set"]
629+
try:
630+
n_features = dataset.data.shape[1]
631+
except AttributeError:
632+
n_features = dataset.tensors[0].shape[1]
633+
if feature_names is None:
634+
feature_names = [f"x{i}" for i in range(n_features)]
635+
# get batch size
636+
batch_size = config["batch_size"]
637+
# test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
638+
639+
data_module = LightDataModule(
640+
dataset=dataset,
641+
batch_size=batch_size,
642+
test_size=fun_control["test_size"],
643+
scaler=fun_control["scaler"],
644+
verbosity=10
645+
)
646+
data_module.setup(stage="test")
647+
test_loader = data_module.test_dataloader()
648+
649+
if attr_method == "IntegratedGradients":
650+
attr = IntegratedGradients(model)
651+
elif attr_method == "DeepLift":
652+
attr = DeepLift(model)
653+
elif attr_method == "GradientShap": # Todo: would need a baseline
654+
if baseline is None:
655+
raise ValueError("baseline cannot be 'None' for GradientShap")
656+
attr = GradientShap(model)
657+
elif attr_method == "FeatureAblation":
658+
attr = FeatureAblation(model)
659+
else:
660+
raise ValueError(
661+
"""
662+
Unsupported attribution method.
663+
Please choose from 'IntegratedGradients', 'DeepLift', 'GradientShap', or 'FeatureAblation'.
664+
"""
665+
)
666+
for inputs, _ in test_loader:
667+
inputs.requires_grad_()
668+
attributions = attr.attribute(inputs, return_convergence_delta=False, baselines=baseline)
669+
if total_attributions is None:
670+
total_attributions = attributions
671+
else:
672+
if len(attributions) == len(total_attributions):
673+
total_attributions += attributions
674+
675+
# Calculation of average attribution across all batches
676+
avg_attributions = total_attributions.mean(dim=0).detach().numpy()
677+
678+
# Transformation to the absolute attribution values if abs_attr is True
679+
# Get indices of the n most important features
680+
if abs_attr is True:
681+
abs_avg_attributions = abs(avg_attributions)
682+
top_n_indices = abs_avg_attributions.argsort()[-n_rel:][::-1]
683+
else:
684+
top_n_indices = avg_attributions.argsort()[-n_rel:][::-1]
685+
686+
# Get the importance values for the top n features
687+
top_n_importances = avg_attributions[top_n_indices]
688+
689+
df = pd.DataFrame(
690+
{
691+
"Feature Index": top_n_indices,
692+
"Feature": [feature_names[i] for i in top_n_indices],
693+
attr_method + "Attribution": top_n_importances,
694+
}
695+
)
696+
return df
697+
698+
699+
def get_attributions_old(
700+
spot_tuner,
701+
fun_control,
702+
attr_method="IntegratedGradients",
703+
baseline=None,
704+
abs_attr=True,
705+
n_rel=5,
706+
device="cpu",
707+
) -> pd.DataFrame:
708+
"""Get the attributions of a neural network.
709+
710+
Args:
711+
spot_tuner (object):
712+
The spot tuner object.
713+
fun_control (dict):
714+
A dictionary with the function control.
715+
attr_method (str, optional):
716+
The attribution method. Defaults to "IntegratedGradients".
717+
baseline (torch.Tensor, optional):
718+
The baseline for the attribution methods. Defaults to None.
719+
abs_attr (bool, optional):
720+
Whether the method should sort by the absolute attribution values. Defaults to True.
721+
n_rel (int, optional):
722+
The number of relevant features. Defaults to 5.
723+
device (str, optional):
724+
The device to use. Defaults to "cpu".
725+
726+
Returns:
727+
pd.DataFrame (object): A DataFrame with the attributions.
728+
"""
729+
try:
730+
fun_control["data_set"].names
731+
except AttributeError:
732+
fun_control["data_set"].names = None
733+
feature_names = fun_control["data_set"].names
734+
total_attributions = None
735+
config = get_tuned_architecture(spot_tuner, fun_control)
736+
train_model(config, fun_control, timestamp=False)
737+
model_loaded = load_light_from_checkpoint(config, fun_control, postfix="_TRAIN")
738+
removed_attributes, model = get_removed_attributes_and_base_net(net=model_loaded)
739+
model = model.to(device)
619740
model.eval()
620741
dataset = fun_control["data_set"]
621742
try:
@@ -626,7 +747,7 @@ def get_attributions(
626747
feature_names = [f"x{i}" for i in range(n_features)]
627748
batch_size = config["batch_size"]
628749
# train_loader = DataLoader(dataset, batch_size=batch_size)
629-
test_loader = DataLoader(dataset, batch_size=batch_size)
750+
test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
630751
if attr_method == "IntegratedGradients":
631752
attr = IntegratedGradients(model)
632753
elif attr_method == "DeepLift":
@@ -717,7 +838,7 @@ def is_square(n) -> bool:
717838
return n == int(math.sqrt(n)) ** 2
718839

719840

720-
def get_layer_conductance(spot_tuner, fun_control, layer_idx) -> np.ndarray:
841+
def get_layer_conductance(spot_tuner, fun_control, layer_idx, device="cpu") -> np.ndarray:
721842
"""
722843
Compute the average layer conductance attributions for a specified layer in the model.
723844
@@ -728,6 +849,8 @@ def get_layer_conductance(spot_tuner, fun_control, layer_idx) -> np.ndarray:
728849
The fun_control dictionary containing the hyperparameters used to train the model.
729850
layer_idx (int):
730851
Index of the layer for which to compute layer conductance attributions.
852+
device (str, optional):
853+
The device to use. Defaults to "cpu".
731854
732855
Returns:
733856
numpy.ndarray:
@@ -744,7 +867,7 @@ def get_layer_conductance(spot_tuner, fun_control, layer_idx) -> np.ndarray:
744867
train_model(config, fun_control, timestamp=False)
745868
model_loaded = load_light_from_checkpoint(config, fun_control, postfix="_TRAIN")
746869
removed_attributes, model = get_removed_attributes_and_base_net(net=model_loaded)
747-
model = model.to("cpu")
870+
model = model.to(device)
748871
model.eval()
749872

750873
dataset = fun_control["data_set"]
@@ -776,15 +899,23 @@ def get_layer_conductance(spot_tuner, fun_control, layer_idx) -> np.ndarray:
776899
return avg_layer_attributions
777900

778901

779-
def get_weights_conductance_last_layer(spot_tuner, fun_control):
902+
def get_weights_conductance_last_layer(spot_tuner, fun_control, device="cpu") -> tuple:
780903
"""
781904
Get the weights and the conductance of the last layer.
905+
906+
Args:
907+
spot_tuner (object):
908+
The spot tuner object.
909+
fun_control (dict):
910+
A dictionary with the function control.
911+
device (str, optional):
912+
The device to use. Defaults to "cpu".
782913
"""
783914
config = get_tuned_architecture(spot_tuner, fun_control)
784915
train_model(config, fun_control, timestamp=False)
785916
model_loaded = load_light_from_checkpoint(config, fun_control, postfix="_TRAIN")
786917
removed_attributes, model = get_removed_attributes_and_base_net(net=model_loaded)
787-
model = model.to("cpu")
918+
model = model.to(device)
788919
model.eval()
789920

790921
weights, index = get_weights(model, return_index=True)
@@ -796,11 +927,28 @@ def get_weights_conductance_last_layer(spot_tuner, fun_control):
796927
return weights_last, layer_conductance_last
797928

798929

799-
def plot_conductance_last_layer(weights_last, layer_conductance_last, show=True):
930+
def plot_conductance_last_layer(weights_last, layer_conductance_last, figsize=(12, 6), show=True) -> None:
800931
"""
801932
Plot the conductance of the last layer.
933+
934+
Args:
935+
weights_last (np.ndarray):
936+
The weights of the last layer.
937+
layer_conductance_last (np.ndarray):
938+
The conductance of the last layer.
939+
figsize (tuple, optional):
940+
The figure size. Defaults to (12, 6).
941+
show (bool, optional):
942+
Whether to show the plot. Defaults to True.
943+
944+
Examples:
945+
>>> import numpy as np
946+
from spotpython.plot.xai import plot_conductance_last_layer
947+
weights_last = np.random.rand(10)
948+
layer_conductance_last = np.random.rand(10)
949+
plot_conductance_last_layer(weights_last, layer_conductance_last, show=True)
802950
"""
803-
fig, ax = plt.subplots(figsize=(12, 6))
951+
fig, ax = plt.subplots(figsize=figsize)
804952
ax.bar(range(len(weights_last)), weights_last / weights_last.max(), label="Weights", alpha=0.5)
805953
ax.bar(
806954
range(len(layer_conductance_last)),
@@ -817,15 +965,55 @@ def plot_conductance_last_layer(weights_last, layer_conductance_last, show=True)
817965
plt.show()
818966

819967

820-
def get_all_layers_conductance(spot_tuner, fun_control):
968+
def get_all_layers_conductance(spot_tuner, fun_control, device="cpu") -> dict:
    """
    Get the conductance of all layers.

    The tuned architecture is trained and reloaded from its "_TRAIN"
    checkpoint in order to determine the layer indices; the conductance of
    each of these layers is then computed with `get_layer_conductance`.

    Args:
        spot_tuner (object):
            The spot tuner object.
        fun_control (dict):
            A dictionary with the function control.
        device (str, optional):
            The device to use. It is applied to the model built here and is
            forwarded to `get_layer_conductance`. Defaults to "cpu".

    Returns:
        dict:
            Maps every layer index reported by
            `get_weights(model, return_index=True)` to the value returned by
            `get_layer_conductance` for that layer.
    """
    config = get_tuned_architecture(spot_tuner, fun_control)
    train_model(config, fun_control, timestamp=False)
    model_loaded = load_light_from_checkpoint(config, fun_control, postfix="_TRAIN")
    # Only the base network is needed here; the removed attributes are discarded.
    _, model = get_removed_attributes_and_base_net(net=model_loaded)
    model = model.to(device)
    model.eval()
    _, index = get_weights(model, return_index=True)
    layer_conductance = {}
    for i in index:
        # Forward `device` so the per-layer computation runs on the requested
        # device instead of silently falling back to the "cpu" default.
        layer_conductance[i] = get_layer_conductance(spot_tuner, fun_control, layer_idx=i, device=device)
    return layer_conductance
991+
992+
993+
def sort_layers(data_dict) -> dict:
    """
    Sorts a dictionary with keys in the format "Layer X" based on the numerical value X.

    Args:
        data_dict (dict): A dictionary with keys in the format "Layer X".

    Returns:
        dict: A new dictionary with the same items, ordered by the numerical value X.
            The input dictionary is not modified.

    Raises:
        IndexError: If a key has no second whitespace-separated token.
        ValueError: If the second token of a key cannot be converted to an integer.

    Examples:
        >>> data_dict = {
        ...     "Layer 1": [1, 2, 3],
        ...     "Layer 3": [4, 5, 6],
        ...     "Layer 2": [7, 8, 9]
        ... }
        >>> sort_layers(data_dict)
        {'Layer 1': [1, 2, 3], 'Layer 2': [7, 8, 9], 'Layer 3': [4, 5, 6]}
    """
    # Sort on the integer value of X rather than on the key string, so that
    # "Layer 10" is placed after "Layer 2" instead of before it.
    sorted_items = sorted(data_dict.items(), key=lambda item: int(item[0].split()[1]))

    # Dictionaries preserve insertion order, so building a dict from the
    # sorted items yields the sorted mapping.
    return dict(sorted_items)

0 commit comments

Comments
 (0)