Skip to content

Commit cc4910c

Browse files
0.18.7
train / test data separately
1 parent 1c3b712 commit cc4910c

9 files changed

Lines changed: 135 additions & 77 deletions

File tree

RELEASE_NOTES.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
2+
spotpython 0.18.7:
3+
4+
Separate train and test data sets can be passed to Lightning DataModules
5+
16
spotpython 0.18.6:
27

38
- split.py:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.18.6"
10+
version = "0.18.7"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/data/lightcrossvalidationdatamodule.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@ class LightCrossValidationDataModule(L.LightningDataModule):
1212
1313
Args:
1414
batch_size (int): The size of the batch. Defaults to 64.
15+
dataset (torch.utils.data.Dataset, optional):
16+
The dataset from the torch.utils.data Dataset class.
17+
It must implement three functions: __init__, __len__, and __getitem__.
18+
data_full_train (torch.utils.data.Dataset, optional):
19+
The full training dataset from which training and validation sets will be derived.
20+
data_test (torch.utils.data.Dataset, optional):
21+
The separate test dataset that will be used for testing.
1522
k (int): The fold number. Defaults to 1.
1623
split_seed (int): The random seed for splitting the data. Defaults to 42.
1724
num_splits (int): The number of splits for cross-validation. Defaults to 10.
@@ -38,7 +45,9 @@ class LightCrossValidationDataModule(L.LightningDataModule):
3845
def __init__(
3946
self,
4047
batch_size=64,
41-
dataset=None,
48+
dataset: Optional[object] = None,
49+
data_full_train: Optional[object] = None,
50+
data_test: Optional[object] = None,
4251
k: int = 1,
4352
split_seed: int = 42,
4453
num_splits: int = 10,
@@ -50,6 +59,8 @@ def __init__(
5059
super().__init__()
5160
self.batch_size = batch_size
5261
self.data_full = dataset
62+
self.data_full_train = data_full_train
63+
self.data_test = data_test
5364
self.data_dir = data_dir
5465
self.num_workers = num_workers
5566
self.k = k

src/spotpython/data/lightdatamodule.py

Lines changed: 100 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -12,45 +12,23 @@ class LightDataModule(L.LightningDataModule):
1212
Args:
1313
batch_size (int):
1414
The batch size. Required.
15-
dataset (torch.utils.data.Dataset):
15+
dataset (torch.utils.data.Dataset, optional):
1616
The dataset from the torch.utils.data Dataset class.
17-
It must implement three functions: __init__, __len__, and __getitem__.
18-
Required.
19-
test_size (float):
20-
The test size. if test_size is float, then train_size is 1 - test_size.
17+
It must implement three functions: __init__, __len__, and __getitem__.
18+
data_full_train (torch.utils.data.Dataset, optional):
19+
The full training dataset from which training and validation sets will be derived.
20+
data_test (torch.utils.data.Dataset, optional):
21+
The separate test dataset that will be used for testing.
22+
test_size (float, optional):
23+
The test size. If test_size is float, then train_size is 1 - test_size.
2124
If test_size is int, then train_size is len(data_full) - test_size.
22-
Train size will be split into train and validation sets.
23-
So if test size is 0.7, the 0.7 train size will be split into 0.7 * 0.7 = 0.49 train set
24-
amd 0.7 * 0.3 = 0.21 validation set.
2525
test_seed (int):
2626
The test seed. Defaults to 42.
2727
num_workers (int):
2828
The number of workers. Defaults to 0.
29-
scaler (object):
29+
scaler (object, optional):
3030
The spot scaler object (e.g. TorchStandardScaler). Defaults to None.
3131
32-
Attributes:
33-
batch_size (int): The batch size.
34-
data_full (Dataset): The full dataset.
35-
data_test (Dataset): The test dataset.
36-
data_train (Dataset): The training dataset.
37-
data_val (Dataset): The validation dataset.
38-
num_workers (int): The number of workers.
39-
test_seed (int): The test seed.
40-
test_size (float): The test size.
41-
42-
Methods:
43-
prepare_data(self):
44-
Usually used for downloading the data. Here: Does nothing, i.e., pass.
45-
setup(self, stage: Optional[str] = None):
46-
Performs the training, validation, and test split.
47-
train_dataloader():
48-
Returns a DataLoader instance for the training set.
49-
val_dataloader():
50-
Returns a DataLoader instance for the validation set.
51-
test_dataloader():
52-
Returns a DataLoader instance for the test set.
53-
5432
Examples:
5533
>>> from spotpython.data.lightdatamodule import LightDataModule
5634
from spotpython.data.csvdataset import CSVDataset
@@ -80,8 +58,10 @@ class LightDataModule(L.LightningDataModule):
8058
def __init__(
8159
self,
8260
batch_size: int,
83-
dataset: object,
84-
test_size: float,
61+
dataset: Optional[object] = None,
62+
data_full_train: Optional[object] = None,
63+
data_test: Optional[object] = None,
64+
test_size: Optional[float] = None,
8565
test_seed: int = 42,
8666
num_workers: int = 0,
8767
scaler: Optional[object] = None,
@@ -90,6 +70,8 @@ def __init__(
9070
super().__init__()
9171
self.batch_size = batch_size
9272
self.data_full = dataset
73+
self.data_full_train = data_full_train
74+
self.data_test = data_test
9375
self.test_size = test_size
9476
self.test_seed = test_seed
9577
self.num_workers = num_workers
@@ -166,49 +148,92 @@ def setup(self, stage: Optional[str] = None) -> None:
166148
Training set size: 3
167149
168150
"""
169-
full_size = len(self.data_full)
170-
test_size = self.test_size
171-
172-
# consider the case when test_size is a float
173-
if isinstance(self.test_size, float):
174-
full_train_size = 1.0 - self.test_size
175-
val_size = full_train_size * self.test_size
176-
train_size = full_train_size - val_size
151+
if self.data_full is not None:
152+
full_size = len(self.data_full)
153+
test_size = self.test_size
154+
155+
# consider the case when test_size is a float
156+
if isinstance(self.test_size, float):
157+
full_train_size = 1.0 - self.test_size
158+
val_size = full_train_size * self.test_size
159+
train_size = full_train_size - val_size
160+
else:
161+
# test_size is an int, training size calculation directly based on it
162+
full_train_size = full_size - self.test_size
163+
val_size = floor(full_train_size * self.test_size / full_size)
164+
train_size = full_size - val_size - test_size
165+
166+
# Assign train/val datasets for use in dataloaders
167+
if stage == "fit" or stage is None:
168+
if self.verbosity > 0:
169+
print(f"train_size: {train_size}, val_size: {val_size} used for train & val data.")
170+
generator_fit = torch.Generator().manual_seed(self.test_seed)
171+
self.data_train, self.data_val, _ = random_split(self.data_full, [train_size, val_size, test_size], generator=generator_fit)
172+
# Handle scaling and transformation if scaler is provided
173+
if self.scaler is not None:
174+
self.handle_scaling_and_transform()
175+
176+
# Assign test dataset for use in dataloader(s)
177+
if stage == "test" or stage is None:
178+
if self.verbosity > 0:
179+
print(f"test_size: {test_size} used for test dataset.")
180+
generator_test = torch.Generator().manual_seed(self.test_seed)
181+
self.data_test, _, _ = random_split(self.data_full, [test_size, train_size, val_size], generator=generator_test)
182+
if self.scaler is not None:
183+
# Transform the test data
184+
self.data_test = self.transform_dataset(self.data_test)
185+
186+
# Assign pred dataset for use in dataloader(s)
187+
if stage == "predict" or stage is None:
188+
if self.verbosity > 0:
189+
print(f"test_size: {test_size} used for predict dataset.")
190+
generator_predict = torch.Generator().manual_seed(self.test_seed)
191+
self.data_predict, _, _ = random_split(self.data_full, [test_size, train_size, val_size], generator=generator_predict)
192+
if self.scaler is not None:
193+
# Transform the predict data
194+
self.data_predict = self.transform_dataset(self.data_predict)
177195
else:
178-
# test_size is an int, training size calculation directly based on it
179-
full_train_size = full_size - self.test_size
180-
val_size = floor(full_train_size * self.test_size / full_size)
181-
train_size = full_size - val_size - test_size
182-
183-
# Assign train/val datasets for use in dataloaders
184-
if stage == "fit" or stage is None:
185-
if self.verbosity > 0:
186-
print(f"train_size: {train_size}, val_size: {val_size} used for train & val data.")
187-
generator_fit = torch.Generator().manual_seed(self.test_seed)
188-
self.data_train, self.data_val, _ = random_split(self.data_full, [train_size, val_size, test_size], generator=generator_fit)
189-
# Handle scaling and transformation if scaler is provided
190-
if self.scaler is not None:
191-
self.handle_scaling_and_transform()
192-
193-
# Assign test dataset for use in dataloader(s)
194-
if stage == "test" or stage is None:
195-
if self.verbosity > 0:
196-
print(f"test_size: {test_size} used for test dataset.")
197-
generator_test = torch.Generator().manual_seed(self.test_seed)
198-
self.data_test, _, _ = random_split(self.data_full, [test_size, train_size, val_size], generator=generator_test)
199-
if self.scaler is not None:
200-
# Transform the test data
201-
self.data_test = self.transform_dataset(self.data_test)
202-
203-
# Assign pred dataset for use in dataloader(s)
204-
if stage == "predict" or stage is None:
205-
if self.verbosity > 0:
206-
print(f"test_size: {test_size} used for predict dataset.")
207-
generator_predict = torch.Generator().manual_seed(self.test_seed)
208-
self.data_predict, _, _ = random_split(self.data_full, [test_size, train_size, val_size], generator=generator_predict)
209-
if self.scaler is not None:
210-
# Transform the predict data
211-
self.data_predict = self.transform_dataset(self.data_predict)
196+
# New functionality with separate full_train and test datasets. Use these datasets directly.
197+
full_train_size = len(self.data_full_train)
198+
test_size = self.test_size
199+
# consider the case when test_size is a float
200+
if isinstance(self.test_size, float):
201+
val_size = self.test_size
202+
train_size = 1 - self.test_size
203+
else:
204+
# test_size is an int, training size calculation directly based on it
205+
full_size = len(self.data_full_train) + len(self.data_test)
206+
full_train_size = len(self.data_full_train)
207+
val_size = floor(full_train_size * self.test_size / full_size)
208+
train_size = full_train_size - val_size
209+
210+
# Assign train/val datasets for use in dataloaders
211+
if stage == "fit" or stage is None:
212+
if self.verbosity > 0:
213+
print(f"train_size: {train_size}, val_size: {val_size} used for train & val data.")
214+
generator_fit = torch.Generator().manual_seed(self.test_seed)
215+
self.data_train, self.data_val = random_split(self.data_full_train, [train_size, val_size], generator=generator_fit)
216+
# Handle scaling and transformation if scaler is provided
217+
if self.scaler is not None:
218+
self.handle_scaling_and_transform()
219+
220+
# Assign test dataset for use in dataloader(s)
221+
if stage == "test" or stage is None:
222+
if self.verbosity > 0:
223+
print(f"test_size: {test_size} used for test dataset.")
224+
self.data_test = self.data_test
225+
if self.scaler is not None:
226+
# Transform the test data
227+
self.data_test = self.transform_dataset(self.data_test)
228+
229+
# Assign pred dataset for use in dataloader(s)
230+
if stage == "predict" or stage is None:
231+
if self.verbosity > 0:
232+
print(f"test_size: {test_size} used for predict dataset.")
233+
self.data_predict = self.data_test
234+
if self.scaler is not None:
235+
# Transform the predict data
236+
self.data_predict = self.transform_dataset(self.data_predict)
212237

213238
def train_dataloader(self) -> DataLoader:
214239
"""

src/spotpython/light/cvmodel.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ def cv_model(config: dict, fun_control: dict) -> float:
6161
num_splits=num_folds,
6262
split_seed=split_seed,
6363
dataset=fun_control["data_set"],
64+
data_full_train=fun_control["data_full_train"],
65+
data_test=fun_control["data_test"],
6466
num_workers=fun_control["num_workers"],
6567
batch_size=config["batch_size"],
6668
data_dir=fun_control["DATASET_PATH"],

src/spotpython/light/predictmodel.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ def predict_model(config: dict, fun_control: dict) -> Tuple[float, float]:
6666
config_id = generate_config_id(config, timestamp=False) + "_TEST"
6767
dm = LightDataModule(
6868
dataset=fun_control["data_set"],
69+
data_full_train=fun_control["data_full_train"],
70+
data_test=fun_control["data_test"],
6971
batch_size=config["batch_size"],
7072
num_workers=fun_control["num_workers"],
7173
test_size=fun_control["test_size"],

src/spotpython/light/testmodel.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ def test_model(config: dict, fun_control: dict) -> Tuple[float, float]:
6767
config_id = generate_config_id(config, timestamp=False) + "_TEST"
6868
dm = LightDataModule(
6969
dataset=fun_control["data_set"],
70+
data_full_train=fun_control["data_full_train"],
71+
data_test=fun_control["data_test"],
7072
batch_size=config["batch_size"],
7173
num_workers=fun_control["num_workers"],
7274
test_size=fun_control["test_size"],

src/spotpython/light/trainmodel.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
109109
model = build_model_instance(config, fun_control)
110110
dm = LightDataModule(
111111
dataset=fun_control["data_set"],
112+
data_full_train=fun_control["data_full_train"],
113+
data_test=fun_control["data_test"],
112114
batch_size=config["batch_size"],
113115
num_workers=fun_control["num_workers"],
114116
test_size=fun_control["test_size"],

src/spotpython/utils/init.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@ def fun_control_init(
2929
core_model=None,
3030
core_model_name=None,
3131
data=None,
32+
data_full_train=None,
3233
data_dir="./data",
3334
data_module=None,
3435
data_set=None,
3536
data_set_name=None,
37+
data_test=None,
3638
db_dict_name=None,
3739
design=None,
3840
device=None,
@@ -125,12 +127,17 @@ def fun_control_init(
125127
The data object. Default is None.
126128
data_dir (str):
127129
The directory to save the data. Default is "./data".
130+
data_full_train (torch.utils.data.Dataset, optional):
131+
The full training dataset from which training and validation sets will be derived.
132+
Default is None.
128133
data_module (object):
129134
The data module object. Default is None.
130135
data_set (object):
131136
The data set object. Default is None.
132137
data_set_name (str):
133138
The name of the data set. Default is None.
139+
data_test (torch.utils.data.Dataset, optional):
140+
The separate test dataset that will be used for testing. Default is None.
134141
db_dict_name (str):
135142
The name of the database dictionary. Default is None.
136143
device (str):
@@ -387,9 +394,11 @@ def fun_control_init(
387394
"counter": 0,
388395
"data": data,
389396
"data_dir": data_dir,
397+
"data_full_train": data_full_train,
390398
"data_module": data_module,
391399
"data_set": data_set,
392400
"data_set_name": data_set_name,
401+
"data_test": data_test,
393402
"db_dict_name": db_dict_name,
394403
"design": design,
395404
"device": device,

0 commit comments

Comments
 (0)