@@ -149,18 +149,19 @@ def setup(self, stage: Optional[str] = None) -> None:
149149 )
150150 if self .scaler is not None :
151151 # Fit the scaler on training data and transform both train and val data
152- train_val_data = torch .stack ([self .data_train [i ][0 ] for i in range (len (self .data_train ))])
153- # train_val_data = self.data_train[:,0]
154- self .scaler .fit (train_val_data )
152+ scaler_train_data = torch .stack ([self .data_train [i ][0 ] for i in range (len (self .data_train ))]).squeeze (1 )
153+ #train_val_data = self.data_train[:,0]
154+ print (scaler_train_data .shape )
155+ self .scaler .fit (scaler_train_data )
155156 self .data_train = [(self .scaler .transform (data ), target ) for data , target in self .data_train ]
156- data_tensors_train = [data .clone ().detach (). requires_grad_ ( True ) for data , target in self .data_train ]
157+ data_tensors_train = [data .clone ().detach () for data , target in self .data_train ]
157158 target_tensors_train = [target .clone ().detach () for data , target in self .data_train ]
158- self .data_train = TensorDataset (torch .stack (data_tensors_train ), torch .stack (target_tensors_train ))
159+ self .data_train = TensorDataset (torch .stack (data_tensors_train ). squeeze ( 1 ) , torch .stack (target_tensors_train ))
159160 # print(self.data_train)
160161 self .data_val = [(self .scaler .transform (data ), target ) for data , target in self .data_val ]
161- data_tensors_val = [data .clone ().detach (). requires_grad_ ( True ) for data , target in self .data_val ]
162+ data_tensors_val = [data .clone ().detach () for data , target in self .data_val ]
162163 target_tensors_val = [target .clone ().detach () for data , target in self .data_val ]
163- self .data_val = TensorDataset (torch .stack (data_tensors_val ), torch .stack (target_tensors_val ))
164+ self .data_val = TensorDataset (torch .stack (data_tensors_val ). squeeze ( 1 ) , torch .stack (target_tensors_val ))
164165
165166 # Assign test dataset for use in dataloader(s)
166167 if stage == "test" or stage is None :
@@ -170,9 +171,9 @@ def setup(self, stage: Optional[str] = None) -> None:
170171 self .data_test , _ = random_split (self .data_full , [test_size , full_train_size ], generator = generator_test )
171172 if self .scaler is not None :
172173 self .data_test = [(self .scaler .transform (data ), target ) for data , target in self .data_test ]
173- data_tensors_test = [data .clone ().detach (). requires_grad_ ( True ) for data , target in self .data_test ]
174+ data_tensors_test = [data .clone ().detach () for data , target in self .data_test ]
174175 target_tensors_test = [target .clone ().detach () for data , target in self .data_test ]
175- self .data_test = TensorDataset (torch .stack (data_tensors_test ), torch .stack (target_tensors_test ))
176+ self .data_test = TensorDataset (torch .stack (data_tensors_test ). squeeze ( 1 ) , torch .stack (target_tensors_test ))
176177
177178 # if stage == "predict" or stage is None:
178179 # print(f"test_size, full_train_size: {test_size}, {full_train_size}")
@@ -194,11 +195,11 @@ def setup(self, stage: Optional[str] = None) -> None:
194195 if self .scaler is not None :
195196 self .data_predict = [(self .scaler .transform (data ), target ) for data , target in self .data_predict ]
196197 data_tensors_predict = [
197- data .clone ().detach (). requires_grad_ ( True ) for data , target in self .data_predict
198+ data .clone ().detach () for data , target in self .data_predict
198199 ]
199200 target_tensors_predict = [target .clone ().detach () for data , target in self .data_predict ]
200201 self .data_predict = TensorDataset (
201- torch .stack (data_tensors_predict ), torch .stack (target_tensors_predict )
202+ torch .stack (data_tensors_predict ). squeeze ( 1 ) , torch .stack (target_tensors_predict )
202203 )
203204
204205 def train_dataloader (self ) -> DataLoader :
@@ -298,4 +299,5 @@ def predict_dataloader(self) -> DataLoader:
298299 # print(f"LightDataModule: predict_dataloader(). batch_size: {self.batch_size}")
299300 # print(f"LightDataModule: predict_dataloader(). num_workers: {self.num_workers}")
300301 # apply fit_transform to the val data
302+
301303 return DataLoader (self .data_predict , batch_size = len (self .data_predict ), num_workers = self .num_workers )
0 commit comments