
Conversation

@liqiangxl (Collaborator) commented Jan 9, 2026

Summary

Enable warp-specialized kernels for fusions with symbolic (dynamic) input shapes by passing launch parameters to the lowering pass. This PR enables the inner persistent scheduler; the inner-outer persistent scheduler will be addressed in a follow-up PR.

Problem

Warp-specialized kernels require static CTA shapes: bdim.x/y/z are fixed at compile time and set in the scheduler heuristics. However, this information was not passed to the lowering pass, causing it to fail to derive and validate CTA shapes when fusion inputs have symbolic shapes.

For example, with symbolic size i3, the loop domain contains rthreadIdx.x164{( ceilDiv(( ceilDiv(i3, 8) ), 5) )}p. The lowering pass cannot derive a static bdim.x from this symbolic expression.

Solution

Pass lparams to the lowering pass. The primary usage is in ParallelDimensionMap::getThreadCountInDim: when bdim.x/y/z are set in the launch parameters, they are used to derive the extent of the corresponding loop domain instead of relying on the symbolic extent:

int64_t ParallelDimensionMap::getThreadCountInDim(ParallelType pt) {
  if (!dim_map_.contains(pt)) {
    return 1;
  }
  if (dim_map_.at(pt)->isConstScalar()) {
    return dim_map_.at(pt)->value().as<int64_t>();
  }
  // If dimension is dynamic but we have launch parameters available,
  // use the actual launch parameter value
  if (GpuLower::hasCurrent() &&
      GpuLower::current()->launchParams().hasDim(pt)) {
    return GpuLower::current()->launchParams().getDim(pt);
  }
  // Return -1 for dynamic dimensions when launch parameters are not known,
  // this disables register sharing on dynamic dimensions since we can't
  // guarantee the number of threads is divisible by 128.
  return -1;
}

Note: the generated kernel still uses symbolic shapes to support kernel reuse when input shapes change. See the added test TmaPersistentTestF.KernelReuse.

github-actions bot commented Jan 9, 2026

Review updated until commit fa855e0

Description

  • Enable warp-specialized kernels for dynamic input shapes by passing CTA dimensions to lowering pass

  • Add compile-time CTA shape parameters (bdimx/y/z) to CompileParams struct

  • Modify ParallelDimensionMap to use static CTA dimensions when available for dynamic shapes

  • Update inner persistent scheduler to compute and pass static CTA dimensions

  • Add kernel reuse test to validate symbolic shape support

Changes walkthrough

Relevant files

Enhancement

csrc/parallel_dimension_map.cpp: Use compile-time CTA dimensions for dynamic shape handling (+14/-4)

  • Modified getThreadCountInDim to use compile-time CTA shape parameters
  • Check bdimx/y/z from GpuLower::current()->compileParams() for dynamic dimensions
  • Return actual compile parameter values instead of -1 when available

csrc/scheduler/normalization_inner_tma.cpp: Compute and pass static CTA dimensions for warp specialization (+12/-0)

  • Add recomputation of bdimx based on persistent batch size
  • Pad bdimx to a multiple of warp size for the warp-specialized version
  • Set compile parameters cparams.bdimx/y/z with computed values

csrc/device_lower/lower2device.h: Expose compile parameters for CTA dimension access (+4/-0)

  • Add compileParams() method to expose compile parameters
  • Provide access to CTA dimensions from GpuLower

csrc/runtime/executor_params.h: Add CTA dimension parameters to CompileParams (+7/-1)

  • Add optional bdimx, bdimy, bdimz parameters to the CompileParams struct
  • Update the equality operator to include the new CTA dimension parameters
  • Support compile-time CTA shape specification

Tests

tests/cpp/test_persistent_buffer.cpp: Add dynamic shape tests and kernel reuse validation (+78/-2)

  • Change makeContigConcreteTensor to makeContigTensor for dynamic shapes
  • Add KernelReuse test validating kernel reuse with different input dimensions
  • Test symbolic shape support and compile-time CTA dimension handling

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Warp size padding logic

The padding logic should be verified to correctly handle edge cases, particularly when bdimx is already a multiple of warp_size or when bdimx is very small:

bdimx =
    bdimx % warp_size == 0 ? bdimx : bdimx + warp_size - bdimx % warp_size;
Error handling robustness

The code assumes GpuLower::hasCurrent() is true and calls GpuLower::current() without additional null checks. Consider adding defensive programming to handle cases where GpuLower might not be properly initialized.

NVF_ERROR(GpuLower::hasCurrent());
const auto& cparams = GpuLower::current()->compileParams();

@liqiangxl force-pushed the llu/ws_dynamic_inner_persistent branch from bb9139c to fec714b on January 9, 2026 15:46

@liqiangxl (Collaborator, Author):

!test

@liqiangxl marked this pull request as ready for review on January 9, 2026 15:46

@liqiangxl (Collaborator, Author):

!test

@greptile-apps greptile-apps bot left a comment

