Skip to content

Commit 959dfbf

Browse files
v0.0.68 New torch classes
1 parent cb9b2f1 commit 959dfbf

11 files changed

Lines changed: 4188 additions & 737 deletions

notebooks/11_spot_hpt_torch_fashion_mnist.ipynb

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616
"metadata": {},
1717
"outputs": [],
1818
"source": [
19-
"MAX_TIME = 600\n",
20-
"INIT_SIZE = 20"
19+
"MAX_TIME = 1\n",
20+
"INIT_SIZE = 5\n",
21+
"DEVICE = \"cpu\" # \"cuda:0\""
2122
]
2223
},
2324
{
@@ -43,7 +44,7 @@
4344
"metadata": {},
4445
"source": [
4546
"# Chapter 11: Sequential Parameter Optimization\n",
46-
"## Hyperparameter Tuning: pytorch with fashionMNIST Data "
47+
"## Hyperparameter Tuning: pytorch with fashionMNIST Data Using Hold-out Data Sets"
4748
]
4849
},
4950
{
@@ -133,6 +134,7 @@
133134
" iterate_dict_values,\n",
134135
")\n",
135136
"\n",
137+
"from spotPython.torch.traintest import evaluate_cv, evaluate_hold_out\n",
136138
"from spotPython.utils.convert import class_for_name\n",
137139
"from spotPython.utils.eda import (\n",
138140
" get_stars,\n",
@@ -149,7 +151,7 @@
149151
"warnings.filterwarnings(\"ignore\")\n",
150152
"\n",
151153
"# Neural Net specific imports:\n",
152-
"from spotPython.torch.netcvfashionMNIST import Net_CV_fashionMNIST"
154+
"from spotPython.torch.netfashionMNIST import Net_fashionMNIST"
153155
]
154156
},
155157
{
@@ -281,13 +283,29 @@
281283
"## 3. Select `algorithm` and `core_model_hyper_dict`"
282284
]
283285
},
286+
{
287+
"attachments": {},
288+
"cell_type": "markdown",
289+
"metadata": {},
290+
"source": [
291+
"`spotPython` implements a class which is similar to the class described in the PyTorch tutorial. The class is called `Net_fashionMNIST` and is implemented in the file `netcifar10.py`. The class is imported here.\n",
292+
"\n",
293+
"Note: In addition to the class Net from the PyTorch tutorial, the class Net_CIFAR10 has additional attributes, namely:\n",
294+
"\n",
295+
"* learning rate (`lr`),\n",
296+
"* batchsize (`batch_size`),\n",
297+
"* epochs (`epochs`), and\n",
298+
"* k_folds (`k_folds`).\n",
299+
"\n",
300+
"Further attributes can be easily added to the class, e.g., `optimizer` or `loss_function`."
301+
]
302+
},
284303
{
285304
"cell_type": "code",
286305
"execution_count": null,
287306
"metadata": {},
288307
"outputs": [],
289308
"source": [
290-
"# core_model = RidgeCV\n",
291309
"core_model = Net_CV_fashionMNIST\n",
292310
"fun_control = add_core_model_to_fun_control(core_model=core_model,\n",
293311
" fun_control=fun_control,\n",
@@ -381,6 +399,7 @@
381399
"weights = 1.0\n",
382400
"shuffle = True\n",
383401
"eval = \"train_hold_out\"\n",
402+
"device = DEVICE\n",
384403
"\n",
385404
"fun_control.update({\n",
386405
" \"data_dir\": None,\n",
@@ -690,7 +709,7 @@
690709
"metadata": {},
691710
"outputs": [],
692711
"source": [
693-
"model_default.evaluate_hold_out(dataset = testset, shuffle=False)"
712+
"evaluate_hold_out(model_default, dataset = testset, shuffle=False)"
694713
]
695714
},
696715
{
@@ -699,7 +718,7 @@
699718
"metadata": {},
700719
"outputs": [],
701720
"source": [
702-
"model_spot.evaluate_hold_out(dataset = testset, shuffle=False)"
721+
"evaluate_hold_out(model_spot, dataset = testset, shuffle=False)"
703722
]
704723
},
705724
{

0 commit comments

Comments
 (0)