Skip to content

Commit 28f774f

Browse files
pkl data added
1 parent 49c9d35 commit 28f774f

2 files changed

Lines changed: 28 additions & 32 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.6.41"
10+
version = "0.6.42"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]
@@ -64,9 +64,6 @@ include-package-data = true
6464
namespaces = true
6565
where = ["src"]
6666

67-
[tool.setuptools.package-data]
68-
spotPython = ["*.json", "*.csv", "*.pkl"]
69-
7067
[tool.black]
7168
line-length = 120
7269
target-version = ["py311"]

test/test_pkldataset.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,34 @@
1-
# TODO: Setup test for PKLDataset that can use the pkl file form data folder
2-
# import pytest
3-
# import torch
4-
# from torch.utils.data import DataLoader
5-
# from spotPython.data.pkldataset import PKLDataset
1+
import pytest
2+
import torch
3+
from torch.utils.data import DataLoader
4+
from spotPython.data.pkldataset import PKLDataset
65

76

8-
# def test_pkl_dataset():
9-
# # Create an instance of PKLDataset for testing
10-
# dataset = PKLDataset(target_column='prognosis')
7+
def test_pkl_dataset():
8+
# Create an instance of PKLDataset for testing
9+
dataset = PKLDataset(target_column='prognosis')
1110

12-
# # Test the length of the dataset
13-
# assert len(dataset) > 0
11+
# Test the length of the dataset
12+
assert len(dataset) > 0
1413

15-
# # Test __getitem__
16-
# idx = 0
17-
# sample = dataset[idx]
18-
# assert isinstance(sample, tuple)
19-
# assert len(sample) == 2
20-
# feature, target = sample
21-
# assert isinstance(feature, torch.Tensor)
22-
# assert isinstance(target, torch.Tensor)
14+
# Test __getitem__
15+
idx = 0
16+
sample = dataset[idx]
17+
assert isinstance(sample, tuple)
18+
assert len(sample) == 2
19+
feature, target = sample
20+
assert isinstance(feature, torch.Tensor)
21+
assert isinstance(target, torch.Tensor)
2322

24-
# # Test DataLoader
25-
# batch_size = 3
26-
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
27-
# for batch in dataloader:
28-
# inputs, targets = batch
29-
# assert inputs.size(0) == batch_size
30-
# assert targets.size(0) == batch_size
31-
# break
23+
# Test DataLoader
24+
batch_size = 3
25+
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
26+
for batch in dataloader:
27+
inputs, targets = batch
28+
assert inputs.size(0) == batch_size
29+
assert targets.size(0) == batch_size
30+
break
3231

3332

34-
# if __name__ == "__main__":
35-
# pytest.main(["-v", __file__])
33+
if __name__ == "__main__":
34+
pytest.main(["-v", __file__])

0 commit comments

Comments
 (0)