@@ -12,6 +12,7 @@ def plot_predictionintervals(
1212 y_test_pred_low ,
1313 y_test_pred_high ,
1414 suptitle : str ,
15+ figsize : tuple = (10 , 10 ), # Default figsize added
1516) -> None :
1617 """
1718 Plots prediction intervals for training and testing data.
@@ -31,37 +32,12 @@ def plot_predictionintervals(
3132 y_test_pred_low (array-like): Lower bounds of prediction intervals for the testing set.
3233 y_test_pred_high (array-like): Upper bounds of prediction intervals for the testing set.
3334 suptitle (str): The title for the entire figure.
35+ figsize (tuple, optional): Size of the figure. Default is (10, 10).
3436
3537 Returns:
3638 None: The function displays the plots but does not return any value.
37-
38- Notes:
39- - The first subplot compares true and predicted values with error bars for both training
40- and testing data.
41- - The second subplot visualizes the width of prediction intervals as a function of true values.
42- - The third subplot orders the prediction interval widths and displays them for both
43- training and testing data.
44- - The fourth subplot shows histograms of the interval widths for training and testing data.
45-
46- References:
47- Function adapted from: https://github.com/scikit-learn-contrib/MAPIE/blob/master/notebooks/regression/exoplanets.ipynb
48-
49- Examples:
50- >>> import numpy as np
51- >>> from spotpython.uc.plot import plot_predictionintervals
52- >>> y_train = np.array([1, 2, 3, 4, 5])
53- >>> y_train_pred = np.array([1.1, 2.2, 3.3, 4.4, 5.5])
54- >>> y_train_pred_low = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
55- >>> y_train_pred_high = np.array([1.2, 2.4, 3.6, 4.8, 6.0])
56- >>> y_test = np.array([6, 7, 8])
57- >>> y_test_pred = np.array([6.1, 7.2, 8.3])
58- >>> y_test_pred_low = np.array([6.0, 7.0, 8.0])
59- >>> y_test_pred_high = np.array([6.2, 7.4, 8.6])
60- >>> suptitle = "Prediction Intervals"
61- >>> plot_predictionintervals(y_train, y_train_pred, y_train_pred_low, y_train_pred_high, y_test, y_test_pred, y_test_pred_low, y_test_pred_high, suptitle)
6239 """
63-
64- fig , ((ax1 , ax2 ), (ax3 , ax4 )) = plt .subplots (2 , 2 , figsize = (12 , 10 ))
40+ fig , ((ax1 , ax2 ), (ax3 , ax4 )) = plt .subplots (2 , 2 , figsize = figsize ) # Use figsize parameter
6541
6642 ax1 .errorbar (
6743 x = y_train ,
0 commit comments