@@ -363,11 +363,10 @@ def plot_nn_values_scatter(
363363 figsize = (6 , 6 ),
364364 return_reshaped = False ,
365365 show = True ,
366+ colorbar_orientation = "auto" ,
366367) -> dict :
367368 """
368369 Plot the values of a neural network including a marker for padding values.
369- For simplicity, this example will annotate 'P' directly on the plot for padding values
370- using a unique marker value approach.
371370
372371 Args:
373372 nn_values (dict):
@@ -387,35 +386,14 @@ def plot_nn_values_scatter(
387386 Whether to return the reshaped values. Defaults to False.
388387 show (bool, optional):
389388 Whether to show the plot. Defaults to True.
389+ colorbar_orientation (str, optional):
390+ The orientation of the colorbar. Can be "auto", "horizontal", "vertical", or "none".
391+ "auto" will choose the orientation based on the geometry of the plot.
392+ "none" will not show the colorbar.
393+ Defaults to "auto".
390394
391395 Returns:
392396 dict: A dictionary with the reshaped values.
393-
394- Examples:
395- >>> from spotpython.utils.init import fun_control_init
396- from spotpython.data.diabetes import Diabetes
397- from spotpython.light.regression.nn_linear_regressor import NNLinearRegressor
398- from spotpython.hyperdict.light_hyper_dict import LightHyperDict
399- from spotpython.hyperparameters.values import (
400- get_default_hyperparameters_as_array, get_one_config_from_X)
401- from spotpython.hyperdict.light_hyper_dict import LightHyperDict
402- # from spotpython.plot.xai import get_gradients
403- fun_control = fun_control_init(
404- _L_in=10, # 10: diabetes
405- _L_out=1,
406- _torchmetric="mean_squared_error",
407- data_set=Diabetes(),
408- core_model=NNLinearRegressor,
409- hyperdict=LightHyperDict)
410- X = get_default_hyperparameters_as_array(fun_control)
411- config = get_one_config_from_X(X, fun_control)
412- _L_in = fun_control["_L_in"]
413- _L_out = fun_control["_L_out"]
414- _torchmetric = fun_control["_torchmetric"]
415- batch_size = 16
416- model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out, _torchmetric=_torchmetric)
417- gradients, layer_sizes = get_gradients(model, fun_control=fun_control, batch_size=batch_size, device = "cpu")
418- plot_nn_values_scatter(nn_values=gradients, layer_sizes=layer_sizes, nn_values_names="Weights")
419397 """
420398 if cmap == "gray" :
421399 cmap = "gray"
@@ -459,13 +437,21 @@ def plot_nn_values_scatter(
459437 if np .isnan (reshaped_values [i , j ]):
460438 ax .text (j , i , "P" , ha = "center" , va = "center" , color = "red" )
461439
462- plt .colorbar (cax , label = "Value" )
440+ if colorbar_orientation == "auto" :
441+ if height < width :
442+ plt .colorbar (cax , orientation = "horizontal" , label = "Value" )
443+ else :
444+ plt .colorbar (cax , orientation = "vertical" , label = "Value" )
445+
446+ if colorbar_orientation in ["horizontal" , "vertical" ]:
447+ plt .colorbar (cax , orientation = colorbar_orientation , label = "Value" )
463448 plt .title (f"{ nn_values_names } Plot for { layer } " )
464449 if show :
465450 plt .show ()
466451
467452 # Add reshaped_values to the dictionary res
468453 res [layer ] = reshaped_values
454+
469455 if return_reshaped :
470456 return res
471457
0 commit comments