Skip to content

Commit 2319ff4

Browse files
committed
implementations of new xai consistency metrics
1 parent 3b94632 commit 2319ff4

3 files changed

Lines changed: 171 additions & 77 deletions

File tree

src/spotpython/light/trainmodel.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import lightning as L
22
from spotpython.data.lightdatamodule import LightDataModule, PadSequenceManyToMany
33
from spotpython.utils.eda import generate_config_id
4-
from spotpython.utils.metrics import calculate_xai_consistency
4+
from spotpython.utils.metrics import calculate_xai_consistency_corr, calculate_xai_consistency_cosine, calculate_xai_consistency_euclidean
55
from pytorch_lightning.loggers import TensorBoardLogger
66
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
77
from lightning.pytorch.callbacks import ModelCheckpoint
@@ -694,38 +694,38 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
694694
if "IntegratedGradients" in fun_control["xai_methods"]:
695695
attr_ig = IntegratedGradients(model)
696696
attribution_ig = attr_ig.attribute(X_val_tensor, baselines=baseline)
697-
ig_attr_test_sum = attribution_ig.detach().numpy().sum(0)
698-
row_sum_ig = np.sum(ig_attr_test_sum, axis=0)
699-
if row_sum_ig == 0:
700-
row_sum_ig += 1e-10
701-
scaled_attribution_ig = ig_attr_test_sum / row_sum_ig
702-
attributions_dict["IntegratedGradients"] = scaled_attribution_ig
697+
ig_attr_test_sum = attribution_ig.detach().numpy().sum(axis=0)
698+
l2_norm = np.linalg.norm(ig_attr_test_sum)
699+
l2_normalized_ig = ig_attr_test_sum / l2_norm if l2_norm != 0 else ig_attr_test_sum
700+
attributions_dict["IntegratedGradients"] = l2_normalized_ig
703701

704702
if "KernelShap" in fun_control["xai_methods"]:
705703
attr_ks = KernelShap(model)
706704
attribution_ks = attr_ks.attribute(X_val_tensor, baselines=baseline)
707-
ks_attr_test_sum = attribution_ks.detach().numpy().sum(0)
708-
row_sum_ks = np.sum(ks_attr_test_sum, axis=0)
709-
if row_sum_ks == 0:
710-
row_sum_ks += 1e-10
711-
scaled_attribution_ks = ks_attr_test_sum / row_sum_ks
712-
attributions_dict["KernelShap"] = scaled_attribution_ks
705+
ks_attr_test_sum = attribution_ks.detach().numpy().sum(axis=0)
706+
l2_norm = np.linalg.norm(ks_attr_test_sum)
707+
l2_normalized_ks = ks_attr_test_sum / l2_norm if l2_norm != 0 else ks_attr_test_sum
708+
attributions_dict["KernelShap"] = l2_normalized_ks
713709

714710
if "DeepLift" in fun_control["xai_methods"]:
715711
attr_dl = DeepLift(model)
716712
attribution_dl = attr_dl.attribute(X_val_tensor, baselines=baseline)
717-
dl_attr_test_sum = attribution_dl.detach().numpy().sum(0)
718-
row_sum_dl = np.sum(dl_attr_test_sum, axis=0)
719-
if row_sum_dl == 0:
720-
row_sum_dl += 1e-10
721-
scaled_attribution_dl = dl_attr_test_sum / row_sum_dl
722-
attributions_dict["DeepLift"] = scaled_attribution_dl
713+
dl_attr_test_sum = attribution_dl.detach().numpy().sum(axis=0)
714+
l2_norm = np.linalg.norm(dl_attr_test_sum)
715+
l2_normalized_dl = dl_attr_test_sum / l2_norm if l2_norm != 0 else dl_attr_test_sum
716+
attributions_dict["DeepLift"] = l2_normalized_dl
723717

724718
attributions_list = [attributions_dict[method] for method in fun_control["xai_methods"]]
725719
attributions = np.stack(attributions_list, axis=0)
726720

