Skip to content

Commit 736d3ac

Browse files
committed
update xai metric
1 parent 001860f commit 736d3ac

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

src/spotpython/light/trainmodel.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,10 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
695695
attributions_list = [attributions_dict[method] for method in fun_control["xai_methods"]]
696696
attributions = np.stack(attributions_list, axis=1)
697697

698+
if fun_control["xai_metric"] not in {"max_diff", "variance", "spearman"}:
699+
print("Invalid or missing xai_metric. Setting it to 'max_diff'.")
700+
fun_control["xai_metric"] = "max_diff"
701+
698702
if fun_control["xai_metric"] == "max_diff":
699703
# Compute the max difference of the attribution methods for each feature
700704
result_xai = np.max(attributions, axis=1) - np.min(attributions, axis=1)

0 commit comments

Comments
 (0)