@@ -653,7 +653,6 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
653653 # Perform feature attribution analysis
654654 model = trainer .model
655655 print ("MODEL :" , model )
656- model .eval ()
657656
658657 # Get the validation dataloader from the LightningDataModule
659658 val_dataloader : DataLoader = dm .val_dataloader () # Fetch validation data loader
@@ -668,9 +667,6 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
668667 X_val_list .append (X_batch )
669668 y_val_list .append (y_batch )
670669
671- # Concatenate all batches into single tensors
672- X_val_tensor = torch .cat (X_val_list , dim = 0 ).to (model .device )
673-
674670 # Perform feature attribution analysis
675671
676672 # Check if at least 2 elements are in list fun_control["xai_methods"]
@@ -683,6 +679,11 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
683679 if method not in valid_xai_methods :
684680 raise ValueError (f"Invalid XAI method: { method } . Valid methods are: { valid_xai_methods } " )
685681
682+ # Ensure the model is in evaluation mode
683+ model .eval ()
684+ X_val_tensor = torch .cat (X_val_list , dim = 0 ).to (model .device )
685+ X_val_tensor .requires_grad_ ()
686+
686687 # Dictionary to store attributions
687688 attributions_dict = {}
688689
@@ -692,29 +693,31 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
692693 print ("Baseline is None. Using training mean as baseline." )
693694 baseline = fun_control ["xai_baseline" ]
694695
695- if "IntegratedGradients" in fun_control ["xai_methods" ]:
696- attr_ig = IntegratedGradients (model )
697- attribution_ig = attr_ig .attribute (X_val_tensor , baselines = baseline )
698- ig_attr_test_sum = attribution_ig .detach ().numpy ().sum (axis = 0 )
699- l2_norm = np .linalg .norm (ig_attr_test_sum )
700- l2_normalized_ig = ig_attr_test_sum / l2_norm if l2_norm != 0 else ig_attr_test_sum
701- attributions_dict ["IntegratedGradients" ] = l2_normalized_ig
702-
703- if "KernelShap" in fun_control ["xai_methods" ]:
704- attr_ks = KernelShap (model )
705- attribution_ks = attr_ks .attribute (X_val_tensor , baselines = baseline )
706- ks_attr_test_sum = attribution_ks .detach ().numpy ().sum (axis = 0 )
707- l2_norm = np .linalg .norm (ks_attr_test_sum )
708- l2_normalized_ks = ks_attr_test_sum / l2_norm if l2_norm != 0 else ks_attr_test_sum
709- attributions_dict ["KernelShap" ] = l2_normalized_ks
710-
711- if "DeepLift" in fun_control ["xai_methods" ]:
712- attr_dl = DeepLift (model )
713- attribution_dl = attr_dl .attribute (X_val_tensor , baselines = baseline )
714- dl_attr_test_sum = attribution_dl .detach ().numpy ().sum (axis = 0 )
715- l2_norm = np .linalg .norm (dl_attr_test_sum )
716- l2_normalized_dl = dl_attr_test_sum / l2_norm if l2_norm != 0 else dl_attr_test_sum
717- attributions_dict ["DeepLift" ] = l2_normalized_dl
696+ with torch .enable_grad ():
697+ if "IntegratedGradients" in fun_control ["xai_methods" ]:
698+ attr_ig = IntegratedGradients (model )
699+ attribution_ig = attr_ig .attribute (X_val_tensor , baselines = baseline , n_steps = 100 , internal_batch_size = 64 )
700+ vec = attribution_ig .detach ().cpu ().numpy ().sum (axis = 0 )
701+ l2 = np .linalg .norm (vec )
702+ attributions_dict ["IntegratedGradients" ] = vec / l2 if l2 != 0 else vec
703+
704+ if "KernelShap" in fun_control ["xai_methods" ]:
705+ attr_ks = KernelShap (model )
706+ samples_ks = 100 * X_val_tensor .shape [1 ]
707+ print ("KernelShap: Using" , samples_ks , "samples for attribution." )
708+ attribution_ks = attr_ks .attribute (X_val_tensor , baselines = baseline , n_samples = samples_ks , perturbations_per_eval = 64 )
709+ ks_attr_test_sum = attribution_ks .detach ().numpy ().sum (axis = 0 )
710+ l2_norm = np .linalg .norm (ks_attr_test_sum )
711+ l2_normalized_ks = ks_attr_test_sum / l2_norm if l2_norm != 0 else ks_attr_test_sum
712+ attributions_dict ["KernelShap" ] = l2_normalized_ks
713+
714+ if "DeepLift" in fun_control ["xai_methods" ]:
715+ attr_dl = DeepLift (model )
716+ attribution_dl = attr_dl .attribute (X_val_tensor , baselines = baseline )
717+ dl_attr_test_sum = attribution_dl .detach ().numpy ().sum (axis = 0 )
718+ l2_norm = np .linalg .norm (dl_attr_test_sum )
719+ l2_normalized_dl = dl_attr_test_sum / l2_norm if l2_norm != 0 else dl_attr_test_sum
720+ attributions_dict ["DeepLift" ] = l2_normalized_dl
718721
719722 attributions_list = [attributions_dict [method ] for method in fun_control ["xai_methods" ]]
720723 attributions = np .stack (attributions_list , axis = 0 )
0 commit comments