Skip to content

Commit 27db166

Browse files
0.24.20
manydataset
1 parent d2c4ab6 commit 27db166

2 files changed

Lines changed: 107 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.24.19"
10+
version = "0.24.20"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/data/manydataset.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
from torch.utils.data import Dataset
2+
import torch
3+
import pandas as pd
4+
from typing import List, Optional, Union
5+
6+
7+
class ManyToManyDataset(Dataset):
    """
    A PyTorch Dataset for many-to-many data.

    Every DataFrame in ``df_list`` is one sample. Its non-target columns become a
    2-D feature tensor (rows x feature columns) and its target column becomes a
    1-D target tensor with one value per row.

    Args:
        df_list (List[pd.DataFrame]): List of pandas DataFrames, one per sample.
        target (str): The target column name.
        drop (Optional[Union[str, List[str]]]): Column(s) to drop from the DataFrames.
            Columns that are not present in a DataFrame are ignored.
            Default is None, which drops nothing.
        dtype (torch.dtype): Data type for the tensors. Default is torch.float32.

    Attributes:
        data (List[pd.DataFrame]): List of pandas DataFrames with specified columns dropped.
        target (List[torch.Tensor]): List of target tensors.
        features (List[torch.Tensor]): List of feature tensors.

    Raises:
        KeyError: If ``target`` is not a column of a DataFrame after dropping.

    Examples:
        >>> import pandas as pd
        >>> from spotpython.data.manydataset import ManyToManyDataset
        >>> df1 = pd.DataFrame({'feature1': [1, 2], 'feature2': [3, 4], 'target': [5, 6]})
        >>> df2 = pd.DataFrame({'feature1': [7, 8], 'feature2': [9, 10], 'target': [11, 12]})
        >>> dataset = ManyToManyDataset([df1, df2], target='target', drop='feature2')
        >>> len(dataset)
        2
        >>> dataset[0]
        (tensor([[1.],
                [2.]]), tensor([5., 6.]))
    """

    def __init__(
        self,
        df_list: List[pd.DataFrame],
        target: str,
        drop: Optional[Union[str, List[str]]] = None,
        dtype: torch.dtype = torch.float32,
    ):
        if drop is None:
            # DataFrame.drop(None, axis=1) raises ValueError, so the default
            # must bypass the drop call entirely.
            self.data = list(df_list)
        else:
            # errors="ignore" skips missing labels per frame instead of
            # abandoning the drop for every frame when a single label is absent.
            self.data = [df.drop(columns=drop, errors="ignore") for df in df_list]
        self.target = [torch.tensor(df[target].to_numpy(), dtype=dtype) for df in self.data]
        self.features = [torch.tensor(df.drop(columns=[target]).to_numpy(), dtype=dtype) for df in self.data]

    def __getitem__(self, index: int):
        """Return the ``(features, target)`` tensor pair of the sample at ``index``."""
        x = self.features[index]
        y = self.target[index]
        return x, y

    def __len__(self) -> int:
        """Return the number of samples (DataFrames) in the dataset."""
        return len(self.data)
56+
57+
58+
class ManyToOneDataset(Dataset):
    """
    A PyTorch Dataset for many-to-one data.

    Every DataFrame in ``df_list`` is one sample. Its non-target columns become a
    2-D feature tensor (rows x feature columns), while the target is a single
    scalar tensor taken from the first row of the target column.

    Args:
        df_list (List[pd.DataFrame]): List of pandas DataFrames, one per sample.
        target (str): The target column name.
        drop (Optional[Union[str, List[str]]]): Column(s) to drop from the DataFrames.
            Columns that are not present in a DataFrame are ignored.
            Default is None, which drops nothing.
        dtype (torch.dtype): Data type for the tensors. Default is torch.float32.

    Attributes:
        data (List[pd.DataFrame]): List of pandas DataFrames with specified columns dropped.
        target (List[torch.Tensor]): List of scalar target tensors (first target value of each DataFrame).
        features (List[torch.Tensor]): List of feature tensors.

    Raises:
        KeyError: If ``target`` is not a column of a DataFrame after dropping.
        IndexError: If a DataFrame has no rows, so no first target value exists.

    Examples:
        >>> import pandas as pd
        >>> from spotpython.data.manydataset import ManyToOneDataset
        >>> df1 = pd.DataFrame({'feature1': [1, 2], 'feature2': [3, 4], 'target': [5, 6]})
        >>> df2 = pd.DataFrame({'feature1': [7, 8], 'feature2': [9, 10], 'target': [11, 12]})
        >>> dataset = ManyToOneDataset([df1, df2], target='target', drop='feature2')
        >>> len(dataset)
        2
        >>> dataset[0]
        (tensor([[1.],
                [2.]]), tensor(5.))
    """

    def __init__(
        self,
        df_list: List[pd.DataFrame],
        target: str,
        drop: Optional[Union[str, List[str]]] = None,
        dtype: torch.dtype = torch.float32,
    ):
        if drop is None:
            # DataFrame.drop(None, axis=1) raises ValueError, so the default
            # must bypass the drop call entirely.
            self.data = list(df_list)
        else:
            # errors="ignore" skips missing labels per frame instead of
            # abandoning the drop for every frame when a single label is absent.
            self.data = [df.drop(columns=drop, errors="ignore") for df in df_list]
        # Only the first row's target value represents the whole sample.
        self.target = [torch.tensor(df[target].to_numpy()[0], dtype=dtype) for df in self.data]
        self.features = [torch.tensor(df.drop(columns=[target]).to_numpy(), dtype=dtype) for df in self.data]

    def __getitem__(self, index: int):
        """Return the ``(features, target)`` tensor pair of the sample at ``index``."""
        x = self.features[index]
        y = self.target[index]
        return x, y

    def __len__(self) -> int:
        """Return the number of samples (DataFrames) in the dataset."""
        return len(self.data)

0 commit comments

Comments
 (0)