From 2c53074a30e9bc6f55adf4cf789e8ec429209589 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Wed, 1 Jul 2026 21:59:50 +0200 Subject: [PATCH 1/2] Fix PyTorch fftconvolve Python bool axes --- src/pyrecest/_backend/pytorch/signal.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/pyrecest/_backend/pytorch/signal.py b/src/pyrecest/_backend/pytorch/signal.py index b5e07ff49..f4376cd3f 100644 --- a/src/pyrecest/_backend/pytorch/signal.py +++ b/src/pyrecest/_backend/pytorch/signal.py @@ -5,6 +5,8 @@ def _coerce_axis(axis): + if isinstance(axis, bool): + return int(axis) try: axis_array = _np.asarray(axis) except (TypeError, ValueError) as exc: From 275a28194929af216655760693ab3b3dc38446e6 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Wed, 1 Jul 2026 22:00:46 +0200 Subject: [PATCH 2/2] Add PyTorch fftconvolve bool axes regression tests --- .../backend/test_pytorch_signal_bool_axes.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 tests/backend/test_pytorch_signal_bool_axes.py diff --git a/tests/backend/test_pytorch_signal_bool_axes.py b/tests/backend/test_pytorch_signal_bool_axes.py new file mode 100644 index 000000000..7604614db --- /dev/null +++ b/tests/backend/test_pytorch_signal_bool_axes.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import numpy as np +import pytest + +pytorch_backend = pytest.importorskip("pyrecest._backend.pytorch") + + +@pytest.mark.parametrize( + ("axes", "expected"), + [ + ( + False, + [[0.5, 2.0, 1.5], [3.5, 9.0, 7.5], [6.0, 10.0, 9.0]], + ), + ( + True, + [[0.5, 2.0, 4.0, 4.0, 1.5], [6.0, 15.5, 25.0, 19.5, 9.0]], + ), + ], +) +def test_fftconvolve_accepts_python_bool_axes_like_scalar_axes( + axes: bool, + expected: list[list[float]], +) -> None: + in1 = pytorch_backend.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + in2 = pytorch_backend.array([[0.5, 1.0, 0.5], [1.5, 2.0, 1.5]]) + + result = pytorch_backend.signal.fftconvolve(in1, in2, axes=axes) + + assert pytorch_backend.allclose(result, pytorch_backend.array(expected)) + + +def test_fftconvolve_rejects_numpy_bool_axis() -> None: + in1 = pytorch_backend.ones((2, 3)) + in2 = pytorch_backend.ones((2, 3)) + + with pytest.raises((TypeError, ValueError)): + pytorch_backend.signal.fftconvolve(in1, in2, axes=np.bool_(True))