Skip to content

Commit 740aca0

Browse files
0.10.20 documentation
Improved docs
1 parent 3267b96 commit 740aca0

2 files changed

Lines changed: 38 additions & 11 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.10.19"
10+
version = "0.10.20"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/data/lightdatamodule.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,18 @@ class LightDataModule(L.LightningDataModule):
99
A LightningDataModule for handling data.
1010
1111
Args:
12-
batch_size (int): The batch size.
13-
dataset (Dataset): The dataset.
14-
test_size (float): The test size. Defaults to 0.6.
15-
test_seed (int): The test seed. Defaults to 42.
16-
num_workers (int): The number of workers. Defaults to 0.
12+
batch_size (int):
13+
The batch size. Required.
14+
dataset (torch.utils.data.Dataset):
15+
The dataset from the torch.utils.data Dataset class.
16+
It must implement three methods: __init__, __len__, and __getitem__.
17+
Required.
18+
test_size (float):
19+
The test size. Defaults to 0.6.
20+
test_seed (int):
21+
The test seed. Defaults to 42.
22+
num_workers (int):
23+
The number of workers. Defaults to 0.
1724
1825
Attributes:
1926
batch_size (int): The batch size.
@@ -25,6 +32,18 @@ class LightDataModule(L.LightningDataModule):
2532
test_seed (int): The test seed.
2633
test_size (float): The test size.
2734
35+
Methods:
36+
prepare_data(self):
37+
Usually used for downloading the data. Here, it does nothing (the body is just `pass`).
38+
setup(self, stage: Optional[str] = None):
39+
Performs the training, validation, and test split.
40+
train_dataloader(self):
41+
Returns a DataLoader instance for the training set.
42+
val_dataloader(self):
43+
Returns a DataLoader instance for the validation set.
44+
test_dataloader(self):
45+
Returns a DataLoader instance for the test set.
46+
2847
Examples:
2948
>>> from spotPython.data.lightdatamodule import LightDataModule
3049
from spotPython.data.csvdataset import CSVDataset
@@ -55,10 +74,15 @@ def prepare_data(self) -> None:
5574

5675
def setup(self, stage: Optional[str] = None) -> None:
5776
"""
58-
Sets up the data for use.
77+
Splits the data for use in training, validation, and testing.
78+
Uses torch.utils.data.random_split() to split the data.
79+
Splitting is based on the test_size and test_seed.
80+
The test_size can be a float or an int.
5981
6082
Args:
61-
stage (Optional[str]): The current stage. Defaults to None.
83+
stage (Optional[str]):
84+
The current stage. Can be "fit" (for training and validation), "test" (testing),
85+
or None (for all three stages). Defaults to None.
6286
6387
Examples:
6488
>>> from spotPython.data.lightdatamodule import LightDataModule
@@ -101,7 +125,8 @@ def setup(self, stage: Optional[str] = None) -> None:
101125

102126
def train_dataloader(self) -> DataLoader:
103127
"""
104-
Returns the training dataloader.
128+
Returns the training dataloader, i.e., a PyTorch DataLoader instance
129+
using the training dataset.
105130
106131
Returns:
107132
DataLoader: The training dataloader.
@@ -125,7 +150,8 @@ def train_dataloader(self) -> DataLoader:
125150

126151
def val_dataloader(self) -> DataLoader:
127152
"""
128-
Returns the validation dataloader.
153+
Returns the validation dataloader, i.e., a PyTorch DataLoader instance
154+
using the validation dataset.
129155
130156
Returns:
131157
DataLoader: The validation dataloader.
@@ -146,7 +172,8 @@ def val_dataloader(self) -> DataLoader:
146172

147173
def test_dataloader(self) -> DataLoader:
148174
"""
149-
Returns the test dataloader.
175+
Returns the test dataloader, i.e., a PyTorch DataLoader instance
176+
using the test dataset.
150177
151178
Returns:
152179
DataLoader: The test dataloader.

0 commit comments

Comments
 (0)