From f6d4edb62d8d6e2a398021f8d3acac8a9d05f961 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Wed, 1 Jul 2026 21:38:08 +0200 Subject: [PATCH 1/2] Add JAX empty assignment index regression test --- ...t_jax_assignment_empty_indices_contract.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 tests/backend_support/test_jax_assignment_empty_indices_contract.py diff --git a/tests/backend_support/test_jax_assignment_empty_indices_contract.py b/tests/backend_support/test_jax_assignment_empty_indices_contract.py new file mode 100644 index 000000000..5c9f28c8d --- /dev/null +++ b/tests/backend_support/test_jax_assignment_empty_indices_contract.py @@ -0,0 +1,37 @@ +"""Regression tests for raw JAX assignment helper empty-index handling.""" + +from __future__ import annotations + +import importlib.util + +import pytest +from tests.support.backend_runner import run_backend_code + + +@pytest.mark.backend_portable +def test_raw_jax_assignment_helpers_treat_empty_indices_as_noop(): + if importlib.util.find_spec("jax") is None: + pytest.skip("JAX is not installed") + + result = run_backend_code( + "jax", + """ +import importlib + +import numpy as np +from pyrecest.backend import to_numpy + +raw_jax = importlib.import_module("pyrecest._backend.jax") +original = raw_jax.array([1.0, 2.0, 3.0]) + +assigned = raw_jax.assignment(original, 99.0, []) +added = raw_jax.assignment_by_sum(original, 99.0, []) +array_like = raw_jax.assignment([1.0, 2.0, 3.0], 99.0, []) + +np.testing.assert_allclose(to_numpy(assigned), [1.0, 2.0, 3.0]) +np.testing.assert_allclose(to_numpy(added), [1.0, 2.0, 3.0]) +np.testing.assert_allclose(to_numpy(array_like), [1.0, 2.0, 3.0]) +""", + ) + + assert result.returncode == 0, result.stderr From f5f819b8274ffafb6b47abdfef1475115e4a0d94 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Wed, 1 Jul 2026 21:40:55 +0200 Subject: [PATCH 2/2] Fix JAX empty assignment index handling --- src/pyrecest/_backend/jax/__init__.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/pyrecest/_backend/jax/__init__.py b/src/pyrecest/_backend/jax/__init__.py index 6ad2afe0a..137164e21 100644 --- a/src/pyrecest/_backend/jax/__init__.py +++ b/src/pyrecest/_backend/jax/__init__.py @@ -8,7 +8,7 @@ import jax.numpy as _jnp from jax import vmap -from jax.numpy import ( # For pyrecest; For Riemannian score-based SDE +from jax.numpy import ( # For pyrecest; For Riemannian Score-based SDE abs, all, allclose, @@ -382,6 +382,10 @@ def _assignment_value_length(values): return 1 +def _is_empty_index_sequence(indices): + return _is_iterable_index(indices) and len(indices) == 0 + + def _normalize_assignment_index(indices, ndim_x, axis=0): if _is_boolean_index(indices): return _jnp.asarray(indices), False, None @@ -445,6 +449,8 @@ def assignment(x, values, indices, axis=0): Copy of x with the values assigned at the given indices. """ x = _jnp.asarray(x) + if _is_empty_index_sequence(indices): + return x normalized_indices, use_vectorization, len_indices = _normalize_assignment_index( indices, x.ndim, @@ -486,6 +492,8 @@ def assignment_by_sum(x, values, indices, axis=0): If a list is given, it must have the same length as indices. """ x = _jnp.asarray(x) + if _is_empty_index_sequence(indices): + return x normalized_indices, use_vectorization, len_indices = _normalize_assignment_index( indices, x.ndim,