Skip to content

Commit 43218bf

Browse files
0.12.9
max_surrogate_points
1 parent cdffd63 commit 43218bf

5 files changed

Lines changed: 126 additions & 4 deletions

File tree

notebooks/00_spotPython_tests.ipynb

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3558,6 +3558,74 @@
35583558
"outputs": [],
35593559
"source": []
35603560
},
3561+
{
3562+
"cell_type": "markdown",
3563+
"metadata": {},
3564+
"source": [
3565+
"# Subset Select"
3566+
]
3567+
},
3568+
{
3569+
"cell_type": "code",
3570+
"execution_count": 3,
3571+
"metadata": {},
3572+
"outputs": [],
3573+
"source": [
3574+
"import numpy as np\n",
3575+
"from sklearn.cluster import KMeans\n",
3576+
"\n",
3577+
"def select_distant_points(X, y, k):\n",
3578+
" \"\"\"\n",
3579+
" Selects k points that are distant from each other using a clustering approach.\n",
3580+
" \n",
3581+
" :param X: np.array of shape (n, k), with n points in k-dimensional space.\n",
3582+
" :param y: np.array of length n, with values corresponding to each point in X.\n",
3583+
" :param k: The number of distant points to select.\n",
3584+
" :return: Selected k points from X and their corresponding y values.\n",
3585+
" \"\"\"\n",
3586+
" # Perform k-means clustering to find k clusters\n",
3587+
" kmeans = KMeans(n_clusters=k, random_state=0, n_init=\"auto\").fit(X)\n",
3588+
" \n",
3589+
" # Find the closest point in X to each cluster center\n",
3590+
" selected_points = np.array([X[np.argmin(np.linalg.norm(X - center, axis=1))] for center in kmeans.cluster_centers_])\n",
3591+
" \n",
3592+
" # Find indices of the selected points in the original X array\n",
3593+
" indices = np.array([np.where(np.all(X==point, axis=1))[0][0] for point in selected_points])\n",
3594+
" \n",
3595+
" # Select the corresponding y values\n",
3596+
" selected_y = y[indices]\n",
3597+
" \n",
3598+
" return selected_points, selected_y\n"
3599+
]
3600+
},
3601+
{
3602+
"cell_type": "code",
3603+
"execution_count": 4,
3604+
"metadata": {},
3605+
"outputs": [
3606+
{
3607+
"name": "stdout",
3608+
"output_type": "stream",
3609+
"text": [
3610+
"Selected Points: [[0.77482755 0.11776665]\n",
3611+
" [0.1600672 0.5466571 ]\n",
3612+
" [0.87752562 0.66913902]\n",
3613+
" [0.37216814 0.33013892]\n",
3614+
" [0.37977024 0.83643457]]\n",
3615+
"Corresponding y values: [0.79945132 0.63677214 0.17382713 0.97910053 0.26962361]\n"
3616+
]
3617+
}
3618+
],
3619+
"source": [
3620+
"X = np.random.rand(100, 2) # Generate some random points\n",
3621+
"y = np.random.rand(100) # Random corresponding y values\n",
3622+
"k = 5\n",
3623+
"\n",
3624+
"selected_points, selected_y = select_distant_points(X, y, k)\n",
3625+
"print(\"Selected Points:\", selected_points)\n",
3626+
"print(\"Corresponding y values:\", selected_y)"
3627+
]
3628+
},
35613629
{
35623630
"cell_type": "code",
35633631
"execution_count": null,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.12.8"
10+
version = "0.12.9"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/spot/spot.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from numpy import min, max
2323
from spotPython.utils.init import fun_control_init, optimizer_control_init, surrogate_control_init, design_control_init
2424
from spotPython.utils.compare import selectNew
25-
from spotPython.utils.aggregate import aggregate_mean_var
25+
from spotPython.utils.aggregate import aggregate_mean_var, select_distant_points
2626
from spotPython.utils.repair import remove_nan
2727
from spotPython.budget.ocba import get_ocba_X
2828
import logging
@@ -247,6 +247,7 @@ def __init__(
247247
self.show_progress = self.fun_control["show_progress"]
248248
self.infill_criterion = self.fun_control["infill_criterion"]
249249
self.n_points = self.fun_control["n_points"]
250+
self.max_surrogate_points = self.fun_control["max_surrogate_points"]
250251

251252
# if the key "spot_writer" is not in the dictionary fun_control,
252253
# set self.spot_writer to None else to the value of the key "spot_writer"
@@ -912,8 +913,15 @@ def fit_surrogate(self) -> None:
912913
logger.debug("In fit_surrogate(): self.y: %s", self.y)
913914
logger.debug("In fit_surrogate(): self.X.shape: %s", self.X.shape)
914915
logger.debug("In fit_surrogate(): self.y.shape: %s", self.y.shape)
915-
if self.X.shape[0] == self.y.shape[0]:
916-
self.surrogate.fit(self.X, self.y)
916+
X_points = self.X.shape[0]
917+
y_points = self.y.shape[0]
918+
if X_points == y_points:
919+
if X_points > self.max_surrogate_points:
920+
X_S, y_S = select_distant_points(X=self.X, y=self.y, k=self.max_surrogate_points)
921+
else:
922+
X_S = self.X
923+
y_S = self.y
924+
self.surrogate.fit(X_S, y_S)
917925
else:
918926
logger.warning("X and y have different sizes. Surrogate not fitted.")
919927
if self.show_models:

src/spotPython/utils/aggregate.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pandas as pd
22
import numpy as np
3+
from sklearn.cluster import KMeans
34

45

56
def aggregate_mean_var(X, y, sort=False) -> (np.ndarray, np.ndarray, np.ndarray):
@@ -72,3 +73,42 @@ def get_ranks(x):
7273
ranks = np.empty_like(ts)
7374
ranks[ts] = np.arange(len(x))
7475
return ranks
76+
77+
78+
def select_distant_points(X, y, k):
79+
"""
80+
Selects k points that are distant from each other using a clustering approach.
81+
82+
Args:
83+
X (numpy.ndarray): X array, shape `(n, k)`.
84+
y (numpy.ndarray): values, shape `(n,)`.
85+
k (int): number of points to select.
86+
87+
Returns:
88+
(numpy.ndarray):
89+
selected `X` values, shape `(k, k)`.
90+
(numpy.ndarray):
91+
selected `y` values, shape `(k,)`.
92+
93+
Examples:
94+
>>> from spotPython.utils.aggregate import select_distant_points
95+
X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
96+
y = np.array([1, 2, 3, 4, 5])
97+
selected_points, selected_y = select_distant_points(X, y, 3)
98+
print(selected_points)
99+
[[1 2]
100+
[7 8]
101+
[9 10]]
102+
print(selected_y)
103+
[1 4 5]
104+
105+
"""
106+
# Perform k-means clustering to find k clusters
107+
kmeans = KMeans(n_clusters=k, random_state=0, n_init="auto").fit(X)
108+
# Find the closest point in X to each cluster center
109+
selected_points = np.array([X[np.argmin(np.linalg.norm(X - center, axis=1))] for center in kmeans.cluster_centers_])
110+
# Find indices of the selected points in the original X array
111+
indices = np.array([np.where(np.all(X == point, axis=1))[0][0] for point in selected_points])
112+
# Select the corresponding y values
113+
selected_y = y[indices]
114+
return selected_points, selected_y

src/spotPython/utils/init.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import datetime
77
from dateutil.tz import tzlocal
88
from torch.utils.tensorboard import SummaryWriter
9+
from math import inf
910

1011

1112
def fun_control_init(
@@ -36,6 +37,7 @@ def fun_control_init(
3637
log_level=50,
3738
lower=None,
3839
max_time=1,
40+
max_surrogate_points=inf,
3941
metric_sklearn=None,
4042
noise=False,
4143
n_points=1,
@@ -132,6 +134,8 @@ def fun_control_init(
132134
lower bound
133135
max_time (int):
134136
The maximum time in minutes.
137+
max_surrogate_points (int):
138+
The maximum number of points in the surrogate model. Default is inf.
135139
metric_sklearn (object):
136140
The metric object from the scikit-learn library. Default is None.
137141
noise (bool):
@@ -234,6 +238,7 @@ def fun_control_init(
234238
'k_folds': None,
235239
'loss_function': None,
236240
'lower': None,
241+
'max_surrogate_points': 100,
237242
'metric_river': None,
238243
'metric_sklearn': None,
239244
'metric_torch': None,
@@ -333,6 +338,7 @@ def fun_control_init(
333338
"loss_function": None,
334339
"lower": lower,
335340
"max_time": max_time,
341+
"max_surrogate_points": max_surrogate_points,
336342
"metric_river": None,
337343
"metric_sklearn": metric_sklearn,
338344
"metric_torch": None,

0 commit comments

Comments
 (0)