Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions backends/cadence/aot/compiler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,21 @@ def broadcastable(shape_1: Sequence[int], shape_2: Sequence[int]) -> bool:
)


def transpose_dims_to_permute_order(ndim: int, dim0: int, dim1: int) -> List[int]:
    """
    Convert transpose(dim0, dim1) to an equivalent permute order list.
    E.g., transpose(0, 1) on a 3D tensor gives [1, 0, 2].
    """
    # Wrap negative axes into the [0, ndim) range, matching torch semantics.
    axes = [d + ndim if d < 0 else d for d in (dim0, dim1)]
    # Start from the identity order and exchange the two requested positions.
    order = list(range(ndim))
    order[axes[0]], order[axes[1]] = order[axes[1]], order[axes[0]]
    return order


def get_transposed_dims(
node: torch.fx.Node, dims: Optional[List[int]] = None
) -> List[int]:
Expand Down
11 changes: 7 additions & 4 deletions backends/cadence/aot/quantizer/fusion_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,10 +678,13 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
quant_node,
)
elif isinstance(pattern, AddmmPattern):
# Transpose the weight tensor
# Transpose the weight tensor using permute
weight_ndim = len(weights_inputs[0].meta["val"].shape)
perm = list(range(weight_ndim))
perm[0], perm[1] = perm[1], perm[0]
transposed_weights = graph_module.graph.call_function(
torch.ops.aten.transpose.int,
(weights_inputs[0], 0, 1),
torch.ops.aten.permute.default,
(weights_inputs[0], perm),
)
assert (
"val" in weights_inputs[0].meta
Expand All @@ -692,7 +695,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
), "fake_mode is None on weight node"
with original_val.fake_mode:
transposed_weights.meta["val"] = (
torch.ops.aten.transpose.int(original_val, 0, 1)
torch.ops.aten.permute.default(original_val, perm)
)
copy_node_metadata(transposed_weights, weights_inputs[0])

Expand Down
106 changes: 36 additions & 70 deletions backends/cadence/aot/replace_ops.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
Expand All @@ -19,7 +19,10 @@