Greptile Overview

    Overview

    This PR enables warp-specialized kernels for fusions with symbolic (dynamic) input shapes by passing launch parameters to the lowering pass. The change flows launch parameters through the compilation stack to allow the lowering pass to use static CTA shapes when available, rather than relying solely on symbolic expressions.

    Key Changes

    • GpuLower: Now accepts LaunchParams in constructor and stores them as member variable with accessor method
    • Launch Parameter Propagation: Parameters flow from KernelExecutor::compile() → CompiledKernel constructor → GpuLower constructor
    • Analysis Improvements:
      • fused_reduction.cpp: Uses launch params for bdimx when available, falling back to const extent
      • parallel_dimension_map.cpp: Checks launch params before returning -1 for dynamic dimensions
    • Test Coverage: New KernelReuse test validates that kernels with dynamic shapes correctly support kernel reuse across different input dimensions

    Architecture

    The implementation properly maintains separation of concerns:

    • Deserialization path correctly passes empty LaunchParams() since it only loads pre-compiled binaries
    • Compilation path passes launch_constraints to enable dynamic shape support
    • All parameter passing is backward compatible with default arguments

    Issue Found

    CRITICAL: Contains debug print statements (std::cout and printTransforms()) in normalization_inner_tma.cpp that must be removed before merging. These are development artifacts left in production code.

    Confidence Score: 2/5

    • This PR has critical issues preventing safe merging - debug print statements are present in scheduler code that will output during kernel compilation.
    • The PR implements a solid architectural approach to enable dynamic shapes in warp-specialized kernels. The core logic is sound: launch parameters flow correctly through the stack (KernelExecutor → CompiledKernel → GpuLower), and the lowering analysis correctly uses these parameters as fallbacks for dynamic dimensions. However, there is a critical blocking issue: debug print statements (std::cout and printTransforms()) are left in normalization_inner_tma.cpp lines 470-471. These are development artifacts that will pollute stdout during normal kernel scheduling. Additionally, the deserialization code path could benefit from clarifying comments. The test coverage is good (new KernelReuse test validates the feature), and most implementation details are correct, but the debug prints must be removed before this can be safely merged to main.
    • csrc/scheduler/normalization_inner_tma.cpp - CRITICAL: Remove debug print statements before merge. All other files are well-implemented.

    Important Files Changed

    File Analysis

    Filename Score Overview
    csrc/scheduler/normalization_inner_tma.cpp 1/5 CRITICAL: Contains debug print statements (std::cout and printTransforms) that should be removed before merging. These are development artifacts that will add unnecessary output during production kernel scheduling.
    csrc/device_lower/lower2device.h 5/5 Header changes look correct. Added lparams parameter to GpuLower constructor, added launchParams() accessor method, and added lparams_ member variable. Changes are properly scoped and backward compatible.
    csrc/device_lower/lower2device.cpp 5/5 Constructor implementation correctly initializes lparams_ member from the parameter passed in. Changes are minimal and correctly integrated into the initialization list.
    csrc/device_lower/analysis/fused_reduction.cpp 5/5 Correctly implements the logic to use launch parameters for bdimx when available, falling back to const extent. Uses proper UNINITIALIZED_VAL check and optional value handling.
    csrc/parallel_dimension_map.cpp 5/5 Correctly checks for launch params before returning -1 for dynamic dimensions. Uses hasDim() check before accessing getDim(), following proper defensive programming patterns.
    csrc/runtime/compiled_kernel.h 5/5 Both CompiledKernel constructors correctly accept LaunchParams as new parameter. Changes are properly integrated with appropriate default values.
    csrc/runtime/compiled_kernel.cpp 5/5 Both constructor implementations correctly pass lparams to GpuLower constructor. The second constructor properly delegates to the first constructor with all new parameters.
    csrc/runtime/executor.cpp 4/5 Updated to pass launch_constraints to CompiledKernel constructor in compile path, and LaunchParams() in deserialization path. However, deserialization path passes empty LaunchParams which may lose launch constraint information during deserialization.
    tests/cpp/test_persistent_buffer.cpp 5/5 Changed concrete tensors to makeContigTensor for dynamic shape support. Added comprehensive KernelReuse test that validates kernel reuse with different input shapes - this effectively validates the PR's main feature works correctly.

    Sequence Diagram

    sequenceDiagram
        participant KE as KernelExecutor
        participant CK as CompiledKernel
        participant GL as GpuLower
        participant FR as FusedReduction
        
        rect rgb(200, 220, 255)
        Note over KE,FR: Compilation Path
        KE->>CK: "compile(fusion, launch_constraints)"
        CK->>GL: "GpuLower(fusion, cparams, launch_params)"
        GL->>GL: "Stores lparams_"
        GL->>GL: "run() executes passes"
        GL->>FR: "Analysis phase uses GpuLower::current()"
        FR->>GL: "launchParams().getDim(ParallelType::TIDx)"
        GL-->>FR: "Returns static bdimx value"
        end
        
        rect rgb(200, 200, 220)
        Note over KE,GL: "Deserialization Path"
        KE->>CK: "deserialize() - loads pre-compiled binary"
        CK->>GL: "GpuLower(fusion, cparams, LaunchParams())"
        Note right of GL: "Empty params OK - binary already compiled"
        end
    

    Comment on lines 470 to 471
    std::cout << "reduction_tv: " << reduction_tv->toString() << std::endl;
    reduction_tv->printTransforms();

    Debug print statements should be removed before merging. These lines add unnecessary stdout output during kernel scheduling in production code.

    Suggested change
    std::cout << "reduction_tv: " << reduction_tv->toString() << std::endl;
    reduction_tv->printTransforms();
    reduction_tv->axis(reduction_pos + 2)->parallelize(ParallelType::TIDx);
    reduction_tv->axis(reduction_pos + 2)->padToMultipleOfWarp();
    // Create rfactor tensor to separate thread-local reduction from block


    greptile-apps bot commented Jan 9, 2026

    Additional Comments (1)

    csrc/runtime/executor.cpp
    Consider adding a comment explaining why empty LaunchParams() is passed during deserialization. Since the kernel is already compiled and only the binary is being loaded, the launch params are not needed for lowering. This clarifies intent for future maintainers.

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

