[Feature] 2-SM support for TMA, TMEM and TCGEN5MMA on Blackwell#1882
LeiWang1999 merged 39 commits into tile-ai:main from
Conversation

Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in your CodeRabbit settings, and use the review commands to resume.
📝 Walkthrough

Adds optional 2-CTA (cluster) support for Blackwell SM100 TCGEN5MMA: new kernel variants, a pass config to disable 2-CTA, metadata propagation, lowering and transform passes to annotate/use cluster mode, CUDA codegen and device templates updated for cluster-aware TMA/MMA and swizzle, plus runtime selection in examples.
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as Python Caller
    participant TileLang as TileLang Compiler/Passes
    participant Meta as TCGEN5 Meta
    participant CodeGen as CUDA CodeGen
    participant GPU as Device (2 CTAs)
    User->>TileLang: JIT compile kernel (pass config)
    TileLang->>Meta: get_tcgen5_mma_meta(M,N,K, disable_2cta=...)
    Meta-->>TileLang: meta (enable_2cta)
    TileLang->>TileLang: LowerBlackwell2SM -> annotate block use_2cta
    TileLang->>TileLang: lower_shared_tmem / inject_tma_barrier (propagate use_2cta)
    TileLang->>CodeGen: emit IR with cluster_dims / use_2cta annotations
    CodeGen->>CodeGen: ClusterInfoExtractor reads cluster_dims
    CodeGen->>GPU: launch kernel with cluster config
    GPU->>GPU: CTA0/CTA1 coordinate (arrive multicast, tmem alloc/dealloc)
    GPU-->>User: results
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)

✏️ Tip: You can configure your own custom pre-merge checks in the settings.
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's required checks before pushing. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Actionable comments posted: 12
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/tl_templates/cuda/instruction/tcgen05mma.h (1)
28-38: ⚠️ Potential issue | 🟠 Major

`tcgen05mma_ts` still has no 2-CTA implementation. The primary template now exposes `use_2cta`, but this file still only defines the default TS specializations. The first `tcgen05mma_ts<..., true>` instantiation will fall back to the generic `static_assert` instead of emitting a 2-CTA instruction.

Also applies to: 40-143
src/target/codegen_cuda.cc (1)
2462-2502: ⚠️ Potential issue | 🟠 Major

`enable_2cta` stops at the SS emitter. This change plumbs the new flag only through `ptx_tcgen05_mma_ss`. Line 2542 still emits `tl::tcgen05mma_ts<(ABType)>` unconditionally, so TMEM-backed tcgen05 kernels keep lowering to the 1-CTA path even when the IR requests 2-CTA.

Also applies to: 2519-2564
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/target/codegen_cuda.cc` around lines 2462 - 2502, The SS emitter only applies enable_2cta via use_2cta_suffix when building tcgen05_call, but the TS/TMEM-backed path still emits "tl::tcgen05mma_ts<(ABType)>" unconditionally; update the TS emitter (the code that emits the tcgen05mma_ts instantiation) to reuse the same use_2cta_suffix logic so the template becomes "tl::tcgen05mma_ts<(ABType)(USE_2CTA_SUFFIX)>" when enable_2cta is true; locate the other emitter that constructs the tcgen05 call (search for "tcgen05mma_ts" or "tcgen05_name") and apply the same Replacer pattern (register_rule("(USE_2CTA_SUFFIX)", use_2cta_suffix)) so both ptx_tcgen05_mma_ss and the TS/MEM-backed lowering honor the enable_2cta flag.
🧹 Nitpick comments (7)
src/transform/inject_tma_barrier.cc (2)
320-329: Consider adding a clarifying comment about the cluster barrier detection heuristic. The check `barrier_id.as<CallNode>() == nullptr` relies on cluster barriers having a `BufferLoad` as barrier_id rather than a `get_mbarrier()` call. A brief comment documenting this assumption would help future maintainers.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/transform/inject_tma_barrier.cc` around lines 320 - 329, Add a short clarifying comment above the cluster detection around tma_op_to_barrier_id_/barrier_id explaining the heuristic: that cluster barriers use a BufferLoad for barrier_id instead of a get_mbarrier() CallNode, so the code checks barrier_id.as<CallNode>() == nullptr to detect cluster barriers (hence skipping int_set computation and restore_barrier_ids_ handling); reference CallNode, BufferLoad, get_mbarrier, is_cluster, and restore_barrier_ids_ in the comment so future readers understand the assumption and why the branch is skipped.
462-476: Minor: Consider combining the two `PostOrderVisit` traversals. There are two separate `PostOrderVisit` calls over `f->body` (lines 453-461 and 464-476). These could be merged into a single traversal for efficiency, though the impact is minimal since this is compile-time code.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@src/transform/inject_tma_barrier.cc` around lines 462 - 476, Combine the two PostOrderVisit traversals over f->body into a single traversal: inside one PostOrderVisit callback detect BlockNode annotations for both "tma_barrier" and "cluster_dims", compute and set cluster_size (previously assigned in the second traversal) and perform the existing tma_barrier handling (previously in the first traversal); update references to the variables like cluster_size and any flag used for the barrier so both pieces of logic run in the same pass and remove the now-redundant second PostOrderVisit call.

src/tl_templates/cuda/threadblock_swizzle.h (1)
52-53: Consider using `cutlass::ceil_div` for consistency. The existing non-cluster templates use `cutlass::ceil_div`, but the new cluster-aware templates define a local lambda. Consider using the existing utility for consistency.

♻️ Suggested change

```diff
 template <int panel_width, int cluster_dim_x>
 TL_DEVICE dim3 rasterization2DRowWithCluster() {
-  auto ceil_div = [](int a, int b) { return (a + b - 1) / b; };
   const unsigned int num_cluster_x = gridDim.x / cluster_dim_x;
   // ... rest unchanged, use cutlass::ceil_div below
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@src/tl_templates/cuda/threadblock_swizzle.h` around lines 52 - 53, Replace the local lambda ceil_div with the existing utility cutlass::ceil_div for consistency: remove the lambda definition and compute num_cluster_x using cutlass::ceil_div(gridDim.x, cluster_dim_x) (referencing the local symbol num_cluster_x and cluster_dim_x) so the code uses the shared cutlass helper instead of a duplicated lambda.

src/op/tcgen5_meta.h (1)
40-45: LGTM - 2CTA path for M=128. The 2CTA exploration correctly returns `atom_m=256` (two CTAs × 128) with `use_2cta=true`. The placement before the `M % 128 == 0` branch ensures M=128 prefers 2CTA mode when not disabled.

The TODO on line 41 suggests expanding 2CTA-preferred shapes. Would you like me to open an issue to track this enhancement?

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@src/op/tcgen5_meta.h` around lines 40 - 45, Add a follow-up issue to track the TODO about expanding 2CTA-preferred shapes: reference the M == 128 branch and the disable_2cta flag in tcgen5_meta (the loop that calls SUCCESS(256, atom_n, 16, false, true)) and describe the enhancement to evaluate additional atom_m/atom_n combinations and placement heuristics for 2CTA mode; include motivation, acceptance criteria (e.g., new shapes tested, performance thresholds, and unit/regression tests), and link to the exact code block so we can prioritize and assign the work.

src/tl_templates/cuda/copy_sm100.h (1)
302-325: Consider extracting common barrier handling to reduce duplication. The barrier address computation logic (lines 308-316) is repeated identically across all five `tma_load_2sm` overloads. Consider extracting this to a helper function:

♻️ Suggested helper extraction

```cpp
template <typename BarrierType>
TL_DEVICE uint32_t get_peer_masked_barrier_addr(BarrierType &smem_mbar) {
  uint32_t smem_int_mbar;
  if constexpr (std::is_pointer_v<BarrierType>) {
    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(smem_mbar));
  } else {
    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
  }
  return smem_int_mbar & Sm100MmaPeerBitMask;
}
```

This would simplify each overload and ensure consistent behavior.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@src/tl_templates/cuda/copy_sm100.h` around lines 302 - 325, The barrier address calculation in tma_load_2sm is duplicated; extract it into a helper like get_peer_masked_barrier_addr<BarrierType> that takes BarrierType &smem_mbar, calls smem_ptr_to_uint with the pointer or address depending on std::is_pointer_v, masks with Sm100MmaPeerBitMask and returns the uint32_t; then replace the local logic in each tma_load_2sm overload to call get_peer_masked_barrier_addr(smem_mbar) and use that result for smem_int_mbar so behavior is consistent and duplication is removed.

tilelang/layout/swizzle.py (1)
111-127: Unused `continuity` parameter should be removed or deprecated. The `continuity` parameter in the function signature is no longer used - the implementation now derives `continuous` directly from `shape[-1]` (line 116). Callers passing a custom `continuity` value will have it silently ignored.

Consider either:
- Removing the parameter if it's no longer needed
- Adding a deprecation warning if backward compatibility is required

♻️ Suggested fix: Remove unused parameter

```diff
 # for TCGEN05MMA Intrinsics
-def make_tcgen05mma_swizzled_layout(buffer: BufferLikeType, continuity: int = None, k_major: bool = True):
+def make_tcgen05mma_swizzled_layout(buffer: BufferLikeType, k_major: bool = True):
     buf, shape, _ = _get_buffer_info(buffer)
```

Note: The `k_major` parameter also appears unused in the new implementation. Verify if it's still needed.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@tilelang/layout/swizzle.py` around lines 111 - 127, The function make_tcgen05mma_swizzled_layout currently ignores the continuity parameter (it computes continuous = int(shape[-1])) and also doesn't use k_major; remove the unused continuity (and optionally k_major if unused elsewhere) from the signature and all call sites, or if backward compatibility is required keep the parameter but emit a DeprecationWarning and honor a non-None continuity by using it instead of shape[-1]; update any callers, tests, and docstrings to match the chosen approach and ensure references to continuous (the local variable) remain correct.

src/transform/lower_shared_tmem.cc (1)
128-139: Share the `use_2cta` attr key instead of duplicating the raw string. This pass hard-codes `"use_2cta"` while `lower_blackwell_2sm.cc` defines the same key separately. A future typo or rename in either place will silently drop the TMEM init/dealloc annotation path.

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@src/transform/lower_shared_tmem.cc` around lines 128 - 139, The code hard-codes the "use_2cta" string in lower_shared_tmem.cc for checking and setting annotations (see uses in op->annotations.count(...) and tmem_call_ann.Set(...)); replace those raw string literals with the shared attribute key constant exported by lower_blackwell_2sm.cc (import/include that header/decl) so both files use the same symbol for the use_2cta attr; update the two sites (the count check and the Set call) to reference that shared constant and remove the duplicated raw string.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/gemm_sm100/gemm_tcgen5mma_2sm.py`:
- Around line 153-159: The test loop masks a correctness bug by using b =
torch.randn(K, 1, device="cuda", dtype=torch.bfloat16).repeat(1,
N).contiguous(); change it to generate full-rank random input with b =
torch.randn(K, N, device="cuda", dtype=torch.bfloat16) (i.e., remove the repeat)
so jit_kernel(a, b) is validated against ref_c = (a @ b).to(torch.bfloat16) on
true random matrices; update or remove the inline comment about the half-correct
case and keep torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) to
catch the bug.
- Around line 133-164: The module currently runs heavy work (building func via
matmul(...), compiling with tilelang.compile, printing kernel source, running
10,000 validation loops and benchmarking via jit_kernel.get_profiler()) at
import time; move all runtime/invocation code into a guarded entrypoint: wrap
the creation of func, the tilelang.compile call, the printing, the for-loop that
generates tensors and asserts, and the profiler/benchmarking calls inside a
main() function and call it only under if __name__ == "__main__": so importing
the file no longer triggers compilation or GPU execution.
In `@examples/gemm_sm100/gemm_tcgen5mma_ws_2sm.py`:
- Around line 79-91: The final writeback/TMA store currently runs when
threadIdx.x == 0 (or similar thread-only guard) but lacks the cluster-leader
check, letting follower CTAs also enqueue the same tma_store and cause duplicate
output and races; update the store guard to require tl::block_rank_in_cluster()
== 0 in addition to the threadIdx.x condition so only the CTA with
tl::block_rank_in_cluster() == 0 performs the tma_store to C_desc (and apply the
same change to the analogous block that starts around the other occurrence
noted, e.g., the block around the tcgen05mma_ss/tmem_full sequence at lines
~108-116), keeping the existing thread/threadIdx.x checks but AND-ing them with
tl::block_rank_in_cluster() == 0.
In `@src/op/gemm_py.cc`:
- Around line 321-324: The C++ binding filters meta using disable_2cta but
GemmPyNode::allowTcgen5Mma() still calls GetTCGEN5MMAMeta(...) without that
flag, causing inconsistent capability checks; update
GemmPyNode::allowTcgen5Mma() to accept and forward the disable_2cta boolean into
its GetTCGEN5MMAMeta(...) invocation (matching the tl.get_tcgen5_mma_meta
binding) so both instruction selection and lowering use the same
disable_2cta-aware meta filtering.
In `@src/tl_templates/cuda/instruction/tcgen05mma.h`:
- Around line 167-184: The specialization tcgen05mma_ss<DataType::kFloat16,
true> drops the tail masks (mask0..mask3) and currently predicates on scalec
only; update the assembly to use the provided mask0..mask3 to predicate partial
tiles so non-full tiles are masked correctly. Specifically, in
tcgen05mma_ss<DataType::kFloat16, true> replace the single predicate setup with
predicate registers derived from mask0..mask3 (e.g., .reg .pred p0..p3 and
setp.ne.b32 p0, mask0, 0; ... for mask1..mask3) and pass the appropriate
predicate(s) into the tcgen05.mma.cta_group::2.kind::f16 instruction operands;
also update the asm input operand list to include mask0..mask3 instead of only
scalec. Ensure the same fix is applied to the analogous block at the later
location (lines 197-203).
In `@src/transform/inject_tma_barrier.cc`:
- Around line 570-573: The assignment to clear_arrive_ from
clear_expect_list_[cur_expect_idx_++] is immediately overwritten by
clear_arrive_ = false, making the conditional dead; decide which behavior you
want and fix accordingly: if clear_arrive_ should always be false (but you still
need to advance cur_expect_idx_), remove the conditional assignment and only
increment cur_expect_idx_ when has_warp_specialization_ is true; if you intended
to use the value from clear_expect_list_, remove the subsequent clear_arrive_ =
false and keep the conditional assignment. Refer to clear_arrive_,
clear_expect_list_, cur_expect_idx_, and has_warp_specialization_ to locate and
update the logic.
- Around line 226-231: The current block handling
tl::ptx_arrive_cluster_barrier() silently ignores cases where call->args[1] is
not an IntImmNode, which causes missed cluster barrier CTA IDs; update the
branch that checks call->args[1].as<IntImmNode>() so that the else path emits a
clear diagnostic (e.g., LOG(WARNING)/ICHECK_MESSAGE or TVM_PLOG) or asserts
depending on desired strictness, referencing the same symbols (call,
op.same_as(tl::ptx_arrive_cluster_barrier()), call->args[1], IntImmNode,
cluster_barrier_cta_ids_.Set, barrier_id) so non-constant cta_id cases are
reported instead of silently skipped.
In `@src/transform/lower_blackwell_2sm.cc`:
- Around line 60-63: GetTCGEN5MMAMeta is being called with node->a_->dtype which
is wrong for TS kernels because A (TMEM) carries the accumulator dtype; update
the GetTCGEN5MMAMeta call to use the same AB dtype rule as allowTcgen5Mma() —
i.e. pass A dtype as (is-TS ? node->c_->dtype : node->a_->dtype) while keeping C
dtype as node->c_->dtype — so the meta lookup can correctly detect enable_2cta
and allow setting use_2cta.
- Around line 86-99: The bug is that VisitStmt_(const BlockRealizeNode *op)
annotates the first post-order BlockRealize visited (inner), because it calls
the base VisitStmt_ first; to fix, annotate the block on the current op before
recursing so the outer/root BlockRealize gets the kUse2Cta annotation: in
VisitStmt_(const BlockRealizeNode *op) check root_block_annotated_ and if false
make a writable copy of op->block, set n->annotations.Set(attr::kUse2Cta,
IntImm(DataType::Int(32),1)), set root_block_annotated_ = true, then continue
with/return the usual traversal (or call the base VisitStmt_ after annotation)
so that lower_shared_tmem.cc will see the annotation on the owning/root block.
In `@tilelang/intrinsics/tcgen05_macro_generator.py`:
- Around line 629-631: The call to get_tcgen5_mma_meta applies the 2CTA policy
globally but tcgen05mma_ts still lowers to a 1CTA sequence and never receives
the _enable_2cta state; update the implementation so the 2CTA flag is
consistently propagated and honored by the TS lowering: either stop applying
2CTA in get_tcgen5_mma_meta and let tcgen05mma_ts query pass_ctx itself, or
change tcgen05mma_ts to accept/obtain the same disable_2cta/_enable_2cta flag
(from pass_ctx.PassConfigKey.TL_DISABLE_2CTA_TCGEN5MMA) and use it when
computing geometry/atom_m/ n_dim and when emitting arrive/intrinsic sequences
(and ensure any gating in gemm_tcgen05.py matches this flag). Ensure symbols to
modify include get_tcgen5_mma_meta, tcgen05mma_ts, and the handling of
_enable_2cta/disable_2cta so TS kernels emit the correct 2CTA-aware lowering.
In `@tilelang/language/tir/op.py`:
- Around line 1165-1167: The docstring for the intrinsic in
tilelang/language/tir/op.py is out of date: it still says "14 or 15 positional
arguments" and lists the argument tuple without the newly added enable_2cta;
update that docstring to state "15 or 16 positional arguments" and adjust the
listed positional arguments to include enable_2cta as the final (16th) optional
parameter (i.e., (kind_dtype, desc_a, A_offset, desc_b, B_offset, C_ptr,
C_offset, desc_val, scale_out, mask0, mask1, mask2, mask3[, enable_ws,
enable_2cta]) so the text and example match the current function signature.
In `@tilelang/tileop/gemm/gemm_tcgen05.py`:
- Around line 142-157: The cluster guard uses the TVM Integer enable_2cta in
boolean logic causing incorrect behavior; coerce enable_2cta to a native Python
int before using it in cluster_cond (e.g., compute an int_enable_2cta =
int(enable_2cta) right after obtaining it from get_tcgen5_mma_meta()) and then
use not int_enable_2cta in the cluster_cond expression used by
_gemm_ss_elect_one_thread and _gemm_ss so the rank check is evaluated correctly
in 1-CTA vs multi-CTA modes.
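The first gemm_tcgen5mma_2sm.py finding above notes that a rank-1 B (a single random column repeated N times) can mask correctness bugs. A framework-free sketch of why, using hypothetical matmul helpers rather than TileLang code: when every column of B is identical, a kernel that reads the wrong column of B still produces the right answer.

```python
def matmul(A, B):
    """Plain reference matmul over nested lists (illustrative only)."""
    K, N, M = len(B), len(B[0]), len(A)
    return [[sum(A[i][k] * B[k][j] for k in range(K)) for j in range(N)]
            for i in range(M)]

def buggy_matmul(A, B):
    """Deliberate bug: always reads column 0 of B."""
    K, N, M = len(B), len(B[0]), len(A)
    return [[sum(A[i][k] * B[k][0] for k in range(K)) for j in range(N)]
            for i in range(M)]

A = [[1.0, 2.0], [3.0, 4.0]]
B_rank1 = [[5.0, 5.0], [7.0, 7.0]]   # like randn(K, 1).repeat(1, N)
B_full = [[5.0, 6.0], [7.0, 8.0]]    # like randn(K, N)

assert buggy_matmul(A, B_rank1) == matmul(A, B_rank1)   # bug hidden by rank-1 B
assert buggy_matmul(A, B_full) != matmul(A, B_full)     # bug exposed by full-rank B
```

This is exactly why the review asks the example to validate against full-rank random inputs.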
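The gemm_tcgen05.py comment above warns against using a TVM Integer directly in boolean logic. A framework-free illustration of the pitfall: the FakeTVMInteger class below is a stand-in, and TVM's actual wrapper may define truthiness differently, which is exactly why explicit coercion makes the intent unambiguous.

```python
class FakeTVMInteger:
    """Stand-in for an IR integer wrapper (illustrative only)."""

    def __init__(self, value):
        self.value = value

    def __int__(self):
        return self.value
    # No __bool__ defined, so every instance is truthy by default.


enable_2cta = FakeTVMInteger(0)

# Wrong: object truthiness ignores the wrapped value.
wrong = not enable_2cta       # False, although the wrapped value is 0

# Right: coerce to a native int first, as the review suggests.
right = not int(enable_2cta)  # True
```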
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 76078a38-8a04-4c49-a1c6-a66fa967c0c0
📒 Files selected for processing (30)
examples/gemm_sm100/gemm_tcgen5mma_2sm.py
examples/gemm_sm100/gemm_tcgen5mma_ws.py
examples/gemm_sm100/gemm_tcgen5mma_ws_2sm.py
examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py
src/op/builtin.cc
src/op/builtin.h
src/op/copy.cc
src/op/gemm_py.cc
src/op/tcgen5_meta.h
src/target/codegen_cuda.cc
src/target/codegen_cuda.h
src/tl_templates/cuda/cluster.h
src/tl_templates/cuda/common.h
src/tl_templates/cuda/copy_sm100.h
src/tl_templates/cuda/instruction/tcgen05mma.h
src/tl_templates/cuda/tcgen_05.h
src/tl_templates/cuda/threadblock_swizzle.h
src/transform/inject_tma_barrier.cc
src/transform/lower_blackwell_2sm.cc
src/transform/lower_shared_tmem.cc
tilelang/engine/phase.py
tilelang/intrinsics/tcgen05_macro_generator.py
tilelang/language/annotations.py
tilelang/language/builtin.py
tilelang/language/cluster.py
tilelang/language/tir/op.py
tilelang/layout/swizzle.py
tilelang/tileop/gemm/gemm_tcgen05.py
tilelang/transform/__init__.py
tilelang/transform/pass_config.py
```cpp
if (((int)threadIdx.x) < 64 && tl::block_rank_in_cluster() == 0) {
  for (int k_1 = 0; k_1 < 128; ++k_1) {
    loaded[(k_1 & 3)].wait(((k_1 >> 2) & 1));
    tl::initialize_tcgen05_descriptor(desc_a, (&(((bfloat16_t*)buf_dyn_shmem)[((k_1 & 3) * 8192)])), 1, 64, 0, 0, 2);
    tl::initialize_tcgen05_descriptor(desc_b, (&(((bfloat16_t*)buf_dyn_shmem)[(((k_1 & 3) * 8192) + 32768)])), 512, 64, 0, 0, 2);
#pragma unroll
    for (int ki = 0; ki < 4; ++ki) {
      tl::tcgen05mma_ss<tl::DataType::kBFloat16, true>(uint64_t(desc_a + (ki * 32)), uint64_t(desc_b + (ki * 2048)), (*reinterpret_cast<uint32_t*>(C_tmem)) + 0, ((0 < ki) ? 1 : ((k_1 == 0) ? 0 : 1)), static_cast<uint32_t>(272696464), 0, 0, 0, 0);
    }
    tl::tcgen05_mma_arrive<true>((&(consumed[(k_1 & 3)])));
  }
  tl::tcgen05_mma_arrive<true>((&(tmem_full[0])));
}
```
Only the leader CTA should own the final writeback.
Lines 79-91 already restrict the tcgen05 MMA path to block_rank_in_cluster() == 0, but this store block is guarded only by threadIdx.x == 0. The follower CTA will still enqueue the same four tma_stores to C_desc, which doubles output traffic and creates a cross-CTA write race.
Suggested fix
```diff
-  if (((int)threadIdx.x) == 0) {
+  if (((int)threadIdx.x) == 0 && tl::block_rank_in_cluster() == 0) {
```

Also applies to: 108-116
```cpp
auto [ok, meta] =
    GetTCGEN5MMAMeta(node->m_, node->n_, node->k_,
                     node->a_->dtype, node->c_->dtype);
if (ok && meta.enable_2cta) {
```
Use the same AB dtype rule as allowTcgen5Mma().
For TS kernels, A is TMEM and carries accumulator dtype, so GetTCGEN5MMAMeta(..., node->a_->dtype, node->c_->dtype) asks the wrong meta table. That can miss valid 2CTA TS GEMMs and skip the use_2cta annotation entirely.
```
Expects 14 or 15 positional arguments:
(kind_dtype, desc_a, A_offset, desc_b, B_offset, C_ptr, C_offset,
 desc_val, scale_out, mask0, mask1, mask2, mask3[, enable_ws]).
```
Docstring needs update for new argument count.
The docstring says "14 or 15 positional arguments" but with the addition of enable_2cta, the intrinsic now accepts 15 or 16 arguments (the 15th being enable_ws and the 16th being enable_2cta).
📝 Suggested docstring fix

```diff
-    Expects 14 or 15 positional arguments:
+    Expects 15 or 16 positional arguments:
     (kind_dtype, desc_a, A_offset, desc_b, B_offset, C_ptr, C_offset,
-    desc_val, scale_out, mask0, mask1, mask2, mask3[, enable_ws]).
+    desc_val, scale_out, mask0, mask1, mask2, mask3, enable_ws[, enable_2cta]).
```
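Beyond fixing the text, a generic sketch of how keyword-only flags avoid this class of problem, where inserting a new optional parameter silently shifts what a positional argument binds to. The signature below is illustrative, not the real tilelang intrinsic:

```python
def mma_intrin(*args, enable_ws=False, enable_2cta=False):
    # Illustrative only: trailing flags are keyword-only, so adding
    # enable_2cta cannot change what an existing positional call means.
    return {"nargs": len(args),
            "enable_ws": enable_ws,
            "enable_2cta": enable_2cta}

old_call = mma_intrin(1, 2, 3)                    # unaffected by the new flag
new_call = mma_intrin(1, 2, 3, enable_2cta=True)  # intent is explicit
```

With positional optionals, by contrast, a caller passing a 15th positional value would have silently switched from setting enable_ws to setting enable_2cta.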
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/op/copy.cc (1)
1681-1713: ⚠️ Potential issue | 🟠 Major

Propagate `use_2cta` through every emitted `tma_load`. This flag is only attached in the unsplit branch, and the current check turns any present key into 1. Split bulk loads, or a stored `use_2cta = 0`, will therefore reach codegen with the wrong 2CTA state.

🛠️ Suggested fix
+ Map<String, ObjectRef> ann; + if (TargetIsSm100(T.target) && is_load && annotations.count("use_2cta")) { + PrimExpr use_2cta = Downcast<PrimExpr>(annotations["use_2cta"]); + if (const auto *imm = use_2cta.as<IntImmNode>(); imm && imm->value != 0) { + ann.Set("use_2cta", IntImm(DataType::Int(32), 1)); + } + } + if ((*inner_box_dim) != instruction_dim) { Var loop_var("i"); int loop_extent = (*inner_box_dim) / instruction_dim; @@ - tma_copy = For(loop_var, 0, loop_extent, ForKind::kUnrolled, - Evaluate(Call(DataType::Handle(), op, args))); + tma_copy = For(loop_var, 0, loop_extent, ForKind::kUnrolled, + Evaluate(Call(DataType::Handle(), op, args, ann))); } else { @@ - Map<String, ObjectRef> ann; - if (TargetIsSm100(T.target) && is_load && - annotations.find("use_2cta") != annotations.end()) { - ann.Set("use_2cta", IntImm(DataType::Int(32), 1)); - } tma_copy = Evaluate(Call(DataType::Handle(), op, args, ann)); }🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/op/copy.cc` around lines 1681 - 1713, The current code only sets the "use_2cta" annotation for the unsplit tma_load branch and coerces any presence to 1, so split (loop) loads or explicit use_2cta=0 are lost; update both branches that construct the tma_call (the For unrolled branch that assigns tma_copy and the unsplit branch where ann is built) to read the existing annotations map for "use_2cta" and propagate its actual IntImm value into the Call's annotation map (for the loop branch create and pass an ann map the same way you do in the unsplit branch), and only set "use_2cta" to 1 when it is absent and TargetIsSm100(T.target) && is_load requires it, preserving any preexisting 0/1 value.
♻️ Duplicate comments (6)
tilelang/language/tir/op.py (1)
1157-1160: ⚠️ Potential issue | 🟠 Major

Update the docstring and consider protecting optional parameters from accidental positional binding.

The new `enable_2cta` parameter insertion shifts positional binding for `ws` and `warp_specialized`. More critically, the docstring at lines 1165-1167 is now outdated: it claims the optional 15th positional argument is `enable_ws`, but with the new parameter order, passing 15 positional arguments now binds to `enable_2cta` instead.

The docstring should be updated to document the new parameter and clarify which parameters are intended for positional vs. keyword-only use. Alternatively, use a `*` separator before the optional parameters to enforce keyword-only access and prevent silent positional-binding shifts in the future:

🛠️ Suggested fix

```diff
 def ptx_tcgen05_mma_ss(
     kind_dtype,
     desc_a,
@@
     mask1,
     mask2,
     mask3,
+    *,
     enable_ws=False,
     enable_2cta=False,
     ws=None,
     warp_specialized=None,
     variant=None,
 ):
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/language/tir/op.py` around lines 1157 - 1160, The docstring and parameter order were changed when adding the new enable_2cta parameter, so update the docstring to reflect that the 15th positional argument is now enable_2cta (not enable_ws) and clarify which parameters are intended to be positional vs keyword-only; to prevent future accidental positional binding, make the optional flags keyword-only by adding a keyword-only separator (*) before enable_ws (so enable_ws, enable_2cta, ws, and warp_specialized must be passed by keyword), and adjust the docstring text to list and describe enable_2cta, enable_ws, ws, and warp_specialized in their final order and note that they are keyword-only.

tilelang/tileop/gemm/gemm_tcgen05.py (1)
142-157: ⚠️ Potential issue | 🟠 Major

Convert `enable_2cta` before using it in this boolean guard.

This guard relies on Python truthiness, but `get_tcgen5_mma_meta()` is consumed as boxed TVM scalars elsewhere in this PR. If `enable_2cta` stays boxed here, `not enable_2cta` can test object truthiness instead of the flag value and disable the 1CTA path incorrectly.

Verify by comparing this call site with other `get_tcgen5_mma_meta()` consumers; the expected result is that callers coerce the returned elements with `int(...)` before boolean logic.

```bash
#!/bin/bash
set -euo pipefail
sed -n '100,160p' tilelang/tileop/gemm/gemm_tcgen05.py
sed -n '620,635p' tilelang/intrinsics/tcgen05_macro_generator.py
rg -n "get_tcgen5_mma_meta\(" --type py -C 2
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/tileop/gemm/gemm_tcgen05.py` around lines 142 - 157, The boolean guard uses enable_2cta which may be a boxed TVM scalar; coerce it to a Python int before using in cluster_cond (e.g. replace enable_2cta with int(enable_2cta)) so not enable_2cta evaluates the actual flag value, ensuring the 1CTA path is not incorrectly disabled; update the cluster_cond definition referenced by _gemm_ss_elect_one_thread and _gemm_ss and ensure any similar uses (e.g. comparisons with warp_size or thread_bounds.extent) follow the same int(...) coercion pattern.

examples/gemm_sm100/gemm_tcgen5mma_ws_2sm.py (1)
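The boxed-scalar truthiness pitfall flagged for `gemm_tcgen05.py` above can be reproduced in plain Python. `BoxedInt` is a hypothetical stand-in for a boxed TVM scalar wrapper, not a real TVM class:

```python
class BoxedInt:
    """Hypothetical stand-in for a boxed TVM scalar wrapper."""

    def __init__(self, value):
        self.value = value

    def __int__(self):
        return self.value
    # No __bool__/__len__ defined, so `not box` falls back to object
    # truthiness, which is True for every instance regardless of value.

enable_2cta = BoxedInt(0)              # the flag is logically off
assert (not enable_2cta) is False      # truthiness test gives the wrong answer
assert (not int(enable_2cta)) is True  # coercing first recovers the real flag
```

This is why callers are expected to wrap the returned meta elements in `int(...)` before any boolean logic.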
108-116: ⚠️ Potential issue | 🟠 Major

Only the leader CTA should own the final writeback.

Both CTAs wait on `tmem_full` and then execute these stores, so the follower duplicates output traffic and races the leader on `C`. Gate both the raw CUDA path and the TileLang writeback with `block_rank_in_cluster() == 0` / `cta_id == 0`.

Possible fix

```diff
-  if (((int)threadIdx.x) == 0) {
+  if (((int)threadIdx.x) == 0 && tl::block_rank_in_cluster() == 0) {
     tl::fence_proxy_async();
     tl::tma_store(C_desc, (&(((bfloat16_t*)buf_dyn_shmem)[0])), (((int)blockIdx.y) * 256), (((int)blockIdx.x) * 128));
     tl::tma_store(C_desc, (&(((bfloat16_t*)buf_dyn_shmem)[8192])), ((((int)blockIdx.y) * 256) + 64), (((int)blockIdx.x) * 128));
     tl::tma_store(C_desc, (&(((bfloat16_t*)buf_dyn_shmem)[16384])), ((((int)blockIdx.y) * 256) + 128), (((int)blockIdx.x) * 128));
     tl::tma_store(C_desc, (&(((bfloat16_t*)buf_dyn_shmem)[24576])), ((((int)blockIdx.y) * 256) + 192), (((int)blockIdx.x) * 128));
     tl::tma_store_arrive();
     tl::tma_store_wait<0>();
   }
```

```diff
-            if use_tma_store:
+            if cta_id == 0 and use_tma_store:
                 T.copy(C_local, C_shared)
                 T.copy(C_shared, C[bx * block_M, by * block_N])
-            else:
+            elif cta_id == 0:
                 T.copy(C_local, C_local_cast)
                 T.copy(C_local_cast, C[bx * block_M, by * block_N])
```

Also applies to: 178-183
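The leader-only gating can be illustrated with a host-side sketch (plain Python; `cta_id` and the tile string are stand-ins for the real cluster rank and TMA stores):

```python
def epilogue(cta_id, C, tile):
    # Only the cluster leader (cta_id == 0) commits the final tile;
    # the follower reaches the same point but skips the writeback.
    if cta_id == 0:
        C.append(tile)

C = []
for cta_id in (0, 1):  # both CTAs in the 2-SM cluster reach the epilogue
    epilogue(cta_id, C, "tile(bx, by)")

assert C == ["tile(bx, by)"]  # one writeback, no duplicate traffic or race
```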
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gemm_sm100/gemm_tcgen5mma_ws_2sm.py` around lines 108 - 116, The final writeback to C is executed by every CTA leader causing duplicate writes and races; guard the TileLang writeback sequence (tl::fence_proxy_async(), tl::tma_store(...) calls, tl::tma_store_arrive(), tl::tma_store_wait<0>()) so it only runs when the CTA is the cluster leader (check block_rank_in_cluster() == 0 or cta_id == 0) and apply the same guard to the equivalent raw CUDA writeback path; use the cluster-leader condition around the blocks that reference C_desc and buf_dyn_shmem to ensure only the single leader performs the stores.

src/transform/lower_blackwell_2sm.cc (2)
60-63: ⚠️ Potential issue | 🟠 Major

Use the same AB-dtype rule here as `allowTcgen5Mma()`.

For TS kernels, `A` lives in TMEM and carries the accumulator dtype, so `GetTCGEN5MMAMeta(..., node->a_->dtype, node->c_->dtype)` queries the wrong table. That can miss valid 2CTA TS GEMMs and skip the `use_2cta` annotation entirely.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/transform/lower_blackwell_2sm.cc` around lines 60 - 63, GetTCGEN5MMAMeta is being called with node->a_->dtype and node->c_->dtype which is wrong for TS kernels—follow the same AB-dtype rule used in allowTcgen5Mma: determine the A/B dtypes using the same logic (i.e., for TS kernels treat A as carrying the accumulator dtype / substitute the accumulator dtype where allowTcgen5Mma does) and pass those resolved A/B dtypes into GetTCGEN5MMAMeta (then preserve the existing meta.enable_2cta -> use_2cta behavior).
86-99: ⚠️ Potential issue | 🟠 Major

This still annotates the first post-order `BlockRealize`, not the owning/root block.

Because the mutator recurses before setting `root_block_annotated_`, the first inner `BlockRealize` wins. If the TMEM alloc/deallocs belong to the outer kernel block, `lower_shared_tmem` never sees `use_2cta` where it needs it.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/transform/lower_blackwell_2sm.cc` around lines 86 - 99, The bug annotates the first post-order (inner) BlockRealize because VisitStmt_(const BlockRealizeNode *op) calls the base visitor before marking root_block_annotated_; fix by annotating the owning/root block before recursing: inspect op (the incoming BlockRealizeNode) and, if root_block_annotated_ is false, copy-and-set the annotation (use block = op->block; BlockNode *n = block.CopyOnWrite(); n->annotations.Set(attr::kUse2Cta, IntImm(...)); root_block_annotated_ = true) on that block first, then call StmtExprMutator::VisitStmt_(op) (or construct the visited BlockRealize from the visited children) so lower_shared_tmem will see attr::kUse2Cta on the actual kernel/owning block.

tilelang/intrinsics/tcgen05_macro_generator.py (1)
629-631: ⚠️ Potential issue | 🟠 Major

Don't surface 2CTA meta to TS callers until TS lowering consumes it.

`get_tcgen5_mma_meta()` now enables 2CTA globally, but `tcgen05mma_ts()` still keeps its 1CTA geometry and ends with a 1CTA `T.tcgen05_mma_arrive(mbar)`. A TS kernel that gets `enable_2cta=1` will therefore be cluster-gated upstream yet emit the old single-CTA sequence.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/intrinsics/tcgen05_macro_generator.py` around lines 629 - 631, The TS-facing generator is returning 2CTA-enabled metadata even though tcgen05mma_ts still emits 1CTA sequences; to avoid exposing 2CTA to TS callers, change the call to _ffi_api.get_tcgen5_mma_meta inside the TS generator so it always requests 1CTA meta. Specifically, in tcgen05_macro_generator.py update the code path used by tcgen05mma_ts to pass disable_2cta=True (rather than reading PassConfigKey.TL_DISABLE_2CTA_TCGEN5MMA from pass_ctx) when calling get_tcgen5_mma_meta so TS callers always receive 1CTA metadata until TS lowering supports 2CTA.
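The suggested fix amounts to hard-wiring the TS path to 1-CTA metadata. A schematic Python sketch, where `get_meta` is a simplified stand-in for the real `get_tcgen5_mma_meta` selection rule:

```python
def get_meta(m, n, disable_2cta=False):
    # Simplified stand-in rule: pretend 2-CTA is profitable whenever M is
    # a multiple of 256 and the caller has not disabled it.
    return {"enable_2cta": (not disable_2cta) and m % 256 == 0}

def ts_meta(m, n):
    # TS lowering still emits 1-CTA sequences, so always request 1-CTA
    # meta instead of consulting the pass config.
    return get_meta(m, n, disable_2cta=True)

assert get_meta(256, 128)["enable_2cta"] is True   # SS path may use 2-CTA
assert ts_meta(256, 128)["enable_2cta"] is False   # TS path stays 1-CTA
```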
🧹 Nitpick comments (4)
src/tl_templates/cuda/copy_sm100.h (1)
302-430: Consider extracting common barrier setup logic.

The barrier address computation (lines 310-316, 335-341, etc.) is duplicated across all five overloads. A small inline helper could reduce repetition.

♻️ Suggested helper

```cpp
template <typename BarrierType>
TL_DEVICE uint32_t get_2sm_mbar_addr(BarrierType &smem_mbar) {
  uint32_t smem_int_mbar;
  if constexpr (std::is_pointer_v<BarrierType>) {
    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(smem_mbar));
  } else {
    smem_int_mbar = smem_ptr_to_uint(reinterpret_cast<uint64_t *>(&smem_mbar));
  }
  return smem_int_mbar & Sm100MmaPeerBitMask;
}
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/tl_templates/cuda/copy_sm100.h` around lines 302 - 430, The five tma_load_2sm overloads duplicate barrier address computation; extract that logic into a small inline helper (e.g., get_2sm_mbar_addr<BarrierType>(BarrierType &smem_mbar)) that performs the if constexpr pointer check, calls smem_ptr_to_uint, masks with Sm100MmaPeerBitMask, and returns the uint32_t, then replace the repeated block in each tma_load_2sm overload with a single call to this helper to reduce duplication and improve readability.

examples/gemm_sm100/gemm_tcgen5mma_2sm.py (2)
9-11: Unused `code` parameter is intentional but could use a clearer name.

The `code` argument is intentionally ignored since this postproc replaces the generated code entirely with a hardcoded kernel. Consider renaming to `_code` to suppress the linter warning and signal intent.

```diff
 @register_cuda_postproc
-def _(code, _):
+def _(_code, _):
     return r"""
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gemm_sm100/gemm_tcgen5mma_2sm.py` around lines 9 - 11, The anonymous postprocessing function registered with register_cuda_postproc currently takes a parameter named code that is intentionally unused; rename that parameter to _code in the function signature (the function defined as def _(code, _):) to convey intent and silence linters, keeping the second parameter (_) unchanged and ensuring no other references to the old name exist in the function body.
133-134: FIXME indicates incomplete implementation.

The comment "buggy when size is larger" suggests this 2-SM path has correctness issues at larger problem sizes. This should be tracked and resolved before the feature is considered complete.
Would you like me to open an issue to track this size-dependent bug?
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gemm_sm100/gemm_tcgen5mma_2sm.py` around lines 133 - 134, The FIXME comment indicates a size-dependent correctness bug when M, N, K are large (variables M, N, K in the global scope and the 2-SM code path triggered thereafter); remove the FIXME by creating a tracked issue and add a runtime guard and test: add a failing unit/integration test that reproduces the bug for large M/N/K, open an issue referencing the test and the 2-SM implementation (the same function/class that contains the 2-SM path), and add a temporary runtime check that prevents selecting the 2-SM path for sizes known to fail (with a clear TODO linking the issue) so behavior is deterministic until the root cause is fixed.

src/transform/inject_tma_barrier.cc (1)
554-568: Consider caching cluster_size in block annotations to avoid repeated traversal.

The `PostOrderVisit` traverses the entire function body to find `cluster_dims`. If this pass runs frequently or on large functions, consider extracting this to a shared utility or caching the result.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/transform/inject_tma_barrier.cc` around lines 554 - 568, Compute and cache the derived cluster size instead of re-traversing f->body: before calling PostOrderVisit look up a cached value on the block annotations (e.g., check block->annotations for a "cluster_size" entry) and only run the traversal to compute cluster_size from the "cluster_dims" annotation if the cache is missing; after computing sz from the Array<Integer> from block->annotations.Get("cluster_dims"), store sz back into block->annotations under "cluster_size" so subsequent uses (and functions that use PostOrderVisit/cluster_dims) can read the cached integer and skip the expensive traversal.
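A minimal sketch of the suggested memoization, with plain Python dicts standing in for the block annotation map and a counter standing in for the `PostOrderVisit` traversal cost:

```python
def get_cluster_size(annotations, stats):
    """Return the cached cluster size, traversing only on the first call."""
    if "cluster_size" not in annotations:
        stats["traversals"] = stats.get("traversals", 0) + 1  # expensive walk
        size = 1
        for dim in annotations.get("cluster_dims", (1, 1, 1)):
            size *= dim
        annotations["cluster_size"] = size                    # cache result
    return annotations["cluster_size"]

ann = {"cluster_dims": (2, 1, 1)}
stats = {}
assert get_cluster_size(ann, stats) == 2
assert get_cluster_size(ann, stats) == 2   # served from the cache
assert stats["traversals"] == 1            # body walked exactly once
```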
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 4dc6f224-fe10-4fe6-8a3e-b319e5e6aa54
📒 Files selected for processing (30)
- examples/gemm_sm100/gemm_tcgen5mma_2sm.py
- examples/gemm_sm100/gemm_tcgen5mma_ws.py
- examples/gemm_sm100/gemm_tcgen5mma_ws_2sm.py
- examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py
- src/op/builtin.cc
- src/op/builtin.h
- src/op/copy.cc
- src/op/gemm_py.cc
- src/op/tcgen5_meta.h
- src/target/codegen_cuda.cc
- src/target/codegen_cuda.h
- src/tl_templates/cuda/cluster.h
- src/tl_templates/cuda/common.h
- src/tl_templates/cuda/copy_sm100.h
- src/tl_templates/cuda/instruction/tcgen05mma.h
- src/tl_templates/cuda/tcgen_05.h
- src/tl_templates/cuda/threadblock_swizzle.h
- src/transform/inject_tma_barrier.cc
- src/transform/lower_blackwell_2sm.cc
- src/transform/lower_shared_tmem.cc
- tilelang/engine/phase.py
- tilelang/intrinsics/tcgen05_macro_generator.py
- tilelang/language/annotations.py
- tilelang/language/builtin.py
- tilelang/language/cluster.py
- tilelang/language/tir/op.py
- tilelang/layout/swizzle.py
- tilelang/tileop/gemm/gemm_tcgen05.py
- tilelang/transform/__init__.py
- tilelang/transform/pass_config.py
🚧 Files skipped from review as they are similar to previous changes (9)
- src/op/builtin.cc
- tilelang/language/annotations.py
- src/op/gemm_py.cc
- tilelang/transform/pass_config.py
- examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py
- src/tl_templates/cuda/threadblock_swizzle.h
- src/target/codegen_cuda.cc
- src/target/codegen_cuda.h
- tilelang/engine/phase.py
```cpp
#if defined(__CUDA_ARCH__)
#define TILELANG_UNREACHABLE(msg)                                              \
  do {                                                                         \
    printf("%s, %s:%d\n", msg, __FILE__, __LINE__);                            \
    __trap();                                                                  \
  } while (0)
#elif defined(__CUDACC_RTC__)
#define TILELANG_UNREACHABLE(msg)                                              \
  do {                                                                         \
    __builtin_trap();                                                          \
  } while (0)
```
🧩 Analysis chain
🏁 Script executed:
```bash
# First, let's look at the file and surrounding context
cat -n src/tl_templates/cuda/common.h | head -100
```

Repository: tile-ai/tilelang
Length of output: 4549
🏁 Script executed:
```bash
# Search for other uses of these macros in the codebase to understand the pattern
rg "__CUDACC_RTC__|__CUDA_ARCH__" --type h --type cpp -B 2 -A 2
```

Repository: tile-ai/tilelang
Length of output: 7481
🏁 Script executed:
```bash
# Check if there are any comments or documentation about this macro
rg "TILELANG_UNREACHABLE|__builtin_trap|printf.*trap" --type h --type cpp -B 3 -A 3
```

Repository: tile-ai/tilelang
Length of output: 5186
🌐 Web query:
NVRTC __CUDACC_RTC__ __CUDA_ARCH__ macro definitions CUDA runtime compilation
💡 Result:
In NVRTC (CUDA runtime compilation), these macros behave essentially the same as in offline nvcc compilation, with a couple of NVRTC-specific points:
- `__CUDACC_RTC__`: Defined by NVRTC (useful to distinguish NVRTC from offline `nvcc` in `#if` logic). (docs.nvidia.com)
- `__CUDA_ARCH__`: Defined during NVRTC device compilation with the same semantics as `nvcc`. Its value is tied to the compile target you pass via NVRTC's `--gpu-architecture`/`-arch` option. (docs.nvidia.com)
- Mapping: if you compile for `compute_<version>`, then `__CUDA_ARCH__ = <version> * 10` (e.g., `compute_80` → 800). (docs.nvidia.cn)

Practical NVRTC example:

- Compile with `-arch=compute_80` (or `--gpu-architecture=compute_80`) ⇒ device code sees `__CUDA_ARCH__ == 800`. (docs.nvidia.com)
Citations:
- 1: https://docs.nvidia.com/cuda/archive/12.9.0/nvrtc/index.html
- 2: https://docs.nvidia.com/cuda/archive/12.9.0/nvrtc/index.html
- 3: https://docs.nvidia.cn/cuda/cuda-programming-guide/05-appendices/cpp-language-extensions.html?utm_source=openai
- 4: https://docs.nvidia.com/cuda/archive/11.5.0/pdf/NVRTC_User_Guide.pdf
Check __CUDACC_RTC__ before __CUDA_ARCH__.
NVRTC defines both macros during device compilation. The current ordering causes the __CUDA_ARCH__ branch to match first, making the __CUDACC_RTC__ branch unreachable. This routes NVRTC builds through the printf path, but <cstdio> is explicitly excluded when __CUDACC_RTC__ is defined (see line 3), causing compilation issues. The RTC-specific __builtin_trap() path must be checked first.
🛠️ Suggested fix

```diff
-#if defined(__CUDA_ARCH__)
-#define TILELANG_UNREACHABLE(msg)                                              \
-  do {                                                                         \
-    printf("%s, %s:%d\n", msg, __FILE__, __LINE__);                            \
-    __trap();                                                                  \
-  } while (0)
-#elif defined(__CUDACC_RTC__)
+#if defined(__CUDACC_RTC__)
 #define TILELANG_UNREACHABLE(msg)                                              \
   do {                                                                         \
     __builtin_trap();                                                          \
   } while (0)
+#elif defined(__CUDA_ARCH__)
+#define TILELANG_UNREACHABLE(msg)                                              \
+  do {                                                                         \
+    printf("%s, %s:%d\n", msg, __FILE__, __LINE__);                            \
+    __trap();                                                                  \
+  } while (0)
 #else
```
Verify each finding against the current code and only fix it if needed.
In `@src/tl_templates/cuda/common.h` around lines 62 - 72, The
TILELANG_UNREACHABLE macro currently checks __CUDA_ARCH__ before __CUDACC_RTC__,
so NVRTC builds (which define both) take the printf/__trap branch and fail
because <cstdio> is excluded; reorder the preprocessor checks in
src/tl_templates/cuda/common.h so the __CUDACC_RTC__ check comes first and
defines TILELANG_UNREACHABLE to use __builtin_trap(), leaving the __CUDA_ARCH__
branch after it to retain the printf + __trap() behavior for real device builds.
```cpp
  Target target_;
  bool has_2sm_tcgen5mma_ = false;
```
A single function-wide has_2sm_tcgen5mma_ flag is too coarse.
Once any GEMM qualifies, the pass annotates the root block and pushes every downstream TMEM alloc/dealloc onto the 2CTA path. If a PrimFunc mixes 1CTA and 2CTA TCGEN5MMA, the 1CTA path is silently mis-lowered instead of being rejected.
Also applies to: 121-124
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/transform/lower_blackwell_2sm.cc` around lines 77 - 78, The single field
has_2sm_tcgen5mma_ is too coarse and causes cross-contamination between blocks;
replace it with a scope-local flag (e.g., a bool parameter or a per-block keyed
map) that is set and checked per block/PrimFunc during traversal instead of a
single class-wide member (references: has_2sm_tcgen5mma_, Target target_);
propagate this local flag through the helper methods that perform TMEM
alloc/dealloc lowering so 1CTA vs 2CTA decisions are made per-block, and add an
explicit detection path that rejects/makes an error when a single PrimFunc mixes
1CTA and 2CTA TCGEN5MMA (also update the logic around the occurrences noted at
lines 121-124 to use the new per-block/state flag).
```diff
 def make_tcgen05mma_swizzled_layout(buffer: BufferLikeType, continuity: int = None, k_major: bool = True):
-    _, shape, _ = _get_buffer_info(buffer)
-    stride, continuous = _get_stride_continuous(buffer)
+    buf, shape, _ = _get_buffer_info(buffer)
+    # Use only the last two dims, matching C++ GetSwizzleShapeInfoChecked (uses shape[-2] as stride).
+    # Use buffer-based functions (make_full/half/quarter_bank_swizzled_layout) which call
+    # base.Expand(leading_dims) instead of reshape, correctly handling staged (3D+) buffers.
+    continuous = int(shape[-1])
     element_size = _get_element_size(buffer)
-    if continuity is None:
-        continuity = continuous
-    base = _ffi_api.make_tcgen05mma_swizzled_layout(
-        stride,
-        continuous,
-        continuity,
-        element_size,
-        k_major,
-    )
-    return base.reshape(shape)
+    vector_size = 128 // element_size
+    stride = int(shape[-2])
+    if stride % 8 == 0:
+        if continuous % (vector_size * 8) == 0:
+            return _ffi_api.make_full_bank_swizzled_layout(buf)
+        elif continuous % (vector_size * 4) == 0:
+            return _ffi_api.make_half_bank_swizzled_layout(buf)
+        elif continuous % (vector_size * 2) == 0:
+            return _ffi_api.make_quarter_bank_swizzled_layout(buf)
+    return make_linear_layout(buffer)
```
Unused continuity and k_major parameters will silently ignore caller intent.
The function signature retains continuity and k_major parameters, but the new implementation never uses them. Callers passing these arguments (e.g., make_tcgen05mma_swizzled_layout(buf, continuity=32, k_major=False)) will have their intent silently ignored.
Options:
- Remove the unused parameters if they are no longer needed.
- Emit a deprecation warning if the API must remain stable for now.
- Incorporate them into the layout selection logic if they should affect behavior.
Additionally, shape[-2] on line 119 will raise IndexError for 1D buffers. Consider adding a guard or documenting that TCGEN05MMA requires ≥2D buffers.
Suggested fix (option 2 – deprecation warning)

```diff
 def make_tcgen05mma_swizzled_layout(buffer: BufferLikeType, continuity: int = None, k_major: bool = True):
     buf, shape, _ = _get_buffer_info(buffer)
+    if continuity is not None or k_major is not True:
+        import warnings
+        warnings.warn(
+            "Parameters 'continuity' and 'k_major' are deprecated and ignored in make_tcgen05mma_swizzled_layout",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+    if len(shape) < 2:
+        return make_linear_layout(buffer)
     # Use only the last two dims, matching C++ GetSwizzleShapeInfoChecked (uses shape[-2] as stride).
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
def make_tcgen05mma_swizzled_layout(buffer: BufferLikeType, continuity: int = None, k_major: bool = True):
    buf, shape, _ = _get_buffer_info(buffer)
    if continuity is not None or k_major is not True:
        import warnings
        warnings.warn(
            "Parameters 'continuity' and 'k_major' are deprecated and ignored in make_tcgen05mma_swizzled_layout",
            DeprecationWarning,
            stacklevel=2,
        )
    if len(shape) < 2:
        return make_linear_layout(buffer)
    # Use only the last two dims, matching C++ GetSwizzleShapeInfoChecked (uses shape[-2] as stride).
    # Use buffer-based functions (make_full/half/quarter_bank_swizzled_layout) which call
    # base.Expand(leading_dims) instead of reshape, correctly handling staged (3D+) buffers.
    continuous = int(shape[-1])
    element_size = _get_element_size(buffer)
    vector_size = 128 // element_size
    stride = int(shape[-2])
    if stride % 8 == 0:
        if continuous % (vector_size * 8) == 0:
            return _ffi_api.make_full_bank_swizzled_layout(buf)
        elif continuous % (vector_size * 4) == 0:
            return _ffi_api.make_half_bank_swizzled_layout(buf)
        elif continuous % (vector_size * 2) == 0:
            return _ffi_api.make_quarter_bank_swizzled_layout(buf)
    return make_linear_layout(buffer)
```
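The tier selection under review reduces to plain modulo arithmetic. A standalone sketch of that rule, outside the codebase (the helper name `pick_swizzle_tier` and treating `element_size` as a bit width are assumptions for illustration, not tilelang API):

```python
def pick_swizzle_tier(stride: int, continuous: int, element_size_bits: int) -> str:
    """Mirror of the modulo checks: pick the widest bank swizzle that divides evenly."""
    vector_size = 128 // element_size_bits
    if stride % 8 == 0:
        if continuous % (vector_size * 8) == 0:
            return "full"
        if continuous % (vector_size * 4) == 0:
            return "half"
        if continuous % (vector_size * 2) == 0:
            return "quarter"
    return "linear"

# For 16-bit elements, vector_size = 8, so a 64-wide contiguous dim hits "full",
# 32-wide "half", 16-wide "quarter"; a stride not divisible by 8 falls back to linear.
```

This also makes the 1D-buffer concern visible: the tier choice needs both `shape[-1]` and `shape[-2]`, so buffers with fewer than two dims need the early fallback.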
🧰 Tools
🪛 Ruff (0.15.4)
[warning] 111-111: Unused function argument: continuity
(ARG001)
[warning] 111-111: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
[warning] 111-111: Unused function argument: k_major
(ARG001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tilelang/layout/swizzle.py` around lines 111 - 127, The function
make_tcgen05mma_swizzled_layout currently ignores the continuity and k_major
parameters and will IndexError on 1D buffers via shape[-2]; update it to (1)
emit a deprecation/warning when continuity or k_major are passed (use the
continuity and k_major parameter names in the warning so callers can find their
usage) if you must keep the API, and (2) guard the shape access by checking
len(shape) >= 2 before using shape[-2] (if not, either raise a clear ValueError
or fall back to make_linear_layout) so 1D buffers don’t crash; reference
make_tcgen05mma_swizzled_layout, the continuity and k_major parameters, and the
shape[-2] access when making these changes.
Actionable comments posted: 3
🧹 Nitpick comments (2)
examples/gemm_sm100/gemm_tcgen5mma_ws_2sm_persistent.py (2)
36-36: Consider documenting why `group_size = 5` differs from the sibling file.

The persistent 1-SM version in `gemm_tcgen5mma_ws_persistent.py` uses `group_size = 8`, while this 2-SM version uses `group_size = 5`. If this is intentional for optimal cluster scheduling, a brief comment explaining the rationale would help maintainability.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gemm_sm100/gemm_tcgen5mma_ws_2sm_persistent.py` at line 36, Add a brief explanatory comment next to the group_size = 5 assignment explaining why it differs from the sibling gemm_tcgen5mma_ws_persistent.py (which uses group_size = 8); mention whether the value is chosen for 2-SM cluster scheduling, occupancy, or empirical tuning and include the rationale (e.g., optimal scheduling for 2 SMs vs 1 SM) so future maintainers understand the divergence; update the comment in the gemm_tcgen5mma_ws_2sm_persistent.py near the group_size variable to state this decision.
55-75: Duplicated tile ID calculation across all three warp sections.

The tile ID and block index calculations (lines 57-64, 79-86, 117-124) are nearly identical across warp 0, warp 1, and epilogue sections. While this is acceptable for an example file, extracting this to a helper or using shared variables could improve maintainability.
Also applies to: 77-86, 115-124
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gemm_sm100/gemm_tcgen5mma_ws_2sm_persistent.py` around lines 55 - 75, The tile and block index math (cluster_id, num_clusters, tile_id, m_clusters, bx_cluster, by, bx) is duplicated across the warp 0 / warp 1 / epilogue sections; refactor by extracting a small helper (e.g., compute_tile_indices(block_id, sm_num, w, group_size, m_blocks, cta_id, block_M, block_N) or compute_tile_indices_for_warp) that returns tile_id, bx, by (and any derived values), and call that helper from each warp branch (or compute once and reuse when w loop is shared) so the logic in the loops using tx, T.copy, and T.mbarrier_* reuses the single implementation instead of repeating the calculations.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f9e7f29e-3473-435b-90da-66ddc1f86998
📒 Files selected for processing (2)
- examples/gemm_sm100/gemm_tcgen5mma_ws_2sm_persistent.py
- src/transform/inject_tma_barrier.cc
```diff
@@ -0,0 +1,172 @@
+# Persistent, 1-SM, num_epi_stages = 2
```
Comment incorrectly describes the implementation as "1-SM".
The comment says "Persistent, 1-SM" but this file implements a 2-SM kernel (cluster_dims=2 at line 38). This should be corrected to avoid confusion.
📝 Suggested fix

```diff
-# Persistent, 1-SM, num_epi_stages = 2
+# Persistent, 2-SM, num_epi_stages = 2
```

📝 Committable suggestion
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/gemm_sm100/gemm_tcgen5mma_ws_2sm_persistent.py` at line 1, The
top-of-file comment incorrectly says "1-SM" while the implementation uses
cluster_dims=2 (a 2-SM kernel); update the header comment to reflect "2-SM" (or
"Persistent, 2-SM, num_epi_stages = 2") so it matches the implementation and
remove any other stray references to 1-SM; search for cluster_dims and the
persistent kernel setup (e.g., where cluster_dims is set) to confirm consistency
and update the comment accordingly.
```cpp
  std::vector<bool> is_cluster_;
  std::vector<arith::IntSet> int_sets_;
  std::vector<int> restore_barrier_ids_;
  bool has_simt_copy_{false};
  int if_depth_{0};
  Map<ObjectRef, PrimExpr> tma_op_to_barrier_id_;
  Map<Var, arith::IntSet> var_int_set_;
};
```
🧩 Analysis chain
🏁 Script executed:

```bash
#!/bin/bash
# Check if var_int_set_ is populated anywhere in TmaSequenceCollector
rg -nA10 'class TmaSequenceCollector' src/transform/inject_tma_barrier.cc | head -80
# Also check for any VisitStmt_ for ForNode in the class
ast-grep --pattern $'class TmaSequenceCollector {
  $$$
  VisitStmt_($_, ForNode $_) {
    $$$
  }
  $$$
}'
```

Repository: tile-ai/tilelang
Length of output: 505
🏁 Script executed:

```bash
# Read the TmaSequenceCollector class definition (lines ~267 to ~360)
sed -n '267,360p' src/transform/inject_tma_barrier.cc | cat -n
```

Repository: tile-ai/tilelang
Length of output: 3986
🏁 Script executed:

```bash
# Search for all references to var_int_set_ in the file
rg 'var_int_set_' src/transform/inject_tma_barrier.cc -n
```

Repository: tile-ai/tilelang
Length of output: 321
🏁 Script executed:

```bash
# Look at line 320 to see the EvalSet call mentioned in the review
sed -n '315,325p' src/transform/inject_tma_barrier.cc | cat -n
```

Repository: tile-ai/tilelang
Length of output: 597
🏁 Script executed:

```bash
# Search for VisitStmt_ implementations that might populate var_int_set_
rg 'VisitStmt_' src/transform/inject_tma_barrier.cc -n
```

Repository: tile-ai/tilelang
Length of output: 1378
var_int_set_ is never populated in TmaSequenceCollector, causing EvalSet to evaluate barrier ID expressions without loop variable bounds.
The var_int_set_ member is declared but has no mechanism to be populated in TmaSequenceCollector. Unlike related classes in the same file which implement VisitStmt_(const ForNode *op) to populate var_int_set_ with loop variable bounds, TmaSequenceCollector lacks this visitor. As a result, when EvalSet is called at line 320 with an empty var_int_set_, loop-dependent barrier IDs may not be evaluated correctly.
Verify whether barrier IDs in this context are guaranteed to be loop-invariant. If not, TmaSequenceCollector should implement VisitStmt_(const ForNode *op) to populate loop variable bounds.
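For intuition, the missing bookkeeping amounts to "record each loop variable's interval before asking for an expression's bounds". A toy Python model of that contract (the name `eval_affine_range` is illustrative, not the TVM `arith` API):

```python
def eval_affine_range(coeff, var, offset, var_ranges):
    """Bounds of coeff*var + offset when var ranges over [lo, hi] (assumes coeff >= 0)."""
    if var not in var_ranges:
        # This is the failure mode above: evaluating with an empty domain map.
        raise KeyError(f"no recorded bounds for loop var {var!r}")
    lo, hi = var_ranges[var]
    return (coeff * lo + offset, coeff * hi + offset)

var_ranges = {}
# What a VisitStmt_(const ForNode*) would record for `for i in range(0, 4)`:
var_ranges["i"] = (0, 3)
bounds = eval_affine_range(2, "i", 1, var_ranges)  # barrier_id = 2*i + 1
```

Without the `var_ranges["i"]` entry, the lookup fails instead of silently returning a degenerate interval, which is the behavior the review asks the collector to avoid.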
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/transform/inject_tma_barrier.cc` around lines 351 - 358,
TmaSequenceCollector currently leaves var_int_set_ empty so EvalSet may evaluate
barrier ID expressions without loop variable bounds; implement VisitStmt_(const
ForNode *op) in TmaSequenceCollector to mirror the loop-bound population logic
used in related collectors (capture loop var -> arith::IntSet via
arith::Analyzer or existing helper methods) so var_int_set_ is populated before
EvalSet runs, and if barrier IDs are guaranteed loop-invariant then add an
explicit assertion/comment in TmaSequenceCollector to document that assumption
instead of leaving var_int_set_ empty.
```cpp
// Compute total cluster size from the "cluster_dims" block annotation
int cluster_size = 1;
PostOrderVisit(f->body, [&](const ObjectRef &node) {
  if (const auto *block = node.as<BlockNode>()) {
    if (block->annotations.count("cluster_dims")) {
      if (auto arr = block->annotations.Get("cluster_dims")
                         ->try_cast<Array<Integer>>()) {
        int sz = 1;
        for (auto d : arr.value())
          sz *= static_cast<int>(d->value);
        cluster_size = sz;
      }
    }
  }
});
```
Last-writer-wins for cluster_dims if multiple blocks define it.
If multiple BlockNodes carry different cluster_dims annotations, this loop silently keeps the value from the last block visited (in post-order). Consider breaking out of the visitor after the first match, or asserting that all blocks share the same value.
Proposed fix

```diff
 int cluster_size = 1;
+bool cluster_size_set = false;
 PostOrderVisit(f->body, [&](const ObjectRef &node) {
   if (const auto *block = node.as<BlockNode>()) {
     if (block->annotations.count("cluster_dims")) {
       if (auto arr = block->annotations.Get("cluster_dims")
                          ->try_cast<Array<Integer>>()) {
         int sz = 1;
         for (auto d : arr.value())
           sz *= static_cast<int>(d->value);
+        if (cluster_size_set) {
+          ICHECK_EQ(cluster_size, sz)
+              << "Conflicting cluster_dims across blocks";
+        }
         cluster_size = sz;
+        cluster_size_set = true;
       }
     }
   }
 });
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```cpp
// Compute total cluster size from the "cluster_dims" block annotation
int cluster_size = 1;
bool cluster_size_set = false;
PostOrderVisit(f->body, [&](const ObjectRef &node) {
  if (const auto *block = node.as<BlockNode>()) {
    if (block->annotations.count("cluster_dims")) {
      if (auto arr = block->annotations.Get("cluster_dims")
                         ->try_cast<Array<Integer>>()) {
        int sz = 1;
        for (auto d : arr.value())
          sz *= static_cast<int>(d->value);
        if (cluster_size_set) {
          ICHECK_EQ(cluster_size, sz)
              << "Conflicting cluster_dims across blocks";
        }
        cluster_size = sz;
        cluster_size_set = true;
      }
    }
  }
});
```
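The first-match-plus-consistency pattern proposed here is easy to check in isolation. A plain-Python sketch (the function name `collect_cluster_size` is hypothetical, not code from this PR):

```python
def collect_cluster_size(annotated_dims):
    """First match wins; later matches must agree, mirroring the ICHECK_EQ proposal."""
    cluster_size = None
    for dims in annotated_dims:  # each entry: a block's cluster_dims tuple, or None
        if dims is None:
            continue
        sz = 1
        for d in dims:
            sz *= d
        if cluster_size is not None and cluster_size != sz:
            raise ValueError(f"Conflicting cluster_dims across blocks: {cluster_size} vs {sz}")
        cluster_size = sz
    return cluster_size if cluster_size is not None else 1
```

With last-writer-wins, `[(2,), (4,)]` would silently yield 4; with the consistency check it fails loudly, which is the behavior the review asks for.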
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/transform/inject_tma_barrier.cc` around lines 557 - 571, The current
PostOrderVisit loop silently overwrites cluster_size when multiple BlockNode
annotations "cluster_dims" exist; change it to record the first discovered size
(e.g., use an optional<int> found_size or a bool seen flag), compute sz the same
way when encountering a BlockNode with "cluster_dims", and then either (a)
short-circuit further processing (stop visiting) or (b) assert that subsequent
matches produce the same sz (fail on mismatch). Reference PostOrderVisit,
BlockNode, the "cluster_dims" annotation, cluster_size, and f->body when making
this change so the loop either breaks after the first match or verifies equality
before assigning cluster_size.
Actionable comments posted: 2
♻️ Duplicate comments (2)
src/transform/lower_blackwell_2sm.cc (2)
77-78: ⚠️ Potential issue | 🟠 Major

A single function-wide flag doesn't handle mixed 1CTA/2CTA scenarios.

The `has_2sm_tcgen5mma_` flag is set when any GEMM qualifies for 2CTA. If a kernel mixes 1CTA and 2CTA TCGEN5MMA operations, the downstream TMEM alloc/dealloc could be mis-lowered. The TODO on line 13 acknowledges this limitation. Consider either:
- Rejecting mixed 1CTA/2CTA kernels with an error, or
- Tracking per-GEMM state to ensure all GEMMs in the kernel agree on 2CTA mode
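The first option (reject mixed kernels) is small enough to sketch outside the pass; `check_cta_modes` is a hypothetical helper for illustration, not code from this PR:

```python
def check_cta_modes(gemm_modes):
    """All TCGEN5MMA ops in one PrimFunc must agree on 1-CTA vs 2-CTA."""
    modes = set(gemm_modes)
    if len(modes) > 1:
        raise ValueError(f"Mixed 1CTA/2CTA TCGEN5MMA in one kernel: {sorted(modes)}")
    return modes.pop() if modes else 1  # default to 1-CTA when no GEMMs are present
```

This fail-fast check is strictly safer than a single boolean: a kernel with modes `[1, 2]` errors out instead of being lowered entirely down the 2CTA path.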
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/transform/lower_blackwell_2sm.cc` around lines 77 - 78, The single function-wide flag has_2sm_tcgen5mma_ in lower_blackwell_2sm.cc is unsafe for kernels that mix 1CTA and 2CTA GEMMs because downstream TMEM alloc/dealloc lowering will be incorrect; change the implementation to either (A) detect mixed-mode kernels and fail fast with a clear error when you encounter conflicting GEMM modes, or (B) replace the global flag with per-GEMM bookkeeping so each GEMM's 2CTA/1CTA choice is recorded and consulted during TMEM allocation/deallocation lowering; locate references to has_2sm_tcgen5mma_, the kernel walk where GEMMs are inspected, and the TMEM alloc/dealloc lowering code and update them to use the per-GEMM state (or add the early-reject logic) so mixed-mode kernels are not mis-lowered.
60-62: ⚠️ Potential issue | 🟠 Major

Use the same AB dtype rule as `allowTcgen5Mma()`.

For TS kernels where A is in `shared.tmem`, the A buffer carries the accumulator dtype, not the input dtype. The current code passes `node->a_->dtype` unconditionally, but `allowTcgen5Mma()` in `gemm_py.cc` uses `(a_.scope() == "shared.tmem") ? b_->dtype : a_->dtype`. This mismatch will cause the detector to miss valid 2CTA TS GEMMs.

🔧 Proposed fix

```diff
   if (node->allowTcgen5Mma(target_)) {
+    // For TS variant (A from TMEM), use B's dtype as the input dtype
+    DataType ab_dtype = (node->a_.scope() == "shared.tmem")
+                            ? node->b_->dtype
+                            : node->a_->dtype;
     auto [ok, meta] = GetTCGEN5MMAMeta(node->m_, node->n_, node->k_,
-                                       node->a_->dtype, node->c_->dtype);
+                                       ab_dtype, node->c_->dtype);
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/transform/lower_blackwell_2sm.cc` around lines 60 - 62, The call to GetTCGEN5MMAMeta currently passes node->a_->dtype unconditionally; change it to follow the same AB dtype rule used by allowTcgen5Mma(): compute dtype_for_a = (node->a_->scope() == "shared.tmem") ? node->b_->dtype : node->a_->dtype and pass dtype_for_a instead of node->a_->dtype so A-in-shared TS kernels use the accumulator dtype when querying GetTCGEN5MMAMeta.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py`:
- Around line 151-159: The code assumes M tiles are paired but doesn't enforce
that m_blocks is an even, non-zero number; update gemm_persistent_2cta to
validate this by asserting m_blocks > 1 and m_blocks % 2 == 0 (or otherwise
ensure m_clusters is computed from an even m_blocks) so the pairing logic that
computes m_clusters and the downstream pairing use (the code around where
m_blocks -> m_clusters and the pairing at the locations referenced originally)
cannot produce an invalid modulo or drop a tail tile. Add that assertion
immediately after m_blocks is computed and before m_clusters is used.
In `@examples/gemm_sm100/gemm_tcgen5mma_ws.py`:
- Around line 84-86: The CTA-local barrier tmem_full is used for multi-CTA
arrival (T.tcgen05_mma_arrive(..., arrive_2cta=True)) and must be
cluster-visible: change tmem_full allocation from T.alloc_barrier([1]) to a
cluster barrier (e.g., T.alloc_cluster_barrier([1] * num_stages) or a single
cluster-visible barrier) so multiple CTAs can wait on it; also restore the TMEM
read synchronization by inserting a T.sync_threads() just before the TMEM
read/copy that follows the wait so the 2-CTA path matches the 1-CTA and
persistent 2-CTA sync behavior (references: loaded, consumed, tmem_full, and
T.tcgen05_mma_arrive).
---
Duplicate comments:
In `@src/transform/lower_blackwell_2sm.cc`:
- Around line 77-78: The single function-wide flag has_2sm_tcgen5mma_ in
lower_blackwell_2sm.cc is unsafe for kernels that mix 1CTA and 2CTA GEMMs
because downstream TMEM alloc/dealloc lowering will be incorrect; change the
implementation to either (A) detect mixed-mode kernels and fail fast with a
clear error when you encounter conflicting GEMM modes, or (B) replace the global
flag with per-GEMM bookkeeping so each GEMM's 2CTA/1CTA choice is recorded and
consulted during TMEM allocation/deallocation lowering; locate references to
has_2sm_tcgen5mma_, the kernel walk where GEMMs are inspected, and the TMEM
alloc/dealloc lowering code and update them to use the per-GEMM state (or add
the early-reject logic) so mixed-mode kernels are not mis-lowered.
- Around line 60-62: The call to GetTCGEN5MMAMeta currently passes
node->a_->dtype unconditionally; change it to follow the same AB dtype rule used
by allowTcgen5Mma(): compute dtype_for_a = (node->a_->scope() == "shared.tmem")
? node->b_->dtype : node->a_->dtype and pass dtype_for_a instead of
node->a_->dtype so A-in-shared TS kernels use the accumulator dtype when
querying GetTCGEN5MMAMeta.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 87476320-4ad0-48a0-8182-67f992cc4474
📒 Files selected for processing (3)
- examples/gemm_sm100/gemm_tcgen5mma_ws.py
- examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py
- src/transform/lower_blackwell_2sm.cc
```python
num_clusters = sm_num // 2
m_blocks = T.ceildiv(M, block_M)
m_clusters = m_blocks // 2
n_blocks = T.ceildiv(N, block_N)
assert K % (2 * block_K) == 0  # for simplicity
k_blocks = T.ceildiv(K, block_K)
waves = T.ceildiv(m_blocks * n_blocks, sm_num)
group_size = 8  # in cluster
assert N % (2 * group_size) == 0  # Please adjust group_size if not satisfied
```
Require an even, non-zero number of clustered M tiles in gemm_persistent_2cta.
Line 153 collapses m_blocks into pairs, but Lines 205-206 assume every M tile belongs to one of those pairs. If m_blocks == 1, % m_clusters is invalid; if m_blocks is odd, the tail M tile is silently dropped. The new Line 159 assert checks N, so it does not protect this path.
Suggested fix

```diff
     m_blocks = T.ceildiv(M, block_M)
+    assert m_blocks >= 2 and m_blocks % 2 == 0, \
+        "gemm_persistent_2cta requires an even number of M tiles"
     m_clusters = m_blocks // 2
```

Also applies to: 205-206
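The failure mode is easy to see in a toy pairing function (the helper `pair_m_tiles` is illustrative, not code from the example): with the guard, an odd tile count fails loudly instead of silently dropping the tail tile.

```python
def pair_m_tiles(m_blocks: int):
    """Each 2-CTA cluster covers two adjacent M tiles; mirrors the suggested assert."""
    assert m_blocks >= 2 and m_blocks % 2 == 0, \
        "gemm_persistent_2cta requires an even number of M tiles"
    m_clusters = m_blocks // 2
    return [(2 * c, 2 * c + 1) for c in range(m_clusters)]
```

Without the assert, `m_blocks = 3` would yield one cluster covering tiles (0, 1) and tile 2 would never be scheduled.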
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py` around lines 151 - 159,
The code assumes M tiles are paired but doesn't enforce that m_blocks is an
even, non-zero number; update gemm_persistent_2cta to validate this by asserting
m_blocks > 1 and m_blocks % 2 == 0 (or otherwise ensure m_clusters is computed
from an even m_blocks) so the pairing logic that computes m_clusters and the
downstream pairing use (the code around where m_blocks -> m_clusters and the
pairing at the locations referenced originally) cannot produce an invalid modulo
or drop a tail tile. Add that assertion immediately after m_blocks is computed
and before m_clusters is used.
Actionable comments posted: 4
♻️ Duplicate comments (16)
src/tl_templates/cuda/common.h (1)
62-79: ⚠️ Potential issue | 🟠 Major

Check `__CUDACC_RTC__` before `__CUDA_ARCH__`.

Because Line 62 matches `__CUDA_ARCH__` first, the `#elif defined(__CUDACC_RTC__)` on Line 68 is unreachable under NVRTC. NVIDIA documents `__CUDACC_RTC__` as the runtime-compilation discriminator and says `__CUDA_ARCH__` is also defined in NVRTC with the same semantics as offline device compilation, so RTC device builds can hit both. Here that routes RTC through the `printf` branch even though `<cstdio>` is excluded by the guard at Line 3, which can break RTC compilation. (docs.nvidia.com)

🛠️ Suggested fix
```diff
-#if defined(__CUDA_ARCH__)
-#define TILELANG_UNREACHABLE(msg) \
-  do { \
-    printf("%s, %s:%d\n", msg, __FILE__, __LINE__); \
-    __trap(); \
-  } while (0)
-#elif defined(__CUDACC_RTC__)
+#if defined(__CUDACC_RTC__)
 #define TILELANG_UNREACHABLE(msg) \
   do { \
     __builtin_trap(); \
   } while (0)
+#elif defined(__CUDA_ARCH__)
+#define TILELANG_UNREACHABLE(msg) \
+  do { \
+    printf("%s, %s:%d\n", msg, __FILE__, __LINE__); \
+    __trap(); \
+  } while (0)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/tl_templates/cuda/common.h` around lines 62 - 79, The preprocessor ordering causes NVRTC builds to match __CUDA_ARCH__ first and use the printf/__trap branch, breaking RTC because <cstdio> is excluded; change the macro checks so __CUDACC_RTC__ is tested before __CUDA_ARCH__ (i.e., reorder the `#if`/`#elif` so the __CUDACC_RTC__ branch with __builtin_trap() is selected for runtime compilation), keeping the TILELANG_UNREACHABLE macro and behavior otherwise identical.

src/op/gemm_py.cc (1)
323-326: ⚠️ Potential issue | 🟠 Major

`disable_2cta` is only wired through the reflection helper.

`tl.get_tcgen5_mma_meta()` now filters with `disable_2cta`, but `GemmPyNode::allowTcgen5Mma()` still calls `GetTCGEN5MMAMeta(...)` without that flag. That lets C++ instruction selection pick TCGEN5MMA even when lowering has disabled the 2CTA-only shapes.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/op/gemm_py.cc` around lines 323 - 326, GemmPyNode::allowTcgen5Mma currently calls GetTCGEN5MMAMeta without forwarding the disable_2cta flag, so instruction selection can still pick TCGEN5MMA even when 2CTA-only shapes are disabled; update GemmPyNode::allowTcgen5Mma to accept or retrieve the disable_2cta boolean and pass it into the GetTCGEN5MMAMeta(...) call (matching the usage in the reflected helper tl.get_tcgen5_mma_meta) so the metadata check respects the disable_2cta filtering.

src/transform/inject_tma_barrier.cc (3)
317-320: ⚠️ Potential issue | 🟠 Major

`EvalSet` still runs with an empty loop-domain map.

`TmaSequenceCollector` added `var_int_set_`, but it still has no `VisitStmt_(const ForNode*)` to populate it. Any loop-dependent barrier id here will be evaluated without loop bounds, and `GetSequence()` immediately assumes the resulting min/max are `IntImm`.

Also applies to: 357-357
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/transform/inject_tma_barrier.cc` around lines 317 - 320, EvalSet is being called with an empty loop-domain map because TmaSequenceCollector never records loop iteration ranges; implement VisitStmt_(const ForNode*) in TmaSequenceCollector to populate var_int_set_ with the loop variable's IntSet (use arith::IntSet::FromRange or construct from for->min and for->extent via arith::EvalSet/arith::IntervalSet), so that when barrier_id is evaluated (in the code paths using barrier_id.as<CallNode>()->args[0] and calling arith::EvalSet(e, var_int_set_)) the map contains proper bounds; ensure the same fix is applied for the other EvalSet usage referenced (the other occurrence around the barrier sequence logic) and that GetSequence() no longer assumes IntImm without checking the resulting IntSet/Interval.
226-231: ⚠️ Potential issue | 🟠 Major

Non-constant cluster leader IDs are silently dropped.

If `call->args[1]` is not an `IntImmNode`, the collector never records a leader CTA for that barrier. The rewriter then skips the cluster-specific `expect_tx`/`use_2cta` path with no diagnostic.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/transform/inject_tma_barrier.cc` around lines 226 - 231, When handling tl::ptx_arrive_cluster_barrier() the code only records leader CTAs when call->args[1] is an IntImmNode, silently dropping non-constant leaders; change the branch so that if call->args[1] is not an IntImmNode you still record a marker (e.g., cluster_barrier_cta_ids_.Set(barrier_id, IntImm(DataType::Int(32), -1)) or store the Expr itself) and emit a warning/log so the rewriter can detect a non-constant leader and take the expect_tx/use_2cta path accordingly; reference call->args[1], IntImmNode, cluster_barrier_cta_ids_.Set, barrier_id, and tl::ptx_arrive_cluster_barrier() when making this change.
557-571: ⚠️ Potential issue | 🟡 Minor

`cluster_dims` extraction is still last-writer-wins.

If multiple blocks carry `cluster_dims`, each later match overwrites `cluster_size`. That silently changes the byte scaling for cluster barriers instead of enforcing consistency.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/transform/inject_tma_barrier.cc` around lines 557 - 571, The current PostOrderVisit loop overwrites cluster_size on each block with a "cluster_dims" annotation (last-writer-wins); change it to detect and enforce consistency: when you encounter a block->annotations.Get("cluster_dims") inside the PostOrderVisit lambda, parse its Array<Integer> into a computed sz and if cluster_size is still the initial sentinel (e.g., 0 or use an optional) set cluster_size = sz, otherwise compare sz to the already-set cluster_size and signal an error/assertion/logging if they differ (do not silently overwrite). Update the code around cluster_size, the PostOrderVisit lambda, and the handling of block->annotations/Get("cluster_dims") to implement this consistency check.

src/transform/lower_blackwell_2sm.cc (3)
77-78: ⚠️ Potential issue | 🟠 Major

A single function-wide flag still mis-lowers mixed 1CTA/2CTA kernels.
Once any GEMM enables 2CTA, the pass annotates the root block and every TMEM alloc/dealloc follows the 2CTA path. A PrimFunc that mixes 1CTA and 2CTA TCGEN5MMA needs per-block tracking or an explicit rejection.
Also applies to: 121-124
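The per-block tracking suggested above can be sketched in plain Python; the block IDs and visitor name here are hypothetical stand-ins, not the pass's actual API:

```python
# Minimal sketch: record the 2-CTA requirement per block instead of in one
# function-wide flag, so a kernel mixing 1-CTA and 2-CTA GEMMs lowers each
# TMEM alloc/dealloc according to its own block.
use_2cta_by_block = {}

def visit_gemm(block_id, meta_enable_2cta):
    use_2cta_by_block[block_id] = bool(meta_enable_2cta)

visit_gemm("blk_2cta", True)    # 2-CTA TCGEN5MMA block
visit_gemm("blk_1cta", False)   # plain 1-CTA block in the same PrimFunc

# TMEM lowering consults the owning block, never a global flag:
assert use_2cta_by_block["blk_2cta"] is True
assert use_2cta_by_block["blk_1cta"] is False
```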
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/transform/lower_blackwell_2sm.cc` around lines 77 - 78, The pass currently uses the function-wide boolean has_2sm_tcgen5mma_ causing mixed 1CTA/2CTA kernels to be mis-lowered; change the logic to track the 2CTA requirement per-block instead of a single flag (or explicitly reject PrimFunc with mixed 1CTA/2CTA). Update the code paths that annotate the root block and that decide TMEM alloc/dealloc to consult this per-block marker (replace uses of has_2sm_tcgen5mma_ with a block-scoped lookup keyed by the Block/Stmt node) so each block follows its own 1CTA/2CTA lowering, and add a validation step in the pass entry to error out if you choose the explicit-reject approach.
86-99: ⚠️ Potential issue | 🟠 Major

This still annotates the first inner `BlockRealize`.

Because the mutator recurses before setting `root_block_annotated_`, the first post-order `BlockRealize` wins. If the TMEM allocs live in the outer block, `lower_shared_tmem.cc` will miss `use_2cta` on the owning block.
60-63: ⚠️ Potential issue | 🟠 Major

Use the same AB dtype rule as `allowTcgen5Mma()`.

For TS kernels, `A` is TMEM and carries accumulator dtype, so querying `GetTCGEN5MMAMeta(...)` with `node->a_->dtype` can miss valid 2CTA shapes and skip the `use_2cta` annotation. Reuse the same AB-dtype selection as `allowTcgen5Mma()` here.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/transform/lower_blackwell_2sm.cc` around lines 60 - 63, The code queries GetTCGEN5MMAMeta using node->a_->dtype which can be wrong for TS kernels; change the dtype selection to match the same A/B-dtype rule used in allowTcgen5Mma(): compute the effective A and B dtypes the same way allowTcgen5Mma() does (for TS kernels use the accumulator/CMem dtype instead of TMEM dtype for A, etc.), then pass those effective dtypes into GetTCGEN5MMAMeta(node->m_, node->n_, node->k_, effective_a_dtype, effective_b_dtype) and use the returned meta to decide meta.enable_2cta. Ensure you reference the same dtype-selection helper or replicate its logic so node->a_->dtype is not used directly for TS kernels.

tilelang/language/tir/op.py (1)
1165-1167: ⚠️ Potential issue | 🟡 Minor

Docstring still doesn't match the signature.

It omits `enable_2cta`, and the supported positional range here is 13 to 15 arguments.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/language/tir/op.py` around lines 1165 - 1167, Update the operator docstring that currently lists the positional arguments "(kind_dtype, desc_a, A_offset, desc_b, B_offset, C_ptr, C_offset, desc_val, scale_out, mask0, mask1, mask2, mask3[, enable_ws])" to include the missing enable_2cta parameter and correct the supported positional-count range to 13–15 arguments; explicitly list the full ordered args including enable_2cta (e.g., "... mask3[, enable_2cta[, enable_ws]]") so the docstring matches the actual signature and accepted argument permutations.

examples/gemm_sm100/gemm_tcgen5mma_ws.py (1)
84-86: ⚠️ Potential issue | 🔴 Critical

Use a cluster barrier for `tmem_full` and sync before the TMEM read.

`arrive_2cta=True` turns `tmem_full` into a cross-CTA completion barrier, but it is still allocated CTA-local here. The TMEM→fragment copy also lost the `T.sync_threads()` that the 1-CTA path keeps before reading `C_tmem`.

Suggested fix

```diff
-        tmem_full = T.alloc_barrier([1])
+        tmem_full = T.alloc_cluster_barrier([1])
@@
         # Wait for all tcgen5 to finish
         T.mbarrier_wait_parity(tmem_full, 0)
+        T.sync_threads()
         T.copy(C_tmem, C_local)
```

Also applies to: 114-118
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/gemm_sm100/gemm_tcgen5mma_ws.py` around lines 84 - 86, tmem_full is currently allocated with T.alloc_barrier but must be a cluster-wide barrier when arrive_2cta=True; change its allocation to T.alloc_cluster_barrier like loaded and consumed (use T.alloc_cluster_barrier([1]) for tmem_full) and add a T.sync_threads() (or the appropriate cluster sync) immediately before the TMEM→fragment read to match the 1-CTA path so the TMEM read waits for the cluster barrier completion; update both occurrences (the block around loaded/consumed/tmem_full and the identical block at lines 114-118) to use the cluster barrier and include the sync before reading C_tmem.

tilelang/layout/swizzle.py (1)
111-127: ⚠️ Potential issue | 🟠 Major

Don't silently ignore `continuity` and `k_major`, and guard 1D inputs.

The public signature still accepts those parameters, but this implementation never uses them. It also does `shape[-2]` unconditionally, so a 1D buffer now raises `IndexError` instead of giving a clear fallback/error.

Suggested fix

```diff
 def make_tcgen05mma_swizzled_layout(buffer: BufferLikeType,
                                     continuity: int = None,
                                     k_major: bool = True):
     buf, shape, _ = _get_buffer_info(buffer)
+    if continuity is not None or k_major is not True:
+        import warnings
+        warnings.warn(
+            "make_tcgen05mma_swizzled_layout currently ignores 'continuity' and 'k_major'.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+    if len(shape) < 2:
+        raise ValueError("TCGEN05MMA swizzle layout requires at least a 2D buffer.")
     # Use only the last two dims, matching C++ GetSwizzleShapeInfoChecked (uses shape[-2] as stride).
```
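The guard logic in the suggested fix can be exercised standalone; `tcgen05_swizzle_stride` below is a hypothetical helper, not the real TileLang function, and only models the parameter warning plus the 1D check:

```python
import warnings

def tcgen05_swizzle_stride(shape, continuity=None, k_major=True):
    # Warn instead of silently ignoring parameters the implementation drops.
    if continuity is not None or k_major is not True:
        warnings.warn(
            "'continuity' and 'k_major' are currently ignored.",
            DeprecationWarning,
            stacklevel=2,
        )
    # Guard before indexing shape[-2], so 1D buffers fail with a clear error
    # rather than an IndexError.
    if len(shape) < 2:
        raise ValueError("TCGEN05MMA swizzle layout requires at least a 2D buffer.")
    return shape[-2]

assert tcgen05_swizzle_stride((128, 64)) == 128
```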
Verify each finding against the current code and only fix it if needed. In `@tilelang/layout/swizzle.py` around lines 111 - 127, make_tcgen05mma_swizzled_layout currently ignores the continuity and k_major parameters and assumes at least 2 dimensions (shape[-2]) which causes an IndexError for 1D buffers; update the function to (1) validate shape from _get_buffer_info and handle 1D inputs by treating stride=1 (or falling back to make_linear_layout) instead of indexing shape[-2], and (2) respect the continuity and k_major parameters (or explicitly validate and raise a clear error) when deciding which swizzled layout to choose; reference the existing helpers _get_buffer_info, _get_element_size, make_full_bank_swizzled_layout / make_half_bank_swizzled_layout / make_quarter_bank_swizzled_layout, and make_linear_layout so the logic uses continuity/k_major and guarded stride access.

tilelang/tileop/gemm/gemm_tcgen05.py (1)
123-123: ⚠️ Potential issue | 🟠 Major

Cast `enable_2cta` to a native `int` before building `cluster_cond`.

`get_tcgen5_mma_meta()` is handled as boxed TVM integers elsewhere in this path. Using `not enable_2cta` directly here makes the guard depend on Python object truthiness instead of the flag value, so the cluster-rank gate can stay active in 1-CTA mode.

Suggested fix

```diff
-        atom_m, atom_n, atom_k, enable_ws, enable_2cta = mma_emitter.get_tcgen5_mma_meta(self.M, self.N, self.K)
+        atom_m, atom_n, atom_k, enable_ws, enable_2cta = [
+            int(x) for x in mma_emitter.get_tcgen5_mma_meta(self.M, self.N, self.K)
+        ]
```

Run the following read-only check to confirm this caller is the only one still using the raw meta values in boolean logic:

```bash
#!/bin/bash
set -euo pipefail
# Expected: the TCGEN5 macro generator casts meta values to int before using enable_2cta,
# while this caller still uses the raw return value.
rg -n -C2 'get_tcgen5_mma_meta\(|int\(x\) for x in meta|enable_2cta' \
  tilelang/tileop/gemm/gemm_tcgen05.py \
  tilelang/intrinsics/tcgen05_macro_generator.py
```

Also applies to: 161-176
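Why the cast matters: a boxed integer object with no `__bool__` is truthy regardless of the value it wraps, so `not enable_2cta` tests the wrong thing. `BoxedInt` below is a stand-in used only to illustrate the pitfall, not the actual TVM class:

```python
class BoxedInt:
    # Stand-in for a boxed integer that defines no __bool__: Python falls
    # back to object truthiness, which is always True.
    def __init__(self, value):
        self.value = value
    def __int__(self):
        return self.value

enable_2cta = BoxedInt(0)          # meta says "2-CTA disabled"
assert bool(enable_2cta) is True   # object truthiness ignores the value
assert int(enable_2cta) == 0       # casting recovers the actual flag
# Guards must therefore be built from the native int:
assert bool(int(enable_2cta)) is False
```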
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/tileop/gemm/gemm_tcgen05.py` at line 123, The guard uses the boxed TVM integer enable_2cta returned by get_tcgen5_mma_meta, so using Python truthiness (e.g. not enable_2cta) is incorrect; cast enable_2cta to a native int before building cluster_cond (e.g., enable_2cta = int(enable_2cta)) so the cluster-rank gate is driven by the actual flag value; make the same change for the other usages around the same block (the caller in gemm_tcgen05.py and any other cluster_cond logic in the 161-176 region) to match how tcgen05_macro_generator handles meta values.

tilelang/intrinsics/tcgen05_macro_generator.py (1)
403-409: ⚠️ Potential issue | 🟠 Major

`tcgen05mma_ts()` still ignores `_enable_2cta`.

`_enable_2cta` is unpacked and then discarded. The TS path keeps full-CTA geometry and line 544 never forwards `arrive_2cta`, so a TS kernel can receive 2CTA meta while still emitting the old 1CTA sequence. Either propagate `_enable_2cta` through the TS geometry/offset/arrive path, or force-disable 2CTA when lowering TS.

Also applies to: 445-545
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tilelang/intrinsics/tcgen05_macro_generator.py` around lines 403 - 409, Unpacked `_enable_2cta` from get_tcgen5_mma_meta is never used in tcgen05mma_ts, so a TS kernel can be given 2CTA meta but still generate the 1CTA geometry/offset/arrive sequence; either propagate this flag through the TS path or disable it when lowering. Update tcgen05mma_ts to forward `_enable_2cta` (the value obtained alongside atom_m, atom_n, atom_k, _enable_ws) into the TS geometry/offset/arrive code paths (e.g., pass it into whatever constructs produce the CTA geometry and the arrive_2cta logic) or explicitly set/clear the 2CTA flag before generating offsets/arrive so the emitted sequence matches the meta returned by get_tcgen5_mma_meta. Ensure any functions or code that build CTA geometry/offset/arrive (referenced as arrive_2cta/geometry/offset in this file) accept and act on the `_enable_2cta` boolean.

examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py (1)
151-159: ⚠️ Potential issue | 🔴 Critical

Require an even, non-zero number of clustered M tiles.

`m_clusters = m_blocks // 2` makes the swizzle invalid when `m_blocks == 1`, and an odd `m_blocks` silently drops the tail M tile.

Suggested fix

```diff
         m_blocks = T.ceildiv(M, block_M)
+        assert m_blocks >= 2 and m_blocks % 2 == 0, \
+            "gemm_persistent_2cta requires an even number of M tiles"
         m_clusters = m_blocks // 2
```
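The arithmetic behind the assertion can be checked in isolation; `m_clusters_paired` is a hypothetical helper mirroring the example's tile pairing, not code from the repository:

```python
def ceildiv(a, b):
    # Same rounding-up behavior as T.ceildiv.
    return -(-a // b)

def m_clusters_paired(M, block_M):
    m_blocks = ceildiv(M, block_M)
    # Pairing M tiles into 2-CTA clusters needs an even, non-zero tile count;
    # otherwise m_blocks // 2 yields 0 clusters or silently drops the tail tile.
    assert m_blocks >= 2 and m_blocks % 2 == 0, \
        "2-CTA persistent GEMM requires an even number of M tiles"
    return m_blocks // 2

assert m_clusters_paired(512, 128) == 2  # 4 tiles -> 2 clusters
```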
Verify each finding against the current code and only fix it if needed. In `@examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py` around lines 151 - 159, The computation m_clusters = m_blocks // 2 drops a tail tile and yields 0 when m_blocks == 1; replace it with m_clusters = T.ceildiv(m_blocks, 2) and add an assertion that m_clusters >= 1 (or raise) so you never produce zero clusters and you keep the tail M tile for odd m_blocks; update any downstream logic that expects pairs to handle the final singleton cluster if present (references: m_blocks, m_clusters, T.ceildiv).

src/tl_templates/cuda/instruction/tcgen05mma.h (2)
167-185: ⚠️ Potential issue | 🔴 Critical

The 2-CTA SS specialization drops tail masks.

The 1-CTA path at line 158 uses `{%5, %6, %7, %8}` to pass `mask0..mask3` for predicating partial tiles. The 2-CTA specialization at line 179 omits these masks entirely, which means non-full tiles will execute as fully enabled MMAs, potentially causing incorrect results.

Please verify against the PTX ISA whether `tcgen05.mma.cta_group::2` supports or requires mask operands. If masks are supported, they should be passed; if not, this limitation should be documented and partial-tile scenarios need alternative handling.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/tl_templates/cuda/instruction/tcgen05mma.h` around lines 167 - 185, The 2-CTA specialization tcgen05mma_ss<DataType::kFloat16, true> is missing the tail-mask predicates (mask0..mask3) in its inline asm, so partial tiles run fully enabled; update the asm template in tcgen05mma_ss to pass the same predicate operands used by the 1-CTA path (mask0, mask1, mask2, mask3) into the tcgen05.mma.cta_group::2 instruction (add corresponding .reg/.pred setup and include them in the operand list and constraint slots), or if PTX for tcgen05.mma.cta_group::2 does not support mask operands, add a clear comment in tcgen05mma_ss and implement an alternative codepath that explicitly guards/zeroes out partial tiles before issuing the MMA (using the same mask variables mask0..mask3 to gate work).
187-204: ⚠️ Potential issue | 🔴 Critical

BFloat16 2-CTA variant inherits the mask issue.

The `tcgen05mma_ss<DataType::kBFloat16, true>` specialization delegates to `tcgen05mma_ss<DataType::kFloat16, true>`, which drops the tail masks. Once the FP16 2-CTA path is fixed, this delegation will automatically inherit the fix.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/tl_templates/cuda/instruction/tcgen05mma.h` around lines 187 - 204, The BFloat16 2-CTA specialization tcgen05mma_ss<DataType::kBFloat16, true> currently delegates to tcgen05mma_ss<DataType::kFloat16, true>, which drops the tail masks; fix this by ensuring the BFloat16 2-CTA path preserves mask handling instead of delegating to the FP16 implementation — either (a) change the call to invoke a mask-preserving implementation for BFloat16, or (b) inline/replicate the corrected mask-handling logic from the FP16 2-CTA fix but specialized for DataType::kBFloat16, keeping the same parameters (desc_a, desc_b, tmem_c, scalec, desc_val, mask0, mask1, mask2, mask3) so tail masks are not lost.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/op/tcgen5_meta.h`:
- Around line 22-23: The function GetTCGEN5MMAMeta currently only checks
disable_2cta for the BF16/F16 branch; update every 2-CTA return path so they are
gated by !disable_2cta as well — specifically, locate all SUCCESS(..., true)
branches (the FP8 and INT8 branches referenced) and add the disable_2cta guard
so they return the 2-CTA success only when !disable_2cta; also ensure the
TL_DISABLE_2CTA_TCGEN5MMA semantics are honored everywhere 2-CTA paths are
produced in GetTCGEN5MMAMeta.
In `@src/target/codegen_cuda.cc`:
- Around line 430-460: ClusterInfoExtractor::VisitStmt reads cluster_dims[0..2]
and calls .as<IntImmNode>()->value without validating array size or null cast
results; update VisitStmt to first check cluster_dims.size() >= 3, then for each
index (0,1,2) ensure cluster_dims[i].as<IntImmNode>() is non-null before
accessing ->value, and handle invalid/malformed attributes by leaving
launch_with_cluster false or logging/ICHECK_FAIL; ensure
cluster_grid_x_ext/cluster_grid_y_ext/cluster_grid_z_ext are only assigned after
successful checks and maintain existing ICHECK for positive values.
In `@src/tl_templates/cuda/copy_sm100.h`:
- Around line 318-324: The PTX modifier order in the tma_load_2sm inline asm
strings is incorrect: move the cta_group::2 token so it follows the completion
mechanism/bytes modifiers (i.e., use
".1d.shared::cluster.global.mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint"
instead of
".1d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint");
update this ordering in all tma_load_2sm asm occurrences (the instances noted in
the comment and the other similar asm strings in the same function) so the
modifiers follow the documented
dim.dst.src{.load_mode}.completion_mechanism{.multicast}{.cta_group}{.level::cache_hint}
sequence.
In `@tilelang/language/tir/op.py`:
- Around line 1157-1160: The new parameter enable_2cta was inserted before
existing positional parameters (ws, warp_specialized, variant) and breaks
callers; fix it by making enable_2cta keyword-only or moving it after the
existing alias parameters: either (A) add a keyword-only marker (*) before
enable_2cta in the function signature so enable_2cta must be passed by name, or
(B) move the enable_2cta parameter to come after variant (or after
warp_specialized/ws) so existing positional calls still bind to
ws/warp_specialized/variant; update the signature where enable_ws, enable_2cta,
ws, warp_specialized, variant are declared accordingly.
---
Duplicate comments:
In `@examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py`:
- Around line 151-159: The computation m_clusters = m_blocks // 2 drops a tail
tile and yields 0 when m_blocks == 1; replace it with m_clusters =
T.ceildiv(m_blocks, 2) and add an assertion that m_clusters >= 1 (or raise) so
you never produce zero clusters and you keep the tail M tile for odd m_blocks;
update any downstream logic that expects pairs to handle the final singleton
cluster if present (references: m_blocks, m_clusters, T.ceildiv).
In `@examples/gemm_sm100/gemm_tcgen5mma_ws.py`:
- Around line 84-86: tmem_full is currently allocated with T.alloc_barrier but
must be a cluster-wide barrier when arrive_2cta=True; change its allocation to
T.alloc_cluster_barrier like loaded and consumed (use
T.alloc_cluster_barrier([1]) for tmem_full) and add a T.sync_threads() (or the
appropriate cluster sync) immediately before the TMEM→fragment read to match the
1-CTA path so the TMEM read waits for the cluster barrier completion; update
both occurrences (the block around loaded/consumed/tmem_full and the identical
block at lines 114-118) to use the cluster barrier and include the sync before
reading C_tmem.
In `@src/op/gemm_py.cc`:
- Around line 323-326: GemmPyNode::allowTcgen5Mma currently calls
GetTCGEN5MMAMeta without forwarding the disable_2cta flag, so instruction
selection can still pick TCGEN5MMA even when 2CTA-only shapes are disabled;
update GemmPyNode::allowTcgen5Mma to accept or retrieve the disable_2cta boolean
and pass it into the GetTCGEN5MMAMeta(...) call (matching the usage in the
reflected helper tl.get_tcgen5_mma_meta) so the metadata check respects the
disable_2cta filtering.
In `@src/tl_templates/cuda/common.h`:
- Around line 62-79: The preprocessor ordering causes NVRTC builds to match
__CUDA_ARCH__ first and use the printf/__trap branch, breaking RTC because
<cstdio> is excluded; change the macro checks so __CUDACC_RTC__ is tested before
__CUDA_ARCH__ (i.e., reorder the `#if/`#elif so the __CUDACC_RTC__ branch with
__builtin_trap() is selected for runtime compilation), keeping the
TILELANG_UNREACHABLE macro and behavior otherwise identical.
In `@src/tl_templates/cuda/instruction/tcgen05mma.h`:
- Around line 167-185: The 2-CTA specialization
tcgen05mma_ss<DataType::kFloat16, true> is missing the tail-mask predicates
(mask0..mask3) in its inline asm, so partial tiles run fully enabled; update the
asm template in tcgen05mma_ss to pass the same predicate operands used by the
1-CTA path (mask0, mask1, mask2, mask3) into the tcgen05.mma.cta_group::2
instruction (add corresponding .reg/.pred setup and include them in the operand
list and constraint slots), or if PTX for tcgen05.mma.cta_group::2 does not
support mask operands, add a clear comment in tcgen05mma_ss and implement an
alternative codepath that explicitly guards/zeroes out partial tiles before
issuing the MMA (using the same mask variables mask0..mask3 to gate work).
- Around line 187-204: The BFloat16 2-CTA specialization
tcgen05mma_ss<DataType::kBFloat16, true> currently delegates to
tcgen05mma_ss<DataType::kFloat16, true>, which drops the tail masks; fix this by
ensuring the BFloat16 2-CTA path preserves mask handling instead of delegating
to the FP16 implementation — either (a) change the call to invoke a
mask-preserving implementation for BFloat16, or (b) inline/replicate the
corrected mask-handling logic from the FP16 2-CTA fix but specialized for
DataType::kBFloat16, keeping the same parameters (desc_a, desc_b, tmem_c,
scalec, desc_val, mask0, mask1, mask2, mask3) so tail masks are not lost.
In `@src/transform/inject_tma_barrier.cc`:
- Around line 317-320: EvalSet is being called with an empty loop-domain map
because TmaSequenceCollector never records loop iteration ranges; implement
VisitStmt_(const ForNode*) in TmaSequenceCollector to populate var_int_set_ with
the loop variable's IntSet (use arith::IntSet::FromRange or construct from
for->min and for->extent via arith::EvalSet/arith::IntervalSet), so that when
barrier_id is evaluated (in the code paths using
barrier_id.as<CallNode>()->args[0] and calling arith::EvalSet(e, var_int_set_))
the map contains proper bounds; ensure the same fix is applied for the other
EvalSet usage referenced (the other occurrence around the barrier sequence
logic) and that GetSequence() no longer assumes IntImm without checking the
resulting IntSet/Interval.
- Around line 226-231: When handling tl::ptx_arrive_cluster_barrier() the code
only records leader CTAs when call->args[1] is an IntImmNode, silently dropping
non-constant leaders; change the branch so that if call->args[1] is not an
IntImmNode you still record a marker (e.g.,
cluster_barrier_cta_ids_.Set(barrier_id, IntImm(DataType::Int(32), -1)) or store
the Expr itself) and emit a warning/log so the rewriter can detect a
non-constant leader and take the expect_tx/use_2cta path accordingly; reference
call->args[1], IntImmNode, cluster_barrier_cta_ids_.Set, barrier_id, and
tl::ptx_arrive_cluster_barrier() when making this change.
- Around line 557-571: The current PostOrderVisit loop overwrites cluster_size
on each block with a "cluster_dims" annotation (last-writer-wins); change it to
detect and enforce consistency: when you encounter a
block->annotations.Get("cluster_dims") inside the PostOrderVisit lambda, parse
its Array<Integer> into a computed sz and if cluster_size is still the initial
sentinel (e.g., 0 or use an optional) set cluster_size = sz, otherwise compare
sz to the already-set cluster_size and signal an error/assertion/logging if they
differ (do not silently overwrite). Update the code around cluster_size, the
PostOrderVisit lambda, and the handling of
block->annotations/Get("cluster_dims") to implement this consistency check.
In `@src/transform/lower_blackwell_2sm.cc`:
- Around line 77-78: The pass currently uses the function-wide boolean
has_2sm_tcgen5mma_ causing mixed 1CTA/2CTA kernels to be mis-lowered; change the
logic to track the 2CTA requirement per-block instead of a single flag (or
explicitly reject PrimFunc with mixed 1CTA/2CTA). Update the code paths that
annotate the root block and that decide TMEM alloc/dealloc to consult this
per-block marker (replace uses of has_2sm_tcgen5mma_ with a block-scoped lookup
keyed by the Block/Stmt node) so each block follows its own 1CTA/2CTA lowering,
and add a validation step in the pass entry to error out if you choose the
explicit-reject approach.
- Around line 60-63: The code queries GetTCGEN5MMAMeta using node->a_->dtype
which can be wrong for TS kernels; change the dtype selection to match the same
A/B-dtype rule used in allowTcgen5Mma(): compute the effective A and B dtypes
the same way allowTcgen5Mma() does (for TS kernels use the accumulator/CMem
dtype instead of TMEM dtype for A, etc.), then pass those effective dtypes into
GetTCGEN5MMAMeta(node->m_, node->n_, node->k_, effective_a_dtype,
effective_b_dtype) and use the returned meta to decide meta.enable_2cta. Ensure
you reference the same dtype-selection helper or replicate its logic so
node->a_->dtype is not used directly for TS kernels.
In `@tilelang/intrinsics/tcgen05_macro_generator.py`:
- Around line 403-409: Unpacked `_enable_2cta` from get_tcgen5_mma_meta is never
used in tcgen05mma_ts, so a TS kernel can be given 2CTA meta but still generate
the 1CTA geometry/offset/arrive sequence; either propagate this flag through the
TS path or disable it when lowering. Update tcgen05mma_ts to forward
`_enable_2cta` (the value obtained alongside atom_m, atom_n, atom_k, _enable_ws)
into the TS geometry/offset/arrive code paths (e.g., pass it into whatever
constructs produce the CTA geometry and the arrive_2cta logic) or explicitly
set/clear the 2CTA flag before generating offsets/arrive so the emitted sequence
matches the meta returned by get_tcgen5_mma_meta. Ensure any functions or code
that build CTA geometry/offset/arrive (referenced as arrive_2cta/geometry/offset
in this file) accept and act on the `_enable_2cta` boolean.
In `@tilelang/language/tir/op.py`:
- Around line 1165-1167: Update the operator docstring that currently lists the
positional arguments "(kind_dtype, desc_a, A_offset, desc_b, B_offset, C_ptr,
C_offset, desc_val, scale_out, mask0, mask1, mask2, mask3[, enable_ws])" to
include the missing enable_2cta parameter and correct the supported
positional-count range to 13–15 arguments; explicitly list the full ordered args
including enable_2cta (e.g., "... mask3[, enable_2cta[, enable_ws]]") so the
docstring matches the actual signature and accepted argument permutations.
In `@tilelang/layout/swizzle.py`:
- Around line 111-127: make_tcgen05mma_swizzled_layout currently ignores the
continuity and k_major parameters and assumes at least 2 dimensions (shape[-2])
which causes an IndexError for 1D buffers; update the function to (1) validate
shape from _get_buffer_info and handle 1D inputs by treating stride=1 (or
falling back to make_linear_layout) instead of indexing shape[-2], and (2)
respect the continuity and k_major parameters (or explicitly validate and raise
a clear error) when deciding which swizzled layout to choose; reference the
existing helpers _get_buffer_info, _get_element_size,
make_full_bank_swizzled_layout / make_half_bank_swizzled_layout /
make_quarter_bank_swizzled_layout, and make_linear_layout so the logic uses
continuity/k_major and guarded stride access.
In `@tilelang/tileop/gemm/gemm_tcgen05.py`:
- Line 123: The guard uses the boxed TVM integer enable_2cta returned by
get_tcgen5_mma_meta, so using Python truthiness (e.g. not enable_2cta) is
incorrect; cast enable_2cta to a native int before building cluster_cond (e.g.,
enable_2cta = int(enable_2cta)) so the cluster-rank gate is driven by the actual
flag value; make the same change for the other usages around the same block (the
caller in gemm_tcgen05.py and any other cluster_cond logic in the 161-176
region) to match how tcgen05_macro_generator handles meta values.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 4d8fb7b9-1ce9-42ca-ba00-dda0c4e44d61
📒 Files selected for processing (28)
- examples/gemm_sm100/gemm_tcgen5mma_ws.py
- examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py
- src/op/builtin.cc
- src/op/builtin.h
- src/op/copy.cc
- src/op/gemm_py.cc
- src/op/tcgen5_meta.h
- src/target/codegen_cuda.cc
- src/target/codegen_cuda.h
- src/tl_templates/cuda/cluster.h
- src/tl_templates/cuda/common.h
- src/tl_templates/cuda/copy_sm100.h
- src/tl_templates/cuda/instruction/tcgen05mma.h
- src/tl_templates/cuda/tcgen_05.h
- src/tl_templates/cuda/threadblock_swizzle.h
- src/transform/inject_tma_barrier.cc
- src/transform/lower_blackwell_2sm.cc
- src/transform/lower_shared_tmem.cc
- tilelang/engine/phase.py
- tilelang/intrinsics/tcgen05_macro_generator.py
- tilelang/language/annotations.py
- tilelang/language/builtin.py
- tilelang/language/cluster.py
- tilelang/language/tir/op.py
- tilelang/layout/swizzle.py
- tilelang/tileop/gemm/gemm_tcgen05.py
- tilelang/transform/__init__.py
- tilelang/transform/pass_config.py
🚧 Files skipped from review as they are similar to previous changes (8)
- src/op/builtin.cc
- tilelang/engine/phase.py
- tilelang/transform/__init__.py
- tilelang/language/cluster.py
- src/op/copy.cc
- tilelang/language/annotations.py
- src/target/codegen_cuda.h
- src/op/builtin.h
```cpp
class ClusterInfoExtractor : public tir::StmtVisitor {
private:
  void VisitStmt(const PrimFunc &f) {
    if (f->GetAttr<Array<PrimExpr>>("cluster_dims").has_value()) {
      launch_with_cluster = true;
      auto cluster_dims = f->GetAttr<Array<PrimExpr>>("cluster_dims").value();
      cluster_grid_x_ext = cluster_dims[0].as<IntImmNode>()->value;
      cluster_grid_y_ext = cluster_dims[1].as<IntImmNode>()->value;
      cluster_grid_z_ext = cluster_dims[2].as<IntImmNode>()->value;
      ICHECK(cluster_grid_x_ext > 0 && cluster_grid_y_ext > 0 &&
             cluster_grid_z_ext > 0);
    }
    StmtVisitor::VisitStmt(f->body);
  }

  bool launch_with_cluster = false;
  int64_t cluster_grid_x_ext = 1;
  int64_t cluster_grid_y_ext = 1;
  int64_t cluster_grid_z_ext = 1;

public:
  std::optional<std::tuple<int64_t, int64_t, int64_t>>
  extract(const PrimFunc &f) {
    this->VisitStmt(f);
    if (launch_with_cluster) {
      return std::make_tuple(cluster_grid_x_ext, cluster_grid_y_ext,
                             cluster_grid_z_ext);
    }
    return std::nullopt;
  }
};
```
Missing bounds and null checks in ClusterInfoExtractor::VisitStmt.
The code accesses cluster_dims[0], [1], [2] without verifying the array has at least 3 elements, and calls .as<IntImmNode>()->value without null-checking the cast result. If the attribute is malformed, this will crash.
🛡️ Proposed fix to add defensive checks

```diff
 void VisitStmt(const PrimFunc &f) {
   if (f->GetAttr<Array<PrimExpr>>("cluster_dims").has_value()) {
-    launch_with_cluster = true;
     auto cluster_dims = f->GetAttr<Array<PrimExpr>>("cluster_dims").value();
-    cluster_grid_x_ext = cluster_dims[0].as<IntImmNode>()->value;
-    cluster_grid_y_ext = cluster_dims[1].as<IntImmNode>()->value;
-    cluster_grid_z_ext = cluster_dims[2].as<IntImmNode>()->value;
+    ICHECK_GE(cluster_dims.size(), 3)
+        << "cluster_dims attribute must have at least 3 elements";
+    const IntImmNode *x = cluster_dims[0].as<IntImmNode>();
+    const IntImmNode *y = cluster_dims[1].as<IntImmNode>();
+    const IntImmNode *z = cluster_dims[2].as<IntImmNode>();
+    ICHECK(x && y && z) << "cluster_dims elements must be integer constants";
+    cluster_grid_x_ext = x->value;
+    cluster_grid_y_ext = y->value;
+    cluster_grid_z_ext = z->value;
     ICHECK(cluster_grid_x_ext > 0 && cluster_grid_y_ext > 0 &&
            cluster_grid_z_ext > 0);
+    launch_with_cluster = true;
   }
   StmtVisitor::VisitStmt(f->body);
 }
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
class ClusterInfoExtractor : public tir::StmtVisitor {
private:
  void VisitStmt(const PrimFunc &f) {
    if (f->GetAttr<Array<PrimExpr>>("cluster_dims").has_value()) {
      launch_with_cluster = true;
      auto cluster_dims = f->GetAttr<Array<PrimExpr>>("cluster_dims").value();
      cluster_grid_x_ext = cluster_dims[0].as<IntImmNode>()->value;
      cluster_grid_y_ext = cluster_dims[1].as<IntImmNode>()->value;
      cluster_grid_z_ext = cluster_dims[2].as<IntImmNode>()->value;
      ICHECK(cluster_grid_x_ext > 0 && cluster_grid_y_ext > 0 &&
             cluster_grid_z_ext > 0);
    }
    StmtVisitor::VisitStmt(f->body);
  }

  bool launch_with_cluster = false;
  int64_t cluster_grid_x_ext = 1;
  int64_t cluster_grid_y_ext = 1;
  int64_t cluster_grid_z_ext = 1;

public:
  std::optional<std::tuple<int64_t, int64_t, int64_t>>
  extract(const PrimFunc &f) {
    this->VisitStmt(f);
    if (launch_with_cluster) {
      return std::make_tuple(cluster_grid_x_ext, cluster_grid_y_ext,
                             cluster_grid_z_ext);
    }
    return std::nullopt;
  }
};
class ClusterInfoExtractor : public tir::StmtVisitor {
private:
  void VisitStmt(const PrimFunc &f) {
    if (f->GetAttr<Array<PrimExpr>>("cluster_dims").has_value()) {
      auto cluster_dims = f->GetAttr<Array<PrimExpr>>("cluster_dims").value();
      ICHECK_GE(cluster_dims.size(), 3)
          << "cluster_dims attribute must have at least 3 elements";
      const IntImmNode *x = cluster_dims[0].as<IntImmNode>();
      const IntImmNode *y = cluster_dims[1].as<IntImmNode>();
      const IntImmNode *z = cluster_dims[2].as<IntImmNode>();
      ICHECK(x && y && z) << "cluster_dims elements must be integer constants";
      cluster_grid_x_ext = x->value;
      cluster_grid_y_ext = y->value;
      cluster_grid_z_ext = z->value;
      ICHECK(cluster_grid_x_ext > 0 && cluster_grid_y_ext > 0 &&
             cluster_grid_z_ext > 0);
      launch_with_cluster = true;
    }
    StmtVisitor::VisitStmt(f->body);
  }

  bool launch_with_cluster = false;
  int64_t cluster_grid_x_ext = 1;
  int64_t cluster_grid_y_ext = 1;
  int64_t cluster_grid_z_ext = 1;

public:
  std::optional<std::tuple<int64_t, int64_t, int64_t>>
  extract(const PrimFunc &f) {
    this->VisitStmt(f);
    if (launch_with_cluster) {
      return std::make_tuple(cluster_grid_x_ext, cluster_grid_y_ext,
                             cluster_grid_z_ext);
    }
    return std::nullopt;
  }
};
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/target/codegen_cuda.cc` around lines 430 - 460,
ClusterInfoExtractor::VisitStmt reads cluster_dims[0..2] and calls
.as<IntImmNode>()->value without validating array size or null cast results;
update VisitStmt to first check cluster_dims.size() >= 3, then for each index
(0,1,2) ensure cluster_dims[i].as<IntImmNode>() is non-null before accessing
->value, and handle invalid/malformed attributes by leaving launch_with_cluster
false or logging/ICHECK_FAIL; ensure
cluster_grid_x_ext/cluster_grid_y_ext/cluster_grid_z_ext are only assigned after
successful checks and maintain existing ICHECK for positive values.
asm volatile("cp.async.bulk.tensor.1d.cta_group::2.shared::cluster.global."
             "mbarrier::complete_tx::bytes.L2::cache_hint"
             " [%0], [%1, {%3}], [%2], %4;"
             :
             : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar),
               "r"(crd0), "l"(cache_hint)
             : "memory");
🧩 Analysis chain
🏁 Script executed:
# Check if the file exists and examine the relevant lines
if [ -f "src/tl_templates/cuda/copy_sm100.h" ]; then
echo "=== Lines 318-324 ==="
sed -n '318,324p' src/tl_templates/cuda/copy_sm100.h
echo ""
echo "=== Lines 343-349 ==="
sed -n '343,349p' src/tl_templates/cuda/copy_sm100.h
echo ""
echo "=== Lines 369-375 ==="
sed -n '369,375p' src/tl_templates/cuda/copy_sm100.h
echo ""
echo "=== Lines 395-401 ==="
sed -n '395,401p' src/tl_templates/cuda/copy_sm100.h
echo ""
echo "=== Lines 422-429 ==="
sed -n '422,429p' src/tl_templates/cuda/copy_sm100.h
else
echo "File not found"
fi
Repository: tile-ai/tilelang
Length of output: 2090
🌐 Web query:
NVIDIA CUDA PTX cp.async.bulk.tensor instruction modifier order syntax documentation
💡 Result:
Official PTX syntax (including modifier order) for cp.async.bulk.tensor is documented in the PTX ISA under “Data Movement and Conversion Instructions: cp.async.bulk.tensor”. (docs.nvidia.com)
Syntax + modifier order (as written in PTX ISA)
Global → Shared::cluster (TMA load):
cp.async.bulk.tensor.dim.dst.src{.load_mode}.completion_mechanism{.multicast}{.level::cache_hint}
[dstMem], [tensorMap, tensorCoords], [mbar]{, im2colOffsets} {, ctaMask} {, cache-policy}
So the modifier sequence is:
- `.dim.dst.src`
- `{.load_mode}` (optional)
- `.completion_mechanism`
- `{.multicast}` (optional)
- `{.level::cache_hint}` (optional) (docs.nvidia.com)
Shared::cta → Global (TMA store):
cp.async.bulk.tensor.dim.dst.src{.load_mode}.completion_mechanism{.level::cache_hint}
[tensorMap, tensorCoords], [srcMem] {, cache-policy}
(PTX ISA shows this variant does not include .multicast in the syntax line.) (docs.nvidia.com)
Practical cross-check (header-level docs)
The libcudacxx PTX wrappers mirror the same dot-modifier ordering in their commented PTX forms (e.g., ...tensor.2d.dst.src.tile.mbarrier::complete_tx::bytes..., and multicast variants). (wmaxey.github.io)
Citations:
- 1: https://docs.nvidia.com/cuda/archive/12.0.1/parallel-thread-execution/index.html
- 2: https://docs.nvidia.com/cuda/archive/12.0.1/parallel-thread-execution/index.html
- 3: https://docs.nvidia.com/cuda/archive/12.0.1/parallel-thread-execution/index.html
- 4: https://wmaxey.github.io/cccl/libcudacxx/ptx/instructions/cp_async_bulk_tensor.html
🌐 Web query:
SM100 Blackwell PTX cp.async.bulk.tensor cta_group documentation NVIDIA
💡 Result:
cp.async.bulk.tensor (TMA) and its optional .cta_group modifier are documented in the NVIDIA PTX ISA manual:
- Instruction + syntax (includes `.cta_group`): see "Data Movement and Conversion Instructions: `cp.async.bulk.tensor`". It defines the forms for `global -> shared::cta` and `global -> shared::cluster` (completion via `mbarrier::complete_tx::bytes`) and `shared::cta -> global` (completion via `bulk_group`), and shows `.cta_group = { .cta_group::1, .cta_group::2 }`. [1]
- When/where `.cta_group` exists (Blackwell SM100 support notes):
  - "Support for qualifier `.cta_group` introduced in PTX ISA 8.6." [2]
  - `.cta_group` is supported on Blackwell family targets including `sm_100a` (and family-specific `sm_100f`+ from PTX ISA 8.8), with additional architecture-name notes (e.g., renames in later PTX ISA versions). [2]
Related (still NVIDIA-official) context that may help when using these instructions from CUDA/C++:
- CUDA Programming Guide (Asynchronous Data Copies): overview table mapping "bulk tensor" async copies to `cuda::ptx::*` APIs (e.g., `cp_async_bulk_tensor`). [3]
Sources:
[1] NVIDIA PTX ISA (Parallel Thread Execution) — cp.async.bulk.tensor section (syntax/description) (docs.nvidia.com)
[2] NVIDIA PTX ISA — notes on .cta_group introduction + supported SM targets (docs.nvidia.com)
[3] CUDA Programming Guide — Asynchronous Data Copies overview (docs.nvidia.cn)
Citations:
- 1: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
- 2: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
- 3: https://docs.nvidia.cn/cuda/cuda-programming-guide/04-special-topics/async-copies.html
🏁 Script executed:
# Check if there are CUTLASS or other examples in the repo that use cta_group with cp.async.bulk.tensor
rg "cp\.async\.bulk\.tensor.*cta_group" --type h --type cpp
Repository: tile-ai/tilelang
Length of output: 624
🌐 Web query:
NVIDIA PTX ISA cp.async.bulk.tensor exact modifier order syntax "completion_mechanism" "cta_group" "cache_hint"
💡 Result:
In the PTX ISA, the modifier order is fixed in the instruction “header” (everything before operands):
cp.async.bulk.tensor exact modifier order (PTX ISA 8.7 / CUDA 12.8)
Global → shared::cta:
cp.async.bulk.tensor.dim.dst.src{.load_mode}.completion_mechanism{.cta_group}{.level::cache_hint}
So the order is:
1. `{.load_mode}` (optional)
2. `.completion_mechanism` (required)
3. `{.cta_group}` (optional)
4. `{.level::cache_hint}` (optional) [1]
Global → shared::cluster (with optional multicast):
cp.async.bulk.tensor.dim.dst.src{.load_mode}.completion_mechanism{.multicast}{.cta_group}{.level::cache_hint}
So the order is:
1. `{.load_mode}` (optional)
2. `.completion_mechanism` (required)
3. `{.multicast}` (optional)
4. `{.cta_group}` (optional)
5. `{.level::cache_hint}` (optional) [1]
Allowed values shown in the ISA:
- `.completion_mechanism = { .mbarrier::complete_tx::bytes }`
- `.cta_group = { .cta_group::1, .cta_group::2 }`
- `.level::cache_hint = { .L2::cache_hint }` [1]
Source: NVIDIA PTX ISA docs, “Data Movement and Conversion Instructions: cp.async.bulk.tensor”. [1]
References:
[1] https://docs.nvidia.com/cuda/archive/12.8.0/parallel-thread-execution/index.html
🏁 Script executed:
# Look for any CUDA examples or CUTLASS references that show correct cp.async.bulk.tensor syntax with cta_group
rg "cp\.async\.bulk\.tensor" -A 2 -B 2 | head -50
Repository: tile-ai/tilelang
Length of output: 3731
Fix the PTX modifier order in all tma_load_2sm asm strings.
The documented cp.async.bulk.tensor syntax is dim.dst.src{.load_mode}.completion_mechanism{.multicast}{.cta_group}{.level::cache_hint}. All five new asm strings place cta_group::2 immediately after the dimension, but it should come after the completion mechanism. Currently: .1d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint → should be: .1d.shared::cluster.global.mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint.
Suggested fix
- asm volatile("cp.async.bulk.tensor.1d.cta_group::2.shared::cluster.global."
- "mbarrier::complete_tx::bytes.L2::cache_hint"
+ asm volatile("cp.async.bulk.tensor.1d.shared::cluster.global."
+ "mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint"
@@
- asm volatile("cp.async.bulk.tensor.2d.cta_group::2.shared::cluster.global."
- "mbarrier::complete_tx::bytes.L2::cache_hint"
+ asm volatile("cp.async.bulk.tensor.2d.shared::cluster.global."
+ "mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint"
@@
- asm volatile("cp.async.bulk.tensor.3d.cta_group::2.shared::cluster.global."
- "mbarrier::complete_tx::bytes.L2::cache_hint"
+ asm volatile("cp.async.bulk.tensor.3d.shared::cluster.global."
+ "mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint"
@@
- asm volatile("cp.async.bulk.tensor.4d.cta_group::2.shared::cluster.global."
- "mbarrier::complete_tx::bytes.L2::cache_hint"
+ asm volatile("cp.async.bulk.tensor.4d.shared::cluster.global."
+ "mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint"
@@
- asm volatile("cp.async.bulk.tensor.5d.cta_group::2.shared::cluster.global."
- "mbarrier::complete_tx::bytes.L2::cache_hint"
+ asm volatile("cp.async.bulk.tensor.5d.shared::cluster.global."
+ "mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint"

Also applies to: 343-349, 369-375, 395-401, 422-429
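The documented ordering can be made concrete with a small, hypothetical Python helper (not part of the codebase) that assembles the opcode string from its components in the PTX ISA sequence; note that `cta_group::2` lands after the completion mechanism, not after the dimension:

```python
def ptx_bulk_tensor_opcode(dim, dst, src, completion,
                           load_mode=None, multicast=None,
                           cta_group=None, cache_hint=None):
    """Assemble a cp.async.bulk.tensor opcode in the documented modifier order:
    dim.dst.src{.load_mode}.completion_mechanism{.multicast}{.cta_group}{.level::cache_hint}
    """
    parts = ["cp.async.bulk.tensor", dim, dst, src]
    if load_mode:
        parts.append(load_mode)
    parts.append(completion)      # required; precedes cta_group
    if multicast:
        parts.append(multicast)
    if cta_group:
        parts.append(cta_group)   # cta_group::2 goes here, not after the dim
    if cache_hint:
        parts.append(cache_hint)
    return ".".join(parts)

print(ptx_bulk_tensor_opcode(
    "1d", "shared::cluster", "global",
    "mbarrier::complete_tx::bytes",
    cta_group="cta_group::2", cache_hint="L2::cache_hint"))
# → cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint
```

The printed string matches the corrected form in the diff above, whereas the original strings interleave `cta_group::2` between the dimension and the state spaces.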
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/tl_templates/cuda/copy_sm100.h` around lines 318 - 324, The PTX modifier
order in the tma_load_2sm inline asm strings is incorrect: move the cta_group::2
token so it follows the completion mechanism/bytes modifiers (i.e., use
".1d.shared::cluster.global.mbarrier::complete_tx::bytes.cta_group::2.L2::cache_hint"
instead of
".1d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint");
update this ordering in all tma_load_2sm asm occurrences (the instances noted in
the comment and the other similar asm strings in the same function) so the
modifiers follow the documented
dim.dst.src{.load_mode}.completion_mechanism{.multicast}{.cta_group}{.level::cache_hint}
sequence.
enable_ws=False,
enable_2cta=False,
ws=None,
warp_specialized=None,
Adding enable_2cta here breaks existing positional callers.
This inserts a new parameter before ws, warp_specialized, and variant, so any old positional call that passed one of those aliases now binds to a different argument. Keep enable_2cta after the aliases or make it keyword-only.
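A minimal sketch of the fix, using an illustrative signature rather than the real tilelang one: placing the new flag after the existing aliases and behind a keyword-only marker keeps old positional calls binding exactly as before.

```python
def gemm_opts(enable_ws=False, ws=None, warp_specialized=None,
              variant=None, *, enable_2cta=False):
    # Hypothetical stand-in for the real function; it just echoes its bindings.
    return {"enable_ws": enable_ws, "ws": ws,
            "warp_specialized": warp_specialized, "variant": variant,
            "enable_2cta": enable_2cta}

# An old positional caller still binds `ws` to the second slot:
opts = gemm_opts(True, "my_ws")
assert opts["ws"] == "my_ws" and opts["enable_2cta"] is False

# The new flag can only be passed by keyword, so it cannot shadow an alias:
assert gemm_opts(True, enable_2cta=True)["enable_2cta"] is True
```

Had `enable_2cta` been inserted positionally before `ws`, the call `gemm_opts(True, "my_ws")` would silently set `enable_2cta="my_ws"` instead.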
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tilelang/language/tir/op.py` around lines 1157 - 1160, The new parameter
enable_2cta was inserted before existing positional parameters (ws,
warp_specialized, variant) and breaks callers; fix it by making enable_2cta
keyword-only or moving it after the existing alias parameters: either (A) add a
keyword-only marker (*) before enable_2cta in the function signature so
enable_2cta must be passed by name, or (B) move the enable_2cta parameter to
come after variant (or after warp_specialized/ws) so existing positional calls
still bind to ws/warp_specialized/variant; update the signature where enable_ws,
enable_2cta, ws, warp_specialized, variant are declared accordingly.
Force-pushed from 20b437a to 605d9dc
This PR adds support for writing high-performance warp-specialized 2-CTA GEMM on Blackwell, bringing GEMM at (8192, 8192, 8192) to ~1670 TFLOPS. Thanks @Hamerlate for providing the dev machine.
Summary by CodeRabbit
New Features
Improvements