|
| 1 | +from multiprocessing import Pool, Manager |
| 2 | +from joblib import Parallel, delayed |
| 3 | +import numpy as np |
| 4 | +from typing import Callable, Any, Union |
| 5 | + |
def evaluate_row(row: Union[np.ndarray, list], objective_function: Callable[[np.ndarray, Any], Any], fun_control: Any) -> Any:
    """
    Apply the objective function to one row, presented as a one-row batch.

    The row is wrapped in an outer list before being converted to a NumPy
    array, so a flat row of length ``n`` reaches the objective function as an
    array of shape ``(1, n)``.

    Args:
        row (array-like): The values of the single row to evaluate.
        objective_function (callable): Called as
            ``objective_function(batch, fun_control)`` where ``batch`` is the
            wrapped row.
        fun_control (any): Passed through unchanged as the second argument of
            the objective function.

    Returns:
        Whatever the objective function returns for the one-row batch.

    Examples:
        >>> from spotpython.utils.parallel import evaluate_row
        >>> import numpy as np
        >>> def sample_objective(row, control):
        ...     return sum(row) + control.get('offset', 0)
        >>> row = [1, 2, 3]
        >>> fun_control = {'offset': 10}
        >>> # The builtin sum iterates over the first axis of the (1, 3)
        >>> # batch, so it yields the row itself; the offset is then added
        >>> # element-wise.
        >>> evaluate_row(row, sample_objective, fun_control)
        array([11, 12, 13])
    """
    # Wrap the row so the objective function always sees a batch with a
    # leading axis of length one.
    single_row_batch = np.array([row])
    outcome = objective_function(single_row_batch, fun_control)
    return outcome
| 30 | + |
def parallel_objective_function(objective_function, X, num_cores, fun_control, method) -> np.ndarray:
    """
    Executes an objective function in parallel using either multiprocessing or joblib.

    Every row of `X` is evaluated through `evaluate_row`, i.e. the objective
    function is called once per row as ``objective_function(batch, control)``
    where ``batch`` is the row wrapped into a one-row NumPy array and
    ``control`` is a manager-backed dict proxy holding the entries of
    `fun_control`.

    Args:
        objective_function (callable): The function to be evaluated for each row in `X`.
            With ``method='mp'`` it is sent to worker processes, so it must be
            picklable.
        X (iterable): The input data, where each element represents a row to be processed.
        num_cores (int): The number of CPU cores to use for parallel processing.
        fun_control (dict or None): A dictionary of shared control parameters for the
            objective function. ``None`` is treated as an empty dictionary.
        method (str): The parallelization method to use. Options are:
            - 'mp': Use Python's multiprocessing module.
            - 'joblib': Use the joblib library.

    Returns:
        numpy.ndarray: A flattened array of results obtained by applying the
        objective function to each row in `X`.

    Raises:
        ValueError: If an unsupported `method` is provided.

    Examples:
        >>> from spotpython.utils.parallel import parallel_objective_function
        >>> import numpy as np
        >>> def sample_objective(row, control):
        ...     return np.sum(row) + control.get('offset', 0)
        >>> X = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
        >>> fun_control = {'offset': 10}
        >>> parallel_objective_function(sample_objective, X, num_cores=2, fun_control=fun_control, method='mp')
        array([16, 25, 34])
        >>> parallel_objective_function(sample_objective, X, num_cores=2, fun_control=fun_control, method='joblib')
        array([16, 25, 34])
    """
    # Validate before any process is started: previously an unknown method
    # fell through both branches and failed on the unbound `results` name
    # instead of raising the documented ValueError.
    if method not in ("mp", "joblib"):
        raise ValueError(f"Unsupported method {method!r}; expected 'mp' or 'joblib'.")

    # A manager dict cannot be built from None, so substitute an empty mapping.
    control_items = {} if fun_control is None else fun_control

    with Manager() as manager:
        shared_control = manager.dict(control_items)
        if method == "mp":
            with Pool(processes=num_cores) as pool:
                results = pool.starmap(
                    evaluate_row,
                    [(row, objective_function, shared_control) for row in X],
                )
        else:
            results = Parallel(n_jobs=num_cores)(
                delayed(evaluate_row)(row, objective_function, shared_control) for row in X
            )
        # Collapse the per-row results into one flat array while the manager
        # (and therefore the shared dict) is still alive.
        return np.array(results).flatten()
| 67 | + |
def make_parallel(obj_func, num_cores, method='mp')->Callable:
    """
    Build a wrapper that evaluates `obj_func` row by row in parallel.

    The returned wrapper has the signature ``parallel_wrap(X, fun_control=None)``
    and delegates to `parallel_objective_function`, forwarding `obj_func`,
    `num_cores` and `method` as captured here.

    Args:
        obj_func (callable): The objective function to be parallelized. It is
            forwarded to `parallel_objective_function` as the function
            evaluated for each row.
        num_cores (int): The number of cores to use for parallel processing.
        method (str, optional): The parallelization method forwarded to
            `parallel_objective_function`. Defaults to 'mp' (multiprocessing).

    Returns:
        callable: The wrapper ``parallel_wrap(X, fun_control=None)``. Its
        result is whatever `parallel_objective_function` returns.

    Note:
        The wrapper is also bound to the module-level name ``parallel_wrap``;
        every call to `make_parallel` rebinds that name to the newest wrapper.

    Examples:
        >>> from spotpython.utils.parallel import make_parallel
        >>> import numpy as np
        >>> def sample_function(X, fun_control):
        ...     return np.sum(X ** 2) + fun_control.get('offset', 0)
        ...
        >>> parallel_func = make_parallel(sample_function, num_cores=2, method='mp')
        >>> parallel_func([[1, 2], [3, 4]], fun_control={'offset': 10})
        array([15, 35])
    """
    # NOTE(review): the global declaration publishes the wrapper at module
    # level, presumably so it can be located by name (e.g. when pickled) --
    # confirm before removing.
    global parallel_wrap

    def parallel_wrap(X, fun_control=None):
        # NOTE(review): fun_control is forwarded as given, including the
        # default None -- confirm the callee accepts None.
        return parallel_objective_function(
            objective_function=obj_func,
            X=X,
            num_cores=num_cores,
            fun_control=fun_control,
            method=method,
        )

    return parallel_wrap
0 commit comments