From 9a7cd3de08bfa5c639ca8fb6b268b2800969e230 Mon Sep 17 00:00:00 2001
From: Soowon Jeong
Date: Wed, 6 May 2026 17:41:09 +0900
Subject: [PATCH 1/2] [BugFix][Relax][Torch] Honor `correction` in std/var converter

`_var` ignored the `correction` kwarg of `aten.var.correction`. Because
`run_decompositions()` rewrites both `aten.std.correction` and
`aten.std.dim` into `var.correction + sqrt`, every torch.export std/var
call landed in `_var` and was always divided by `n` regardless of
correction. As a result `torch.std([[1,2,3,4,5]], dim=1, unbiased=True)`
produced 1.4142 (population std) instead of 1.5811 (Bessel std).

Route `aten.var.correction` to a dedicated `_var_correction` that reads
`correction` (defaulting `None` to 1, matching PyTorch's
`Scalar? correction = None` overload) and scales the variance by
`n / (n - correction)` after the existing `relax.op.variance`. The
legacy fx `_var` path is intentionally left alone to keep this fix
narrowly scoped.

When `n - correction <= 0`, emit a NaN scale so the output matches
PyTorch's documented `max(0, N - correction)` behavior instead of
erroring at import time.

Updated `test_std`/`test_var` (which encoded the buggy correction=0
behavior for `torch.var(x)` on a (5, 3) input) and added
`test_var_correction` covering explicit `correction=2` and
`correction=0`.
---
 .../torch/base_fx_graph_translator.py         | 58 +++++++++++++++++++
 .../test_frontend_from_exported_program.py    | 49 +++++++++++++++-
 2 files changed, 104 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index c146cf6c00e3..9c5bbb78816c 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1645,12 +1645,70 @@ def _sum(self, node: fx.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.sum(x, dim, keepdims=keepdim))

     def _var(self, node: fx.Node) -> relax.Var:
+        # `aten.var.correction` (and decomposed `aten.std.*`) carries an
+        # optional `correction` kwarg whose `None` default means 1 (Bessel).
+        # Legacy fx `tensor.var(...)` calls go through the original path
+        # below to keep this fix narrowly scoped.
+        target = node.target
+        if getattr(target, "_overloadname", None) == "correction" or getattr(
+            target, "overload_name", None
+        ) == "correction":
+            return self._var_correction(node)
         args = self.retrieve_args(node)
         x = args[0]
         dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
         keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
         return self.block_builder.emit(relax.op.variance(x, dim, keepdims=keepdim))

+    def _var_correction(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        keepdim = node.kwargs.get("keepdim", False)
+        correction = node.kwargs.get("correction", None)
+        if correction is None:
+            correction = 1
+        var = self.block_builder.emit(relax.op.variance(x, dim, keepdims=keepdim))
+        if correction == 0:
+            return var
+        n = self._reduction_size(x, dim)
+        if n is None:
+            raise NotImplementedError(
+                "var/std with non-zero correction requires statically known "
+                "reduction-axis sizes."
+            )
+        # PyTorch returns NaN (with a warning) when `n - correction <= 0`;
+        # mirror that behavior rather than failing the import.
+        if n - correction <= 0:
+            scale = float("nan")
+        else:
+            scale = float(n) / float(n - correction)
+        return self.block_builder.emit(
+            relax.op.multiply(var, relax.const(scale, x.struct_info.dtype))
+        )
+
+    @staticmethod
+    def _reduction_size(x: relax.Expr, dim) -> int | None:
+        """Static product of reduced-axis sizes; None if any axis is dynamic."""
+        shape = x.struct_info.shape
+        if shape is None:
+            return None
+        rank = len(shape)
+        if dim is None:
+            axes = list(range(rank))
+        elif isinstance(dim, int):
+            axes = [dim]
+        else:
+            axes = list(dim)
+        n = 1
+        for ax in axes:
+            ax = ax + rank if ax < 0 else ax
+            s = shape[ax]
+            if not isinstance(s, tvm.tirx.IntImm):
+                return None
+            n *= int(s.value)
+        return n
+
     def _any(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         x = args[0]
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 602949937247..cb0e4a80a8fe 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -7492,6 +7492,7 @@ def main(


 def test_std():
+    # torch.std(x) defaults to correction=1 (Bessel); decomposes to var.correction + sqrt.
     class Std(Module):
         def forward(self, x):
             return torch.std(x)
@@ -7504,8 +7505,9 @@ def main(
         ) -> R.Tuple(R.Tensor((), dtype="float32")):
             with R.dataflow():
                 lv: R.Tensor((), dtype="float32") = R.variance(x, axis=None, keepdims=False)
-                lv1: R.Tensor((), dtype="float32") = R.sqrt(lv)
-                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,)
+                lv1: R.Tensor((), dtype="float32") = R.multiply(lv, R.const(15.0 / 14.0, "float32"))
+                lv2: R.Tensor((), dtype="float32") = R.sqrt(lv1)
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv2,)
                 R.output(gv)
             return gv

@@ -7514,6 +7516,7 @@ def main(


 def test_var():
+    # torch.var(x) defaults to correction=1 (Bessel).
     class Var(Module):
         def forward(self, x):
             return torch.var(x)
@@ -7526,7 +7529,8 @@ def main(
         ) -> R.Tuple(R.Tensor((), dtype="float32")):
             with R.dataflow():
                 lv: R.Tensor((), dtype="float32") = R.variance(x, axis=None, keepdims=False)
-                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+                lv1: R.Tensor((), dtype="float32") = R.multiply(lv, R.const(15.0 / 14.0, "float32"))
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv1,)
                 R.output(gv)
             return gv

@@ -7534,6 +7538,45 @@ def main(
     verify_model(Var(), example_args, {}, Expected)


+def test_var_correction():
+    class VarCorrection2(Module):
+        def forward(self, x):
+            return torch.var(x, dim=-1, correction=2)
+
+    class VarCorrection0(Module):
+        def forward(self, x):
+            return torch.var(x, dim=1, correction=0)
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            x: R.Tensor((2, 5), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2,), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2,), dtype="float32") = R.variance(x, axis=[-1], keepdims=False)
+                lv1: R.Tensor((2,), dtype="float32") = R.multiply(lv, R.const(5.0 / 3.0, "float32"))
+                gv: R.Tuple(R.Tensor((2,), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected0:
+        @R.function
+        def main(
+            x: R.Tensor((2, 5), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2,), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2,), dtype="float32") = R.variance(x, axis=[1], keepdims=False)
+                gv: R.Tuple(R.Tensor((2,), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(2, 5, dtype=torch.float32),)
+    verify_model(VarCorrection2(), example_args, {}, Expected2)
+    verify_model(VarCorrection0(), example_args, {}, Expected0)
+
+
 def test_prod():
     class Prod(Module):
         def forward(self, x):

From 0cf3be6534a0b0231586b1e3493a371d77411e43 Mon Sep 17 00:00:00 2001
From: Soowon Jeong
Date: Wed, 6 May 2026 18:09:48 +0900
Subject: [PATCH 2/2] [Relax][Torch] Tighten dim validation in var/std reduction-size helper

Previously `axes = list(dim)` would raise TypeError when `dim` was not
iterable, and other unsupported `dim` values slipped past only to fail
later with a less obvious error. Validate the type explicitly and
return None for the unsupported case (matching the dynamic-shape
fallback already in this helper). Also drop the redundant `tvm.` prefix
on `tirx.IntImm` to match the file's existing import.
---
 python/tvm/relax/frontend/torch/base_fx_graph_translator.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 9c5bbb78816c..7296f73e9a0f 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1698,13 +1698,15 @@ def _reduction_size(x: relax.Expr, dim) -> int | None:
             axes = list(range(rank))
         elif isinstance(dim, int):
             axes = [dim]
-        else:
+        elif isinstance(dim, (list, tuple)) and all(isinstance(a, int) for a in dim):
             axes = list(dim)
+        else:
+            return None
         n = 1
         for ax in axes:
             ax = ax + rank if ax < 0 else ax
             s = shape[ax]
-            if not isinstance(s, tvm.tirx.IntImm):
+            if not isinstance(s, tirx.IntImm):
                 return None
             n *= int(s.value)
         return n
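A minimal reference sketch of the rescaling relationship the converter relies on (not part of either patch; plain PyTorch, independent of TVM, with illustrative values chosen to match the numbers in the first commit message):

    import torch

    x = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
    n = x.shape[1]  # 5 elements reduced along dim=1

    pop_var = torch.var(x, dim=1, correction=0)     # divide by n,     tensor([2.0])
    bessel_var = torch.var(x, dim=1, correction=1)  # divide by n - 1, tensor([2.5])

    # The corrected variance is the population variance scaled by n / (n - correction),
    # which is the factor _var_correction applies after relax.op.variance.
    assert torch.allclose(pop_var * n / (n - 1), bessel_var)
    print(torch.std(x, dim=1))  # tensor([1.5811]), the Bessel std cited in the commit message

Applying the scale after `relax.op.variance` (which always divides by n) keeps the corrected case expressible with existing Relax ops, so no new operator is needed.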