diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 6c9e3e3f5ef5..b96316adeef3 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -1603,6 +1603,7 @@ def create_convert_map( "add.Tensor": self._binary_op(relax.op.add, operator.add), "add.Scalar": self._binary_op(relax.op.add, operator.add), "add_.Tensor": self._binary_op(relax.op.add, operator.add), + "atan2.default": self._binary_op(relax.op.atan2, torch.atan2), "bitwise_and.Tensor": self._binary_op(relax.op.bitwise_and, operator.and_), "bitwise_and.Scalar": self._binary_op(relax.op.bitwise_and, operator.and_), "bitwise_or_.Scalar": self._binary_op(relax.op.bitwise_or, operator.or_), diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 66d17a58283b..4932871bad6a 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -791,6 +791,7 @@ def create_convert_map( ) -> dict[torch.nn.Module | str, Callable[[fx.Node], relax.Var]]: import operator + import torch # type: ignore from torch import nn return { @@ -909,6 +910,7 @@ def create_convert_map( # binary "add": self._binary_op(relax.op.add, operator.add), "and_": self._binary_op(relax.op.bitwise_and, operator.and_), + "atan2": self._binary_op(relax.op.atan2, torch.atan2), "bitwise_or_": self._binary_op_inplace(relax.op.bitwise_or, operator.or_), "bitwise_or": self._binary_op(relax.op.bitwise_or, operator.or_), "div": self._div, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index ee2f4a8f8df6..dac0bd1e2a83 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1016,6 +1016,32 @@ def main( verify_model(LogAddExp(), example_args, {}, expected) +def test_atan2(): + class Atan2(Module): + def forward(self, lhs, rhs): + return torch.atan2(lhs, rhs) + + @tvm.script.ir_module + class expected: + @R.function + def main( + lhs: R.Tensor((1, 3, 10, 10), dtype="float32"), + rhs: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.atan2(lhs, rhs) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = ( + torch.randn(1, 3, 10, 10, dtype=torch.float32), + torch.randn(1, 3, 10, 10, dtype=torch.float32), + ) + verify_model(Atan2(), example_args, {}, expected) + + def test_logical_and(): class LogicalAnd(Module): def forward(self, lhs, rhs): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 34da69d5f061..bcb9252b89ae 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -5335,6 +5335,27 @@ def main( verify_model(Min(), [([256, 256], "float32"), ([256, 256], "float32")], {}, Expected1) +def test_atan2(): + class Atan2(Module): + def forward(self, x, y): + return torch.atan2(x, y) + + @I.ir_module + class Expected1: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32"), + inp_1: R.Tensor((256, 256), dtype="float32"), + ) -> R.Tensor((256, 256), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((256, 256), dtype="float32") = R.atan2(inp_0, inp_1) + gv: R.Tensor((256, 256), dtype="float32") = lv + R.output(gv) + return gv + + verify_model(Atan2(), [([256, 256], "float32"), ([256, 256], "float32")], {}, Expected1) + + def test_attention(): @I.ir_module class Expected1: