
Count leading parallel dimensions #5817

Merged
wujingyue merged 7 commits into main from wjy/parallel
Jan 14, 2026
Conversation

@wujingyue
Collaborator

@wujingyue wujingyue commented Jan 13, 2026

A follow-up to #5806

@wujingyue
Collaborator Author

!test

@github-actions

github-actions Bot commented Jan 13, 2026

Review updated until commit b031eb7

Description

  • Move numDeviceDims function from multidevice/utils to scheduler/utils

  • Rename to countLeadingParallelDimensions with improved semantics

  • Update all schedulers (matmul, pointwise, reduction) to use new function

  • Rename member variable from num_device_dims_ to num_parallel_dims_

Changes walkthrough

Relevant files

Refactoring (2 files)

  • utils.cpp: Remove numDeviceDims function, move to scheduler_utils (+0/-7)
  • utils.h: Remove numDeviceDims declaration (+0/-3)

Enhancement (8 files)

  • matmul.cpp: Replace numDeviceDims with countLeadingParallelDimensions (+3/-2)
  • matmul_ampere-.cpp: Update variable name from num_device_dims_ to num_parallel_dims_ (+1/-1)
  • matmul_hopper+.cpp: Update all references to use num_parallel_dims_ (+19/-19)
  • pointwise_utils.cpp: Replace numDeviceDims with countLeadingParallelDimensions (+7/-6)
  • reduction_utils.cpp: Replace numDeviceDims with countLeadingParallelDimensions (+8/-6)
  • utils.cpp: Add countLeadingParallelDimensions function to scheduler_utils (+23/-4)
  • matmul.h: Rename member variable to num_parallel_dims_ (+1/-1)
  • utils.h: Add countLeadingParallelDimensions declaration (+8/-0)

Tests (1 file)

  • test_sharding.cpp: Update test to use new function and expectations (+6/-8)

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review
Function Implementation

The new countLeadingParallelDimensions function correctly implements the logic to count leading parallel non-reduction dimensions with proper validation. The implementation looks sound and includes error checking for unexpected parallel dimensions.

Variable Renaming

The PR consistently replaces num_device_dims_ with num_parallel_dims_ and updates the calculation logic. The change from numDeviceDims(mma_result) to scheduler_utils::countLeadingParallelDimensions(mma_result) appears correct.

Test Updates

Tests have been updated to use the new API. The test changes from checking numDeviceDims(tv2) == 1 to using EXPECT_THAT with Contains(IsParallelized(ParallelType::DIDx)).Times(1), which checks that exactly one element is parallelized on DIDx and is a more appropriate test for the new functionality.
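
For readers unfamiliar with the matcher form, the following is a rough standard-library approximation of what Contains(...).Times(1) asserts: exactly one element of the container satisfies the inner predicate. The types ParallelKind and LoopAxis and the helper containsParallelizedTimes are hypothetical stand-ins invented for this sketch. The actual test uses gmock and nvFuser's own IsParallelized matcher, neither of which is reproduced here.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins; not nvFuser's ParallelType or IterDomain.
enum class ParallelKind { Serial, DIDx, DIDy, DIDz };

struct LoopAxis {
  ParallelKind kind;
};

// Approximates Contains(IsParallelized(kind)).Times(n): true when exactly
// n axes in the loop domain carry the requested parallel kind.
bool containsParallelizedTimes(
    const std::vector<LoopAxis>& loop,
    ParallelKind kind,
    std::int64_t n) {
  const auto matches = std::count_if(
      loop.begin(), loop.end(), [kind](const LoopAxis& axis) {
        return axis.kind == kind;
      });
  return static_cast<std::int64_t>(matches) == n;
}
```

Unlike a bare count comparison, the matcher form names the parallel type being checked, so a failure message can point at the offending axes rather than at a mismatched integer.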

@greptile-apps
Contributor

greptile-apps Bot commented Jan 13, 2026

Greptile Summary

This PR refactors dimension counting logic across schedulers by replacing the numDeviceDims function with a new countLeadingParallelDimensions function.

Key changes:

  • Replaces numDeviceDims() (which counted all device dimensions anywhere in the loop domain) with countLeadingParallelDimensions() (which counts leading parallel non-reduction dimensions and enforces that they appear only at the beginning)
  • Renames num_device_dims_ to num_parallel_dims_ throughout matmul schedulers to reflect broader scope
  • Migrates function from multidevice/utils to scheduler_utils for better organization
  • Updates all scheduler call sites (matmul, pointwise, reduction) to use the new function
  • Simplifies test using gmock matchers instead of manual numDeviceDims checks

Behavioral change:
The new function counts any parallel type (DIDx/DIDy/DIDz, Stream, BIDx/y/z, TIDx/y/z) that appears at the beginning, not just device dimensions. It also enforces that parallel dimensions must be contiguous and leading, which is a stricter requirement than the old implementation.
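
The difference between the two behaviors can be sketched with simplified stand-in types. This is not the merged implementation: Axis, the local ParallelType enum, and both *Sketch functions are hypothetical names made up for illustration, and the handling of parallelized reduction dimensions after the leading run is an assumption taken from the summary above rather than from the diff.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins for nvFuser's ParallelType and IterDomain.
enum class ParallelType {
  Serial, DIDx, DIDy, DIDz, Stream, BIDx, BIDy, BIDz, TIDx, TIDy, TIDz
};

struct Axis {
  ParallelType type;
  bool is_reduction;
};

bool isDeviceDim(const Axis& a) {
  return a.type == ParallelType::DIDx || a.type == ParallelType::DIDy ||
      a.type == ParallelType::DIDz;
}

bool isParallel(const Axis& a) {
  return a.type != ParallelType::Serial;
}

// Old behavior as summarized: count device dimensions anywhere in the loop
// domain, with no constraint on their position.
std::int64_t numDeviceDimsSketch(const std::vector<Axis>& loop) {
  return static_cast<std::int64_t>(
      std::count_if(loop.begin(), loop.end(), isDeviceDim));
}

// New behavior as summarized: count the leading run of parallel,
// non-reduction axes, then reject any parallel axis that appears later.
// Assumption: a later parallel axis is an error even if it is a reduction.
std::int64_t countLeadingParallelDimensionsSketch(
    const std::vector<Axis>& loop) {
  std::size_t i = 0;
  while (i < loop.size() && isParallel(loop[i]) && !loop[i].is_reduction) {
    ++i;
  }
  const std::int64_t count = static_cast<std::int64_t>(i);
  for (; i < loop.size(); ++i) {
    if (isParallel(loop[i])) {
      throw std::logic_error("parallel axis found after the leading run");
    }
  }
  return count;
}
```

Under these assumptions, a loop domain of [DIDx, Stream, Serial] yields 2 from the new sketch but 1 from the old one, while [Serial, DIDx] yields 1 from the old sketch and throws in the new one.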

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The refactoring is well-structured with comprehensive updates across all affected schedulers. The new function includes proper validation logic (it asserts that parallel dimensions are only leading), has clear documentation, and the test simplification demonstrates the improvement. The semantic broadening from device-only to all parallel types aligns with the goal of supporting 2D mesh reductions, as mentioned in the related PR #5806 (Support reduction with 2D mesh)
  • No files require special attention

Important Files Changed

Filename: Overview

  • csrc/scheduler/utils.cpp: Added new countLeadingParallelDimensions function that counts leading parallel non-reduction dimensions and asserts proper ordering
  • csrc/scheduler/utils.h: Added declaration for countLeadingParallelDimensions with comprehensive documentation
  • csrc/scheduler/matmul.cpp: Updated countDims() to use new countLeadingParallelDimensions instead of numDeviceDims
  • csrc/scheduler/matmul_hopper+.cpp: Replaced all references to num_device_dims_ with num_parallel_dims_ throughout the scheduler
  • csrc/scheduler/pointwise_utils.cpp: Replaced numDeviceDims with countLeadingParallelDimensions in pointwise scheduling
  • csrc/scheduler/reduction_utils.cpp: Updated reduction scheduler to use countLeadingParallelDimensions instead of numDeviceDims

Sequence Diagram

sequenceDiagram
    participant Scheduler
    participant Utils
    participant TV
    
    Scheduler->>Utils: Call countLeadingParallelDimensions
    Utils->>TV: Get nDims
    TV-->>Utils: Return dimension count
    
    loop Count leading parallel dimensions
        Utils->>TV: Get axis at index
        TV-->>Utils: Return IterDomain
        alt IterDomain is parallel and not reduction
            Utils->>Utils: Increment counter
        else IterDomain is not parallel
            Utils->>Utils: Stop counting
        end
    end
    
    loop Validate remaining dimensions
        Utils->>TV: Get axis at index
        TV-->>Utils: Return IterDomain
        alt IterDomain is parallel
            Utils->>Utils: Throw error
        end
    end
    
    Utils-->>Scheduler: Return parallel dimension count


@wujingyue wujingyue changed the base branch from main to wjy/reduction January 13, 2026 23:26
Comment thread csrc/scheduler/utils.h Outdated
Collaborator

@Priya2698 Priya2698 left a comment

Thanks for making the change!

Comment thread csrc/scheduler/matmul.cpp Outdated
Base automatically changed from wjy/reduction to main January 14, 2026 21:39
@wujingyue
Collaborator Author

!test

@wujingyue wujingyue merged commit 99f41a7 into main Jan 14, 2026
36 of 37 checks passed
@wujingyue wujingyue deleted the wjy/parallel branch January 14, 2026 23:37
2 participants