Skip to content

Commit cc30f2d

Browse files
committed
fix scaler
1 parent 4e1cb2f commit cc30f2d

1 file changed

Lines changed: 13 additions & 11 deletions

File tree

src/spotPython/data/lightdatamodule.py

Lines changed: 13 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -149,18 +149,19 @@ def setup(self, stage: Optional[str] = None) -> None:
149149
)
150150
if self.scaler is not None:
151151
# Fit the scaler on training data and transform both train and val data
152-
train_val_data = torch.stack([self.data_train[i][0] for i in range(len(self.data_train))])
153-
# train_val_data = self.data_train[:,0]
154-
self.scaler.fit(train_val_data)
152+
scaler_train_data = torch.stack([self.data_train[i][0] for i in range(len(self.data_train))]).squeeze(1)
153+
#train_val_data = self.data_train[:,0]
154+
print(scaler_train_data.shape)
155+
self.scaler.fit(scaler_train_data)
155156
self.data_train = [(self.scaler.transform(data), target) for data, target in self.data_train]
156-
data_tensors_train = [data.clone().detach().requires_grad_(True) for data, target in self.data_train]
157+
data_tensors_train = [data.clone().detach() for data, target in self.data_train]
157158
target_tensors_train = [target.clone().detach() for data, target in self.data_train]
158-
self.data_train = TensorDataset(torch.stack(data_tensors_train), torch.stack(target_tensors_train))
159+
self.data_train = TensorDataset(torch.stack(data_tensors_train).squeeze(1), torch.stack(target_tensors_train))
159160
# print(self.data_train)
160161
self.data_val = [(self.scaler.transform(data), target) for data, target in self.data_val]
161-
data_tensors_val = [data.clone().detach().requires_grad_(True) for data, target in self.data_val]
162+
data_tensors_val = [data.clone().detach() for data, target in self.data_val]
162163
target_tensors_val = [target.clone().detach() for data, target in self.data_val]
163-
self.data_val = TensorDataset(torch.stack(data_tensors_val), torch.stack(target_tensors_val))
164+
self.data_val = TensorDataset(torch.stack(data_tensors_val).squeeze(1), torch.stack(target_tensors_val))
164165

165166
# Assign test dataset for use in dataloader(s)
166167
if stage == "test" or stage is None:
@@ -170,9 +171,9 @@ def setup(self, stage: Optional[str] = None) -> None:
170171
self.data_test, _ = random_split(self.data_full, [test_size, full_train_size], generator=generator_test)
171172
if self.scaler is not None:
172173
self.data_test = [(self.scaler.transform(data), target) for data, target in self.data_test]
173-
data_tensors_test = [data.clone().detach().requires_grad_(True) for data, target in self.data_test]
174+
data_tensors_test = [data.clone().detach() for data, target in self.data_test]
174175
target_tensors_test = [target.clone().detach() for data, target in self.data_test]
175-
self.data_test = TensorDataset(torch.stack(data_tensors_test), torch.stack(target_tensors_test))
176+
self.data_test = TensorDataset(torch.stack(data_tensors_test).squeeze(1), torch.stack(target_tensors_test))
176177

177178
# if stage == "predict" or stage is None:
178179
# print(f"test_size, full_train_size: {test_size}, {full_train_size}")
@@ -194,11 +195,11 @@ def setup(self, stage: Optional[str] = None) -> None:
194195
if self.scaler is not None:
195196
self.data_predict = [(self.scaler.transform(data), target) for data, target in self.data_predict]
196197
data_tensors_predict = [
197-
data.clone().detach().requires_grad_(True) for data, target in self.data_predict
198+
data.clone().detach() for data, target in self.data_predict
198199
]
199200
target_tensors_predict = [target.clone().detach() for data, target in self.data_predict]
200201
self.data_predict = TensorDataset(
201-
torch.stack(data_tensors_predict), torch.stack(target_tensors_predict)
202+
torch.stack(data_tensors_predict).squeeze(1), torch.stack(target_tensors_predict)
202203
)
203204

204205
def train_dataloader(self) -> DataLoader:
@@ -298,4 +299,5 @@ def predict_dataloader(self) -> DataLoader:
298299
# print(f"LightDataModule: predict_dataloader(). batch_size: {self.batch_size}")
299300
# print(f"LightDataModule: predict_dataloader(). num_workers: {self.num_workers}")
300301
# apply fit_transform to the val data
302+
301303
return DataLoader(self.data_predict, batch_size=len(self.data_predict), num_workers=self.num_workers)

0 commit comments

Comments
 (0)