GGEMM+srelu kernels for MxFP8 Nemotron #2981
/te-ci pytorch
Please sign off your commits, @sraman-rgb.
Greptile Summary
This PR refactors the fused GroupedMLP kernel hierarchy into a shared base class and adds
Confidence Score: 5/5
The refactor is well-structured and the SReLU kernel wiring follows the established GLU pattern closely; the two flagged items are clarifying questions rather than confirmed failures.
The class hierarchy generalisation is clean, dscales_tensor is always an allocated tensor, the recompute-FC2-input path is guarded by multiple independent checks, and test coverage spans both unit-level ScaledSReLU and the full grouped-MLP integration.
forward_grouped_mlp.py (prob_tensor dtype) and _common.py (_nvidia_cudnn_frontend_supports_wgrad guard)
Important Files Changed
Sequence Diagram
sequenceDiagram
participant Fuser
participant GLUFwd as ForwardGroupedMLP_CuTeGEMMGLU_MXFP8
participant SReLUFwd as ForwardGroupedMLP_CuTeGEMMUnary_MXFP8
participant SReLUBwd as BackwardGroupedMLP_CuTeGEMMDUnary_MXFP8
participant cuDNN as cuDNN FE Kernels
Fuser->>GLUFwd: fuse_forward_ops GLU pattern
GLUFwd->>cuDNN: grouped_gemm_glu_wrapper_sm100
cuDNN-->>GLUFwd: fc2_in scales and activation_in
Fuser->>SReLUFwd: fuse_forward_srelu_ops SReLU pattern
SReLUFwd->>cuDNN: grouped_gemm_srelu_wrapper_sm100
cuDNN-->>SReLUFwd: fc2_in scales and activation_in
Note over SReLUFwd: Save activation_in and scales
Note over SReLUFwd: optionally skip saving fc2_x
Fuser->>SReLUBwd: fuse_backward_srelu_ops
SReLUBwd->>cuDNN: grouped_gemm_dsrelu_wrapper_sm100
cuDNN-->>SReLUBwd: FC1 dy tensors and grad_scales
cuDNN-->>SReLUBwd: optional recomputed FC2 input
SReLUBwd->>cuDNN: grouped_gemm_wgrad for FC1 and FC2
Reviews (8): Last reviewed commit: "Address grouped MLP ScaledSReLU review c..."
8373402 to 765d2e9
Signed-off-by: sraman-rgb <sraman@nvidia.com>
765d2e9 to 43093cc
timmoon10 left a comment
Overall looks good, but we've gotten to the point where we need to start thinking about how to gracefully handle adding new activations. It seems that every model has a different activation function.
swiglu: Optional[ScaledSwiGLU | ScaledClampedQGeGLU] = None,
srelu: Optional[ScaledSReLU] = None,
Why not have a single arg?
- swiglu: Optional[ScaledSwiGLU | ScaledClampedQGeGLU] = None,
- srelu: Optional[ScaledSReLU] = None,
+ activation: Optional[FusibleOperation] = None,
It seems like we're adding one activation function after another, so we want interfaces that scale gracefully. Also, fused ops are basically internal to TE and these ops in particular are experimental, so backward compatibility is not a major concern.
The forward fused op should have a similar design. Changing to a consistent arg name would also let us get rid of the kwarg name messiness in the op fusion function.
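To illustrate the point about kwarg names, here is a minimal sketch of how a fusion helper could construct either fused op once both accept the same `activation` keyword. All classes below are stand-ins written for this example, and `build_fused_op` is a hypothetical helper, not the actual op fusion function in TE.

```python
class FusibleOperation:
    """Stand-in for the real FusibleOperation base class."""


class ScaledSwiGLU(FusibleOperation):
    """Stand-in GLU-style activation."""


class ScaledSReLU(FusibleOperation):
    """Stand-in unary activation."""


class FusedGLUOp:
    """Hypothetical fused op for GLU-style activations."""

    def __init__(self, *, activation: FusibleOperation) -> None:
        self.activation = activation


class FusedUnaryOp:
    """Hypothetical fused op for unary activations."""

    def __init__(self, *, activation: FusibleOperation) -> None:
        self.activation = activation


def build_fused_op(activation: FusibleOperation):
    """Pick the fused op class, then pass the activation under one kwarg name."""
    fused_cls = FusedUnaryOp if isinstance(activation, ScaledSReLU) else FusedGLUOp
    # No per-activation kwarg name (swiglu=..., srelu=...) is needed here.
    return fused_cls(activation=activation)
```

Adding another activation then only touches the class selection, not the call site.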
return fc2_out, [(), (), ()]


class ForwardGroupedMLP_CuTeGEMMSReLU_MXFP8(ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8):
This is an awkward class hierarchy. It would be better to have a virtual base class that both the GLU and non-GLU functions inherit from. The backward fused ops should have a similar design.
While we're messing with the existing classes, we should reconsider the names. The "SwiGLU" op is actually used for both SwiGLU and ClampedQGeGLU, so a name like "GLU" would be better. And there's no reason to expect "SReLU" won't be applied to other activations later, so maybe "Unary" would be more general.
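A rough sketch of the suggested shape, using Python's `abc` module: one abstract base holds the shared logic, and the GLU and unary variants are siblings instead of one inheriting from the other. The class names and the `kernel_name` hook are invented for this example; the returned strings only echo the wrapper names shown in the sequence diagram and are not real calls.

```python
from abc import ABC, abstractmethod


class ForwardGroupedMLPBase(ABC):
    """Hypothetical shared base with logic common to every activation."""

    def __init__(self, *, activation: object) -> None:
        self.activation = activation

    @abstractmethod
    def kernel_name(self) -> str:
        """Label of the fused kernel this variant would dispatch to."""

    def describe(self) -> str:
        # Shared code path that depends only on the abstract hook.
        return f"{type(self).__name__} -> {self.kernel_name()}"


class ForwardGroupedMLPGLU(ForwardGroupedMLPBase):
    """Variant for GLU-style activations (for example SwiGLU, ClampedQGeGLU)."""

    def kernel_name(self) -> str:
        return "grouped_gemm_glu_wrapper_sm100"


class ForwardGroupedMLPUnary(ForwardGroupedMLPBase):
    """Variant for unary activations (for example SReLU)."""

    def kernel_name(self) -> str:
        return "grouped_gemm_srelu_wrapper_sm100"
```

With this layout neither variant can accidentally pick up behaviour that only makes sense for the other, and the base class cannot be instantiated on its own.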
    pytest.skip("Quantized group GEMM is only supported with BF16/FP16")
if activation == "scaled_srelu" and quantization != "mxfp8":
    pytest.skip("ScaledSReLU grouped MLP fusion is only supported with MXFP8")
if activation == "scaled_srelu" and glu_interleave_size is not None:
Nit: This is assuming that activations are GLUs by default, and SReLU is weird. Isn't that kind of backward? In any case, it would be more logical to have a single point where we check is_glu_activation, and then use that everywhere.
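One way to read this suggestion is to compute `is_glu_activation` once and branch on it everywhere else. The sketch below assumes activations are identified by string names; only `"scaled_srelu"` appears in the test code above, so the GLU names, the `skip_reason` helper, and its messages are placeholders for illustration.

```python
from typing import Optional

# Assumed activation names; only "scaled_srelu" is taken from the test above.
GLU_ACTIVATIONS = frozenset({"scaled_swiglu", "scaled_clamped_qgeglu"})


def is_glu_activation(activation: str) -> bool:
    """Single place that decides whether an activation is GLU-style."""
    return activation in GLU_ACTIVATIONS


def skip_reason(
    activation: str,
    quantization: Optional[str],
    glu_interleave_size: Optional[int],
) -> Optional[str]:
    """Return a reason to skip a test configuration, or None to run it."""
    is_glu = is_glu_activation(activation)
    if not is_glu and quantization != "mxfp8":
        return "non-GLU grouped MLP fusion requires MXFP8"
    if not is_glu and glu_interleave_size is not None:
        return "GLU interleaving does not apply to non-GLU activations"
    return None
```

The checks then describe a property of the activation instead of singling out one activation by name.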
*,
fc1: GroupedLinear,
swiglu: ScaledSwiGLU | ScaledClampedQGeGLU,
activation: Optional[FusibleOperation] = None,
Nit: Python supports kwargs without defaults.
- activation: Optional[FusibleOperation] = None,
+ activation: Optional[FusibleOperation],
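For reference, parameters after a bare `*` are keyword-only, and they become required when no default is given. A tiny sketch with a hypothetical `make_fused_op` function and plain strings standing in for the real op types:

```python
from typing import Optional


def make_fused_op(*, fc1: str, activation: Optional[str], fc2: str) -> tuple:
    """Every argument must be passed by keyword; none has a default."""
    return (fc1, activation, fc2)


# Passing None is still allowed, but it has to be spelled out by the caller.
ops = make_fused_op(fc1="fc1", activation=None, fc2="fc2")
```

Omitting `activation`, or passing it positionally, raises `TypeError` at the call site.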
fc2_ctx.dtype = dtype
fc2_ctx.input_requires_grad = input_requires_grad
fc2_ctx.weight_requires_grad = weight_requires_grad
fc2_ctx.recompute_input_from_dsrelu = recompute_srelu_fc2_x
This option isn't supported in the unfused GroupedLinear op. This is a problem because the forward and backward fusions are performed independently, so everything needs to be compatible with the unfused op interfaces in case there are different forward and backward fusions. However, I also don't want to include this in the unfused op because this is so hyper-specific to this particular fusion.
The requirement that the fused and unfused ops are interchangeable has been causing some trouble with the grouped MLP block. It may be worth relaxing, but we would need to have some guarantee that the forward and backward fusions match exactly. I propose we change the op fuser to operate in three stages: fuse forward-backward ops together, fuse forward ops, fuse backward ops. For fused ops with matching forward and backward, we can tolerate tighter forward-backward integration.
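As a toy model of the three-stage proposal, ops can be treated as plain strings and fusion as window replacement. Everything here (`replace_window`, `fuse_three_stage`, the op names) is hypothetical and only meant to show the ordering: joint forward-backward fusions are applied first and are therefore identical in both passes, after which the forward and backward lists are fused independently.

```python
from typing import List, Sequence, Tuple

Pattern = Tuple[List[str], str]  # (window of op names, fused op name)


def replace_window(ops: List[str], pattern: List[str], fused: str) -> List[str]:
    """Replace each consecutive occurrence of `pattern` with `fused`."""
    out: List[str] = []
    i, n = 0, len(pattern)
    while i < len(ops):
        if ops[i : i + n] == pattern:
            out.append(fused)
            i += n
        else:
            out.append(ops[i])
            i += 1
    return out


def fuse_three_stage(
    ops: List[str],
    joint: Sequence[Pattern],
    forward: Sequence[Pattern],
    backward: Sequence[Pattern],
) -> Tuple[List[str], List[str]]:
    # Stage 1: joint fusions, shared by the forward and backward passes.
    for pattern, fused in joint:
        ops = replace_window(ops, pattern, fused)
    forward_ops, backward_ops = list(ops), list(ops)
    # Stage 2: forward-only fusions.
    for pattern, fused in forward:
        forward_ops = replace_window(forward_ops, pattern, fused)
    # Stage 3: backward-only fusions.
    for pattern, fused in backward:
        backward_ops = replace_window(backward_ops, pattern, fused)
    return forward_ops, backward_ops
```

Because a jointly fused op is already a single entry before stages 2 and 3 run, its forward and backward halves cannot be paired with a different fusion, which is the guarantee the tighter integration relies on.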
Signed-off-by: Siddhartha Raman S <sraman@login-lyris01.lyris.clusters.nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
912b1d9 to 46b3169
vthumbe1503 left a comment
LGTM. We might want to wait for the cuDNN release and for appropriate cuDNN guards to be added.
else:
    try:
        validate_grouped_mlp_dims(window[0], window[1], window[2])
    except (TypeError, ValueError):
        matches_pattern = False
We would eventually want to disable SReLU fusion here based on the cuDNN version, before the merge.
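A version guard of this kind could look roughly like the sketch below. The helper names are made up, the required version is deliberately left as a parameter because the thread does not state which cuDNN release is needed, and how the installed version string is obtained is out of scope here.

```python
import re
from typing import Tuple


def parse_version(version: str) -> Tuple[int, int, int]:
    """Parse 'major.minor.patch' into integers, ignoring suffixes like 'rc1'."""
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:
        parts.append(0)  # Treat missing components as zero.
    return (parts[0], parts[1], parts[2])


def srelu_fusion_enabled(installed: str, required: str) -> bool:
    """Hypothetical guard: allow the fusion only on new enough versions."""
    return parse_version(installed) >= parse_version(required)
```

Comparing integer tuples avoids the usual string-comparison pitfall where `"1.10.0"` sorts before `"1.9.5"`.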
scales.detach().to(dtype=dtype).reshape(-1, 1, 1)
if scales is not None
else torch.ones((in_shape[0], 1, 1), dtype=torch.float32, device=device)
This might be a holdover from before, right? We expect the scales passed in to never be None, so can we revert this change?
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: