Skip to content

Commit 046b1e4

Browse files
0.14.26
1 parent e1bc738 commit 046b1e4

6 files changed

Lines changed: 35 additions & 2 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.14.25"
10+
version = "0.14.26"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/data/lightdatamodule.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,23 @@ class LightDataModule(L.LightningDataModule):
7272
7373
"""
7474

75-
def __init__(self, batch_size: int, dataset: object, test_size: float, test_seed: int = 42, num_workers: int = 0):
75+
def __init__(
76+
self,
77+
batch_size: int,
78+
dataset: object,
79+
test_size: float,
80+
scaler: None = None,
81+
test_seed: int = 42,
82+
num_workers: int = 0,
83+
):
7684
super().__init__()
7785
self.batch_size = batch_size
7886
self.data_full = dataset
7987
self.test_size = test_size
88+
if scaler is not None:
89+
self.scaler = scaler()
90+
else:
91+
self.scaler = None
8092
self.test_seed = test_seed
8193
self.num_workers = num_workers
8294

@@ -182,6 +194,9 @@ def train_dataloader(self) -> DataLoader:
182194
print(f"LightDataModule.train_dataloader(). data_train size: {len(self.data_train)}")
183195
# print(f"LightDataModule: train_dataloader(). batch_size: {self.batch_size}")
184196
# print(f"LightDataModule: train_dataloader(). num_workers: {self.num_workers}")
197+
# apply fit_transform to the training data
198+
if self.scaler is not None:
199+
self.data_train = self.scaler.fit_transform(self.data_train)
185200
return DataLoader(self.data_train, batch_size=self.batch_size, num_workers=self.num_workers)
186201

187202
def val_dataloader(self) -> DataLoader:
@@ -205,6 +220,9 @@ def val_dataloader(self) -> DataLoader:
205220
print(f"LightDataModule.val_dataloader(). Val. set size: {len(self.data_val)}")
206221
# print(f"LightDataModule: val_dataloader(). batch_size: {self.batch_size}")
207222
# print(f"LightDataModule: val_dataloader(). num_workers: {self.num_workers}")
223+
# apply fit_transform to the val data
224+
if self.scaler is not None:
225+
self.data_val = self.scaler.transform(self.data_val)
208226
return DataLoader(self.data_val, batch_size=self.batch_size, num_workers=self.num_workers)
209227

210228
def test_dataloader(self) -> DataLoader:
@@ -229,6 +247,9 @@ def test_dataloader(self) -> DataLoader:
229247
print(f"LightDataModule.test_dataloader(). Test set size: {len(self.data_test)}")
230248
# print(f"LightDataModule: test_dataloader(). batch_size: {self.batch_size}")
231249
# print(f"LightDataModule: test_dataloader(). num_workers: {self.num_workers}")
250+
# apply fit_transform to the val data
251+
if self.scaler is not None:
252+
self.data_test = self.scaler.transform(self.data_test)
232253
return DataLoader(self.data_test, batch_size=self.batch_size, num_workers=self.num_workers)
233254

234255
def predict_dataloader(self) -> DataLoader:
@@ -253,4 +274,8 @@ def predict_dataloader(self) -> DataLoader:
253274
print(f"LightDataModule.predict_dataloader(). Predict set size: {len(self.data_predict)}")
254275
# print(f"LightDataModule: predict_dataloader(). batch_size: {self.batch_size}")
255276
# print(f"LightDataModule: predict_dataloader(). num_workers: {self.num_workers}")
277+
# apply fit_transform to the val data
278+
if self.scaler is not None:
279+
self.data_test = self.scaler.transform(self.data_test)
280+
return DataLoader(self.data_test, batch_size=self.batch_size, num_workers=self.num_workers)
256281
return DataLoader(self.data_predict, batch_size=len(self.data_predict), num_workers=self.num_workers)

src/spotPython/light/predictmodel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def predict_model(config: dict, fun_control: dict) -> Tuple[float, float]:
7070
num_workers=fun_control["num_workers"],
7171
test_size=fun_control["test_size"],
7272
test_seed=fun_control["test_seed"],
73+
scaler=fun_control["scaler"],
7374
)
7475
# TODO: Check if this is necessary:
7576
# dm.setup(stage="train")

src/spotPython/light/testmodel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def test_model(config: dict, fun_control: dict) -> Tuple[float, float]:
7171
num_workers=fun_control["num_workers"],
7272
test_size=fun_control["test_size"],
7373
test_seed=fun_control["test_seed"],
74+
scaler=fun_control["scaler"],
7475
)
7576
# TODO: Check if this is necessary:
7677
# dm.setup()

src/spotPython/light/trainmodel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
102102
num_workers=fun_control["num_workers"],
103103
test_size=fun_control["test_size"],
104104
test_seed=fun_control["test_seed"],
105+
scaler=fun_control["scaler"],
105106
)
106107
# TODO: Check if this is necessary:
107108
# dm.setup()

src/spotPython/utils/init.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def fun_control_init(
5858
prep_model=None,
5959
prep_model_name=None,
6060
progress_file=None,
61+
scaler=None,
6162
scenario=None,
6263
seed=123,
6364
show_models=False,
@@ -186,6 +187,9 @@ def fun_control_init(
186187
The name of the preprocessing model. Default is None.
187188
progress_file (str):
188189
The name of the progress file. Default is None.
190+
scaler (object):
191+
The scaler object, e.g., StandardScaler from sklearn via "from sklearn.preprocessing import StandardScaler".
192+
Default is None.
189193
scenario (str):
190194
The scenario to use. Default is None. Can be "river", "sklearn", or "lightning".
191195
seed (int):
@@ -403,6 +407,7 @@ def fun_control_init(
403407
"prep_model_name": prep_model_name,
404408
"progress_file": progress_file,
405409
"save_model": False,
410+
"scaler": scaler,
406411
"scenario": scenario,
407412
"seed": seed,
408413
"show_batch_interval": 1_000_000,

0 commit comments

Comments
 (0)