From 840ab62f4de590aa94eb7b2014e46d1e79a66f28 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Fri, 3 Jul 2026 00:03:03 +0200 Subject: [PATCH 1/2] Fix PyTorch FFT None axis alias handling --- src/pyrecest/_backend/pytorch/fft.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/pyrecest/_backend/pytorch/fft.py b/src/pyrecest/_backend/pytorch/fft.py index 150f97176..8d299b7b7 100644 --- a/src/pyrecest/_backend/pytorch/fft.py +++ b/src/pyrecest/_backend/pytorch/fft.py @@ -73,15 +73,20 @@ def _normalize_fft_dim_sequence(dim): return tuple(normalized_entries) -def _with_dim_alias(kwargs, alias, func_name): +def _with_dim_alias(kwargs, alias, func_name, *, none_alias_means_default=True): if alias not in kwargs: return kwargs kwargs = dict(kwargs) alias_value = kwargs.pop(alias) + dim_value = kwargs.get("dim") if alias_value is None: + if none_alias_means_default: + return kwargs + if dim_value is not None: + raise TypeError("conflicting FFT axis aliases") + kwargs["dim"] = None return kwargs - dim_value = kwargs.get("dim") if dim_value is not None: dim_value = _normalize_fft_dim_sequence(dim_value) alias_value = _normalize_fft_dim_sequence(alias_value) @@ -100,11 +105,17 @@ def _wrap_arraylike_fft( empty_dim_is_noop=False, normalize_scalar_dim=False, normalize_dim_sequence=False, + none_alias_means_default=True, ): @_wraps(torch_func) def fft_func(value, *args, **kwargs): if dim_alias is not None: - kwargs = _with_dim_alias(kwargs, dim_alias, func_name) + kwargs = _with_dim_alias( + kwargs, + dim_alias, + func_name, + none_alias_means_default=none_alias_means_default, + ) if normalize_scalar_dim and "dim" in kwargs: kwargs = dict(kwargs) kwargs["dim"] = _normalize_single_fft_dim(kwargs["dim"]) @@ -128,12 +139,14 @@ def fft_func(value, *args, **kwargs): func_name="rfft", dim_alias="axis", normalize_scalar_dim=True, + none_alias_means_default=False, ) irfft = _wrap_arraylike_fft( _torch.fft.irfft, func_name="irfft", dim_alias="axis", normalize_scalar_dim=True, + none_alias_means_default=False, ) fftshift = _wrap_arraylike_fft( _torch.fft.fftshift, From 6d295deaffa838297bdbcc8561370c8ba09bbd0b Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Fri, 3 Jul 2026 00:03:22 +0200 Subject: [PATCH 2/2] Add regression test for PyTorch FFT None axis aliases --- .../test_pytorch_fft_axis_contract.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/backend_support/test_pytorch_fft_axis_contract.py b/tests/backend_support/test_pytorch_fft_axis_contract.py index 5fc3a2717..f3190a734 100644 --- a/tests/backend_support/test_pytorch_fft_axis_contract.py +++ b/tests/backend_support/test_pytorch_fft_axis_contract.py @@ -45,6 +45,19 @@ def test_raw_pytorch_fft_helpers_accept_numpy_axis_aliases(): ) +@pytest.mark.backend_portable +@pytest.mark.parametrize("fft_func", [pytorch_fft.rfft, pytorch_fft.irfft]) +def test_raw_pytorch_single_axis_fft_rejects_none_axis_alias(fft_func): + with pytest.raises(TypeError): + fft_func(np.arange(4.0), axis=None) + + +@pytest.mark.backend_portable +def test_raw_pytorch_single_axis_fft_rejects_conflicting_none_axis_alias(): + with pytest.raises(TypeError): + pytorch_fft.rfft(np.arange(4.0), axis=None, dim=0) + + @pytest.mark.backend_portable def test_raw_pytorch_fft_none_axis_alias_preserves_explicit_dim(): matrix = np.arange(6.0).reshape(2, 3)