Skip to content

Commit d56eaca

Browse files
0.14.13
xai
1 parent ca60bea commit d56eaca

6 files changed

Lines changed: 301 additions & 11 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.14.11"
10+
version = "0.14.13"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/data/california.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
import torch
2+
from torch.utils.data import Dataset
3+
from sklearn.datasets import fetch_california_housing
4+
5+
6+
class CaliforniaHousing(Dataset):
7+
"""
8+
A PyTorch Dataset for regression. A toy data set from scikit-learn.
9+
Features:
10+
* MedInc median income in block group
11+
* HouseAge median house age in block group
12+
* AveRooms average number of rooms per household
13+
* AveBedrms average number of bedrooms per household
14+
* Population block group population
15+
* AveOccup average number of household members
16+
* Latitude block group latitude
17+
* Longitude block group longitude
18+
The target variable is the median house value for California districts,
19+
expressed in hundreds of thousands of Dollars ($100,000).
20+
Samples total: 20640, Dimensionality: 8, Features: real, Target: real 0.15 - 5.
21+
This dataset was derived from the 1990 U.S. census, using one row per census block group.
22+
A block group is the smallest geographical unit for which the U.S. Census Bureau publishes sample data
23+
(a block group typically has a population of 600 to 3,000 people).
24+
25+
Args:
26+
feature_type (torch.dtype): The data type of the features. Defaults to torch.float.
27+
target_type (torch.dtype): The data type of the targets. Defaults to torch.long.
28+
train (bool): Whether the dataset is for training or not. Defaults to True.
29+
n_samples (int): The number of samples of the dataset. Defaults to None, which means the entire dataset is used.
30+
31+
Attributes:
32+
data (Tensor): The data features.
33+
targets (Tensor): The data targets.
34+
35+
Examples:
36+
>>> from torch.utils.data import DataLoader
37+
from spotPython.data.diabetes import Diabetes
38+
import torch
39+
dataset = Diabetes(feature_type=torch.float32, target_type=torch.float32)
40+
# Set batch size for DataLoader
41+
batch_size = 5
42+
# Create DataLoader
43+
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
44+
# Iterate over the data in the DataLoader
45+
for batch in dataloader:
46+
inputs, targets = batch
47+
print(f"Batch Size: {inputs.size(0)}")
48+
print("---------------")
49+
print(f"Inputs: {inputs}")
50+
print(f"Targets: {targets}")
51+
"""
52+
53+
def __init__(
54+
self,
55+
feature_type: torch.dtype = torch.float,
56+
target_type: torch.dtype = torch.float,
57+
train: bool = True,
58+
n_samples: int = None,
59+
) -> None:
60+
super().__init__()
61+
self.feature_type = feature_type
62+
self.target_type = target_type
63+
self.train = train
64+
self.n_samples = n_samples
65+
self.data, self.targets = self._load_data()
66+
67+
def _load_data(self) -> tuple:
68+
"""Loads the data from scikit-learn and returns the features and targets.
69+
70+
Returns:
71+
tuple: A tuple containing the features and targets.
72+
73+
Examples:
74+
>>> from spotPython.data.diabetes import Diabetes
75+
dataset = Diabetes()
76+
print(dataset.data.shape)
77+
print(dataset.targets.shape)
78+
torch.Size([442, 10])
79+
torch.Size([442])
80+
"""
81+
feature_df, target_df = fetch_california_housing(return_X_y=True, as_frame=True)
82+
if self.n_samples is not None:
83+
feature_df = feature_df[: self.n_samples]
84+
target_df = target_df[: self.n_samples]
85+
# Convert DataFrames to PyTorch tensors
86+
feature_tensor = torch.tensor(feature_df.values, dtype=self.feature_type)
87+
target_tensor = torch.tensor(target_df.values, dtype=self.target_type)
88+
89+
return feature_tensor, target_tensor
90+
91+
def __getitem__(self, idx: int) -> tuple:
92+
"""
93+
Returns the feature and target at the given index.
94+
95+
Args:
96+
idx (int): The index.
97+
98+
Returns:
99+
tuple: A tuple containing the feature and target at the given index.
100+
101+
Examples:
102+
>>> from spotPython.light.csvdataset import CSVDataset
103+
dataset = CSVDataset(filename='./data/spotPython/data.csv', target_column='prognosis')
104+
print(dataset.data.shape)
105+
print(dataset.targets.shape)
106+
torch.Size([11, 65])
107+
torch.Size([11])
108+
"""
109+
feature = self.data[idx]
110+
target = self.targets[idx]
111+
return feature, target
112+
113+
def __len__(self) -> int:
114+
"""
115+
Returns the length of the dataset.
116+
117+
Returns:
118+
int: The length of the dataset.
119+
120+
Examples:
121+
>>> from spotPython.light import CSVDataset
122+
>>> dataset = CSVDataset()
123+
>>> print(len(dataset))
124+
60000
125+
126+
"""
127+
return len(self.data)
128+
129+
def extra_repr(self) -> str:
130+
"""
131+
Returns a string representation of the dataset.
132+
133+
Returns:
134+
str: A string representation of the dataset.
135+
136+
Examples:
137+
>>> from spotPython.light import CSVDataset
138+
>>> dataset = CSVDataset()
139+
>>> print(dataset)
140+
Split: Train
141+
142+
"""
143+
split = "Train" if self.train else "Test"
144+
return f"Split: {split}"

