Skip to content

Commit b44525f

Browse files
committed
update xai parameter
1 parent 75d87a9 commit b44525f

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/spotpython/light/trainmodel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,7 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
696696
with torch.enable_grad():
697697
if "IntegratedGradients" in fun_control["xai_methods"]:
698698
attr_ig = IntegratedGradients(model)
699-
attribution_ig = attr_ig.attribute(X_val_tensor, baselines=baseline, n_steps=100, internal_batch_size=64)
699+
attribution_ig = attr_ig.attribute(X_val_tensor, baselines=baseline)
700700
vec = attribution_ig.detach().cpu().numpy().sum(axis=0)
701701
l2 = np.linalg.norm(vec)
702702
attributions_dict["IntegratedGradients"] = vec / l2 if l2 != 0 else vec

0 commit comments

Comments
 (0)