@@ -9,11 +9,18 @@ class LightDataModule(L.LightningDataModule):
99 A LightningDataModule for handling data.
1010
1111 Args:
12- batch_size (int): The batch size.
13- dataset (Dataset): The dataset.
14- test_size (float): The test size. Defaults to 0.6.
15- test_seed (int): The test seed. Defaults to 42.
16- num_workers (int): The number of workers. Defaults to 0.
12+ batch_size (int):
13+ The batch size. Required.
14+ dataset (torch.utils.data.Dataset):
15+ The dataset from the torch.utils.data Dataset class.
16+ It must implement three functions: __init__, __len__, and __getitem__.
17+ Required.
18+ test_size (float):
19+ The test size. Defaults to 0.6.
20+ test_seed (int):
21+ The test seed. Defaults to 42.
22+ num_workers (int):
23+ The number of workers. Defaults to 0.
1724
1825 Attributes:
1926 batch_size (int): The batch size.
@@ -25,6 +32,18 @@ class LightDataModule(L.LightningDataModule):
2532 test_seed (int): The test seed.
2633 test_size (float): The test size.
2734
35+ Methods:
36+ prepare_data(self):
37+ Usually used for downloading the data. Here: Does nothing, i.e., pass.
38+ setup(self, stage: Optional[str] = None):
39+ Performs the training, validation, and test split.
40+ train_dataloader():
41+ Returns a DataLoader instance for the training set.
42+ val_dataloader():
43+ Returns a DataLoader instance for the validation set.
44+ test_dataloader():
45+ Returns a DataLoader instance for the test set.
46+
2847 Examples:
2948 >>> from spotPython.data.lightdatamodule import LightDataModule
3049 from spotPython.data.csvdataset import CSVDataset
@@ -55,10 +74,15 @@ def prepare_data(self) -> None:
5574
5675 def setup (self , stage : Optional [str ] = None ) -> None :
5776 """
58- Sets up the data for use.
77+ Splits the data for use in training, validation, and testing.
78+ Uses torch.utils.data.random_split() to split the data.
79+ Splitting is based on the test_size and test_seed.
80+ The test_size can be a float or an int.
5981
6082 Args:
61- stage (Optional[str]): The current stage. Defaults to None.
83+ stage (Optional[str]):
84+ The current stage. Can be "fit" (for training and validation), "test" (testing),
85+ or None (for all three stages). Defaults to None.
6286
6387 Examples:
6488 >>> from spotPython.data.lightdatamodule import LightDataModule
@@ -101,7 +125,8 @@ def setup(self, stage: Optional[str] = None) -> None:
101125
102126 def train_dataloader (self ) -> DataLoader :
103127 """
104- Returns the training dataloader.
128+ Returns the training dataloader, i.e., a pytorch DataLoader instance
129+ using the training dataset.
105130
106131 Returns:
107132 DataLoader: The training dataloader.
@@ -125,7 +150,8 @@ def train_dataloader(self) -> DataLoader:
125150
126151 def val_dataloader (self ) -> DataLoader :
127152 """
128- Returns the validation dataloader.
153+ Returns the validation dataloader, i.e., a pytorch DataLoader instance
154+ using the validation dataset.
129155
130156 Returns:
131157 DataLoader: The validation dataloader.
@@ -146,7 +172,8 @@ def val_dataloader(self) -> DataLoader:
146172
147173 def test_dataloader (self ) -> DataLoader :
148174 """
149- Returns the test dataloader.
175+ Returns the test dataloader, i.e., a pytorch DataLoader instance
176+ using the test dataset.
150177
151178 Returns:
152179 DataLoader: The test dataloader.
0 commit comments