@liqiangxl force-pushed the llu/ws_dynamic_inner_persistent branch from 2b617d9 to 0d8fc6e on January 9, 2026 19:05

@liqiangxl (Collaborator, Author):

!test

@greptile-apps greptile-apps bot left a comment

Greptile Overview

    This PR enables warp-specialized kernels for fusions with symbolic (dynamic) input shapes by threading launch parameters through the lowering pipeline.

    Key Changes:

    • Modified GpuLower to accept and store LaunchParams alongside CompileParams
    • Enhanced ParallelDimensionMap::getThreadCountInDim() to use static bdim.x/y/z from launch params when dimensions are dynamic
    • Updated CompiledKernel and KernelExecutor to pass launch params to lowering
    • Changed test to use symbolic tensors and added kernel reuse validation

    How It Works:
    The scheduler sets static CTA dimensions (e.g., bdim.x=160) in params->lparams during heuristic computation. These are now passed to GpuLower, allowing the lowering pass to use concrete thread counts even when input shapes are symbolic. The generated kernel still uses symbolic shapes for runtime flexibility and kernel reuse.

    Issue Found:
    Debug print statements on lines 470-471 of csrc/scheduler/normalization_inner_tma.cpp must be removed before merge.

    Confidence Score: 4/5

    • This PR is safe to merge after removing debug print statements
    • The implementation is clean and follows existing patterns. Launch params are properly threaded through the compilation pipeline with appropriate fallbacks. The only issue is leftover debug code that must be removed.
    • csrc/scheduler/normalization_inner_tma.cpp - contains debug print statements that must be removed

    Important Files Changed

    File Analysis

    Filename Score Overview
    csrc/scheduler/normalization_inner_tma.cpp 2/5 Added debug print statements (lines 470-471) that must be removed before merging
    csrc/device_lower/lower2device.h 5/5 Added LaunchParams parameter to GpuLower constructor and member field with accessor method - clean implementation
    csrc/parallel_dimension_map.cpp 5/5 Enhanced getThreadCountInDim to use launch params when available for dynamic dimensions - properly handles the case when launch params are not set
    csrc/runtime/executor.cpp 5/5 Passes launch_constraints during compilation and empty LaunchParams() during deserialization - correctly handles both cases
    tests/cpp/test_persistent_buffer.cpp 5/5 Changed test from concrete tensors to symbolic tensors and added kernel reuse test - validates the feature works with dynamic shapes

    Sequence Diagram

    sequenceDiagram
        participant Scheduler as Scheduler Heuristics
        participant Executor as KernelExecutor
        participant CK as CompiledKernel
        participant GL as GpuLower
        participant PDM as ParallelDimensionMap
    
        Note over Scheduler: Scheduler determines<br/>launch params (bdim.x/y/z)<br/>for warp-specialized kernel
    
        Scheduler->>Executor: compile(fusion, args, launch_constraints)
        
        Executor->>CK: new CompiledKernel(fusion, compile_params, launch_params)
        
        CK->>GL: new GpuLower(fusion, cparams, lparams)
        Note over GL: Stores lparams as<br/>member variable
        
        GL->>GL: analysis(fusion)
        Note over GL: Builds ParallelDimensionMap<br/>and other lowering info
        
        GL->>PDM: getThreadCountInDim(ParallelType::TIDx)
        
        alt Dimension is const scalar
            PDM-->>GL: Return const value
        else Dimension is dynamic AND lparams available
            PDM->>GL: launchParams().getDim(pt)
            Note over PDM: NEW: Use static bdim.x<br/>from launch params
            PDM-->>GL: Return static thread count
        else Dimension is dynamic AND no lparams
            PDM-->>GL: Return -1 (disables register sharing)
        end
        
        Note over GL: Continues lowering with<br/>correct thread dimensions
    


    greptile-apps bot commented Jan 9, 2026

    Additional Comments (1)

    csrc/scheduler/normalization_inner_tma.cpp
    Debug print statements must be removed before merge.


@greptile-apps greptile-apps bot left a comment

