Skip to content

Commit 001860f

Browse files
committed
update xai metrics
1 parent 4094650 commit 001860f

2 files changed

Lines changed: 32 additions & 5 deletions

File tree

src/spotpython/fun/xai_hyperlight.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,6 @@ def fun(self, X: np.ndarray, fun_control: dict = None) -> np.ndarray:
169169
# Multiply results by the weights. Positive weights mean that the result is to be minimized.
170170
# Negative weights mean that the result is to be maximized, e.g., accuracy.
171171
z_val = fun_control["weights"] * df_eval
172-
print("Attribution : ", xai_attr)
173172
xai_incons = fun_control["xai_weight"] * xai_attr
174173

175174
# Append, since several configurations can be evaluated at once.

src/spotpython/light/trainmodel.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from captum.attr import IntegratedGradients, DeepLift, KernelShap
99
import torch
1010
import os
11+
from scipy.stats import spearmanr
12+
1113

1214
import numpy as np
1315

@@ -693,11 +695,37 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
693695
attributions_list = [attributions_dict[method] for method in fun_control["xai_methods"]]
694696
attributions = np.stack(attributions_list, axis=1)
695697

696-
# Compute the max difference of the attribution methods for each feature
697-
max_diffs = np.max(attributions, axis=1) - np.min(attributions, axis=1)
698+
if fun_control["xai_metric"] == "max_diff":
699+
# Compute the max difference of the attribution methods for each feature
700+
result_xai = np.max(attributions, axis=1) - np.min(attributions, axis=1)
701+
print("Maximum differences of feature attribution methods:", result_xai)
702+
result_xai = result_xai.sum()
703+
704+
if fun_control["xai_metric"] == "variance":
705+
result_xai = np.var(attributions, axis=1)
706+
print("Variance of feature attribution methods:", result_xai)
707+
result_xai = result_xai.sum()
708+
709+
if fun_control["xai_metric"] == "spearman":
710+
num_methods = attributions.shape[1]
711+
spearman_matrix = np.zeros((num_methods, num_methods)) # Store correlation values
712+
713+
for i in range(num_methods):
714+
for j in range(i + 1, num_methods): # Only compute upper triangle
715+
corr, _ = spearmanr(attributions[:, i], attributions[:, j]) # Compute Spearman correlation
716+
spearman_matrix[i, j] = corr
717+
spearman_matrix[j, i] = corr # Mirror value in symmetric matrix
718+
719+
# Extract upper triangular values (excluding diagonal)
720+
upper_triangle_values = spearman_matrix[np.triu_indices(num_methods, k=1)]
721+
722+
# Compute mean correlation as the consistency score
723+
# Negative sign to use the result as loss of the objective function for minimization
724+
result_xai = -np.mean(upper_triangle_values)
698725

699-
print("MAX DIFFS:", max_diffs)
726+
print("Spearman rank correlation matrix:\n", spearman_matrix)
727+
print("Consistency Score (Mean Spearman Correlation):", -result_xai)
700728

701729
# -------------------------------------------------------------------------------------------------------------------
702730

703-
return result["val_loss"], max_diffs.sum()
731+
return result["val_loss"], result_xai

0 commit comments

Comments
 (0)