7 changes: 7 additions & 0 deletions tests/pytorch/test_sanity.py
@@ -194,6 +194,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci
continue
elif "weight" in name and p.requires_grad:
p.main_grad = torch.zeros_like(p)
p.grad_added_to_main_grad = False # Should be set to True after backward

use_fp8 = fp8_recipe is not None
with autocast(enabled=use_fp8, recipe=fp8_recipe):
@@ -203,13 +204,19 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci
torch.cuda.synchronize()

failed_grads = []
failed_grad_added_flags = []
for name, p in block.named_parameters():
if "layer_norm_weight" in name:
continue
elif "weight" in name and p.requires_grad:
if not torch.count_nonzero(p.main_grad) > 0:
failed_grads.append(name)
if not getattr(p, "grad_added_to_main_grad", False):
failed_grad_added_flags.append(name)
assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}."
assert (
len(failed_grad_added_flags) == 0
), f"grad_added_to_main_grad not set to True for {failed_grad_added_flags}."


def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad):
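
For context on what the new test asserts: `grad_added_to_main_grad` is the handshake between Transformer Engine and a Megatron-Core-style DDP/grad hook. When the fused wgrad GEMM has already accumulated into `param.main_grad`, the module sets the flag so the hook can skip a second accumulation. A minimal sketch of such a consumer (the hook below is hypothetical and only illustrates the contract; it is not Megatron-Core's actual API):

import torch

def post_backward_param_hook(param: torch.nn.Parameter) -> None:
    # Hypothetical consumer of the flag set by TE modules.
    # If the fused wgrad GEMM already accumulated into param.main_grad,
    # param.grad only holds a dummy buffer that must not be added again.
    if getattr(param, "grad_added_to_main_grad", False):
        param.grad = None  # drop the dummy wgrad
    elif param.grad is not None:
        param.main_grad.add_(param.grad)  # fall back to manual accumulation
        param.grad = None
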
70 changes: 33 additions & 37 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -6,6 +6,7 @@
from typing import Union, Optional, Callable, Tuple, List
from itertools import chain
import warnings
import weakref

import functools
import torch
@@ -239,23 +240,9 @@ def forward(
else:
inputmats = [None] * num_gemms

if cpu_offloading:
ctx.grad_added_to_main_grad = hasattr(weights[0], "grad_added_to_main_grad")

if ctx.grad_added_to_main_grad:
# If you are passing torch.nn.Parameter through the Torch hooks, you will
# get back torch.Tensor. Torch rips off the Parameter wrapper.
# You need to preserve the weight object to keep all the attributes the
# user sets on the weights. Because of this, it is not recommended to offload
# weights if they are touched outside this module
ctx.weight_objects = []
for weight in weights:
ctx.weight_objects.append(weight)

tensors_to_save, tensor_objects = prepare_for_saving(
*inputmats,
*weights_fp8,
*weights,
*biases,
)
ctx.save_for_backward(*tensors_to_save)
@@ -267,6 +254,12 @@

ctx.weights_requires_grad = weights[0].requires_grad
if fuse_wgrad_accumulation and ctx.weights_requires_grad:
# Keep weakrefs to weights to preserve attributes like main_grad
# when we need to modify the weight python objects
ctx.origin_weight_refs = [weakref.ref(w) for w in weights]
ctx.origin_weights_overwrite_main_grad = getattr(
weights[0], "overwrite_main_grad", False
)
# This check is needed to ensure that main_grad is not created
# during the forward pass when using MCore FSDP as it creates
# the main_grad buffer lazily before backprop
@@ -277,8 +270,6 @@
ctx.main_grad_funcs = [
lambda j=i: weights[j].main_grad for i in range(num_gemms)
]
else:
ctx.main_grad_funcs = [lambda: None for i in range(num_gemms)]
ctx.device = device
ctx.output_quantizers = output_quantizers
ctx.m_splits = m_splits
@@ -315,19 +306,24 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
N = ctx.num_gemms
inputmats = saved_tensors[:N]
weights = saved_tensors[N : 2 * N]
origin_weights = saved_tensors[2 * N : 3 * N]
biases = saved_tensors[3 * N : 4 * N]
main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]

if ctx.cpu_offloading:
if ctx.grad_added_to_main_grad:
for i, weight in enumerate(ctx.weight_objects):
origin_weights[i] = ctx.weight_objects[i]
ctx.weight_objects[i] = None

if ctx.fuse_wgrad_accumulation:
for i in range(N):
origin_weights[i].main_grad = main_grads[i]
biases = saved_tensors[2 * N : 3 * N]

# Restore from weakrefs to get original weight python objects
# (preserves attributes like main_grad, grad_added_to_main_grad, etc.)
# Only needed when fuse_wgrad_accumulation is enabled.
origin_weights = [None] * N
main_grads = [None] * N
if ctx.fuse_wgrad_accumulation and ctx.weights_requires_grad:
origin_weight_refs = ctx.origin_weight_refs
ctx.origin_weight_refs = None
origin_weights = [ref() if ref is not None else None for ref in origin_weight_refs]
assert all(
w is not None for w in origin_weights
), "weight was removed while fuse_wgrad_accumulation=True"
main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]
for origin_weight, main_grad in zip(origin_weights, main_grads):
if main_grad is not None:
origin_weight.main_grad = main_grad

# Preprocess grad output
grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
@@ -464,7 +460,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
use_split_accumulator=wgrad_gemm_use_split_accumulator,
accumulate=(
accumulate_wgrad_into_param_main_grad
if not getattr(weights[0], "overwrite_main_grad", False)
if not getattr(ctx, "origin_weights_overwrite_main_grad", False)
else False
),
)
@@ -482,7 +478,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad):
# Deallocate input tensor
clear_tensor_data(*inputmats)

def handle_custom_ddp_from_mcore(weight, wgrad):
def handle_custom_ddp_from_mcore(weight, main_grad, wgrad):
if ctx.weights_requires_grad:
# Handle custom DDP from mcore.
if ctx.fuse_wgrad_accumulation and hasattr(
@@ -491,14 +487,14 @@ def handle_custom_ddp_from_mcore(weight, wgrad):
weight.grad_added_to_main_grad = True
if getattr(weight, "zero_out_wgrad", False):
wgrad = get_dummy_wgrad(
list(weight.main_grad.shape),
weight.dtype,
list(main_grad.shape),
main_grad.dtype,
zero=True,
)
else:
wgrad = get_dummy_wgrad(
list(weight.main_grad.shape),
weight.dtype,
list(main_grad.shape),
main_grad.dtype,
)
elif ctx.fuse_wgrad_accumulation:
wgrad = None
@@ -507,8 +503,8 @@ def handle_custom_ddp_from_mcore(weight, wgrad):
return wgrad

wgrad_list = [
handle_custom_ddp_from_mcore(weight, wgrad)
for weight, wgrad in zip(origin_weights, wgrad_list)
handle_custom_ddp_from_mcore(weight, main_grad, wgrad)
for weight, main_grad, wgrad in zip(origin_weights, main_grads, wgrad_list)
]
else:
wgrad_list = [None] * ctx.num_gemms
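
The core change in `grouped_linear.py` is the bookkeeping above: forward stores `weakref.ref(weight)` plus a lazy `main_grad` accessor on the autograd ctx, and backward dereferences them to recover the original `torch.nn.Parameter` objects (with `main_grad`, `zero_out_wgrad`, etc. intact) instead of relying on `save_for_backward`, which hands back plain tensors stripped of those attributes. A minimal standalone sketch of the pattern, with illustrative names and a single weight for simplicity:

import weakref
import torch

class FusedWgradLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight):
        # Only tensors go through save_for_backward; the Parameter itself is
        # kept as a weakref so its Python attributes survive until backward
        # without ctx holding a strong reference to it.
        ctx.save_for_backward(inp)
        weight_ref = weakref.ref(weight)
        ctx.weight_ref = weight_ref
        # Fetch main_grad lazily: frameworks such as MCore FSDP may allocate
        # the main_grad buffer only right before backprop.
        ctx.main_grad_func = lambda: getattr(weight_ref(), "main_grad", None)
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        (inp,) = ctx.saved_tensors
        weight = ctx.weight_ref()
        assert weight is not None, "weight was freed before backward"
        dgrad = grad_out @ weight
        main_grad = ctx.main_grad_func()
        if main_grad is not None:
            # Accumulate wgrad directly into main_grad and mark the Parameter
            # so a DDP hook does not accumulate it a second time.
            main_grad.add_(grad_out.t() @ inp)
            weight.grad_added_to_main_grad = True
            return dgrad, None
        return dgrad, grad_out.t() @ inp

In this fused-accumulation path the weight gradient returned to autograd is None (or a dummy buffer in the real modules); the actual gradient lives in `weight.main_grad`.
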
59 changes: 31 additions & 28 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -5,6 +5,7 @@
"""LayerNormLinear API"""
import os
import warnings
import weakref
from typing import Callable, Dict, Optional, Tuple, Union, List
from functools import reduce
from operator import mul as multiply_op
@@ -454,19 +455,10 @@ def forward(
ln_weight,
ln_bias,
)
ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
if ctx.grad_added_to_main_grad:
# If you are passing torch.nn.Parameter through the Torch hooks, you will
# get back torch.Tensor. Torch rips off the Parameter wrapper.
# You need to preserve the weight object to keep all the attributes the
# user sets on the weights. Because of this, it is not recommended to offload
# weights if they are touched outside this module
ctx.weight_object = weight

tensors_to_save, tensor_objects = prepare_for_saving(
inputmat,
weightmat,
weight,
bias,
ln_weight,
ln_out,
@@ -479,6 +471,13 @@
ctx.requires_wgrad = weight.requires_grad
ctx.is_weight_param_quantized = is_weight_param_quantized
if fuse_wgrad_accumulation and weight.requires_grad:
# Keep weakref to weight to preserve attributes like main_grad
# when we need to modify the weight python object
ctx.origin_weight_ref = weakref.ref(weight)
# Save the overwrite_main_grad flag now, while we still have access to the weight object
ctx.origin_weight_overwrites_main_grad = getattr(
weight, "overwrite_main_grad", False
)
# This check is needed to ensure that main_grad is not created
# during the forward pass when using MCore FSDP as it creates
# the main_grad buffer lazily before backprop
@@ -554,7 +553,6 @@ def backward(
( # pylint: disable=unbalanced-tuple-unpacking
inputmat,
weight,
origin_weight,
bias,
ln_weight,
ln_out,
@@ -566,12 +564,25 @@
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None

# Since main_grad can be modified inplace, it should not be a part of saved_tensors
main_grad = (
ctx.main_grad_func()
if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad
else None
# Restore from weakref to get original weight python object
# (preserves attributes like main_grad, grad_added_to_main_grad, etc.)
# Only needed when fuse_wgrad_accumulation is enabled.
origin_weight = None
origin_weight_overwrites_main_grad = getattr(
ctx, "origin_weight_overwrites_main_grad", False
)
main_grad = None
if ctx.fuse_wgrad_accumulation and ctx.requires_wgrad:
origin_weight_ref = ctx.origin_weight_ref
ctx.origin_weight_ref = None
origin_weight = origin_weight_ref() if origin_weight_ref is not None else None
assert (
origin_weight is not None
), "weight was removed while fuse_wgrad_accumulation=True"
# Since main_grad can be modified inplace, it should not be a part of saved_tensors
main_grad = ctx.main_grad_func() if weight is not None else None
if main_grad is not None:
origin_weight.main_grad = main_grad

# Gather intermediate/activation tensors if needed
# NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already
@@ -587,14 +598,6 @@
)
nvtx_range_pop(f"{nvtx_label}.fsdp_gather")

# For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
# so we need to reconnect them here.
if ctx.cpu_offloading:
if ctx.grad_added_to_main_grad:
origin_weight = ctx.weight_object
if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
origin_weight.main_grad = main_grad

# Configure Userbuffers communication (comm+GEMM overlap)
ctx.ub_obj_gradout = None
ub_obj_dgrad = None
@@ -868,7 +871,7 @@ def backward(
"quantization_params": ctx.grad_weight_quantizer,
"accumulate": (
accumulate_wgrad_into_param_main_grad
if not getattr(weight, "overwrite_main_grad", False)
if not origin_weight_overwrites_main_grad
else False
),
"layout": "NT",
@@ -1000,14 +1003,14 @@ def wgrad_gemm(
origin_weight.grad_added_to_main_grad = True
if getattr(origin_weight, "zero_out_wgrad", False):
wgrad = get_dummy_wgrad(
list(origin_weight.main_grad.shape),
origin_weight.dtype,
list(main_grad.shape),
main_grad.dtype,
zero=True,
)
else:
wgrad = get_dummy_wgrad(
list(origin_weight.main_grad.shape),
origin_weight.dtype,
list(main_grad.shape),
main_grad.dtype,
)
elif ctx.fuse_wgrad_accumulation:
wgrad = None
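
`layernorm_linear.py` applies the same weakref bookkeeping, with two additional details visible above: `overwrite_main_grad` is read once in forward and cached on ctx, so the wgrad GEMM's `accumulate` decision in backward no longer touches the (possibly unwrapped) weight object, and the dummy-wgrad path now takes shape and dtype from `main_grad` rather than from the weight. A condensed sketch of that decision logic, using a simplified stand-in for `get_dummy_wgrad` (illustrative only, not TE's actual helper):

import torch

def _dummy_wgrad(like: torch.Tensor, zero: bool = False) -> torch.Tensor:
    # Simplified stand-in for get_dummy_wgrad: a placeholder gradient that
    # matches main_grad's shape/dtype and is never read by the optimizer.
    return torch.zeros_like(like) if zero else torch.empty_like(like)

def finalize_wgrad(origin_weight, main_grad, wgrad, fuse_wgrad_accumulation):
    # Mirrors the handle_custom_ddp_from_mcore logic: when wgrad accumulation
    # is fused and the Parameter carries the MCore attribute, the gradient has
    # already landed in main_grad, so only a placeholder (or None) is returned.
    if fuse_wgrad_accumulation and hasattr(origin_weight, "grad_added_to_main_grad"):
        origin_weight.grad_added_to_main_grad = True
        zero = getattr(origin_weight, "zero_out_wgrad", False)
        return _dummy_wgrad(main_grad, zero=zero)
    if fuse_wgrad_accumulation:
        return None
    return wgrad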