diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index c146cf6c00e3..0d92576c5911 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1802,11 +1802,15 @@ def _flatten(self, node: fx.Node) -> relax.Var: def _flip(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] dims = node.args[1] if len(node.args) > 1 else node.kwargs.get("dims", None) - if isinstance(dims, list | tuple) and len(dims) > 0: - dims = dims[0] - elif not isinstance(dims, int): - raise TypeError(f"flip expects an integer axis, but got {type(dims)}: {dims}") - return self.block_builder.emit(relax.op.flip(x, dims)) + if isinstance(dims, int): + dims = [dims] + elif not isinstance(dims, list | tuple): + raise TypeError(f"flip expects an int or list of ints, but got {type(dims)}: {dims}") + # relax.op.flip is single-axis; iterate to honor multi-axis torch.flip semantics. + out = x + for d in dims: + out = self.block_builder.emit(relax.op.flip(out, d)) + return out def _gather(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 602949937247..d5ed2aca7c49 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -7441,6 +7441,47 @@ def main( verify_model(Flip1(), example_args, {}, Expected1) +def test_flip_multi_axis(): + class FlipMulti(Module): + def forward(self, data): + return torch.flip(data, [0, 1]) + + class FlipNegMulti(Module): + def forward(self, data): + return torch.flip(data, dims=[-1, -2]) + + @tvm.script.ir_module + class ExpectedMulti: + @R.function + def main( + inp_0: R.Tensor((2, 3), dtype="float32"), + ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((2, 3), dtype="float32") = R.flip(inp_0, axis=0) + lv1: R.Tensor((2, 3), dtype="float32") = R.flip(lv, axis=1) + gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + @tvm.script.ir_module + class ExpectedNegMulti: + @R.function + def main( + inp_0: R.Tensor((2, 3), dtype="float32"), + ) -> R.Tuple(R.Tensor((2, 3), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((2, 3), dtype="float32") = R.flip(inp_0, axis=-1) + lv1: R.Tensor((2, 3), dtype="float32") = R.flip(lv, axis=-2) + gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + example_args = (torch.randn(2, 3, dtype=torch.float32),) + + verify_model(FlipMulti(), example_args, {}, ExpectedMulti) + verify_model(FlipNegMulti(), example_args, {}, ExpectedNegMulti) + + def test_take(): class Take(Module): def forward(self, data, indices): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 4d9060bf720e..890c6ef3a1ff 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -5862,6 +5862,27 @@ def main( verify_model(Flip1(), [([2, 2], "float32")], {}, Expected1) +def test_flip_multi_axis(): + class FlipMulti(Module): + def forward(self, data): + return torch.flip(data, [0, 1]) + + @tvm.script.ir_module + class ExpectedMulti: + @R.function + def main( + inp_0: R.Tensor((2, 3), dtype="float32"), + ) -> R.Tensor((2, 3), dtype="float32"): + with R.dataflow(): + lv: R.Tensor((2, 3), dtype="float32") = R.flip(inp_0, axis=0) + lv1: R.Tensor((2, 3), dtype="float32") = R.flip(lv, axis=1) + gv: R.Tensor((2, 3), dtype="float32") = lv1 + R.output(gv) + return gv + + verify_model(FlipMulti(), [([2, 3], "float32")], {}, ExpectedMulti) + + def test_take(): class Take(Module): def forward(self, data, indices):