|
986 | 986 | }, |
987 | 987 | { |
988 | 988 | "cell_type": "code", |
989 | | - "execution_count": 13, |
| 989 | + "execution_count": null, |
990 | 990 | "metadata": {}, |
991 | | - "outputs": [ |
992 | | - { |
993 | | - "name": "stdout", |
994 | | - "output_type": "stream", |
995 | | - "text": [ |
996 | | - "S.X: [[0 0]\n", |
997 | | - " [0 1]\n", |
998 | | - " [1 0]\n", |
999 | | - " [1 1]]\n", |
1000 | | - "S.y: [0. 1. 1. 2.]\n" |
1001 | | - ] |
1002 | | - } |
1003 | | - ], |
| 991 | + "outputs": [], |
1004 | 992 | "source": [ |
1005 | 993 | "import numpy as np\n", |
1006 | 994 | "from spotPython.fun.objectivefunctions import analytical\n", |
|
1038 | 1026 | }, |
1039 | 1027 | { |
1040 | 1028 | "cell_type": "code", |
1041 | | - "execution_count": 23, |
| 1029 | + "execution_count": null, |
1042 | 1030 | "metadata": {}, |
1043 | | - "outputs": [ |
1044 | | - { |
1045 | | - "name": "stdout", |
1046 | | - "output_type": "stream", |
1047 | | - "text": [ |
1048 | | - "S.X: [[0 0]\n", |
1049 | | - " [0 1]\n", |
1050 | | - " [1 0]\n", |
1051 | | - " [1 1]\n", |
1052 | | - " [1 1]]\n", |
1053 | | - "S.y: [0. 1. 1. 2. 2.]\n" |
1054 | | - ] |
1055 | | - } |
1056 | | - ], |
| 1031 | + "outputs": [], |
1057 | 1032 | "source": [ |
1058 | 1033 | "import numpy as np\n", |
1059 | 1034 | "from spotPython.fun.objectivefunctions import analytical\n", |
|
1093 | 1068 | }, |
1094 | 1069 | { |
1095 | 1070 | "cell_type": "code", |
1096 | | - "execution_count": 29, |
| 1071 | + "execution_count": null, |
1097 | 1072 | "metadata": {}, |
1098 | | - "outputs": [ |
1099 | | - { |
1100 | | - "name": "stdout", |
1101 | | - "output_type": "stream", |
1102 | | - "text": [ |
1103 | | - "S.X: [[0 0]\n", |
1104 | | - " [0 1]\n", |
1105 | | - " [1 0]\n", |
1106 | | - " [1 1]\n", |
1107 | | - " [1 1]]\n", |
1108 | | - "S.y: [0. 1. 1. 2. 2.]\n" |
1109 | | - ] |
1110 | | - } |
1111 | | - ], |
| 1073 | + "outputs": [], |
1112 | 1074 | "source": [ |
1113 | 1075 | "import numpy as np\n", |
1114 | 1076 | "from spotPython.fun.objectivefunctions import analytical\n", |
|
1159 | 1121 | }, |
1160 | 1122 | { |
1161 | 1123 | "cell_type": "code", |
1162 | | - "execution_count": 36, |
| 1124 | + "execution_count": null, |
| 1125 | + "metadata": {}, |
| 1126 | + "outputs": [], |
| 1127 | + "source": [ |
| 1128 | + "import numpy as np\n", |
| 1129 | + "from spotPython.fun.objectivefunctions import analytical\n", |
| 1130 | + "from spotPython.spot import spot\n", |
| 1131 | + "# number of initial points:\n", |
| 1132 | + "ni = 0\n", |
| 1133 | + "X_start = np.array([[0, 0], [0, 1], [1, 0], [1, 1], [1, 1]])\n", |
| 1134 | + "\n", |
| 1135 | + "fun = analytical().fun_sphere\n", |
| 1136 | + "lower = np.array([-1, -1])\n", |
| 1137 | + "upper = np.array([1, 1])\n", |
| 1138 | + "design_control={\"init_size\": ni}\n", |
| 1139 | + "\n", |
| 1140 | + "S = spot.Spot(fun=fun,\n", |
| 1141 | + " noise=False,\n", |
| 1142 | + " lower = lower,\n", |
| 1143 | + " upper= upper,\n", |
| 1144 | + " show_progress=True,\n", |
| 1145 | + " design_control=design_control,)\n", |
| 1146 | + "S.initialize_design(X_start=X_start)\n", |
| 1147 | + "print(f\"S.X: {S.X}\")\n", |
| 1148 | + "print(f\"S.y: {S.y}\")\n", |
| 1149 | + "S.update_stats()\n", |
| 1150 | + "S.fit_surrogate()\n", |
| 1151 | + "assert S.surrogate.Psi.shape[0] == S.X.shape[0]\n" |
| 1152 | + ] |
| 1153 | + }, |
| 1154 | + { |
| 1155 | + "cell_type": "markdown", |
| 1156 | + "metadata": {}, |
| 1157 | + "source": [ |
| 1158 | + "## test update_design()" |
| 1159 | + ] |
| 1160 | + }, |
| 1161 | + { |
| 1162 | + "cell_type": "code", |
| 1163 | + "execution_count": null, |
| 1164 | + "metadata": {}, |
| 1165 | + "outputs": [], |
| 1166 | + "source": [ |
| 1167 | + "import numpy as np\n", |
| 1168 | + "from spotPython.fun.objectivefunctions import analytical\n", |
| 1169 | + "from spotPython.spot import spot\n", |
| 1170 | + "# number of initial points:\n", |
| 1171 | + "ni = 0\n", |
| 1172 | + "X_start = np.array([[0, 0], [0, 1], [1, 0], [1, 1], [1, 1]])\n", |
| 1173 | + "\n", |
| 1174 | + "fun = analytical().fun_sphere\n", |
| 1175 | + "lower = np.array([-1, -1])\n", |
| 1176 | + "upper = np.array([1, 1])\n", |
| 1177 | + "design_control={\"init_size\": ni}\n", |
| 1178 | + "\n", |
| 1179 | + "S = spot.Spot(fun=fun,\n", |
| 1180 | + " noise=False,\n", |
| 1181 | + " lower = lower,\n", |
| 1182 | + " upper= upper,\n", |
| 1183 | + " show_progress=True,\n", |
| 1184 | + " design_control=design_control,)\n", |
| 1185 | + "S.initialize_design(X_start=X_start)\n", |
| 1186 | + "print(f\"S.X: {S.X}\")\n", |
| 1187 | + "print(f\"S.y: {S.y}\")\n", |
| 1188 | + "X_shape_before = S.X.shape\n", |
| 1189 | + "print(f\"X_shape_before: {X_shape_before}\")\n", |
| 1190 | + "print(f\"y_size_before: {S.y.size}\")\n", |
| 1191 | + "y_size_before = S.y.size\n", |
| 1192 | + "S.update_stats()\n", |
| 1193 | + "S.fit_surrogate()\n", |
| 1194 | + "S.update_design()\n", |
| 1195 | + "print(f\"S.X: {S.X}\")\n", |
| 1196 | + "print(f\"S.y: {S.y}\")\n", |
| 1197 | + "print(f\"S.n_points: {S.n_points}\")\n", |
| 1198 | + "print(f\"X_shape_after: {S.X.shape}\")\n", |
| 1199 | + "print(f\"y_size_after: {S.y.size}\")\n", |
| 1200 | + "# compare the shapes of the X and y values before and after the update_design method\n", |
| 1201 | + "assert X_shape_before[0] + S.n_points == S.X.shape[0]\n", |
| 1202 | + "assert X_shape_before[1] == S.X.shape[1]\n", |
| 1203 | + "assert y_size_before + S.n_points == S.y.size" |
| 1204 | + ] |
| 1205 | + }, |
| 1206 | + { |
| 1207 | + "cell_type": "code", |
| 1208 | + "execution_count": 8, |
1163 | 1209 | "metadata": {}, |
1164 | 1210 | "outputs": [ |
| 1211 | + { |
| 1212 | + "name": "stderr", |
| 1213 | + "output_type": "stream", |
| 1214 | + "text": [ |
| 1215 | + "Seed set to 42\n" |
| 1216 | + ] |
| 1217 | + }, |
1165 | 1218 | { |
1166 | 1219 | "name": "stdout", |
1167 | 1220 | "output_type": "stream", |
1168 | 1221 | "text": [ |
1169 | | - "S.X: [[0 0]\n", |
1170 | | - " [0 1]\n", |
1171 | | - " [1 0]\n", |
1172 | | - " [1 1]\n", |
1173 | | - " [1 1]]\n", |
1174 | | - "S.y: [0. 1. 1. 2. 2.]\n" |
| 1222 | + "S.X: [[ 0. 1. ]\n", |
| 1223 | + " [ 1. 0. ]\n", |
| 1224 | + " [ 1. 1. ]\n", |
| 1225 | + " [ 1. 1. ]\n", |
| 1226 | + " [ 0.54509876 -0.36921401]\n", |
| 1227 | + " [ 0.54509876 -0.36921401]\n", |
| 1228 | + " [ 0.18642675 0.87708546]\n", |
| 1229 | + " [ 0.18642675 0.87708546]\n", |
| 1230 | + " [-0.45060393 -0.208063 ]\n", |
| 1231 | + " [-0.45060393 -0.208063 ]]\n", |
| 1232 | + "S.y: [0.98021757 0.99264427 2.02575851 2.00387949 0.45185626 0.44499372\n", |
| 1233 | + " 0.79130456 0.81487288 0.24000221 0.23988634]\n", |
| 1234 | + "X_shape_before: (10, 2)\n", |
| 1235 | + "y_size_before: 10\n", |
| 1236 | + "S.X: [[ 0. 1. ]\n", |
| 1237 | + " [ 1. 0. ]\n", |
| 1238 | + " [ 1. 1. ]\n", |
| 1239 | + " [ 1. 1. ]\n", |
| 1240 | + " [ 0.54509876 -0.36921401]\n", |
| 1241 | + " [ 0.54509876 -0.36921401]\n", |
| 1242 | + " [ 0.18642675 0.87708546]\n", |
| 1243 | + " [ 0.18642675 0.87708546]\n", |
| 1244 | + " [-0.45060393 -0.208063 ]\n", |
| 1245 | + " [-0.45060393 -0.208063 ]\n", |
| 1246 | + " [-0.45060393 -0.208063 ]\n", |
| 1247 | + " [-0.39841465 -0.21105872]\n", |
| 1248 | + " [-0.39841465 -0.21105872]]\n", |
| 1249 | + "S.y: [0.98021757 0.99264427 2.02575851 2.00387949 0.45185626 0.44499372\n", |
| 1250 | + " 0.79130456 0.81487288 0.24000221 0.23988634 0.22655169 0.19592429\n", |
| 1251 | + " 0.22903853]\n", |
| 1252 | + "S.n_points: 1\n", |
| 1253 | + "S.ocba_delta: 1\n", |
| 1254 | + "X_shape_after: (13, 2)\n", |
| 1255 | + "y_size_after: 13\n" |
| 1256 | + ] |
| 1257 | + }, |
| 1258 | + { |
| 1259 | + "name": "stderr", |
| 1260 | + "output_type": "stream", |
| 1261 | + "text": [ |
| 1262 | + "/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/spotPython/budget/ocba.py:65: RuntimeWarning: invalid value encountered in cast\n", |
| 1263 | + " add_budget = around(add_budget).astype(int)\n" |
1175 | 1264 | ] |
1176 | 1265 | } |
1177 | 1266 | ], |
1178 | 1267 | "source": [ |
1179 | 1268 | "import numpy as np\n", |
1180 | 1269 | "from spotPython.fun.objectivefunctions import analytical\n", |
1181 | 1270 | "from spotPython.spot import spot\n", |
| 1271 | + "from spotPython.utils.init import fun_control_init\n", |
1182 | 1272 | "# number of initial points:\n", |
1183 | | - "ni = 0\n", |
1184 | | - "X_start = np.array([[0, 0], [0, 1], [1, 0], [1, 1], [1, 1]])\n", |
| 1273 | + "ni = 3\n", |
| 1274 | + "X_start = np.array([[0, 1], [1, 0], [1, 1], [1, 1]])\n", |
1185 | 1275 | "\n", |
1186 | 1276 | "fun = analytical().fun_sphere\n", |
| 1277 | + "fun_control = fun_control_init(\n", |
| 1278 | + " sigma=0.02,\n", |
| 1279 | + " seed=123,)\n", |
1187 | 1280 | "lower = np.array([-1, -1])\n", |
1188 | 1281 | "upper = np.array([1, 1])\n", |
1189 | | - "design_control={\"init_size\": ni}\n", |
| 1282 | + "design_control={\"init_size\": ni,\n", |
| 1283 | + " \"repeats\": 2}\n", |
1190 | 1284 | "\n", |
1191 | 1285 | "S = spot.Spot(fun=fun,\n", |
1192 | | - " noise=False,\n", |
| 1286 | + " noise=True,\n", |
| 1287 | + " fun_repeats=2,\n", |
| 1288 | + " n_points=1,\n", |
| 1289 | + " ocba_delta=1,\n", |
1193 | 1290 | " lower = lower,\n", |
1194 | 1291 | " upper= upper,\n", |
1195 | 1292 | " show_progress=True,\n", |
1196 | | - " design_control=design_control,)\n", |
| 1293 | + " design_control=design_control,\n", |
| 1294 | + " fun_control=fun_control\n", |
| 1295 | + ")\n", |
1197 | 1296 | "S.initialize_design(X_start=X_start)\n", |
1198 | 1297 | "print(f\"S.X: {S.X}\")\n", |
1199 | 1298 | "print(f\"S.y: {S.y}\")\n", |
| 1299 | + "X_shape_before = S.X.shape\n", |
| 1300 | + "print(f\"X_shape_before: {X_shape_before}\")\n", |
| 1301 | + "print(f\"y_size_before: {S.y.size}\")\n", |
| 1302 | + "y_size_before = S.y.size\n", |
1200 | 1303 | "S.update_stats()\n", |
1201 | 1304 | "S.fit_surrogate()\n", |
1202 | | - "assert S.surrogate.Psi.shape[0] == S.X.shape[0]\n" |
| 1305 | + "S.update_design()\n", |
| 1306 | + "print(f\"S.X: {S.X}\")\n", |
| 1307 | + "print(f\"S.y: {S.y}\")\n", |
| 1308 | + "print(f\"S.n_points: {S.n_points}\")\n", |
| 1309 | + "print(f\"S.ocba_delta: {S.ocba_delta}\")\n", |
| 1310 | + "print(f\"X_shape_after: {S.X.shape}\")\n", |
| 1311 | + "print(f\"y_size_after: {S.y.size}\")\n", |
| 1312 | + "# compare the shapes of the X and y values before and after the update_design method\n", |
| 1313 | + "assert X_shape_before[0] + S.n_points * S.fun_repeats + S.ocba_delta == S.X.shape[0]\n", |
| 1314 | + "assert X_shape_before[1] == S.X.shape[1]\n", |
| 1315 | + "assert y_size_before + S.n_points * S.fun_repeats + S.ocba_delta == S.y.size" |
1203 | 1316 | ] |
1204 | 1317 | }, |
1205 | 1318 | { |
|
0 commit comments