Greptile Overview

    Overview

    This PR enables warp-specialized kernels to work with dynamic (symbolic) input shapes by passing LaunchParams to the lowering pass. Previously, the lowering pass couldn't derive static block dimensions (bdim.x/y/z) when fusion inputs had symbolic shapes, because this information wasn't available during lowering.

    Key Changes

    1. Modified GpuLower interface: Added optional LaunchParams parameter to constructor, allowing the lowering pass to access static block dimensions set by the scheduler
    2. Enhanced ParallelDimensionMap: Updated getThreadCountInDim() to use actual launch parameters for dynamic dimensions instead of always returning -1
    3. Updated CompiledKernel: Threads LaunchParams through both constructor variants to GpuLower
    4. KernelExecutor integration: Passes launch_constraints during compilation; uses empty LaunchParams() during deserialization (justified since binary is pre-compiled)
    5. Test coverage: Added comprehensive test validating kernel reuse with dynamic shapes

    Design Rationale

    The solution maintains kernel reusability with symbolic shapes while enabling the scheduler to provide static CTA dimensions. When launch parameters are available (during compilation), the lowering pass can use them to derive static extent for thread-indexed loop domains. When unavailable (during deserialization), the code gracefully falls back to existing behavior.

    Backward Compatibility

    All changes maintain full backward compatibility through default parameter values. Existing code paths continue to work as before.

    Testing

    The new TmaPersistentTestF.KernelReuse test validates that:

    • Kernels are properly reused when input shapes change but launch parameters remain compatible
    • New kernels are created when launch parameters must change
    • The kernel maintains functionality across multiple input shapes

    Confidence Score: 5/5

    • This PR is safe to merge with minimal risk. Changes are well-designed, maintain backward compatibility, and include comprehensive testing.
    • Score reflects: (1) Clean API design with default parameters maintaining backward compatibility, (2) Localized changes with no risky cross-module dependencies, (3) Proper thread-through of parameters across all code paths, (4) Comprehensive test coverage validating the kernel reuse mechanism, (5) Graceful fallback handling in deserialization path, (6) Clear separation of concerns between scheduler constraints and lowering logic.
    • No files require special attention. All changes are well-integrated and properly tested.

    Important Files Changed

    File Analysis

    Filename Score Overview
    csrc/device_lower/lower2device.h 5/5 Added LaunchParams parameter to GpuLower constructor and added launchParams() getter. Changes are clean and preserve API compatibility with default parameter.
    csrc/device_lower/lower2device.cpp 5/5 Updated GpuLower constructor to accept and store LaunchParams. Implementation correctly passes lparams through initialization list and stores for later access.
    csrc/parallel_dimension_map.cpp 5/5 Enhanced getThreadCountInDim() to use launch parameters when available for dynamic dimensions. Falls back to returning -1 when params unavailable. Improves support for warp-specialized kernels with dynamic shapes.
    csrc/runtime/compiled_kernel.cpp 5/5 Updated CompiledKernel constructors to receive and pass LaunchParams to GpuLower. Implementation correctly threads parameters through all constructor variants.
    csrc/runtime/executor.cpp 4/5 Passes launch_constraints to CompiledKernel in compile path and empty LaunchParams() in deserialization path (justified since kernel is already compiled). Should include comment explaining empty params in deserialization.
    tests/cpp/test_persistent_buffer.cpp 5/5 Added the TmaPersistentTestF.KernelReuse test validating kernel reuse with dynamic shapes. Changed to symbolic tensor creation in existing tests to support dynamic shapes properly.

    Sequence Diagram

    sequenceDiagram
        participant KernelExecutor
        participant CompiledKernel
        participant GpuLower
        participant ParallelDimensionMap
    
        KernelExecutor->>CompiledKernel: construct(launch_params)
        CompiledKernel->>GpuLower: construct(lparams)
        GpuLower->>GpuLower: analysis() - creates ParallelDimensionMap
        ParallelDimensionMap->>GpuLower: reads launch_params via GpuLower::current()
        ParallelDimensionMap->>ParallelDimensionMap: getThreadCountInDim() uses launch_params
        
        Note over GpuLower,ParallelDimensionMap: Dynamic dimensions use launch_params<br/>when available instead of symbolic extent
        
        Note over KernelExecutor,CompiledKernel: Deserialization passes LaunchParams()<br/>because kernel is already compiled
    

    Comment on lines 1611 to 1614
    _fusion,
    compile_params,
    LaunchParams(),
    device,

    Consider adding a clarifying comment explaining why an empty LaunchParams() is passed during deserialization. Since the kernel is already compiled and only the binary is being loaded, the launch parameters are not needed for lowering. This would help future maintainers understand the intentional design.


    Collaborator

    This comment seems to make sense. What do you think? @liqiangxl

@greptile-apps greptile-apps bot left a comment