src/spotPython/light/loadmodel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def load_light_from_checkpoint(config: dict, fun_control: dict, postfix: str = "
2121
A dictionary containing the function control parameters.
2222
postfix (str):
2323
The postfix to append to the configuration ID when generating the checkpoint path.
24+
Default is "_TEST". Can be set to "_TRAIN" for training checkpoints.
2425
2526
Returns:
2627
Any: The loaded model.

src/spotPython/light/trainmodel.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,22 @@
44
from pytorch_lightning.loggers import TensorBoardLogger
55
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
66
from spotPython.torch.initialization import kaiming_init, xavier_init
7+
from lightning.pytorch.callbacks import ModelCheckpoint
78
import os
89

910

10-
def train_model(config: dict, fun_control: dict) -> float:
11+
def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> float:
1112
"""
1213
Trains a model using the given configuration and function control parameters.
1314
1415
Args:
15-
config (dict): A dictionary containing the configuration parameters for the model.
16-
fun_control (dict): A dictionary containing the function control parameters.
16+
config (dict):
17+
A dictionary containing the configuration parameters for the model.
18+
fun_control (dict):
19+
A dictionary containing the function control parameters.
20+
timestamp (bool):
21+
A boolean value indicating whether to include a timestamp in the config id. Default is True.
22+
If False, the string "_TRAIN" is appended to the config id.
1723
1824
Returns:
1925
float: The validation loss of the trained model.
@@ -72,9 +78,15 @@ def train_model(config: dict, fun_control: dict) -> float:
7278
enable_progress_bar = False
7379
else:
7480
enable_progress_bar = fun_control["enable_progress_bar"]
75-
# config id is unique. Since the model is not loaded from a checkpoint,
76-
# the config id is generated here with a timestamp.
77-
config_id = generate_config_id(config, timestamp=True)
81+
if timestamp:
82+
# config id is unique. Since the model is not loaded from a checkpoint,
83+
# the config id is generated here with a timestamp.
84+
config_id = generate_config_id(config, timestamp=True)
85+
else:
86+
# config id is not time-dependent and therefore unique,
87+
# so that the model can be loaded from a checkpoint,
88+
# the config id is generated here without a timestamp.
89+
config_id = generate_config_id(config, timestamp=False) + "_TRAIN"
7890
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out, _torchmetric=_torchmetric)
7991
initialization = config["initialization"]
8092
if initialization == "Xavier":
@@ -97,6 +109,16 @@ def train_model(config: dict, fun_control: dict) -> float:
97109
# print(f"train_model(): Train set size: {len(dm.data_train)}")
98110
# print(f"train_model(): Batch size: {config['batch_size']}")
99111

112+
# Callbacks
113+
callbacks = [
114+
EarlyStopping(monitor="val_loss", patience=config["patience"], mode="min", strict=False, verbose=False)
115+
]
116+
if not timestamp:
117+
# add ModelCheckpoint only if timestamp is False
118+
callbacks.append(
119+
ModelCheckpoint(dirpath=os.path.join(fun_control["CHECKPOINT_PATH"], config_id), save_last=True)
120+
) # Save the last checkpoint
121+
100122
# Init trainer
101123
trainer = L.Trainer(
102124
# Where to save models
@@ -110,9 +132,7 @@ def train_model(config: dict, fun_control: dict) -> float:
110132
default_hp_metric=True,
111133
log_graph=fun_control["log_graph"],
112134
),
113-
callbacks=[
114-
EarlyStopping(monitor="val_loss", patience=config["patience"], mode="min", strict=False, verbose=False)
115-
],
135+
callbacks=callbacks,
116136
enable_progress_bar=enable_progress_bar,
117137
)
118138
# Pass the datamodule as arg to trainer.fit to override model hooks :)

src/spotPython/plot/xai.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@
77
import torch.nn as nn
88
import torch.nn.functional as F
99
import matplotlib.colors as colors
10+
from spotPython.hyperparameters.values import get_tuned_architecture
11+
from spotPython.light.trainmodel import train_model
12+
from spotPython.light.loadmodel import load_light_from_checkpoint
13+
from spotPython.utils.classes import get_removed_attributes_and_base_net
14+
import pandas as pd
15+
from captum.attr import LayerConductance, LayerActivation, LayerIntegratedGradients
16+
from captum.attr import IntegratedGradients, DeepLift, GradientShap, NoiseTunnel, FeatureAblation
1017

1118

1219
def get_activations(net, fun_control, batch_size, device="cpu") -> dict:
@@ -472,3 +479,119 @@ def visualize_gradients(net, fun_control, batch_size, absolute=True, cmap="gray"
472479
batch_size=batch_size,
473480
)
474481
plot_nn_values_scatter(nn_values=grads, nn_values_names="Gradients", absolute=absolute, cmap=cmap, figsize=figsize)
482+
483+
484+
def get_attributions(
485+
spot_tuner,
486+
fun_control,
487+
attr_method="IntegratedGradients",
488+
baseline=None,
489+
abs_attr=True,
490+
n_rel=5,
491+
feature_names=None,
492+
):
493+
"""Get the attributions of a neural network.
494+
495+
Args:
496+
spot_tuner (object):
497+
The spot tuner object.
498+
fun_control (dict):
499+
A dictionary with the function control.
500+
attr_method (str, optional):
501+
The attribution method. Defaults to "IntegratedGradients".
502+
baseline (torch.Tensor, optional):
503+
The baseline for the attribution methods. Defaults to None.
504+
abs_attr (bool, optional):
505+
Whether the method should sort by the absolute attribution values. Defaults to True.
506+
n_rel (int, optional):
507+
The number of relevant features. Defaults to 5.
508+
feature_names (list, optional):
509+
The feature names. Defaults to None.
510+
511+
Returns:
512+
pd.DataFrame: A DataFrame with the attributions.
513+
"""
514+
total_attributions = None
515+
config = get_tuned_architecture(spot_tuner, fun_control)
516+
train_model(config, fun_control, timestamp=False)
517+
model_loaded = load_light_from_checkpoint(config, fun_control, postfix="_TRAIN")
518+
removed_attributes, model = get_removed_attributes_and_base_net(net=model_loaded)
519+
model = model.to("cpu")
520+
model.eval()
521+
dataset = fun_control["data_set"]
522+
n_features = dataset.data.shape[1]
523+
if feature_names is None:
524+
feature_names = [f"x{i}" for i in range(n_features)]
525+
batch_size = config["batch_size"]
526+
# train_loader = DataLoader(dataset, batch_size=batch_size)
527+
test_loader = DataLoader(dataset, batch_size=batch_size)
528+
if attr_method == "IntegratedGradients":
529+
attr = IntegratedGradients(model)
530+
elif attr_method == "DeepLift":
531+
attr = DeepLift(model)
532+
elif attr_method == "GradientShap": # Todo: would need a baseline
533+
if baseline is None:
534+
raise ValueError("baseline cannot be 'None' for GradientShap")
535+
attr = GradientShap(model)
536+
elif attr_method == "FeatureAblation":
537+
attr = FeatureAblation(model)
538+
else:
539+
raise ValueError(
540+
"""
541+
Unsupported attribution method.
542+
Please choose from 'IntegratedGradients', 'DeepLift', 'GradientShap', or 'FeatureAblation'.
543+
"""
544+
)
545+
for inputs, labels in test_loader:
546+
attributions = attr.attribute(inputs, return_convergence_delta=False, baselines=baseline)
547+
if total_attributions is None:
548+
total_attributions = attributions
549+
else:
550+
if len(attributions) == len(total_attributions):
551+
total_attributions += attributions
552+
553+
# Calculation of average attribution across all batches
554+
avg_attributions = total_attributions.mean(dim=0).detach().numpy()
555+
556+
# Transformation to the absolute attribution values if abs_attr is True
557+
# Get indices of the n most important features
558+
if abs_attr is True:
559+
abs_avg_attributions = abs(avg_attributions)
560+
top_n_indices = abs_avg_attributions.argsort()[-n_rel:][::-1]
561+
else:
562+
top_n_indices = avg_attributions.argsort()[-n_rel:][::-1]
563+
564+
# Get the importance values for the top n features
565+
top_n_importances = avg_attributions[top_n_indices]
566+
567+
df = pd.DataFrame(
568+
{
569+
"Feature Index": top_n_indices,
570+
"Feature": [feature_names[i] for i in top_n_indices],
571+
attr_method + "Attribution": top_n_importances,
572+
}
573+
)
574+
return df
575+
576+
577+
def plot_attributions(df, attr_method="IntegratedGradients"):
578+
"""
579+
Plot the attributions of a neural network.
580+
581+
Args:
582+
df (pd.DataFrame):
583+
A DataFrame with the attributions.
584+
attr_method (str, optional):
585+
The attribution method. Defaults to "IntegratedGradients".
586+
587+
Returns:
588+
None
589+
590+
"""
591+
sns.set_theme(style="whitegrid")
592+
plt.figure(figsize=(10, 6))
593+
sns.barplot(x=attr_method + "Attribution", y="Feature", data=df, palette="viridis", hue="Feature")
594+
plt.title(f"Top {df.shape[0]} Features by {attr_method} Attribution")
595+
plt.xlabel(f"{attr_method} Attribution Value")
596+
plt.ylabel("Feature")
597+
plt.show()

src/spotPython/spot/spot.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -708,7 +708,9 @@ def write_db_dict(self) -> None:
708708
# Generate a description of the results:
709709
# if spot_tuner_control['min_y'] exists:
710710
try:
711-
result = f"Results for {ident}: Finally, the best value is {spot_tuner_control['min_y']} at {spot_tuner_control['min_X']}."
711+
result = f"""
712+
Results for {ident}: Finally, the best value is {spot_tuner_control['min_y']}
713+
at {spot_tuner_control['min_X']}."""
712714
#
713715
db_dict = {
714716
"data": {

0 commit comments

Comments
 (0)