@@ -72,11 +72,23 @@ class LightDataModule(L.LightningDataModule):
7272
7373 """
7474
75- def __init__ (self , batch_size : int , dataset : object , test_size : float , test_seed : int = 42 , num_workers : int = 0 ):
75+ def __init__ (
76+ self ,
77+ batch_size : int ,
78+ dataset : object ,
79+ test_size : float ,
80+ scaler : None = None ,
81+ test_seed : int = 42 ,
82+ num_workers : int = 0 ,
83+ ):
7684 super ().__init__ ()
7785 self .batch_size = batch_size
7886 self .data_full = dataset
7987 self .test_size = test_size
88+ if scaler is not None :
89+ self .scaler = scaler ()
90+ else :
91+ self .scaler = None
8092 self .test_seed = test_seed
8193 self .num_workers = num_workers
8294
@@ -182,6 +194,9 @@ def train_dataloader(self) -> DataLoader:
182194 print (f"LightDataModule.train_dataloader(). data_train size: { len (self .data_train )} " )
183195 # print(f"LightDataModule: train_dataloader(). batch_size: {self.batch_size}")
184196 # print(f"LightDataModule: train_dataloader(). num_workers: {self.num_workers}")
197+ # apply fit_transform to the training data
198+ if self .scaler is not None :
199+ self .data_train = self .scaler .fit_transform (self .data_train )
185200 return DataLoader (self .data_train , batch_size = self .batch_size , num_workers = self .num_workers )
186201
187202 def val_dataloader (self ) -> DataLoader :
@@ -205,6 +220,9 @@ def val_dataloader(self) -> DataLoader:
205220 print (f"LightDataModule.val_dataloader(). Val. set size: { len (self .data_val )} " )
206221 # print(f"LightDataModule: val_dataloader(). batch_size: {self.batch_size}")
207222 # print(f"LightDataModule: val_dataloader(). num_workers: {self.num_workers}")
223+ # apply transform (not fit_transform) to the val data
224+ if self .scaler is not None :
225+ self .data_val = self .scaler .transform (self .data_val )
208226 return DataLoader (self .data_val , batch_size = self .batch_size , num_workers = self .num_workers )
209227
210228 def test_dataloader (self ) -> DataLoader :
@@ -229,6 +247,9 @@ def test_dataloader(self) -> DataLoader:
229247 print (f"LightDataModule.test_dataloader(). Test set size: { len (self .data_test )} " )
230248 # print(f"LightDataModule: test_dataloader(). batch_size: {self.batch_size}")
231249 # print(f"LightDataModule: test_dataloader(). num_workers: {self.num_workers}")
250+ # apply transform (not fit_transform) to the test data
251+ if self .scaler is not None :
252+ self .data_test = self .scaler .transform (self .data_test )
232253 return DataLoader (self .data_test , batch_size = self .batch_size , num_workers = self .num_workers )
233254
234255 def predict_dataloader (self ) -> DataLoader :
@@ -253,4 +274,8 @@ def predict_dataloader(self) -> DataLoader:
253274 print (f"LightDataModule.predict_dataloader(). Predict set size: { len (self .data_predict )} " )
254275 # print(f"LightDataModule: predict_dataloader(). batch_size: {self.batch_size}")
255276 # print(f"LightDataModule: predict_dataloader(). num_workers: {self.num_workers}")
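277+ # apply transform to the test data; NOTE(review): this branch scales and returns self.data_test, not self.data_predict - confirm this is intended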
278+ if self .scaler is not None :
279+ self .data_test = self .scaler .transform (self .data_test )
280+ return DataLoader (self .data_test , batch_size = self .batch_size , num_workers = self .num_workers )
256281 return DataLoader (self .data_predict , batch_size = len (self .data_predict ), num_workers = self .num_workers )
0 commit comments