Skip to content

Commit 188809d

Browse files
0.9.10
set_data_module/set() removed encoder in cvs/pkl datasets improved
1 parent 535a41e commit 188809d

13 files changed

Lines changed: 151 additions & 430 deletions

notebooks/00_spotPython_tests.ipynb

Lines changed: 92 additions & 308 deletions
Large diffs are not rendered by default.

notebooks/testKriging.ipynb

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1943,7 +1943,7 @@
19431943
"outputs": [],
19441944
"source": [
19451945
"from spotPython.utils.init import fun_control_init\n",
1946-
"from spotPython.hyperparameters.values import set_data_module\n",
1946+
"from spotPython.hyperparameters.values import set_control_key_value\n",
19471947
"from spotPython.data.lightdatamodule import LightDataModule\n",
19481948
"from spotPython.data.csvdataset import CSVDataset\n",
19491949
"from spotPython.data.pkldataset import PKLDataset\n",
@@ -1952,8 +1952,9 @@
19521952
"dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)\n",
19531953
"dm = LightDataModule(dataset=dataset, batch_size=5, test_size=7)\n",
19541954
"dm.setup()\n",
1955-
"set_data_module(fun_control=fun_control,\n",
1956-
" data_module=dm)\n",
1955+
"set_control_key_value(control_dict=fun_control,\n",
1956+
" key=\"data_module\",\n",
1957+
" value=dm, replace=True)\n",
19571958
"assert isinstance(fun_control[\"data_module\"], LightDataModule)"
19581959
]
19591960
},

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.9.9"
10+
version = "0.9.10"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/data/csvdataset.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
import pandas as pd
33
from torch.utils.data import Dataset
4-
from sklearn.preprocessing import LabelEncoder
4+
from sklearn.preprocessing import LabelEncoder, OrdinalEncoder
55
import pathlib
66

77

@@ -85,17 +85,21 @@ def _load_data(self) -> tuple:
8585
df = df.dropna()
8686
if self.dropId:
8787
df = df.drop(columns=["id"])
88+
89+
oe = OrdinalEncoder()
8890
# Apply LabelEncoder to string columns
8991
le = LabelEncoder()
90-
df = df.apply(lambda col: le.fit_transform(col) if col.dtypes == object else col)
92+
# df = df.apply(lambda col: le.fit_transform(col) if col.dtypes == object else col)
9193

9294
# Split DataFrame into feature and target DataFrames
9395
feature_df = df.drop(columns=[self.target_column])
96+
feature_df = oe.fit_transform(feature_df)
9497
target_df = df[self.target_column]
98+
target_df = le.fit_transform(target_df)
9599

96100
# Convert DataFrames to PyTorch tensors
97-
feature_tensor = torch.tensor(feature_df.values, dtype=self.feature_type)
98-
target_tensor = torch.tensor(target_df.values, dtype=self.target_type)
101+
feature_tensor = torch.tensor(feature_df, dtype=self.feature_type)
102+
target_tensor = torch.tensor(target_df, dtype=self.target_type)
99103

100104
return feature_tensor, target_tensor
101105

src/spotPython/data/pkldataset.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
import pandas as pd
33
from torch.utils.data import Dataset
4-
from sklearn.preprocessing import LabelEncoder
4+
from sklearn.preprocessing import LabelEncoder, OrdinalEncoder
55
import pathlib
66

77

@@ -137,17 +137,21 @@ def _load_data(self) -> tuple:
137137
# rm rows with NA
138138
if self.rmNA:
139139
df = df.dropna()
140+
141+
oe = OrdinalEncoder()
140142
# Apply LabelEncoder to string columns
141143
le = LabelEncoder()
142-
df = df.apply(lambda col: le.fit_transform(col) if col.dtypes == object else col)
144+
# df = df.apply(lambda col: le.fit_transform(col) if col.dtypes == object else col)
143145

144146
# Split DataFrame into feature and target DataFrames
145147
feature_df = df.drop(columns=[self.target_column])
148+
feature_df = oe.fit_transform(feature_df)
146149
target_df = df[self.target_column]
150+
target_df = le.fit_transform(target_df)
147151

