Skip to content

Commit 98bfa54

Browse files
0.14.47
pkl updated
1 parent 7c29993 commit 98bfa54

7 files changed

Lines changed: 78 additions & 27 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.14.46"
10+
version = "0.14.47"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/data/csvdataset.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,3 +149,20 @@ def extra_repr(self) -> str:
149149
"""
150150
split = "Train" if self.train else "Test"
151151
return f"Split: {split}"
152+
153+
def __ncols__(self) -> int:
154+
"""
155+
Returns the number of columns in the dataset.
156+
157+
Returns:
158+
int: The number of columns in the dataset.
159+
160+
Examples:
161+
>>> from spotPython.data.pkldataset import PKLDataset
162+
import torch
163+
from torch.utils.data import DataLoader
164+
dataset = PKLDataset(target_column='prognosis', feature_type=torch.long)
165+
print(dataset.__ncols__())
166+
64
167+
"""
168+
return self.data.size(1)

src/spotPython/data/pkldataset.py

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@ class PKLDataset(Dataset):
3434
The directory where the pkl file is located.
3535
feature_type (torch.dtype):
3636
The data type of the features.
37+
Defaults to torch.float.
3738
target_column (str):
3839
The name of the target column.
3940
target_type (torch.dtype):
4041
The data type of the targets.
42+
Defaults to torch.float.
4143
train (bool):
4244
Whether the dataset is for training or not.
4345
rmNA (bool):
@@ -73,7 +75,7 @@ class PKLDataset(Dataset):
7375
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
7476
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
7577
[1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
76-
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0,
78+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0,
7779
0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],
7880
[1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,
7981
1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1,
@@ -104,9 +106,11 @@ def __init__(
104106
directory: None = None,
105107
feature_type: torch.dtype = torch.float,
106108
target_column: str = "y",
107-
target_type: torch.dtype = torch.long,
109+
target_type: torch.dtype = torch.float,
108110
train: bool = True,
109111
rmNA=True,
112+
oe=OrdinalEncoder(),
113+
le=LabelEncoder(),
110114
**desc,
111115
) -> None:
112116
super().__init__()
@@ -117,16 +121,15 @@ def __init__(
117121
self.target_column = target_column
118122
self.train = train
119123
self.rmNA = rmNA
124+
self.oe = oe
125+
self.le = le
120126
self.data, self.targets = self._load_data()
121127

122128
@property
123129
def path(self):
124-
# user defined directory:
125130
if self.directory:
126131
return pathlib.Path(self.directory).joinpath(self.filename)
127-
# no user defined directory, use package directory
128-
else:
129-
return pathlib.Path(__file__).parent.joinpath(self.filename)
132+
return pathlib.Path(__file__).parent.joinpath(self.filename)
130133

131134
@property
132135
def _repr_content(self):
@@ -135,26 +138,37 @@ def _repr_content(self):
135138
return content
136139

137140
def _load_data(self) -> tuple:
141+
# ensure that self.target_type and self.feature_type are the same torch types
142+
if self.target_type != self.feature_type:
143+
raise ValueError("target_type and feature_type must be the same torch type")
138144
with open(self.path, "rb") as f:
139145
df = pd.read_pickle(f)
140146
# rm rows with NA
141147
if self.rmNA:
142148
df = df.dropna()
143149

144-
oe = OrdinalEncoder()
145-
# Apply LabelEncoder to string columns
146-
le = LabelEncoder()
147-
# df = df.apply(lambda col: le.fit_transform(col) if col.dtypes == object else col)
148-
149150
# Split DataFrame into feature and target DataFrames
150151
feature_df = df.drop(columns=[self.target_column])
151-
feature_df = oe.fit_transform(feature_df)
152+
153+
# Identify non-numerical columns in the feature DataFrame
154+
non_numerical_columns = feature_df.select_dtypes(exclude=["number"]).columns.tolist()
155+
156+
# Apply OrdinalEncoder to non-numerical feature columns
157+
if non_numerical_columns:
158+
feature_df[non_numerical_columns] = self.oe.fit_transform(feature_df[non_numerical_columns])
159+
152160
target_df = df[self.target_column]
153-
target_df = le.fit_transform(target_df)
154161

155-
# Convert DataFrames to PyTorch tensors
156-
feature_tensor = torch.tensor(feature_df, dtype=self.feature_type)
157-
target_tensor = torch.tensor(target_df, dtype=self.target_type)
162+
# Check if the target column is non-numerical using dtype
163+
if not pd.api.types.is_numeric_dtype(target_df):
164+
target_df = self.le.fit_transform(target_df)
165+
166+
# Convert DataFrames to NumPy arrays and then to PyTorch tensors
167+
feature_array = feature_df.to_numpy()
168+
target_array = target_df
169+
170+
feature_tensor = torch.tensor(feature_array, dtype=self.feature_type)
171+
target_tensor = torch.tensor(target_array, dtype=self.target_type)
158172

159173
return feature_tensor, target_tensor
160174

@@ -214,3 +228,20 @@ def extra_repr(self) -> str:
214228
print(dataset)
215229
"""
216230
return "filename={}, directory={}".format(self.filename, self.directory)
231+
232+
def __ncols__(self) -> int:
233+
"""
234+
Returns the number of columns in the dataset.
235+
236+
Returns:
237+
int: The number of columns in the dataset.
238+
239+
Examples:
240+
>>> from spotPython.data.pkldataset import PKLDataset
241+
import torch
242+
from torch.utils.data import DataLoader
243+
dataset = PKLDataset(target_column='prognosis', feature_type=torch.long, target_type=torch.long)
244+
print(dataset.__ncols__())
245+
64
246+
"""
247+
return self.data.size(1)

