|
1205 | 1205 | }, |
1206 | 1206 | { |
1207 | 1207 | "cell_type": "code", |
1208 | | - "execution_count": 8, |
| 1208 | + "execution_count": null, |
1209 | 1209 | "metadata": {}, |
1210 | | - "outputs": [ |
1211 | | - { |
1212 | | - "name": "stderr", |
1213 | | - "output_type": "stream", |
1214 | | - "text": [ |
1215 | | - "Seed set to 42\n" |
1216 | | - ] |
1217 | | - }, |
1218 | | - { |
1219 | | - "name": "stdout", |
1220 | | - "output_type": "stream", |
1221 | | - "text": [ |
1222 | | - "S.X: [[ 0. 1. ]\n", |
1223 | | - " [ 1. 0. ]\n", |
1224 | | - " [ 1. 1. ]\n", |
1225 | | - " [ 1. 1. ]\n", |
1226 | | - " [ 0.54509876 -0.36921401]\n", |
1227 | | - " [ 0.54509876 -0.36921401]\n", |
1228 | | - " [ 0.18642675 0.87708546]\n", |
1229 | | - " [ 0.18642675 0.87708546]\n", |
1230 | | - " [-0.45060393 -0.208063 ]\n", |
1231 | | - " [-0.45060393 -0.208063 ]]\n", |
1232 | | - "S.y: [0.98021757 0.99264427 2.02575851 2.00387949 0.45185626 0.44499372\n", |
1233 | | - " 0.79130456 0.81487288 0.24000221 0.23988634]\n", |
1234 | | - "X_shape_before: (10, 2)\n", |
1235 | | - "y_size_before: 10\n", |
1236 | | - "S.X: [[ 0. 1. ]\n", |
1237 | | - " [ 1. 0. ]\n", |
1238 | | - " [ 1. 1. ]\n", |
1239 | | - " [ 1. 1. ]\n", |
1240 | | - " [ 0.54509876 -0.36921401]\n", |
1241 | | - " [ 0.54509876 -0.36921401]\n", |
1242 | | - " [ 0.18642675 0.87708546]\n", |
1243 | | - " [ 0.18642675 0.87708546]\n", |
1244 | | - " [-0.45060393 -0.208063 ]\n", |
1245 | | - " [-0.45060393 -0.208063 ]\n", |
1246 | | - " [-0.45060393 -0.208063 ]\n", |
1247 | | - " [-0.39841465 -0.21105872]\n", |
1248 | | - " [-0.39841465 -0.21105872]]\n", |
1249 | | - "S.y: [0.98021757 0.99264427 2.02575851 2.00387949 0.45185626 0.44499372\n", |
1250 | | - " 0.79130456 0.81487288 0.24000221 0.23988634 0.22655169 0.19592429\n", |
1251 | | - " 0.22903853]\n", |
1252 | | - "S.n_points: 1\n", |
1253 | | - "S.ocba_delta: 1\n", |
1254 | | - "X_shape_after: (13, 2)\n", |
1255 | | - "y_size_after: 13\n" |
1256 | | - ] |
1257 | | - }, |
1258 | | - { |
1259 | | - "name": "stderr", |
1260 | | - "output_type": "stream", |
1261 | | - "text": [ |
1262 | | - "/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/spotPython/budget/ocba.py:65: RuntimeWarning: invalid value encountered in cast\n", |
1263 | | - " add_budget = around(add_budget).astype(int)\n" |
1264 | | - ] |
1265 | | - } |
1266 | | - ], |
| 1210 | + "outputs": [], |
1267 | 1211 | "source": [ |
1268 | 1212 | "import numpy as np\n", |
1269 | 1213 | "from spotPython.fun.objectivefunctions import analytical\n", |
|
1315 | 1259 | "assert y_size_before + S.n_points * S.fun_repeats + S.ocba_delta == S.y.size" |
1316 | 1260 | ] |
1317 | 1261 | }, |
| 1262 | + { |
| 1263 | + "cell_type": "markdown", |
| 1264 | + "metadata": {}, |
| 1265 | + "source": [ |
| 1266 | + "## test get_new_X0()" |
| 1267 | + ] |
| 1268 | + }, |
| 1269 | + { |
| 1270 | + "cell_type": "code", |
| 1271 | + "execution_count": 5, |
| 1272 | + "metadata": {}, |
| 1273 | + "outputs": [ |
| 1274 | + { |
| 1275 | + "name": "stderr", |
| 1276 | + "output_type": "stream", |
| 1277 | + "text": [ |
| 1278 | + "Seed set to 123\n" |
| 1279 | + ] |
| 1280 | + }, |
| 1281 | + { |
| 1282 | + "name": "stdout", |
| 1283 | + "output_type": "stream", |
| 1284 | + "text": [ |
| 1285 | + "X0: [[-0.44554771693476863 -0.20875926895365957]\n", |
| 1286 | + " [-0.4455476317976184 -0.20875838912156508]\n", |
| 1287 | + " [-0.4455472980490104 -0.20875947484788715]\n", |
| 1288 | + " [-0.4455471641839423 -0.2087604928416959 ]\n", |
| 1289 | + " [-0.44554694753382335 -0.20875893366783724]\n", |
| 1290 | + " [-0.44554687059608705 -0.2087595746675645 ]\n", |
| 1291 | + " [-0.44554687053984077 -0.20875957380178856]\n", |
| 1292 | + " [-0.44554635573095563 -0.2087586627530108 ]\n", |
| 1293 | + " [-0.44554621783755394 -0.2087589280357544 ]\n", |
| 1294 | + " [-0.4455458637237855 -0.20875981800729412]]\n" |
| 1295 | + ] |
| 1296 | + } |
| 1297 | + ], |
| 1298 | + "source": [ |
| 1299 | + "import numpy as np\n", |
| 1300 | + "from spotPython.fun.objectivefunctions import analytical\n", |
| 1301 | + "from spotPython.spot import spot\n", |
| 1302 | + "from spotPython.utils.init import fun_control_init\n", |
| 1303 | + "# number of initial points:\n", |
| 1304 | + "ni = 3\n", |
| 1305 | + "X_start = np.array([[0, 1], [1, 0], [1, 1], [1, 1]])\n", |
| 1306 | + "\n", |
| 1307 | + "fun = analytical().fun_sphere\n", |
| 1308 | + "fun_control = fun_control_init(\n", |
| 1309 | + " sigma=0.0,\n", |
| 1310 | + " seed=123,)\n", |
| 1311 | + "lower = np.array([-1, -1])\n", |
| 1312 | + "upper = np.array([1, 1])\n", |
| 1313 | + "design_control={\"init_size\": ni,\n", |
| 1314 | + " \"repeats\": 1}\n", |
| 1315 | + "\n", |
| 1316 | + "S = spot.Spot(fun=fun,\n", |
| 1317 | + " noise=False,\n", |
| 1318 | + " fun_repeats=1,\n", |
| 1319 | + " n_points=10,\n", |
| 1320 | + " ocba_delta=0,\n", |
| 1321 | + " lower = lower,\n", |
| 1322 | + " upper= upper,\n", |
| 1323 | + " show_progress=True,\n", |
| 1324 | + " design_control=design_control,\n", |
| 1325 | + " fun_control=fun_control\n", |
| 1326 | + ")\n", |
| 1327 | + "S.initialize_design(X_start=X_start)\n", |
| 1328 | + "S.update_stats()\n", |
| 1329 | + "S.fit_surrogate()\n", |
| 1330 | + "X_ocba = None\n", |
| 1331 | + "X0 = S.get_new_X0()\n", |
| 1332 | + "assert X0.shape[0] == S.n_points\n", |
| 1333 | + "assert X0.shape[1] == S.lower.size\n", |
| 1334 | + "# assert new points are in the interval [lower, upper]\n", |
| 1335 | + "assert np.all(X0 >= S.lower)\n", |
| 1336 | + "assert np.all(X0 <= S.upper)\n", |
| 1337 | + "# print using 20 digits precision\n", |
| 1338 | + "np.set_printoptions(precision=20)\n", |
| 1339 | + "print(f\"X0: {X0}\")\n", |
| 1340 | + "\n" |
| 1341 | + ] |
| 1342 | + }, |
| 1343 | + { |
| 1344 | + "cell_type": "markdown", |
| 1345 | + "metadata": {}, |
| 1346 | + "source": [ |
| 1347 | + "## test selectNew()" |
| 1348 | + ] |
| 1349 | + }, |
| 1350 | + { |
| 1351 | + "cell_type": "code", |
| 1352 | + "execution_count": 19, |
| 1353 | + "metadata": {}, |
| 1354 | + "outputs": [], |
| 1355 | + "source": [ |
| 1356 | + "from spotPython.utils.compare import selectNew\n", |
| 1357 | + "import numpy as np\n", |
| 1358 | + "A = np.array([[1,2,3],[4,5,6]])\n", |
| 1359 | + "X = np.array([[1,2,3],[4,5,6]])\n", |
| 1360 | + "B, ind = selectNew(A, X)\n", |
| 1361 | + "assert B.shape[0] == 0\n", |
| 1362 | + "assert np.equal(ind, np.array([False, False])).all()" |
| 1363 | + ] |
| 1364 | + }, |
| 1365 | + { |
| 1366 | + "cell_type": "code", |
| 1367 | + "execution_count": 21, |
| 1368 | + "metadata": {}, |
| 1369 | + "outputs": [], |
| 1370 | + "source": [ |
| 1371 | + "A = np.array([[1,2,3],[4,5,7]])\n", |
| 1372 | + "X = np.array([[1,2,3],[4,5,6]])\n", |
| 1373 | + "B, ind = selectNew(A, X)\n", |
| 1374 | + "assert B.shape[0] == 1\n", |
| 1375 | + "assert np.equal(ind, np.array([False, True])).all()" |
| 1376 | + ] |
| 1377 | + }, |
1318 | 1378 | { |
1319 | 1379 | "cell_type": "code", |
1320 | 1380 | "execution_count": null, |
|
0 commit comments