Potential OOB shared memory read in online safe softmax (MD block reduction) #409

@teosssss

Description

Hi! I think there is an out-of-bounds shared memory read in the block-reduction stage of the online safe softmax kernels that use the MD struct.

Where

In kernels/softmax/softmax.cu, in both:

  • online_safe_softmax_f32_per_token_kernel
  • online_safe_softmax_f32x4_per_token_kernel

Shared memory is sized by the number of warps:

const int WARP_NUM = NUM_THREADS / WARP_SIZE;
__shared__ MD shared[WARP_NUM];

Then after writing one partial per warp (shared[warp_id]), the block-reduce stage does:

if (local_tid < WARP_SIZE) {
  MD block_res = shared[local_tid];   // <-- OOB if local_tid >= WARP_NUM
  block_res = warp_reduce_md_op<WARP_NUM>(block_res);
  if (local_tid == 0) shared[0] = block_res;
}

If NUM_THREADS = 256 and WARP_SIZE = 32, then WARP_NUM = NUM_THREADS / WARP_SIZE = 256 / 32 = 8, so shared has 8 elements with valid indices 0..7. Every thread with local_tid < WARP_SIZE enters the branch, so threads with local_tid = 8..31 read shared[8..31], which is past the end of the array. That is an OOB shared memory read.

I think a simple fix is to guard the load from shared and initialize the extra values in the final warp with identity values before the warp reduction.

Concretely, lanes with local_tid >= WARP_NUM should not read shared[local_tid]. Instead, set:

block_res.m = -FLT_MAX;  
block_res.d = 0.0f;    
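
To illustrate why padding with these values is safe, here is a minimal host-side C++ model of the guarded load, not the kernel code itself. The names md_merge, load_partial and reduce_final_warp are hypothetical, and the merge formula is an assumption about what warp_reduce_md_op computes (the issue does not show its body). The sequential loop stands in for the warp shuffle reduction.

```cpp
#include <cassert>
#include <cfloat>
#include <cmath>
#include <vector>

// Running (max, denominator) pair used by online softmax.
struct MD {
  float m;  // running maximum
  float d;  // running sum of exp(x - m)
};

// Assumed form of the MD combine step: keep the larger maximum and
// rescale the other denominator onto it.
MD md_merge(MD a, MD b) {
  const MD bigger = (a.m > b.m) ? a : b;
  const MD smaller = (a.m > b.m) ? b : a;
  return {bigger.m, bigger.d + smaller.d * std::exp(smaller.m - bigger.m)};
}

// Guarded load: lanes at or beyond the number of per-warp partials never
// index the array and receive the identity element instead.
MD load_partial(const std::vector<MD>& partials, int lane) {
  const int warp_num = static_cast<int>(partials.size());
  if (lane < warp_num) return partials[lane];
  return {-FLT_MAX, 0.0f};
}

// Sequential stand-in for the final warp reduction over warp_size lanes.
MD reduce_final_warp(const std::vector<MD>& partials, int warp_size) {
  assert(static_cast<int>(partials.size()) <= warp_size);
  MD acc = load_partial(partials, 0);
  for (int lane = 1; lane < warp_size; ++lane) {
    acc = md_merge(acc, load_partial(partials, lane));
  }
  return acc;
}
```

With 8 partials, reduce_final_warp(partials, 32) should match a reduction over only the 8 real entries, because a padded lane contributes d = 0 and never wins the max comparison. In the kernel, the same guard would presumably be a conditional on local_tid < WARP_NUM around the shared[local_tid] read.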
