|
1 | 1 | import lightning as L |
2 | 2 | from spotpython.data.lightdatamodule import LightDataModule, PadSequenceManyToMany |
3 | 3 | from spotpython.utils.eda import generate_config_id |
| 4 | +from spotpython.utils.metrics import calculate_xai_consistency |
4 | 5 | from pytorch_lightning.loggers import TensorBoardLogger |
5 | 6 | from lightning.pytorch.callbacks.early_stopping import EarlyStopping |
6 | 7 | from lightning.pytorch.callbacks import ModelCheckpoint |
@@ -694,87 +695,41 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) -> |
694 | 695 | attr_ig = IntegratedGradients(model) |
695 | 696 | attribution_ig = attr_ig.attribute(X_val_tensor, baselines=baseline) |
696 | 697 | ig_attr_test_sum = attribution_ig.detach().numpy().sum(0) |
697 | | - ig_attr_test_norm_sum = ig_attr_test_sum / np.linalg.norm(ig_attr_test_sum, ord=1) |
698 | | - attributions_dict["IntegratedGradients"] = ig_attr_test_norm_sum |
| 698 | + row_sum_ig = np.sum(ig_attr_test_sum, axis=0) |
| 699 | + if row_sum_ig == 0: |
| 700 | + row_sum_ig += 1e-10 |
| 701 | + scaled_attribution_ig = ig_attr_test_sum / row_sum_ig |
| 702 | + attributions_dict["IntegratedGradients"] = scaled_attribution_ig |
699 | 703 |
|
700 | 704 | if "KernelShap" in fun_control["xai_methods"]: |
701 | 705 | attr_ks = KernelShap(model) |
702 | 706 | attribution_ks = attr_ks.attribute(X_val_tensor, baselines=baseline) |
703 | 707 | ks_attr_test_sum = attribution_ks.detach().numpy().sum(0) |
704 | | - ks_attr_test_norm_sum = ks_attr_test_sum / np.linalg.norm(ks_attr_test_sum, ord=1) |
705 | | - attributions_dict["KernelShap"] = ks_attr_test_norm_sum |
| 708 | + row_sum_ks = np.sum(ks_attr_test_sum, axis=0) |
| 709 | + if row_sum_ks == 0: |
| 710 | + row_sum_ks += 1e-10 |
| 711 | + scaled_attribution_ks = ks_attr_test_sum / row_sum_ks |
| 712 | + attributions_dict["KernelShap"] = scaled_attribution_ks |
706 | 713 |
|
707 | 714 | if "DeepLift" in fun_control["xai_methods"]: |
708 | 715 | attr_dl = DeepLift(model) |
709 | 716 | attribution_dl = attr_dl.attribute(X_val_tensor, baselines=baseline) |
710 | 717 | dl_attr_test_sum = attribution_dl.detach().numpy().sum(0) |
711 | | - dl_attr_test_norm_sum = dl_attr_test_sum / np.linalg.norm(dl_attr_test_sum, ord=1) |
712 | | - attributions_dict["DeepLift"] = dl_attr_test_norm_sum |
| 718 | + row_sum_dl = np.sum(dl_attr_test_sum, axis=0) |
| 719 | + if row_sum_dl == 0: |
| 720 | + row_sum_dl += 1e-10 |
| 721 | + scaled_attribution_dl = dl_attr_test_sum / row_sum_dl |
| 722 | + attributions_dict["DeepLift"] = scaled_attribution_dl |
713 | 723 |
|
714 | 724 | attributions_list = [attributions_dict[method] for method in fun_control["xai_methods"]] |
715 | | - attributions = np.stack(attributions_list, axis=1) |
| 725 | + attributions = np.stack(attributions_list, axis=0) |
716 | 726 |
|
717 | | - if fun_control["xai_metric"] not in {"max_diff", "variance", "spearman", "spearman+variance"}: |
718 | | - print("Invalid or missing xai_metric. Setting it to 'max_diff'.") |
719 | | - fun_control["xai_metric"] = "max_diff" |
| 727 | + result_xai = calculate_xai_consistency(attributions) |
720 | 728 |
|
721 | | - if fun_control["xai_metric"] == "max_diff": |
722 | | - # Compute the max difference of the attribution methods for each feature |
723 | | - result_xai = np.max(attributions, axis=1) - np.min(attributions, axis=1) |
724 | | - print("Maximum differences of feature attribution methods:", result_xai) |
725 | | - result_xai = result_xai.sum() |
726 | 729 |
|
727 | | - if fun_control["xai_metric"] == "variance": |
728 | | - result_xai = np.var(attributions, axis=1) |
729 | | - print("Variance of feature attribution methods:", result_xai) |
730 | | - result_xai = result_xai.sum() |
731 | | - |
732 | | - if fun_control["xai_metric"] == "spearman": |
733 | | - num_methods = attributions.shape[1] |
734 | | - spearman_matrix = np.zeros((num_methods, num_methods)) # Store correlation values |
735 | | - |
736 | | - for i in range(num_methods): |
737 | | - for j in range(i + 1, num_methods): # Only compute upper triangle |
738 | | - corr, _ = spearmanr(attributions[:, i], attributions[:, j]) # Compute Spearman correlation |
739 | | - spearman_matrix[i, j] = corr |
740 | | - spearman_matrix[j, i] = corr # Mirror value in symmetric matrix |
741 | | - |
742 | | - # Extract upper triangular values (excluding diagonal) |
743 | | - upper_triangle_values = spearman_matrix[np.triu_indices(num_methods, k=1)] |
744 | | - |
745 | | - # Compute mean correlation as the consistency score |
746 | | - # Negative sign to use the result as loss of the objective function for minimization |
747 | | - result_xai = -np.mean(upper_triangle_values) |
748 | | - |
749 | | - print("Spearman rank correlation matrix:\n", spearman_matrix) |
750 | | - print("Consistency Score (Mean Spearman Correlation):", -result_xai) |
751 | | - |
752 | | - if fun_control["xai_metric"] == "spearman+variance": |
753 | | - # Compute Spearman mean |
754 | | - num_methods = attributions.shape[1] |
755 | | - spearman_matrix = np.zeros((num_methods, num_methods)) |
756 | | - |
757 | | - for i in range(num_methods): |
758 | | - for j in range(i + 1, num_methods): |
759 | | - corr, _ = spearmanr(attributions[:, i], attributions[:, j]) |
760 | | - spearman_matrix[i, j] = corr |
761 | | - spearman_matrix[j, i] = corr |
762 | | - |
763 | | - upper_triangle_values = spearman_matrix[np.triu_indices(num_methods, k=1)] |
764 | | - mean_spearman = np.mean(upper_triangle_values) |
765 | | - |
766 | | - # Compute attribution variance across methods for each feature |
767 | | - variance = np.var(attributions, axis=1).mean() # mean over features |
768 | | - |
769 | | - # Combine both (λ is a trade-off hyperparameter you define) |
770 | | - lambda_variance = fun_control["lambda_variance"] if "lambda_variance" in fun_control else 1.0 |
771 | | - result_xai = -mean_spearman + lambda_variance * variance |
772 | | - |
773 | | - print("Mean Spearman correlation:", mean_spearman) |
774 | | - print("Mean Variance:", variance) |
775 | | - print("Variance Weight:", lambda_variance) |
776 | | - print("Combined XAI loss: ", result_xai) |
777 | 730 |
|
778 | 731 | # ------------------------------------------------------------------------------------------------------------------- |
779 | 732 |
|
780 | 733 | return result["val_loss"], result_xai |
| 734 | + |
| 735 | + |
0 commit comments