Skip to content

Commit 8265f57

Browse files
committed
tests xai consistency
1 parent 1e0b451 commit 8265f57

1 file changed

Lines changed: 50 additions & 0 deletions

File tree

test/test_xai_consistency.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import numpy as np
2+
from spotpython.utils.metrics import calculate_xai_consistency
3+
4+
5+
def test_xai_consistency():
6+
# Mock data for testing
7+
8+
dl_attr_test_sum = [1, 2, 3, 4, 5]
9+
row_sum_dl = np.sum(dl_attr_test_sum, axis=0)
10+
if row_sum_dl == 0:
11+
row_sum_dl += 1e-10
12+
scaled_attribution_dl = dl_attr_test_sum / row_sum_dl
13+
14+
ig_attr_test_sum = [1, 2, 3, 4, 5]
15+
row_sum_ig = np.sum(ig_attr_test_sum, axis=0)
16+
if row_sum_ig == 0:
17+
row_sum_ig += 1e-10
18+
scaled_attribution_ig = ig_attr_test_sum / row_sum_ig
19+
20+
attributions = [scaled_attribution_dl, scaled_attribution_ig]
21+
result = calculate_xai_consistency(attributions)
22+
print("XAI Consistency Result:")
23+
print(result)
24+
25+
# Assert that the result is 1
26+
assert abs(result - 1) < 1e-10
27+
28+
29+
def test_xai_consistency_negative():
30+
# Mock data for testing negative consistency
31+
32+
dl_attr_test_sum = [1, 2, 3, 4, 5]
33+
row_sum_dl = np.sum(dl_attr_test_sum, axis=0)
34+
if row_sum_dl == 0:
35+
row_sum_dl += 1e-10
36+
scaled_attribution_dl = dl_attr_test_sum / row_sum_dl
37+
38+
ig_attr_test_sum = [-1, -2, -3, -4, -5]
39+
row_sum_ig = np.sum(np.abs(ig_attr_test_sum), axis=0)
40+
if row_sum_ig == 0:
41+
row_sum_ig += 1e-10
42+
scaled_attribution_ig = ig_attr_test_sum / row_sum_ig
43+
44+
attributions = [scaled_attribution_dl, scaled_attribution_ig]
45+
result = calculate_xai_consistency(attributions)
46+
print("XAI Consistency Result (Negative):")
47+
print(result)
48+
49+
# Assert that the result is -1
50+
assert abs(result + 1) < 1e-10

0 commit comments

Comments
 (0)