Skip to content

Commit 3a69d3f

Browse files
committed
fix scaling
1 parent 7e72f7d commit 3a69d3f

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

src/spotPython/data/lightdatamodule.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ def setup(self, stage: Optional[str] = None) -> None:
149149
)
150150
if self.scaler is not None:
151151
# Fit the scaler on training data and transform both train and val data
152-
train_val_data = torch.cat([self.data_train[i][0] for i in range(len(self.data_train))])
152+
train_val_data = torch.stack([self.data_train[i][0] for i in range(len(self.data_train))])
153+
# train_val_data = self.data_train[:,0]
153154
self.scaler.fit(train_val_data)
154155
self.data_train = [(self.scaler.transform(data), target) for data, target in self.data_train]
155156
data_tensors_train = [data.clone().detach().requires_grad_(True) for data, target in self.data_train]

0 commit comments

Comments
 (0)