|
| 1 | +import math |
| 2 | +import random |
| 3 | + |
| 4 | + |
| 5 | +class FriedmanDriftDataset: |
| 6 | + """Friedman Drift Dataset.""" |
| 7 | + |
| 8 | + def __init__(self, n_samples=100, change_point1=50, change_point2=75, seed=None, constant=False): |
| 9 | + """Constructor for the Friedman Drift Dataset. |
| 10 | +
|
| 11 | + Args: |
| 12 | + n_samples (int): The number of samples to generate. |
| 13 | + change_point1 (int): The index of the first change point. |
| 14 | + change_point2 (int): The index of the second change point. |
| 15 | + seed (int): The seed for the random number generator. |
| 16 | + constant (bool): If True, only the first feature is set to 1 and all others are set to 0. |
| 17 | +
|
| 18 | + Returns: |
| 19 | + None |
| 20 | +
|
| 21 | + Examples: |
| 22 | + >>> from spotPython.data.friedman import FriedmanDriftDataset |
| 23 | + data_generator = FriedmanDriftDataset(n_samples=100, |
| 24 | + seed=42, change_point1=50, change_point2=75, constant=False) |
| 25 | + data = [data for data in data_generator] |
| 26 | + indices = [i for _, _, i in data] |
| 27 | + values = {f"x{i}": [] for i in range(5)} |
| 28 | + values["y"] = [] |
| 29 | + for x, y, _ in data: |
| 30 | + for i in range(5): |
| 31 | + values[f"x{i}"].append(x[i]) |
| 32 | + values["y"].append(y) |
| 33 | + plt.figure(figsize=(10, 6)) |
| 34 | + for label, series in values.items(): |
| 35 | + plt.plot(indices, series, label=label) |
| 36 | + plt.xlabel('Index') |
| 37 | + plt.ylabel('Value') |
| 38 | + plt.title('') |
| 39 | + plt.axvline(x=50, color='k', linestyle='--', label='Drift Point 1') |
| 40 | + plt.axvline(x=75, color='r', linestyle='--', label='Drift Point 2') |
| 41 | + plt.legend() |
| 42 | + plt.grid(True) |
| 43 | + plt.show() |
| 44 | + """ |
| 45 | + self.n_samples = n_samples |
| 46 | + self._change_point1 = change_point1 |
| 47 | + self._change_point2 = change_point2 |
| 48 | + self.seed = seed |
| 49 | + self.index = 0 |
| 50 | + self.rng = random.Random(self.seed) |
| 51 | + self.constant = constant |
| 52 | + |
| 53 | + def __iter__(self): |
| 54 | + return self |
| 55 | + |
| 56 | + def __next__(self): |
| 57 | + if self.index >= self.n_samples: # Specifying end of generation |
| 58 | + raise StopIteration |
| 59 | + if self.constant: |
| 60 | + # x[0] is set to 1, all others to 0 |
| 61 | + x = {0: 1} |
| 62 | + x.update({i: 0 for i in range(1, 10)}) # All x[i] are 0 for i > 0 |
| 63 | + else: |
| 64 | + x = {i: self.rng.uniform(a=0, b=1) for i in range(10)} |
| 65 | + y = self._global_recurring_abrupt_gen(x, self.index) + self.rng.gauss(mu=0, sigma=1) |
| 66 | + result = (x, y, self.index) |
| 67 | + self.index += 1 |
| 68 | + return result |
| 69 | + |
| 70 | + def _global_recurring_abrupt_gen(self, x, index): |
| 71 | + if index < self._change_point1 or index >= self._change_point2: |
| 72 | + return 10 * math.sin(math.pi * x[0] * x[1]) + 20 * (x[2] - 0.5) ** 2 + 10 * x[3] + 5 * x[4] |
| 73 | + else: |
| 74 | + return 10 * math.sin(math.pi * x[3] * x[5]) + 20 * (x[1] - 0.5) ** 2 + 10 * x[0] + 5 * x[2] |
| 75 | + |
| 76 | + def __len__(self) -> int: |
| 77 | + """ |
| 78 | + Returns the length of the dataset. |
| 79 | +
|
| 80 | + Returns: |
| 81 | + int: The length of the dataset. |
| 82 | +
|
| 83 | +
|
| 84 | + """ |
| 85 | + return self.n_samples |
0 commit comments