Skip to content

Commit 24c0a68

Browse files
v0.0.41
1 parent fa740c6 commit 24c0a68

6 files changed

Lines changed: 235 additions & 27 deletions

File tree

notebooks/11_spot_hpt_torch.ipynb

Lines changed: 63 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
{
3333
"data": {
3434
"text/plain": [
35-
"'10-sklearn_maans05_1min_10init_2023-04-26_09-52-27'"
35+
"'10-sklearn_maans05_1min_10init_2023-04-26_12-27-51'"
3636
]
3737
},
3838
"execution_count": 2,
@@ -112,7 +112,19 @@
112112
"cell_type": "code",
113113
"execution_count": 5,
114114
"metadata": {},
115-
"outputs": [],
115+
"outputs": [
116+
{
117+
"ename": "ModuleNotFoundError",
118+
"evalue": "No module named 'spotPython.torch'",
119+
"output_type": "error",
120+
"traceback": [
121+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
122+
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
123+
"Cell \u001b[0;32mIn[5], line 53\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mspotPython\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mutils\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mconvert\u001b[39;00m \u001b[39mimport\u001b[39;00m get_Xy_from_df\n\u001b[1;32m 52\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mspotPython\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mplot\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mvalidation\u001b[39;00m \u001b[39mimport\u001b[39;00m plot_cv_predictions, plot_roc, plot_confusion_matrix\n\u001b[0;32m---> 53\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mspotPython\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mtorch\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mnet\u001b[39;00m \u001b[39mimport\u001b[39;00m Net_CIFAR10\n\u001b[1;32m 55\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39msklearn\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mpreprocessing\u001b[39;00m \u001b[39mimport\u001b[39;00m OneHotEncoder , MinMaxScaler, StandardScaler\n\u001b[1;32m 56\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39msklearn\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mpreprocessing\u001b[39;00m \u001b[39mimport\u001b[39;00m OrdinalEncoder\n",
124+
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'spotPython.torch'"
125+
]
126+
}
127+
],
116128
"source": [
117129
"from tabulate import tabulate\n",
118130
"import copy\n",
@@ -200,7 +212,7 @@
200212
},
201213
{
202214
"cell_type": "code",
203-
"execution_count": 8,
215+
"execution_count": null,
204216
"metadata": {},
205217
"outputs": [],
206218
"source": [
@@ -225,7 +237,7 @@
225237
},
226238
{
227239
"cell_type": "code",
228-
"execution_count": 7,
240+
"execution_count": null,
229241
"metadata": {},
230242
"outputs": [],
231243
"source": [
@@ -253,7 +265,7 @@
253265
},
254266
{
255267
"cell_type": "code",
256-
"execution_count": 8,
268+
"execution_count": null,
257269
"metadata": {},
258270
"outputs": [],
259271
"source": [
@@ -274,7 +286,7 @@
274286
},
275287
{
276288
"cell_type": "code",
277-
"execution_count": 9,
289+
"execution_count": null,
278290
"metadata": {},
279291
"outputs": [],
280292
"source": [
@@ -310,7 +322,7 @@
310322
},
311323
{
312324
"cell_type": "code",
313-
"execution_count": 4,
325+
"execution_count": null,
314326
"metadata": {},
315327
"outputs": [],
316328
"source": [
@@ -328,7 +340,7 @@
328340
},
329341
{
330342
"cell_type": "code",
331-
"execution_count": 5,
343+
"execution_count": null,
332344
"metadata": {},
333345
"outputs": [],
334346
"source": [
@@ -349,7 +361,7 @@
349361
},
350362
{
351363
"cell_type": "code",
352-
"execution_count": 6,
364+
"execution_count": null,
353365
"metadata": {},
354366
"outputs": [
355367
{
@@ -381,7 +393,7 @@
381393
},
382394
{
383395
"cell_type": "code",
384-
"execution_count": 13,
396+
"execution_count": null,
385397
"metadata": {},
386398
"outputs": [
387399
{
@@ -410,7 +422,7 @@
410422
},
411423
{
412424
"cell_type": "code",
413-
"execution_count": 14,
425+
"execution_count": null,
414426
"metadata": {},
415427
"outputs": [],
416428
"source": [
@@ -433,7 +445,7 @@
433445
},
434446
{
435447
"cell_type": "code",
436-
"execution_count": 11,
448+
"execution_count": null,
437449
"metadata": {},
438450
"outputs": [],
439451
"source": [
@@ -459,18 +471,51 @@
459471
},
460472
{
461473
"cell_type": "code",
462-
"execution_count": 22,
474+
"execution_count": 23,
463475
"metadata": {},
464-
"outputs": [],
476+
"outputs": [
477+
{
478+
"ename": "NameError",
479+
"evalue": "name 'RidgeCV' is not defined",
480+
"output_type": "error",
481+
"traceback": [
482+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
483+
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
484+
"Cell \u001b[0;32mIn[23], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m core_model \u001b[39m=\u001b[39m RidgeCV\n\u001b[1;32m 2\u001b[0m \u001b[39m#core_model = Net_CIFAR10\u001b[39;00m\n\u001b[1;32m 3\u001b[0m fun_control \u001b[39m=\u001b[39m add_core_model_to_fun_control(core_model\u001b[39m=\u001b[39mcore_model,\n\u001b[1;32m 4\u001b[0m fun_control\u001b[39m=\u001b[39mfun_control,\n\u001b[1;32m 5\u001b[0m hyper_dict\u001b[39m=\u001b[39mSklearnHyperDict,\n\u001b[1;32m 6\u001b[0m filename\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m)\n",
485+
"\u001b[0;31mNameError\u001b[0m: name 'RidgeCV' is not defined"
486+
]
487+
}
488+
],
465489
"source": [
466-
"# core_model = RidgeCV\n",
467-
"core_model = Net_CIFAR10\n",
490+
"core_model = RidgeCV\n",
491+
"#core_model = Net_CIFAR10\n",
468492
"fun_control = add_core_model_to_fun_control(core_model=core_model,\n",
469493
" fun_control=fun_control,\n",
470494
" hyper_dict=SklearnHyperDict,\n",
471495
" filename=None)"
472496
]
473497
},
498+
{
499+
"cell_type": "code",
500+
"execution_count": 24,
501+
"metadata": {},
502+
"outputs": [
503+
{
504+
"ename": "NameError",
505+
"evalue": "name 'core_model' is not defined",
506+
"output_type": "error",
507+
"traceback": [
508+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
509+
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
510+
"Cell \u001b[0;32mIn[24], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m core_model\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m\n",
511+
"\u001b[0;31mNameError\u001b[0m: name 'core_model' is not defined"
512+
]
513+
}
514+
],
515+
"source": [
516+
"core_model.__name__"
517+
]
518+
},
474519
{
475520
"attachments": {},
476521
"cell_type": "markdown",
@@ -549,6 +594,8 @@
549594
"weight_coeff = 1.0\n",
550595
"\n",
551596
"fun_control.update({\n",
597+
" \"data_dir\": None,\n",
598+
" \"checkpoint_dir\": None,\n",
552599
" \"horizon\": horizon,\n",
553600
" \"oml_grace_period\": oml_grace_period,\n",
554601
" \"weights\": weights,\n",

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.0.40"
10+
version = "0.0.41"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
{
2+
"Net_CIFAR10":
3+
{
4+
"lr": {
5+
"type": "float",
6+
"default": 1e-03,
7+
"transform": "None",
8+
"lower": 1e-04,
9+
"upper": 1e-01},
10+
"l1": {
11+
"type": "int",
12+
"default": 5,
13+
"transform": "transform_power_2_int",
14+
"lower": 2,
15+
"upper": 9},
16+
"l2": {
17+
"type": "int",
18+
"default": 5,
19+
"transform": "transform_power_2_int",
20+
"lower": 2,
21+
"upper": 9},
22+
"batch_size": {
23+
"type": "int",
24+
"default": 4,
25+
"transform": "transform_power_2_int",
26+
"lower": 1,
27+
"upper": 4}
28+
},
29+
"Template":
30+
{
31+
"integer_hyperparameter": {
32+
"type": "int",
33+
"default": 200,
34+
"transform": "None",
35+
"lower": 10,
36+
"upper": 1000},
37+
"integer_hyperparameter_with_transformation": {
38+
"type": "int",
39+
"default": 20,
40+
"transform": "transform_power_2_int",
41+
"lower": 2,
42+
"upper": 20},
43+
"float_hyperparameter": {
44+
"type": "float",
45+
"default": 1e-07,
46+
"transform": "None",
47+
"lower": 1e-08,
48+
"upper": 1e-06},
49+
"factor_hyperparameter": {
50+
"levels": ["mc", "nb", "nba"],
51+
"type": "factor",
52+
"default": "nba",
53+
"transform": "None",
54+
"core_model_parameter_type": "str",
55+
"lower": 0,
56+
"upper": 2},
57+
"bool_hyperparameter": {
58+
"levels": [0, 1],
59+
"type": "factor",
60+
"default": 0, "transform": "None",
61+
"core_model_parameter_type": "bool",
62+
"lower": 0,
63+
"upper": 1}
64+
}
65+
}

src/spotPython/fun/hypertorch.py

Lines changed: 91 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@
33
from numpy import array
44
from sklearn.pipeline import make_pipeline
55
from spotPython.utils.convert import get_Xy_from_df
6+
from spotPython.utils.data import load_data
7+
import torch.nn as nn
8+
import torch.optim as optim
9+
import torch
10+
import os
11+
from torch.utils.data import random_split
612

713

814
from spotPython.hyperparameters.values import assign_values
@@ -64,23 +70,100 @@ def check_X_shape(self, X):
6470
raise Exception
6571

6672
def evaluate_model(self, model, fun_control):
73+
# TODO: config anpassen
6774
try:
75+
lr = fun_control["lr"]
76+
checkpoint_dir = fun_control["checkpoint_dir"]
77+
data_dir = fun_control["data_dir"]
78+
6879
X_train, y_train = get_Xy_from_df(fun_control["train"], fun_control["target_column"])
6980
X_test, y_test = get_Xy_from_df(fun_control["test"], fun_control["target_column"])
7081
model.fit(X_train, y_train)
71-
72-
73-
74-
75-
7682
df_preds = model.predict(X_test)
7783
df_eval = fun_control["metric_sklearn"](y_test, df_preds)
84+
#
85+
device = "cpu"
86+
# if torch.cuda.is_available():
87+
# device = "cuda:0"
88+
# if torch.cuda.device_count() > 1:
89+
# net = nn.DataParallel(net)
90+
model.to(device)
91+
92+
criterion = nn.CrossEntropyLoss()
93+
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
94+
95+
if checkpoint_dir:
96+
model_state, optimizer_state = torch.load(os.path.join(checkpoint_dir, "checkpoint"))
97+
model.load_state_dict(model_state)
98+
optimizer.load_state_dict(optimizer_state)
99+
100+
trainset, testset = load_data(data_dir)
101+
102+
test_abs = int(len(trainset) * 0.8)
103+
train_subset, val_subset = random_split(trainset, [test_abs, len(trainset) - test_abs])
104+
105+
trainloader = torch.utils.data.DataLoader(
106+
train_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8
107+
)
108+
valloader = torch.utils.data.DataLoader(
109+
val_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8
110+
)
111+
112+
for epoch in range(10): # loop over the dataset multiple times
113+
running_loss = 0.0
114+
epoch_steps = 0
115+
for i, data in enumerate(trainloader, 0):
116+
# get the inputs; data is a list of [inputs, labels]
117+
inputs, labels = data
118+
inputs, labels = inputs.to(device), labels.to(device)
119+
120+
# zero the parameter gradients
121+
optimizer.zero_grad()
122+
123+
# forward + backward + optimize
124+
outputs = model(inputs)
125+
loss = criterion(outputs, labels)
126+
loss.backward()
127+
optimizer.step()
128+
129+
# print statistics
130+
running_loss += loss.item()
131+
epoch_steps += 1
132+
if i % 2000 == 1999: # print every 2000 mini-batches
133+
print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / epoch_steps))
134+
running_loss = 0.0
135+
136+
# Validation loss
137+
val_loss = 0.0
138+
val_steps = 0
139+
total = 0
140+
correct = 0
141+
for i, data in enumerate(valloader, 0):
142+
with torch.no_grad():
143+
inputs, labels = data
144+
inputs, labels = inputs.to(device), labels.to(device)
145+
146+
outputs = model(inputs)
147+
_, predicted = torch.max(outputs.data, 1)
148+
total += labels.size(0)
149+
correct += (predicted == labels).sum().item()
150+
151+
loss = criterion(outputs, labels)
152+
val_loss += loss.cpu().numpy()
153+
val_steps += 1
154+
155+
# TODO:
156+
# with tune.checkpoint_dir(epoch) as checkpoint_dir:
157+
path = os.path.join(checkpoint_dir, "checkpoint")
158+
torch.save((model.state_dict(), optimizer.state_dict()), path)
159+
df_eval = val_loss / val_steps
160+
df_preds = np.nan
161+
# accuracy = correct / total
78162
except Exception as err:
79163
print(f"Error in fun_sklearn(). Call to evaluate_model failed. {err=}, {type(err)=}")
80164
df_eval = np.nan
81-
df_eval = np.nan
165+
df_preds = np.nan
82166
return df_eval, df_preds
83-
84167

85168
def get_sklearn_df_eval_preds(self, model):
86169
try:
@@ -92,7 +175,7 @@ def get_sklearn_df_eval_preds(self, model):
92175
df_preds = np.nan
93176
return df_eval, df_preds
94177

95-
def fun_sklearn(self, X, fun_control=None):
178+
def fun_torch(self, X, fun_control=None):
96179
z_res = np.array([], dtype=float)
97180
self.fun_control.update(fun_control)
98181
self.check_X_shape(X)

0 commit comments

Comments
 (0)