
Integrate DeepSeek Sparse Attention with Tokamax Flash Attention#3087

Open
RissyRan wants to merge 1 commit into main from v32_flash_integration

Conversation

@RissyRan RissyRan (Collaborator) commented Feb 4, 2026

Description

Integrate DSA with Tokamax Flash Attention

  • Leverages make_dynamic_splash_mha to support 3D dynamic index masking within the Flash Attention kernel (see the sketch below)
  • Moves index_mask reshaping to occur locally within the corresponding dot_product attention call
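
For reference, here is a minimal, self-contained JAX sketch of the 3D dynamic index-masking idea described in the bullets above. The helper names build_index_mask and masked_dot_product_attention, the tensor shapes, and the way the mask is consumed are illustrative assumptions only; the actual integration goes through Tokamax's make_dynamic_splash_mha rather than this reference path.

# Minimal sketch of 3D dynamic index masking (illustration only; shapes and
# helper names are assumptions, not the MaxText/Tokamax interfaces).
import jax
import jax.numpy as jnp

def build_index_mask(topk_indices, seq_len):
  """Turns per-query top-k key indices into a dense [batch, seq_q, seq_kv] bool mask."""
  kv_positions = jnp.arange(seq_len)                                   # [seq_kv]
  index_mask = (topk_indices[..., None] == kv_positions).any(axis=-2)  # [b, seq_q, seq_kv]
  causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))          # causal constraint
  return index_mask & causal

def masked_dot_product_attention(q, k, v, mask):
  """Reference dot-product path that applies the 3D index mask locally."""
  # q, k, v: [batch, seq, num_heads, head_dim]; mask: [batch, seq_q, seq_kv].
  logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
  logits = jnp.where(mask[:, None, :, :], logits, -1e30)  # broadcast over heads
  return jnp.einsum("bhqk,bkhd->bqhd", jax.nn.softmax(logits, axis=-1), v)

In the flash path, a mask of the same [batch, seq_q, seq_kv] shape would instead be handed to the splash-attention kernel built by make_dynamic_splash_mha; the point of the refactor is that the reshape happens locally in whichever attention call consumes it.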

Tests

  • Added dedicated unit tests for the Flash Attention path. Note that because the kernel requires the sequence length to be a multiple of 128 (the number of TPU lanes), the new test sequence length has been standardized to 128. Observed that both dot_product and flash attention require normalization to pass at longer sequence lengths such as seq=128 (a sketch of such a normalized comparison follows the benchmark logs below).
    • Added unit tests; run with python3 -m unittest tests.unit.deepseek32_vs_reference_test - logs: link
    • Performed a manual diff of the final three entries of the last batch. I did observe 1-2 outliers with index_topk=4 between the dot_product and flash_attention implementations - link
    • Verified that the batch mask indices are also identical across both attention paths - link
    • Also tested with a different topK value on the reference side, which resulted in unit test failures for both the dot_product and flash_attention paths, even when normalization was applied. Given that the failures are consistent across both implementations, using normalization within the unit tests appears to be an acceptable baseline for stability.
  • Observed a ~10% throughput boost (from ~74.8 to ~82.7 TFLOP/s/device) compared to the baseline implementation, tested on a small model version of DS v3.2 (2B model size). More benchmarks will be covered in b/469549024.
V3.2 version (2B) - batch=12, seq=4096, FSDP

Per train step:
 Total TFLOPs: 145.89 

# dot_product
I0204 22:40:21.182972 140619320139328 metric_logger.py:179] completed step: 9, seconds: 1.951, TFLOP/s/device: 74.797, Tokens/s/device: 12599.730, total_weights: 98304, loss: 11.982

# flash 
I0204 22:45:41.304875 139758854504000 metric_logger.py:179] completed step: 9, seconds: 1.765, TFLOP/s/device: 82.661, Tokens/s/device: 13924.521, total_weights: 98304, loss: 11.980
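
As referenced in the Tests section above, here is a hedged sketch of what a normalized parity check between the two attention paths could look like. The helper name, tolerances, and normalization axis are assumptions for illustration, not the actual test code.

# Sketch of a normalized parity check between the two attention paths
# (helper name, tolerances, and axis choice are assumptions, not the real test).
import numpy as np

def assert_outputs_match(dot_product_out, flash_out, rtol=1e-2, atol=1e-2):
  """Compares attention outputs after per-token L2 normalization.

  Normalizing along the feature axis removes overall scale differences, which
  appears to be why the raw comparison fails at longer sequence lengths
  (e.g. seq=128) for both paths, as noted above.
  """
  def normalize(x):
    x = np.asarray(x)
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + 1e-6)
  np.testing.assert_allclose(normalize(dot_product_out), normalize(flash_out),
                             rtol=rtol, atol=atol)

On the throughput numbers: the logged metrics are self-consistent if "Total TFLOPs: 145.89" is read as per device, since 145.89 / 1.951 s ≈ 74.8 TFLOP/s/device (dot_product) and 145.89 / 1.765 s ≈ 82.7 TFLOP/s/device (flash), i.e. roughly a 10.5% improvement.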

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@RissyRan RissyRan changed the title [WIP] Integrate DS V32 with flash integration [WIP] Integrate DS V32 with flash attention Feb 4, 2026

codecov bot commented Feb 4, 2026

Codecov Report

❌ Patch coverage is 30.76923% with 9 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Lines
src/MaxText/layers/attention_op.py | 30.76% | 8 Missing and 1 partial ⚠️


@RissyRan RissyRan force-pushed the v32_flash_integration branch 3 times, most recently from 4b7e14c to 908d7bb on February 5, 2026 01:01
@RissyRan RissyRan changed the title [WIP] Integrate DS V32 with flash attention [WIP] Integrate DeepSeek Sparse Attention with Tokamax Flash Attention Feb 5, 2026
@RissyRan RissyRan force-pushed the v32_flash_integration branch 2 times, most recently from 202751f to 7b2fe81 on February 5, 2026 02:07
@RissyRan RissyRan force-pushed the v32_flash_integration branch from 7b2fe81 to ffdc797 on February 5, 2026 02:21
@RissyRan RissyRan changed the title [WIP] Integrate DeepSeek Sparse Attention with Tokamax Flash Attention Integrate DeepSeek Sparse Attention with Tokamax Flash Attention Feb 5, 2026

github-actions bot commented Feb 5, 2026

🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions bot left a comment


📋 Review Summary

This pull request successfully integrates DeepSeek Sparse Attention with Tokamax Flash Attention, which provides a notable performance improvement. The changes are well-structured, and the logic for handling the dynamic sparse mask within the flash attention kernel appears correct.

🔍 General Feedback

  • The unit tests have been significantly improved to cover both the dot_product and flash attention paths, ensuring parity between the two implementations. The parameterization of the tests is a good addition.
  • The movement of the index_mask reshaping logic into the apply_attention_dot function is a clean refactoring.
  • The updates to the configuration validation and compile tests are thorough and ensure the new attention mechanism is properly supported.

Overall, this is a high-quality contribution that enhances the performance of sparse attention. The suggestions provided are minor and aimed at improving code clarity and test readability.

@RissyRan RissyRan (Collaborator, Author) left a comment


@gemini-cli /review


github-actions bot commented Feb 6, 2026

🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@AI-Hypercomputer AI-Hypercomputer deleted a comment from github-actions bot Feb 6, 2026

@github-actions github-actions bot left a comment


📋 Review Summary

This pull request successfully integrates DeepSeek Sparse Attention with Tokamax Flash Attention, which is a valuable performance enhancement. The implementation is solid, and the added tests are comprehensive, ensuring the correctness of the new attention path.

🔍 General Feedback

  • The changes are well-structured and easy to follow.
  • The extension of the test suite to cover both dot_product and flash attention with various configurations is a great addition and significantly improves the robustness of the implementation.
  • The PR description is clear and provides good context, including performance numbers.

Comment on lines 912 to +947

@parameterized.named_parameters(
{"testcase_name": "seq_len=2 (index_topk=4)", "seq_len": 2},
{"testcase_name": "seq_len=8 (index_topk=4)", "seq_len": 8},
{
"testcase_name": "seq_len=2 (index_topk=4) & dot_product",
"attention": "dot_product",
"seq_len": 2,
"index_topk": 4,
},
{
"testcase_name": "seq_len=8 (index_topk=4) & dot_product",
"attention": "dot_product",
"seq_len": 8,
"index_topk": 4,
},
{
"testcase_name": "seq_len=128 (index_topk=4) & dot_product",
"attention": "dot_product",
"seq_len": 128,
"index_topk": 4,
"check_norm": True,
},
{
"testcase_name": "seq_len=128 (index_topk=128) & dot_product",
"attention": "dot_product",
"seq_len": 128,
"index_topk": 128,
"check_norm": True,
},
{
"testcase_name": "seq_len=128 (index_topk=4) & flash",
"attention": "flash",
"seq_len": 128,
"index_topk": 4,
"check_norm": True,
},
{


🟢 The test case names are a bit long. Consider a more concise format for better readability in test output, like attention_s<seq_len>_k<top_k>.

Suggested change
@parameterized.named_parameters(
{"testcase_name": "seq_len=2 (index_topk=4)", "seq_len": 2},
{"testcase_name": "seq_len=8 (index_topk=4)", "seq_len": 8},
{
"testcase_name": "seq_len=2 (index_topk=4) & dot_product",
"attention": "dot_product",
"seq_len": 2,
"index_topk": 4,
},
{
"testcase_name": "seq_len=8 (index_topk=4) & dot_product",
"attention": "dot_product",
"seq_len": 8,
"index_topk": 4,
},
{
"testcase_name": "seq_len=128 (index_topk=4) & dot_product",
"attention": "dot_product",
"seq_len": 128,
"index_topk": 4,
"check_norm": True,
},
{
"testcase_name": "seq_len=128 (index_topk=128) & dot_product",
"attention": "dot_product",
"seq_len": 128,
"index_topk": 128,
"check_norm": True,
},
{
"testcase_name": "seq_len=128 (index_topk=4) & flash",
"attention": "flash",
"seq_len": 128,
"index_topk": 4,
"check_norm": True,
},
{
{
"testcase_name": "dot_product_s2_k4",
"attention": "dot_product",
"seq_len": 2,
"index_topk": 4,
},
{
"testcase_name": "dot_product_s8_k4",
"attention": "dot_product",
"seq_len": 8,
"index_topk": 4,
},
{
"testcase_name": "dot_product_s128_k4",
"attention": "dot_product",
"seq_len": 128,
"index_topk": 4,
"check_norm": True,
},
{
"testcase_name": "dot_product_s128_k128",
"attention": "dot_product",
"seq_len": 128,
"index_topk": 128,
"check_norm": True,
},
{
"testcase_name": "flash_s128_k4",
"attention": "flash",
"seq_len": 128,
"index_topk": 4,
"check_norm": True,
},
{
"testcase_name": "flash_s128_k128",
"attention": "flash",
"seq_len": 128,
"index_topk": 128,
"check_norm": True,
},

{
"testcase_name": "seq_len=128 (index_topk=128) & flash",
"attention": "flash",
"seq_len": 128,


🟢 The docstring for this test function could be more accurate. It's the test itself, not a helper function.

Suggested change
"seq_len": 128,
"""Verifies JAX MLA output against the PyTorch reference implementation."""