727-
result_xai = calculate_xai_consistency(attributions)
728-
729-
# -------------------------------------------------------------------------------------------------------------------
721+
# Calculate the XAI consistency with the metric selected in fun_control["xai_metric"]:
722+
if fun_control["xai_metric"] not in ["corr", "cosine", "euclidean"]:
723+
raise ValueError(f"Invalid xai_metric: {fun_control['xai_metric']}. Valid metrics are: 'corr', 'cosine', 'euclidean'")
724+
if fun_control["xai_metric"] == "corr":
725+
result_xai = calculate_xai_consistency_corr(attributions)
726+
elif fun_control["xai_metric"] == "cosine":
727+
result_xai = calculate_xai_consistency_cosine(attributions)
728+
elif fun_control["xai_metric"] == "euclidean":
729+
result_xai = calculate_xai_consistency_euclidean(attributions)
730730

731731
return result["val_loss"], result_xai

src/spotpython/utils/metrics.py

Lines changed: 59 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
import numpy as np
2828
from spotpython.utils.convert import series_to_array
29+
from sklearn.metrics.pairwise import euclidean_distances
2930

3031

3132
def apk(actual, predicted, k=10):
@@ -198,43 +199,17 @@ def get_metric_sign(metric_name):
198199
raise ValueError(f"Metric '{metric_name}' not found.")
199200

200201

