diff --git a/exir/passes/BUCK b/exir/passes/BUCK index 4647388b388..e655e97bea0 100644 --- a/exir/passes/BUCK +++ b/exir/passes/BUCK @@ -466,6 +466,7 @@ fbcode_target(_kind = runtime.python_library, "propagate_device_pass.py", ], deps = [ + ":device_copy_ops_registry", "//caffe2:torch", "//executorch/exir:delegate", "//executorch/exir:lowered_backend_module", diff --git a/exir/passes/propagate_device_pass.py b/exir/passes/propagate_device_pass.py index aaf5a97e5ae..5ed0c20b1bb 100644 --- a/exir/passes/propagate_device_pass.py +++ b/exir/passes/propagate_device_pass.py @@ -6,9 +6,14 @@ # pyre-strict +import copy import logging +import operator from typing import Optional +# Import to register the et_copy ops so torch.ops.et_copy is available. +import executorch.exir.passes._device_copy_ops_registry # noqa: F401 + import executorch.exir.schema as schema import torch @@ -124,23 +129,150 @@ def _tag_specs_with_device( return False +def _clone_spec_with_device( + spec: TensorSpec, + device_type: schema.DeviceType, + device_index: int = 0, +) -> TensorSpec: + """Shallow-copy *spec*, re-init its mem-planning fields, and tag the device.""" + new_spec = copy.copy(spec) + new_spec.init_mem_planning_fields() + _set_device_on_spec(new_spec, device_type, device_index) + return new_spec + + class PropagateDevicePass(PassBase): """ - After to_backend, walk the graph and set device metadata on TensorSpecs - based on partitioner-assigned delegation info. - - Rules: - 1. Delegated nodes: Input and output tensors of a delegate call are marked - with the target device derived from the delegate's CompileSpec - (key="target_device"). - 2. Non-delegated nodes: Remain on CPU (default). - 3. Getitem nodes that extract from a delegate call inherit the device from - the delegate call's output spec at the corresponding index. + After to_backend, walk the graph and insert H2D/D2H copy ops at delegate + boundaries based on partitioner-assigned device info.
+ + When a delegate has a target_device CompileSpec (e.g., "cuda:0"): + - For each delegate input: insert et_copy._h2d_copy before the delegate call. + The original input node stays CPU; the h2d_copy output is tagged as device. + - For each delegate output: insert et_copy._d2h_copy after each getitem. + The getitem stays device; the d2h_copy output is tagged as CPU. + - Getitem nodes that extract from a delegate call inherit the device. + + Planned skip-copy optimizations (NOT implemented in this pass yet; the + _is_placeholder/_feeds_directly_to_output helpers are currently unused): + - skip_h2d_for_method_inputs: tag a placeholder that feeds a delegate as + device instead of inserting H2D (user provides a device tensor). + - skip_d2h_for_method_outputs: skip D2H for a getitem that feeds only the + graph output, so the output stays on device. """ + def __init__( + self, + ) -> None: + super().__init__() + + def _is_placeholder(self, node: torch.fx.Node) -> bool: + """Check if a node is a graph-level input (placeholder).""" + return node.op == "placeholder" + + def _feeds_directly_to_output(self, node: torch.fx.Node) -> bool: + """Check if all users of a node are output nodes.""" + return all(user.op == "output" for user in node.users) + + def _insert_h2d_copies( + self, + graph_module: torch.fx.GraphModule, + node: torch.fx.Node, + target_device_type: schema.DeviceType, + device_index: int, + ) -> bool: + """Insert an H2D copy per tensor arg of a delegate call, skipping args[0].""" + changed = False + new_args = list(node.args) + for i, arg in enumerate(node.args[1:], start=1): + if not isinstance(arg, torch.fx.Node): + continue + arg_spec = arg.meta.get("spec") + if not isinstance(arg_spec, TensorSpec): + continue + + with graph_module.graph.inserting_before(node): + h2d_node = graph_module.graph.call_function( + torch.ops.et_copy._h2d_copy.default, + (arg,), + ) + h2d_spec = _clone_spec_with_device( + arg_spec, target_device_type, device_index + ) + h2d_node.meta["spec"] = h2d_spec + h2d_node.meta["val"] =
arg.meta.get("val") + if "tensor_meta" in arg.meta: + h2d_node.meta["tensor_meta"] = arg.meta["tensor_meta"] + new_args[i] = h2d_node + changed = True + + node.args = tuple(new_args) + return changed + + def _insert_d2h_for_getitem( + self, + graph_module: torch.fx.GraphModule, + node: torch.fx.Node, + ) -> bool: + """Tag a delegate getitem's spec with the delegate output's device, insert a + D2H copy after it and reroute its other users; False if left untouched.""" + source_node = node.args[0] + if not ( + isinstance(source_node, torch.fx.Node) + and source_node.op == "call_function" + and source_node.target == executorch_call_delegate + ): + return False + + spec = node.meta.get("spec") + source_specs = source_node.meta.get("spec") + idx = node.args[1] + if not ( + isinstance(spec, TensorSpec) + and isinstance(source_specs, (tuple, list)) + and isinstance(idx, int) + and idx < len(source_specs) + ): + return False + + source_spec = source_specs[idx] + if not isinstance(source_spec, TensorSpec): + return False + + _set_device_on_spec(spec, source_spec.device, source_spec.device_index) + + with graph_module.graph.inserting_after(node): + d2h_node = graph_module.graph.call_function( + torch.ops.et_copy._d2h_copy.default, + (node,), + ) + d2h_spec = _clone_spec_with_device(spec, schema.DeviceType.CPU, 0) + d2h_node.meta["spec"] = d2h_spec + d2h_node.meta["val"] = node.meta.get("val") + if "tensor_meta" in node.meta: + d2h_node.meta["tensor_meta"] = node.meta["tensor_meta"] + + node.replace_all_uses_with( + d2h_node, + delete_user_cb=lambda user, _d2h=d2h_node: user != _d2h, + ) + return True + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + # Two-pass approach: + # Pass 1 – For each delegate with a target_device CompileSpec, insert + # H2D copy nodes before delegate inputs and tag the delegate + # output specs with the target device. Delegates without a + # target_device are left untouched (no copies, specs stay CPU).
+ # Pass 2 – For each getitem that extracts from a device-tagged delegate + # (tracked in device_delegates), propagate the device onto the + # getitem spec and insert a D2H copy after it so downstream + # non-delegated ops receive CPU tensors. changed = False - for node in graph_module.graph.nodes: + device_delegates: set[torch.fx.Node] = set() + + # Pass 1: insert H2D copies and tag delegate output specs. + for node in list(graph_module.graph.nodes): if node.op == "call_function" and node.target == executorch_call_delegate: lowered_module = _get_lowered_module(graph_module, node) if lowered_module is None: @@ -151,18 +283,12 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: continue target_device_type, device_index = result + device_delegates.add(node) + + changed |= self._insert_h2d_copies( + graph_module, node, target_device_type, device_index + ) - # Tag delegate input tensors. - # args[0] is the get_attr node for the lowered module; skip it. - for arg in node.args[1:]: - if isinstance(arg, torch.fx.Node): - changed |= _tag_specs_with_device( - arg.meta.get("spec"), - target_device_type, - device_index, - ) - - # Tag delegate output tensors. changed |= _tag_specs_with_device( node.meta.get("spec"), target_device_type, @@ -177,34 +303,13 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: lowered_module.backend_id, ) - # Second pass: propagate device through getitem nodes that extract - # individual outputs from a delegate call. 
- for node in graph_module.graph.nodes: - if node.op == "call_function" and node.target.__name__ == "getitem": - source_node = node.args[0] - if ( - isinstance(source_node, torch.fx.Node) - and source_node.op == "call_function" - and source_node.target == executorch_call_delegate - ): - spec = node.meta.get("spec") - source_specs = source_node.meta.get("spec") - idx = node.args[1] - if ( - spec is not None - and isinstance(spec, TensorSpec) - and source_specs is not None - and isinstance(source_specs, (tuple, list)) - and isinstance(idx, int) - and idx < len(source_specs) - ): - source_spec = source_specs[idx] - if isinstance(source_spec, TensorSpec): - _set_device_on_spec( - spec, - source_spec.device, - source_spec.device_index, - ) - changed = True + # Second pass: propagate device through getitem nodes and insert D2H + # only for delegates that have a target_device. + for node in list(graph_module.graph.nodes): + if node.op == "call_function" and node.target == operator.getitem: + source = node.args[0] + if isinstance(source, torch.fx.Node) and source in device_delegates: + changed |= self._insert_d2h_for_getitem(graph_module, node) + graph_module.recompile() return PassResult(graph_module, changed) diff --git a/exir/tests/TARGETS b/exir/tests/TARGETS index 21493a69644..1871cacf3ac 100644 --- a/exir/tests/TARGETS +++ b/exir/tests/TARGETS @@ -502,6 +502,7 @@ python_unittest( "//executorch/exir/backend/test:backend_with_compiler_demo", "//executorch/exir/dialects:lib", "//executorch/exir/passes:propagate_device_pass", + "//executorch/exir/passes:device_copy_ops_registry", ], ) diff --git a/exir/tests/test_propagate_device_pass.py b/exir/tests/test_propagate_device_pass.py index 26249991be9..8bb2fa1ab42 100644 --- a/exir/tests/test_propagate_device_pass.py +++ b/exir/tests/test_propagate_device_pass.py @@ -9,6 +9,9 @@ from copy import deepcopy from typing import Dict, final, List +# Import to register et_copy ops +import 
executorch.exir.passes._device_copy_ops_registry # noqa: F401 + import torch from executorch.exir import EdgeCompileConfig, to_edge, to_edge_transform_and_lower from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import ( @@ -164,6 +167,177 @@ def _assert_specs_device( if expected_index is not None: self.assertEqual(s.device_index, expected_index) + # ---- Integration tests: copy nodes after to_executorch ---- + + def test_h2d_d2h_nodes_inserted(self): + """Verify H2D/D2H copy nodes are inserted and survive the full + to_executorch pipeline with correct .out variant targets, exact + counts, and proper graph ordering.""" + + class Model(torch.nn.Module): + def forward(self, a, b): + return torch.add(a, b) + + model = Model() + inputs = (torch.randn(2, 2), torch.randn(2, 2)) + + for pipeline, gm in _lower_model_to_executorch( + model, inputs, DeviceAwarePartitioner("cuda:0") + ): + with self.subTest(pipeline=pipeline): + h2d_nodes = [] + d2h_nodes = [] + delegate_nodes = [] + getitem_nodes = [] + + for node in gm.graph.nodes: + if node.op != "call_function": + continue + if node.target == torch.ops.et_copy._h2d_copy.out: + h2d_nodes.append(node) + elif node.target == torch.ops.et_copy._d2h_copy.out: + d2h_nodes.append(node) + elif node.target == executorch_call_delegate: + delegate_nodes.append(node) + elif node.target == operator.getitem: + getitem_nodes.append(node) + + # Model has 2 inputs, 1 output → 2 H2D, 1 D2H + self.assertEqual( + len(h2d_nodes), + 2, + f"[{pipeline}] Expected 2 H2D copy nodes (one per " + f"delegate input), got {len(h2d_nodes)}", + ) + self.assertEqual( + len(d2h_nodes), + 1, + f"[{pipeline}] Expected 1 D2H copy node (one per " + f"delegate output), got {len(d2h_nodes)}", + ) + self.assertEqual(len(delegate_nodes), 1) + + # Verify graph ordering: + # placeholder → h2d_copy → delegate → getitem → d2h_copy → output + all_nodes = list(gm.graph.nodes) + delegate_idx = all_nodes.index(delegate_nodes[0]) + for h2d in 
h2d_nodes: + self.assertLess( + all_nodes.index(h2d), + delegate_idx, + f"[{pipeline}] H2D '{h2d.name}' must appear before " + f"delegate '{delegate_nodes[0].name}'", + ) + for d2h in d2h_nodes: + for gi in getitem_nodes: + if gi.args[0] == delegate_nodes[0]: + self.assertGreater( + all_nodes.index(d2h), + all_nodes.index(gi), + f"[{pipeline}] D2H '{d2h.name}' must appear " + f"after getitem '{gi.name}'", + ) + + def test_e2e_copy_nodes_in_executorch_graph(self): + """End-to-end: copy nodes survive the full to_executorch pipeline + and have correct .out targets and device specs on TensorSpecs.""" + + class Model(torch.nn.Module): + def forward(self, a, b): + return torch.add(a, b) + + model = Model() + inputs = (torch.randn(2, 2), torch.randn(2, 2)) + + for pipeline, gm in _lower_model_to_executorch( + model, inputs, DeviceAwarePartitioner("cuda:0") + ): + with self.subTest(pipeline=pipeline): + h2d_nodes = [] + d2h_nodes = [] + + for node in gm.graph.nodes: + if node.op != "call_function": + continue + if node.target == torch.ops.et_copy._h2d_copy.out: + h2d_nodes.append(node) + elif node.target == torch.ops.et_copy._d2h_copy.out: + d2h_nodes.append(node) + + self.assertGreater( + len(h2d_nodes), + 0, + f"[{pipeline}] H2D copy nodes must survive to_executorch", + ) + self.assertGreater( + len(d2h_nodes), + 0, + f"[{pipeline}] D2H copy nodes must survive to_executorch", + ) + + for h2d in h2d_nodes: + spec = h2d.meta.get("spec") + self.assertIsNotNone( + spec, + f"[{pipeline}] H2D node '{h2d.name}' missing spec", + ) + if isinstance(spec, TensorSpec): + self.assertEqual( + spec.device, + DeviceType.CUDA, + f"[{pipeline}] H2D output '{h2d.name}' should be " + f"on CUDA, got {spec.device.name}", + ) + self.assertEqual(spec.device_index, 0) + + for d2h in d2h_nodes: + spec = d2h.meta.get("spec") + self.assertIsNotNone( + spec, + f"[{pipeline}] D2H node '{d2h.name}' missing spec", + ) + if isinstance(spec, TensorSpec): + self.assertEqual( + spec.device, + 
DeviceType.CPU, + f"[{pipeline}] D2H output '{d2h.name}' should be " + f"on CPU, got {spec.device.name}", + ) + + def test_no_copy_nodes_without_device(self): + """When the partitioner has no target_device CompileSpec, no H2D/D2H + copy nodes should be inserted in the final graph.""" + + class Model(torch.nn.Module): + def forward(self, a, b): + return torch.add(a, b) + + model = Model() + inputs = (torch.randn(2, 2), torch.randn(2, 2)) + + for pipeline, gm in _lower_model_to_executorch( + model, inputs, CpuOnlyPartitioner() + ): + with self.subTest(pipeline=pipeline): + for node in gm.graph.nodes: + if node.op != "call_function": + continue + self.assertNotEqual( + node.target, + torch.ops.et_copy._h2d_copy.out, + f"[{pipeline}] Unexpected H2D copy node '{node.name}' " + f"when no target_device is set", + ) + self.assertNotEqual( + node.target, + torch.ops.et_copy._d2h_copy.out, + f"[{pipeline}] Unexpected D2H copy node '{node.name}' " + f"when no target_device is set", + ) + + # ---- Integration tests: device consistency after to_executorch ---- + + def test_device_consistency_cuda_1(self): """Verify device tags are correct with cuda:1 after to_executorch() to verify device_index propagation through the full pipeline.""" @@ -251,7 +425,20 @@ def forward(self, a, b): continue label = f"[{pipeline}] '{node.name}'" - if node.target == executorch_call_delegate: + if node.target == torch.ops.et_copy._h2d_copy.out: + self._assert_specs_device( + specs, + DeviceType.CUDA, + f"{label} H2D output should be CUDA", + expected_index=0, + ) + elif node.target == torch.ops.et_copy._d2h_copy.out: + self._assert_specs_device( + specs, + DeviceType.CPU, + f"{label} D2H output should be CPU", + ) + elif node.target == executorch_call_delegate: self._assert_specs_device( specs, DeviceType.CUDA,