Greptile Overview

    Overview

    This PR enables warp-specialized kernels to work with dynamic (symbolic) input shapes by passing launch parameters (LaunchParams) to the GPU lowering pass. Previously, when fusion inputs had symbolic shapes, the lowering pass couldn't derive static block dimensions (bdim.x/y/z) needed for warp-specialized kernels.

    Key Changes

    Core Mechanism:

    • Added LaunchParams parameter to GpuLower constructor and stored it as a member (lparams_)
    • Exposed launchParams() getter method in GpuLower
    • Updated ParallelDimensionMap::getThreadCountInDim() to consult launch parameters when available: if a ParallelType dimension is set in launch parameters, use that value instead of relying on symbolic extent

    Data Flow:

    • KernelExecutor::compile() now passes launch_constraints to CompiledKernel
    • CompiledKernel constructors accept launch_params and forward to GpuLower
    • During deserialization, empty LaunchParams() is passed (appropriate since kernel is already compiled)

    Testing:

    • Added TmaPersistentTestF.KernelReuse test that validates kernel reuse with different input shapes
    • Test verifies: same kernel binary can be reused when different input dimensions produce the same launch configuration
    • Test also validates that different launch configurations create different kernel runtimes

    Technical Correctness

    • Thread Safety: The code correctly uses GpuLower::hasCurrent() guard before accessing GpuLower::current(). The pattern is used in ParallelDimensionMap constructor which runs during GpuLower::analysis() when LowerGuard is active
    • Backward Compatibility: Fallback behavior is preserved: when launch parameters are unavailable (e.g., during scheduler phase in greedy.cpp), getThreadCountInDim() returns -1 as before
    • Initialization: LaunchParams has proper default constructor and methods (hasDim(), getDim()) used correctly

    Code Quality

    • Changes follow existing patterns in the codebase
    • Parameters are properly threaded through constructor chains
    • Getter methods are const-correct
    • File analyses provided comprehensive coverage of all changes

    The implementation correctly solves the stated problem: warp-specialized kernels can now use launch parameters to determine static CTA shapes even when input shapes are dynamic, while still supporting kernel reuse via symbolic shapes in the generated kernel code.

    Confidence Score: 5/5

    • This PR is safe to merge with high confidence. The changes are well-designed, properly guarded against unsafe access, and thoroughly tested.
    • Score reflects: (1) Proper thread safety with GpuLower::hasCurrent() checks before accessing GpuLower::current(), (2) Correct preservation of fallback behavior for cases where GpuLower is inactive, (3) Simple, straightforward parameter threading through constructors with no complex logic, (4) Comprehensive test coverage validating both kernel reuse and correct behavior with dynamic shapes, (5) No breaking changes to existing APIs (LaunchParams defaults are provided), (6) Clear intent and well-documented through PR description and test cases.
    • No files require special attention. All changes are well-implemented and tested.

    Important Files Changed

    File Analysis

    Filename Score Overview
    csrc/device_lower/lower2device.h 5/5 Added LaunchParams parameter to GpuLower constructor and added launchParams() getter method. Changes are straightforward and follow existing patterns in the codebase.
    csrc/device_lower/lower2device.cpp 5/5 Updated GpuLower constructor to accept and store LaunchParams. Correctly initializes lparams_ member. Logic is sound and follows existing patterns.
    csrc/parallel_dimension_map.cpp 5/5 Enhanced getThreadCountInDim() to use launch parameters when available. Correctly guards access to GpuLower::current() with hasCurrent() check. Fallback behavior (returning -1) preserves existing behavior when launch params unavailable.
    csrc/runtime/compiled_kernel.h 5/5 Updated CompiledKernel constructor signatures to accept LaunchParams. Changes are purely API modifications that pass parameters through to GpuLower.
    csrc/runtime/compiled_kernel.cpp 5/5 Updated CompiledKernel constructors to accept and forward launch_params to GpuLower. Correctly passes parameters in both primary constructor and delegating constructor.
    csrc/runtime/executor.cpp 4/5 Updated CompiledKernel creation calls to pass launch_constraints during normal compilation and empty LaunchParams() during deserialization. The deserialization pattern is reasonable since the kernel binary is already compiled, but could benefit from a clarifying comment as noted in PR thread.
    tests/cpp/test_persistent_buffer.cpp 5/5 Added new test TmaPersistentTestF.KernelReuse that validates kernel reuse with dynamic input shapes. Test correctly verifies that the same compiled kernel binary can be reused across different input dimensions that produce the same launch configuration.

    Sequence Diagram

    sequenceDiagram
        participant KE as KernelExecutor
        participant CK as CompiledKernel
        participant GL as GpuLower
        participant PDM as ParallelDimensionMap
        participant LP as LaunchParams
    
        KE->>CK: new CompiledKernel(fusion, compile_params,<br/>launch_constraints)
        activate CK
        CK->>GL: make_unique(fusion, compile_params,<br/>launch_params)
        activate GL
        GL->>GL: analysis(fusion) [LowerGuard active]
        activate GL
        GL->>PDM: make_unique(fusion)
        activate PDM
        PDM->>GL: hasCurrent() ✓
        PDM->>LP: hasDim(ParallelType::TIDx) ?
        LP-->>PDM: true/false
        alt Launch Params Available
            LP->>PDM: getDim() → static value
        else Not Available
            PDM->>PDM: return -1
        end
        deactivate PDM
        deactivate GL
        GL->>GL: run() [applies lowering passes]
        deactivate GL
        deactivate CK
    
        Note over GL,PDM: During lowering phase with<br/>launch parameters available,<br/>static CTA shapes can be derived<br/>even with symbolic input shapes
    
    

@liqiangxl requested review from naoyam and rdspring1 on January 12, 2026 13:54

@liqiangxl (Collaborator, Author):

!test

greptile-apps bot commented Jan 12, 2026

    Greptile Summary

    This PR enables warp-specialized kernels for fusions with symbolic (dynamic) input shapes by passing compile-time CTA shape information to the lowering pass through CompileParams. Previously, the lowering pass couldn't derive static CTA dimensions from symbolic expressions like rthreadIdx.x164{( ceilDiv(( ceilDiv(i3, 8) ), 5) )}p, causing warp specialization to fail.

    Key changes:

    • Added bdimx/y/z fields to CompileParams to store compile-time CTA shape
    • Modified ParallelDimensionMap::getThreadCountInDim to use CompileParams values for dynamic dimensions instead of returning -1
    • Scheduler now sets static CTA dimensions in CompileParams when warp specialization conditions are met
    • Test updated to use dynamic tensor shapes and validate kernel reuse behavior
    • Generated kernels still use symbolic shapes to support kernel reuse across different input sizes

    The approach correctly distinguishes between compile-time decisions (using CompileParams) and runtime behavior (using symbolic shapes in generated code).

    Confidence Score: 4/5

    • Safe to merge with minor concerns about edge case handling
    • The core logic is sound and addresses a real limitation with dynamic shapes. The changes are well-structured and include appropriate tests. Score is 4 (not 5) due to: (1) the NVF_ERROR assertion in parallel_dimension_map.cpp:164 being less forgiving than the previous implementation, though this is likely acceptable since GpuLower should always be available in lowering contexts; (2) the bdimx recomputation logic has some complexity that previous reviewers questioned, though it appears intentional for the heuristic.
    • Pay close attention to csrc/parallel_dimension_map.cpp for the assertion behavior and csrc/scheduler/normalization_inner_tma.cpp for the bdimx recomputation logic

    Important Files Changed

