From 2471426bc2d9b9365a92e35625220826263cbe0f Mon Sep 17 00:00:00 2001 From: javierdejesusda Date: Sat, 20 Jun 2026 22:46:20 +0200 Subject: [PATCH] [Relax][PyTorch] Add atan2 converter torch.atan2 was not registered in either the ExportedProgram or FX frontend, so importing a model that uses it failed with an "Unsupported function types" error. The relax.op.atan2 operator already exists and legalizes to topi.atan2, so the frontends only needed to route the op to it. Register atan2 in the FX frontend and atan2.default in the ExportedProgram frontend, reusing the shared _binary_op helper (the same pattern as the existing maximum/minimum/logaddexp converters), and add a structural test in both test_frontend_from_fx.py and test_frontend_from_exported_program.py. --- .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 2 ++ .../test_frontend_from_exported_program.py | 26 +++++++++++++++++++ tests/python/relax/test_frontend_from_fx.py | 21 +++++++++++++++ 4 files changed, 50 insertions(+) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 6c9e3e3f5ef5..b96316adeef3 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -1603,6 +1603,7 @@ def create_convert_map( "add.Tensor": self._binary_op(relax.op.add, operator.add), "add.Scalar": self._binary_op(relax.op.add, operator.add), "add_.Tensor": self._binary_op(relax.op.add, operator.add), + "atan2.default": self._binary_op(relax.op.atan2, torch.atan2), "bitwise_and.Tensor": self._binary_op(relax.op.bitwise_and, operator.and_), "bitwise_and.Scalar": self._binary_op(relax.op.bitwise_and, operator.and_), "bitwise_or_.Scalar": self._binary_op(relax.op.bitwise_or, operator.or_), diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 66d17a58283b..4932871bad6a 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -791,6 +791,7 @@ def create_convert_map( ) -> dict[torch.nn.Module | str, Callable[[fx.Node], relax.Var]]: import operator + import torch # type: ignore from torch import nn return { @@ -909,6 +910,7 @@ def create_convert_map( # binary "add": self._binary_op(relax.op.add, operator.add), "and_": self._binary_op(relax.op.bitwise_and, operator.and_), + "atan2": self._binary_op(relax.op.atan2, torch.atan2), "bitwise_or_": self._binary_op_inplace(relax.op.bitwise_or, operator.or_), "bitwise_or": self._binary_op(relax.op.bitwise_or, operator.or_), "div": self._div, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index ee2f4a8f8df6..dac0bd1e2a83 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1016,6 +1016,32 @@ def main( verify_model(LogAddExp(), example_args, {}, expected) +def test_atan2(): + class Atan2(Module): + def forward(self, lhs, rhs): + return torch.atan2(lhs, rhs) + + @tvm.script.ir_module + class expected: + @R.function + def main( + lhs: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atan2(lhs, rhs) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = ( + torch.randn(1, 3, 10, 10, dtype=torch.float32), + torch.randn(1, 3, 10, 10, dtype=torch.float32), + ) + verify_model(Atan2(), example_args, {}, expected) + + def test_logical_and(): class LogicalAnd(Module): def forward(self, lhs, rhs): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 34da69d5f061..bcb9252b89ae 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -5335,6 +5335,27 @@ def main( verify_model(Min(), [([256, 256], "float32"), ([256, 256], "float32")], {}, Expected1) +def test_atan2(): + class Atan2(Module): + def forward(self, x, y): + return torch.atan2(x, y) + + @I.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32"), + inp_1: R.Tensor((256, 256), dtype="float32"), + ) -> R.Tensor((256, 256), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((256, 256), dtype="float32") = R.atan2(inp_0, inp_1) + gv: R.Tensor((256, 256), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Atan2(), [([256, 256], "float32"), ([256, 256], "float32")], {}, Expected1) + + def test_attention(): @I.ir_module class Expected1: