Skip to content

Commit c9a30ee

Browse files
0.22.0
save + load updated
1 parent 6482831 commit c9a30ee

7 files changed

Lines changed: 82 additions & 72 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.21.7"
10+
version = "0.22.0"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/light/testmodel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ def test_model(config: dict, fun_control: dict) -> Tuple[float, float]:
110110
trainer.fit(model=model, datamodule=dm)
111111

112112
# Load the last checkpoint
113-
test_result = trainer.test(datamodule=dm, ckpt_path="last")
113+
# test_result = trainer.test(datamodule=dm, ckpt_path="last")
114+
test_result = trainer.test(datamodule=dm, ckpt_path="best")
114115
test_result = test_result[0]
115116
print(f"test_model result: {test_result}")
116117
return test_result["val_loss"], test_result["hp_metric"]

src/spotpython/spot/spot.py

Lines changed: 37 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -316,32 +316,6 @@ def __init__(
316316
if self.surrogate_control["n_theta"] > 1:
317317
surrogate_control.update({"n_theta": self.k})
318318

319-
# # If no surrogate model is specified, use the internal
320-
# # spotpython kriging surrogate:
321-
# if self.surrogate is None:
322-
# # Call kriging with surrogate_control parameters:
323-
# self.surrogate = Kriging(
324-
# name="kriging",
325-
# noise=self.surrogate_control["noise"],
326-
# model_optimizer=self.surrogate_control["model_optimizer"],
327-
# model_fun_evals=self.surrogate_control["model_fun_evals"],
328-
# seed=self.surrogate_control["seed"],
329-
# log_level=self.log_level,
330-
# min_theta=self.surrogate_control["min_theta"],
331-
# max_theta=self.surrogate_control["max_theta"],
332-
# metric_factorial=self.surrogate_control["metric_factorial"],
333-
# n_theta=self.surrogate_control["n_theta"],
334-
# theta_init_zero=self.surrogate_control["theta_init_zero"],
335-
# p_val=self.surrogate_control["p_val"],
336-
# n_p=self.surrogate_control["n_p"],
337-
# optim_p=self.surrogate_control["optim_p"],
338-
# min_Lambda=self.surrogate_control["min_Lambda"],
339-
# max_Lambda=self.surrogate_control["max_Lambda"],
340-
# var_type=self.surrogate_control["var_type"],
341-
# spot_writer=self.spot_writer,
342-
# counter=self.design_control["init_size"] * self.design_control["repeats"] - 1,
343-
# )
344-
345319
# Internal attributes:
346320
self.X = None
347321
self.y = None
@@ -355,6 +329,40 @@ def __init__(
355329
self.mean_y = None
356330
self.var_y = None
357331

332+
# save experiment must be called before the spot_writer is initialized
333+
if self.fun_control.get("save_experiment"):
334+
filename = self.fun_control.get("PREFIX") + "_exp.pkl"
335+
self.save_experiment(filename=filename, verbosity=self.verbosity)
336+
337+
# Tensorboard must be initialized before the surrogate model:
338+
self.init_spot_writer()
339+
340+
# If no surrogate model is specified, use the internal
341+
# spotpython kriging surrogate:
342+
if self.surrogate is None:
343+
# Call kriging with surrogate_control parameters:
344+
self.surrogate = Kriging(
345+
name="kriging",
346+
noise=self.surrogate_control["noise"],
347+
model_optimizer=self.surrogate_control["model_optimizer"],
348+
model_fun_evals=self.surrogate_control["model_fun_evals"],
349+
seed=self.surrogate_control["seed"],
350+
log_level=self.log_level,
351+
min_theta=self.surrogate_control["min_theta"],
352+
max_theta=self.surrogate_control["max_theta"],
353+
metric_factorial=self.surrogate_control["metric_factorial"],
354+
n_theta=self.surrogate_control["n_theta"],
355+
theta_init_zero=self.surrogate_control["theta_init_zero"],
356+
p_val=self.surrogate_control["p_val"],
357+
n_p=self.surrogate_control["n_p"],
358+
optim_p=self.surrogate_control["optim_p"],
359+
min_Lambda=self.surrogate_control["min_Lambda"],
360+
max_Lambda=self.surrogate_control["max_Lambda"],
361+
var_type=self.surrogate_control["var_type"],
362+
spot_writer=self.spot_writer,
363+
counter=self.design_control["init_size"] * self.design_control["repeats"] - 1,
364+
)
365+
358366
logger.setLevel(self.log_level)
359367
logger.info(f"Starting the logger at level {self.log_level} for module {__name__}:")
360368
logger.debug("In Spot() init(): fun_control: %s", self.fun_control)
@@ -791,35 +799,6 @@ def run(self, X_start: np.ndarray = None) -> Spot:
791799
3.7179535332164810e-04])
792800
793801
"""
794-
# Tensorboard:
795-
self.init_spot_writer()
796-
797-
# If no surrogate model is specified, use the internal
798-
# spotpython kriging surrogate:
799-
if self.surrogate is None:
800-
# Call kriging with surrogate_control parameters:
801-
self.surrogate = Kriging(
802-
name="kriging",
803-
noise=self.surrogate_control["noise"],
804-
model_optimizer=self.surrogate_control["model_optimizer"],
805-
model_fun_evals=self.surrogate_control["model_fun_evals"],
806-
seed=self.surrogate_control["seed"],
807-
log_level=self.log_level,
808-
min_theta=self.surrogate_control["min_theta"],
809-
max_theta=self.surrogate_control["max_theta"],
810-
metric_factorial=self.surrogate_control["metric_factorial"],
811-
n_theta=self.surrogate_control["n_theta"],
812-
theta_init_zero=self.surrogate_control["theta_init_zero"],
813-
p_val=self.surrogate_control["p_val"],
814-
n_p=self.surrogate_control["n_p"],
815-
optim_p=self.surrogate_control["optim_p"],
816-
min_Lambda=self.surrogate_control["min_Lambda"],
817-
max_Lambda=self.surrogate_control["max_Lambda"],
818-
var_type=self.surrogate_control["var_type"],
819-
spot_writer=self.spot_writer,
820-
counter=self.design_control["init_size"] * self.design_control["repeats"] - 1,
821-
)
822-
823802
self.initialize_design(X_start)
824803
self.update_stats()
825804
self.fit_surrogate()
@@ -837,8 +816,9 @@ def run(self, X_start: np.ndarray = None) -> Spot:
837816
if self.fun_control.get("db_dict_name") is not None:
838817
self._write_db_dict()
839818

840-
if self.fun_control.get("save_experiment"):
841-
self.save_experiment()
819+
if self.fun_control.get("save_result"):
820+
filename = self.fun_control.get("PREFIX") + "_res.pkl"
821+
self.save_experiment(filename=filename, verbosity=self.verbosity)
842822
return self
843823

844824
def initialize_design(self, X_start=None) -> None:

src/spotpython/utils/file.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,28 @@ def get_experiment_filename(PREFIX) -> str:
9595
return filename
9696

9797

98-
def load_experiment(PKL_NAME=None, PREFIX=None):
98+
def load_result(PREFIX) -> None:
99+
"""Loads the result from a pickle file with the name
100+
PREFIX + "_res.pkl".
101+
This is the standard filename for the result file,
102+
when it is saved by the spot tuner using `save_result()`, i.e.,
103+
when fun_control["save_result"] is set to True.
104+
105+
Args:
106+
PREFIX (str): Prefix of the experiment.
107+
108+
Examples:
109+
>>> from spotpython.utils.file import load_result
110+
>>> load_result("branin")
111+
112+
"""
113+
if PREFIX is None:
114+
raise ValueError("No PREFIX provided.")
115+
PKL_NAME = PREFIX + "_res.pkl"
116+
load_experiment(PKL_NAME)
117+
118+
119+
def load_experiment(PREFIX=None, PKL_NAME=None):
99120
"""
100121
Loads the experiment from a pickle file.
101122
If PKL_NAME is None and PREFIX is not None, the experiment is loaded based on the PREFIX
@@ -105,8 +126,8 @@ def load_experiment(PKL_NAME=None, PREFIX=None):
105126
and `None` is assigned to the corresponding variables.
106127
107128
Args:
108-
PKL_NAME (str): Name of the pickle file. Defaults to None.
109129
PREFIX (str, optional): Prefix of the experiment. Defaults to None.
130+
PKL_NAME (str): Name of the pickle file. Defaults to None.
110131
111132
Returns:
112133
spot_tuner (object): The spot tuner object.
@@ -120,7 +141,7 @@ def load_experiment(PKL_NAME=None, PREFIX=None):
120141
121142
Examples:
122143
>>> from spotpython.utils.file import load_experiment
123-
>>> spot_tuner, fun_control, design_control, _, _ = load_experiment("RUN_0.pkl")
144+
>>> spot_tuner, fun_control, design_control, _, _ = load_experiment(PKL_NAME="RUN_0.pkl")
124145
125146
"""
126147
if PKL_NAME is None and PREFIX is not None:

src/spotpython/utils/init.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ def fun_control_init(
7070
prep_model=None,
7171
prep_model_name=None,
7272
progress_file=None,
73-
save_experiment=False,
73+
save_experiment=True,
74+
save_result=True,
7475
scaler=None,
7576
scaler_name=None,
7677
scenario=None,
@@ -234,7 +235,9 @@ def fun_control_init(
234235
progress_file (str):
235236
The name of the progress file. Default is None.
236237
save_experiment (bool):
237-
Whether to save the experiment or not. Default is False.
238+
Whether to save the experiment before the run is started or not. Default is False.
239+
save_result (bool):
240+
Whether to save the result after the experiment is done or not. Default is False.
238241
scaler (object):
239242
The scaler object, e.g., the TorchStandard scaler from spot.utils.scaler.py.
240243
Default is None.
@@ -450,6 +453,7 @@ def fun_control_init(
450453
"prep_model_name": prep_model_name,
451454
"progress_file": progress_file,
452455
"save_experiment": save_experiment,
456+
"save_result": save_result,
453457
"save_model": False,
454458
"scaler": scaler,
455459
"scaler_name": scaler_name,

test/test_generate_random_point.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import pytest
22
import numpy as np
33
from unittest.mock import Mock
4-
from spotpython.spot import spot
4+
from spotpython.spot import Spot
5+
from spotpython.fun import Analytical
56
from spotpython.utils.init import fun_control_init
67

78

@@ -13,10 +14,13 @@ def setup_spot():
1314
)
1415
return fun_control
1516

17+
def fun_nan(X, fun_control):
18+
return np.array([np.nan])
19+
1620

1721
def test_generate_random_point(setup_spot):
18-
fun = Mock(return_value=np.array([0])) # Replace with valid function
19-
S = spot.Spot(fun=fun, fun_control=setup_spot)
22+
fun = Analytical().fun_sphere
23+
S = Spot(fun=fun, fun_control=setup_spot)
2024
X0, y0 = S.generate_random_point()
2125

2226
print(f"X0: {X0}")
@@ -30,8 +34,8 @@ def test_generate_random_point(setup_spot):
3034

3135

3236
def test_generate_random_point_with_nan(setup_spot):
33-
fun = Mock(return_value=np.array([np.nan])) # Function that returns NaN
34-
S = spot.Spot(fun=fun, fun_control=setup_spot)
37+
fun = fun_nan
38+
S = Spot(fun=fun, fun_control=setup_spot)
3539
X0, y0 = S.generate_random_point()
3640

3741
print(f"X0 with NaN: {X0}")

test/test_to_all_dim.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
import numpy as np
3-
from spotpython.fun.objectivefunctions import Analytical
4-
from spotpython.spot.spot import Spot
3+
from spotpython.fun import Analytical
4+
from spotpython.spot import Spot
55
from spotpython.utils.init import fun_control_init, design_control_init, optimizer_control_init, surrogate_control_init
66

77
def test_to_all_dim():
@@ -14,7 +14,7 @@ def test_to_all_dim():
1414
surrogate_control = surrogate_control_init()
1515

1616
spot_instance = Spot(
17-
fun=lambda x: x, # Dummy function
17+
fun = Analytical().fun_sphere, # Dummy function
1818
fun_control=fun_control,
1919
design_control=design_control,
2020
optimizer_control=optimizer_control,

0 commit comments

Comments
 (0)