@@ -715,12 +715,13 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
715715 model .eval ()
716716
717717 target = fun_control .get ("xai_target" , None )
718+ xai_mode = fun_control .get ("xai_mode" , None )
719+ metric = fun_control .get ("xai_metric" , "corr" )
718720
719721 if "KernelShap" in fun_control ["xai_methods" ]:
720722 attr_ks = KernelShap (model )
721723 n_features = X_val_tensor .shape [1 ]
722724 samples_ks = min (2000 , 100 * n_features )
723- print ("KernelShap: Using" , samples_ks , "samples for attribution." )
724725 with torch .no_grad ():
725726 attribution_ks = attr_ks .attribute (
726727 X_val_tensor ,
@@ -730,16 +731,27 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
730731 target = target ,
731732 show_progress = False ,
732733 )
733- ks_sum = attribution_ks .detach ().cpu ().numpy ().sum (axis = 0 )
734- # Remove NaN and Inf values
735- ks_sum = np .nan_to_num (ks_sum , nan = 0.0 , posinf = 0.0 , neginf = 0.0 )
736- # Safe normalization (no division by ~0)
737- l2_norm = np .linalg .norm (ks_sum )
738- if not np .isfinite (l2_norm ) or l2_norm < 1e-12 :
739- l2_normalized_ks = np .zeros_like (ks_sum )
740- else :
741- l2_normalized_ks = ks_sum / l2_norm
742- attributions_dict ["KernelShap" ] = l2_normalized_ks
734+ ks_tensor = attribution_ks .detach ().cpu ().numpy () # Shape: N_samples x N_features
735+ ks_tensor = np .nan_to_num (ks_tensor , nan = 0.0 , posinf = 0.0 , neginf = 0.0 )
736+
737+ if xai_mode == "local" :
738+ print ("local consensus calculation" )
739+ # normalize each sample individually
740+ norms = np .linalg .norm (ks_tensor , axis = 1 , keepdims = True )
741+ norms [norms < 1e-12 ] = 1.0 # avoid division by zero
742+ ks_normalized = ks_tensor / norms
743+ # attributions_dict now stores per-sample normalized attributions
744+ attributions_dict ["KernelShap" ] = ks_normalized
745+
746+ else : # global attributions
747+ print ("global consensus calculation" )
748+ ks_sum = ks_tensor .sum (axis = 0 )
749+ l2_norm = np .linalg .norm (ks_sum )
750+ if not np .isfinite (l2_norm ) or l2_norm < 1e-12 :
751+ ks_normalized = np .zeros_like (ks_sum )
752+ else :
753+ ks_normalized = ks_sum / l2_norm
754+ attributions_dict ["KernelShap" ] = ks_normalized
743755
744756 if "IntegratedGradients" in fun_control ["xai_methods" ]:
745757 attr_ig = IntegratedGradients (model )
@@ -750,13 +762,26 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
750762 baselines = baseline ,
751763 target = target ,
752764 )
753- ig_sum = attribution_ig .detach ().cpu ().numpy ().sum (axis = 0 )
754- ig_sum = np .nan_to_num (ig_sum , nan = 0.0 , posinf = 0.0 , neginf = 0.0 )
755- l2 = np .linalg .norm (ig_sum )
756- if not np .isfinite (l2 ) or l2 < 1e-12 :
757- attributions_dict ["IntegratedGradients" ] = np .zeros_like (ig_sum )
758- else :
759- attributions_dict ["IntegratedGradients" ] = ig_sum / l2
765+ ig_tensor = attribution_ig .detach ().cpu ().numpy () # Shape: N_samples x N_features
766+ ig_tensor = np .nan_to_num (ig_tensor , nan = 0.0 , posinf = 0.0 , neginf = 0.0 )
767+
768+ if xai_mode == "local" :
769+ # normalize each sample individually
770+ norms = np .linalg .norm (ig_tensor , axis = 1 , keepdims = True )
771+ norms [norms < 1e-12 ] = 1.0 # avoid division by zero
772+ ig_normalized = ig_tensor / norms
773+ # attributions_dict now stores per-sample normalized attributions
774+ attributions_dict ["IntegratedGradients" ] = ig_normalized
775+
776+ else : # global attributions
777+ # previous behaviour: sum attributions over all samples
778+ ig_sum = ig_tensor .sum (axis = 0 )
779+ l2_norm = np .linalg .norm (ig_sum )
780+ if not np .isfinite (l2_norm ) or l2_norm < 1e-12 :
781+ ig_normalized = np .zeros_like (ig_sum )
782+ else :
783+ ig_normalized = ig_sum / l2_norm
784+ attributions_dict ["IntegratedGradients" ] = ig_normalized
760785
761786 if "DeepLift" in fun_control ["xai_methods" ]:
762787 attr_dl = DeepLift (model )
@@ -766,28 +791,59 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
766791 baselines = baseline ,
767792 target = target ,
768793 )
769- dl_sum = attribution_dl .detach ().cpu ().numpy ().sum (axis = 0 )
770- dl_sum = np .nan_to_num (dl_sum , nan = 0.0 , posinf = 0.0 , neginf = 0.0 )
771- l2_norm = np .linalg .norm (dl_sum )
772- if not np .isfinite (l2_norm ) or l2_norm < 1e-12 :
773- l2_normalized_dl = np .zeros_like (dl_sum )
774- else :
775- l2_normalized_dl = dl_sum / l2_norm
776- attributions_dict ["DeepLift" ] = l2_normalized_dl
777-
778- attributions_list = [attributions_dict [method ] for method in fun_control ["xai_methods" ]]
779- attributions = np .stack (attributions_list , axis = 0 )
780-
781- # Calculate corr:
782- if fun_control ["xai_metric" ] not in ["corr" , "cosine" , "euclidean" , "spearman" ]:
783- raise ValueError (f"Invalid xai_metric: { fun_control ['xai_metric' ]} . Valid metrics are: 'corr', 'cosine', 'euclidean', 'spearman'" )
784- if fun_control ["xai_metric" ] == "corr" :
785- result_xai = calculate_xai_consistency_corr (attributions )
786- elif fun_control ["xai_metric" ] == "cosine" :
787- result_xai = calculate_xai_consistency_cosine (attributions )
788- elif fun_control ["xai_metric" ] == "euclidean" :
789- result_xai = calculate_xai_consistency_euclidean (attributions )
790- elif fun_control ["xai_metric" ] == "spearman" :
791- result_xai = calculate_xai_consistency_spearman (attributions )
792-
793- return result ["val_loss" ], result_xai
794+ dl_tensor = attribution_dl .detach ().cpu ().numpy () # Shape: N_samples x N_features
795+ dl_tensor = np .nan_to_num (dl_tensor , nan = 0.0 , posinf = 0.0 , neginf = 0.0 )
796+
797+ if xai_mode == "local" :
798+ # normalize each sample individually
799+ norms = np .linalg .norm (dl_tensor , axis = 1 , keepdims = True )
800+ norms [norms < 1e-12 ] = 1.0 # avoid division by zero
801+ dl_normalized = dl_tensor / norms
802+ # attributions_dict now stores per-sample normalized attributions
803+ attributions_dict ["DeepLift" ] = dl_normalized
804+
805+ else : # global attributions
806+ # previous behaviour: sum attributions over all samples
807+ dl_sum = dl_tensor .sum (axis = 0 )
808+ l2_norm = np .linalg .norm (dl_sum )
809+ if not np .isfinite (l2_norm ) or l2_norm < 1e-12 :
810+ dl_normalized = np .zeros_like (dl_sum )
811+ else :
812+ dl_normalized = dl_sum / l2_norm
813+ attributions_dict ["DeepLift" ] = dl_normalized
814+
815+
816+ if xai_mode == "local" :
817+ # consensus per sample
818+ N_samples = attributions_dict [fun_control ["xai_methods" ][0 ]].shape [0 ]
819+ per_sample_consensus = []
820+ for i in range (N_samples ):
821+ sample_attrs = np .stack ([attributions_dict [m ][i ] for m in fun_control ["xai_methods" ]], axis = 0 )
822+ # sample_attrs: n_methods x n_features
823+ if metric == "corr" :
824+ sample_consensus = calculate_xai_consistency_corr (sample_attrs )
825+ elif metric == "cosine" :
826+ sample_consensus = calculate_xai_consistency_cosine (sample_attrs )
827+ elif metric == "euclidean" :
828+ sample_consensus = calculate_xai_consistency_euclidean (sample_attrs )
829+ elif metric == "spearman" :
830+ sample_consensus = calculate_xai_consistency_spearman (sample_attrs )
831+ per_sample_consensus .append (sample_consensus )
832+ result_xai = np .mean (per_sample_consensus )
833+ print ("aggregated local consensus:" , result_xai )
834+
835+ else :
836+ # global mode
837+ attributions_list = [attributions_dict [m ] for m in fun_control ["xai_methods" ]]
838+ attributions = np .stack (attributions_list , axis = 0 ) # n_methods x n_features
839+ if metric == "corr" :
840+ result_xai = calculate_xai_consistency_corr (attributions )
841+ elif metric == "cosine" :
842+ result_xai = calculate_xai_consistency_cosine (attributions )
843+ elif metric == "euclidean" :
844+ result_xai = calculate_xai_consistency_euclidean (attributions )
845+ elif metric == "spearman" :
846+ result_xai = calculate_xai_consistency_spearman (attributions )
847+ print ("global consensus:" , result_xai )
848+
849+ return result ["val_loss" ], result_xai
0 commit comments