@@ -12,45 +12,23 @@ class LightDataModule(L.LightningDataModule):
1212 Args:
1313 batch_size (int):
1414 The batch size. Required.
15- dataset (torch.utils.data.Dataset):
15+ dataset (torch.utils.data.Dataset, optional ):
1616 The dataset from the torch.utils.data Dataset class.
17- It must implement three functions: __init__, __len__, and __getitem__.
18- Required.
19- test_size (float):
20- The test size. if test_size is float, then train_size is 1 - test_size.
17+ It must implement three functions: __init__, __len__, and __getitem__.
18+ data_full_train (torch.utils.data.Dataset, optional):
19+ The full training dataset from which training and validation sets will be derived.
20+ data_test (torch.utils.data.Dataset, optional):
21+ The separate test dataset that will be used for testing.
22+ test_size (float, optional):
23+ The test size. If test_size is float, then train_size is 1 - test_size.
2124 If test_size is int, then train_size is len(data_full) - test_size.
22- Train size will be split into train and validation sets.
23- So if test size is 0.7, the 0.7 train size will be split into 0.7 * 0.7 = 0.49 train set
24- amd 0.7 * 0.3 = 0.21 validation set.
2525 test_seed (int):
2626 The test seed. Defaults to 42.
2727 num_workers (int):
2828 The number of workers. Defaults to 0.
29- scaler (object):
29+ scaler (object, optional ):
3030 The spot scaler object (e.g. TorchStandardScaler). Defaults to None.
3131
32- Attributes:
33- batch_size (int): The batch size.
34- data_full (Dataset): The full dataset.
35- data_test (Dataset): The test dataset.
36- data_train (Dataset): The training dataset.
37- data_val (Dataset): The validation dataset.
38- num_workers (int): The number of workers.
39- test_seed (int): The test seed.
40- test_size (float): The test size.
41-
42- Methods:
43- prepare_data(self):
44- Usually used for downloading the data. Here: Does nothing, i.e., pass.
45- setup(self, stage: Optional[str] = None):
46- Performs the training, validation, and test split.
47- train_dataloader():
48- Returns a DataLoader instance for the training set.
49- val_dataloader():
50- Returns a DataLoader instance for the validation set.
51- test_dataloader():
52- Returns a DataLoader instance for the test set.
53-
5432 Examples:
5533 >>> from spotpython.data.lightdatamodule import LightDataModule
5634 from spotpython.data.csvdataset import CSVDataset
@@ -80,8 +58,10 @@ class LightDataModule(L.LightningDataModule):
8058 def __init__ (
8159 self ,
8260 batch_size : int ,
83- dataset : object ,
84- test_size : float ,
61+ dataset : Optional [object ] = None ,
62+ data_full_train : Optional [object ] = None ,
63+ data_test : Optional [object ] = None ,
64+ test_size : Optional [float ] = None ,
8565 test_seed : int = 42 ,
8666 num_workers : int = 0 ,
8767 scaler : Optional [object ] = None ,
@@ -90,6 +70,8 @@ def __init__(
9070 super ().__init__ ()
9171 self .batch_size = batch_size
9272 self .data_full = dataset
73+ self .data_full_train = data_full_train
74+ self .data_test = data_test
9375 self .test_size = test_size
9476 self .test_seed = test_seed
9577 self .num_workers = num_workers
@@ -166,49 +148,92 @@ def setup(self, stage: Optional[str] = None) -> None:
166148 Training set size: 3
167149
168150 """
169- full_size = len (self .data_full )
170- test_size = self .test_size
171-
172- # consider the case when test_size is a float
173- if isinstance (self .test_size , float ):
174- full_train_size = 1.0 - self .test_size
175- val_size = full_train_size * self .test_size
176- train_size = full_train_size - val_size
151+ if self .data_full is not None :
152+ full_size = len (self .data_full )
153+ test_size = self .test_size
154+
155+ # consider the case when test_size is a float
156+ if isinstance (self .test_size , float ):
157+ full_train_size = 1.0 - self .test_size
158+ val_size = full_train_size * self .test_size
159+ train_size = full_train_size - val_size
160+ else :
161+ # test_size is an int, training size calculation directly based on it
162+ full_train_size = full_size - self .test_size
163+ val_size = floor (full_train_size * self .test_size / full_size )
164+ train_size = full_size - val_size - test_size
165+
166+ # Assign train/val datasets for use in dataloaders
167+ if stage == "fit" or stage is None :
168+ if self .verbosity > 0 :
169+ print (f"train_size: { train_size } , val_size: { val_size } used for train & val data." )
170+ generator_fit = torch .Generator ().manual_seed (self .test_seed )
171+ self .data_train , self .data_val , _ = random_split (self .data_full , [train_size , val_size , test_size ], generator = generator_fit )
172+ # Handle scaling and transformation if scaler is provided
173+ if self .scaler is not None :
174+ self .handle_scaling_and_transform ()
175+
176+ # Assign test dataset for use in dataloader(s)
177+ if stage == "test" or stage is None :
178+ if self .verbosity > 0 :
179+ print (f"test_size: { test_size } used for test dataset." )
180+ generator_test = torch .Generator ().manual_seed (self .test_seed )
181+ self .data_test , _ , _ = random_split (self .data_full , [test_size , train_size , val_size ], generator = generator_test )
182+ if self .scaler is not None :
183+ # Transform the test data
184+ self .data_test = self .transform_dataset (self .data_test )
185+
186+ # Assign pred dataset for use in dataloader(s)
187+ if stage == "predict" or stage is None :
188+ if self .verbosity > 0 :
189+ print (f"test_size: { test_size } used for predict dataset." )
190+ generator_predict = torch .Generator ().manual_seed (self .test_seed )
191+ self .data_predict , _ , _ = random_split (self .data_full , [test_size , train_size , val_size ], generator = generator_predict )
192+ if self .scaler is not None :
193+ # Transform the predict data
194+ self .data_predict = self .transform_dataset (self .data_predict )
177195 else :
178- # test_size is an int, training size calculation directly based on it
179- full_train_size = full_size - self .test_size
180- val_size = floor (full_train_size * self .test_size / full_size )
181- train_size = full_size - val_size - test_size
182-
183- # Assign train/val datasets for use in dataloaders
184- if stage == "fit" or stage is None :
185- if self .verbosity > 0 :
186- print (f"train_size: { train_size } , val_size: { val_size } used for train & val data." )
187- generator_fit = torch .Generator ().manual_seed (self .test_seed )
188- self .data_train , self .data_val , _ = random_split (self .data_full , [train_size , val_size , test_size ], generator = generator_fit )
189- # Handle scaling and transformation if scaler is provided
190- if self .scaler is not None :
191- self .handle_scaling_and_transform ()
192-
193- # Assign test dataset for use in dataloader(s)
194- if stage == "test" or stage is None :
195- if self .verbosity > 0 :
196- print (f"test_size: { test_size } used for test dataset." )
197- generator_test = torch .Generator ().manual_seed (self .test_seed )
198- self .data_test , _ , _ = random_split (self .data_full , [test_size , train_size , val_size ], generator = generator_test )
199- if self .scaler is not None :
200- # Transform the test data
201- self .data_test = self .transform_dataset (self .data_test )
202-
203- # Assign pred dataset for use in dataloader(s)
204- if stage == "predict" or stage is None :
205- if self .verbosity > 0 :
206- print (f"test_size: { test_size } used for predict dataset." )
207- generator_predict = torch .Generator ().manual_seed (self .test_seed )
208- self .data_predict , _ , _ = random_split (self .data_full , [test_size , train_size , val_size ], generator = generator_predict )
209- if self .scaler is not None :
210- # Transform the predict data
211- self .data_predict = self .transform_dataset (self .data_predict )
196+ # New functionality with separate full_train and test datasets. Use these datasets directly.
197+ full_train_size = len (self .data_full_train )
198+ test_size = self .test_size
199+ # consider the case when test_size is a float
200+ if isinstance (self .test_size , float ):
201+ val_size = self .test_size
202+ train_size = 1 - self .test_size
203+ else :
204+ # test_size is an int, training size calculation directly based on it
205+ full_size = len (self .data_full_train ) + len (self .data_test )
206+ full_train_size = len (self .data_full_train )
207+ val_size = floor (full_train_size * self .test_size / full_size )
208+ train_size = full_train_size - val_size
209+
210+ # Assign train/val datasets for use in dataloaders
211+ if stage == "fit" or stage is None :
212+ if self .verbosity > 0 :
213+ print (f"train_size: { train_size } , val_size: { val_size } used for train & val data." )
214+ generator_fit = torch .Generator ().manual_seed (self .test_seed )
215+ self .data_train , self .data_val = random_split (self .data_full_train , [train_size , val_size ], generator = generator_fit )
216+ # Handle scaling and transformation if scaler is provided
217+ if self .scaler is not None :
218+ self .handle_scaling_and_transform ()
219+
220+ # Assign test dataset for use in dataloader(s)
221+ if stage == "test" or stage is None :
222+ if self .verbosity > 0 :
223+ print (f"test_size: { test_size } used for test dataset." )
224+ self .data_test = self .data_test
225+ if self .scaler is not None :
226+ # Transform the test data
227+ self .data_test = self .transform_dataset (self .data_test )
228+
229+ # Assign pred dataset for use in dataloader(s)
230+ if stage == "predict" or stage is None :
231+ if self .verbosity > 0 :
232+ print (f"test_size: { test_size } used for predict dataset." )
233+ self .data_predict = self .data_test
234+ if self .scaler is not None :
235+ # Transform the predict data
236+ self .data_predict = self .transform_dataset (self .data_predict )
212237
213238 def train_dataloader (self ) -> DataLoader :
214239 """
0 commit comments