@@ -9,8 +9,10 @@ class TorchStandardScaler:
99 def fit (self , x ):
1010 """
1111 Compute the mean and standard deviation of the input tensor.
12+
1213 Args:
1314 x (torch.Tensor): The input tensor.
15+
1416 Raises:
1517 TypeError: If the input is not a torch tensor.
1618 """
@@ -22,10 +24,13 @@ def fit(self, x):
2224 def transform (self , x ):
2325 """
2426 Scale the input tensor using the computed mean and standard deviation.
27+
2528 Args:
2629 x (torch.Tensor): The input tensor.
30+
2731 Returns:
2832 torch.Tensor: The scaled tensor.
33+
2934 Raises:
3035 TypeError: If the input is not a torch tensor.
3136 RuntimeError: If the scaler has not been fitted before transforming data.
@@ -40,11 +45,18 @@ def transform(self, x):
4045 def fit_transform (self , x ):
4146 """
4247 Fit the scaler to the input tensor and then scale the tensor.
48+
4349 Args:
4450 x (torch.Tensor): The input tensor.
51+
4552 Returns:
4653 torch.Tensor: The scaled tensor.
54+
55+ Raises:
56+ TypeError: If the input is not a torch tensor.
4757 """
58+ if not torch .is_tensor (x ):
59+ raise TypeError ("Input should be a torch tensor" )
4860 self .fit (x )
4961 return self .transform (x )
5062
@@ -56,12 +68,13 @@ class TorchMinMaxScaler:
5668
5769 def fit (self , x ):
5870 """
59- Fit the scaler to the input data.
71+ Compute the minimum and maximum value of the input tensor.
72+
6073 Parameters:
61- - x: torch.Tensor
62- The input data to fit the scaler to.
74+ x ( torch.Tensor): The input tensor.
75+
6376 Raises:
64- - TypeError: If the input is not a torch tensor.
77+ TypeError: If the input is not a torch tensor.
6578 """
6679 if not torch .is_tensor (x ):
6780 raise TypeError ("Input should be a torch tensor" )
@@ -70,15 +83,17 @@ def fit(self, x):
7083
7184 def transform (self , x ):
7285 """
73- Transform the input data using the fitted scaler.
74- Parameters:
75- - x: torch.Tensor
76- The input data to transform.
86+ Scale the input tensor using the computed minimum and maximum values.
87+
88+ Args:
89+ x (torch.Tensor): The input tensor.
90+
7791 Returns:
78- - torch.Tensor: The transformed data.
92+ torch.Tensor: The scaled tensor.
93+
7994 Raises:
80- - TypeError: If the input is not a torch tensor.
81- - RuntimeError: If the scaler has not been fitted before transforming data.
95+ TypeError: If the input is not a torch tensor.
96+ RuntimeError: If the scaler has not been fitted before transforming data.
8297 """
8398 if not torch .is_tensor (x ):
8499 raise TypeError ("Input should be a torch tensor" )
@@ -89,14 +104,18 @@ def transform(self, x):
89104
90105 def fit_transform (self , x ):
91106 """
92- Fit the scaler to the input data and transform it.
93- Parameters:
94- - x: torch.Tensor
95- The input data to fit and transform.
107+ Fit the scaler to the input tensor and then scale the tensor.
108+
109+ Args:
110+ x (torch.Tensor): The input tensor.
111+
96112 Returns:
97- - torch.Tensor: The transformed data.
113+ torch.Tensor: The scaled tensor.
114+
98115 Raises:
99- - TypeError: If the input is not a torch tensor.
116+ TypeError: If the input is not a torch tensor.
100117 """
118+ if not torch .is_tensor (x ):
119+ raise TypeError ("Input should be a torch tensor" )
101120 self .fit (x )
102121 return self .transform (x )
0 commit comments