

@huanghua1994

The original implementation does not support qk_head_dim != v_head_dim, which is needed in Multi-head Latent Attention. Also fix some test code logic.

Description

The original implementation does not support qk_head_dim != v_head_dim, which is needed for Multi-head Latent Attention. Problem sizes in samples/AttentionFMHA.py are updated so that qk_head_dim != v_head_dim and q_num_head != kv_num_head, exercising a generic GQA case. The parameters and the way PyTorch's scaled_dot_product_attention is called are also updated so that the reference call does not fail to find a working backend.
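For illustration, here is a minimal sketch of how such a reference check can be set up with PyTorch's scaled_dot_product_attention under these shape constraints. It is not the PR's actual sample code: the shapes, tensor names, and the KV-head expansion via repeat_interleave are illustrative assumptions, not taken from samples/AttentionFMHA.py.

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Illustrative problem sizes (not the ones used in the PR).
batch, seq_len = 2, 128
q_num_head, kv_num_head = 8, 2        # generic GQA: q_num_head != kv_num_head
qk_head_dim, v_head_dim = 192, 128    # MLA-style: qk_head_dim != v_head_dim

q = torch.randn(batch, q_num_head, seq_len, qk_head_dim, device=device, dtype=dtype)
k = torch.randn(batch, kv_num_head, seq_len, qk_head_dim, device=device, dtype=dtype)
v = torch.randn(batch, kv_num_head, seq_len, v_head_dim, device=device, dtype=dtype)

# Expand the KV heads so the reference call sees matching head counts.
# SDPA accepts value tensors whose head dim differs from query/key; backends
# that cannot handle it are skipped and PyTorch falls back to one that can.
group = q_num_head // kv_num_head
k_ref = k.repeat_interleave(group, dim=1)
v_ref = v.repeat_interleave(group, dim=1)

ref = F.scaled_dot_product_attention(q, k_ref, v_ref)
print(ref.shape)  # (batch, q_num_head, seq_len, v_head_dim) -- last dim follows v
```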

All tests have passed locally on a B200.

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
