diff --git a/backends/cadence/aot/compiler_utils.py b/backends/cadence/aot/compiler_utils.py index eff3f49abbf..65c448351e4 100644 --- a/backends/cadence/aot/compiler_utils.py +++ b/backends/cadence/aot/compiler_utils.py @@ -87,6 +87,21 @@ def broadcastable(shape_1: Sequence[int], shape_2: Sequence[int]) -> bool: ) +def transpose_dims_to_permute_order(ndim: int, dim0: int, dim1: int) -> List[int]: + """ + Convert transpose(dim0, dim1) to an equivalent permute order list. + E.g., transpose(0, 1) on a 3D tensor gives [1, 0, 2]. + """ + # Normalize negative dims + if dim0 < 0: + dim0 += ndim + if dim1 < 0: + dim1 += ndim + order = list(range(ndim)) + order[dim0], order[dim1] = order[dim1], order[dim0] + return order + + def get_transposed_dims( node: torch.fx.Node, dims: Optional[List[int]] = None ) -> List[int]: diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py index 2853136081b..5fc4ea97bf8 100644 --- a/backends/cadence/aot/quantizer/fusion_pass.py +++ b/backends/cadence/aot/quantizer/fusion_pass.py @@ -678,10 +678,13 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901 quant_node, ) elif isinstance(pattern, AddmmPattern): - # Transpose the weight tensor + # Transpose the weight tensor using permute + weight_ndim = len(weights_inputs[0].meta["val"].shape) + perm = list(range(weight_ndim)) + perm[0], perm[1] = perm[1], perm[0] transposed_weights = graph_module.graph.call_function( - torch.ops.aten.transpose.int, - (weights_inputs[0], 0, 1), + torch.ops.aten.permute.default, + (weights_inputs[0], perm), ) assert ( "val" in weights_inputs[0].meta @@ -692,7 +695,7 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901 ), "fake_mode is None on weight node" with original_val.fake_mode: transposed_weights.meta["val"] = ( - torch.ops.aten.transpose.int(original_val, 0, 1) + torch.ops.aten.permute.default(original_val, perm) ) copy_node_metadata(transposed_weights, 
weights_inputs[0]) diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 6e2f85fab0f..8d9353ccfa4 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -19,7 +19,10 @@ import torch import torch.fx -from executorch.backends.cadence.aot.compiler_utils import quantize_tensor_multiplier +from executorch.backends.cadence.aot.compiler_utils import ( + quantize_tensor_multiplier, + transpose_dims_to_permute_order, +) from executorch.backends.cadence.aot.fuse_ops import FuseCascadedTransposeOrPermuteOps from executorch.backends.cadence.aot.pass_utils import ( CadencePassAttribute, @@ -355,9 +358,9 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Handle transpose: if mat2 is a transpose op, extract the original tensor transposed_mat2 = False - if ( - mat2.op == "call_function" - and mat2.target == exir_ops.edge.aten.transpose_copy.int + if mat2.op == "call_function" and ( + mat2.target == exir_ops.edge.aten.transpose_copy.int + or mat2.target == exir_ops.edge.aten.permute_copy.default ): # mat2 is already transposed, so we use the input to the transpose mat2 = cast(torch.fx.Node, mat2.args[0]) @@ -405,9 +408,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Transpose mat2 if it wasn't already transposed if not transposed_mat2: with graph.inserting_before(node): + ndim = len(mat2.meta["val"].shape) + perm = transpose_dims_to_permute_order(ndim, -1, -2) mat2 = graph.call_function( - exir_ops.edge.aten.transpose_copy.int, - args=(mat2, -1, -2), + exir_ops.edge.aten.permute_copy.default, + args=(mat2, perm), ) # Metadata copy important @@ -430,6 +435,35 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: return True +@register_cadence_pass(CadencePassAttribute(opt_level=1)) +class ReplaceTransposeWithPermutePass(RemoveOrReplacePassInterface): + """ + Replace transpose_copy.int ops with equivalent permute_copy.default ops + to canonicalize on 
permute as the single layout-change op. + """ + + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.transpose_copy.int] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + in_tensor = node.args[0] + assert isinstance(in_tensor, torch.fx.Node) + ndim = len(in_tensor.meta["val"].shape) + dim0 = cast(int, node.args[1]) + dim1 = cast(int, node.args[2]) + perm = transpose_dims_to_permute_order(ndim, dim0, dim1) + + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.permute_copy.default, + args=(in_tensor, perm), + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True + + @register_cadence_pass(CadencePassAttribute(opt_level=1)) class ReplacePermuteWithTransposePass(RemoveOrReplacePassInterface): """ @@ -471,7 +505,6 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: return False - @register_cadence_pass(CadencePassAttribute(opt_level=0)) class ReplaceConvolutionOptionalArgsWithConcreteArgsPass(RemoveOrReplacePassInterface): """ @@ -798,9 +831,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # gather stencil. Also, the first two dimensions of weight must be # transposed/interchanged. 
assert isinstance(weight, torch.fx.Node) + weight_ndim = len(weight.meta["val"].shape) + perm = transpose_dims_to_permute_order(weight_ndim, 0, 1) transposed_weight = node.graph.call_function( - exir_ops.edge.aten.transpose_copy.int, - args=(weight, 0, 1), + exir_ops.edge.aten.permute_copy.default, + args=(weight, perm), ) transposed_weight.meta = weight.meta @@ -1036,18 +1071,19 @@ def targets(self) -> list[EdgeOpOverload]: def _transpose_dims( self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int ) -> torch.fx.Node: - """Helper function to transpose dims of a node.""" + """Helper function to transpose dims of a node using permute.""" shape = node.meta["val"].shape dim0, dim1 = ( canonicalize_transposed_dim(dim0, shape), canonicalize_transposed_dim(dim1, shape), ) dim0, dim1 = min(dim0, dim1), max(dim0, dim1) - transpose_node = graph.call_function( - exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {} + perm = transpose_dims_to_permute_order(len(shape), dim0, dim1) + permute_node = graph.call_function( + exir_ops.edge.aten.permute_copy.default, (node, perm), {} ) - transpose_node.meta = node.meta - return transpose_node + permute_node.meta = node.meta + return permute_node def _change_nchw_to_nhwc( self, graph: torch.fx.Graph, node: torch.fx.Node @@ -1263,18 +1299,19 @@ def targets(self) -> list[EdgeOpOverload]: def _transpose_dims( self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int ) -> torch.fx.Node: - """Helper function to transpose dims of a node.""" + """Helper function to transpose dims of a node using permute.""" shape = node.meta["val"].shape dim0, dim1 = ( canonicalize_transposed_dim(dim0, shape), canonicalize_transposed_dim(dim1, shape), ) dim0, dim1 = min(dim0, dim1), max(dim0, dim1) - transpose_node = graph.call_function( - exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {} + perm = transpose_dims_to_permute_order(len(shape), dim0, dim1) + permute_node = graph.call_function( + 
exir_ops.edge.aten.permute_copy.default, (node, perm), {} ) - transpose_node.meta = node.meta - return transpose_node + permute_node.meta = node.meta + return permute_node def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Get the dimension argument @@ -1526,8 +1563,8 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: if not channel_last: with graph.inserting_before(node): linear_res = graph.call_function( - exir_ops.edge.aten.transpose_copy.int, - args=(linear_res, 1, 2), + exir_ops.edge.aten.permute_copy.default, + args=(linear_res, [0, 2, 1]), ) linear_res.meta = node.meta @@ -1717,8 +1754,8 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: if not channel_last: with graph.inserting_before(node): linear_res = graph.call_function( - exir_ops.edge.aten.transpose_copy.int, - args=(linear_res, 1, 2), + exir_ops.edge.aten.permute_copy.default, + args=(linear_res, [0, 2, 1]), ) linear_res.meta = node.meta @@ -2375,9 +2412,11 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Transpose Y_arg with graph.inserting_before(node): + Y_ndim = len(Y_tensor_val.shape) + perm = transpose_dims_to_permute_order(Y_ndim, -1, -2) Y_arg_t = graph.call_function( - exir_ops.edge.aten.transpose_copy.int, - args=(Y_arg, -1, -2), + exir_ops.edge.aten.permute_copy.default, + args=(Y_arg, perm), ) Y_arg_t.meta = node.meta @@ -2410,13 +2449,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: result = super().call(graph_module) modified = modified or result.modified if modified: - # Fuse any inserted transpose node with transpose/permute nodes + # Fuse any inserted permute node with transpose/permute nodes # surrounding it. result = FuseCascadedTransposeOrPermuteOps().call(result.graph_module) modified = modified or result.modified - # Replace permute with transpose. 
- result = ReplacePermuteWithTransposePass().call(result.graph_module) - modified = modified or result.modified return PassResult(result.graph_module, modified) @@ -2640,10 +2676,10 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # graph with another. class CadenceReplaceOpsInGraph: passes = CommonReplacePasses.passes + [ + ReplaceTransposeWithPermutePass, ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass, ReplaceEmptyTensorsWithFullPass, ReplaceFunctionallyEquivalentOpTargets, - ReplacePermuteWithTransposePass, ReplaceConvolutionOptionalArgsWithConcreteArgsPass, ReplaceAddMMWithLinearPass, ReplacePadWithCatPass, @@ -2657,6 +2693,8 @@ class CadenceReplaceOpsInGraph: ReplaceIm2RowWithViewPass, MakeSliceAndCatDimOutermostPass, ReplaceMatmulWithTransposedMatmulPass, + # Convert permutes back to transposes after all passes that create them. + ReplacePermuteWithTransposePass, ReplaceNopTransposeOrPermuteWithViewPass, ReplaceLinearWithFullyConnectedOpPass, ReplaceScalarTensorWithFullPass, diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index ee13726a94b..dfb93c2faed 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -37,7 +37,6 @@ ReplaceMulTensorWithMulAndFullOpsPass, ReplaceNopTransposeOrPermuteWithViewPass, ReplacePadWithCatPass, - ReplacePermuteWithTransposePass, ReplacePowWithMulPass, ReplaceRepeatWithCatPass, ReplaceScalarTensorWithFullPass, @@ -172,7 +171,7 @@ def test_replace_matmul_with_transposed_matmul( graph_after_passes = result.graph_module self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), + count_node(graph_after_passes, exir_ops.edge.aten.permute_copy.default), 1, ) self.assertEqual( @@ -647,7 +646,7 @@ def test_replace_transposed_conv_with_linear( weights = builder.placeholder("weights", weights_tensor) transposed_weights = builder.call_operator( - 
op=exir_ops.edge.aten.transpose_copy.int, args=(weights, 0, 1) + op=exir_ops.edge.aten.permute_copy.default, args=(weights, [1, 0, 2]) ) flipped_weights = builder.call_operator( exir_ops.edge.aten.flip.default, @@ -965,10 +964,7 @@ def test_replace_linear_with_fully_connected(self) -> None: builder.output([mm]) original_gm = builder.get_graph_module() - gm = cast( - PassResult, ReplacePermuteWithTransposePass()(original_gm) - ).graph_module - gm = cast(PassResult, ReplaceMMWithAddMMPass()(gm)).graph_module + gm = cast(PassResult, ReplaceMMWithAddMMPass()(original_gm)).graph_module gm_before_linear = copy.deepcopy(gm) pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(gm)) @@ -1029,12 +1025,8 @@ def test_replace_addmm_with_linear(self) -> None: builder.output([addmm]) original_gm = builder.get_graph_module() - gm = cast( - PassResult, ReplacePermuteWithTransposePass()(original_gm) - ).graph_module - - gm_before_linear = copy.deepcopy(gm) - pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(gm)) + gm_before_linear = copy.deepcopy(original_gm) + pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(original_gm)) self.assertTrue(pass_result.modified) graph_after_passes = pass_result.graph_module @@ -1077,11 +1069,7 @@ def test_no_replace_addmm_with_linear(self) -> None: builder.output([addmm]) original_gm = builder.get_graph_module() - gm = cast( - PassResult, ReplacePermuteWithTransposePass()(original_gm) - ).graph_module - - pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(gm)) + pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(original_gm)) self.assertFalse(pass_result.modified) @torch.no_grad() @@ -1715,63 +1703,6 @@ def test_replace_nop_permute_with_view( count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 1 ) - @expand( - [ - # permutations replaced by transpose - [(3, 4), (1, 0)], - [(3, 4, 6), (0, 2, 1)], - ] - ) - @torch.no_grad() - def test_replace_permute_with_transpose( - self, shape: 
Tuple[int], dims: Tuple[int] - ) -> None: - x = torch.randn(shape) - original_gm = single_op_builder( - placeholders=(x,), - op=exir_ops.edge.aten.permute_copy.default, - args=(x, dims), - ) - - gm_before = copy.deepcopy(original_gm) - p = ReplacePermuteWithTransposePass() - result = cast(PassResult, p(original_gm)) - self.assertTrue(result.modified) - graph_after_passes = result.graph_module - inputs = [x] - validate( - gm_before, graph_after_passes, inputs, "ReplacePermuteWithTransposePass" - ) - - # Assert that permute op was replaced by a transpose op - self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.aten.permute_copy.default), 0 - ) - self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 1 - ) - - @torch.no_grad() - def test_replace_permute_with_transpose_nop( - self, - ) -> None: - x = torch.randn(3, 4) - original_gm = single_op_builder( - placeholders=(x,), - op=exir_ops.edge.aten.permute_copy.default, - args=(x, [0, 1]), - ) - p = ReplacePermuteWithTransposePass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module - - # Assert that permute op was replaced by a transpose op - self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.aten.permute_copy.default), 0 - ) - self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 0 - ) - class TestReplaceWhereWithFullArgsWithWhereScalar(unittest.TestCase): def test_replace_aten_where_with_cadence(self) -> None: @@ -2727,7 +2658,7 @@ def test_1d_depthwise_convolution_weight_shape(self) -> None: """Test that 1D depthwise conv weight is transformed to [OC, K, 1] format. For 1D depthwise conv with groups == in_channels > 1, the weight should be - transformed from [OC, 1, K] to [OC, K, 1] (3D) via transpose_copy.int, + transformed from [OC, 1, K] to [OC, K, 1] (3D) via permute_copy.default, matching the standard NLC weight format expected by C++ kernels. 
""" placeholders, gm = self.create_1d_depthwise_convolution_graph_module() @@ -2749,15 +2680,11 @@ def test_1d_depthwise_convolution_weight_shape(self) -> None: # For 1D depthwise: # - Input/output use permute_copy (2 ops) - # - Weight uses transpose_copy.int (1 op) via _transpose_dims + # - Weight uses permute_copy (1 op) via _transpose_dims # - No squeeze_copy needed self.assertEqual( count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default), - 2, - ) - self.assertEqual( - count_node(gm_after_replacement, exir_ops.edge.aten.transpose_copy.int), - 1, + 3, ) self.assertEqual( count_node(gm_after_replacement, exir_ops.edge.aten.squeeze_copy.dim), @@ -2847,10 +2774,40 @@ def test_slice_no_transpose_if_outermost_dimensions_are_one(self) -> None: # Validate numerical accuracy validate(original, gm_after_pass, [x], "MakeSliceAndCatDimOutermostPass") - # Assert that no transpose ops were added. The slice is on the second + # Assert that no permute ops were added. The slice is on the second # outermost dimension, but the outermost dimension is already 1. self.assertEqual( - count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int), + count_node(gm_after_pass, exir_ops.edge.aten.permute_copy.default), + 0, + ) + + def test_cat_no_transpose_if_outermost_dimensions_are_one(self) -> None: + # Create a graph with a single slice node on second outermost dimension. + input1 = torch.randn(1, 1, 3, 5) + input2 = torch.randn(1, 2, 3, 5) + gm = self.create_cat_graph(input_shapes=((1, 1, 3, 5), (1, 2, 3, 5)), cat_dim=1) + # Check if graph module is valid by running exportpass on it. + gm = ExportPass().call(gm).graph_module + self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1) + + # Deepcopy before the pass + original = copy.deepcopy(gm) + + # Apply replacement pass. 
+ p = MakeSliceAndCatDimOutermostPass() + result = cast(PassResult, p(gm)) + self.assertFalse(result.modified) + gm_after_pass = result.graph_module + + # Validate numerical accuracy + validate( + original, gm_after_pass, [input1, input2], "MakeSliceAndCatDimOutermostPass" + ) + + # Assert that no permute ops were added. The cat is on the second + # outermost dimension, but the outermost dimension is already 1. + self.assertEqual( + count_node(gm_after_pass, exir_ops.edge.aten.permute_copy.default), 0, ) @@ -2874,9 +2831,9 @@ def test_slice_insert_transpose(self) -> None: # Validate numerical accuracy validate(original, gm_after_pass, [x], "MakeSliceAndCatDimOutermostPass") - # Assert that there are two transpose ops added. + # Assert that there are two permute ops added. self.assertEqual( - count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int), + count_node(gm_after_pass, exir_ops.edge.aten.permute_copy.default), 2, ) @@ -2915,40 +2872,10 @@ def test_cat_no_transpose_if_already_outermost(self) -> None: original, gm_after_pass, [input1, input2], "MakeSliceAndCatDimOutermostPass" ) - # Assert that no transpose ops were added. The slice is on the second + # Assert that no permute ops were added. The cat is on the second # outermost dimension, but the outermost dimension is already 1. self.assertEqual( - count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int), - 0, - ) - - def test_cat_no_transpose_if_outermost_dimensions_are_one(self) -> None: - # Create a graph with a single slice node on second outermost dimension. - input1 = torch.randn(1, 1, 3, 5) - input2 = torch.randn(1, 2, 3, 5) - gm = self.create_cat_graph(input_shapes=((1, 1, 3, 5), (1, 2, 3, 5)), cat_dim=1) - # Check if graph module is valid by running exportpass on it. - gm = ExportPass().call(gm).graph_module - self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1) - - # Deepcopy before the pass - original = copy.deepcopy(gm) - - # Apply replacement pass. 
- p = MakeSliceAndCatDimOutermostPass() - result = cast(PassResult, p(gm)) - self.assertFalse(result.modified) - gm_after_pass = result.graph_module - - # Validate numerical accuracy - validate( - original, gm_after_pass, [input1, input2], "MakeSliceAndCatDimOutermostPass" - ) - - # Assert that no transpose ops were added. The slice is on the second - # outermost dimension, but the outermost dimension is already 1. - self.assertEqual( - count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int), + count_node(gm_after_pass, exir_ops.edge.aten.permute_copy.default), 0, ) @@ -2977,13 +2904,14 @@ def test_cat_insert_transpose(self) -> None: original, gm_after_pass, [input1, input2], "MakeSliceAndCatDimOutermostPass" ) - # Assert that transpose ops were added to make cat on outermost dimension. + # Assert that permute ops were added to make cat on outermost dimension. self.assertEqual( - count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int), + count_node(gm_after_pass, exir_ops.edge.aten.permute_copy.default), 3, ) + class TestReplaceMaxPool2dWithChannelLastMaxPool2dPass(unittest.TestCase): def test_replace_max_pool2d_nchw_with_nhwc(self) -> None: # Create a graph with a single quantized_max_pool2d_nchw node.