From 1ca48fdbdd4ecbe6173aaeb7b0471e25fd039f0a Mon Sep 17 00:00:00 2001 From: Saoirse Stewart Date: Tue, 5 May 2026 13:22:52 +0100 Subject: [PATCH] Arm backend: Add TOSA dialect ARGMAX op Signed-off-by: Saoirse Stewart --- .../arm/test/misc/test_tosa_dialect_argmax.py | 43 ++++++++++ backends/arm/tosa/dialect/__init__.py | 1 + backends/arm/tosa/dialect/ops/argmax.py | 78 +++++++++++++++++++ 3 files changed, 122 insertions(+) create mode 100644 backends/arm/test/misc/test_tosa_dialect_argmax.py create mode 100644 backends/arm/tosa/dialect/ops/argmax.py diff --git a/backends/arm/test/misc/test_tosa_dialect_argmax.py b/backends/arm/test/misc/test_tosa_dialect_argmax.py new file mode 100644 index 00000000000..50985fbf336 --- /dev/null +++ b/backends/arm/test/misc/test_tosa_dialect_argmax.py @@ -0,0 +1,43 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import executorch.backends.arm.tosa.dialect # noqa: F401 +import pytest +import torch +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch._subclasses.fake_tensor import FakeTensorMode + + +def test_argmax_tosa_fp() -> None: + sample_input = torch.randn((2, 3, 4), dtype=torch.float32) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP") + ), FakeTensorMode() as mode: + output = exir_ops.backend.tosa.ARGMAX.default( + mode.from_tensor(sample_input), + axis=1, + ) + + assert output.dtype == torch.int32 + assert tuple(output.shape) == (2, 4) + + +def test_argmax_rejects_bfloat16_without_extension() -> None: + sample_input = torch.randn((2, 3, 4), dtype=torch.bfloat16) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP") + ), FakeTensorMode() as mode: + with pytest.raises(TosaValueError, match="doesn't support bfloat16"): + exir_ops.backend.tosa.ARGMAX.default( + mode.from_tensor(sample_input), + axis=1, + ) diff --git a/backends/arm/tosa/dialect/__init__.py b/backends/arm/tosa/dialect/__init__.py index 4678da4d118..a03f3c3998f 100644 --- a/backends/arm/tosa/dialect/__init__.py +++ b/backends/arm/tosa/dialect/__init__.py @@ -5,6 +5,7 @@ from executorch.backends.arm.tosa.dialect.ops import ( # noqa F401 activation, + argmax, avg_pool2d, avg_pool2d_adaptive, conv2d, diff --git a/backends/arm/tosa/dialect/ops/argmax.py b/backends/arm/tosa/dialect/ops/argmax.py new file mode 100644 index 00000000000..a2717124fcd --- /dev/null +++ b/backends/arm/tosa/dialect/ops/argmax.py @@ -0,0 +1,78 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.dialect.ops._common import validate_nan_mode +from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.specification import ( + get_context_spec, + TosaSpecification, +) + + +def _validate_argmax_dtype(dtype: torch.dtype) -> None: + tosa_spec = get_context_spec() + + if dtype == torch.int8: + if not tosa_spec.support_integer(): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support int8 for ARGMAX", + op="ARGMAX", + ) + return + + if dtype == torch.int16: + if not (tosa_spec.support_integer() and tosa_spec.support_extension("int16")): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support int16 for ARGMAX", + op="ARGMAX", + ) + return + + if dtype in (torch.float16, torch.float32): + if not tosa_spec.support_float(): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support {dtype} for ARGMAX", + op="ARGMAX", + ) + return + + if dtype == torch.bfloat16: + if not (tosa_spec.support_float() and tosa_spec.support_extension("bf16")): + raise TosaValueError( + f"TOSA spec {tosa_spec} doesn't support bfloat16 for ARGMAX", + op="ARGMAX", + ) + return + + raise TosaValueError(f"Unsupported dtype {dtype} for ARGMAX", op="ARGMAX") + + +@register_fake_tosa_op( + 'ARGMAX(Tensor input, int axis, *, str nan_mode="PROPAGATE") -> Tensor', + TosaSpecification.all_versions_and_profiles(), +) +def ARGMAX( + input: torch.Tensor, + axis: int, + *, + nan_mode: str = "PROPAGATE", +) -> torch.Tensor: + validate_nan_mode(nan_mode, "ARGMAX") + _validate_argmax_dtype(input.dtype) + + if input.dim() == 0: + raise TosaValueError( + "ARGMAX requires an input with rank at least 1", op="ARGMAX" + ) + if axis < 0 or axis >= input.dim(): + raise TosaValueError( + f"axis must be in [0, {input.dim() - 1}] but got {axis}", + op="ARGMAX", + ) + + output_shape = tuple(input.shape[:axis]) + tuple(input.shape[axis + 1 :]) + return torch.empty(output_shape, dtype=torch.int32)