Description
Is this a new feature, an improvement, or a change to existing functionality?
New Feature
How would you describe the priority of this feature request?
High
Please provide a clear description of problem this feature solves
I am writing a GEMM kernel with cuTile. Before using cuTile, I was used to structuring CUDA programs in three levels: device, block, and warp. When implementing GEMM with cuTile, however, I ran into an issue I hadn't anticipated.
In a standard GEMM (D = A * B + C), the kernel is launched on a 2D grid, and each block computes one [tile_m, tile_n] tile of the output matrix. To do so, the block reads a [tile_m, K] tile from matrix A and a [tile_n, K] tile from matrix B (assuming B is column-major, so it is stored as [N, K]).
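The tile bookkeeping above can be sketched in plain Python (no cuTile involved). The helper names `gemm_grid` and `block_input_rows` are hypothetical, and clamping the last tile to the matrix edge is an assumption for sizes that are not multiples of the tile size:

```python
def gemm_grid(M: int, N: int, tm: int, tn: int) -> tuple:
    # One block per [tm, tn] output tile; ceil-divide so a partial edge tile still gets a block.
    return ((M + tm - 1) // tm, (N + tn - 1) // tn)


def block_input_rows(bx: int, by: int, tm: int, tn: int, M: int, N: int):
    # Block (bx, by) reads rows [bx*tm, bx*tm + tm) of A (shape [M, K]) and
    # rows [by*tn, by*tn + tn) of B stored as [N, K]; all K columns are read.
    # The upper bounds are clamped so edge blocks stay inside the matrices.
    a_rows = range(bx * tm, min(bx * tm + tm, M))
    b_rows = range(by * tn, min(by * tn + tn, N))
    return a_rows, b_rows


print(gemm_grid(256, 192, 64, 64))                 # (4, 3)
print(*block_input_rows(1, 2, 64, 64, 256, 192))   # range(64, 128) range(128, 192)
```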
I normally structure my program like this:
- device_gemm: responsible for splitting the work at the block level, partitioning A and B into tiles that each block will handle.
- block_gemm (called tile_gemm in the code below): responsible for the actual GEMM computation on one block's tiles.
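As a rough CPU reference for this split, the sketch below mirrors device_gemm and block_gemm in plain Python, assuming A is [M, K], B is stored as [N, K], and C is [M, N]. It is not cuTile code, and Python list slicing copies data rather than creating views, which is exactly the gap this request is about:

```python
# Plain-Python reference of the device_gemm / block_gemm split (no cuTile).
# A is [M, K], B is stored as [N, K] (column-major B), C and D are [M, N].

def block_gemm(a_tile, b_tile, c_tile):
    # Compute one output tile: D_tile = A_tile @ B_tile^T + C_tile.
    return [
        [sum(a * b for a, b in zip(a_row, b_row)) + c
         for b_row, c in zip(b_tile, c_row)]
        for a_row, c_row in zip(a_tile, c_tile)
    ]


def device_gemm(A, B, C, tm, tn):
    # Split the output into [tm, tn] tiles; each (bx, by) pair plays the role
    # of one block in a 2D grid and receives only the slices it needs.
    M, N = len(A), len(B)
    D = [[0] * N for _ in range(M)]
    for bx in range((M + tm - 1) // tm):
        for by in range((N + tn - 1) // tn):
            rows = slice(bx * tm, bx * tm + tm)
            cols = slice(by * tn, by * tn + tn)
            a_tile = A[rows]                       # [tm, K] slice of A
            b_tile = B[cols]                       # [tn, K] slice of stored B
            c_tile = [r[cols] for r in C[rows]]    # [tm, tn] slice of C
            d_tile = block_gemm(a_tile, b_tile, c_tile)
            for i, d_row in enumerate(d_tile):
                D[bx * tm + i][cols] = d_row       # write the tile back into D
    return D


A = [[1, 2], [3, 4], [5, 6]]   # M=3, K=2
B = [[1, 0], [0, 1]]           # stored [N, K] with N=2 (identity)
C = [[0, 0], [0, 0], [1, 1]]
print(device_gemm(A, B, C, tm=2, tn=1))  # [[1, 2], [3, 4], [6, 7]]
```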
But when I tried to write device_gemm, I encountered a problem:
```python
import cuda.tile as ct


@ct.kernel
def device_gemm(
    x: ct.Array, y: ct.Array, out: ct.Array,
    tm: ct.Constant, tn: ct.Constant, tk: ct.Constant
):
    block_idx_x, block_idx_y = ct.bid(0), ct.bid(1)
    # Cannot create local_x here:
    # cuda.tile._exception.TileTypeError: Non-constant slices are not supported
    local_x = x[block_idx_x * tm: block_idx_x * tm + tm, :]
    local_y = y[block_idx_y * tn: block_idx_y * tn + tn, :]
    acc = tile_gemm(local_x, local_y, tm, tn, tk)
    ct.store(out, (block_idx_x, block_idx_y), acc)
```

With the current cuTile design, I cannot "relocate" an array. What I need is a new array obtained by indexing and slicing, which is just a view of the original array.
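I could instead pass the entire arrays into tile_gemm and have ct.load inside it compute every access from the block indices, but in my experience that is not a good approach: tile_gemm then has to know about the block-level partitioning that device_gemm is supposed to own. Creating new arrays (as views) through indexing and slicing keeps each level responsible only for its own part of the work and makes the program much more readable.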
Feature Description
Support indexing & slicing in ct.Array
```python
@ct.kernel
def kernel(
    x: ct.Array, tile_size: ct.Constant
):
    block_idx = ct.bid(0)
    # Just like this (ct.ArrayView is the proposed view type):
    local_x: ct.ArrayView = x[block_idx * tile_size: block_idx * tile_size + tile_size]
```

Describe your ideal solution
Introduce a view mechanism for ct.Array so that an array can be relocated (sliced into a sub-array view that shares the original storage) without changing anything else in cuTile.
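To make the request concrete, here is a minimal CPU-side sketch of what such a view could record: a shared buffer plus an offset, shape, and strides. `ArrayView` here is a hypothetical class for illustration, not an existing cuTile type, and it only handles unit-step slices of a 2D array:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ArrayView:
    # Hypothetical 2D view: `data` is a flat row-major buffer shared by every
    # view; a view only records where it starts and how it is laid out.
    data: list
    offset: int
    shape: tuple
    strides: tuple

    @classmethod
    def from_rows(cls, rows):
        # Build a base (non-sliced) view from a list of equal-length rows.
        n_rows, n_cols = len(rows), len(rows[0])
        flat = [v for row in rows for v in row]
        return cls(flat, 0, (n_rows, n_cols), (n_cols, 1))

    def __getitem__(self, key):
        # Slicing (slices only) returns a new view over the same buffer; nothing is copied.
        if not isinstance(key, tuple):
            key = (key,)
        key = key + (slice(None),) * (2 - len(key))
        offset, shape = self.offset, []
        for s, dim, stride in zip(key, self.shape, self.strides):
            start, stop, step = s.indices(dim)
            if step != 1:
                raise ValueError("this sketch only supports unit-step slices")
            offset += start * stride
            shape.append(max(stop - start, 0))
        return ArrayView(self.data, offset, tuple(shape), self.strides)

    def element(self, i, j):
        # Read one element through the view's offset and strides.
        if not (0 <= i < self.shape[0] and 0 <= j < self.shape[1]):
            raise IndexError("index outside the view")
        return self.data[self.offset + i * self.strides[0] + j * self.strides[1]]

    def load_tile(self, ti, tj, th, tw):
        # Tile-indexed read: (ti, tj) counts [th, tw] tiles, not elements.
        return [[self.element(ti * th + r, tj * tw + c) for c in range(tw)]
                for r in range(th)]


x = ArrayView.from_rows([[r * 10 + c for c in range(4)] for r in range(6)])  # 6 x 4
tile_size, block_idx = 2, 1
local_x = x[block_idx * tile_size: block_idx * tile_size + tile_size]       # rows 2..3
print(local_x.shape)                   # (2, 4)
print(local_x.data is x.data)          # True: no copy was made
print(local_x.load_tile(0, 1, 2, 2))   # [[22, 23], [32, 33]]
```

Because slicing only adjusts the offset and shape, tile (0, 1) of `local_x` refers to the same data as tile (1, 1) of `x`. This is the kind of relocation device_gemm needs before handing a block's tiles to tile_gemm.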
Describe any alternatives you have considered
No response
Additional context
No response
Contributing Guidelines
- I agree to follow cuTile Python's contributing guidelines
- I have searched the open feature requests and have found no duplicates for this feature request