
Conversation

@wujingyue (Collaborator) commented Jan 20, 2026

To unblock #5772

github-actions bot commented Jan 20, 2026

Review updated until commit 5e114fd

Description

  • Refactor attention output layout handling with new relayoutByTensorView function

  • Add dimension validation checks in meta tensor computation

  • Improve code organization by moving functions to anonymous namespace

  • Add comprehensive test for flash attention stride order validation

  • Fix typos and enhance comments in allocation order inference

Changes walkthrough

Relevant files

Enhancement
csrc/ir/composite_nodes.cpp: Refactor attention output layout handling (+47/-44)

  • Move _scaled_dot_product_attention_meta to the anonymous namespace and rename it to scaled_dot_product_attention_meta
  • Add dimension validation checks (4D tensors required; a hedged sketch follows this list)
  • Implement the new relayoutByTensorView function for proper layout conversion
  • Refactor attention output handling to use the new relayout function
  • Remove an unnecessary include
  • Add explanatory comments about layout requirements
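
A hedged sketch (editor-added, not the PR's code) of the kind of 4-D validation and meta-shape computation the bullets above describe. The function name, header path, and exact output set are assumptions (per the sequence diagram further down, the real function also returns philox state); only the 4-D requirement via NVF_ERROR_EQ and the standard SDPA shapes ([N, H, S, Ev] output, [N, H, S] logsumexp) are taken from the PR summary:

    // Editor-added sketch; names and the header path below are assumptions.
    #include <ATen/ATen.h>
    #include <utility>

    #include <exceptions.h> // assumed nvFuser header providing NVF_ERROR_EQ

    namespace {

    // Builds meta-device outputs with the standard SDPA shapes:
    // attention output [N, H, S, Ev] and logsumexp [N, H, S] (float).
    std::pair<at::Tensor, at::Tensor> sdpaMetaSketch(
        const at::Tensor& query,
        const at::Tensor& value) {
      // The PR requires 4-D [N, H, S, E] inputs at this point.
      NVF_ERROR_EQ(query.dim(), 4);
      NVF_ERROR_EQ(value.dim(), 4);

      const int64_t n = query.size(0);
      const int64_t h = query.size(1);
      const int64_t s = query.size(2);
      const int64_t ev = value.size(3);

      // Meta tensors carry only sizes/strides/dtype; no data is allocated.
      at::Tensor attn_out =
          at::empty({n, h, s, ev}, query.options().device(at::kMeta));
      at::Tensor logsumexp = at::empty(
          {n, h, s}, query.options().dtype(at::kFloat).device(at::kMeta));
      return {attn_out, logsumexp};
    }

    } // namespace
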
Documentation
csrc/preseg_passes/allocation_order_inference.cpp: Fix typos and improve documentation (+5/-3)

  • Fix typo: "SPDA" corrected to "SDPA"
  • Enhance comments explaining the approach and its fragility
  • Add a reference to the test verifying flash attention API expectations
  • Minor formatting improvements

Tests
tests/cpp/test_sdpa_node.cpp: Add flash attention stride order validation test (+61/-3)

  • Add the GoogleTest mocking framework include
  • Implement the FlashAttentionStrideOrder test case
  • Verify attention output sizes and contiguous stride requirements
  • Test backward pass gradient sizes and stride order
  • Add using testing::ElementsAre for better assertions

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Layout handling correctness

    The new relayoutByTensorView function is critical for ensuring output tensors match the expected allocation domain. The function performs layout canonicalization and permutation operations. Verify that the error handling and permutation calculations are robust, especially the assumption that the allocation domain is always a permutation of the logical domain.

    at::Tensor relayoutByTensorView(at::Tensor t, TensorView* tv) {
      const std::optional<Layout> layout = canonicalizeLayout(tv);
      NVF_CHECK(layout.has_value(), "Failed to canonicalize output layout of ", tv);
      const std::optional<std::vector<int64_t>> permutation =
          ir_utils::computePermutation(
              tv->getLogicalDomain(), layout->allocation_domain());
      NVF_ERROR(
          permutation.has_value(),
          "The allocation domain of a canonicalized layout of ",
          tv,
          " is not a permutation of its logical domain.");
      return t.permute(*permutation)
          .contiguous()
          .permute(ir_utils::inversePermutation(*permutation));
    }
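
    As a side note for reviewers, the permute/contiguous/inverse-permute pattern keeps the logical dimension order intact and only moves the underlying memory into the allocation order. A standalone ATen illustration of that effect (editor-added, not part of the PR; the shapes and the [n, s, h, e] allocation order are assumptions):

    // Editor-added illustration; not part of the PR.
    #include <ATen/ATen.h>
    #include <iostream>
    #include <vector>

    int main() {
      // A logically [n, h, s, e] tensor currently allocated contiguously in
      // that same order, i.e. not yet matching the desired [n, s, h, e] layout.
      at::Tensor t = at::randn({2, 8, 4, 16});

      // Logical -> allocation order; this particular permutation is its own inverse.
      std::vector<int64_t> perm = {0, 2, 1, 3};

      at::Tensor relayouted = t.permute(perm).contiguous().permute(perm);

      // Sizes (the logical view) are unchanged; only the strides move so that
      // memory now follows the [n, s, h, e] allocation order.
      std::cout << relayouted.sizes() << "\n";   // [2, 8, 4, 16]
      std::cout << relayouted.strides() << "\n"; // [512, 16, 128, 1]
      return 0;
    }
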
    Test coverage completeness

    The new FlashAttentionStrideOrder test verifies stride order for flash attention but only tests the forward pass. Consider whether additional test cases are needed to cover edge cases like different tensor layouts, non-contiguous inputs, or various attention configurations (causal vs non-causal).

    TEST_F(SDPATest, FlashAttentionStrideOrder) {
      NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);
    
      at::Tensor qkv =
          at::randn({n, s, h * e * 3}, at::dtype(at::kHalf).device(at::kCUDA));
      std::vector<at::Tensor> splits =
          at::chunk(qkv.view({n, s, h, e * 3}), /*chunks=*/3, /*dim=*/-1);
      ASSERT_EQ(splits.size(), 3);
      at::Tensor q = splits.at(0).permute({0, 2, 1, 3});
      at::Tensor k = splits.at(1).permute({0, 2, 1, 3});
      at::Tensor v = splits.at(2).permute({0, 2, 1, 3});
    
      auto outs = at::_scaled_dot_product_flash_attention(q, k, v);
    
      at::Tensor attn_out = std::get<0>(outs);
      at::Tensor logsumexp = std::get<1>(outs);
      at::Tensor cum_seq_q = std::get<2>(outs);
      at::Tensor cum_seq_k = std::get<3>(outs);
      at::SymInt max_q = std::get<4>(outs);
      at::SymInt max_k = std::get<5>(outs);
      at::Tensor philox_seed = std::get<6>(outs);
      at::Tensor philox_offset = std::get<7>(outs);
    
      EXPECT_THAT(attn_out.sizes(), ElementsAre(n, h, s, e));
      EXPECT_TRUE(attn_out.transpose(1, 2).is_contiguous()) << attn_out.strides();
    
      auto [q_grad, k_grad, v_grad] =
          at::_scaled_dot_product_flash_attention_backward_symint(
              /*grad_output=*/attn_out, // This test merely verifies sizes and
                                        // strides so it's fine to reuse `attn_out`
                                        // as `grad_output`
              q,
              k,
              v,
              attn_out,
              logsumexp,
              cum_seq_q,
              cum_seq_k,
              max_q,
              max_k,
              /*dropout_p=*/0.0,
              /*is_causal=*/false,
              philox_seed,
              philox_offset);
    
      for (at::Tensor grad : {q_grad, k_grad, v_grad}) {
        EXPECT_THAT(grad.sizes(), ElementsAre(n, h, s, e));
        EXPECT_TRUE(grad.transpose(1, 2).is_contiguous()) << grad.strides();
      }
    }
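
    A hedged sketch of one such additional case, reusing the qkv layout from the test above but exercising the causal path. It is editor-added, not part of the PR; the SDPATest fixture and the n/s/h/e constants are assumed to be the same ones the existing test uses, and the stride-order expectation for the causal path is an assumption rather than a verified fact:

    // Editor-added sketch; fixture and constants assumed from the test above.
    TEST_F(SDPATest, FlashAttentionStrideOrderCausal) {
      NVFUSER_TEST_CUDA_ARCH_GUARD(8, 0);

      at::Tensor qkv =
          at::randn({n, s, h * e * 3}, at::dtype(at::kHalf).device(at::kCUDA));
      std::vector<at::Tensor> splits =
          at::chunk(qkv.view({n, s, h, e * 3}), /*chunks=*/3, /*dim=*/-1);
      at::Tensor q = splits.at(0).permute({0, 2, 1, 3});
      at::Tensor k = splits.at(1).permute({0, 2, 1, 3});
      at::Tensor v = splits.at(2).permute({0, 2, 1, 3});

      // Only the mask differs from the non-causal test, so the same output
      // size and stride-order expectations are assumed to hold.
      auto outs = at::_scaled_dot_product_flash_attention(
          q, k, v, /*dropout_p=*/0.0, /*is_causal=*/true);

      at::Tensor attn_out = std::get<0>(outs);
      EXPECT_THAT(attn_out.sizes(), ElementsAre(n, h, s, e));
      EXPECT_TRUE(attn_out.transpose(1, 2).is_contiguous()) << attn_out.strides();
    }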

    Test failures (partial, pipeline still running)

    • (Medium, 1) Thunder vs. Torch scalar mismatch in nanogpt autograd test (thunder/tests/test_networks.py) on dlcluster_h100

      Test name (H100): thunder.tests.test_networks.test_nanogpt_complete_autograd_nvfuser_cuda_thunder.dtypes.float32

    @wujingyue (Collaborator, Author)

    !test

    @wujingyue wujingyue marked this pull request as ready for review January 20, 2026 16:40
    @wujingyue wujingyue requested a review from zasdfgbnm January 20, 2026 16:40

    greptile-apps bot commented Jan 20, 2026

    Greptile Summary

    This PR refactors SdpaFwdOp::evaluate to compute meta tensors that respect allocation domains, which is part of the work to unblock issue #5772.

    Key Changes:

    • Extracted layout conversion logic into a relayoutByTensorView helper function that relayouts a tensor based on the TensorView's allocation domain
    • Unified relayout handling: previously only the at::_scaled_dot_product_attention_math path was relayouted; now both backends (math and flash attention) go through relayoutByTensorView
    • Moved scaled_dot_product_attention_meta (renamed from _scaled_dot_product_attention_meta) from the sdpa_meta namespace into the anonymous namespace
    • Added dimension validation checks (NVF_ERROR_EQ for 4D tensors) in the meta function
    • Improved variable naming consistency (out → attn_out)
    • Added the FlashAttentionStrideOrder test to verify that flash attention API stride behavior matches allocation order inference assumptions

    Impact:
    The refactoring correctly applies relayout logic to both attention backends. For flash attention, the relayout's .contiguous() call will be a no-op since the output already has the expected stride order. This change makes the code more maintainable and ensures consistent handling of allocation domains across both code paths.
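
    A small standalone check of that no-op claim (editor-added, not from the PR; shapes are arbitrary): when a tensor already carries the target [n, s, h, e] stride order, the permute/contiguous/inverse-permute round trip in relayoutByTensorView returns the same storage and launches no copy:

    // Editor-added check; not part of the PR.
    #include <ATen/ATen.h>
    #include <iostream>

    int main() {
      at::Tensor base = at::randn({2, 4, 8, 16});         // allocated as [n, s, h, e]
      at::Tensor attn_like = base.permute({0, 2, 1, 3});  // logical [n, h, s, e] view

      at::Tensor relayouted = attn_like.permute({0, 2, 1, 3}) // back to allocation order
                                  .contiguous()                // already contiguous: no copy
                                  .permute({0, 2, 1, 3});      // restore logical order

      // Same storage, so the relayout really is a no-op in this case.
      std::cout << std::boolalpha
                << (relayouted.data_ptr() == attn_like.data_ptr()) << "\n"; // true
      return 0;
    }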

    Confidence Score: 4/5

    • This PR is safe to merge with a solid refactoring that improves code organization
    • Score reflects well-structured refactoring with good test coverage. The changes extract common logic, add validation, and unify behavior across code paths. The new test verifies flash attention API assumptions. Minor deduction for complexity of allocation domain logic and the fact that relayout now applies to both paths (though flash attention's case is a no-op).
    • No files require special attention

    Important Files Changed

    Filename and overview:

    • csrc/ir/composite_nodes.cpp: Refactored SdpaFwdOp::evaluate to extract relayout logic into a relayoutByTensorView helper function, moved scaled_dot_product_attention_meta into the anonymous namespace, and unified relayout handling for both attention backends (math and flash). Added dimension checks and renamed the variable out to attn_out.
    • csrc/preseg_passes/allocation_order_inference.cpp: Minor documentation improvement: updated comments from "Propagate" to "Propagates" (consistent verb form) and added a reference to the test that verifies flash attention API behavior matches expectations.
    • tests/cpp/test_sdpa_node.cpp: Added the FlashAttentionStrideOrder test to verify that flash attention API behavior matches allocation order inference expectations. Includes gmock headers and new ATen operation headers. The test validates output stride order for both forward and backward passes.

    Sequence Diagram

    sequenceDiagram
        participant Client
        participant SdpaFwdOp
        participant Meta as scaled_dot_product_attention_meta
        participant Math as at::_scaled_dot_product_attention_math
        participant Flash as at::_scaled_dot_product_flash_attention
        participant Helper as relayoutByTensorView
    
        Client->>SdpaFwdOp: evaluate(inputs)
        SdpaFwdOp->>SdpaFwdOp: flattenBatchDims(query, key, value)
        
        alt query.is_meta()
            SdpaFwdOp->>Meta: scaled_dot_product_attention_meta(query, value)
            Meta->>Meta: validate 4D tensors
            Meta->>Meta: compute output shapes
            Meta-->>SdpaFwdOp: (output, logsumexp, philox_seed, philox_offset)
        else attn_bias.defined()
            SdpaFwdOp->>Math: at::_scaled_dot_product_attention_math(...)
            Math-->>SdpaFwdOp: (attn_out, logsumexp)
            Note over SdpaFwdOp: attn_out is contiguous
        else flash attention
            SdpaFwdOp->>Flash: at::_scaled_dot_product_flash_attention(...)
            Flash-->>SdpaFwdOp: (attn_out, logsumexp, ...)
            Note over SdpaFwdOp: attn_out has query's stride order
        end
        
        alt batch_dims.size() > 1
            SdpaFwdOp->>SdpaFwdOp: unflattenBatchDim(output, logsumexp)
        end
        
        SdpaFwdOp->>Helper: relayoutByTensorView(output, attn_out())
        Helper->>Helper: canonicalizeLayout(tv)
        Helper->>Helper: computePermutation(logical, allocation)
        Helper->>Helper: permute → contiguous → inverse permute
        Helper-->>SdpaFwdOp: relayouted output
        
        SdpaFwdOp-->>Client: (output, logsumexp, philox_seed, philox_offset)
    


    @greptile-apps greptile-apps bot left a comment


    4 files reviewed, 1 comment


    "Flash attention requires the last dimension to be a multiple of 8, but "
    "got: ",
    last_dim_size);
    // Conmpute scale using original size of last dimension
    @wujingyue (Collaborator, Author)


    Suggested change
    // Conmpute scale using original size of last dimension


    zasdfgbnm commented Jan 20, 2026

    Are you suggesting the following?

    After preseg pass, the fusion IR will have some contiguity on each TensorView. We should leave this contiguity as-is, and if the expr-eval segment does not return a tensor with that contiguity, we need to modify that tensor to align with that contiguity.

    Would this add extra copy-kernel? Are we so rich that we just pay for extra copy-kernels?

    @wujingyue (Collaborator, Author)

    After preseg pass, the fusion IR will have some contiguity on each TensorView. We should leave this contiguity as-is, and if the expr-eval segment does not return a tensor with that contiguity, we need to modify that tensor to align with that contiguity.

    Yes

    Would this add extra copy-kernel?

    I hope not. I want preseg (e.g. MarkAliasesPrepare and AllocationOrderInference) to set the allocation domain and contiguity that matches the behavior of ExpressionEvaluator. If an expr-eval segment returns a different "layout", I consider it a bug that should be fixed by making the two match without extra data copy.

    Thanks for clarifying -- now I see your motivation to "reset" contiguities based on ExpressionEvaluator. I haven't thought too much about this approach. The first thing that will fail is that ReorderShardedAxis is allocation-sensitive -- at this moment, we require the scattered/gathered axis of a communication I/O to be outermost for NCCL constraints. If the allocation of such a TensorView gets changed after segmentation, we may fail to lower a communication segment to a NCCL call or may have to do extra data copy around the NCCL call that could be fused with upstream or downstream kernel segments.

    cc @zasdfgbnm

    @wujingyue (Collaborator, Author)

    !test

    @wujingyue wujingyue merged commit 6037ef6 into main Jan 21, 2026
    57 of 59 checks passed
    @wujingyue wujingyue deleted the wjy/sdpa-meta branch January 21, 2026 04:34