201-
def calculate_xai_consistency(attributions) -> float:
202-
"""Calculate the consistency between different XAI methods.
203-
Computes the pairwise correlation between different XAI methods' attributions
204-
and returns their mean correlation as a measure of consistency. A higher value
205-
indicates greater agreement between different XAI methods.
202+
def calculate_xai_consistency_corr(attributions):
203+
"""
204+
Calculates the consistency of XAI methods by computing the mean of the upper triangle
205+
of the correlation matrix of the provided attributions.
206206
207207
Args:
208208
attributions (np.ndarray): Array of shape (n_methods, n_features) containing
209-
feature importance scores from different XAI methods. Each row represents
210-
a different XAI method's attributions, and each column represents a feature.
209+
the attributions from different XAI methods.
211210
212211
Returns:
213-
float: Mean correlation between XAI methods, ranging from -1 to 1.
214-
- 1: Perfect consistency between methods
215-
- 0: No consistency between methods
216-
- -1: Perfect negative consistency between methods
217-
218-
Examples:
219-
>>> import numpy as np
220-
>>> # Three XAI methods' attributions for four features
221-
>>> attributions = np.array([
222-
... [0.1, 0.2, 0.3, 0.4], # Method 1
223-
... [0.2, 0.3, 0.4, 0.5], # Method 2
224-
... [0.0, 0.1, 0.2, 0.3] # Method 3
225-
... ])
226-
>>> consistency = calculate_xai_consistency(attributions)
227-
>>> print(f"XAI Consistency: {consistency:.2f}")
228-
Attribution Correlation Matrix:
229-
[[ 1. 0.97 0.98]
230-
[ 0.97 1. 0.99]
231-
[ 0.98 0.99 1. ]]
232-
XAI Consistency: 0.98
233-
234-
Note:
235-
The correlation matrix is computed using numpy's corrcoef function, which
236-
calculates Pearson correlation coefficients. Only the upper triangle of
237-
the correlation matrix is used to avoid counting correlations twice.
212+
float: Mean value of the upper triangle of the correlation matrix.
238213
"""
239214
global_attr_np = np.array(attributions)
240215
corr_matrix = np.corrcoef(global_attr_np)
@@ -248,3 +223,55 @@ def calculate_xai_consistency(attributions) -> float:
248223
print("XAI Consistency (mean of upper triangle of correlation matrix):")
249224
print(result_xai)
250225
return result_xai
226+
227+
228+
def calculate_xai_consistency_cosine(attributions):
    """
    Calculates the consistency of XAI methods by computing the mean of the upper triangle
    of the cosine similarity matrix of the provided attributions.

    The similarity of every pair that involves an all-zero attribution vector is
    defined as 0.0, so such a vector no longer causes a division by zero and a
    NaN result.

    Args:
        attributions (np.ndarray): Array of shape (n_methods, n_features) containing
            the attributions from different XAI methods.

    Returns:
        float: Mean value of the upper triangle of the cosine similarity matrix.
            With fewer than two methods there are no pairs, and the mean of the
            empty upper triangle is NaN.
    """
    global_attr_np = np.asarray(attributions, dtype=float)

    # Compute each row norm once instead of once per pair.
    row_norms = np.linalg.norm(global_attr_np, axis=1)
    # A zero vector has a zero dot product with everything; dividing that by 1
    # instead of 0 yields a similarity of 0.0 rather than NaN.
    safe_norms = np.where(row_norms == 0, 1.0, row_norms)

    cosine_sim_matrix = (global_attr_np @ global_attr_np.T) / np.outer(safe_norms, safe_norms)
    print("Attribution Cosine Similarity Matrix:")
    print(cosine_sim_matrix)

    # Calculate the mean of the upper triangle of the cosine similarity matrix
    # (k=1 skips the diagonal, so every unordered pair is counted exactly once).
    upper_triangle_indices = np.triu_indices_from(cosine_sim_matrix, k=1)
    upper_triangle_values = cosine_sim_matrix[upper_triangle_indices]
    result_xai = upper_triangle_values.mean()
    print("XAI Consistency (mean of upper triangle of cosine similarity matrix):")
    print(result_xai)
    return result_xai
252+
253+
254+
def calculate_xai_consistency_euclidean(attributions):
    """
    Calculates the consistency of XAI methods by computing the mean of the upper triangle
    of the Euclidean distance matrix of the provided attributions.

    Distances are computed from the element-wise differences of each pair of
    rows, so identical attribution vectors yield a distance of exactly 0.

    Args:
        attributions (np.ndarray): Array of shape (n_methods, n_features) containing
            the attributions from different XAI methods.

    Returns:
        float: Mean value of the upper triangle of the Euclidean distance matrix.
            With fewer than two methods there are no pairs, and the mean of the
            empty upper triangle is NaN.

    Raises:
        ValueError: If `attributions` is not a non-empty 2D array or contains
            NaN or infinite values.
    """
    global_attr_np = np.asarray(attributions, dtype=float)
    if global_attr_np.ndim != 2 or global_attr_np.size == 0:
        raise ValueError(f"Expected a non-empty 2D array of shape (n_methods, n_features), " f"got shape {global_attr_np.shape}.")
    if not np.isfinite(global_attr_np).all():
        raise ValueError("attributions must not contain NaN or infinite values.")

    # Subtract first, then take the norm. The expanded form
    # ||x||^2 - 2 x.y + ||y||^2 loses precision through cancellation when two
    # rows are (nearly) equal, which is exactly the high-consistency case.
    pairwise_diff = global_attr_np[:, np.newaxis, :] - global_attr_np[np.newaxis, :, :]
    euclidean_dist_matrix = np.linalg.norm(pairwise_diff, axis=-1)
    print("Attribution Euclidean Distance Matrix:")
    print(euclidean_dist_matrix)

    # Calculate the mean of the upper triangle of the Euclidean distance matrix
    # (k=1 skips the diagonal, so every unordered pair is counted exactly once).
    upper_triangle_indices = np.triu_indices_from(euclidean_dist_matrix, k=1)
    upper_triangle_values = euclidean_dist_matrix[upper_triangle_indices]
    result_xai = upper_triangle_values.mean()
    print("XAI Consistency (mean of upper triangle of Euclidean distance matrix):")
    print(result_xai)
    return result_xai

test/test_xai_consistency.py

Lines changed: 90 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,117 @@
11
import numpy as np
2-
from spotpython.utils.metrics import calculate_xai_consistency
2+
from spotpython.utils.metrics import calculate_xai_consistency_corr, calculate_xai_consistency_cosine, calculate_xai_consistency_euclidean
33

44

5-
def test_xai_consistency():
5+
def test_xai_consistency_corr():
    """Two identical L2-normalized attribution vectors must have a correlation consistency of 1."""
    raw_vectors = ([1, 2, 3, 4, 5], [1, 2, 3, 4, 5])

    scaled_vectors = []
    for raw in raw_vectors:
        norm = np.linalg.norm(raw)
        # An all-zero vector is left unscaled to avoid dividing by zero.
        if norm != 0:
            scaled_vectors.append(raw / norm)
        else:
            scaled_vectors.append(raw)

    consistency = calculate_xai_consistency_corr(scaled_vectors)
    print("XAI Consistency Result:")
    print(consistency)

    # Identical vectors are perfectly correlated: expect 1.
    assert abs(consistency - 1) < 1e-10
2723

2824

29-
def test_xai_consistency_negative():
25+
def test_xai_consistency_negative_corr():
    """An increasing and a decreasing attribution vector must have a correlation consistency of -1."""
    raw_vectors = ([1, 2, 3, 4, 5], [-2, -3, -4, -5, -6])

    scaled_vectors = []
    for raw in raw_vectors:
        norm = np.linalg.norm(raw)
        # An all-zero vector is left unscaled to avoid dividing by zero.
        if norm != 0:
            scaled_vectors.append(raw / norm)
        else:
            scaled_vectors.append(raw)

    consistency = calculate_xai_consistency_corr(scaled_vectors)
    print("XAI Consistency Result (Negative):")
    print(consistency)

    # The second vector falls linearly while the first rises: expect -1.
    assert abs(consistency + 1) < 1e-10
43+
44+
45+
def test_xai_consistency_cosine():
    """Two identical L2-normalized attribution vectors must have a cosine consistency of 1."""
    raw_vectors = ([1, 2, 3, 4, 5], [1, 2, 3, 4, 5])

    scaled_vectors = []
    for raw in raw_vectors:
        norm = np.linalg.norm(raw)
        # An all-zero vector is left unscaled to avoid dividing by zero.
        if norm != 0:
            scaled_vectors.append(raw / norm)
        else:
            scaled_vectors.append(raw)

    consistency = calculate_xai_consistency_cosine(scaled_vectors)
    print("XAI Consistency Cosine Result:")
    print(consistency)

    # Vectors pointing in the same direction: expect 1.
    assert abs(consistency - 1) < 1e-10
61+
62+
63+
def test_xai_consistency_negative_cosine():
    """Two opposite L2-normalized attribution vectors must have a cosine consistency of -1."""
    raw_vectors = ([1, 2, 3, 4, 5], [-1, -2, -3, -4, -5])

    scaled_vectors = []
    for raw in raw_vectors:
        norm = np.linalg.norm(raw)
        # An all-zero vector is left unscaled to avoid dividing by zero.
        if norm != 0:
            scaled_vectors.append(raw / norm)
        else:
            scaled_vectors.append(raw)

    consistency = calculate_xai_consistency_cosine(scaled_vectors)
    print("XAI Consistency Cosine Result (Negative):")
    print(consistency)

    # Vectors pointing in opposite directions: expect -1.
    assert abs(consistency + 1) < 1e-10
80+
81+
82+
def test_xai_consistency_euclidean():
    """Two identical L2-normalized attribution vectors must have a mean Euclidean distance of 0."""
    raw_vectors = ([1, 2, 3, 4, 5], [1, 2, 3, 4, 5])

    scaled_vectors = []
    for raw in raw_vectors:
        norm = np.linalg.norm(raw)
        # An all-zero vector is left unscaled to avoid dividing by zero.
        if norm != 0:
            scaled_vectors.append(raw / norm)
        else:
            scaled_vectors.append(raw)

    consistency = calculate_xai_consistency_euclidean(scaled_vectors)
    print("XAI Consistency Euclidean Result:")
    print(consistency)

    # Identical vectors are zero distance apart: expect a value close to 0.
    assert abs(consistency) < 1e-10
99+
100+
101+
def test_xai_consistency_negative_euclidean():
    """Two opposite L2-normalized attribution vectors must have a mean Euclidean distance of 2."""
    raw_vectors = ([1, 2, 3, 4, 5], [-1, -2, -3, -4, -5])

    scaled_vectors = []
    for raw in raw_vectors:
        norm = np.linalg.norm(raw)
        # An all-zero vector is left unscaled to avoid dividing by zero.
        if norm != 0:
            scaled_vectors.append(raw / norm)
        else:
            scaled_vectors.append(raw)

    consistency = calculate_xai_consistency_euclidean(scaled_vectors)
    print("XAI Consistency Euclidean Result (Negative):")
    print(consistency)

    # Opposite unit vectors are two unit lengths apart: expect a value close to 2.
    assert abs(consistency - 2) < 1e-10

0 commit comments

Comments
 (0)