From 39204850fe0827272b5e5c4672892a8aa5cf8d69 Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Mon, 4 May 2026 11:07:34 +0200 Subject: [PATCH] Arm backend: Add TOSA dialect data layout ops Adds TOSA dialect fake implementations for CONCAT, RESHAPE, REVERSE, TILE and TRANSPOSE. Also moves PAD and SLICE into data_layout_ops.py. Signed-off-by: Oscar Andersson Change-Id: I93adb38dcfa4382b0bb60853c45db252de5f4250 --- .../misc/tosa_dialect/test_data_layout_ops.py | 180 ++++++++++++ backends/arm/tosa/dialect/__init__.py | 3 +- .../arm/tosa/dialect/ops/data_layout_ops.py | 274 ++++++++++++++++++ backends/arm/tosa/dialect/ops/pad.py | 61 ---- backends/arm/tosa/dialect/ops/slice.py | 65 ----- 5 files changed, 455 insertions(+), 128 deletions(-) create mode 100644 backends/arm/test/misc/tosa_dialect/test_data_layout_ops.py create mode 100644 backends/arm/tosa/dialect/ops/data_layout_ops.py delete mode 100644 backends/arm/tosa/dialect/ops/pad.py delete mode 100644 backends/arm/tosa/dialect/ops/slice.py diff --git a/backends/arm/test/misc/tosa_dialect/test_data_layout_ops.py b/backends/arm/test/misc/tosa_dialect/test_data_layout_ops.py new file mode 100644 index 00000000000..35074fc32b5 --- /dev/null +++ b/backends/arm/test/misc/tosa_dialect/test_data_layout_ops.py @@ -0,0 +1,180 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import executorch.backends.arm.tosa.dialect # noqa: F401 +import pytest +import torch +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch._subclasses.fake_tensor import FakeTensorMode + + +def _fake_tensor(dtype: torch.dtype, mode: FakeTensorMode) -> torch.Tensor: + return mode.from_tensor(torch.empty((2, 3), dtype=dtype)) + + +_DATA_LAYOUT_OPS = [ + pytest.param( + lambda x: exir_ops.backend.tosa.CONCAT.default([x, x], axis=0), + (4, 3), + id="concat", + ), + pytest.param( + lambda x: exir_ops.backend.tosa.PAD.default(x, [1, 2, 3, 4], value=0), + (5, 10), + id="pad", + ), + pytest.param( + lambda x: exir_ops.backend.tosa.RESHAPE.default(x, [3, 2]), + (3, 2), + id="reshape", + ), + pytest.param( + lambda x: exir_ops.backend.tosa.REVERSE.default(x, axis=0), + (2, 3), + id="reverse", + ), + pytest.param( + lambda x: exir_ops.backend.tosa.SLICE.default(x, [0, 1], [2, 2]), + (2, 2), + id="slice", + ), + pytest.param( + lambda x: exir_ops.backend.tosa.TILE.default(x, [1, 2]), + (2, 6), + id="tile", + ), + pytest.param( + lambda x: exir_ops.backend.tosa.TRANSPOSE.default(x, [1, 0]), + (3, 2), + id="transpose", + ), +] + +_POSITIVE_DTYPES = [ + pytest.param("TOSA-1.1+FP", torch.float32, id="fp32"), + pytest.param("TOSA-1.1+INT", torch.int32, id="int32"), + pytest.param("TOSA-1.1+FP", torch.bool, id="bool"), + pytest.param("TOSA-1.1+INT+int64", torch.int64, id="int64"), + pytest.param("TOSA-1.1+FP+bf16", torch.bfloat16, id="bf16"), + pytest.param("TOSA-1.1+FP+fp8e4m3", torch.float8_e4m3fn, id="fp8e4m3"), + pytest.param("TOSA-1.1+FP+fp8e5m2", torch.float8_e5m2, id="fp8e5m2"), +] + + +@pytest.mark.parametrize("spec,dtype", _POSITIVE_DTYPES) +@pytest.mark.parametrize("op,expected_shape", _DATA_LAYOUT_OPS) +def test_data_layout_ops_positive(op, expected_shape, spec, dtype) -> None: + with TosaLoweringContext( + TosaSpecification.create_from_string(spec) + ), FakeTensorMode() as mode: + output = op(_fake_tensor(dtype, mode)) + + assert output.dtype == dtype + assert tuple(output.shape) == expected_shape + + +@pytest.mark.parametrize( + "op,error_match", + [ + pytest.param( + lambda x: exir_ops.backend.tosa.CONCAT.default([x, x], axis=2), + "out of range", + id="concat", + ), + pytest.param( + lambda x: exir_ops.backend.tosa.PAD.default(x, [0, -1, 0, 0], value=0), + "non-negative", + id="pad", + ), + pytest.param( + lambda x: exir_ops.backend.tosa.RESHAPE.default(x, [-2, -3]), + "Negative dimension", + id="reshape", + ), + pytest.param( + lambda x: exir_ops.backend.tosa.REVERSE.default(x, axis=2), + "out of range", + id="reverse", + ), + pytest.param( + lambda x: exir_ops.backend.tosa.SLICE.default(x, [0, 0], [2, 0]), + r"Expected start \+ size", + id="slice", + ), + pytest.param( + lambda x: exir_ops.backend.tosa.TILE.default(x, [0, 1]), + "TILE multiples must be positive", + id="tile", + ), + pytest.param( + lambda x: exir_ops.backend.tosa.TRANSPOSE.default(x, [0, 0]), + "Invalid permutation", + id="transpose", + ), + ], +) +def test_data_layout_ops_reject_invalid_arguments(op, error_match) -> None: + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP") + ), FakeTensorMode() as mode: + with pytest.raises(TosaValueError, match=error_match): + op(_fake_tensor(torch.float32, mode)) + + +@pytest.mark.parametrize("op,expected_shape", _DATA_LAYOUT_OPS) +def test_data_layout_ops_reject_int64_without_extension(op, expected_shape) -> None: + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP") + ), FakeTensorMode() as mode: + with pytest.raises(TosaValueError, match="Unsupported dtype"): + op(_fake_tensor(torch.int64, mode)) + + +def test_int16_data_layout_dtype_support_follows_tosa_spec() -> None: + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.0+INT") + ), FakeTensorMode() as mode: + x = _fake_tensor(torch.int16, mode) + + assert exir_ops.backend.tosa.RESHAPE.default(x, [3, 2]).dtype == torch.int16 + assert exir_ops.backend.tosa.REVERSE.default(x, axis=0).dtype == torch.int16 + assert exir_ops.backend.tosa.TILE.default(x, [1, 1]).dtype == torch.int16 + + with pytest.raises(TosaValueError, match="Unsupported dtype"): + exir_ops.backend.tosa.CONCAT.default([x, x], axis=0) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.0+INT+int16") + ), FakeTensorMode() as mode: + x = _fake_tensor(torch.int16, mode) + assert exir_ops.backend.tosa.CONCAT.default([x, x], axis=0).dtype == torch.int16 + + +def test_pad_rejects_wrong_padding_length() -> None: + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.0+FP") + ), FakeTensorMode() as mode: + with pytest.raises(TosaValueError, match="Padding length"): + exir_ops.backend.tosa.PAD.default( + mode.from_tensor(torch.randn((2, 3), dtype=torch.float32)), + [1, 2], + value=0.0, + ) + + +def test_reshape_rejects_size_change(): + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP") + ), FakeTensorMode() as mode: + with pytest.raises(TosaValueError, match="same number of elements"): + exir_ops.backend.tosa.RESHAPE.default( + mode.from_tensor(torch.randn((2, 3), dtype=torch.float32)), + [5], + ) diff --git a/backends/arm/tosa/dialect/__init__.py b/backends/arm/tosa/dialect/__init__.py index 9f16720d893..10f526f40b6 100644 --- a/backends/arm/tosa/dialect/__init__.py +++ b/backends/arm/tosa/dialect/__init__.py @@ -10,6 +10,7 @@ conv2d, conv3d, custom, + data_layout_ops, depthwise_conv2d, fft, gather, @@ -17,13 +18,11 @@ matmul, max_pool2d, max_pool2d_adaptive, - pad, reduction_ops, rescale, resize, scatter, shape_ops, - slice, table, transpose_conv2d, unary_elementwise, diff --git a/backends/arm/tosa/dialect/ops/data_layout_ops.py b/backends/arm/tosa/dialect/ops/data_layout_ops.py new file mode 100644 index 00000000000..f7b8e2e1825 --- /dev/null +++ b/backends/arm/tosa/dialect/ops/data_layout_ops.py @@ -0,0 +1,274 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Iterable + +import torch + +from executorch.backends.arm.constants import MAX_RANK +from executorch.backends.arm.tosa.dialect.lib import TosaValueError +from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.specification import ( + get_context_spec, + TosaSpecification, +) + + +def _supported_data_layout_dtypes( + allow_int16_without_extension: bool, +) -> set[torch.dtype]: + tosa_spec = get_context_spec() + supported_dtypes = {torch.bool} + + if tosa_spec.support_integer(): + supported_dtypes.update({torch.int8, torch.int32}) + if allow_int16_without_extension or tosa_spec.support_extension("int16"): + supported_dtypes.add(torch.int16) + if tosa_spec.support_float(): + supported_dtypes.update({torch.float16, torch.float32}) + if tosa_spec.support_extension("int64"): + supported_dtypes.add(torch.int64) + if tosa_spec.support_extension("bf16"): + supported_dtypes.add(torch.bfloat16) + if tosa_spec.support_extension("fp8e4m3"): + supported_dtypes.add(torch.float8_e4m3fn) + if tosa_spec.support_extension("fp8e5m2"): + supported_dtypes.add(torch.float8_e5m2) + + return supported_dtypes + + +def _validate_data_layout_dtype( + dtype: torch.dtype, op: str, allow_int16_without_extension: bool = True +) -> None: + supported_dtypes = _supported_data_layout_dtypes(allow_int16_without_extension) + if dtype not in supported_dtypes: + raise TosaValueError( + f"Unsupported dtype {dtype} for {op}. Supported dtypes are {supported_dtypes}", + op=op, + ) + + +def _validate_data_layout_tensor( + x: torch.Tensor, op: str, allow_int16_without_extension: bool = True +) -> None: + _validate_data_layout_dtype(x.dtype, op, allow_int16_without_extension) + + +def _validate_concat_tensor(x: torch.Tensor) -> None: + _validate_data_layout_tensor(x, "CONCAT", allow_int16_without_extension=False) + + +def _shape_product(shape: Iterable[int | torch.SymInt], op: str) -> int | torch.SymInt: + result: int | torch.SymInt = 1 + for dim in shape: + if dim < 0: + raise TosaValueError( + f"Negative dimension {dim} is not allowed in shape {shape}", + op=op, + ) + result = result * dim + return result + + +@register_fake_tosa_op( + "CONCAT(Tensor[] input1, *, int axis) -> Tensor", + TosaSpecification.all_versions_and_profiles(), +) +def CONCAT(inputs: list[torch.Tensor], *, axis: int) -> torch.Tensor: + if not inputs: + raise TosaValueError("CONCAT requires at least one input tensor", op="CONCAT") + + reference = inputs[0] + _validate_concat_tensor(reference) + + if axis < 0 or axis >= max(1, reference.dim()): + raise TosaValueError( + f"CONCAT axis {axis} is out of range for rank {reference.dim()}", + op="CONCAT", + ) + + output_shape = list(reference.shape) + axis_sum = 0 + for tensor in inputs: + _validate_concat_tensor(tensor) + if tensor.dtype != reference.dtype: + raise TosaValueError( + "CONCAT requires matching dtypes, got " + f"{reference.dtype} and {tensor.dtype}", + op="CONCAT", + ) + if tensor.dim() < 1 or tensor.dim() > MAX_RANK: + raise TosaValueError( + f"CONCAT input tensors must have rank between 1 and {MAX_RANK}, got {tensor.dim()}", + op="CONCAT", + ) + if tensor.dim() != reference.dim(): + raise TosaValueError( + "CONCAT requires matching ranks, got " + f"{reference.dim()} and {tensor.dim()}", + op="CONCAT", + ) + for dim, (lhs, rhs) in enumerate(zip(reference.shape, tensor.shape)): + if dim != axis and lhs != rhs: + raise TosaValueError( + "CONCAT requires matching non-axis dimensions, " + f"got {tuple(reference.shape)} and {tuple(tensor.shape)}", + op="CONCAT", + ) + axis_sum = axis_sum + tensor.shape[axis] + + output_shape[axis] = axis_sum + return torch.empty(size=output_shape, dtype=reference.dtype) + + +@register_fake_tosa_op( + "PAD(Tensor input1, SymInt[] padding, *, Scalar value) -> Tensor", + TosaSpecification.all_versions_and_profiles(), +) +def PAD(x: torch.Tensor, padding: list[int | torch.SymInt], *, value) -> torch.Tensor: + _validate_data_layout_dtype(x.dtype, "PAD") + + if len(padding) != 2 * len(x.shape): + raise TosaValueError( + f"Padding length {len(padding)} is not compatible with input rank {len(x.shape)}", + op="PAD", + ) + + output_shape: list[int | torch.SymInt] = [] + for i, dim in enumerate(x.shape): + pad_before = padding[i * 2] + pad_after = padding[i * 2 + 1] + if pad_before < 0 or pad_after < 0: + raise TosaValueError( + f"Expected padding values to be non-negative, got {pad_before} and {pad_after}", + op="PAD", + ) + output_shape.append(pad_before + dim + pad_after) + + return torch.empty(size=output_shape, dtype=x.dtype) + + +@register_fake_tosa_op( + "RESHAPE(Tensor input1, SymInt[] shape) -> Tensor", + TosaSpecification.all_versions_and_profiles(), +) +def RESHAPE(x: torch.Tensor, shape: list[int | torch.SymInt]) -> torch.Tensor: + _validate_data_layout_tensor(x, "RESHAPE") + if _shape_product(x.shape, "RESHAPE") != _shape_product(shape, "RESHAPE"): + raise TosaValueError( + "RESHAPE requires the same number of elements, got " + f"{tuple(x.shape)} -> {tuple(shape)}", + op="RESHAPE", + ) + return torch.empty(size=shape, dtype=x.dtype) + + +@register_fake_tosa_op( + "REVERSE(Tensor input1, *, int axis) -> Tensor", + TosaSpecification.all_versions_and_profiles(), +) +def REVERSE(x: torch.Tensor, *, axis: int) -> torch.Tensor: + _validate_data_layout_tensor(x, "REVERSE") + if x.dim() < 1: + raise TosaValueError("REVERSE requires rank >= 1 input", op="REVERSE") + if axis < 0 or axis >= x.dim(): + raise TosaValueError( + f"REVERSE axis {axis} is out of range for rank {x.dim()}", + op="REVERSE", + ) + return torch.empty_like(x) + + +@register_fake_tosa_op( + "SLICE(Tensor input1, SymInt[] start, SymInt[] size) -> Tensor", + TosaSpecification.all_versions_and_profiles(), +) +def SLICE( + x: torch.Tensor, start: list[int | torch.SymInt], size: list[int | torch.SymInt] +) -> torch.Tensor: + input_rank = x.dim() + if input_rank != len(start): + raise TosaValueError( + f"start list does not have the same rank {len(start)} as input {input_rank}", + op="SLICE", + ) + if len(start) != len(size): + raise TosaValueError( + f"size list does not have the same rank {len(size)} as start list {len(start)}", + op="SLICE", + ) + + for i, dim_start in enumerate(start): + if dim_start < 0 or dim_start > x.shape[i]: + raise TosaValueError( + f"Expected start values between [0, {x.shape[i]}] but got {dim_start}", + op="SLICE", + ) + dim_size = size[i] + if dim_size <= 0 or dim_start + dim_size > x.shape[i]: + raise TosaValueError( + f"Expected start + size values between [0, {x.shape[i]}] but got {dim_start + dim_size}", + op="SLICE", + ) + + _validate_data_layout_dtype(x.dtype, "SLICE") + + return torch.empty(size=size, dtype=x.dtype) + + +@register_fake_tosa_op( + "TILE(Tensor input1, SymInt[] multiples) -> Tensor", + TosaSpecification.all_versions_and_profiles(), +) +def TILE(x: torch.Tensor, multiples: list[int | torch.SymInt]) -> torch.Tensor: + _validate_data_layout_tensor(x, "TILE") + if len(multiples) != x.dim(): + raise TosaValueError( + f"TILE multiples length {len(multiples)} does not match rank {x.dim()}", + op="TILE", + ) + output_shape = [] + for dim, multiple in enumerate(multiples): + if multiple <= 0: + raise TosaValueError( + f"TILE multiples must be positive, got {multiple} at dimension {dim}", + op="TILE", + ) + output_shape.append(x.shape[dim] * multiple) + return torch.empty(size=output_shape, dtype=x.dtype) + + +@register_fake_tosa_op( + "TRANSPOSE(Tensor input, int[] perms) -> Tensor", + TosaSpecification.all_versions_and_profiles(), +) +def TRANSPOSE(x: torch.Tensor, perms: list[int]) -> torch.Tensor: + _validate_data_layout_tensor(x, "TRANSPOSE") + input_rank = x.dim() + + if input_rank < 1 or input_rank > MAX_RANK: + raise TosaValueError( + f"TRANSPOSE requires rank in [1, {MAX_RANK}], got {input_rank}", + op="TRANSPOSE", + ) + + if len(perms) != input_rank: + raise TosaValueError( + f"Expected permutation rank {input_rank}, got {len(perms)}", + op="TRANSPOSE", + ) + + seen_dims: set[int] = set() + for dim in perms: + if dim < 0 or dim >= input_rank or dim in seen_dims: + raise TosaValueError( + f"Invalid permutation {perms} for rank-{input_rank} input", + op="TRANSPOSE", + ) + seen_dims.add(dim) + + output_shape = [x.shape[dim] for dim in perms] + return torch.empty(size=output_shape, dtype=x.dtype) diff --git a/backends/arm/tosa/dialect/ops/pad.py b/backends/arm/tosa/dialect/ops/pad.py deleted file mode 100644 index 3b5628b0ede..00000000000 --- a/backends/arm/tosa/dialect/ops/pad.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2026 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import List - -import torch - -from executorch.backends.arm.tosa.dialect.lib import TosaValueError -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op - -from executorch.backends.arm.tosa.specification import ( - get_context_spec, - TosaSpecification, -) - - -@register_fake_tosa_op( - "PAD(Tensor input1, SymInt[] padding, *, Scalar value) -> Tensor", # schema - ( - TosaSpecification.create_from_string("TOSA-1.0+INT"), - TosaSpecification.create_from_string("TOSA-1.0+FP"), - ), # target TOSA specifications -) -def PAD(a: torch.Tensor, padding: List[int | torch.SymInt], *, value): - tosa_spec = get_context_spec() - - supported_dtypes = {torch.bool} - if tosa_spec.support_integer(): - supported_dtypes.update({torch.int8, torch.int16, torch.int32}) - if tosa_spec.support_float(): - supported_dtypes.update({torch.float16, torch.float32}) - if tosa_spec.support_extension("bf16"): - supported_dtypes.add(torch.bfloat16) - if tosa_spec.support_extension("fp8e4m3"): - supported_dtypes.add(torch.float8_e4m3fn) - if tosa_spec.support_extension("fp8e5m2"): - supported_dtypes.add(torch.float8_e5m2) - if a.dtype not in supported_dtypes: - raise TosaValueError( - f"Input tensor dtype {a.dtype} is not supported by the target TOSA specification." - f" Supported dtypes are: {supported_dtypes}", - op="PAD", - ) - - if len(padding) != 2 * len(a.shape): - raise TosaValueError( - f"Padding length {len(padding)} is not compatible with input rank {len(a.shape)}", - op="PAD", - ) - - # new shape: - new_shape: List[int | torch.SymInt] = [] - for i, d in enumerate(a.shape): - pad_before = padding[i * 2] - pad_after = padding[i * 2 + 1] - new_shape.append(pad_before + d + pad_after) - - # return a new tensor with the new shape - return torch.empty(size=new_shape, dtype=a.dtype) diff --git a/backends/arm/tosa/dialect/ops/slice.py b/backends/arm/tosa/dialect/ops/slice.py deleted file mode 100644 index 3406ccf911b..00000000000 --- a/backends/arm/tosa/dialect/ops/slice.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright 2026 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import torch -from executorch.backends.arm.tosa.dialect.lib import TosaValueError -from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op - -from executorch.backends.arm.tosa.specification import ( - get_context_spec, - TosaSpecification, -) - - -@register_fake_tosa_op( - "SLICE(Tensor input1, SymInt[] start, SymInt[] size) -> Tensor", # schema - TosaSpecification.all_versions_and_profiles(), # target TOSA specifications -) -def SLICE(a, start, size): - tosa_spec = get_context_spec() - - # Rank validation - input_rank = a.dim() - if input_rank != len(start): - raise TosaValueError( - f"start list does not have the same rank {len(start)} as input {input_rank}" - ) - if len(start) != len(size): - raise TosaValueError( - f"size list does not have the same rank {len(size)} as start list {len(start)}" - ) - - # Shape validation - for i in range(len(start)): - dim_start = start[i] - if dim_start < 0 or dim_start > a.shape[i]: - raise TosaValueError( - f"Expected start values between [0, {a.shape[i]}] but got {dim_start}" - ) - dim_size = size[i] - if dim_size < 0 or dim_start + dim_size > a.shape[i]: - raise TosaValueError( - f"Expected start + size values between [0, {a.shape[i]}] but got {dim_start + dim_size}" - ) - - # Dtype validation - supported_dtypes = [torch.bool] - if tosa_spec.support_integer(): - supported_dtypes += [torch.int8, torch.int16, torch.int32] - if tosa_spec.support_float(): - supported_dtypes += [torch.float16, torch.float32] - if tosa_spec.support_extension("bf16"): - supported_dtypes += [torch.bfloat16] - if tosa_spec.support_extension("fp8e4m3"): - supported_dtypes += [torch.float8_e4m3fn] - if tosa_spec.support_extension("fp8e5m2"): - supported_dtypes += [torch.float8_e5m2] - - if a.dtype not in supported_dtypes: - raise TosaValueError( - f"Unsupported dtype {a.dtype} for SLICE. Supported dtypes are {supported_dtypes}" - ) - - return torch.empty(size=size, dtype=a.dtype)