[BugFix][Relax][Torch] Honor correction in std/var converter#19512
Merged
tlopex merged 2 commits into apache:main on May 6, 2026
Conversation
`_var` ignored the `correction` kwarg of `aten.var.correction`. Because `run_decompositions()` rewrites both `aten.std.correction` and `aten.std.dim` into `var.correction + sqrt`, every torch.export std/var call landed in `_var` and was always divided by `n` regardless of correction. As a result `torch.std([[1,2,3,4,5]], dim=1, unbiased=True)` produced 1.4142 (population std) instead of 1.5811 (Bessel std).

Route `aten.var.correction` to a dedicated `_var_correction` that reads `correction` (defaulting `None` to 1, matching PyTorch's `Scalar? correction = None` overload) and scales the variance by `n / (n - correction)` after the existing `relax.op.variance`. The legacy fx `_var` path is intentionally left alone to keep this fix narrowly scoped. When `n - correction <= 0`, emit a NaN scale so the output matches PyTorch's documented `max(0, N - correction)` behavior instead of erroring at import time.

Updated `test_std`/`test_var` (which encoded the buggy correction=0 behavior for `torch.var(x)` on a (5, 3) input) and added `test_var_correction` covering explicit `correction=2` and `correction=0`.
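As a quick numeric sanity check of the `n / (n - correction)` rescaling (a standalone NumPy sketch, not code from this diff):

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
n = x.size

# Population variance: divide by n, i.e. correction = 0. This is what the
# buggy converter always produced.
pop_var = ((x - x.mean()) ** 2).mean()
print(np.sqrt(pop_var))  # 1.4142

# Rescaling by n / (n - correction) recovers the requested statistic.
correction = 1  # PyTorch's default (Bessel)
print(np.sqrt(pop_var * n / (n - correction)))  # 1.5811
```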
Contributor
Code Review
This pull request implements support for the `correction` argument in `torch.var` and `torch.std` operations by calculating a scaling factor based on the reduction size. The feedback identifies the need to handle non-integer `dim` types in `_reduction_size` to avoid type errors, resolve the `correction` argument from the FX graph to handle symbolic values, and maintain naming consistency for `tirx.IntImm`.
Previously `axes = list(dim)` would raise `TypeError` if `dim` was not an iterable of Python ints. Validate the type explicitly and return `None` for the unsupported case (matching the dynamic-shape fallback already in this helper). Also drop the redundant `tvm.` prefix on `tirx.IntImm` to match the file's existing import.
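For context, a sketch of the validation this describes; the helper name `_reduction_size` is taken from the review above, but the body here is illustrative rather than the merged code:

```python
from typing import Optional, Sequence

def _reduction_size(shape: Sequence, dim) -> Optional[int]:
    """Static element count over the reduced axes, or None if unsupported."""
    if dim is None:
        axes = list(range(len(shape)))
    elif isinstance(dim, int):
        axes = [dim]
    elif isinstance(dim, (list, tuple)) and all(isinstance(d, int) for d in dim):
        axes = list(dim)  # safe now: dim is a known iterable of Python ints
    else:
        return None  # unsupported dim type; same fallback as dynamic shapes
    size = 1
    for ax in axes:
        extent = shape[ax]
        if not isinstance(extent, int):
            return None  # dynamic extent
        size *= extent
    return size
```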
Motivation
The PyTorch frontend's `_var` ignored the `correction` kwarg of `aten.var.correction`. `torch.export.run_decompositions()` rewrites both `aten.std.correction` and `aten.std.dim` into `aten.var.correction(..., correction=<value>) → sqrt`, so every `torch.std`/`torch.var` call lands in `_var`, but the `correction` value was dropped on the floor. The variance was therefore always divided by `n` regardless of what the user requested.

Minimal repro (vs PyTorch eager):
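```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])

# Eager PyTorch honors Bessel's correction:
print(torch.std(x, dim=1, unbiased=True))  # tensor([1.5811])

# Before this fix, round-tripping the same computation through torch.export
# and the Relax frontend produced the population std instead:
# tensor([1.4142])
```

(The snippet above is reconstructed from the figures quoted in the description.)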
The same omission shows up for explicit
torch.var(x, correction=k)and any model that relies on the documented Bessel default.
Fix
Route `aten.var.correction` (identified by `OpOverload._overloadname`, not a substring match) to a new `_var_correction` helper. It reads `correction` from `node.kwargs`, treats `None` as 1 to match the overload's `Scalar? correction = None` schema, and scales the existing `relax.op.variance` output by `n / (n - correction)` when `correction != 0`.

When `n - correction <= 0`, the multiplier is set to NaN rather than raising; this mirrors PyTorch's documented `max(0, N - correction)` semantics (eager produces NaN with a warning, not an error).
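In converter form, the logic looks roughly like the sketch below. The surrounding frontend plumbing (`self.env`, `self.block_builder`, `_reduction_size`) follows the file's existing conventions; this is an illustration of the approach, not the exact merged diff:

```python
import math
from tvm import relax

def _var_correction(self, node):
    x = self.env[node.args[0]]
    dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
    keepdim = node.kwargs.get("keepdim", False)
    correction = node.kwargs.get("correction", None)
    if correction is None:
        correction = 1  # Scalar? correction = None defaults to Bessel

    var = self.block_builder.emit(relax.op.variance(x, dim, keepdims=keepdim))
    if correction == 0:
        return var  # population variance, no rescaling needed

    n = _reduction_size(x, dim)  # static element count over the reduced axes
    # A NaN scale mirrors PyTorch's max(0, N - correction) semantics when the
    # denominator is non-positive.
    scale = n / (n - correction) if n - correction > 0 else math.nan
    return self.block_builder.emit(
        relax.op.multiply(var, relax.const(scale, "float32"))
    )
```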
Reduction-axis sizes are read from `x.struct_info.shape`. Dynamic sizes raise `NotImplementedError`; static-shape models cover the real-world `torch.export` flow.

The legacy fx path through `_var` is intentionally left alone: it has a separate preexisting bug (it reads `args[2]` as `keepdim` even when that slot is `unbiased`), but fixing that here would expand the scope of this PR beyond the `correction` semantics.
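To see why `args[2]` is `unbiased` rather than `keepdim` on that path, compare the overloads in eager PyTorch (a small standalone check, not part of the diff):

```python
import torch

x = torch.randn(4, 3)

# In aten::var.dim(Tensor, int[1]? dim, bool unbiased=True, bool keepdim=False)
# the third positional slot is `unbiased`:
a = torch.var(x, 1, False)             # unbiased=False -> population variance
b = torch.var(x, dim=1, correction=0)  # same statistic via the correction overload
assert torch.allclose(a, b)
```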
Notes
- `_std` is also registered for `"std.correction"` but is unreachable on the default exported-program path because `aten.std.*` always decomposes to `var.correction + sqrt` before dispatch. Sparse-tensor exports that skip `run_decompositions` still hit the old `_std`; that path is out of scope for this fix.
- `test_std`/`test_var` encoded the buggy `correction=0` IR for `torch.var(x)` (which defaults to Bessel) and have been updated to expect the correct `R.multiply(var, R.const(15/14))`. New `test_var_correction` covers explicit `correction=2` and `correction=0`.
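A sketch of the numerics `test_var_correction` pins down (the actual tests use the file's existing harness; this standalone check only demonstrates the expected scaling):

```python
import torch

x = torch.randn(5, 3)

# correction=2 divides by (n - 2); relative to the population variance
# (correction=0) that is a scale of n / (n - 2) = 3 / 1 = 3 along dim=1.
pop = torch.var(x, dim=1, correction=0)
assert torch.allclose(torch.var(x, dim=1, correction=2), pop * 3.0)

# And torch.var(x) over all 15 elements defaults to Bessel, hence the
# R.multiply(var, R.const(15/14)) expected in the updated tests.
assert torch.allclose(torch.var(x), torch.var(x, correction=0) * (15 / 14))
```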