diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml index 9848e2fb9a7..4f557f210a9 100644 --- a/backends/cadence/aot/functions.yaml +++ b/backends/cadence/aot/functions.yaml @@ -399,6 +399,16 @@ - arg_meta: null kernel_name: impl::generic::quantized_conv1d_nlc_per_tensor_out +- func: cadence::quantized_depthwise_conv1d_ncl.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::generic::quantized_depthwise_conv1d_ncl_per_tensor_out + +- func: cadence::quantized_depthwise_conv1d_nlc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::generic::quantized_depthwise_conv1d_nlc_per_tensor_out + - func: cadence::quantized_conv2d_nchw_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) 
kernels: - arg_meta: null diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 92e82e6e7de..59275590905 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -262,6 +262,18 @@ def register_fake( lib.define( "quantized_conv1d_nlc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" ) +lib.define( + "quantized_depthwise_conv1d_ncl.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" +) +lib.define( + "quantized_depthwise_conv1d_ncl.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "quantized_depthwise_conv1d_nlc.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" +) +lib.define( + "quantized_depthwise_conv1d_nlc.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) 
out) -> Tensor(a!)" +) lib.define( "quantized_conv2d_nchw(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Z)" ) @@ -1256,6 +1268,78 @@ def quantized_conv1d_nlc_per_tensor_meta( return input.new_empty(output_size, dtype=input.dtype) +@register_fake("cadence::quantized_depthwise_conv1d_ncl.per_tensor") +def quantized_depthwise_conv1d_ncl_per_tensor_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: Tuple[int], + padding: Tuple[int], + dilation: Tuple[int], + groups: int, + in_zero_point: int, + weight_zero_point: int, + bias_scale: float, + output_scale: float, + output_zero_point: int, + out_multiplier: int, + out_shift: int, +) -> torch.Tensor: + # NCL format: input is [N, C, L], weight is [OC, IC/groups, K] + out_channels, _, kernel_size = weight.shape + + in_size = input.shape + assert len(in_size) == 3 + + output_size = get_conv1d_output_size( + in_size, + out_channels, + stride[-1], + padding[-1], + dilation[-1], + kernel_size, + False, + ) + + return input.new_empty(output_size, dtype=input.dtype) + + +@register_fake("cadence::quantized_depthwise_conv1d_nlc.per_tensor") +def quantized_depthwise_conv1d_nlc_per_tensor_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: Tuple[int], + padding: Tuple[int], + dilation: Tuple[int], + groups: int, + in_zero_point: int, + weight_zero_point: int, + bias_scale: float, + output_scale: float, + output_zero_point: int, + out_multiplier: int, + out_shift: int, +) -> torch.Tensor: + # NLC format: input is [N, L, C], weight is [OC, K, IC/groups] + out_channels, kernel_size, _ = weight.shape + + in_size = input.shape + assert len(in_size) == 3 + + output_size = get_conv1d_output_size( + in_size, + out_channels, + stride[-1], + padding[-1], + dilation[-1], 
+ kernel_size, + True, + ) + + return input.new_empty(output_size, dtype=input.dtype) + + @register_fake("cadence::quantized_conv2d_nchw") def quantized_conv2d_nchw_meta( input: torch.Tensor, diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py index 2853136081b..63e04d549a0 100644 --- a/backends/cadence/aot/quantizer/fusion_pass.py +++ b/backends/cadence/aot/quantizer/fusion_pass.py @@ -12,6 +12,7 @@ import torch from executorch.backends.cadence.aot.compiler_utils import get_shape from executorch.backends.cadence.aot.pass_utils import get_arg +from executorch.backends.cadence.aot.utils import is_depthwise_conv from executorch.backends.cadence.aot.quantizer.patterns import ( AddmmPattern, AddPattern, @@ -758,8 +759,23 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901 op_node, ) + # Determine the replacement op, routing depthwise conv1d + # to the dedicated depthwise operator. + replacement_op = pattern.replacement_op() + if ( + replacement_op + == torch.ops.cadence.quantized_conv1d_ncl.per_tensor + ): + groups = kwargs.get("groups", 1) + # NCL format: input shape is [N, C, L] + in_channels = args[0].meta["val"].shape[1] + if is_depthwise_conv(groups, in_channels): + replacement_op = ( + torch.ops.cadence.quantized_depthwise_conv1d_ncl.per_tensor + ) + fused = graph_module.graph.call_function( - pattern.replacement_op(), + replacement_op, args, kwargs, ) diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 8404fe25268..98dc68d2434 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -1136,6 +1136,122 @@ def quantized_conv1d_nlc( ) +@impl_tracked(m, "quantized_depthwise_conv1d_ncl.per_tensor") +def quantized_depthwise_conv1d_ncl_per_tensor( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: tuple[int], + padding: tuple[int], + dilation: 
tuple[int], + groups: int, + in_zero_point: int, + weight_zero_point: int, + bias_scale: float, + output_scale: float, + output_zero_point: int, + out_multiplier: int, + out_shift: int, +) -> torch.Tensor: + """ + Quantized depthwise 1D convolution in NCL (channels-first) format. + + This op only handles depthwise convolutions (groups == in_channels, groups > 1). + Regular convolutions must use quantized_conv1d_ncl instead. + + Args: + - input_tensor (Tensor): [N, C, L] format + - weight (Tensor): [OC, 1, K] format (IC/groups == 1 for depthwise) + - bias (Tensor): [OC] + - stride, padding, dilation, groups: convolution parameters + - in_zero_point, weight_zero_point, bias_scale: quantization params + - output_scale, output_zero_point: output quantization params + - out_multiplier, out_shift: unused + """ + assert is_depthwise_conv( + groups, input_tensor.shape[1] + ), f"quantized_depthwise_conv1d_ncl requires depthwise conv (groups == in_channels), got groups={groups}, in_channels={input_tensor.shape[1]}" + + return quantized_conv_per_tensor( + input_tensor, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + ) + + +@impl_tracked(m, "quantized_depthwise_conv1d_nlc.per_tensor") +def quantized_depthwise_conv1d_nlc_per_tensor( + input_tensor: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: tuple[int], + padding: tuple[int], + dilation: tuple[int], + groups: int, + in_zero_point: int, + weight_zero_point: int, + bias_scale: float, + output_scale: float, + output_zero_point: int, + out_multiplier: int, + out_shift: int, +) -> torch.Tensor: + """ + Quantized depthwise 1D convolution in NLC (channels-last) format. + + This op only handles depthwise convolutions (groups == in_channels, groups > 1). + Regular convolutions must use quantized_conv1d_nlc instead. 
+ + Args: + - input_tensor (Tensor): [N, L, C] format + - weight (Tensor): [OC, K, 1] format (IC/groups == 1 for depthwise) + - bias (Tensor): [OC] + - stride, padding, dilation, groups: convolution parameters + - in_zero_point, weight_zero_point, bias_scale: quantization params + - output_scale, output_zero_point: output quantization params + - out_multiplier, out_shift: unused + """ + assert is_depthwise_conv( + groups, input_tensor.shape[-1] + ), f"quantized_depthwise_conv1d_nlc requires depthwise conv (groups == in_channels), got groups={groups}, in_channels={input_tensor.shape[-1]}" + + # Convert NLC to NCL for processing + input_ncl = input_tensor.permute(0, 2, 1).contiguous() + # Convert weight from [OC, K, IC/groups] to [OC, IC/groups, K] + weight_ncl = weight.permute(0, 2, 1).contiguous() + + result_ncl = quantized_conv_per_tensor( + input_ncl, + weight_ncl, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + ) + + # Convert result back to NLC format + return result_ncl.permute(0, 2, 1).contiguous() + + @impl_tracked(m, "quantized_conv2d_nchw") def quantized_conv2d_nchw( input_tensor: torch.Tensor, diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 6e2f85fab0f..d46686cc84b 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -1030,6 +1030,7 @@ def targets(self) -> list[EdgeOpOverload]: exir_ops.edge.cadence.conv2d.default, exir_ops.edge.cadence.conv3d.default, exir_ops.edge.cadence.quantized_conv1d_ncl.per_tensor, + exir_ops.edge.cadence.quantized_depthwise_conv1d_ncl.per_tensor, exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, ] @@ -1114,6 +1115,7 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: assert isinstance(node.target, EdgeOpOverload) quantized_op = node.target in { exir_ops.edge.cadence.quantized_conv1d_ncl.per_tensor, + 
exir_ops.edge.cadence.quantized_depthwise_conv1d_ncl.per_tensor, exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, } @@ -1132,7 +1134,15 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: new_op = exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor else: assert len(input_shape) == 3 - new_op = exir_ops.edge.cadence.quantized_conv1d_nlc.per_tensor + if ( + node.target + == exir_ops.edge.cadence.quantized_depthwise_conv1d_ncl.per_tensor + ): + new_op = ( + exir_ops.edge.cadence.quantized_depthwise_conv1d_nlc.per_tensor + ) + else: + new_op = exir_ops.edge.cadence.quantized_conv1d_nlc.per_tensor else: new_op = node.target diff --git a/backends/cadence/generic/operators/op_quantized_depthwise_conv1d_ncl.cpp b/backends/cadence/generic/operators/op_quantized_depthwise_conv1d_ncl.cpp new file mode 100644 index 00000000000..ddc8ff44b3a --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_depthwise_conv1d_ncl.cpp @@ -0,0 +1,65 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include <executorch/backends/cadence/generic/operators/operators.h> // NOTE(review): include targets were stripped in transit; reconstructed from targets.bzl deps -- verify paths + +#include <executorch/runtime/core/exec_aten/exec_aten.h> +#include <executorch/runtime/kernel/kernel_includes.h> + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +// Depthwise conv1d NCL: delegates to the regular conv1d NCL implementation +// which already handles grouped (depthwise) convolution correctly via +// ocpg/icpg decomposition. This operator exists as a separate entry point +// so that depthwise and regular conv1d are cleanly separated at the graph +// level, enabling independent optimization. 
::executorch::aten::Tensor& quantized_depthwise_conv1d_ncl_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t input_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out) { + return quantized_conv1d_ncl_per_tensor_out( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + input_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + out); +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/op_quantized_depthwise_conv1d_nlc.cpp b/backends/cadence/generic/operators/op_quantized_depthwise_conv1d_nlc.cpp new file mode 100644 index 00000000000..2ae06a651d2 --- /dev/null +++ b/backends/cadence/generic/operators/op_quantized_depthwise_conv1d_nlc.cpp @@ -0,0 +1,65 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include <executorch/backends/cadence/generic/operators/operators.h> // NOTE(review): include targets were stripped in transit; reconstructed from targets.bzl deps -- verify paths + +#include <executorch/runtime/core/exec_aten/exec_aten.h> +#include <executorch/runtime/kernel/kernel_includes.h> + +namespace impl { +namespace generic { +namespace native { + +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +// Depthwise conv1d NLC: delegates to the regular conv1d NLC implementation +// which already handles grouped (depthwise) convolution correctly via +// ocpg/icpg decomposition. This operator exists as a separate entry point +// so that depthwise and regular conv1d are cleanly separated at the graph +// level, enabling independent optimization. 
+::executorch::aten::Tensor& quantized_depthwise_conv1d_nlc_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t input_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + int64_t out_multiplier, + int64_t out_shift, + Tensor& out) { + return quantized_conv1d_nlc_per_tensor_out( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + input_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out_multiplier, + out_shift, + out); +} + +} // namespace native +} // namespace generic +} // namespace impl diff --git a/backends/cadence/generic/operators/targets.bzl b/backends/cadence/generic/operators/targets.bzl index fa6708a188e..12246da8f3a 100644 --- a/backends/cadence/generic/operators/targets.bzl +++ b/backends/cadence/generic/operators/targets.bzl @@ -146,6 +146,28 @@ def define_common_targets(): visibility = ["PUBLIC"], ) + runtime.cxx_library( + name = "op_quantized_depthwise_conv1d_ncl", + srcs = ["op_quantized_depthwise_conv1d_ncl.cpp"], + platforms = CXX, + deps = [ + ":op_quantized_conv1d_ncl", + "//executorch/runtime/kernel:kernel_includes", + ], + visibility = ["PUBLIC"], + ) + + runtime.cxx_library( + name = "op_quantized_depthwise_conv1d_nlc", + srcs = ["op_quantized_depthwise_conv1d_nlc.cpp"], + platforms = CXX, + deps = [ + ":op_quantized_conv1d_nlc", + "//executorch/runtime/kernel:kernel_includes", + ], + visibility = ["PUBLIC"], + ) + runtime.cxx_library( name = "op_quantized_conv2d", srcs = ["op_quantized_conv2d.cpp"],