@@ -275,7 +275,9 @@ def plot_nn_values_hist(nn_values, net, nn_values_names="", color="C0", columns=
275275 plt .show ()
276276
277277
278- def plot_nn_values_scatter (nn_values , nn_values_names = "" , absolute = True , cmap = "gray" , figsize = (6 , 6 )):
278+ def old_plot_nn_values_scatter (
279+ nn_values , nn_values_names = "" , absolute = True , cmap = "gray" , figsize = (6 , 6 ), return_reshaped = False
280+ ):
279281 """
280282 Plot the values of a neural network.
281283 Can be used to plot the weights, gradients, or activations of a neural network.
@@ -292,6 +294,8 @@ def plot_nn_values_scatter(nn_values, nn_values_names="", absolute=True, cmap="g
292294 The colormap to use. Defaults to "gray".
293295 figsize (tuple, optional):
294296 The figure size. Defaults to (6, 6).
297+ return_reshaped (bool, optional):
298+ Whether to return the reshaped values. Defaults to False.
295299
296300 """
297301 if cmap == "gray" :
@@ -301,13 +305,19 @@ def plot_nn_values_scatter(nn_values, nn_values_names="", absolute=True, cmap="g
301305 else :
302306 cmap = "viridis"
303307
308+ res = {}
304309 for layer , values in nn_values .items ():
305- n = int (math .sqrt (len (values )))
306- if n * n != len (values ): # if the length is not a perfect square
307- n += 1 # increase n by 1
310+ k = len (values )
311+ print (f"{ k } values in Layer { layer } ." )
312+ if is_square (k ):
313+ n = int (math .sqrt (k ))
314+ else :
315+ n = int (math .sqrt (len (values )) + 1 )
308316 padding = np .zeros (n * n - len (values )) # create a zero array for padding
317+ print (f"{ len (padding )} padding values added." )
309318 values = np .concatenate ((values , padding )) # append the padding to the values
310319
320+ print (f"{ len (values )} values in Layer { layer } ." )
311321 if absolute :
312322 reshaped_values = np .abs (values .reshape ((n , n )))
313323 else :
@@ -318,6 +328,81 @@ def plot_nn_values_scatter(nn_values, nn_values_names="", absolute=True, cmap="g
318328 plt .colorbar (label = "Value" )
319329 plt .title (f"{ nn_values_names } Plot for { layer } " )
320330 plt .show ()
331+ # add reshaped_values to the dictionary res
332+ res [layer ] = reshaped_values
333+ if return_reshaped :
334+ return res
335+
336+
337+ def plot_nn_values_scatter (
338+ nn_values , nn_values_names = "" , absolute = True , cmap = "gray" , figsize = (6 , 6 ), return_reshaped = False , show = True
339+ ) -> dict :
340+ """
341+ Plot the values of a neural network including a marker for padding values.
342+ For simplicity, this example will annotate 'P' directly on the plot for padding values
343+ using a unique marker value approach.
344+
345+ Args:
346+ nn_values (dict):
347+ A dictionary with the values of the neural network. For example,
348+ the weights, gradients, or activations.
349+ nn_values_names (str, optional):
350+ The name of the values. Defaults to "".
351+ absolute (bool, optional):
352+ Whether to use the absolute values. Defaults to True.
353+ cmap (str, optional):
354+ The colormap to use. Defaults to "gray".
355+ figsize (tuple, optional):
356+ The figure size. Defaults to (6, 6).
357+ return_reshaped (bool, optional):
358+ Whether to return the reshaped values. Defaults to False.
359+ show (bool, optional):
360+ Whether to show the plot. Defaults to True.
361+
362+ Returns:
363+ dict: A dictionary with the reshaped values.
364+ """
365+ if cmap == "gray" :
366+ cmap = "gray"
367+ elif cmap == "BlueWhiteRed" :
368+ cmap = colors .LinearSegmentedColormap .from_list ("" , ["blue" , "white" , "red" ])
369+ else :
370+ cmap = "viridis"
371+
372+ res = {}
373+ padding_marker = np .nan # Use NaN as a special marker for padding
374+ for layer , values in nn_values .items ():
375+ k = len (values )
376+ print (f"{ k } values in Layer { layer } ." )
377+ n = int (math .sqrt (k ))
378+ if n * n != k : # if the length is not a perfect square
379+ n += 1 # Adjust n for padding
380+ print (f"{ n * n - k } padding values added." )
381+ values = np .append (values , [padding_marker ] * (n * n - k )) # Append padding values
382+
383+ print (f"{ len (values )} values now in Layer { layer } ." )
384+
385+ if absolute :
386+ reshaped_values = np .abs (values ).reshape ((n , n ))
387+ # Mark padding values distinctly by setting them back to NaN
388+ reshaped_values [reshaped_values == np .abs (padding_marker )] = np .nan
389+ else :
390+ reshaped_values = values .reshape ((n , n ))
391+
392+ _ , ax = plt .figure (figsize = figsize ), plt .gca ()
393+ cax = ax .imshow (reshaped_values , cmap = cmap , interpolation = "nearest" )
394+ for i in range (n ):
395+ for j in range (n ):
396+ if np .isnan (reshaped_values [i , j ]):
397+ ax .text (j , i , "P" , ha = "center" , va = "center" , color = "red" )
398+ plt .colorbar (cax , label = "Value" )
399+ plt .title (f"{ nn_values_names } Plot for { layer } " )
400+ if show :
401+ plt .show ()
402+ # Add reshaped_values to the dictionary res
403+ res [layer ] = reshaped_values
404+ if return_reshaped :
405+ return res
321406
322407
323408def visualize_activations_distributions (net , fun_control , batch_size , device = "cpu" , color = "C0" , columns = 2 ) -> None :
@@ -595,3 +680,21 @@ def plot_attributions(df, attr_method="IntegratedGradients"):
595680 plt .xlabel (f"{ attr_method } Attribution Value" )
596681 plt .ylabel ("Feature" )
597682 plt .show ()
683+
684+
685+ def is_square (n ):
686+ """Check if a number is a square number.
687+
688+ Args:
689+ n (int): The number to check.
690+
691+ Returns:
692+ bool: True if the number is a square number, False otherwise.
693+
694+ Examples:
695+ >>> is_square(4)
696+ True
697+ >>> is_square(5)
698+ False
699+ """
700+ return n == int (math .sqrt (n )) ** 2
0 commit comments