Skip to content

Commit bb8f7f0

Browse files
0.14.36
Added examples and tests to scaler.py
1 parent 9348ce4 commit bb8f7f0

5 files changed

Lines changed: 183 additions & 27 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.14.35"
10+
version = "0.14.36"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/utils/scaler.py

Lines changed: 70 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,58 @@
44
class TorchStandardScaler:
55
"""
66
A class for scaling data using standardization with torch tensors.
7+
This scaler computes the mean and standard deviation on a dataset so that
8+
it can later be used to scale the data using the computed mean and standard deviation.
9+
10+
Attributes:
11+
mean (torch.Tensor): The mean value computed over the fitted data.
12+
std (torch.Tensor): The standard deviation computed over the fitted data.
13+
14+
Examples:
15+
>>> import torch
16+
>>> from spotPython.utils.scaler import TorchStandardScaler
17+
>>> # Create a sample tensor
18+
>>> tensor = torch.rand((10, 3)) # Random tensor with shape (10, 3)
19+
>>> scaler = TorchStandardScaler()
20+
>>> # Fit the scaler to the data
21+
>>> scaler.fit(tensor)
22+
>>> # Transform the data using the fitted scaler
23+
>>> transformed_tensor = scaler.transform(tensor)
24+
>>> print(transformed_tensor)
25+
>>> # Use the fit_transform method to fit and transform in one step
26+
>>> another_tensor = torch.rand((10, 3))
27+
>>> scaled_tensor = scaler.fit_transform(another_tensor)
28+
>>> print(scaled_tensor)
729
"""
830

9-
def fit(self, x):
31+
def __init__(self):
32+
"""
33+
Initializes the TorchStandardScaler class without any pre-defined mean and std.
34+
"""
35+
self.mean = None
36+
self.std = None
37+
38+
def fit(self, x: torch.Tensor) -> None:
1039
"""
1140
Compute the mean and standard deviation of the input tensor.
1241
1342
Args:
14-
x (torch.Tensor): The input tensor.
43+
x (torch.Tensor): The input tensor, expected shape [n_samples, n_features]
1544
1645
Raises:
1746
TypeError: If the input is not a torch tensor.
1847
"""
1948
if not torch.is_tensor(x):
2049
raise TypeError("Input should be a torch tensor")
21-
self.mean = x.mean(0, keepdim=True)
22-
self.std = x.std(0, unbiased=False, keepdim=True)
50+
self.mean = x.mean(dim=0, keepdim=True)
51+
self.std = x.std(dim=0, unbiased=False, keepdim=True)
2352

24-
def transform(self, x):
53+
def transform(self, x: torch.Tensor) -> torch.Tensor:
2554
"""
2655
Scale the input tensor using the computed mean and standard deviation.
2756
2857
Args:
29-
x (torch.Tensor): The input tensor.
58+
x (torch.Tensor): The input tensor to be transformed, expected shape [n_samples, n_features]
3059
3160
Returns:
3261
torch.Tensor: The scaled tensor.
@@ -37,56 +66,77 @@ def transform(self, x):
3766
"""
3867
if not torch.is_tensor(x):
3968
raise TypeError("Input should be a torch tensor")
40-
if not hasattr(self, "mean") or not hasattr(self, "std"):
69+
if self.mean is None or self.std is None:
4170
raise RuntimeError("Must fit scaler before transforming data")
4271
x = (x - self.mean) / (self.std + 1e-7)
4372
return x
4473

45-
def fit_transform(self, x):
74+
def fit_transform(self, x: torch.Tensor) -> torch.Tensor:
4675
"""
4776
Fit the scaler to the input tensor and then scale the tensor.
4877
4978
Args:
50-
x (torch.Tensor): The input tensor.
79+
x (torch.Tensor): The input tensor, expected shape [n_samples, n_features]
5180
5281
Returns:
5382
torch.Tensor: The scaled tensor.
54-
83+
5584
Raises:
5685
TypeError: If the input is not a torch tensor.
5786
"""
58-
if not torch.is_tensor(x):
59-
raise TypeError("Input should be a torch tensor")
6087
self.fit(x)
6188
return self.transform(x)
6289

6390

