Skip to content

Commit 8a8fced

Browse files
committed
new xai consitency metric
1 parent 4c681a1 commit 8a8fced

2 files changed

Lines changed: 47 additions & 65 deletions

File tree

src/spotpython/light/trainmodel.py

Lines changed: 20 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import lightning as L
22
from spotpython.data.lightdatamodule import LightDataModule, PadSequenceManyToMany
33
from spotpython.utils.eda import generate_config_id
4+
from spotpython.utils.metrics import calculate_xai_consistency
45
from pytorch_lightning.loggers import TensorBoardLogger
56
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
67
from lightning.pytorch.callbacks import ModelCheckpoint
@@ -694,87 +695,41 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
694695
attr_ig = IntegratedGradients(model)
695696
attribution_ig = attr_ig.attribute(X_val_tensor, baselines=baseline)
696697
ig_attr_test_sum = attribution_ig.detach().numpy().sum(0)
697-
ig_attr_test_norm_sum = ig_attr_test_sum / np.linalg.norm(ig_attr_test_sum, ord=1)
698-
attributions_dict["IntegratedGradients"] = ig_attr_test_norm_sum
698+
row_sum_ig = np.sum(ig_attr_test_sum, axis=0)
699+
if row_sum_ig == 0:
700+
row_sum_ig += 1e-10
701+
scaled_attribution_ig = ig_attr_test_sum / row_sum_ig
702+
attributions_dict["IntegratedGradients"] = scaled_attribution_ig
699703

700704
if "KernelShap" in fun_control["xai_methods"]:
701705
attr_ks = KernelShap(model)
702706
attribution_ks = attr_ks.attribute(X_val_tensor, baselines=baseline)
703707
ks_attr_test_sum = attribution_ks.detach().numpy().sum(0)
704-
ks_attr_test_norm_sum = ks_attr_test_sum / np.linalg.norm(ks_attr_test_sum, ord=1)
705-
attributions_dict["KernelShap"] = ks_attr_test_norm_sum
708+
row_sum_ks = np.sum(ks_attr_test_sum, axis=0)
709+
if row_sum_ks == 0:
710+
row_sum_ks += 1e-10
711+
scaled_attribution_ks = ks_attr_test_sum / row_sum_ks
712+
attributions_dict["KernelShap"] = scaled_attribution_ks
706713

707714
if "DeepLift" in fun_control["xai_methods"]:
708715
attr_dl = DeepLift(model)
709716
attribution_dl = attr_dl.attribute(X_val_tensor, baselines=baseline)
710717
dl_attr_test_sum = attribution_dl.detach().numpy().sum(0)
711-
dl_attr_test_norm_sum = dl_attr_test_sum / np.linalg.norm(dl_attr_test_sum, ord=1)
712-
attributions_dict["DeepLift"] = dl_attr_test_norm_sum
718+
row_sum_dl = np.sum(dl_attr_test_sum, axis=0)
719+
if row_sum_dl == 0:
720+
row_sum_dl += 1e-10
721+
scaled_attribution_dl = dl_attr_test_sum / row_sum_dl
722+
attributions_dict["DeepLift"] = scaled_attribution_dl
713723

714724
attributions_list = [attributions_dict[method] for method in fun_control["xai_methods"]]
715-
attributions = np.stack(attributions_list, axis=1)
725+
attributions = np.stack(attributions_list, axis=0)
716726

717-
if fun_control["xai_metric"] not in {"max_diff", "variance", "spearman", "spearman+variance"}:
718-
print("Invalid or missing xai_metric. Setting it to 'max_diff'.")
719-
fun_control["xai_metric"] = "max_diff"
727+
result_xai = calculate_xai_consistency(attributions)
720728

721-
if fun_control["xai_metric"] == "max_diff":
722-
# Compute the max difference of the attribution methods for each feature
723-
result_xai = np.max(attributions, axis=1) - np.min(attributions, axis=1)
724-
print("Maximum differences of feature attribution methods:", result_xai)
725-
result_xai = result_xai.sum()
726729

