Skip to content

Commit 325f3b6

Browse files
v0.0.69
1 parent 959dfbf commit 325f3b6

14 files changed

Lines changed: 2226 additions & 2363 deletions

notebooks/11_spot_hpt_torch_fashion_mnist.ipynb

Lines changed: 1675 additions & 108 deletions
Large diffs are not rendered by default.

notebooks/12_spot_hpt_torch_cifar10.ipynb

Lines changed: 36 additions & 148 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
},
1313
{
1414
"cell_type": "code",
15-
"execution_count": 1,
15+
"execution_count": null,
1616
"metadata": {},
1717
"outputs": [],
1818
"source": [
@@ -23,20 +23,9 @@
2323
},
2424
{
2525
"cell_type": "code",
26-
"execution_count": 2,
27-
"metadata": {},
28-
"outputs": [
29-
{
30-
"data": {
31-
"text/plain": [
32-
"'12-torch_p040025_1min_5init_2023-05-12_11-11-36'"
33-
]
34-
},
35-
"execution_count": 2,
36-
"metadata": {},
37-
"output_type": "execute_result"
38-
}
39-
],
26+
"execution_count": null,
27+
"metadata": {},
28+
"outputs": [],
4029
"source": [
4130
"import pickle\n",
4231
"import socket\n",
@@ -69,25 +58,16 @@
6958
},
7059
{
7160
"cell_type": "code",
72-
"execution_count": 3,
73-
"metadata": {},
74-
"outputs": [
75-
{
76-
"name": "stdout",
77-
"output_type": "stream",
78-
"text": [
79-
"spotPython 0.0.68\n",
80-
"Note: you may need to restart the kernel to use updated packages.\n"
81-
]
82-
}
83-
],
61+
"execution_count": null,
62+
"metadata": {},
63+
"outputs": [],
8464
"source": [
8565
"pip list | grep \"spot[RiverPython]\""
8666
]
8767
},
8868
{
8969
"cell_type": "code",
90-
"execution_count": 4,
70+
"execution_count": null,
9171
"metadata": {},
9272
"outputs": [],
9373
"source": [
@@ -98,7 +78,7 @@
9878
},
9979
{
10080
"cell_type": "code",
101-
"execution_count": 5,
81+
"execution_count": null,
10282
"metadata": {},
10383
"outputs": [],
10484
"source": [
@@ -175,18 +155,9 @@
175155
},
176156
{
177157
"cell_type": "code",
178-
"execution_count": 6,
179-
"metadata": {},
180-
"outputs": [
181-
{
182-
"name": "stdout",
183-
"output_type": "stream",
184-
"text": [
185-
"2.0.1\n",
186-
"MPS device: mps\n"
187-
]
188-
}
189-
],
158+
"execution_count": null,
159+
"metadata": {},
160+
"outputs": [],
190161
"source": [
191162
"print(torch.__version__)\n",
192163
"# Check that MPS is available\n",
@@ -213,7 +184,7 @@
213184
},
214185
{
215186
"cell_type": "code",
216-
"execution_count": 7,
187+
"execution_count": null,
217188
"metadata": {},
218189
"outputs": [],
219190
"source": [
@@ -230,7 +201,7 @@
230201
},
231202
{
232203
"cell_type": "code",
233-
"execution_count": 8,
204+
"execution_count": null,
234205
"metadata": {},
235206
"outputs": [],
236207
"source": [
@@ -251,36 +222,17 @@
251222
},
252223
{
253224
"cell_type": "code",
254-
"execution_count": 9,
255-
"metadata": {},
256-
"outputs": [
257-
{
258-
"name": "stdout",
259-
"output_type": "stream",
260-
"text": [
261-
"Files already downloaded and verified\n",
262-
"Files already downloaded and verified\n"
263-
]
264-
},
265-
{
266-
"data": {
267-
"text/plain": [
268-
"((50000, 32, 32, 3), (10000, 32, 32, 3))"
269-
]
270-
},
271-
"execution_count": 9,
272-
"metadata": {},
273-
"output_type": "execute_result"
274-
}
275-
],
225+
"execution_count": null,
226+
"metadata": {},
227+
"outputs": [],
276228
"source": [
277229
"train, test = load_data()\n",
278230
"train.data.shape, test.data.shape"
279231
]
280232
},
281233
{
282234
"cell_type": "code",
283-
"execution_count": 10,
235+
"execution_count": null,
284236
"metadata": {},
285237
"outputs": [],
286238
"source": [
@@ -303,7 +255,7 @@
303255
},
304256
{
305257
"cell_type": "code",
306-
"execution_count": 11,
258+
"execution_count": null,
307259
"metadata": {},
308260
"outputs": [],
309261
"source": [
@@ -346,7 +298,7 @@
346298
},
347299
{
348300
"cell_type": "code",
349-
"execution_count": 12,
301+
"execution_count": null,
350302
"metadata": {},
351303
"outputs": [],
352304
"source": [
@@ -476,7 +428,7 @@
476428
},
477429
{
478430
"cell_type": "code",
479-
"execution_count": 13,
431+
"execution_count": null,
480432
"metadata": {},
481433
"outputs": [],
482434
"source": [
@@ -504,7 +456,7 @@
504456
},
505457
{
506458
"cell_type": "code",
507-
"execution_count": 14,
459+
"execution_count": null,
508460
"metadata": {},
509461
"outputs": [],
510462
"source": [
@@ -544,7 +496,7 @@
544496
},
545497
{
546498
"cell_type": "code",
547-
"execution_count": 15,
499+
"execution_count": null,
548500
"metadata": {},
549501
"outputs": [],
550502
"source": [
@@ -553,6 +505,7 @@
553505
"shuffle = True\n",
554506
"eval = \"train_hold_out\"\n",
555507
"device = DEVICE\n",
508+
"show_batch_interval = 100_000\n",
556509
"\n",
557510
"fun_control.update({\n",
558511
" \"data_dir\": None,\n",
@@ -568,6 +521,7 @@
568521
" \"shuffle\": shuffle,\n",
569522
" \"eval\": eval,\n",
570523
" \"device\": device,\n",
524+
" \"show_batch_interval\": show_batch_interval,\n",
571525
" })"
572526
]
573527
},
@@ -597,7 +551,7 @@
597551
},
598552
{
599553
"cell_type": "code",
600-
"execution_count": 16,
554+
"execution_count": null,
601555
"metadata": {},
602556
"outputs": [],
603557
"source": [
@@ -612,24 +566,9 @@
612566
},
613567
{
614568
"cell_type": "code",
615-
"execution_count": 17,
616-
"metadata": {},
617-
"outputs": [
618-
{
619-
"name": "stdout",
620-
"output_type": "stream",
621-
"text": [
622-
"| name | type | default | lower | upper | transform |\n",
623-
"|------------|--------|-----------|---------|---------|-----------------------|\n",
624-
"| l1 | int | 5 | 2 | 9 | transform_power_2_int |\n",
625-
"| l2 | int | 5 | 2 | 9 | transform_power_2_int |\n",
626-
"| lr | float | 0.001 | 1e-05 | 0.01 | None |\n",
627-
"| batch_size | int | 4 | 1 | 4 | transform_power_2_int |\n",
628-
"| epochs | int | 3 | 3 | 4 | transform_power_2_int |\n",
629-
"| k_folds | int | 2 | 2 | 2 | None |\n"
630-
]
631-
}
632-
],
569+
"execution_count": null,
570+
"metadata": {},
571+
"outputs": [],
633572
"source": [
634573
"print(gen_design_table(fun_control))"
635574
]
@@ -647,20 +586,9 @@
647586
},
648587
{
649588
"cell_type": "code",
650-
"execution_count": 18,
651-
"metadata": {},
652-
"outputs": [
653-
{
654-
"data": {
655-
"text/plain": [
656-
"array([[5.e+00, 5.e+00, 1.e-03, 4.e+00, 3.e+00, 2.e+00]])"
657-
]
658-
},
659-
"execution_count": 18,
660-
"metadata": {},
661-
"output_type": "execute_result"
662-
}
663-
],
589+
"execution_count": null,
590+
"metadata": {},
591+
"outputs": [],
664592
"source": [
665593
"from spotPython.hyperparameters.values import get_default_hyperparameters_as_array\n",
666594
"hyper_dict=TorchHyperDict().load()\n",
@@ -670,49 +598,9 @@
670598
},
671599
{
672600
"cell_type": "code",
673-
"execution_count": 19,
674-
"metadata": {},
675-
"outputs": [
676-
{
677-
"name": "stdout",
678-
"output_type": "stream",
679-
"text": [
680-
"Additional attributes: {'lr', 'batch_size', 'k_folds', 'epochs'}\n",
681-
"Removed attributes: {'lr', 'batch_size', 'k_folds', 'epochs'}\n",
682-
"Epoch: 1\n",
683-
"Batch: 1000. Batch Size: 8. Training Loss (running): 2.126\n",
684-
"Batch: 2000. Batch Size: 8. Training Loss (running): 0.898\n",
685-
"Batch: 3000. Batch Size: 8. Training Loss (running): 0.552\n",
686-
"Loss on hold-out set: 1.590382700252533\n",
687-
"Accuracy on hold-out set: 0.4229\n",
688-
"Epoch: 2\n",
689-
"Batch: 1000. Batch Size: 8. Training Loss (running): 1.524\n",
690-
"Batch: 2000. Batch Size: 8. Training Loss (running): 0.733\n",
691-
"Batch: 3000. Batch Size: 8. Training Loss (running): 0.476\n",
692-
"Loss on hold-out set: 1.4275218252420425\n",
693-
"Accuracy on hold-out set: 0.4857\n",
694-
"Epoch: 3\n",
695-
"Batch: 1000. Batch Size: 8. Training Loss (running): 1.353\n",
696-
"Batch: 2000. Batch Size: 8. Training Loss (running): 0.664\n",
697-
"Batch: 3000. Batch Size: 8. Training Loss (running): 0.436\n",
698-
"Loss on hold-out set: 1.3183471252799035\n",
699-
"Accuracy on hold-out set: 0.5349\n",
700-
"Epoch: 4\n",
701-
"Batch: 1000. Batch Size: 8. Training Loss (running): 1.227\n",
702-
"Batch: 2000. Batch Size: 8. Training Loss (running): 0.623\n",
703-
"Batch: 3000. Batch Size: 8. Training Loss (running): 0.414\n",
704-
"Loss on hold-out set: 1.3052782666921616\n",
705-
"Accuracy on hold-out set: 0.5461\n",
706-
"Epoch: 5\n",
707-
"Batch: 1000. Batch Size: 8. Training Loss (running): 1.155\n",
708-
"Batch: 2000. Batch Size: 8. Training Loss (running): 0.591\n",
709-
"Batch: 3000. Batch Size: 8. Training Loss (running): 0.378\n",
710-
"Loss on hold-out set: 1.2335647928357125\n",
711-
"Accuracy on hold-out set: 0.57405\n",
712-
"Epoch: 6\n"
713-
]
714-
}
715-
],
601+
"execution_count": null,
602+
"metadata": {},
603+
"outputs": [],
716604
"source": [
717605
"spot_tuner = spot.Spot(fun=fun,\n",
718606
" lower = lower,\n",
@@ -967,7 +855,7 @@
967855
},
968856
{
969857
"cell_type": "code",
970-
"execution_count": null,
858+
"execution_count": 35,
971859
"metadata": {},
972860
"outputs": [],
973861
"source": [

0 commit comments

Comments
 (0)