diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index e9d24c1a8e..2db95bffe6 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -194,6 +194,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci continue elif "weight" in name and p.requires_grad: p.main_grad = torch.zeros_like(p) + p.grad_added_to_main_grad = False # Should be set to True after backward use_fp8 = fp8_recipe is not None with autocast(enabled=use_fp8, recipe=fp8_recipe): @@ -203,13 +204,19 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci torch.cuda.synchronize() failed_grads = [] + failed_grad_added_flags = [] for name, p in block.named_parameters(): if "layer_norm_weight" in name: continue elif "weight" in name and p.requires_grad: if not torch.count_nonzero(p.main_grad) > 0: failed_grads.append(name) + if not getattr(p, "grad_added_to_main_grad", False): + failed_grad_added_flags.append(name) assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}." + assert ( + len(failed_grad_added_flags) == 0 + ), f"grad_added_to_main_grad not set to True for {failed_grad_added_flags}." def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad): diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 1e6f0b00ab..263d2e9ff0 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -6,6 +6,7 @@ from typing import Union, Optional, Callable, Tuple, List from itertools import chain import warnings +import weakref import functools import torch @@ -239,23 +240,9 @@ def forward( else: inputmats = [None] * num_gemms - if cpu_offloading: - ctx.grad_added_to_main_grad = hasattr(weights[0], "grad_added_to_main_grad") - - if ctx.grad_added_to_main_grad: - # If you are passing torch.nn.Parameter through the Torch hooks, you will - # get back torch.Tensor. Torch rips off the Parameter wrapper. - # You need to preserve the weight object to have all the attributes user - # sets for the weights. 
Because of this, it is not recommended to offload - # weights if weights are externally touched outside this module - ctx.weight_objects = [] - for weight in weights: - ctx.weight_objects.append(weight) - tensors_to_save, tensor_objects = prepare_for_saving( *inputmats, *weights_fp8, - *weights, *biases, ) ctx.save_for_backward(*tensors_to_save) @@ -267,6 +254,12 @@ def forward( ctx.weights_requires_grad = weights[0].requires_grad if fuse_wgrad_accumulation and ctx.weights_requires_grad: + # Keep weakrefs to weights to preserve attributes like main_grad + # when we need to modify the weight python objects + ctx.origin_weight_refs = [weakref.ref(w) for w in weights] + ctx.origin_weights_overwrite_main_grad = getattr( + weights[0], "overwrite_main_grad", False + ) # This check is needed to ensure that main_grad is not created # during the forward pass when using MCore FSDP as it creates # the main_grad buffer lazily before backprop @@ -277,8 +270,6 @@ def forward( ctx.main_grad_funcs = [ lambda j=i: weights[j].main_grad for i in range(num_gemms) ] - else: - ctx.main_grad_funcs = [lambda: None for i in range(num_gemms)] ctx.device = device ctx.output_quantizers = output_quantizers ctx.m_splits = m_splits @@ -315,19 +306,24 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], N = ctx.num_gemms inputmats = saved_tensors[:N] weights = saved_tensors[N : 2 * N] - origin_weights = saved_tensors[2 * N : 3 * N] - biases = saved_tensors[3 * N : 4 * N] - main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] - - if ctx.cpu_offloading: - if ctx.grad_added_to_main_grad: - for i, weight in enumerate(ctx.weight_objects): - origin_weights[i] = ctx.weight_objects[i] - ctx.weight_objects[i] = None - - if ctx.fuse_wgrad_accumulation: - for i in range(N): - origin_weights[i].main_grad = main_grads[i] + biases = saved_tensors[2 * N : 3 * N] + + # Restore from weakrefs to get original weight python objects + # (preserves attributes like main_grad, grad_added_to_main_grad, etc.) + # Only needed when fuse_wgrad_accumulation is enabled. + origin_weights = [None] * N + main_grads = [None] * N + if ctx.fuse_wgrad_accumulation and ctx.weights_requires_grad: + origin_weight_refs = ctx.origin_weight_refs + ctx.origin_weight_refs = None + origin_weights = [ref() if ref is not None else None for ref in origin_weight_refs] + assert all( + w is not None for w in origin_weights + ), "weight was removed while fuse_wgrad_accumulation=True" + main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] + for origin_weight, main_grad in zip(origin_weights, main_grads): + if main_grad is not None: + origin_weight.main_grad = main_grad # Preprocess grad output grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) @@ -464,7 +460,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_split_accumulator=wgrad_gemm_use_split_accumulator, accumulate=( accumulate_wgrad_into_param_main_grad - if not getattr(weights[0], "overwrite_main_grad", False) + if not getattr(ctx, "origin_weights_overwrite_main_grad", False) else False ), ) @@ -482,7 +478,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Deallocate input tensor clear_tensor_data(*inputmats) - def handle_custom_ddp_from_mcore(weight, wgrad): + def handle_custom_ddp_from_mcore(weight, main_grad, wgrad): if ctx.weights_requires_grad: # Handle custom DDP from mcore. 
if ctx.fuse_wgrad_accumulation and hasattr( @@ -491,14 +487,14 @@ def handle_custom_ddp_from_mcore(weight, wgrad): weight.grad_added_to_main_grad = True if getattr(weight, "zero_out_wgrad", False): wgrad = get_dummy_wgrad( - list(weight.main_grad.shape), - weight.dtype, + list(main_grad.shape), + main_grad.dtype, zero=True, ) else: wgrad = get_dummy_wgrad( - list(weight.main_grad.shape), - weight.dtype, + list(main_grad.shape), + main_grad.dtype, ) elif ctx.fuse_wgrad_accumulation: wgrad = None @@ -507,8 +503,8 @@ def handle_custom_ddp_from_mcore(weight, wgrad): return wgrad wgrad_list = [ - handle_custom_ddp_from_mcore(weight, wgrad) - for weight, wgrad in zip(origin_weights, wgrad_list) + handle_custom_ddp_from_mcore(weight, main_grad, wgrad) + for weight, main_grad, wgrad in zip(origin_weights, main_grads, wgrad_list) ] else: wgrad_list = [None] * ctx.num_gemms diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 13b94f2327..e9d8320526 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -5,6 +5,7 @@ """LayerNormLinear API""" import os import warnings +import weakref from typing import Callable, Dict, Optional, Tuple, Union, List from functools import reduce from operator import mul as multiply_op @@ -454,19 +455,10 @@ def forward( ln_weight, ln_bias, ) - ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") - if ctx.grad_added_to_main_grad: - # If you are passing torch.nn.Parameter through the Torch hooks, you will - # get back torch.Tensor. Torch rips off the Parameter wrapper. - # You need to preserve the weight object to have all the attributes user - # sets for the weights. Because of this, it is not recommended to offload - # weights if weights are externally touched outside this module - ctx.weight_object = weight tensors_to_save, tensor_objects = prepare_for_saving( inputmat, weightmat, - weight, bias, ln_weight, ln_out, @@ -479,6 +471,13 @@ def forward( ctx.requires_wgrad = weight.requires_grad ctx.is_weight_param_quantized = is_weight_param_quantized if fuse_wgrad_accumulation and weight.requires_grad: + # Keep weakref to weight to preserve attributes like main_grad + # when we need to modify the weight python object + ctx.origin_weight_ref = weakref.ref(weight) + # Save overwrite_main_grad flag now while we have access to weight objec + ctx.origin_weight_overwrites_main_grad = getattr( + weight, "overwrite_main_grad", False + ) # This check is needed to ensure that main_grad is not created # during the forward pass when using MCore FSDP as it creates # the main_grad buffer lazily before backprop @@ -554,7 +553,6 @@ def backward( ( # pylint: disable=unbalanced-tuple-unpacking inputmat, weight, - origin_weight, bias, ln_weight, ln_out, @@ -566,12 +564,25 @@ def backward( # by the `restore_from_saved` method to construct back the actual tensors. ctx.tensor_objects = None - # Since main_grad can be modified inplace, it should not be a part of saved_tensors - main_grad = ( - ctx.main_grad_func() - if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad - else None + # Restore from weakref to get original weight python object + # (preserves attributes like main_grad, grad_added_to_main_grad, etc.) + # Only needed when fuse_wgrad_accumulation is enabled. 
+ origin_weight = None + origin_weight_overwrites_main_grad = getattr( + ctx, "origin_weight_overwrites_main_grad", False ) + main_grad = None + if ctx.fuse_wgrad_accumulation and ctx.requires_wgrad: + origin_weight_ref = ctx.origin_weight_ref + ctx.origin_weight_ref = None + origin_weight = origin_weight_ref() if origin_weight_ref is not None else None + assert ( + origin_weight is not None + ), "weight was removed while fuse_wgrad_accumulation=True" + # Since main_grad can be modified inplace, it should not be a part of saved_tensors + main_grad = ctx.main_grad_func() if weight is not None else None + if main_grad is not None: + origin_weight.main_grad = main_grad # Gather intermediate/activation tensors if needed # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -587,14 +598,6 @@ def backward( ) nvtx_range_pop(f"{nvtx_label}.fsdp_gather") - # For CPU offloading, we offloaded weight and weight.main_grad to different tensors, - # we need to connect them into one. - if ctx.cpu_offloading: - if ctx.grad_added_to_main_grad: - origin_weight = ctx.weight_object - if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: - origin_weight.main_grad = main_grad - # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None ub_obj_dgrad = None @@ -868,7 +871,7 @@ def backward( "quantization_params": ctx.grad_weight_quantizer, "accumulate": ( accumulate_wgrad_into_param_main_grad - if not getattr(weight, "overwrite_main_grad", False) + if not origin_weight_overwrites_main_grad else False ), "layout": "NT", @@ -1000,14 +1003,14 @@ def wgrad_gemm( origin_weight.grad_added_to_main_grad = True if getattr(origin_weight, "zero_out_wgrad", False): wgrad = get_dummy_wgrad( - list(origin_weight.main_grad.shape), - origin_weight.dtype, + list(main_grad.shape), + main_grad.dtype, zero=True, ) else: wgrad = get_dummy_wgrad( - list(origin_weight.main_grad.shape), - origin_weight.dtype, + list(main_grad.shape), + main_grad.dtype, ) elif ctx.fuse_wgrad_accumulation: wgrad = None diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index ddb33f303c..3ddaf44844 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -5,6 +5,7 @@ """LayerNormMLP API""" import os import warnings +import weakref from typing import Callable, Optional, Tuple, Union, List from functools import reduce from operator import mul as multiply_op @@ -748,13 +749,11 @@ def _forward( ln_weight, ln_out, fc1_weight_final, - fc1_weight, fc1_bias, fc1_out, fc1_out_without_bias, act_out, fc2_weight_final, - fc2_weight, fc2_bias, mu, rsigma, @@ -764,6 +763,20 @@ def _forward( ctx.tensor_objects = tensor_objects if fuse_wgrad_accumulation: + # Keep weakrefs to weights to preserve attributes like main_grad + # when we need to modify the weight python objects + ctx.fc1_weight_python_object_ref = ( + weakref.ref(fc1_weight) if fc1_weight.requires_grad else None + ) + ctx.fc2_weight_python_object_ref = ( + weakref.ref(fc2_weight) if fc2_weight.requires_grad else None + ) + ctx.fc1_weight_overwrites_main_grad = getattr( + fc1_weight, "overwrite_main_grad", False + ) + ctx.fc2_weight_overwrites_main_grad = getattr( + fc2_weight, "overwrite_main_grad", False + ) # This check is needed to ensure that main_grad is not created # during the forward pass when using MCore FSDP as it creates # the main_grad buffer lazily before backprop @@ -791,8 +804,6 @@ def _forward( 
ctx.fc1_weight_requires_grad = fc1_weight.requires_grad ctx.fc2_weight_requires_grad = fc2_weight.requires_grad - ctx.fc1_weight = fc1_weight - ctx.fc2_weight = fc2_weight ctx.device = device ctx.activation_dtype = activation_dtype @@ -844,13 +855,11 @@ def _forward( ln_weight, ln_out, fc1_weight_final, - fc1_weight, fc1_bias, fc1_out, fc1_out_without_bias, act_out, fc2_weight_final, - fc2_weight, fc2_bias, mu, rsigma, @@ -964,39 +973,49 @@ def backward( ln_weight, ln_out, fc1_weight, - origin_fc1_weight, fc1_bias, fc1_out, fc1_out_without_bias, act_out, fc2_weight, - origin_fc2_weight, fc2_bias, mu, rsigma, ) = _LayerNormMLP._recompute(ctx) - # Since main_grad can be modified inplace, it should not be a part of saved_tensors - fc1_weight_main_grad = ( - ctx.fc1_main_grad_func() - if fc1_weight is not None - and ctx.fuse_wgrad_accumulation - and ctx.fc1_weight_requires_grad - else None - ) - fc2_weight_main_grad = ( - ctx.fc2_main_grad_func() - if origin_fc2_weight is not None - and ctx.fuse_wgrad_accumulation - and ctx.fc2_weight_requires_grad - else None - ) - - # For CPU offloading, we offloaded weight and weight.main_grad to different tensors, - # we need to connect them into one. + # Restore origin weights from weakrefs + # Only needed when fuse_wgrad_accumulation is enabled. + fc1_weight_python_object = None + fc2_weight_python_object = None + fc1_weight_main_grad = None + fc2_weight_main_grad = None if ctx.fuse_wgrad_accumulation: - origin_fc1_weight.main_grad = fc1_weight_main_grad - origin_fc2_weight.main_grad = fc2_weight_main_grad + fc1_weight_python_object_ref = getattr(ctx, "fc1_weight_python_object_ref", None) + fc2_weight_python_object_ref = getattr(ctx, "fc2_weight_python_object_ref", None) + ctx.fc1_weight_python_object_ref = None + ctx.fc2_weight_python_object_ref = None + fc1_weight_python_object = ( + fc1_weight_python_object_ref() + if fc1_weight_python_object_ref is not None + else None + ) + fc2_weight_python_object = ( + fc2_weight_python_object_ref() + if fc2_weight_python_object_ref is not None + else None + ) + if ctx.fc1_weight_requires_grad: + assert ( + fc1_weight_python_object is not None + ), "fc1_weight was removed while fuse_wgrad_accumulation=True" + fc1_weight_main_grad = ctx.fc1_main_grad_func() + fc1_weight_python_object.main_grad = fc1_weight_main_grad + if ctx.fc2_weight_requires_grad: + assert ( + fc2_weight_python_object is not None + ), "fc2_weight was removed while fuse_wgrad_accumulation=True" + fc2_weight_main_grad = ctx.fc2_main_grad_func() + fc2_weight_python_object.main_grad = fc2_weight_main_grad # TODO: Fix this # pylint: disable=fixme # Gather saved autograd context tensors when running with FSDP @@ -1115,9 +1134,9 @@ def backward( if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ctx.fc2_weight_quantizer is not None and isinstance( - ctx.fc2_weight, QuantizedTensorStorage + fc2_weight, QuantizedTensorStorage ): - ctx.fc2_weight.update_usage(columnwise_usage=True) + fc2_weight.update_usage(columnwise_usage=True) # Perform GEMM gemm_output, *_ = general_gemm( @@ -1217,18 +1236,18 @@ def backward( # Arguments to include in wgrad GEMM closure fc2_wgrad_gemm_kwargs = { "out_dtype": ( - origin_fc2_weight.main_grad.dtype + fc2_weight_main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), "quantization_params": ctx.fc2_grad_weight_quantizer, # wgrad in high precision "accumulate": ( accumulate_wgrad_into_param_main_grad - if not getattr(fc1_weight, "overwrite_main_grad", 
False) + if not getattr(ctx, "fc2_weight_overwrites_main_grad", False) else False ), "layout": "NT", - "out": origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, + "out": fc2_weight_main_grad if ctx.fuse_wgrad_accumulation else None, "bias": fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None, "use_split_accumulator": wgrad_use_split_accumulator, "grad": grad_arg, @@ -1367,9 +1386,9 @@ def fc2_wgrad_gemm( # Make sure required data is available if ctx.fc1_weight_quantizer is not None and isinstance( - ctx.fc1_weight_quantizer, QuantizedTensorStorage + fc1_weight, QuantizedTensorStorage ): - ctx.fc1_weight.update_usage(columnwise_usage=True) + fc1_weight.update_usage(columnwise_usage=True) # Output buffers for Userbuffers reduce-scatter gemm_out = None @@ -1464,18 +1483,18 @@ def fc2_wgrad_gemm( # Arguments to include in wgrad GEMM closure fc1_wgrad_gemm_kwargs = { "out_dtype": ( - origin_fc1_weight.main_grad.dtype + fc1_weight_main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), "quantization_params": ctx.fc1_grad_weight_quantizer, "accumulate": ( accumulate_wgrad_into_param_main_grad - if not getattr(fc2_weight, "overwrite_main_grad", False) + if not getattr(ctx, "fc1_weight_overwrites_main_grad", False) else False ), "layout": "NT", - "out": origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, + "out": fc1_weight_main_grad if ctx.fuse_wgrad_accumulation else None, "bias": fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None, "use_split_accumulator": wgrad_use_split_accumulator, "grad": fuse_gemm_and_bias_fc1_wgrad, @@ -1579,19 +1598,21 @@ def fc1_wgrad_gemm( if ctx.fc1_weight_requires_grad: # Handle custom DDP from mcore. - if ctx.fuse_wgrad_accumulation and hasattr(fc1_weight, "grad_added_to_main_grad"): - origin_fc1_weight.grad_added_to_main_grad = True - if getattr(origin_fc1_weight, "zero_out_wgrad", False): + if ctx.fuse_wgrad_accumulation and hasattr( + fc1_weight_python_object, "grad_added_to_main_grad" + ): + fc1_weight_python_object.grad_added_to_main_grad = True + if getattr(fc1_weight_python_object, "zero_out_wgrad", False): fc1_wgrad = torch.zeros( - origin_fc1_weight.main_grad.shape, - dtype=origin_fc1_weight.dtype, + fc1_weight_main_grad.shape, + dtype=fc1_weight_main_grad.dtype, device=torch.cuda.current_device(), requires_grad=False, ) else: fc1_wgrad = torch.empty( - origin_fc1_weight.main_grad.shape, - dtype=origin_fc1_weight.dtype, + fc1_weight_main_grad.shape, + dtype=fc1_weight_main_grad.dtype, device=torch.cuda.current_device(), requires_grad=False, ) @@ -1603,20 +1624,20 @@ def fc1_wgrad_gemm( if ctx.fc2_weight_requires_grad: # Handle custom DDP from mcore. 
if ctx.fuse_wgrad_accumulation and hasattr( - origin_fc2_weight, "grad_added_to_main_grad" + fc2_weight_python_object, "grad_added_to_main_grad" ): - origin_fc2_weight.grad_added_to_main_grad = True - if getattr(origin_fc2_weight, "zero_out_wgrad", False): + fc2_weight_python_object.grad_added_to_main_grad = True + if getattr(fc2_weight_python_object, "zero_out_wgrad", False): fc2_wgrad = torch.zeros( - origin_fc2_weight.main_grad.shape, - dtype=origin_fc2_weight.dtype, + fc2_weight_main_grad.shape, + dtype=fc2_weight_main_grad.dtype, device=torch.cuda.current_device(), requires_grad=False, ) else: fc2_wgrad = torch.empty( - origin_fc2_weight.main_grad.shape, - dtype=origin_fc2_weight.dtype, + fc2_weight_main_grad.shape, + dtype=fc2_weight_main_grad.dtype, device=torch.cuda.current_device(), requires_grad=False, ) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b8349f84a0..47c3397b56 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -7,6 +7,7 @@ from functools import reduce from operator import mul as multiply_op import warnings +import weakref import torch @@ -417,24 +418,13 @@ def forward( ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") - if cpu_offloading: - ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") - - if ctx.grad_added_to_main_grad: - # If you are passing torch.nn.Parameter through the Torch hooks, you will - # get back torch.Tensor. Torch rips off the Parameter wrapper. - # You need to preserve the weight object to have all the attributes user - # sets for the weights. Because of this, it is not recommended to offload - # weights if weights are externally touched outside this module - ctx.weight_object = weight - if cpu_offloading: mark_not_offload(weight, weightmat, bias) + # TODO(ksivamani): Check memory usage tensors_to_save, tensor_objects = prepare_for_saving( saved_inputmat, weightmat, - weight, bias, ) ctx.save_for_backward(*tensors_to_save) @@ -447,8 +437,15 @@ def forward( ctx.grad_input_quantizer = grad_input_quantizer ctx.grad_weight_quantizer = grad_weight_quantizer ctx.grad_output_quantizer = grad_output_quantizer + ctx.is_weight_param_quantized = isinstance(weight, QuantizedTensorStorage) ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation if fuse_wgrad_accumulation and weight.requires_grad: + # Keep weakref to weight to preserve attributes like main_grad + # when we need to modify the weight python object + ctx.origin_weight_ref = weakref.ref(weight) + ctx.origin_weight_overwrites_main_grad = getattr( + weight, "overwrite_main_grad", False + ) # This check is needed to ensure that main_grad is not created # during the forward pass when using MCore FSDP as it creates # the main_grad buffer lazily before backprop @@ -503,7 +500,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], with get_nvtx_range_context("_Linear_backward"): saved_tensors = ctx.saved_tensors - inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking + inputmat, weight_fp8, bias = ( # pylint: disable=unbalanced-tuple-unpacking restore_from_saved(ctx.tensor_objects, saved_tensors) ) @@ -511,18 +508,26 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # by the `restore_from_saved` method to construct back the actual tensors. 
ctx.tensor_objects = None - # Since main_grad can be modified inplace, it should not be a part of saved_tensors - main_grad = ( - ctx.main_grad_func() - if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad - else None + # Restore from weakref to get original weight python object + # (preserves attributes like main_grad, grad_added_to_main_grad, etc.) + # Only needed when fuse_wgrad_accumulation is enabled. + origin_weight_python_object = None + origin_weight_overwrites_main_grad = getattr( + ctx, "origin_weight_overwrites_main_grad", False ) - - if ctx.cpu_offloading: - if ctx.grad_added_to_main_grad: - weight = ctx.weight_object - if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: - weight.main_grad = main_grad + main_grad = None + if ctx.fuse_wgrad_accumulation and ctx.requires_wgrad: + origin_weight_ref = ctx.origin_weight_ref + ctx.origin_weight_ref = None + origin_weight_python_object = ( + origin_weight_ref() if origin_weight_ref is not None else None + ) + assert ( + origin_weight_python_object is not None + ), "weight was removed while fuse_wgrad_accumulation=True" + # Since main_grad can be modified inplace, it should not be a part of saved_tensors + main_grad = ctx.main_grad_func() + origin_weight_python_object.main_grad = main_grad # Gather intermediate/activation tensors if needed # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -854,7 +859,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], "quantization_params": ctx.grad_weight_quantizer, "accumulate": ( accumulate_wgrad_into_param_main_grad - if not getattr(weight, "overwrite_main_grad", False) + if not origin_weight_overwrites_main_grad else False ), "layout": "NT", @@ -944,22 +949,20 @@ def wgrad_gemm( if ctx.requires_wgrad: # Handle custom DDP from mcore. - if ( - ctx.fuse_wgrad_accumulation - and weight is not None - and hasattr(weight, "grad_added_to_main_grad") + if ctx.fuse_wgrad_accumulation and hasattr( + origin_weight_python_object, "grad_added_to_main_grad" ): - weight.grad_added_to_main_grad = True - if getattr(weight, "zero_out_wgrad", False): + origin_weight_python_object.grad_added_to_main_grad = True + if getattr(origin_weight_python_object, "zero_out_wgrad", False): wgrad = get_dummy_wgrad( - list(weight.main_grad.shape), - weight.dtype, + list(main_grad.shape), + main_grad.dtype, zero=True, ) else: wgrad = get_dummy_wgrad( - list(weight.main_grad.shape), - weight.dtype, + list(main_grad.shape), + main_grad.dtype, ) elif ctx.fuse_wgrad_accumulation: wgrad = None @@ -973,7 +976,7 @@ def wgrad_gemm( nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors") # Scatter fp8 weight buffers - if ctx.fp8 and not isinstance(weight, QuantizedTensorStorage): + if ctx.fp8 and not ctx.is_weight_param_quantized: _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) return ( wgrad,
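
A minimal standalone sketch (not part of the patch, and not Transformer Engine code) of the pattern the diff adopts for fuse_wgrad_accumulation: the forward keeps a weakref.ref to the original nn.Parameter on the autograd ctx instead of routing the Parameter through save_for_backward, and the backward dereferences it to reach user-set attributes such as main_grad and to set grad_added_to_main_grad, which the updated sanity test asserts. The class name _ToyFusedLinear, the toy GEMMs, and the plain "+=" accumulation are assumptions for illustration only; just the ctx/weakref bookkeeping mirrors the changes above.

    import weakref
    import torch

    class _ToyFusedLinear(torch.autograd.Function):
        """Toy linear that accumulates wgrad into weight.main_grad (mcore-DDP style)."""

        @staticmethod
        def forward(ctx, inp, weight):
            # Save only the tensors needed for the backward GEMMs. Relying on
            # save_for_backward to carry user-set Python attributes of the
            # Parameter (main_grad, grad_added_to_main_grad, ...) is fragile,
            # so those are reached through a weakref instead.
            ctx.save_for_backward(inp, weight)
            if weight.requires_grad:
                # Weak reference to the original Parameter object: backward can
                # still reach its attributes, without the ctx extending the
                # Parameter's lifetime or creating a reference cycle.
                ctx.origin_weight_ref = weakref.ref(weight)
            return inp @ weight.t()

        @staticmethod
        def backward(ctx, grad_out):
            inp, weight = ctx.saved_tensors
            grad_input = grad_out @ weight
            wgrad = None
            if ctx.needs_input_grad[1]:
                origin_weight = ctx.origin_weight_ref()
                ctx.origin_weight_ref = None
                assert origin_weight is not None, "Parameter was freed before backward"
                # Fused accumulation: write wgrad straight into main_grad and
                # flag the Parameter so a Megatron-style DDP skips adding .grad
                # into main_grad a second time.
                origin_weight.main_grad += grad_out.t() @ inp
                origin_weight.grad_added_to_main_grad = True
                # Hand autograd a dummy buffer shaped like main_grad; the
                # optimizer is expected to read main_grad, not .grad.
                wgrad = torch.empty_like(origin_weight.main_grad)
            return grad_input, wgrad

    # Usage sketch: mirrors what the updated sanity test checks.
    w = torch.nn.Parameter(torch.randn(8, 4))
    w.main_grad = torch.zeros(8, 4, dtype=torch.float32)
    x = torch.randn(2, 4, requires_grad=True)
    _ToyFusedLinear.apply(x, w).sum().backward()
    assert w.grad_added_to_main_grad and torch.count_nonzero(w.main_grad) > 0

Because only a weak reference is stored, the assertion that the dereferenced Parameter is not None documents the patch's assumption that the weight object outlives the forward pass whenever fuse_wgrad_accumulation is enabled.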