Skip to content

Commit b5e999d

Browse files
0.29.14
parallel
1 parent 3972356 commit b5e999d

3 files changed

Lines changed: 13 additions & 11 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.29.13"
10+
version = "0.29.14"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/utils/parallel.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from multiprocessing import Pool, Manager
1+
from multiprocessing import Pool
22
from joblib import Parallel, delayed
33
import numpy as np
44
from typing import Callable, Any, Union
5+
from spotpython.utils.seed import set_all_seeds
56

67

78
def evaluate_row(row: Union[np.ndarray, list], objective_function: Callable[[np.ndarray, Any], Any], fun_control: Any) -> Any:
@@ -21,12 +22,16 @@ def evaluate_row(row: Union[np.ndarray, list], objective_function: Callable[[np.
2122
>>> from spotpython.utils.parallel import evaluate_row
2223
>>> import numpy as np
2324
>>> def sample_objective(row, control):
24-
... return sum(row) + control.get('offset', 0)
25+
... return row + control.get('offset', 0)
2526
>>> row = [1, 2, 3]
2627
>>> fun_control = {'offset': 10}
2728
>>> evaluate_row(row, sample_objective, fun_control)
2829
array([11, 12, 13])
2930
"""
31+
if fun_control is not None:
32+
if "seed" in fun_control:
33+
seed = fun_control["seed"]
34+
set_all_seeds(seed)
3035
return objective_function(np.array([row]), fun_control)
3136

3237

@@ -61,13 +66,11 @@ def parallel_objective_function(objective_function, X, num_cores, fun_control, m
6166
>>> parallel_objective_function(sample_objective, X, num_cores=2, fun_control=fun_control, method='joblib')
6267
array([16, 25, 34])
6368
"""
64-
with Manager() as manager:
65-
shared_control = manager.dict(fun_control)
66-
if method == "mp":
67-
with Pool(processes=num_cores) as pool:
68-
results = pool.starmap(evaluate_row, [(row, objective_function, shared_control) for row in X])
69-
elif method == "joblib":
70-
results = Parallel(n_jobs=num_cores)(delayed(evaluate_row)(row, objective_function, shared_control) for row in X)
69+
if method == "mp":
70+
with Pool(processes=num_cores) as pool:
71+
results = pool.starmap(evaluate_row, [(row, objective_function, fun_control) for row in X])
72+
elif method == "joblib":
73+
results = Parallel(n_jobs=num_cores)(delayed(evaluate_row)(row, objective_function, fun_control) for row in X)
7174

7275
return np.array(results).flatten()
7376

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import pytest
21
import numpy as np
32
from spotpython.utils.parallel import evaluate_row
43

0 commit comments

Comments
 (0)