148152
# Convert DataFrames to PyTorch tensors
149-
feature_tensor = torch.tensor(feature_df.values, dtype=self.feature_type)
150-
target_tensor = torch.tensor(target_df.values, dtype=self.target_type)
153+
feature_tensor = torch.tensor(feature_df, dtype=self.feature_type)
154+
target_tensor = torch.tensor(target_df, dtype=self.target_type)
151155

152156
return feature_tensor, target_tensor
153157

src/spotPython/fun/hyperlight.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,15 @@ def fun(self, X: np.ndarray, fun_control: dict = None) -> np.ndarray:
107107
get_default_hyperparameters_as_array)
108108
from spotPython.fun.hyperlight import HyperLight
109109
from spotPython.data.diabetes import Diabetes
110-
from spotPython.hyperparameters.values import set_data_set
110+
from spotPython.hyperparameters.values import set_control_key_value
111111
import numpy as np
112112
fun_control = fun_control_init(
113113
_L_in=10,
114114
_L_out=1,)
115115
dataset = Diabetes()
116-
set_data_set(fun_control=fun_control,
117-
data_set=dataset)
116+
set_control_key_value(control_dict=fun_control,
117+
key="data_set",
118+
value=dataset)
118119
add_core_model_to_fun_control(core_model=NetLightRegression,
119120
fun_control=fun_control,
120121
hyper_dict=LightHyperDict)

src/spotPython/hyperparameters/values.py

Lines changed: 11 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ def assign_values(X: np.array, var_list: list) -> dict:
271271
def modify_hyper_parameter_levels(fun_control, hyperparameter, levels) -> None:
272272
"""
273273
This function modifies the levels of a hyperparameter in the fun_control dictionary.
274+
It also sets the lower and upper bounds of the hyperparameter to 0 and len(levels) - 1, respectively.
274275
275276
Args:
276277
fun_control (dict):
@@ -684,7 +685,8 @@ def get_values_from_dict(dictionary) -> np.array:
684685

685686

686687
def add_core_model_to_fun_control(core_model, fun_control, hyper_dict=None, filename=None) -> dict:
687-
"""Add the core model to the function control dictionary.
688+
"""Add the core model to the function control dictionary. It updates the keys "core_model",
689+
"core_model_hyper_dict", "var_type", "var_name" in the fun_control dictionary.
688690
689691
Args:
690692
core_model (class):
@@ -703,6 +705,14 @@ def add_core_model_to_fun_control(core_model, fun_control, hyper_dict=None, file
703705
(dict):
704706
The updated fun_control dictionary.
705707
708+
Notes:
709+
The function adds the following keys to the fun_control dictionary:
710+
"core_model": The core model.
711+
"core_model_hyper_dict": The hyper parameter dictionary for the core model.
712+
"var_type": A list of variable types.
713+
"var_name": A list of variable names.
714+
The original hyperparameters of the core model are stored in the "core_model_hyper_dict" key.
715+
706716
Examples:
707717
>>> from spotPython.light.netlightregressione import NetLightRegression
708718
from spotPython.hyperdict.light_hyper_dict import LightHyperDict
@@ -924,66 +934,6 @@ def get_default_hyperparameters_for_core_model(fun_control) -> dict:
924934
return values
925935

926936

927-
def set_data_set(fun_control, data_set) -> dict:
928-
"""
929-
This function sets the lightning dataset in the fun_control dictionary.
930-
931-
Args:
932-
fun_control (dict):
933-
fun_control dictionary
934-
data_set (class): Dataset class from torch.utils.data
935-
936-
Returns:
937-
fun_control (dict):
938-
updated fun_control
939-
940-
Examples:
941-
>>> from spotPython.utils.init import fun_control_init
942-
from spotPython.hyperparameters.values import set_data_module
943-
from spotPython.data.lightdatamodule import LightDataModule
944-
from spotPython.data.csvdataset import CSVDataset
945-
from spotPython.data.pkldataset import PKLDataset
946-
import torch
947-
fun_control = fun_control_init()
948-
ds = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)
949-
set_data_set(fun_control=fun_control,
950-
data_set=ds)
951-
fun_control["data_set"]
952-
"""
953-
fun_control.update({"data_set": data_set})
954-
955-
956-
def set_data_module(fun_control, data_module) -> dict:
957-
"""
958-
This function sets the lightning datamodule in the fun_control dictionary.
959-
960-
Args:
961-
fun_control (dict):
962-
fun_control dictionary
963-
data_module (class): DataLoader class from torch.utils.data
964-
965-
Returns:
966-
fun_control (dict):
967-
updated fun_control
968-
969-
Examples:
970-
>>> from spotPython.utils.init import fun_control_init
971-
from spotPython.hyperparameters.values import set_data_module
972-
from spotPython.data.lightdatamodule import LightDataModule
973-
from spotPython.data.csvdataset import CSVDataset
974-
from spotPython.data.pkldataset import PKLDataset
975-
import torch
976-
fun_control = fun_control_init()
977-
dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)
978-
dm = LightDataModule(dataset=dataset, batch_size=5, test_size=7)
979-
dm.setup()
980-
set_data_module(fun_control=fun_control,
981-
data_module=dm)
982-
fun_control["data_module"]
983-
"""
984-
fun_control.update({"data_module": data_module})
985-
986-
987937
def get_tuned_architecture(spot_tuner, fun_control) -> dict:
988938
"""
989939
Returns the tuned architecture.