• csrc/runtime/executor_params.h: Added CTA shape fields (bdimx/y/z) to CompileParams and updated the equality operator
• csrc/parallel_dimension_map.cpp: Modified getThreadCountInDim to use CompileParams for dynamic dimensions instead of returning -1
• csrc/scheduler/normalization_inner_tma.cpp: Recomputed bdimx and set the CTA shape in CompileParams for warp specialization

    Sequence Diagram

    sequenceDiagram
        participant Scheduler as Scheduler (normalization_inner_tma)
        participant Heuristics as InnerNormTmaParams
        participant GpuLower as GpuLower (lowering pass)
        participant ParallelDim as ParallelDimensionMap
        participant CompileParams as CompileParams
    
        Scheduler->>Scheduler: Calculate bdimx, bdimy, bdimz<br/>for warp specialization
        Scheduler->>Heuristics: Set cparams.bdimx/y/z<br/>(static CTA shape)
        Scheduler->>Heuristics: Set lparams (launch params)
        
        Note over Scheduler,Heuristics: For dynamic shapes, scheduler<br/>sets static CTA dimensions<br/>in CompileParams
        
        Heuristics->>GpuLower: Pass CompileParams<br/>during lowering
        
        GpuLower->>ParallelDim: getThreadCountInDim(TIDx/y/z)
        
        alt Dimension is const scalar
            ParallelDim-->>GpuLower: Return constant value
        else Dimension is dynamic
            ParallelDim->>CompileParams: Check cparams.bdimx/y/z.has_value()
            alt CompileParams has static CTA shape
                CompileParams-->>ParallelDim: Return bdimx/y/z value
                ParallelDim-->>GpuLower: Return static CTA dimension
            else CompileParams not set
                ParallelDim-->>GpuLower: Return -1 (disable optimizations)
            end
        end
        
        Note over GpuLower,ParallelDim: Enables register sharing<br/>and warp specialization<br/>for dynamic shapes
    

    Contributor

    @greptile-apps greptile-apps bot left a comment


    No files reviewed, no comments


    @liqiangxl
    Collaborator Author

    !test

    Contributor

    @greptile-apps greptile-apps bot left a comment


    2 files reviewed, 2 comments


    @greptile-apps
    Contributor

    greptile-apps bot commented Jan 12, 2026

    Additional Comments (2)

    csrc/scheduler/normalization_inner_tma.cpp
    Debug print statements must be removed before merge. These std::cout statements are for debugging purposes and should not be in production code.

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!


    csrc/scheduler/normalization_inner_tma.cpp
    Debug print statement must be removed before merge. This std::cout statement is for debugging purposes and should not be in production code.


    Contributor

    @greptile-apps greptile-apps bot left a comment


    1 file reviewed, 1 comment


    Comment on lines +84 to +86
    bdimx = ceilDiv(after_vect, params->persistent_batch_size);
    bdimx =
    bdimx % warp_size == 0 ? bdimx : bdimx + warp_size - bdimx % warp_size;
    Contributor


    The bdimx recomputation can break warp specialization. When the while loop exits with bdimx=128, recomputing bdimx = ceilDiv(after_vect, pbs) may produce a different value that fails the bdimx == 128 check at line 103. For example: if after_vect=160 and the loop sets bdimx=128, pbs=2, the recomputation yields ceilDiv(160,2)=80, which rounds to 96 after warp padding, preventing warp specialization despite meeting the original conditions.

    Contributor

    @greptile-apps greptile-apps bot left a comment


    2 files reviewed, 2 comments


    Comment on lines 162 to 167
    // If dimension is dynamic but we have launch parameters available,
    // use the actual launch parameter value
    if (GpuLower::hasCurrent() &&
    GpuLower::current()->launchParams().hasDim(pt)) {
    return GpuLower::current()->launchParams().getDim(pt);
    }
    Contributor


    Consider adding a validation check here to ensure the launch parameter value is valid and matches the expected parallel type dimension. While hasDim() checks if the dimension is set, it doesn't validate that the value is positive and reasonable (e.g., not accidentally UNINITIALIZED_VAL or negative).

