Skip to content

Commit de4fe2c

Browse files
0.10.59
save_experiment() closes tensorflow file writer before saving to pkl
1 parent b6c104b commit de4fe2c

5 files changed

Lines changed: 79 additions & 12 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.10.58"
10+
version = "0.10.59"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/hyperparameters/values.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ def return_conf_list_from_var_dict(
7575

7676
def iterate_dict_values(var_dict: Dict[str, np.ndarray]) -> Generator[Dict[str, Union[int, float]], None, None]:
7777
"""Iterate over the values of a dictionary of variables.
78-
7978
This function takes a dictionary of variables as input arguments and returns a generator that
8079
yields dictionaries with the values from the arrays in the input dictionary.
8180
@@ -100,7 +99,6 @@ def iterate_dict_values(var_dict: Dict[str, np.ndarray]) -> Generator[Dict[str,
10099

101100
def convert_keys(d: Dict[str, Union[int, float, str]], var_type: List[str]) -> Dict[str, Union[int, float]]:
102101
"""Convert values in a dictionary to integers based on a list of variable types.
103-
104102
This function takes a dictionary `d` and a list of variable types `var_type` as arguments.
105103
For each key in the dictionary,
106104
if the corresponding entry in `var_type` is not equal to `"num"`,
@@ -131,7 +129,6 @@ def convert_keys(d: Dict[str, Union[int, float, str]], var_type: List[str]) -> D
131129

132130
def get_dict_with_levels_and_types(fun_control: Dict[str, Any], v: Dict[str, Any]) -> Dict[str, Any]:
133131
"""Get dictionary with levels and types.
134-
135132
The function maps the numerical output of the hyperparameter optimization to the corresponding levels
136133
of the hyperparameter needed by the core model, i.e., the tuned algorithm.
137134
The function takes the dictionaries fun_control and v and returns a new dictionary with the same keys as v
@@ -300,6 +297,8 @@ def modify_hyper_parameter_levels(fun_control, hyperparameter, levels) -> None:
300297

301298
def modify_hyper_parameter_bounds(fun_control, hyperparameter, bounds) -> None:
302299
"""
300+
Modify the bounds of a hyperparameter in the fun_control dictionary.
301+
303302
Args:
304303
fun_control (dict):
305304
fun_control dictionary
@@ -545,7 +544,6 @@ def get_var_name(fun_control) -> list:
545544

546545
def get_bound_values(fun_control: dict, bound: str, as_list: bool = False) -> Union[List, np.ndarray]:
547546
"""Generate a list or array from a dictionary.
548-
549547
This function takes the values from the keys "bound" in the
550548
fun_control["core_model_hyper_dict"] dictionary and returns a list or array of the values
551549
in the same order as the keys in the dictionary.
@@ -744,6 +742,7 @@ def add_core_model_to_fun_control(fun_control, core_model, hyper_dict=None, file
744742

745743
def get_one_core_model_from_X(X, fun_control=None):
746744
"""Get one core model from X.
745+
747746
Args:
748747
X (np.array):
749748
The array with the hyper parameter values.
@@ -874,6 +873,7 @@ def get_one_river_model_from_X(X, fun_control=None):
874873

875874
def get_default_hyperparameters_as_array(fun_control) -> np.array:
876875
"""Get the default hyper parameters as array.
876+
877877
Args:
878878
fun_control (dict):
879879
The function control dictionary.
@@ -906,6 +906,7 @@ def get_default_hyperparameters_as_array(fun_control) -> np.array:
906906

907907
def get_default_hyperparameters_for_core_model(fun_control) -> dict:
908908
"""Get the default hyper parameters for the core model.
909+
909910
Args:
910911
fun_control (dict):
911912
The function control dictionary.

src/spotPython/utils/file.py

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,71 @@ def save_experiment(
8585
8686
Returns:
8787
PKL_NAME (str):
88-
Name of the pickle file. Build as "spot_" + PREFIX + "experiment.pickle".
88+
Name of the pickle file. Build as "spot_" + PREFIX + "_experiment.pickle".
8989
90+
Examples:
91+
>>> import os
92+
from spotPython.utils.file import save_experiment, load_experiment
93+
import numpy as np
94+
from math import inf
95+
from spotPython.spot import spot
96+
from spotPython.utils.init import (
97+
fun_control_init,
98+
design_control_init,
99+
surrogate_control_init,
100+
optimizer_control_init)
101+
from spotPython.fun.objectivefunctions import analytical
102+
fun = analytical().fun_branin
103+
fun_control = fun_control_init(
104+
PREFIX="branin",
105+
SUMMARY_WRITER=False,
106+
lower = np.array([0, 0]),
107+
upper = np.array([10, 10]),
108+
fun_evals=8,
109+
fun_repeats=1,
110+
max_time=inf,
111+
noise=False,
112+
tolerance_x=0,
113+
ocba_delta=0,
114+
var_type=["num", "num"],
115+
infill_criterion="ei",
116+
n_points=1,
117+
seed=123,
118+
log_level=20,
119+
show_models=False,
120+
show_progress=True)
121+
design_control = design_control_init(
122+
init_size=5,
123+
repeats=1)
124+
surrogate_control = surrogate_control_init(
125+
model_fun_evals=10000,
126+
min_theta=-3,
127+
max_theta=3,
128+
n_theta=2,
129+
theta_init_zero=True,
130+
n_p=1,
131+
optim_p=False,
132+
var_type=["num", "num"],
133+
seed=124)
134+
optimizer_control = optimizer_control_init(
135+
max_iter=1000,
136+
seed=125)
137+
spot_tuner = spot.Spot(fun=fun,
138+
fun_control=fun_control,
139+
design_control=design_control,
140+
surrogate_control=surrogate_control,
141+
optimizer_control=optimizer_control)
142+
# Call the save_experiment function
143+
pkl_name = save_experiment(
144+
spot_tuner=spot_tuner,
145+
fun_control=fun_control,
146+
design_control=None,
147+
surrogate_control=None,
148+
optimizer_control=None
149+
)
150+
# Call the load_experiment function
151+
(spot_tuner_1, fun_control_1, design_control_1,
152+
surrogate_control_1, optimizer_control_1) = load_experiment(pkl_name)
90153
"""
91154
if design_control is None:
92155
design_control = design_control_init()
@@ -106,8 +169,11 @@ def save_experiment(
106169
"surrogate_control": surrogate_control,
107170
"optimizer_control": optimizer_control,
108171
}
172+
# check if the key "spot_writer" is in the fun_control dictionary
173+
if "spot_writer" in fun_control and fun_control["spot_writer"] is not None:
174+
fun_control["spot_writer"].close()
109175
PREFIX = fun_control["PREFIX"]
110-
PKL_NAME = "spot_" + PREFIX + "experiment.pickle"
176+
PKL_NAME = "spot_" + PREFIX + "_experiment.pickle"
111177
with open(PKL_NAME, "wb") as handle:
112178
pickle.dump(experiment, handle, protocol=pickle.HIGHEST_PROTOCOL)
113179
print(f"Experiment saved as {PKL_NAME}")
@@ -136,6 +202,4 @@ def load_experiment(PKL_NAME):
136202
design_control = experiment["design_control"]
137203
surrogate_control = experiment["surrogate_control"]
138204
optimizer_control = experiment["optimizer_control"]
139-
# TODO: Add the key "spot_writer" to the fun_control dictionary,
140-
# because it was not saved in the pickle file.
141205
return spot_tuner, fun_control, design_control, surrogate_control, optimizer_control

src/spotPython/utils/init.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def fun_control_init(
2323
device=None,
2424
devices=1,
2525
enable_progress_bar=False,
26+
EXPERIMENT_NAME=None,
2627
fun_evals=15,
2728
fun_repeats=1,
2829
infill_criterion="y",
@@ -38,6 +39,7 @@ def fun_control_init(
3839
show_models=False,
3940
show_progress=True,
4041
sigma=0.0,
42+
surrogate=None,
4143
task=None,
4244
test_seed=1234,
4345
test_size=0.4,
@@ -48,12 +50,13 @@ def fun_control_init(
4850
verbosity=0,
4951
):
5052
"""Initialize fun_control dictionary.
53+
5154
Args:
5255
_L_in (int):
5356
The number of input features.
5457
_L_out (int):
5558
The number of output features.
56-
acceleration (str):
59+
accelerator (str):
5760
The accelerator to be used by the Lighting Trainer.
5861
It can be either "auto", "dp", "ddp", "ddp2", "ddp_spawn", "ddp_cpu", "gpu", "tpu".
5962
Default is "auto".
@@ -448,7 +451,6 @@ def surrogate_control_init(
448451
for many situations. For example, for k=2 these are 30 000 iterations.
449452
Therefore we set this value to 1000.
450453
451-
452454
"""
453455
surrogate_control = {
454456
"log_level": log_level,

test/test_file_save_load.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def test_file_save_load():
6868
spot_tuner_1, fun_control_1, design_control_1, surrogate_control_1, optimizer_control_1 = load_experiment(pkl_name)
6969

7070
# Verify the name of the pickle file
71-
assert pkl_name == f"spot_{fun_control['PREFIX']}experiment.pickle"
71+
assert pkl_name == f"spot_{fun_control['PREFIX']}_experiment.pickle"
7272

7373
# Clean up the temporary directory
7474
os.remove(pkl_name)

0 commit comments

Comments
 (0)