Skip to content

Commit 47138f9

Browse files
docs/examples
1 parent 3bed52d commit 47138f9

2 files changed

Lines changed: 18 additions & 28 deletions

File tree

src/spotPython/fun/hyperlight.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def check_X_shape(self, X: np.ndarray) -> np.ndarray:
7272
Exception:
7373
if the shape of the input array is not valid.
7474
75-
Example:
75+
Examples:
7676
>>> hyper_light = HyperLight(seed=126, log_level=50)
7777
>>> X = np.array([[1, 2], [3, 4]])
7878
>>> hyper_light.check_X_shape(X)
@@ -101,7 +101,7 @@ def fun(self, X: np.ndarray, fun_control: dict = None) -> np.ndarray:
101101
(np.ndarray):
102102
array containing the evaluation results.
103103
104-
Example:
104+
Examples:
105105
>>> hyper_light = HyperLight(seed=126, log_level=50)
106106
>>> X = np.array([[1, 2], [3, 4]])
107107
>>> fun_control = {"weights": np.array([1, 0, 0])}

src/spotPython/fun/hypersklearn.py

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ class HyperSklearn:
3131
rng (Generator): random number generator.
3232
fun_control (dict): dictionary containing control parameters for the function.
3333
log_level (int): log level for logger.
34+
Examples:
35+
>>> from spotPython.fun.hypersklearn import HyperSklearn
36+
>>> hyper_sklearn = HyperSklearn(seed=126, log_level=50)
37+
>>> print(hyper_sklearn.seed)
38+
126
3439
"""
3540

3641
def __init__(self, seed: int = 126, log_level: int = 50):
@@ -65,6 +70,16 @@ def check_X_shape(self, X: np.ndarray) -> None:
6570
6671
Raises:
6772
Exception: if the second dimension of X does not match the length of var_name in fun_control.
73+
Examples:
74+
>>> from spotPython.fun.hypersklearn import HyperSklearn
75+
>>> hyper_sklearn = HyperSklearn(seed=126, log_level=50)
76+
>>> hyper_sklearn.fun_control["var_name"] = ["a", "b", "c"]
77+
>>> hyper_sklearn.check_X_shape(X=np.array([[1, 2, 3]]))
78+
>>> hyper_sklearn.check_X_shape(X=np.array([[1, 2]]))
79+
Traceback (most recent call last):
80+
...
81+
Exception
82+
6883
"""
6984
try:
7085
X.shape[1]
@@ -84,6 +99,7 @@ def get_sklearn_df_eval_preds(self, model) -> tuple:
8499
85100
Raises:
86101
Exception: if call to evaluate_model fails.
102+
87103
"""
88104
try:
89105
df_eval, df_preds = self.evaluate_model(model, self.fun_control)
@@ -108,32 +124,6 @@ def fun_sklearn(self, X: np.ndarray, fun_control: dict = None) -> np.ndarray:
108124
Raises:
109125
Exception: if call to evaluate_model fails.
110126
111-
Example:
112-
>>> from sklearn.tree import DecisionTreeRegressor
113-
>>> from spotPython.data.load import load_data
114-
>>> from spotPython.hyperparameters.values import generate_var_dict_from_config_space
115-
>>> from spotPython.sklearn.traintest import split_data_for_hold_out
116-
>>> data = load_data("boston")
117-
>>> data_train, data_test = split_data_for_hold_out(data)
118-
>>> config_space = {
119-
... "max_depth": [3, 4],
120-
... "min_samples_split": [2],
121-
... "min_samples_leaf": [1],
122-
... }
123-
>>> var_dict = generate_var_dict_from_config_space(config_space)
124-
>>> var_name = list(var_dict.keys())
125-
>>> var_type = ["int"] * len(var_name)
126-
>>> fun_control = {
127-
... "data": data_train,
128-
... "var_name": var_name,
129-
... "var_type": var_type,
130-
... "core_model": DecisionTreeRegressor,
131-
... "eval": "train_hold_out",
132-
... }
133-
>>> h = HyperSklearn()
134-
>>> X = np.array([[3, 2, 1]])
135-
>>> h.fun_sklearn(X, fun_control)
136-
array([3.05555556])
137127
"""
138128
z_res = np.array([], dtype=float)
139129
self.fun_control.update(fun_control)

0 commit comments

Comments
 (0)