
Commit d7af46f

Update docs
1 parent a8c92e4 commit d7af46f

250 files changed: 9991 additions & 1350 deletions

Lines changed: 123 additions & 0 deletions
tilelang.contrib.cutedsl.atomic
===============================

.. py:module:: tilelang.contrib.cutedsl.atomic

.. autoapi-nested-parse::

   Atomic operations for the CuTeDSL backend.

   This module provides implementations of atomic operations using the NVVM and LLVM dialects.


Functions
---------

.. autoapisummary::

   tilelang.contrib.cutedsl.atomic.AtomicAdd
   tilelang.contrib.cutedsl.atomic.AtomicAddRet
   tilelang.contrib.cutedsl.atomic.AtomicAddx2
   tilelang.contrib.cutedsl.atomic.AtomicAddx4
   tilelang.contrib.cutedsl.atomic.AtomicMax
   tilelang.contrib.cutedsl.atomic.AtomicMaxRet
   tilelang.contrib.cutedsl.atomic.AtomicMin
   tilelang.contrib.cutedsl.atomic.AtomicMinRet
   tilelang.contrib.cutedsl.atomic.AtomicLoad
   tilelang.contrib.cutedsl.atomic.AtomicStore


Module Contents
---------------
.. py:function:: AtomicAdd(ptr, value, *, loc=None, ip=None)

   Perform atomic addition on a pointer.

   Supports float16, float32, int32, and int64 types.
   Returns the old value before the addition (atomicrmw semantics).


.. py:function:: AtomicAddRet(ptr, value, *, loc=None, ip=None)

   Perform atomic addition and return the previous value.

   This is the same as AtomicAdd, since nvvm.atomicrmw always returns the old value.
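The fetch-and-add semantics described above (the operation returns the value held *before* the addition) can be sketched in plain Python. The lock stands in for the hardware's atomicity guarantee; the class and method names are illustrative and not part of this module's API.

```python
import threading

class AtomicCell:
    """Emulates atomicrmw add semantics: add, but return the OLD value."""

    def __init__(self, value=0):
        self._value = value
        self._lock = threading.Lock()

    def atomic_add(self, delta):
        with self._lock:
            old = self._value        # value before the addition
            self._value += delta
            return old               # atomicrmw returns the prior value

cell = AtomicCell(10)
assert cell.atomic_add(5) == 10      # old value, not 15
assert cell.atomic_add(1) == 15
```

This is why AtomicAdd and AtomicAddRet coincide: the underlying operation always yields the prior value, whether or not the caller uses it.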
.. py:function:: AtomicAddx2(dst_ptr, src_values, *, loc=None, ip=None)

   Vectorized atomic add for 2 consecutive elements.

   Uses PTX atom.add.v2.f32 for float32 or atom.add.noftz.v2.f16 for float16.

   :param dst_ptr: Pointer to destination (2 consecutive elements)
   :param src_values: Source values; can be a TensorSSA (loaded tensor) or a Pointer


.. py:function:: AtomicAddx4(dst_ptr, src_values, *, loc=None, ip=None)

   Vectorized atomic add for 4 consecutive float32 elements.

   Uses PTX atom.global.add.v4.f32 for a true vectorized atomic operation on SM90+.

   :param dst_ptr: Pointer to destination (4 consecutive float32 elements)
   :param src_values: Source values; can be a TensorSSA (loaded tensor) or a Pointer
.. py:function:: AtomicMax(ptr, value, *, loc=None, ip=None)

   Perform an atomic maximum operation.

   For integers, uses nvvm.atomicrmw with MAX.
   For floats, uses a CAS loop, since PTX has no atomic max for float32.


.. py:function:: AtomicMaxRet(ptr, value, *, loc=None, ip=None)

   Perform an atomic maximum and return the previous value.


.. py:function:: AtomicMin(ptr, value, *, loc=None, ip=None)

   Perform an atomic minimum operation.

   For integers, uses nvvm.atomicrmw with MIN.
   For floats, uses a CAS loop, since PTX has no atomic min for float32.


.. py:function:: AtomicMinRet(ptr, value, *, loc=None, ip=None)

   Perform an atomic minimum and return the previous value.
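The CAS-loop strategy that the float variants rely on can be sketched in Python: read the current value, and retry a compare-and-swap until no other writer has intervened. Operating on the 32-bit integer image of the float mirrors how a word-level CAS is applied to float data; the cell class and function names here are illustrative stand-ins, not this module's API.

```python
import struct
import threading

def f32_to_bits(x):
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bits_to_f32(b):
    return struct.unpack("<f", struct.pack("<I", b))[0]

class WordCell:
    """A 32-bit word with compare-and-swap, standing in for device memory."""

    def __init__(self, value):
        self._bits = f32_to_bits(value)
        self._lock = threading.Lock()

    def load_bits(self):
        with self._lock:
            return self._bits

    def cas(self, expected_bits, new_bits):
        # Returns the observed bits; the CAS succeeded iff they match expected.
        with self._lock:
            observed = self._bits
            if observed == expected_bits:
                self._bits = new_bits
            return observed

def atomic_max_f32(cell, value):
    """CAS loop: retry until our maximum sticks or is already in place."""
    new_bits = f32_to_bits(value)
    while True:
        old_bits = cell.load_bits()
        old = bits_to_f32(old_bits)
        if old >= value:
            return old               # already at least `value`; nothing to do
        if cell.cas(old_bits, new_bits) == old_bits:
            return old               # CAS succeeded; return the previous value

cell = WordCell(1.5)
assert atomic_max_f32(cell, 3.0) == 1.5   # previous value returned
assert atomic_max_f32(cell, 2.0) == 3.0   # 3.0 already the max; unchanged
```

AtomicMin follows the same loop with the comparison reversed.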
.. py:function:: AtomicLoad(ptr, memory_order, *, loc=None, ip=None)

   Perform an atomic load with the specified memory ordering.

   :param ptr: Pointer to load from
   :param memory_order: TileLang memory order ID (0=relaxed, 2=acquire, 5=seq_cst, etc.)

   :returns: The loaded value

   PTX mapping (per the NVIDIA ABI):

      relaxed: ld.relaxed.<scope>
      acquire: ld.acquire.<scope>
      seq_cst: fence.sc.<scope>; ld.relaxed.<scope>


.. py:function:: AtomicStore(ptr, value, memory_order, *, loc=None, ip=None)

   Perform an atomic store with the specified memory ordering.

   :param ptr: Pointer to store to
   :param value: Value to store
   :param memory_order: TileLang memory order ID (0=relaxed, 3=release, 5=seq_cst, etc.)

   PTX mapping (per the NVIDIA ABI):

      relaxed: st.relaxed.<scope>
      release: st.release.<scope>
      seq_cst: fence.sc.<scope>; st.relaxed.<scope>
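The two PTX mapping tables above can be captured as a small lookup. Only the order IDs spelled out in the docstrings (0, 2, 3, 5) are included, since the rest are elided there; the ``<scope>`` placeholder is kept verbatim, and the table and function names are illustrative.

```python
# Memory-order IDs documented above: 0=relaxed, 2=acquire, 3=release, 5=seq_cst.
LOAD_PTX = {
    0: "ld.relaxed.<scope>",
    2: "ld.acquire.<scope>",
    5: "fence.sc.<scope>; ld.relaxed.<scope>",   # seq_cst = fence then relaxed load
}
STORE_PTX = {
    0: "st.relaxed.<scope>",
    3: "st.release.<scope>",
    5: "fence.sc.<scope>; st.relaxed.<scope>",   # seq_cst = fence then relaxed store
}

def ptx_for_load(memory_order):
    """PTX sequence for an atomic load with the given TileLang order ID."""
    return LOAD_PTX[memory_order]

def ptx_for_store(memory_order):
    """PTX sequence for an atomic store with the given TileLang order ID."""
    return STORE_PTX[memory_order]

assert ptx_for_load(5) == "fence.sc.<scope>; ld.relaxed.<scope>"
```

Note the symmetry: seq_cst on both sides is lowered to a sequentially consistent fence followed by the relaxed form of the access.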

_sources/autoapi/tilelang/contrib/cutedsl/cpasync/index.rst.txt

Lines changed: 34 additions & 17 deletions
@@ -4,15 +4,6 @@ tilelang.contrib.cutedsl.cpasync
 .. py:module:: tilelang.contrib.cutedsl.cpasync


-Attributes
-----------
-
-.. autoapisummary::
-
-   tilelang.contrib.cutedsl.cpasync.BYTES_PER_TENSORMAP
-   tilelang.contrib.cutedsl.cpasync.BYTES_PER_POINTER
-
-
 Functions
 ---------

@@ -23,23 +14,20 @@ Functions
    tilelang.contrib.cutedsl.cpasync.extract_tensormap_ptr
    tilelang.contrib.cutedsl.cpasync.tma_load
    tilelang.contrib.cutedsl.cpasync.tma_store
+   tilelang.contrib.cutedsl.cpasync.tma_reduce
    tilelang.contrib.cutedsl.cpasync.tma_store_arrive
    tilelang.contrib.cutedsl.cpasync.tma_store_wait
    tilelang.contrib.cutedsl.cpasync.cp_async_shared_global
    tilelang.contrib.cutedsl.cpasync.prefetch_tma_descriptor
+   tilelang.contrib.cutedsl.cpasync.mbarrier_wait
+   tilelang.contrib.cutedsl.cpasync.mbarrier_cp_async_arrive
+   tilelang.contrib.cutedsl.cpasync.fence_proxy_async
+   tilelang.contrib.cutedsl.cpasync.fence_barrier_init


 Module Contents
 ---------------

-.. py:data:: BYTES_PER_TENSORMAP
-   :value: 128
-
-
-.. py:data:: BYTES_PER_POINTER
-   :value: 8
-
-
 .. py:function:: cp_async_gs(size, dst, src)

 .. py:function:: cp_async_gs_conditional(size, dst, src, cond)

@@ -77,6 +65,21 @@ Module Contents
    :type crd: tuple[Int, ...]


+.. py:function:: tma_reduce(tma_desc, smem_ptr, crd, *, loc=None, ip=None)
+
+   Reduce data from shared memory to global memory using TMA with atomic ADD reduction.
+
+   This performs an atomic add of shared-memory data to global memory using
+   the TMA unit's reduce capability.
+
+   :param tma_desc: TMA descriptor for the tensor
+   :type tma_desc: TMA descriptor
+   :param smem_ptr: Source pointer in shared memory
+   :type smem_ptr: Pointer
+   :param crd: Coordinates tuple for the tensor access
+   :type crd: tuple[Int, ...]
+
+
 .. py:function:: tma_store_arrive(*, loc=None, ip=None)

    Indicate arrival of the warp issuing the TMA_STORE.
@@ -114,3 +117,17 @@ Module Contents
    Corresponds to PTX instruction: prefetch.tensormap;


+.. py:function:: mbarrier_wait(mbar_ptr, phase, timeout_ns=10000000, *, loc=None, ip=None)
+
+   Wait on an mbarrier with the specified phase (blocking loop).
+
+   Uses inline PTX to loop until the try_wait succeeds.
+   The CUDA backend does: while (!mbar.try_wait(parity)) {}
+
+
+.. py:function:: mbarrier_cp_async_arrive(mbar_ptr, *, loc=None, ip=None)
+
+.. py:function:: fence_proxy_async()
+
+.. py:function:: fence_barrier_init()
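The blocking loop that ``mbarrier_wait`` describes, repeatedly issuing a try-wait until the barrier reaches the expected phase, can be sketched in Python. The barrier object and its ``try_wait`` are illustrative stand-ins for the hardware mbarrier, and the deadline handling is an assumption about how ``timeout_ns`` might bound the spin.

```python
import time

class FakeMbarrier:
    """Stand-in for a hardware mbarrier: flips parity once all threads arrive."""

    def __init__(self, expected_arrivals):
        self.expected = expected_arrivals
        self.count = 0
        self.phase = 0

    def arrive(self):
        self.count += 1
        if self.count == self.expected:
            self.count = 0
            self.phase ^= 1          # completing a phase flips the parity bit

    def try_wait(self, parity):
        # Succeeds once the tracked phase has moved past `parity`.
        return self.phase != parity

def mbarrier_wait(mbar, phase, timeout_ns=10_000_000):
    """Spin until try_wait succeeds: while (!mbar.try_wait(parity)) {}"""
    deadline = time.monotonic_ns() + timeout_ns
    while not mbar.try_wait(phase):
        if time.monotonic_ns() > deadline:
            raise TimeoutError("mbarrier wait timed out")

mbar = FakeMbarrier(expected_arrivals=2)
mbar.arrive()
mbar.arrive()                 # second arrival completes phase 0
mbarrier_wait(mbar, phase=0)  # returns immediately; phase 0 has passed
```

Waiting on a phase that never completes would spin until the timeout, which is the failure mode the ``timeout_ns`` parameter exists to catch.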
Lines changed: 163 additions & 0 deletions
tilelang.contrib.cutedsl.gemm_tcgen05
=====================================

.. py:module:: tilelang.contrib.cutedsl.gemm_tcgen05

.. autoapi-nested-parse::

   tcgen05 (SM100/Blackwell) MMA support for the CuTeDSL backend.

   Provides:

   - Tcgen05SmemDescriptor: 64-bit SMEM descriptor for tcgen05 MMA
   - initialize_tcgen05_descriptor: bitfield packing matching the common.h layout
   - tcgen05mma_ss / tcgen05mma_ws_ss / tcgen05mma_ts: MMA PTX inline asm
   - tcgen05_mma_arrive: mbarrier arrive for MMA commit
   - tmem_allocate / tmem_deallocate: TMEM allocation/deallocation


Classes
-------

.. autoapisummary::

   tilelang.contrib.cutedsl.gemm_tcgen05.Tcgen05SmemDescriptor


Functions
---------

.. autoapisummary::

   tilelang.contrib.cutedsl.gemm_tcgen05.initialize_tcgen05_descriptor
   tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05mma_ss
   tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05mma_ws_ss
   tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05mma_ts
   tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05_mma_arrive
   tilelang.contrib.cutedsl.gemm_tcgen05.tmem_allocate
   tilelang.contrib.cutedsl.gemm_tcgen05.tmem_deallocate
   tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05_ld_32dp32bNx
   tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05_ld_32dp64bNx
   tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05_ld_32dp128bNx
   tilelang.contrib.cutedsl.gemm_tcgen05.tcgen05_ld_32dp256bNx


Module Contents
---------------
.. py:class:: Tcgen05SmemDescriptor(desc_64=None)

   64-bit shared-memory descriptor for tcgen05 MMA (Blackwell).

   Mirrors tl::Tcgen05SMemDescriptor from common.h.
   Stored as two Int32 registers; recast to Int64 for the PTX operand.


   .. py:attribute:: desc


   .. py:attribute:: desc_i64


   .. py:method:: __add__(offset)

      Add a byte offset. Like the C++ operator+, the offset is shifted right by 4 first.
.. py:function:: initialize_tcgen05_descriptor(desc, start_address, leading_byte_offset, stride_byte_offset, base_offset, leading_abs, swizzle_mode)

   Pack the tcgen05 SMEM descriptor bitfields.

   Matches the C++ ``initialize_tcgen05_descriptor`` in common.h:

   Low 32 bits (reg32_[0]):

      [0:14)  start_address >> 4
      [16:30) leading_byte_offset (already >>4 from TIR)

   High 32 bits (reg32_[1]):

      [0:14)  stride_byte_offset (already >>4 from TIR)
      [14:16) version = 1
      [17:20) base_offset & 0x7
      [20:21) lbo_mode (leading_is_absolute ? 1 : 0)
      [29:32) layout_type (swizzle_mode & 0x7)
.. py:function:: tcgen05mma_ss(kind_dtype, desc_a, desc_b, tmem_c, desc_val, scale_out, mask0, mask1, mask2, mask3)

   tcgen05.mma.cta_group::1.kind::{kind} [tmem_c], desc_a, desc_b, desc_val, {masks}, p;

   Guarded by elect_one_sync; only one thread in the warp issues the MMA.
   The TIR codegen also wraps calls in ``if (threadIdx.x >> 5) == 0``,
   which selects warp 0.


.. py:function:: tcgen05mma_ws_ss(kind_dtype, desc_a, desc_b, tmem_c, desc_val, scale_out)

   tcgen05.mma.ws.cta_group::1.kind::{kind} [tmem_c], desc_a, desc_b, desc_val, p, 0;


.. py:function:: tcgen05mma_ts(kind_dtype, tmem_a, desc_b, tmem_c, desc_val, scale_out, mask0, mask1, mask2, mask3)

   tcgen05.mma.cta_group::1.kind::{kind} [tmem_c], [tmem_a], desc_b, desc_val, {masks}, p;


.. py:function:: tcgen05_mma_arrive(mbar_ptr)

   tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [mbar];

   Guarded by elect_one_sync; only one thread in the warp issues the commit.
.. py:function:: tmem_allocate(tmem_buffer_ptr, num_cols)

   tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [dst], num_cols;

   tmem_buffer_ptr: SMEM pointer that receives the allocated TMEM address.
   num_cols: number of columns to allocate.


.. py:function:: tmem_deallocate(tmem_ptr, num_cols)

   tcgen05.dealloc.cta_group::1.sync.aligned.b32 tmem_addr, num_cols;

   tmem_ptr: SMEM pointer to the uint32 holding the TMEM address.
   num_cols: number of columns to deallocate.
.. py:function:: tcgen05_ld_32dp32bNx(N, pack16, tmem_start_col, tmem_col_offset, dst_ptr)

   Load N uint32 values from TMEM using tcgen05.ld.sync.aligned.32x32b.

   Matches tl::tcgen05_ld_32dp32bNx from copy_sm100.h.
   N: number of 32-bit elements to load (x-count, compile-time constant).
   pack16: if True, use 16-bit packing (not implemented yet).
   tmem_start_col: TMEM base column address.
   tmem_col_offset: additional column offset.
   dst_ptr: destination pointer (register memory).


.. py:function:: tcgen05_ld_32dp64bNx(N, pack16, tmem_start_col, tmem_col_offset, dst_ptr)

   Load from TMEM using the 32dp64b pattern (2x 16x64b for the lower/upper 16 rows).

   Matches tl::tmem_ld_32dp64bNx from tcgen_05_ld.h.
   N: x-count for the 16x64b instructions. Total output: 2*N i32 regs.


.. py:function:: tcgen05_ld_32dp128bNx(N, pack16, tmem_start_col, tmem_col_offset, dst_ptr)

   Load from TMEM using the 32dp128b pattern (2x 16x128b for the lower/upper 16 rows).

   Matches tl::tmem_ld_32dp128bNx from tcgen_05_ld.h.
   N: x-count for the 16x128b instructions. Total output: 4*N i32 regs.
   16x128b.xN produces 2*N i32 regs per half.


.. py:function:: tcgen05_ld_32dp256bNx(N, pack16, tmem_start_col, tmem_col_offset, dst_ptr)

   Load from TMEM using the 32dp256b pattern (2x 16x256b for the lower/upper 16 rows).

   Matches tl::tmem_ld_32dp256bNx from tcgen_05_ld.h.
   N: x-count for the 16x256b instructions. Total output: 8*N i32 regs.
   16x256b.xN produces 4*N i32 regs per half.
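The register-count arithmetic stated across the four load variants (N, 2*N, 4*N, and 8*N i32 registers, respectively) can be captured in a small helper; the function name and the lookup-table shape are illustrative, not part of the module.

```python
# i32 destination registers produced per call, per the docstrings above:
# 32dp32bNx -> N, 32dp64bNx -> 2*N, 32dp128bNx -> 4*N, 32dp256bNx -> 8*N.
# For the 64b/128b/256b patterns this is 2x the per-half count, since the
# load is issued once for the lower and once for the upper 16 rows.
REGS_PER_X = {"32dp32b": 1, "32dp64b": 2, "32dp128b": 4, "32dp256b": 8}

def tcgen05_ld_reg_count(pattern, n):
    """Total i32 destination registers for a tcgen05.ld pattern with x-count n."""
    return REGS_PER_X[pattern] * n

assert tcgen05_ld_reg_count("32dp32b", 4) == 4
assert tcgen05_ld_reg_count("32dp256b", 1) == 8
```

This is the count a caller would need when sizing ``dst_ptr``'s register-memory buffer for a given x-count.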
