44class TorchStandardScaler :
55 """
66 A class for scaling data using standardization with torch tensors.
7+ This scaler computes the mean and standard deviation on a dataset so that
8+ it can later be used to scale the data using the computed mean and standard deviation.
9+
10+ Attributes:
11+ mean (torch.Tensor): The mean value computed over the fitted data.
12+ std (torch.Tensor): The standard deviation computed over the fitted data.
13+
14+ Examples:
15+ >>> import torch
16+ >>> from spotPython.utils.scaler import TorchStandardScaler
17+ # Create a sample tensor
18+ >>> tensor = torch.rand((10, 3)) # Random tensor with shape (10, 3)
19+ >>> scaler = TorchStandardScaler()
20+ # Fit the scaler to the data
21+ >>> scaler.fit(tensor)
22+ # Transform the data using the fitted scaler
23+ >>> transformed_tensor = scaler.transform(tensor)
24+ >>> print(transformed_tensor)
25+ # Using fit_transform method to fit and transform in one step
26+ >>> another_tensor = torch.rand((10, 3))
27+ >>> scaled_tensor = scaler.fit_transform(another_tensor)
28+ >>> print(scaled_tensor)
729 """
830
9- def fit (self , x ):
31+ def __init__ (self ):
32+ """
33+ Initializes the TorchStandardScaler class without any pre-defined mean and std.
34+ """
35+ self .mean = None
36+ self .std = None
37+
38+ def fit (self , x : torch .Tensor ) -> None :
1039 """
1140 Compute the mean and standard deviation of the input tensor.
1241
1342 Args:
14- x (torch.Tensor): The input tensor.
43+ x (torch.Tensor): The input tensor, expected shape [n_samples, n_features]
1544
1645 Raises:
1746 TypeError: If the input is not a torch tensor.
1847 """
1948 if not torch .is_tensor (x ):
2049 raise TypeError ("Input should be a torch tensor" )
21- self .mean = x .mean (0 , keepdim = True )
22- self .std = x .std (0 , unbiased = False , keepdim = True )
50+ self .mean = x .mean (dim = 0 , keepdim = True )
51+ self .std = x .std (dim = 0 , unbiased = False , keepdim = True )
2352
24- def transform (self , x ) :
53+ def transform (self , x : torch . Tensor ) -> torch . Tensor :
2554 """
2655 Scale the input tensor using the computed mean and standard deviation.
2756
2857 Args:
29- x (torch.Tensor): The input tensor.
58+ x (torch.Tensor): The input tensor to be transformed, expected shape [n_samples, n_features]
3059
3160 Returns:
3261 torch.Tensor: The scaled tensor.
@@ -37,56 +66,77 @@ def transform(self, x):
3766 """
3867 if not torch .is_tensor (x ):
3968 raise TypeError ("Input should be a torch tensor" )
40- if not hasattr ( self , " mean" ) or not hasattr ( self , " std" ) :
69+ if self . mean is None or self . std is None :
4170 raise RuntimeError ("Must fit scaler before transforming data" )
4271 x = (x - self .mean ) / (self .std + 1e-7 )
4372 return x
4473
45- def fit_transform (self , x ) :
74+ def fit_transform (self , x : torch . Tensor ) -> torch . Tensor :
4675 """
4776 Fit the scaler to the input tensor and then scale the tensor.
4877
4978 Args:
50- x (torch.Tensor): The input tensor.
79+ x (torch.Tensor): The input tensor, expected shape [n_samples, n_features]
5180
5281 Returns:
5382 torch.Tensor: The scaled tensor.
54-
83+
5584 Raises:
5685 TypeError: If the input is not a torch tensor.
5786 """
58- if not torch .is_tensor (x ):
59- raise TypeError ("Input should be a torch tensor" )
6087 self .fit (x )
6188 return self .transform (x )
6289
6390
6491class TorchMinMaxScaler :
6592 """
6693 A class for scaling data using min-max normalization with PyTorch tensors.
94+ This scaler calculates the minimum and maximum values in the dataset to scale the data within a given range.
95+
96+ Attributes:
97+ min (torch.Tensor): The minimum values computed over the fitted data.
98+ max (torch.Tensor): The maximum values computed over the fitted data.
99+
100+ Examples:
101+ >>> import torch
102+ >>> from spotPython.utils.scaler import TorchMinMaxScaler
103+ >>> scaler = TorchMinMaxScaler()
104+ # Given a tensor
105+ >>> tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
106+ # Fit and transform the tensor using the scaler
107+ >>> scaled_tensor = scaler.fit_transform(tensor)
108+ >>> print(scaled_tensor)
109+ # The output will be a tensor with values scaled between 0 and 1.
67110 """
68111
69- def fit (self , x ):
112+ def __init__ (self ):
113+ """
114+ Initializes the TorchMinMaxScaler class without any predefined min and max.
115+ """
116+ self .min = None
117+ self .max = None
118+
119+ def fit (self , x : torch .Tensor ) -> None :
70120 """
71121 Compute the minimum and maximum value of the input tensor.
72122
73- Parameters :
123+ Args :
74124 x (torch.Tensor): The input tensor.
75125
76126 Raises:
77127 TypeError: If the input is not a torch tensor.
78128 """
79129 if not torch .is_tensor (x ):
80130 raise TypeError ("Input should be a torch tensor" )
81- self .min = x .min (0 , keepdim = True ).values
82- self .max = x .max (0 , keepdim = True ).values
131+ self .min = x .min (dim = 0 , keepdim = True ).values
132+ self .max = x .max (dim = 0 , keepdim = True ).values
83133
84- def transform (self , x ) :
134+ def transform (self , x : torch . Tensor ) -> torch . Tensor :
85135 """
86136 Scale the input tensor using the computed minimum and maximum values.
87137
88138 Args:
89- x (torch.Tensor): The input tensor.
139+ x (torch.Tensor): The input tensor to be scaled .
90140
91141 Returns:
92142 torch.Tensor: The scaled tensor.
@@ -97,12 +147,12 @@ def transform(self, x):
97147 """
98148 if not torch .is_tensor (x ):
99149 raise TypeError ("Input should be a torch tensor" )
100- if not hasattr ( self , " min" ) or not hasattr ( self , " max" ) :
150+ if self . min is None or self . max is None :
101151 raise RuntimeError ("Must fit scaler before transforming data" )
102152 x = (x - self .min ) / (self .max - self .min + 1e-7 )
103153 return x
104154
105- def fit_transform (self , x ) :
155+ def fit_transform (self , x : torch . Tensor ) -> torch . Tensor :
106156 """
107157 Fit the scaler to the input tensor and then scale the tensor.
108158
@@ -115,7 +165,5 @@ def fit_transform(self, x):
115165 Raises:
116166 TypeError: If the input is not a torch tensor.
117167 """
118- if not torch .is_tensor (x ):
119- raise TypeError ("Input should be a torch tensor" )
120168 self .fit (x )
121169 return self .transform (x )
0 commit comments