Skip to content

Commit 9348ce4

Browse files
committed
update scaler doc
1 parent 0cc4f54 commit 9348ce4

1 file changed

Lines changed: 36 additions & 17 deletions

File tree

src/spotPython/utils/scaler.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@ class TorchStandardScaler:
99
def fit(self, x):
1010
"""
1111
Compute the mean and standard deviation of the input tensor.
12+
1213
Args:
1314
x (torch.Tensor): The input tensor.
15+
1416
Raises:
1517
TypeError: If the input is not a torch tensor.
1618
"""
@@ -22,10 +24,13 @@ def fit(self, x):
2224
def transform(self, x):
2325
"""
2426
Scale the input tensor using the computed mean and standard deviation.
27+
2528
Args:
2629
x (torch.Tensor): The input tensor.
30+
2731
Returns:
2832
torch.Tensor: The scaled tensor.
33+
2934
Raises:
3035
TypeError: If the input is not a torch tensor.
3136
RuntimeError: If the scaler has not been fitted before transforming data.
@@ -40,11 +45,18 @@ def transform(self, x):
4045
def fit_transform(self, x):
4146
"""
4247
Fit the scaler to the input tensor and then scale the tensor.
48+
4349
Args:
4450
x (torch.Tensor): The input tensor.
51+
4552
Returns:
4653
torch.Tensor: The scaled tensor.
54+
55+
Raises:
56+
TypeError: If the input is not a torch tensor.
4757
"""
58+
if not torch.is_tensor(x):
59+
raise TypeError("Input should be a torch tensor")
4860
self.fit(x)
4961
return self.transform(x)
5062

@@ -56,12 +68,13 @@ class TorchMinMaxScaler:
5668

5769
def fit(self, x):
5870
"""
59-
Fit the scaler to the input data.
71+
Compute the minimum and maximum value of the input tensor.
72+
6073
Parameters:
61-
- x: torch.Tensor
62-
The input data to fit the scaler to.
74+
x (torch.Tensor): The input tensor.
75+
6376
Raises:
64-
- TypeError: If the input is not a torch tensor.
77+
TypeError: If the input is not a torch tensor.
6578
"""
6679
if not torch.is_tensor(x):
6780
raise TypeError("Input should be a torch tensor")
@@ -70,15 +83,17 @@ def fit(self, x):
7083

7184
def transform(self, x):
7285
"""
73-
Transform the input data using the fitted scaler.
74-
Parameters:
75-
- x: torch.Tensor
76-
The input data to transform.
86+
Scale the input tensor using the computed minimum and maximum values.
87+
88+
Args:
89+
x (torch.Tensor): The input tensor.
90+
7791
Returns:
78-
- torch.Tensor: The transformed data.
92+
torch.Tensor: The scaled tensor.
93+
7994
Raises:
80-
- TypeError: If the input is not a torch tensor.
81-
- RuntimeError: If the scaler has not been fitted before transforming data.
95+
TypeError: If the input is not a torch tensor.
96+
RuntimeError: If the scaler has not been fitted before transforming data.
8297
"""
8398
if not torch.is_tensor(x):
8499
raise TypeError("Input should be a torch tensor")
@@ -89,14 +104,18 @@ def transform(self, x):
89104

90105
def fit_transform(self, x):
91106
"""
92-
Fit the scaler to the input data and transform it.
93-
Parameters:
94-
- x: torch.Tensor
95-
The input data to fit and transform.
107+
Fit the scaler to the input tensor and then scale the tensor.
108+
109+
Args:
110+
x (torch.Tensor): The input tensor.
111+
96112
Returns:
97-
- torch.Tensor: The transformed data.
113+
torch.Tensor: The scaled tensor.
114+
98115
Raises:
99-
- TypeError: If the input is not a torch tensor.
116+
TypeError: If the input is not a torch tensor.
100117
"""
118+
if not torch.is_tensor(x):
119+
raise TypeError("Input should be a torch tensor")
101120
self.fit(x)
102121
return self.transform(x)

0 commit comments

Comments
 (0)