727-
if fun_control["xai_metric"] == "variance":
728-
result_xai = np.var(attributions, axis=1)
729-
print("Variance of feature attribution methods:", result_xai)
730-
result_xai = result_xai.sum()
731-
732-
if fun_control["xai_metric"] == "spearman":
733-
num_methods = attributions.shape[1]
734-
spearman_matrix = np.zeros((num_methods, num_methods)) # Store correlation values
735-
736-
for i in range(num_methods):
737-
for j in range(i + 1, num_methods): # Only compute upper triangle
738-
corr, _ = spearmanr(attributions[:, i], attributions[:, j]) # Compute Spearman correlation
739-
spearman_matrix[i, j] = corr
740-
spearman_matrix[j, i] = corr # Mirror value in symmetric matrix
741-
742-
# Extract upper triangular values (excluding diagonal)
743-
upper_triangle_values = spearman_matrix[np.triu_indices(num_methods, k=1)]
744-
745-
# Compute mean correlation as the consistency score
746-
# Negative sign to use the result as loss of the objective function for minimization
747-
result_xai = -np.mean(upper_triangle_values)
748-
749-
print("Spearman rank correlation matrix:\n", spearman_matrix)
750-
print("Consistency Score (Mean Spearman Correlation):", -result_xai)
751-
752-
if fun_control["xai_metric"] == "spearman+variance":
753-
# Compute Spearman mean
754-
num_methods = attributions.shape[1]
755-
spearman_matrix = np.zeros((num_methods, num_methods))
756-
757-
for i in range(num_methods):
758-
for j in range(i + 1, num_methods):
759-
corr, _ = spearmanr(attributions[:, i], attributions[:, j])
760-
spearman_matrix[i, j] = corr
761-
spearman_matrix[j, i] = corr
762-
763-
upper_triangle_values = spearman_matrix[np.triu_indices(num_methods, k=1)]
764-
mean_spearman = np.mean(upper_triangle_values)
765-
766-
# Compute attribution variance across methods for each feature
767-
variance = np.var(attributions, axis=1).mean() # mean over features
768-
769-
# Combine both (λ is a trade-off hyperparameter you define)
770-
lambda_variance = fun_control["lambda_variance"] if "lambda_variance" in fun_control else 1.0
771-
result_xai = -mean_spearman + lambda_variance * variance
772-
773-
print("Mean Spearman correlation:", mean_spearman)
774-
print("Mean Variance:", variance)
775-
print("Variance Weight:", lambda_variance)
776-
print("Combined XAI loss: ", result_xai)
777730

778731
# -------------------------------------------------------------------------------------------------------------------
779732

780733
return result["val_loss"], result_xai
734+
735+

src/spotpython/utils/metrics.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,3 +196,30 @@ def get_metric_sign(metric_name):
196196
return +1
197197
else:
198198
raise ValueError(f"Metric '{metric_name}' not found.")
199+
200+
201+
202+
def calculate_xai_consistency(attributions):
203+
"""
204+
Calculates the consistency of XAI methods by computing the mean of the upper triangle
205+
of the correlation matrix of the provided attributions.
206+
207+
Args:
208+
attributions (np.ndarray): Array of shape (n_methods, n_features) containing
209+
the attributions from different XAI methods.
210+
211+
Returns:
212+
float: Mean value of the upper triangle of the correlation matrix.
213+
"""
214+
global_attr_np = np.array(attributions)
215+
corr_matrix = np.corrcoef(global_attr_np)
216+
print("Attribution Correlation Matrix:")
217+
print(corr_matrix)
218+
219+
# Calculate the mean of the upper triangle of the correlation matrix
220+
upper_triangle_indices = np.triu_indices_from(corr_matrix, k=1)
221+
upper_triangle_values = corr_matrix[upper_triangle_indices]
222+
result_xai = upper_triangle_values.mean()
223+
print("XAI Consistency (mean of upper triangle of correlation matrix):")
224+
print(result_xai)
225+
return result_xai

0 commit comments

Comments
 (0)