diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py
index dec890c5561..701d5337636 100644
--- a/backends/arm/_passes/decompose_meandim_pass.py
+++ b/backends/arm/_passes/decompose_meandim_pass.py
@@ -35,6 +35,22 @@ def get_meandim_decomposition(op) -> tuple:
     raise RuntimeError(f"Can't get meandim decomposition for op {op}")
 
 
+def get_dynamic_meandim_decomposition(op) -> tuple:
+    if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
+        return (
+            exir_ops.edge.aten.sum.dim_IntList,
+            exir_ops.edge.aten.mul.Tensor,
+            exir_ops.edge.aten.full.default,
+            exir_ops.edge.aten.reciprocal.default,
+            exir_ops.edge.aten.expand_copy.default,
+        )
+    if op in (torch.ops.aten.mean.dim, torch.ops.aten.mean.default):
+        raise NotImplementedError(
+            "Dynamic mean.dim decomposition is not supported for torch.aten.mean."
+        )
+    raise RuntimeError(f"Can't get meandim decomposition for op {op}")
+
+
 def get_avgpool(op):
     if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
         return exir_ops.edge.aten.avg_pool2d.default
@@ -103,26 +119,39 @@ def __init__(self, graph_module, tosa_spec, *args, **kwargs):
             self._tosa_spec, WhyNoPartitionReporter()
         )
 
-    def call_operator(self, op, args, kwargs, meta):
+    def call_operator(self, op, args, kwargs, meta, updated=False):
         if op not in (
             exir_ops.edge.aten.mean.dim,
             torch.ops.aten.mean.dim,
             exir_ops.edge.aten.mean.default,
             torch.ops.aten.mean.default,
         ) or not self.allowed_to_transform(meta):
-            return super().call_operator(op, args, kwargs, meta)
+            return super().call_operator(op, args, kwargs, meta, updated)
 
         x = get_node_arg(args, 0)
         input_shape = list(x.data.shape)
         output_shape = list(meta["val"].shape)
+
         dims_to_reduce = get_node_arg(args, 1, range(len(input_shape)))
         if dims_to_reduce is None:
             dims_to_reduce = range(len(input_shape))
+
         dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce]
-        dims_to_reduce = [dim for dim in dims_to_reduce if input_shape[dim] != 1]
+
+        has_symbolic_reduce_dim = any(
+            isinstance(input_shape[dim], torch.SymInt) for dim in dims_to_reduce
+        )
+        if has_symbolic_reduce_dim and get_quantization(x.node.target) is not None:
+            raise NotImplementedError(
+                "Quantized mean.dim with symbolic reduced dimensions is not supported"
+            )
 
         view_op = get_view(op)
 
+        if not has_symbolic_reduce_dim:
+            # for static shapes we should ensure that we only keep non 1 dimensions.
+            dims_to_reduce = [dim for dim in dims_to_reduce if input_shape[dim] != 1]
+
         # Reshape to 4D
         if len(input_shape) != 4:
             new_shape = copy(input_shape)
@@ -140,19 +169,25 @@ def call_operator(self, op, args, kwargs, meta):
             x = self._maybe_insert_q_dq_after(x, meta)
 
         # Reduce (h,w) dims by avg pool if possible
-        x, dims_to_reduce = self._reduce_by_average_pool(op, x, dims_to_reduce, meta)
+        if not has_symbolic_reduce_dim:
+            x, dims_to_reduce = self._reduce_by_average_pool(
+                op, x, dims_to_reduce, meta
+            )
 
         # Reshape back to 5D if necessary
         if len(input_shape) > 4:
-            original_dims = input_shape[0:-3]
+            original_dims = input_shape[:-3]
             temp_shape = list(x.data.shape)[1:]
             temp_shape = original_dims + temp_shape
             dims_to_reduce = [dim + len(original_dims) - 1 for dim in dims_to_reduce]
             x = super().call_operator(view_op, (x, temp_shape), {}, meta, True)
             x = self._maybe_insert_q_dq_after(x, meta)
 
-        # Reduce remaining dims by sum
-        x = self._reduce_by_sum(op, x, dims_to_reduce, meta)
+
+        if has_symbolic_reduce_dim:
+            x = self._reduce_by_sum_symbolic(op, x, dims_to_reduce, meta)
+        else:
+            x = self._reduce_by_sum(op, x, dims_to_reduce, meta)
 
         # Reshape to correct output shape if necessary
         if list(x.data.shape) != output_shape:
@@ -160,6 +195,40 @@ def call_operator(self, op, args, kwargs, meta):
 
         return x
 
+    def _reduce_by_sum_symbolic(self, op, input_node, dims, meta):
+        input_shape = input_node.data.size()
+        reduced_shape = [input_shape[dim] for dim in dims]
+
+        sum_op, mul_op, full_op, recip_op, expand_op = (
+            get_dynamic_meandim_decomposition(op)
+        )
+
+        sum = super().call_operator(sum_op, (input_node, dims, True), {}, meta, True)
+
+        ones = super().call_operator(
+            full_op,
+            ([1], 1.0),
+            {"dtype": meta.data["val"].dtype, "device": input_node.data.device},
+            meta,
+            True,
+        )
+        expanded_ones = super().call_operator(
+            expand_op,
+            (ones, reduced_shape),
+            {},
+            meta,
+            True,
+        )
+        counts = super().call_operator(
+            sum_op,
+            (expanded_ones, list(range(len(reduced_shape))), True),
+            {},
+            meta,
+            True,
+        )
+        recip = super().call_operator(recip_op, (counts,), {}, meta, True)
+        return super().call_operator(mul_op, (sum, recip), {}, meta, True)
+
     def _reduce_by_sum(self, op, input_node, dims, meta):
         if len(dims) == 0:
             return input_node
@@ -224,13 +293,9 @@ def _reduce_by_average_pool(self, op, input_node, dims, meta):
         if is_supported:
             out = super().call_operator(avgpool_op, args, {}, meta, True)
             out = self._maybe_insert_q_dq_after(out, meta)
-            return (
-                out,
-                dims_to_reduce_by_sum,
-            )
+            return out, dims_to_reduce_by_sum
 
-        else:
-            return input_node, dims
+        return input_node, dims
 
     def _maybe_insert_q_dq_after(self, op, meta):
         """If the input node of op is a dequant node, insert a q-dq pair after
@@ -242,20 +307,18 @@ def _maybe_insert_q_dq_after(self, op, meta):
                 f"Expected one input to {op.node}, got inputs {op.node.all_input_nodes}"
             )
         input_node = op.node.all_input_nodes[0]
-        if (quant_ops := get_quantization(input_node.target)) is not None:
-            q_op, dq_op = quant_ops
-            quant_args = list(input_node.args[1:])
-            q_args = (op, *quant_args)
-            out = super().call_operator(
-                q_op,
-                q_args,
-                kwargs={},
-                meta=meta,
-                updated=True,
-            )
-            dq_args = (out, *quant_args)
-            return super().call_operator(
-                dq_op, dq_args, kwargs={}, meta=meta, updated=True
-            )
-        else:
+        if (quant_ops := get_quantization(input_node.target)) is None:
             return op
+
+        q_op, dq_op = quant_ops
+        quant_args = list(input_node.args[1:])
+        q_args = (op, *quant_args)
+        out = super().call_operator(
+            q_op,
+            q_args,
+            kwargs={},
+            meta=meta,
+            updated=True,
+        )
+        dq_args = (out, *quant_args)
+        return super().call_operator(dq_op, dq_args, kwargs={}, meta=meta, updated=True)