Skip to content

Commit 9028ee1

Browse files
committed
test min max scaler
1 parent 3f5b232 commit 9028ee1

2 files changed

Lines changed: 36 additions & 13 deletions

File tree

notebooks/00_spotPython_tests.ipynb

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4066,7 +4066,7 @@
40664066
},
40674067
{
40684068
"cell_type": "code",
4069-
"execution_count": 25,
4069+
"execution_count": 28,
40704070
"metadata": {},
40714071
"outputs": [
40724072
{
@@ -4079,19 +4079,30 @@
40794079
"test_size: 0.5 used for predict dataset.\n",
40804080
"LightDataModule.train_dataloader(). data_train size: 5160\n"
40814081
]
4082+
},
4083+
{
4084+
"data": {
4085+
"text/plain": [
4086+
"tensor([ 23.0493, 27.5234, 23.3288, 22.5529, 275.2078, 22.8845, 28.7669,\n",
4087+
" 0.8448], grad_fn=<AddBackward0>)"
4088+
]
4089+
},
4090+
"execution_count": 28,
4091+
"metadata": {},
4092+
"output_type": "execute_result"
40824093
}
40834094
],
40844095
"source": [
40854096
"import torch\n",
40864097
"from torch.utils.data import DataLoader\n",
40874098
"from spotPython.data.lightdatamodule import LightDataModule\n",
40884099
"from spotPython.data.csvdataset import CSVDataset\n",
4089-
"from spotPython.utils.scaler import TorchStandardScaler\n",
4100+
"from spotPython.utils.scaler import TorchStandardScaler, TorchMinMaxScaler\n",
40904101
"from spotPython.data.california_housing import CaliforniaHousing\n",
40914102
"\n",
40924103
"\n",
40934104
"dataset = CaliforniaHousing(feature_type=torch.float32, target_type=torch.float32)\n",
4094-
"scaler = TorchStandardScaler()\n",
4105+
"scaler = TorchMinMaxScaler()\n",
40954106
"data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5, scaler=scaler)\n",
40964107
"data_module.setup()\n",
40974108
"\n",
@@ -4103,15 +4114,9 @@
41034114
"# Iterate over batches in the DataLoader\n",
41044115
"for batch in loader():\n",
41054116
" inputs, targets = batch\n",
4106-
" if total_sum is None:\n",
4107-
" total_sum = inputs.sum(dim=0)\n",
4108-
" else:\n",
4109-
" total_sum += inputs.sum(dim=0)\n",
4110-
" total_count += inputs.shape[0]\n",
4117+
" \n",
41114118
"\n",
4112-
"# Calculate the mean over all inputs\n",
4113-
"mean_inputs = total_sum / total_count\n",
4114-
"assert mean_inputs.mean() < 0.00001"
4119+
"total_sum\n"
41154120
]
41164121
},
41174122
{

test/test_scaler.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import torch
22
from spotPython.data.lightdatamodule import LightDataModule
33
from spotPython.data.csvdataset import CSVDataset
4-
from spotPython.utils.scaler import TorchStandardScaler
4+
from spotPython.utils.scaler import TorchStandardScaler, TorchMinMaxScaler
55
from spotPython.data.california_housing import CaliforniaHousing
66

7-
def test_scaler():
7+
def test_standard_scaler():
8+
"""
9+
Test if TorchStandardScaler scales data around 0.
10+
"""
811
dataset = CaliforniaHousing(feature_type=torch.float32, target_type=torch.float32)
912
scaler = TorchStandardScaler()
1013
data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5, scaler=scaler)
@@ -30,4 +33,19 @@ def test_scaler():
3033
#assert that overall mean goes against zero
3134
assert overall_mean < 0.00001
3235

36+
def test_min_max_scaler():
37+
"""
38+
Test if TorchMinMaxScaler scales data between 0 and 1.
39+
"""
40+
dataset = CaliforniaHousing(feature_type=torch.float32, target_type=torch.float32)
41+
scaler = TorchMinMaxScaler()
42+
data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5, scaler=scaler)
43+
data_module.setup()
44+
45+
loader = data_module.train_dataloader
46+
47+
# Iterate over batches in the DataLoader
48+
for batch in loader():
49+
inputs, targets = batch
50+
assert torch.all(inputs >= 0) and torch.all(inputs <= 1), "Inputs are not scaled between 0 and 1"
3351

0 commit comments

Comments
 (0)