Skip to content

Commit d41b899

Browse files
wip
1 parent 5d027b4 commit d41b899

3 files changed

Lines changed: 56 additions & 2 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.13.7"
10+
version = "0.13.8"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/hyperparameters/values.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1366,3 +1366,40 @@ def get_tuned_hyperparameters(spot_tuner, fun_control=None) -> dict:
13661366
'initialization': 1.0}
13671367
"""
13681368
return spot_tuner.get_tuned_hyperparameters(fun_control=fun_control)
1369+
1370+
1371+
def update_fun_control(fun_control, new_control) -> dict:
1372+
for i, (key, value) in enumerate(new_control.items()):
1373+
if new_control[key]["type"] == "int":
1374+
set_control_hyperparameter_value(
1375+
fun_control,
1376+
key,
1377+
[
1378+
int(new_control[key]["lower"]),
1379+
int(new_control[key]["upper"]),
1380+
],
1381+
)
1382+
if (new_control[key]["type"] == "factor") and (new_control[key]["core_model_parameter_type"] == "bool"):
1383+
set_control_hyperparameter_value(
1384+
fun_control,
1385+
key,
1386+
[
1387+
int(new_control[key]["lower"]),
1388+
int(new_control[key]["upper"]),
1389+
],
1390+
)
1391+
if new_control[key]["type"] == "float":
1392+
set_control_hyperparameter_value(
1393+
fun_control,
1394+
key,
1395+
[
1396+
float(new_control[key]["lower"]),
1397+
float(new_control[key]["upper"]),
1398+
],
1399+
)
1400+
if new_control[key]["type"] == "factor" and new_control[key]["core_model_parameter_type"] != "bool":
1401+
fle = new_control[key]["levels"]
1402+
# convert the string to a list of strings
1403+
fle = fle.split()
1404+
set_control_hyperparameter_value(fun_control, key, fle)
1405+
fun_control["core_model_hyper_new_control"][key].update({"upper": len(fle) - 1})

src/spotPython/utils/init.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def fun_control_init(
3838
max_time=1,
3939
max_surrogate_points=30,
4040
metric_sklearn=None,
41+
metric_sklearn_name=None,
4142
noise=False,
4243
n_points=1,
4344
n_samples=None,
@@ -47,9 +48,11 @@ def fun_control_init(
4748
oml_grace_period=None,
4849
optimizer=None,
4950
prep_model=None,
51+
prep_model_name=None,
5052
seed=123,
5153
show_models=False,
5254
show_progress=True,
55+
shuffle=None,
5356
sigma=0.0,
5457
surrogate=None,
5558
target_column=None,
@@ -66,6 +69,7 @@ def fun_control_init(
6669
verbosity=0,
6770
weights=1.0,
6871
weight_coeff=0.0,
72+
weights_entry=None,
6973
):
7074
"""Initialize fun_control dictionary.
7175
@@ -138,6 +142,8 @@ def fun_control_init(
138142
The maximum number of points in the surrogate model. Default is inf.
139143
metric_sklearn (object):
140144
The metric object from the scikit-learn library. Default is None.
145+
metric_sklearn_name (str):
146+
The name of the metric object from the scikit-learn library. Default is None.
141147
noise (bool):
142148
Whether the objective function is noiy or not. Default is False.
143149
Affects the repeat of the function evaluations.
@@ -161,6 +167,8 @@ def fun_control_init(
161167
that us an instance of a SummaryWriter(), is created. Default is None.
162168
prep_model (object):
163169
The preprocessing model object. Used for river. Default is None.
170+
prep_model_name (str):
171+
The name of the preprocessing model. Default is None.
164172
seed (int):
165173
The seed to use for the random number generator. Default is 123.
166174
sigma (float):
@@ -170,6 +178,8 @@ def fun_control_init(
170178
show_models (bool):
171179
Plot model each generation.
172180
Currently only 1-dim functions are supported. Default is `False`.
181+
shuffle (bool):
182+
Whether the data were shuffled or not. Default is None.
173183
surrogate (object):
174184
The surrogate model object. Default is None.
175185
target_column (str):
@@ -210,6 +220,8 @@ def fun_control_init(
210220
Can be an array, so that different weights can be used for different (multiple) objectives.
211221
weight_coeff (float):
212222
Determines how to weight older measures. Default is 1.0. Used in the OML algorithm eval_oml.py.
223+
weights_entry (str):
224+
The weights entry used in the GUI. Default is None.
213225
214226
Returns:
215227
fun_control (dict):
@@ -243,6 +255,7 @@ def fun_control_init(
243255
'max_surrogate_points': 100,
244256
'metric_river': None,
245257
'metric_sklearn': None,
258+
'metric_sklearn_name': None,
246259
'metric_torch': None,
247260
'metric_params': {},
248261
'model_dict': {},
@@ -255,6 +268,7 @@ def fun_control_init(
255268
'optimizer': None,
256269
'path': None,
257270
'prep_model': None,
271+
prep_model_name': None,
258272
'save_model': False,
259273
'seed': 1234,
260274
'show_batch_interval': 1000000,
@@ -344,6 +358,7 @@ def fun_control_init(
344358
"max_surrogate_points": max_surrogate_points,
345359
"metric_river": None,
346360
"metric_sklearn": metric_sklearn,
361+
"metric_sklearn_name": metric_sklearn_name,
347362
"metric_torch": None,
348363
"metric_params": {},
349364
"model_dict": {},
@@ -357,12 +372,13 @@ def fun_control_init(
357372
"optimizer": optimizer,
358373
"path": None,
359374
"prep_model": prep_model,
375+
"prep_model_name": prep_model_name,
360376
"save_model": False,
361377
"seed": seed,
362378
"show_batch_interval": 1_000_000,
363379
"show_models": show_models,
364380
"show_progress": show_progress,
365-
"shuffle": None,
381+
"shuffle": shuffle,
366382
"sigma": sigma,
367383
"spot_tensorboard_path": spot_tensorboard_path,
368384
"spot_writer": spot_writer,
@@ -380,6 +396,7 @@ def fun_control_init(
380396
"verbosity": verbosity,
381397
"weights": weights,
382398
"weight_coeff": weight_coeff,
399+
"weights_entry": weights_entry
383400
}
384401
# lower = X_reshape(lower)
385402
# fun_control.update({"lower": lower})

0 commit comments

Comments
 (0)