|
1392 | 1392 | "S.min_y" |
1393 | 1393 | ] |
1394 | 1394 | }, |
| 1395 | + { |
| 1396 | + "cell_type": "markdown", |
| 1397 | + "metadata": {}, |
| 1398 | + "source": [ |
| 1399 | + "## test ocba" |
| 1400 | + ] |
| 1401 | + }, |
1395 | 1402 | { |
1396 | 1403 | "cell_type": "code", |
1397 | 1404 | "execution_count": null, |
1398 | 1405 | "metadata": {}, |
1399 | 1406 | "outputs": [], |
1400 | | - "source": [] |
| 1407 | + "source": [ |
| 1408 | + "\n", |
| 1409 | + "import copy\n", |
| 1410 | + "import numpy as np\n", |
| 1411 | + "from spotPython.fun.objectivefunctions import analytical\n", |
| 1412 | + "from spotPython.spot import spot\n", |
| 1413 | + "from spotPython.budget.ocba import get_ocba\n", |
| 1414 | + "\n", |
| 1415 | + "# Test based on the example from the book:\n", |
| 1416 | + "# Chun-Hung Chen and Loo Hay Lee:\n", |
| 1417 | + "# Stochastic Simulation Optimization: An Optimal Computer Budget Allocation,\n", |
| 1418 | + "# pp. 49 and pp. 215\n", |
| 1419 | + "# p. 49:\n", |
| 1420 | + "# mean_y = np.array([1,2,3,4,5])\n", |
| 1421 | + "# var_y = np.array([1,1,9,9,4])\n", |
| 1422 | + "# get_ocba(mean_y, var_y, 50)\n", |
| 1423 | + "# [11 9 19 9 2]\n", |
| 1424 | + "\n", |
| 1425 | + "fun = analytical().fun_linear\n", |
| 1426 | + "fun_control = {\"sigma\": 0.001,\n", |
| 1427 | + " \"seed\": 123}\n", |
| 1428 | + "spot_1_noisy = spot.Spot(fun=fun,\n", |
| 1429 | + " lower = np.array([-1]),\n", |
| 1430 | + " upper = np.array([1]),\n", |
| 1431 | + " fun_evals = 20,\n", |
| 1432 | + " fun_repeats = 2,\n", |
| 1433 | + " noise = True,\n", |
| 1434 | + " ocba_delta=1,\n", |
| 1435 | + " seed=123,\n", |
| 1436 | + " show_models=False,\n", |
| 1437 | + " fun_control = fun_control,\n", |
| 1438 | + " design_control={\"init_size\": 3,\n", |
| 1439 | + " \"repeats\": 2},\n", |
| 1440 | + " surrogate_control={\"noise\": True})\n", |
| 1441 | + "spot_1_noisy.run()\n", |
| 1442 | + "spot_2 = copy.deepcopy(spot_1_noisy)\n", |
| 1443 | + "spot_2.mean_y = np.array([1,2,3,4,5])\n", |
| 1444 | + "spot_2.var_y = np.array([1,1,9,9,4])\n", |
| 1445 | + "n = 50\n", |
| 1446 | + "o = get_ocba(spot_2.mean_y, spot_2.var_y, n)\n", |
| 1447 | + "assert sum(o) == 50\n", |
| 1448 | + "assert (o == np.array([[11, 9, 19, 9, 2]])).all()\n", |
| 1449 | + "o" |
| 1450 | + ] |
| 1451 | + }, |
| 1452 | + { |
| 1453 | + "cell_type": "code", |
| 1454 | + "execution_count": 1, |
| 1455 | + "metadata": {}, |
| 1456 | + "outputs": [ |
| 1457 | + { |
| 1458 | + "name": "stderr", |
| 1459 | + "output_type": "stream", |
| 1460 | + "text": [ |
| 1461 | + "Seed set to 123\n" |
| 1462 | + ] |
| 1463 | + }, |
| 1464 | + { |
| 1465 | + "name": "stdout", |
| 1466 | + "output_type": "stream", |
| 1467 | + "text": [ |
| 1468 | + "S.X: [[ 0. 1. ]\n", |
| 1469 | + " [ 1. 0. ]\n", |
| 1470 | + " [ 1. 1. ]\n", |
| 1471 | + " [ 1. 1. ]\n", |
| 1472 | + " [ 0.54509876 -0.36921401]\n", |
| 1473 | + " [ 0.54509876 -0.36921401]\n", |
| 1474 | + " [ 0.18642675 0.87708546]\n", |
| 1475 | + " [ 0.18642675 0.87708546]\n", |
| 1476 | + " [-0.45060393 -0.208063 ]\n", |
| 1477 | + " [-0.45060393 -0.208063 ]]\n", |
| 1478 | + "S.y: [0.98021757 0.99264427 2.02575851 2.00387949 0.45185626 0.44499372\n", |
| 1479 | + " 0.79130456 0.81487288 0.24000221 0.23988634]\n", |
| 1480 | + "X_shape_before: (10, 2)\n", |
| 1481 | + "y_size_before: 10\n", |
| 1482 | + "S.X: [[ 0. 1. ]\n", |
| 1483 | + " [ 1. 0. ]\n", |
| 1484 | + " [ 1. 1. ]\n", |
| 1485 | + " [ 1. 1. ]\n", |
| 1486 | + " [ 0.54509876 -0.36921401]\n", |
| 1487 | + " [ 0.54509876 -0.36921401]\n", |
| 1488 | + " [ 0.18642675 0.87708546]\n", |
| 1489 | + " [ 0.18642675 0.87708546]\n", |
| 1490 | + " [-0.45060393 -0.208063 ]\n", |
| 1491 | + " [-0.45060393 -0.208063 ]\n", |
| 1492 | + " [-0.39841465 -0.21105872]\n", |
| 1493 | + " [-0.39841465 -0.21105872]]\n", |
| 1494 | + "S.y: [0.98021757 0.99264427 2.02575851 2.00387949 0.45185626 0.44499372\n", |
| 1495 | + " 0.79130456 0.81487288 0.24000221 0.23988634 0.18349759 0.19592429]\n", |
| 1496 | + "S.n_points: 1\n", |
| 1497 | + "S.ocba_delta: 1\n", |
| 1498 | + "X_shape_after: (12, 2)\n", |
| 1499 | + "y_size_after: 12\n" |
| 1500 | + ] |
| 1501 | + } |
| 1502 | + ], |
| 1503 | + "source": [ |
| 1504 | + "import numpy as np\n", |
| 1505 | + "from spotPython.fun.objectivefunctions import analytical\n", |
| 1506 | + "from spotPython.spot import spot\n", |
| 1507 | + "from spotPython.utils.init import fun_control_init\n", |
| 1508 | + "# number of initial points:\n", |
| 1509 | + "ni = 3\n", |
| 1510 | + "X_start = np.array([[0, 1], [1, 0], [1, 1], [1, 1]])\n", |
| 1511 | + "\n", |
| 1512 | + "fun = analytical().fun_sphere\n", |
| 1513 | + "fun_control = fun_control_init(\n", |
| 1514 | + " sigma=0.02,\n", |
| 1515 | + " seed=123,)\n", |
| 1516 | + "lower = np.array([-1, -1])\n", |
| 1517 | + "upper = np.array([1, 1])\n", |
| 1518 | + "design_control={\"init_size\": ni,\n", |
| 1519 | + " \"repeats\": 2}\n", |
| 1520 | + "\n", |
| 1521 | + "S = spot.Spot(fun=fun,\n", |
| 1522 | + " noise=True,\n", |
| 1523 | + " fun_repeats=2,\n", |
| 1524 | + " n_points=1,\n", |
| 1525 | + " ocba_delta=1,\n", |
| 1526 | + " log_level = 10,\n", |
| 1527 | + " lower = lower,\n", |
| 1528 | + " upper= upper,\n", |
| 1529 | + " show_progress=False,\n", |
| 1530 | + " design_control=design_control,\n", |
| 1531 | + " fun_control=fun_control\n", |
| 1532 | + ")\n", |
| 1533 | + "S.initialize_design(X_start=X_start)\n", |
| 1534 | + "print(f\"S.X: {S.X}\")\n", |
| 1535 | + "print(f\"S.y: {S.y}\")\n", |
| 1536 | + "X_shape_before = S.X.shape\n", |
| 1537 | + "print(f\"X_shape_before: {X_shape_before}\")\n", |
| 1538 | + "print(f\"y_size_before: {S.y.size}\")\n", |
| 1539 | + "y_size_before = S.y.size\n", |
| 1540 | + "S.update_stats()\n", |
| 1541 | + "S.fit_surrogate()\n", |
| 1542 | + "S.update_design()\n", |
| 1543 | + "print(f\"S.X: {S.X}\")\n", |
| 1544 | + "print(f\"S.y: {S.y}\")\n", |
| 1545 | + "print(f\"S.n_points: {S.n_points}\")\n", |
| 1546 | + "print(f\"S.ocba_delta: {S.ocba_delta}\")\n", |
| 1547 | + "print(f\"X_shape_after: {S.X.shape}\")\n", |
| 1548 | + "print(f\"y_size_after: {S.y.size}\")\n" |
| 1549 | + ] |
1401 | 1550 | }, |
1402 | 1551 | { |
1403 | 1552 | "cell_type": "code", |
|
0 commit comments