|
12 | 12 | }, |
13 | 13 | { |
14 | 14 | "cell_type": "code", |
15 | | - "execution_count": 1, |
| 15 | + "execution_count": null, |
16 | 16 | "metadata": {}, |
17 | 17 | "outputs": [], |
18 | 18 | "source": [ |
|
23 | 23 | }, |
24 | 24 | { |
25 | 25 | "cell_type": "code", |
26 | | - "execution_count": 2, |
27 | | - "metadata": {}, |
28 | | - "outputs": [ |
29 | | - { |
30 | | - "data": { |
31 | | - "text/plain": [ |
32 | | - "'12-torch_p040025_1min_5init_2023-05-12_11-11-36'" |
33 | | - ] |
34 | | - }, |
35 | | - "execution_count": 2, |
36 | | - "metadata": {}, |
37 | | - "output_type": "execute_result" |
38 | | - } |
39 | | - ], |
| 26 | + "execution_count": null, |
| 27 | + "metadata": {}, |
| 28 | + "outputs": [], |
40 | 29 | "source": [ |
41 | 30 | "import pickle\n", |
42 | 31 | "import socket\n", |
|
69 | 58 | }, |
70 | 59 | { |
71 | 60 | "cell_type": "code", |
72 | | - "execution_count": 3, |
73 | | - "metadata": {}, |
74 | | - "outputs": [ |
75 | | - { |
76 | | - "name": "stdout", |
77 | | - "output_type": "stream", |
78 | | - "text": [ |
79 | | - "spotPython 0.0.68\n", |
80 | | - "Note: you may need to restart the kernel to use updated packages.\n" |
81 | | - ] |
82 | | - } |
83 | | - ], |
| 61 | + "execution_count": null, |
| 62 | + "metadata": {}, |
| 63 | + "outputs": [], |
84 | 64 | "source": [ |
85 | 65 | "pip list | grep \"spot[RiverPython]\"" |
86 | 66 | ] |
87 | 67 | }, |
88 | 68 | { |
89 | 69 | "cell_type": "code", |
90 | | - "execution_count": 4, |
| 70 | + "execution_count": null, |
91 | 71 | "metadata": {}, |
92 | 72 | "outputs": [], |
93 | 73 | "source": [ |
|
98 | 78 | }, |
99 | 79 | { |
100 | 80 | "cell_type": "code", |
101 | | - "execution_count": 5, |
| 81 | + "execution_count": null, |
102 | 82 | "metadata": {}, |
103 | 83 | "outputs": [], |
104 | 84 | "source": [ |
|
175 | 155 | }, |
176 | 156 | { |
177 | 157 | "cell_type": "code", |
178 | | - "execution_count": 6, |
179 | | - "metadata": {}, |
180 | | - "outputs": [ |
181 | | - { |
182 | | - "name": "stdout", |
183 | | - "output_type": "stream", |
184 | | - "text": [ |
185 | | - "2.0.1\n", |
186 | | - "MPS device: mps\n" |
187 | | - ] |
188 | | - } |
189 | | - ], |
| 158 | + "execution_count": null, |
| 159 | + "metadata": {}, |
| 160 | + "outputs": [], |
190 | 161 | "source": [ |
191 | 162 | "print(torch.__version__)\n", |
192 | 163 | "# Check that MPS is available\n", |
|
213 | 184 | }, |
214 | 185 | { |
215 | 186 | "cell_type": "code", |
216 | | - "execution_count": 7, |
| 187 | + "execution_count": null, |
217 | 188 | "metadata": {}, |
218 | 189 | "outputs": [], |
219 | 190 | "source": [ |
|
230 | 201 | }, |
231 | 202 | { |
232 | 203 | "cell_type": "code", |
233 | | - "execution_count": 8, |
| 204 | + "execution_count": null, |
234 | 205 | "metadata": {}, |
235 | 206 | "outputs": [], |
236 | 207 | "source": [ |
|
251 | 222 | }, |
252 | 223 | { |
253 | 224 | "cell_type": "code", |
254 | | - "execution_count": 9, |
255 | | - "metadata": {}, |
256 | | - "outputs": [ |
257 | | - { |
258 | | - "name": "stdout", |
259 | | - "output_type": "stream", |
260 | | - "text": [ |
261 | | - "Files already downloaded and verified\n", |
262 | | - "Files already downloaded and verified\n" |
263 | | - ] |
264 | | - }, |
265 | | - { |
266 | | - "data": { |
267 | | - "text/plain": [ |
268 | | - "((50000, 32, 32, 3), (10000, 32, 32, 3))" |
269 | | - ] |
270 | | - }, |
271 | | - "execution_count": 9, |
272 | | - "metadata": {}, |
273 | | - "output_type": "execute_result" |
274 | | - } |
275 | | - ], |
| 225 | + "execution_count": null, |
| 226 | + "metadata": {}, |
| 227 | + "outputs": [], |
276 | 228 | "source": [ |
277 | 229 | "train, test = load_data()\n", |
278 | 230 | "train.data.shape, test.data.shape" |
279 | 231 | ] |
280 | 232 | }, |
281 | 233 | { |
282 | 234 | "cell_type": "code", |
283 | | - "execution_count": 10, |
| 235 | + "execution_count": null, |
284 | 236 | "metadata": {}, |
285 | 237 | "outputs": [], |
286 | 238 | "source": [ |
|
303 | 255 | }, |
304 | 256 | { |
305 | 257 | "cell_type": "code", |
306 | | - "execution_count": 11, |
| 258 | + "execution_count": null, |
307 | 259 | "metadata": {}, |
308 | 260 | "outputs": [], |
309 | 261 | "source": [ |
|
346 | 298 | }, |
347 | 299 | { |
348 | 300 | "cell_type": "code", |
349 | | - "execution_count": 12, |
| 301 | + "execution_count": null, |
350 | 302 | "metadata": {}, |
351 | 303 | "outputs": [], |
352 | 304 | "source": [ |
|
476 | 428 | }, |
477 | 429 | { |
478 | 430 | "cell_type": "code", |
479 | | - "execution_count": 13, |
| 431 | + "execution_count": null, |
480 | 432 | "metadata": {}, |
481 | 433 | "outputs": [], |
482 | 434 | "source": [ |
|
504 | 456 | }, |
505 | 457 | { |
506 | 458 | "cell_type": "code", |
507 | | - "execution_count": 14, |
| 459 | + "execution_count": null, |
508 | 460 | "metadata": {}, |
509 | 461 | "outputs": [], |
510 | 462 | "source": [ |
|
544 | 496 | }, |
545 | 497 | { |
546 | 498 | "cell_type": "code", |
547 | | - "execution_count": 15, |
| 499 | + "execution_count": null, |
548 | 500 | "metadata": {}, |
549 | 501 | "outputs": [], |
550 | 502 | "source": [ |
|
553 | 505 | "shuffle = True\n", |
554 | 506 | "eval = \"train_hold_out\"\n", |
555 | 507 | "device = DEVICE\n", |
| 508 | + "show_batch_interval = 100_000\n", |
556 | 509 | "\n", |
557 | 510 | "fun_control.update({\n", |
558 | 511 | " \"data_dir\": None,\n", |
|
568 | 521 | " \"shuffle\": shuffle,\n", |
569 | 522 | " \"eval\": eval,\n", |
570 | 523 | " \"device\": device,\n", |
| 524 | + " \"show_batch_interval\": show_batch_interval,\n", |
571 | 525 | " })" |
572 | 526 | ] |
573 | 527 | }, |
|
597 | 551 | }, |
598 | 552 | { |
599 | 553 | "cell_type": "code", |
600 | | - "execution_count": 16, |
| 554 | + "execution_count": null, |
601 | 555 | "metadata": {}, |
602 | 556 | "outputs": [], |
603 | 557 | "source": [ |
|
612 | 566 | }, |
613 | 567 | { |
614 | 568 | "cell_type": "code", |
615 | | - "execution_count": 17, |
616 | | - "metadata": {}, |
617 | | - "outputs": [ |
618 | | - { |
619 | | - "name": "stdout", |
620 | | - "output_type": "stream", |
621 | | - "text": [ |
622 | | - "| name | type | default | lower | upper | transform |\n", |
623 | | - "|------------|--------|-----------|---------|---------|-----------------------|\n", |
624 | | - "| l1 | int | 5 | 2 | 9 | transform_power_2_int |\n", |
625 | | - "| l2 | int | 5 | 2 | 9 | transform_power_2_int |\n", |
626 | | - "| lr | float | 0.001 | 1e-05 | 0.01 | None |\n", |
627 | | - "| batch_size | int | 4 | 1 | 4 | transform_power_2_int |\n", |
628 | | - "| epochs | int | 3 | 3 | 4 | transform_power_2_int |\n", |
629 | | - "| k_folds | int | 2 | 2 | 2 | None |\n" |
630 | | - ] |
631 | | - } |
632 | | - ], |
| 569 | + "execution_count": null, |
| 570 | + "metadata": {}, |
| 571 | + "outputs": [], |
633 | 572 | "source": [ |
634 | 573 | "print(gen_design_table(fun_control))" |
635 | 574 | ] |
|
647 | 586 | }, |
648 | 587 | { |
649 | 588 | "cell_type": "code", |
650 | | - "execution_count": 18, |
651 | | - "metadata": {}, |
652 | | - "outputs": [ |
653 | | - { |
654 | | - "data": { |
655 | | - "text/plain": [ |
656 | | - "array([[5.e+00, 5.e+00, 1.e-03, 4.e+00, 3.e+00, 2.e+00]])" |
657 | | - ] |
658 | | - }, |
659 | | - "execution_count": 18, |
660 | | - "metadata": {}, |
661 | | - "output_type": "execute_result" |
662 | | - } |
663 | | - ], |
| 589 | + "execution_count": null, |
| 590 | + "metadata": {}, |
| 591 | + "outputs": [], |
664 | 592 | "source": [ |
665 | 593 | "from spotPython.hyperparameters.values import get_default_hyperparameters_as_array\n", |
666 | 594 | "hyper_dict=TorchHyperDict().load()\n", |
|
670 | 598 | }, |
671 | 599 | { |
672 | 600 | "cell_type": "code", |
673 | | - "execution_count": 19, |
674 | | - "metadata": {}, |
675 | | - "outputs": [ |
676 | | - { |
677 | | - "name": "stdout", |
678 | | - "output_type": "stream", |
679 | | - "text": [ |
680 | | - "Additional attributes: {'lr', 'batch_size', 'k_folds', 'epochs'}\n", |
681 | | - "Removed attributes: {'lr', 'batch_size', 'k_folds', 'epochs'}\n", |
682 | | - "Epoch: 1\n", |
683 | | - "Batch: 1000. Batch Size: 8. Training Loss (running): 2.126\n", |
684 | | - "Batch: 2000. Batch Size: 8. Training Loss (running): 0.898\n", |
685 | | - "Batch: 3000. Batch Size: 8. Training Loss (running): 0.552\n", |
686 | | - "Loss on hold-out set: 1.590382700252533\n", |
687 | | - "Accuracy on hold-out set: 0.4229\n", |
688 | | - "Epoch: 2\n", |
689 | | - "Batch: 1000. Batch Size: 8. Training Loss (running): 1.524\n", |
690 | | - "Batch: 2000. Batch Size: 8. Training Loss (running): 0.733\n", |
691 | | - "Batch: 3000. Batch Size: 8. Training Loss (running): 0.476\n", |
692 | | - "Loss on hold-out set: 1.4275218252420425\n", |
693 | | - "Accuracy on hold-out set: 0.4857\n", |
694 | | - "Epoch: 3\n", |
695 | | - "Batch: 1000. Batch Size: 8. Training Loss (running): 1.353\n", |
696 | | - "Batch: 2000. Batch Size: 8. Training Loss (running): 0.664\n", |
697 | | - "Batch: 3000. Batch Size: 8. Training Loss (running): 0.436\n", |
698 | | - "Loss on hold-out set: 1.3183471252799035\n", |
699 | | - "Accuracy on hold-out set: 0.5349\n", |
700 | | - "Epoch: 4\n", |
701 | | - "Batch: 1000. Batch Size: 8. Training Loss (running): 1.227\n", |
702 | | - "Batch: 2000. Batch Size: 8. Training Loss (running): 0.623\n", |
703 | | - "Batch: 3000. Batch Size: 8. Training Loss (running): 0.414\n", |
704 | | - "Loss on hold-out set: 1.3052782666921616\n", |
705 | | - "Accuracy on hold-out set: 0.5461\n", |
706 | | - "Epoch: 5\n", |
707 | | - "Batch: 1000. Batch Size: 8. Training Loss (running): 1.155\n", |
708 | | - "Batch: 2000. Batch Size: 8. Training Loss (running): 0.591\n", |
709 | | - "Batch: 3000. Batch Size: 8. Training Loss (running): 0.378\n", |
710 | | - "Loss on hold-out set: 1.2335647928357125\n", |
711 | | - "Accuracy on hold-out set: 0.57405\n", |
712 | | - "Epoch: 6\n" |
713 | | - ] |
714 | | - } |
715 | | - ], |
| 601 | + "execution_count": null, |
| 602 | + "metadata": {}, |
| 603 | + "outputs": [], |
716 | 604 | "source": [ |
717 | 605 | "spot_tuner = spot.Spot(fun=fun,\n", |
718 | 606 | " lower = lower,\n", |
|
967 | 855 | }, |
968 | 856 | { |
969 | 857 | "cell_type": "code", |
970 | | - "execution_count": null, |
| 858 | + "execution_count": 35, |
971 | 859 | "metadata": {}, |
972 | 860 | "outputs": [], |
973 | 861 | "source": [ |
|
0 commit comments