[Draft] Support Flash Attention #501
Draft
cmathw wants to merge 5 commits into TransformerLensOrg:dev from
Conversation
Collaborator
Is there any update on this, and any idea on what can be done to get the tolerance to an acceptable level?
Contributor (Author)
I haven't taken a further look yet; is this something currently blocking other features?
Collaborator
Nope, I have just been going through PRs and closing out anything that can be closed, and helping get anything else closed out. If you need some help with this, let me know. If I have time to help you out, I would be happy to.
Fails with Llama-3; likely a grouped-query attention issue, since Llama-2 doesn't have the same error.
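A plausible sketch of the mismatch: Llama-3 uses grouped-query attention, so its key/value tensors carry fewer heads than the queries, and a direct call to scaled_dot_product_attention with mismatched head counts fails unless the KV heads are first expanded to match. The shapes below are illustrative only, not the model's actual configuration:

```python
import torch
import torch.nn.functional as F

# Illustrative GQA shapes: 8 query heads sharing 2 key/value heads.
batch, seq, d_head = 1, 4, 16
n_heads, n_kv_heads = 8, 2

q = torch.randn(batch, n_heads, seq, d_head)
k = torch.randn(batch, n_kv_heads, seq, d_head)
v = torch.randn(batch, n_kv_heads, seq, d_head)

# Without expansion, the 8-head queries cannot attend over 2-head keys/values,
# so repeat each KV head across its group of query heads first.
k_exp = k.repeat_interleave(n_heads // n_kv_heads, dim=1)
v_exp = v.repeat_interleave(n_heads // n_kv_heads, dim=1)

out = F.scaled_dot_product_attention(q, k_exp, v_exp, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 4, 16])
```

Recent PyTorch versions also expose an enable_gqa flag on scaled_dot_product_attention that performs this grouping internally, which may be the cleaner fix if the minimum supported PyTorch version allows it.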
Description
Resolves #378 by adding support for torch.nn.functional.scaled_dot_product_attention as found here. This implementation includes FlashAttention-2 as well as two other (potentially faster) attention implementations; PyTorch attempts to automatically select the fastest implementation based on the inputs. Thank you @alan-cooney for recommending this implementation!

Currently still in draft because the tolerances between a model that uses a fast attention implementation and one that does not are a bit high. This is likely because scaled_dot_product_attention requires casting to float16 (which we then cast back to float32 after the fused attention). I will look into this further and see if there is an improvement to be made here.

Type of change
Checklist:
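The tolerance concern described above can be probed with a quick sketch comparing the fused kernel against an explicit attention computation. The shapes and seed here are arbitrary illustrations, not taken from the PR:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Arbitrary shapes: (batch, heads, seq, d_head).
q = torch.randn(1, 4, 8, 32)
k = torch.randn(1, 4, 8, 32)
v = torch.randn(1, 4, 8, 32)

# Reference: explicit attention, entirely in float32.
scores = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
ref = torch.softmax(scores, dim=-1) @ v

# Fused path, here kept in float32 for comparison; the fp16 round-trip
# the PR describes (cast to half, run sdpa, cast back) would widen this gap.
out = F.scaled_dot_product_attention(q, k, v)

print((ref - out).abs().max().item())
```

In float32 the two paths should agree to within normal floating-point noise, so a gap much larger than that points at the float16 cast rather than the fused kernel itself.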