
Conversation

@pggPL pggPL commented Dec 30, 2025

Description

MCore's fused wgrad accumulation feature requires setting the grad_added_to_main_grad attribute on the weight's Python object. This means the original Python object must be accessible and modifiable during the backward pass.

Currently, weights are saved via save_for_backward, with the assumption that no hooks substitute them with different tensors (e.g., during CPU offload/reload). For CPU offloading, we work around this by saving weights directly on ctx. However, this approach is incompatible with non-TE CPU offloading scenarios and potentially conflicts with FSDP, which also manages weight tensors.

This PR addresses these issues by saving weak references to weights for the backward pass instead. When modifications to the original Python object are needed (e.g., setting grad_added_to_main_grad), the weakref is dereferenced and the modification is applied. This is done conditionally, only when MCore FSDP or MCore fused wgrad accumulation is enabled.

Changes:

  • Replace direct weight references with weakref in forward pass
  • Dereference weakrefs in backward pass only when fuse_wgrad_accumulation is enabled
  • Remove CPU offloading workarounds that saved weights directly on ctx
  • Apply consistent pattern across linear.py, layernorm_linear.py, grouped_linear.py, and layernorm_mlp.py
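For illustration, a minimal sketch of the forward-side bookkeeping described above; the helper name and the exact ctx attribute names are placeholders, not the literal Transformer Engine code:

    import weakref

    def save_weight_refs_for_backward(ctx, weight, fuse_wgrad_accumulation):
        """Sketch: keep a weak reference to the original Parameter on ctx.

        Only the weakref and the plain flag value are stored; the Parameter
        itself no longer goes through save_for_backward, so hooks that swap
        the saved tensors (e.g. CPU offload/reload) cannot break the link to
        the original Python object.
        """
        if fuse_wgrad_accumulation and weight.requires_grad:
            ctx.origin_weight_ref = weakref.ref(weight)
            ctx.weight_overwrites_main_grad = getattr(weight, "overwrite_main_grad", False)
        else:
            ctx.origin_weight_ref = None
            ctx.weight_overwrites_main_grad = False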

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL and others added 5 commits December 30, 2025 16:14
@pggPL pggPL marked this pull request as ready for review January 13, 2026 12:50
greptile-apps bot commented Jan 13, 2026

Greptile Overview

Greptile Summary

This PR successfully refactors weight tensor handling across four autograd Function classes (_Linear, _LayerNormLinear, _GroupedLinear, _LayerNormMLP) to use weak references instead of direct tensor saving via save_for_backward.

Motivation:
MCore's fused wgrad accumulation feature requires modifying attributes (like grad_added_to_main_grad) on the original Python weight objects during the backward pass. Previously, weights were saved via save_for_backward, where hooks (e.g., during CPU offload) could substitute them with different tensors, breaking the connection to the original Python objects. This PR solves that by saving weak references to the weights instead.

Implementation Pattern:

Forward Pass:

  • When fuse_wgrad_accumulation=True and weight requires grad, save weakref.ref(weight) on context
  • Also save the overwrite_main_grad flag value from the weight object on context
  • Remove weights from save_for_backward() call (only save weight tensor data via weightmat)

Backward Pass:

  • When fuse_wgrad_accumulation=True and weight requires grad, dereference the weakref to recover the original Python object
  • Assert that the weakref returns non-None (guards against weight being garbage collected)
  • Set main_grad attribute and other flags on the recovered Python object
  • Use the saved overwrite_main_grad flag value (not the tensor data) for logic decisions
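The backward side of the same pattern, again as a hedged sketch with illustrative names rather than the literal module code:

    def recover_origin_weight(ctx):
        """Sketch: dereference the weakref saved in forward.

        Returns the original Parameter so backward can set attributes such as
        grad_added_to_main_grad on it, or None if the pattern is not active.
        """
        if ctx.origin_weight_ref is None:
            return None
        origin_weight = ctx.origin_weight_ref()
        # Guard against the Parameter having been deallocated between
        # forward and backward.
        assert origin_weight is not None, "weight was garbage collected before backward"
        return origin_weight

    # In backward, when fuse_wgrad_accumulation is enabled, the wgrad GEMM
    # accumulates into origin_weight.main_grad and the flag is then set:
    #     if hasattr(origin_weight, "grad_added_to_main_grad"):
    #         origin_weight.grad_added_to_main_grad = True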

Key Changes:

  • Removed CPU offloading workarounds that manually saved weight objects on ctx
  • Applied consistent weakref pattern across all four module files
  • Added test validation to ensure grad_added_to_main_grad is properly set after backward

Correctness:
The implementation correctly distinguishes between:

  • weight/weightmat: Tensor data (potentially QuantizedTensorStorage) used for computation
  • origin_weight/origin_weight_python_object: Original Python parameter object accessed via weakref for attribute modification

All attribute access uses the correct object type.

Confidence Score: 5/5

  • Safe to merge - refactoring is correctly implemented with proper weakref handling and comprehensive test coverage
  • The PR implements a clean refactoring with a consistent pattern across all affected files. The weakref approach is correctly implemented: weakrefs are created in forward when needed, properly dereferenced in backward with assertions, and all attribute access uses the correct object (Python object vs tensor data). The code removes unnecessary CPU offloading workarounds and adds test validation. No bugs or issues were found.
  • No files require special attention

Important Files Changed

File Analysis

  • transformer_engine/pytorch/module/linear.py (5/5): Refactored weight handling to use the weakref pattern; removes the CPU offloading workarounds. Correctly saves the weakref in forward, dereferences it in backward, and uses origin_weight_python_object for all attribute access.
  • transformer_engine/pytorch/module/layernorm_linear.py (5/5): Refactored weight handling to use the weakref pattern. Correctly implements the same pattern as linear.py, with proper attribute access via origin_weight.
  • transformer_engine/pytorch/module/grouped_linear.py (5/5): Refactored to use weakrefs for multiple weights. Correctly saves a list of weakrefs in forward and dereferences them in backward. Properly uses origin_weights for attribute access.
  • transformer_engine/pytorch/module/layernorm_mlp.py (5/5): Refactored to use weakrefs for fc1_weight and fc2_weight. Correctly implements the weakref pattern for both weights, with proper attribute access via python_object variables.
  • tests/pytorch/test_sanity.py (5/5): Added test validation ensuring the grad_added_to_main_grad flag is set on weight parameters after the backward pass, verifying that the weakref refactoring works correctly.
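For illustration, the kind of check this test adds could look like the sketch below; the module construction details and helper name are assumptions, not the actual test_sanity.py code:

    import torch
    import transformer_engine.pytorch as te

    def check_grad_added_to_main_grad(hidden_size=128, batch=8):
        """Sketch: after backward with fuse_wgrad_accumulation, the flag on the
        original weight Parameter should have been set via the weakref path."""
        layer = te.Linear(hidden_size, hidden_size, fuse_wgrad_accumulation=True)
        # MCore-style setup: the accumulation buffer lives on the Parameter.
        layer.weight.main_grad = torch.zeros_like(layer.weight, dtype=torch.float32)
        layer.weight.grad_added_to_main_grad = False

        inp = torch.randn(batch, hidden_size, device=layer.weight.device)
        layer(inp).sum().backward()

        assert layer.weight.grad_added_to_main_grad is True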

Sequence Diagram

sequenceDiagram
    participant FWD as Forward Pass
    participant CTX as AutogradContext
    participant WEAKREF as WeakReference
    participant WEIGHT as Weight Parameter
    participant BWD as Backward Pass
    participant MCORE as MCore DDP

    FWD->>WEIGHT: Check fuse_wgrad_accumulation & requires_grad
    FWD->>WEAKREF: Create weakref.ref(weight)
    FWD->>CTX: Save origin_weight_ref
    FWD->>CTX: Save origin_weight_overwrites_main_grad flag
    FWD->>CTX: save_for_backward(weightmat) [tensor data only]
    
    Note over FWD,BWD: Forward/Backward boundary
    
    BWD->>CTX: Retrieve origin_weight_ref
    BWD->>WEAKREF: Dereference weakref()
    WEAKREF-->>BWD: Return original Python object
    BWD->>BWD: Assert origin_weight is not None
    BWD->>WEIGHT: Set origin_weight.main_grad
    BWD->>MCORE: Check hasattr(origin_weight, "grad_added_to_main_grad")
    MCORE-->>BWD: Attribute exists
    BWD->>WEIGHT: Set origin_weight.grad_added_to_main_grad = True
    BWD->>BWD: Return dummy wgrad (gradient handled via main_grad)
@greptile-apps greptile-apps bot left a comment

1 file reviewed, 1 comment

pggPL and others added 2 commits January 13, 2026 14:12
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@greptile-apps greptile-apps bot left a comment

3 files reviewed, 3 comments

@greptile-apps greptile-apps bot left a comment

7 files reviewed, 7 comments

Comment on lines 1226 to 1230
"out_dtype": (
origin_fc2_weight.main_grad.dtype
fc2_weight_main_grad.dtype
if ctx.fuse_wgrad_accumulation
else ctx.activation_dtype
),

When ctx.fuse_wgrad_accumulation is True but the weakref returns None (line 996-999), fc2_weight_main_grad remains None (check fails at line 1004). Accessing .dtype on None will cause an AttributeError. Need to add a safety check or ensure fc2_weight_main_grad has a fallback value.

greptile-apps bot commented Jan 13, 2026

Additional Comments (2)

transformer_engine/pytorch/module/layernorm_linear.py
Using weight (the tensor restored from save_for_backward) instead of origin_weight (the weakref-restored original Python object) defeats the purpose of this PR. The overwrite_main_grad attribute should be checked on the original Python object to ensure correct behavior with MCore's fused wgrad accumulation.

Should use:

                    "accumulate": (
                        accumulate_wgrad_into_param_main_grad
                        if not (
                            origin_weight is not None
                            and getattr(origin_weight, "overwrite_main_grad", False)
                        )
                        else False
                    ),

transformer_engine/pytorch/module/grouped_linear.py
Using weights[0] (tensor restored from save_for_backward) instead of origin_weights[0] (the weakref-restored original Python object) defeats the purpose of this PR. The overwrite_main_grad attribute should be checked on the original Python object to ensure correct behavior with MCore's fused wgrad accumulation.

Should use:

                    accumulate=(
                        accumulate_wgrad_into_param_main_grad
                        if not (
                            origin_weights[0] is not None
                            and getattr(origin_weights[0], "overwrite_main_grad", False)
                        )
                        else False
                    ),

pggPL and others added 2 commits January 13, 2026 15:15
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@greptile-apps greptile-apps bot left a comment

2 files reviewed, 2 comments

Comment on lines 1238 to +1251
                 fc2_wgrad_gemm_kwargs = {
                     "out_dtype": (
-                        origin_fc2_weight.main_grad.dtype
+                        fc2_weight_main_grad.dtype
                         if ctx.fuse_wgrad_accumulation
                         else ctx.activation_dtype
                     ),
                     "quantization_params": ctx.fc2_grad_weight_quantizer,  # wgrad in high precision
                     "accumulate": (
                         accumulate_wgrad_into_param_main_grad
-                        if not getattr(fc1_weight, "overwrite_main_grad", False)
+                        if not getattr(ctx, "fc2_weight_overwrites_main_grad", False)
                         else False
                     ),
                     "layout": "NT",
-                    "out": origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
+                    "out": fc2_weight_main_grad if ctx.fuse_wgrad_accumulation else None,

[P0] fc2_weight_main_grad is None when fuse_wgrad_accumulation=False, causing AttributeError on .dtype access. This block executes when ctx.fc2_weight_requires_grad=True (line 1178), but fc2_weight_main_grad is only initialized when both ctx.fuse_wgrad_accumulation=True AND ctx.fc2_weight_requires_grad=True (lines 993-1018). Wrap this kwargs dict in if ctx.fuse_wgrad_accumulation: or use ternary to handle None case.

pggPL (Collaborator, Author) replied:

But .dtype is only evaluated when ctx.fuse_wgrad_accumulation is True, so the None case is never reached.

@greptile-apps greptile-apps bot left a comment

No files reviewed, no comments

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
pggPL commented Jan 13, 2026

/te-ci pytorch L1

@greptile-apps greptile-apps bot left a comment

No files reviewed, no comments
