diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index c146cf6c00e3..1b27da9e0b82 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1857,7 +1857,19 @@ def _index_put(self, node: fx.Node) -> relax.Var:
             indices = relax.Tuple(processed_indices)
         else:
             indices = relax.Tuple(indices)
-        return self.block_builder.emit(relax.op.index_put(tensor, indices, values, accumulate))
+
+        output = self.block_builder.emit(relax.op.index_put(tensor, indices, values, accumulate))
+
+        target_name = (
+            node.target if isinstance(node.target, str) else getattr(node.target, "__name__", "")
+        )
+        if target_name.startswith("index_put_") and len(node.args) > 0:
+            from torch import fx
+
+            if isinstance(node.args[0], fx.Node):
+                self.env[node.args[0]] = output
+
+        return output

     def _index_tensor(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
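Note (not part of the patch): the env remap in the hunk above is what makes an
in-place aten.index_put_ visible to later uses of the same FX node. A minimal
sketch of the aliasing pattern it covers, using a hypothetical repro module:

    import torch
    from torch.export import export

    class InPlacePut(torch.nn.Module):  # hypothetical repro module
        def forward(self, x, idx):
            x[idx] = 1.0    # lowers to in-place aten.index_put_
            return x * 2.0  # this read must observe the mutation

    ep = export(InPlacePut(), (torch.zeros(5), torch.tensor([0, 2])))
    # During import, `self.env[node.args[0]] = output` makes `x * 2.0`
    # translate against the index_put result rather than the pre-mutation var.
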
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index cc37554bf301..5bd2c785f205 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1338,7 +1338,40 @@ def _translate_fx_graph(
                 raise ValueError(f"Unsupported op {node.op}")

        assert output_args is not None
-        return output_args
+        return self._flatten_output_args(output_args)
+
+    @staticmethod
+    def _flatten_output_args(output_args) -> tuple[relax.Expr, ...]:
+        """Flatten output args into a tuple of Relax expressions.
+
+        ExportedProgram output trees contain nested Python tuple/list containers
+        (e.g. mutation outputs + user tuple outputs). Emitting nested Python tuples
+        directly through FFI may construct invalid Relax tuples.
+        """
+
+        flattened: list[relax.Expr] = []
+
+        def _visit(value):
+            if isinstance(value, relax.Expr):
+                flattened.append(value)
+            elif isinstance(value, list | tuple):
+                for item in value:
+                    _visit(item)
+            elif value is None:
+                # Preserve explicit None outputs as Relax null objects.
+                flattened.append(relax.op.null_value())
+            else:
+                raise ValueError(
+                    "Unsupported output type in exported graph output: "
+                    f"{type(value)}"
+                )
+
+        _visit(output_args)
+
+        if not flattened:
+            raise ValueError("Exported graph produced no Relax outputs")
+
+        return tuple(flattened)

    def _import_branch_subgraph(
        self,
@@ -1995,7 +2028,7 @@ def from_exported_program(
        output_args = self._translate_fx_graph(
            exported_program.graph_module, nodes, inputs_vars, custom_ops
        )
-        assert isinstance(output_args, tuple | relax.Tuple)
+        output_args = self._flatten_output_args(output_args)

        if unwrap_unit_return_tuple and len(output_args) == 1:
            ret = output_args[0]
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 602949937247..5643f5541fbe 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7402,6 +7402,69 @@ def main(x: R.Tensor((2, 10), dtype="float32")) -> R.Tuple(
    verify_model(IndexPutBatchedWithNone(), example_args_batched_none, {}, ExpectedBatchedWithNone)


+def test_index_put_with_tuple_output():
+    class IndexPutTupleOutput(Module):
+        def forward(self, x, l, idx):
+            values = x
+            l[..., idx, idx] = values
+            return x[..., 1], l
+
+    example_args = (
+        torch.ones(2, 3, 5, dtype=torch.float32),
+        torch.zeros(2, 3, 5, 5, dtype=torch.float32),
+        torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64),
+    )
+
+    exported_program = export(IndexPutTupleOutput(), args=example_args)
+    mod = from_exported_program(exported_program)
+
+    ret_sinfo = mod["main"].ret_struct_info
+    assert isinstance(ret_sinfo, relax.TupleStructInfo)
+
+    tensor_fields = [f for f in ret_sinfo.fields if isinstance(f, relax.TensorStructInfo)]
+    assert len(tensor_fields) >= 2
+
+    assert any(
+        len(f.shape) == 4 and int(f.shape[-2]) == 5 and int(f.shape[-1]) == 5
+        for f in tensor_fields
+    )
+
+
+def test_m4d_diag_index_put_tuple_output_regression():
+    class M4D(Module):
+        def forward(self, x):
+            b, k, n = 2, 3, 5
+            l = x.new_zeros(b, k, n, n)
+            idx = torch.arange(n, device=x.device)
+
+            diag = l[..., idx, idx]
+            diag = torch.nn.functional.elu(diag) + 1.0 + 1e-8
+            l[..., idx, idx] = diag
+
+            return x[..., :1], l
+
+    ex_in = torch.zeros(2, 3, 5, dtype=torch.float32)
+    exported_program = export(M4D().eval(), args=(ex_in,))
+
+    exported_targets = [str(getattr(n, "target", "")) for n in exported_program.graph.nodes]
+    assert any("index_put" in target for target in exported_targets)
+
+    # Regression focus: importing this graph should not segfault at Tuple construction.
+    mod = from_exported_program(exported_program)
+    ret_sinfo = mod["main"].ret_struct_info
+    assert isinstance(ret_sinfo, relax.TupleStructInfo)
+
+    tensor_fields = [f for f in ret_sinfo.fields if isinstance(f, relax.TensorStructInfo)]
+    assert len(tensor_fields) >= 2
+    # x: (2, 3, 5) → x[..., :1]: (2, 3, 1)
+    assert any(len(f.shape) == 3 and int(f.shape[-1]) == 1 for f in tensor_fields)
+    # l: (2, 3, 5, 5) → 4-D with spatial dims 5×5
+    assert any(
+        len(f.shape) == 4 and int(f.shape[-2]) == 5 and int(f.shape[-1]) == 5
+        for f in tensor_fields
+    )
+
+
 def test_flip():
     class Flip0(Module):
         def forward(self, data):
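
Note (not part of the patch): a standalone sketch of the flattening that
_flatten_output_args performs on a nested output tree. It mirrors the patched
helper rather than importing it, and all variable names are illustrative:

    from tvm import relax

    def flatten(value, out):
        # Walk nested tuples/lists, keep Relax exprs, and map explicit None
        # outputs to relax.op.null_value(), as the helper above does.
        if isinstance(value, relax.Expr):
            out.append(value)
        elif isinstance(value, (list, tuple)):
            for item in value:
                flatten(item, out)
        elif value is None:
            out.append(relax.op.null_value())

    a = relax.Var("a", relax.TensorStructInfo([2, 3], "float32"))
    flat = []
    flatten(((a,), [None]), flat)  # e.g. mutation outputs + user outputs
    assert len(flat) == 2          # -> [a, null_value()]; safe for relax.Tuple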