diff --git a/src/pyrecest/backend_support/_shared_numpy_assignment_by_sum_contract.py b/src/pyrecest/backend_support/_shared_numpy_assignment_by_sum_contract.py new file mode 100644 index 000000000..6a6f89925 --- /dev/null +++ b/src/pyrecest/backend_support/_shared_numpy_assignment_by_sum_contract.py @@ -0,0 +1,81 @@ +"""Shared NumPy ``assignment_by_sum`` duplicate-index compatibility hook.""" + +from __future__ import annotations + +import sys + + +def _install_assignment_by_sum(assignment_by_sum, backend, shared_numpy) -> None: + """Install the patched helper on loaded shared-NumPy facade modules.""" + + shared_numpy.assignment_by_sum = assignment_by_sum + for module_name in ("pyrecest._backend.numpy", "pyrecest._backend.autograd"): + module = sys.modules.get(module_name) + if module is not None: + module.assignment_by_sum = assignment_by_sum + if getattr(backend, "__backend_name__", None) in {"numpy", "autograd"}: + backend.assignment_by_sum = assignment_by_sum + + +def patch_shared_numpy_assignment_by_sum_duplicate_indices() -> None: + """Make shared NumPy assignment-by-sum accumulate repeated indices.""" + + try: + import pyrecest._backend._shared_numpy as shared_numpy # pylint: disable=import-outside-toplevel + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + except ModuleNotFoundError: # pragma: no cover - backend import failure path + return + + if getattr(backend, "__backend_name__", None) not in {"numpy", "autograd"}: + return + + original_assignment_by_sum = shared_numpy.assignment_by_sum + if getattr( + original_assignment_by_sum, + "_pyrecest_duplicate_index_accumulation_contract", + False, + ): + _install_assignment_by_sum(original_assignment_by_sum, backend, shared_numpy) + return + + def assignment_by_sum(x, values, indices, axis=0): + x_new = shared_numpy.copy(shared_numpy.array(x)) + + if shared_numpy._is_empty_index_sequence(indices): + return x_new + + use_vectorization = hasattr(indices, "__len__") and len(indices) < shared_numpy.ndim( + x_new + ) + if shared_numpy._is_boolean(indices): + x_new[indices] += values + return x_new + + zip_indices = shared_numpy._is_iterable(indices) and shared_numpy._is_iterable( + indices[0] + ) + if zip_indices: + indices = tuple(zip(*indices)) + if not use_vectorization: + len_indices = len(indices) if shared_numpy._is_iterable(indices) else 1 + len_values = len(values) if shared_numpy._is_iterable(values) else 1 + if len_values > 1 and len_values != len_indices: + raise ValueError("Either one value or as many values as indices") + shared_numpy._np.add.at(x_new, indices, values) + else: + indices = tuple(list(indices[:axis]) + [slice(None)] + list(indices[axis:])) + x_new[indices] += values + return x_new + + assignment_by_sum.__name__ = getattr( + original_assignment_by_sum, + "__name__", + "assignment_by_sum", + ) + assignment_by_sum.__doc__ = getattr(original_assignment_by_sum, "__doc__", None) + assignment_by_sum._pyrecest_duplicate_index_accumulation_contract = True + + _install_assignment_by_sum(assignment_by_sum, backend, shared_numpy) + + +__all__ = ["patch_shared_numpy_assignment_by_sum_duplicate_indices"] diff --git a/src/pyrecest/stability.py b/src/pyrecest/stability.py index aa56e78f3..9e6e8cb0f 100644 --- a/src/pyrecest/stability.py +++ b/src/pyrecest/stability.py @@ -9,8 +9,12 @@ from pyrecest.backend_support._pytorch_allclose_device_contract import ( patch_pytorch_allclose_device_contract as _patch_pytorch_allclose_device_contract, ) +from pyrecest.backend_support._shared_numpy_assignment_by_sum_contract import ( + patch_shared_numpy_assignment_by_sum_duplicate_indices as _patch_shared_numpy_assignment_by_sum_duplicate_indices, +) _patch_pytorch_allclose_device_contract() +_patch_shared_numpy_assignment_by_sum_duplicate_indices() P = ParamSpec("P") R = TypeVar("R") diff --git a/tests/backend_support/test_shared_numpy_assignment_by_sum_duplicate_indices.py b/tests/backend_support/test_shared_numpy_assignment_by_sum_duplicate_indices.py new file mode 100644 index 000000000..79c73fa2c --- /dev/null +++ b/tests/backend_support/test_shared_numpy_assignment_by_sum_duplicate_indices.py @@ -0,0 +1,31 @@ +import importlib.util + +import pytest + +from tests.support.backend_runner import run_backend_code + +pytestmark = pytest.mark.backend_portable + + +@pytest.mark.parametrize("backend_name", ["numpy", "autograd"]) +def test_assignment_by_sum_accumulates_duplicate_advanced_indices(backend_name): + if backend_name == "autograd" and importlib.util.find_spec("autograd") is None: + pytest.skip("autograd is not installed") + + code = """ +import pyrecest.backend as backend + +vector = backend.zeros(3) +vector_result = backend.assignment_by_sum(vector, [1.0, 2.0], [0, 0]) +assert backend.to_numpy(vector_result).tolist() == [3.0, 0.0, 0.0] + +matrix = backend.zeros((2, 2)) +matrix_result = backend.assignment_by_sum(matrix, [1.0, 2.0], [(0, 1), (0, 1)]) +assert backend.to_numpy(matrix_result).tolist() == [[0.0, 3.0], [0.0, 0.0]] + +print("ok") +""" + result = run_backend_code(backend_name, code) + + assert result.returncode == 0, result.stderr + assert "ok" in result.stdout