11import lightning as L
22import torch
3- from torch .utils .data import DataLoader , random_split
3+ from torch .utils .data import DataLoader , random_split , TensorDataset
44from typing import Optional
55
66
7+
78class LightDataModule (L .LightningDataModule ):
89 """
910 A LightningDataModule for handling data.
@@ -25,6 +26,8 @@ class LightDataModule(L.LightningDataModule):
2526 The test seed. Defaults to 42.
2627 num_workers (int):
2728 The number of workers. Defaults to 0.
29+ scaler (object):
30+ Optional scaler object (e.g. TorchStandardScaler) providing fit() and transform(); used in setup() to scale the input data. Defaults to None.
2831
2932 Attributes:
3033 batch_size (int): The batch size.
@@ -79,13 +82,16 @@ def __init__(
7982 test_size : float ,
8083 test_seed : int = 42 ,
8184 num_workers : int = 0 ,
85+ scaler : Optional [object ] = None ,
8286 ):
8387 super ().__init__ ()
8488 self .batch_size = batch_size
8589 self .data_full = dataset
8690 self .test_size = test_size
8791 self .test_seed = test_seed
8892 self .num_workers = num_workers
93+ self .scaler = scaler
94+
8995
9096 def prepare_data (self ) -> None :
9197 """Prepares the data for use."""
@@ -98,6 +104,7 @@ def setup(self, stage: Optional[str] = None) -> None:
98104 Uses torch.utils.data.random_split() to split the data.
99105 Splitting is based on the test_size and test_seed.
100106 The test_size can be a float or an int.
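107+ If a scaler is provided, it is fitted on the training split and used to transform the inputs of the train, validation, test, and predict splits (targets are not scaled).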
101108
102109 Args:
103110 stage (Optional[str]):
@@ -140,14 +147,32 @@ def setup(self, stage: Optional[str] = None) -> None:
140147 self .data_train , self .data_val , _ = random_split (
141148 self .data_full , [train_size , val_size , test_size ], generator = generator_fit
142149 )
143-
150+ if self .scaler is not None :
151+ # Fit the scaler on training data and transform both train and val data
152+ train_val_data = torch .cat ([self .data_train [i ][0 ] for i in range (len (self .data_train ))])
153+ self .scaler .fit (train_val_data )
154+ self .data_train = [(self .scaler .transform (data ), target ) for data , target in self .data_train ]
155+ data_tensors_train = [torch .tensor (data , dtype = torch .float32 ) for data , target in self .data_train ]
156+ target_tensors_train = [torch .tensor (target , dtype = torch .float32 ) for data , target in self .data_train ]
157+ self .data_train = TensorDataset (torch .stack (data_tensors_train ), torch .stack (target_tensors_train ))
158+ #print(self.data_train)
159+ self .data_val = [(self .scaler .transform (data ), target ) for data , target in self .data_val ]
160+ data_tensors_val = [torch .tensor (data , dtype = torch .float32 ) for data , target in self .data_val ]
161+ target_tensors_val = [torch .tensor (target , dtype = torch .float32 ) for data , target in self .data_val ]
162+ self .data_val = TensorDataset (torch .stack (data_tensors_val ), torch .stack (target_tensors_val ))
163+
144164 # Assign test dataset for use in dataloader(s)
145165 if stage == "test" or stage is None :
146166 print (f"test_size: { test_size } used for test dataset." )
147167 # get test data set as test_abs percent of the full dataset
148168 generator_test = torch .Generator ().manual_seed (self .test_seed )
149169 self .data_test , _ = random_split (self .data_full , [test_size , full_train_size ], generator = generator_test )
150-
170+ if self .scaler is not None :
171+ self .data_test = [(self .scaler .transform (data ), target ) for data , target in self .data_test ]
172+ data_tensors_test = [torch .tensor (data , dtype = torch .float32 ) for data , target in self .data_test ]
173+ target_tensors_test = [torch .tensor (target , dtype = torch .float32 ) for data , target in self .data_test ]
174+ self .data_test = TensorDataset (torch .stack (data_tensors_test ), torch .stack (target_tensors_test ))
175+
151176 # if stage == "predict" or stage is None:
152177 # print(f"test_size, full_train_size: {test_size}, {full_train_size}")
153178 # generator_predict = torch.Generator().manual_seed(self.test_seed)
@@ -165,6 +190,11 @@ def setup(self, stage: Optional[str] = None) -> None:
165190 self .data_predict , _ = random_split (
166191 self .data_full , [test_size , full_train_size ], generator = generator_predict
167192 )
193+ if self .scaler is not None :
194+ self .data_predict = [(self .scaler .transform (data ), target ) for data , target in self .data_predict ]
195+ data_tensors_predict = [torch .tensor (data , dtype = torch .float32 ) for data , target in self .data_predict ]
196+ target_tensors_predict = [torch .tensor (target , dtype = torch .float32 ) for data , target in self .data_predict ]
197+ self .data_predict = TensorDataset (torch .stack (data_tensors_predict ), torch .stack (target_tensors_predict ))
168198
169199 def train_dataloader (self ) -> DataLoader :
170200 """
@@ -265,3 +295,5 @@ def predict_dataloader(self) -> DataLoader:
265295 # apply fit_transform to the val data
266296 return DataLoader (self .data_test , batch_size = self .batch_size , num_workers = self .num_workers )
267297 return DataLoader (self .data_predict , batch_size = len (self .data_predict ), num_workers = self .num_workers )
298+
299+
0 commit comments