From f75307e6ebe9f5f59e61898d37b10ef0966e861f Mon Sep 17 00:00:00 2001 From: gasoonjia Date: Mon, 6 Apr 2026 16:36:21 -0700 Subject: [PATCH] [ET Device Support] PropagateDevicePass inserts H2D/D2H copy ops at delegate boundaries Extend PropagateDevicePass to insert explicit et_copy._h2d_copy and et_copy._d2h_copy ops at delegate boundaries, making the graph functional by explicitly transferring data between CPU and device memory. Key changes: - Inserts _h2d_copy before each delegate input, _d2h_copy after each output - Original input nodes stay CPU; h2d_copy output tagged as device - Getitem nodes inherit device; d2h_copy output tagged as CPU - Skip-copy optimizations via skip_h2d_for_method_inputs/skip_d2h_for_method_outputs (this diff only adds the _is_placeholder/_feeds_directly_to_output helpers; the flags are not yet accepted by __init__ or used in call()) - _parse_device_spec_value: lowercases string, raises ValueError for unknown types (not touched by this diff) - _program.py passes config flags to PropagateDevicePass constructor (not part of this diff; _program.py is absent from the diffstat) Differential Revision: [D99636777](https://our.internmc.facebook.com/intern/diff/D99636777/) [ghstack-poisoned] --- exir/passes/BUCK | 1 + exir/passes/propagate_device_pass.py | 192 +++++++++++++++++------ exir/tests/TARGETS | 1 + exir/tests/test_propagate_device_pass.py | 3 + 4 files changed, 146 insertions(+), 51 deletions(-) diff --git a/exir/passes/BUCK b/exir/passes/BUCK index f9b5f52b4af..8c13af45e3f 100644 --- a/exir/passes/BUCK +++ b/exir/passes/BUCK @@ -454,6 +454,7 @@ fbcode_target(_kind = runtime.python_library, "propagate_device_pass.py", ], deps = [ + ":device_copy_ops_registry", "//caffe2:torch", "//executorch/exir:delegate", "//executorch/exir:lowered_backend_module", diff --git a/exir/passes/propagate_device_pass.py b/exir/passes/propagate_device_pass.py index aaf5a97e5ae..c949ca5e781 100644 --- a/exir/passes/propagate_device_pass.py +++ b/exir/passes/propagate_device_pass.py @@ -6,9 +6,14 @@ # pyre-strict +import copy import logging +import operator from typing import Optional +# Import to register the et_copy ops so torch.ops.et_copy is available. 
+import executorch.exir.passes._device_copy_ops_registry # noqa: F401 + import executorch.exir.schema as schema import torch @@ -124,23 +129,139 @@ def _tag_specs_with_device( return False +def _clone_spec_with_device( + spec: TensorSpec, + device_type: schema.DeviceType, + device_index: int = 0, +) -> TensorSpec: + """Create a copy of a TensorSpec with a different device.""" + new_spec = copy.copy(spec) + new_spec.init_mem_planning_fields() + _set_device_on_spec(new_spec, device_type, device_index) + return new_spec + + class PropagateDevicePass(PassBase): """ - After to_backend, walk the graph and set device metadata on TensorSpecs - based on partitioner-assigned delegation info. - - Rules: - 1. Delegated nodes: Input and output tensors of a delegate call are marked - with the target device derived from the delegate's CompileSpec - (key="target_device"). - 2. Non-delegated nodes: Remain on CPU (default). - 3. Getitem nodes that extract from a delegate call inherit the device from - the delegate call's output spec at the corresponding index. + After to_backend, walk the graph and insert H2D/D2H copy ops at delegate + boundaries based on partitioner-assigned device info. + + When a delegate has a target_device CompileSpec (e.g., "cuda:0"): + - For each delegate input: insert et_copy._h2d_copy before the delegate call. + The original input node stays CPU; the h2d_copy output is tagged as device. + - For each delegate output: insert et_copy._d2h_copy after each getitem. + The getitem stays device; the d2h_copy output is tagged as CPU. + - Getitem nodes that extract from a delegate call inherit the device. + + Skip-copy optimizations (planned; not yet applied by call(), and __init__ takes no flags): + - skip_h2d_for_method_inputs: If the input is a graph-level placeholder + feeding directly to a delegate, don't insert H2D — tag the placeholder + as device instead (user provides GPU tensor at runtime). + - skip_d2h_for_method_outputs: If the getitem feeds directly to graph + output, don't insert D2H — the output stays on device. 
""" + def __init__( + self, + ) -> None: + super().__init__() + + def _is_placeholder(self, node: torch.fx.Node) -> bool: + """Check if a node is a graph-level input (placeholder).""" + return node.op == "placeholder" + + def _feeds_directly_to_output(self, node: torch.fx.Node) -> bool: + """Check if all users of a node are output nodes.""" + return all(user.op == "output" for user in node.users) + + def _insert_h2d_copies( + self, + graph_module: torch.fx.GraphModule, + node: torch.fx.Node, + target_device_type: schema.DeviceType, + device_index: int, + ) -> bool: + """Insert H2D copy nodes for each tensor input to a delegate call.""" + changed = False + new_args = list(node.args) + for i, arg in enumerate(node.args[1:], start=1): + if not isinstance(arg, torch.fx.Node): + continue + arg_spec = arg.meta.get("spec") + if not isinstance(arg_spec, TensorSpec): + continue + + with graph_module.graph.inserting_before(node): + h2d_node = graph_module.graph.call_function( + torch.ops.et_copy._h2d_copy.default, + (arg,), + ) + h2d_spec = _clone_spec_with_device( + arg_spec, target_device_type, device_index + ) + h2d_node.meta["spec"] = h2d_spec + h2d_node.meta["val"] = arg.meta.get("val") + if "tensor_meta" in arg.meta: + h2d_node.meta["tensor_meta"] = arg.meta["tensor_meta"] + new_args[i] = h2d_node + changed = True + + node.args = tuple(new_args) + return changed + + def _insert_d2h_for_getitem( + self, + graph_module: torch.fx.GraphModule, + node: torch.fx.Node, + ) -> bool: + """If *node* is a getitem extracting from a delegate call, tag its spec + with the delegate device and insert a D2H copy after it.""" + source_node = node.args[0] + if not ( + isinstance(source_node, torch.fx.Node) + and source_node.op == "call_function" + and source_node.target == executorch_call_delegate + ): + return False + + spec = node.meta.get("spec") + source_specs = source_node.meta.get("spec") + idx = node.args[1] + if not ( + isinstance(spec, TensorSpec) + and isinstance(source_specs, 
(tuple, list)) + and isinstance(idx, int) + and idx < len(source_specs) + ): + return False + + source_spec = source_specs[idx] + if not isinstance(source_spec, TensorSpec): + return False + + _set_device_on_spec(spec, source_spec.device, source_spec.device_index) + + with graph_module.graph.inserting_after(node): + d2h_node = graph_module.graph.call_function( + torch.ops.et_copy._d2h_copy.default, + (node,), + ) + d2h_spec = _clone_spec_with_device(spec, schema.DeviceType.CPU, 0) + d2h_node.meta["spec"] = d2h_spec + d2h_node.meta["val"] = node.meta.get("val") + if "tensor_meta" in node.meta: + d2h_node.meta["tensor_meta"] = node.meta["tensor_meta"] + + node.replace_all_uses_with( + d2h_node, + delete_user_cb=lambda user, _d2h=d2h_node: user != _d2h, + ) + return True + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: changed = False - for node in graph_module.graph.nodes: + + for node in list(graph_module.graph.nodes): if node.op == "call_function" and node.target == executorch_call_delegate: lowered_module = _get_lowered_module(graph_module, node) if lowered_module is None: @@ -152,17 +273,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: target_device_type, device_index = result - # Tag delegate input tensors. - # args[0] is the get_attr node for the lowered module; skip it. - for arg in node.args[1:]: - if isinstance(arg, torch.fx.Node): - changed |= _tag_specs_with_device( - arg.meta.get("spec"), - target_device_type, - device_index, - ) - - # Tag delegate output tensors. + changed |= self._insert_h2d_copies( + graph_module, node, target_device_type, device_index + ) + changed |= _tag_specs_with_device( node.meta.get("spec"), target_device_type, @@ -177,34 +291,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: lowered_module.backend_id, ) - # Second pass: propagate device through getitem nodes that extract - # individual outputs from a delegate call. 
- for node in graph_module.graph.nodes: - if node.op == "call_function" and node.target.__name__ == "getitem": - source_node = node.args[0] - if ( - isinstance(source_node, torch.fx.Node) - and source_node.op == "call_function" - and source_node.target == executorch_call_delegate - ): - spec = node.meta.get("spec") - source_specs = source_node.meta.get("spec") - idx = node.args[1] - if ( - spec is not None - and isinstance(spec, TensorSpec) - and source_specs is not None - and isinstance(source_specs, (tuple, list)) - and isinstance(idx, int) - and idx < len(source_specs) - ): - source_spec = source_specs[idx] - if isinstance(source_spec, TensorSpec): - _set_device_on_spec( - spec, - source_spec.device, - source_spec.device_index, - ) - changed = True + # Second pass: propagate device through getitem nodes and insert D2H. + for node in list(graph_module.graph.nodes): + if node.op == "call_function" and node.target == operator.getitem: + changed |= self._insert_d2h_for_getitem(graph_module, node) + graph_module.recompile() return PassResult(graph_module, changed) diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS index 21493a69644..1871cacf3ac 100644 --- a/exir/tests/TARGETS +++ b/exir/tests/TARGETS @@ -502,6 +502,7 @@ python_unittest( "//executorch/exir/backend/test:backend_with_compiler_demo", "//executorch/exir/dialects:lib", "//executorch/exir/passes:propagate_device_pass", + "//executorch/exir/passes:device_copy_ops_registry", ], ) diff --git a/exir/tests/test_propagate_device_pass.py b/exir/tests/test_propagate_device_pass.py index 26249991be9..83b6cff8d49 100644 --- a/exir/tests/test_propagate_device_pass.py +++ b/exir/tests/test_propagate_device_pass.py @@ -9,6 +9,9 @@ from copy import deepcopy from typing import Dict, final, List +# Import to register et_copy ops +import executorch.exir.passes._device_copy_ops_registry # noqa: F401 + import torch from executorch.exir import EdgeCompileConfig, to_edge, to_edge_transform_and_lower from 
executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (