Skip to content

Commit c3e7df0

Browse files
committed
improve numerical stability of the attribution normalization process
1 parent 72d8d35 commit c3e7df0

1 file changed

Lines changed: 52 additions & 21 deletions

File tree

src/spotpython/light/trainmodel.py

Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -718,34 +718,65 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
718718
model.eval()
719719

720720
target = fun_control.get("xai_target", None)
721-
721+
722722
if "KernelShap" in fun_control["xai_methods"]:
723723
attr_ks = KernelShap(model)
724724
n_features = X_val_tensor.shape[1]
725-
samples_ks = min(2000, 100 * n_features) # Adjust number of samples based on features, maximum 2000
725+
samples_ks = min(2000, 100 * n_features)
726726
print("KernelShap: Using", samples_ks, "samples for attribution.")
727727
with torch.no_grad():
728-
attribution_ks = attr_ks.attribute(X_val_tensor, baselines=baseline, n_samples=samples_ks, perturbations_per_eval=64, target=target)
729-
ks_attr_test_sum = attribution_ks.detach().numpy().sum(axis=0)
730-
l2_norm = np.linalg.norm(ks_attr_test_sum)
731-
l2_normalized_ks = ks_attr_test_sum / l2_norm if l2_norm != 0 else ks_attr_test_sum
728+
attribution_ks = attr_ks.attribute(
729+
X_val_tensor,
730+
baselines=baseline,
731+
n_samples=samples_ks,
732+
perturbations_per_eval=64,,
733+
target=target,
734+
show_progress=False,
735+
)
736+
ks_sum = attribution_ks.detach().cpu().numpy().sum(axis=0)
737+
# Remove NaN and Inf values
738+
ks_sum = np.nan_to_num(ks_sum, nan=0.0, posinf=0.0, neginf=0.0)
739+
# Safe normalization (no division by ~0)
740+
l2_norm = np.linalg.norm(ks_sum)
741+
if not np.isfinite(l2_norm) or l2_norm < 1e-12:
742+
l2_normalized_ks = np.zeros_like(ks_sum)
743+
else:
744+
l2_normalized_ks = ks_sum / l2_norm
732745
attributions_dict["KernelShap"] = l2_normalized_ks
733746

734-
with torch.enable_grad():
735-
if "IntegratedGradients" in fun_control["xai_methods"]:
736-
attr_ig = IntegratedGradients(model)
737-
attribution_ig = attr_ig.attribute(X_val_tensor, baselines=baseline, target=target)
738-
vec = attribution_ig.detach().cpu().numpy().sum(axis=0)
739-
l2 = np.linalg.norm(vec)
740-
attributions_dict["IntegratedGradients"] = vec / l2 if l2 != 0 else vec
741-
742-
if "DeepLift" in fun_control["xai_methods"]:
743-
attr_dl = DeepLift(model)
744-
attribution_dl = attr_dl.attribute(X_val_tensor, baselines=baseline, target=target)
745-
dl_attr_test_sum = attribution_dl.detach().numpy().sum(axis=0)
746-
l2_norm = np.linalg.norm(dl_attr_test_sum)
747-
l2_normalized_dl = dl_attr_test_sum / l2_norm if l2_norm != 0 else dl_attr_test_sum
748-
attributions_dict["DeepLift"] = l2_normalized_dl
747+
if "IntegratedGradients" in fun_control["xai_methods"]:
748+
attr_ig = IntegratedGradients(model)
749+
# IG braucht Gradienten ⇒ enable_grad
750+
with torch.enable_grad():
751+
attribution_ig = attr_ig.attribute(
752+
X_val_tensor,
753+
baselines=baseline,
754+
target=target,
755+
)
756+
ig_sum = attribution_ig.detach().cpu().numpy().sum(axis=0)
757+
ig_sum = np.nan_to_num(ig_sum, nan=0.0, posinf=0.0, neginf=0.0)
758+
l2 = np.linalg.norm(ig_sum)
759+
if not np.isfinite(l2) or l2 < 1e-12:
760+
attributions_dict["IntegratedGradients"] = np.zeros_like(ig_sum)
761+
else:
762+
attributions_dict["IntegratedGradients"] = ig_sum / l2
763+
764+
if "DeepLift" in fun_control["xai_methods"]:
765+
attr_dl = DeepLift(model)
766+
with torch.enable_grad(): # DeepLIFT via Backprop
767+
attribution_dl = attr_dl.attribute(
768+
X_val_tensor,
769+
baselines=baseline,
770+
target=target,
771+
)
772+
dl_sum = attribution_dl.detach().cpu().numpy().sum(axis=0)
773+
dl_sum = np.nan_to_num(dl_sum, nan=0.0, posinf=0.0, neginf=0.0)
774+
l2_norm = np.linalg.norm(dl_sum)
775+
if not np.isfinite(l2_norm) or l2_norm < 1e-12:
776+
l2_normalized_dl = np.zeros_like(dl_sum)
777+
else:
778+
l2_normalized_dl = dl_sum / l2_norm
779+
attributions_dict["DeepLift"] = l2_normalized_dl
749780

750781
attributions_list = [attributions_dict[method] for method in fun_control["xai_methods"]]
751782
attributions = np.stack(attributions_list, axis=0)

0 commit comments

Comments
 (0)