Skip to content

Commit 3ab8d36

Browse files
0.14.4
1 parent e48635f commit 3ab8d36

7 files changed

Lines changed: 143 additions & 17 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.14.3"
10+
version = "0.14.4"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/light/transformer/encoderblock.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44

55
class EncoderBlock(nn.Module):
6-
def __init__(self, input_dim, num_heads, dim_feedforward, dropout=0.0):
6+
def __init__(self, input_dim, num_heads, dim_feedforward, dropout=0.0) -> None:
77
"""
88
Initializes the EncoderBlock object.
99

src/spotPython/spot/spot.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
from __future__ import annotations
22

3+
import pprint
4+
import os
5+
import copy
6+
import json
37
from numpy.random import default_rng
48
from spotPython.design.spacefilling import spacefilling
59
from spotPython.build.kriging import Kriging
@@ -35,6 +39,7 @@
3539
)
3640
import plotly.graph_objects as go
3741
from typing import Callable
42+
from spotPython.utils.numpy2json import NumpyEncoder
3843

3944

4045
logger = logging.getLogger(__name__)
@@ -612,6 +617,88 @@ def get_new_X0(self) -> np.array:
612617
logger.warning("No new XO found on surrogate. Generate new solution %s", X0)
613618
return X0
614619

620+
def write_db_dict(self) -> None:
621+
"""Writes a dictionary with the experiment parameters to the json file spotPython_db.json.
622+
623+
Args:
624+
self (object): Spot object
625+
626+
Returns:
627+
(NoneType): None
628+
629+
"""
630+
# get the time in seconds from 1.1.1970 and convert the time to a string
631+
t_str = str(time.time())
632+
ident = str(self.fun_control["PREFIX"]) + "_" + t_str
633+
634+
spot_tuner = copy.deepcopy(self)
635+
spot_tuner_control = vars(spot_tuner)
636+
637+
fun_control = copy.deepcopy(spot_tuner_control["fun_control"])
638+
design_control = copy.deepcopy(spot_tuner_control["design_control"])
639+
optimizer_control = copy.deepcopy(spot_tuner_control["optimizer_control"])
640+
surrogate_control = copy.deepcopy(spot_tuner_control["surrogate_control"])
641+
642+
# remove keys from the dictionaries:
643+
spot_tuner_control.pop("fun_control", None)
644+
spot_tuner_control.pop("design_control", None)
645+
spot_tuner_control.pop("optimizer_control", None)
646+
spot_tuner_control.pop("surrogate_control", None)
647+
spot_tuner_control.pop("spot_writer", None)
648+
spot_tuner_control.pop("design", None)
649+
spot_tuner_control.pop("fun", None)
650+
spot_tuner_control.pop("optimizer", None)
651+
spot_tuner_control.pop("rng", None)
652+
spot_tuner_control.pop("surrogate", None)
653+
654+
fun_control.pop("core_model", None)
655+
fun_control.pop("metric_river", None)
656+
fun_control.pop("metric_sklearn", None)
657+
fun_control.pop("metric_torch", None)
658+
fun_control.pop("prep_model", None)
659+
fun_control.pop("spot_writer", None)
660+
fun_control.pop("test", None)
661+
fun_control.pop("train", None)
662+
663+
surrogate_control.pop("model_optimizer", None)
664+
surrogate_control.pop("surrogate", None)
665+
666+
print("\n**********************")
667+
print("The following dictionaries are written to the json file spotPython_db.json:")
668+
print("fun_control:")
669+
pprint.pprint(fun_control)
670+
print("design_control:")
671+
pprint.pprint(design_control)
672+
print("optimizer_control:")
673+
pprint.pprint(optimizer_control)
674+
print("surrogate_control:")
675+
pprint.pprint(surrogate_control)
676+
print("spot_tuner_control:")
677+
pprint.pprint(spot_tuner_control)
678+
db_dict = {
679+
str(ident): {
680+
"fun_control": fun_control,
681+
"design_control": design_control,
682+
"surrogate_control": surrogate_control,
683+
"optimizer_control": optimizer_control,
684+
"spot_tuner_control": spot_tuner_control,
685+
}
686+
}
687+
688+
# check if the directory "db_dicts" exists.
689+
if not os.path.exists("db_dicts"):
690+
try:
691+
os.makedirs("db_dicts")
692+
except OSError as e:
693+
raise Exception(f"Error creating directory: {e}")
694+
if os.path.exists("db_dicts"):
695+
try:
696+
with open("db_dicts/" + self.fun_control["db_dict_name"], "a") as f:
697+
json.dump(db_dict, f, indent=4, cls=NumpyEncoder)
698+
f.close()
699+
except OSError as e:
700+
raise Exception(f"Error writing to file: {e}")
701+
615702
def run(self, X_start=None) -> Spot:
616703
self.initialize_design(X_start)
617704
# New: self.update_stats() moved here:
@@ -640,6 +727,9 @@ def run(self, X_start=None) -> Spot:
640727
if self.spot_writer is not None:
641728
writer = self.spot_writer
642729
writer.close()
730+
pprint.pprint(self.fun_control)
731+
if self.fun_control["db_dict_name"] is not None:
732+
self.write_db_dict()
643733
return self
644734

645735
def initialize_design(self, X_start=None) -> None:

src/spotPython/torch/cosinewarmupcheduler.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,6 @@ class CosineWarmupScheduler(optim.lr_scheduler._LRScheduler):
1010
warmup (int): The number of warmup steps.
1111
max_iters (int): The number of maximum iterations the model is trained for.
1212
13-
Returns:
14-
torch.optim.Optimizer: The optimizer with the learning rate updated.
15-
1613
Example:
1714
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
1815
>>> scheduler = CosineWarmupScheduler(optimizer, warmup=10, max_iters=100)

src/spotPython/utils/convert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def map_to_True_False(value):
141141
return False
142142

143143

144-
def sort_by_kth_and_return_indices(array, k):
144+
def sort_by_kth_and_return_indices(array, k) -> list:
145145
"""Sorts an array of arrays based on the k-th values in descending order and returns
146146
the indices of the original array entries.
147147

