Skip to content

Commit c1d49f1

Browse files
Create cifar10datamodule.py
1 parent a54941c commit c1d49f1

1 file changed

Lines changed: 111 additions & 0 deletions

File tree

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import lightning as L
2+
from torch.utils.data import DataLoader, random_split
3+
from spotPython.light.csvdataset import CSVDataset
4+
from typing import Optional
5+
6+
7+
class CSVDataModule(L.LightningDataModule):
    """
    A LightningDataModule for handling CSV data.

    The data is read with ``CSVDataset`` from ``<DATASET_PATH>/VBDP/train.csv``.
    For fitting, the full dataset is randomly split into a training part
    (60% of the samples, rounded down) and a validation part (the remainder).

    Args:
        batch_size (int): The size of the batch.
        DATASET_PATH (str): The path to the dataset. Defaults to "./data".
        num_workers (int): The number of workers for data loading. Defaults to 0.

    Attributes:
        batch_size (int): The batch size passed to every dataloader.
        DATASET_PATH (str): The directory that contains ``VBDP/train.csv``.
        num_workers (int): The number of workers passed to every dataloader.
        data_train (Dataset): The training dataset. Available after ``setup``.
        data_val (Dataset): The validation dataset. Available after ``setup``.
        data_test (Dataset): The test dataset. Available after ``setup``.
    """

    # Fraction of the full dataset that is assigned to the training split.
    _TRAIN_FRACTION = 0.6

    def __init__(self, batch_size: int, DATASET_PATH: str = "./data", num_workers: int = 0):
        super().__init__()
        self.batch_size = batch_size
        # Previously this argument was accepted but silently dropped.
        self.DATASET_PATH = DATASET_PATH
        self.num_workers = num_workers

    def _csv_file(self) -> str:
        """Returns the path of the CSV file located below ``DATASET_PATH``."""
        # Function-scope import so that this class does not depend on a
        # module-level import that the file does not have.
        import os

        return os.path.join(self.DATASET_PATH, "VBDP", "train.csv")

    def prepare_data(self) -> None:
        """Prepares the data for use. Nothing is downloaded; this is a no-op."""
        pass

    def setup(self, stage: Optional[str] = None) -> None:
        """
        Sets up the data for use.

        Creates ``data_train`` and ``data_val`` when ``stage`` is "fit" or None,
        and ``data_test`` when ``stage`` is "test" or None. For the "validate"
        stage the train/validation split is created only if no validation
        dataset exists yet, so that a split made during fitting is reused.

        Args:
            stage (Optional[str]): The current stage. Defaults to None,
                which sets up all datasets.

        Examples:
            >>> from spotPython.light import CSVDataModule
            >>> data_module = CSVDataModule(batch_size=64)
            >>> data_module.setup()
            >>> n_train = len(data_module.data_train)
            >>> n_val = len(data_module.data_val)
            >>> n_test = len(data_module.data_test)

            The sizes depend on the CSV file: ``n_train`` is 60% of the rows
            (rounded down), ``n_val`` the remaining rows, and ``n_test`` all rows.

        """
        csv_file = self._csv_file()

        fit_requested = stage == "fit" or stage is None
        # Trainer.validate() calls setup("validate"); without a split the
        # validation dataloader would fail on the missing attribute.
        validate_needs_split = stage == "validate" and getattr(self, "data_val", None) is None

        # Assign train/val datasets for use in dataloaders
        if fit_requested or validate_needs_split:
            data_full = CSVDataset(csv_file=csv_file, train=True)
            train_size = int(len(data_full) * self._TRAIN_FRACTION)
            self.data_train, self.data_val = random_split(data_full, [train_size, len(data_full) - train_size])

        # Assign test dataset for use in dataloader(s)
        # TODO: Adapt this to the VBDP Situation. The test set is currently the
        # complete training CSV, i.e., it overlaps with the training data.
        if stage == "test" or stage is None:
            self.data_test = CSVDataset(csv_file=csv_file, train=True)

    def train_dataloader(self) -> DataLoader:
        """
        Returns the training dataloader.

        Returns:
            DataLoader: The training dataloader, built from ``data_train``.

        Examples:
            >>> from spotPython.light import CSVDataModule
            >>> data_module = CSVDataModule(batch_size=64)
            >>> data_module.setup()
            >>> train_dataloader = data_module.train_dataloader()
            >>> n_batches = len(train_dataloader)

        """
        return DataLoader(self.data_train, batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self) -> DataLoader:
        """
        Returns the validation dataloader.

        Returns:
            DataLoader: The validation dataloader, built from ``data_val``.

        Examples:
            >>> from spotPython.light import CSVDataModule
            >>> data_module = CSVDataModule(batch_size=64)
            >>> data_module.setup()
            >>> val_dataloader = data_module.val_dataloader()
            >>> n_batches = len(val_dataloader)

        """
        return DataLoader(self.data_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self) -> DataLoader:
        """
        Returns the test dataloader.

        Returns:
            DataLoader: The test dataloader, built from ``data_test``.

        Examples:
            >>> from spotPython.light import CSVDataModule
            >>> data_module = CSVDataModule(batch_size=64)
            >>> data_module.setup()
            >>> test_dataloader = data_module.test_dataloader()
            >>> n_batches = len(test_dataloader)

        """
        return DataLoader(self.data_test, batch_size=self.batch_size, num_workers=self.num_workers)

0 commit comments

Comments
 (0)