Skip to content

Commit 6374424

Browse files
0.16.6
xai plot fixed
1 parent b0fe5d2 commit 6374424

2 files changed

Lines changed: 20 additions & 12 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.16.4"
10+
version = "0.16.6"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/plot/xai.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,10 @@ def get_activations(net, fun_control, batch_size, device="cpu", normalize=False)
115115
if isinstance(layer, nn.Linear):
116116
activations[layer_index] = inputs.view(-1).cpu().numpy()
117117
mean_activations[layer_index] = inputs.mean(dim=0).cpu().numpy()
118-
# Record the size of the activations
119-
layer_sizes[layer_index] = np.array(inputs.size())
118+
# Record the size of the activations and set the first dimension to 1
119+
layer_size = np.array(inputs.size())
120+
layer_size[0] = 1 # Set the first dimension to 1
121+
layer_sizes[layer_index] = layer_size
120122

121123
return activations, mean_activations, layer_sizes
122124

@@ -444,27 +446,26 @@ def plot_nn_values_scatter(
444446

445447
if absolute:
446448
reshaped_values = np.abs(values).reshape((height, width))
449+
# Mark padding values distinctly by setting them back to NaN
447450
reshaped_values[reshaped_values == np.abs(padding_marker)] = np.nan
448451
else:
449452
reshaped_values = values.reshape((height, width))
450453

451-
fig, ax = plt.subplots(figsize=figsize)
454+
_, ax = plt.subplots(figsize=figsize)
452455
cax = ax.imshow(reshaped_values, cmap=cmap, interpolation="nearest")
453456

454-
# Adjust the position and size of the colorbar
455-
cbar = fig.colorbar(cax, ax=ax, fraction=0.046, pad=0.04)
456-
457457
for i in range(height):
458458
for j in range(width):
459459
if np.isnan(reshaped_values[i, j]):
460460
ax.text(j, i, "P", ha="center", va="center", color="red")
461461

462+
plt.colorbar(cax, label="Value")
462463
plt.title(f"{nn_values_names} Plot for {layer}")
463464
if show:
464465
plt.show()
465466

467+
# Add reshaped_values to the dictionary res
466468
res[layer] = reshaped_values
467-
468469
if return_reshaped:
469470
return res
470471

@@ -534,7 +535,7 @@ def visualize_gradient_distributions(
534535
plot_nn_values_hist(grads, net, nn_values_names="Gradients", color=color, columns=columns)
535536

536537

537-
def visualize_mean_activations(mean_activations, absolute=True, cmap="gray", figsize=(6, 6)) -> None:
538+
def visualize_mean_activations(mean_activations, layer_sizes, absolute=True, cmap="gray", figsize=(6, 6)) -> None:
538539
"""
539540
Scatter plots the mean activations of a neural network for each layer.
540541
means_activations is a dictionary with the mean activations of the neural network computed via
@@ -543,6 +544,8 @@ def visualize_mean_activations(mean_activations, absolute=True, cmap="gray", fig
543544
Args:
544545
mean_activations (dict):
545546
A dictionary with the mean activations of the neural network.
547+
layer_sizes (dict):
548+
A dictionary with layer names as keys and their sizes as entries in NumPy array format.
546549
absolute (bool, optional):
547550
Whether to use the absolute values. Defaults to True.
548551
cmap (str, optional):
@@ -555,12 +558,17 @@ def visualize_mean_activations(mean_activations, absolute=True, cmap="gray", fig
555558
556559
Examples:
557560
>>> from spotpython.plot.xai import get_activations
558-
activations, mean_activations, _ = get_activations(net, fun_control)
559-
visualize_mean_activations(mean_activations)
561+
activations, mean_activations, layer_sizes = get_activations(net, fun_control)
562+
visualize_mean_activations(mean_activations, layer_sizes)
560563
561564
"""
562565
plot_nn_values_scatter(
563-
nn_values=mean_activations, nn_values_names="Average Activations", absolute=absolute, cmap=cmap, figsize=figsize
566+
nn_values=mean_activations,
567+
layer_sizes=layer_sizes,
568+
nn_values_names="Average Activations",
569+
absolute=absolute,
570+
cmap=cmap,
571+
figsize=figsize,
564572
)
565573

566574

0 commit comments

Comments
 (0)