diff --git a/coreai_torch/_aten_to_core.py b/coreai_torch/_aten_to_core.py
index 11e9530..c9e6f0b 100644
--- a/coreai_torch/_aten_to_core.py
+++ b/coreai_torch/_aten_to_core.py
@@ -2245,11 +2245,45 @@ def replace_remainder(
 
 def replace_repeat(values_map: dict[str, Value], node: fx.Node, loc: Location) -> Value:
+    """Lower ``aten.repeat`` to ``coreai.tile``.
+
+    ``node.args[1]`` holds the repeat counts. Each entry is either a plain
+    ``int`` or an ``fx.Node`` (a SymInt) whose lowered value is looked up in
+    ``values_map``. If there are more counts than ``x`` has dimensions, ``x``
+    is first expanded with new axes at positions ``0..extra_dims - 1``.
+
+    Raises:
+        TypeError: If a repeat count is neither an ``int`` nor an ``fx.Node``.
+    """
     x = _get_operand(values_map, node, 0)
-    repeats = np.array(node.args[1], dtype=np.uint32)
-    extra_dims = len(repeats) - x.type.rank
+    repeat_args = list(node.args[1])
+    extra_dims = len(repeat_args) - x.type.rank
     if extra_dims > 0:
         x = coreai.expand_dims(x, list(range(extra_dims)))
-    return coreai.tile(x, repeats)
+
+    if all(isinstance(r, int) for r in repeat_args):
+        return coreai.tile(x, np.array(repeat_args, dtype=np.uint32))
+
+    # At least one repeat is a SymInt fx.Node — build a rank-1 uint32 dim
+    # vector at runtime, with per-axis constants for plain ints and the
+    # resolved Value (cast to uint32, lifted to rank-1 if scalar) for
+    # SymInts. coreai.tile accepts a runtime Value for its dims.
+    chunks: list[Value] = []
+    for r in repeat_args:
+        if isinstance(r, int):
+            chunks.append(coreai.constant([r], dtype=np.uint32))
+        elif isinstance(r, fx.Node):
+            v = coreai.cast(values_map[r.name], dtype=np.uint32)
+            if v.type.rank == 0:
+                v = coreai.reshape(v, [1])
+            chunks.append(v)
+        else:
+            # Explicit raise instead of ``assert``: asserts are stripped
+            # under ``python -O``.
+            raise TypeError(
+                "aten.repeat: repeat counts must be int or fx.Node, "
+                f"got {type(r).__name__}"
+            )
+    return coreai.tile(x, coreai.concat(0, chunks))
 
 
 def replace_round_decimals(
diff --git a/tests/ops/test_ops.py b/tests/ops/test_ops.py
index 4a15547..9771619 100644
--- a/tests/ops/test_ops.py
+++ b/tests/ops/test_ops.py
@@ -3132,6 +3132,61 @@ def forward(self, x: Tensor) -> Tensor:
         await validate_numerical_output(model=model, x=x, dynamic_shapes=dynamic_shapes)
 
 
+class TestRepeat:
+    @pytest.mark.ir
+    def test_symint_arg_lowers_ir(self) -> None:
+        """``aten.repeat`` with a SymInt entry in the repeats list (i.e. a
+        ``torch.fx.Node``, not a plain int) must lower to a dynamic
+        ``coreai.tile`` whose dim vector is built at runtime."""
+
+        class RepeatModel(nn.Module):
+            def forward(self, x: Tensor, y: Tensor) -> Tensor:
+                return x.repeat(y.shape[0], 1)
+
+        x = torch.rand(2, 3)
+        y = torch.rand(4, 8)
+        batch = torch.export.Dim("batch", min=1, max=16)
+        program = torch.export.export(
+            RepeatModel(), args=(x, y), dynamic_shapes=({}, {0: batch})
+        ).run_decompositions()
+
+        coreai_program = TorchConverter().add_exported_program(program).to_coreai()
+        filecheck_pattern(
+            str(coreai_program),
+            check_file="""
+            // CHECK-LABEL: coreai.graph @main
+            // CHECK-SAME: %arg0: tensor<2x3xf32>
+            // CHECK-SAME: %arg1: tensor<?x8xf32>
+            // CHECK: %[[SHAPE:.+]] = coreai.get_shape %arg1 : tensor<?x8xf32> -> tensor<2xui32>
+            // CHECK: %[[SLICE:.+]] = coreai.slice %[[SHAPE]]
+            // CHECK-SAME: -> tensor<1xui32>
+            // CHECK: %[[ONE:.+]] = coreai.constant dense<1> : tensor<1xui32>
+            // CHECK: %[[DIMS:.+]] = coreai.concat {{.*}}, %{{.+}}, %[[ONE]]
+            // CHECK-SAME: -> tensor<2xui32>
+            // CHECK: %[[OUT:.+]] = coreai.tile %arg0, %[[DIMS]]
+            // CHECK: coreai.output %[[OUT]]
+            """,
+        )
+
+    async def test_symint_arg_numerical(self) -> None:
+        """Numerical validation: ``aten.repeat`` with a SymInt entry in the
+        repeats list must produce the same result as ``torch.Tensor.repeat``."""
+
+        class RepeatModel(nn.Module):
+            def forward(self, x: Tensor, y: Tensor) -> Tensor:
+                return x.repeat(y.shape[0], 1)
+
+        x = torch.rand(2, 3)
+        y = torch.rand(4, 8)
+        batch = torch.export.Dim("batch", min=1, max=16)
+        await validate_numerical_output(
+            model=RepeatModel().eval(),
+            x=x,
+            y=y,
+            dynamic_shapes=({}, {0: batch}),
+        )
+
+
 @pytest.mark.parametrize("dynamic", [False, True])
 @pytest.mark.parametrize(
     "x",