From db4f4b0d61d51e8d796210fb1b5f213fca8ccaa7 Mon Sep 17 00:00:00 2001 From: rozyczko Date: Thu, 9 Apr 2026 12:12:47 +0200 Subject: [PATCH 1/9] initial implementation --- src/easyscience/base_classes/__init__.py | 11 ++- .../fitting/minimizers/minimizer_bumps.py | 39 ++++++++- .../fitting/minimizers/minimizer_dfo.py | 15 +++- .../fitting/minimizers/minimizer_lmfit.py | 2 + src/easyscience/fitting/minimizers/utils.py | 4 + src/easyscience/fitting/multi_fitter.py | 2 + tests/integration/fitting/test_fitter.py | 50 ++++++++++-- .../integration/fitting/test_multi_fitter.py | 40 +++++++++ .../minimizers/test_minimizer_bumps.py | 18 ++++- .../fitting/minimizers/test_minimizer_dfo.py | 81 ++++++++++++++++++- .../minimizers/test_minimizer_lmfit.py | 19 +++++ 11 files changed, 256 insertions(+), 25 deletions(-) diff --git a/src/easyscience/base_classes/__init__.py b/src/easyscience/base_classes/__init__.py index b5dc0418..b2a04b24 100644 --- a/src/easyscience/base_classes/__init__.py +++ b/src/easyscience/base_classes/__init__.py @@ -3,9 +3,18 @@ from .based_base import BasedBase from .collection_base import CollectionBase +from .collection_base_easylist import CollectionBaseEasyList from .easy_list import EasyList from .model_base import ModelBase from .new_base import NewBase from .obj_base import ObjBase -__all__ = [BasedBase, CollectionBase, ObjBase, ModelBase, NewBase, EasyList] +__all__ = [ + BasedBase, + CollectionBase, + CollectionBaseEasyList, + ObjBase, + ModelBase, + NewBase, + EasyList, +] diff --git a/src/easyscience/fitting/minimizers/minimizer_bumps.py b/src/easyscience/fitting/minimizers/minimizer_bumps.py index 37f8873e..68098ddc 100644 --- a/src/easyscience/fitting/minimizers/minimizer_bumps.py +++ b/src/easyscience/fitting/minimizers/minimizer_bumps.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: BSD-3-Clause import copy +import functools +import inspect from typing import Callable from typing import List from typing import Optional @@ -28,6 +30,19 @@ FIT_AVAILABLE_IDS_FILTERED.remove('pt') +class _EvalCounter: + def __init__(self, fn: Callable): + self._fn = fn + self.count = 0 + self.__name__ = getattr(fn, '__name__', self.__class__.__name__) + self.__signature__ = inspect.signature(fn) + functools.update_wrapper(self, fn) + + def __call__(self, *args, **kwargs): + self.count += 1 + return self._fn(*args, **kwargs) + + class Bumps(MinimizerBase): """ This is a wrapper to Bumps: https://bumps.readthedocs.io/ @@ -54,6 +69,7 @@ def __init__( """ super().__init__(obj=obj, fit_function=fit_function, minimizer_enum=minimizer_enum) self._p_0 = {} + self._eval_counter: Optional[_EvalCounter] = None @staticmethod def all_methods() -> List[str]: @@ -148,7 +164,7 @@ def fit( try: model_results = bumps_fit(problem, **method_dict, **minimizer_kwargs, **kwargs) self._set_parameter_fit_result(model_results, stack_status, problem._parameters) - results = self._gen_fit_results(model_results) + results = self._gen_fit_results(model_results, max_evaluations=max_evaluations) except Exception as e: for key in self._cached_pars.keys(): self._cached_pars[key].value = self._cached_pars_vals[key][0] @@ -200,7 +216,8 @@ def _make_model(self, parameters: Optional[List[BumpsParameter]] = None) -> Call :return: Callable to make a bumps Curve model :rtype: Callable """ - fit_func = self._generate_fit_function() + fit_func = _EvalCounter(self._generate_fit_function()) + self._eval_counter = fit_func def _outer(obj): def _make_func(x, y, weights): @@ -249,7 +266,12 @@ def _set_parameter_fit_result( if stack_status: global_object.stack.endMacro() - def _gen_fit_results(self, fit_results, **kwargs) -> FitResults: + def _gen_fit_results( + self, + fit_results, + max_evaluations: Optional[int] = None, + **kwargs, + ) -> FitResults: """Convert fit results into the unified `FitResults` format. :param fit_result: Fit object which contains info on the fit @@ -261,7 +283,10 @@ def _gen_fit_results(self, fit_results, **kwargs) -> FitResults: for name, value in kwargs.items(): if getattr(results, name, False): setattr(results, name, value) - results.success = fit_results.success + nit = getattr(fit_results, 'nit', 0) + stopped_on_budget = max_evaluations is not None and nit >= max_evaluations - 1 + + results.success = fit_results.success and not stopped_on_budget pars = self._cached_pars item = {} for index, name in enumerate(self._cached_model.pars.keys()): @@ -275,6 +300,12 @@ def _gen_fit_results(self, fit_results, **kwargs) -> FitResults: results.y_obs = self._cached_model.y results.y_calc = self.evaluate(results.x, minimizer_parameters=results.p) results.y_err = self._cached_model.dy + results.n_evaluations = None if self._eval_counter is None else self._eval_counter.count + results.message = ( + f'Fit stopped: reached maximum evaluations ({max_evaluations})' + if stopped_on_budget + else '' + ) # results.residual = results.y_obs - results.y_calc # results.goodness_of_fit = np.sum(results.residual**2) results.minimizer_engine = self.__class__ diff --git a/src/easyscience/fitting/minimizers/minimizer_dfo.py b/src/easyscience/fitting/minimizers/minimizer_dfo.py index a480c823..74c2d8f9 100644 --- a/src/easyscience/fitting/minimizers/minimizer_dfo.py +++ b/src/easyscience/fitting/minimizers/minimizer_dfo.py @@ -122,6 +122,10 @@ def fit( model_results = self._dfo_fit(self._cached_pars, model, **kwargs) self._set_parameter_fit_result(model_results, stack_status) results = self._gen_fit_results(model_results, weights) + except FitError: + for key in self._cached_pars.keys(): + self._cached_pars[key].value = self._cached_pars_vals[key][0] + raise except Exception as e: for key in self._cached_pars.keys(): self._cached_pars[key].value = self._cached_pars_vals[key][0] @@ -208,7 +212,7 @@ def _gen_fit_results(self, fit_results, weights, **kwargs) -> FitResults: for name, value in kwargs.items(): if getattr(results, name, False): setattr(results, name, value) - results.success = not bool(fit_results.flag) + results.success = fit_results.flag == fit_results.EXIT_SUCCESS pars = {} for p_name, par in self._cached_pars.items(): @@ -220,11 +224,14 @@ def _gen_fit_results(self, fit_results, weights, **kwargs) -> FitResults: results.y_obs = self._cached_model.y results.y_calc = self.evaluate(results.x, minimizer_parameters=results.p) results.y_err = weights + results.n_evaluations = int(fit_results.nf) + results.message = str(fit_results.msg) # results.residual = results.y_obs - results.y_calc # results.goodness_of_fit = fit_results.f results.minimizer_engine = self.__class__ results.fit_args = None + results.engine_result = fit_results # results.check_sanity() return results @@ -258,10 +265,10 @@ def _dfo_fit( results = dfols.solve(model, pars_values, bounds=bounds, **kwargs) - if 'Success' not in results.msg: - raise FitError(f'Fit failed with message: {results.msg}') + if results.flag in {results.EXIT_SUCCESS, results.EXIT_MAXFUN_WARNING}: + return results - return results + raise FitError(f'Fit failed with message: {results.msg}') @staticmethod def _prepare_kwargs( diff --git a/src/easyscience/fitting/minimizers/minimizer_lmfit.py b/src/easyscience/fitting/minimizers/minimizer_lmfit.py index 4a8104b2..237ee00a 100644 --- a/src/easyscience/fitting/minimizers/minimizer_lmfit.py +++ b/src/easyscience/fitting/minimizers/minimizer_lmfit.py @@ -298,6 +298,8 @@ def _gen_fit_results(self, fit_results: ModelResult, **kwargs) -> FitResults: # results.goodness_of_fit = fit_results.chisqr results.y_calc = fit_results.best_fit results.y_err = 1 / fit_results.weights + results.n_evaluations = fit_results.nfev + results.message = fit_results.message results.minimizer_engine = self.__class__ results.fit_args = None diff --git a/src/easyscience/fitting/minimizers/utils.py b/src/easyscience/fitting/minimizers/utils.py index 76449a17..e3633365 100644 --- a/src/easyscience/fitting/minimizers/utils.py +++ b/src/easyscience/fitting/minimizers/utils.py @@ -20,6 +20,8 @@ class FitResults: 'y_obs', 'y_calc', 'y_err', + 'n_evaluations', + 'message', 'engine_result', 'total_results', ] @@ -35,6 +37,8 @@ def __init__(self): self.y_obs = np.ndarray([]) self.y_calc = np.ndarray([]) self.y_err = np.ndarray([]) + self.n_evaluations = None + self.message = '' self.engine_result = None self.total_results = None diff --git a/src/easyscience/fitting/multi_fitter.py b/src/easyscience/fitting/multi_fitter.py index 94e715c6..6f3b9938 100644 --- a/src/easyscience/fitting/multi_fitter.py +++ b/src/easyscience/fitting/multi_fitter.py @@ -127,6 +127,8 @@ def _post_compute_reshaping( current_results.minimizer_engine = fit_result_obj.minimizer_engine current_results.p = fit_result_obj.p current_results.p0 = fit_result_obj.p0 + current_results.n_evaluations = fit_result_obj.n_evaluations + current_results.message = fit_result_obj.message current_results.x = this_x current_results.y_obs = y[idx] current_results.y_calc = np.reshape( diff --git a/tests/integration/fitting/test_fitter.py b/tests/integration/fitting/test_fitter.py index c6d130fd..9b956cb1 100644 --- a/tests/integration/fitting/test_fitter.py +++ b/tests/integration/fitting/test_fitter.py @@ -207,14 +207,48 @@ def test_basic_max_evaluations(fit_engine): except AttributeError: pytest.skip(msg=f'{fit_engine} is not installed') f.max_evaluations = 3 - try: - result = f.fit(x=x, y=y, weights=weights) - # Result should not be the same as the reference - assert sp_sin.phase.value != pytest.approx(ref_sin.phase.value, rel=1e-3) - assert sp_sin.offset.value != pytest.approx(ref_sin.offset.value, rel=1e-3) - except FitError as e: - # DFO throws a different error - assert 'Objective has been called MAXFUN times' in str(e) + result = f.fit(x=x, y=y, weights=weights) + # Result should not be the same as the reference + assert sp_sin.phase.value != pytest.approx(ref_sin.phase.value, rel=1e-3) + assert sp_sin.offset.value != pytest.approx(ref_sin.offset.value, rel=1e-3) + + +@pytest.mark.fast +@pytest.mark.parametrize( + 'fit_engine', + [ + None, + AvailableMinimizers.LMFit, + AvailableMinimizers.Bumps, + AvailableMinimizers.DFO, + ], +) +def test_max_evaluations_populates_fit_result_fields(fit_engine): + """With a tight budget every engine must return success=False, n_evaluations>0, non-empty message.""" + ref_sin = AbsSin(0.2, np.pi) + sp_sin = AbsSin(0.354, 3.05) + + x = np.linspace(0, 5, 200) + weights = np.ones_like(x) + y = ref_sin(x) + + sp_sin.offset.fixed = False + sp_sin.phase.fixed = False + + f = Fitter(sp_sin, sp_sin) + if fit_engine is not None: + try: + f.switch_minimizer(fit_engine) + except AttributeError: + pytest.skip(msg=f'{fit_engine} is not installed') + f.max_evaluations = 3 + result = f.fit(x=x, y=y, weights=weights) + + assert result.success is False + assert result.n_evaluations is not None + assert result.n_evaluations > 0 + assert isinstance(result.message, str) + assert len(result.message) > 0 @pytest.mark.fast diff --git a/tests/integration/fitting/test_multi_fitter.py b/tests/integration/fitting/test_multi_fitter.py index 1cc5b395..25cc9adf 100644 --- a/tests/integration/fitting/test_multi_fitter.py +++ b/tests/integration/fitting/test_multi_fitter.py @@ -103,6 +103,46 @@ def test_multi_fit(fit_engine): assert result.residual == pytest.approx(F_real[idx](X[idx]) - F_ref[idx](X[idx]), abs=1e-2) +@pytest.mark.parametrize('fit_engine', [None, 'LMFit', 'Bumps', 'DFO']) +def test_multi_fit_propagates_n_evaluations_and_message(fit_engine): + """Verify that n_evaluations and message are copied into each per-dataset result.""" + ref_sin_1 = AbsSin(0.2, np.pi) + sp_sin_1 = AbsSin(0.354, 3.05) + ref_sin_2 = AbsSin(np.pi * 0.45, 0.45 * np.pi * 0.5) + sp_sin_2 = AbsSin(1, 0.5) + + ref_sin_2.offset.make_dependent_on( + dependency_expression='ref_sin1', dependency_map={'ref_sin1': ref_sin_1.offset} + ) + sp_sin_2.offset.make_dependent_on( + dependency_expression='sp_sin1', dependency_map={'sp_sin1': sp_sin_1.offset} + ) + + x1 = np.linspace(0, 5, 200) + y1 = ref_sin_1(x1) + x2 = np.copy(x1) + y2 = ref_sin_2(x2) + weights = np.ones_like(x1) + + sp_sin_1.offset.fixed = False + sp_sin_1.phase.fixed = False + sp_sin_2.phase.fixed = False + + f = MultiFitter([sp_sin_1, sp_sin_2], [sp_sin_1, sp_sin_2]) + if fit_engine is not None: + try: + f.switch_minimizer(fit_engine) + except AttributeError: + pytest.skip(msg=f'{fit_engine} is not installed') + + results = f.fit(x=[x1, x2], y=[y1, y2], weights=[weights, weights]) + for result in results: + assert result.n_evaluations is not None + assert isinstance(result.n_evaluations, int) + assert result.n_evaluations > 0 + assert isinstance(result.message, str) + + @pytest.mark.parametrize('fit_engine', [None, 'LMFit', 'Bumps', 'DFO']) def test_multi_fit2(fit_engine): ref_sin_1 = AbsSin(0.2, np.pi) diff --git a/tests/unit/fitting/minimizers/test_minimizer_bumps.py b/tests/unit/fitting/minimizers/test_minimizer_bumps.py index ba86b4d0..3d01fddc 100644 --- a/tests/unit/fitting/minimizers/test_minimizer_bumps.py +++ b/tests/unit/fitting/minimizers/test_minimizer_bumps.py @@ -89,7 +89,7 @@ def fake_set_parameter_fit_result(fit_result, stack_status, par_list): assert result == 'gen_fit_results' mock_bumps_fit.assert_called_once_with(mock_FitProblem_instance, method='amoeba') minimizer._make_model.assert_called_once_with(parameters=None) - minimizer._gen_fit_results.assert_called_once_with('fit') + minimizer._gen_fit_results.assert_called_once_with('fit', max_evaluations=None) mock_model_function.assert_called_once_with(1.0, 2.0, 1) mock_FitProblem.assert_called_once_with(mock_model) @@ -127,10 +127,13 @@ def test_make_model(self, minimizer: Bumps, monkeypatch) -> None: curve_for_model = model( x=np.array([1, 2]), y=np.array([10, 20]), weights=np.array([100, 200]) ) + wrapped_fit_function = mock_Curve.call_args[0][0] + wrapped_fit_function(np.array([1, 2]), pmock_parm_1=3) # Expect minimizer._generate_fit_function.assert_called_once_with() - assert mock_Curve.call_args[0][0] == mock_fit_function + assert minimizer._eval_counter is wrapped_fit_function + assert minimizer._eval_counter.count == 1 assert all(mock_Curve.call_args[0][1] == np.array([1, 2])) assert all(mock_Curve.call_args[0][2] == np.array([10, 20])) assert curve_for_model == 'curve' @@ -178,6 +181,7 @@ def test_gen_fit_results(self, minimizer: Bumps, monkeypatch): mock_fit_result = MagicMock() mock_fit_result.success = True + mock_fit_result.nit = 2 # nit >= max_evaluations - 1 → budget exhausted mock_cached_model = MagicMock() mock_cached_model.x = 'x' @@ -193,28 +197,34 @@ def test_gen_fit_results(self, minimizer: Bumps, monkeypatch): minimizer._cached_pars = {'par_1': mock_cached_par_1, 'par_2': mock_cached_par_2} minimizer._p_0 = 'p_0' + minimizer._eval_counter = MagicMock(count=7) minimizer.evaluate = MagicMock(return_value='evaluate') # Then domain_fit_results = minimizer._gen_fit_results( - mock_fit_result, **{'kwargs_set_key': 'kwargs_set_val'} + mock_fit_result, + max_evaluations=3, + **{'kwargs_set_key': 'kwargs_set_val'}, ) # Expect assert domain_fit_results == mock_domain_fit_results assert domain_fit_results.kwargs_set_key == 'kwargs_set_val' - assert domain_fit_results.success == True + assert domain_fit_results.success == False assert domain_fit_results.y_obs == 'y' assert domain_fit_results.x == 'x' assert domain_fit_results.p == {'ppar_1': 'par_value_1', 'ppar_2': 'par_value_2'} assert domain_fit_results.p0 == 'p_0' assert domain_fit_results.y_calc == 'evaluate' assert domain_fit_results.y_err == 'dy' + assert domain_fit_results.n_evaluations == 7 + assert domain_fit_results.message == 'Fit stopped: reached maximum evaluations (3)' assert ( str(domain_fit_results.minimizer_engine) == "" ) assert domain_fit_results.fit_args is None + assert domain_fit_results.engine_result == mock_fit_result minimizer.evaluate.assert_called_once_with( 'x', minimizer_parameters={'ppar_1': 'par_value_1', 'ppar_2': 'par_value_2'} ) diff --git a/tests/unit/fitting/minimizers/test_minimizer_dfo.py b/tests/unit/fitting/minimizers/test_minimizer_dfo.py index e1d5eeef..c9a62fb0 100644 --- a/tests/unit/fitting/minimizers/test_minimizer_dfo.py +++ b/tests/unit/fitting/minimizers/test_minimizer_dfo.py @@ -177,7 +177,10 @@ def test_gen_fit_results(self, minimizer: DFO, monkeypatch): ) mock_fit_result = MagicMock() - mock_fit_result.flag = False + mock_fit_result.EXIT_SUCCESS = 0 + mock_fit_result.flag = 0 + mock_fit_result.nf = 12 + mock_fit_result.msg = 'Maximum function evaluations reached' mock_cached_model = MagicMock() mock_cached_model.x = 'x' @@ -214,15 +217,76 @@ def test_gen_fit_results(self, minimizer: DFO, monkeypatch): assert domain_fit_results.p0 == 'p_0' assert domain_fit_results.y_calc == 'evaluate' assert domain_fit_results.y_err == 'weights' + assert domain_fit_results.n_evaluations == 12 + assert domain_fit_results.message == 'Maximum function evaluations reached' + assert domain_fit_results.engine_result == mock_fit_result assert ( str(domain_fit_results.minimizer_engine) == "" ) - assert domain_fit_results.fit_args is None - minimizer.evaluate.assert_called_once_with( - 'x', minimizer_parameters={'ppar_1': 'par_value_1', 'ppar_2': 'par_value_2'} + + def test_gen_fit_results_maxfun_warning_sets_success_false(self, minimizer: DFO, monkeypatch): + """When DFO returns EXIT_MAXFUN_WARNING, _gen_fit_results must set success=False.""" + mock_domain_fit_results = MagicMock() + mock_FitResults = MagicMock(return_value=mock_domain_fit_results) + monkeypatch.setattr( + easyscience.fitting.minimizers.minimizer_dfo, 'FitResults', mock_FitResults ) + mock_fit_result = MagicMock() + mock_fit_result.EXIT_SUCCESS = 0 + mock_fit_result.EXIT_MAXFUN_WARNING = 1 + mock_fit_result.flag = 1 # MAXFUN_WARNING + mock_fit_result.nf = 50 + mock_fit_result.msg = 'Objective has been called MAXFUN times' + + mock_cached_model = MagicMock() + mock_cached_model.x = 'x' + mock_cached_model.y = 'y' + minimizer._cached_model = mock_cached_model + + mock_cached_par_1 = MagicMock() + mock_cached_par_1.value = 'v1' + minimizer._cached_pars = {'par_1': mock_cached_par_1} + minimizer._p_0 = 'p_0' + minimizer.evaluate = MagicMock(return_value='evaluate') + + domain_fit_results = minimizer._gen_fit_results(mock_fit_result, 'weights') + + assert domain_fit_results.success == False + assert domain_fit_results.n_evaluations == 50 + assert domain_fit_results.message == 'Objective has been called MAXFUN times' + + def test_dfo_fit_allows_maxfun_warning(self, minimizer: DFO, monkeypatch) -> None: + mock_result = MagicMock() + mock_result.EXIT_SUCCESS = 0 + mock_result.EXIT_MAXFUN_WARNING = 1 + mock_result.flag = 1 + + mock_solve = MagicMock(return_value=mock_result) + monkeypatch.setattr(easyscience.fitting.minimizers.minimizer_dfo.dfols, 'solve', mock_solve) + + parameter = MagicMock(min=0.0, max=1.0, value=0.5) + + result = minimizer._dfo_fit({'par': parameter}, MagicMock()) + + assert result == mock_result + + def test_dfo_fit_raises_for_non_maxfun_failure(self, minimizer: DFO, monkeypatch) -> None: + mock_result = MagicMock() + mock_result.EXIT_SUCCESS = 0 + mock_result.EXIT_MAXFUN_WARNING = 1 + mock_result.flag = 4 + mock_result.msg = 'linear algebra error' + + mock_solve = MagicMock(return_value=mock_result) + monkeypatch.setattr(easyscience.fitting.minimizers.minimizer_dfo.dfols, 'solve', mock_solve) + + parameter = MagicMock(min=0.0, max=1.0, value=0.5) + + with pytest.raises(FitError, match='linear algebra error'): + minimizer._dfo_fit({'par': parameter}, MagicMock()) + def test_dfo_fit(self, minimizer: DFO, monkeypatch): # When mock_parm_1 = MagicMock(Parameter) @@ -239,6 +303,9 @@ def test_dfo_fit(self, minimizer: DFO, monkeypatch): mock_dfols = MagicMock() mock_results = MagicMock() + mock_results.EXIT_SUCCESS = 0 + mock_results.EXIT_MAXFUN_WARNING = 1 + mock_results.flag = 0 mock_results.msg = 'Success' mock_dfols.solve = MagicMock(return_value=mock_results) @@ -272,6 +339,9 @@ def test_dfo_fit_no_scaling(self, minimizer: DFO, monkeypatch): mock_dfols = MagicMock() mock_results = MagicMock() + mock_results.EXIT_SUCCESS = 0 + mock_results.EXIT_MAXFUN_WARNING = 1 + mock_results.flag = 0 mock_results.msg = 'Success' mock_dfols.solve = MagicMock(return_value=mock_results) @@ -297,6 +367,9 @@ def test_dfo_fit_exception(self, minimizer: DFO, monkeypatch): mock_dfols = MagicMock() mock_results = MagicMock() + mock_results.EXIT_SUCCESS = 0 + mock_results.EXIT_MAXFUN_WARNING = 1 + mock_results.flag = 3 mock_results.msg = 'Failed' mock_dfols.solve = MagicMock(return_value=mock_results) diff --git a/tests/unit/fitting/minimizers/test_minimizer_lmfit.py b/tests/unit/fitting/minimizers/test_minimizer_lmfit.py index ac280873..0b2065a1 100644 --- a/tests/unit/fitting/minimizers/test_minimizer_lmfit.py +++ b/tests/unit/fitting/minimizers/test_minimizer_lmfit.py @@ -209,6 +209,25 @@ def test_fit_exception(self, minimizer: LMFit) -> None: with pytest.raises(FitError): minimizer.fit(x=1.0, y=2.0, weights=1) + def test_gen_fit_results_populates_evaluation_metadata(self, minimizer: LMFit) -> None: + fit_results = MagicMock() + fit_results.success = False + fit_results.data = 'data' + fit_results.userkws = {'x': 'x'} + fit_results.values = {'p1': 1.0} + fit_results.init_values = {'p1': 0.5} + fit_results.best_fit = 'best_fit' + fit_results.weights = 2 + fit_results.nfev = 9 + fit_results.message = 'max evaluations reached' + + result = minimizer._gen_fit_results(fit_results) + + assert result.success is False + assert result.n_evaluations == 9 + assert result.message == 'max evaluations reached' + assert result.engine_result == fit_results + def test_convert_to_pars_obj(self, minimizer: LMFit, monkeypatch) -> None: # When minimizer._object = MagicMock() From 2e6398537554fdfcb3eca10c8554244600b2159c Mon Sep 17 00:00:00 2001 From: rozyczko Date: Thu, 9 Apr 2026 14:02:38 +0200 Subject: [PATCH 2/9] ruff --- src/easyscience/base_classes/__init__.py | 14 +++++++------- .../unit/fitting/minimizers/test_minimizer_dfo.py | 8 ++++++-- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/easyscience/base_classes/__init__.py b/src/easyscience/base_classes/__init__.py index b2a04b24..ce5698e7 100644 --- a/src/easyscience/base_classes/__init__.py +++ b/src/easyscience/base_classes/__init__.py @@ -10,11 +10,11 @@ from .obj_base import ObjBase __all__ = [ - BasedBase, - CollectionBase, - CollectionBaseEasyList, - ObjBase, - ModelBase, - NewBase, - EasyList, + BasedBase, + CollectionBase, + CollectionBaseEasyList, + ObjBase, + ModelBase, + NewBase, + EasyList, ] diff --git a/tests/unit/fitting/minimizers/test_minimizer_dfo.py b/tests/unit/fitting/minimizers/test_minimizer_dfo.py index c9a62fb0..f4c71be9 100644 --- a/tests/unit/fitting/minimizers/test_minimizer_dfo.py +++ b/tests/unit/fitting/minimizers/test_minimizer_dfo.py @@ -264,7 +264,9 @@ def test_dfo_fit_allows_maxfun_warning(self, minimizer: DFO, monkeypatch) -> Non mock_result.flag = 1 mock_solve = MagicMock(return_value=mock_result) - monkeypatch.setattr(easyscience.fitting.minimizers.minimizer_dfo.dfols, 'solve', mock_solve) + monkeypatch.setattr( + easyscience.fitting.minimizers.minimizer_dfo.dfols, 'solve', mock_solve + ) parameter = MagicMock(min=0.0, max=1.0, value=0.5) @@ -280,7 +282,9 @@ def test_dfo_fit_raises_for_non_maxfun_failure(self, minimizer: DFO, monkeypatch mock_result.msg = 'linear algebra error' mock_solve = MagicMock(return_value=mock_result) - monkeypatch.setattr(easyscience.fitting.minimizers.minimizer_dfo.dfols, 'solve', mock_solve) + monkeypatch.setattr( + easyscience.fitting.minimizers.minimizer_dfo.dfols, 'solve', mock_solve + ) parameter = MagicMock(min=0.0, max=1.0, value=0.5) From f88120dc678678f4e0c6f2681ad52a1634ef4ce3 Mon Sep 17 00:00:00 2001 From: rozyczko Date: Thu, 9 Apr 2026 14:15:40 +0200 Subject: [PATCH 3/9] added missing collection base files --- .../base_classes/collection_base_easylist.py | 153 ++++++++++++++++ .../test_collection_base_easylist.py | 172 ++++++++++++++++++ 2 files changed, 325 insertions(+) create mode 100644 src/easyscience/base_classes/collection_base_easylist.py create mode 100644 tests/unit_tests/base_classes/test_collection_base_easylist.py diff --git a/src/easyscience/base_classes/collection_base_easylist.py b/src/easyscience/base_classes/collection_base_easylist.py new file mode 100644 index 00000000..f9a9ee4d --- /dev/null +++ b/src/easyscience/base_classes/collection_base_easylist.py @@ -0,0 +1,153 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +from collections.abc import MutableSequence +from typing import Any + +from easyscience.io.serializer_base import SerializerBase + +from .model_base import ModelBase + + +class CollectionBaseEasyList(ModelBase, MutableSequence[ModelBase]): + """Compatibility model-aware collection with list semantics. + + This preserves the older `CollectionBaseEasyList` API expected by legacy + tests while keeping the newer `CollectionBase` implementation unchanged. + """ + + def __init__( + self, + name: str, + *args: ModelBase | list[ModelBase], + unique_name: str | None = None, + display_name: str | None = None, + ): + if display_name is None: + display_name = name + super().__init__(unique_name=unique_name, display_name=display_name) + self._name = name + self._data: list[ModelBase] = [] + + for item in args: + if isinstance(item, list): + for nested_item in item: + self.append(nested_item) + else: + self.append(item) + + @property + def name(self) -> str: + return self._name + + @name.setter + def name(self, new_name: str) -> None: + self._name = new_name + self.display_name = new_name + + def _get_key(self, obj: ModelBase) -> str: + return obj.unique_name + + def _is_duplicate(self, value: ModelBase) -> bool: + for existing in self._data: + try: + if self._get_key(existing) == self._get_key(value): + return True + except AttributeError: + if existing is value: + return True + return False + + def get_all_variables(self): + variables = [] + for item in self._data: + variables.extend(item.get_all_variables()) + return variables + + def get_all_parameters(self): + parameters = [] + for item in self._data: + parameters.extend(item.get_all_parameters()) + return parameters + + def get_free_parameters(self): + parameters = [] + for item in self._data: + parameters.extend(item.get_free_parameters()) + return parameters + + def __getitem__(self, idx: int | slice | str): + if isinstance(idx, int): + return self._data[idx] + if isinstance(idx, slice): + return self.__class__(self.name, self._data[idx]) + if isinstance(idx, str): + unique_name_match = next( + (item for item in self._data if item.unique_name == idx), None + ) + if unique_name_match is not None: + return unique_name_match + + matches = [item for item in self._data if item.name == idx] + if not matches: + raise IndexError('Given index does not exist') + if len(matches) == 1: + return matches[0] + return self.__class__(self.name, matches) + raise TypeError('Index must be an int, slice, or str') + + def __setitem__(self, idx: int | slice, value): + if isinstance(idx, int): + if not isinstance(value, ModelBase): + raise AttributeError('CollectionBaseEasyList can only contain model objects') + self._data[idx] = value + return + if isinstance(idx, slice): + replacement = list(value) + if not all(isinstance(item, ModelBase) for item in replacement): + raise AttributeError('CollectionBaseEasyList can only contain model objects') + self._data[idx] = replacement + return + raise TypeError('Index must be an int or slice') + + def __delitem__(self, idx: int | slice) -> None: + del self._data[idx] + + def __len__(self) -> int: + return len(self._data) + + def insert(self, index: int, value: ModelBase) -> None: + if not isinstance(value, ModelBase): + raise AttributeError('CollectionBaseEasyList can only contain model objects') + if self._is_duplicate(value): + return + self._data.insert(index, value) + + def sort(self, key=None, reverse: bool = False) -> None: + self._data.sort(key=key, reverse=reverse) + + def _convert_to_dict(self, in_dict, encoder, skip=None, **kwargs) -> dict: + if skip is None: + skip = [] + in_dict['data'] = [ + encoder._convert_to_dict(item, skip=skip, **kwargs) for item in self._data + ] + return in_dict + + @classmethod + def from_dict(cls, obj_dict: dict[str, Any]) -> CollectionBaseEasyList: + if not SerializerBase._is_serialized_easyscience_object(obj_dict): + raise ValueError( + 'Input must be a dictionary representing an EasyScience CollectionBaseEasyList object.' + ) + if obj_dict['@class'] != cls.__name__: + raise ValueError( + f'Class name in dictionary does not match the expected class: {cls.__name__}.' + ) + + kwargs = SerializerBase.deserialize_dict(obj_dict) + data = kwargs.pop('data', []) + name = kwargs.pop('name') + return cls(name, data, **kwargs) diff --git a/tests/unit_tests/base_classes/test_collection_base_easylist.py b/tests/unit_tests/base_classes/test_collection_base_easylist.py new file mode 100644 index 00000000..107e96e7 --- /dev/null +++ b/tests/unit_tests/base_classes/test_collection_base_easylist.py @@ -0,0 +1,172 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project Parameter: + return Parameter(name, cast(Any, value), **kwargs) + + +class DummyModel(ModelBase): + def __init__( + self, + name: str, + value: Any, + unique_name: str | None = None, + display_name: str | None = None, + ): + if display_name is None: + display_name = name + super().__init__(unique_name=unique_name, display_name=display_name) + self._name = name + self._value = ( + value if isinstance(value, Parameter) else make_parameter(f'{name}_value', value) + ) + + @property + def name(self) -> str: + return self._name + + @name.setter + def name(self, new_name: str) -> None: + self._name = new_name + self.display_name = new_name + + @property + def value(self) -> Parameter: + return self._value + + @value.setter + def value(self, new_value: float) -> None: + self._value.value = new_value + + +@pytest.fixture(autouse=True) +def clear(): + global_object.map._clear() + + +def test_collection_base_getitem_supports_unique_name_and_name_fallback(): + m1 = DummyModel('dup', make_parameter('p1', 1.0), unique_name='m1') + m2 = DummyModel('dup', make_parameter('p2', 2.0), unique_name='m2') + + collection = CollectionBase('test', m1, m2) + + assert collection['m1'] is m1 + same_name = collection['dup'] + assert isinstance(same_name, CollectionBase) + assert len(same_name) == 2 + assert list(same_name) == [m1, m2] + + +def test_collection_base_get_all_variables_recurses_into_models(): + p1 = make_parameter('p1', 1.0) + p2 = make_parameter('p2', 2.0) + model1 = DummyModel('m1', p1) + model2 = DummyModel('m2', p2) + + collection = CollectionBase('test', model1, model2) + + variables = collection.get_all_variables() + + assert p1 in variables + assert p2 in variables + + +def test_collection_base_get_parameters_recurses_into_nested_objects(): + nested = CollectionBase( + 'nested', + DummyModel('m1', make_parameter('p1', 1.0)), + DummyModel('m2', make_parameter('p2', 2.0, fixed=True)), + ) + model = DummyModel('model', make_parameter('p3', 3.0)) + collection = CollectionBase('test', nested, model) + + parameters = collection.get_all_parameters() + free_parameters = collection.get_free_parameters() + + assert [parameter.name for parameter in parameters] == ['p1', 'p2', 'p3'] + assert [parameter.name for parameter in free_parameters] == ['p1', 'p3'] + + +def test_collection_base_rejects_non_model_items(): + with pytest.raises(AttributeError, match='model objects'): + CollectionBase('test', make_parameter('p1', 1.0)) + + +def test_collection_base_rejects_basedbase_objects(): + with pytest.raises(AttributeError, match='model objects'): + from easyscience import ObjBase + + CollectionBase('test', ObjBase('legacy', p=make_parameter('p1', 1.0))) + + +def test_collection_base_accepts_nested_collections(): + inner = CollectionBase('inner', DummyModel('m', make_parameter('p', 1.0))) + outer = CollectionBase('outer', inner) + + assert len(outer) == 1 + assert outer[0] is inner + + +def test_collection_base_prevents_duplicates_by_unique_name(): + m1 = DummyModel('m1', make_parameter('p1', 1.0), unique_name='m1') + m2 = DummyModel('m2', make_parameter('p2', 2.0), unique_name='m2') + + collection = CollectionBase('test', m1) + collection._get_key = lambda _obj: 'same-key' + collection.append(m2) + + assert len(collection) == 1 + assert collection[0] is m1 + + +def test_collection_base_duplicate_identity_fallback_when_key_unavailable(): + m1 = DummyModel('m1', make_parameter('p1', 1.0), unique_name='m1') + collection = CollectionBase('test', m1) + + def _broken_get_key(_obj): + raise AttributeError('key unavailable') + + collection._get_key = _broken_get_key + collection.append(m1) + + assert len(collection) == 1 + + +def test_collection_base_to_dict_round_trip_preserves_name_and_data(): + m1 = DummyModel('m1', make_parameter('p1', 1.0)) + m2 = DummyModel('m2', make_parameter('p2', 2.0)) + collection = CollectionBase('test', m1, m2) + + encoded = collection.to_dict() + decoded = CollectionBase.from_dict(encoded) + + assert decoded.name == 'test' + assert [item.name for item in decoded] == ['m1', 'm2'] + + +def test_collection_base_sort_accepts_key(): + m1 = DummyModel('m1', make_parameter('p1', 3.0)) + m2 = DummyModel('m2', make_parameter('p2', 1.0)) + m3 = DummyModel('m3', make_parameter('p3', 2.0)) + collection = CollectionBase('test', m1, m2, m3) + + collection.sort(key=lambda item: item.value.value) + + assert [item.name for item in collection] == ['m2', 'm3', 'm1'] + + +def test_collection_base_isinstance_model_base(): + collection = CollectionBase('test') + assert isinstance(collection, ModelBase) From 3dbc23a93e4e1cdb46a0dd81d105c0f1d8f855b9 Mon Sep 17 00:00:00 2001 From: rozyczko Date: Thu, 9 Apr 2026 15:22:29 +0200 Subject: [PATCH 4/9] removed unneeded files --- src/easyscience/base_classes/__init__.py | 11 +- .../base_classes/collection_base_easylist.py | 153 ---------------- .../test_collection_base_easylist.py | 172 ------------------ 3 files changed, 1 insertion(+), 335 deletions(-) delete mode 100644 src/easyscience/base_classes/collection_base_easylist.py delete mode 100644 tests/unit_tests/base_classes/test_collection_base_easylist.py diff --git a/src/easyscience/base_classes/__init__.py b/src/easyscience/base_classes/__init__.py index ce5698e7..b5dc0418 100644 --- a/src/easyscience/base_classes/__init__.py +++ b/src/easyscience/base_classes/__init__.py @@ -3,18 +3,9 @@ from .based_base import BasedBase from .collection_base import CollectionBase -from .collection_base_easylist import CollectionBaseEasyList from .easy_list import EasyList from .model_base import ModelBase from .new_base import NewBase from .obj_base import ObjBase -__all__ = [ - BasedBase, - CollectionBase, - CollectionBaseEasyList, - ObjBase, - ModelBase, - NewBase, - EasyList, -] +__all__ = [BasedBase, CollectionBase, ObjBase, ModelBase, NewBase, EasyList] diff --git a/src/easyscience/base_classes/collection_base_easylist.py b/src/easyscience/base_classes/collection_base_easylist.py deleted file mode 100644 index f9a9ee4d..00000000 --- a/src/easyscience/base_classes/collection_base_easylist.py +++ /dev/null @@ -1,153 +0,0 @@ -# SPDX-FileCopyrightText: 2026 EasyScience contributors -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -from collections.abc import MutableSequence -from typing import Any - -from easyscience.io.serializer_base import SerializerBase - -from .model_base import ModelBase - - -class CollectionBaseEasyList(ModelBase, MutableSequence[ModelBase]): - """Compatibility model-aware collection with list semantics. - - This preserves the older `CollectionBaseEasyList` API expected by legacy - tests while keeping the newer `CollectionBase` implementation unchanged. - """ - - def __init__( - self, - name: str, - *args: ModelBase | list[ModelBase], - unique_name: str | None = None, - display_name: str | None = None, - ): - if display_name is None: - display_name = name - super().__init__(unique_name=unique_name, display_name=display_name) - self._name = name - self._data: list[ModelBase] = [] - - for item in args: - if isinstance(item, list): - for nested_item in item: - self.append(nested_item) - else: - self.append(item) - - @property - def name(self) -> str: - return self._name - - @name.setter - def name(self, new_name: str) -> None: - self._name = new_name - self.display_name = new_name - - def _get_key(self, obj: ModelBase) -> str: - return obj.unique_name - - def _is_duplicate(self, value: ModelBase) -> bool: - for existing in self._data: - try: - if self._get_key(existing) == self._get_key(value): - return True - except AttributeError: - if existing is value: - return True - return False - - def get_all_variables(self): - variables = [] - for item in self._data: - variables.extend(item.get_all_variables()) - return variables - - def get_all_parameters(self): - parameters = [] - for item in self._data: - parameters.extend(item.get_all_parameters()) - return parameters - - def get_free_parameters(self): - parameters = [] - for item in self._data: - parameters.extend(item.get_free_parameters()) - return parameters - - def __getitem__(self, idx: int | slice | str): - if isinstance(idx, int): - return self._data[idx] - if isinstance(idx, slice): - return self.__class__(self.name, self._data[idx]) - if isinstance(idx, str): - unique_name_match = next( - (item for item in self._data if item.unique_name == idx), None - ) - if unique_name_match is not None: - return unique_name_match - - matches = [item for item in self._data if item.name == idx] - if not matches: - raise IndexError('Given index does not exist') - if len(matches) == 1: - return matches[0] - return self.__class__(self.name, matches) - raise TypeError('Index must be an int, slice, or str') - - def __setitem__(self, idx: int | slice, value): - if isinstance(idx, int): - if not isinstance(value, ModelBase): - raise AttributeError('CollectionBaseEasyList can only contain model objects') - self._data[idx] = value - return - if isinstance(idx, slice): - replacement = list(value) - if not all(isinstance(item, ModelBase) for item in replacement): - raise AttributeError('CollectionBaseEasyList can only contain model objects') - self._data[idx] = replacement - return - raise TypeError('Index must be an int or slice') - - def __delitem__(self, idx: int | slice) -> None: - del self._data[idx] - - def __len__(self) -> int: - return len(self._data) - - def insert(self, index: int, value: ModelBase) -> None: - if not isinstance(value, ModelBase): - raise AttributeError('CollectionBaseEasyList can only contain model objects') - if self._is_duplicate(value): - return - self._data.insert(index, value) - - def sort(self, key=None, reverse: bool = False) -> None: - self._data.sort(key=key, reverse=reverse) - - def _convert_to_dict(self, in_dict, encoder, skip=None, **kwargs) -> dict: - if skip is None: - skip = [] - in_dict['data'] = [ - encoder._convert_to_dict(item, skip=skip, **kwargs) for item in self._data - ] - return in_dict - - @classmethod - def from_dict(cls, obj_dict: dict[str, Any]) -> CollectionBaseEasyList: - if not SerializerBase._is_serialized_easyscience_object(obj_dict): - raise ValueError( - 'Input must be a dictionary representing an EasyScience CollectionBaseEasyList object.' - ) - if obj_dict['@class'] != cls.__name__: - raise ValueError( - f'Class name in dictionary does not match the expected class: {cls.__name__}.' - ) - - kwargs = SerializerBase.deserialize_dict(obj_dict) - data = kwargs.pop('data', []) - name = kwargs.pop('name') - return cls(name, data, **kwargs) diff --git a/tests/unit_tests/base_classes/test_collection_base_easylist.py b/tests/unit_tests/base_classes/test_collection_base_easylist.py deleted file mode 100644 index 107e96e7..00000000 --- a/tests/unit_tests/base_classes/test_collection_base_easylist.py +++ /dev/null @@ -1,172 +0,0 @@ -# SPDX-FileCopyrightText: 2025 EasyScience contributors -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2025 Contributors to the EasyScience project Parameter: - return Parameter(name, cast(Any, value), **kwargs) - - -class DummyModel(ModelBase): - def __init__( - self, - name: str, - value: Any, - unique_name: str | None = None, - display_name: str | None = None, - ): - if display_name is None: - display_name = name - super().__init__(unique_name=unique_name, display_name=display_name) - self._name = name - self._value = ( - value if isinstance(value, Parameter) else make_parameter(f'{name}_value', value) - ) - - @property - def name(self) -> str: - return self._name - - @name.setter - def name(self, new_name: str) -> None: - self._name = new_name - self.display_name = new_name - - @property - def value(self) -> Parameter: - return self._value - - @value.setter - def value(self, new_value: float) -> None: - self._value.value = new_value - - -@pytest.fixture(autouse=True) -def clear(): - global_object.map._clear() - - -def test_collection_base_getitem_supports_unique_name_and_name_fallback(): - m1 = DummyModel('dup', make_parameter('p1', 1.0), unique_name='m1') - m2 = DummyModel('dup', make_parameter('p2', 2.0), unique_name='m2') - - collection = CollectionBase('test', m1, m2) - - assert collection['m1'] is m1 - same_name = collection['dup'] - assert isinstance(same_name, CollectionBase) - assert len(same_name) == 2 - assert list(same_name) == [m1, m2] - - -def test_collection_base_get_all_variables_recurses_into_models(): - p1 = make_parameter('p1', 1.0) - p2 = make_parameter('p2', 2.0) - model1 = DummyModel('m1', p1) - model2 = DummyModel('m2', p2) - - collection = CollectionBase('test', model1, model2) - - variables = collection.get_all_variables() - - assert p1 in variables - assert p2 in variables - - -def test_collection_base_get_parameters_recurses_into_nested_objects(): - nested = CollectionBase( - 'nested', - DummyModel('m1', make_parameter('p1', 1.0)), - DummyModel('m2', make_parameter('p2', 2.0, fixed=True)), - ) - model = DummyModel('model', make_parameter('p3', 3.0)) - collection = CollectionBase('test', nested, model) - - parameters = collection.get_all_parameters() - free_parameters = collection.get_free_parameters() - - assert [parameter.name for parameter in parameters] == ['p1', 'p2', 'p3'] - assert [parameter.name for parameter in free_parameters] == ['p1', 'p3'] - - -def test_collection_base_rejects_non_model_items(): - with pytest.raises(AttributeError, match='model objects'): - CollectionBase('test', make_parameter('p1', 1.0)) - - -def test_collection_base_rejects_basedbase_objects(): - with pytest.raises(AttributeError, match='model objects'): - from easyscience import ObjBase - - CollectionBase('test', ObjBase('legacy', p=make_parameter('p1', 1.0))) - - -def test_collection_base_accepts_nested_collections(): - inner = CollectionBase('inner', DummyModel('m', make_parameter('p', 1.0))) - outer = CollectionBase('outer', inner) - - assert len(outer) == 1 - assert outer[0] is inner - - -def test_collection_base_prevents_duplicates_by_unique_name(): - m1 = DummyModel('m1', make_parameter('p1', 1.0), unique_name='m1') - m2 = DummyModel('m2', make_parameter('p2', 2.0), unique_name='m2') - - collection = CollectionBase('test', m1) - collection._get_key = lambda _obj: 'same-key' - collection.append(m2) - - assert len(collection) == 1 - assert collection[0] is m1 - - -def test_collection_base_duplicate_identity_fallback_when_key_unavailable(): - m1 = DummyModel('m1', make_parameter('p1', 1.0), unique_name='m1') - collection = CollectionBase('test', m1) - - def _broken_get_key(_obj): - raise AttributeError('key unavailable') - - collection._get_key = _broken_get_key - collection.append(m1) - - assert len(collection) == 1 - - -def test_collection_base_to_dict_round_trip_preserves_name_and_data(): - m1 = DummyModel('m1', make_parameter('p1', 1.0)) - m2 = DummyModel('m2', make_parameter('p2', 2.0)) - collection = CollectionBase('test', m1, m2) - - encoded = collection.to_dict() - decoded = CollectionBase.from_dict(encoded) - - assert decoded.name == 'test' - assert [item.name for item in decoded] == ['m1', 'm2'] - - -def test_collection_base_sort_accepts_key(): - m1 = DummyModel('m1', make_parameter('p1', 3.0)) - m2 = DummyModel('m2', make_parameter('p2', 1.0)) - m3 = DummyModel('m3', make_parameter('p3', 2.0)) - collection = CollectionBase('test', m1, m2, m3) - - collection.sort(key=lambda item: item.value.value) - - assert [item.name for item in collection] == ['m2', 'm3', 'm1'] - - -def test_collection_base_isinstance_model_base(): - collection = CollectionBase('test') - assert isinstance(collection, ModelBase) From 4effb57bc19eb714b69f82ebcc8d96fcb16daad0 Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Fri, 10 Apr 2026 11:04:06 +0200 Subject: [PATCH 5/9] added a test --- .../fitting/minimizers/test_minimizer_dfo.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/unit/fitting/minimizers/test_minimizer_dfo.py b/tests/unit/fitting/minimizers/test_minimizer_dfo.py index f4c71be9..a4e2d340 100644 --- a/tests/unit/fitting/minimizers/test_minimizer_dfo.py +++ b/tests/unit/fitting/minimizers/test_minimizer_dfo.py @@ -364,6 +364,31 @@ def test_dfo_fit_no_scaling(self, minimizer: DFO, monkeypatch): assert 'kwargs_set_key' in list(mock_dfols.solve.call_args[1].keys()) assert mock_dfols.solve.call_args[1]['kwargs_set_key'] == 'kwargs_set_val' + def test_fit_generic_exception_resets_parameters_and_raises_fit_error(self, minimizer: DFO) -> None: + """When _dfo_fit raises a non-FitError exception, fit() must reset + parameter values to cached originals and re-raise as FitError.""" + from easyscience import global_object + + global_object.stack.enabled = False + + mock_model = MagicMock() + mock_model_function = MagicMock(return_value=mock_model) + minimizer._make_model = MagicMock(return_value=mock_model_function) + minimizer._dfo_fit = MagicMock(side_effect=RuntimeError('solver crashed')) + + cached_par_1 = MagicMock() + cached_par_1.value = 5.0 + cached_par_2 = MagicMock() + cached_par_2.value = 10.0 + minimizer._cached_pars = {'a': cached_par_1, 'b': cached_par_2} + minimizer._cached_pars_vals = {'a': (1.0, 0.1), 'b': (2.0, 0.2)} + + with pytest.raises(FitError): + minimizer.fit(x=np.array([1.0]), y=np.array([1.0]), weights=np.array([1.0])) + + assert cached_par_1.value == 1.0 + assert cached_par_2.value == 2.0 + def test_dfo_fit_exception(self, minimizer: DFO, monkeypatch): # When pars = {1: MagicMock(Parameter)} From 41c222462796463f373c2495c0c573440357d0f2 Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Fri, 10 Apr 2026 11:04:36 +0200 Subject: [PATCH 6/9] ruff --- tests/unit/fitting/minimizers/test_minimizer_dfo.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/fitting/minimizers/test_minimizer_dfo.py b/tests/unit/fitting/minimizers/test_minimizer_dfo.py index a4e2d340..34f98c4f 100644 --- a/tests/unit/fitting/minimizers/test_minimizer_dfo.py +++ b/tests/unit/fitting/minimizers/test_minimizer_dfo.py @@ -364,7 +364,9 @@ def test_dfo_fit_no_scaling(self, minimizer: DFO, monkeypatch): assert 'kwargs_set_key' in list(mock_dfols.solve.call_args[1].keys()) assert mock_dfols.solve.call_args[1]['kwargs_set_key'] == 'kwargs_set_val' - def test_fit_generic_exception_resets_parameters_and_raises_fit_error(self, minimizer: DFO) -> None: + def test_fit_generic_exception_resets_parameters_and_raises_fit_error( + self, minimizer: DFO + ) -> None: """When _dfo_fit raises a non-FitError exception, fit() must reset parameter values to cached originals and re-raise as FitError.""" from easyscience import global_object From d5321a20b4a5af5be89eda281175451cdb9fe966 Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Mon, 13 Apr 2026 16:52:31 +0200 Subject: [PATCH 7/9] addressed code review points --- .../fitting/minimizers/minimizer_bumps.py | 6 +- src/easyscience/fitting/minimizers/utils.py | 33 +++++++ .../minimizers/test_minimizer_bumps.py | 34 ++++++++ tests/unit/fitting/minimizers/test_utils.py | 85 +++++++++++++++++++ 4 files changed, 155 insertions(+), 3 deletions(-) create mode 100644 tests/unit/fitting/minimizers/test_utils.py diff --git a/src/easyscience/fitting/minimizers/minimizer_bumps.py b/src/easyscience/fitting/minimizers/minimizer_bumps.py index 68098ddc..db6bfcca 100644 --- a/src/easyscience/fitting/minimizers/minimizer_bumps.py +++ b/src/easyscience/fitting/minimizers/minimizer_bumps.py @@ -283,8 +283,8 @@ def _gen_fit_results( for name, value in kwargs.items(): if getattr(results, name, False): setattr(results, name, value) - nit = getattr(fit_results, 'nit', 0) - stopped_on_budget = max_evaluations is not None and nit >= max_evaluations - 1 + n_evaluations = None if self._eval_counter is None else self._eval_counter.count + stopped_on_budget = max_evaluations is not None and n_evaluations is not None and n_evaluations >= max_evaluations results.success = fit_results.success and not stopped_on_budget pars = self._cached_pars @@ -300,7 +300,7 @@ def _gen_fit_results( results.y_obs = self._cached_model.y results.y_calc = self.evaluate(results.x, minimizer_parameters=results.p) results.y_err = self._cached_model.dy - results.n_evaluations = None if self._eval_counter is None else self._eval_counter.count + results.n_evaluations = n_evaluations results.message = ( f'Fit stopped: reached maximum evaluations ({max_evaluations})' if stopped_on_budget diff --git a/src/easyscience/fitting/minimizers/utils.py b/src/easyscience/fitting/minimizers/utils.py index e3633365..a9a7b781 100644 --- a/src/easyscience/fitting/minimizers/utils.py +++ b/src/easyscience/fitting/minimizers/utils.py @@ -42,6 +42,39 @@ def __init__(self): self.engine_result = None self.total_results = None + def __repr__(self) -> str: + engine_name = self.minimizer_engine.__name__ if self.minimizer_engine else None + try: + chi2_val = self.chi2 + reduced_val = self.reduced_chi + if not np.isfinite(chi2_val) or not np.isfinite(reduced_val): + raise ValueError + chi2 = f'{chi2_val:.4g}' + reduced = f'{reduced_val:.4g}' + except Exception: + chi2 = 'N/A' + reduced = 'N/A' + + try: + n_points = len(self.x) + except TypeError: + n_points = 0 + + lines = [ + f'FitResults(success={self.success}', + f' n_pars={self.n_pars}, n_points={n_points}', + f' chi2={chi2}, reduced_chi={reduced}', + f' n_evaluations={self.n_evaluations}', + f' minimizer={engine_name}', + ] + if self.message: + lines.append(f" message='{self.message}'") + if self.p: + par_str = ', '.join(f'{k}={v:.4g}' for k, v in self.p.items()) + lines.append(f' parameters={{{par_str}}}') + lines.append(')') + return '\n'.join(lines) + @property def n_pars(self): return len(self.p) diff --git a/tests/unit/fitting/minimizers/test_minimizer_bumps.py b/tests/unit/fitting/minimizers/test_minimizer_bumps.py index 3d01fddc..488ae5bc 100644 --- a/tests/unit/fitting/minimizers/test_minimizer_bumps.py +++ b/tests/unit/fitting/minimizers/test_minimizer_bumps.py @@ -228,3 +228,37 @@ def test_gen_fit_results(self, minimizer: Bumps, monkeypatch): minimizer.evaluate.assert_called_once_with( 'x', minimizer_parameters={'ppar_1': 'par_value_1', 'ppar_2': 'par_value_2'} ) + + def test_gen_fit_results_uses_n_evaluations_for_budget_check( + self, minimizer: Bumps, monkeypatch + ): + mock_domain_fit_results = MagicMock() + mock_FitResults = MagicMock(return_value=mock_domain_fit_results) + monkeypatch.setattr( + easyscience.fitting.minimizers.minimizer_bumps, 'FitResults', mock_FitResults + ) + + mock_fit_result = MagicMock() + mock_fit_result.success = True + mock_fit_result.nit = 99 + + mock_cached_model = MagicMock() + mock_cached_model.x = 'x' + mock_cached_model.y = 'y' + mock_cached_model.dy = 'dy' + mock_cached_model.pars = {'ppar_1': 0} + minimizer._cached_model = mock_cached_model + + mock_cached_par = MagicMock() + mock_cached_par.value = 'par_value_1' + minimizer._cached_pars = {'par_1': mock_cached_par} + + minimizer._p_0 = 'p_0' + minimizer._eval_counter = MagicMock(count=2) + minimizer.evaluate = MagicMock(return_value='evaluate') + + domain_fit_results = minimizer._gen_fit_results(mock_fit_result, max_evaluations=3) + + assert domain_fit_results.success == True + assert domain_fit_results.n_evaluations == 2 + assert domain_fit_results.message == '' diff --git a/tests/unit/fitting/minimizers/test_utils.py b/tests/unit/fitting/minimizers/test_utils.py new file mode 100644 index 00000000..fb6f29fe --- /dev/null +++ b/tests/unit/fitting/minimizers/test_utils.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: 2026 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause + +import numpy as np +import pytest + +from easyscience.fitting.minimizers.utils import FitResults + + +class TestFitResultsRepr: + def _make_result(self, **overrides): + r = FitResults() + r.success = True + r.x = np.array([1.0, 2.0, 3.0]) + r.y_obs = np.array([1.0, 2.0, 3.0]) + r.y_calc = np.array([1.1, 1.9, 3.05]) + r.y_err = np.array([0.1, 0.1, 0.1]) + r.p = {'pa': 1.234, 'pb': 5.678} + r.n_evaluations = 42 + r.minimizer_engine = type('Bumps', (), {'__name__': 'Bumps'}) + for k, v in overrides.items(): + setattr(r, k, v) + return r + + def test_repr_contains_success(self): + r = self._make_result() + assert 'success=True' in repr(r) + + def test_repr_contains_n_pars_and_n_points(self): + r = self._make_result() + text = repr(r) + assert 'n_pars=2' in text + assert 'n_points=3' in text + + def test_repr_contains_chi2_values(self): + r = self._make_result() + text = repr(r) + assert 'chi2=' in text + assert 'reduced_chi=' in text + assert 'N/A' not in text + + def test_repr_shows_na_when_chi2_cannot_be_computed(self): + r = self._make_result(y_err=np.array([0.0, 0.0, 0.0])) + text = repr(r) + assert 'chi2=N/A' in text + assert 'reduced_chi=N/A' in text + + def test_repr_contains_n_evaluations(self): + r = self._make_result() + assert 'n_evaluations=42' in repr(r) + + def test_repr_contains_minimizer_name(self): + r = self._make_result() + assert 'minimizer=Bumps' in repr(r) + + def test_repr_minimizer_none(self): + r = self._make_result(minimizer_engine=None) + assert 'minimizer=None' in repr(r) + + def test_repr_includes_message_when_set(self): + r = self._make_result(message='Fit stopped: reached maximum evaluations (3)') + assert 'Fit stopped: reached maximum evaluations (3)' in repr(r) + + def test_repr_omits_message_when_empty(self): + r = self._make_result(message='') + assert 'message' not in repr(r) + + def test_repr_includes_parameters(self): + r = self._make_result() + text = repr(r) + assert 'pa=1.234' in text + assert 'pb=5.678' in text + + def test_repr_omits_parameters_when_empty(self): + r = self._make_result(p={}) + assert 'parameters' not in repr(r) + + def test_repr_default_fit_results(self): + r = FitResults() + text = repr(r) + assert 'success=False' in text + assert 'n_pars=0' in text + assert 'n_points=0' in text + assert 'n_evaluations=None' in text + assert 'chi2=N/A' in text From f83ef52b97c529228f58c28383779f3f35ea62ef Mon Sep 17 00:00:00 2001 From: Piotr Rozyczko Date: Mon, 13 Apr 2026 16:53:06 +0200 Subject: [PATCH 8/9] ruff format --- src/easyscience/fitting/minimizers/minimizer_bumps.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/easyscience/fitting/minimizers/minimizer_bumps.py b/src/easyscience/fitting/minimizers/minimizer_bumps.py index db6bfcca..3862f196 100644 --- a/src/easyscience/fitting/minimizers/minimizer_bumps.py +++ b/src/easyscience/fitting/minimizers/minimizer_bumps.py @@ -284,7 +284,11 @@ def _gen_fit_results( if getattr(results, name, False): setattr(results, name, value) n_evaluations = None if self._eval_counter is None else self._eval_counter.count - stopped_on_budget = max_evaluations is not None and n_evaluations is not None and n_evaluations >= max_evaluations + stopped_on_budget = ( + max_evaluations is not None + and n_evaluations is not None + and n_evaluations >= max_evaluations + ) results.success = fit_results.success and not stopped_on_budget pars = self._cached_pars From ce6feeba4fd802ecf8192ce557ba9fde48ac3cc6 Mon Sep 17 00:00:00 2001 From: rozyczko Date: Tue, 14 Apr 2026 08:14:41 +0200 Subject: [PATCH 9/9] reduced_chi -> reduced_chi2 --- src/easyscience/fitting/minimizers/utils.py | 6 +++--- tests/integration/fitting/test_fitter.py | 6 +++--- tests/integration/fitting/test_multi_fitter.py | 6 +++--- tests/unit/fitting/minimizers/test_utils.py | 4 ++-- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/easyscience/fitting/minimizers/utils.py b/src/easyscience/fitting/minimizers/utils.py index a9a7b781..c6d462fe 100644 --- a/src/easyscience/fitting/minimizers/utils.py +++ b/src/easyscience/fitting/minimizers/utils.py @@ -46,7 +46,7 @@ def __repr__(self) -> str: engine_name = self.minimizer_engine.__name__ if self.minimizer_engine else None try: chi2_val = self.chi2 - reduced_val = self.reduced_chi + reduced_val = self.reduced_chi2 if not np.isfinite(chi2_val) or not np.isfinite(reduced_val): raise ValueError chi2 = f'{chi2_val:.4g}' @@ -63,7 +63,7 @@ def __repr__(self) -> str: lines = [ f'FitResults(success={self.success}', f' n_pars={self.n_pars}, n_points={n_points}', - f' chi2={chi2}, reduced_chi={reduced}', + f' chi2={chi2}, reduced_chi2={reduced}', f' n_evaluations={self.n_evaluations}', f' minimizer={engine_name}', ] @@ -88,7 +88,7 @@ def chi2(self): return ((self.residual / self.y_err) ** 2).sum() @property - def reduced_chi(self): + def reduced_chi2(self): return self.chi2 / (len(self.x) - self.n_pars) diff --git a/tests/integration/fitting/test_fitter.py b/tests/integration/fitting/test_fitter.py index 9b956cb1..63ede513 100644 --- a/tests/integration/fitting/test_fitter.py +++ b/tests/integration/fitting/test_fitter.py @@ -81,7 +81,7 @@ def __call__(self, x: np.ndarray) -> np.ndarray: def check_fit_results(result, sp_sin, ref_sin, x, **kwargs): assert result.n_pars == len(sp_sin.get_fit_parameters()) assert result.chi2 == pytest.approx(0, abs=1.5e-3 * (len(result.x) - result.n_pars)) - assert result.reduced_chi == pytest.approx(0, abs=1.5e-3) + assert result.reduced_chi2 == pytest.approx(0, abs=1.5e-3) assert result.success if 'sp_ref1' in kwargs.keys(): sp_ref1 = kwargs['sp_ref1'] @@ -385,7 +385,7 @@ def test_2D_vectorized(fit_engine): else: raise e assert result.n_pars == len(m2.get_fit_parameters()) - assert result.reduced_chi == pytest.approx(0, abs=1.5e-3) + assert result.reduced_chi2 == pytest.approx(0, abs=1.5e-3) assert result.success assert np.all(result.x == XY) y_calc_ref = m2(XY) @@ -424,7 +424,7 @@ def test_2D_non_vectorized(fit_engine): else: raise e assert result.n_pars == len(m2.get_fit_parameters()) - assert result.reduced_chi == pytest.approx(0, abs=1.5e-3) + assert result.reduced_chi2 == pytest.approx(0, abs=1.5e-3) assert result.success assert np.all(result.x == XY) y_calc_ref = m2(XY.reshape(-1, 2)) diff --git a/tests/integration/fitting/test_multi_fitter.py b/tests/integration/fitting/test_multi_fitter.py index 25cc9adf..fe4df933 100644 --- a/tests/integration/fitting/test_multi_fitter.py +++ b/tests/integration/fitting/test_multi_fitter.py @@ -95,7 +95,7 @@ def test_multi_fit(fit_engine): sp_sin_2.get_fit_parameters() ) assert result.chi2 == pytest.approx(0, abs=1.5e-3 * (len(result.x) - result.n_pars)) - assert result.reduced_chi == pytest.approx(0, abs=1.5e-3) + assert result.reduced_chi2 == pytest.approx(0, abs=1.5e-3) assert result.success assert np.all(result.x == X[idx]) assert np.all(result.y_obs == Y[idx]) @@ -200,7 +200,7 @@ def test_multi_fit2(fit_engine): sp_sin_2.get_fit_parameters() ) + len(sp_line.get_fit_parameters()) assert result.chi2 == pytest.approx(0, abs=1.5e-3 * (len(result.x) - result.n_pars)) - assert result.reduced_chi == pytest.approx(0, abs=1.5e-3) + assert result.reduced_chi2 == pytest.approx(0, abs=1.5e-3) assert result.success assert np.all(result.x == X[idx]) assert np.all(result.y_obs == Y[idx]) @@ -275,7 +275,7 @@ def test_multi_fit_1D_2D(fit_engine): fit_engine != 'DFO' ): # DFO apparently does not fit well with even weights. Can't be bothered to fix assert result.chi2 == pytest.approx(0, abs=1.5e-3 * (len(result.x) - result.n_pars)) - assert result.reduced_chi == pytest.approx(0, abs=1.5e-3) + assert result.reduced_chi2 == pytest.approx(0, abs=1.5e-3) assert result.y_calc == pytest.approx(F_ref[idx](X[idx]), abs=1e-2) assert result.residual == pytest.approx( F_real[idx](X[idx]) - F_ref[idx](X[idx]), abs=1e-2 diff --git a/tests/unit/fitting/minimizers/test_utils.py b/tests/unit/fitting/minimizers/test_utils.py index fb6f29fe..f9227852 100644 --- a/tests/unit/fitting/minimizers/test_utils.py +++ b/tests/unit/fitting/minimizers/test_utils.py @@ -36,14 +36,14 @@ def test_repr_contains_chi2_values(self): r = self._make_result() text = repr(r) assert 'chi2=' in text - assert 'reduced_chi=' in text + assert 'reduced_chi2=' in text assert 'N/A' not in text def test_repr_shows_na_when_chi2_cannot_be_computed(self): r = self._make_result(y_err=np.array([0.0, 0.0, 0.0])) text = repr(r) assert 'chi2=N/A' in text - assert 'reduced_chi=N/A' in text + assert 'reduced_chi2=N/A' in text def test_repr_contains_n_evaluations(self): r = self._make_result()