Skip to content

Commit 21f661a

Browse files
0.15.26
Cleanup lightdatamodule
1 parent 4375a36 commit 21f661a

5 files changed

Lines changed: 159 additions & 60 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.15.25"
10+
version = "0.15.26"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/data/lightdatamodule.py

Lines changed: 23 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch
33
from torch.utils.data import DataLoader, random_split, TensorDataset
44
from typing import Optional
5+
from spotpython.utils.split import calculate_data_split
56

67

78
class LightDataModule(L.LightningDataModule):
@@ -95,6 +96,13 @@ def __init__(
9596
self.scaler = scaler
9697
self.verbosity = verbosity
9798

99+
def transform_dataset(self, dataset):
100+
"""Applies the scaler transformation to the dataset."""
101+
transformed_data = [(self.scaler.transform(data), target) for data, target in dataset]
102+
data_tensors = [data.clone().detach() for data, target in transformed_data]
103+
target_tensors = [target.clone().detach() for data, target in transformed_data]
104+
return TensorDataset(torch.stack(data_tensors).squeeze(1), torch.stack(target_tensors))
105+
98106
def prepare_data(self) -> None:
99107
"""Prepares the data for use."""
100108
# download
@@ -124,25 +132,12 @@ def setup(self, stage: Optional[str] = None) -> None:
124132
Training set size: 3
125133
126134
"""
127-
# if test_size is float, then train_size is 1 - test_size
128-
test_size = self.test_size
129-
if isinstance(self.test_size, float):
130-
full_train_size = round(1.0 - test_size, 2)
131-
val_size = round(full_train_size * test_size, 2)
132-
train_size = round(full_train_size - val_size, 2)
133-
else:
134-
# if test_size is int, then train_size is len(data_full) - test_size
135-
full_train_size = len(self.data_full) - test_size
136-
val_size = int(full_train_size * test_size / len(self.data_full))
137-
train_size = full_train_size - val_size
138-
139-
if self.verbosity > 0:
140-
print(f"LightDataModule.setup(): stage: {stage}")
141-
if self.verbosity > 1:
142-
print(f"LightDataModule setup(): full_train_size: {full_train_size}")
143-
print(f"LightDataModule setup(): val_size: {val_size}")
144-
print(f"LightDataModule setup(): train_size: {train_size}")
145-
print(f"LightDataModule setup(): test_size: {test_size}")
135+
full_train_size, val_size, train_size, test_size = calculate_data_split(
136+
test_size=self.test_size,
137+
full_size=len(self.data_full),
138+
verbosity=self.verbosity,
139+
stage=stage,
140+
)
146141

147142
# Assign train/val datasets for use in dataloaders
148143
if stage == "fit" or stage is None:
@@ -153,64 +148,37 @@ def setup(self, stage: Optional[str] = None) -> None:
153148
self.data_full, [train_size, val_size, test_size], generator=generator_fit
154149
)
155150
if self.scaler is not None:
156-
# Fit the scaler on training data and transform both train and val data
151+
# Fit the scaler on training data
157152
scaler_train_data = torch.stack([self.data_train[i][0] for i in range(len(self.data_train))]).squeeze(1)
158-
# train_val_data = self.data_train[:,0]
159153
if self.verbosity > 0:
160154
print(scaler_train_data.shape)
161155
self.scaler.fit(scaler_train_data)
162-
self.data_train = [(self.scaler.transform(data), target) for data, target in self.data_train]
163-
data_tensors_train = [data.clone().detach() for data, target in self.data_train]
164-
target_tensors_train = [target.clone().detach() for data, target in self.data_train]
165-
self.data_train = TensorDataset(
166-
torch.stack(data_tensors_train).squeeze(1), torch.stack(target_tensors_train)
167-
)
168-
# print(self.data_train)
169-
self.data_val = [(self.scaler.transform(data), target) for data, target in self.data_val]
170-
data_tensors_val = [data.clone().detach() for data, target in self.data_val]
171-
target_tensors_val = [target.clone().detach() for data, target in self.data_val]
172-
self.data_val = TensorDataset(torch.stack(data_tensors_val).squeeze(1), torch.stack(target_tensors_val))
156+
# Transform the training data
157+
self.data_train = self.transform_dataset(self.data_train)
158+
# Transform the validation data
159+
self.data_val = self.transform_dataset(self.data_val)
173160

174161
# Assign test dataset for use in dataloader(s)
175162
if stage == "test" or stage is None:
176163
if self.verbosity > 0:
177164
print(f"test_size: {test_size} used for test dataset.")
178-
# get test data set as test_abs percent of the full dataset
179165
generator_test = torch.Generator().manual_seed(self.test_seed)
180166
self.data_test, _ = random_split(self.data_full, [test_size, full_train_size], generator=generator_test)
181167
if self.scaler is not None:
182-
self.data_test = [(self.scaler.transform(data), target) for data, target in self.data_test]
183-
data_tensors_test = [data.clone().detach() for data, target in self.data_test]
184-
target_tensors_test = [target.clone().detach() for data, target in self.data_test]
185-
self.data_test = TensorDataset(
186-
torch.stack(data_tensors_test).squeeze(1), torch.stack(target_tensors_test)
187-
)
188-
189-
# if stage == "predict" or stage is None:
190-
# print(f"test_size, full_train_size: {test_size}, {full_train_size}")
191-
# generator_predict = torch.Generator().manual_seed(self.test_seed)
192-
# full_data_predict, _ = random_split(
193-
# self.data_full, [test_size, full_train_size], generator=generator_predict
194-
# )
195-
# # Only keep the features for prediction
196-
# self.data_predict = [x for x, _ in full_data_predict]
168+
# Transform the test data
169+
self.data_test = self.transform_dataset(self.data_test)
197170

198171
# Assign pred dataset for use in dataloader(s)
199172
if stage == "predict" or stage is None:
200173
if self.verbosity > 0:
201174
print(f"test_size: {test_size} used for predict dataset.")
202-
# get test data set as test_abs percent of the full dataset
203175
generator_predict = torch.Generator().manual_seed(self.test_seed)
204176
self.data_predict, _ = random_split(
205177
self.data_full, [test_size, full_train_size], generator=generator_predict
206178
)
207179
if self.scaler is not None:
208-
self.data_predict = [(self.scaler.transform(data), target) for data, target in self.data_predict]
209-
data_tensors_predict = [data.clone().detach() for data, target in self.data_predict]
210-
target_tensors_predict = [target.clone().detach() for data, target in self.data_predict]
211-
self.data_predict = TensorDataset(
212-
torch.stack(data_tensors_predict).squeeze(1), torch.stack(target_tensors_predict)
213-
)
180+
# Transform the predict data
181+
self.data_predict = self.transform_dataset(self.data_predict)
214182

215183
def train_dataloader(self) -> DataLoader:
216184
"""
@@ -235,7 +203,6 @@ def train_dataloader(self) -> DataLoader:
235203
print(f"LightDataModule.train_dataloader(). data_train size: {len(self.data_train)}")
236204
# print(f"LightDataModule: train_dataloader(). batch_size: {self.batch_size}")
237205
# print(f"LightDataModule: train_dataloader(). num_workers: {self.num_workers}")
238-
# apply fit_transform to the training data
239206
return DataLoader(self.data_train, batch_size=self.batch_size, num_workers=self.num_workers)
240207

241208
def val_dataloader(self) -> DataLoader:
@@ -260,7 +227,6 @@ def val_dataloader(self) -> DataLoader:
260227
print(f"LightDataModule.val_dataloader(). Val. set size: {len(self.data_val)}")
261228
# print(f"LightDataModule: val_dataloader(). batch_size: {self.batch_size}")
262229
# print(f"LightDataModule: val_dataloader(). num_workers: {self.num_workers}")
263-
# apply fit_transform to the val data
264230
return DataLoader(self.data_val, batch_size=self.batch_size, num_workers=self.num_workers)
265231

266232
def test_dataloader(self) -> DataLoader:
@@ -312,6 +278,4 @@ def predict_dataloader(self) -> DataLoader:
312278
print(f"LightDataModule.predict_dataloader(). Predict set size: {len(self.data_predict)}")
313279
# print(f"LightDataModule: predict_dataloader(). batch_size: {self.batch_size}")
314280
# print(f"LightDataModule: predict_dataloader(). num_workers: {self.num_workers}")
315-
# apply fit_transform to the val data
316-
317281
return DataLoader(self.data_predict, batch_size=len(self.data_predict), num_workers=self.num_workers)

src/spotpython/utils/split.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
def calculate_data_split(test_size, full_size, verbosity=0, stage=None) -> tuple:
2+
"""
3+
Calculates the split sizes for training, validation, and test datasets.
4+
5+
Args:
6+
test_size (float or int):
7+
The size of the test set.
8+
Can be a float for proportion or an int for absolute number of test samples.
9+
full_size (int):
10+
The size of the full dataset.
11+
verbosity (int, optional):
12+
The level of verbosity for debug output. Defaults to 0.
13+
stage (str, optional):
14+
The stage of setup, for debug output if needed.
15+
16+
Returns:
17+
tuple: A tuple containing the sizes (full_train_size, val_size, train_size, test_size).
18+
"""
19+
if isinstance(test_size, float):
20+
full_train_size = round(1.0 - test_size, 2)
21+
val_size = round(full_train_size * test_size, 2)
22+
train_size = round(full_train_size - val_size, 2)
23+
else:
24+
# test_size is considered an int, training size calculation directly based on it
25+
full_train_size = full_size - test_size
26+
val_size = int(full_train_size * test_size / full_size)
27+
train_size = full_train_size - val_size
28+
29+
if verbosity > 0:
30+
print(f"stage: {stage}")
31+
if verbosity > 1:
32+
print(f"full_sizefull_train_size: {full_train_size}")
33+
print(f"full_sizeval_size: {val_size}")
34+
print(f"full_sizetrain_size: {train_size}")
35+
print(f"full_sizetest_size: {test_size}")
36+
37+
return full_train_size, val_size, train_size, test_size

test/test_calculate_split.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from spotpython.utils.split import calculate_data_split
2+
import pytest
3+
4+
def test_calculate_data_split_float():
5+
full_size = 100
6+
test_size = 0.2
7+
expected_full_train_size = 0.8
8+
expected_val_size = 0.16
9+
expected_train_size = 0.64
10+
11+
result = calculate_data_split(test_size, full_size)
12+
13+
assert result == (expected_full_train_size, expected_val_size, expected_train_size, test_size), \
14+
f"Result was {result}, expected {(expected_full_train_size, expected_val_size, expected_train_size, test_size)}"
15+
16+
def test_calculate_data_split_int():
17+
full_size = 100
18+
test_size = 20
19+
expected_full_train_size = 80
20+
expected_val_size = 16 # Calculated as 80 * 20 / 100
21+
expected_train_size = 64 # 80 - 16
22+
23+
result = calculate_data_split(test_size, full_size)
24+
25+
assert result == (expected_full_train_size, expected_val_size, expected_train_size, test_size), \
26+
f"Result was {result}, expected {(expected_full_train_size, expected_val_size, expected_train_size, test_size)}"
27+
28+
def test_calculate_data_split_verbosity():
29+
full_size = 100
30+
test_size = 0.2
31+
32+
# Ideally, we'd capture the output here as well
33+
# For now, we just confirm it runs without error
34+
result = calculate_data_split(test_size, full_size, verbosity=2, stage='test')
35+
36+
expected_full_train_size = 0.8
37+
expected_val_size = 0.16
38+
expected_train_size = 0.64
39+
40+
assert result == (expected_full_train_size, expected_val_size, expected_train_size, test_size)
41+
42+
if __name__ == "__main__":
43+
pytest.main()

test/test_transform_dataset.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import pytest
2+
import torch
3+
from torch.utils.data import TensorDataset
4+
from unittest.mock import MagicMock
5+
6+
# Assuming the class containing transform_dataset is named MyDataModule
7+
class MyDataModule:
8+
def __init__(self, scaler):
9+
self.scaler = scaler
10+
11+
def transform_dataset(self, dataset):
12+
transformed_data = [(self.scaler.transform(data), target) for data, target in dataset]
13+
data_tensors = [data.clone().detach() for data, target in transformed_data]
14+
target_tensors = [target.clone().detach() for data, target in transformed_data]
15+
return TensorDataset(torch.stack(data_tensors).squeeze(1), torch.stack(target_tensors))
16+
17+
# Test function for transform_dataset
18+
@pytest.fixture
19+
def setup_data():
20+
# Mock dataset
21+
input_data = torch.randn(3, 4) # Mock input data
22+
target_data = torch.tensor([0, 1, 2]) # Mock target data
23+
24+
dataset = [(input_data[i], target_data[i]) for i in range(len(target_data))]
25+
26+
# Mock scaler with a simple transform logic
27+
mock_scaler = MagicMock()
28+
mock_scaler.transform = lambda x: 2 * x # Example transformation: multiply by 2
29+
30+
return mock_scaler, dataset
31+
32+
def test_transform_dataset(setup_data):
33+
mock_scaler, dataset = setup_data
34+
data_module = MyDataModule(mock_scaler)
35+
36+
transformed_dataset = data_module.transform_dataset(dataset)
37+
38+
# Check that transform_dataset returns a TensorDataset
39+
assert isinstance(transformed_dataset, TensorDataset)
40+
41+
# Extract transformed data and targets
42+
transformed_data, transformed_targets = transformed_dataset.tensors
43+
44+
# Verify the shape
45+
assert transformed_data.shape == torch.Size([3, 4])
46+
assert transformed_targets.shape == torch.Size([3])
47+
48+
# Verify that the data was transformed correctly (i.e., multiplied by 2)
49+
expected_data = torch.stack([mock_scaler.transform(d[0]) for d in dataset]).squeeze(1)
50+
for td, ed in zip(transformed_data, expected_data):
51+
assert torch.allclose(td, ed)
52+
53+
# Verify that the targets were unchanged
54+
expected_targets = torch.tensor([d[1] for d in dataset])
55+
assert torch.equal(transformed_targets, expected_targets)

0 commit comments

Comments
 (0)