Skip to content

Commit 0a9b540

Browse files
0.34.9
1 parent 6b0942a commit 0a9b540

25 files changed

Lines changed: 1550 additions & 4 deletions

notebooks/00_spotPython_tests.ipynb

Lines changed: 427 additions & 1 deletion
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.34.8"
10+
version = "0.34.9"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/spot/spot.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,13 +133,14 @@ class Spot:
133133
Examples:
134134
>>> import numpy as np
135135
from math import inf
136+
from scipy.optimize import differential_evolution
136137
from spotpython.spot import spot
137138
from spotpython.utils.init import (
138139
fun_control_init,
139140
design_control_init,
140141
surrogate_control_init,
141142
optimizer_control_init)
142-
def objective_function(X, fun_control=None):
143+
def objective_function(X, **kwargs):
143144
if not isinstance(X, np.ndarray):
144145
X = np.array(X)
145146
if X.shape[1] != 2:
@@ -303,6 +304,22 @@ def _set_fun(self, fun):
303304
Exception: No objective function specified.
304305
Exception: Objective function is not callable
305306
307+
Examples:
308+
>>> import numpy as np
309+
from spotpython.spot import spot
310+
from spotpython.utils.init import fun_control_init
311+
312+
def objective_function(X):
313+
return np.sum(X, axis=1)
314+
315+
fun_control = fun_control_init(
316+
lower=np.array([0, 0]),
317+
upper=np.array([5, 5]),
318+
seed=12345
319+
)
320+
321+
S = spot.Spot(fun=objective_function, fun_control=fun_control)
322+
print(S.fun(np.array([[1, 2], [3, 4]])))
306323
"""
307324
self.fun = fun
308325
if self.fun is None:
@@ -318,6 +335,21 @@ def _set_bounds_and_dim(self) -> None:
318335
Returns:
319336
(NoneType): None
320337
338+
Examples:
339+
>>> import numpy as np
340+
from spotpython.spot import spot
341+
from spotpython.utils.init import fun_control_init
342+
343+
fun_control = fun_control_init(
344+
lower=np.array([-1, -1]),
345+
upper=np.array([1, 1])
346+
)
347+
348+
S = spot.Spot(fun=lambda x: x, fun_control=fun_control)
349+
print("Lower bounds:", S.lower)
350+
print("Upper bounds:", S.upper)
351+
print("Number of dimensions (k):", S.k)
352+
321353
"""
322354
# lower attribute updates:
323355
# if lower is in the fun_control dictionary, use the value of the key "lower" as the lower bound
@@ -336,6 +368,20 @@ def _set_var_type(self) -> None:
336368
Set the variable types based on the fun_control dictionary.
337369
If the variable types are not specified,
338370
all variable types are forced to 'num'.
371+
372+
Examples:
373+
>>> import numpy as np
374+
from spotpython.spot import spot
375+
from spotpython.utils.init import fun_control_init
376+
377+
fun_control = fun_control_init(
378+
lower=np.array([0, 0]),
379+
upper=np.array([10, 10]),
380+
var_type=["num", "cat"]
381+
)
382+
383+
S = spot.Spot(fun=lambda x: x, fun_control=fun_control)
384+
print("Variable types:", S.var_type)
339385
"""
340386
self.var_type = self.fun_control["var_type"]
341387
# Force numeric type as default in every dim:
@@ -345,11 +391,33 @@ def _set_var_type(self) -> None:
345391
self.var_type = self.var_type * self.k
346392
logger.warning("All variable types forced to 'num'.")
347393

394+
# --- check for allowed variable types ---
395+
allowed_types = {"num", "int", "float", "factor"}
396+
for vt in self.var_type:
397+
if vt not in allowed_types:
398+
raise ValueError(f"Invalid var_type '{vt}'. Allowed types are: {allowed_types}.")
399+
# "num" is the superset of "int" and "float"
400+
# (no further action needed, just a check)
401+
348402
def _set_var_name(self) -> None:
349403
"""
350404
Set the variable names based on the fun_control dictionary.
351405
If the variable names are not specified,
352406
all variable names are set to x0, x1, x2, ...
407+
408+
Examples:
409+
>>> import numpy as np
410+
from spotpython.spot import spot
411+
from spotpython.utils.init import fun_control_init
412+
413+
fun_control = fun_control_init(
414+
lower=np.array([0, 0]),
415+
upper=np.array([10, 10]),
416+
var_name=["length", "width"]
417+
)
418+
419+
S = spot.Spot(fun=lambda x: x, fun_control=fun_control)
420+
print("Variable names:", S.var_name)
353421
"""
354422
self.var_name = self.fun_control["var_name"]
355423
# use x0, x1, ... as default variable names:
@@ -523,6 +591,7 @@ def _design_setup(self, design) -> None:
523591
def _optimizer_setup(self, optimizer) -> None:
524592
"""
525593
Optimizer setup. If no optimizer is specified, use Differential Evolution.
594+
The current implementation of _optimizer_setup always overwrites self.optimizer with the default if optimizer is None, even if self.optimizer was already set.
526595
"""
527596
self.optimizer = optimizer
528597
if self.optimizer is None:
@@ -536,7 +605,7 @@ def _surrogate_control_setup(self) -> None:
536605
# to the value of the key "method" from the fun_control dictionary.
537606
# If the value is set (i.e., not None), it is not updated.
538607
if self.surrogate_control["method"] is None:
539-
self.surrogate_control.update({"method": self.fun_control.method})
608+
self.surrogate_control.update({"method": self.fun_control["method"]})
540609
if self.surrogate_control["model_fun_evals"] is None:
541610
self.surrogate_control.update({"model_fun_evals": self.optimizer_control["max_iter"]})
542611
# self.optimizer is not None here. If 1) the key "model_optimizer"

test/test_spot_design_setup.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import numpy as np
2+
from spotpython.spot.spot import Spot
3+
from spotpython.utils.init import fun_control_init, design_control_init
4+
5+
def dummy_fun(X, **kwargs):
6+
return np.sum(X, axis=1)
7+
8+
def test_design_setup_default():
9+
fun_control = fun_control_init(lower=np.array([0, 0]), upper=np.array([1, 1]))
10+
spot = Spot(fun=dummy_fun, fun_control=fun_control)
11+
# _design_setup is called in __init__, but let's call it again to check idempotency
12+
spot._design_setup(None)
13+
# Should use SpaceFilling by default
14+
from spotpython.design.spacefilling import SpaceFilling
15+
assert isinstance(spot.design, SpaceFilling)
16+
assert spot.design.k == 2
17+
18+
def test_design_setup_custom_object():
19+
fun_control = fun_control_init(lower=np.array([0, 0]), upper=np.array([1, 1]))
20+
class DummyDesign:
21+
def __init__(self, k):
22+
self.k = k
23+
custom_design = DummyDesign(k=2)
24+
spot = Spot(fun=dummy_fun, fun_control=fun_control)
25+
spot._design_setup(custom_design)
26+
assert spot.design is custom_design
27+
assert spot.design.k == 2
28+
29+
def test_design_setup_with_different_k():
30+
fun_control = fun_control_init(lower=np.array([0, 0, 0]), upper=np.array([1, 1, 1]))
31+
spot = Spot(fun=dummy_fun, fun_control=fun_control)
32+
spot._design_setup(None)
33+
assert spot.design.k == 3

test/test_spot_full_run.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import numpy as np
2+
import pytest
3+
from spotpython.spot.spot import Spot
4+
from spotpython.utils.init import fun_control_init, design_control_init
5+
6+
def dummy_fun(X, fun_control=None):
7+
X = np.atleast_2d(X)
8+
return np.sum(X**2, axis=1)
9+
10+
def test_run_basic_execution():
11+
lower = np.array([-1, -1])
12+
upper = np.array([1, 1])
13+
fun_control = fun_control_init(lower=lower, upper=upper, fun_evals=5)
14+
design_control = design_control_init(init_size=3)
15+
spot = Spot(fun=dummy_fun, fun_control=fun_control, design_control=design_control)
16+
result = spot.run()
17+
assert isinstance(result, Spot)
18+
# After run, X and y should be set and have at least fun_evals rows
19+
assert spot.X is not None
20+
assert spot.y is not None
21+
assert spot.X.shape[0] >= fun_control["fun_evals"]
22+
assert spot.y.shape[0] >= fun_control["fun_evals"]
23+
# min_y and min_X should be set
24+
assert hasattr(spot, "min_y")
25+
assert hasattr(spot, "min_X")
26+
27+
def test_run_with_initial_design():
28+
lower = np.array([0, 0])
29+
upper = np.array([1, 1])
30+
fun_control = fun_control_init(lower=lower, upper=upper, fun_evals=4)
31+
design_control = design_control_init(init_size=2)
32+
spot = Spot(fun=dummy_fun, fun_control=fun_control, design_control=design_control)
33+
X_start = np.array([[0, 0], [1, 1]])
34+
result = spot.run(X_start=X_start)
35+
assert isinstance(result, Spot)
36+
# X should contain the initial design
37+
assert np.any(np.all(spot.X == [0, 0], axis=1))
38+
assert np.any(np.all(spot.X == [1, 1], axis=1))
39+
assert spot.y.shape[0] >= fun_control["fun_evals"]
40+
41+
def test_run_sets_min_y_and_min_X():
42+
lower = np.array([-2, -2])
43+
upper = np.array([2, 2])
44+
fun_control = fun_control_init(lower=lower, upper=upper, fun_evals=6)
45+
design_control = design_control_init(init_size=3)
46+
spot = Spot(fun=dummy_fun, fun_control=fun_control, design_control=design_control)
47+
spot.run()
48+
# min_y should be the minimum of y
49+
assert np.isclose(spot.min_y, np.min(spot.y))
50+
# min_X should be the row in X corresponding to min_y
51+
idx = np.argmin(spot.y)
52+
np.testing.assert_array_equal(spot.min_X, spot.X[idx])

test/test_spot_get_counter.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import numpy as np
2+
import pytest
3+
from spotpython.spot.spot import Spot
4+
from spotpython.utils.init import fun_control_init
5+
6+
def dummy_fun(X, **kwargs):
7+
return np.sum(X, axis=1)
8+
9+
def test_get_counter_initial():
10+
fun_control = fun_control_init(lower=np.array([0, 0]), upper=np.array([1, 1]))
11+
spot = Spot(fun=dummy_fun, fun_control=fun_control)
12+
assert spot._get_counter() == 0
13+
14+
def test_get_counter_after_set():
15+
fun_control = fun_control_init(lower=np.array([0, 0]), upper=np.array([1, 1]))
16+
spot = Spot(fun=dummy_fun, fun_control=fun_control)
17+
spot._set_counter(5)
18+
assert spot._get_counter() == 5
19+
20+
def test_get_counter_after_increment():
21+
fun_control = fun_control_init(lower=np.array([0, 0]), upper=np.array([1, 1]))
22+
spot = Spot(fun=dummy_fun, fun_control=fun_control)
23+
spot._increment_counter(3)
24+
assert spot._get_counter() == 3
25+
26+
def test_get_counter_none_value():
27+
fun_control = fun_control_init(lower=np.array([0, 0]), upper=np.array([1, 1]))
28+
spot = Spot(fun=dummy_fun, fun_control=fun_control)
29+
spot.counter = None
30+
assert spot._get_counter() == 0
31+
32+
def test_get_counter_negative_value():
33+
fun_control = fun_control_init(lower=np.array([0, 0]), upper=np.array([1, 1]))
34+
spot = Spot(fun=dummy_fun, fun_control=fun_control)
35+
spot.counter = -7
36+
# _get_counter should just return the value, even if negative (no check in _get_counter)
37+
assert spot._get_counter() == -7

test/test_spot_get_new_X0.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import numpy as np
2+
import pytest
3+
from spotpython.spot.spot import Spot
4+
from spotpython.utils.init import fun_control_init, design_control_init
5+
6+
def dummy_fun(X, fun_control=None):
7+
X = np.atleast_2d(X)
8+
return np.sum(X, axis=1)
9+
10+
def test_get_new_X0_returns_valid_shape():
11+
lower = np.array([0, 0])
12+
upper = np.array([1, 1])
13+
var_type = ['float', 'float']
14+
var_name = ['x1', 'x2']
15+
fun_control = fun_control_init(lower=lower, upper=upper, var_type=var_type, var_name=var_name, n_points=2)
16+
design_control = design_control_init(init_size=3)
17+
spot = Spot(fun=dummy_fun, fun_control=fun_control, design_control=design_control)
18+
spot.X = np.array([[0.1, 0.2], [0.3, 0.4]])
19+
spot.y = dummy_fun(spot.X)
20+
spot.fit_surrogate()
21+
X0 = spot.get_new_X0()
22+
assert isinstance(X0, np.ndarray)
23+
assert X0.shape[1] == spot.k
24+
assert X0.shape[0] % spot.fun_repeats == 0
25+
assert np.all(X0 >= spot.lower)
26+
assert np.all(X0 <= spot.upper)
27+
28+
def test_get_new_X0_fallback(monkeypatch):
29+
lower = np.array([0, 0])
30+
upper = np.array([1, 1])
31+
var_type = ['float', 'float']
32+
var_name = ['x1', 'x2']
33+
fun_control = fun_control_init(lower=lower, upper=upper, var_type=var_type, var_name=var_name, n_points=2)
34+
design_control = design_control_init(init_size=3)
35+
spot = Spot(fun=dummy_fun, fun_control=fun_control, design_control=design_control)
36+
spot.X = np.array([[0.1, 0.2], [0.3, 0.4]])
37+
spot.y = dummy_fun(spot.X)
38+
spot.fit_surrogate()
39+
40+
# Monkeypatch suggest_new_X to return empty array to trigger fallback
41+
def suggest_new_X_empty():
42+
return np.empty((0, spot.k))
43+
spot.suggest_new_X = suggest_new_X_empty
44+
45+
X0 = spot.get_new_X0()
46+
assert isinstance(X0, np.ndarray)
47+
assert X0.shape[1] == spot.k
48+
assert np.all(X0 >= spot.lower)
49+
assert np.all(X0 <= spot.upper)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import numpy as np
2+
import pytest
3+
import pandas as pd
4+
5+
from spotpython.spot.spot import Spot
6+
from spotpython.utils.init import fun_control_init, design_control_init
7+
8+
def dummy_fun(X, fun_control=None):
9+
X = np.atleast_2d(X)
10+
return np.sum(X**2, axis=1)
11+
12+
def test_get_spot_attributes_as_df_basic():
13+
fun_control = fun_control_init(lower=np.array([-1, -1]), upper=np.array([1, 1]), fun_evals=5)
14+
design_control = design_control_init(init_size=3)
15+
spot = Spot(fun=dummy_fun, fun_control=fun_control, design_control=design_control)
16+
df = spot.get_spot_attributes_as_df()
17+
assert isinstance(df, pd.DataFrame)
18+
# Check that some expected attributes are present
19+
assert "fun_control" in df["Attribute Name"].values
20+
assert "design_control" in df["Attribute Name"].values
21+
assert "surrogate" in df["Attribute Name"].values
22+
# Check that the number of rows matches the number of attributes
23+
assert len(df) == len([attr for attr in dir(spot) if not callable(getattr(spot, attr)) and not attr.startswith("__")])
24+
25+
def test_get_spot_attributes_as_df_contains_values():
26+
fun_control = fun_control_init(lower=np.array([-2, -2]), upper=np.array([2, 2]), fun_evals=3)
27+
spot = Spot(fun=dummy_fun, fun_control=fun_control)
28+
df = spot.get_spot_attributes_as_df()
29+
# Check that lower and upper bounds are present and correct
30+
lower_row = df[df["Attribute Name"] == "lower"]
31+
upper_row = df[df["Attribute Name"] == "upper"]
32+
assert np.allclose(lower_row.iloc[0]["Attribute Value"], np.array([-2, -2]))
33+
assert np.allclose(upper_row.iloc[0]["Attribute Value"], np.array([2, 2]))
34+
35+
def test_get_spot_attributes_as_df_dataframe_content():
36+
fun_control = fun_control_init(lower=np.array([0]), upper=np.array([1]), fun_evals=2)
37+
spot = Spot(fun=dummy_fun, fun_control=fun_control)
38+
df = spot.get_spot_attributes_as_df()
39+
# DataFrame should have columns 'Attribute Name' and 'Attribute Value'
40+
assert set(df.columns) == {"Attribute Name", "Attribute Value"}
41+
# There should be at least one attribute
42+
assert len(df) > 0

0 commit comments

Comments
 (0)