Skip to content

Commit b349faa

Browse files
Update spot.py
1 parent 98b114e commit b349faa

1 file changed

Lines changed: 80 additions & 76 deletions

File tree

src/spotPython/spot/spot.py

Lines changed: 80 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,86 +1326,90 @@ def get_tuned_hyperparameters(self, fun_control=None) -> dict:
13261326
"""Return the tuned hyperparameter values from the run.
13271327
If `noise == True`, the mean values are returned.
13281328
1329+
Args:
1330+
fun_control (dict, optional):
1331+
fun_control dictionary
1332+
13291333
Returns:
13301334
(dict): dictionary of tuned hyperparameters.
13311335
13321336
Examples:
1333-
>>> from spotPython.utils.device import getDevice
1334-
from math import inf
1335-
from spotPython.utils.init import fun_control_init
1336-
import numpy as np
1337-
from spotPython.hyperparameters.values import set_control_key_value
1338-
from spotPython.data.diabetes import Diabetes
1339-
MAX_TIME = 1
1340-
FUN_EVALS = 10
1341-
INIT_SIZE = 5
1342-
WORKERS = 0
1343-
PREFIX="037"
1344-
DEVICE = getDevice()
1345-
DEVICES = 1
1346-
TEST_SIZE = 0.4
1347-
TORCH_METRIC = "mean_squared_error"
1348-
dataset = Diabetes()
1349-
fun_control = fun_control_init(
1350-
_L_in=10,
1351-
_L_out=1,
1352-
_torchmetric=TORCH_METRIC,
1353-
PREFIX=PREFIX,
1354-
TENSORBOARD_CLEAN=True,
1355-
data_set=dataset,
1356-
device=DEVICE,
1357-
enable_progress_bar=False,
1358-
fun_evals=FUN_EVALS,
1359-
log_level=50,
1360-
max_time=MAX_TIME,
1361-
num_workers=WORKERS,
1362-
show_progress=True,
1363-
test_size=TEST_SIZE,
1364-
tolerance_x=np.sqrt(np.spacing(1)),
1365-
)
1366-
from spotPython.light.regression.netlightregression import NetLightRegression
1367-
from spotPython.hyperdict.light_hyper_dict import LightHyperDict
1368-
from spotPython.hyperparameters.values import add_core_model_to_fun_control
1369-
add_core_model_to_fun_control(fun_control=fun_control,
1370-
core_model=NetLightRegression,
1371-
hyper_dict=LightHyperDict)
1372-
from spotPython.hyperparameters.values import set_control_hyperparameter_value
1373-
set_control_hyperparameter_value(fun_control, "l1", [7, 8])
1374-
set_control_hyperparameter_value(fun_control, "epochs", [3, 5])
1375-
set_control_hyperparameter_value(fun_control, "batch_size", [4, 5])
1376-
set_control_hyperparameter_value(fun_control, "optimizer", [
1377-
"Adam",
1378-
"RAdam",
1379-
])
1380-
set_control_hyperparameter_value(fun_control, "dropout_prob", [0.01, 0.1])
1381-
set_control_hyperparameter_value(fun_control, "lr_mult", [0.5, 5.0])
1382-
set_control_hyperparameter_value(fun_control, "patience", [2, 3])
1383-
set_control_hyperparameter_value(fun_control, "act_fn",[
1384-
"ReLU",
1385-
"LeakyReLU"
1386-
] )
1387-
from spotPython.utils.init import design_control_init, surrogate_control_init
1388-
design_control = design_control_init(init_size=INIT_SIZE)
1389-
surrogate_control = surrogate_control_init(noise=True,
1390-
n_theta=2)
1391-
from spotPython.fun.hyperlight import HyperLight
1392-
fun = HyperLight(log_level=50).fun
1393-
from spotPython.spot import spot
1394-
spot_tuner = spot.Spot(fun=fun,
1395-
fun_control=fun_control,
1396-
design_control=design_control,
1397-
surrogate_control=surrogate_control)
1398-
spot_tuner.run()
1399-
spot_tuner.get_tuned_hyperparameters()
1400-
{'l1': 7.0,
1401-
'epochs': 5.0,
1402-
'batch_size': 4.0,
1403-
'act_fn': 0.0,
1404-
'optimizer': 0.0,
1405-
'dropout_prob': 0.01,
1406-
'lr_mult': 5.0,
1407-
'patience': 3.0,
1408-
'initialization': 1.0}
1337+
>>> from spotPython.utils.device import getDevice
1338+
from math import inf
1339+
from spotPython.utils.init import fun_control_init
1340+
import numpy as np
1341+
from spotPython.hyperparameters.values import set_control_key_value
1342+
from spotPython.data.diabetes import Diabetes
1343+
MAX_TIME = 1
1344+
FUN_EVALS = 10
1345+
INIT_SIZE = 5
1346+
WORKERS = 0
1347+
PREFIX="037"
1348+
DEVICE = getDevice()
1349+
DEVICES = 1
1350+
TEST_SIZE = 0.4
1351+
TORCH_METRIC = "mean_squared_error"
1352+
dataset = Diabetes()
1353+
fun_control = fun_control_init(
1354+
_L_in=10,
1355+
_L_out=1,
1356+
_torchmetric=TORCH_METRIC,
1357+
PREFIX=PREFIX,
1358+
TENSORBOARD_CLEAN=True,
1359+
data_set=dataset,
1360+
device=DEVICE,
1361+
enable_progress_bar=False,
1362+
fun_evals=FUN_EVALS,
1363+
log_level=50,
1364+
max_time=MAX_TIME,
1365+
num_workers=WORKERS,
1366+
show_progress=True,
1367+
test_size=TEST_SIZE,
1368+
tolerance_x=np.sqrt(np.spacing(1)),
1369+
)
1370+
from spotPython.light.regression.netlightregression import NetLightRegression
1371+
from spotPython.hyperdict.light_hyper_dict import LightHyperDict
1372+
from spotPython.hyperparameters.values import add_core_model_to_fun_control
1373+
add_core_model_to_fun_control(fun_control=fun_control,
1374+
core_model=NetLightRegression,
1375+
hyper_dict=LightHyperDict)
1376+
from spotPython.hyperparameters.values import set_control_hyperparameter_value
1377+
set_control_hyperparameter_value(fun_control, "l1", [7, 8])
1378+
set_control_hyperparameter_value(fun_control, "epochs", [3, 5])
1379+
set_control_hyperparameter_value(fun_control, "batch_size", [4, 5])
1380+
set_control_hyperparameter_value(fun_control, "optimizer", [
1381+
"Adam",
1382+
"RAdam",
1383+
])
1384+
set_control_hyperparameter_value(fun_control, "dropout_prob", [0.01, 0.1])
1385+
set_control_hyperparameter_value(fun_control, "lr_mult", [0.5, 5.0])
1386+
set_control_hyperparameter_value(fun_control, "patience", [2, 3])
1387+
set_control_hyperparameter_value(fun_control, "act_fn",[
1388+
"ReLU",
1389+
"LeakyReLU"
1390+
] )
1391+
from spotPython.utils.init import design_control_init, surrogate_control_init
1392+
design_control = design_control_init(init_size=INIT_SIZE)
1393+
surrogate_control = surrogate_control_init(noise=True,
1394+
n_theta=2)
1395+
from spotPython.fun.hyperlight import HyperLight
1396+
fun = HyperLight(log_level=50).fun
1397+
from spotPython.spot import spot
1398+
spot_tuner = spot.Spot(fun=fun,
1399+
fun_control=fun_control,
1400+
design_control=design_control,
1401+
surrogate_control=surrogate_control)
1402+
spot_tuner.run()
1403+
spot_tuner.get_tuned_hyperparameters()
1404+
{'l1': 7.0,
1405+
'epochs': 5.0,
1406+
'batch_size': 4.0,
1407+
'act_fn': 0.0,
1408+
'optimizer': 0.0,
1409+
'dropout_prob': 0.01,
1410+
'lr_mult': 5.0,
1411+
'patience': 3.0,
1412+
'initialization': 1.0}
14091413
14101414
"""
14111415
output = []

0 commit comments

Comments
 (0)