Skip to content

Commit 755446b

Browse files
v0.6.7
DATASET_PATH -> data_dir
1 parent 5516035 commit 755446b

4 files changed

Lines changed: 9 additions & 9 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.6.6"
10+
version = "0.6.7"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/light/crossvalidationdatamodule.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class CrossValidationDataModule(L.LightningDataModule):
1515
k (int): The fold number. Defaults to 1.
1616
split_seed (int): The random seed for splitting the data. Defaults to 42.
1717
num_splits (int): The number of splits for cross-validation. Defaults to 10.
18-
DATASET_PATH (str): The path to the dataset. Defaults to "./data".
18+
data_dir (str): The path to the dataset. Defaults to "./data".
1919
num_workers (int): The number of workers for data loading. Defaults to 0.
2020
pin_memory (bool): Whether to pin memory for data loading. Defaults to False.
2121
@@ -40,13 +40,13 @@ def __init__(
4040
k: int = 1,
4141
split_seed: int = 42,
4242
num_splits: int = 10,
43-
DATASET_PATH: str = "./data",
43+
data_dir: str = "./data",
4444
num_workers: int = 0,
4545
pin_memory: bool = False,
4646
):
4747
super().__init__()
4848
self.batch_size = batch_size
49-
self.DATASET_PATH = DATASET_PATH
49+
self.data_dir = data_dir
5050
self.num_workers = num_workers
5151
self.k = k
5252
self.split_seed = split_seed

src/spotPython/light/csvdatamodule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class CSVDataModule(L.LightningDataModule):
1010
1111
Args:
1212
batch_size (int): The size of the batch.
13-
DATASET_PATH (str): The path to the dataset. Defaults to "./data".
13+
data_dir (str): The path to the dataset. Defaults to "./data".
1414
num_workers (int): The number of workers for data loading. Defaults to 0.
1515
1616
Attributes:
@@ -19,7 +19,7 @@ class CSVDataModule(L.LightningDataModule):
1919
data_test (Dataset): The test dataset.
2020
"""
2121

22-
def __init__(self, batch_size: int, DATASET_PATH: str = "./data", num_workers: int = 0):
22+
def __init__(self, batch_size: int, data_dir: str = "./data", num_workers: int = 0):
2323
super().__init__()
2424
self.batch_size = batch_size
2525
self.num_workers = num_workers

src/spotPython/light/traintest.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def train_model(config: dict, fun_control: dict) -> float:
6262
dm = CSVDataModule(
6363
batch_size=config["batch_size"],
6464
num_workers=fun_control["num_workers"],
65-
DATASET_PATH=fun_control["DATASET_PATH"],
65+
data_dir=fun_control["DATASET_PATH"],
6666
)
6767

6868
# Init trainer
@@ -130,7 +130,7 @@ def test_model(config: dict, fun_control: dict) -> Tuple[float, float]:
130130
dm = CSVDataModule(
131131
batch_size=config["batch_size"],
132132
num_workers=fun_control["num_workers"],
133-
DATASET_PATH=fun_control["DATASET_PATH"],
133+
data_dir=fun_control["DATASET_PATH"],
134134
)
135135
# Init model from datamodule's attributes
136136
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out)
@@ -224,7 +224,7 @@ def cv_model(config: dict, fun_control: dict) -> float:
224224
num_splits=num_folds,
225225
split_seed=split_seed,
226226
batch_size=config["batch_size"],
227-
DATASET_PATH=fun_control["DATASET_PATH"],
227+
data_dir=fun_control["DATASET_PATH"],
228228
)
229229
dm.prepare_data()
230230
dm.setup()

0 commit comments

Comments
 (0)