Skip to content

Commit 2a43204

Browse files
0.31.7
Before changing the Spot.plot() method
1 parent ecbc76c commit 2a43204

3 files changed

Lines changed: 50 additions & 74 deletions

File tree

notebooks/00_spotPython_tests.ipynb

Lines changed: 20 additions & 68 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.31.6"
10+
version = "0.31.7"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/surrogate/plot.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,8 @@ def plot_3d_surface(
204204
cmap: str = "jet",
205205
error_surface: bool = False,
206206
add_points: bool = False,
207+
vmin: float = None,
208+
vmax: float = None,
207209
) -> None:
208210
"""
209211
Plot a 3D surface and scatter input points, colored by prediction error.
@@ -229,7 +231,7 @@ def plot_3d_surface(
229231
Returns:
230232
None
231233
"""
232-
ax.plot_surface(*Z[:2], Z[2], cmap=cmap, alpha=alpha) if isinstance(Z, tuple) else ax.plot_surface(Z[0], Z[1], Z[2], cmap=cmap, alpha=alpha)
234+
ax.plot_surface(*Z[:2], Z[2], cmap=cmap, alpha=alpha, vmin=vmin, vmax=vmax) if isinstance(Z, tuple) else ax.plot_surface(Z[0], Z[1], Z[2], cmap=cmap, alpha=alpha, vmin=vmin, vmax=vmax)
233235
ax.set_title(surface_label)
234236
ax.set_xlabel(var_names[0] if var_names else f"Dimension {i}")
235237
ax.set_ylabel(var_names[1] if var_names else f"Dimension {j}")
@@ -238,7 +240,7 @@ def plot_3d_surface(
238240
plot_error_points(ax, X, y, model, i, j, eps, max_error, var_names, z_mode="error" if error_surface else "actual")
239241

240242

241-
def plot_contour(
243+
def plot_contour_and_err(
242244
X_i: np.ndarray,
243245
X_j: np.ndarray,
244246
Z: np.ndarray,
@@ -255,6 +257,8 @@ def plot_contour(
255257
levels: int = 30,
256258
title: str = "Prediction Contour",
257259
add_points: bool = False,
260+
vmin: float = None,
261+
vmax: float = None,
258262
) -> None:
259263
"""
260264
Plot a filled contour plot with scatter points colored by prediction error.
@@ -276,11 +280,13 @@ def plot_contour(
276280
levels (int): Number of contour levels.
277281
title (str): Title for the plot.
278282
add_points (bool): If True, adds scatter points to the contour plot.
283+
vmin (float): Minimum value for color scaling.
284+
vmax (float): Maximum value for color scaling.
279285
280286
Returns:
281287
None
282288
"""
283-
contour = ax.contourf(X_i, X_j, Z, cmap=cmap, levels=levels)
289+
contour = ax.contourf(X_i, X_j, Z, cmap=cmap, levels=levels, vmin=vmin, vmax=vmax)
284290
plt.colorbar(contour, ax=ax)
285291
if add_points and X is not None and y is not None and model is not None:
286292
plot_error_points(ax, X, y, model, i, j, eps, max_error, var_names, title, z_mode=None)
@@ -313,6 +319,9 @@ def plotkd(
313319
var_names: Optional[List[str]] = None,
314320
cmap: str = "jet",
315321
n_grid: int = 100,
322+
vmin: Optional[float] = None,
323+
vmax: Optional[float] = None,
324+
add_points: bool = False,
316325
) -> None:
317326
"""
318327
Plots the Kriging surrogate model for k-dimensional input data by varying two dimensions (i, j).
@@ -330,6 +339,9 @@ def plotkd(
330339
var_names (list of str, optional): List of variable names for axis labeling. If None, generic labels are used.
331340
cmap (str): Colormap for the surface and contour plots. Default is "jet".
332341
n_grid (int): Number of grid points per dimension for the mesh grid. Default is 100.
342+
vmin (float, optional): Minimum value for the color scale. If None, determined from predictions.
343+
vmax (float, optional): Maximum value for the color scale. If None, determined from predictions.
344+
add_points (bool): If True, adds scatter points to the surface and contour plots. Default is False.
333345
334346
Examples:
335347
>>> import numpy as np
@@ -373,6 +385,9 @@ def plotkd(
373385
max_error=max_error,
374386
cmap=cmap,
375387
error_surface=False,
388+
vmin=vmin,
389+
vmax=vmax,
390+
add_points=add_points,
376391
)
377392

378393
# Plot prediction error
@@ -393,11 +408,14 @@ def plotkd(
393408
max_error=max_error,
394409
cmap=cmap,
395410
error_surface=True,
411+
vmin=vmin,
412+
vmax=vmax,
413+
add_points=add_points,
396414
)
397415

398416
# Contour plot of predicted values
399417
ax3 = fig.add_subplot(223)
400-
plot_contour(
418+
plot_contour_and_err(
401419
X_i,
402420
X_j,
403421
Z_pred,
@@ -413,11 +431,14 @@ def plotkd(
413431
cmap=cmap,
414432
levels=30,
415433
title="Prediction Contour",
434+
vmin=vmin,
435+
vmax=vmax,
436+
add_points=add_points,
416437
)
417438

418439
# Contour plot of prediction error
419440
ax4 = fig.add_subplot(224)
420-
plot_contour(
441+
plot_contour_and_err(
421442
X_i,
422443
X_j,
423444
Z_std,
@@ -433,6 +454,9 @@ def plotkd(
433454
cmap=cmap,
434455
levels=30,
435456
title="Error Contour",
457+
vmin=vmin,
458+
vmax=vmax,
459+
add_points=add_points,
436460
)
437461

438462
if show:

0 commit comments

Comments
 (0)