Integrate DeepSeek Sparse Attention with Tokamax Flash Attention #3087
Conversation
🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
This pull request successfully integrates DeepSeek Sparse Attention with Tokamax Flash Attention, which provides a notable performance improvement. The changes are well-structured, and the logic for handling the dynamic sparse mask within the flash attention kernel appears correct.
🔍 General Feedback
- The unit tests have been significantly improved to cover both the `dot_product` and `flash` attention paths, ensuring parity between the two implementations. The parameterization of the tests is a good addition.
- Moving the `index_mask` reshaping logic into the `apply_attention_dot` function is a clean refactoring (see the sketch after this summary).
- The updates to the configuration validation and compile tests are thorough and ensure the new attention mechanism is properly supported.
Overall, this is a high-quality contribution that enhances the performance of sparse attention. The suggestions provided are minor and aimed at improving code clarity and test readability.
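For context, here is a minimal JAX sketch of the pattern the review describes: turning per-query top-k key indices into a boolean `index_mask`, reshaping it to broadcast over heads, and applying it inside dot-product attention. The function name, argument shapes, and layout below are assumptions for illustration, not the PR's actual `apply_attention_dot` code, and causal masking is omitted.

```python
# Illustrative sketch only: names and shapes are assumptions, not the PR's code.
import jax
import jax.numpy as jnp


def sparse_dot_product_attention(q, k, v, topk_idx):
  """q, k, v: [batch, seq, heads, head_dim]; topk_idx: [batch, seq, index_topk]."""
  b, q_len, _, d = q.shape
  kv_len = k.shape[1]
  # index_mask[b, q, kv] is True iff key position kv was selected for query q.
  empty = jnp.zeros((b, q_len, kv_len), dtype=bool)
  index_mask = jax.vmap(jax.vmap(lambda row, idx: row.at[idx].set(True)))(empty, topk_idx)
  # Reshape the 3D mask so it broadcasts over heads: [b, 1, q, kv] vs logits [b, h, q, kv].
  index_mask = index_mask[:, None, :, :]
  logits = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(d)
  logits = jnp.where(index_mask, logits, jnp.finfo(logits.dtype).min)
  weights = jax.nn.softmax(logits, axis=-1)
  return jnp.einsum("bhqk,bkhd->bqhd", weights, v)
```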
RissyRan left a comment:
@gemini-cli /review
🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
This pull request successfully integrates DeepSeek Sparse Attention with Tokamax Flash Attention, which is a valuable performance enhancement. The implementation is solid, and the added tests are comprehensive, ensuring the correctness of the new attention path.
🔍 General Feedback
- The changes are well-structured and easy to follow.
- The extension of the test suite to cover both `dot_product` and `flash` attention with various configurations is a great addition and significantly improves the robustness of the implementation (a minimal parity-check sketch follows this list).
- The PR description is clear and provides good context, including performance numbers.
```python
@parameterized.named_parameters(
    {"testcase_name": "seq_len=2 (index_topk=4)", "seq_len": 2},
    {"testcase_name": "seq_len=8 (index_topk=4)", "seq_len": 8},
    {
        "testcase_name": "seq_len=2 (index_topk=4) & dot_product",
        "attention": "dot_product",
        "seq_len": 2,
        "index_topk": 4,
    },
    {
        "testcase_name": "seq_len=8 (index_topk=4) & dot_product",
        "attention": "dot_product",
        "seq_len": 8,
        "index_topk": 4,
    },
    {
        "testcase_name": "seq_len=128 (index_topk=4) & dot_product",
        "attention": "dot_product",
        "seq_len": 128,
        "index_topk": 4,
        "check_norm": True,
    },
    {
        "testcase_name": "seq_len=128 (index_topk=128) & dot_product",
        "attention": "dot_product",
        "seq_len": 128,
        "index_topk": 128,
        "check_norm": True,
    },
    {
        "testcase_name": "seq_len=128 (index_topk=4) & flash",
        "attention": "flash",
        "seq_len": 128,
        "index_topk": 4,
        "check_norm": True,
    },
```
🟢 The test case names are a bit long. Consider a more concise format for better readability in test output, like `attention_s<seq_len>_k<top_k>`.
Suggested change:
```python
    {
        "testcase_name": "dot_product_s2_k4",
        "attention": "dot_product",
        "seq_len": 2,
        "index_topk": 4,
    },
    {
        "testcase_name": "dot_product_s8_k4",
        "attention": "dot_product",
        "seq_len": 8,
        "index_topk": 4,
    },
    {
        "testcase_name": "dot_product_s128_k4",
        "attention": "dot_product",
        "seq_len": 128,
        "index_topk": 4,
        "check_norm": True,
    },
    {
        "testcase_name": "dot_product_s128_k128",
        "attention": "dot_product",
        "seq_len": 128,
        "index_topk": 128,
        "check_norm": True,
    },
    {
        "testcase_name": "flash_s128_k4",
        "attention": "flash",
        "seq_len": 128,
        "index_topk": 4,
        "check_norm": True,
    },
    {
        "testcase_name": "flash_s128_k128",
        "attention": "flash",
        "seq_len": 128,
        "index_topk": 128,
        "check_norm": True,
    },
```
```python
    {
        "testcase_name": "seq_len=128 (index_topk=128) & flash",
        "attention": "flash",
        "seq_len": 128,
```
🟢 The docstring for this test function could be more accurate. It's the test itself, not a helper function.
| "seq_len": 128, | |
| """Verifies JAX MLA output against the PyTorch reference implementation.""" |
Description
- Integrate DSA with Tokamax Flash Attention.
- Update `make_dynamic_splash_mha` to support 3D dynamic index masking within the Flash Attention (a minimal mask-construction sketch follows).
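The mask construction could look roughly like the sketch below: materialize a 3D boolean mask of shape [batch, q_seq, kv_seq] from the per-query top-k indices. This is an illustration under assumed shapes; how the mask is actually threaded into `make_dynamic_splash_mha` is defined by this PR, not by the snippet.

```python
# Illustrative only; the real plumbing into make_dynamic_splash_mha is per this PR.
import jax
import jax.numpy as jnp


def topk_indices_to_3d_mask(topk_idx, kv_seq_len):
  """topk_idx: int array [batch, q_seq, index_topk] of selected key positions."""
  # One-hot each selected key index, then OR across the top-k slots.
  one_hot = jax.nn.one_hot(topk_idx, kv_seq_len, dtype=bool)  # [b, q, topk, kv]
  return jnp.any(one_hot, axis=-2)                            # [b, q, kv]
```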
Tests
- `python3 -m unittest tests.unit.deepseek32_vs_reference_test`
- logs: link
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.