@@ -134,14 +134,67 @@ def error_color(z_actual: float, z_predicted: float, eps: float = 1e-4, max_erro
134134 return f"#{ grey :02x} { grey :02x} { grey :02x} "
135135
136136
137- def plot_3d_surface (
138- ax : "matplotlib.axes.Axes" ,
137+ def plot_error_points (
138+ ax ,
139139 X : np .ndarray ,
140140 y : np .ndarray ,
141141 model ,
142142 i : int ,
143143 j : int ,
144+ eps : float = 1e-4 ,
145+ max_error : float = 1e-3 ,
146+ var_names : Optional [List [str ]] = None ,
147+ title : Optional [str ] = None ,
148+ z_mode : str = "actual" , # "actual", "error", or None
149+ ) -> None :
150+ """
151+ Scatter input points colored by prediction error.
152+
153+ Args:
154+ ax (matplotlib.axes.Axes): The matplotlib axis to plot on.
155+ X (np.ndarray): Input data, shape (n_samples, k).
156+ y (np.ndarray): Target values, shape (n_samples,).
157+ model (object): Fitted model with predict().
158+ i (int): Index of first varied dimension.
159+ j (int): Index of second varied dimension.
160+ eps (float): Tolerance for coloring points based on prediction error.
161+ max_error (float): Maximum error for color scaling.
162+ var_names (list of str or None): List of axis labels or None.
163+ title (str or None): Title for the plot.
164+ z_mode (str): "actual" for z_actual (for 3D), "error" for abs error (for 3D error surface), or None (for 2D).
165+ """
166+ n , k = X .shape
167+ check_ij (i , j , k )
168+ for idx in range (n ):
169+ x_point = X [idx , i ]
170+ y_point_ = X [idx , j ]
171+ z_actual = y [idx ]
172+ z_predicted = model .predict (X [idx ].reshape (1 , - 1 ))[0 ]
173+ color = error_color (z_actual , z_predicted , eps , max_error )
174+ if z_mode == "actual" :
175+ ax .scatter (x_point , y_point_ , z_actual , color = color , s = 50 , edgecolor = "black" )
176+ elif z_mode == "error" :
177+ ax .scatter (x_point , y_point_ , abs (z_actual - z_predicted ), color = color , s = 50 , edgecolor = "black" )
178+ else :
179+ ax .scatter (x_point , y_point_ , color = color , s = 50 , edgecolor = "black" )
180+ if title is not None :
181+ ax .set_title (title )
182+ if var_names is not None :
183+ ax .set_xlabel (var_names [0 ])
184+ ax .set_ylabel (var_names [1 ])
185+ else :
186+ ax .set_xlabel (f"Dimension { i } " )
187+ ax .set_ylabel (f"Dimension { j } " )
188+
189+
190+ def plot_3d_surface (
144191 Z ,
192+ i : int ,
193+ j : int ,
194+ ax : "matplotlib.axes.Axes" ,
195+ X : np .ndarray = None ,
196+ y : np .ndarray = None ,
197+ model = None ,
145198 surface_label : str = "Prediction Surface" ,
146199 zlabel : str = "Prediction" ,
147200 var_names : Optional [List [str ]] = None ,
@@ -150,18 +203,19 @@ def plot_3d_surface(
150203 max_error : float = 1e-3 ,
151204 cmap : str = "jet" ,
152205 error_surface : bool = False ,
206+ add_points : bool = False ,
153207) -> None :
154208 """
155209 Plot a 3D surface and scatter input points, colored by prediction error.
156210
157211 Args:
212+ Z (tuple or np.ndarray): Surface values to plot, shape matching meshgrid.
213+ i (int): Index of first varied dimension.
214+ j (int): Index of second varied dimension.
158215 ax (matplotlib.axes.Axes): Matplotlib 3D axis.
159216 X (np.ndarray): Input data, shape (n_samples, k).
160217 y (np.ndarray): Target values, shape (n_samples,).
161218 model (object): Fitted model with predict().
162- i (int): Index of first varied dimension.
163- j (int): Index of second varied dimension.
164- Z (tuple or np.ndarray): Surface values to plot, shape matching meshgrid.
165219 surface_label (str): Title for the surface.
166220 zlabel (str): Label for the z-axis.
167221 var_names (list of str or None): List of axis labels or None.
@@ -170,6 +224,7 @@ def plot_3d_surface(
170224 max_error (float): Maximum error for color scaling.
171225 cmap (str): Colormap for the surface.
172226 error_surface (bool): If True, scatter z is abs(y_actual - y_predicted).
227+ add_points (bool): If True, adds scatter points to the surface plot.
173228
174229 Returns:
175230 None
@@ -179,71 +234,70 @@ def plot_3d_surface(
179234 ax .set_xlabel (var_names [0 ] if var_names else f"Dimension { i } " )
180235 ax .set_ylabel (var_names [1 ] if var_names else f"Dimension { j } " )
181236 ax .set_zlabel (var_names [2 ] if var_names else zlabel )
182- for idx in range (X .shape [0 ]):
183- x_point = X [idx , i ]
184- y_point = X [idx , j ]
185- z_actual = y [idx ]
186- z_predicted = model .predict (X [idx ].reshape (1 , - 1 ))[0 ]
187- if error_surface :
188- z_scatter = abs (z_actual - z_predicted )
189- else :
190- z_scatter = z_actual
191- color = error_color (z_actual , z_predicted , eps , max_error )
192- ax .scatter (x_point , y_point , z_scatter , color = color , s = 50 , edgecolor = "black" )
237+ if add_points and X is not None and y is not None and model is not None :
238+ plot_error_points (ax , X , y , model , i , j , eps , max_error , var_names , z_mode = "error" if error_surface else "actual" )
193239
194240
195241def plot_contour (
196- ax ,
197242 X_i : np .ndarray ,
198243 X_j : np .ndarray ,
199244 Z : np .ndarray ,
200- X : np .ndarray ,
201- y : np .ndarray ,
202- model ,
203245 i : int ,
204246 j : int ,
247+ ax : "matplotlib.axes.Axes" ,
248+ X : np .ndarray = None ,
249+ y : np .ndarray = None ,
250+ model = None ,
205251 eps : float = 1e-4 ,
206252 max_error : float = 1e-3 ,
207253 var_names : Optional [List [str ]] = None ,
208254 cmap : str = "jet" ,
209255 levels : int = 30 ,
210256 title : str = "Prediction Contour" ,
257+ add_points : bool = False ,
211258) -> None :
212259 """
213260 Plot a filled contour plot with scatter points colored by prediction error.
214261
215262 Args:
216- ax (matplotlib.axes.Axes): The matplotlib axis to plot on.
217263 X_i (np.ndarray): Meshgrid for the i-th dimension.
218264 X_j (np.ndarray): Meshgrid for the j-th dimension.
219265 Z (np.ndarray): Contour values (predicted or error), shape matching meshgrid.
266+ i (int): Index of first varied dimension.
267+ j (int): Index of second varied dimension.
268+ ax (matplotlib.axes.Axes): The matplotlib axis to plot on.
220269 X (np.ndarray): Input data, shape (n_samples, k).
221270 y (np.ndarray): Target values, shape (n_samples,).
222271 model (object): Fitted model with predict().
223- i (int): Index of first varied dimension.
224- j (int): Index of second varied dimension.
225272 eps (float): Tolerance for coloring points based on prediction error.
226273 max_error (float): Maximum error for color scaling.
227274 var_names (list of str or None): List of axis labels or None.
228275 cmap (str): Colormap for the contour plot.
229276 levels (int): Number of contour levels.
230277 title (str): Title for the plot.
278+ add_points (bool): If True, adds scatter points to the contour plot.
231279
232280 Returns:
233281 None
234282 """
235283 contour = ax .contourf (X_i , X_j , Z , cmap = cmap , levels = levels )
236284 plt .colorbar (contour , ax = ax )
237- for idx in range (X .shape [0 ]):
238- x_point = X [idx , i ]
239- y_point = X [idx , j ]
240- z_actual = y [idx ]
241- z_predicted = model .predict (X [idx ].reshape (1 , - 1 ))[0 ]
242- color = error_color (z_actual , z_predicted , eps , max_error )
243- ax .scatter (x_point , y_point , color = color , s = 50 , edgecolor = "black" )
244- ax .set_title (title )
245- ax .set_xlabel (var_names [0 ] if var_names else f"Dimension { i } " )
246- ax .set_ylabel (var_names [1 ] if var_names else f"Dimension { j } " )
285+ if add_points and X is not None and y is not None and model is not None :
286+ plot_error_points (ax , X , y , model , i , j , eps , max_error , var_names , title , z_mode = None )
287+
288+
289+ def check_ij (i : int , j : int , k : int ) -> None :
290+ """
291+ Check if indices i and j are valid for the number of features k.
292+ Args:
293+ i (int): Index of the first dimension.
294+ j (int): Index of the second dimension.
295+ k (int): Total number of features.
296+ """
297+ if i >= k or j >= k :
298+ raise ValueError (f"Dimensions i and j must be less than the number of features (k={ k } )." )
299+ if i == j :
300+ raise ValueError ("Dimensions i and j must be different." )
247301
248302
249303def plotkd (
@@ -291,11 +345,7 @@ def plotkd(
291345
292346 """
293347 k = X .shape [1 ]
294- if i >= k or j >= k :
295- raise ValueError (f"Dimensions i and j must be less than the number of features (k={ k } )." )
296- if i == j :
297- raise ValueError ("Dimensions i and j must be different." )
298-
348+ check_ij (i , j , k )
299349 X_i , X_j , grid_points = generate_mesh_grid (X , i , j , n_grid )
300350
301351 # Predict the values and standard deviations
@@ -308,13 +358,13 @@ def plotkd(
308358 # Plot predicted values
309359 ax1 = fig .add_subplot (221 , projection = "3d" )
310360 plot_3d_surface (
361+ (X_i , X_j , Z_pred ),
362+ i ,
363+ j ,
311364 ax1 ,
312365 X ,
313366 y ,
314367 model ,
315- i ,
316- j ,
317- (X_i , X_j , Z_pred ),
318368 surface_label = "Prediction Surface" ,
319369 zlabel = "Prediction" ,
320370 var_names = var_names ,
@@ -328,13 +378,13 @@ def plotkd(
328378 # Plot prediction error
329379 ax2 = fig .add_subplot (222 , projection = "3d" )
330380 plot_3d_surface (
381+ (X_i , X_j , Z_std ),
382+ i ,
383+ j ,
331384 ax2 ,
332385 X ,
333386 y ,
334387 model ,
335- i ,
336- j ,
337- (X_i , X_j , Z_std ),
338388 surface_label = "Prediction Error Surface" ,
339389 zlabel = "Error" ,
340390 var_names = var_names ,
@@ -348,15 +398,15 @@ def plotkd(
348398 # Contour plot of predicted values
349399 ax3 = fig .add_subplot (223 )
350400 plot_contour (
351- ax3 ,
352401 X_i ,
353402 X_j ,
354403 Z_pred ,
404+ i ,
405+ j ,
406+ ax3 ,
355407 X ,
356408 y ,
357409 model ,
358- i ,
359- j ,
360410 eps = eps ,
361411 max_error = max_error ,
362412 var_names = var_names ,
@@ -368,15 +418,15 @@ def plotkd(
368418 # Contour plot of prediction error
369419 ax4 = fig .add_subplot (224 )
370420 plot_contour (
371- ax4 ,
372421 X_i ,
373422 X_j ,
374423 Z_std ,
424+ i ,
425+ j ,
426+ ax4 ,
375427 X ,
376428 y ,
377429 model ,
378- i ,
379- j ,
380430 eps = eps ,
381431 max_error = max_error ,
382432 var_names = var_names ,
0 commit comments