1- from multiprocessing import Pool , Manager
1+ from multiprocessing import Pool
22from joblib import Parallel , delayed
33import numpy as np
44from typing import Callable , Any , Union
5+ from spotpython .utils .seed import set_all_seeds
56
67
78def evaluate_row (row : Union [np .ndarray , list ], objective_function : Callable [[np .ndarray , Any ], Any ], fun_control : Any ) -> Any :
@@ -21,12 +22,16 @@ def evaluate_row(row: Union[np.ndarray, list], objective_function: Callable[[np.
2122 >>> from spotpython.utils.parallel import evaluate_row
2223 >>> import numpy as np
2324 >>> def sample_objective(row, control):
24- ... return sum( row) + control.get('offset', 0)
25+ ... return row + control.get('offset', 0)
2526 >>> row = [1, 2, 3]
2627 >>> fun_control = {'offset': 10}
2728 >>> evaluate_row(row, sample_objective, fun_control)
2829 array([11, 12, 13])
2930 """
31+ if fun_control is not None :
32+ if "seed" in fun_control :
33+ seed = fun_control ["seed" ]
34+ set_all_seeds (seed )
3035 return objective_function (np .array ([row ]), fun_control )
3136
3237
@@ -61,13 +66,11 @@ def parallel_objective_function(objective_function, X, num_cores, fun_control, m
6166 >>> parallel_objective_function(sample_objective, X, num_cores=2, fun_control=fun_control, method='joblib')
6267 array([16, 25, 34])
6368 """
64- with Manager () as manager :
65- shared_control = manager .dict (fun_control )
66- if method == "mp" :
67- with Pool (processes = num_cores ) as pool :
68- results = pool .starmap (evaluate_row , [(row , objective_function , shared_control ) for row in X ])
69- elif method == "joblib" :
70- results = Parallel (n_jobs = num_cores )(delayed (evaluate_row )(row , objective_function , shared_control ) for row in X )
69+ if method == "mp" :
70+ with Pool (processes = num_cores ) as pool :
71+ results = pool .starmap (evaluate_row , [(row , objective_function , fun_control ) for row in X ])
72+ elif method == "joblib" :
73+ results = Parallel (n_jobs = num_cores )(delayed (evaluate_row )(row , objective_function , fun_control ) for row in X )
7174
7275 return np .array (results ).flatten ()
7376
0 commit comments