Skip to content

Commit ad71e45

Browse files
test fixed
1 parent 83e5a7a commit ad71e45

4 files changed

Lines changed: 190 additions & 14 deletions

File tree

notebooks/00_spotPython_tests.ipynb

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -491,26 +491,26 @@
491491
},
492492
{
493493
"cell_type": "code",
494-
"execution_count": 2,
494+
"execution_count": 7,
495495
"metadata": {},
496496
"outputs": [],
497497
"source": [
498-
"data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5)"
498+
"data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=7)"
499499
]
500500
},
501501
{
502502
"cell_type": "code",
503-
"execution_count": 3,
503+
"execution_count": 8,
504504
"metadata": {},
505505
"outputs": [
506506
{
507507
"name": "stdout",
508508
"output_type": "stream",
509509
"text": [
510-
"full_train_size: 0.5\n",
511-
"val_size: 0.25\n",
512-
"train_size: 0.25\n",
513-
"test_size: 0.5\n"
510+
"full_train_size: 4\n",
511+
"val_size: 2\n",
512+
"train_size: 2\n",
513+
"test_size: 7\n"
514514
]
515515
}
516516
],
@@ -520,14 +520,14 @@
520520
},
521521
{
522522
"cell_type": "code",
523-
"execution_count": 4,
523+
"execution_count": 9,
524524
"metadata": {},
525525
"outputs": [
526526
{
527527
"name": "stdout",
528528
"output_type": "stream",
529529
"text": [
530-
"Training set size: 3\n"
530+
"Training set size: 2\n"
531531
]
532532
}
533533
],
@@ -537,14 +537,14 @@
537537
},
538538
{
539539
"cell_type": "code",
540-
"execution_count": 5,
540+
"execution_count": 10,
541541
"metadata": {},
542542
"outputs": [
543543
{
544544
"name": "stdout",
545545
"output_type": "stream",
546546
"text": [
547-
"Validation set size: 3\n"
547+
"Validation set size: 2\n"
548548
]
549549
}
550550
],
@@ -554,14 +554,14 @@
554554
},
555555
{
556556
"cell_type": "code",
557-
"execution_count": 6,
557+
"execution_count": 11,
558558
"metadata": {},
559559
"outputs": [
560560
{
561561
"name": "stdout",
562562
"output_type": "stream",
563563
"text": [
564-
"Test set size: 6\n"
564+
"Test set size: 7\n"
565565
]
566566
}
567567
],

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.6.43"
10+
version = "0.6.44"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
import lightning as L
2+
import torch
3+
from torch.utils.data import DataLoader, random_split
4+
from typing import Optional
5+
6+
7+
class LightDataModule(L.LightningDataModule):
8+
"""
9+
A LightningDataModule for handling data.
10+
11+
Args:
12+
batch_size (int): The batch size.
13+
dataset (Dataset): The dataset.
14+
test_size (float): The test size. Defaults to 0.6.
15+
test_seed (int): The test seed. Defaults to 42.
16+
num_workers (int): The number of workers. Defaults to 0.
17+
18+
Attributes:
19+
batch_size (int): The batch size.
20+
data_full (Dataset): The full dataset.
21+
data_test (Dataset): The test dataset.
22+
data_train (Dataset): The training dataset.
23+
data_val (Dataset): The validation dataset.
24+
num_workers (int): The number of workers.
25+
test_seed (int): The test seed.
26+
test_size (float): The test size.
27+
28+
Examples:
29+
>>> from spotPython.data.lightdatamodule import LightDataModule
30+
from spotPython.data.csvdataset import CSVDataset
31+
from spotPython.data.pkldataset import PKLDataset
32+
import torch
33+
dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)
34+
data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5)
35+
data_module.setup()
36+
print(f"Training set size: {len(data_module.data_train)}")
37+
Training set size: 3
38+
39+
"""
40+
41+
def __init__(
42+
self, batch_size: int, dataset=None, test_size: float = 0.6, test_seed: int = 42, num_workers: int = 0
43+
):
44+
super().__init__()
45+
self.batch_size = batch_size
46+
self.data_full = dataset
47+
self.test_size = test_size
48+
self.test_seed = test_seed
49+
self.num_workers = num_workers
50+
51+
def prepare_data(self) -> None:
52+
"""Prepares the data for use."""
53+
# download
54+
pass
55+
56+
def setup(self, stage: Optional[str] = None) -> None:
57+
"""
58+
Sets up the data for use.
59+
60+
Args:
61+
stage (Optional[str]): The current stage. Defaults to None.
62+
63+
Examples:
64+
>>> from spotPython.data.lightdatamodule import LightDataModule
65+
from spotPython.data.csvdataset import CSVDataset
66+
from spotPython.data.pkldataset import PKLDataset
67+
import torch
68+
dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)
69+
data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5)
70+
data_module.setup()
71+
print(f"Training set size: {len(data_module.data_train)}")
72+
Training set size: 3
73+
74+
"""
75+
# if test_size is float, then train_size is 1 - test_size
76+
test_size = self.test_size
77+
if isinstance(self.test_size, float):
78+
full_train_size = round(1.0 - test_size, 2)
79+
val_size = round(full_train_size * test_size, 2)
80+
train_size = round(full_train_size - val_size, 2)
81+
else:
82+
# if test_size is int, then train_size is len(data_full) - test_size
83+
full_train_size = len(self.data_full) - test_size
84+
val_size = int(full_train_size * test_size / len(self.data_full))
85+
train_size = full_train_size - val_size
86+
87+
print(f"full_train_size: {full_train_size}")
88+
print(f"val_size: {val_size}")
89+
print(f"train_size: {train_size}")
90+
print(f"test_size: {test_size}")
91+
92+
# Assign train/val datasets for use in dataloaders
93+
if stage == "fit" or stage is None:
94+
self.data_train, self.data_val, _ = random_split(self.data_full, [train_size, val_size, test_size])
95+
96+
# Assign test dataset for use in dataloader(s)
97+
if stage == "test" or stage is None:
98+
# get test data aset as test_abs percent of the full dataset
99+
generator_test = torch.Generator().manual_seed(self.test_seed)
100+
self.data_test, _ = random_split(self.data_full, [test_size, full_train_size], generator=generator_test)
101+
102+
def train_dataloader(self) -> DataLoader:
103+
"""
104+
Returns the training dataloader.
105+
106+
Returns:
107+
DataLoader: The training dataloader.
108+
109+
Examples:
110+
>>> from spotPython.data.lightdatamodule import LightDataModule
111+
from spotPython.data.csvdataset import CSVDataset
112+
from spotPython.data.pkldataset import PKLDataset
113+
import torch
114+
dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)
115+
data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5)
116+
data_module.setup()
117+
print(f"Training set size: {len(data_module.data_train)}")
118+
Training set size: 3
119+
120+
"""
121+
return DataLoader(self.data_train, batch_size=self.batch_size, num_workers=self.num_workers)
122+
123+
def val_dataloader(self) -> DataLoader:
124+
"""
125+
Returns the validation dataloader.
126+
127+
Returns:
128+
DataLoader: The validation dataloader.
129+
130+
Examples:
131+
>>> from spotPython.data.lightdatamodule import LightDataModule
132+
from spotPython.data.csvdataset import CSVDataset
133+
from spotPython.data.pkldataset import PKLDataset
134+
import torch
135+
dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)
136+
data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5)
137+
data_module.setup()
138+
print(f"Training set size: {len(data_module.data_val)}")
139+
Training set size: 3
140+
141+
"""
142+
return DataLoader(self.data_val, batch_size=self.batch_size, num_workers=self.num_workers)
143+
144+
def test_dataloader(self) -> DataLoader:
145+
"""
146+
Returns the test dataloader.
147+
148+
Returns:
149+
DataLoader: The test dataloader.
150+
151+
Examples:
152+
>>> from spotPython.data.lightdatamodule import LightDataModule
153+
from spotPython.data.csvdataset import CSVDataset
154+
from spotPython.data.pkldataset import PKLDataset
155+
import torch
156+
dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)
157+
data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5)
158+
data_module.setup()
159+
print(f"Test set size: {len(data_module.data_test)}")
160+
Test set size: 6
161+
162+
"""
163+
return DataLoader(self.data_test, batch_size=self.batch_size, num_workers=self.num_workers)

test/test_lightdatamodule.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,19 @@ def test_light_data_module():
1717
# Test the length of val and train: should be equal, because test_size=0.5
1818
assert len(data_module.data_train) == len(data_module.data_val)
1919

20+
def test_light_data_module_test_size():
21+
# Create an instance of CSVDataset for testing
22+
dataset = CSVDataset(target_column='prognosis', feature_type=torch.long)
23+
24+
# Test the length of the dataset
25+
assert len(dataset) > 0
26+
27+
# Now testing an absolute test_size
28+
data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=7)
29+
data_module.setup()
30+
31+
# Test the length of val and train: should be equal, because test_size=0.5
32+
assert len(data_module.data_test) == 7
2033

2134
if __name__ == "__main__":
2235
pytest.main(["-v", __file__])

0 commit comments

Comments
 (0)