Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1857,7 +1857,25 @@ def _index_put(self, node: fx.Node) -> relax.Var:
indices = relax.Tuple(processed_indices)
else:
indices = relax.Tuple(indices)
return self.block_builder.emit(relax.op.index_put(tensor, indices, values, accumulate))

output = self.block_builder.emit(relax.op.index_put(tensor, indices, values, accumulate))

target_name = (
node.target if isinstance(node.target, str) else getattr(node.target, "__name__", "")
)
if target_name.startswith("index_put_") and len(node.args) > 0:
from torch import fx

if isinstance(node.args[0], fx.Node):
# `index_put_` is in-place. If the mutated input is an alias of another
# FX node, later reads via either the alias node or the original node
# must observe the updated tensor.
aliased_expr = tensor
for env_node, env_expr in list(self.env.items()):
if env_expr is aliased_expr:
self.env[env_node] = output

return output

def _index_tensor(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
Expand Down
37 changes: 35 additions & 2 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1338,7 +1338,40 @@ def _translate_fx_graph(
raise ValueError(f"Unsupported op {node.op}")

assert output_args is not None
return output_args
return self._flatten_output_args(output_args)

@staticmethod
def _flatten_output_args(output_args) -> tuple[relax.Expr, ...]:
    """Flatten output args into a flat tuple of Relax expressions.

    ExportedProgram output trees contain nested Python tuple/list containers
    (e.g. mutation outputs + user tuple outputs). Emitting nested Python tuples
    directly through FFI may construct invalid Relax tuples, so the tree is
    flattened depth-first into a single flat tuple of leaf expressions.

    Parameters
    ----------
    output_args
        A ``relax.Expr``, ``None``, or an arbitrarily nested Python
        list/tuple of the same.

    Returns
    -------
    tuple[relax.Expr, ...]
        The leaf expressions in depth-first, left-to-right order. Explicit
        ``None`` leaves are preserved as Relax null objects.

    Raises
    ------
    ValueError
        If a leaf has an unsupported type, or if no leaves were found.
    """

    flattened: list[relax.Expr] = []

    def _visit(value):
        if isinstance(value, relax.Expr):
            flattened.append(value)
        # Use a tuple of types rather than the `list | tuple` union so the
        # isinstance check also works on Python < 3.10 (PEP 604 unions in
        # isinstance are 3.10+ only).
        elif isinstance(value, (list, tuple)):
            for item in value:
                _visit(item)
        elif value is None:
            # Preserve explicit None outputs as Relax null objects.
            flattened.append(relax.op.null_value())
        else:
            raise ValueError(
                "Unsupported output type in exported graph output: "
                f"{type(value)}"
            )

    _visit(output_args)

    if not flattened:
        raise ValueError("Exported graph produced no Relax outputs")

    return tuple(flattened)

def _import_branch_subgraph(
self,
Expand Down Expand Up @@ -1995,7 +2028,7 @@ def from_exported_program(
output_args = self._translate_fx_graph(
exported_program.graph_module, nodes, inputs_vars, custom_ops
)
assert isinstance(output_args, tuple | relax.Tuple)
output_args = self._flatten_output_args(output_args)

if unwrap_unit_return_tuple and len(output_args) == 1:
ret = output_args[0]
Expand Down
108 changes: 108 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -7402,6 +7402,114 @@ def main(x: R.Tensor((2, 10), dtype="float32")) -> R.Tuple(
verify_model(IndexPutBatchedWithNone(), example_args_batched_none, {}, ExpectedBatchedWithNone)


def test_index_put_with_tuple_output():
    """Importing an index_put model with a tuple return must not crash."""

    class IndexPutTupleOutput(Module):
        def forward(self, x, l, idx):
            values = x
            l[..., idx, idx] = values
            return x[..., 1], l

    args = (
        torch.ones(2, 3, 5, dtype=torch.float32),
        torch.zeros(2, 3, 5, 5, dtype=torch.float32),
        torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64),
    )

    mod = from_exported_program(export(IndexPutTupleOutput(), args=args))

    # The imported function must return a tuple containing at least two
    # tensor fields, one of which is the mutated 4-D buffer.
    ret_info = mod["main"].ret_struct_info
    assert isinstance(ret_info, relax.TupleStructInfo)

    tensors = [field for field in ret_info.fields if isinstance(field, relax.TensorStructInfo)]
    assert len(tensors) >= 2

    has_4d_buffer = False
    for field in tensors:
        if len(field.shape) == 4 and int(field.shape[-2]) == 5 and int(field.shape[-1]) == 5:
            has_4d_buffer = True
    assert has_4d_buffer


def test_m4d_diag_index_put_tuple_output_regression():
    """Diagonal index_put plus tuple output must import without a segfault."""

    class M4D(Module):
        def forward(self, x):
            b, k, n = 2, 3, 5
            l = x.new_zeros(b, k, n, n)
            idx = torch.arange(n, device=x.device)

            diag = l[..., idx, idx]
            diag = torch.nn.functional.elu(diag) + 1.0 + 1e-8
            l[..., idx, idx] = diag

            return x[..., :1], l

    exported_program = export(M4D().eval(), args=(torch.zeros(2, 3, 5, dtype=torch.float32),))

    # Sanity-check the exported graph actually contains an index_put node.
    targets = (str(getattr(node, "target", "")) for node in exported_program.graph.nodes)
    assert any("index_put" in target for target in targets)

    # Regression focus: importing this graph should not segfault at Tuple construction.
    mod = from_exported_program(exported_program)
    ret_sinfo = mod["main"].ret_struct_info
    assert isinstance(ret_sinfo, relax.TupleStructInfo)

    tensor_fields = [f for f in ret_sinfo.fields if isinstance(f, relax.TensorStructInfo)]
    assert len(tensor_fields) >= 2
    # x: (2, 3, 5) → x[..., :1]: (2, 3, 1)
    sliced = [f for f in tensor_fields if len(f.shape) == 3 and int(f.shape[-1]) == 1]
    assert sliced
    # l: (2, 3, 5, 5) → 4-D with spatial dims 5×5
    buffers = [
        f
        for f in tensor_fields
        if len(f.shape) == 4 and int(f.shape[-2]) == 5 and int(f.shape[-1]) == 5
    ]
    assert buffers


def test_index_put_mutation_through_alias_regression():
    """In-place index_put through an alias must update both alias and original.

    `y` is an alias of `x`, so after `y[idx] = values` both `x` and `y` must
    read as the mutated tensor; the expected IR therefore returns the same
    updated value for every output field.
    """

    class IndexPutAlias(Module):
        def forward(self, x, idx, values):
            # Aliased view of x; mutating y must also be visible through x.
            y = torch.ops.aten.alias.default(x)
            y[idx] = values
            return x, y

    example_args = (
        torch.zeros(5, dtype=torch.float32),
        torch.tensor([1, 3], dtype=torch.int64),
        torch.tensor([2.0, 4.0], dtype=torch.float32),
    )

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((5,), dtype="float32"),
            idx: R.Tensor((2,), dtype="int64"),
            values: R.Tensor((2,), dtype="float32"),
        ) -> R.Tuple(
            R.Tensor((5,), dtype="float32"),
            R.Tensor((5,), dtype="float32"),
            R.Tensor((5,), dtype="float32"),
        ):
            with R.dataflow():
                lv: R.Tensor((5,), dtype="float32") = R.index_put(
                    x, (idx,), values, accumulate=False
                )
                # ExportedProgram may include an additional mutation output.
                # All three fields are the same updated tensor: the mutation
                # output, the original `x`, and the alias `y`.
                gv: R.Tuple(
                    R.Tensor((5,), dtype="float32"),
                    R.Tensor((5,), dtype="float32"),
                    R.Tensor((5,), dtype="float32"),
                ) = (
                    lv,
                    lv,
                    lv,
                )
                R.output(gv)
            return gv

    verify_model(IndexPutAlias(), example_args, {}, Expected)


def test_flip():
class Flip0(Module):
def forward(self, data):
Expand Down
Loading