@@ -654,13 +654,24 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
654654 model = trainer .model
655655 print ("MODEL :" , model )
656656
657- # Get the validation dataloader from the LightningDataModule
657+ # Get the validation and train dataloader from the LightningDataModule
658+ train_dataloader : DataLoader = dm .train_dataloader () # Fetch train data loader
658659 val_dataloader : DataLoader = dm .val_dataloader () # Fetch validation data loader
659660
660661 # Collect all validation data
661662 X_val_list = []
662663 y_val_list = []
663664
665+ # Collect all train data
666+ X_train_list = []
667+ y_train_list = []
668+
669+ # Iterate over the train dataloader to gather all data
670+ for batch in train_dataloader :
671+ X_batch , y_batch = batch # Extract inputs and labels
672+ X_train_list .append (X_batch )
673+ y_train_list .append (y_batch )
674+
664675 # Iterate over the validation dataloader to gather all data
665676 for batch in val_dataloader :
666677 X_batch , y_batch = batch # Extract inputs and labels
@@ -679,38 +690,54 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
679690 if method not in valid_xai_methods :
680691 raise ValueError (f"Invalid XAI method: { method } . Valid methods are: { valid_xai_methods } " )
681692
682- # Ensure the model is in evaluation mode
683- model . eval ()
693+ X_train_tensor = torch . cat ( X_train_list , dim = 0 ). to ( model . device )
694+ X_train_tensor . requires_grad_ ()
684695 X_val_tensor = torch .cat (X_val_list , dim = 0 ).to (model .device )
685696 X_val_tensor .requires_grad_ ()
686697
687698 # Dictionary to store attributions
688699 attributions_dict = {}
689700
690701 if fun_control ["xai_baseline" ] is None :
691- X_train_mean = X_val_tensor .mean (dim = 0 )
702+ X_train_mean = X_train_tensor .mean (dim = 0 )
692703 fun_control ["xai_baseline" ] = X_train_mean .unsqueeze (0 )
693704 print ("Baseline is None. Using training mean as baseline." )
694705 baseline = fun_control ["xai_baseline" ]
695706
696- with torch .enable_grad ():
697- if "IntegratedGradients" in fun_control ["xai_methods" ]:
698- attr_ig = IntegratedGradients (model )
699- attribution_ig = attr_ig .attribute (X_val_tensor , baselines = baseline )
700- vec = attribution_ig .detach ().cpu ().numpy ().sum (axis = 0 )
701- l2 = np .linalg .norm (vec )
702- attributions_dict ["IntegratedGradients" ] = vec / l2 if l2 != 0 else vec
707+ # Use a subset of the validation data for attribution
708+ if fun_control ["xai_subset_size" ] is not None :
709+ N_total = X_val_tensor .size (0 )
710+ N_attr = min (fun_control ["xai_subset_size" ], N_total )
711+ print (f"Using a subset of { N_attr } samples for attribution analysis out of { N_total } total samples." )
712+ g = torch .Generator (device = X_val_tensor .device )
713+ g .manual_seed (fun_control ["seed" ])
714+ perm = torch .randperm (N_total , generator = g ,
715+ device = X_val_tensor .device )[:N_attr ]
716+ X_val_tensor = X_val_tensor [perm ]
703717
704- if "KernelShap" in fun_control ["xai_methods" ]:
718+ # Ensure the model is in evaluation mode
719+ model .eval ()
720+
721+ if "KernelShap" in fun_control ["xai_methods" ]:
705722 attr_ks = KernelShap (model )
706- samples_ks = 100 * X_val_tensor .shape [1 ]
723+ n_features = X_val_tensor .shape [1 ]
724+ samples_ks = min (2000 , 100 * n_features ) # Adjust number of samples based on features, maximum 2000
707725 print ("KernelShap: Using" , samples_ks , "samples for attribution." )
708- attribution_ks = attr_ks .attribute (X_val_tensor , baselines = baseline , n_samples = samples_ks , perturbations_per_eval = 64 )
726+ with torch .no_grad ():
727+ attribution_ks = attr_ks .attribute (X_val_tensor , baselines = baseline , n_samples = samples_ks , perturbations_per_eval = 64 )
709728 ks_attr_test_sum = attribution_ks .detach ().numpy ().sum (axis = 0 )
710729 l2_norm = np .linalg .norm (ks_attr_test_sum )
711730 l2_normalized_ks = ks_attr_test_sum / l2_norm if l2_norm != 0 else ks_attr_test_sum
712731 attributions_dict ["KernelShap" ] = l2_normalized_ks
713732
733+ with torch .enable_grad ():
734+ if "IntegratedGradients" in fun_control ["xai_methods" ]:
735+ attr_ig = IntegratedGradients (model )
736+ attribution_ig = attr_ig .attribute (X_val_tensor , baselines = baseline )
737+ vec = attribution_ig .detach ().cpu ().numpy ().sum (axis = 0 )
738+ l2 = np .linalg .norm (vec )
739+ attributions_dict ["IntegratedGradients" ] = vec / l2 if l2 != 0 else vec
740+
714741 if "DeepLift" in fun_control ["xai_methods" ]:
715742 attr_dl = DeepLift (model )
716743 attribution_dl = attr_dl .attribute (X_val_tensor , baselines = baseline )
0 commit comments