import torch
import torch.fx
from executorch.backends.cadence.aot.compiler_utils import quantize_tensor_multiplier
from executorch.backends.cadence.aot.compiler_utils import (
quantize_tensor_multiplier,
transpose_dims_to_permute_order,
)
from executorch.backends.cadence.aot.fuse_ops import FuseCascadedTransposeOrPermuteOps
from executorch.backends.cadence.aot.pass_utils import (
CadencePassAttribute,
Expand Down Expand Up @@ -355,9 +358,9 @@

# Handle transpose: if mat2 is a transpose op, extract the original tensor
transposed_mat2 = False
if (
mat2.op == "call_function"
and mat2.target == exir_ops.edge.aten.transpose_copy.int
if mat2.op == "call_function" and (
mat2.target == exir_ops.edge.aten.transpose_copy.int
or mat2.target == exir_ops.edge.aten.permute_copy.default
):
# mat2 is already transposed, so we use the input to the transpose
mat2 = cast(torch.fx.Node, mat2.args[0])
Expand Down Expand Up @@ -405,9 +408,11 @@
# Transpose mat2 if it wasn't already transposed
if not transposed_mat2:
with graph.inserting_before(node):
ndim = len(mat2.meta["val"].shape)
perm = transpose_dims_to_permute_order(ndim, -1, -2)
mat2 = graph.call_function(
exir_ops.edge.aten.transpose_copy.int,
args=(mat2, -1, -2),
exir_ops.edge.aten.permute_copy.default,
args=(mat2, perm),
)

# Metadata copy important
Expand All @@ -430,47 +435,6 @@
return True


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class ReplacePermuteWithTransposePass(RemoveOrReplacePassInterface):
"""
Replace permute op with transpose if the permutation is only along
two dimensions.
"""

@property
def targets(self) -> list[EdgeOpOverload]:
return [exir_ops.edge.aten.permute_copy.default]

def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
# Get the old dim and new dim order
in_tensor = node.args[0]
assert isinstance(in_tensor, torch.fx.Node)
in_shape = in_tensor.meta["val"].shape
old_dims = tuple(range(len(in_shape)))
new_dims = cast(Sequence[int], node.args[1])

# Compute the number of positions in which the old and new order differ
diff = [od for od, nd in zip(old_dims, new_dims) if od != nd]

# If the difference is zero, replace with identity (just the input)
if len(diff) == 0:
node.replace_all_uses_with(in_tensor)
return True

# If the difference is in two dimensions, we can replace this permute op
# with transpose op.
if len(diff) == 2:
with node.graph.inserting_before(node):
new_node = node.graph.call_function(
exir_ops.edge.aten.transpose_copy.int,
args=(node.args[0], diff[0], diff[1]),
)
new_node.meta = node.meta
node.replace_all_uses_with(new_node)
return True

return False


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceConvolutionOptionalArgsWithConcreteArgsPass(RemoveOrReplacePassInterface):
Expand Down Expand Up @@ -798,9 +762,11 @@
# gather stencil. Also, the first two dimensions of weight must be
# transposed/interchanged.
assert isinstance(weight, torch.fx.Node)
weight_ndim = len(weight.meta["val"].shape)
perm = transpose_dims_to_permute_order(weight_ndim, 0, 1)
transposed_weight = node.graph.call_function(
exir_ops.edge.aten.transpose_copy.int,
args=(weight, 0, 1),
exir_ops.edge.aten.permute_copy.default,
args=(weight, perm),
)
transposed_weight.meta = weight.meta

Expand Down Expand Up @@ -1036,18 +1002,19 @@
def _transpose_dims(
self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int
) -> torch.fx.Node:
"""Helper function to transpose dims of a node."""
"""Helper function to transpose dims of a node using permute."""
shape = node.meta["val"].shape
dim0, dim1 = (
canonicalize_transposed_dim(dim0, shape),
canonicalize_transposed_dim(dim1, shape),
)
dim0, dim1 = min(dim0, dim1), max(dim0, dim1)
transpose_node = graph.call_function(
exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {}
perm = transpose_dims_to_permute_order(len(shape), dim0, dim1)
permute_node = graph.call_function(
exir_ops.edge.aten.permute_copy.default, (node, perm), {}
)
transpose_node.meta = node.meta
return transpose_node
permute_node.meta = node.meta
return permute_node

def _change_nchw_to_nhwc(
self, graph: torch.fx.Graph, node: torch.fx.Node
Expand Down Expand Up @@ -1263,18 +1230,19 @@
def _transpose_dims(
self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int
) -> torch.fx.Node:
"""Helper function to transpose dims of a node."""
"""Helper function to transpose dims of a node using permute."""
shape = node.meta["val"].shape
dim0, dim1 = (
canonicalize_transposed_dim(dim0, shape),
canonicalize_transposed_dim(dim1, shape),
)
dim0, dim1 = min(dim0, dim1), max(dim0, dim1)
transpose_node = graph.call_function(
exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {}
perm = transpose_dims_to_permute_order(len(shape), dim0, dim1)
permute_node = graph.call_function(
exir_ops.edge.aten.permute_copy.default, (node, perm), {}
)
transpose_node.meta = node.meta
return transpose_node
permute_node.meta = node.meta
return permute_node

def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
# Get the dimension argument
Expand Down Expand Up @@ -1526,8 +1494,8 @@
if not channel_last:
with graph.inserting_before(node):
linear_res = graph.call_function(
exir_ops.edge.aten.transpose_copy.int,
args=(linear_res, 1, 2),
exir_ops.edge.aten.permute_copy.default,
args=(linear_res, [0, 2, 1]),
)
linear_res.meta = node.meta

Expand Down Expand Up @@ -1717,8 +1685,8 @@
if not channel_last:
with graph.inserting_before(node):
linear_res = graph.call_function(
exir_ops.edge.aten.transpose_copy.int,
args=(linear_res, 1, 2),
exir_ops.edge.aten.permute_copy.default,
args=(linear_res, [0, 2, 1]),
)
linear_res.meta = node.meta

Expand Down Expand Up @@ -2375,9 +2343,11 @@

# Transpose Y_arg
with graph.inserting_before(node):
Y_ndim = len(Y_tensor_val.shape)
perm = transpose_dims_to_permute_order(Y_ndim, -1, -2)
Y_arg_t = graph.call_function(
exir_ops.edge.aten.transpose_copy.int,
args=(Y_arg, -1, -2),
exir_ops.edge.aten.permute_copy.default,
args=(Y_arg, perm),
)
Y_arg_t.meta = node.meta

Expand Down Expand Up @@ -2410,13 +2380,10 @@
result = super().call(graph_module)
modified = modified or result.modified
if modified:
# Fuse any inserted transpose node with transpose/permute nodes
# Fuse any inserted permute node with transpose/permute nodes
# surrounding it.
result = FuseCascadedTransposeOrPermuteOps().call(result.graph_module)
modified = modified or result.modified
# Replace permute with transpose.
result = ReplacePermuteWithTransposePass().call(result.graph_module)
modified = modified or result.modified

return PassResult(result.graph_module, modified)

Expand Down Expand Up @@ -2643,7 +2610,6 @@
ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass,
ReplaceEmptyTensorsWithFullPass,
ReplaceFunctionallyEquivalentOpTargets,
ReplacePermuteWithTransposePass,
ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
ReplaceAddMMWithLinearPass,
ReplacePadWithCatPass,
Expand Down
77 changes: 4 additions & 73 deletions backends/cadence/aot/tests/test_replace_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
ReplaceMulTensorWithMulAndFullOpsPass,
ReplaceNopTransposeOrPermuteWithViewPass,
ReplacePadWithCatPass,
ReplacePermuteWithTransposePass,
ReplacePowWithMulPass,
ReplaceRepeatWithCatPass,
ReplaceScalarTensorWithFullPass,
Expand Down Expand Up @@ -965,10 +964,7 @@ def test_replace_linear_with_fully_connected(self) -> None:
builder.output([mm])
original_gm = builder.get_graph_module()

gm = cast(
PassResult, ReplacePermuteWithTransposePass()(original_gm)
).graph_module
gm = cast(PassResult, ReplaceMMWithAddMMPass()(gm)).graph_module
gm = cast(PassResult, ReplaceMMWithAddMMPass()(original_gm)).graph_module

gm_before_linear = copy.deepcopy(gm)
pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(gm))
Expand Down Expand Up @@ -1029,12 +1025,8 @@ def test_replace_addmm_with_linear(self) -> None:
builder.output([addmm])
original_gm = builder.get_graph_module()

gm = cast(
PassResult, ReplacePermuteWithTransposePass()(original_gm)
).graph_module

gm_before_linear = copy.deepcopy(gm)
pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(gm))
gm_before_linear = copy.deepcopy(original_gm)
pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(original_gm))
self.assertTrue(pass_result.modified)
graph_after_passes = pass_result.graph_module

Expand Down Expand Up @@ -1077,11 +1069,7 @@ def test_no_replace_addmm_with_linear(self) -> None:
builder.output([addmm])
original_gm = builder.get_graph_module()

gm = cast(
PassResult, ReplacePermuteWithTransposePass()(original_gm)
).graph_module

pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(gm))
pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(original_gm))
self.assertFalse(pass_result.modified)

