88class CSVDataset (Dataset ):
99 """
1010 A PyTorch Dataset for handling CSV data.
11-
12- Args:
13- filename (str): The path to the CSV file. Defaults to "data.csv".
14- directory (str): The path to the directory where the CSV file is stored. Defaults to None.
15- feature_type (torch.dtype): The data type of the features. Defaults to torch.float.
16- target_column (str): The name of the target column. Defaults to "y".
17- target_type (torch.dtype): The data type of the targets. Defaults to torch.long.
18- train (bool): Whether the dataset is for training or not. Defaults to True.
19- rmNA (bool): Whether to remove rows with NA values or not. Defaults to True.
20- dropId (bool): Whether to drop the "id" column or not. Defaults to False.
21- **desc (Any): Additional keyword arguments.
22-
23- Attributes:
24- data (Tensor): The data features.
25- targets (Tensor): The data targets.
26-
27- Examples:
28- >>> from torch.utils.data import DataLoader
29- from spotPython.data.csvdataset import CSVDataset
30- import torch
31- dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)
32- # Set batch size for DataLoader
33- batch_size = 5
34- # Create DataLoader
35- dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
36- # Iterate over the data in the DataLoader
37- for batch in dataloader:
38- inputs, targets = batch
39- print(f"Batch Size: {inputs.size(0)}")
40- print("---------------")
41- print(f"Inputs: {inputs}")
42- print(f"Targets: {targets}")
4311 """
4412
4513 def __init__ (
@@ -48,10 +16,12 @@ def __init__(
4816 directory : None = None ,
4917 feature_type : torch .dtype = torch .float ,
5018 target_column : str = "y" ,
51- target_type : torch .dtype = torch .long ,
19+ target_type : torch .dtype = torch .float ,
5220 train : bool = True ,
5321 rmNA = True ,
5422 dropId = False ,
23+ oe = OrdinalEncoder (),
24+ le = LabelEncoder (),
5525 ** desc ,
5626 ) -> None :
5727 super ().__init__ ()
@@ -63,6 +33,8 @@ def __init__(
6333 self .train = train
6434 self .rmNA = rmNA
6535 self .dropId = dropId
36+ self .oe = oe
37+ self .le = le
6638 self .data , self .targets = self ._load_data ()
6739
6840 @property
@@ -78,30 +50,48 @@ def _repr_content(self):
7850 return content
7951
def _load_data(self) -> tuple:
    """Read the CSV file at ``self.path`` and convert it to tensors.

    Processing steps, in order:

    1. Read the CSV with ``pd.read_csv(self.path, index_col=False)``.
    2. Drop rows containing NA values when ``self.rmNA`` is true.
    3. Drop the ``"id"`` column when ``self.dropId`` is true and the
       column is present.
    4. Split off ``self.target_column``; every other column is a feature.
    5. Encode non-numerical feature columns with ``self.oe`` and a
       non-numerical target with ``self.le``.
    6. Build the feature tensor with dtype ``self.feature_type`` and the
       target tensor with dtype ``self.target_type``.

    Returns:
        tuple: ``(feature_tensor, target_tensor)``.

    Raises:
        ValueError: If non-numerical feature columns exist but ``self.oe``
            is None, or if the target column is non-numerical but
            ``self.le`` is None.
    """
    df = pd.read_csv(self.path, index_col=False)

    # Remove rows with NA if specified
    if self.rmNA:
        df = df.dropna()

    # Drop the id column if specified (silently skipped when absent)
    if self.dropId and "id" in df.columns:
        df = df.drop(columns=["id"])

    # Split DataFrame into feature and target parts
    feature_df = df.drop(columns=[self.target_column])
    target_df = df[self.target_column]

    # Only the non-numerical feature columns are passed to the encoder;
    # numerical columns keep their original values.
    non_numerical_columns = feature_df.select_dtypes(exclude=["number"]).columns.tolist()
    if non_numerical_columns:
        if self.oe is None:
            raise ValueError(
                f"\n!!! non_numerical_columns in data: {non_numerical_columns}"
                "\nOrdinalEncoder object oe must be provided for encoding non-numerical columns"
            )
        feature_df[non_numerical_columns] = self.oe.fit_transform(feature_df[non_numerical_columns])

    if not pd.api.types.is_numeric_dtype(target_df):
        if self.le is None:
            raise ValueError(
                f"\n!!! The target column '{self.target_column}' is non-numerical"
                "\nLabelEncoder object le must be provided for encoding non-numerical target"
            )
        target_array = self.le.fit_transform(target_df)
    else:
        # Convert to a plain NumPy array first. Handing the Series itself to
        # torch.tensor makes torch index it like a sequence (series[0], ...),
        # which is a label lookup and fails once dropna() has removed rows
        # and the index no longer starts at 0.
        target_array = target_df.to_numpy()

    feature_array = feature_df.to_numpy()

    feature_tensor = torch.tensor(feature_array, dtype=self.feature_type)
    target_tensor = torch.tensor(target_array, dtype=self.target_type)

    return feature_tensor, target_tensor
10797
0 commit comments