|
4064 | 4064 | "print(f\"S.y: {S.y}\")" |
4065 | 4065 | ] |
4066 | 4066 | }, |
| 4067 | + { |
| 4068 | + "cell_type": "code", |
| 4069 | + "execution_count": 25, |
| 4070 | + "metadata": {}, |
| 4071 | + "outputs": [ |
| 4072 | + { |
| 4073 | + "name": "stdout", |
| 4074 | + "output_type": "stream", |
| 4075 | + "text": [ |
| 4076 | + "LightDataModule.setup(): stage: None\n", |
| 4077 | + "train_size: 0.25, val_size: 0.25 used for train & val data.\n", |
| 4078 | + "test_size: 0.5 used for test dataset.\n", |
| 4079 | + "test_size: 0.5 used for predict dataset.\n", |
| 4080 | + "LightDataModule.train_dataloader(). data_train size: 5160\n" |
| 4081 | + ] |
| 4082 | + } |
| 4083 | + ], |
| 4084 | + "source": [ |
| 4085 | + "import torch\n", |
| 4086 | + "from torch.utils.data import DataLoader\n", |
| 4087 | + "from spotPython.data.lightdatamodule import LightDataModule\n", |
| 4088 | + "from spotPython.data.csvdataset import CSVDataset\n", |
| 4089 | + "from spotPython.utils.scaler import TorchStandardScaler\n", |
| 4090 | + "from spotPython.data.california_housing import CaliforniaHousing\n", |
| 4091 | + "\n", |
| 4092 | + "\n", |
| 4093 | + "dataset = CaliforniaHousing(feature_type=torch.float32, target_type=torch.float32)\n", |
| 4094 | + "scaler = TorchStandardScaler()\n", |
| 4095 | + "data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5, scaler=scaler)\n", |
| 4096 | + "data_module.setup()\n", |
| 4097 | + "\n", |
| 4098 | + "loader = data_module.train_dataloader\n", |
| 4099 | + "\n", |
| 4100 | + "total_sum = None\n", |
| 4101 | + "total_count = 0\n", |
| 4102 | + "\n", |
| 4103 | + "# Iterate over batches in the DataLoader\n", |
| 4104 | + "for batch in loader():\n", |
| 4105 | + " inputs, targets = batch\n", |
| 4106 | + " if total_sum is None:\n", |
| 4107 | + " total_sum = inputs.sum(dim=0)\n", |
| 4108 | + " else:\n", |
| 4109 | + " total_sum += inputs.sum(dim=0)\n", |
| 4110 | + " total_count += inputs.shape[0]\n", |
| 4111 | + "\n", |
| 4112 | + "# Calculate the mean over all inputs\n", |
| 4113 | + "mean_inputs = total_sum / total_count\n", |
| 4114 | + "assert mean_inputs.mean() < 0.00001" |
| 4115 | + ] |
| 4116 | + }, |
4067 | 4117 | { |
4068 | 4118 | "cell_type": "code", |
4069 | 4119 | "execution_count": null, |
|
0 commit comments