@@ -16,7 +16,11 @@ class LightDataModule(L.LightningDataModule):
1616 It must implement three functions: __init__, __len__, and __getitem__.
1717 Required.
1818 test_size (float):
19- The test size. Required.
19+ The test size. If test_size is a float, then train_size is 1 - test_size.
20+ If test_size is an int, then train_size is len(data_full) - test_size.
21+ Train size will be split into train and validation sets.
22+ So if test size is 0.3, the 0.7 train size will be split into 0.7 * 0.7 = 0.49 train set
23+ and 0.7 * 0.3 = 0.21 validation set.
2024 test_seed (int):
2125 The test seed. Defaults to 42.
2226 num_workers (int):
@@ -47,13 +51,21 @@ class LightDataModule(L.LightningDataModule):
4751 Examples:
4852 >>> from spotPython.data.lightdatamodule import LightDataModule
4953 from spotPython.data.csvdataset import CSVDataset
50- from spotPython.data.pkldataset import PKLDataset
5154 import torch
55+ # data.csv is a simple CSV file with 11 samples
5256 dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)
5357 data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5)
5458 data_module.setup()
5559 print(f"Training set size: {len(data_module.data_train)}")
60+ print(f"Validation set size: {len(data_module.data_val)}")
61+ print(f"Test set size: {len(data_module.data_test)}")
62+ full_train_size: 0.5
63+ val_size: 0.25
64+ train_size: 0.25
65+ test_size: 0.5
5666 Training set size: 3
67+ Validation set size: 3
68+ Test set size: 6
5769
5870 References:
5971 See https://lightning.ai/docs/pytorch/stable/data/datamodule.html
@@ -109,17 +121,20 @@ def setup(self, stage: Optional[str] = None) -> None:
109121 val_size = int (full_train_size * test_size / len (self .data_full ))
110122 train_size = full_train_size - val_size
111123
112- print (f"full_train_size: { full_train_size } " )
113- print (f"val_size: { val_size } " )
114- print (f"train_size: { train_size } " )
115- print (f"test_size: { test_size } " )
124+ print (f"LightDataModule: setup(). stage: { stage } " )
125+ print (f"LightDataModule setup(): full_train_size: { full_train_size } " )
126+ print (f"LightDataModule setup(): val_size: { val_size } " )
127+ print (f"LightDataModule setup(): train_size: { train_size } " )
128+ print (f"LightDataModule setup(): test_size: { test_size } " )
116129
117130 # Assign train/val datasets for use in dataloaders
118131 if stage == "fit" or stage is None :
132+ print ("LightDataModule: setup(). stage: fit" )
119133 self .data_train , self .data_val , _ = random_split (self .data_full , [train_size , val_size , test_size ])
120134
121135 # Assign test dataset for use in dataloader(s)
122136 if stage == "test" or stage is None :
137+ print ("LightDataModule: setup(). stage: test" )
123138 # get test dataset as test_abs percent of the full dataset
124139 generator_test = torch .Generator ().manual_seed (self .test_seed )
125140 self .data_test , _ = random_split (self .data_full , [test_size , full_train_size ], generator = generator_test )
@@ -135,6 +150,7 @@ def setup(self, stage: Optional[str] = None) -> None:
135150
136151 # Assign pred dataset for use in dataloader(s)
137152 if stage == "predict" or stage is None :
153+ print ("LightDataModule: setup(). stage: predict" )
139155 # get test dataset as test_abs percent of the full dataset
139155 generator_predict = torch .Generator ().manual_seed (self .test_seed )
140156 self .data_predict , _ = random_split (
@@ -152,7 +168,6 @@ def train_dataloader(self) -> DataLoader:
152168 Examples:
153169 >>> from spotPython.data.lightdatamodule import LightDataModule
154170 from spotPython.data.csvdataset import CSVDataset
155- from spotPython.data.pkldataset import PKLDataset
156171 import torch
157172 dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)
158173 data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5)
@@ -177,15 +192,16 @@ def val_dataloader(self) -> DataLoader:
177192 Examples:
178193 >>> from spotPython.data.lightdatamodule import LightDataModule
179194 from spotPython.data.csvdataset import CSVDataset
180- from spotPython.data.pkldataset import PKLDataset
181195 import torch
182196 dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)
183197 data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5)
184198 data_module.setup()
185199 print(f"Validation set size: {len(data_module.data_val)}")
186200 Validation set size: 3
187-
188201 """
202+ print (f"LightDataModule: val_dataloader(). Training set size: { len (self .data_val )} " )
203+ print (f"LightDataModule: val_dataloader(). batch_size: { self .batch_size } " )
204+ print (f"LightDataModule: val_dataloader(). num_workers: { self .num_workers } " )
189205 return DataLoader (self .data_val , batch_size = self .batch_size , num_workers = self .num_workers )
190206
191207 def test_dataloader (self ) -> DataLoader :
@@ -199,7 +215,6 @@ def test_dataloader(self) -> DataLoader:
199215 Examples:
200216 >>> from spotPython.data.lightdatamodule import LightDataModule
201217 from spotPython.data.csvdataset import CSVDataset
202- from spotPython.data.pkldataset import PKLDataset
203218 import torch
204219 dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)
205220 data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5)
@@ -214,6 +229,24 @@ def test_dataloader(self) -> DataLoader:
214229 return DataLoader (self .data_test , batch_size = self .batch_size , num_workers = self .num_workers )
215230
216231 def predict_dataloader (self ) -> DataLoader :
232+ """
233+ Returns the predict dataloader, i.e., a pytorch DataLoader instance
234+ using the predict dataset.
235+
236+ Returns:
237+ DataLoader: The predict dataloader.
238+
239+ Examples:
240+ >>> from spotPython.data.lightdatamodule import LightDataModule
241+ from spotPython.data.csvdataset import CSVDataset
242+ import torch
243+ dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)
244+ data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5)
245+ data_module.setup()
246+ print(f"Predict set size: {len(data_module.data_predict)}")
247+ Predict set size: 6
248+
249+ """
217250 print (f"LightDataModule: predict_dataloader(). Predict set size: { len (self .data_predict )} " )
218251 print (f"LightDataModule: predict_dataloader(). batch_size: { self .batch_size } " )
219252 print (f"LightDataModule: predict_dataloader(). num_workers: { self .num_workers } " )
0 commit comments