211 changes: 128 additions & 83 deletions modelopt/torch/distill/plugins/megatron.py
@@ -26,6 +26,7 @@
 from typing import TYPE_CHECKING
 
 import torch
+import torch.distributed.nn as dist_nn
 import torch.nn as nn
 import torch.nn.functional as F
 import yaml
@@ -57,13 +58,15 @@ class DistillationConfig:
         skip_lm_loss: Whether to skip computing the standard language model loss (default: ``True``).
         kd_loss_scale: Relative scaling factor for the distillation loss if ``skip_lm_loss`` is ``False``.
         logit_kl_temperature: Temperature for the logit KL-divergence loss.
+        logit_kl_topk: If not None, use TopKLogitsKLLoss instead of LogitsKLLoss with this top-k value.
     """
 
     intermediate_layer_pairs: list[tuple[str, ...]] = field(default_factory=list)
     logit_layers: tuple[str, str] = ("output_layer", "output_layer")
     skip_lm_loss: bool = True
     kd_loss_scale: float = 1.0
     logit_kl_temperature: float = 1.0
+    logit_kl_topk: int | None = None
     criterion: Criterion | None = None
     loss_balancer: mtd.DistillationLossBalancer | None = None
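For reference, a minimal sketch of constructing the updated config directly in Python; the field values are illustrative and `logit_kl_topk` is the only new knob (leaving it at its default of `None` keeps the original dense `LogitsKLLoss` path):

# Illustrative values only; logit_kl_topk=None (the default) preserves the old behavior.
kd_cfg = DistillationConfig(
    logit_layers=("output_layer", "output_layer"),
    skip_lm_loss=True,
    kd_loss_scale=1.0,
    logit_kl_temperature=1.0,
    logit_kl_topk=100,  # hypothetical K: restrict the logit KL to the teacher's top 100 entries
)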

@@ -123,9 +126,15 @@ def setup_distillation_config(
     if cfg.criterion is None:
         criterion = {}
         if parallel_state.is_pipeline_last_stage():
-            criterion[tuple(cfg.logit_layers)] = LogitsKLLoss(
-                student_cfg, temperature=cfg.logit_kl_temperature
-            )
+            # Use TopKLogitsKLLoss if logit_kl_topk is specified, otherwise use LogitsKLLoss
+            if cfg.logit_kl_topk is not None:
+                criterion[tuple(cfg.logit_layers)] = TopKLogitsKLLoss(
+                    student_cfg, temperature=cfg.logit_kl_temperature, top_k=cfg.logit_kl_topk
+                )
+            else:
+                criterion[tuple(cfg.logit_layers)] = LogitsKLLoss(
+                    student_cfg, temperature=cfg.logit_kl_temperature
+                )
         # NOTE: Projection layer shared among intermediate layer pairs.
         projection_layer = ProjectionLayer(student_cfg, teacher_cfg)
 
@@ -310,81 +319,143 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor:
         predictions, targets = self.pre_forward(predictions, targets)
 
         # Division by temp should happen prior to finding max for both student and teacher.
         # Currently we don't use temperature in any of ours runs (temp=1.0)
         output_teacher = targets.float() / self._temperature
         output_student = predictions.float() / self._temperature
 
         # Compute local softmax, and the reweight to compute global softmax.
         if self._config.tensor_model_parallel_size > 1:
-            # Maximum value along vocab dimension across all GPUs.
-            teacher_logits_max, _ = torch.max(output_teacher, dim=-1)
+            tp_group = parallel_state.get_tensor_model_parallel_group()
+
+            # Subtract maximum value along vocab dimension across all GPUs (for stability)
+            teacher_logits_max, _ = torch.max(output_teacher, dim=-1, keepdim=True)
             torch.distributed.all_reduce(
                 teacher_logits_max,
                 op=torch.distributed.ReduceOp.MAX,
-                group=parallel_state.get_tensor_model_parallel_group(),
+                group=tp_group,
             )
-            output_teacher = output_teacher - teacher_logits_max.unsqueeze(dim=-1)
+            output_teacher -= teacher_logits_max
 
-            denom_teacher = torch.sum(torch.exp(output_teacher), dim=-1)
-            # We can't use standard reduction function here since the computation
-            # that follows it isn't identical across TP ranks.
-            denom_teacher = all_reduce_autograd(
-                denom_teacher, group=parallel_state.get_tensor_model_parallel_group()
-            )
-
-            # Maximum value along vocab dimension across all GPUs.
-            student_logits_max, _ = torch.max(output_student, dim=-1)
+            student_logits_max, _ = torch.max(output_student, dim=-1, keepdim=True)
             torch.distributed.all_reduce(
                 student_logits_max,
                 op=torch.distributed.ReduceOp.MAX,
-                group=parallel_state.get_tensor_model_parallel_group(),
+                group=tp_group,
             )
-            output_student = output_student - student_logits_max.unsqueeze(dim=-1).detach()
+            output_student -= student_logits_max.detach()
 
-            denom_student = torch.sum(torch.exp(output_student), dim=-1)
-            denom_student = all_reduce_autograd(
-                denom_student, group=parallel_state.get_tensor_model_parallel_group()
-            )
-
-            slen, bsz, sharded_vocab_size = output_student.shape
-            student_log_prob = output_student - torch.log(denom_student).view(slen, bsz, 1).expand(
-                slen, bsz, sharded_vocab_size
-            )
-            teacher_log_prob = output_teacher - torch.log(denom_teacher).view(slen, bsz, 1).expand(
-                slen, bsz, sharded_vocab_size
-            )
-
-            if self._reverse:
-                loss = torch.sum(
-                    F.kl_div(teacher_log_prob, student_log_prob, reduction="none", log_target=True),
-                    dim=-1,
-                )
-            else:
-                loss = torch.sum(
-                    F.kl_div(student_log_prob, teacher_log_prob, reduction="none", log_target=True),
-                    dim=-1,
-                )
-
-        elif self._reverse:
-            loss = torch.sum(
-                F.kl_div(
-                    F.log_softmax(output_teacher, dim=-1),
-                    F.softmax(output_student, dim=-1),
-                    reduction="none",
-                ),
-                dim=-1,
-            )
-        else:
-            loss = torch.sum(
-                F.kl_div(
-                    F.log_softmax(output_student, dim=-1),
-                    F.softmax(output_teacher, dim=-1),
-                    reduction="none",
-                ),
-                dim=-1,
-            )
-
-        return self.post_forward(loss, tp_reduce=True)
+            # Compute global softmax denominators
+            # We can't use standard all_reduce function here since the computation
+            # that follows it isn't identical across TP ranks.
+            denom_teacher = torch.sum(torch.exp(output_teacher), dim=-1, keepdim=True)
+            denom_teacher = dist_nn.functional.all_reduce(denom_teacher, group=tp_group)
+
+            denom_student = torch.sum(torch.exp(output_student), dim=-1, keepdim=True)
+            denom_student = dist_nn.functional.all_reduce(denom_student, group=tp_group)
+
+            # Compute log probabilities (log softmax)
+            teacher_log_prob = output_teacher - torch.log(denom_teacher)
+            student_log_prob = output_student - torch.log(denom_student)
+
+            # KL divergence
+            p, q = student_log_prob, teacher_log_prob
+        else:
+            # Compute log probabilities
+            p, q = F.log_softmax(output_student, dim=-1), F.log_softmax(output_teacher, dim=-1)
+
+        # KL divergence
+        if self._reverse:
+            p, q = q, p
+        loss = torch.sum(F.kl_div(p, q, reduction="none", log_target=True), dim=-1)
+
+        return self.post_forward(loss, tp_reduce=True)
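The tensor-parallel branch above never materializes the full-vocabulary softmax: it all-reduces the per-rank max (for numerical stability) and the per-rank exp-sum, then subtracts the log of that global denominator. A single-process sketch of why this reproduces the dense log-softmax, faking two TP ranks with `torch.chunk` instead of real collectives:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
full_logits = torch.randn(3, 2, 8)            # [seq, batch, vocab]
shards = torch.chunk(full_logits, 2, dim=-1)  # stand-ins for two TP ranks' vocab shards

# Equivalent of all_reduce(MAX): global max over the sharded vocab dimension.
global_max = torch.maximum(*[s.max(dim=-1, keepdim=True).values for s in shards])
shifted = [s - global_max for s in shards]

# Equivalent of all_reduce(SUM): global softmax denominator.
denom = sum(s.exp().sum(dim=-1, keepdim=True) for s in shifted)

# Per-shard log-probabilities; concatenated, they match the dense log_softmax.
log_probs = torch.cat([s - denom.log() for s in shifted], dim=-1)
assert torch.allclose(log_probs, F.log_softmax(full_logits, dim=-1), atol=1e-5)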


+class TopKLogitsKLLoss(LogitsKLLoss):
+    """Calculates KL-Divergence loss restricted to the Teacher's Top-K vocabulary entries.
+
+    Calculates using the global Top-K entries without gathering full logits.
+    NOTE: Will gather Top-K logits per rank, so mind the value of K for memory and communication.
+    """
+
+    def __init__(
+        self,
+        model_config: "TransformerConfig",
+        temperature: float = 1.0,
+        reverse: bool = False,
+        top_k: int = 1000,
+    ):
+        """Constructor.
+
+        Args:
+            model_config: MCore transformer config.
+            temperature: Divide tensors by this value prior to calculating loss.
+            reverse: Whether to reverse the loss as KLD(teacher, student) instead of KLD(student, teacher)
+            top_k: The number of top vocabulary entries to keep from the teacher's distribution.
+        """
+        super().__init__(model_config, temperature, reverse)
+        self.top_k = top_k
+
+    def forward(self, predictions: Tensor, targets: Tensor) -> Tensor:
+        """Forward function.
+
+        Args:
+            predictions: Student model tensors (size [s, b, h])
+            targets: Teacher model tensors (size [s, b, h])
+
+        Returns:
+            Top-K KLD loss of tensors (size [b, s])
+        """
+        predictions, targets = self.pre_forward(predictions, targets)
+
+        tp_size = self._config.tensor_model_parallel_size
+        assert self.top_k <= targets.size(-1) * tp_size, (
+            f"top_k ({self.top_k}) is larger than total vocab size ({targets.size(-1) * tp_size})"
+        )
+
+        # Divide by temperature first
+        output_teacher = targets.float() / self._temperature
+        output_student = predictions.float() / self._temperature
+
+        # Extract local Top-K
+        # We take K from each rank and then find the global Top-K of all those.
+        local_top_k = min(self.top_k, targets.size(-1))
+        top_teacher_vals, top_idx = torch.topk(output_teacher, local_top_k, dim=-1)
+        top_student_vals = torch.gather(output_student, dim=-1, index=top_idx)
+
+        if tp_size > 1:
+            tp_group = parallel_state.get_tensor_model_parallel_group()
+
+            # Gather all candidates into shape [N_rows, local_k * tp_size]
+            # Use all_gather from torch.distributed.nn.functional to preserve gradients
+            all_teacher_vals = dist_nn.functional.all_gather(
+                top_teacher_vals.contiguous(), group=tp_group
+            )
+            all_student_vals = dist_nn.functional.all_gather(
+                top_student_vals.contiguous(), group=tp_group
+            )
+            all_teacher_vals = torch.cat(all_teacher_vals, dim=-1)
+            all_student_vals = torch.cat(all_student_vals, dim=-1)
+
+            # Pick the true Top-K based on Teacher values
+            global_top_vals, global_top_idx = torch.topk(all_teacher_vals, self.top_k, dim=-1)
+
+            final_teacher_logits = global_top_vals
+            final_student_logits = torch.gather(all_student_vals, dim=-1, index=global_top_idx)
+        else:
+            final_teacher_logits = top_teacher_vals
+            final_student_logits = top_student_vals
+
+        # Standard (dense) Softmax + KL
+        p = F.log_softmax(final_student_logits, dim=-1)
+        q = F.log_softmax(final_teacher_logits, dim=-1)
+
+        # KL divergence
+        if self._reverse:
+            p, q = q, p
+        loss = torch.sum(F.kl_div(p, q, reduction="none", log_target=True), dim=-1)
+
+        # No need to reduce since all ranks compute same global Top-K
+        return self.post_forward(loss, tp_reduce=False)
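The two-stage selection above (local top-k on each rank, all-gather only those candidates, then a global top-k ranked by teacher values) yields the same entries as a top-k over the full vocabulary, because any entry in the global top-k must already be in its own shard's local top-k. A single-process sketch of that equivalence, again emulating the TP shards with `torch.chunk`:

import torch

torch.manual_seed(0)
k = 4
teacher_full = torch.randn(5, 16)              # [rows, full vocab]
shards = torch.chunk(teacher_full, 2, dim=-1)  # stand-ins for two TP ranks

# Stage 1: local top-k per "rank"; stage 2: global top-k over the gathered candidates.
local_vals = [torch.topk(s, k, dim=-1).values for s in shards]
candidate_vals = torch.cat(local_vals, dim=-1)  # emulates all_gather + cat
two_stage_vals, _ = torch.topk(candidate_vals, k, dim=-1)

# Reference: top-k taken directly over the full vocabulary.
reference_vals, _ = torch.topk(teacher_full, k, dim=-1)
assert torch.equal(two_stage_vals, reference_vals)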


class LogitsAndIntermediatesLossBalancer(mtd.DistillationLossBalancer):
@@ -417,7 +488,7 @@ def forward(self, loss_dict: dict[str, Tensor]) -> Tensor:
         """
         original_loss = loss_dict.pop(mtd.loss_balancers.STUDENT_LOSS_KEY)
         for _key in loss_dict:
-            if _key.startswith(LogitsKLLoss.__name__):
+            if "Logits" in _key:  # class name
                 logits_key = _key  # should only be one
         logits_loss = loss_dict.pop(logits_key)
         intermediate_loss = sum(loss_dict.values()) / max(len(loss_dict), 1)
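The looser match is what lets the balancer pick up either logit loss: the keys are derived from the loss class name, and "TopKLogitsKLLoss" does not start with "LogitsKLLoss", while both contain "Logits". A tiny check (the key strings here are illustrative, not the module's actual key format):

# Illustrative class-name strings only.
assert "Logits" in "LogitsKLLoss" and "Logits" in "TopKLogitsKLLoss"  # new check matches both
assert not "TopKLogitsKLLoss".startswith("LogitsKLLoss")              # old check missed the Top-K loss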
@@ -481,32 +552,6 @@ def _init_weights(self, module):
                 module.bias.data.zero_()
 
 
-class _AllReduce(torch.autograd.Function):
-    """Implementation from old PyTorch `torch.distributed.nn.parallel`."""
-
-    @staticmethod
-    def forward(ctx, op, group, tensor):
-        ctx.group, ctx.op = group, op
-        tensor = tensor.clone()
-        torch.distributed.all_reduce(tensor, op=op, group=group)
-        return tensor
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        return (None, None, _AllReduce.apply(ctx.op, ctx.group, grad_output))
-
-
-def all_reduce_autograd(
-    tensor, op=torch.distributed.ReduceOp.SUM, group=torch.distributed.group.WORLD
-):
-    """Custom all-reduce function.
-
-    Needed instead of other all-reduce functions available when the computation following
-    the all-reduce call differs per rank. In KL loss, this corresponds to the different numerators.
-    """
-    return _AllReduce.apply(op, group, tensor)
-
-
 ########################################################