@torch.no_grad()
Expand Down Expand Up @@ -1715,63 +1703,6 @@ def test_replace_nop_permute_with_view(
count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 1
)

@expand(
[
# permutations replaced by transpose
[(3, 4), (1, 0)],
[(3, 4, 6), (0, 2, 1)],
]
)
@torch.no_grad()
def test_replace_permute_with_transpose(
self, shape: Tuple[int], dims: Tuple[int]
) -> None:
x = torch.randn(shape)
original_gm = single_op_builder(
placeholders=(x,),
op=exir_ops.edge.aten.permute_copy.default,
args=(x, dims),
)

gm_before = copy.deepcopy(original_gm)
p = ReplacePermuteWithTransposePass()
result = cast(PassResult, p(original_gm))
self.assertTrue(result.modified)
graph_after_passes = result.graph_module
inputs = [x]
validate(
gm_before, graph_after_passes, inputs, "ReplacePermuteWithTransposePass"
)

# Assert that permute op was replaced by a transpose op
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.permute_copy.default), 0
)
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 1
)

@torch.no_grad()
def test_replace_permute_with_transpose_nop(
self,
) -> None:
x = torch.randn(3, 4)
original_gm = single_op_builder(
placeholders=(x,),
op=exir_ops.edge.aten.permute_copy.default,
args=(x, [0, 1]),
)
p = ReplacePermuteWithTransposePass()
graph_after_passes = cast(PassResult, p(original_gm)).graph_module

# Assert that permute op was replaced by a transpose op
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.permute_copy.default), 0
)
self.assertEqual(
count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 0
)


class TestReplaceWhereWithFullArgsWithWhereScalar(unittest.TestCase):
def test_replace_aten_where_with_cadence(self) -> None:
Expand Down
Loading