src/spotPython/utils/file.py

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -74,18 +74,51 @@ def load_pickle(filename: str):
7474
return obj
7575

7676

77+
def get_experiment_filename(PREFIX):
78+
"""Returns the name of the experiment.
79+
80+
Args:
81+
PREFIX (str): Prefix of the experiment.
82+
83+
Returns:
84+
filename (str): Name of the experiment.
85+
86+
Examples:
87+
>>> from spotPython.utils.file import get_experiment_name
88+
>>> from spotPython.utils.init import fun_control_init
89+
>>> fun_control = fun_control_init(PREFIX="branin")
90+
>>> PREFIX = fun_control["PREFIX"]
91+
>>> filename = get_experiment_filename(PREFIX)
92+
"""
93+
filename = "spot_" + PREFIX + "_experiment.pickle"
94+
return filename
95+
96+
7797
def save_experiment(
78-
spot_tuner, fun_control, design_control=None, surrogate_control=None, optimizer_control=None
98+
spot_tuner,
99+
fun_control,
100+
design_control=None,
101+
surrogate_control=None,
102+
optimizer_control=None,
103+
filename=None,
79104
) -> str:
80105
"""
81106
Saves the experiment as a pickle file.
82107
83108
Args:
84-
spot_tuner (object): The spot tuner object.
85-
fun_control (dict): The function control dictionary.
86-
design_control (dict, optional): The design control dictionary. Defaults to None.
87-
surrogate_control (dict, optional): The surrogate control dictionary. Defaults to None.
88-
optimizer_control (dict, optional): The optimizer control dictionary. Defaults to None.
109+
spot_tuner (object):
110+
The spot tuner object.
111+
fun_control (dict):
112+
The function control dictionary.
113+
design_control (dict, optional):
114+
The design control dictionary. Defaults to None.
115+
surrogate_control (dict, optional):
116+
The surrogate control dictionary. Defaults to None.
117+
optimizer_control (dict, optional):
118+
The optimizer control dictionary. Defaults to None.
119+
filename (str, optional):
120+
Name of the pickle file. Defaults to None.
121+
If None, the name is built as "spot_" + PREFIX + "_experiment.pickle".
89122
90123
Returns:
91124
PKL_NAME (str):
@@ -176,12 +209,13 @@ def save_experiment(
176209
# check if the key "spot_writer" is in the fun_control dictionary
177210
if "spot_writer" in fun_control and fun_control["spot_writer"] is not None:
178211
fun_control["spot_writer"].close()
179-
PREFIX = fun_control["PREFIX"]
180-
PKL_NAME = "spot_" + PREFIX + "_experiment.pickle"
181-
with open(PKL_NAME, "wb") as handle:
212+
if filename is None:
213+
PREFIX = fun_control["PREFIX"]
214+
filename = get_experiment_filename(PREFIX)
215+
with open(filename, "wb") as handle:
182216
pickle.dump(experiment, handle, protocol=pickle.HIGHEST_PROTOCOL)
183-
print(f"Experiment saved as {PKL_NAME}")
184-
return PKL_NAME
217+
print(f"Experiment saved as {filename}")
218+
return filename
185219

186220

187221
def load_experiment(PKL_NAME):

src/spotPython/utils/init.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def fun_control_init(
2424
data_module=None,
2525
data_set=None,
2626
data_set_name=None,
27+
db_dict_name=None,
2728
design=None,
2829
device=None,
2930
devices=1,
@@ -102,6 +103,8 @@ def fun_control_init(
102103
The data set object. Default is None.
103104
data_set_name (str):
104105
The name of the data set. Default is None.
106+
db_dict_name (str):
107+
The name of the database dictionary. Default is None.
105108
device (str):
106109
The device to use for the training. It can be either "cpu", "mps", or "cuda".
107110
devices (str or int):
@@ -243,6 +246,7 @@ def fun_control_init(
243246
'core_model_name': None,
244247
'data': None,
245248
'data_dir': './data',
249+
'db_dict_name': None,
246250
'device': None,
247251
'devices': "auto",
248252
'enable_progress_bar': False,
@@ -340,6 +344,7 @@ def fun_control_init(
340344
"data_module": data_module,
341345
"data_set": data_set,
342346
"data_set_name": data_set_name,
347+
"db_dict_name": db_dict_name,
343348
"design": design,
344349
"device": device,
345350
"devices": devices,

0 commit comments

Comments
 (0)