@@ -714,7 +714,7 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
714714 attributions_list = [attributions_dict [method ] for method in fun_control ["xai_methods" ]]
715715 attributions = np .stack (attributions_list , axis = 1 )
716716
717- if fun_control ["xai_metric" ] not in {"max_diff" , "variance" , "spearman" }:
717+ if fun_control ["xai_metric" ] not in {"max_diff" , "variance" , "spearman" , "spearman+variance" }:
718718 print ("Invalid or missing xai_metric. Setting it to 'max_diff'." )
719719 fun_control ["xai_metric" ] = "max_diff"
720720
@@ -749,6 +749,32 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
749749 print ("Spearman rank correlation matrix:\n " , spearman_matrix )
750750 print ("Consistency Score (Mean Spearman Correlation):" , - result_xai )
751751
752+ if fun_control ["xai_metric" ] == "spearman+variance" :
753+ # Compute Spearman mean
754+ num_methods = attributions .shape [1 ]
755+ spearman_matrix = np .zeros ((num_methods , num_methods ))
756+
757+ for i in range (num_methods ):
758+ for j in range (i + 1 , num_methods ):
759+ corr , _ = spearmanr (attributions [:, i ], attributions [:, j ])
760+ spearman_matrix [i , j ] = corr
761+ spearman_matrix [j , i ] = corr
762+
763+ upper_triangle_values = spearman_matrix [np .triu_indices (num_methods , k = 1 )]
764+ mean_spearman = np .mean (upper_triangle_values )
765+
766+ # Compute attribution variance across methods for each feature
767+ variance = np .var (attributions , axis = 1 ).mean () # mean over features
768+
769+ # Combine both (λ is a trade-off hyperparameter you define)
770+ lambda_variance = fun_control ["lambda_variance" ] if "lambda_variance" in fun_control else 1.0
771+ result_xai = - mean_spearman + lambda_variance * variance
772+
773+ print ("Mean Spearman correlation:" , mean_spearman )
774+ print ("Mean Variance:" , variance )
775+ print ("Variance Weight:" , lambda_variance )
776+ print ("Combined XAI loss: " , result_xai )
777+
752778 # -------------------------------------------------------------------------------------------------------------------
753779
754780 return result ["val_loss" ], result_xai
0 commit comments