Skip to content

Commit a2b759e

Browse files
committed
parameter update for feature attribution methods
1 parent 78dc7f7 commit a2b759e

1 file changed

Lines changed: 30 additions & 27 deletions

File tree

src/spotpython/light/trainmodel.py

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -653,7 +653,6 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
653653
# Perform feature attribution analysis
654654
model = trainer.model
655655
print("MODEL :", model)
656-
model.eval()
657656

658657
# Get the validation dataloader from the LightningDataModule
659658
val_dataloader: DataLoader = dm.val_dataloader() # Fetch validation data loader
@@ -668,9 +667,6 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
668667
X_val_list.append(X_batch)
669668
y_val_list.append(y_batch)
670669

671-
# Concatenate all batches into single tensors
672-
X_val_tensor = torch.cat(X_val_list, dim=0).to(model.device)
673-
674670
# Perform feature attribution analysis
675671

676672
# Check if at least 2 elements are in list fun_control["xai_methods"]
@@ -683,6 +679,11 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
683679
if method not in valid_xai_methods:
684680
raise ValueError(f"Invalid XAI method: {method}. Valid methods are: {valid_xai_methods}")
685681

682+
# Ensure the model is in evaluation mode
683+
model.eval()
684+
X_val_tensor = torch.cat(X_val_list, dim=0).to(model.device)
685+
X_val_tensor.requires_grad_()
686+
686687
# Dictionary to store attributions
687688
attributions_dict = {}
688689

@@ -692,29 +693,31 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
692693
print("Baseline is None. Using training mean as baseline.")
693694
baseline = fun_control["xai_baseline"]
694695

695-
if "IntegratedGradients" in fun_control["xai_methods"]:
696-
attr_ig = IntegratedGradients(model)
697-
attribution_ig = attr_ig.attribute(X_val_tensor, baselines=baseline)
698-
ig_attr_test_sum = attribution_ig.detach().numpy().sum(axis=0)
699-
l2_norm = np.linalg.norm(ig_attr_test_sum)
700-
l2_normalized_ig = ig_attr_test_sum / l2_norm if l2_norm != 0 else ig_attr_test_sum
701-
attributions_dict["IntegratedGradients"] = l2_normalized_ig
702-
703-
if "KernelShap" in fun_control["xai_methods"]:
704-
attr_ks = KernelShap(model)
705-
attribution_ks = attr_ks.attribute(X_val_tensor, baselines=baseline)
706-
ks_attr_test_sum = attribution_ks.detach().numpy().sum(axis=0)
707-
l2_norm = np.linalg.norm(ks_attr_test_sum)
708-
l2_normalized_ks = ks_attr_test_sum / l2_norm if l2_norm != 0 else ks_attr_test_sum
709-
attributions_dict["KernelShap"] = l2_normalized_ks
710-
711-
if "DeepLift" in fun_control["xai_methods"]:
712-
attr_dl = DeepLift(model)
713-
attribution_dl = attr_dl.attribute(X_val_tensor, baselines=baseline)
714-
dl_attr_test_sum = attribution_dl.detach().numpy().sum(axis=0)
715-
l2_norm = np.linalg.norm(dl_attr_test_sum)
716-
l2_normalized_dl = dl_attr_test_sum / l2_norm if l2_norm != 0 else dl_attr_test_sum
717-
attributions_dict["DeepLift"] = l2_normalized_dl
696+
with torch.enable_grad():
697+
if "IntegratedGradients" in fun_control["xai_methods"]:
698+
attr_ig = IntegratedGradients(model)
699+
attribution_ig = attr_ig.attribute(X_val_tensor, baselines=baseline, n_steps=100, internal_batch_size=64)
700+
vec = attribution_ig.detach().cpu().numpy().sum(axis=0)
701+
l2 = np.linalg.norm(vec)
702+
attributions_dict["IntegratedGradients"] = vec / l2 if l2 != 0 else vec
703+
704+
if "KernelShap" in fun_control["xai_methods"]:
705+
attr_ks = KernelShap(model)
706+
samples_ks = 100 * X_val_tensor.shape[1]
707+
print("KernelShap: Using", samples_ks, "samples for attribution.")
708+
attribution_ks = attr_ks.attribute(X_val_tensor, baselines=baseline, n_samples=samples_ks, perturbations_per_eval=64)
709+
ks_attr_test_sum = attribution_ks.detach().numpy().sum(axis=0)
710+
l2_norm = np.linalg.norm(ks_attr_test_sum)
711+
l2_normalized_ks = ks_attr_test_sum / l2_norm if l2_norm != 0 else ks_attr_test_sum
712+
attributions_dict["KernelShap"] = l2_normalized_ks
713+
714+
if "DeepLift" in fun_control["xai_methods"]:
715+
attr_dl = DeepLift(model)
716+
attribution_dl = attr_dl.attribute(X_val_tensor, baselines=baseline)
717+
dl_attr_test_sum = attribution_dl.detach().numpy().sum(axis=0)
718+
l2_norm = np.linalg.norm(dl_attr_test_sum)
719+
l2_normalized_dl = dl_attr_test_sum / l2_norm if l2_norm != 0 else dl_attr_test_sum
720+
attributions_dict["DeepLift"] = l2_normalized_dl
718721

719722
attributions_list = [attributions_dict[method] for method in fun_control["xai_methods"]]
720723
attributions = np.stack(attributions_list, axis=0)

0 commit comments

Comments
 (0)