python/tvm/relax/frontend/torch/base_fx_graph_translator.py (12 changes: 11 additions & 1 deletion)
@@ -1857,7 +1857,17 @@ def _index_put(self, node: fx.Node) -> relax.Var:
             indices = relax.Tuple(processed_indices)
         else:
             indices = relax.Tuple(indices)
-        return self.block_builder.emit(relax.op.index_put(tensor, indices, values, accumulate))
+
+        output = self.block_builder.emit(relax.op.index_put(tensor, indices, values, accumulate))
+
+        target_name = node.target if isinstance(node.target, str) else getattr(node.target, "__name__", "")
+        if target_name.startswith("index_put_") and len(node.args) > 0:
+            from torch import fx
+
+            if isinstance(node.args[0], fx.Node):
+                self.env[node.args[0]] = output
+
+        return output
 
     def _index_tensor(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
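Note on the hunk above: `index_put_` is PyTorch's in-place variant, so FX consumers of the mutated input must observe the written values. The new code therefore re-binds the input node in `self.env` to the freshly emitted functional `index_put` result. A minimal sketch of the in-place vs. functional semantics being bridged (plain PyTorch, no TVM involved):

```python
import torch

x = torch.zeros(5)
indices = (torch.tensor([0, 2]),)
values = torch.tensor([1.0, 2.0])

# Functional form: returns a new tensor and leaves x untouched.
y = x.index_put(indices, values)
assert x.sum().item() == 0.0 and y.sum().item() == 3.0

# In-place form: mutates x, so every later use of x observes the write.
# Relax is purely functional, which is why the importer must re-point
# the FX node for x at the new index_put output instead of mutating.
x.index_put_(indices, values)
assert x.sum().item() == 3.0
```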
python/tvm/relax/frontend/torch/exported_program_translator.py (37 changes: 35 additions & 2 deletions)
@@ -1338,7 +1338,40 @@ def _translate_fx_graph(
                 raise ValueError(f"Unsupported op {node.op}")
 
         assert output_args is not None
-        return output_args
+        return self._flatten_output_args(output_args)
+
+    @staticmethod
+    def _flatten_output_args(output_args) -> tuple[relax.Expr, ...]:
+        """Flatten output args into a tuple of Relax expressions.
+
+        ExportedProgram output trees contain nested Python tuple/list containers
+        (e.g. mutation outputs + user tuple outputs). Emitting nested Python tuples
+        directly through FFI may construct invalid Relax tuples.
+        """
+
+        flattened: list[relax.Expr] = []
+
+        def _visit(value):
+            if isinstance(value, relax.Expr):
+                flattened.append(value)
+            elif isinstance(value, list | tuple):
+                for item in value:
+                    _visit(item)
+            elif value is None:
+                # Preserve explicit None outputs as Relax null objects.
+                flattened.append(relax.op.null_value())
+            else:
+                raise ValueError(
+                    "Unsupported output type in exported graph output: "
+                    f"{type(value)}"
+                )
+
+        _visit(output_args)
+
+        if not flattened:
+            raise ValueError("Exported graph produced no Relax outputs")
+
+        return tuple(flattened)
 
     def _import_branch_subgraph(
         self,
@@ -1995,7 +2028,7 @@ def from_exported_program(
         output_args = self._translate_fx_graph(
             exported_program.graph_module, nodes, inputs_vars, custom_ops
         )
-        assert isinstance(output_args, tuple | relax.Tuple)
+        output_args = self._flatten_output_args(output_args)
 
         if unwrap_unit_return_tuple and len(output_args) == 1:
             ret = output_args[0]
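To make the flattening concrete: torch.export can emit the graph outputs as a nested container (buffer-mutation outputs alongside the user's own output tuple), and handing that nested Python tuple straight to Relax Tuple construction over FFI is what triggered the crash this PR fixes. A rough pure-Python analogue of `_flatten_output_args` (the tree shape below is an illustrative assumption, not the exact ExportedProgram layout):

```python
# Illustrative output tree: one mutated buffer plus a two-element user tuple.
tree = (("l_mutated",), ("user_out_0", "user_out_1"))

def flatten(value, out=None):
    """Depth-first flattening, mirroring _visit in _flatten_output_args."""
    out = [] if out is None else out
    if isinstance(value, (list, tuple)):
        for item in value:
            flatten(item, out)
    elif value is not None:
        out.append(value)  # the real code appends relax.Expr leaves here
    else:
        out.append("null")  # stands in for relax.op.null_value()
    return tuple(out)

assert flatten(tree) == ("l_mutated", "user_out_0", "user_out_1")
```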
tests/python/relax/test_frontend_from_exported_program.py (63 changes: 63 additions & 0 deletions)
@@ -7402,6 +7402,69 @@ def main(x: R.Tensor((2, 10), dtype="float32")) -> R.Tuple(
     verify_model(IndexPutBatchedWithNone(), example_args_batched_none, {}, ExpectedBatchedWithNone)
 
 
+def test_index_put_with_tuple_output():
+    class IndexPutTupleOutput(Module):
+        def forward(self, x, l, idx):
+            values = x
+            l[..., idx, idx] = values
+            return x[..., 1], l
+
+    example_args = (
+        torch.ones(2, 3, 5, dtype=torch.float32),
+        torch.zeros(2, 3, 5, 5, dtype=torch.float32),
+        torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64),
+    )
+
+    exported_program = export(IndexPutTupleOutput(), args=example_args)
+    mod = from_exported_program(exported_program)
+
+    ret_sinfo = mod["main"].ret_struct_info
+    assert isinstance(ret_sinfo, relax.TupleStructInfo)
+
+    tensor_fields = [f for f in ret_sinfo.fields if isinstance(f, relax.TensorStructInfo)]
+    assert len(tensor_fields) >= 2
+
+    assert any(
+        len(f.shape) == 4 and int(f.shape[-2]) == 5 and int(f.shape[-1]) == 5
+        for f in tensor_fields
+    )
+
+
+def test_m4d_diag_index_put_tuple_output_regression():
+    class M4D(Module):
+        def forward(self, x):
+            b, k, n = 2, 3, 5
+            l = x.new_zeros(b, k, n, n)
+            idx = torch.arange(n, device=x.device)
+
+            diag = l[..., idx, idx]
+            diag = torch.nn.functional.elu(diag) + 1.0 + 1e-8
+            l[..., idx, idx] = diag
+
+            return x[..., :1], l
+
+    ex_in = torch.zeros(2, 3, 5, dtype=torch.float32)
+    exported_program = export(M4D().eval(), args=(ex_in,))
+
+    exported_targets = [str(getattr(n, "target", "")) for n in exported_program.graph.nodes]
+    assert any("index_put" in target for target in exported_targets)
+
+    # Regression focus: importing this graph should not segfault at Tuple construction.
+    mod = from_exported_program(exported_program)
+    ret_sinfo = mod["main"].ret_struct_info
+    assert isinstance(ret_sinfo, relax.TupleStructInfo)
+
+    tensor_fields = [f for f in ret_sinfo.fields if isinstance(f, relax.TensorStructInfo)]
+    assert len(tensor_fields) >= 2
+    # x: (2, 3, 5) → x[..., :1]: (2, 3, 1)
+    assert any(len(f.shape) == 3 and int(f.shape[-1]) == 1 for f in tensor_fields)
+    # l: (2, 3, 5, 5) → 4-D with spatial dims 5×5
+    assert any(
+        len(f.shape) == 4 and int(f.shape[-2]) == 5 and int(f.shape[-1]) == 5
+        for f in tensor_fields
+    )
+
+
 def test_flip():
     class Flip0(Module):
         def forward(self, data):