Skip to content

Commit a571d86

Browse files
0.21.5
save experiment improved
1 parent cb31fec commit a571d86

3 files changed

Lines changed: 53 additions & 18 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.21.4"
10+
version = "0.21.5"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/spot/spot.py

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,14 +1052,23 @@ def write_tensorboard_log(self) -> None:
10521052
self.spot_writer.add_hparams(config, {"hp_metric": y_j})
10531053
self.spot_writer.flush()
10541054

1055-
def save_experiment(self, filename=None, path=None, overwrite=True) -> None:
1055+
def save_experiment(self, filename=None, path=None, overwrite=True, verbosity=0) -> None:
10561056
"""
10571057
Save the experiment to a file.
10581058
10591059
Args:
1060-
filename (str): The filename of the experiment file.
1061-
path (str): The path to the experiment file.
1062-
overwrite (bool): If `True`, the file will be overwritten if it already exists. Default is `True`.
1060+
filename (str):
1061+
The filename of the experiment file. If not provided,
1062+
the filename is generated based on the PREFIX using the
1063+
`get_experiment_filename()` function. Default is `None`.
1064+
path (str):
1065+
The path to the experiment file. If not provided, the file
1066+
is saved in the current working directory. Default is `None`.
1067+
overwrite (bool):
1068+
If `True`, the file will be overwritten if it already exists.
1069+
Default is `True`.
1070+
verbosity (int):
1071+
The level of verbosity. Default is 0.
10631072
10641073
Returns:
10651074
None
@@ -1079,14 +1088,14 @@ def save_experiment(self, filename=None, path=None, overwrite=True) -> None:
10791088
"design_control": design_control,
10801089
"fun_control": fun_control,
10811090
"optimizer_control": optimizer_control,
1082-
"spot_tuner": self._get_pickle_safe_spot_tuner(),
1091+
"spot_tuner": self._get_pickle_safe_spot_tuner(verbosity=verbosity),
10831092
"surrogate_control": surrogate_control,
10841093
}
10851094

10861095
# Determine the filename based on PREFIX if not provided
10871096
PREFIX = fun_control.get("PREFIX", "experiment")
10881097
if filename is None:
1089-
filename = self.get_experiment_filename(PREFIX)
1098+
filename = get_experiment_filename(PREFIX)
10901099

10911100
if path is not None:
10921101
filename = os.path.join(path, filename)
@@ -1126,21 +1135,45 @@ def _close_and_del_spot_writer(self) -> None:
11261135
self.spot_writer.close()
11271136
del self.spot_writer
11281137

1129-
def _get_pickle_safe_spot_tuner(self):
1138+
def _get_pickle_safe_spot_tuner(self, verbosity=0) -> Spot:
11301139
"""
11311140
Create a copy of self excluding unpickleable components for safe pickling.
11321141
This ensures no unpicklable components are passed to pickle.dump().
11331142
1143+
Args:
1144+
verbosity (int):
1145+
The level of verbosity. Default is 0.
1146+
11341147
Returns:
11351148
Spot: A copy of the Spot instance with unpickleable components removed.
11361149
"""
1137-
# Make a deepcopy and manually remove unpickleable components
1138-
spot_tuner = copy.deepcopy(self)
1139-
unpickleable_attrs = ["spot_writer", "logger", "fun", "optimizer", "surrogate"]
1140-
for attr in unpickleable_attrs:
1141-
if hasattr(spot_tuner, attr):
1142-
delattr(spot_tuner, attr)
1143-
return spot_tuner
1150+
# List of attributes that can't be pickled
1151+
unpickleable_attrs = ["spot_writer", "logger", "fun", "optimizer", "surrogate", "data_set", "scaler", "rng", "design"]
1152+
1153+
# Prepare a dictionary to store picklable state
1154+
picklable_state = {}
1155+
1156+
# Copy picklable attributes to the dictionary
1157+
for key, value in self.__dict__.items():
1158+
if key not in unpickleable_attrs:
1159+
try:
1160+
# Test if the attribute can be pickled
1161+
copy.deepcopy(value)
1162+
picklable_state[key] = value
1163+
if verbosity > 1:
1164+
print(f"Attribute {key} is picklable and will be included in the experiment file.")
1165+
except Exception:
1166+
if verbosity > 0:
1167+
print(f"Attribute {key} is not picklable and will be excluded from the experiment file.")
1168+
continue
1169+
1170+
# Use the dictionary to create a new instance
1171+
picklable_instance = self.__class__.__new__(self.__class__)
1172+
picklable_instance.__dict__.update(picklable_state)
1173+
if verbosity > 1:
1174+
print(f"Picklable instance created: {picklable_instance.__dict__}")
1175+
1176+
return picklable_instance
11441177

11451178
def init_spot_writer(self) -> None:
11461179
"""

src/spotpython/utils/file.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,10 @@ def load_pickle(filename: str):
7070
return obj
7171

7272

73-
def get_experiment_filename(PREFIX):
73+
def get_experiment_filename(PREFIX) -> str:
7474
"""Returns the name of the experiment.
75+
This is the PREFIX with the suffix ".pkl".
76+
It is none, if PREFIX is None.
7577
7678
Args:
7779
PREFIX (str): Prefix of the experiment.
@@ -89,7 +91,7 @@ def get_experiment_filename(PREFIX):
8991
if PREFIX is None:
9092
return None
9193
else:
92-
filename = "spot_" + PREFIX + "_experiment.pickle"
94+
filename = PREFIX + ".pkl"
9395
return filename
9496

9597

@@ -118,7 +120,7 @@ def load_experiment(PKL_NAME=None, PREFIX=None):
118120
119121
Examples:
120122
>>> from spotpython.utils.file import load_experiment
121-
>>> spot_tuner, fun_control, design_control, _, _ = load_experiment("spot_0_experiment.pickle")
123+
>>> spot_tuner, fun_control, design_control, _, _ = load_experiment("RUN_0.pkl")
122124
123125
"""
124126
if PKL_NAME is None and PREFIX is not None:

0 commit comments

Comments
 (0)