
[BugFix][Relax][Torch] Honor correction in std/var converter #19512

Merged
tlopex merged 2 commits into apache:main from swjng:fix/torch-std-var-correction
May 6, 2026

Conversation

@swjng (Contributor) commented May 6, 2026

Motivation

The PyTorch frontend's _var ignored the correction kwarg of
aten.var.correction. torch.export.run_decompositions() rewrites
both aten.std.correction and aten.std.dim into
aten.var.correction(..., correction=<value>) followed by sqrt, so
every torch.std/torch.var call lands in _var, but the correction
value was silently dropped. The variance was therefore always
divided by n regardless of what the user requested.

Minimal repro (vs PyTorch eager):

x = torch.tensor([[1., 2., 3., 4., 5.], [2., 2., 2., 2., 2.]])
torch.std(x, dim=1, unbiased=True)
  ref: [1.5811, 0.0]   # sqrt(2.5), Bessel-corrected
  tvm: [1.4142, 0.0]   # sqrt(2.0) -- correction silently set to 0

The same omission shows up for explicit torch.var(x, correction=k)
and any model that relies on the documented Bessel default.
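
The decomposition is easy to observe directly; a minimal sketch (the module and input shape are illustrative):

```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.std(x, dim=1, unbiased=True)

ep = torch.export.export(M(), (torch.randn(2, 5),))
# After decomposition there is no aten.std node left: the graph contains
# aten.var.correction(..., correction=1) followed by aten.sqrt.
print(ep.run_decompositions().graph)
```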

Fix

Route aten.var.correction (identified by OpOverload._overloadname,
not a substring match) to a new _var_correction helper. It reads
correction from node.kwargs, treats None as 1 to match the
overload's Scalar? correction = None schema, and scales the
existing relax.op.variance output by n / (n - correction) when
correction != 0.

When n - correction <= 0, the multiplier is set to NaN rather than
raising — this mirrors PyTorch's documented
max(0, N - correction) semantics (eager produces NaN with a warning,
not an error).
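
A minimal sketch of that rescaling, written as a free function for readability (the PR's actual converter is a method in base_fx_graph_translator.py and its exact signature may differ):

```python
import math
from tvm import relax

def _var_correction_sketch(bb, x, axis, keepdims, correction, n):
    """Rescale relax.op.variance (which divides by n) to divide by n - correction."""
    var = bb.emit(relax.op.variance(x, axis=axis, keepdims=keepdims))
    if correction is None:
        correction = 1  # the `Scalar? correction = None` schema defaults to Bessel
    if correction == 0:
        return var  # population variance: relax.op.variance is already correct
    if n - correction <= 0:
        scale = math.nan  # mirror PyTorch's max(0, N - correction): NaN, not an error
    else:
        scale = n / (n - correction)
    return bb.emit(relax.op.multiply(var, relax.const(scale, x.struct_info.dtype)))
```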

Reduction-axis sizes are read from x.struct_info.shape. Dynamic
sizes raise NotImplementedError; static-shape models cover the
real-world torch.export flow.
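
A sketch of that lookup (the helper name and the struct-info access pattern are assumptions):

```python
from tvm import tir

def _static_reduction_size(x, axes):
    """Product of the reduced extents; only static (tir.IntImm) extents qualify."""
    shape = x.struct_info.shape
    if shape is None:
        raise NotImplementedError("correction rescaling needs a known shape")
    dims = shape.values  # PrimExprs; static extents are tir.IntImm
    if axes is None:
        axes = range(len(dims))  # full reduction
    n = 1
    for d in axes:
        extent = dims[d]
        if not isinstance(extent, tir.IntImm):
            raise NotImplementedError("dynamic reduction extent is unsupported")
        n *= extent.value
    return n
```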

The legacy fx path through _var is intentionally left alone — it has
a separate preexisting bug (it reads args[2] as keepdim even when
that slot is unbiased), but fixing that here would expand the scope
of this PR beyond the correction semantics.

Notes

  • _std is also registered for "std.correction" but is unreachable on
    the default exported-program path because aten.std.* always
    decomposes to var.correction + sqrt before dispatch. Sparse-tensor
    exports that skip run_decompositions still hit the old _std; that
    path is out of scope for this fix.
  • Existing test_std/test_var encoded the buggy correction=0 IR
    for torch.var(x) (which defaults to Bessel) and have been updated
    to expect the correct R.multiply(var, R.const(15/14)). New
    test_var_correction covers explicit correction=2 and correction=0;
    an illustrative end-to-end sketch follows below.
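
For illustration, an end-to-end exercise of the new path might look like the following sketch (not the PR's actual test; it assumes the from_exported_program entry point of TVM's Relax frontend):

```python
import torch
from tvm.relax.frontend.torch import from_exported_program

class VarWithCorrection(torch.nn.Module):
    def forward(self, x):
        return torch.var(x, dim=1, correction=2)

ep = torch.export.export(VarWithCorrection(), (torch.randn(5, 3),))
mod = from_exported_program(ep.run_decompositions())
# With n = 3 along dim=1 and correction=2, the emitted variance should be
# rescaled by n / (n - correction) = 3.
mod.show()
```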

`_var` ignored the `correction` kwarg of `aten.var.correction`. Because
`run_decompositions()` rewrites both `aten.std.correction` and
`aten.std.dim` into `var.correction + sqrt`, every torch.export std/var
call landed in `_var` and was always divided by `n` regardless of
correction. As a result `torch.std([[1,2,3,4,5]], dim=1, unbiased=True)`
produced 1.4142 (population std) instead of 1.5811 (Bessel std).

Route `aten.var.correction` to a dedicated `_var_correction` that reads
`correction` (defaulting `None` to 1, matching PyTorch's
`Scalar? correction = None` overload) and scales the variance by
`n / (n - correction)` after the existing `relax.op.variance`. The
legacy fx `_var` path is intentionally left alone to keep this fix
narrowly scoped. When `n - correction <= 0`, emit a NaN scale so the
output matches PyTorch's documented `max(0, N - correction)` behavior
instead of erroring at import time.

Updated `test_std`/`test_var` (which encoded the buggy correction=0
behavior for `torch.var(x)` on a (5, 3) input) and added
`test_var_correction` covering explicit `correction=2` and `correction=0`.

@gemini-code-assist (bot) left a comment


Code Review

This pull request implements support for the correction argument in torch.var and torch.std by scaling the variance according to the reduction size. The review feedback asks to handle non-integer dim types in _reduction_size to avoid TypeErrors, to resolve the correction argument from the FX graph so symbolic values are handled, and to keep the tirx.IntImm naming consistent with the file's existing import.

Review threads: python/tvm/relax/frontend/torch/base_fx_graph_translator.py (three threads, one marked Outdated).

Previously `axes = list(dim)` would raise TypeError if dim was not an
iterable of Python ints. Validate the type explicitly and return None
for the unsupported case (matching the dynamic-shape fallback already
in this helper). Also drop the redundant `tvm.` prefix on
`tirx.IntImm` to match the file's existing import.
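
A minimal sketch of the guard this commit describes (simplified and renamed; the real _reduction_size also derives the size from the shape):

```python
def _normalized_axes_sketch(shape, dim):
    """Normalize `dim` to a list of int axes, or None when unsupported."""
    if dim is None:
        return list(range(len(shape)))  # full reduction
    if isinstance(dim, int):
        return [dim]
    if isinstance(dim, (list, tuple)) and all(isinstance(d, int) for d in dim):
        return list(dim)
    return None  # unsupported dim type: same fallback as the dynamic-shape case
```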

@tlopex (Member) left a comment


LGTM

@tlopex merged commit 5a7da7a into apache:main on May 6, 2026
9 checks passed
