diff --git a/modelopt/torch/distill/plugins/megatron.py b/modelopt/torch/distill/plugins/megatron.py
index 500921ce3..b0eeeab74 100644
--- a/modelopt/torch/distill/plugins/megatron.py
+++ b/modelopt/torch/distill/plugins/megatron.py
@@ -26,6 +26,7 @@ from typing import TYPE_CHECKING
 
 import torch
+import torch.distributed.nn as dist_nn
 import torch.nn as nn
 import torch.nn.functional as F
 import yaml
@@ -57,6 +58,7 @@ class DistillationConfig:
         skip_lm_loss: Whether to skip computing the standard language model loss (default: ``True``).
         kd_loss_scale: Relative scaling factor for the distillation loss if ``skip_lm_loss`` is ``False``.
         logit_kl_temperature: Temperature for the logit KL-divergence loss.
+        logit_kl_topk: If not None, use TopKLogitsKLLoss instead of LogitsKLLoss with this top-k value.
     """
 
     intermediate_layer_pairs: list[tuple[str, ...]] = field(default_factory=list)
@@ -64,6 +66,7 @@ class DistillationConfig:
     skip_lm_loss: bool = True
     kd_loss_scale: float = 1.0
     logit_kl_temperature: float = 1.0
+    logit_kl_topk: int | None = None
 
     criterion: Criterion | None = None
     loss_balancer: mtd.DistillationLossBalancer | None = None
@@ -123,9 +126,15 @@ def setup_distillation_config(
     if cfg.criterion is None:
         criterion = {}
         if parallel_state.is_pipeline_last_stage():
-            criterion[tuple(cfg.logit_layers)] = LogitsKLLoss(
-                student_cfg, temperature=cfg.logit_kl_temperature
-            )
+            # Use TopKLogitsKLLoss if logit_kl_topk is specified, otherwise use LogitsKLLoss
+            if cfg.logit_kl_topk is not None:
+                criterion[tuple(cfg.logit_layers)] = TopKLogitsKLLoss(
+                    student_cfg, temperature=cfg.logit_kl_temperature, top_k=cfg.logit_kl_topk
+                )
+            else:
+                criterion[tuple(cfg.logit_layers)] = LogitsKLLoss(
+                    student_cfg, temperature=cfg.logit_kl_temperature
+                )
 
         # NOTE: Projection layer shared among intermediate layer pairs.
         projection_layer = ProjectionLayer(student_cfg, teacher_cfg)
@@ -310,81 +319,143 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor:
         predictions, targets = self.pre_forward(predictions, targets)
 
         # Division by temp should happen prior to finding max for both student and teacher.
-        # Currently we don't use temperature in any of ours runs (temp=1.0)
         output_teacher = targets.float() / self._temperature
         output_student = predictions.float() / self._temperature
 
         # Compute local softmax, and the reweight to compute global softmax.
         if self._config.tensor_model_parallel_size > 1:
-            # Maximum value along vocab dimension across all GPUs.
-            teacher_logits_max, _ = torch.max(output_teacher, dim=-1)
+            tp_group = parallel_state.get_tensor_model_parallel_group()
+
+            # Subtract maximum value along vocab dimension across all GPUs (for stability)
+            teacher_logits_max, _ = torch.max(output_teacher, dim=-1, keepdim=True)
             torch.distributed.all_reduce(
                 teacher_logits_max,
                 op=torch.distributed.ReduceOp.MAX,
-                group=parallel_state.get_tensor_model_parallel_group(),
+                group=tp_group,
            )
-            output_teacher = output_teacher - teacher_logits_max.unsqueeze(dim=-1)
+            output_teacher -= teacher_logits_max
 
-            denom_teacher = torch.sum(torch.exp(output_teacher), dim=-1)
-            # We can't use standard reduction function here since the computation
-            # that follows it isn't identical across TP ranks.
-            denom_teacher = all_reduce_autograd(
-                denom_teacher, group=parallel_state.get_tensor_model_parallel_group()
-            )
-
-            # Maximum value along vocab dimension across all GPUs.
-            student_logits_max, _ = torch.max(output_student, dim=-1)
+            student_logits_max, _ = torch.max(output_student, dim=-1, keepdim=True)
             torch.distributed.all_reduce(
                 student_logits_max,
                 op=torch.distributed.ReduceOp.MAX,
-                group=parallel_state.get_tensor_model_parallel_group(),
+                group=tp_group,
             )
-            output_student = output_student - student_logits_max.unsqueeze(dim=-1).detach()
+            output_student -= student_logits_max.detach()
 
-            denom_student = torch.sum(torch.exp(output_student), dim=-1)
-            denom_student = all_reduce_autograd(
-                denom_student, group=parallel_state.get_tensor_model_parallel_group()
-            )
+            # Compute global softmax denominators
+            # We can't use standard all_reduce function here since the computation
+            # that follows it isn't identical across TP ranks.
+            denom_teacher = torch.sum(torch.exp(output_teacher), dim=-1, keepdim=True)
+            denom_teacher = dist_nn.functional.all_reduce(denom_teacher, group=tp_group)
+
+            denom_student = torch.sum(torch.exp(output_student), dim=-1, keepdim=True)
+            denom_student = dist_nn.functional.all_reduce(denom_student, group=tp_group)
+
+            # Compute log probabilities (log softmax)
+            teacher_log_prob = output_teacher - torch.log(denom_teacher)
+            student_log_prob = output_student - torch.log(denom_student)
+
+            # KL divergence
+            p, q = student_log_prob, teacher_log_prob
+        else:
+            # Compute log probabilities
+            p, q = F.log_softmax(output_student, dim=-1), F.log_softmax(output_teacher, dim=-1)
+
+        # KL divergence
+        if self._reverse:
+            p, q = q, p
+        loss = torch.sum(F.kl_div(p, q, reduction="none", log_target=True), dim=-1)
+
+        return self.post_forward(loss, tp_reduce=True)
+
+
+class TopKLogitsKLLoss(LogitsKLLoss):
+    """Calculates KL-Divergence loss restricted to the Teacher's Top-K vocabulary entries.
+
+    Calculates using the global Top-K entries without gathering full logits.
+    NOTE: Will gather Top-K logits per rank, so mind the value of K for memory and communication.
+    """
+
+    def __init__(
+        self,
+        model_config: "TransformerConfig",
+        temperature: float = 1.0,
+        reverse: bool = False,
+        top_k: int = 1000,
+    ):
+        """Constructor.
+
+        Args:
+            model_config: MCore transformer config.
+            temperature: Divide tensors by this value prior to calculating loss.
+            reverse: Whether to reverse the loss as KLD(teacher, student) instead of KLD(student, teacher)
+            top_k: The number of top vocabulary entries to keep from the teacher's distribution.
+        """
+        super().__init__(model_config, temperature, reverse)
+        self.top_k = top_k
+
+    def forward(self, predictions: Tensor, targets: Tensor) -> Tensor:
+        """Forward function.
 
-            slen, bsz, sharded_vocab_size = output_student.shape
-            student_log_prob = output_student - torch.log(denom_student).view(slen, bsz, 1).expand(
-                slen, bsz, sharded_vocab_size
+        Args:
+            predictions: Student model tensors (size [s, b, h])
+            targets: Teacher model tensors (size [s, b, h])
+
+        Returns:
+            Top-K KLD loss of tensors (size [b, s])
+        """
+        predictions, targets = self.pre_forward(predictions, targets)
+
+        tp_size = self._config.tensor_model_parallel_size
+        assert self.top_k <= targets.size(-1) * tp_size, (
+            f"top_k ({self.top_k}) is larger than total vocab size ({targets.size(-1) * tp_size})"
+        )
+
+        # Divide by temperature first
+        output_teacher = targets.float() / self._temperature
+        output_student = predictions.float() / self._temperature
+
+        # Extract local Top-K
+        # We take K from each rank and then find the global Top-K of all those.
+        local_top_k = min(self.top_k, targets.size(-1))
+        top_teacher_vals, top_idx = torch.topk(output_teacher, local_top_k, dim=-1)
+        top_student_vals = torch.gather(output_student, dim=-1, index=top_idx)
+
+        if tp_size > 1:
+            tp_group = parallel_state.get_tensor_model_parallel_group()
+
+            # Gather all candidates into shape [N_rows, local_k * tp_size]
+            # Use all_gather from torch.distributed.nn.functional to preserve gradients
+            all_teacher_vals = dist_nn.functional.all_gather(
+                top_teacher_vals.contiguous(), group=tp_group
             )
-            teacher_log_prob = output_teacher - torch.log(denom_teacher).view(slen, bsz, 1).expand(
-                slen, bsz, sharded_vocab_size
+            all_student_vals = dist_nn.functional.all_gather(
+                top_student_vals.contiguous(), group=tp_group
             )
+            all_teacher_vals = torch.cat(all_teacher_vals, dim=-1)
+            all_student_vals = torch.cat(all_student_vals, dim=-1)
 
-            if self._reverse:
-                loss = torch.sum(
-                    F.kl_div(teacher_log_prob, student_log_prob, reduction="none", log_target=True),
-                    dim=-1,
-                )
-            else:
-                loss = torch.sum(
-                    F.kl_div(student_log_prob, teacher_log_prob, reduction="none", log_target=True),
-                    dim=-1,
-                )
+            # Pick the true Top-K based on Teacher values
+            global_top_vals, global_top_idx = torch.topk(all_teacher_vals, self.top_k, dim=-1)
 
-        elif self._reverse:
-            loss = torch.sum(
-                F.kl_div(
-                    F.log_softmax(output_teacher, dim=-1),
-                    F.softmax(output_student, dim=-1),
-                    reduction="none",
-                ),
-                dim=-1,
-            )
+            final_teacher_logits = global_top_vals
+            final_student_logits = torch.gather(all_student_vals, dim=-1, index=global_top_idx)
         else:
-            loss = torch.sum(
-                F.kl_div(
-                    F.log_softmax(output_student, dim=-1),
-                    F.softmax(output_teacher, dim=-1),
-                    reduction="none",
-                ),
-                dim=-1,
-            )
+            final_teacher_logits = top_teacher_vals
+            final_student_logits = top_student_vals
 
-        return self.post_forward(loss, tp_reduce=True)
+        # Standard (dense) Softmax + KL
+        p = F.log_softmax(final_student_logits, dim=-1)
+        q = F.log_softmax(final_teacher_logits, dim=-1)
+
+        # KL divergence
+        if self._reverse:
+            p, q = q, p
+        loss = torch.sum(F.kl_div(p, q, reduction="none", log_target=True), dim=-1)
+
+        # No need to reduce since all ranks compute same global Top-K
+        return self.post_forward(loss, tp_reduce=False)
 
 
 class LogitsAndIntermediatesLossBalancer(mtd.DistillationLossBalancer):
@@ -417,7 +488,7 @@ def forward(self, loss_dict: dict[str, Tensor]) -> Tensor:
         """
         original_loss = loss_dict.pop(mtd.loss_balancers.STUDENT_LOSS_KEY)
         for _key in loss_dict:
-            if _key.startswith(LogitsKLLoss.__name__):
+            if "Logits" in _key:  # class name
                 logits_key = _key  # should only be one
         logits_loss = loss_dict.pop(logits_key)
         intermediate_loss = sum(loss_dict.values()) / max(len(loss_dict), 1)
@@ -481,32 +552,6 @@ def _init_weights(self, module):
             module.bias.data.zero_()
 
-
-class _AllReduce(torch.autograd.Function):
-    """Implementation from old PyTorch `torch.distributed.nn.parallel`."""
-
-    @staticmethod
-    def forward(ctx, op, group, tensor):
-        ctx.group, ctx.op = group, op
-        tensor = tensor.clone()
-        torch.distributed.all_reduce(tensor, op=op, group=group)
-        return tensor
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return (None, None, _AllReduce.apply(ctx.op, ctx.group, grad_output))
-
-
-def all_reduce_autograd(
-    tensor, op=torch.distributed.ReduceOp.SUM, group=torch.distributed.group.WORLD
-):
-    """Custom all-reduce function.
-
-    Needed instead of other all-reduce functions available when the computation following
-    the all-reduce call differs per rank. In KL loss, this corresponds to the different numerators.
-    """
-    return _AllReduce.apply(op, group, tensor)
-
-
 ########################################################
diff --git a/tests/gpu/torch/distill/plugins/test_megatron.py b/tests/gpu/torch/distill/plugins/test_megatron.py
new file mode 100644
index 000000000..6e1833dd6
--- /dev/null
+++ b/tests/gpu/torch/distill/plugins/test_megatron.py
@@ -0,0 +1,237 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+
+import torch
+from _test_utils.import_helper import skip_if_no_megatron
+
+skip_if_no_megatron(apex_or_te_required=True)
+
+from _test_utils.torch.distributed.utils import spawn_multiprocess_job
+from _test_utils.torch.megatron.models import get_mcore_gpt_model
+from _test_utils.torch.megatron.utils import run_mcore_inference_with_dummy_input
+from _test_utils.torch.misc import set_seed
+
+import modelopt.torch.distill as mtd
+from modelopt.torch.distill.plugins.megatron import (
+    DistillationConfig,
+    adjust_distillation_model_for_mcore,
+    setup_distillation_config,
+)
+
+SEED = 1234
+
+
+def _test_logits_kl_loss(rank, size):
+    """Test basic LogitsKLLoss with simple forward/backward pass."""
+    channel_divisor = 4
+
+    num_layers = 2
+    hidden_size = channel_divisor * 2
+    num_attention_heads = 4
+    num_query_groups = 2
+    ffn_hidden_size = channel_divisor * 2
+    max_sequence_length = 8
+    vocab_size = 32
+    batch_size = 2
+
+    # Create teacher model (slightly larger)
+    teacher_model = get_mcore_gpt_model(
+        tensor_model_parallel_size=size,
+        pipeline_model_parallel_size=1,
+        initialize_megatron=True,
+        num_layers=num_layers,
+        hidden_size=hidden_size,
+        num_attention_heads=num_attention_heads,
+        num_query_groups=num_query_groups,
+        ffn_hidden_size=ffn_hidden_size,
+        max_sequence_length=max_sequence_length,
+        vocab_size=vocab_size,
+        activation_func="squared_relu",
+    ).cuda()
+
+    # Create student model (same size for simplicity)
+    student_model = get_mcore_gpt_model(
+        tensor_model_parallel_size=size,
+        pipeline_model_parallel_size=1,
+        initialize_megatron=False,
+        num_layers=num_layers,
+        hidden_size=hidden_size,
+        num_attention_heads=num_attention_heads,
+        num_query_groups=num_query_groups,
+        ffn_hidden_size=ffn_hidden_size,
+        max_sequence_length=max_sequence_length,
+        vocab_size=vocab_size,
+        activation_func="squared_relu",
+    ).cuda()
+
+    # Setup distillation config
+    distill_cfg = setup_distillation_config(
+        config_or_path=None,
+        student_cfg=student_model.config,
+        teacher_cfg=teacher_model.config,
+    )
+
+    # Convert to distillation model
+    kd_config = {
+        "teacher_model": teacher_model,
+        "criterion": distill_cfg.criterion,
+        "loss_balancer": distill_cfg.loss_balancer,
+    }
+    distillation_model = mtd.convert(student_model, mode=[("kd_loss", kd_config)])
+
+    # Apply Megatron-specific adjustments
+    adjust_distillation_model_for_mcore(distillation_model, distill_cfg)
+
+    # Forward pass with dummy input
+    distillation_model.train()
+    run_mcore_inference_with_dummy_input(distillation_model, batch_size, hidden_size)
+
+    # Forward and backward pass to verify gradients
+    prompt_tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda()
+    labels = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda()
+    position_ids = (
+        torch.arange(max_sequence_length, dtype=torch.long)
+        .unsqueeze(0)
+        .repeat(batch_size, 1)
+        .cuda()
+    )
+    attention_mask = torch.tril(
+        torch.ones((batch_size, 1, max_sequence_length, max_sequence_length), dtype=torch.bool)
+    ).cuda()
+
+    student_loss = distillation_model(prompt_tokens, position_ids, attention_mask, labels=labels)
+
+    # Compute distillation loss
+    loss = distillation_model.compute_kd_loss(
+        student_loss=student_loss, loss_reduction_fn=lambda x: x[0].mean()
+    )
+    assert isinstance(loss, dict), "Loss should be a dictionary"
+    assert "kd_loss" in loss, "Should contain kd_loss key"
+
+    # Backward pass
+    loss["kd_loss"].backward()
+
+
+def _test_topk_logits_kl_loss(top_k, rank, size):
+    """Test TopKLogitsKLLoss with simple forward/backward pass."""
+    channel_divisor = 4
+
+    num_layers = 2
+    hidden_size = channel_divisor * 2
+    num_attention_heads = 4
+    num_query_groups = 2
+    ffn_hidden_size = channel_divisor * 2
+    max_sequence_length = 8
+    vocab_size = 128
+    batch_size = 2
+
+    # Create teacher model
+    teacher_model = get_mcore_gpt_model(
+        tensor_model_parallel_size=size,
+        pipeline_model_parallel_size=1,
+        initialize_megatron=True,
+        num_layers=num_layers,
+        hidden_size=hidden_size,
+        num_attention_heads=num_attention_heads,
+        num_query_groups=num_query_groups,
+        ffn_hidden_size=ffn_hidden_size,
+        max_sequence_length=max_sequence_length,
+        vocab_size=vocab_size,
+        activation_func="squared_relu",
+    ).cuda()
+
+    # Create student model
+    student_model = get_mcore_gpt_model(
+        tensor_model_parallel_size=size,
+        pipeline_model_parallel_size=1,
+        initialize_megatron=False,
+        num_layers=num_layers,
+        hidden_size=hidden_size,
+        num_attention_heads=num_attention_heads,
+        num_query_groups=num_query_groups,
+        ffn_hidden_size=ffn_hidden_size,
+        max_sequence_length=max_sequence_length,
+        vocab_size=vocab_size,
+        activation_func="squared_relu",
+    ).cuda()
+
+    # Setup distillation config with TopKLogitsKLLoss via logit_kl_topk argument
+    distill_cfg = setup_distillation_config(
+        config_or_path=DistillationConfig(logit_kl_topk=top_k),
+        student_cfg=student_model.config,
+        teacher_cfg=teacher_model.config,
+    )
+
+    # Convert to distillation model
+    kd_config = {
+        "teacher_model": teacher_model,
+        "criterion": distill_cfg.criterion,
+        "loss_balancer": distill_cfg.loss_balancer,
+    }
+    distillation_model = mtd.convert(student_model, mode=[("kd_loss", kd_config)])
+
+    # Apply Megatron-specific adjustments
+    adjust_distillation_model_for_mcore(distillation_model, distill_cfg)
+
+    # Forward pass with dummy input
+    distillation_model.train()
+    run_mcore_inference_with_dummy_input(distillation_model, batch_size, hidden_size)
+
+    # Forward and backward pass to verify gradients
+    prompt_tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda()
+    labels = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda()
+    position_ids = (
+        torch.arange(max_sequence_length, dtype=torch.long)
+        .unsqueeze(0)
+        .repeat(batch_size, 1)
+        .cuda()
+    )
+    attention_mask = torch.tril(
+        torch.ones((batch_size, 1, max_sequence_length, max_sequence_length), dtype=torch.bool)
+    ).cuda()
+
+    student_loss = distillation_model(prompt_tokens, position_ids, attention_mask, labels=labels)
+
+    # Compute distillation loss
+    loss = distillation_model.compute_kd_loss(
+        student_loss=student_loss, loss_reduction_fn=lambda x: x[0].mean()
+    )
+    assert isinstance(loss, dict), "Loss should be a dictionary"
+    assert "kd_loss" in loss, "Should contain kd_loss key"
+
+    # Backward pass
+    loss["kd_loss"].backward()
+
+
+def test_logits_kl_loss():
+    """Test LogitsKLLoss with TP parallelism."""
+    set_seed(SEED)
+    spawn_multiprocess_job(
+        size=torch.cuda.device_count(),
+        job=_test_logits_kl_loss,
+        backend="nccl",
+    )
+
+
+def test_topk_logits_kl_loss(top_k: int = 5):
+    """Test TopKLogitsKLLoss with TP parallelism."""
+    set_seed(SEED)
+    spawn_multiprocess_job(
+        size=torch.cuda.device_count(),
+        job=partial(_test_topk_logits_kl_loss, top_k),
+        backend="nccl",
+    )
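
Usage sketch, assembled from the config change and the new test above; it is not part of the patch. `teacher_model` and `student_model` stand in for already-constructed MCore GPT models, and the top-k value of 100 is an arbitrary example. Setting `logit_kl_topk` on `DistillationConfig` is the only switch needed to get `TopKLogitsKLLoss` instead of the dense `LogitsKLLoss`; leaving it at the default `None` keeps the existing behavior.

import modelopt.torch.distill as mtd
from modelopt.torch.distill.plugins.megatron import (
    DistillationConfig,
    adjust_distillation_model_for_mcore,
    setup_distillation_config,
)

# Build the distillation criterion; the last pipeline stage gets TopKLogitsKLLoss
# because logit_kl_topk is not None.
distill_cfg = setup_distillation_config(
    config_or_path=DistillationConfig(logit_kl_topk=100),
    student_cfg=student_model.config,   # assumed pre-built student model
    teacher_cfg=teacher_model.config,   # assumed pre-built teacher model
)

# Wrap the student exactly as the test does.
kd_config = {
    "teacher_model": teacher_model,
    "criterion": distill_cfg.criterion,
    "loss_balancer": distill_cfg.loss_balancer,
}
distillation_model = mtd.convert(student_model, mode=[("kd_loss", kd_config)])
adjust_distillation_model_for_mcore(distillation_model, distill_cfg)

For the loss itself, a minimal single-rank reference (tensor parallel size 1, temperature 1.0) of what `TopKLogitsKLLoss.forward` computes, using standalone dummy tensors rather than model outputs:

import torch
import torch.nn.functional as F

s, b, vocab, top_k = 8, 2, 32, 5
teacher_logits = torch.randn(s, b, vocab)                      # [s, b, v] dummy teacher logits
student_logits = torch.randn(s, b, vocab, requires_grad=True)  # [s, b, v] dummy student logits

# Keep the teacher's top-k logits and the student's logits at the same vocab indices,
# renormalize both over that restricted vocabulary, then sum the pointwise KL terms.
top_teacher_vals, top_idx = torch.topk(teacher_logits, top_k, dim=-1)
top_student_vals = torch.gather(student_logits, dim=-1, index=top_idx)
p = F.log_softmax(top_student_vals, dim=-1)
q = F.log_softmax(top_teacher_vals, dim=-1)
loss = F.kl_div(p, q, reduction="none", log_target=True).sum(dim=-1)  # [s, b]
loss.mean().backward()  # gradients reach the student only through the kept entries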