Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 92 additions & 29 deletions backends/arm/_passes/decompose_meandim_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,22 @@ def get_meandim_decomposition(op) -> tuple:
raise RuntimeError(f"Can't get meandim decomposition for op {op}")


def get_dynamic_meandim_decomposition(op) -> tuple:
    """Return the edge ops used to decompose a mean with symbolic dims.

    The dynamic decomposition computes mean as sum(x) * reciprocal(count),
    where the element count is materialized at runtime via a
    full -> expand -> sum chain, so symbolic (dynamic) reduced dimensions
    are handled correctly.

    Raises:
        NotImplementedError: for the torch.aten variants, which have no
            dynamic decomposition.
        RuntimeError: for any op that is not a mean variant.
    """
    edge_mean_ops = (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default)
    aten_mean_ops = (torch.ops.aten.mean.dim, torch.ops.aten.mean.default)

    if op in edge_mean_ops:
        return (
            exir_ops.edge.aten.sum.dim_IntList,
            exir_ops.edge.aten.mul.Tensor,
            exir_ops.edge.aten.full.default,
            exir_ops.edge.aten.reciprocal.default,
            exir_ops.edge.aten.expand_copy.default,
        )

    if op in aten_mean_ops:
        raise NotImplementedError(
            "Dynamic mean.dim decomposition is not supported for torch.aten.mean."
        )

    raise RuntimeError(f"Can't get meandim decomposition for op {op}")


def get_avgpool(op):
if op in (exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.mean.default):
return exir_ops.edge.aten.avg_pool2d.default
Expand Down Expand Up @@ -103,26 +119,39 @@ def __init__(self, graph_module, tosa_spec, *args, **kwargs):
self._tosa_spec, WhyNoPartitionReporter()
)

def call_operator(self, op, args, kwargs, meta):
def call_operator(self, op, args, kwargs, meta, updated=False):
if op not in (
exir_ops.edge.aten.mean.dim,
torch.ops.aten.mean.dim,
exir_ops.edge.aten.mean.default,
torch.ops.aten.mean.default,
) or not self.allowed_to_transform(meta):
return super().call_operator(op, args, kwargs, meta)
return super().call_operator(op, args, kwargs, meta, updated)

x = get_node_arg(args, 0)
input_shape = list(x.data.shape)
output_shape = list(meta["val"].shape)

dims_to_reduce = get_node_arg(args, 1, range(len(input_shape)))
if dims_to_reduce is None:
dims_to_reduce = range(len(input_shape))

dims_to_reduce = [dim % len(input_shape) for dim in dims_to_reduce]
dims_to_reduce = [dim for dim in dims_to_reduce if input_shape[dim] != 1]

has_symbolic_reduce_dim = any(
isinstance(input_shape[dim], torch.SymInt) for dim in dims_to_reduce
)
if has_symbolic_reduce_dim and get_quantization(x.node.target) is not None:
raise NotImplementedError(
"Quantized mean.dim with symbolic reduced dimensions is not supported"
)

view_op = get_view(op)

if not has_symbolic_reduce_dim:
# for static shapes we should ensure that we only keep non 1 dimensions.
dims_to_reduce = [dim for dim in dims_to_reduce if input_shape[dim] != 1]

# Reshape to 4D
if len(input_shape) != 4:
new_shape = copy(input_shape)
Expand All @@ -140,26 +169,66 @@ def call_operator(self, op, args, kwargs, meta):
x = self._maybe_insert_q_dq_after(x, meta)

# Reduce (h,w) dims by avg pool if possible
x, dims_to_reduce = self._reduce_by_average_pool(op, x, dims_to_reduce, meta)
if not has_symbolic_reduce_dim:
x, dims_to_reduce = self._reduce_by_average_pool(
op, x, dims_to_reduce, meta
)

# Reshape back to 5D if necessary
if len(input_shape) > 4:
original_dims = input_shape[0:-3]
original_dims = input_shape[:-3]
temp_shape = list(x.data.shape)[1:]
temp_shape = original_dims + temp_shape
dims_to_reduce = [dim + len(original_dims) - 1 for dim in dims_to_reduce]

x = super().call_operator(view_op, (x, temp_shape), {}, meta, True)
x = self._maybe_insert_q_dq_after(x, meta)
# Reduce remaining dims by sum
x = self._reduce_by_sum(op, x, dims_to_reduce, meta)

