Skip to content

Commit 6e73687

Browse files
csv and pkl datasets moved to data
1 parent 8226edc commit 6e73687

8 files changed

Lines changed: 267 additions & 46 deletions

File tree

MANIFEST.in

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
include src/spotPython/data/*.csv
22
include src/spotPython/data/*.json
3-
4-
3+
include src/spotPython/data/*.pkl

notebooks/00_spotPython_tests.ipynb

Lines changed: 99 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -137,30 +137,71 @@
137137
},
138138
{
139139
"cell_type": "code",
140-
"execution_count": null,
141-
"metadata": {},
142-
"outputs": [],
140+
"execution_count": 1,
141+
"metadata": {},
142+
"outputs": [
143+
{
144+
"name": "stdout",
145+
"output_type": "stream",
146+
"text": [
147+
"Loading data from /Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/spotPython/data/data.csv\n",
148+
"torch.Size([11, 64])\n",
149+
"torch.Size([11])\n"
150+
]
151+
}
152+
],
143153
"source": [
144-
"from spotPython.light.csvdataset import CSVDataset\n",
145-
"dataset = CSVDataset(csv_file='./data/spotPython/data.csv', target_column='prognosis')\n",
154+
"from spotPython.data.csvdataset import CSVDataset\n",
155+
"# dataset = CSVDataset(csv_file='./data/spotPython/data.csv', target_column='prognosis')\n",
156+
"dataset = CSVDataset(target_column='prognosis')\n",
146157
"print(dataset.data.shape)\n",
147158
"print(dataset.targets.shape) "
148159
]
149160
},
150161
{
151162
"cell_type": "code",
152-
"execution_count": null,
153-
"metadata": {},
154-
"outputs": [],
163+
"execution_count": 5,
164+
"metadata": {},
165+
"outputs": [
166+
{
167+
"data": {
168+
"text/plain": [
169+
"'Split: Train'"
170+
]
171+
},
172+
"execution_count": 5,
173+
"metadata": {},
174+
"output_type": "execute_result"
175+
}
176+
],
155177
"source": [
156178
"dataset.extra_repr()"
157179
]
158180
},
159181
{
160182
"cell_type": "code",
161-
"execution_count": null,
162-
"metadata": {},
163-
"outputs": [],
183+
"execution_count": 6,
184+
"metadata": {},
185+
"outputs": [
186+
{
187+
"name": "stdout",
188+
"output_type": "stream",
189+
"text": [
190+
"Batch Size: 3\n",
191+
"---------------\n",
192+
"Inputs: tensor([[1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,\n",
193+
" 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1,\n",
194+
" 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0],\n",
195+
" [0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0,\n",
196+
" 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1,\n",
197+
" 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
198+
" [1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1,\n",
199+
" 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0,\n",
200+
" 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1]])\n",
201+
"Targets: tensor([6, 8, 3])\n"
202+
]
203+
}
204+
],
164205
"source": [
165206
"from torch.utils.data import DataLoader\n",
166207
"# Set batch size for DataLoader\n",
@@ -203,9 +244,9 @@
203244
"metadata": {},
204245
"outputs": [],
205246
"source": [
206-
"from spotPython.light.csvdataset import CSVDataset\n",
247+
"from spotPython.data.csvdataset import CSVDataset\n",
207248
"import torch\n",
208-
"dataset = CSVDataset(csv_file='./data/spotPython/data.csv', target_column='prognosis', feature_type=torch.long)"
249+
"dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)"
209250
]
210251
},
211252
{
@@ -353,35 +394,60 @@
353394
},
354395
{
355396
"cell_type": "code",
356-
"execution_count": null,
397+
"execution_count": 4,
357398
"metadata": {},
358399
"outputs": [],
359400
"source": [
360-
"# from spotPython.light.pkldataset import PKLDataset\n",
361-
"# import torch\n",
362-
"# dataset = PKLDataset(pkl_file='./data/spotPython/data.pkl', target_column='prognosis', feature_type=torch.long)"
401+
"from spotPython.data.pkldataset import PKLDataset\n",
402+
"import torch\n",
403+
"dataset = PKLDataset(target_column='prognosis', feature_type=torch.long)"
363404
]
364405
},
365406
{
366407
"cell_type": "code",
367-
"execution_count": null,
368-
"metadata": {},
369-
"outputs": [],
408+
"execution_count": 3,
409+
"metadata": {},
410+
"outputs": [
411+
{
412+
"name": "stdout",
413+
"output_type": "stream",
414+
"text": [
415+
"Batch Size: 5\n",
416+
"---------------\n",
417+
"Inputs: tensor([[1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,\n",
418+
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
419+
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
420+
" [1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,\n",
421+
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0,\n",
422+
" 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],\n",
423+
" [1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,\n",
424+
" 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1,\n",
425+
" 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0],\n",
426+
" [1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0,\n",
427+
" 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
428+
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
429+
" [0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,\n",
430+
" 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,\n",
431+
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])\n",
432+
"Targets: tensor([ 0, 1, 6, 9, 10])\n"
433+
]
434+
}
435+
],
370436
"source": [
371-
"# from torch.utils.data import DataLoader\n",
372-
"# # Set batch size for DataLoader\n",
373-
"# batch_size = 5\n",
374-
"# # Create DataLoader\n",
375-
"# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)\n",
437+
"from torch.utils.data import DataLoader\n",
438+
"# Set batch size for DataLoader\n",
439+
"batch_size = 5\n",
440+
"# Create DataLoader\n",
441+
"dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)\n",
376442
"\n",
377-
"# # Iterate over the data in the DataLoader\n",
378-
"# for batch in dataloader:\n",
379-
"# inputs, targets = batch\n",
380-
"# print(f\"Batch Size: {inputs.size(0)}\")\n",
381-
"# print(\"---------------\")\n",
382-
"# print(f\"Inputs: {inputs}\")\n",
383-
"# print(f\"Targets: {targets}\")\n",
384-
"# break"
443+
"# Iterate over the data in the DataLoader\n",
444+
"for batch in dataloader:\n",
445+
" inputs, targets = batch\n",
446+
" print(f\"Batch Size: {inputs.size(0)}\")\n",
447+
" print(\"---------------\")\n",
448+
" print(f\"Inputs: {inputs}\")\n",
449+
" print(f\"Targets: {targets}\")\n",
450+
" break"
385451
]
386452
},
387453
{

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.6.37"
10+
version = "0.6.38"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]
Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,41 +2,82 @@
22
import pandas as pd
33
from torch.utils.data import Dataset
44
from sklearn.preprocessing import LabelEncoder
5+
import pathlib
56

67

78
class CSVDataset(Dataset):
89
"""
910
A PyTorch Dataset for handling CSV data.
1011
1112
Args:
12-
csv_file (str): The path to the CSV file. Defaults to "./data/spotPython/data.csv".
13+
filename (str): The path to the CSV file. Defaults to "data.csv".
14+
directory (str): The path to the directory where the CSV file is stored. Defaults to None.
15+
feature_type (torch.dtype): The data type of the features. Defaults to torch.float.
16+
target_column (str): The name of the target column. Defaults to "y".
17+
target_type (torch.dtype): The data type of the targets. Defaults to torch.long.
1318
train (bool): Whether the dataset is for training or not. Defaults to True.
19+
rmNA (bool): Whether to remove rows with NA values or not. Defaults to True.
20+
**desc: Additional keyword arguments.
1421
1522
Attributes:
1623
data (Tensor): The data features.
1724
targets (Tensor): The data targets.
25+
26+
Examples:
27+
>>> from torch.utils.data import DataLoader
28+
from spotPython.data.csvdataset import CSVDataset
29+
import torch
30+
dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)
31+
# Set batch size for DataLoader
32+
batch_size = 5
33+
# Create DataLoader
34+
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
35+
36+
# Iterate over the data in the DataLoader
37+
for batch in dataloader:
38+
inputs, targets = batch
39+
print(f"Batch Size: {inputs.size(0)}")
40+
print("---------------")
41+
print(f"Inputs: {inputs}")
42+
print(f"Targets: {targets}")
1843
"""
1944

2045
def __init__(
2146
self,
22-
csv_file: str = "./data/spotPython/data.csv",
47+
filename: str = "data.csv",
48+
directory: None = None,
2349
feature_type: torch.dtype = torch.float,
2450
target_column: str = "y",
2551
target_type: torch.dtype = torch.long,
2652
train: bool = True,
2753
rmNA=True,
54+
**desc,
2855
) -> None:
2956
super().__init__()
30-
self.csv_file = csv_file
57+
self.filename = filename
58+
self.directory = directory
3159
self.feature_type = feature_type
3260
self.target_type = target_type
3361
self.target_column = target_column
3462
self.train = train
3563
self.rmNA = rmNA
3664
self.data, self.targets = self._load_data()
3765

66+
@property
67+
def path(self):
68+
if self.directory:
69+
return pathlib.Path(self.directory).joinpath(self.filename)
70+
return pathlib.Path(__file__).parent.joinpath(self.filename)
71+
72+
@property
73+
def _repr_content(self):
74+
content = super()._repr_content
75+
content["Path"] = str(self.path)
76+
return content
77+
3878
def _load_data(self) -> tuple:
39-
df = pd.read_csv(self.csv_file, index_col=False)
79+
print(f"Loading data from {self.path}")
80+
df = pd.read_csv(self.path, index_col=False)
4081
# rm rows with NA
4182
if self.rmNA:
4283
df = df.dropna()
@@ -66,7 +107,7 @@ def __getitem__(self, idx: int) -> tuple:
66107
67108
Examples:
68109
>>> from spotPython.light.csvdataset import CSVDataset
69-
dataset = CSVDataset(csv_file='./data/spotPython/data.csv', target_column='prognosis')
110+
dataset = CSVDataset(filename='./data/spotPython/data.csv', target_column='prognosis')
70111
print(dataset.data.shape)
71112
print(dataset.targets.shape)
72113
torch.Size([11, 65])

src/spotPython/data/data.csv

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
sudden_fever,headache,mouth_bleed,nose_bleed,muscle_pain,joint_pain,vomiting,rash,diarrhea,hypotension,pleural_effusion,ascites,gastro_bleeding,swelling,nausea,chills,myalgia,digestion_trouble,fatigue,skin_lesions,stomach_pain,orbital_pain,neck_pain,weakness,back_pain,weight_loss,gum_bleed,jaundice,coma,diziness,inflammation,red_eyes,loss_of_appetite,urination_loss,slow_heart_rate,abdominal_pain,light_sensitivity,yellow_skin,yellow_eyes,facial_distortion,microcephaly,rigor,bitter_tongue,convulsion,anemia,cocacola_urine,hypoglycemia,prostraction,hyperpyrexia,stiff_neck,irritability,confusion,tremor,paralysis,lymph_swells,breathing_restriction,toe_inflammation,finger_inflammation,lips_irritation,itchiness,ulcers,toenail_loss,speech_problem,bullseye_rash,prognosis
2+
1,0,0,0,1,1,1,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,Chikungunya
3+
1,0,0,0,1,1,1,1,0,1,0,1,1,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,1,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0,1,0,0,Dengue
4+
1,1,1,1,0,1,0,1,0,1,0,0,1,1,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,0,1,0,1,0,0,0,0,0,Rift Valley fever
5+
1,1,0,1,1,0,0,0,0,0,0,0,1,0,1,0,0,0,1,0,1,0,1,0,1,0,1,1,0,0,1,1,1,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,Yellow Fever
6+
0,0,1,0,0,1,0,0,0,0,1,1,1,0,0,0,0,0,0,1,0,0,0,0,0,1,0,0,1,0,0,0,1,1,1,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,Zika
7+
1,1,0,1,1,0,0,1,1,1,1,1,1,0,0,0,0,0,0,1,0,1,0,1,0,1,0,0,0,0,1,1,0,1,1,0,1,1,0,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,Malaria
8+
0,0,0,1,1,1,1,0,0,0,0,0,1,0,1,1,1,0,1,0,1,0,0,1,0,0,1,0,1,1,1,0,0,0,0,1,1,0,1,0,0,1,0,0,0,1,0,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,0,0,Japanese encephalitis
9+
0,1,0,0,0,1,1,0,0,0,0,0,0,0,1,1,1,0,0,1,1,1,0,0,1,0,0,0,0,0,0,0,1,1,0,0,1,0,0,1,0,0,0,1,1,0,0,1,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,West Nile fever
10+
1,1,1,1,1,0,1,1,0,1,1,0,1,1,1,1,1,0,1,0,1,1,1,1,1,1,1,1,1,1,1,0,1,0,1,1,0,1,1,1,1,0,1,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,Plague
11+
0,1,0,0,0,0,0,1,1,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,1,0,0,1,0,0,0,1,0,0,0,0,0,1,0,0,0,0,1,1,1,0,0,Tungiasis
12+
1,1,1,0,0,1,1,1,0,0,1,1,0,0,1,1,1,1,1,0,1,1,1,1,1,0,1,1,1,0,1,0,1,0,1,1,0,1,1,0,0,0,1,0,1,1,1,0,1,1,1,1,1,1,0,1,1,0,0,1,1,0,1,1,Lyme disease
Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,41 +2,76 @@
22
import pandas as pd
33
from torch.utils.data import Dataset
44
from sklearn.preprocessing import LabelEncoder
5+
import pathlib
56

67

78
class PKLDataset(Dataset):
89
"""
910
A PyTorch Dataset for handling pickle (*.pkl) data.
1011
1112
Args:
12-
pkl_file (str): The path to the pkl file. Defaults to "./data/spotPython/data.pkl".
13+
filename (str): The path to the pkl file. Defaults to "data.pkl".
1314
train (bool): Whether the dataset is for training or not. Defaults to True.
1415
1516
Attributes:
1617
data (Tensor): The data features.
1718
targets (Tensor): The data targets.
19+
20+
Examples:
21+
>>> from spotPython.data.pkldataset import PKLDataset
22+
import torch
23+
from torch.utils.data import DataLoader
24+
dataset = PKLDataset(target_column='prognosis', feature_type=torch.long)
25+
# Set batch size for DataLoader
26+
batch_size = 5
27+
# Create DataLoader
28+
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
29+
30+
# Iterate over the data in the DataLoader
31+
for batch in dataloader:
32+
inputs, targets = batch
33+
print(f"Batch Size: {inputs.size(0)}")
34+
print("---------------")
35+
print(f"Inputs: {inputs}")
36+
print(f"Targets: {targets}")
37+
break
1838
"""
1939

2040
def __init__(
2141
self,
22-
pkl_file: str = "./data/spotPython/data.pkl",
42+
filename: str = "data.pkl",
43+
directory: None = None,
2344
feature_type: torch.dtype = torch.float,
2445
target_column: str = "y",
2546
target_type: torch.dtype = torch.long,
2647
train: bool = True,
2748
rmNA=True,
49+
**desc,
2850
) -> None:
2951
super().__init__()
30-
self.pkl_file = pkl_file
52+
self.filename = filename
53+
self.directory = directory
3154
self.feature_type = feature_type
3255
self.target_type = target_type
3356
self.target_column = target_column
3457
self.train = train
3558
self.rmNA = rmNA
3659
self.data, self.targets = self._load_data()
3760

61+
@property
62+
def path(self):
63+
if self.directory:
64+
return pathlib.Path(self.directory).joinpath(self.filename)
65+
return pathlib.Path(__file__).parent.joinpath(self.filename)
66+
67+
@property
68+
def _repr_content(self):
69+
content = super()._repr_content
70+
content["Path"] = str(self.path)
71+
return content
72+
3873
def _load_data(self) -> tuple:
39-
with open(self.pkl_file, "rb") as f:
74+
with open(self.path, "rb") as f:
4075
df = pd.read_pickle(f)
4176
# rm rows with NA
4277
if self.rmNA:
@@ -67,7 +102,7 @@ def __getitem__(self, idx: int) -> tuple:
67102
68103
Examples:
69104
>>> from spotPython.light.pkldataset import pklDataset
70-
dataset = pklDataset(pkl_file='./data/spotPython/data.pkl', target_column='prognosis')
105+
dataset = pklDataset(filename='./data/spotPython/data.pkl', target_column='prognosis')
71106
print(dataset.data.shape)
72107
print(dataset.targets.shape)
73108
torch.Size([11, 65])

0 commit comments

Comments
 (0)