Skip to content

Commit 71cbda3

Browse files
v0.8.14
1 parent b727a09 commit 71cbda3

7 files changed

Lines changed: 248 additions & 68 deletions

File tree

notebooks/testKriging.ipynb

Lines changed: 118 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,65 +1205,9 @@
12051205
},
12061206
{
12071207
"cell_type": "code",
1208-
"execution_count": 8,
1208+
"execution_count": null,
12091209
"metadata": {},
1210-
"outputs": [
1211-
{
1212-
"name": "stderr",
1213-
"output_type": "stream",
1214-
"text": [
1215-
"Seed set to 42\n"
1216-
]
1217-
},
1218-
{
1219-
"name": "stdout",
1220-
"output_type": "stream",
1221-
"text": [
1222-
"S.X: [[ 0. 1. ]\n",
1223-
" [ 1. 0. ]\n",
1224-
" [ 1. 1. ]\n",
1225-
" [ 1. 1. ]\n",
1226-
" [ 0.54509876 -0.36921401]\n",
1227-
" [ 0.54509876 -0.36921401]\n",
1228-
" [ 0.18642675 0.87708546]\n",
1229-
" [ 0.18642675 0.87708546]\n",
1230-
" [-0.45060393 -0.208063 ]\n",
1231-
" [-0.45060393 -0.208063 ]]\n",
1232-
"S.y: [0.98021757 0.99264427 2.02575851 2.00387949 0.45185626 0.44499372\n",
1233-
" 0.79130456 0.81487288 0.24000221 0.23988634]\n",
1234-
"X_shape_before: (10, 2)\n",
1235-
"y_size_before: 10\n",
1236-
"S.X: [[ 0. 1. ]\n",
1237-
" [ 1. 0. ]\n",
1238-
" [ 1. 1. ]\n",
1239-
" [ 1. 1. ]\n",
1240-
" [ 0.54509876 -0.36921401]\n",
1241-
" [ 0.54509876 -0.36921401]\n",
1242-
" [ 0.18642675 0.87708546]\n",
1243-
" [ 0.18642675 0.87708546]\n",
1244-
" [-0.45060393 -0.208063 ]\n",
1245-
" [-0.45060393 -0.208063 ]\n",
1246-
" [-0.45060393 -0.208063 ]\n",
1247-
" [-0.39841465 -0.21105872]\n",
1248-
" [-0.39841465 -0.21105872]]\n",
1249-
"S.y: [0.98021757 0.99264427 2.02575851 2.00387949 0.45185626 0.44499372\n",
1250-
" 0.79130456 0.81487288 0.24000221 0.23988634 0.22655169 0.19592429\n",
1251-
" 0.22903853]\n",
1252-
"S.n_points: 1\n",
1253-
"S.ocba_delta: 1\n",
1254-
"X_shape_after: (13, 2)\n",
1255-
"y_size_after: 13\n"
1256-
]
1257-
},
1258-
{
1259-
"name": "stderr",
1260-
"output_type": "stream",
1261-
"text": [
1262-
"/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/spotPython/budget/ocba.py:65: RuntimeWarning: invalid value encountered in cast\n",
1263-
" add_budget = around(add_budget).astype(int)\n"
1264-
]
1265-
}
1266-
],
1210+
"outputs": [],
12671211
"source": [
12681212
"import numpy as np\n",
12691213
"from spotPython.fun.objectivefunctions import analytical\n",
@@ -1315,6 +1259,122 @@
13151259
"assert y_size_before + S.n_points * S.fun_repeats + S.ocba_delta == S.y.size"
13161260
]
13171261
},
1262+
{
1263+
"cell_type": "markdown",
1264+
"metadata": {},
1265+
"source": [
1266+
"## test get_new_X0()"
1267+
]
1268+
},
1269+
{
1270+
"cell_type": "code",
1271+
"execution_count": 5,
1272+
"metadata": {},
1273+
"outputs": [
1274+
{
1275+
"name": "stderr",
1276+
"output_type": "stream",
1277+
"text": [
1278+
"Seed set to 123\n"
1279+
]
1280+
},
1281+
{
1282+
"name": "stdout",
1283+
"output_type": "stream",
1284+
"text": [
1285+
"X0: [[-0.44554771693476863 -0.20875926895365957]\n",
1286+
" [-0.4455476317976184 -0.20875838912156508]\n",
1287+
" [-0.4455472980490104 -0.20875947484788715]\n",
1288+
" [-0.4455471641839423 -0.2087604928416959 ]\n",
1289+
" [-0.44554694753382335 -0.20875893366783724]\n",
1290+
" [-0.44554687059608705 -0.2087595746675645 ]\n",
1291+
" [-0.44554687053984077 -0.20875957380178856]\n",
1292+
" [-0.44554635573095563 -0.2087586627530108 ]\n",
1293+
" [-0.44554621783755394 -0.2087589280357544 ]\n",
1294+
" [-0.4455458637237855 -0.20875981800729412]]\n"
1295+
]
1296+
}
1297+
],
1298+
"source": [
1299+
"import numpy as np\n",
1300+
"from spotPython.fun.objectivefunctions import analytical\n",
1301+
"from spotPython.spot import spot\n",
1302+
"from spotPython.utils.init import fun_control_init\n",
1303+
"# number of initial points:\n",
1304+
"ni = 3\n",
1305+
"X_start = np.array([[0, 1], [1, 0], [1, 1], [1, 1]])\n",
1306+
"\n",
1307+
"fun = analytical().fun_sphere\n",
1308+
"fun_control = fun_control_init(\n",
1309+
" sigma=0.0,\n",
1310+
" seed=123,)\n",
1311+
"lower = np.array([-1, -1])\n",
1312+
"upper = np.array([1, 1])\n",
1313+
"design_control={\"init_size\": ni,\n",
1314+
" \"repeats\": 1}\n",
1315+
"\n",
1316+
"S = spot.Spot(fun=fun,\n",
1317+
" noise=False,\n",
1318+
" fun_repeats=1,\n",
1319+
" n_points=10,\n",
1320+
" ocba_delta=0,\n",
1321+
" lower = lower,\n",
1322+
" upper= upper,\n",
1323+
" show_progress=True,\n",
1324+
" design_control=design_control,\n",
1325+
" fun_control=fun_control\n",
1326+
")\n",
1327+
"S.initialize_design(X_start=X_start)\n",
1328+
"S.update_stats()\n",
1329+
"S.fit_surrogate()\n",
1330+
"X_ocba = None\n",
1331+
"X0 = S.get_new_X0()\n",
1332+
"assert X0.shape[0] == S.n_points\n",
1333+
"assert X0.shape[1] == S.lower.size\n",
1334+
"# assert new points are in the interval [lower, upper]\n",
1335+
"assert np.all(X0 >= S.lower)\n",
1336+
"assert np.all(X0 <= S.upper)\n",
1337+
"# print using 20 digits precision\n",
1338+
"np.set_printoptions(precision=20)\n",
1339+
"print(f\"X0: {X0}\")\n",
1340+
"\n"
1341+
]
1342+
},
1343+
{
1344+
"cell_type": "markdown",
1345+
"metadata": {},
1346+
"source": [
1347+
"## test selectNew()"
1348+
]
1349+
},
1350+
{
1351+
"cell_type": "code",
1352+
"execution_count": 19,
1353+
"metadata": {},
1354+
"outputs": [],
1355+
"source": [
1356+
"from spotPython.utils.compare import selectNew\n",
1357+
"import numpy as np\n",
1358+
"A = np.array([[1,2,3],[4,5,6]])\n",
1359+
"X = np.array([[1,2,3],[4,5,6]])\n",
1360+
"B, ind = selectNew(A, X)\n",
1361+
"assert B.shape[0] == 0\n",
1362+
"assert np.equal(ind, np.array([False, False])).all()"
1363+
]
1364+
},
1365+
{
1366+
"cell_type": "code",
1367+
"execution_count": 21,
1368+
"metadata": {},
1369+
"outputs": [],
1370+
"source": [
1371+
"A = np.array([[1,2,3],[4,5,7]])\n",
1372+
"X = np.array([[1,2,3],[4,5,6]])\n",
1373+
"B, ind = selectNew(A, X)\n",
1374+
"assert B.shape[0] == 1\n",
1375+
"assert np.equal(ind, np.array([False, True])).all()"
1376+
]
1377+
},
13181378
{
13191379
"cell_type": "code",
13201380
"execution_count": null,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.8.12"
10+
version = "0.8.14"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/spot/spot.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,63 @@ def to_all_dim_if_needed(self, X) -> np.array:
443443
return X
444444

445445
def get_new_X0(self) -> np.array:
446+
"""
447+
Get new design points.
448+
Calls `suggest_new_X()` and repairs the new design points, e.g.,
449+
by `repair_non_numeric()` and `selectNew()`.
450+
451+
Args:
452+
self (object): Spot object
453+
454+
Returns:
455+
(numpy.ndarray): new design points
456+
457+
Notes:
458+
* self.design (object): an experimental design is used to generate new design points
459+
if no new design points are found, a new experimental design is generated.
460+
461+
Examples:
462+
>>> import numpy as np
463+
from spotPython.fun.objectivefunctions import analytical
464+
from spotPython.spot import spot
465+
from spotPython.utils.init import fun_control_init
466+
# number of initial points:
467+
ni = 3
468+
X_start = np.array([[0, 1], [1, 0], [1, 1], [1, 1]])
469+
fun = analytical().fun_sphere
470+
fun_control = fun_control_init(
471+
sigma=0.0,
472+
seed=123,)
473+
lower = np.array([-1, -1])
474+
upper = np.array([1, 1])
475+
design_control={"init_size": ni,
476+
"repeats": 1}
477+
S = spot.Spot(fun=fun,
478+
noise=False,
479+
fun_repeats=1,
480+
n_points=10,
481+
ocba_delta=0,
482+
lower = lower,
483+
upper= upper,
484+
show_progress=True,
485+
design_control=design_control,
486+
fun_control=fun_control
487+
)
488+
S.initialize_design(X_start=X_start)
489+
S.update_stats()
490+
S.fit_surrogate()
491+
X_ocba = None
492+
X0 = S.get_new_X0()
493+
assert X0.shape[0] == S.n_points
494+
assert X0.shape[1] == S.lower.size
495+
# assert new points are in the interval [lower, upper]
496+
assert np.all(X0 >= S.lower)
497+
assert np.all(X0 <= S.upper)
498+
# print using 20 digits precision
499+
np.set_printoptions(precision=20)
500+
print(f"X0: {X0}")
501+
502+
"""
446503
X0 = self.suggest_new_X()
447504
X0 = repair_non_numeric(X0, self.var_type)
448505
# (S-16) Duplicate Handling:
@@ -851,6 +908,7 @@ def update_writer(self) -> None:
851908
def suggest_new_X(self) -> np.array:
852909
"""
853910
Compute `n_points` new infill points in natural units.
911+
These diffrent points are computed by the optimizer using increasing seed.
854912
The optimizer searches in the ranges from `lower_j` to `upper_j`.
855913
The method `infill()` is used as the objective function.
856914
@@ -876,11 +934,11 @@ def suggest_new_X(self) -> np.array:
876934
"basinhopping": lambda: self.optimizer(func=self.infill, x0=self.min_X),
877935
"default": lambda: self.optimizer(func=self.infill, bounds=self.de_bounds),
878936
}
879-
880937
for i in range(self.n_points):
938+
self.optimizer_control["seed"] = self.optimizer_control["seed"] + i
881939
result = optimizers.get(optimizer_name, optimizers["default"])()
882940
new_X[i][:] = result.x
883-
return new_X
941+
return np.unique(new_X, axis=0)
884942

885943
def infill(self, x) -> float:
886944
"""

src/spotPython/utils/compare.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,19 @@ def selectNew(A: np.ndarray, X: np.ndarray, tolerance: float = 0) -> Tuple[np.nd
1818
(numpy.ndarray): array with `True` if value is new, otherwise `False`.
1919
2020
Examples:
21-
>>> from spotPython.utils.compare import selectNew
22-
A = np.array([[1,2,3],[4,5,6]])
23-
X = np.array([[1,2,3],[4,5,6]])
24-
selectNew(A, X)
25-
(array([], shape=(0, 3), dtype=int64), array([], dtype=bool))
26-
21+
>>> from spotPython.utils.compare import selectNew
22+
import numpy as np
23+
A = np.array([[1,2,3],[4,5,6]])
24+
X = np.array([[1,2,3],[4,5,6]])
25+
B, ind = selectNew(A, X)
26+
assert B.shape[0] == 0
27+
assert np.equal(ind, np.array([False, False])).all()
28+
>>> from spotPython.utils.compare import selectNew
29+
A = np.array([[1,2,3],[4,5,7]])
30+
X = np.array([[1,2,3],[4,5,6]])
31+
B, ind = selectNew(A, X)
32+
assert B.shape[0] == 1
33+
assert np.equal(ind, np.array([False, True])).all()
2734
"""
2835
B = np.abs(A[:, None] - X)
2936
ind = np.any(np.all(B <= tolerance, axis=2), axis=1)

src/spotPython/utils/init.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def fun_control_init(
8282
'writer': None}
8383
"""
8484
# Setting the seed
85-
L.seed_everything(42)
85+
L.seed_everything(seed)
8686

8787
# Path to the folder where the pretrained models are saved
8888
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "runs/saved_models/")

test/test_get_new_X0.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import numpy as np
2+
from spotPython.fun.objectivefunctions import analytical
3+
from spotPython.spot import spot
4+
from spotPython.utils.init import fun_control_init
5+
6+
def test_get_new_X0():
7+
# number of initial points:
8+
ni = 3
9+
X_start = np.array([[0, 1], [1, 0], [1, 1], [1, 1]])
10+
11+
fun = analytical().fun_sphere
12+
fun_control = fun_control_init(
13+
sigma=0.0,
14+
seed=123,)
15+
lower = np.array([-1, -1])
16+
upper = np.array([1, 1])
17+
design_control={"init_size": ni,
18+
"repeats": 1}
19+
20+
S = spot.Spot(fun=fun,
21+
noise=False,
22+
fun_repeats=1,
23+
n_points=10,
24+
ocba_delta=0,
25+
lower = lower,
26+
upper= upper,
27+
show_progress=True,
28+
design_control=design_control,
29+
fun_control=fun_control
30+
)
31+
S.initialize_design(X_start=X_start)
32+
S.update_stats()
33+
S.fit_surrogate()
34+
X0 = S.get_new_X0()
35+
assert X0.shape[0] == S.n_points
36+
assert X0.shape[1] == S.lower.size
37+
# assert new points are in the interval [lower, upper]
38+
assert np.all(X0 >= S.lower)
39+
assert np.all(X0 <= S.upper)

test/test_selectNew().py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from spotPython.utils.compare import selectNew
2+
import numpy as np
3+
4+
def test_selectNew_All_Equal():
5+
A = np.array([[1,2,3],[4,5,6]])
6+
X = np.array([[1,2,3],[4,5,6]])
7+
B, ind = selectNew(A, X)
8+
assert B.shape[0] == 0
9+
assert np.equal(ind, np.array([False, False])).all()
10+
11+
def test_selectNew_One_Not_Equal():
12+
A = np.array([[1,2,3],[4,5,6]])
13+
X = np.array([[1,2,3],[4,5,7]])
14+
B, ind = selectNew(A, X)
15+
assert B.shape[0] == 1
16+
assert np.equal(ind, np.array([False, True])).all()

0 commit comments

Comments
 (0)