Skip to content

Commit 52e7ccf

Browse files
0.28.11
plot mo options
1 parent 59c9b65 commit 52e7ccf

2 files changed

Lines changed: 16 additions & 22 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.28.10"
10+
version = "0.28.11"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/mo/plot.py

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ def plot_mo(
1717
pareto_label: bool = False,
1818
y_rf_color="blue",
1919
y_best_color="red",
20+
x_axis_transformation: str = "id", # New argument for x-axis transformation
21+
y_axis_transformation: str = "id", # New argument for y-axis transformation
2022
) -> None:
2123
"""
2224
Generates scatter plots for each combination of two targets from a multi-output prediction while highlighting Pareto optimal points.
@@ -34,6 +36,8 @@ def plot_mo(
3436
pareto_label (bool): If True, label Pareto points with their index. Defaults to False.
3537
y_rf_color (str): The color of the predicted points. Defaults to "blue".
3638
y_best_color (str): The color of the best point. Defaults to "red".
39+
x_axis_transformation (str): Transformation for the x-axis. Options are "id" (linear), "log" (logarithmic), and "loglog" (log-log). Defaults to "id".
40+
y_axis_transformation (str): Transformation for the y-axis. Options are "id" (linear), "log" (logarithmic), and "loglog" (log-log). Defaults to "id".
3741
3842
Returns:
3943
None: Displays the plot.
@@ -66,58 +70,48 @@ def plot_mo(
6670

6771
# Plot original data if provided
6872
if y_orig is not None:
69-
# Determine Pareto optimal points for original data
7073
minimize = pareto == "min"
7174
pareto_mask_orig = is_pareto_efficient(y_orig[:, [i, j]], minimize)
72-
73-
# Plot all original points
7475
plt.scatter(y_orig[:, i], y_orig[:, j], edgecolor="w", c="gray", s=s, marker="o", alpha=a, label="Original Points")
75-
76-
# Highlight Pareto points for original data
7776
plt.scatter(y_orig[pareto_mask_orig, i], y_orig[pareto_mask_orig, j], edgecolor="k", c="gray", s=pareto_size, marker="o", alpha=a, label="Original Pareto")
78-
79-
# Label Pareto points for original data if requested
8077
if pareto_label:
8178
for idx in np.where(pareto_mask_orig)[0]:
8279
plt.text(y_orig[idx, i], y_orig[idx, j], str(idx), color="black", fontsize=8, ha="center", va="center")
83-
84-
# Draw Pareto front for original data if requested
8580
if pareto_front_orig:
8681
sorted_indices_orig = np.argsort(y_orig[pareto_mask_orig, i])
8782
plt.plot(y_orig[pareto_mask_orig, i][sorted_indices_orig], y_orig[pareto_mask_orig, j][sorted_indices_orig], "k-", alpha=a, label="Original Pareto Front")
8883

8984
if y_rf is not None:
90-
# Determine Pareto optimal points for predicted data
9185
minimize = pareto == "min"
9286
pareto_mask = is_pareto_efficient(y_rf[:, [i, j]], minimize)
93-
94-
# Plot all predicted points
9587
plt.scatter(y_rf[:, i], y_rf[:, j], edgecolor="w", c=y_rf_color, s=s, marker="^", alpha=a, label="Predicted Points")
96-
97-
# Highlight Pareto points for predicted data
9888
plt.scatter(y_rf[pareto_mask, i], y_rf[pareto_mask, j], edgecolor="k", c=y_rf_color, s=pareto_size, marker="s", alpha=a, label="Predicted Pareto")
99-
100-
# Label Pareto points for predicted data if requested
10189
if pareto_label:
10290
for idx in np.where(pareto_mask)[0]:
10391
plt.text(y_rf[idx, i], y_rf[idx, j], str(idx), color="black", fontsize=8, ha="center", va="center")
104-
105-
# Draw Pareto front for predicted data if requested
10692
if pareto_front:
10793
sorted_indices = np.argsort(y_rf[pareto_mask, i])
10894
plt.plot(
10995
y_rf[pareto_mask, i][sorted_indices],
11096
y_rf[pareto_mask, j][sorted_indices],
111-
linestyle="-", # Specify the line style
112-
color=y_rf_color, # Use the color specified by y_rf_color
97+
linestyle="-",
98+
color=y_rf_color,
11399
alpha=a,
114100
label="Predicted Pareto Front",
115101
)
116102

117-
# Plot the best point, if provided
118103
if y_best is not None:
119104
plt.scatter(y_best[:, i], y_best[:, j], edgecolor="k", c=y_best_color, s=s, marker="D", alpha=1, label="Best")
120105

106+
# Apply axis transformations
107+
if x_axis_transformation == "log":
108+
plt.xscale("log")
109+
if y_axis_transformation == "log":
110+
plt.yscale("log")
111+
if x_axis_transformation == "loglog" or y_axis_transformation == "loglog":
112+
plt.xscale("log")
113+
plt.yscale("log")
114+
121115
plt.xlabel(target_names[i])
122116
plt.ylabel(target_names[j])
123117
plt.grid()

0 commit comments

Comments
 (0)