Skip to content

Commit f6ba61d

Browse files
0.18.10
tests added
1 parent 070f678 commit f6ba61d

3 files changed

Lines changed: 63 additions & 7 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.18.9"
10+
version = "0.18.10"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/data/lightdatamodule.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -143,30 +143,33 @@ def _setup_full_data_provided(self, stage) -> None:
143143

144144
# Assign train/val datasets for use in dataloaders
145145
if stage == "fit" or stage is None:
146-
if self.verbosity > 0:
147-
print(f"train_size: {train_size}, val_size: {val_size} used for train & val data.")
148146
generator_fit = torch.Generator().manual_seed(self.test_seed)
149147
self.data_train, self.data_val, _ = random_split(self.data_full, [train_size, val_size, test_size], generator=generator_fit)
148+
if self.verbosity > 0:
149+
print(f"train_size: {train_size}, val_size: {val_size}, test_size: {test_size} for splitting train & val data.")
150+
print(f"train samples: {len(self.data_train)}, val samples: {len(self.data_val)} generated for train & val data.")
150151
# Handle scaling and transformation if scaler is provided
151152
if self.scaler is not None:
152153
self.handle_scaling_and_transform()
153154

154155
# Assign test dataset for use in dataloader(s)
155156
if stage == "test" or stage is None:
156-
if self.verbosity > 0:
157-
print(f"test_size: {test_size} used for test dataset.")
158157
generator_test = torch.Generator().manual_seed(self.test_seed)
159158
self.data_test, _, _ = random_split(self.data_full, [test_size, train_size, val_size], generator=generator_test)
159+
if self.verbosity > 0:
160+
print(f"train_size: {train_size}, val_size: {val_size}, test_size: {test_size} for splitting test data.")
161+
print(f"test samples: {len(self.data_test)} generated for test data.")
160162
if self.scaler is not None:
161163
# Transform the test data
162164
self.data_test = self.transform_dataset(self.data_test)
163165

164166
# Assign pred dataset for use in dataloader(s)
165167
if stage == "predict" or stage is None:
166-
if self.verbosity > 0:
167-
print(f"test_size: {test_size} used for predict dataset.")
168168
generator_predict = torch.Generator().manual_seed(self.test_seed)
169169
self.data_predict, _, _ = random_split(self.data_full, [test_size, train_size, val_size], generator=generator_predict)
170+
if self.verbosity > 0:
171+
print(f"train_size: {train_size}, val_size: {val_size}, test_size (= predict_size): {test_size} for splitting predict data.")
172+
print(f"predict samples: {len(self.data_predict)} generated for predict data.")
170173
if self.scaler is not None:
171174
# Transform the predict data
172175
self.data_predict = self.transform_dataset(self.data_predict)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import pytest
2+
import torch
3+
from torch.utils.data import TensorDataset
4+
from lightning import seed_everything
5+
from spotpython.data.lightdatamodule import LightDataModule
6+
7+
# LightDataModule is imported above from spotpython.data.lightdatamodule.
8+
9+
# Define a mock scaler for testing purposes.
10+
class MockScaler:
    """Identity stand-in for a scaler, exposing ``fit`` and ``transform``.

    Fitting learns nothing and transforming hands the input back untouched.
    """

    def fit(self, data):
        """Accept ``data`` without learning anything; returns ``None``."""
        return None

    def transform(self, data):
        """Return ``data`` itself, unmodified."""
        return data
16+
17+
# Define a simple dataset for testing.
18+
def create_mock_dataset(size=12):
    """Build a toy ``TensorDataset`` holding ``size`` samples.

    Sample ``i`` consists of a one-element float feature ``[i]`` and the
    integer target ``i``.
    """
    indices = torch.arange(size)
    features = indices.float().view(-1, 1)
    labels = indices.long()
    return TensorDataset(features, labels)
22+
23+
# Test initialization and data splits
24+
@pytest.mark.parametrize(
    ("test_size", "expected_train_size", "expected_val_size", "expected_test_size"),
    [
        # 12 items -> 3 train, 3 val, 6 test
        (0.5, 3, 3, 6),
        # 12 items -> 5 train, 3 val, 5 test; train/val and test are taken
        # from separate random_split calls, so the counts need not sum to 12.
        (0.4, 5, 3, 5),
    ],
)
def test_data_splitting(test_size, expected_train_size, expected_val_size, expected_test_size):
    """Check the sizes of the train, val and test subsets after ``setup()``."""
    module = LightDataModule(dataset=create_mock_dataset(), batch_size=2, test_size=test_size, verbosity=1)
    module.setup()

    observed = (len(module.data_train), len(module.data_val), len(module.data_test))
    assert observed == (expected_train_size, expected_val_size, expected_test_size)
36+
37+
38+
# Test DataLoader
39+
def test_dataloader():
    """Each dataloader must wrap the subset of the same name on the module."""
    module = LightDataModule(dataset=create_mock_dataset(), batch_size=2, test_size=0.5, verbosity=1)
    module.setup()

    loader_and_subset = (
        (module.train_dataloader(), module.data_train),
        (module.val_dataloader(), module.data_val),
        (module.test_dataloader(), module.data_test),
    )
    for loader, subset in loader_and_subset:
        assert len(loader.dataset) == len(subset)
51+
52+
if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file as a script
    # reports test failures to the shell/CI instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__]))

0 commit comments

Comments
 (0)