@@ -271,6 +271,7 @@ def assign_values(X: np.array, var_list: list) -> dict:
271271def modify_hyper_parameter_levels (fun_control , hyperparameter , levels ) -> None :
272272 """
273273 This function modifies the levels of a hyperparameter in the fun_control dictionary.
274+ It also sets the lower and upper bounds of the hyperparameter to 0 and len(levels) - 1, respectively.
274275
275276 Args:
276277 fun_control (dict):
@@ -684,7 +685,8 @@ def get_values_from_dict(dictionary) -> np.array:
684685
685686
686687def add_core_model_to_fun_control (core_model , fun_control , hyper_dict = None , filename = None ) -> dict :
687- """Add the core model to the function control dictionary.
688+ """Add the core model to the function control dictionary. It updates the keys "core_model",
689+ "core_model_hyper_dict", "var_type", "var_name" in the fun_control dictionary.
688690
689691 Args:
690692 core_model (class):
@@ -703,6 +705,14 @@ def add_core_model_to_fun_control(core_model, fun_control, hyper_dict=None, file
703705 (dict):
704706 The updated fun_control dictionary.
705707
708+ Notes:
709+ The function adds the following keys to the fun_control dictionary:
710+ "core_model": The core model.
711+ "core_model_hyper_dict": The hyper parameter dictionary for the core model.
712+ "var_type": A list of variable types.
713+ "var_name": A list of variable names.
714+ The original hyperparameters of the core model are stored in the "core_model_hyper_dict" key.
715+
706716 Examples:
707717 >>> from spotPython.light.netlightregressione import NetLightRegression
708718 from spotPython.hyperdict.light_hyper_dict import LightHyperDict
@@ -924,66 +934,6 @@ def get_default_hyperparameters_for_core_model(fun_control) -> dict:
924934 return values
925935
926936
927- def set_data_set (fun_control , data_set ) -> dict :
928- """
929- This function sets the lightning dataset in the fun_control dictionary.
930-
931- Args:
932- fun_control (dict):
933- fun_control dictionary
934- data_set (class): Dataset class from torch.utils.data
935-
936- Returns:
937- fun_control (dict):
938- updated fun_control
939-
940- Examples:
941- >>> from spotPython.utils.init import fun_control_init
942- from spotPython.hyperparameters.values import set_data_module
943- from spotPython.data.lightdatamodule import LightDataModule
944- from spotPython.data.csvdataset import CSVDataset
945- from spotPython.data.pkldataset import PKLDataset
946- import torch
947- fun_control = fun_control_init()
948- ds = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)
949- set_data_set(fun_control=fun_control,
950- data_set=ds)
951- fun_control["data_set"]
952- """
953- fun_control .update ({"data_set" : data_set })
954-
955-
956- def set_data_module (fun_control , data_module ) -> dict :
957- """
958- This function sets the lightning datamodule in the fun_control dictionary.
959-
960- Args:
961- fun_control (dict):
962- fun_control dictionary
963- data_module (class): DataLoader class from torch.utils.data
964-
965- Returns:
966- fun_control (dict):
967- updated fun_control
968-
969- Examples:
970- >>> from spotPython.utils.init import fun_control_init
971- from spotPython.hyperparameters.values import set_data_module
972- from spotPython.data.lightdatamodule import LightDataModule
973- from spotPython.data.csvdataset import CSVDataset
974- from spotPython.data.pkldataset import PKLDataset
975- import torch
976- fun_control = fun_control_init()
977- dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)
978- dm = LightDataModule(dataset=dataset, batch_size=5, test_size=7)
979- dm.setup()
980- set_data_module(fun_control=fun_control,
981- data_module=dm)
982- fun_control["data_module"]
983- """
984- fun_control .update ({"data_module" : data_module })
985-
986-
987937def get_tuned_architecture (spot_tuner , fun_control ) -> dict :
988938 """
989939 Returns the tuned architecture.
0 commit comments