Skip to content

Commit 3f5b232

Browse files
committed
scaler test
1 parent 54489cf commit 3f5b232

2 files changed

Lines changed: 83 additions & 0 deletions

File tree

notebooks/00_spotPython_tests.ipynb

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4064,6 +4064,56 @@
40644064
"print(f\"S.y: {S.y}\")"
40654065
]
40664066
},
4067+
{
4068+
"cell_type": "code",
4069+
"execution_count": 25,
4070+
"metadata": {},
4071+
"outputs": [
4072+
{
4073+
"name": "stdout",
4074+
"output_type": "stream",
4075+
"text": [
4076+
"LightDataModule.setup(): stage: None\n",
4077+
"train_size: 0.25, val_size: 0.25 used for train & val data.\n",
4078+
"test_size: 0.5 used for test dataset.\n",
4079+
"test_size: 0.5 used for predict dataset.\n",
4080+
"LightDataModule.train_dataloader(). data_train size: 5160\n"
4081+
]
4082+
}
4083+
],
4084+
"source": [
4085+
"import torch\n",
4086+
"from torch.utils.data import DataLoader\n",
4087+
"from spotPython.data.lightdatamodule import LightDataModule\n",
4088+
"from spotPython.data.csvdataset import CSVDataset\n",
4089+
"from spotPython.utils.scaler import TorchStandardScaler\n",
4090+
"from spotPython.data.california_housing import CaliforniaHousing\n",
4091+
"\n",
4092+
"\n",
4093+
"dataset = CaliforniaHousing(feature_type=torch.float32, target_type=torch.float32)\n",
4094+
"scaler = TorchStandardScaler()\n",
4095+
"data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5, scaler=scaler)\n",
4096+
"data_module.setup()\n",
4097+
"\n",
4098+
"loader = data_module.train_dataloader\n",
4099+
"\n",
4100+
"total_sum = None\n",
4101+
"total_count = 0\n",
4102+
"\n",
4103+
"# Iterate over batches in the DataLoader\n",
4104+
"for batch in loader():\n",
4105+
" inputs, targets = batch\n",
4106+
" if total_sum is None:\n",
4107+
" total_sum = inputs.sum(dim=0)\n",
4108+
" else:\n",
4109+
" total_sum += inputs.sum(dim=0)\n",
4110+
" total_count += inputs.shape[0]\n",
4111+
"\n",
4112+
"# Calculate the mean over all inputs\n",
4113+
"mean_inputs = total_sum / total_count\n",
4114+
"assert mean_inputs.mean() < 0.00001"
4115+
]
4116+
},
40674117
{
40684118
"cell_type": "code",
40694119
"execution_count": null,

test/test_scaler.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import torch
2+
from spotPython.data.lightdatamodule import LightDataModule
3+
from spotPython.data.csvdataset import CSVDataset
4+
from spotPython.utils.scaler import TorchStandardScaler
5+
from spotPython.data.california_housing import CaliforniaHousing
6+
7+
def test_scaler():
    """Check that scaled training inputs are centred around zero.

    Builds a ``LightDataModule`` over the ``CaliforniaHousing`` dataset with a
    ``TorchStandardScaler`` attached, iterates over every batch of the
    training dataloader, and computes the per-feature mean of the inputs.
    The average of those per-feature means must lie within a small tolerance
    of zero in *either* direction.

    NOTE(review): this presumes the scaler standardizes the features served
    by ``train_dataloader()`` -- confirm against ``LightDataModule.setup()``.
    """
    tolerance = 0.00001

    dataset = CaliforniaHousing(feature_type=torch.float32, target_type=torch.float32)
    scaler = TorchStandardScaler()
    data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5, scaler=scaler)
    data_module.setup()

    total_sum = None
    total_count = 0

    # Accumulate the per-feature sums batch by batch; targets are not needed.
    for inputs, _ in data_module.train_dataloader():
        batch_sum = inputs.sum(dim=0)
        if total_sum is None:
            total_sum = batch_sum
        else:
            total_sum += batch_sum
        total_count += inputs.shape[0]

    # Fail with a clear message instead of a TypeError on `None / 0`
    # when the dataloader yields nothing.
    assert total_count > 0, "train_dataloader() yielded no samples"

    # Calculate the mean over all inputs
    mean_inputs = total_sum / total_count
    overall_mean = mean_inputs.mean().item()

    # The mean must be close to zero from both sides: a one-sided `<` check
    # would let an arbitrarily large negative mean pass.
    assert abs(overall_mean) < tolerance, (
        f"overall mean of scaled inputs is {overall_mean}, expected |mean| < {tolerance}"
    )
32+
33+

0 commit comments

Comments
 (0)