Skip to content

Commit 8f62f62

Browse files
0.29.10
figsize added to plot.py
1 parent d1ad54d commit 8f62f62

2 files changed

Lines changed: 4 additions & 28 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.29.9"
10+
version = "0.29.10"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/uc/plot.py

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def plot_predictionintervals(
1212
y_test_pred_low,
1313
y_test_pred_high,
1414
suptitle: str,
15+
figsize: tuple = (10, 10), # Default figsize added
1516
) -> None:
1617
"""
1718
Plots prediction intervals for training and testing data.
@@ -31,37 +32,12 @@ def plot_predictionintervals(
3132
y_test_pred_low (array-like): Lower bounds of prediction intervals for the testing set.
3233
y_test_pred_high (array-like): Upper bounds of prediction intervals for the testing set.
3334
suptitle (str): The title for the entire figure.
35+
figsize (tuple, optional): Size of the figure. Default is (10, 10).
3436
3537
Returns:
3638
None: The function displays the plots but does not return any value.
37-
38-
Notes:
39-
- The first subplot compares true and predicted values with error bars for both training
40-
and testing data.
41-
- The second subplot visualizes the width of prediction intervals as a function of true values.
42-
- The third subplot orders the prediction interval widths and displays them for both
43-
training and testing data.
44-
- The fourth subplot shows histograms of the interval widths for training and testing data.
45-
46-
References:
47-
Function adapted from: https://github.com/scikit-learn-contrib/MAPIE/blob/master/notebooks/regression/exoplanets.ipynb
48-
49-
Examples:
50-
>>> import numpy as np
51-
>>> from spotpython.uc.plot import plot_predictionintervals
52-
>>> y_train = np.array([1, 2, 3, 4, 5])
53-
>>> y_train_pred = np.array([1.1, 2.2, 3.3, 4.4, 5.5])
54-
>>> y_train_pred_low = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
55-
>>> y_train_pred_high = np.array([1.2, 2.4, 3.6, 4.8, 6.0])
56-
>>> y_test = np.array([6, 7, 8])
57-
>>> y_test_pred = np.array([6.1, 7.2, 8.3])
58-
>>> y_test_pred_low = np.array([6.0, 7.0, 8.0])
59-
>>> y_test_pred_high = np.array([6.2, 7.4, 8.6])
60-
>>> suptitle = "Prediction Intervals"
61-
>>> plot_predictionintervals(y_train, y_train_pred, y_train_pred_low, y_train_pred_high, y_test, y_test_pred, y_test_pred_low, y_test_pred_high, suptitle)
6239
"""
63-
64-
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 10))
40+
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=figsize) # Use figsize parameter
6541

6642
ax1.errorbar(
6743
x=y_train,

0 commit comments

Comments
 (0)