Suggested change:

    -    // If dimension is dynamic but we have launch parameters available,
    -    // use the actual launch parameter value
    -    if (GpuLower::hasCurrent() &&
    -        GpuLower::current()->launchParams().hasDim(pt)) {
    -      return GpuLower::current()->launchParams().getDim(pt);
    -    }
    +    // If dimension is dynamic but we have launch parameters available,
    +    // use the actual launch parameter value
    +    if (GpuLower::hasCurrent() &&
    +        GpuLower::current()->launchParams().hasDim(pt)) {
    +      int64_t dim = GpuLower::current()->launchParams().getDim(pt);
    +      NVF_ERROR(dim > 0, "Invalid launch parameter value for ", pt, ": ", dim);
    +      return dim;
    +    }

    This ensures we catch any configuration issues early rather than returning invalid values that could cause subtle bugs downstream.

    Comment on lines +81 to +86
    // recompute bdimx since it is used to parallelize the dynamic dimension
    // pad to multiple of warp size to use warp reduction.
    // to use warp specialized version, it should pad to 128 threads.
    bdimx = ceilDiv(after_vect, params->persistent_batch_size);
    bdimx =
    bdimx % warp_size == 0 ? bdimx : bdimx + warp_size - bdimx % warp_size;
    Contributor


    The bdimx recalculation logic looks correct, but the relationship between this recalculated value and the earlier loop (lines 75-78) could be clearer. Consider adding a comment explaining why we can't just use the bdimx value from after the loop and why we need to recalculate from scratch.

    The recalculation is necessary because:

    1. The loop may have increased bdimx beyond what's actually needed for the given persistent_batch_size
    2. We need the minimum bdimx that satisfies ceilDiv(after_vect, persistent_batch_size) with proper warp alignment

    Example: if after_vect=160, the loop might set bdimx=128, pbs=2, but recalculation gives bdimx=ceilDiv(160,2)=80→96 (padded to warp). This is more efficient.

    Consider adding this clarifying comment above line 81.


    Collaborator

    @naoyam naoyam left a comment


    The generated kernel still uses symbolic shapes

    What does this mean? Doesn't GpuLower::current()->launchParams().getDim(pt) return a static size?

    Comment on lines 1611 to 1614
    _fusion,
    compile_params,
    LaunchParams(),
    device,
    Collaborator


    This comment seems to make sense. What do you think? @liqiangxl

         Fusion* fusion,
    -    const CompileParams& cparams = CompileParams());
    +    const CompileParams& cparams = CompileParams(),
    +    const LaunchParams& lparams = LaunchParams());
    Collaborator

    @naoyam naoyam Jan 14, 2026


    Just a general comment. CompileParams and LaunchParams are meant to be about compilations and kernel launches, respectively. If something is needed for compilation, it should be included in CompileParams rather than passing the parameter struct meant for kernel launches.

    Collaborator Author


Sounds like a good point. I revised it to add the static CTA shape as optional vars to CompileParams, so we don't need to pass LaunchParams to the lowering pass.

      // CTA shape known at compile time
      std::optional<int64_t> bdimx = std::nullopt;
      std::optional<int64_t> bdimy = std::nullopt;
      std::optional<int64_t> bdimz = std::nullopt;
    

    @liqiangxl
    Collaborator Author

    The generated kernel still uses symbolic shapes

    What does this mean?

    It means the generated cuda code still uses symbolic/dynamic input size, e.g.

      Tensor<__bfloat, 2, 2> s1;
      s1.data = T0.data;
      s1.logical_size = T0.logical_size;
      s1.alloc_stride = T0.alloc_stride;
      Array<nvfuser_index_t, 2, 1> a2;
      a2 = s1.logical_size;
    

    Doesn't GpuLower::current()->launchParams().getDim(pt) return a static size?

    Yes, it's a static const size. We didn't use it to derive a static input shape. For example rthreadIdx.x164{( ceilDiv(( ceilDiv(i3, 8) ), 5) )}p, we know bdimx = 128 from launch params. But we didn't further derive i3 using i3/8/5 == 128

    @naoyam
    Collaborator

    naoyam commented Jan 15, 2026

    The generated kernel still uses symbolic shapes

    What does this mean?

    It means the generated cuda code still uses symbolic/dynamic input size, e.g.

      Tensor<__bfloat, 2, 2> s1;
      s1.data = T0.data;
      s1.logical_size = T0.logical_size;
      s1.alloc_stride = T0.alloc_stride;
      Array<nvfuser_index_t, 2, 1> a2;
      a2 = s1.logical_size;
    

    Doesn't GpuLower::current()->launchParams().getDim(pt) return a static size?

    Yes, it's a static const size. We didn't use it to derive a static input shape. For example rthreadIdx.x164{( ceilDiv(( ceilDiv(i3, 8) ), 5) )}p, we know bdimx = 128 from launch params. But we didn't further derive i3 using i3/8/5 == 128

    So, with this PR, 128 is used instead of the symbolic Val of ceilDiv(( ceilDiv(i3, 8) ), 5) ). Am I understanding correctly? If so, isn't the generated kernel using the static shape of 128?

    @liqiangxl
    Collaborator Author

    liqiangxl commented Jan 15, 2026

    Yes, it's a static const size. We didn't use it to derive a static input shape. For example rthreadIdx.x164{( ceilDiv(( ceilDiv(i3, 8) ), 5) )}p, we know bdimx = 128 from launch params. But we didn't further derive i3 using i3/8/5 == 128

    So, with this PR, 128 is used instead of the symbolic Val of ceilDiv(( ceilDiv(i3, 8) ), 5) ). Am I understanding correctly? If so, isn't the generated kernel using the static shape of 128?

