Skip to content

Commit dcee62e

Browse files
v0.9.6
n_theta = k if n_theta > 1
1 parent bd278eb commit dcee62e

3 files changed

Lines changed: 23 additions & 6 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.9.5"
10+
version = "0.9.6"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/spot/spot.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,13 @@ def __init__(
332332
"seed": 124,
333333
}
334334
self.surrogate_control.update(surrogate_control)
335+
336+
# If self.surrogate_control["n_theta"] > 1, use k theta values:
337+
if self.surrogate_control["n_theta"] > 1:
338+
surrogate_control.update({"n_theta": self.k})
339+
else:
340+
surrogate_control.update({"n_theta": 1})
341+
335342
# If no surrogate model is specified, use the internal
336343
# spotPython kriging surrogate:
337344
if self.surrogate is None:

src/spotPython/utils/init.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,28 @@
11
import os
22
import lightning as L
33
import datetime
4+
from math import inf
45

56
# PyTorch TensorBoard support
67
from torch.utils.tensorboard import SummaryWriter
78

89

910
def fun_control_init(
10-
task="classification",
1111
_L_in=None,
1212
_L_out=None,
13-
enable_progress_bar=False,
14-
spot_tensorboard_path=None,
1513
TENSORBOARD_CLEAN=False,
16-
num_workers=0,
1714
device=None,
15+
enable_progress_bar=False,
16+
fun_evals=inf,
17+
log_level=10,
18+
max_time=1,
19+
num_workers=0,
1820
seed=1234,
1921
sigma=0.0,
22+
show_progress=False,
23+
spot_tensorboard_path=None,
24+
task="classification",
25+
tolerance_x=0,
2026
):
2127
"""Initialize fun_control dictionary.
2228
Args:
@@ -127,9 +133,11 @@ def fun_control_init(
127133
"device": device,
128134
"enable_progress_bar": enable_progress_bar,
129135
"eval": None,
130-
"fun_evals": 15,
136+
"fun_evals": fun_evals,
131137
"k_folds": 3,
138+
"log_level": log_level,
132139
"loss_function": None,
140+
"max_time": max_time,
133141
"metric_river": None,
134142
"metric_sklearn": None,
135143
"metric_torch": None,
@@ -143,9 +151,11 @@ def fun_control_init(
143151
"save_model": False,
144152
"seed": seed,
145153
"show_batch_interval": 1_000_000,
154+
"show_progress": show_progress,
146155
"shuffle": None,
147156
"sigma": sigma,
148157
"target_column": None,
158+
"tolerance_x": tolerance_x,
149159
"train": None,
150160
"test": None,
151161
"task": task,

0 commit comments

Comments
 (0)