Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/tirx/transform/common_subexpr_elim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@
* - It is not a leaf (Var, IntImm, FloatImm, StringImm).
* - It does not contain Call or BufferLoad (side-effects / memory dependence).
* - It is not Ramp or Broadcast (hardware-specific vector ops).
* - It is not bool-typed. Boolean predicates are kept inline because the
* consumer (if / Select / assert) reads more clearly with the condition
* spelled out, and downstream simplification benefits from seeing the
* predicate directly.
*
* Scope tree
* ----------
Expand Down Expand Up @@ -263,6 +267,8 @@ class CSEPlanner : public StmtExprVisitor {
* - Not a Call or BufferLoad (side effects / memory dependence).
* - Not Ramp or Broadcast (hardware-specific vector construction).
* - Does not transitively contain any forbidden node.
* - Is not bool-typed (predicates are kept inline for readability and
* downstream simplification).
*
* \param expr The expression to check.
* \return true if the expression can participate in CSE.
Expand All @@ -274,6 +280,14 @@ class CSEPlanner : public StmtExprVisitor {
}
if (IsForbiddenNode(expr)) return false;
if (expr.as<RampNode>() || expr.as<BroadcastNode>()) return false;
// Reject bool-typed expressions. Boolean predicates almost always feed an
// if / Select / assert, where reading the condition inline is clearer than
// going through a `cse_v: bool = (a < b)` temporary, and where downstream
// simplification (ProveCondition, branch elimination) benefits from seeing
// the predicate directly. BoolImm is already filtered above as an IntImm
// leaf, so this rule only affects compound bool expressions
// (LT/LE/GT/GE/EQ/NE/And/Or/Not/Cast-to-bool/Select-of-bool).
if (expr.dtype().is_bool()) return false;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current check expr.dtype().is_bool() only identifies scalar boolean expressions (typically uint1 with 1 lane). In TIR, boolean predicates are frequently vectorized (e.g., uint1x4 or uint1x8), especially when feeding into vectorized Select nodes.

Since the stated goal is to keep predicates inline to facilitate downstream simplification and readability, this logic should also apply to vectorized boolean expressions. Downstream passes like Simplify or ProveCondition often benefit from seeing the vectorized comparison directly within the Select condition.

Consider checking for a 1-bit width to cover both scalar and vector boolean types.

Suggested change
if (expr.dtype().is_bool()) return false;
if (expr.dtype().bits() == 1) return false;

if (CheckContains::ExprContains(expr, IsForbiddenNode)) return false;
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,47 @@ def test_let_floordiv_pattern():
assert "cse_v" not in script, f"CSE incorrectly extracted from Let body:\n{script}"


# =====================================================================
# T22: No lifting of bool predicate (comparison expression)
# A duplicated `i < n` feeds two if-statements. CSE must leave it
# inline rather than hoisting a `cse_v: bool = (i < n)` binding.
# =====================================================================
def test_no_lift_bool_predicate():
@tvm.script.ir_module
class Before:
@T.prim_func
def main(B: T.Buffer((50,), "int32"), n: T.int32, x: T.int32):
for i in range(50):
if i < n:
B[i] = x
if i < n:
B[i] = x + 1

after = tvm.tirx.transform.CommonSubexprElim()(Before)
tvm.ir.assert_structural_equal(after, Before)
assert "cse_v" not in after["main"].script()


# =====================================================================
# T23: No lifting of bool logical expression (And)
# A duplicated `a && b` feeds two if-statements. CSE must leave it
# inline rather than hoisting a `cse_v: bool = T.And(a, b)` binding.
# =====================================================================
def test_no_lift_bool_logical():
@tvm.script.ir_module
class Before:
@T.prim_func
def main(B: T.Buffer((50,), "int32"), a: T.bool, b: T.bool, x: T.int32):
if T.And(a, b):
B[0] = x
if T.And(a, b):
B[1] = x + 1

after = tvm.tirx.transform.CommonSubexprElim()(Before)
tvm.ir.assert_structural_equal(after, Before)
assert "cse_v" not in after["main"].script()


if __name__ == "__main__":
test_basic()
test_if_single_branch()
Expand All @@ -735,3 +776,5 @@ def test_let_floordiv_pattern():
test_let_value_cse()
test_nested_let_no_extraction()
test_let_floordiv_pattern()
test_no_lift_bool_predicate()
test_no_lift_bool_logical()
Loading