Description
Hi! I think there is an out-of-bounds shared memory read in the block-reduction stage of the online safe softmax kernels that use the MD struct.
Where
In kernels/softmax/softmax.cu, in both:
online_safe_softmax_f32_per_token_kernel
online_safe_softmax_f32x4_per_token_kernel
Shared memory is sized by the number of warps:
const int WARP_NUM = NUM_THREADS / WARP_SIZE;
__shared__ MD shared[WARP_NUM];
Then, after each warp writes its one partial result to shared[warp_id], the block-reduce stage does:
if (local_tid < WARP_SIZE) {
  MD block_res = shared[local_tid]; // <-- OOB if local_tid >= WARP_NUM
  block_res = warp_reduce_md_op<WARP_NUM>(block_res);
  if (local_tid == 0) shared[0] = block_res;
}
If NUM_THREADS = 256, then WARP_NUM = NUM_THREADS / WARP_SIZE = 256 / 32 = 8, so shared has only 8 elements (valid indices 0..7). Threads with local_tid = 8..31 still pass the local_tid < WARP_SIZE check and read shared[8..31], which is an OOB shared memory read.
I think a simple fix is to guard the load from shared and initialize the extra lanes of the warp that performs the final reduction with identity values before the warp reduction.
Concretely, lanes with local_tid >= WARP_NUM should not read shared[local_tid]. Instead, set:
block_res.m = -FLT_MAX;
block_res.d = 0.0f;
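To make the proposal concrete, here is a minimal, self-contained sketch of the guarded load. It is not the code from softmax.cu: the MD layout, the md_merge helper, the warp_reduce_md function, and the md_block_reduce_kernel test kernel are all my own stand-ins, written on the assumption that the real warp_reduce_md_op merges (m, d) pairs in the usual online-softmax way. The only part meant as the actual suggestion is the if/else around the shared-memory load.

```cuda
#include <cfloat>
#include <cmath>
#include <cstdio>
#include <cuda_runtime.h>

#define WARP_SIZE 32

// Assumed layout: running max m and running denominator d.
struct MD { float m; float d; };

// Hypothetical helper: merge two online-softmax partials.
// (-FLT_MAX, 0) acts as the identity because its d contributes 0.
__device__ __forceinline__ MD md_merge(MD a, MD b) {
  bool a_bigger = a.m > b.m;
  MD big = a_bigger ? a : b;
  MD small = a_bigger ? b : a;
  MD out;
  out.m = big.m;
  out.d = big.d + small.d * __expf(small.m - big.m);
  return out;
}

// Stand-in for warp_reduce_md_op: butterfly reduction over kWidth lanes.
template <int kWidth = WARP_SIZE>
__device__ __forceinline__ MD warp_reduce_md(MD v) {
#pragma unroll
  for (int stride = kWidth >> 1; stride >= 1; stride >>= 1) {
    MD other;
    other.m = __shfl_xor_sync(0xffffffff, v.m, stride);
    other.d = __shfl_xor_sync(0xffffffff, v.d, stride);
    v = md_merge(v, other);
  }
  return v;
}

// Hypothetical test kernel: one block, one input value per thread.
template <int NUM_THREADS>
__global__ void md_block_reduce_kernel(const float* x, MD* out) {
  constexpr int WARP_NUM = NUM_THREADS / WARP_SIZE;
  __shared__ MD shared[WARP_NUM];
  const int local_tid = threadIdx.x;
  const int warp_id = local_tid / WARP_SIZE;
  const int lane_id = local_tid % WARP_SIZE;

  MD val;
  val.m = x[local_tid];
  val.d = 1.0f;
  val = warp_reduce_md<WARP_SIZE>(val);
  if (lane_id == 0) shared[warp_id] = val;  // one partial per warp
  __syncthreads();

  if (local_tid < WARP_SIZE) {
    MD block_res;
    if (local_tid < WARP_NUM) {
      block_res = shared[local_tid];  // in bounds: indices 0..WARP_NUM-1
    } else {
      block_res.m = -FLT_MAX;         // identity value, no shared read
      block_res.d = 0.0f;
    }
    block_res = warp_reduce_md<WARP_NUM>(block_res);
    if (local_tid == 0) out[0] = block_res;
  }
}

int main() {
  constexpr int N = 256;  // NUM_THREADS = 256, so WARP_NUM = 8
  float h_x[N];
  for (int i = 0; i < N; ++i) h_x[i] = 0.01f * i;

  float* d_x = nullptr;
  MD* d_out = nullptr;
  cudaMalloc(&d_x, N * sizeof(float));
  cudaMalloc(&d_out, sizeof(MD));
  cudaMemcpy(d_x, h_x, N * sizeof(float), cudaMemcpyHostToDevice);

  md_block_reduce_kernel<N><<<1, N>>>(d_x, d_out);

  MD h_out;
  cudaMemcpy(&h_out, d_out, sizeof(MD), cudaMemcpyDeviceToHost);

  // CPU reference: max and sum of exp(x - max).
  float m = -FLT_MAX, d = 0.0f;
  for (int i = 0; i < N; ++i) m = fmaxf(m, h_x[i]);
  for (int i = 0; i < N; ++i) d += expf(h_x[i] - m);

  printf("gpu: m=%f d=%f | cpu: m=%f d=%f\n", h_out.m, h_out.d, m, d);
  cudaFree(d_x);
  cudaFree(d_out);
  return 0;
}
```

Something like nvcc on this file and a run under compute-sanitizer should be enough to check that the guarded version no longer touches shared[8..31]. The GPU and CPU values of d may differ slightly because the sketch uses __expf on the device.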