Skip to content

Commit c1fac44

Browse files
committed
add local xai consensus option
1 parent 3b178e2 commit c1fac44

2 files changed

Lines changed: 99 additions & 60 deletions

File tree

src/spotpython/light/trainmodel.py

Lines changed: 99 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -715,12 +715,13 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
715715
model.eval()
716716

717717
target = fun_control.get("xai_target", None)
718+
xai_mode = fun_control.get("xai_mode", None)
719+
metric = fun_control.get("xai_metric", "corr")
718720

719721
if "KernelShap" in fun_control["xai_methods"]:
720722
attr_ks = KernelShap(model)
721723
n_features = X_val_tensor.shape[1]
722724
samples_ks = min(2000, 100 * n_features)
723-
print("KernelShap: Using", samples_ks, "samples for attribution.")
724725
with torch.no_grad():
725726
attribution_ks = attr_ks.attribute(
726727
X_val_tensor,
@@ -730,16 +731,27 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
730731
target=target,
731732
show_progress=False,
732733
)
733-
ks_sum = attribution_ks.detach().cpu().numpy().sum(axis=0)
734-
# Remove NaN and Inf values
735-
ks_sum = np.nan_to_num(ks_sum, nan=0.0, posinf=0.0, neginf=0.0)
736-
# Safe normalization (no division by ~0)
737-
l2_norm = np.linalg.norm(ks_sum)
738-
if not np.isfinite(l2_norm) or l2_norm < 1e-12:
739-
l2_normalized_ks = np.zeros_like(ks_sum)
740-
else:
741-
l2_normalized_ks = ks_sum / l2_norm
742-
attributions_dict["KernelShap"] = l2_normalized_ks
734+
ks_tensor = attribution_ks.detach().cpu().numpy() # Shape: N_samples x N_features
735+
ks_tensor = np.nan_to_num(ks_tensor, nan=0.0, posinf=0.0, neginf=0.0)
736+
737+
if xai_mode == "local":
738+
print("local consensus calculation")
739+
# normalize each sample individually
740+
norms = np.linalg.norm(ks_tensor, axis=1, keepdims=True)
741+
norms[norms < 1e-12] = 1.0 # avoid division by zero
742+
ks_normalized = ks_tensor / norms
743+
# attributions_dict now stores per-sample normalized attributions
744+
attributions_dict["KernelShap"] = ks_normalized
745+
746+
else: # global attributions
747+
print("global consensus calculation")
748+
ks_sum = ks_tensor.sum(axis=0)
749+
l2_norm = np.linalg.norm(ks_sum)
750+
if not np.isfinite(l2_norm) or l2_norm < 1e-12:
751+
ks_normalized = np.zeros_like(ks_sum)
752+
else:
753+
ks_normalized = ks_sum / l2_norm
754+
attributions_dict["KernelShap"] = ks_normalized
743755

744756
if "IntegratedGradients" in fun_control["xai_methods"]:
745757
attr_ig = IntegratedGradients(model)
@@ -750,13 +762,26 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
750762
baselines=baseline,
751763
target=target,
752764
)
753-
ig_sum = attribution_ig.detach().cpu().numpy().sum(axis=0)
754-
ig_sum = np.nan_to_num(ig_sum, nan=0.0, posinf=0.0, neginf=0.0)
755-
l2 = np.linalg.norm(ig_sum)
756-
if not np.isfinite(l2) or l2 < 1e-12:
757-
attributions_dict["IntegratedGradients"] = np.zeros_like(ig_sum)
758-
else:
759-
attributions_dict["IntegratedGradients"] = ig_sum / l2
765+
ig_tensor = attribution_ig.detach().cpu().numpy() # Shape: N_samples x N_features
766+
ig_tensor = np.nan_to_num(ig_tensor, nan=0.0, posinf=0.0, neginf=0.0)
767+
768+
if xai_mode == "local":
769+
# normalize each sample individually
770+
norms = np.linalg.norm(ig_tensor, axis=1, keepdims=True)
771+
norms[norms < 1e-12] = 1.0 # avoid division by zero
772+
ig_normalized = ig_tensor / norms
773+
# attributions_dict now stores per-sample normalized attributions
774+
attributions_dict["IntegratedGradients"] = ig_normalized
775+
776+
else: # global attributions
777+
# alte Variante: Summe über Samples
778+
ig_sum = ig_tensor.sum(axis=0)
779+
l2_norm = np.linalg.norm(ig_sum)
780+
if not np.isfinite(l2_norm) or l2_norm < 1e-12:
781+
ig_normalized = np.zeros_like(ig_sum)
782+
else:
783+
ig_normalized = ig_sum / l2_norm
784+
attributions_dict["IntegratedGradients"] = ig_normalized
760785

