Skip to content

Commit c585668

Browse files
committed
scaler options
1 parent 9dda9a0 commit c585668

3 files changed

Lines changed: 136 additions & 3 deletions

File tree

src/spotPython/data/lightdatamodule.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import lightning as L
22
import torch
3-
from torch.utils.data import DataLoader, random_split
3+
from torch.utils.data import DataLoader, random_split, TensorDataset
44
from typing import Optional
55

66

7+
78
class LightDataModule(L.LightningDataModule):
89
"""
910
A LightningDataModule for handling data.
@@ -25,6 +26,8 @@ class LightDataModule(L.LightningDataModule):
2526
The test seed. Defaults to 42.
2627
num_workers (int):
2728
The number of workers. Defaults to 0.
29+
scaler (object):
30+
The spot scaler object (e.g. TorchStandardScaler). Defaults to None.
2831
2932
Attributes:
3033
batch_size (int): The batch size.
@@ -79,13 +82,16 @@ def __init__(
7982
test_size: float,
8083
test_seed: int = 42,
8184
num_workers: int = 0,
85+
scaler: Optional[object] = None,
8286
):
8387
super().__init__()
8488
self.batch_size = batch_size
8589
self.data_full = dataset
8690
self.test_size = test_size
8791
self.test_seed = test_seed
8892
self.num_workers = num_workers
93+
self.scaler = scaler
94+
8995

9096
def prepare_data(self) -> None:
9197
"""Prepares the data for use."""
@@ -98,6 +104,7 @@ def setup(self, stage: Optional[str] = None) -> None:
98104
Uses torch.utils.data.random_split() to split the data.
99105
Splitting is based on the test_size and test_seed.
100106
The test_size can be a float or an int.
107+
If a spotPython scaler object is defined, the data will be scaled.
101108
102109
Args:
103110
stage (Optional[str]):
@@ -140,14 +147,32 @@ def setup(self, stage: Optional[str] = None) -> None:
140147
self.data_train, self.data_val, _ = random_split(
141148
self.data_full, [train_size, val_size, test_size], generator=generator_fit
142149
)
143-
150+
if self.scaler is not None:
151+
# Fit the scaler on training data and transform both train and val data
152+
train_val_data = torch.cat([self.data_train[i][0] for i in range(len(self.data_train))])
153+
self.scaler.fit(train_val_data)
154+
self.data_train = [(self.scaler.transform(data), target) for data, target in self.data_train]
155+
data_tensors_train = [torch.tensor(data, dtype=torch.float32) for data, target in self.data_train]
156+
target_tensors_train = [torch.tensor(target, dtype=torch.float32) for data, target in self.data_train]
157+
self.data_train = TensorDataset(torch.stack(data_tensors_train), torch.stack(target_tensors_train))
158+
#print(self.data_train)
159+
self.data_val = [(self.scaler.transform(data), target) for data, target in self.data_val]
160+
data_tensors_val = [torch.tensor(data, dtype=torch.float32) for data, target in self.data_val]
161+
target_tensors_val = [torch.tensor(target, dtype=torch.float32) for data, target in self.data_val]
162+
self.data_val = TensorDataset(torch.stack(data_tensors_val), torch.stack(target_tensors_val))
163+
144164
# Assign test dataset for use in dataloader(s)
145165
if stage == "test" or stage is None:
146166
print(f"test_size: {test_size} used for test dataset.")
147167
# get test data set as test_abs percent of the full dataset
148168
generator_test = torch.Generator().manual_seed(self.test_seed)
149169
self.data_test, _ = random_split(self.data_full, [test_size, full_train_size], generator=generator_test)
150-
170+
if self.scaler is not None:
171+
self.data_test = [(self.scaler.transform(data), target) for data, target in self.data_test]
172+
data_tensors_test = [torch.tensor(data, dtype=torch.float32) for data, target in self.data_test]
173+
target_tensors_test = [torch.tensor(target, dtype=torch.float32) for data, target in self.data_test]
174+
self.data_test = TensorDataset(torch.stack(data_tensors_test), torch.stack(target_tensors_test))
175+
151176
# if stage == "predict" or stage is None:
152177
# print(f"test_size, full_train_size: {test_size}, {full_train_size}")
153178
# generator_predict = torch.Generator().manual_seed(self.test_seed)
@@ -165,6 +190,11 @@ def setup(self, stage: Optional[str] = None) -> None:
165190
self.data_predict, _ = random_split(
166191
self.data_full, [test_size, full_train_size], generator=generator_predict
167192
)
193+
if self.scaler is not None:
194+
self.data_predict = [(self.scaler.transform(data), target) for data, target in self.data_predict]
195+
data_tensors_predict= [torch.tensor(data, dtype=torch.float32) for data, target in self.data_predict]
196+
target_tensors_predict = [torch.tensor(target, dtype=torch.float32) for data, target in self.data_predict]
197+
self.data_predict = TensorDataset(torch.stack(data_tensors_predict), torch.stack(target_tensors_predict))
168198

169199
def train_dataloader(self) -> DataLoader:
170200
"""
@@ -265,3 +295,5 @@ def predict_dataloader(self) -> DataLoader:
265295
# apply fit_transform to the val data
266296
return DataLoader(self.data_test, batch_size=self.batch_size, num_workers=self.num_workers)
267297
return DataLoader(self.data_predict, batch_size=len(self.data_predict), num_workers=self.num_workers)
298+
299+

