add OoTPatchedModuleFusedSDPA #2361
base: v3.6.post.oot
Conversation
@yiliu30 @Wei-Lin-Intel @czhu15
```python
causal_res = self.fp8_fsdpa_fwd(q, causal_k, causal_v, causal_mask, dropout_p, scale, False, softmax_mode)
causal_out, causal_m, causal_linv = (gqa_output_reshape(x) if gqa else x for x in causal_res[:3])
causal_m = causal_m.to(torch.float32)
causal_linv = causal_linv.to(torch.float32) * (128.0 if softmax_mode != "fp32" else 1.0)
```
Only fast mode requires the `* 128.0`; for the fp32 and None modes, the scale is 1.0.
Done, thx!
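For clarity, a minimal sketch of what the corrected scaling could look like after this review (assuming `"fast"` is the softmax mode string; the excerpt above only shows `"fp32"`):

```python
# Apply the 128.0 descale factor only in fast softmax mode; for the fp32 and
# None modes the scale is 1.0, as noted in the review comment above.
# ("fast" as the mode string is an assumption, not shown in this excerpt.)
causal_linv = causal_linv.to(torch.float32) * (128.0 if softmax_mode == "fast" else 1.0)
```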
yiliu30 left a comment
LGTM
Please add usage in PR desc :)
Done.
So this solution only truncates the prefix, without padding?
Motivation
The current implementation in vLLM calls FusedSDPA with a huge attention mask, which leads to poor performance. This PR introduces three implementations to achieve better performance.
Usage
Three environment variables are introduced to control the implementation:
- `PT_HPU_QKV_SLICE_SEQ_LEN_THLD`: int, the threshold on `kv_len` (= `q_len` + `prefix_len`) above which the implementations are applied; defaults to `4096`.
- `PT_HPU_QKV_SLICE_CHUNK_SIZE`: int, the chunk size used for slicing in the implementation; defaults to `PT_HPU_QKV_SLICE_SEQ_LEN_THLD`.
- `PT_HPU_QKV_SLICE_IMPL`: str, one of `['split_kv', 'slice_causal', 'slice_qkv']`; selects the implementation, defaults to `slice_qkv`.

A parsing sketch for these variables is shown below.
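The sketch is illustrative; the variable names and defaults follow this PR's description, but the actual parsing code in the patch may differ:

```python
import os

# Illustrative parsing of the three knobs; defaults follow the PR description.
seq_len_thld = int(os.environ.get("PT_HPU_QKV_SLICE_SEQ_LEN_THLD", "4096"))
chunk_size = int(os.environ.get("PT_HPU_QKV_SLICE_CHUNK_SIZE", str(seq_len_thld)))
impl = os.environ.get("PT_HPU_QKV_SLICE_IMPL", "slice_qkv")
assert impl in ("split_kv", "slice_causal", "slice_qkv"), f"unknown impl: {impl}"
```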
Implementations
Consider a FusedSDPA call with `q_len=11525` and `prefix_len=10752`. The lengths are respectively padded and truncated to `q_len=16384` and `prefix_len=8192` before FusedSDPA is called. The full attention mask is shown below, in which the gray parts stand for the values to be masked out. A sketch of how these rounded lengths could arise follows the image.

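One way the example numbers could arise, assuming the lengths are bucketed to powers of two (an assumption for illustration; vLLM's actual bucketing logic may differ):

```python
# Assumed power-of-two bucketing (illustrative only): pad q_len up to the next
# power of two, truncate prefix_len down to the previous power of two.
def round_up_pow2(n: int) -> int:
    return 1 << (n - 1).bit_length()

def round_down_pow2(n: int) -> int:
    return 1 << (n.bit_length() - 1)

assert round_up_pow2(11525) == 16384    # q_len is padded up
assert round_down_pow2(10752) == 8192   # prefix_len is truncated down
```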
Notation
The following images include rectangles in three colors:
- rgb(255,0,0) (red): `is_causal=False` and `attn_mask is not None`
- rgb(255,255,0) (yellow): `is_causal=True` and `attn_mask=None`
- rgb(255,0,255) (magenta): `is_causal=False` and `attn_mask=None`
The original implementation
The original implementation passes the full attention mask and sets `is_causal=False` and `valid_seq_len=None`, which results in a poor TPC/MME pipeline. The implementation can be visualized as the following image; a small sketch of this baseline follows it.

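As a point of reference, here is an illustrative reconstruction of the baseline using stock PyTorch SDPA in place of the HPU FusedSDPA op, with small demo sizes (the real lengths are in the thousands, which is why the materialized mask is so costly):

```python
import torch
import torch.nn.functional as F

# Small demo sizes; the real q_len/prefix_len are in the thousands.
q_len, prefix_len, d = 1024, 512, 128
q = torch.randn(1, 1, q_len, d)
k = torch.randn(1, 1, prefix_len + q_len, d)
v = torch.randn(1, 1, prefix_len + q_len, d)

# Full [q_len, prefix_len + q_len] mask: prefix columns fully visible,
# causal columns lower-triangular. This whole mask is materialized.
full_mask = torch.cat(
    [torch.ones(q_len, prefix_len, dtype=torch.bool),
     torch.tril(torch.ones(q_len, q_len, dtype=torch.bool))],
    dim=-1,
)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=full_mask, is_causal=False)
```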
The SplitKV implementation
This implementation calls FusedSDPA twice, once for the prefix part and once for the causal part, and does not pass `attn_mask` for the prefix part, which gives better performance. A numerical sketch of the idea follows the image.

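To make the idea concrete, here is a minimal numerical sketch using plain PyTorch tensors rather than the HPU FusedSDPA op (`partial_attention`, `merge_partials`, and `split_kv_attention` are illustrative helpers, not the PR's API): the prefix and causal parts are computed separately and merged with their log-sum-exp statistics, which is what the `causal_m` / `causal_linv` bookkeeping in the diff above corresponds to.

```python
import torch

def partial_attention(q, k, v, scale, causal=False):
    # One attention partial over a subset of keys; returns the output together
    # with the per-row log-sum-exp needed for merging. When causal=True, q and
    # k are assumed to have equal sequence length.
    s = (q @ k.transpose(-2, -1)) * scale
    if causal:
        L = s.shape[-1]
        mask = torch.triu(torch.ones(L, L, dtype=torch.bool, device=q.device), diagonal=1)
        s = s.masked_fill(mask, float("-inf"))
    lse = torch.logsumexp(s, dim=-1)
    return torch.softmax(s, dim=-1) @ v, lse

def merge_partials(out_a, lse_a, out_b, lse_b):
    # Combine two attention partials computed over disjoint key sets; the
    # merged log-sum-exp allows folding in further partials later.
    lse = torch.logaddexp(lse_a, lse_b)
    out = (torch.exp(lse_a - lse).unsqueeze(-1) * out_a
           + torch.exp(lse_b - lse).unsqueeze(-1) * out_b)
    return out, lse

def split_kv_attention(q, k_prefix, v_prefix, k_causal, v_causal, scale):
    # Prefix part: every query sees all prefix keys, so no mask is passed.
    out_p, lse_p = partial_attention(q, k_prefix, v_prefix, scale)
    # Causal part: rely on the causal structure instead of a materialized mask.
    out_c, lse_c = partial_attention(q, k_causal, v_causal, scale, causal=True)
    return merge_partials(out_p, lse_p, out_c, lse_c)[0]
```

This is the same log-sum-exp merge used by flash-decoding-style split-KV attention; the two partial calls avoid ever materializing the full `[q_len, kv_len]` mask.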
The SliceCausal implementation
This implementation further slices the causal part into smaller chunks, as illustrated in the following image; a code sketch follows it.

The SliceQKV implementation
This implementation further slices the prefix part as well, as shown below; a sketch follows the image.
