|
8 | 8 | from captum.attr import IntegratedGradients, DeepLift, KernelShap |
9 | 9 | import torch |
10 | 10 | import os |
| 11 | +from scipy.stats import spearmanr |
| 12 | + |
11 | 13 |
|
12 | 14 | import numpy as np |
13 | 15 |
|
@@ -693,11 +695,41 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) -> |
693 | 695 | attributions_list = [attributions_dict[method] for method in fun_control["xai_methods"]] |
694 | 696 | attributions = np.stack(attributions_list, axis=1) |
695 | 697 |
|
696 | | - # Compute the max difference of the attribution methods for each feature |
697 | | - max_diffs = np.max(attributions, axis=1) - np.min(attributions, axis=1) |
| 698 | + if fun_control["xai_metric"] not in {"max_diff", "variance", "spearman"}: |
| 699 | + print("Invalid or missing xai_metric. Setting it to 'max_diff'.") |
| 700 | + fun_control["xai_metric"] = "max_diff" |
| 701 | + |
| 702 | + if fun_control["xai_metric"] == "max_diff": |
| 703 | + # Compute the max difference of the attribution methods for each feature |
| 704 | + result_xai = np.max(attributions, axis=1) - np.min(attributions, axis=1) |
| 705 | + print("Maximum differences of feature attribution methods:", result_xai) |
| 706 | + result_xai = result_xai.sum() |
| 707 | + |
| 708 | + if fun_control["xai_metric"] == "variance": |
| 709 | + result_xai = np.var(attributions, axis=1) |
| 710 | + print("Variance of feature attribution methods:", result_xai) |
| 711 | + result_xai = result_xai.sum() |
| 712 | + |
| 713 | + if fun_control["xai_metric"] == "spearman": |
| 714 | + num_methods = attributions.shape[1] |
| 715 | + spearman_matrix = np.zeros((num_methods, num_methods)) # Store correlation values |
| 716 | + |
| 717 | + for i in range(num_methods): |
| 718 | + for j in range(i + 1, num_methods): # Only compute upper triangle |
| 719 | + corr, _ = spearmanr(attributions[:, i], attributions[:, j]) # Compute Spearman correlation |
| 720 | + spearman_matrix[i, j] = corr |
| 721 | + spearman_matrix[j, i] = corr # Mirror value in symmetric matrix |
| 722 | + |
| 723 | + # Extract upper triangular values (excluding diagonal) |
| 724 | + upper_triangle_values = spearman_matrix[np.triu_indices(num_methods, k=1)] |
| 725 | + |
| 726 | + # Compute mean correlation as the consistency score |
| 727 | + # Negative sign to use the result as loss of the objective function for minimization |
| 728 | + result_xai = -np.mean(upper_triangle_values) |
698 | 729 |
|
699 | | - print("MAX DIFFS:", max_diffs) |
| 730 | + print("Spearman rank correlation matrix:\n", spearman_matrix) |
| 731 | + print("Consistency Score (Mean Spearman Correlation):", -result_xai) |
700 | 732 |
|
701 | 733 | # ------------------------------------------------------------------------------------------------------------------- |
702 | 734 |
|
703 | | - return result["val_loss"], max_diffs.sum() |
| 735 | + return result["val_loss"], result_xai |
0 commit comments