Skip to content

Commit 563a01d

Browse files
transformer
1 parent 4538c20 commit 563a01d

11 files changed

Lines changed: 708 additions & 63 deletions

notebooks/00_spotPython_tests.ipynb

Lines changed: 182 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3060,7 +3060,7 @@
30603060
},
30613061
{
30623062
"cell_type": "code",
3063-
"execution_count": 14,
3063+
"execution_count": null,
30643064
"metadata": {},
30653065
"outputs": [],
30663066
"source": [
@@ -3147,26 +3147,203 @@
31473147
},
31483148
{
31493149
"cell_type": "code",
3150-
"execution_count": 15,
3150+
"execution_count": null,
3151+
"metadata": {},
3152+
"outputs": [],
3153+
"source": [
3154+
"test_file_save_load()"
3155+
]
3156+
},
3157+
{
3158+
"cell_type": "markdown",
3159+
"metadata": {},
3160+
"source": [
3161+
"# Netlightregression2"
3162+
]
3163+
},
3164+
{
3165+
"cell_type": "code",
3166+
"execution_count": 6,
31513167
"metadata": {},
31523168
"outputs": [
31533169
{
31543170
"name": "stderr",
31553171
"output_type": "stream",
31563172
"text": [
3157-
"Seed set to 123\n"
3173+
"/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'act_fn' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['act_fn'])`.\n",
3174+
"GPU available: True (mps), used: True\n",
3175+
"TPU available: False, using: 0 TPU cores\n",
3176+
"IPU available: False, using: 0 IPUs\n",
3177+
"HPU available: False, using: 0 HPUs\n"
31583178
]
31593179
},
31603180
{
31613181
"name": "stdout",
31623182
"output_type": "stream",
31633183
"text": [
3164-
"Experiment saved as spot_braninexperiment.pickle\n"
3184+
"batch_x.shape: torch.Size([8, 10])\n",
3185+
"batch_y.shape: torch.Size([8])\n"
31653186
]
3187+
},
3188+
{
3189+
"name": "stderr",
3190+
"output_type": "stream",
3191+
"text": [
3192+
"/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.\n",
3193+
"\n",
3194+
" | Name | Type | Params | In sizes | Out sizes\n",
3195+
"-------------------------------------------------------------\n",
3196+
"0 | layers | Sequential | 25.6 K | [8, 10] | [8, 1] \n",
3197+
"-------------------------------------------------------------\n",
3198+
"25.6 K Trainable params\n",
3199+
"0 Non-trainable params\n",
3200+
"25.6 K Total params\n",
3201+
"0.102 Total estimated model params size (MB)\n",
3202+
"/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n",
3203+
"/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (20) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
3204+
"`Trainer.fit` stopped: `max_epochs=10` reached.\n",
3205+
"/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n"
3206+
]
3207+
},
3208+
{
3209+
"data": {
3210+
"text/html": [
3211+
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
3212+
"┃<span style=\"font-weight: bold\"> Validate metric </span>┃<span style=\"font-weight: bold\"> DataLoader 0 </span>┃\n",
3213+
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
3214+
"│<span style=\"color: #008080; text-decoration-color: #008080\"> hp_metric </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 28803.052734375 </span>│\n",
3215+
"│<span style=\"color: #008080; text-decoration-color: #008080\"> val_loss </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 28803.052734375 </span>│\n",
3216+
"└───────────────────────────┴───────────────────────────┘\n",
3217+
"</pre>\n"
3218+
],
3219+
"text/plain": [
3220+
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
3221+
"\u001b[1m \u001b[0m\u001b[1m Validate metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n",
3222+
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
3223+
"\u001b[36m \u001b[0m\u001b[36m hp_metric \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 28803.052734375 \u001b[0m\u001b[35m \u001b[0m│\n",
3224+
"\u001b[36m \u001b[0m\u001b[36m val_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 28803.052734375 \u001b[0m\u001b[35m \u001b[0m│\n",
3225+
"└───────────────────────────┴───────────────────────────┘\n"
3226+
]
3227+
},
3228+
"metadata": {},
3229+
"output_type": "display_data"
3230+
},
3231+
{
3232+
"name": "stderr",
3233+
"output_type": "stream",
3234+
"text": [
3235+
"/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n"
3236+
]
3237+
},
3238+
{
3239+
"data": {
3240+
"text/html": [
3241+
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
3242+
"┃<span style=\"font-weight: bold\"> Test metric </span>┃<span style=\"font-weight: bold\"> DataLoader 0 </span>┃\n",
3243+
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
3244+
"│<span style=\"color: #008080; text-decoration-color: #008080\"> hp_metric </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 28280.533203125 </span>│\n",
3245+
"│<span style=\"color: #008080; text-decoration-color: #008080\"> val_loss </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 28280.533203125 </span>│\n",
3246+
"└───────────────────────────┴───────────────────────────┘\n",
3247+
"</pre>\n"
3248+
],
3249+
"text/plain": [
3250+
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
3251+
"\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n",
3252+
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
3253+
"\u001b[36m \u001b[0m\u001b[36m hp_metric \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 28280.533203125 \u001b[0m\u001b[35m \u001b[0m│\n",
3254+
"\u001b[36m \u001b[0m\u001b[36m val_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 28280.533203125 \u001b[0m\u001b[35m \u001b[0m│\n",
3255+
"└───────────────────────────┴───────────────────────────┘\n"
3256+
]
3257+
},
3258+
"metadata": {},
3259+
"output_type": "display_data"
3260+
},
3261+
{
3262+
"data": {
3263+
"text/plain": [
3264+
"[{'val_loss': 28280.533203125, 'hp_metric': 28280.533203125}]"
3265+
]
3266+
},
3267+
"execution_count": 6,
3268+
"metadata": {},
3269+
"output_type": "execute_result"
31663270
}
31673271
],
31683272
"source": [
3169-
"test_file_save_load()"
3273+
"from torch.utils.data import DataLoader\n",
3274+
"from spotPython.data.diabetes import Diabetes\n",
3275+
"from spotPython.light.regression.netlightregression2 import NetLightRegression2\n",
3276+
"from torch import nn\n",
3277+
"import lightning as L\n",
3278+
"import torch\n",
3279+
"BATCH_SIZE = 8\n",
3280+
"dataset = Diabetes()\n",
3281+
"train1_set, test_set = torch.utils.data.random_split(dataset, [0.6, 0.4])\n",
3282+
"train_set, val_set = torch.utils.data.random_split(train1_set, [0.6, 0.4])\n",
3283+
"train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, pin_memory=True)\n",
3284+
"test_loader = DataLoader(test_set, batch_size=BATCH_SIZE)\n",
3285+
"val_loader = DataLoader(val_set, batch_size=BATCH_SIZE)\n",
3286+
"batch_x, batch_y = next(iter(train_loader))\n",
3287+
"print(f\"batch_x.shape: {batch_x.shape}\")\n",
3288+
"print(f\"batch_y.shape: {batch_y.shape}\")\n",
3289+
"net_light_base = NetLightRegression2(l1=128,\n",
3290+
" epochs=10,\n",
3291+
" batch_size=BATCH_SIZE,\n",
3292+
" initialization='Default',\n",
3293+
" act_fn=nn.ReLU(),\n",
3294+
" optimizer='Adam',\n",
3295+
" dropout_prob=0.1,\n",
3296+
" lr_mult=0.1,\n",
3297+
" patience=5,\n",
3298+
" _L_in=10,\n",
3299+
" _L_out=1)\n",
3300+
"trainer = L.Trainer(max_epochs=10, enable_progress_bar=False)\n",
3301+
"trainer.fit(net_light_base, train_loader)\n",
3302+
"trainer.validate(net_light_base, val_loader)\n",
3303+
"trainer.test(net_light_base, test_loader)"
3304+
]
3305+
},
3306+
{
3307+
"cell_type": "markdown",
3308+
"metadata": {},
3309+
"source": [
3310+
"# LightDataModule"
3311+
]
3312+
},
3313+
{
3314+
"cell_type": "code",
3315+
"execution_count": 8,
3316+
"metadata": {},
3317+
"outputs": [
3318+
{
3319+
"name": "stdout",
3320+
"output_type": "stream",
3321+
"text": [
3322+
"LightDataModule: setup(). stage: None\n",
3323+
"LightDataModule setup(): full_train_size: 0.5\n",
3324+
"LightDataModule setup(): val_size: 0.25\n",
3325+
"LightDataModule setup(): train_size: 0.25\n",
3326+
"LightDataModule setup(): test_size: 0.5\n",
3327+
"LightDataModule: setup(). stage: fit\n",
3328+
"LightDataModule: setup(). stage: test\n",
3329+
"LightDataModule: setup(). stage: predict\n",
3330+
"Training set size: 3\n",
3331+
"Validation set size: 3\n",
3332+
"Test set size: 6\n"
3333+
]
3334+
}
3335+
],
3336+
"source": [
3337+
"from spotPython.data.lightdatamodule import LightDataModule\n",
3338+
"from spotPython.data.csvdataset import CSVDataset\n",
3339+
"import torch\n",
3340+
"# data.csv is simple csv file with 11 samples\n",
3341+
"dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)\n",
3342+
"data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5)\n",
3343+
"data_module.setup()\n",
3344+
"print(f\"Training set size: {len(data_module.data_train)}\")\n",
3345+
"print(f\"Validation set size: {len(data_module.data_val)}\")\n",
3346+
"print(f\"Test set size: {len(data_module.data_test)}\")"
31703347
]
31713348
},
31723349
{

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.10.66"
10+
version = "0.10.67"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/data/lightdatamodule.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,12 +130,15 @@ def setup(self, stage: Optional[str] = None) -> None:
130130
# Assign train/val datasets for use in dataloaders
131131
if stage == "fit" or stage is None:
132132
print("LightDataModule: setup(). stage: fit")
133-
self.data_train, self.data_val, _ = random_split(self.data_full, [train_size, val_size, test_size])
133+
generator_fit = torch.Generator().manual_seed(self.test_seed)
134+
self.data_train, self.data_val, _ = random_split(
135+
self.data_full, [train_size, val_size, test_size], generator=generator_fit
136+
)
134137

135138
# Assign test dataset for use in dataloader(s)
136139
if stage == "test" or stage is None:
137140
print("LightDataModule: setup(). stage: test")
138-
# get test data aset as test_abs percent of the full dataset
141+
# get test data set as test_abs percent of the full dataset
139142
generator_test = torch.Generator().manual_seed(self.test_seed)
140143
self.data_test, _ = random_split(self.data_full, [test_size, full_train_size], generator=generator_test)
141144

@@ -151,7 +154,7 @@ def setup(self, stage: Optional[str] = None) -> None:
151154
# Assign pred dataset for use in dataloader(s)
152155
if stage == "predict" or stage is None:
153156
print("LightDataModule: setup(). stage: predict")
154-
# get test data aset as test_abs percent of the full dataset
157+
# get test data set as test_abs percent of the full dataset
155158
generator_predict = torch.Generator().manual_seed(self.test_seed)
156159
self.data_predict, _ = random_split(
157160
self.data_full, [test_size, full_train_size], generator=generator_predict
@@ -199,7 +202,7 @@ def val_dataloader(self) -> DataLoader:
199202
print(f"Validation set size: {len(data_module.data_val)}")
200203
Validation set size: 3
201204
"""
202-
print(f"LightDataModule: val_dataloader(). Training set size: {len(self.data_val)}")
205+
print(f"LightDataModule: val_dataloader(). Validation set size: {len(self.data_val)}")
203206
print(f"LightDataModule: val_dataloader(). batch_size: {self.batch_size}")
204207
print(f"LightDataModule: val_dataloader(). num_workers: {self.num_workers}")
205208
return DataLoader(self.data_val, batch_size=self.batch_size, num_workers=self.num_workers)

0 commit comments

Comments
 (0)