Introduce a fused padding + cast transpose kernel grouped linear #632

Open: alextmagro wants to merge 2 commits into dev from fused_pad_MCT

@alextmagro alextmagro commented Jun 17, 2026

Contributor

Fuses the following two kernels:

  • Pad tensors in BF16
  • Cast from BF16 to FP8, transpose, and store

into a single kernel that casts from BF16 to FP8, transposes, and stores with padding.

This gives a 2x speedup over the unfused kernels and is applicable to grouped linear.
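
The equivalence the fusion relies on can be sketched in plain Python. This is a reference model only, not the kernel from this PR: `fake_fp8_cast` is a hypothetical stand-in for the BF16 to FP8 cast (it just scales and clamps), `FP8_MAX = 448.0` assumes the OCP E4M3 range, and the function names and list-of-lists layout are made up for illustration.

```python
FP8_MAX = 448.0  # assumption: OCP E4M3 max finite value; other FP8 variants differ


def fake_fp8_cast(x, scale):
    """Hypothetical stand-in for the BF16 -> FP8 cast: scale, then clamp."""
    return max(-FP8_MAX, min(FP8_MAX, x * scale))


def pad_rows(mat, padded_rows):
    """Unfused step 1: append zero rows until the matrix has padded_rows rows."""
    cols = len(mat[0])
    return mat + [[0.0] * cols for _ in range(padded_rows - len(mat))]


def cast_transpose(mat, scale):
    """Unfused step 2: cast every element and return the transposed result."""
    rows, cols = len(mat), len(mat[0])
    return [[fake_fp8_cast(mat[r][c], scale) for r in range(rows)]
            for c in range(cols)]


def fused_pad_cast_transpose(mat, padded_rows, scale):
    """Fused version: a single pass that reads only the valid rows and leaves
    zeros in the padded tail of each transposed row."""
    valid_rows, cols = len(mat), len(mat[0])
    out = [[0.0] * padded_rows for _ in range(cols)]
    for r in range(valid_rows):
        for c in range(cols):
            out[c][r] = fake_fp8_cast(mat[r][c], scale)
    return out


x = [[1.0, -2.0, 3.0],
     [4.0, 500.0, -6.0]]
unfused = cast_transpose(pad_rows(x, 4), scale=1.0)
fused = fused_pad_cast_transpose(x, 4, scale=1.0)
```

Both paths produce the same 3 x 4 transposed output here. The fused path never materializes the padded BF16 intermediate, which is the kind of saving the description attributes to the fused kernel.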

Comment thread tests/cpp/operator/test_multi_cast_transpose.cu Outdated
Comment thread tests/cpp/operator/test_multi_cast_transpose.cu
Comment thread transformer_engine/common/transpose/multi_cast_transpose.cu
Comment thread transformer_engine/pytorch/module/grouped_linear.py Outdated
Comment thread transformer_engine/pytorch/module/grouped_linear.py
Comment thread transformer_engine/pytorch/module/grouped_linear.py
Comment thread transformer_engine/pytorch/module/grouped_linear.py
@github-actions


Claude code review summary

Reviewed PR #632 (fused padding + cast-transpose for grouped linear) against dev (HEAD^1). Scope: 11 files, ~640 LoC added across the C API, ROCm kernel, PyTorch extension, Python GroupedLinear, C++ test, Python test, and benchmark.

Overall: Approach is sound — a single padded MCT kernel replacing pad-then-MCT, exposed end-to-end through split_quantize(valid_split_sections=...) and consumed by GroupedLinear(actual_m_splits=, unpad_output=). The ROCm kernel padding semantics (input zero-fill, amax/store gated by valid_num_rows, columnwise store gated by num_rows) are internally consistent. Upstream isolation is good — almost all new logic sits behind #ifdef __HIP_PLATFORM_AMD__ or IS_HIP_EXTENSION.
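
One possible reading of the padding semantics listed above, written as a plain-Python reference model. It is an interpretation of the review text, not the ROCm kernel: the function name, the return layout, and the use of a simple multiply in place of the FP8 cast are all assumptions. It takes "store gated by valid_num_rows" to mean the rowwise output.

```python
def padded_cast_transpose_reference(inp, valid_num_rows, num_rows, scale):
    """Reference model of the gating (an interpretation, not the kernel).

    inp holds only the valid rows. Returns (rowwise, columnwise, amax).
    """
    num_cols = len(inp[0])
    rowwise = []  # cast output, valid rows only
    columnwise = [[None] * num_rows for _ in range(num_cols)]  # transposed, padded
    amax = 0.0
    for r in range(num_rows):
        valid = r < valid_num_rows
        # Input zero-fill: rows past valid_num_rows are treated as zeros.
        row = inp[r] if valid else [0.0] * num_cols
        cast_row = [v * scale for v in row]  # placeholder for the FP8 cast
        if valid:
            # amax and the rowwise store are gated by valid_num_rows.
            amax = max(amax, max(abs(v) for v in row))
            rowwise.append(cast_row)
        # The columnwise store is gated by num_rows, so padding is written too.
        for c in range(num_cols):
            columnwise[c][r] = cast_row[c]
    return rowwise, columnwise, amax


inp = [[1.0, -8.0],
       [2.0, 3.0]]
rowwise, columnwise, amax = padded_cast_transpose_reference(
    inp, valid_num_rows=2, num_rows=4, scale=0.5)
```

Under this reading, padded rows show up as zeros in the transposed output but never in the rowwise output. Because they are excluded from the amax reduction, padding cannot change the scaling statistics.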

Findings (see inline comments):

  • Blocker — likely compile failure in the new performTestWithPadding C++ test: references valid_h that was never declared (the loop binds in_h).
  • One unguarded change to shared CUDA-reachable code in multi_cast_transpose.cu (NVTE_CHECK operand swap) — semantically equivalent on CUDA but creates IFU noise. Suggest guarding or commenting.
  • GroupedLinear.forward signature gains actual_m_splits / unpad_output without docstring updates; the dim-0 view at return assumes the unpadded case.
  • A few small Python nits (stray \ in a comment, mutation of grad_output_quantizers without explanation).
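
The thread does not document `actual_m_splits` or `unpad_output`. The sketch below shows one plausible meaning so the dim-0 concern is concrete. It assumes that `actual_m_splits` holds the valid row count per group and that unpadding drops the padded tail of each group. The helper names and the alignment of 16 are hypothetical and do not come from the PR.

```python
def pad_splits(actual_m_splits, align=16):
    """Round each group's row count up to a multiple of align (hypothetical)."""
    return [-(-m // align) * align for m in actual_m_splits]


def unpad_rows(rows, padded_m_splits, actual_m_splits):
    """Keep the first actual_m rows of each padded group and drop the rest."""
    out, start = [], 0
    for padded_m, actual_m in zip(padded_m_splits, actual_m_splits):
        out.extend(rows[start:start + actual_m])
        start += padded_m
    return out


actual = [3, 16, 20]
padded = pad_splits(actual)        # per-group sizes after padding
rows = list(range(sum(padded)))    # stand-in for output rows along dim 0
kept = unpad_rows(rows, padded, actual)
```

In this example the padded output has 64 rows along dim 0 and the unpadded one has 39, so any reshape at return has to know which of the two it is looking at.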

Copyright headers: OK — all 11 modified files have correctly updated AMD headers with the 2026 end-year and preserved NVIDIA lines.

Not duplicated: @ipanfilo's nit on test_multi_cast_transpose.cu:108 about marking the make_nvte_vector lambda removal with ROCm: for IFU clarity still stands.

@alextmagro alextmagro left a comment

Contributor Author


Addressed Copilot comments.

Comment thread tests/cpp/operator/test_multi_cast_transpose.cu Outdated
Comment thread transformer_engine/pytorch/module/grouped_linear.py
@alextmagro alextmagro added the ci-level 3 (CI test level 3) label and removed the ci-level 2 (CI test level 2) label Jun 17, 2026

@ipanfilo ipanfilo left a comment

Collaborator


Pending CI results


Labels

ci-level 3 CI test level 3


2 participants