[WIP] exp with SDPA instead of einsum#63

Open
ramithuh wants to merge 3 commits into all_pair_embed from sdpa_all_pair_embed

@ramithuh ramithuh commented Feb 27, 2026

TL;DR: It looks like we can rely on SDPA even with the pair bias track. The benefit will likely only show up with longer sequences or with more pair-biased attention layers.

Todos:

  • See if the performance improvement is tied to a specific PyTorch version.
  • See which sequence lengths might benefit from this.
  • Is `self.proj_z(z)` dominating the runtime anyway, so that this does not matter?
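The last two todos could be checked with a small timing sketch like the one below. It assumes that `self.proj_z` is a linear projection from the pair channel dimension to one bias value per head; that form, along with all sizes and names (`C_Z`, `mean_seconds`, `timings`), is hypothetical and not taken from the actual model.

```python
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical sizes; the real model's dimensions are not given in this thread.
B, H, D, C_Z = 1, 4, 16, 32
# Assumed form of self.proj_z: pair channels -> one bias value per head.
proj_z = nn.Linear(C_Z, H, bias=False).to(device)

def mean_seconds(fn, repeats=3):
    """Average wall-clock time of fn over a few runs, after one warm-up call."""
    fn()
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    if device == "cuda":
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / repeats

timings = {}
with torch.no_grad():
    for L in (64, 128):
        q = torch.randn(B, H, L, D, device=device)
        k = torch.randn(B, H, L, D, device=device)
        v = torch.randn(B, H, L, D, device=device)
        z = torch.randn(B, L, L, C_Z, device=device)

        def project():
            # (B, L, L, C_Z) -> (B, L, L, H) -> (B, H, L, L)
            return proj_z(z).permute(0, 3, 1, 2).contiguous()

        bias = project()

        def attend():
            # The dense float bias is passed as attn_mask.
            return F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

        timings[L] = (mean_seconds(project), mean_seconds(attend))

for L, (t_proj, t_attn) in timings.items():
    print(f"L={L}: proj_z {t_proj * 1e3:.3f} ms, SDPA {t_attn * 1e3:.3f} ms")
```

The numbers from such a toy setup only indicate how the two costs scale with sequence length; a profiler trace of the real training step is the more reliable check.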

More details:

  • When a dense pair bias is passed to `scaled_dot_product_attention`, it falls back to the memory-efficient attention backend, which should be competitive with FlashAttention.

More info about the backends available in `F.scaled_dot_product_attention` can be found in these slides.
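As a minimal sketch of why the swap should be numerically safe: a floating-point `attn_mask` passed to `F.scaled_dot_product_attention` is added to the attention scores before the softmax, which is the same operation as an explicit einsum implementation with an additive pair bias. The tensor sizes and the helper names `einsum_attention` and `sdpa_attention` below are hypothetical, not the code in this PR.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes: batch, heads, sequence length, head dimension.
B, H, L, D = 2, 4, 16, 8
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)
# Dense per-head pair bias, standing in for a projection of the pair track.
pair_bias = torch.randn(B, H, L, L)

def einsum_attention(q, k, v, bias):
    # Explicit path: scaled logits plus bias, softmax over keys, weighted sum.
    logits = torch.einsum("bhid,bhjd->bhij", q, k) / math.sqrt(q.shape[-1])
    weights = torch.softmax(logits + bias, dim=-1)
    return torch.einsum("bhij,bhjd->bhid", weights, v)

def sdpa_attention(q, k, v, bias):
    # A float attn_mask is added to the attention scores inside SDPA.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

out_einsum = einsum_attention(q, k, v, pair_bias)
out_sdpa = sdpa_attention(q, k, v, pair_bias)
max_abs_diff = (out_einsum - out_sdpa).abs().max().item()
print(f"max abs difference: {max_abs_diff:.2e}")
```

The two outputs should agree up to floating-point rounding. Which backend SDPA picks depends on the device, dtype, and PyTorch version, so this checks equivalence only, not speed.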


I ran two quick experiments to check whether the two implementations give the same answer (see the wandb report comparing the two runs). I did not see a difference in training speed at this scale, though :|

ramithuh commented Feb 27, 2026

[image]

We need to optimize `self.proj_z(z)` because that is the bottleneck. The SDPA vs. einsum choice accounts for only a small fraction of the total time, since `self.proj_z(z)` dominates.

Some directions:

  1. Share one bias value across heads instead of computing a different bias for each head.
  2. Use FlexAttention to possibly fuse this computation.
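Direction 1 can be sketched as follows, assuming `self.proj_z` is a linear layer from the pair channels to one value per head (an assumption, since its definition is not shown here). Projecting to a single channel yields a `(B, 1, L, L)` bias, which SDPA accepts because `attn_mask` only needs to be broadcastable to the attention-weight shape. All names and sizes are hypothetical.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes: batch, heads, sequence length, head dim, pair channels.
B, H, L, D, C_Z = 2, 4, 16, 8, 32
q = torch.randn(B, H, L, D)
k = torch.randn(B, H, L, D)
v = torch.randn(B, H, L, D)
z = torch.randn(B, L, L, C_Z)

proj_z_per_head = nn.Linear(C_Z, H, bias=False)  # assumed current form: H bias maps
proj_z_shared = nn.Linear(C_Z, 1, bias=False)    # proposed: one map shared by all heads

with torch.no_grad():
    # (B, L, L, 1) -> (B, 1, L, L); the size-1 head axis broadcasts over heads.
    bias_shared = proj_z_shared(z).permute(0, 3, 1, 2).contiguous()
    out_broadcast = F.scaled_dot_product_attention(q, k, v, attn_mask=bias_shared)

    # Explicit reference with the same shared bias, for comparison.
    logits = torch.einsum("bhid,bhjd->bhij", q, k) / math.sqrt(D)
    weights = torch.softmax(logits + bias_shared, dim=-1)
    out_reference = torch.einsum("bhij,bhjd->bhid", weights, v)

print(proj_z_per_head.weight.numel(), "->", proj_z_shared.weight.numel(), "projection weights")
```

This reduces the projection's output width by a factor of `H`, at the cost of removing per-head variation in the bias; whether that hurts model quality would have to be tested.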

In general, training with `+trainer.precision=bf16-mixed` helps speed up transformers.

The improvement is minor, but helpful: 1.33 it/s to 1.41 it/s on GB10 with the following command:

`python routines/train.py num_workers=2 name=pair_embed_test_SDPA_minimal_proj_z task_group=fixed_protein_cond_a_plinder_only edges_per_batch=125000 max_steps=800000 prot_pos_std=0.5 +trainer.log_every_n_steps=1 wandb_conf.project=omtra_per +trainer.num_sanity_val_steps=0` (testing without pharmit for now)