@@ -115,8 +115,10 @@ def get_activations(net, fun_control, batch_size, device="cpu", normalize=False)
115115 if isinstance (layer , nn .Linear ):
116116 activations [layer_index ] = inputs .view (- 1 ).cpu ().numpy ()
117117 mean_activations [layer_index ] = inputs .mean (dim = 0 ).cpu ().numpy ()
118- # Record the size of the activations
119- layer_sizes [layer_index ] = np .array (inputs .size ())
118+ # Record the size of the activations and set the first dimension to 1
119+ layer_size = np .array (inputs .size ())
120+ layer_size [0 ] = 1 # Set the first dimension to 1
121+ layer_sizes [layer_index ] = layer_size
120122
121123 return activations , mean_activations , layer_sizes
122124
@@ -444,27 +446,26 @@ def plot_nn_values_scatter(
444446
445447 if absolute :
446448 reshaped_values = np .abs (values ).reshape ((height , width ))
449+ # Mark padding values distinctly by setting them back to NaN
447450 reshaped_values [reshaped_values == np .abs (padding_marker )] = np .nan
448451 else :
449452 reshaped_values = values .reshape ((height , width ))
450453
451- fig , ax = plt .subplots (figsize = figsize )
454+ _ , ax = plt .subplots (figsize = figsize )
452455 cax = ax .imshow (reshaped_values , cmap = cmap , interpolation = "nearest" )
453456
454- # Adjust the position and size of the colorbar
455- cbar = fig .colorbar (cax , ax = ax , fraction = 0.046 , pad = 0.04 )
456-
457457 for i in range (height ):
458458 for j in range (width ):
459459 if np .isnan (reshaped_values [i , j ]):
460460 ax .text (j , i , "P" , ha = "center" , va = "center" , color = "red" )
461461
462+ plt .colorbar (cax , label = "Value" )
462463 plt .title (f"{ nn_values_names } Plot for { layer } " )
463464 if show :
464465 plt .show ()
465466
467+ # Add reshaped_values to the dictionary res
466468 res [layer ] = reshaped_values
467-
468469 if return_reshaped :
469470 return res
470471
@@ -534,7 +535,7 @@ def visualize_gradient_distributions(
534535 plot_nn_values_hist (grads , net , nn_values_names = "Gradients" , color = color , columns = columns )
535536
536537
537- def visualize_mean_activations (mean_activations , absolute = True , cmap = "gray" , figsize = (6 , 6 )) -> None :
538+ def visualize_mean_activations (mean_activations , layer_sizes , absolute = True , cmap = "gray" , figsize = (6 , 6 )) -> None :
538539 """
539540 Scatter plots the mean activations of a neural network for each layer.
540541 means_activations is a dictionary with the mean activations of the neural network computed via
@@ -543,6 +544,8 @@ def visualize_mean_activations(mean_activations, absolute=True, cmap="gray", fig
543544 Args:
544545 mean_activations (dict):
545546 A dictionary with the mean activations of the neural network.
547+ layer_sizes (dict):
548+ A dictionary with layer names as keys and their sizes as entries in NumPy array format.
546549 absolute (bool, optional):
547550 Whether to use the absolute values. Defaults to True.
548551 cmap (str, optional):
@@ -555,12 +558,17 @@ def visualize_mean_activations(mean_activations, absolute=True, cmap="gray", fig
555558
556559 Examples:
557560 >>> from spotpython.plot.xai import get_activations
558- activations, mean_activations, _ = get_activations(net, fun_control)
559- visualize_mean_activations(mean_activations)
561+ activations, mean_activations, layer_sizes = get_activations(net, fun_control)
562+ visualize_mean_activations(mean_activations, layer_sizes )
560563
561564 """
562565 plot_nn_values_scatter (
563- nn_values = mean_activations , nn_values_names = "Average Activations" , absolute = absolute , cmap = cmap , figsize = figsize
566+ nn_values = mean_activations ,
567+ layer_sizes = layer_sizes ,
568+ nn_values_names = "Average Activations" ,
569+ absolute = absolute ,
570+ cmap = cmap ,
571+ figsize = figsize ,
564572 )
565573
566574
0 commit comments