6491
class TorchMinMaxScaler:
6592
"""
6693
A class for scaling data using min-max normalization with PyTorch tensors.
94+
This scaler calculates the minimum and maximum values in the dataset to scale the data within a given range.
95+
96+
Attributes:
97+
min (torch.Tensor): The minimum values computed over the fitted data.
98+
max (torch.Tensor): The maximum values computed over the fitted data.
99+
100+
Examples:
101+
>>> import torch
102+
>>> from spotPython.utils.scaler import TorchMinMaxScaler
103+
>>> scaler = TorchMinMaxScaler()
104+
>>> # Given a tensor
105+
>>> tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
106+
>>> # Fit and transform the tensor using the scaler
107+
>>> scaled_tensor = scaler.fit_transform(tensor)
108+
>>> print(scaled_tensor)
109+
>>> # The output is a tensor with values scaled between 0 and 1.
67110
"""
68111

69-
def fit(self, x):
112+
def __init__(self):
113+
"""
114+
Initializes the TorchMinMaxScaler class without any predefined min and max.
115+
"""
116+
self.min = None
117+
self.max = None
118+
119+
def fit(self, x: torch.Tensor) -> None:
70120
"""
71121
Compute the minimum and maximum value of the input tensor.
72122
73-
Parameters:
123+
Args:
74124
x (torch.Tensor): The input tensor.
75125
76126
Raises:
77127
TypeError: If the input is not a torch tensor.
78128
"""
79129
if not torch.is_tensor(x):
80130
raise TypeError("Input should be a torch tensor")
81-
self.min = x.min(0, keepdim=True).values
82-
self.max = x.max(0, keepdim=True).values
131+
self.min = x.min(dim=0, keepdim=True).values
132+
self.max = x.max(dim=0, keepdim=True).values
83133

84-
def transform(self, x):
134+
def transform(self, x: torch.Tensor) -> torch.Tensor:
85135
"""
86136
Scale the input tensor using the computed minimum and maximum values.
87137
88138
Args:
89-
x (torch.Tensor): The input tensor.
139+
x (torch.Tensor): The input tensor to be scaled.
90140
91141
Returns:
92142
torch.Tensor: The scaled tensor.
@@ -97,12 +147,12 @@ def transform(self, x):
97147
"""
98148
if not torch.is_tensor(x):
99149
raise TypeError("Input should be a torch tensor")
100-
if not hasattr(self, "min") or not hasattr(self, "max"):
150+
if self.min is None or self.max is None:
101151
raise RuntimeError("Must fit scaler before transforming data")
102152
x = (x - self.min) / (self.max - self.min + 1e-7)
103153
return x
104154

105-
def fit_transform(self, x):
155+
def fit_transform(self, x: torch.Tensor) -> torch.Tensor:
106156
"""
107157
Fit the scaler to the input tensor and then scale the tensor.
108158
@@ -115,7 +165,5 @@ def fit_transform(self, x):
115165
Raises:
116166
TypeError: If the input is not a torch tensor.
117167
"""
118-
if not torch.is_tensor(x):
119-
raise TypeError("Input should be a torch tensor")
120168
self.fit(x)
121169
return self.transform(x)

test/test_scaler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import torch
22
from spotPython.data.lightdatamodule import LightDataModule
3-
from spotPython.data.csvdataset import CSVDataset
43
from spotPython.utils.scaler import TorchStandardScaler, TorchMinMaxScaler
54
from spotPython.data.california_housing import CaliforniaHousing
65

6+
77
def test_standard_scaler():
88
"""
99
Test if TorchStandardScaler scales data around 0.
@@ -30,9 +30,10 @@ def test_standard_scaler():
3030
# Calculate the mean over all inputs
3131
mean_inputs = total_sum / total_count
3232
overall_mean = mean_inputs.mean()
33-
#assert that overall mean goes against zero
33+
# assert that the overall mean is close to zero
3434
assert overall_mean < 0.00001
35-
35+
36+
3637
def test_min_max_scaler():
3738
"""
3839
Test if TorchMinMaxScaler scales data between 0 and 1.
@@ -48,4 +49,3 @@ def test_min_max_scaler():
4849
for batch in loader():
4950
inputs, targets = batch
5051
assert torch.all(inputs >= 0) and torch.all(inputs <= 1), "Inputs are not scaled between 0 and 1"
51-

