Skip to content

Commit e45b87a

Browse files
0.27.15
1 parent 2907c4d commit e45b87a

2 files changed

Lines changed: 96 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.27.14"
10+
version = "0.27.15"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/utils/parallel.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from multiprocessing import Pool, Manager
2+
from joblib import Parallel, delayed
3+
import numpy as np
4+
from typing import Callable, Any, Union
5+
6+
def evaluate_row(row: Union[np.ndarray, list], objective_function: Callable[[np.ndarray, Any], Any], fun_control: Any) -> Any:
7+
"""
8+
Evaluates a single row using the provided objective function.
9+
10+
Args:
11+
row (array-like): The input data for the row to be evaluated.
12+
objective_function (callable): A function that computes the objective value.
13+
It should accept a NumPy array and an additional control parameter.
14+
fun_control (any): Additional control parameter to be passed to the objective function.
15+
16+
Returns:
17+
The result of the objective function applied to the row.
18+
19+
Examples:
20+
>>> from spotpython.utils.parallel import evaluate_row
21+
>>> import numpy as np
22+
>>> def sample_objective(row, control):
23+
... return sum(row) + control.get('offset', 0)
24+
>>> row = [1, 2, 3]
25+
>>> fun_control = {'offset': 10}
26+
>>> evaluate_row(row, sample_objective, fun_control)
27+
array([11, 12, 13])
28+
"""
29+
return objective_function(np.array([row]), fun_control)
30+
31+
def parallel_objective_function(objective_function, X, num_cores, fun_control, method)->np.ndarray:
32+
"""
33+
Executes an objective function in parallel using either multiprocessing or joblib.
34+
Args:
35+
objective_function (callable): The function to be evaluated for each row in `X`.
36+
X (iterable): The input data, where each element represents a row to be processed.
37+
num_cores (int): The number of CPU cores to use for parallel processing.
38+
fun_control (dict): A dictionary of shared control parameters for the objective function.
39+
method (str): The parallelization method to use. Options are:
40+
- 'mp': Use Python's multiprocessing module.
41+
- 'joblib': Use the joblib library.
42+
Returns:
43+
numpy.ndarray: A flattened array of results obtained by applying the objective function to each row in `X`.
44+
Raises:
45+
ValueError: If an unsupported `method` is provided.
46+
Examples:
47+
>>> from spotpython.utils.parallel import parallel_objective_function
48+
>>> import numpy as np
49+
>>> def sample_objective(row, control):
50+
... return sum(row) + control.get('offset', 0)
51+
>>> X = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
52+
>>> fun_control = {'offset': 10}
53+
>>> parallel_objective_function(sample_objective, X, num_cores=2, fun_control=fun_control, method='mp')
54+
array([16, 25, 34])
55+
>>> parallel_objective_function(sample_objective, X, num_cores=2, fun_control=fun_control, method='joblib')
56+
array([16, 25, 34])
57+
"""
58+
with Manager() as manager:
59+
shared_control = manager.dict(fun_control)
60+
if method=='mp':
61+
with Pool(processes=num_cores) as pool:
62+
results = pool.starmap(evaluate_row, [(row, objective_function, shared_control) for row in X])
63+
elif method=='joblib':
64+
results = Parallel(n_jobs=num_cores)(delayed(evaluate_row)(row, objective_function, shared_control) for row in X)
65+
66+
return np.array(results).flatten()
67+
68+
def make_parallel(obj_func, num_cores, method='mp')->Callable:
69+
"""
70+
Creates a parallelized wrapper function for the given objective function.
71+
Args:
72+
obj_func (callable): The objective function to be parallelized.
73+
It should accept the same arguments as the wrapper function.
74+
num_cores (int): The number of cores to use for parallel processing.
75+
method (str, optional): The parallelization method to use.
76+
Defaults to 'mp' (multiprocessing). Other methods may be supported
77+
depending on the implementation of `parallel_objective_function`.
78+
Returns:
79+
callable: A wrapper function that executes the objective function
80+
in parallel using the specified number of cores and method.
81+
Examples:
82+
>>> from spotpython.utils.parallel import make_parallel
83+
>>> def sample_function(x):
84+
... return x ** 2
85+
...
86+
>>> parallel_func = make_parallel(sample_function, num_cores=4, method='mp')
87+
>>> result = parallel_func([1, 2, 3, 4])
88+
>>> print(result)
89+
[1, 4, 9, 16]
90+
"""
91+
global parallel_wrap
92+
def parallel_wrap(X, fun_control=None):
93+
return parallel_objective_function(obj_func, X, num_cores, fun_control, method)
94+
95+
return parallel_wrap

0 commit comments

Comments
 (0)