src/spotPython/light/testmodel.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,17 @@ def test_model(config: dict, fun_control: dict) -> Tuple[float, float]:
3333
from spotPython.hyperparameters.values import (add_core_model_to_fun_control,
3434
get_default_hyperparameters_as_array)
3535
from spotPython.data.diabetes import Diabetes
36-
from spotPython.hyperparameters.values import set_data_set
36+
from spotPython.hyperparameters.values import set_control_key_value
3737
from spotPython.hyperparameters.values import (get_var_name, assign_values,
3838
generate_one_config_from_var_dict)
3939
import spotPython.light.testmodel as tm
4040
fun_control = fun_control_init(
4141
_L_in=10,
4242
_L_out=1,)
4343
dataset = Diabetes()
44-
set_data_set(fun_control=fun_control,
45-
data_set=dataset)
44+
set_control_key_value(control_dict=fun_control,
45+
key="data_set",
46+
value=dataset)
4647
add_core_model_to_fun_control(core_model=NetLightRegression,
4748
fun_control=fun_control,
4849
hyper_dict=LightHyperDict)

src/spotPython/light/trainmodel.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,17 @@ def train_model(config: dict, fun_control: dict) -> float:
2626
add_core_model_to_fun_control,
2727
get_default_hyperparameters_as_array)
2828
from spotPython.data.diabetes import Diabetes
29-
from spotPython.hyperparameters.values import set_data_set
29+
from spotPython.hyperparameters.values import set_control_key_value
3030
from spotPython.hyperparameters.values import get_var_name, assign_values, generate_one_config_from_var_dict
3131
from spotPython.light.traintest import train_model
3232
fun_control = fun_control_init(
3333
_L_in=10,
3434
_L_out=1,)
3535
# Select a dataset
3636
dataset = Diabetes()
37-
set_data_set(fun_control=fun_control,
38-
data_set=dataset)
37+
set_control_key_value(control_dict=fun_control,
38+
key="data_set",
39+
value=dataset)
3940
# Select a model
4041
add_core_model_to_fun_control(core_model=NetLightRegression,
4142
fun_control=fun_control,

test/test_hyper_light_fun.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from spotPython.hyperparameters.values import add_core_model_to_fun_control, get_default_hyperparameters_as_array
66
from spotPython.fun.hyperlight import HyperLight
77
from spotPython.data.diabetes import Diabetes
8-
from spotPython.hyperparameters.values import set_data_set
8+
from spotPython.hyperparameters.values import set_control_key_value
99
import numpy as np
1010

1111

@@ -15,8 +15,10 @@ def test_hyper_light_fun():
1515
_L_out=1,)
1616

1717
dataset = Diabetes()
18-
set_data_set(fun_control=fun_control,
19-
data_set=dataset)
18+
set_control_key_value(control_dict=fun_control,
19+
key="data_set",
20+
value=dataset)
21+
2022

2123
add_core_model_to_fun_control(core_model=NetLightRegression,
2224
fun_control=fun_control,

0 commit comments

Comments
 (0)