SdpaFwdOp::evaluate computes meta tensors respecting allocation domains. #5848
Conversation
Review updated until commit 5e114fd

Description
Relevant files
- Enhancement
- Documentation
- Tests
PR Reviewer Guide
Here are some key observations to aid the review process:
🧪 PR contains tests
⚡ Recommended focus areas for review
Layout handling correctness
The relayoutByTensorView function is critical for ensuring output tensors match the expected allocation domain. The function performs layout canonicalization and permutation operations. Verify that the error handling and permutation calculations are robust, especially the assumption that the allocation domain is always a permutation of the logical domain.
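For orientation, here is a minimal sketch of that permute, make contiguous, inverse-permute pattern written against plain ATen tensors. relayoutByPermutation and its perm argument are hypothetical stand-ins; the PR's relayoutByTensorView derives the permutation from a TensorView's logical and allocation domains rather than taking it as a parameter.

```cpp
#include <ATen/ATen.h>
#include <vector>

// Hypothetical helper: re-stride `t` so that memory follows `perm` (the
// allocation order) while the returned tensor keeps the logical dim order.
at::Tensor relayoutByPermutation(
    const at::Tensor& t,
    const std::vector<int64_t>& perm) {
  // Step 1: view dims in allocation order and materialize them contiguously.
  at::Tensor alloc_ordered = t.permute(perm).contiguous();

  // Step 2: invert the permutation so dims return to logical order, but the
  // strides now reflect the allocation order. This only works when `perm` is
  // a true permutation of the logical domain, which is the assumption flagged
  // above.
  std::vector<int64_t> inverse(perm.size());
  for (size_t i = 0; i < perm.size(); ++i) {
    inverse[perm[i]] = static_cast<int64_t>(i);
  }
  return alloc_ordered.permute(inverse);
}
```

The contiguous() call is where an extra copy kernel can appear, which is what the copy-kernel discussion further down in this thread is about.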
Test failures (partial, pipeline still running)
- (Medium, 1) Thunder vs. Torch scalar mismatch in nanogpt autograd test (thunder/tests/test_networks.py) on dlcluster_h100 ❌
  Test Name: thunder.tests.test_networks.test_nanogpt_complete_autograd_nvfuser_cuda_thunder.dtypes.float32 (H100)
!test
Greptile Summary
This PR refactors SdpaFwdOp::evaluate so that the meta-tensor path respects allocation domains (see the sequence diagram below).
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
participant Client
participant SdpaFwdOp
participant Meta as scaled_dot_product_attention_meta
participant Math as at::_scaled_dot_product_attention_math
participant Flash as at::_scaled_dot_product_flash_attention
participant Helper as relayoutByTensorView
Client->>SdpaFwdOp: evaluate(inputs)
SdpaFwdOp->>SdpaFwdOp: flattenBatchDims(query, key, value)
alt query.is_meta()
SdpaFwdOp->>Meta: scaled_dot_product_attention_meta(query, value)
Meta->>Meta: validate 4D tensors
Meta->>Meta: compute output shapes
Meta-->>SdpaFwdOp: (output, logsumexp, philox_seed, philox_offset)
else attn_bias.defined()
SdpaFwdOp->>Math: at::_scaled_dot_product_attention_math(...)
Math-->>SdpaFwdOp: (attn_out, logsumexp)
Note over SdpaFwdOp: attn_out is contiguous
else flash attention
SdpaFwdOp->>Flash: at::_scaled_dot_product_flash_attention(...)
Flash-->>SdpaFwdOp: (attn_out, logsumexp, ...)
Note over SdpaFwdOp: attn_out has query's stride order
end
alt batch_dims.size() > 1
SdpaFwdOp->>SdpaFwdOp: unflattenBatchDim(output, logsumexp)
end
SdpaFwdOp->>Helper: relayoutByTensorView(output, attn_out())
Helper->>Helper: canonicalizeLayout(tv)
Helper->>Helper: computePermutation(logical, allocation)
Helper->>Helper: permute → contiguous → inverse permute
Helper-->>SdpaFwdOp: relayouted output
SdpaFwdOp-->>Client: (output, logsumexp, philox_seed, philox_offset)
```
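For readers unfamiliar with the meta path in the diagram: a meta-device SDPA forward only has to produce tensors with the right shapes and options, no data. The sketch below is an editor-added illustration of that idea; sdpaForwardMetaSketch is a made-up name, and the PR's actual scaled_dot_product_attention_meta also returns philox_seed/philox_offset and sets strides according to the allocation domain.

```cpp
#include <ATen/ATen.h>
#include <tuple>

// Inputs are assumed to be 4-D meta tensors shaped [N, H, L, E] (query) and
// [N, H, S, Ev] (value), matching the "validate 4D tensors" step above.
std::tuple<at::Tensor, at::Tensor> sdpaForwardMetaSketch(
    const at::Tensor& query,
    const at::Tensor& value) {
  TORCH_CHECK(query.dim() == 4 && value.dim() == 4, "expected 4-D inputs");

  const int64_t n = query.size(0);
  const int64_t h = query.size(1);
  const int64_t l = query.size(2);
  const int64_t ev = value.size(3);

  // Attention output has shape [N, H, L, Ev]; logsumexp has shape [N, H, L].
  // Because query/value live on the meta device, these allocate no storage.
  at::Tensor output = at::empty({n, h, l, ev}, query.options());
  at::Tensor logsumexp =
      at::empty({n, h, l}, query.options().dtype(at::kFloat));
  return {output, logsumexp};
}
```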
4 files reviewed, 1 comment
csrc/ir/composite_nodes.cpp
Outdated
| "Flash attention requires the last dimension to be a multiple of 8, but " | ||
| "got: ", | ||
| last_dim_size); | ||
| // Conmpute scale using original size of last dimension |
Suggested change:
// Compute scale using original size of last dimension
Are you suggesting the following?
Would this add an extra copy kernel? Are we so rich that we just pay for extra copy kernels?
Yes
I hope not. I want preseg (e.g. MarkAliasesPrepare and AllocationOrderInference) to set the allocation domain and contiguity that match the behavior of ExpressionEvaluator. If an expr-eval segment returns a different "layout", I consider that a bug that should be fixed by making the two match without an extra data copy.

Thanks for clarifying -- now I see your motivation to "reset" contiguities based on ExpressionEvaluator. I haven't thought too much about this approach. The first thing that will fail is that ReorderShardedAxis is allocation-sensitive: at the moment, we require the scattered/gathered axis of a communication I/O to be outermost due to NCCL constraints. If the allocation of such a TensorView gets changed after segmentation, we may fail to lower a communication segment to a NCCL call, or may have to do an extra data copy around the NCCL call that could otherwise have been fused with upstream or downstream kernel segments.

cc @zasdfgbnm
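A small, self-contained illustration of the constraint mentioned above (editor-added, not nvFuser code): a zero-copy gather or scatter needs each rank's shard to be one contiguous block, which is only the case when the sharded axis is outermost.

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Full tensor of shape [ranks, rows, cols]; imagine each of 2 ranks owning
  // one slice along some axis.
  at::Tensor full = at::arange(2 * 3 * 4, at::kFloat).view({2, 3, 4});

  // Sharded along the outermost axis: the per-rank slice is contiguous, so it
  // can be handed to a NCCL collective without repacking.
  std::cout << full.select(0, 0).is_contiguous() << "\n";  // prints 1

  // Sharded along an inner axis: the slice is strided, so lowering to a NCCL
  // call needs an extra copy, or the allocation order must be changed so that
  // axis becomes outermost, which is what ReorderShardedAxis arranges.
  std::cout << full.select(1, 0).is_contiguous() << "\n";  // prints 0
  return 0;
}
```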
!test
To unblock #5772