|
4119 | 4119 | "total_sum\n" |
4120 | 4120 | ] |
4121 | 4121 | }, |
| 4122 | + { |
| 4123 | + "cell_type": "code", |
| 4124 | + "execution_count": 1, |
| 4125 | + "metadata": {}, |
| 4126 | + "outputs": [ |
| 4127 | + { |
| 4128 | + "name": "stderr", |
| 4129 | + "output_type": "stream", |
| 4130 | + "text": [ |
| 4131 | + "/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'act_fn' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['act_fn'])`.\n", |
| 4132 | + "GPU available: True (mps), used: True\n", |
| 4133 | + "TPU available: False, using: 0 TPU cores\n", |
| 4134 | + "IPU available: False, using: 0 IPUs\n", |
| 4135 | + "HPU available: False, using: 0 HPUs\n", |
| 4136 | + "\n", |
| 4137 | + " | Name | Type | Params | In sizes | Out sizes\n", |
| 4138 | + "-------------------------------------------------------------\n", |
| 4139 | + "0 | layers | Sequential | 15.9 K | [8, 10] | [8, 1] \n", |
| 4140 | + "-------------------------------------------------------------\n", |
| 4141 | + "15.9 K Trainable params\n", |
| 4142 | + "0 Non-trainable params\n", |
| 4143 | + "15.9 K Total params\n", |
| 4144 | + "0.064 Total estimated model params size (MB)\n", |
| 4145 | + "/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n", |
| 4146 | + "/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n", |
| 4147 | + "`Trainer.fit` stopped: `max_epochs=2` reached.\n", |
| 4148 | + "/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.\n" |
| 4149 | + ] |
| 4150 | + }, |
| 4151 | + { |
| 4152 | + "data": { |
| 4153 | + "text/html": [ |
| 4154 | + "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", |
| 4155 | + "┃<span style=\"font-weight: bold\"> Test metric </span>┃<span style=\"font-weight: bold\"> DataLoader 0 </span>┃\n", |
| 4156 | + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", |
| 4157 | + "│<span style=\"color: #008080; text-decoration-color: #008080\"> hp_metric </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 29018.087890625 </span>│\n", |
| 4158 | + "│<span style=\"color: #008080; text-decoration-color: #008080\"> val_loss </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 29018.087890625 </span>│\n", |
| 4159 | + "└───────────────────────────┴───────────────────────────┘\n", |
| 4160 | + "</pre>\n" |
| 4161 | + ], |
| 4162 | + "text/plain": [ |
| 4163 | + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", |
| 4164 | + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", |
| 4165 | + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", |
| 4166 | + "│\u001b[36m \u001b[0m\u001b[36m hp_metric \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 29018.087890625 \u001b[0m\u001b[35m \u001b[0m│\n", |
| 4167 | + "│\u001b[36m \u001b[0m\u001b[36m val_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 29018.087890625 \u001b[0m\u001b[35m \u001b[0m│\n", |
| 4168 | + "└───────────────────────────┴───────────────────────────┘\n" |
| 4169 | + ] |
| 4170 | + }, |
| 4171 | + "metadata": {}, |
| 4172 | + "output_type": "display_data" |
| 4173 | + } |
| 4174 | + ], |
| 4175 | + "source": [ |
| 4176 | + "from torch.utils.data import DataLoader\n", |
| 4177 | + "from spotPython.data.diabetes import Diabetes\n", |
| 4178 | + "from spotPython.light.regression.netlightregression import NetLightRegression\n", |
| 4179 | + "from torch import nn\n", |
| 4180 | + "import lightning as L\n", |
| 4181 | + "\n", |
| 4182 | + "\n", |
| 4183 | + "def test_net_light_regression_class():\n", |
| 4184 | + " BATCH_SIZE = 8\n", |
| 4185 | + "\n", |
| 4186 | + " dataset = Diabetes()\n", |
| 4187 | + " train_loader = DataLoader(dataset, batch_size=BATCH_SIZE)\n", |
| 4188 | + " test_loader = DataLoader(dataset, batch_size=BATCH_SIZE)\n", |
| 4189 | + " val_loader = DataLoader(dataset, batch_size=BATCH_SIZE)\n", |
| 4190 | + "\n", |
| 4191 | + " net_light_regression = NetLightRegression(\n", |
| 4192 | + " l1=128,\n", |
| 4193 | + " epochs=10,\n", |
| 4194 | + " batch_size=BATCH_SIZE,\n", |
| 4195 | + " initialization=\"Default\",\n", |
| 4196 | + " act_fn=nn.ReLU(),\n", |
| 4197 | + " optimizer=\"Adam\",\n", |
| 4198 | + " dropout_prob=0.1,\n", |
| 4199 | + " lr_mult=0.1,\n", |
| 4200 | + " patience=5,\n", |
| 4201 | + " _L_in=10,\n", |
| 4202 | + " _L_out=1,\n", |
| 4203 | + " _torchmetric=\"mean_squared_error\",\n", |
| 4204 | + " )\n", |
| 4205 | + " trainer = L.Trainer(\n", |
| 4206 | + " max_epochs=2,\n", |
| 4207 | + " enable_progress_bar=False,\n", |
| 4208 | + " )\n", |
| 4209 | + " trainer.fit(net_light_regression, train_loader, val_loader)\n", |
| 4210 | + " res = trainer.test(net_light_regression, test_loader)\n", |
| 4211 | + " # test if the entry 'hp_metric' is in the res dict\n", |
| 4212 | + " assert \"hp_metric\" in res[0].keys()\n", |
| 4213 | + "\n", |
| 4214 | + "test_net_light_regression_class()" |
| 4215 | + ] |
| 4216 | + }, |
4122 | 4217 | { |
4123 | 4218 | "cell_type": "code", |
4124 | 4219 | "execution_count": null, |
|
4143 | 4238 | "name": "python", |
4144 | 4239 | "nbconvert_exporter": "python", |
4145 | 4240 | "pygments_lexer": "ipython3", |
4146 | | - "version": "3.11.8" |
| 4241 | + "version": "3.11.7" |
4147 | 4242 | } |
4148 | 4243 | }, |
4149 | 4244 | "nbformat": 4, |
|
0 commit comments