if has_symbolic_reduce_dim:
x = self._reduce_by_sum_symbolic(op, x, dims_to_reduce, meta)
else:
x = self._reduce_by_sum(op, x, dims_to_reduce, meta)

# Reshape to correct output shape if necessary
if list(x.data.shape) != output_shape:
x = super().call_operator(view_op, (x, output_shape), {}, meta, True)

return x

def _reduce_by_sum_symbolic(self, op, input_node, dims, meta):
    """Reduce `dims` of `input_node` to a mean via sum(x) * reciprocal(count).

    Used when at least one reduced dimension is symbolic: the element
    count cannot be folded into a compile-time constant, so it is built
    at runtime by expanding a scalar 1 to the reduced shape and summing
    it (full -> expand -> sum), then combined with the data sum through
    reciprocal and multiply.

    Args:
        op: The mean op being decomposed; selects the decomposition ops.
        input_node: Proxy for the tensor to reduce.
        dims: Dimensions to reduce (already normalized to non-negative).
        meta: Node metadata propagated to every inserted op.

    Returns:
        Proxy for the computed mean, with reduced dims kept (keepdim=True).
    """
    input_shape = input_node.data.size()
    # May contain SymInts; that is the point of this path.
    reduced_shape = [input_shape[dim] for dim in dims]

    sum_op, mul_op, full_op, recip_op, expand_op = get_dynamic_meandim_decomposition(
        op
    )

    # Data sum over the reduced dims, keeping them as size-1 dims.
    # (Renamed from `sum` to avoid shadowing the builtin.)
    summed = super().call_operator(sum_op, (input_node, dims, True), {}, meta, True)

    # Build the runtime element count: a scalar 1 expanded to the
    # (possibly symbolic) reduced shape, then summed down to one value.
    ones = super().call_operator(
        full_op,
        ([1], 1.0),
        # Use meta["val"] for consistency with the rest of this pass.
        {"dtype": meta["val"].dtype, "device": input_node.data.device},
        meta,
        True,
    )
    expanded_ones = super().call_operator(
        expand_op,
        (ones, reduced_shape),
        {},
        meta,
        True,
    )
    counts = super().call_operator(
        sum_op,
        (expanded_ones, list(range(len(reduced_shape))), True),
        {},
        meta,
        True,
    )

    # mean = sum / count, expressed as sum * (1 / count).
    recip = super().call_operator(recip_op, (counts,), {}, meta, True)
    return super().call_operator(mul_op, (summed, recip), {}, meta, True)

def _reduce_by_sum(self, op, input_node, dims, meta):
if len(dims) == 0:
return input_node
Expand Down Expand Up @@ -224,13 +293,9 @@ def _reduce_by_average_pool(self, op, input_node, dims, meta):
if is_supported:
out = super().call_operator(avgpool_op, args, {}, meta, True)
out = self._maybe_insert_q_dq_after(out, meta)
return (
out,
dims_to_reduce_by_sum,
)
return out, dims_to_reduce_by_sum

else:
return input_node, dims
return input_node, dims

def _maybe_insert_q_dq_after(self, op, meta):
"""If the input node of op is a dequant node, insert a q-dq pair after
Expand All @@ -242,20 +307,18 @@ def _maybe_insert_q_dq_after(self, op, meta):
f"Expected one input to {op.node}, got inputs {op.node.all_input_nodes}"
)
input_node = op.node.all_input_nodes[0]
if (quant_ops := get_quantization(input_node.target)) is not None:
q_op, dq_op = quant_ops
quant_args = list(input_node.args[1:])
q_args = (op, *quant_args)
out = super().call_operator(
q_op,
q_args,
kwargs={},
meta=meta,
updated=True,
)
dq_args = (out, *quant_args)
return super().call_operator(
dq_op, dq_args, kwargs={}, meta=meta, updated=True
)
else:
if (quant_ops := get_quantization(input_node.target)) is None:
return op

q_op, dq_op = quant_ops
quant_args = list(input_node.args[1:])
q_args = (op, *quant_args)
out = super().call_operator(
q_op,
q_args,
kwargs={},
meta=meta,
updated=True,
)
dq_args = (out, *quant_args)
return super().call_operator(dq_op, dq_args, kwargs={}, meta=meta, updated=True)
Loading