From 584861f0d482341e937b1ad96a2f2c965c33aacb Mon Sep 17 00:00:00 2001 From: cchung100m Date: Fri, 1 May 2026 14:39:10 +0800 Subject: [PATCH 01/10] [Relax][PyTorch] Fix segfault in from_exported_program when model uses index_put_ with tuple output --- .../torch/base_fx_graph_translator.py | 12 ++++++- .../torch/exported_program_translator.py | 32 +++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index c146cf6c00e3..e2d5e3879ffa 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1857,7 +1857,17 @@ def _index_put(self, node: fx.Node) -> relax.Var: indices = relax.Tuple(processed_indices) else: indices = relax.Tuple(indices) - return self.block_builder.emit(relax.op.index_put(tensor, indices, values, accumulate)) + + output = self.block_builder.emit(relax.op.index_put(tensor, indices, values, accumulate)) + + target_name = getattr(node.target, "__name__", "") + if target_name.startswith("index_put_") and len(node.args) > 0: + from torch import fx + + if isinstance(node.args[0], fx.Node): + self.env[node.args[0]] = output + + return output def _index_tensor(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index cc37554bf301..3b6d03c8b9e6 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -1340,6 +1340,38 @@ def _translate_fx_graph( assert output_args is not None return output_args + @staticmethod + def _flatten_output_args(output_args) -> tuple[relax.Expr, ...]: + """Flatten output args into a tuple of Relax expressions. + + ExportedProgram output trees contain nested Python tuple/list containers + (e.g. mutation outputs + user tuple outputs). Emitting nested Python tuples + directly through FFI may construct invalid Relax tuples. + """ + + flattened: list[relax.Expr] = [] + + def _visit(value): + if isinstance(value, relax.Expr): + flattened.append(value) + elif isinstance(value, list | tuple): + for item in value: + _visit(item) + elif value is None: + return + else: + raise ValueError( + "Unsupported output type in exported graph output: " + f"{type(value)}" + ) + + _visit(output_args) + + if not flattened: + raise ValueError("Exported graph produced no Relax outputs") + + return tuple(flattened) + def _import_branch_subgraph( self, graph_module, # torch.fx.GraphModule From ccd4bbd1e8243d34cd4ffeefdc54563ed200cdd9 Mon Sep 17 00:00:00 2001 From: Neo Chien <6762509+cchung100m@users.noreply.github.com> Date: Sat, 2 May 2026 18:39:21 +0800 Subject: [PATCH 02/10] Flatten output arguments before returning them --- python/tvm/relax/frontend/torch/exported_program_translator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 3b6d03c8b9e6..b8637261dca5 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -1338,7 +1338,7 @@ def _translate_fx_graph( raise ValueError(f"Unsupported op {node.op}") assert output_args is not None - return output_args + return self._flatten_output_args(output_args) @staticmethod def _flatten_output_args(output_args) -> tuple[relax.Expr, ...]: From a0e583a61d818ac9738d6cb4facf31ed25fbe602 Mon Sep 17 00:00:00 2001 From: Neo Chien <6762509+cchung100m@users.noreply.github.com> Date: Sat, 2 May 2026 23:45:13 +0800 Subject: [PATCH 03/10] Handle None outputs in exported program translator Preserve explicit None outputs by appending Relax null objects. --- python/tvm/relax/frontend/torch/exported_program_translator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index b8637261dca5..8f7f3f55482b 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -1358,7 +1358,8 @@ def _visit(value): for item in value: _visit(item) elif value is None: - return + # Preserve explicit None outputs as Relax null objects. + flattened.append(relax.op.null_value()) else: raise ValueError( "Unsupported output type in exported graph output: " From 2548f10b3d017e99d9d7b9ba3d2d4874c7e8f48c Mon Sep 17 00:00:00 2001 From: Neo Chien <6762509+cchung100m@users.noreply.github.com> Date: Sun, 3 May 2026 09:55:05 +0800 Subject: [PATCH 04/10] Fix target_name assignment for node.target --- python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index e2d5e3879ffa..1b27da9e0b82 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1860,7 +1860,7 @@ def _index_put(self, node: fx.Node) -> relax.Var: output = self.block_builder.emit(relax.op.index_put(tensor, indices, values, accumulate)) - target_name = getattr(node.target, "__name__", "") + target_name = node.target if isinstance(node.target, str) else getattr(node.target, "__name__", "") if target_name.startswith("index_put_") and len(node.args) > 0: from torch import fx From 59224206d58b22f46ace4851fc17e727e13663df Mon Sep 17 00:00:00 2001 From: Neo Chien <6762509+cchung100m@users.noreply.github.com> Date: Sun, 3 May 2026 13:59:52 +0800 Subject: [PATCH 05/10] Flatten output_args in exported_program_translator Flatten output arguments before further processing. --- python/tvm/relax/frontend/torch/exported_program_translator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 8f7f3f55482b..5bd2c785f205 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -2028,7 +2028,7 @@ def from_exported_program( output_args = self._translate_fx_graph( exported_program.graph_module, nodes, inputs_vars, custom_ops ) - assert isinstance(output_args, tuple | relax.Tuple) + output_args = self._flatten_output_args(output_args) if unwrap_unit_return_tuple and len(output_args) == 1: ret = output_args[0] From cc13a5529a0ee42997403b93ee3b54c18052ab5c Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sun, 3 May 2026 17:49:57 +0800 Subject: [PATCH 06/10] [Relax][PyTorch] Add test case: test_index_put_with_tuple_output --- .../test_frontend_from_exported_program.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 602949937247..12da88213637 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -7402,6 +7402,32 @@ def main(x: R.Tensor((2, 10), dtype="float32")) -> R.Tuple( verify_model(IndexPutBatchedWithNone(), example_args_batched_none, {}, ExpectedBatchedWithNone) +def test_index_put_with_tuple_output(): + class IndexPutTupleOutput(Module): + def forward(self, x, l, idx): + values = x[..., :1] + l[..., idx, idx] = values + return x[..., 1], l + + example_args = ( + torch.ones(2, 3, 11, 11, dtype=torch.float32), + torch.zeros(2, 3, 11, 11, dtype=torch.float32), + torch.tensor([0, 2, 5, 7, 9], dtype=torch.int64), + ) + + exported_program = export(IndexPutTupleOutput(), args=example_args) + mod = from_exported_program(exported_program) + + ret_sinfo = mod["main"].ret_struct_info + assert isinstance(ret_sinfo, relax.TupleStructInfo) + + tensor_fields = [f for f in ret_sinfo.fields if isinstance(f, relax.TensorStructInfo)] + assert len(tensor_fields) >= 2 + + assert any(len(f.shape) == 4 and f.shape[-1] == 1 for f in tensor_fields) + assert any(len(f.shape) == 4 and f.shape[-2] == 11 and f.shape[-1] == 11 for f in tensor_fields) + + def test_flip(): class Flip0(Module): def forward(self, data): From a355ee3a68c11423de1dd6ec2f59f6b53fc40861 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sun, 3 May 2026 19:21:53 +0800 Subject: [PATCH 07/10] [Relax][PyTorch] Fix test case: test_index_put_with_tuple_output --- tests/python/relax/test_frontend_from_exported_program.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 12da88213637..2351db67aacb 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -7405,12 +7405,12 @@ def main(x: R.Tensor((2, 10), dtype="float32")) -> R.Tuple( def test_index_put_with_tuple_output(): class IndexPutTupleOutput(Module): def forward(self, x, l, idx): - values = x[..., :1] + values = x l[..., idx, idx] = values return x[..., 1], l example_args = ( - torch.ones(2, 3, 11, 11, dtype=torch.float32), + torch.ones(2, 3, 5, dtype=torch.float32), torch.zeros(2, 3, 11, 11, dtype=torch.float32), torch.tensor([0, 2, 5, 7, 9], dtype=torch.int64), ) @@ -7424,7 +7424,7 @@ def forward(self, x, l, idx): tensor_fields = [f for f in ret_sinfo.fields if isinstance(f, relax.TensorStructInfo)] assert len(tensor_fields) >= 2 - assert any(len(f.shape) == 4 and f.shape[-1] == 1 for f in tensor_fields) + assert any(len(f.shape) == 3 and f.shape[-1] == 1 for f in tensor_fields) assert any(len(f.shape) == 4 and f.shape[-2] == 11 and f.shape[-1] == 11 for f in tensor_fields) From 573e793cbf924efaa67e14a709608e06df8a1af3 Mon Sep 17 00:00:00 2001 From: cchung100m Date: Sun, 3 May 2026 20:42:41 +0800 Subject: [PATCH 08/10] [Relax][PyTorch] Fix test case: test_index_put_with_tuple_output --- .../relax/test_frontend_from_exported_program.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 2351db67aacb..51075b9ef73a 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -7411,8 +7411,8 @@ def forward(self, x, l, idx): example_args = ( torch.ones(2, 3, 5, dtype=torch.float32), - torch.zeros(2, 3, 11, 11, dtype=torch.float32), - torch.tensor([0, 2, 5, 7, 9], dtype=torch.int64), + torch.zeros(2, 3, 5, 5, dtype=torch.float32), + torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64), ) exported_program = export(IndexPutTupleOutput(), args=example_args) @@ -7424,8 +7424,11 @@ def forward(self, x, l, idx): tensor_fields = [f for f in ret_sinfo.fields if isinstance(f, relax.TensorStructInfo)] assert len(tensor_fields) >= 2 - assert any(len(f.shape) == 3 and f.shape[-1] == 1 for f in tensor_fields) - assert any(len(f.shape) == 4 and f.shape[-2] == 11 and f.shape[-1] == 11 for f in tensor_fields) + assert any(len(f.shape) == 3 and int(f.shape[-1]) == 1 for f in tensor_fields) + assert any( + len(f.shape) == 4 and int(f.shape[-2]) == 5 and int(f.shape[-1]) == 5 + for f in tensor_fields + ) def test_flip(): From dda02c2aad2cbf4cc57d02a569ac4a05bd79536f Mon Sep 17 00:00:00 2001 From: Neo Chien <6762509+cchung100m@users.noreply.github.com> Date: Mon, 4 May 2026 10:43:21 +0800 Subject: [PATCH 09/10] Remove 3D tensor shape assertion from test Remove assertion for 3D tensor shape in frontend test. --- tests/python/relax/test_frontend_from_exported_program.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 51075b9ef73a..dfd1dbc7eb5c 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -7424,7 +7424,6 @@ def forward(self, x, l, idx): tensor_fields = [f for f in ret_sinfo.fields if isinstance(f, relax.TensorStructInfo)] assert len(tensor_fields) >= 2 - assert any(len(f.shape) == 3 and int(f.shape[-1]) == 1 for f in tensor_fields) assert any( len(f.shape) == 4 and int(f.shape[-2]) == 5 and int(f.shape[-1]) == 5 for f in tensor_fields From 5fd8971212f41423f6e7c4ea6fb2bb85ca0642a4 Mon Sep 17 00:00:00 2001 From: Neo Chien <6762509+cchung100m@users.noreply.github.com> Date: Mon, 4 May 2026 12:15:12 +0800 Subject: [PATCH 10/10] Implement regression test for M4D index_put Add regression test for M4D module's index_put behavior. --- .../test_frontend_from_exported_program.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index dfd1dbc7eb5c..5643f5541fbe 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -7430,6 +7430,41 @@ def forward(self, x, l, idx): ) +def test_m4d_diag_index_put_tuple_output_regression(): + class M4D(Module): + def forward(self, x): + b, k, n = 2, 3, 5 + l = x.new_zeros(b, k, n, n) + idx = torch.arange(n, device=x.device) + + diag = l[..., idx, idx] + diag = torch.nn.functional.elu(diag) + 1.0 + 1e-8 + l[..., idx, idx] = diag + + return x[..., :1], l + + ex_in = torch.zeros(2, 3, 5, dtype=torch.float32) + exported_program = export(M4D().eval(), args=(ex_in,)) + + exported_targets = [str(getattr(n, "target", "")) for n in exported_program.graph.nodes] + assert any("index_put" in target for target in exported_targets) + + # Regression focus: importing this graph should not segfault at Tuple construction. + mod = from_exported_program(exported_program) + ret_sinfo = mod["main"].ret_struct_info + assert isinstance(ret_sinfo, relax.TupleStructInfo) + + tensor_fields = [f for f in ret_sinfo.fields if isinstance(f, relax.TensorStructInfo)] + assert len(tensor_fields) >= 2 + # x: (2, 3, 5) → x[..., :1]: (2, 3, 1) + assert any(len(f.shape) == 3 and int(f.shape[-1]) == 1 for f in tensor_fields) + # l: (2, 3, 5, 5) → 4-D with spatial dims 5×5 + assert any( + len(f.shape) == 4 and int(f.shape[-2]) == 5 and int(f.shape[-1]) == 5 + for f in tensor_fields + ) + + def test_flip(): class Flip0(Module): def forward(self, data):