Skip to content

Commit 4e1cb2f

Browse files
0.14.40
lightdatamodule.py returns correct predictor
1 parent 3a69d3f commit 4e1cb2f

5 files changed

Lines changed: 39 additions & 7 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,3 +320,4 @@ notebooks/00_spotPython_tests_files/libs/quarto-html/quarto.js
320320
notebooks/00_spotPython_tests_files/libs/quarto-html/tippy.css
321321
notebooks/00_spotPython_tests_files/libs/quarto-html/tippy.umd.min.js
322322
notebooks/00_spotPython_tests.html
323+
data.csv

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.14.38"
10+
version = "0.14.40"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/data/lightdatamodule.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,5 +298,4 @@ def predict_dataloader(self) -> DataLoader:
298298
# print(f"LightDataModule: predict_dataloader(). batch_size: {self.batch_size}")
299299
# print(f"LightDataModule: predict_dataloader(). num_workers: {self.num_workers}")
300300
# apply fit_transform to the val data
301-
return DataLoader(self.data_test, batch_size=self.batch_size, num_workers=self.num_workers)
302301
return DataLoader(self.data_predict, batch_size=len(self.data_predict), num_workers=self.num_workers)

test/test_lightdatamodule.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import pytest
21
import torch
32
from spotPython.data.lightdatamodule import LightDataModule
43
from spotPython.data.csvdataset import CSVDataset
@@ -44,7 +43,3 @@ def test_light_data_module_sizes():
4443
assert len(data_module.data_train) == 3
4544
assert len(data_module.data_val) == 3
4645
assert len(data_module.data_test) == 6
47-
48-
49-
if __name__ == "__main__":
50-
pytest.main(["-v", __file__])

test/test_lightdatamodule_class.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import pytest
2+
import torch
3+
from spotPython.data.lightdatamodule import LightDataModule
4+
from spotPython.data.csvdataset import CSVDataset
5+
6+
7+
class TestLightDataModule:
8+
@pytest.fixture
9+
def setup_data_module(self):
10+
# Setup the dataset and data module as per the provided code snippet.
11+
# Mock the data.csv file content
12+
csv_content = """feature1,feature2,prognosis
13+
1,2,0
14+
3,4,1
15+
5,6,0
16+
7,8,1
17+
9,10,0
18+
11,12,1
19+
13,14,0
20+
15,16,1
21+
17,18,0
22+
19,20,1
23+
21,22,0
24+
23,24,1"""
25+
26+
with open("data.csv", "w") as f:
27+
f.write(csv_content)
28+
29+
dataset = CSVDataset(csv_file="data.csv", target_column="prognosis", feature_type=torch.long)
30+
data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5)
31+
data_module.setup()
32+
return data_module
33+
34+
def test_predict_set_size(self, setup_data_module):
35+
data_module = setup_data_module
36+
assert len(data_module.data_predict) == 6, "Expected predict set size to be 6,"
37+
f"but got {len(data_module.data_predict)}"

0 commit comments

Comments
 (0)