From 5f26bf071e2f4448826a140b4b6915b84648ffa3 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Thu, 2 Jul 2026 00:05:54 +0200 Subject: [PATCH] Fix JAX assignment NumPy index arrays --- .../_jax_assignment_numpy_index_contract.py | 46 +++++++++++++++++++ src/pyrecest/stability.py | 4 ++ 2 files changed, 50 insertions(+) create mode 100644 src/pyrecest/backend_support/_jax_assignment_numpy_index_contract.py diff --git a/src/pyrecest/backend_support/_jax_assignment_numpy_index_contract.py b/src/pyrecest/backend_support/_jax_assignment_numpy_index_contract.py new file mode 100644 index 000000000..ad09e7756 --- /dev/null +++ b/src/pyrecest/backend_support/_jax_assignment_numpy_index_contract.py @@ -0,0 +1,46 @@ +"""JAX assignment compatibility helpers.""" + +from __future__ import annotations + + +def _normalize_indices(indices, np, jnp): + """Return JAX-friendly index arrays for NumPy ndarray inputs.""" + + if isinstance(indices, np.ndarray): + if indices.ndim > 0 and indices.size == 0: + return indices + return jnp.asarray(indices) + return indices + + +def _wrap_helper(helper, np, jnp): + """Normalize NumPy ndarray indices before delegating to a JAX helper.""" + + if getattr(helper, "_pyrecest_numpy_index_contract", False): + return helper + + def wrapped(x, values, indices, axis=0): + return helper(x, values, _normalize_indices(indices, np, jnp), axis=axis) + + wrapped.__name__ = getattr(helper, "__name__", "assignment") + wrapped.__doc__ = getattr(helper, "__doc__", None) + wrapped._pyrecest_numpy_index_contract = True + return wrapped + + +def patch_jax_assignment_numpy_index_contract() -> None: + """Make JAX assignment helpers accept NumPy ndarray index sequences.""" + + try: + import jax.numpy as jnp # pylint: disable=import-outside-toplevel + import numpy as np # pylint: disable=import-outside-toplevel + import pyrecest._backend.jax as jax_backend # pylint: disable=import-outside-toplevel + import pyrecest.backend as backend # pylint: disable=import-outside-toplevel + except ModuleNotFoundError: # pragma: no cover - JAX backend may be unavailable + return + + jax_backend.assignment = _wrap_helper(jax_backend.assignment, np, jnp) + jax_backend.assignment_by_sum = _wrap_helper(jax_backend.assignment_by_sum, np, jnp) + if getattr(backend, "__backend_name__", None) == "jax": + backend.assignment = jax_backend.assignment + backend.assignment_by_sum = jax_backend.assignment_by_sum diff --git a/src/pyrecest/stability.py b/src/pyrecest/stability.py index aa56e78f3..e5e443cac 100644 --- a/src/pyrecest/stability.py +++ b/src/pyrecest/stability.py @@ -6,10 +6,14 @@ from dataclasses import asdict, dataclass from typing import Final, Literal, ParamSpec, TypeVar +from pyrecest.backend_support._jax_assignment_numpy_index_contract import ( + patch_jax_assignment_numpy_index_contract as _patch_jax_assignment_numpy_index_contract, +) from pyrecest.backend_support._pytorch_allclose_device_contract import ( patch_pytorch_allclose_device_contract as _patch_pytorch_allclose_device_contract, ) +_patch_jax_assignment_numpy_index_contract() _patch_pytorch_allclose_device_contract() P = ParamSpec("P")