test/test_torch_minmax_scaler.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import pytest
2+
import torch
3+
from spotPython.utils.scaler import TorchMinMaxScaler
4+
5+
6+
def test_min_max_scaler_fit():
7+
"""Test the min and max values computed by the `fit` method."""
8+
tensor = torch.tensor([[2.0, 4.0], [1.0, 5.0], [3.0, 6.0]])
9+
expected_min = torch.tensor([[1.0, 4.0]])
10+
expected_max = torch.tensor([[3.0, 6.0]])
11+
12+
scaler = TorchMinMaxScaler()
13+
scaler.fit(tensor)
14+
15+
torch.testing.assert_allclose(scaler.min, expected_min)
16+
torch.testing.assert_allclose(scaler.max, expected_max)
17+
18+
19+
def test_min_max_scaler_transform():
20+
"""Test the output of the `transform` method."""
21+
tensor = torch.tensor([[2.0, 4.0], [1.0, 5.0], [3.0, 6.0]])
22+
scaler = TorchMinMaxScaler()
23+
scaler.fit(tensor)
24+
transformed = scaler.transform(tensor)
25+
26+
expected_transformed = torch.tensor([[0.5, 0.0], [0.0, 0.5], [1.0, 1.0]])
27+
28+
torch.testing.assert_allclose(transformed, expected_transformed)
29+
30+
31+
def test_min_max_scaler_fit_transform():
32+
"""Check that `fit_transform` method correctly fits and transforms the data."""
33+
tensor = torch.tensor([[2.0, 4.0], [1.0, 5.0], [3.0, 6.0]])
34+
scaler = TorchMinMaxScaler()
35+
transformed = scaler.fit_transform(tensor)
36+
37+
expected_transformed = torch.tensor([[0.5, 0.0], [0.0, 0.5], [1.0, 1.0]])
38+
39+
torch.testing.assert_allclose(transformed, expected_transformed)
40+
41+
42+
def test_input_validation():
43+
"""Ensure type error is raised with incorrect input type."""
44+
scaler = TorchMinMaxScaler()
45+
with pytest.raises(TypeError):
46+
scaler.fit([[1, 2], [3, 4]]) # Not a tensor, should raise error
47+
48+
49+
def test_transform_before_fit():
50+
"""Ensure appropriate error is raised when transform is called before fit."""
51+
scaler = TorchMinMaxScaler()
52+
with pytest.raises(RuntimeError):
53+
scaler.transform(torch.tensor([[2.0, 4.0], [1.0, 5.0]]))

test/test_torch_standard_scaler.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import pytest
2+
import torch
3+
from spotPython.utils.scaler import TorchStandardScaler
4+
5+
6+
def test_fit():
7+
"""Test the `fit` method for correct mean and std computation."""
8+
tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
9+
expected_mean = torch.tensor([[2.0, 3.0]])
10+
expected_std = torch.tensor([[1.0, 1.0]])
11+
12+
scaler = TorchStandardScaler()
13+
scaler.fit(tensor)
14+
15+
torch.testing.assert_allclose(scaler.mean, expected_mean)
16+
torch.testing.assert_allclose(scaler.std, expected_std, atol=1e-7, rtol=1e-7)
17+
18+
19+
def test_transform():
20+
"""Test the `transform` method for correct data scaling."""
21+
tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
22+
scaler = TorchStandardScaler()
23+
scaler.fit(tensor)
24+
transformed = scaler.transform(tensor)
25+
26+
expected_transformed = torch.tensor([[-1.0, -1.0], [1.0, 1.0]])
27+
28+
torch.testing.assert_allclose(transformed, expected_transformed)
29+
30+
31+
def test_fit_transform():
32+
"""Test the `fit_transform` method for combined fitting and transforming."""
33+
tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
34+
scaler = TorchStandardScaler()
35+
transformed = scaler.fit_transform(tensor)
36+
37+
expected_transformed = torch.tensor([[-1.0, -1.0], [1.0, 1.0]])
38+
39+
torch.testing.assert_allclose(transformed, expected_transformed)
40+
41+
42+
def test_input_not_tensor():
43+
"""Test that a TypeError is raised if the input data is not a tensor."""
44+
scaler = TorchStandardScaler()
45+
with pytest.raises(TypeError):
46+
scaler.fit([1.0, 2.0]) # Passing a list instead of a tensor
47+
48+
49+
def test_unfitted_transform():
50+
"""Test that a RuntimeError is raised if attempting to transform without fitting first."""
51+
tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
52+
scaler = TorchStandardScaler()
53+
54+
with pytest.raises(RuntimeError):
55+
scaler.transform(tensor)

0 commit comments

Comments
 (0)