Skip to content

Commit 8a9fec7

Browse files
0.26.5 crude reset
1 parent 8c39001 commit 8a9fec7

3 files changed

Lines changed: 125 additions & 267 deletions

File tree

notebooks/00_spotPython_tests.ipynb

Lines changed: 79 additions & 249 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.26.4"
10+
version = "0.26.5"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/gp/gp_sep.py

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,44 @@
1616
from spotpython.utils.aggregate import select_distant_points
1717

1818

19+
def crude_reset(theta, tmin, tmax, m):
20+
"""
21+
Check whether any elements of the parameter vector ``theta`` lie below the
22+
corresponding elements of the lower bound ``tmin``. If so, reset ``theta``
23+
to a new vector based on the weighted average of ``tmin`` and ``tmax``,
24+
leaving bounds unmodified except for cases where ``tmax`` is negative.
25+
26+
Args:
27+
theta (np.ndarray): The current parameter values.
28+
tmin (np.ndarray): The lower bounds for the parameters.
29+
tmax (np.ndarray): The upper bounds for the parameters (may be adjusted if negative).
30+
m (int): The dimensionality or number of parameters (used to adjust negative ``tmax`` entries).
31+
32+
Returns:
33+
dict or None: A dictionary containing:
34+
- "theta" (np.ndarray): The reset parameter values.
35+
- "its" (int): Number of iterations (0, indicating immediate reset).
36+
- "msg" (str): Reason for the reset.
37+
- "conv" (int): Reset code (102).
38+
Returns None if no reset is needed.
39+
"""
40+
if np.any(theta < tmin):
41+
print("resetting due to init on lower boundary")
42+
print(f"theta: {theta}")
43+
print(f"tmin: {tmin}")
44+
for i in range(len(tmax)):
45+
if tmax[i] < 0:
46+
tmax[i] = np.sqrt(m)
47+
theta_new = 0.9 * np.maximum(tmin, 0) + 0.1 * np.array(tmax)
48+
return {
49+
"theta": theta_new,
50+
"its": 0,
51+
"msg": "reset due to init on lower boundary",
52+
"conv": 102,
53+
}
54+
return None
55+
56+
1957
def getDs(X: np.ndarray, p: float = 0.1, samp_size: int = 1000) -> dict:
2058
"""
2159
Calculate a rough starting, minimum, and maximum length-scale from the data X.
@@ -505,23 +543,13 @@ def fit(self, X: np.ndarray, y: np.ndarray, d=None, g=None, dK: bool = True, aut
505543
# Possibly reset parameters
506544
theta = np.concatenate((self.get_d(), [self.get_g()]))
507545
# Check if theta is on the boundary. If not on the boundary,
508-
# build the model and return the current parameters.
509-
if np.any(theta < tmin):
510-
print("resetting due to init on lower boundary")
511-
print(f"theta: {theta}")
512-
print(f"tmin: {tmin}")
513-
for i in range(len(tmax)):
514-
if tmax[i] < 0:
515-
tmax[i] = np.sqrt(m)
516-
theta_new = 0.9 * np.maximum(tmin, 0) + 0.1 * np.array(tmax)
517-
self.set_new_params(theta_new[:m], theta_new[m])
518-
self.build()
519-
return {
520-
"theta": theta_new,
521-
"its": 0,
522-
"msg": "reset due to init on lower boundary",
523-
"conv": 102,
524-
}
546+
# reset the current parameters.
547+
theta_new = crude_reset(theta, tmin, tmax, m)
548+
if theta_new is not None:
549+
theta = theta_new["theta"]
550+
# isuue a warning if the parameters are reset
551+
warnings.warn(f"resetting due to init on lower boundary: {theta_new['msg']}", RuntimeWarning)
552+
525553
# Convert ab to numpy array if it is a list
526554
if not isinstance(ab, np.ndarray):
527555
ab = np.array(ab, dtype=float)

0 commit comments

Comments
 (0)