Skip to content

Commit 78dc7f7

Browse files
committed
set up default xai baseline
1 parent 6fb2b1d commit 78dc7f7

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

src/spotpython/light/trainmodel.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -687,8 +687,9 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
687687
attributions_dict = {}
688688

689689
if fun_control["xai_baseline"] is None:
690-
fun_control["xai_baseline"] = torch.zeros_like(X_val_tensor)
691-
print("Baseline is None. Using zeros as baseline.")
690+
X_train_mean = X_val_tensor.mean(dim=0)
691+
fun_control["xai_baseline"] = X_train_mean.unsqueeze(0)
692+
print("Baseline is None. Using training mean as baseline.")
692693
baseline = fun_control["xai_baseline"]
693694

694695
if "IntegratedGradients" in fun_control["xai_methods"]:

0 commit comments

Comments
 (0)