Skip to content

Commit 75d87a9

Browse files
committed
implementation of spearman correlation
1 parent a2b759e commit 75d87a9

2 files changed

Lines changed: 36 additions & 3 deletions

File tree

src/spotpython/light/trainmodel.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import lightning as L
22
from spotpython.data.lightdatamodule import LightDataModule, PadSequenceManyToMany
33
from spotpython.utils.eda import generate_config_id
4-
from spotpython.utils.metrics import calculate_xai_consistency_corr, calculate_xai_consistency_cosine, calculate_xai_consistency_euclidean
4+
from spotpython.utils.metrics import calculate_xai_consistency_corr, calculate_xai_consistency_cosine, calculate_xai_consistency_euclidean, calculate_xai_consistency_spearman
55
from pytorch_lightning.loggers import TensorBoardLogger
66
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
77
from lightning.pytorch.callbacks import ModelCheckpoint
@@ -723,13 +723,15 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
723723
attributions = np.stack(attributions_list, axis=0)
724724

725725
# Calculate corr:
726-
if fun_control["xai_metric"] not in ["corr", "cosine", "euclidean"]:
727-
raise ValueError(f"Invalid xai_metric: {fun_control['xai_metric']}. Valid metrics are: 'corr', 'cosine', 'euclidean'")
726+
if fun_control["xai_metric"] not in ["corr", "cosine", "euclidean", "spearman"]:
727+
raise ValueError(f"Invalid xai_metric: {fun_control['xai_metric']}. Valid metrics are: 'corr', 'cosine', 'euclidean', 'spearman'")
728728
if fun_control["xai_metric"] == "corr":
729729
result_xai = calculate_xai_consistency_corr(attributions)
730730
elif fun_control["xai_metric"] == "cosine":
731731
result_xai = calculate_xai_consistency_cosine(attributions)
732732
elif fun_control["xai_metric"] == "euclidean":
733733
result_xai = calculate_xai_consistency_euclidean(attributions)
734+
elif fun_control["xai_metric"] == "spearman":
735+
result_xai = calculate_xai_consistency_spearman(attributions)
734736

735737
return result["val_loss"], result_xai

src/spotpython/utils/metrics.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
import numpy as np
2828
from spotpython.utils.convert import series_to_array
2929
from sklearn.metrics.pairwise import euclidean_distances
30+
from scipy.stats import spearmanr
31+
3032

3133

3234
def apk(actual, predicted, k=10):
@@ -275,3 +277,32 @@ def calculate_xai_consistency_euclidean(attributions):
275277
print("XAI Consistency (mean of upper triangle of Euclidean distance matrix):")
276278
print(result_xai)
277279
return result_xai
280+
281+
def calculate_xai_consistency_spearman(attributions):
282+
"""
283+
Calculates the consistency of XAI methods using Spearman rank correlation.
284+
285+
Args:
286+
attributions (np.ndarray): shape (n_methods, n_features)
287+
288+
Returns:
289+
float: Mean of upper triangle of Spearman correlation matrix (excluding diagonal)
290+
"""
291+
attributions = np.array(attributions)
292+
n_methods = attributions.shape[0]
293+
294+
spearman_corr_matrix = np.zeros((n_methods, n_methods))
295+
for i in range(n_methods):
296+
for j in range(n_methods):
297+
corr = spearmanr(attributions[i], attributions[j]).correlation
298+
if np.isnan(corr):
299+
corr = 0.0
300+
spearman_corr_matrix[i, j] = corr
301+
302+
upper_triangle_values = spearman_corr_matrix[np.triu_indices(n_methods, k=1)]
303+
print("Attribution Spearman Correlation Matrix:")
304+
print(spearman_corr_matrix)
305+
306+
print("XAI Consistency (mean of upper triangle of Spearman correlation matrix):")
307+
print(upper_triangle_values.mean())
308+
return upper_triangle_values.mean()

0 commit comments

Comments
 (0)