@@ -17,6 +17,8 @@ def plot_mo(
1717 pareto_label : bool = False ,
1818 y_rf_color = "blue" ,
1919 y_best_color = "red" ,
20+ x_axis_transformation : str = "id" , # New argument for x-axis transformation
21+ y_axis_transformation : str = "id" , # New argument for y-axis transformation
2022) -> None :
2123 """
2224 Generates scatter plots for each combination of two targets from a multi-output prediction while highlighting Pareto optimal points.
@@ -34,6 +36,8 @@ def plot_mo(
3436 pareto_label (bool): If True, label Pareto points with their index. Defaults to False.
3537 y_rf_color (str): The color of the predicted points. Defaults to "blue".
3638 y_best_color (str): The color of the best point. Defaults to "red".
39+ x_axis_transformation (str): Transformation for the x-axis. Options are "id" (linear), "log" (logarithmic), and "loglog" (log-log). Defaults to "id".
40+ y_axis_transformation (str): Transformation for the y-axis. Options are "id" (linear), "log" (logarithmic), and "loglog" (log-log). Defaults to "id".
3741
3842 Returns:
3943 None: Displays the plot.
@@ -66,58 +70,48 @@ def plot_mo(
6670
6771 # Plot original data if provided
6872 if y_orig is not None :
69- # Determine Pareto optimal points for original data
7073 minimize = pareto == "min"
7174 pareto_mask_orig = is_pareto_efficient (y_orig [:, [i , j ]], minimize )
72-
73- # Plot all original points
7475 plt .scatter (y_orig [:, i ], y_orig [:, j ], edgecolor = "w" , c = "gray" , s = s , marker = "o" , alpha = a , label = "Original Points" )
75-
76- # Highlight Pareto points for original data
7776 plt .scatter (y_orig [pareto_mask_orig , i ], y_orig [pareto_mask_orig , j ], edgecolor = "k" , c = "gray" , s = pareto_size , marker = "o" , alpha = a , label = "Original Pareto" )
78-
79- # Label Pareto points for original data if requested
8077 if pareto_label :
8178 for idx in np .where (pareto_mask_orig )[0 ]:
8279 plt .text (y_orig [idx , i ], y_orig [idx , j ], str (idx ), color = "black" , fontsize = 8 , ha = "center" , va = "center" )
83-
84- # Draw Pareto front for original data if requested
8580 if pareto_front_orig :
8681 sorted_indices_orig = np .argsort (y_orig [pareto_mask_orig , i ])
8782 plt .plot (y_orig [pareto_mask_orig , i ][sorted_indices_orig ], y_orig [pareto_mask_orig , j ][sorted_indices_orig ], "k-" , alpha = a , label = "Original Pareto Front" )
8883
8984 if y_rf is not None :
90- # Determine Pareto optimal points for predicted data
9185 minimize = pareto == "min"
9286 pareto_mask = is_pareto_efficient (y_rf [:, [i , j ]], minimize )
93-
94- # Plot all predicted points
9587 plt .scatter (y_rf [:, i ], y_rf [:, j ], edgecolor = "w" , c = y_rf_color , s = s , marker = "^" , alpha = a , label = "Predicted Points" )
96-
97- # Highlight Pareto points for predicted data
9888 plt .scatter (y_rf [pareto_mask , i ], y_rf [pareto_mask , j ], edgecolor = "k" , c = y_rf_color , s = pareto_size , marker = "s" , alpha = a , label = "Predicted Pareto" )
99-
100- # Label Pareto points for predicted data if requested
10189 if pareto_label :
10290 for idx in np .where (pareto_mask )[0 ]:
10391 plt .text (y_rf [idx , i ], y_rf [idx , j ], str (idx ), color = "black" , fontsize = 8 , ha = "center" , va = "center" )
104-
105- # Draw Pareto front for predicted data if requested
10692 if pareto_front :
10793 sorted_indices = np .argsort (y_rf [pareto_mask , i ])
10894 plt .plot (
10995 y_rf [pareto_mask , i ][sorted_indices ],
11096 y_rf [pareto_mask , j ][sorted_indices ],
111- linestyle = "-" , # Specify the line style
112- color = y_rf_color , # Use the color specified by y_rf_color
97+ linestyle = "-" ,
98+ color = y_rf_color ,
11399 alpha = a ,
114100 label = "Predicted Pareto Front" ,
115101 )
116102
117- # Plot the best point, if provided
118103 if y_best is not None :
119104 plt .scatter (y_best [:, i ], y_best [:, j ], edgecolor = "k" , c = y_best_color , s = s , marker = "D" , alpha = 1 , label = "Best" )
120105
106+ # Apply axis transformations
107+ if x_axis_transformation == "log" :
108+ plt .xscale ("log" )
109+ if y_axis_transformation == "log" :
110+ plt .yscale ("log" )
111+ if x_axis_transformation == "loglog" or y_axis_transformation == "loglog" :
112+ plt .xscale ("log" )
113+ plt .yscale ("log" )
114+
121115 plt .xlabel (target_names [i ])
122116 plt .ylabel (target_names [j ])
123117 plt .grid ()
0 commit comments