Skip to content

Commit 5d7d910

Browse files
0.35.1
1 parent 6a80322 commit 5d7d910

4 files changed

Lines changed: 27 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.35.0"
10+
version = "0.35.1"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/spot/spot.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,13 @@ def __init__(
212212
self.optimizer_control = optimizer_control
213213
self.surrogate_control = surrogate_control
214214

215+
# Kernel selection from fun_control (NEW)
216+
self.kernel = None
217+
self.kernel_params = None
218+
if fun_control is not None:
219+
self.kernel = fun_control.get("kernel", "gauss")
220+
self.kernel_params = fun_control.get("kernel_params", {})
221+
215222
self.counter = 0
216223
self.success_rate = 0.0
217224
self.success_counter = 0
@@ -710,6 +717,8 @@ def surrogate_setup(self, surrogate) -> None:
710717
use_nystrom=self.surrogate_control["use_nystrom"],
711718
nystrom_m=self.surrogate_control["nystrom_m"],
712719
nystrom_seed=self.surrogate_control["nystrom_seed"],
720+
kernel=self.kernel,
721+
kernel_params=self.kernel_params,
713722
)
714723

715724
def get_spot_attributes_as_df(self) -> pd.DataFrame:
@@ -1695,6 +1704,10 @@ def fit_surrogate(self) -> None:
16951704
logger.debug("In fit_surrogate(): self.y: %s", self.y)
16961705
logger.debug("In fit_surrogate(): self.X.shape: %s", self.X.shape)
16971706
logger.debug("In fit_surrogate(): self.y.shape: %s", self.y.shape)
1707+
# Pass kernel options to surrogate if Kriging is used
1708+
if hasattr(self.surrogate, "kernel"):
1709+
self.surrogate.kernel = self.kernel
1710+
self.surrogate.kernel_params = self.kernel_params
16981711
X_points = self.X.shape[0]
16991712
y_points = self.y.shape[0]
17001713
if X_points == y_points:

src/spotpython/utils/init.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ def fun_control_init(
5757
horizon=None,
5858
hyperdict=None,
5959
infill_criterion="y",
60+
kernel="gauss",
61+
kernel_params={},
6062
log_every_n_steps=50,
6163
log_level=50,
6264
lower=None,
@@ -208,6 +210,11 @@ def fun_control_init(
208210
For example: `spotriver.hyperdict.river_hyper_dict import RiverHyperDict`
209211
infill_criterion (str):
210212
Can be `"y"`, `"s"`, `"ei"` (negative expected improvement), or `"all"`. Default is "y".
213+
kernel (str):
214+
The kernel to be used by the Kriging surrogate model.
215+
Can be either "gauss", "matern32", "matern52", or "exponential". Default is "gauss".
216+
kernel_params (dict):
217+
The parameters for the kernel function. Default is an empty dictionary.
211218
log_every_n_steps (int):
212219
Lightning: How often to log within steps. Default: 50.
213220
log_level (int):
@@ -392,6 +399,8 @@ def fun_control_init(
392399
'horizon': 7,
393400
'infill_criterion': 'y',
394401
'k_folds': None,
402+
'kernel': 'gauss',
403+
'kernel_params': {},
395404
'loss_function': None,
396405
'lower': None,
397406
'max_surrogate_points': 100,
@@ -483,6 +492,8 @@ def fun_control_init(
483492
"hyperdict": hyperdict,
484493
"infill_criterion": infill_criterion,
485494
"k_folds": 3,
495+
"kernel": kernel,
496+
"kernel_params": kernel_params,
486497
"log_every_n_steps": log_every_n_steps,
487498
"log_graph": False,
488499
"log_level": log_level,

test/test_get_spot_attributes_as_df.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def test_get_spot_attributes_as_df():
5151
'ident',
5252
'infill_criterion',
5353
'k',
54+
'kernel',
55+
'kernel_params',
5456
'log_level',
5557
'lower',
5658
'max_surrogate_points',

0 commit comments

Comments
 (0)