|
1 | 1 | import copy |
2 | 2 | from math import erf |
3 | 3 | import matplotlib.pyplot as plt |
4 | | -from numpy import max, min, var, mean |
| 4 | +from numpy import max, min, var |
5 | 5 | from numpy import sqrt |
6 | 6 | from numpy import exp |
7 | 7 | from numpy import array |
|
28 | 28 | from spotpython.utils.aggregate import aggregate_mean_var |
29 | 29 | import logging |
30 | 30 | import numpy as np |
31 | | -from typing import List, Union, Tuple, Any, Optional |
| 31 | +from typing import List, Union, Tuple, Any, Optional, Dict |
32 | 32 |
|
33 | 33 |
|
34 | 34 | logger = logging.getLogger(__name__) |
@@ -386,31 +386,52 @@ def optimize_model(self) -> Union[List[float], Tuple[float]]: |
386 | 386 | result["x"] (Union[List[float], Tuple[float]]): |
387 | 387 | A list or tuple of optimized parameter values. |
388 | 388 | """ |
389 | | - logger.debug("In optimize_model(): self.de_bounds passed to optimizer: %s", self.de_bounds) |
390 | | - if self.model_optimizer.__name__ == 'dual_annealing': |
391 | | - result = self.model_optimizer(func=self.fun_likelihood, |
392 | | - bounds=self.de_bounds) |
393 | | - elif self.model_optimizer.__name__ == 'differential_evolution': |
394 | | - result = self.model_optimizer(func=self.fun_likelihood, |
395 | | - bounds=self.de_bounds, |
396 | | - maxiter=self.model_fun_evals, |
397 | | - seed=self.seed) |
398 | | - elif self.model_optimizer.__name__ == 'direct': |
399 | | - result = self.model_optimizer(func=self.fun_likelihood, |
400 | | - bounds=self.de_bounds, |
401 | | - # maxfun=self.model_fun_evals, |
402 | | - eps=1e-2) |
403 | | - elif self.model_optimizer.__name__ == 'shgo': |
404 | | - result = self.model_optimizer(func=self.fun_likelihood, |
405 | | - bounds=self.de_bounds) |
406 | | - elif self.model_optimizer.__name__ == 'basinhopping': |
407 | | - result = self.model_optimizer(func=self.fun_likelihood, |
408 | | - x0=mean(self.de_bounds, axis=1)) |
| 389 | + logger.debug("Entering optimize_model.") |
| 390 | + if not callable(self.model_optimizer): |
| 391 | + logger.error("model_optimizer is not callable.") |
| 392 | + raise ValueError("model_optimizer must be a callable function or method.") |
| 393 | + |
| 394 | + optimizer_strategies: Dict[str, Dict] = { |
| 395 | + 'dual_annealing': {'func': self.fun_likelihood, 'bounds': self.de_bounds}, |
| 396 | + 'differential_evolution': { |
| 397 | + 'func': self.fun_likelihood, |
| 398 | + 'bounds': self.de_bounds, |
| 399 | + 'maxiter': self.model_fun_evals, |
| 400 | + 'seed': self.seed |
| 401 | + }, |
| 402 | + 'direct': { |
| 403 | + 'func': self.fun_likelihood, |
| 404 | + 'bounds': self.de_bounds, |
| 405 | + 'eps': 1e-2 |
| 406 | + }, |
| 407 | + 'shgo': {'func': self.fun_likelihood, 'bounds': self.de_bounds}, |
| 408 | + 'basinhopping': {'func': self.fun_likelihood, 'x0': np.mean(self.de_bounds, axis=1)} |
| 409 | + } |
| 410 | + |
| 411 | + optimizer_name = self.model_optimizer.__name__ |
| 412 | + logger.debug("Optimizer selected: %s", optimizer_name) |
| 413 | + |
| 414 | + if optimizer_name not in optimizer_strategies: |
| 415 | + logger.info("Using default options for optimizer: %s", optimizer_name) |
| 416 | + optimizer_args = {'func': self.fun_likelihood, 'bounds': self.de_bounds} |
409 | 417 | else: |
410 | | - result = self.model_optimizer(func=self.fun_likelihood, bounds=self.de_bounds) |
411 | | - logger.debug("In optimize_model(): result: %s", result) |
412 | | - logger.debug('In optimize_model(): returned result["x"]: %s', result["x"]) |
413 | | - return result["x"] |
| 418 | + optimizer_args = optimizer_strategies[optimizer_name] |
| 419 | + |
| 420 | + logger.debug("Parameters for optimization: %s", optimizer_args) |
| 421 | + |
| 422 | + try: |
| 423 | + result = self.model_optimizer(**optimizer_args) |
| 424 | + except Exception as e: |
| 425 | + logger.error("Optimization failed due to error: %s", str(e)) |
| 426 | + raise |
| 427 | + |
| 428 | + if "x" not in result: |
| 429 | + logger.error("Optimization result does not contain 'x'. Result: %s", result) |
| 430 | + raise ValueError("The optimization result does not contain the expected 'x' key.") |
| 431 | + logger.debug("Optimization result: %s", result) |
| 432 | + optimized_parameters = list(result["x"]) |
| 433 | + logger.debug("Extracted optimized parameters: %s", optimized_parameters) |
| 434 | + return optimized_parameters |
414 | 435 |
|
415 | 436 | def update_log(self) -> None: |
416 | 437 | """ |
|
0 commit comments