Yes, ceilDiv(( ceilDiv(i3, 8) ), 5) ) = 128. When you say static shape of 128, it means the CTA shape is static, not the input tensor shape; e.g. i3 can be 5120, 5120 - 8, 5120 - 16, ... 4096+8 (the minimum value to reuse the kernel depends on heuristics). When i3 = 4096, the heuristics will select a different persistent batch size of 4 and regenerate a new kernel.

    @liqiangxl
    Collaborator Author

    !test

    auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0);

    // Helper to get the number of compiled kernel runtimes
    auto numRuntimes = [&executor_cache]() -> size_t {
    Collaborator

    @rdspring1 rdspring1 Jan 15, 2026


Or create a regular function in an anonymous namespace?

Suggested change:

    -    auto numRuntimes = [&executor_cache]() -> size_t {
    +    auto num_runtimes = [&executor_cache]() -> size_t {

    };

    // Helper to run fusion with given dimensions and return the runtime
    auto runAndValidate = [&](int64_t outer_dim,
    Collaborator


Suggested change:

    -    auto runAndValidate = [&](int64_t outer_dim,
    +    auto run_and_validate = [&](int64_t outer_dim,

    EXPECT_NE(first_runtime, fourth_runtime);
    EXPECT_EQ(numRuntimes(), 2);

    // Fifth run with different inner dimension - should not reuse the kernel
    Collaborator


      // Fifth run with different inner dimension should not reuse the kernel.
      // After vectorization of 8, the inner dimension (1280) is 160. After persistent
      // batch size of 2, it is 80, so 80 threads are required. The current heuristic pads
      // to 96 threads and won't use warp specialized version. The warp specialized
      // version requires padding to 128 threads.

    @naoyam
    Collaborator

    naoyam commented Jan 15, 2026

    Yes, it's a static const size. We didn't use it to derive a static input shape. For example rthreadIdx.x164{( ceilDiv(( ceilDiv(i3, 8) ), 5) )}p, we know bdimx = 128 from launch params. But we didn't further derive i3 using i3/8/5 == 128

    So, with this PR, 128 is used instead of the symbolic Val of ceilDiv(( ceilDiv(i3, 8) ), 5) ). Am I understanding correctly? If so, isn't the generated kernel using the static shape of 128?

    Yes, ceilDiv(( ceilDiv(i3, 8) ), 5) ) = 128. when you say static shape of 128, it means the CTA shape is static, not the input tensor shape, e.g. i3 can be 5120, 5120 - 8, 5120 - 16, ...4096+8 (minmum value to reuse kernel depends on heuristics). When i3 = 4096, the heuristics will select a different persistent batch size of 4 and regenerate a new kernel

    Lowering now uses the launch parameters whenever they are available. Is it guaranteed to be safe without changing how heuristic parameters are matched? https://github.com/NVIDIA/Fuser/blob/main/csrc/scheduler/reduction_heuristic.h#L187

    @naoyam
    Collaborator

    naoyam commented Jan 15, 2026

    Yes, it's a static const size. We didn't use it to derive a static input shape. For example rthreadIdx.x164{( ceilDiv(( ceilDiv(i3, 8) ), 5) )}p, we know bdimx = 128 from launch params. But we didn't further derive i3 using i3/8/5 == 128

    So, with this PR, 128 is used instead of the symbolic Val of ceilDiv(( ceilDiv(i3, 8) ), 5) ). Am I understanding correctly? If so, isn't the generated kernel using the static shape of 128?

    Yes, ceilDiv(( ceilDiv(i3, 8) ), 5) ) = 128. when you say static shape of 128, it means the CTA shape is static, not the input tensor shape, e.g. i3 can be 5120, 5120 - 8, 5120 - 16, ...4096+8 (minmum value to reuse kernel depends on heuristics). When i3 = 4096, the heuristics will select a different persistent batch size of 4 and regenerate a new kernel

    Lowering now uses the launch parameters whenever they are available. Is it guaranteed to be safe without changing how heuristic parameters are matched? https://github.com/NVIDIA/Fuser/blob/main/csrc/scheduler/reduction_heuristic.h#L187

    Ah, I remember that's a responsibility of each scheduler. Since the launch parameter is indeed used anyway, it must be safe to use it in the generated kernel. Am I missing anything?

    Collaborator

    @naoyam naoyam left a comment


    LGTM

    // check to a point where the launch parameters are known.
    // If dimension is dynamic but we have compile-time CTA shape available,
    // use the actual compile parameter value
    if (GpuLower::hasCurrent()) {
    Collaborator


    short-circuit

Suggested change:

    -    if (GpuLower::hasCurrent()) {
    +    if (!GpuLower::hasCurrent()) {
    +      return -1;
    +    }

    Collaborator


getThreadCountInDim isn't called outside of ParallelDimMap, so can it be an assert like NVF_ERROR(GpuLower::hasCurrent());?

    @liqiangxl
    Collaborator Author

    Lowering now uses the launch parameters whenever they are available. Is it guaranteed to be safe without changing how heuristic parameters are matched? https://github.com/NVIDIA/Fuser/blob/main/csrc/scheduler/reduction_heuristic.h#L187

    Ah, I remember that's a responsibility of each scheduler. Since the launch parameter is indeed used anyway, it must be safe to use it in the generated kernel. Am I missing anything?

It's safe. The actual launch params are computed using the heuristic launch params as constraints. In other words, the kernel must be launched with the same threads/blocks if they were set by the heuristics.

      LaunchParams computeLaunchParams(
          const LaunchParams& launch_constraints,
          ExpressionEvaluator& expr_eval,
          const int64_t warp_size,
          DataType index_dtype);
    

Following your previous suggestion, I added the optional static CTA shape to the compile params; compile params are matched when comparing heuristic parameters.

    @liqiangxl
    Collaborator Author

    !test

    @liqiangxl liqiangxl merged commit 4dedf26 into main Jan 15, 2026
    61 checks passed
    @liqiangxl liqiangxl deleted the llu/ws_dynamic_inner_persistent branch January 15, 2026 17:16
    @liqiangxl liqiangxl changed the title pass lparams to lower pass to allow dynamic shapes in inner persistent warp specialized scheduler add optional static CTA shape to cparams to allow dynamic shapes in inner persistent warp specialized scheduler Jan 15, 2026