|
3060 | 3060 | }, |
3061 | 3061 | { |
3062 | 3062 | "cell_type": "code", |
3063 | | - "execution_count": 14, |
| 3063 | + "execution_count": null, |
3064 | 3064 | "metadata": {}, |
3065 | 3065 | "outputs": [], |
3066 | 3066 | "source": [ |
|
3147 | 3147 | }, |
3148 | 3148 | { |
3149 | 3149 | "cell_type": "code", |
3150 | | - "execution_count": 15, |
| 3150 | + "execution_count": null, |
| 3151 | + "metadata": {}, |
| 3152 | + "outputs": [], |
| 3153 | + "source": [ |
| 3154 | + "test_file_save_load()" |
| 3155 | + ] |
| 3156 | + }, |
| 3157 | + { |
| 3158 | + "cell_type": "markdown", |
| 3159 | + "metadata": {}, |
| 3160 | + "source": [ |
| 3161 | + "# NetLightRegression2" |
| 3162 | + ] |
| 3163 | + }, |
| 3164 | + { |
| 3165 | + "cell_type": "code", |
| 3166 | + "execution_count": 6, |
3151 | 3167 | "metadata": {}, |
3152 | 3168 | "outputs": [ |
3153 | 3169 | { |
3154 | 3170 | "name": "stderr", |
3155 | 3171 | "output_type": "stream", |
3156 | 3172 | "text": [ |
3157 | | - "Seed set to 123\n" |
| 3173 | + "/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'act_fn' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['act_fn'])`.\n", |
| 3174 | + "GPU available: True (mps), used: True\n", |
| 3175 | + "TPU available: False, using: 0 TPU cores\n", |
| 3176 | + "IPU available: False, using: 0 IPUs\n", |
| 3177 | + "HPU available: False, using: 0 HPUs\n" |
3158 | 3178 | ] |
3159 | 3179 | }, |
3160 | 3180 | { |
3161 | 3181 | "name": "stdout", |
3162 | 3182 | "output_type": "stream", |
3163 | 3183 | "text": [ |
3164 | | - "Experiment saved as spot_braninexperiment.pickle\n" |
| 3184 | + "batch_x.shape: torch.Size([8, 10])\n", |
| 3185 | + "batch_y.shape: torch.Size([8])\n" |
3165 | 3186 | ] |
| 3187 | + }, |
| 3188 | + { |
| 3189 | + "name": "stderr", |
| 3190 | + "output_type": "stream", |
| 3191 | + "text": [ |
| 3192 | + "/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n", |
| 3193 | + "\n", |
| 3194 | + " | Name | Type | Params | In sizes | Out sizes\n", |
| 3195 | + "-------------------------------------------------------------\n", |
| 3196 | + "0 | layers | Sequential | 25.6 K | [8, 10] | [8, 1] \n", |
| 3197 | + "-------------------------------------------------------------\n", |
| 3198 | + "25.6 K Trainable params\n", |
| 3199 | + "0 Non-trainable params\n", |
| 3200 | + "25.6 K Total params\n", |
| 3201 | + "0.102 Total estimated model params size (MB)\n", |
| 3202 | + "/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n", |
| 3203 | + "/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (20) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n", |
| 3204 | + "`Trainer.fit` stopped: `max_epochs=10` reached.\n", |
| 3205 | + "/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n" |
| 3206 | + ] |
| 3207 | + }, |
| 3208 | + { |
| 3209 | + "data": { |
| 3210 | + "text/html": [ |
| 3211 | + "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", |
| 3212 | + "┃<span style=\"font-weight: bold\"> Validate metric </span>┃<span style=\"font-weight: bold\"> DataLoader 0 </span>┃\n", |
| 3213 | + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", |
| 3214 | + "│<span style=\"color: #008080; text-decoration-color: #008080\"> hp_metric </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 28803.052734375 </span>│\n", |
| 3215 | + "│<span style=\"color: #008080; text-decoration-color: #008080\"> val_loss </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 28803.052734375 </span>│\n", |
| 3216 | + "└───────────────────────────┴───────────────────────────┘\n", |
| 3217 | + "</pre>\n" |
| 3218 | + ], |
| 3219 | + "text/plain": [ |
| 3220 | + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", |
| 3221 | + "┃\u001b[1m \u001b[0m\u001b[1m Validate metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", |
| 3222 | + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", |
| 3223 | + "│\u001b[36m \u001b[0m\u001b[36m hp_metric \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 28803.052734375 \u001b[0m\u001b[35m \u001b[0m│\n", |
| 3224 | + "│\u001b[36m \u001b[0m\u001b[36m val_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 28803.052734375 \u001b[0m\u001b[35m \u001b[0m│\n", |
| 3225 | + "└───────────────────────────┴───────────────────────────┘\n" |
| 3226 | + ] |
| 3227 | + }, |
| 3228 | + "metadata": {}, |
| 3229 | + "output_type": "display_data" |
| 3230 | + }, |
| 3231 | + { |
| 3232 | + "name": "stderr", |
| 3233 | + "output_type": "stream", |
| 3234 | + "text": [ |
| 3235 | + "/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n" |
| 3236 | + ] |
| 3237 | + }, |
| 3238 | + { |
| 3239 | + "data": { |
| 3240 | + "text/html": [ |
| 3241 | + "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", |
| 3242 | + "┃<span style=\"font-weight: bold\"> Test metric </span>┃<span style=\"font-weight: bold\"> DataLoader 0 </span>┃\n", |
| 3243 | + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", |
| 3244 | + "│<span style=\"color: #008080; text-decoration-color: #008080\"> hp_metric </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 28280.533203125 </span>│\n", |
| 3245 | + "│<span style=\"color: #008080; text-decoration-color: #008080\"> val_loss </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 28280.533203125 </span>│\n", |
| 3246 | + "└───────────────────────────┴───────────────────────────┘\n", |
| 3247 | + "</pre>\n" |
| 3248 | + ], |
| 3249 | + "text/plain": [ |
| 3250 | + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", |
| 3251 | + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", |
| 3252 | + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", |
| 3253 | + "│\u001b[36m \u001b[0m\u001b[36m hp_metric \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 28280.533203125 \u001b[0m\u001b[35m \u001b[0m│\n", |
| 3254 | + "│\u001b[36m \u001b[0m\u001b[36m val_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 28280.533203125 \u001b[0m\u001b[35m \u001b[0m│\n", |
| 3255 | + "└───────────────────────────┴───────────────────────────┘\n" |
| 3256 | + ] |
| 3257 | + }, |
| 3258 | + "metadata": {}, |
| 3259 | + "output_type": "display_data" |
| 3260 | + }, |
| 3261 | + { |
| 3262 | + "data": { |
| 3263 | + "text/plain": [ |
| 3264 | + "[{'val_loss': 28280.533203125, 'hp_metric': 28280.533203125}]" |
| 3265 | + ] |
| 3266 | + }, |
| 3267 | + "execution_count": 6, |
| 3268 | + "metadata": {}, |
| 3269 | + "output_type": "execute_result" |
3166 | 3270 | } |
3167 | 3271 | ], |
3168 | 3272 | "source": [ |
3169 | | - "test_file_save_load()" |
| 3273 | + "from torch.utils.data import DataLoader\n", |
| 3274 | + "from spotPython.data.diabetes import Diabetes\n", |
| 3275 | + "from spotPython.light.regression.netlightregression2 import NetLightRegression2\n", |
| 3276 | + "from torch import nn\n", |
| 3277 | + "import lightning as L\n", |
| 3278 | + "import torch\n", |
| 3279 | + "BATCH_SIZE = 8\n", |
| 3280 | + "dataset = Diabetes()\n", |
| 3281 | + "train1_set, test_set = torch.utils.data.random_split(dataset, [0.6, 0.4])\n", |
| 3282 | + "train_set, val_set = torch.utils.data.random_split(train1_set, [0.6, 0.4])\n", |
| 3283 | + "train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, pin_memory=True)\n", |
| 3284 | + "test_loader = DataLoader(test_set, batch_size=BATCH_SIZE)\n", |
| 3285 | + "val_loader = DataLoader(val_set, batch_size=BATCH_SIZE)\n", |
| 3286 | + "batch_x, batch_y = next(iter(train_loader))\n", |
| 3287 | + "print(f\"batch_x.shape: {batch_x.shape}\")\n", |
| 3288 | + "print(f\"batch_y.shape: {batch_y.shape}\")\n", |
| 3289 | + "net_light_base = NetLightRegression2(l1=128,\n", |
| 3290 | + " epochs=10,\n", |
| 3291 | + " batch_size=BATCH_SIZE,\n", |
| 3292 | + " initialization='Default',\n", |
| 3293 | + " act_fn=nn.ReLU(),\n", |
| 3294 | + " optimizer='Adam',\n", |
| 3295 | + " dropout_prob=0.1,\n", |
| 3296 | + " lr_mult=0.1,\n", |
| 3297 | + " patience=5,\n", |
| 3298 | + " _L_in=10,\n", |
| 3299 | + " _L_out=1)\n", |
| 3300 | + "trainer = L.Trainer(max_epochs=10, enable_progress_bar=False)\n", |
| 3301 | + "trainer.fit(net_light_base, train_loader)\n", |
| 3302 | + "trainer.validate(net_light_base, val_loader)\n", |
| 3303 | + "trainer.test(net_light_base, test_loader)" |
| 3304 | + ] |
| 3305 | + }, |
| 3306 | + { |
| 3307 | + "cell_type": "markdown", |
| 3308 | + "metadata": {}, |
| 3309 | + "source": [ |
| 3310 | + "# LightDataModule" |
| 3311 | + ] |
| 3312 | + }, |
| 3313 | + { |
| 3314 | + "cell_type": "code", |
| 3315 | + "execution_count": 8, |
| 3316 | + "metadata": {}, |
| 3317 | + "outputs": [ |
| 3318 | + { |
| 3319 | + "name": "stdout", |
| 3320 | + "output_type": "stream", |
| 3321 | + "text": [ |
| 3322 | + "LightDataModule: setup(). stage: None\n", |
| 3323 | + "LightDataModule setup(): full_train_size: 0.5\n", |
| 3324 | + "LightDataModule setup(): val_size: 0.25\n", |
| 3325 | + "LightDataModule setup(): train_size: 0.25\n", |
| 3326 | + "LightDataModule setup(): test_size: 0.5\n", |
| 3327 | + "LightDataModule: setup(). stage: fit\n", |
| 3328 | + "LightDataModule: setup(). stage: test\n", |
| 3329 | + "LightDataModule: setup(). stage: predict\n", |
| 3330 | + "Training set size: 3\n", |
| 3331 | + "Validation set size: 3\n", |
| 3332 | + "Test set size: 6\n" |
| 3333 | + ] |
| 3334 | + } |
| 3335 | + ], |
| 3336 | + "source": [ |
| 3337 | + "from spotPython.data.lightdatamodule import LightDataModule\n", |
| 3338 | + "from spotPython.data.csvdataset import CSVDataset\n", |
| 3339 | + "import torch\n", |
| 3340 | + "# data.csv is a simple CSV file with 11 samples\n", |
| 3341 | + "dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)\n", |
| 3342 | + "data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5)\n", |
| 3343 | + "data_module.setup()\n", |
| 3344 | + "print(f\"Training set size: {len(data_module.data_train)}\")\n", |
| 3345 | + "print(f\"Validation set size: {len(data_module.data_val)}\")\n", |
| 3346 | + "print(f\"Test set size: {len(data_module.data_test)}\")" |
3170 | 3347 | ] |
3171 | 3348 | }, |
3172 | 3349 | { |
|
0 commit comments