Skip to content

Commit ebf33dd

Browse files
levels management
1 parent 8657188 commit ebf33dd

4 files changed

Lines changed: 49 additions & 3 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.0.14"
10+
version = "0.0.15"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
def modify_hyper_parameter_levels(fun_control, hyperparameter, levels):
2+
"""
3+
4+
Args:
5+
fun_control (dict): fun_control dictionary
6+
hyperparameter (str): hyperparameter name
7+
levels (list): list of levels
8+
9+
Returns:
10+
fun_control (dict): updated fun_control
11+
Example:
12+
>>> fun_control = {}
13+
core_model = HoeffdingTreeRegressor
14+
fun_control.update({"core_model": core_model})
15+
fun_control.update({"core_model_hyper_dict": river_hyper_dict[core_model.__name__]})
16+
levels = ["mean", "model"]
17+
fun_control = modify_hyper_parameter_levels(fun_control, "leaf_prediction", levels)
18+
"""
19+
fun_control["core_model_hyper_dict"][hyperparameter].update({"levels": levels})
20+
fun_control["core_model_hyper_dict"][hyperparameter].update({"lower": 0})
21+
fun_control["core_model_hyper_dict"][hyperparameter].update({"upper": len(levels) - 1})
22+
return fun_control

src/spotPython/spot/spot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from numpy import ravel
1919
from numpy import array
2020
from numpy import append
21-
from spotPython.utils.compare import selectNew
21+
from spotPython.utils.compare import selectNew, find_equal_in_lists
2222
from spotPython.utils.aggregate import aggregate_mean_var
2323
from spotPython.utils.repair import remove_nan
2424
from spotPython.budget.ocba import get_ocba_X
@@ -251,7 +251,7 @@ def __init__(
251251
def to_red_dim(self):
252252
self.all_lower = self.lower
253253
self.all_upper = self.upper
254-
self.ident = (self.upper - self.lower) == 0
254+
self.ident = find_equal_in_lists(a=self.lower, b=self.upper)
255255
self.lower = self.lower[~self.ident]
256256
self.upper = self.upper[~self.ident]
257257
self.red_dim = self.ident.any()

src/spotPython/utils/compare.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,27 @@ def selectNew(A, X, tolerance=0):
1818
B = np.abs(A - X[i, :])
1919
ind = ind + np.all(B <= tolerance, axis=1)
2020
return A[~ind], ~ind
21+
22+
23+
def find_equal_in_lists(a, b):
24+
"""Find equal values in two lists.
25+
26+
Args:
27+
a (list): list with a values
28+
b (list): list with b values
29+
30+
Returns:
31+
list: list with 1 if equal, otherwise 0
32+
Example:
33+
>>> a = [1, 2, 3, 4, 5]
34+
>>> b = [1, 2, 3, 4, 5]
35+
>>> find_equal_in_lists(a, b)
36+
[1, 1, 1, 1, 1]
37+
"""
38+
equal = []
39+
for i in range(len(a)):
40+
if a[i] == b[i]:
41+
equal.append(1)
42+
else:
43+
equal.append(0)
44+
return equal

0 commit comments

Comments
 (0)