761786
if "DeepLift" in fun_control["xai_methods"]:
762787
attr_dl = DeepLift(model)
@@ -766,28 +791,59 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
766791
baselines=baseline,
767792
target=target,
768793
)
769-
dl_sum = attribution_dl.detach().cpu().numpy().sum(axis=0)
770-
dl_sum = np.nan_to_num(dl_sum, nan=0.0, posinf=0.0, neginf=0.0)
771-
l2_norm = np.linalg.norm(dl_sum)
772-
if not np.isfinite(l2_norm) or l2_norm < 1e-12:
773-
l2_normalized_dl = np.zeros_like(dl_sum)
774-
else:
775-
l2_normalized_dl = dl_sum / l2_norm
776-
attributions_dict["DeepLift"] = l2_normalized_dl
777-
778-
attributions_list = [attributions_dict[method] for method in fun_control["xai_methods"]]
779-
attributions = np.stack(attributions_list, axis=0)
780-
781-
# Calculate corr:
782-
if fun_control["xai_metric"] not in ["corr", "cosine", "euclidean", "spearman"]:
783-
raise ValueError(f"Invalid xai_metric: {fun_control['xai_metric']}. Valid metrics are: 'corr', 'cosine', 'euclidean', 'spearman'")
784-
if fun_control["xai_metric"] == "corr":
785-
result_xai = calculate_xai_consistency_corr(attributions)
786-
elif fun_control["xai_metric"] == "cosine":
787-
result_xai = calculate_xai_consistency_cosine(attributions)
788-
elif fun_control["xai_metric"] == "euclidean":
789-
result_xai = calculate_xai_consistency_euclidean(attributions)
790-
elif fun_control["xai_metric"] == "spearman":
791-
result_xai = calculate_xai_consistency_spearman(attributions)
792-
793-
return result["val_loss"], result_xai
794+
dl_tensor = attribution_dl.detach().cpu().numpy() # Shape: N_samples x N_features
795+
dl_tensor = np.nan_to_num(dl_tensor, nan=0.0, posinf=0.0, neginf=0.0)
796+
797+
if xai_mode == "local":
798+
# normalize each sample individually
799+
norms = np.linalg.norm(dl_tensor, axis=1, keepdims=True)
800+
norms[norms < 1e-12] = 1.0 # avoid division by zero
801+
dl_normalized = dl_tensor / norms
802+
# attributions_dict now stores per-sample normalized attributions
803+
attributions_dict["DeepLift"] = dl_normalized
804+
805+
else: # global attributions
806+
# alte Variante: Summe über Samples
807+
dl_sum = dl_tensor.sum(axis=0)
808+
l2_norm = np.linalg.norm(dl_sum)
809+
if not np.isfinite(l2_norm) or l2_norm < 1e-12:
810+
dl_normalized = np.zeros_like(dl_sum)
811+
else:
812+
dl_normalized = dl_sum / l2_norm
813+
attributions_dict["DeepLift"] = dl_normalized
814+
815+
816+
if xai_mode == "local":
817+
# Konsens pro Sample
818+
N_samples = attributions_dict[fun_control["xai_methods"][0]].shape[0]
819+
per_sample_consensus = []
820+
for i in range(N_samples):
821+
sample_attrs = np.stack([attributions_dict[m][i] for m in fun_control["xai_methods"]], axis=0)
822+
# sample_attrs: n_methods x n_features
823+
if metric == "corr":
824+
sample_consensus = calculate_xai_consistency_corr(sample_attrs)
825+
elif metric == "cosine":
826+
sample_consensus = calculate_xai_consistency_cosine(sample_attrs)
827+
elif metric == "euclidean":
828+
sample_consensus = calculate_xai_consistency_euclidean(sample_attrs)
829+
elif metric == "spearman":
830+
sample_consensus = calculate_xai_consistency_spearman(sample_attrs)
831+
per_sample_consensus.append(sample_consensus)
832+
result_xai = np.mean(per_sample_consensus)
833+
print("aggregated local consensus:", result_xai)
834+
835+
else:
836+
# global modus
837+
attributions_list = [attributions_dict[m] for m in fun_control["xai_methods"]]
838+
attributions = np.stack(attributions_list, axis=0) # n_methods x n_features
839+
if metric == "corr":
840+
result_xai = calculate_xai_consistency_corr(attributions)
841+
elif metric == "cosine":
842+
result_xai = calculate_xai_consistency_cosine(attributions)
843+
elif metric == "euclidean":
844+
result_xai = calculate_xai_consistency_euclidean(attributions)
845+
elif metric == "spearman":
846+
result_xai = calculate_xai_consistency_spearman(attributions)
847+
print("global consensus:", result_xai)
848+
849+
return result["val_loss"], result_xai

