Skip to content

Commit 4e2fa45

Browse files
v0.8.18
docs
1 parent 280fa7b commit 4e2fa45

7 files changed

Lines changed: 753 additions & 166 deletions

File tree

notebooks/testKriging.ipynb

Lines changed: 272 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,36 @@
6565
"assert np.allclose(mean_prediction[training_indices], y[training_indices], atol=1e-6)\n"
6666
]
6767
},
68+
{
69+
"cell_type": "markdown",
70+
"metadata": {},
71+
"source": [
72+
"## Kriging predict"
73+
]
74+
},
75+
{
76+
"cell_type": "code",
77+
"execution_count": null,
78+
"metadata": {},
79+
"outputs": [],
80+
"source": [
81+
"from spotPython.build.kriging import Kriging\n",
82+
"import numpy as np\n",
83+
"import matplotlib.pyplot as plt\n",
84+
"from numpy import linspace, arange\n",
85+
"rng = np.random.RandomState(1)\n",
86+
"X = linspace(start=0, stop=10, num=1_0).reshape(-1, 1)\n",
87+
"y = np.squeeze(X * np.sin(X))\n",
88+
"training_indices = rng.choice(arange(y.size), size=6, replace=False)\n",
89+
"X_train, y_train = X[training_indices], y[training_indices]\n",
90+
"S = Kriging(name='kriging', seed=124)\n",
91+
"S.fit(X_train, y_train)\n",
92+
"mean_prediction, std_prediction, s_ei = S.predict(X, return_val=\"all\")\n",
93+
"print(f\"mean_prediction: {mean_prediction}\")\n",
94+
"print(f\"std_prediction: {std_prediction}\")\n",
95+
"print(f\"s_ei: {s_ei}\")"
96+
]
97+
},
6898
{
6999
"cell_type": "markdown",
70100
"metadata": {},
@@ -270,7 +300,9 @@
270300
"assert S.nat_X.all() == nat_X.all()\n",
271301
"assert S.nat_y.all() == nat_y.all()\n",
272302
"assert S.nat_X.shape == (2, 2)\n",
273-
"assert S.nat_y.shape == (2,)"
303+
"assert S.nat_y.shape == (2,)\n",
304+
"print(f\"S.nat_X: {S.nat_X}\")\n",
305+
"print(f\"S.nat_y: {S.nat_y}\")\n"
274306
]
275307
},
276308
{
@@ -407,6 +439,34 @@
407439
"assert S.Psi.shape == (n, n)\n"
408440
]
409441
},
442+
{
443+
"cell_type": "markdown",
444+
"metadata": {},
445+
"source": [
446+
"## test is_any()"
447+
]
448+
},
449+
{
450+
"cell_type": "code",
451+
"execution_count": 13,
452+
"metadata": {},
453+
"outputs": [],
454+
"source": [
455+
"from spotPython.build.kriging import Kriging\n",
456+
"from numpy import power\n",
457+
"import numpy as np\n",
458+
"nat_X = np.array([[0], [1]])\n",
459+
"nat_y = np.array([0, 1])\n",
460+
"n=1\n",
461+
"p=1\n",
462+
"S=Kriging(name='kriging', seed=124, n_theta=n, n_p=p, optim_p=True, noise=False)\n",
463+
"S.initialize_variables(nat_X, nat_y)\n",
464+
"S.set_variable_types()\n",
465+
"S.set_theta_values()\n",
466+
"assert np.equal(S.__is_any__(power(10.0, S.theta), 0), False)\n",
467+
"assert np.equal(S.__is_any__(S.theta, 0), True)"
468+
]
469+
},
410470
{
411471
"cell_type": "markdown",
412472
"metadata": {},
@@ -573,20 +633,17 @@
573633
"source": [
574634
"import numpy as np\n",
575635
"from spotPython.build.kriging import Kriging\n",
576-
"\n",
577636
"X_train = np.array([[1., 2.],\n",
578637
" [2., 4.],\n",
579638
" [3., 6.]])\n",
580639
"y_train = np.array([1., 2., 3.])\n",
581-
"\n",
582640
"S = Kriging(name='kriging',\n",
583641
" seed=123,\n",
584642
" log_level=50,\n",
585643
" n_theta=1,\n",
586644
" noise=False,\n",
587645
" cod_type=\"norm\")\n",
588646
"S.fit(X_train, y_train)\n",
589-
"\n",
590647
"# force theta to simple values:\n",
591648
"S.theta = np.array([0.0])\n",
592649
"nat_X = np.array([1., 0.])\n",
@@ -595,10 +652,9 @@
595652
"res = np.array([[np.exp(-4)],\n",
596653
" [np.exp(-17)],\n",
597654
" [np.exp(-40)]])\n",
598-
"\n",
599-
"# assert np.array_equal(S.psi, res)\n",
600-
"\n",
601-
"S.psi, res\n"
655+
"assert np.array_equal(S.psi, res)\n",
656+
"print(f\"S.psi:\\n {S.psi}\")\n",
657+
"print(f\"Control value res:\\n {res}\")"
602658
]
603659
},
604660
{
@@ -1370,7 +1426,136 @@
13701426
"# number of initial points:\n",
13711427
"ni = 7\n",
13721428
"# number of points\n",
1373-
"fun_evals = 100\n",
1429+
"fun_evals = 10\n",
1430+
"\n",
1431+
"fun = analytical().fun_sphere\n",
1432+
"lower = np.array([-1, -1])\n",
1433+
"upper = np.array([1, 1])\n",
1434+
"design_control={\"init_size\": ni}\n",
1435+
"surrogate_control={\"n_theta\": 3}\n",
1436+
"S = spot.Spot(fun=fun,\n",
1437+
" lower = lower,\n",
1438+
" upper= upper,\n",
1439+
" log_level = 50,\n",
1440+
" fun_evals = fun_evals,\n",
1441+
" tolerance_x = np.sqrt(np.spacing(1)),\n",
1442+
" show_progress=True,\n",
1443+
" design_control=design_control,\n",
1444+
" surrogate_control=surrogate_control,)\n",
1445+
"S.run()\n",
1446+
"S.plot_progress(log_y=True)\n"
1447+
]
1448+
},
1449+
{
1450+
"cell_type": "code",
1451+
"execution_count": null,
1452+
"metadata": {},
1453+
"outputs": [],
1454+
"source": [
1455+
"import numpy as np\n",
1456+
"from spotPython.fun.objectivefunctions import analytical\n",
1457+
"from spotPython.spot import spot\n",
1458+
"# number of initial points:\n",
1459+
"ni = 5\n",
1460+
"# number of points\n",
1461+
"fun_evals = 10\n",
1462+
"\n",
1463+
"fun = analytical().fun_sphere\n",
1464+
"lower = np.array([-1, -1, -1])\n",
1465+
"upper = np.array([1, 1, 1])\n",
1466+
"design_control={\"init_size\": ni}\n",
1467+
"surrogate_control={\"n_theta\": 3}\n",
1468+
"S = spot.Spot(fun=fun,\n",
1469+
" lower = lower,\n",
1470+
" upper= upper,\n",
1471+
" log_level = 50,\n",
1472+
" fun_evals = fun_evals,\n",
1473+
" tolerance_x = np.sqrt(np.spacing(1)),\n",
1474+
" show_progress=True,\n",
1475+
" design_control=design_control,\n",
1476+
" surrogate_control=surrogate_control,)\n",
1477+
"S.run()\n",
1478+
"S.plot_important_hyperparameter_contour()"
1479+
]
1480+
},
1481+
{
1482+
"cell_type": "code",
1483+
"execution_count": null,
1484+
"metadata": {},
1485+
"outputs": [],
1486+
"source": [
1487+
"import numpy as np\n",
1488+
"from spotPython.fun.objectivefunctions import analytical\n",
1489+
"from spotPython.spot import spot\n",
1490+
"# number of initial points:\n",
1491+
"ni = 5\n",
1492+
"# number of points\n",
1493+
"fun_evals = 10\n",
1494+
"\n",
1495+
"fun = analytical().fun_sphere\n",
1496+
"lower = np.array([-1, -1, -1])\n",
1497+
"upper = np.array([1, 1, 1])\n",
1498+
"design_control={\"init_size\": ni}\n",
1499+
"surrogate_control={\"n_theta\": 3}\n",
1500+
"S = spot.Spot(fun=fun,\n",
1501+
" lower = lower,\n",
1502+
" upper= upper,\n",
1503+
" log_level = 50,\n",
1504+
" fun_evals = fun_evals,\n",
1505+
" tolerance_x = np.sqrt(np.spacing(1)),\n",
1506+
" show_progress=True,\n",
1507+
" design_control=design_control,\n",
1508+
" surrogate_control=surrogate_control,)\n",
1509+
"S.run()\n",
1510+
"S.plot_contour()\n",
1511+
"S.plot_contour(i=1, j=2)"
1512+
]
1513+
},
1514+
{
1515+
"cell_type": "code",
1516+
"execution_count": null,
1517+
"metadata": {},
1518+
"outputs": [],
1519+
"source": [
1520+
"import numpy as np\n",
1521+
"from spotPython.fun.objectivefunctions import analytical\n",
1522+
"from spotPython.spot import spot\n",
1523+
"# number of initial points:\n",
1524+
"ni = 3\n",
1525+
"# number of points\n",
1526+
"fun_evals = 7\n",
1527+
"\n",
1528+
"fun = analytical().fun_sphere\n",
1529+
"lower = np.array([-1])\n",
1530+
"upper = np.array([1])\n",
1531+
"design_control={\"init_size\": ni}\n",
1532+
"surrogate_control={\"n_theta\": 1}\n",
1533+
"S = spot.Spot(fun=fun,\n",
1534+
" lower = lower,\n",
1535+
" upper= upper,\n",
1536+
" log_level = 50,\n",
1537+
" fun_evals = fun_evals,\n",
1538+
" tolerance_x = np.sqrt(np.spacing(1)),\n",
1539+
" show_progress=True,\n",
1540+
" design_control=design_control,\n",
1541+
" surrogate_control=surrogate_control,)\n",
1542+
"S.run()\n",
1543+
"S.plot_model()\n"
1544+
]
1545+
},
1546+
{
1547+
"cell_type": "code",
1548+
"execution_count": null,
1549+
"metadata": {},
1550+
"outputs": [],
1551+
"source": [
1552+
"import numpy as np\n",
1553+
"from spotPython.fun.objectivefunctions import analytical\n",
1554+
"from spotPython.spot import spot\n",
1555+
"# number of initial points:\n",
1556+
"ni = 5\n",
1557+
"# number of points\n",
1558+
"fun_evals = 10\n",
13741559
"\n",
13751560
"fun = analytical().fun_sphere\n",
13761561
"lower = np.array([-1, -1, -1])\n",
@@ -1387,9 +1572,46 @@
13871572
" design_control=design_control,\n",
13881573
" surrogate_control=surrogate_control,)\n",
13891574
"S.run()\n",
1390-
"S.plot_important_hyperparameter_contour()\n",
1391-
"S.min_X\n",
1392-
"S.min_y"
1575+
"S.parallel_plot()"
1576+
]
1577+
},
1578+
{
1579+
"cell_type": "code",
1580+
"execution_count": null,
1581+
"metadata": {},
1582+
"outputs": [],
1583+
"source": [
1584+
"import numpy as np\n",
1585+
"from spotPython.fun.objectivefunctions import analytical\n",
1586+
"from spotPython.spot import spot\n",
1587+
"# 1-dimensional example\n",
1588+
"fun = analytical().fun_sphere\n",
1589+
"lower = np.array([-1])\n",
1590+
"upper = np.array([1])\n",
1591+
"design_control={\"init_size\": 10}\n",
1592+
"S = spot.Spot(fun=fun,\n",
1593+
" noise=False,\n",
1594+
" lower = lower,\n",
1595+
" upper= upper,\n",
1596+
" design_control=design_control,)\n",
1597+
"S.initialize_design()\n",
1598+
"S.update_stats()\n",
1599+
"S.fit_surrogate()\n",
1600+
"S.surrogate.plot()\n",
1601+
"# 2-dimensional example\n",
1602+
"fun = analytical().fun_sphere\n",
1603+
"lower = np.array([-1, -1])\n",
1604+
"upper = np.array([1, 1])\n",
1605+
"design_control={\"init_size\": 10}\n",
1606+
"S = spot.Spot(fun=fun,\n",
1607+
" noise=False,\n",
1608+
" lower = lower,\n",
1609+
" upper= upper,\n",
1610+
" design_control=design_control,)\n",
1611+
"S.initialize_design()\n",
1612+
"S.update_stats()\n",
1613+
"S.fit_surrogate()\n",
1614+
"S.surrogate.plot()"
13931615
]
13941616
},
13951617
{
@@ -1508,7 +1730,7 @@
15081730
"metadata": {},
15091731
"outputs": [],
15101732
"source": [
1511-
"from spotPython.budget.ocba import get_ocba, get_ocba_X\n",
1733+
"from spotPython.budget.ocba import get_ocba_X\n",
15121734
"from spotPython.utils.aggregate import aggregate_mean_var\n",
15131735
"import numpy as np\n",
15141736
"X = np.array([[1,2,3],\n",
@@ -1529,36 +1751,15 @@
15291751
"print(f\"mean_y: {mean_y}\")\n",
15301752
"print(f\"var_y: {var_y}\")\n",
15311753
"delta = 5\n",
1532-
"# get_ocba(means, vars, delta,verbose=True)\n",
15331754
"X_new = get_ocba_X(X=mean_X, means=mean_y, vars=var_y, delta=delta,verbose=True)\n",
15341755
"X_new\n"
15351756
]
15361757
},
15371758
{
15381759
"cell_type": "code",
1539-
"execution_count": 2,
1760+
"execution_count": null,
15401761
"metadata": {},
1541-
"outputs": [
1542-
{
1543-
"name": "stdout",
1544-
"output_type": "stream",
1545-
"text": [
1546-
"X: [[1 2 3]\n",
1547-
" [1 2 3]\n",
1548-
" [4 5 6]\n",
1549-
" [4 5 6]\n",
1550-
" [4 5 6]\n",
1551-
" [4 5 6]\n",
1552-
" [4 5 6]]\n",
1553-
"mean_X.shape: (2, 3)\n",
1554-
"y: [ 1 2 30 40 40 500 600]\n",
1555-
"mean_X: [[1. 2. 3.]\n",
1556-
" [4. 5. 6.]]\n",
1557-
"mean_y: [ 1.5 242. ]\n",
1558-
"var_y: [5.000e-01 8.032e+04]\n"
1559-
]
1560-
}
1561-
],
1762+
"outputs": [],
15621763
"source": [
15631764
"from spotPython.budget.ocba import get_ocba, get_ocba_X\n",
15641765
"from spotPython.utils.aggregate import aggregate_mean_var\n",
@@ -1587,6 +1788,41 @@
15871788
"assert X_new is None\n"
15881789
]
15891790
},
1791+
{
1792+
"cell_type": "code",
1793+
"execution_count": 21,
1794+
"metadata": {},
1795+
"outputs": [
1796+
{
1797+
"name": "stdout",
1798+
"output_type": "stream",
1799+
"text": [
1800+
"Before: [1, 2, 3]\n",
1801+
"After: [4, 2, 5]\n"
1802+
]
1803+
}
1804+
],
1805+
"source": [
1806+
"import numpy as np\n",
1807+
"from spotPython.fun.objectivefunctions import analytical\n",
1808+
"from spotPython.spot import spot\n",
1809+
"fun = analytical().fun_sphere\n",
1810+
"lower = np.array([-1, -1])\n",
1811+
"upper = np.array([1, 1])\n",
1812+
"S = spot.Spot(fun=fun,\n",
1813+
" lower = lower,\n",
1814+
" upper= upper,\n",
1815+
")\n",
1816+
"z0 = [1, 2, 3]\n",
1817+
"print(f\"Before: {z0}\")\n",
1818+
"new_val_1 = 4\n",
1819+
"new_val_2 = 5\n",
1820+
"index_1 = 0\n",
1821+
"index_2 = 2\n",
1822+
"S.chg(x=new_val_1, y=new_val_2, z0=z0, i=index_1, j=index_2)\n",
1823+
"print(f\"After: {z0}\")"
1824+
]
1825+
},
15901826
{
15911827
"cell_type": "code",
15921828
"execution_count": null,

0 commit comments

Comments
 (0)