@@ -32,6 +32,7 @@ def fun_control_init(
3232 log_level = 50 ,
3333 lower = None ,
3434 max_time = 1 ,
35+ metric_sklearn = None ,
3536 noise = False ,
3637 n_points = 1 ,
3738 num_workers = 0 ,
@@ -43,6 +44,7 @@ def fun_control_init(
4344 show_progress = True ,
4445 sigma = 0.0 ,
4546 surrogate = None ,
47+ target_column = None ,
4648 task = None ,
4749 test = None ,
4850 test_seed = 1234 ,
@@ -109,6 +111,8 @@ def fun_control_init(
109111 lower bound
110112 max_time (int):
111113 The maximum time in minutes.
114+ metric_sklearn (object):
115+ The metric object from the scikit-learn library. Default is None.
112116 noise (bool):
113117 Whether the objective function is noiy or not. Default is False.
114118 Affects the repeat of the function evaluations.
@@ -137,6 +141,8 @@ def fun_control_init(
137141 Currently only 1-dim functions are supported. Default is `False`.
138142 surrogate (object):
139143 The surrogate model object. Default is None.
144+ target_column (str):
145+ The name of the target column. Default is None.
140146 task (str):
141147 The task to perform. It can be either "classification" or "regression".
142148 Default is None.
@@ -288,7 +294,7 @@ def fun_control_init(
288294 "lower" : lower ,
289295 "max_time" : max_time ,
290296 "metric_river" : None ,
291- "metric_sklearn" : None ,
297+ "metric_sklearn" : metric_sklearn ,
292298 "metric_torch" : None ,
293299 "metric_params" : {},
294300 "model_dict" : {},
@@ -310,7 +316,7 @@ def fun_control_init(
310316 "sigma" : sigma ,
311317 "spot_tensorboard_path" : spot_tensorboard_path ,
312318 "spot_writer" : spot_writer ,
313- "target_column" : None ,
319+ "target_column" : target_column ,
314320 "task" : task ,
315321 "test" : test ,
316322 "test_seed" : test_seed ,
0 commit comments