diff --git a/backends/cadence/aot/compiler_funcs.py b/backends/cadence/aot/compiler_funcs.py index e8c0f2a602b..2f8dbf33416 100644 --- a/backends/cadence/aot/compiler_funcs.py +++ b/backends/cadence/aot/compiler_funcs.py @@ -22,6 +22,28 @@ logger: logging.Logger = logging.getLogger(__name__) QuantArgs = tuple[float, int, int, int, torch.dtype] +TRANSPARENT_OPS: frozenset[torch._ops.OpOverloadPacket] = frozenset( + { + torch.ops.aten.view, + torch.ops.aten.view_copy, + torch.ops.aten._unsafe_view, + torch.ops.aten.reshape, + torch.ops.aten.permute, + torch.ops.aten.permute_copy, + torch.ops.aten.transpose, + torch.ops.aten.transpose_copy, + torch.ops.aten.squeeze, + torch.ops.aten.squeeze_copy, + torch.ops.aten.unsqueeze, + torch.ops.aten.unsqueeze_copy, + torch.ops.aten.slice, + torch.ops.aten.slice_copy, + torch.ops.aten.contiguous, + torch.ops.aten.clone, + torch.ops.aten.to, + torch.ops.aten._to_copy, + } +) @torch.no_grad() @@ -244,6 +266,11 @@ def extract_input_quant_params_from_graph( ) -> dict[int, QuantArgs]: """ Extract quantization parameters from the FX graph for model inputs. + + For each name in ``input_names``, walk forward from the node with that name + through ``TRANSPARENT_OPS`` (reshape, permute, slice, ``to``, ...) until the + first quantize op (e.g. ``quantize_per_tensor``), which fixes that input's + scale and zero-point. Results are keyed by the index into ``input_names``. """ quant_args: dict[int, QuantArgs] = {} found_names: set[str] = set() @@ -251,29 +278,39 @@ def extract_input_quant_params_from_graph( if not input_names: return quant_args + # Inputs are referenced by node name, which may be a placeholder or a node + # that unpacks/derives the input (e.g. a `getitem` off a tuple/multi-output + # input, as the modai eye-tracking model does), so look the start node up + # across all nodes -- not just placeholders. Build the name->node map once + # and reuse it for every requested input. 
+ nodes_by_name = {n.name: n for n in module.graph.nodes} + + quantize_ops = _get_quantize_ops() for idx, name in enumerate(input_names): - for node in module.graph.nodes: - if node.op != "call_function": + start = nodes_by_name.get(name) + if start is None: + continue + seen: set[torch.fx.Node] = set() + to_visit: list[torch.fx.Node] = list(start.users) + while to_visit: + node = to_visit.pop() + if node in seen or node.op != "call_function": continue - - if ( - node.args - and isinstance(node.args[0], torch.fx.Node) - and node.args[0].name == name - and not node.name.startswith("_assert_tensor_metadata") - and "quantize_per_tensor" in str(node.target) - ): - args = node.args[1:] - if len(args) >= 5: - quant_args[idx] = ( - float(args[0]), # scale - int(args[1]), # zero_point - int(args[2]), # qmin - int(args[3]), # qmax - args[4], # dtype - ) - found_names.add(name) + seen.add(node) + if node.target in quantize_ops: + # Normalize args→kwargs so params passed positionally or as + # kwargs (or via defaults) are all handled uniformly. + quant_args[idx] = ( + float(get_arg(node, "scale", float)), + int(get_arg(node, "zero_point", int)), + int(get_arg(node, "quant_min", int)), + int(get_arg(node, "quant_max", int)), + get_arg(node, "dtype", torch.dtype), + ) + found_names.add(name) break + if getattr(node.target, "overloadpacket", None) in TRANSPARENT_OPS: + to_visit.extend(node.users) missing_names = set(input_names) - found_names if missing_names: