|
4066 | 4066 | }, |
4067 | 4067 | { |
4068 | 4068 | "cell_type": "code", |
4069 | | - "execution_count": 25, |
| 4069 | + "execution_count": 28, |
4070 | 4070 | "metadata": {}, |
4071 | 4071 | "outputs": [ |
4072 | 4072 | { |
|
4079 | 4079 | "test_size: 0.5 used for predict dataset.\n", |
4080 | 4080 | "LightDataModule.train_dataloader(). data_train size: 5160\n" |
4081 | 4081 | ] |
| 4082 | + }, |
| 4083 | + { |
| 4084 | + "data": { |
| 4085 | + "text/plain": [ |
| 4086 | + "tensor([ 23.0493, 27.5234, 23.3288, 22.5529, 275.2078, 22.8845, 28.7669,\n", |
| 4087 | + " 0.8448], grad_fn=<AddBackward0>)" |
| 4088 | + ] |
| 4089 | + }, |
| 4090 | + "execution_count": 28, |
| 4091 | + "metadata": {}, |
| 4092 | + "output_type": "execute_result" |
4082 | 4093 | } |
4083 | 4094 | ], |
4084 | 4095 | "source": [ |
4085 | 4096 | "import torch\n", |
4086 | 4097 | "from torch.utils.data import DataLoader\n", |
4087 | 4098 | "from spotPython.data.lightdatamodule import LightDataModule\n", |
4088 | 4099 | "from spotPython.data.csvdataset import CSVDataset\n", |
4089 | | - "from spotPython.utils.scaler import TorchStandardScaler\n", |
| 4100 | + "from spotPython.utils.scaler import TorchStandardScaler, TorchMinMaxScaler\n", |
4090 | 4101 | "from spotPython.data.california_housing import CaliforniaHousing\n", |
4091 | 4102 | "\n", |
4092 | 4103 | "\n", |
4093 | 4104 | "dataset = CaliforniaHousing(feature_type=torch.float32, target_type=torch.float32)\n", |
4094 | | - "scaler = TorchStandardScaler()\n", |
| 4105 | + "scaler = TorchMinMaxScaler()\n", |
4095 | 4106 | "data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5, scaler=scaler)\n", |
4096 | 4107 | "data_module.setup()\n", |
4097 | 4108 | "\n", |
|
4103 | 4114 | "# Iterate over batches in the DataLoader\n", |
4104 | 4115 | "for batch in loader():\n", |
4105 | 4116 | " inputs, targets = batch\n", |
4106 | | - " if total_sum is None:\n", |
4107 | | - " total_sum = inputs.sum(dim=0)\n", |
4108 | | - " else:\n", |
4109 | | - " total_sum += inputs.sum(dim=0)\n", |
4110 | | - " total_count += inputs.shape[0]\n", |
| 4117 | + " \n", |
4111 | 4118 | "\n", |
4112 | | - "# Calculate the mean over all inputs\n", |
4113 | | - "mean_inputs = total_sum / total_count\n", |
4114 | | - "assert mean_inputs.mean() < 0.00001" |
| 4119 | + "total_sum\n" |
4115 | 4120 | ] |
4116 | 4121 | }, |
4117 | 4122 | { |
|
0 commit comments