|
1 | 1 | import lightning as L |
2 | 2 | from spotpython.data.lightdatamodule import LightDataModule, PadSequenceManyToMany |
3 | 3 | from spotpython.utils.eda import generate_config_id |
4 | | -from spotpython.utils.metrics import calculate_xai_consistency_corr, calculate_xai_consistency_cosine, calculate_xai_consistency_euclidean |
| 4 | +from spotpython.utils.metrics import calculate_xai_consistency_corr, calculate_xai_consistency_cosine, calculate_xai_consistency_euclidean, calculate_xai_consistency_spearman |
5 | 5 | from pytorch_lightning.loggers import TensorBoardLogger |
6 | 6 | from lightning.pytorch.callbacks.early_stopping import EarlyStopping |
7 | 7 | from lightning.pytorch.callbacks import ModelCheckpoint |
@@ -723,13 +723,15 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) -> |
723 | 723 | attributions = np.stack(attributions_list, axis=0) |
724 | 724 |
|
725 | 725 | # Calculate corr: |
726 | | - if fun_control["xai_metric"] not in ["corr", "cosine", "euclidean"]: |
727 | | - raise ValueError(f"Invalid xai_metric: {fun_control['xai_metric']}. Valid metrics are: 'corr', 'cosine', 'euclidean'") |
| 726 | + if fun_control["xai_metric"] not in ["corr", "cosine", "euclidean", "spearman"]: |
| 727 | + raise ValueError(f"Invalid xai_metric: {fun_control['xai_metric']}. Valid metrics are: 'corr', 'cosine', 'euclidean', 'spearman'") |
728 | 728 | if fun_control["xai_metric"] == "corr": |
729 | 729 | result_xai = calculate_xai_consistency_corr(attributions) |
730 | 730 | elif fun_control["xai_metric"] == "cosine": |
731 | 731 | result_xai = calculate_xai_consistency_cosine(attributions) |
732 | 732 | elif fun_control["xai_metric"] == "euclidean": |
733 | 733 | result_xai = calculate_xai_consistency_euclidean(attributions) |
| 734 | + elif fun_control["xai_metric"] == "spearman": |
| 735 | + result_xai = calculate_xai_consistency_spearman(attributions) |
734 | 736 |
|
735 | 737 | return result["val_loss"], result_xai |
0 commit comments