Skip to content

Commit b164e56

Browse files
committed
combined xai consistency metric
1 parent 2f4c111 commit b164e56

1 file changed

Lines changed: 27 additions & 1 deletion

File tree

src/spotpython/light/trainmodel.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -714,7 +714,7 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
714714
attributions_list = [attributions_dict[method] for method in fun_control["xai_methods"]]
715715
attributions = np.stack(attributions_list, axis=1)
716716

717-
if fun_control["xai_metric"] not in {"max_diff", "variance", "spearman"}:
717+
if fun_control["xai_metric"] not in {"max_diff", "variance", "spearman", "spearman+variance"}:
718718
print("Invalid or missing xai_metric. Setting it to 'max_diff'.")
719719
fun_control["xai_metric"] = "max_diff"
720720

@@ -749,6 +749,32 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
749749
print("Spearman rank correlation matrix:\n", spearman_matrix)
750750
print("Consistency Score (Mean Spearman Correlation):", -result_xai)
751751

752+
if fun_control["xai_metric"] == "spearman+variance":
753+
# Compute Spearman mean
754+
num_methods = attributions.shape[1]
755+
spearman_matrix = np.zeros((num_methods, num_methods))
756+
757+
for i in range(num_methods):
758+
for j in range(i + 1, num_methods):
759+
corr, _ = spearmanr(attributions[:, i], attributions[:, j])
760+
spearman_matrix[i, j] = corr
761+
spearman_matrix[j, i] = corr
762+
763+
upper_triangle_values = spearman_matrix[np.triu_indices(num_methods, k=1)]
764+
mean_spearman = np.mean(upper_triangle_values)
765+
766+
# Compute attribution variance across methods for each feature
767+
variance = np.var(attributions, axis=1).mean() # mean over features
768+
769+
# Combine both (λ is a trade-off hyperparameter you define)
770+
lambda_variance = fun_control["lambda_variance"] if "lambda_variance" in fun_control else 1.0
771+
result_xai = -mean_spearman + lambda_variance * variance
772+
773+
print("Mean Spearman correlation:", mean_spearman)
774+
print("Mean Variance:", variance)
775+
print("Variance Weight:", lambda_variance)
776+
print("Combined XAI loss: ", result_xai)
777+
752778
# -------------------------------------------------------------------------------------------------------------------
753779

754780
return result["val_loss"], result_xai

0 commit comments

Comments
 (0)