Skip to content

Commit 1d9d3b6

Browse files
committed
update xai parameters
1 parent b44525f commit 1d9d3b6

1 file changed

Lines changed: 41 additions & 14 deletions

File tree

src/spotpython/light/trainmodel.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -654,13 +654,24 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
654654
model = trainer.model
655655
print("MODEL :", model)
656656

657-
# Get the validation dataloader from the LightningDataModule
657+
# Get the validation and train dataloader from the LightningDataModule
658+
train_dataloader: DataLoader = dm.train_dataloader() # Fetch train data loader
658659
val_dataloader: DataLoader = dm.val_dataloader() # Fetch validation data loader
659660

660661
# Collect all validation data
661662
X_val_list = []
662663
y_val_list = []
663664

665+
# Collect all train data
666+
X_train_list = []
667+
y_train_list = []
668+
669+
# Iterate over the train dataloader to gather all data
670+
for batch in train_dataloader:
671+
X_batch, y_batch = batch # Extract inputs and labels
672+
X_train_list.append(X_batch)
673+
y_train_list.append(y_batch)
674+
664675
# Iterate over the validation dataloader to gather all data
665676
for batch in val_dataloader:
666677
X_batch, y_batch = batch # Extract inputs and labels
@@ -679,38 +690,54 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
679690
if method not in valid_xai_methods:
680691
raise ValueError(f"Invalid XAI method: {method}. Valid methods are: {valid_xai_methods}")
681692

682-
# Ensure the model is in evaluation mode
683-
model.eval()
693+
X_train_tensor = torch.cat(X_train_list, dim=0).to(model.device)
694+
X_train_tensor.requires_grad_()
684695
X_val_tensor = torch.cat(X_val_list, dim=0).to(model.device)
685696
X_val_tensor.requires_grad_()
686697

687698
# Dictionary to store attributions
688699
attributions_dict = {}
689700

690701
if fun_control["xai_baseline"] is None:
691-
X_train_mean = X_val_tensor.mean(dim=0)
702+
X_train_mean = X_train_tensor.mean(dim=0)
692703
fun_control["xai_baseline"] = X_train_mean.unsqueeze(0)
693704
print("Baseline is None. Using training mean as baseline.")
694705
baseline = fun_control["xai_baseline"]
695706

696-
with torch.enable_grad():
697-
if "IntegratedGradients" in fun_control["xai_methods"]:
698-
attr_ig = IntegratedGradients(model)
699-
attribution_ig = attr_ig.attribute(X_val_tensor, baselines=baseline)
700-
vec = attribution_ig.detach().cpu().numpy().sum(axis=0)
701-
l2 = np.linalg.norm(vec)
702-
attributions_dict["IntegratedGradients"] = vec / l2 if l2 != 0 else vec
707+
# Use a subset of the validation data for attribution
708+
if fun_control["xai_subset_size"] is not None:
709+
N_total = X_val_tensor.size(0)
710+
N_attr = min(fun_control["xai_subset_size"], N_total)
711+
print(f"Using a subset of {N_attr} samples for attribution analysis out of {N_total} total samples.")
712+
g = torch.Generator(device=X_val_tensor.device)
713+
g.manual_seed(fun_control["seed"])
714+
perm = torch.randperm(N_total, generator=g,
715+
device=X_val_tensor.device)[:N_attr]
716+
X_val_tensor = X_val_tensor[perm]
703717

704-
if "KernelShap" in fun_control["xai_methods"]:
718+
# Ensure the model is in evaluation mode
719+
model.eval()
720+
721+
if "KernelShap" in fun_control["xai_methods"]:
705722
attr_ks = KernelShap(model)
706-
samples_ks = 100 * X_val_tensor.shape[1]
723+
n_features = X_val_tensor.shape[1]
724+
samples_ks = min(2000, 100 * n_features) # Adjust number of samples based on features, maximum 2000
707725
print("KernelShap: Using", samples_ks, "samples for attribution.")
708-
attribution_ks = attr_ks.attribute(X_val_tensor, baselines=baseline, n_samples=samples_ks, perturbations_per_eval=64)
726+
with torch.no_grad():
727+
attribution_ks = attr_ks.attribute(X_val_tensor, baselines=baseline, n_samples=samples_ks, perturbations_per_eval=64)
709728
ks_attr_test_sum = attribution_ks.detach().numpy().sum(axis=0)
710729
l2_norm = np.linalg.norm(ks_attr_test_sum)
711730
l2_normalized_ks = ks_attr_test_sum / l2_norm if l2_norm != 0 else ks_attr_test_sum
712731
attributions_dict["KernelShap"] = l2_normalized_ks
713732

733+
with torch.enable_grad():
734+
if "IntegratedGradients" in fun_control["xai_methods"]:
735+
attr_ig = IntegratedGradients(model)
736+
attribution_ig = attr_ig.attribute(X_val_tensor, baselines=baseline)
737+
vec = attribution_ig.detach().cpu().numpy().sum(axis=0)
738+
l2 = np.linalg.norm(vec)
739+
attributions_dict["IntegratedGradients"] = vec / l2 if l2 != 0 else vec
740+
714741
if "DeepLift" in fun_control["xai_methods"]:
715742
attr_dl = DeepLift(model)
716743
attribution_dl = attr_dl.attribute(X_val_tensor, baselines=baseline)

0 commit comments

Comments
 (0)