Skip to content

Commit ecbc76c

Browse files
0.31.6
1 parent 1973616 commit ecbc76c

3 files changed

Lines changed: 108 additions & 58 deletions

File tree

notebooks/00_spotPython_tests.ipynb

Lines changed: 7 additions & 7 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.31.5"
10+
version = "0.31.6"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/surrogate/plot.py

Lines changed: 100 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,67 @@ def error_color(z_actual: float, z_predicted: float, eps: float = 1e-4, max_erro
134134
return f"#{grey:02x}{grey:02x}{grey:02x}"
135135

136136

137-
def plot_3d_surface(
138-
ax: "matplotlib.axes.Axes",
137+
def plot_error_points(
138+
ax,
139139
X: np.ndarray,
140140
y: np.ndarray,
141141
model,
142142
i: int,
143143
j: int,
144+
eps: float = 1e-4,
145+
max_error: float = 1e-3,
146+
var_names: Optional[List[str]] = None,
147+
title: Optional[str] = None,
148+
z_mode: str = "actual", # "actual", "error", or None
149+
) -> None:
150+
"""
151+
Scatter input points colored by prediction error.
152+
153+
Args:
154+
ax (matplotlib.axes.Axes): The matplotlib axis to plot on.
155+
X (np.ndarray): Input data, shape (n_samples, k).
156+
y (np.ndarray): Target values, shape (n_samples,).
157+
model (object): Fitted model with predict().
158+
i (int): Index of first varied dimension.
159+
j (int): Index of second varied dimension.
160+
eps (float): Tolerance for coloring points based on prediction error.
161+
max_error (float): Maximum error for color scaling.
162+
var_names (list of str or None): List of axis labels or None.
163+
title (str or None): Title for the plot.
164+
z_mode (str): "actual" for z_actual (for 3D), "error" for abs error (for 3D error surface), or None (for 2D).
165+
"""
166+
n, k = X.shape
167+
check_ij(i, j, k)
168+
for idx in range(n):
169+
x_point = X[idx, i]
170+
y_point_ = X[idx, j]
171+
z_actual = y[idx]
172+
z_predicted = model.predict(X[idx].reshape(1, -1))[0]
173+
color = error_color(z_actual, z_predicted, eps, max_error)
174+
if z_mode == "actual":
175+
ax.scatter(x_point, y_point_, z_actual, color=color, s=50, edgecolor="black")
176+
elif z_mode == "error":
177+
ax.scatter(x_point, y_point_, abs(z_actual - z_predicted), color=color, s=50, edgecolor="black")
178+
else:
179+
ax.scatter(x_point, y_point_, color=color, s=50, edgecolor="black")
180+
if title is not None:
181+
ax.set_title(title)
182+
if var_names is not None:
183+
ax.set_xlabel(var_names[0])
184+
ax.set_ylabel(var_names[1])
185+
else:
186+
ax.set_xlabel(f"Dimension {i}")
187+
ax.set_ylabel(f"Dimension {j}")
188+
189+
190+
def plot_3d_surface(
144191
Z,
192+
i: int,
193+
j: int,
194+
ax: "matplotlib.axes.Axes",
195+
X: np.ndarray = None,
196+
y: np.ndarray = None,
197+
model=None,
145198
surface_label: str = "Prediction Surface",
146199
zlabel: str = "Prediction",
147200
var_names: Optional[List[str]] = None,
@@ -150,18 +203,19 @@ def plot_3d_surface(
150203
max_error: float = 1e-3,
151204
cmap: str = "jet",
152205
error_surface: bool = False,
206+
add_points: bool = False,
153207
) -> None:
154208
"""
155209
Plot a 3D surface and scatter input points, colored by prediction error.
156210
157211
Args:
212+
Z (tuple or np.ndarray): Surface values to plot, shape matching meshgrid.
213+
i (int): Index of first varied dimension.
214+
j (int): Index of second varied dimension.
158215
ax (matplotlib.axes.Axes): Matplotlib 3D axis.
159216
X (np.ndarray): Input data, shape (n_samples, k).
160217
y (np.ndarray): Target values, shape (n_samples,).
161218
model (object): Fitted model with predict().
162-
i (int): Index of first varied dimension.
163-
j (int): Index of second varied dimension.
164-
Z (tuple or np.ndarray): Surface values to plot, shape matching meshgrid.
165219
surface_label (str): Title for the surface.
166220
zlabel (str): Label for the z-axis.
167221
var_names (list of str or None): List of axis labels or None.
@@ -170,6 +224,7 @@ def plot_3d_surface(
170224
max_error (float): Maximum error for color scaling.
171225
cmap (str): Colormap for the surface.
172226
error_surface (bool): If True, scatter z is abs(y_actual - y_predicted).
227+
add_points (bool): If True, adds scatter points to the surface plot.
173228
174229
Returns:
175230
None
@@ -179,71 +234,70 @@ def plot_3d_surface(
179234
ax.set_xlabel(var_names[0] if var_names else f"Dimension {i}")
180235
ax.set_ylabel(var_names[1] if var_names else f"Dimension {j}")
181236
ax.set_zlabel(var_names[2] if var_names else zlabel)
182-
for idx in range(X.shape[0]):
183-
x_point = X[idx, i]
184-
y_point = X[idx, j]
185-
z_actual = y[idx]
186-
z_predicted = model.predict(X[idx].reshape(1, -1))[0]
187-
if error_surface:
188-
z_scatter = abs(z_actual - z_predicted)
189-
else:
190-
z_scatter = z_actual
191-
color = error_color(z_actual, z_predicted, eps, max_error)
192-
ax.scatter(x_point, y_point, z_scatter, color=color, s=50, edgecolor="black")
237+
if add_points and X is not None and y is not None and model is not None:
238+
plot_error_points(ax, X, y, model, i, j, eps, max_error, var_names, z_mode="error" if error_surface else "actual")
193239

194240

195241
def plot_contour(
196-
ax,
197242
X_i: np.ndarray,
198243
X_j: np.ndarray,
199244
Z: np.ndarray,
200-
X: np.ndarray,
201-
y: np.ndarray,
202-
model,
203245
i: int,
204246
j: int,
247+
ax: "matplotlib.axes.Axes",
248+
X: np.ndarray = None,
249+
y: np.ndarray = None,
250+
model=None,
205251
eps: float = 1e-4,
206252
max_error: float = 1e-3,
207253
var_names: Optional[List[str]] = None,
208254
cmap: str = "jet",
209255
levels: int = 30,
210256
title: str = "Prediction Contour",
257+
add_points: bool = False,
211258
) -> None:
212259
"""
213260
Plot a filled contour plot with scatter points colored by prediction error.
214261
215262
Args:
216-
ax (matplotlib.axes.Axes): The matplotlib axis to plot on.
217263
X_i (np.ndarray): Meshgrid for the i-th dimension.
218264
X_j (np.ndarray): Meshgrid for the j-th dimension.
219265
Z (np.ndarray): Contour values (predicted or error), shape matching meshgrid.
266+
i (int): Index of first varied dimension.
267+
j (int): Index of second varied dimension.
268+
ax (matplotlib.axes.Axes): The matplotlib axis to plot on.
220269
X (np.ndarray): Input data, shape (n_samples, k).
221270
y (np.ndarray): Target values, shape (n_samples,).
222271
model (object): Fitted model with predict().
223-
i (int): Index of first varied dimension.
224-
j (int): Index of second varied dimension.
225272
eps (float): Tolerance for coloring points based on prediction error.
226273
max_error (float): Maximum error for color scaling.
227274
var_names (list of str or None): List of axis labels or None.
228275
cmap (str): Colormap for the contour plot.
229276
levels (int): Number of contour levels.
230277
title (str): Title for the plot.
278+
add_points (bool): If True, adds scatter points to the contour plot.
231279
232280
Returns:
233281
None
234282
"""
235283
contour = ax.contourf(X_i, X_j, Z, cmap=cmap, levels=levels)
236284
plt.colorbar(contour, ax=ax)
237-
for idx in range(X.shape[0]):
238-
x_point = X[idx, i]
239-
y_point = X[idx, j]
240-
z_actual = y[idx]
241-
z_predicted = model.predict(X[idx].reshape(1, -1))[0]
242-
color = error_color(z_actual, z_predicted, eps, max_error)
243-
ax.scatter(x_point, y_point, color=color, s=50, edgecolor="black")
244-
ax.set_title(title)
245-
ax.set_xlabel(var_names[0] if var_names else f"Dimension {i}")
246-
ax.set_ylabel(var_names[1] if var_names else f"Dimension {j}")
285+
if add_points and X is not None and y is not None and model is not None:
286+
plot_error_points(ax, X, y, model, i, j, eps, max_error, var_names, title, z_mode=None)
287+
288+
289+
def check_ij(i: int, j: int, k: int) -> None:
290+
"""
291+
Check if indices i and j are valid for the number of features k.
292+
Args:
293+
i (int): Index of the first dimension.
294+
j (int): Index of the second dimension.
295+
k (int): Total number of features.
296+
"""
297+
if i >= k or j >= k:
298+
raise ValueError(f"Dimensions i and j must be less than the number of features (k={k}).")
299+
if i == j:
300+
raise ValueError("Dimensions i and j must be different.")
247301

248302

249303
def plotkd(
@@ -291,11 +345,7 @@ def plotkd(
291345
292346
"""
293347
k = X.shape[1]
294-
if i >= k or j >= k:
295-
raise ValueError(f"Dimensions i and j must be less than the number of features (k={k}).")
296-
if i == j:
297-
raise ValueError("Dimensions i and j must be different.")
298-
348+
check_ij(i, j, k)
299349
X_i, X_j, grid_points = generate_mesh_grid(X, i, j, n_grid)
300350

301351
# Predict the values and standard deviations
@@ -308,13 +358,13 @@ def plotkd(
308358
# Plot predicted values
309359
ax1 = fig.add_subplot(221, projection="3d")
310360
plot_3d_surface(
361+
(X_i, X_j, Z_pred),
362+
i,
363+
j,
311364
ax1,
312365
X,
313366
y,
314367
model,
315-
i,
316-
j,
317-
(X_i, X_j, Z_pred),
318368
surface_label="Prediction Surface",
319369
zlabel="Prediction",
320370
var_names=var_names,
@@ -328,13 +378,13 @@ def plotkd(
328378
# Plot prediction error
329379
ax2 = fig.add_subplot(222, projection="3d")
330380
plot_3d_surface(
381+
(X_i, X_j, Z_std),
382+
i,
383+
j,
331384
ax2,
332385
X,
333386
y,
334387
model,
335-
i,
336-
j,
337-
(X_i, X_j, Z_std),
338388
surface_label="Prediction Error Surface",
339389
zlabel="Error",
340390
var_names=var_names,
@@ -348,15 +398,15 @@ def plotkd(
348398
# Contour plot of predicted values
349399
ax3 = fig.add_subplot(223)
350400
plot_contour(
351-
ax3,
352401
X_i,
353402
X_j,
354403
Z_pred,
404+
i,
405+
j,
406+
ax3,
355407
X,
356408
y,
357409
model,
358-
i,
359-
j,
360410
eps=eps,
361411
max_error=max_error,
362412
var_names=var_names,
@@ -368,15 +418,15 @@ def plotkd(
368418
# Contour plot of prediction error
369419
ax4 = fig.add_subplot(224)
370420
plot_contour(
371-
ax4,
372421
X_i,
373422
X_j,
374423
Z_std,
424+
i,
425+
j,
426+
ax4,
375427
X,
376428
y,
377429
model,
378-
i,
379-
j,
380430
eps=eps,
381431
max_error=max_error,
382432
var_names=var_names,

0 commit comments

Comments
 (0)