src/spotPython/utils/init.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ def fun_control_init(
407407
"prep_model_name": prep_model_name,
408408
"progress_file": progress_file,
409409
"save_model": False,
410+
"scaler":scaler,
410411
"scenario": scenario,
411412
"seed": seed,
412413
"show_batch_interval": 1_000_000,

src/spotPython/utils/scaler.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import torch
2+
3+
class TorchStandardScaler:
    """
    Standardize torch tensors feature-wise: ``(x - mean) / (std + eps)``.

    The mean and the (population, ``unbiased=False``) standard deviation are
    computed along dimension 0 and stored with ``keepdim=True``, so they
    broadcast against inputs whose trailing dimensions match the fitted data.

    Args:
        eps (float):
            Small constant added to the standard deviation before dividing,
            so that constant features do not cause a division by zero.
            Defaults to 1e-7.

    Attributes:
        eps (float): The stabilising constant used in ``transform``.
        mean (torch.Tensor): Per-feature mean. Only set after ``fit``.
        std (torch.Tensor): Per-feature standard deviation. Only set after ``fit``.
    """

    # Class-level fallback so instances created without running __init__
    # (e.g. via __new__) still use the historical constant.
    eps = 1e-7

    def __init__(self, eps=1e-7):
        self.eps = eps

    def fit(self, x):
        """
        Compute the mean and standard deviation of the input tensor.

        Args:
            x (torch.Tensor): The input tensor. Statistics are taken along
                dimension 0. Integer and boolean tensors are promoted to the
                default floating point dtype first.

        Raises:
            TypeError: If the input is not a torch tensor.
        """
        if not torch.is_tensor(x):
            raise TypeError("Input should be a torch tensor")
        # mean()/std() are only defined for floating point or complex tensors;
        # promote anything else instead of failing inside torch.
        if not (x.is_floating_point() or x.is_complex()):
            x = x.to(torch.get_default_dtype())
        self.mean = x.mean(0, keepdim=True)
        self.std = x.std(0, unbiased=False, keepdim=True)

    def transform(self, x):
        """
        Scale the input tensor using the computed mean and standard deviation.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The scaled tensor. Because the statistics keep their
            leading dimension of size 1, the result follows normal
            broadcasting rules against that shape.

        Raises:
            TypeError: If the input is not a torch tensor.
            RuntimeError: If the scaler has not been fitted before transforming data.
        """
        if not torch.is_tensor(x):
            raise TypeError("Input should be a torch tensor")
        if not hasattr(self, "mean") or not hasattr(self, "std"):
            raise RuntimeError("Must fit scaler before transforming data")
        return (x - self.mean) / (self.std + self.eps)

    def fit_transform(self, x):
        """
        Fit the scaler to the input tensor and then scale the tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The scaled tensor.

        Raises:
            TypeError: If the input is not a torch tensor.
        """
        self.fit(x)
        return self.transform(x)
48+
49+
50+
51+
class TorchMinMaxScaler:
    """
    Min-max normalize torch tensors feature-wise: ``(x - min) / (max - min + eps)``.

    The minimum and maximum are computed along dimension 0 and stored with
    ``keepdim=True``, so they broadcast against inputs whose trailing
    dimensions match the fitted data.

    Args:
        eps (float):
            Small constant added to the range ``max - min`` before dividing,
            so that constant features do not cause a division by zero.
            Defaults to 1e-7.

    Attributes:
        eps (float): The stabilising constant used in ``transform``.
        min (torch.Tensor): Per-feature minimum. Only set after ``fit``.
        max (torch.Tensor): Per-feature maximum. Only set after ``fit``.
    """

    # Class-level fallback so instances created without running __init__
    # (e.g. via __new__) still use the historical constant.
    eps = 1e-7

    def __init__(self, eps=1e-7):
        self.eps = eps

    def fit(self, x):
        """
        Fit the scaler to the input data.

        Args:
            x (torch.Tensor): The input data to fit the scaler to. The
                minimum and maximum are taken along dimension 0.

        Raises:
            TypeError: If the input is not a torch tensor.
        """
        if not torch.is_tensor(x):
            raise TypeError("Input should be a torch tensor")
        # The dim-wise min/max return (values, indices); only values are kept.
        self.min = x.min(0, keepdim=True).values
        self.max = x.max(0, keepdim=True).values

    def transform(self, x):
        """
        Transform the input data using the fitted scaler.

        Args:
            x (torch.Tensor): The input data to transform.

        Returns:
            torch.Tensor: The transformed data.

        Raises:
            TypeError: If the input is not a torch tensor.
            RuntimeError: If the scaler has not been fitted before transforming data.
        """
        if not torch.is_tensor(x):
            raise TypeError("Input should be a torch tensor")
        if not hasattr(self, "min") or not hasattr(self, "max"):
            raise RuntimeError("Must fit scaler before transforming data")
        return (x - self.min) / (self.max - self.min + self.eps)

    def fit_transform(self, x):
        """
        Fit the scaler to the input data and transform it.

        Args:
            x (torch.Tensor): The input data to fit and transform.

        Returns:
            torch.Tensor: The transformed data.

        Raises:
            TypeError: If the input is not a torch tensor.
        """
        self.fit(x)
        return self.transform(x)

0 commit comments

Comments
 (0)