Skip to content

Commit c2ca9c5

Browse files
0.14.57
Improved plots for oml-book
1 parent 3280be0 commit c2ca9c5

4 files changed

Lines changed: 71 additions & 7 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.14.56"
10+
version = "0.14.57"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/plot/ts.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from spotPython.data.friedman import FriedmanDriftDataset
2+
import matplotlib.pyplot as plt
3+
4+
5+
def plot_friedman_drift_data(
6+
n_samples, seed, change_point1, change_point2, constant=True, show=True, filename=None
7+
) -> None:
8+
"""Plot the Friedman dataset with drifts at change_point1 and change_point2.
9+
10+
Args:
11+
n_samples (int):
12+
Number of samples to generate.
13+
seed (int):
14+
Seed for the random number generator.
15+
change_point1 (int):
16+
Index of the first drift point.
17+
change_point2 (int):
18+
Index of the second drift point.
19+
constant (bool, optional):
20+
If True, the drifts are constant. Defaults to True.
21+
filename (str, optional):
22+
Name of the file to save the plot. Defaults to None.
23+
24+
Returns:
25+
None
26+
27+
Examples:
28+
>>> from spotPython.plot.ts import plot_friedman_drift_data
29+
>>> plot_friedman_drift_data(n_samples=100, seed=42, change_point1=50, change_point2=75, constant=False)
30+
>>> plot_friedman_drift_data(n_samples=100, seed=42, change_point1=50, change_point2=75, constant=True)
31+
"""
32+
data_generator = FriedmanDriftDataset(
33+
n_samples=n_samples, seed=seed, change_point1=change_point1, change_point2=change_point2, constant=constant
34+
)
35+
data = [data for data in data_generator]
36+
indices = [i for _, _, i in data]
37+
values = {f"x{i}": [] for i in range(5)}
38+
values["y"] = []
39+
for x, y, _ in data:
40+
for i in range(5):
41+
values[f"x{i}"].append(x[i])
42+
values["y"].append(y)
43+
44+
plt.figure(figsize=(10, 6))
45+
for label, series in values.items():
46+
plt.plot(indices, series, label=label)
47+
plt.xlabel("Index")
48+
plt.ylabel("Value")
49+
plt.axvline(x=change_point1, color="k", linestyle="--", label="Drift Point 1")
50+
plt.axvline(x=change_point2, color="r", linestyle="--", label="Drift Point 2")
51+
plt.legend()
52+
plt.grid(True)
53+
if filename is not None:
54+
plt.savefig(filename)
55+
if show:
56+
plt.show()

src/spotPython/plot/validation.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -239,14 +239,20 @@ def plot_confusion_matrix(
239239
plt.show()
240240

241241

242-
def plot_actual_vs_predicted(y_test, y_pred, title=None, show=True) -> None:
242+
def plot_actual_vs_predicted(y_test, y_pred, title=None, show=True, filename=None) -> None:
243243
"""Plot actual vs. predicted values.
244244
245245
Args:
246-
y_test (np.ndarray): True values.
247-
y_pred (np.ndarray): Predicted values.
248-
title (str, optional): Title of the plot. Defaults to None.
249-
show (bool, optional): If True, the plot is shown. Defaults to True.
246+
y_test (np.ndarray):
247+
True values.
248+
y_pred (np.ndarray):
249+
Predicted values.
250+
title (str, optional):
251+
Title of the plot. Defaults to None.
252+
show (bool, optional):
253+
If True, the plot is shown. Defaults to True.
254+
filename (str, optional):
255+
Name of the file to save the plot. Defaults to None.
250256
251257
Returns:
252258
(NoneType): None
@@ -284,5 +290,7 @@ def plot_actual_vs_predicted(y_test, y_pred, title=None, show=True) -> None:
284290
if title is not None:
285291
fig.suptitle(title)
286292
plt.tight_layout()
293+
if filename is not None:
294+
plt.savefig(filename)
287295
if show:
288296
plt.show()

src/spotPython/plot/xy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def plot_y_vs_X(X, y, nrows=5, ncols=2, figsize=(30, 20), ylabel="y", feature_na
2222
feature_names (list of str, optional):
2323
List of feature names. Defaults to None. If None, generates feature names as x0, x1, etc.
2424
25-
Example:
25+
Examples:
2626
>>> from sklearn.datasets import load_diabetes
2727
>>> from spotPython.plot.xy import plot_y_vs_X
2828
>>> data = load_diabetes()

0 commit comments

Comments
 (0)