diff --git a/src/pyrecest/_backend/__init__.py b/src/pyrecest/_backend/__init__.py index d3963f8cf..dc16be936 100644 --- a/src/pyrecest/_backend/__init__.py +++ b/src/pyrecest/_backend/__init__.py @@ -328,6 +328,20 @@ def meshgrid(*axes, **kwargs): return meshgrid +def _flip_with_numpy_axis(flip_func): + """Return a flip wrapper accepting NumPy integer axes on strict backends.""" + + @wraps(flip_func) + def flip(x, axis): + if isinstance(axis, _numbers.Integral): + axis = int(axis) + elif axis is not None: + axis = tuple(int(one_axis) for one_axis in axis) + return flip_func(x, axis) + + return flip + + def _mean_with_numpy_signature( mean_func, asarray_func, @@ -716,6 +730,12 @@ def _create_backend_module(self, backend_name: str): getattr(backend, "asarray"), getattr(backend, "atleast_1d"), ) + if ( + module_name == "" + and attribute_name == "flip" + and backend_name == "pytorch" + ): + attribute = _flip_with_numpy_axis(attribute) if ( module_name == "" and attribute_name == "quantile" diff --git a/tests/test_pytorch_flip_numpy_axis.py b/tests/test_pytorch_flip_numpy_axis.py new file mode 100644 index 000000000..f59e27d75 --- /dev/null +++ b/tests/test_pytorch_flip_numpy_axis.py @@ -0,0 +1,13 @@ +from tests.support.backend_runner import run_backend_code + + +def test_pytorch_flip_accepts_numpy_integer_axis(): + code = ''' +import numpy as np +import pyrecest.backend as backend + +result = backend.flip([1, 2, 3], axis=np.int64(0)) +assert backend.to_numpy(result).tolist() == [3, 2, 1] +''' + result = run_backend_code("pytorch", code) + assert result.returncode == 0, result.stderr