add optional static CTA shape to cparams to allow dynamic shapes in inner persistent warp specialized scheduler #5785
Conversation
Review updated until commit fa855e0
Description
Relevant files: Enhancement, Tests
PR Reviewer Guide
Here are some key observations to aid the review process:
🧪 PR contains tests
⚡ Recommended focus areas for review
Warp size padding logic: `bdimx = bdimx % warp_size == 0 ? bdimx : bdimx + warp_size - bdimx % warp_size;` should be verified to correctly handle edge cases, particularly when `bdimx` is already a multiple of `warp_size` or when `bdimx` is very small (see the sketch below).
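A minimal standalone sketch of that padding expression, checking the edge cases called out above; `warp_size = 32` and the sample inputs are illustrative values, not taken from the PR:

```cpp
#include <cassert>
#include <cstdint>

// Round bdimx up to the next multiple of warp_size; same expression as the
// scheduler snippet quoted above, just wrapped in a helper for testing.
int64_t padToWarpMultiple(int64_t bdimx, int64_t warp_size) {
  return bdimx % warp_size == 0 ? bdimx
                                : bdimx + warp_size - bdimx % warp_size;
}

int main() {
  const int64_t warp_size = 32;
  assert(padToWarpMultiple(128, warp_size) == 128); // already a multiple: unchanged
  assert(padToWarpMultiple(1, warp_size) == 32);    // very small bdimx: padded to one warp
  assert(padToWarpMultiple(80, warp_size) == 96);   // generic case: rounded up
  return 0;
}
```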
Force-pushed bb9139c to fec714b
!test

!test
Greptile Overview
Greptile Summary
Overview
This PR enables warp-specialized kernels for fusions with symbolic (dynamic) input shapes by passing launch parameters to the lowering pass. The change flows launch parameters through the compilation stack to allow the lowering pass to use static CTA shapes when available, rather than relying solely on symbolic expressions.
Key Changes
- GpuLower: now accepts `LaunchParams` in the constructor and stores them as a member variable with an accessor method
- Launch Parameter Propagation: parameters flow from `KernelExecutor::compile()` → `CompiledKernel` constructor → `GpuLower` constructor
- Analysis Improvements:
  - `fused_reduction.cpp`: uses launch params for bdimx when available, falling back to const extent
  - `parallel_dimension_map.cpp`: checks launch params before returning -1 for dynamic dimensions
- Test Coverage: new `KernelReuse` test validates that kernels with dynamic shapes correctly support kernel reuse across different input dimensions
Architecture
The implementation properly maintains separation of concerns:
- Deserialization path correctly passes empty `LaunchParams()` since it only loads pre-compiled binaries
- Compilation path passes `launch_constraints` to enable dynamic shape support
- All parameter passing is backward compatible with default arguments (a minimal sketch follows this list)
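A minimal sketch of the backward-compatible signature described above; `Fusion`, `CompileParams`, and `LaunchParams` here are stand-in stubs rather than the real nvFuser classes:

```cpp
// Stub types standing in for the real nvFuser classes, for illustration only.
struct Fusion {};
struct CompileParams {};
struct LaunchParams {};

class GpuLower {
 public:
  // The defaulted lparams argument keeps existing call sites compiling
  // unchanged while letting new callers forward launch parameters.
  explicit GpuLower(
      Fusion* /*fusion*/,
      const CompileParams& cparams = CompileParams(),
      const LaunchParams& lparams = LaunchParams())
      : cparams_(cparams), lparams_(lparams) {}

  const LaunchParams& launchParams() const { return lparams_; }

 private:
  CompileParams cparams_;
  LaunchParams lparams_;
};

int main() {
  Fusion fusion;
  GpuLower old_style(&fusion);                                  // pre-existing call pattern
  GpuLower new_style(&fusion, CompileParams(), LaunchParams()); // new call pattern
  (void)old_style;
  (void)new_style;
  return 0;
}
```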
Issue Found
CRITICAL: Contains debug print statements (std::cout and printTransforms()) in normalization_inner_tma.cpp that must be removed before merging. These are development artifacts left in production code.
Confidence Score: 2/5
- This PR has critical issues preventing safe merging - debug print statements are present in scheduler code that will output during kernel compilation.
- The PR implements a solid architectural approach to enable dynamic shapes in warp-specialized kernels. The core logic is sound: launch parameters flow correctly through the stack (KernelExecutor → CompiledKernel → GpuLower), and the lowering analysis correctly uses these parameters as fallbacks for dynamic dimensions. However, there is a critical blocking issue: debug print statements (std::cout and printTransforms()) are left in normalization_inner_tma.cpp lines 470-471. These are development artifacts that will pollute stdout during normal kernel scheduling. Additionally, the deserialization code path could benefit from clarifying comments. The test coverage is good (new KernelReuse test validates the feature), and most implementation details are correct, but the debug prints must be removed before this can be safely merged to main.
- csrc/scheduler/normalization_inner_tma.cpp - CRITICAL: Remove debug print statements before merge. All other files are well-implemented.
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| csrc/scheduler/normalization_inner_tma.cpp | 1/5 | CRITICAL: Contains debug print statements (std::cout and printTransforms) that should be removed before merging. These are development artifacts that will add unnecessary output during production kernel scheduling. |
| csrc/device_lower/lower2device.h | 5/5 | Header changes look correct. Added lparams parameter to GpuLower constructor, added launchParams() accessor method, and added lparams_ member variable. Changes are properly scoped and backward compatible. |
| csrc/device_lower/lower2device.cpp | 5/5 | Constructor implementation correctly initializes lparams_ member from the parameter passed in. Changes are minimal and correctly integrated into the initialization list. |
| csrc/device_lower/analysis/fused_reduction.cpp | 5/5 | Correctly implements the logic to use launch parameters for bdimx when available, falling back to const extent. Uses proper UNINITIALIZED_VAL check and optional value handling. |
| csrc/parallel_dimension_map.cpp | 5/5 | Correctly checks for launch params before returning -1 for dynamic dimensions. Uses hasDim() check before accessing getDim(), following proper defensive programming patterns. |
| csrc/runtime/compiled_kernel.h | 5/5 | Both CompiledKernel constructors correctly accept LaunchParams as new parameter. Changes are properly integrated with appropriate default values. |
| csrc/runtime/compiled_kernel.cpp | 5/5 | Both constructor implementations correctly pass lparams to GpuLower constructor. The second constructor properly delegates to the first constructor with all new parameters. |
| csrc/runtime/executor.cpp | 4/5 | Updated to pass launch_constraints to CompiledKernel constructor in compile path, and LaunchParams() in deserialization path. However, deserialization path passes empty LaunchParams which may lose launch constraint information during deserialization. |
| tests/cpp/test_persistent_buffer.cpp | 5/5 | Changed concrete tensors to makeContigTensor for dynamic shape support. Added comprehensive KernelReuse test that validates kernel reuse with different input shapes - this effectively validates the PR's main feature works correctly. |
Sequence Diagram
```mermaid
sequenceDiagram
participant KE as KernelExecutor
participant CK as CompiledKernel
participant GL as GpuLower
participant FR as FusedReduction
rect rgb(200, 220, 255)
Note over KE,FR: Compilation Path
KE->>CK: "compile(fusion, launch_constraints)"
CK->>GL: "GpuLower(fusion, cparams, launch_params)"
GL->>GL: "Stores lparams_"
GL->>GL: "run() executes passes"
GL->>FR: "Analysis phase uses GpuLower::current()"
FR->>GL: "launchParams().getDim(ParallelType::TIDx)"
GL-->>FR: "Returns static bdimx value"
end
rect rgb(200, 200, 220)
Note over KE,GL: "Deserialization Path"
KE->>CK: "deserialize() - loads pre-compiled binary"
CK->>GL: "GpuLower(fusion, cparams, LaunchParams())"
Note right of GL: "Empty params OK - binary already compiled"
end
```
```cpp
std::cout << "reduction_tv: " << reduction_tv->toString() << std::endl;
reduction_tv->printTransforms();
```
Debug print statements should be removed before merging. These lines add unnecessary stdout output during kernel scheduling in production code.
Suggested change:
```diff
-std::cout << "reduction_tv: " << reduction_tv->toString() << std::endl;
-reduction_tv->printTransforms();
 reduction_tv->axis(reduction_pos + 2)->parallelize(ParallelType::TIDx);
 reduction_tv->axis(reduction_pos + 2)->padToMultipleOfWarp();
 // Create rfactor tensor to separate thread-local reduction from block
```
Force-pushed 2b617d9 to 0d8fc6e
!test
Greptile Overview
Greptile Summary
This PR enables warp-specialized kernels for fusions with symbolic (dynamic) input shapes by threading launch parameters through the lowering pipeline.
Key Changes:
- Modified `GpuLower` to accept and store `LaunchParams` alongside `CompileParams`
- Enhanced `ParallelDimensionMap::getThreadCountInDim()` to use static `bdim.x/y/z` from launch params when dimensions are dynamic
- Updated `CompiledKernel` and `KernelExecutor` to pass launch params to lowering
- Changed test to use symbolic tensors and added kernel reuse validation
How It Works:
The scheduler sets static CTA dimensions (e.g., bdim.x=160) in params->lparams during heuristic computation. These are now passed to GpuLower, allowing the lowering pass to use concrete thread counts even when input shapes are symbolic. The generated kernel still uses symbolic shapes for runtime flexibility and kernel reuse.
Issue Found:
Debug print statements on lines 470-471 of csrc/scheduler/normalization_inner_tma.cpp must be removed before merge.
Confidence Score: 4/5
- This PR is safe to merge after removing debug print statements
- The implementation is clean and follows existing patterns. Launch params are properly threaded through the compilation pipeline with appropriate fallbacks. The only issue is leftover debug code that must be removed.
- csrc/scheduler/normalization_inner_tma.cpp - contains debug print statements that must be removed
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| csrc/scheduler/normalization_inner_tma.cpp | 2/5 | Added debug print statements (lines 470-471) that must be removed before merging |
| csrc/device_lower/lower2device.h | 5/5 | Added LaunchParams parameter to GpuLower constructor and member field with accessor method - clean implementation |
| csrc/parallel_dimension_map.cpp | 5/5 | Enhanced getThreadCountInDim to use launch params when available for dynamic dimensions - properly handles the case when launch params are not set |
| csrc/runtime/executor.cpp | 5/5 | Passes launch_constraints during compilation and empty LaunchParams() during deserialization - correctly handles both cases |
| tests/cpp/test_persistent_buffer.cpp | 5/5 | Changed test from concrete tensors to symbolic tensors and added kernel reuse test - validates the feature works with dynamic shapes |
Sequence Diagram
```mermaid
sequenceDiagram
participant Scheduler as Scheduler Heuristics
participant Executor as KernelExecutor
participant CK as CompiledKernel
participant GL as GpuLower
participant PDM as ParallelDimensionMap
Note over Scheduler: Scheduler determines<br/>launch params (bdim.x/y/z)<br/>for warp-specialized kernel
Scheduler->>Executor: compile(fusion, args, launch_constraints)
Executor->>CK: new CompiledKernel(fusion, compile_params, launch_params)
CK->>GL: new GpuLower(fusion, cparams, lparams)
Note over GL: Stores lparams as<br/>member variable
GL->>GL: analysis(fusion)
Note over GL: Builds ParallelDimensionMap<br/>and other lowering info
GL->>PDM: getThreadCountInDim(ParallelType::TIDx)
alt Dimension is const scalar
PDM-->>GL: Return const value
else Dimension is dynamic AND lparams available
PDM->>GL: launchParams().getDim(pt)
Note over PDM: NEW: Use static bdim.x<br/>from launch params
PDM-->>GL: Return static thread count
else Dimension is dynamic AND no lparams
PDM-->>GL: Return -1 (disables register sharing)
end
Note over GL: Continues lowering with<br/>correct thread dimensions
```
Greptile Overview
Greptile Summary
Overview
This PR enables warp-specialized kernels to work with dynamic (symbolic) input shapes by passing LaunchParams to the lowering pass. Previously, the lowering pass couldn't derive static block dimensions (bdim.x/y/z) when fusion inputs had symbolic shapes, because this information wasn't available during lowering.
Key Changes
- Modified GpuLower interface: added an optional `LaunchParams` parameter to the constructor, allowing the lowering pass to access static block dimensions set by the scheduler
- Enhanced ParallelDimensionMap: updated `getThreadCountInDim()` to use actual launch parameters for dynamic dimensions instead of always returning -1
- Updated CompiledKernel: threads `LaunchParams` through both constructor variants to `GpuLower`
- KernelExecutor integration: passes `launch_constraints` during compilation; uses empty `LaunchParams()` during deserialization (justified since the binary is pre-compiled)
- Test coverage: added a comprehensive test validating kernel reuse with dynamic shapes
Design Rationale
The solution maintains kernel reusability with symbolic shapes while enabling the scheduler to provide static CTA dimensions. When launch parameters are available (during compilation), the lowering pass can use them to derive static extent for thread-indexed loop domains. When unavailable (during deserialization), the code gracefully falls back to existing behavior.
Backward Compatibility
All changes maintain full backward compatibility through default parameter values. Existing code paths continue to work as before.
Testing
The new TmaPersistentTestF::KernelReuse test validates that:
- Kernels are properly reused when input shapes change but launch parameters remain compatible
- New kernels are created when launch parameters must change
- The kernel maintains functionality across multiple input shapes
Confidence Score: 5/5
- This PR is safe to merge with minimal risk. Changes are well-designed, maintain backward compatibility, and include comprehensive testing.
- Score reflects: (1) Clean API design with default parameters maintaining backward compatibility, (2) Localized changes with no risky cross-module dependencies, (3) Proper thread-through of parameters across all code paths, (4) Comprehensive test coverage validating the kernel reuse mechanism, (5) Graceful fallback handling in deserialization path, (6) Clear separation of concerns between scheduler constraints and lowering logic.
- No files require special attention. All changes are well-integrated and properly tested.
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| csrc/device_lower/lower2device.h | 5/5 | Added LaunchParams parameter to GpuLower constructor and added launchParams() getter. Changes are clean and preserve API compatibility with default parameter. |
| csrc/device_lower/lower2device.cpp | 5/5 | Updated GpuLower constructor to accept and store LaunchParams. Implementation correctly passes lparams through initialization list and stores for later access. |
| csrc/parallel_dimension_map.cpp | 5/5 | Enhanced getThreadCountInDim() to use launch parameters when available for dynamic dimensions. Falls back to returning -1 when params unavailable. Improves support for warp-specialized kernels with dynamic shapes. |
| csrc/runtime/compiled_kernel.cpp | 5/5 | Updated CompiledKernel constructors to receive and pass LaunchParams to GpuLower. Implementation correctly threads parameters through all constructor variants. |
| csrc/runtime/executor.cpp | 4/5 | Passes launch_constraints to CompiledKernel in compile path and empty LaunchParams() in deserialization path (justified since kernel is already compiled). Should include comment explaining empty params in deserialization. |
| tests/cpp/test_persistent_buffer.cpp | 5/5 | Added TmaPersistentTestF.KernelReuse test validating kernel reuse with dynamic shapes. Changed existing tests to symbolic tensor creation to support dynamic shapes properly. |
Sequence Diagram
```mermaid
sequenceDiagram
participant KernelExecutor
participant CompiledKernel
participant GpuLower
participant ParallelDimensionMap
KernelExecutor->>CompiledKernel: construct(launch_params)
CompiledKernel->>GpuLower: construct(lparams)
GpuLower->>GpuLower: analysis() - creates ParallelDimensionMap
ParallelDimensionMap->>GpuLower: reads launch_params via GpuLower::current()
ParallelDimensionMap->>ParallelDimensionMap: getThreadCountInDim() uses launch_params
Note over GpuLower,ParallelDimensionMap: Dynamic dimensions use launch_params<br/>when available instead of symbolic extent
Note over KernelExecutor,CompiledKernel: Deserialization passes LaunchParams()<br/>because kernel is already compiled
```
```cpp
_fusion,
compile_params,
LaunchParams(),
device,
```
Consider adding a clarifying comment explaining why an empty LaunchParams() is passed during deserialization. Since the kernel is already compiled and only the binary is being loaded, the launch parameters are not needed for lowering. This would help future maintainers understand the intentional design.
This comment seems to make sense. What do you think? @liqiangxl
Greptile Overview
Greptile Summary
Overview
This PR enables warp-specialized kernels to work with dynamic (symbolic) input shapes by passing launch parameters (LaunchParams) to the GPU lowering pass. Previously, when fusion inputs had symbolic shapes, the lowering pass couldn't derive static block dimensions (bdim.x/y/z) needed for warp-specialized kernels.
Key Changes
Core Mechanism:
- Added `LaunchParams` parameter to the `GpuLower` constructor and stored it as a member (`lparams_`)
- Exposed `launchParams()` getter method in `GpuLower`
- Updated `ParallelDimensionMap::getThreadCountInDim()` to consult launch parameters when available: if a ParallelType dimension is set in launch parameters, use that value instead of relying on the symbolic extent
Data Flow:
- `KernelExecutor::compile()` now passes `launch_constraints` to `CompiledKernel`
- `CompiledKernel` constructors accept `launch_params` and forward them to `GpuLower`
- During deserialization, empty `LaunchParams()` is passed (appropriate since the kernel is already compiled)
Testing:
- Added `TmaPersistentTestF.KernelReuse` test that validates kernel reuse with different input shapes
- Test verifies that the same kernel binary can be reused when different input dimensions produce the same launch configuration
- Test also validates that different launch configurations create different kernel runtimes
Technical Correctness
- Thread Safety: the code correctly uses a `GpuLower::hasCurrent()` guard before accessing `GpuLower::current()`. The pattern is used in the ParallelDimensionMap constructor, which runs during `GpuLower::analysis()` when LowerGuard is active
- Backward Compatibility: fallback behavior is preserved; when launch parameters are unavailable (e.g., during the scheduler phase in greedy.cpp), `getThreadCountInDim()` returns -1 as before
- Initialization: LaunchParams has a proper default constructor, and its methods (`hasDim()`, `getDim()`) are used correctly
Code Quality
- Changes follow existing patterns in the codebase
- Parameters are properly threaded through constructor chains
- Getter methods are const-correct
- File analyses provided comprehensive coverage of all changes
The implementation correctly solves the stated problem: warp-specialized kernels can now use launch parameters to determine static CTA shapes even when input shapes are dynamic, while still supporting kernel reuse via symbolic shapes in the generated kernel code.
Confidence Score: 5/5
- This PR is safe to merge with high confidence. The changes are well-designed, properly guarded against unsafe access, and thoroughly tested.
- Score reflects: (1) Proper thread safety with GpuLower::hasCurrent() checks before accessing GpuLower::current(), (2) Correct preservation of fallback behavior for cases where GpuLower is inactive, (3) Simple, straightforward parameter threading through constructors with no complex logic, (4) Comprehensive test coverage validating both kernel reuse and correct behavior with dynamic shapes, (5) No breaking changes to existing APIs (LaunchParams defaults are provided), (6) Clear intent and well-documented through PR description and test cases.
- No files require special attention. All changes are well-implemented and tested.
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| csrc/device_lower/lower2device.h | 5/5 | Added LaunchParams parameter to GpuLower constructor and added launchParams() getter method. Changes are straightforward and follow existing patterns in the codebase. |
| csrc/device_lower/lower2device.cpp | 5/5 | Updated GpuLower constructor to accept and store LaunchParams. Correctly initializes lparams_ member. Logic is sound and follows existing patterns. |
| csrc/parallel_dimension_map.cpp | 5/5 | Enhanced getThreadCountInDim() to use launch parameters when available. Correctly guards access to GpuLower::current() with hasCurrent() check. Fallback behavior (returning -1) preserves existing behavior when launch params unavailable. |
| csrc/runtime/compiled_kernel.h | 5/5 | Updated CompiledKernel constructor signatures to accept LaunchParams. Changes are purely API modifications that pass parameters through to GpuLower. |
| csrc/runtime/compiled_kernel.cpp | 5/5 | Updated CompiledKernel constructors to accept and forward launch_params to GpuLower. Correctly passes parameters in both primary constructor and delegating constructor. |
| csrc/runtime/executor.cpp | 4/5 | Updated CompiledKernel creation calls to pass launch_constraints during normal compilation and empty LaunchParams() during deserialization. The deserialization pattern is reasonable since the kernel binary is already compiled, but could benefit from a clarifying comment as noted in PR thread. |
| tests/cpp/test_persistent_buffer.cpp | 5/5 | Added new test TmaPersistentTestF.KernelReuse that validates kernel reuse with dynamic input shapes. Test correctly verifies that the same compiled kernel binary can be reused across different input dimensions that produce the same launch configuration. |
Sequence Diagram
```mermaid
sequenceDiagram
participant KE as KernelExecutor
participant CK as CompiledKernel
participant GL as GpuLower
participant PDM as ParallelDimensionMap
participant LP as LaunchParams
KE->>CK: new CompiledKernel(fusion, compile_params,<br/>launch_constraints)
activate CK
CK->>GL: make_unique(fusion, compile_params,<br/>launch_params)
activate GL
GL->>GL: analysis(fusion) [LowerGuard active]
activate GL
GL->>PDM: make_unique(fusion)
activate PDM
PDM->>GL: hasCurrent() ✓
PDM->>LP: hasDim(ParallelType::TIDx) ?
LP-->>PDM: true/false
alt Launch Params Available
LP->>PDM: getDim() → static value
else Not Available
PDM->>PDM: return -1
end
deactivate PDM
deactivate GL
GL->>GL: run() [applies lowering passes]
deactivate GL
deactivate CK
Note over GL,PDM: During lowering phase with<br/>launch parameters available,<br/>static CTA shapes can be derived<br/>even with symbolic input shapes
```
!test
Greptile Summary
This PR enables warp-specialized kernels for fusions with symbolic (dynamic) input shapes by passing compile-time CTA shape information to the lowering pass through `CompileParams`.
Key changes: optional static CTA shape fields (`bdimx`/`bdimy`/`bdimz`) are added to `CompileParams` and consulted by the lowering pass when loop-domain extents are dynamic.
The approach correctly distinguishes between compile-time decisions (using `CompileParams`) and runtime launch configuration (using `LaunchParams`).
Confidence Score: 4/5
Important Files Changed
Sequence Diagram
```mermaid
sequenceDiagram
participant Scheduler as Scheduler (normalization_inner_tma)
participant Heuristics as InnerNormTmaParams
participant GpuLower as GpuLower (lowering pass)
participant ParallelDim as ParallelDimensionMap
participant CompileParams as CompileParams
Scheduler->>Scheduler: Calculate bdimx, bdimy, bdimz<br/>for warp specialization
Scheduler->>Heuristics: Set cparams.bdimx/y/z<br/>(static CTA shape)
Scheduler->>Heuristics: Set lparams (launch params)
Note over Scheduler,Heuristics: For dynamic shapes, scheduler<br/>sets static CTA dimensions<br/>in CompileParams
Heuristics->>GpuLower: Pass CompileParams<br/>during lowering
GpuLower->>ParallelDim: getThreadCountInDim(TIDx/y/z)
alt Dimension is const scalar
ParallelDim-->>GpuLower: Return constant value
else Dimension is dynamic
ParallelDim->>CompileParams: Check cparams.bdimx/y/z.has_value()
alt CompileParams has static CTA shape
CompileParams-->>ParallelDim: Return bdimx/y/z value
ParallelDim-->>GpuLower: Return static CTA dimension
else CompileParams not set
ParallelDim-->>GpuLower: Return -1 (disable optimizations)
end
end
Note over GpuLower,ParallelDim: Enables register sharing<br/>and warp specialization<br/>for dynamic shapes
```
No files reviewed, no comments
!test
2 files reviewed, 2 comments
1 file reviewed, 1 comment
```cpp
bdimx = ceilDiv(after_vect, params->persistent_batch_size);
bdimx =
    bdimx % warp_size == 0 ? bdimx : bdimx + warp_size - bdimx % warp_size;
```
The bdimx recomputation can break warp specialization. When the while loop exits with bdimx=128, recomputing bdimx = ceilDiv(after_vect, pbs) may produce a different value that fails the bdimx == 128 check at line 103. For example: if after_vect=160 and the loop sets bdimx=128, pbs=2, the recomputation yields ceilDiv(160,2)=80, which rounds to 96 after warp padding, preventing warp specialization despite meeting the original conditions.
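A standalone reproduction of the arithmetic in this comment; `after_vect = 160` and `persistent_batch_size = 2` are the reviewer's hypothetical values, and `ceilDiv` mirrors the rounding-up division used in the scheduler:

```cpp
#include <cstdint>
#include <iostream>

int64_t ceilDiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  const int64_t warp_size = 32;
  const int64_t after_vect = 160;          // reviewer's example input
  const int64_t persistent_batch_size = 2; // assumed value at loop exit

  // Recomputation as quoted from the scheduler snippet above:
  int64_t bdimx = ceilDiv(after_vect, persistent_batch_size); // 80
  bdimx =
      bdimx % warp_size == 0 ? bdimx : bdimx + warp_size - bdimx % warp_size; // 96

  // 96 != 128, so the warp-specialized path (which requires 128 threads)
  // would be skipped even though the loop had settled on bdimx = 128.
  std::cout << "recomputed bdimx = " << bdimx << std::endl;
  return 0;
}
```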
2 files reviewed, 2 comments
csrc/parallel_dimension_map.cpp (outdated)
```cpp
// If dimension is dynamic but we have launch parameters available,
// use the actual launch parameter value
if (GpuLower::hasCurrent() &&
    GpuLower::current()->launchParams().hasDim(pt)) {
  return GpuLower::current()->launchParams().getDim(pt);
}
```
Consider adding a validation check here to ensure the launch parameter value is valid and matches the expected parallel type dimension. While hasDim() checks if the dimension is set, it doesn't validate that the value is positive and reasonable (e.g., not accidentally UNINITIALIZED_VAL or negative).
Suggested change:
```diff
 // If dimension is dynamic but we have launch parameters available,
 // use the actual launch parameter value
 if (GpuLower::hasCurrent() &&
     GpuLower::current()->launchParams().hasDim(pt)) {
-  return GpuLower::current()->launchParams().getDim(pt);
+  int64_t dim = GpuLower::current()->launchParams().getDim(pt);
+  NVF_ERROR(dim > 0, "Invalid launch parameter value for ", pt, ": ", dim);
+  return dim;
 }
```
This ensures we catch any configuration issues early rather than returning invalid values that could cause subtle bugs downstream.
```cpp
// recompute bimdx since it is used to parallelize the dynamic dimension
// pad to multiple of warp size to use warp reduction.
// to use warp specialized version, it should pad to 128 threads.
bdimx = ceilDiv(after_vect, params->persistent_batch_size);
bdimx =
    bdimx % warp_size == 0 ? bdimx : bdimx + warp_size - bdimx % warp_size;
```
The bdimx recalculation logic looks correct, but the relationship between this recalculated value and the earlier loop (lines 75-78) could be clearer. Consider adding a comment explaining why we can't just use the bdimx value from after the loop and why we need to recalculate from scratch.
The recalculation is necessary because:
- The loop may have increased `bdimx` beyond what's actually needed for the given `persistent_batch_size`
- We need the minimum `bdimx` that satisfies `ceilDiv(after_vect, persistent_batch_size)` with proper warp alignment
Example: if after_vect=160, the loop might set bdimx=128, pbs=2, but recalculation gives bdimx=ceilDiv(160,2)=80→96 (padded to warp). This is more efficient.
Consider adding this clarifying comment above line 81.
naoyam left a comment
> The generated kernel still uses symbolic shapes
What does this mean? Doesn't GpuLower::current()->launchParams().getDim(pt) return a static size?
```cpp
_fusion,
compile_params,
LaunchParams(),
device,
```
This comment seems to make sense. What do you think? @liqiangxl
csrc/device_lower/lower2device.h (outdated)
```diff
 Fusion* fusion,
-const CompileParams& cparams = CompileParams());
+const CompileParams& cparams = CompileParams(),
+const LaunchParams& lparams = LaunchParams());
```
Just a general comment. CompileParams and LaunchParams are meant to be about compilations and kernel launches, respectively. If something is needed for compilation, it should be included in CompileParams rather than passing the parameter struct meant for kernel launches.
Sounds like a good point. I revised to add the static CTA shape as optional vars to CompileParams, so we don't need to pass LaunchParams to the lowering pass:
```cpp
// CTA shape known at compile time
std::optional<int64_t> bdimx = std::nullopt;
std::optional<int64_t> bdimy = std::nullopt;
std::optional<int64_t> bdimz = std::nullopt;
```
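A minimal sketch of how the lowering pass could consult these optional fields for a dynamic `TIDx` extent; the `threadCountInTIDx` helper and the trimmed-down `CompileParams` struct are hypothetical, only the `bdimx/y/z` fields follow the snippet above:

```cpp
#include <cstdint>
#include <optional>

// Simplified stand-in for CompileParams carrying the optional static CTA shape.
struct CompileParams {
  std::optional<int64_t> bdimx = std::nullopt;
  std::optional<int64_t> bdimy = std::nullopt;
  std::optional<int64_t> bdimz = std::nullopt;
};

// Hypothetical helper: return the TIDx thread count if it is known either from
// a constant extent or from the compile-time CTA shape, otherwise -1 (the
// existing "unknown, disable warp specialization" fallback).
int64_t threadCountInTIDx(
    std::optional<int64_t> const_extent, const CompileParams& cparams) {
  if (const_extent.has_value()) {
    return *const_extent; // extent already a compile-time constant
  }
  if (cparams.bdimx.has_value()) {
    return *cparams.bdimx; // dynamic extent, but CTA shape fixed by the scheduler
  }
  return -1;
}

int main() {
  CompileParams cparams;
  cparams.bdimx = 128; // scheduler decided on a 128-thread warp-specialized CTA
  return threadCountInTIDx(std::nullopt, cparams) == 128 ? 0 : 1;
}
```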
It means the generated cuda code still uses symbolic/dynamic input size, e.g.

Yes, it's a static const size. We didn't use it to derive a static input shape. For example

So, with this PR,

Yes,
!test
```cpp
auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0);

// Helper to get the number of compiled kernel runtimes
auto numRuntimes = [&executor_cache]() -> size_t {
```
Or create regular function in anonymous namespace?
```diff
-auto numRuntimes = [&executor_cache]() -> size_t {
+auto num_runtimes = [&executor_cache]() -> size_t {
```
```cpp
};

// Helper to run fusion with given dimensions and return the runtime
auto runAndValidate = [&](int64_t outer_dim,
```
```diff
-auto runAndValidate = [&](int64_t outer_dim,
+auto run_and_validate = [&](int64_t outer_dim,
```
```cpp
EXPECT_NE(first_runtime, fourth_runtime);
EXPECT_EQ(numRuntimes(), 2);

// Fifth run with different inner dimension - should not reuse the kernel
```
Suggested comment:
```cpp
// Fifth run with different inner dimension should not reuse the kernel.
// After vectorization of 8, the inner dimension (1280) is 160. After persistent
// batch size of 2, it is 80, so 80 threads are required. The current heuristic pads
// to 96 threads and won't use warp specialized version. The warp specialized
// version requires padding to 128 threads.
```
Lowering now uses the launch parameters whenever they are available. Is it guaranteed to be safe without changing how heuristic parameters are matched? https://github.com/NVIDIA/Fuser/blob/main/csrc/scheduler/reduction_heuristic.h#L187

Ah, I remember that's a responsibility of each scheduler. Since the launch parameter is indeed used anyway, it must be safe to use it in the generated kernel. Am I missing anything?
naoyam left a comment
LGTM
csrc/parallel_dimension_map.cpp (outdated)
```cpp
// check to a point where the launch parameters are known.
// If dimension is dynamic but we have compile-time CTA shape available,
// use the actual compile parameter value
if (GpuLower::hasCurrent()) {
```
short-circuit

Suggested change:
```diff
-if (GpuLower::hasCurrent()) {
+if (!GpuLower::hasCurrent()) {
+  return -1;
+}
```
getThreadCountInDim isn't called outside of ParallelDimMap, so can it be an assert like `NVF_ERROR(GpuLower::hasCurrent());`?
It's safe. The actual launch params are computed using the heuristic launch params as constraints. In other words, the kernel must be launched with the same threads/blocks if they were set by the heuristics. Following your previous suggestion, I added an optional static CTA shape to compile params; compile params are matched when comparing heuristic parameters.
!test
Summary
Enable warp-specialized kernels for fusions with symbolic (dynamic) input shapes by passing launch parameters to the lowering pass. This PR enables the inner persistent scheduler; the inner-outer persistent scheduler will be addressed in a follow-up PR.
Problem
Warp-specialized kernels require static CTA shapes: `bdim.x/y/z` are fixed at compile time and set in the scheduler heuristics. However, this information was not passed to the lowering pass, so it failed to derive and validate CTA shapes when fusion inputs have symbolic shapes. For example, with symbolic size `i3`, the loop domain contains `rthreadIdx.x164{( ceilDiv(( ceilDiv(i3, 8) ), 5) )}p`. The lowering pass cannot derive a static `bdim.x` from this symbolic expression.
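To make the problem concrete: for one hypothetical runtime size the expression does evaluate to a fixed thread count, but only at runtime, while the lowering pass sees just the symbolic form:

```cpp
#include <cstdint>
#include <iostream>

int64_t ceilDiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  const int64_t i3 = 5120; // hypothetical runtime size, not taken from the PR
  // Symbolic extent from the loop domain above: ceilDiv(ceilDiv(i3, 8), 5)
  const int64_t extent = ceilDiv(ceilDiv(i3, 8), 5); // 5120/8 = 640; 640/5 = 128
  std::cout << "bdim.x for i3 = " << i3 << " is " << extent << std::endl;
  return 0;
}
```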
Solution
Pass `lparams` to the lowering pass.
Primary usage is in `ParallelDimensionMap::getThreadCountInDim`: when `bdim.x/y/z` are set in the launch parameters, use them to derive the extent of the corresponding loop domain instead of relying on the symbolic extent.
Note: The generated kernel still uses symbolic shapes to support kernel reuse when input shapes change. See the added test `TmaPersistentTestF.KernelReuse`.