src/spotpython/utils/metrics.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -214,15 +214,11 @@ def calculate_xai_consistency_corr(attributions) -> float:
214214
"""
215215
global_attr_np = np.array(attributions)
216216
corr_matrix = np.corrcoef(global_attr_np)
217-
print("Attribution Correlation Matrix:")
218-
print(corr_matrix)
219217

220218
# Calculate the mean of the upper triangle of the correlation matrix
221219
upper_triangle_indices = np.triu_indices_from(corr_matrix, k=1)
222220
upper_triangle_values = corr_matrix[upper_triangle_indices]
223221
result_xai = upper_triangle_values.mean()
224-
print("XAI Consistency (mean of upper triangle of correlation matrix):")
225-
print(result_xai)
226222
return result_xai
227223

228224

@@ -240,15 +236,11 @@ def calculate_xai_consistency_cosine(attributions) -> float:
240236
"""
241237
global_attr_np = np.array(attributions)
242238
cosine_sim_matrix = np.array([[np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)) for b in global_attr_np] for a in global_attr_np])
243-
print("Attribution Cosine Similarity Matrix:")
244-
print(cosine_sim_matrix)
245239

246240
# Calculate the mean of the upper triangle of the cosine similarity matrix
247241
upper_triangle_indices = np.triu_indices_from(cosine_sim_matrix, k=1)
248242
upper_triangle_values = cosine_sim_matrix[upper_triangle_indices]
249243
result_xai = upper_triangle_values.mean()
250-
print("XAI Consistency (mean of upper triangle of cosine similarity matrix):")
251-
print(result_xai)
252244
return result_xai
253245

254246

@@ -266,15 +258,11 @@ def calculate_xai_consistency_euclidean(attributions) -> float:
266258
"""
267259
global_attr_np = np.array(attributions)
268260
euclidean_dist_matrix = euclidean_distances(global_attr_np)
269-
print("Attribution Euclidean Distance Matrix:")
270-
print(euclidean_dist_matrix)
271261

272262
# Calculate the mean of the upper triangle of the Euclidean distance matrix
273263
upper_triangle_indices = np.triu_indices_from(euclidean_dist_matrix, k=1)
274264
upper_triangle_values = euclidean_dist_matrix[upper_triangle_indices]
275265
result_xai = upper_triangle_values.mean()
276-
print("XAI Consistency (mean of upper triangle of Euclidean distance matrix):")
277-
print(result_xai)
278266
return result_xai
279267

280268

@@ -300,9 +288,4 @@ def calculate_xai_consistency_spearman(attributions) -> float:
300288
spearman_corr_matrix[i, j] = corr
301289

302290
upper_triangle_values = spearman_corr_matrix[np.triu_indices(n_methods, k=1)]
303-
print("Attribution Spearman Correlation Matrix:")
304-
print(spearman_corr_matrix)
305-
306-
print("XAI Consistency (mean of upper triangle of Spearman correlation matrix):")
307-
print(upper_triangle_values.mean())
308291
return upper_triangle_values.mean()

0 commit comments

Comments
 (0)