Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Shared NumPy ``assignment_by_sum`` duplicate-index compatibility hook."""

from __future__ import annotations

import sys


def _install_assignment_by_sum(assignment_by_sum, backend, shared_numpy) -> None:
"""Install the patched helper on loaded shared-NumPy facade modules."""

shared_numpy.assignment_by_sum = assignment_by_sum
for module_name in ("pyrecest._backend.numpy", "pyrecest._backend.autograd"):
module = sys.modules.get(module_name)
if module is not None:
module.assignment_by_sum = assignment_by_sum
if getattr(backend, "__backend_name__", None) in {"numpy", "autograd"}:
backend.assignment_by_sum = assignment_by_sum


def patch_shared_numpy_assignment_by_sum_duplicate_indices() -> None:
"""Make shared NumPy assignment-by-sum accumulate repeated indices."""

try:
import pyrecest._backend._shared_numpy as shared_numpy # pylint: disable=import-outside-toplevel
import pyrecest.backend as backend # pylint: disable=import-outside-toplevel
except ModuleNotFoundError: # pragma: no cover - backend import failure path
return

if getattr(backend, "__backend_name__", None) not in {"numpy", "autograd"}:
return

original_assignment_by_sum = shared_numpy.assignment_by_sum
if getattr(
original_assignment_by_sum,
"_pyrecest_duplicate_index_accumulation_contract",
False,
):
_install_assignment_by_sum(original_assignment_by_sum, backend, shared_numpy)
return

def assignment_by_sum(x, values, indices, axis=0):
x_new = shared_numpy.copy(shared_numpy.array(x))

if shared_numpy._is_empty_index_sequence(indices):
return x_new

use_vectorization = hasattr(indices, "__len__") and len(indices) < shared_numpy.ndim(
x_new
)
if shared_numpy._is_boolean(indices):
x_new[indices] += values
return x_new

zip_indices = shared_numpy._is_iterable(indices) and shared_numpy._is_iterable(
indices[0]
)
if zip_indices:
indices = tuple(zip(*indices))
if not use_vectorization:
len_indices = len(indices) if shared_numpy._is_iterable(indices) else 1
len_values = len(values) if shared_numpy._is_iterable(values) else 1
if len_values > 1 and len_values != len_indices:
raise ValueError("Either one value or as many values as indices")
shared_numpy._np.add.at(x_new, indices, values)
else:
indices = tuple(list(indices[:axis]) + [slice(None)] + list(indices[axis:]))
x_new[indices] += values
return x_new

assignment_by_sum.__name__ = getattr(
original_assignment_by_sum,
"__name__",
"assignment_by_sum",
)
assignment_by_sum.__doc__ = getattr(original_assignment_by_sum, "__doc__", None)
assignment_by_sum._pyrecest_duplicate_index_accumulation_contract = True

_install_assignment_by_sum(assignment_by_sum, backend, shared_numpy)


__all__ = ["patch_shared_numpy_assignment_by_sum_duplicate_indices"]
4 changes: 4 additions & 0 deletions src/pyrecest/stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
from pyrecest.backend_support._pytorch_allclose_device_contract import (
patch_pytorch_allclose_device_contract as _patch_pytorch_allclose_device_contract,
)
from pyrecest.backend_support._shared_numpy_assignment_by_sum_contract import (
patch_shared_numpy_assignment_by_sum_duplicate_indices as _patch_shared_numpy_assignment_by_sum_duplicate_indices,
)

_patch_pytorch_allclose_device_contract()
_patch_shared_numpy_assignment_by_sum_duplicate_indices()

P = ParamSpec("P")
R = TypeVar("R")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import importlib.util

import pytest

from tests.support.backend_runner import run_backend_code

pytestmark = pytest.mark.backend_portable


@pytest.mark.parametrize("backend_name", ["numpy", "autograd"])
def test_assignment_by_sum_accumulates_duplicate_advanced_indices(backend_name):
if backend_name == "autograd" and importlib.util.find_spec("autograd") is None:
pytest.skip("autograd is not installed")

code = """
import pyrecest.backend as backend

vector = backend.zeros(3)
vector_result = backend.assignment_by_sum(vector, [1.0, 2.0], [0, 0])
assert backend.to_numpy(vector_result).tolist() == [3.0, 0.0, 0.0]

matrix = backend.zeros((2, 2))
matrix_result = backend.assignment_by_sum(matrix, [1.0, 2.0], [(0, 1), (0, 1)])
assert backend.to_numpy(matrix_result).tolist() == [[0.0, 3.0], [0.0, 0.0]]

print("ok")
"""
result = run_backend_code(backend_name, code)

assert result.returncode == 0, result.stderr
assert "ok" in result.stdout
Loading