Skip to content

[WIP] Enable MultiCastTranspose for expert weights#628

Draft
sudhu2k wants to merge 3 commits into
devfrom
sudhu/grouped_linear_mct
Draft

[WIP] Enable MultiCastTranspose for expert weights#628
sudhu2k wants to merge 3 commits into
devfrom
sudhu/grouped_linear_mct

Conversation

@sudhu2k

@sudhu2k sudhu2k commented Jun 16, 2026

Copy link
Copy Markdown
Contributor

Description

This PR optimizes FP8 weight quantization for GroupedLinear by replacing the per-weight quantize loop with a single fused multi_cast_transpose kernel for delayed-scaling FP8. Previously each weight in a group was cast and transposed with its own quantize call; now the whole group is processed in one fused call, reducing kernel-launch overhead and unnecessary tensor reallocations during FP8 operations.

To make this work efficiently with the existing weight-caching mechanism, multi_tensor_quantize now accepts an optional outputs argument so the fused kernel can write directly into cached workspace buffers in place rather than allocating fresh tensors on every update.

Fixes #16929

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Added TransformerEngineBaseModule.get_multi_weight_workspace, a group-aware analog of get_weight_workspace that casts and transposes a whole group of weights with a single fused multi_tensor_quantize kernel for delayed-scaling FP8, with caching support. Falls back to the per-tensor get_weight_workspace path when fusion is not applicable (other recipes, rowwise-only usage, already-quantized weights, or CUDA-graph weight caching).
  • Updated _GroupedLinear to use get_multi_weight_workspace instead of looping over get_weight_workspace per GEMM.
  • Extended the multi_tensor_quantize C++ extension (extensions.h, cast.cpp, pybind.cpp) with an optional outputs parameter, enabling in-place quantization into pre-allocated/cached workspace buffers and avoiding redundant tensor allocations.
  • Refactored the workspace allocation/caching logic to handle cache hits and misses through a unified fused path: a cache miss allocates fresh buffers, while a cache hit with update_workspace reuses cached buffers via the new outputs argument. Quantizers are temporarily forced to internal=False so cached workspaces survive prepare_for_saving.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

sudhu2k added 3 commits June 10, 2026 23:01
This new method allows for efficient workspace management of multiple weights, enabling fused operations for delayed-scaling FP8. It enhances performance by reducing the number of quantization calls and supports caching of workspaces. The grouped_linear module has been updated to utilize this method.
…ed workspace management

This update enhances the quantization process by optimizing workspace allocation and handling cache misses more effectively. It introduces a streamlined approach for both cache hits and misses, ensuring efficient in-place quantization and reducing unnecessary memory reallocations. The changes aim to improve performance during FP8 operations while maintaining compatibility with existing functionality.
This update modifies the multi_tensor_quantize function to accept an optional outputs parameter, allowing for in-place quantization when cached workspaces are provided. The changes improve memory efficiency and performance during FP8 operations by reducing unnecessary tensor reallocations.
@sudhu2k sudhu2k changed the title Enable MultiCastTranspose for expert weights [WIP] Enable MultiCastTranspose for expert weights Jun 16, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant