Skip to content

Commit f9dc086

Browse files
0.16.8
1 parent 3476c1a commit f9dc086

4 files changed

Lines changed: 31 additions & 31 deletions

File tree

RELEASE_NOTES.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
spotpython-0.16.8:
2+
3+
- xai.py: automatically handle the orientation of the colorbar in the plot_nn_values_scatter function

notebooks/00_spotPython_tests.ipynb

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5520,7 +5520,7 @@
55205520
},
55215521
{
55225522
"cell_type": "code",
5523-
"execution_count": 4,
5523+
"execution_count": null,
55245524
"metadata": {},
55255525
"outputs": [
55265526
{
@@ -5579,6 +5579,17 @@
55795579
"execution_count": 4,
55805580
"metadata": {},
55815581
"output_type": "execute_result"
5582+
},
5583+
{
5584+
"ename": "",
5585+
"evalue": "",
5586+
"output_type": "error",
5587+
"traceback": [
5588+
"\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
5589+
"\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
5590+
"\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
5591+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
5592+
]
55825593
}
55835594
],
55845595
"source": [

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.16.7"
10+
version = "0.16.8"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/plot/xai.py

Lines changed: 15 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -363,11 +363,10 @@ def plot_nn_values_scatter(
363363
figsize=(6, 6),
364364
return_reshaped=False,
365365
show=True,
366+
colorbar_orientation="auto",
366367
) -> dict:
367368
"""
368369
Plot the values of a neural network including a marker for padding values.
369-
For simplicity, this example will annotate 'P' directly on the plot for padding values
370-
using a unique marker value approach.
371370
372371
Args:
373372
nn_values (dict):
@@ -387,35 +386,14 @@ def plot_nn_values_scatter(
387386
Whether to return the reshaped values. Defaults to False.
388387
show (bool, optional):
389388
Whether to show the plot. Defaults to True.
389+
colorbar_orientation (str, optional):
390+
The orientation of the colorbar. Can be "auto", "horizontal", "vertical", or "none".
391+
"auto" will choose the orientation based on the geometry of the plot.
392+
"none" will not show the colorbar.
393+
Defaults to "auto".
390394
391395
Returns:
392396
dict: A dictionary with the reshaped values.
393-
394-
Examples:
395-
>>> from spotpython.utils.init import fun_control_init
396-
from spotpython.data.diabetes import Diabetes
397-
from spotpython.light.regression.nn_linear_regressor import NNLinearRegressor
398-
from spotpython.hyperdict.light_hyper_dict import LightHyperDict
399-
from spotpython.hyperparameters.values import (
400-
get_default_hyperparameters_as_array, get_one_config_from_X)
401-
from spotpython.hyperdict.light_hyper_dict import LightHyperDict
402-
# from spotpython.plot.xai import get_gradients
403-
fun_control = fun_control_init(
404-
_L_in=10, # 10: diabetes
405-
_L_out=1,
406-
_torchmetric="mean_squared_error",
407-
data_set=Diabetes(),
408-
core_model=NNLinearRegressor,
409-
hyperdict=LightHyperDict)
410-
X = get_default_hyperparameters_as_array(fun_control)
411-
config = get_one_config_from_X(X, fun_control)
412-
_L_in = fun_control["_L_in"]
413-
_L_out = fun_control["_L_out"]
414-
_torchmetric = fun_control["_torchmetric"]
415-
batch_size = 16
416-
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out, _torchmetric=_torchmetric)
417-
gradients, layer_sizes = get_gradients(model, fun_control=fun_control, batch_size=batch_size, device = "cpu")
418-
plot_nn_values_scatter(nn_values=gradients, layer_sizes=layer_sizes, nn_values_names="Weights")
419397
"""
420398
if cmap == "gray":
421399
cmap = "gray"
@@ -459,13 +437,21 @@ def plot_nn_values_scatter(
459437
if np.isnan(reshaped_values[i, j]):
460438
ax.text(j, i, "P", ha="center", va="center", color="red")
461439

462-
plt.colorbar(cax, label="Value")
440+
if colorbar_orientation == "auto":
441+
if height < width:
442+
plt.colorbar(cax, orientation="horizontal", label="Value")
443+
else:
444+
plt.colorbar(cax, orientation="vertical", label="Value")
445+
446+
if colorbar_orientation in ["horizontal", "vertical"]:
447+
plt.colorbar(cax, orientation=colorbar_orientation, label="Value")
463448
plt.title(f"{nn_values_names} Plot for {layer}")
464449
if show:
465450
plt.show()
466451

467452
# Add reshaped_values to the dictionary res
468453
res[layer] = reshaped_values
454+
469455
if return_reshaped:
470456
return res
471457

0 commit comments

Comments
 (0)