src/spotPython/fun/hyperlight.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def fun(self, X: np.ndarray, fun_control: dict = None) -> np.ndarray:
146146
except Exception as err:
147147
if fun_control["verbosity"] > 0:
148148
print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
149+
pprint.pprint(fun_control)
149150
print(f"Error in fun(). Call to train_model failed. {err=}, {type(err)=}")
150151
print("Setting df_eval to np.nan\n")
151152
logger.error(f"Error in fun(). Call to train_model failed. {err=}, {type(err)=}")

src/spotPython/hyperdict/light_hyper_dict.json

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -349,17 +349,11 @@
349349
},
350350
"optimizer": {
351351
"levels": [
352-
"Adadelta",
353-
"Adagrad",
354352
"Adam",
355353
"AdamW",
356-
"SparseAdam",
357354
"Adamax",
358-
"ASGD",
359355
"NAdam",
360356
"RAdam",
361-
"RMSprop",
362-
"Rprop",
363357
"SGD"
364358
],
365359
"type": "factor",
@@ -393,16 +387,14 @@
393387
},
394388
"initialization": {
395389
"levels": [
396-
"Default",
397-
"Kaiming",
398-
"Xavier"
390+
"Default"
399391
],
400392
"type": "factor",
401393
"default": "Default",
402394
"transform": "None",
403395
"core_model_parameter_type": "str",
404396
"lower": 0,
405-
"upper": 2
397+
"upper": 0
406398
}
407399
}
408400
}

src/spotPython/spot/spot.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,14 @@ def write_db_dict(self) -> None:
696696
print("The following dictionaries are written to the json file spotPython_db.json:")
697697
print("fun_control:")
698698
pprint.pprint(fun_control)
699+
# check if all the values in the dictionary are of a JSON-serializable type
700+
for key in fun_control.keys():
701+
if not isinstance(fun_control[key], (int, float, str, list, dict)):
702+
# remove the key from the dictionary
703+
print(f"Removing non-serializable key: {key}")
704+
fun_control.pop(key)
705+
print("fun_control after removing non-serializabel keys:")
706+
pprint.pprint(fun_control)
699707
print("design_control:")
700708
pprint.pprint(design_control)
701709
print("optimizer_control:")

src/spotPython/utils/init.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,10 @@ def fun_control_init(
242242
The weight coefficient of the objective function. Positive values mean minimization.
243243
If set to -1, scores that are better when maximized will be minimized, e.g, accuracy.
244244
Can be an array, so that different weights can be used for different (multiple) objectives.
245+
Default is 1.0.
245246
weight_coeff (float):
246247
Determines how to weight older measures. Used in the OML algorithm eval_oml.py.
248+
Default is 0.0.
247249
weights_entry (str):
248250
The weights entry used in the GUI. Default is None.
249251

0 commit comments

Comments
 (0)