
Conversation

@yangulei commented Dec 11, 2025

Motivation

The current implementation in vLLM calls FusedSDPA with a huge attention mask, which leads to poor performance. This PR introduces three implementations that achieve better performance.

Usage

Three environment variables are introduced to control the implementation (a minimal sketch of how they could be read follows the list):

  • PT_HPU_QKV_SLICE_SEQ_LEN_THLD: int, the kv_len threshold (kv_len = q_len + prefix_len) above which these implementations apply; defaults to 4096.
  • PT_HPU_QKV_SLICE_CHUNK_SIZE: int, the chunk size used for slicing; defaults to PT_HPU_QKV_SLICE_SEQ_LEN_THLD.
  • PT_HPU_QKV_SLICE_IMPL: str, one of ['split_kv', 'slice_causal', 'slice_qkv']; selects the implementation, defaults to 'slice_qkv'.
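A minimal sketch of how these knobs could be read; the variable names are the real ones from this PR, but the exact parsing in the code may differ:

```python
import os

# Defaults follow the list above; the parsing shown here is an assumption.
seq_len_thld = int(os.environ.get("PT_HPU_QKV_SLICE_SEQ_LEN_THLD", "4096"))
chunk_size = int(os.environ.get("PT_HPU_QKV_SLICE_CHUNK_SIZE", str(seq_len_thld)))
impl = os.environ.get("PT_HPU_QKV_SLICE_IMPL", "slice_qkv")
assert impl in ("split_kv", "slice_causal", "slice_qkv")
```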

Implementations

Consider a FusedSDPA call with q_len=11525 and prefix_len=10752. Before calling FusedSDPA, the lengths are padded and truncated to q_len=16384 and prefix_len=8192 respectively, so kv_len = 16384 + 8192 = 24576, well above the default threshold of 4096. The full attention mask is shown below; the gray parts stand for the values to be masked out.
[figure: attention_mask]

Notation

The following images contain rectangles in three colors:

  • red, rgb(255,0,0): is_causal=False and attn_mask is not None
  • yellow, rgb(255,255,0): is_causal=True and attn_mask=None
  • magenta, rgb(255,0,255): is_causal=False and attn_mask=None

The original implementation

The original implementation passes the full attention mask with is_causal=False and valid_seq_len=None, which results in poor TPC/MME pipelining. The implementation can be visualized as the following image; a sketch of the mask it materializes follows.
[figure: slicing_schedule_fp8_origin]
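For reference, a sketch of the mask shape this path materializes (boolean True = keep; handling of the padded gray regions is omitted, and the tensor layout is an assumption):

```python
import torch

q_len, prefix_len = 16384, 8192
prefix = torch.ones(q_len, prefix_len, dtype=torch.bool)         # every query row sees the whole prefix
causal = torch.tril(torch.ones(q_len, q_len, dtype=torch.bool))  # causal triangle over the new tokens
full_mask = torch.cat([prefix, causal], dim=-1)                  # shape (16384, 24576): ~0.4 GB of bools
```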

The SplitKV implementation

This implementation calls FusedSDPA twice, once for the prefix part and once for the causal part, and does not pass an attn_mask for the prefix part, which gives better performance. A sketch of how the two partial results could be merged follows the figure.
[figure: slicing_schedule_fp8_SplitKV]
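A minimal sketch of the merge step, assuming each forward returns (out, m, linv) with m the per-row max and linv the inverse softmax denominator, as suggested by the snippet reviewed later in this thread. This is the standard log-sum-exp combination of two partial attention results, not the PR's exact code:

```python
import torch

def merge_partial(out1, m1, linv1, out2, m2, linv2):
    m = torch.maximum(m1, m2)          # new per-row max over both parts
    s1 = torch.exp(m1 - m) / linv1     # rescaled softmax denominator, prefix part
    s2 = torch.exp(m2 - m) / linv2     # rescaled softmax denominator, causal part
    return (out1 * s1 + out2 * s2) / (s1 + s2)
```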

The SliceCausal implementation

This implementation further slices the causal part into smaller chunks, as illustrated in the following image; a sketch of the resulting call schedule follows.
[figure: slicing_schedule_fp8_SliceCausal]
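A sketch of the SliceCausal call schedule, assuming the layout from the figures: prefix kv columns [0, prefix_len), causal kv columns [prefix_len, prefix_len + q_len). Each tuple stands for one FusedSDPA call as (q row range, kv column range, is_causal); partial results per q chunk would then be merged as in the SplitKV sketch:

```python
def slice_causal_schedule(q_len, prefix_len, chunk):
    calls = [((0, q_len), (0, prefix_len), False)]  # prefix part, no mask
    for i in range(0, q_len, chunk):
        hi = min(i + chunk, q_len)
        if i > 0:
            # off-diagonal block: full attention, no mask (magenta)
            calls.append(((i, hi), (prefix_len, prefix_len + i), False))
        # diagonal block: is_causal=True, no mask (yellow)
        calls.append(((i, hi), (prefix_len + i, prefix_len + hi), True))
    return calls
```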

The SliceQKV implementation

This implementation additionally slices the prefix part, as shown below; the schedule sketch after the figure extends the SliceCausal one.
[figure: slicing_schedule_fp8_SliceQKV]
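A sketch under the same conventions as the SliceCausal schedule; whether the off-diagonal causal columns are chunked together with the prefix columns, as assumed here, would follow the figure:

```python
def slice_qkv_schedule(q_len, prefix_len, chunk):
    calls = []
    for i in range(0, q_len, chunk):
        hi = min(i + chunk, q_len)
        # prefix columns plus earlier causal columns, chunk by chunk (magenta)
        for j in range(0, prefix_len + i, chunk):
            calls.append(((i, hi), (j, min(j + chunk, prefix_len + i)), False))
        # diagonal block stays causal (yellow)
        calls.append(((i, hi), (prefix_len + i, prefix_len + hi), True))
    return calls
```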

@yiliu30 added the OoT label Dec 15, 2025
@yangulei (Author)

@yiliu30 @Wei-Lin-Intel @czhu15
Please help to review, thanks!

```python
causal_res = self.fp8_fsdpa_fwd(q, causal_k, causal_v, causal_mask, dropout_p, scale, False, softmax_mode)
causal_out, causal_m, causal_linv = (gqa_output_reshape(x) if gqa else x for x in causal_res[:3])
causal_m = causal_m.to(torch.float32)
causal_linv = causal_linv.to(torch.float32) * (128.0 if softmax_mode != "fp32" else 1.0)
```


Only fast mode requires * 128.0; for the fp32 and None modes, the scale is 1.0.
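A one-line sketch of the suggested change (the merged fix may differ):

```python
# apply the 128.0 descale only for the "fast" softmax mode
causal_linv = causal_linv.to(torch.float32) * (128.0 if softmax_mode == "fast" else 1.0)
```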

@yangulei (Author)

Done, thx!

@yiliu30 (Contributor) left a comment

LGTM
Please add usage in PR desc :)

@yangulei (Author)

> LGTM Please add usage in PR desc :)

Done.

@yangulei changed the title [Draft] add OoTPatchedModuleFusedSDPA → add OoTPatchedModuleFusedSDPA Dec 17, 2025
@czhu15 commented Dec 17, 2025

So this solution only truncates the prefix, with no padding?
