22import torch
33from torch .utils .data import DataLoader , random_split , TensorDataset
44from typing import Optional
5+ from spotpython .utils .split import calculate_data_split
56
67
78class LightDataModule (L .LightningDataModule ):
@@ -95,6 +96,13 @@ def __init__(
9596 self .scaler = scaler
9697 self .verbosity = verbosity
9798
99+ def transform_dataset (self , dataset ):
100+ """Applies the scaler transformation to the dataset."""
101+ transformed_data = [(self .scaler .transform (data ), target ) for data , target in dataset ]
102+ data_tensors = [data .clone ().detach () for data , target in transformed_data ]
103+ target_tensors = [target .clone ().detach () for data , target in transformed_data ]
104+ return TensorDataset (torch .stack (data_tensors ).squeeze (1 ), torch .stack (target_tensors ))
105+
98106 def prepare_data (self ) -> None :
99107 """Prepares the data for use."""
100108 # download
@@ -124,25 +132,12 @@ def setup(self, stage: Optional[str] = None) -> None:
124132 Training set size: 3
125133
126134 """
127- # if test_size is float, then train_size is 1 - test_size
128- test_size = self .test_size
129- if isinstance (self .test_size , float ):
130- full_train_size = round (1.0 - test_size , 2 )
131- val_size = round (full_train_size * test_size , 2 )
132- train_size = round (full_train_size - val_size , 2 )
133- else :
134- # if test_size is int, then train_size is len(data_full) - test_size
135- full_train_size = len (self .data_full ) - test_size
136- val_size = int (full_train_size * test_size / len (self .data_full ))
137- train_size = full_train_size - val_size
138-
139- if self .verbosity > 0 :
140- print (f"LightDataModule.setup(): stage: { stage } " )
141- if self .verbosity > 1 :
142- print (f"LightDataModule setup(): full_train_size: { full_train_size } " )
143- print (f"LightDataModule setup(): val_size: { val_size } " )
144- print (f"LightDataModule setup(): train_size: { train_size } " )
145- print (f"LightDataModule setup(): test_size: { test_size } " )
135+ full_train_size , val_size , train_size , test_size = calculate_data_split (
136+ test_size = self .test_size ,
137+ full_size = len (self .data_full ),
138+ verbosity = self .verbosity ,
139+ stage = stage ,
140+ )
146141
147142 # Assign train/val datasets for use in dataloaders
148143 if stage == "fit" or stage is None :
@@ -153,64 +148,37 @@ def setup(self, stage: Optional[str] = None) -> None:
153148 self .data_full , [train_size , val_size , test_size ], generator = generator_fit
154149 )
155150 if self .scaler is not None :
156- # Fit the scaler on training data and transform both train and val data
151+ # Fit the scaler on training data
157152 scaler_train_data = torch .stack ([self .data_train [i ][0 ] for i in range (len (self .data_train ))]).squeeze (1 )
158- # train_val_data = self.data_train[:,0]
159153 if self .verbosity > 0 :
160154 print (scaler_train_data .shape )
161155 self .scaler .fit (scaler_train_data )
162- self .data_train = [(self .scaler .transform (data ), target ) for data , target in self .data_train ]
163- data_tensors_train = [data .clone ().detach () for data , target in self .data_train ]
164- target_tensors_train = [target .clone ().detach () for data , target in self .data_train ]
165- self .data_train = TensorDataset (
166- torch .stack (data_tensors_train ).squeeze (1 ), torch .stack (target_tensors_train )
167- )
168- # print(self.data_train)
169- self .data_val = [(self .scaler .transform (data ), target ) for data , target in self .data_val ]
170- data_tensors_val = [data .clone ().detach () for data , target in self .data_val ]
171- target_tensors_val = [target .clone ().detach () for data , target in self .data_val ]
172- self .data_val = TensorDataset (torch .stack (data_tensors_val ).squeeze (1 ), torch .stack (target_tensors_val ))
156+ # Transform the training data
157+ self .data_train = self .transform_dataset (self .data_train )
158+ # Transform the validation data
159+ self .data_val = self .transform_dataset (self .data_val )
173160
174161 # Assign test dataset for use in dataloader(s)
175162 if stage == "test" or stage is None :
176163 if self .verbosity > 0 :
177164 print (f"test_size: { test_size } used for test dataset." )
178- # get test data set as test_abs percent of the full dataset
179165 generator_test = torch .Generator ().manual_seed (self .test_seed )
180166 self .data_test , _ = random_split (self .data_full , [test_size , full_train_size ], generator = generator_test )
181167 if self .scaler is not None :
182- self .data_test = [(self .scaler .transform (data ), target ) for data , target in self .data_test ]
183- data_tensors_test = [data .clone ().detach () for data , target in self .data_test ]
184- target_tensors_test = [target .clone ().detach () for data , target in self .data_test ]
185- self .data_test = TensorDataset (
186- torch .stack (data_tensors_test ).squeeze (1 ), torch .stack (target_tensors_test )
187- )
188-
189- # if stage == "predict" or stage is None:
190- # print(f"test_size, full_train_size: {test_size}, {full_train_size}")
191- # generator_predict = torch.Generator().manual_seed(self.test_seed)
192- # full_data_predict, _ = random_split(
193- # self.data_full, [test_size, full_train_size], generator=generator_predict
194- # )
195- # # Only keep the features for prediction
196- # self.data_predict = [x for x, _ in full_data_predict]
168+ # Transform the test data
169+ self .data_test = self .transform_dataset (self .data_test )
197170
198171 # Assign pred dataset for use in dataloader(s)
199172 if stage == "predict" or stage is None :
200173 if self .verbosity > 0 :
201174 print (f"test_size: { test_size } used for predict dataset." )
202- # get test data set as test_abs percent of the full dataset
203175 generator_predict = torch .Generator ().manual_seed (self .test_seed )
204176 self .data_predict , _ = random_split (
205177 self .data_full , [test_size , full_train_size ], generator = generator_predict
206178 )
207179 if self .scaler is not None :
208- self .data_predict = [(self .scaler .transform (data ), target ) for data , target in self .data_predict ]
209- data_tensors_predict = [data .clone ().detach () for data , target in self .data_predict ]
210- target_tensors_predict = [target .clone ().detach () for data , target in self .data_predict ]
211- self .data_predict = TensorDataset (
212- torch .stack (data_tensors_predict ).squeeze (1 ), torch .stack (target_tensors_predict )
213- )
180+ # Transform the predict data
181+ self .data_predict = self .transform_dataset (self .data_predict )
214182
215183 def train_dataloader (self ) -> DataLoader :
216184 """
@@ -235,7 +203,6 @@ def train_dataloader(self) -> DataLoader:
235203 print (f"LightDataModule.train_dataloader(). data_train size: { len (self .data_train )} " )
236204 # print(f"LightDataModule: train_dataloader(). batch_size: {self.batch_size}")
237205 # print(f"LightDataModule: train_dataloader(). num_workers: {self.num_workers}")
238- # apply fit_transform to the training data
239206 return DataLoader (self .data_train , batch_size = self .batch_size , num_workers = self .num_workers )
240207
241208 def val_dataloader (self ) -> DataLoader :
@@ -260,7 +227,6 @@ def val_dataloader(self) -> DataLoader:
260227 print (f"LightDataModule.val_dataloader(). Val. set size: { len (self .data_val )} " )
261228 # print(f"LightDataModule: val_dataloader(). batch_size: {self.batch_size}")
262229 # print(f"LightDataModule: val_dataloader(). num_workers: {self.num_workers}")
263- # apply fit_transform to the val data
264230 return DataLoader (self .data_val , batch_size = self .batch_size , num_workers = self .num_workers )
265231
266232 def test_dataloader (self ) -> DataLoader :
@@ -312,6 +278,4 @@ def predict_dataloader(self) -> DataLoader:
312278 print (f"LightDataModule.predict_dataloader(). Predict set size: { len (self .data_predict )} " )
313279 # print(f"LightDataModule: predict_dataloader(). batch_size: {self.batch_size}")
314280 # print(f"LightDataModule: predict_dataloader(). num_workers: {self.num_workers}")
315- # apply fit_transform to the val data
316-
317281 return DataLoader (self .data_predict , batch_size = len (self .data_predict ), num_workers = self .num_workers )
0 commit comments