|
1 | | -# TODO: Setup test for PKLDataset that can use the pkl file form data folder |
2 | | -# import pytest |
3 | | -# import torch |
4 | | -# from torch.utils.data import DataLoader |
5 | | -# from spotPython.data.pkldataset import PKLDataset |
| 1 | +import pytest |
| 2 | +import torch |
| 3 | +from torch.utils.data import DataLoader |
| 4 | +from spotPython.data.pkldataset import PKLDataset |
6 | 5 |
|
7 | 6 |
|
8 | | -# def test_pkl_dataset(): |
9 | | -# # Create an instance of PKLDataset for testing |
10 | | -# dataset = PKLDataset(target_column='prognosis') |
| 7 | +def test_pkl_dataset(): |
| 8 | + # Create an instance of PKLDataset for testing |
| 9 | + dataset = PKLDataset(target_column='prognosis') |
11 | 10 |
|
12 | | -# # Test the length of the dataset |
13 | | -# assert len(dataset) > 0 |
| 11 | + # Test the length of the dataset |
| 12 | + assert len(dataset) > 0 |
14 | 13 |
|
15 | | -# # Test __getitem__ |
16 | | -# idx = 0 |
17 | | -# sample = dataset[idx] |
18 | | -# assert isinstance(sample, tuple) |
19 | | -# assert len(sample) == 2 |
20 | | -# feature, target = sample |
21 | | -# assert isinstance(feature, torch.Tensor) |
22 | | -# assert isinstance(target, torch.Tensor) |
| 14 | + # Test __getitem__ |
| 15 | + idx = 0 |
| 16 | + sample = dataset[idx] |
| 17 | + assert isinstance(sample, tuple) |
| 18 | + assert len(sample) == 2 |
| 19 | + feature, target = sample |
| 20 | + assert isinstance(feature, torch.Tensor) |
| 21 | + assert isinstance(target, torch.Tensor) |
23 | 22 |
|
24 | | -# # Test DataLoader |
25 | | -# batch_size = 3 |
26 | | -# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) |
27 | | -# for batch in dataloader: |
28 | | -# inputs, targets = batch |
29 | | -# assert inputs.size(0) == batch_size |
30 | | -# assert targets.size(0) == batch_size |
31 | | -# break |
| 23 | + # Test DataLoader |
| 24 | + batch_size = 3 |
| 25 | + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) |
| 26 | + for batch in dataloader: |
| 27 | + inputs, targets = batch |
| 28 | + assert inputs.size(0) == batch_size |
| 29 | + assert targets.size(0) == batch_size |
| 30 | + break |
32 | 31 |
|
33 | 32 |
|
34 | | -# if __name__ == "__main__": |
35 | | -# pytest.main(["-v", __file__]) |
| 33 | +if __name__ == "__main__": |
| 34 | + pytest.main(["-v", __file__]) |
0 commit comments