Skip to content

Commit 9dda9a0

Browse files
scaler prepared
1 parent a71141f commit 9dda9a0

6 files changed

Lines changed: 11 additions & 4 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.14.33"
10+
version = "0.14.34"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/light/predictmodel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def predict_model(config: dict, fun_control: dict) -> Tuple[float, float]:
7070
num_workers=fun_control["num_workers"],
7171
test_size=fun_control["test_size"],
7272
test_seed=fun_control["test_seed"],
73+
scaler=fun_control["scaler"],
7374
)
7475
# TODO: Check if this is necessary:
7576
# dm.setup(stage="train")

src/spotPython/light/testmodel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def test_model(config: dict, fun_control: dict) -> Tuple[float, float]:
7171
num_workers=fun_control["num_workers"],
7272
test_size=fun_control["test_size"],
7373
test_seed=fun_control["test_seed"],
74+
scaler=fun_control["scaler"],
7475
)
7576
# TODO: Check if this is necessary:
7677
# dm.setup()

src/spotPython/light/trainmodel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
102102
num_workers=fun_control["num_workers"],
103103
test_size=fun_control["test_size"],
104104
test_seed=fun_control["test_seed"],
105+
scaler=fun_control["scaler"],
105106
)
106107
# TODO: Check if this is necessary:
107108
# dm.setup()

src/spotPython/spot/spot.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -787,9 +787,9 @@ def run(self, X_start=None) -> Spot:
787787
print(f"S.y: {S.y}")
788788
Seed set to 123
789789
Seed set to 123
790-
spotPython tuning: 0.0 [########--] 80.00%
791-
spotPython tuning: 0.0 [#########-] 86.67%
792-
spotPython tuning: 0.0 [#########-] 93.33%
790+
spotPython tuning: 0.0 [########--] 80.00%
791+
spotPython tuning: 0.0 [#########-] 86.67%
792+
spotPython tuning: 0.0 [#########-] 93.33%
793793
spotPython tuning: 0.0 [##########] 100.00% Done...
794794
795795
S.X: [[ 0.00000000e+00 0.00000000e+00]

src/spotPython/utils/init.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def fun_control_init(
5858
prep_model=None,
5959
prep_model_name=None,
6060
progress_file=None,
61+
scaler=None,
6162
scenario=None,
6263
seed=123,
6364
show_models=False,
@@ -186,6 +187,9 @@ def fun_control_init(
186187
The name of the preprocessing model. Default is None.
187188
progress_file (str):
188189
The name of the progress file. Default is None.
190+
scaler (object):
191+
The scaler object, e.g., the TorchStandard scaler from spot.utils.scaler.py.
192+
Default is None.
189193
scenario (str):
190194
The scenario to use. Default is None. Can be "river", "sklearn", or "lightning".
191195
seed (int):

0 commit comments

Comments
 (0)