Skip to content

Commit 391293d

Browse files
0.14.16
plotXai
1 parent fff19ad commit 391293d

4 files changed

Lines changed: 334 additions & 54 deletions

File tree

notebooks/00_spotPython_tests.ipynb

Lines changed: 201 additions & 49 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.14.14"
10+
version = "0.14.16"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/plot/xai.py

Lines changed: 107 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,9 @@ def plot_nn_values_hist(nn_values, net, nn_values_names="", color="C0", columns=
275275
plt.show()
276276

277277

278-
def plot_nn_values_scatter(nn_values, nn_values_names="", absolute=True, cmap="gray", figsize=(6, 6)):
278+
def old_plot_nn_values_scatter(
279+
nn_values, nn_values_names="", absolute=True, cmap="gray", figsize=(6, 6), return_reshaped=False
280+
):
279281
"""
280282
Plot the values of a neural network.
281283
Can be used to plot the weights, gradients, or activations of a neural network.
@@ -292,6 +294,8 @@ def plot_nn_values_scatter(nn_values, nn_values_names="", absolute=True, cmap="g
292294
The colormap to use. Defaults to "gray".
293295
figsize (tuple, optional):
294296
The figure size. Defaults to (6, 6).
297+
return_reshaped (bool, optional):
298+
Whether to return the reshaped values. Defaults to False.
295299
296300
"""
297301
if cmap == "gray":
@@ -301,13 +305,19 @@ def plot_nn_values_scatter(nn_values, nn_values_names="", absolute=True, cmap="g
301305
else:
302306
cmap = "viridis"
303307

308+
res = {}
304309
for layer, values in nn_values.items():
305-
n = int(math.sqrt(len(values)))
306-
if n * n != len(values): # if the length is not a perfect square
307-
n += 1 # increase n by 1
310+
k = len(values)
311+
print(f"{k} values in Layer {layer}.")
312+
if is_square(k):
313+
n = int(math.sqrt(k))
314+
else:
315+
n = int(math.sqrt(len(values)) + 1)
308316
padding = np.zeros(n * n - len(values)) # create a zero array for padding
317+
print(f"{len(padding)} padding values added.")
309318
values = np.concatenate((values, padding)) # append the padding to the values
310319

320+
print(f"{len(values)} values in Layer {layer}.")
311321
if absolute:
312322
reshaped_values = np.abs(values.reshape((n, n)))
313323
else:
@@ -318,6 +328,81 @@ def plot_nn_values_scatter(nn_values, nn_values_names="", absolute=True, cmap="g
318328
plt.colorbar(label="Value")
319329
plt.title(f"{nn_values_names} Plot for {layer}")
320330
plt.show()
331+
# add reshaped_values to the dictionary res
332+
res[layer] = reshaped_values
333+
if return_reshaped:
334+
return res
335+
336+
337+
def plot_nn_values_scatter(
338+
nn_values, nn_values_names="", absolute=True, cmap="gray", figsize=(6, 6), return_reshaped=False, show=True
339+
) -> dict:
340+
"""
341+
Plot the values of a neural network including a marker for padding values.
342+
For simplicity, this example will annotate 'P' directly on the plot for padding values
343+
using a unique marker value approach.
344+
345+
Args:
346+
nn_values (dict):
347+
A dictionary with the values of the neural network. For example,
348+
the weights, gradients, or activations.
349+
nn_values_names (str, optional):
350+
The name of the values. Defaults to "".
351+
absolute (bool, optional):
352+
Whether to use the absolute values. Defaults to True.
353+
cmap (str, optional):
354+
The colormap to use. Defaults to "gray".
355+
figsize (tuple, optional):
356+
The figure size. Defaults to (6, 6).
357+
return_reshaped (bool, optional):
358+
Whether to return the reshaped values. Defaults to False.
359+
show (bool, optional):
360+
Whether to show the plot. Defaults to True.
361+
362+
Returns:
363+
dict: A dictionary with the reshaped values.
364+
"""
365+
if cmap == "gray":
366+
cmap = "gray"
367+
elif cmap == "BlueWhiteRed":
368+
cmap = colors.LinearSegmentedColormap.from_list("", ["blue", "white", "red"])
369+
else:
370+
cmap = "viridis"
371+
372+
res = {}
373+
padding_marker = np.nan # Use NaN as a special marker for padding
374+
for layer, values in nn_values.items():
375+
k = len(values)
376+
print(f"{k} values in Layer {layer}.")
377+
n = int(math.sqrt(k))
378+
if n * n != k: # if the length is not a perfect square
379+
n += 1 # Adjust n for padding
380+
print(f"{n*n-k} padding values added.")
381+
values = np.append(values, [padding_marker] * (n * n - k)) # Append padding values
382+
383+
print(f"{len(values)} values now in Layer {layer}.")
384+
385+
if absolute:
386+
reshaped_values = np.abs(values).reshape((n, n))
387+
# Mark padding values distinctly by setting them back to NaN
388+
reshaped_values[reshaped_values == np.abs(padding_marker)] = np.nan
389+
else:
390+
reshaped_values = values.reshape((n, n))
391+
392+
_, ax = plt.figure(figsize=figsize), plt.gca()
393+
cax = ax.imshow(reshaped_values, cmap=cmap, interpolation="nearest")
394+
for i in range(n):
395+
for j in range(n):
396+
if np.isnan(reshaped_values[i, j]):
397+
ax.text(j, i, "P", ha="center", va="center", color="red")
398+
plt.colorbar(cax, label="Value")
399+
plt.title(f"{nn_values_names} Plot for {layer}")
400+
if show:
401+
plt.show()
402+
# Add reshaped_values to the dictionary res
403+
res[layer] = reshaped_values
404+
if return_reshaped:
405+
return res
321406

322407

323408
def visualize_activations_distributions(net, fun_control, batch_size, device="cpu", color="C0", columns=2) -> None:
@@ -595,3 +680,21 @@ def plot_attributions(df, attr_method="IntegratedGradients"):
595680
plt.xlabel(f"{attr_method} Attribution Value")
596681
plt.ylabel("Feature")
597682
plt.show()
683+
684+
685+
def is_square(n):
686+
"""Check if a number is a square number.
687+
688+
Args:
689+
n (int): The number to check.
690+
691+
Returns:
692+
bool: True if the number is a square number, False otherwise.
693+
694+
Examples:
695+
>>> is_square(4)
696+
True
697+
>>> is_square(5)
698+
False
699+
"""
700+
return n == int(math.sqrt(n)) ** 2

test/test_xai.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import numpy as np
2+
import pprint
3+
from spotPython.plot.xai import plot_nn_values_scatter
4+
5+
6+
def test_plot_nn_values_scatter_reshaped_values():
7+
# Mock data for testing
8+
nn_values = {
9+
'layer1': np.random.rand(16), # 16 values suggesting a perfect square (4x4)
10+
'layer2': np.random.rand(18), # 18 values suggesting padding will be required for a 5x5 shape
11+
}
12+
13+
# Use the modified function that returns reshaped_values for testing
14+
reshaped_values = plot_nn_values_scatter(nn_values, 'Test Layer1', return_reshaped=True, show=False)
15+
16+
pprint.pprint(nn_values)
17+
pprint.pprint(reshaped_values)
18+
# Assert for layer1: Checks if reshaping is correct for perfect square
19+
assert reshaped_values['layer1'].shape == (4, 4)
20+
# Assert for layer2: Checks if reshaping is correct for non-square
21+
assert reshaped_values['layer2'].shape == (5, 5)
22+
23+
24+
if __name__ == "__main__":
25+
pytest.main(["-v", __file__])

0 commit comments

Comments
 (0)