@@ -718,34 +718,65 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
718718 model .eval ()
719719
720720 target = fun_control .get ("xai_target" , None )
721-
721+
722722 if "KernelShap" in fun_control ["xai_methods" ]:
723723 attr_ks = KernelShap (model )
724724 n_features = X_val_tensor .shape [1 ]
725- samples_ks = min (2000 , 100 * n_features ) # Adjust number of samples based on features, maximum 2000
725+ samples_ks = min (2000 , 100 * n_features )
726726 print ("KernelShap: Using" , samples_ks , "samples for attribution." )
727727 with torch .no_grad ():
728- attribution_ks = attr_ks .attribute (X_val_tensor , baselines = baseline , n_samples = samples_ks , perturbations_per_eval = 64 , target = target )
729- ks_attr_test_sum = attribution_ks .detach ().numpy ().sum (axis = 0 )
730- l2_norm = np .linalg .norm (ks_attr_test_sum )
731- l2_normalized_ks = ks_attr_test_sum / l2_norm if l2_norm != 0 else ks_attr_test_sum
728+ attribution_ks = attr_ks .attribute (
729+ X_val_tensor ,
730+ baselines = baseline ,
731+ n_samples = samples_ks ,
732+ perturbations_per_eval = 64 , ,
733+ target = target ,
734+ show_progress = False ,
735+ )
736+ ks_sum = attribution_ks .detach ().cpu ().numpy ().sum (axis = 0 )
737+ # Remove NaN and Inf values
738+ ks_sum = np .nan_to_num (ks_sum , nan = 0.0 , posinf = 0.0 , neginf = 0.0 )
739+ # Safe normalization (no division by ~0)
740+ l2_norm = np .linalg .norm (ks_sum )
741+ if not np .isfinite (l2_norm ) or l2_norm < 1e-12 :
742+ l2_normalized_ks = np .zeros_like (ks_sum )
743+ else :
744+ l2_normalized_ks = ks_sum / l2_norm
732745 attributions_dict ["KernelShap" ] = l2_normalized_ks
733746
734- with torch .enable_grad ():
735- if "IntegratedGradients" in fun_control ["xai_methods" ]:
736- attr_ig = IntegratedGradients (model )
737- attribution_ig = attr_ig .attribute (X_val_tensor , baselines = baseline , target = target )
738- vec = attribution_ig .detach ().cpu ().numpy ().sum (axis = 0 )
739- l2 = np .linalg .norm (vec )
740- attributions_dict ["IntegratedGradients" ] = vec / l2 if l2 != 0 else vec
741-
742- if "DeepLift" in fun_control ["xai_methods" ]:
743- attr_dl = DeepLift (model )
744- attribution_dl = attr_dl .attribute (X_val_tensor , baselines = baseline , target = target )
745- dl_attr_test_sum = attribution_dl .detach ().numpy ().sum (axis = 0 )
746- l2_norm = np .linalg .norm (dl_attr_test_sum )
747- l2_normalized_dl = dl_attr_test_sum / l2_norm if l2_norm != 0 else dl_attr_test_sum
748- attributions_dict ["DeepLift" ] = l2_normalized_dl
747+ if "IntegratedGradients" in fun_control ["xai_methods" ]:
748+ attr_ig = IntegratedGradients (model )
749+ # IG braucht Gradienten ⇒ enable_grad
750+ with torch .enable_grad ():
751+ attribution_ig = attr_ig .attribute (
752+ X_val_tensor ,
753+ baselines = baseline ,
754+ target = target ,
755+ )
756+ ig_sum = attribution_ig .detach ().cpu ().numpy ().sum (axis = 0 )
757+ ig_sum = np .nan_to_num (ig_sum , nan = 0.0 , posinf = 0.0 , neginf = 0.0 )
758+ l2 = np .linalg .norm (ig_sum )
759+ if not np .isfinite (l2 ) or l2 < 1e-12 :
760+ attributions_dict ["IntegratedGradients" ] = np .zeros_like (ig_sum )
761+ else :
762+ attributions_dict ["IntegratedGradients" ] = ig_sum / l2
763+
764+ if "DeepLift" in fun_control ["xai_methods" ]:
765+ attr_dl = DeepLift (model )
766+ with torch .enable_grad (): # DeepLIFT via Backprop
767+ attribution_dl = attr_dl .attribute (
768+ X_val_tensor ,
769+ baselines = baseline ,
770+ target = target ,
771+ )
772+ dl_sum = attribution_dl .detach ().cpu ().numpy ().sum (axis = 0 )
773+ dl_sum = np .nan_to_num (dl_sum , nan = 0.0 , posinf = 0.0 , neginf = 0.0 )
774+ l2_norm = np .linalg .norm (dl_sum )
775+ if not np .isfinite (l2_norm ) or l2_norm < 1e-12 :
776+ l2_normalized_dl = np .zeros_like (dl_sum )
777+ else :
778+ l2_normalized_dl = dl_sum / l2_norm
779+ attributions_dict ["DeepLift" ] = l2_normalized_dl
749780
750781 attributions_list = [attributions_dict [method ] for method in fun_control ["xai_methods" ]]
751782 attributions = np .stack (attributions_list , axis = 0 )
0 commit comments