From 8bc29ace73f7a6d888b8a5303141ff224cb317ce Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Tue, 6 Jan 2026 11:06:24 -0800 Subject: [PATCH 1/8] First version of TopK (doesn't save memory) Signed-off-by: Asha Anoosheh --- modelopt/torch/distill/plugins/megatron.py | 195 +++++++++++++++------ 1 file changed, 143 insertions(+), 52 deletions(-) diff --git a/modelopt/torch/distill/plugins/megatron.py b/modelopt/torch/distill/plugins/megatron.py index 500921ce3..907783973 100644 --- a/modelopt/torch/distill/plugins/megatron.py +++ b/modelopt/torch/distill/plugins/megatron.py @@ -297,54 +297,37 @@ def __init__( self._temperature = temperature self._reverse = reverse - def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: - """Forward function. - - Args: - predictions: Student model tensors (size [s, b, h]) - targets: Teacher model tensors (size [s, b, h]) - - Returns: - KLD loss of tensors (size [b, s]) - """ - predictions, targets = self.pre_forward(predictions, targets) - - # Division by temp should happen prior to finding max for both student and teacher. - # Currently we don't use temperature in any of ours runs (temp=1.0) - output_teacher = targets.float() / self._temperature - output_student = predictions.float() / self._temperature - + def _calculate_kld(self, output_teacher: Tensor, output_student: Tensor) -> Tensor: + """Calculate KLD loss between two tensors.""" # Compute local softmax, and the reweight to compute global softmax. if self._config.tensor_model_parallel_size > 1: + tp_group = parallel_state.get_tensor_model_parallel_group() + # Maximum value along vocab dimension across all GPUs. teacher_logits_max, _ = torch.max(output_teacher, dim=-1) torch.distributed.all_reduce( teacher_logits_max, op=torch.distributed.ReduceOp.MAX, - group=parallel_state.get_tensor_model_parallel_group(), + group=tp_group, ) - output_teacher = output_teacher - teacher_logits_max.unsqueeze(dim=-1) + output_teacher -= teacher_logits_max.unsqueeze(dim=-1) denom_teacher = torch.sum(torch.exp(output_teacher), dim=-1) # We can't use standard reduction function here since the computation # that follows it isn't identical across TP ranks. - denom_teacher = all_reduce_autograd( - denom_teacher, group=parallel_state.get_tensor_model_parallel_group() - ) + denom_teacher = all_reduce_autograd(denom_teacher, group=tp_group) # Maximum value along vocab dimension across all GPUs. 
student_logits_max, _ = torch.max(output_student, dim=-1) torch.distributed.all_reduce( student_logits_max, op=torch.distributed.ReduceOp.MAX, - group=parallel_state.get_tensor_model_parallel_group(), + group=tp_group, ) - output_student = output_student - student_logits_max.unsqueeze(dim=-1).detach() + output_student -= student_logits_max.unsqueeze(dim=-1).detach() denom_student = torch.sum(torch.exp(output_student), dim=-1) - denom_student = all_reduce_autograd( - denom_student, group=parallel_state.get_tensor_model_parallel_group() - ) + denom_student = all_reduce_autograd(denom_student, group=tp_group) slen, bsz, sharded_vocab_size = output_student.shape student_log_prob = output_student - torch.log(denom_student).view(slen, bsz, 1).expand( @@ -355,34 +338,142 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: ) if self._reverse: - loss = torch.sum( - F.kl_div(teacher_log_prob, student_log_prob, reduction="none", log_target=True), - dim=-1, - ) + p, q = teacher_log_prob, student_log_prob else: - loss = torch.sum( - F.kl_div(student_log_prob, teacher_log_prob, reduction="none", log_target=True), - dim=-1, - ) + p, q = student_log_prob, teacher_log_prob + + loss = torch.sum(F.kl_div(p, q, reduction="none", log_target=True), dim=-1) - elif self._reverse: - loss = torch.sum( - F.kl_div( - F.log_softmax(output_teacher, dim=-1), - F.softmax(output_student, dim=-1), - reduction="none", - ), - dim=-1, - ) else: - loss = torch.sum( - F.kl_div( - F.log_softmax(output_student, dim=-1), - F.softmax(output_teacher, dim=-1), - reduction="none", - ), - dim=-1, - ) + if self._reverse: + p, q = F.log_softmax(output_teacher, dim=-1), F.softmax(output_student, dim=-1) + else: + p, q = F.log_softmax(output_student, dim=-1), F.softmax(output_teacher, dim=-1) + + loss = torch.sum(F.kl_div(p, q, reduction="none", log_target=True), dim=-1) + + return loss + + def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: + """Forward function. + + Args: + predictions: Student model tensors (size [s, b, h]) + targets: Teacher model tensors (size [s, b, h]) + + Returns: + KLD loss of tensors (size [b, s]) + """ + predictions, targets = self.pre_forward(predictions, targets) + + # Division by temp should happen prior to finding max for both student and teacher. + output_teacher = targets.float() / self._temperature + output_student = predictions.float() / self._temperature + + loss = self._calculate_kld(output_teacher, output_student) + + return self.post_forward(loss, tp_reduce=True) + + +class TopKLogitsKLLoss(LogitsKLLoss): + """Calculates KL-Divergence loss restricted to the Teacher's Top-K vocabulary entries. + + Respects Tensor Parallelism by finding the global Top-K threshold without + gathering full logits. + """ + + def __init__( + self, + model_config: "TransformerConfig", + temperature: float = 1.0, + reverse: bool = False, + top_k: int = 100, + ): + """Constructor. + + Args: + model_config: MCore transformer config. + temperature: Divide tensors by this value prior to calculating loss. + reverse: Whether to reverse the loss as KLD(teacher, student) instead of KLD(student, teacher) + top_k: The number of top vocabulary entries to keep from the teacher's distribution. + """ + super().__init__(model_config, temperature, reverse) + self.top_k = top_k + + def _get_global_min_threshold(self, local_teacher_logits: Tensor) -> Tensor: + """Determines the cutoff value (threshold) for the global Top-K elements. 
+ + Args: + local_teacher_logits: Tensor of shape [s, b, h] + + Returns: + threshold: Tensor of shape [s, b, 1] containing the K-th largest value globally. + """ + # 1. Get Local Top-K values (we don't need indices, just values) + # We clamp k to the local vocab size to avoid errors if vocab < k + local_k = min(self.top_k, local_teacher_logits.size(-1)) + local_top_vals, _ = torch.topk(local_teacher_logits, local_k, dim=-1) # [s, b, k] + + # If TP is 1, local is global + if self._config.tensor_model_parallel_size == 1: + return local_top_vals[..., -1:] + + # 2. Gather these candidates from all TP ranks + # Resulting shape will be effectively [s, b, k * tp_size] + gathered_list = [ + torch.zeros_like(local_top_vals) for _ in range(self._config.tensor_model_parallel_size) + ] + torch.distributed.all_gather( + gathered_list, + local_top_vals.contiguous(), + group=parallel_state.get_tensor_model_parallel_group(), + ) + + # Concatenate along the top-k dimension + global_candidates = torch.cat(gathered_list, dim=-1) + + # 3. Find the global Top-K from the candidates + # The K-th value here is the global threshold. + # Note: We must ensure we don't ask for more than available if k*tp is small (unlikely) + global_k = min(self.top_k, global_candidates.size(-1)) + global_top_vals, _ = torch.topk(global_candidates, global_k, dim=-1) + + # The last element is the smallest of the top K, i.e., the threshold. + threshold = global_top_vals[..., -1:] # [s, b, 1] + + return threshold + + def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: + """Forward function. + + Args: + predictions: Student model tensors (size [s, b, h]) + targets: Teacher model tensors (size [s, b, h]) + + Returns: + KLD loss of tensors (size [b, s]) + """ + predictions, targets = self.pre_forward(predictions, targets) + + # Apply Temperature + output_teacher = targets.float() / self._temperature + output_student = predictions.float() / self._temperature + + # We determine the mask based on teacher's confidence. + with torch.no_grad(): + threshold = self._get_global_min_threshold(output_teacher) + + # Create mask: True if value is NOT in top-k + # We use strict inequality (<) for the threshold cut. + # (Ties might include slightly more than K, which is acceptable and numerically safer) + mask = output_teacher < threshold + + # Apply mask to both Teacher and Student + # Setting to -inf ensures they contribute 0 to the sum(exp(x)) later + output_teacher = output_teacher.masked_fill(mask, float("-inf")) + output_student = output_student.masked_fill(mask, float("-inf")) + + loss = self._calculate_kld(output_teacher, output_student) return self.post_forward(loss, tp_reduce=True) From 4fb7a569ede5757c525fe3a90147485e4746a0ad Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Wed, 7 Jan 2026 09:09:31 -0800 Subject: [PATCH 2/8] Streamline Signed-off-by: Asha Anoosheh --- modelopt/torch/distill/plugins/megatron.py | 36 +++++++++------------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/modelopt/torch/distill/plugins/megatron.py b/modelopt/torch/distill/plugins/megatron.py index 907783973..f8d94c678 100644 --- a/modelopt/torch/distill/plugins/megatron.py +++ b/modelopt/torch/distill/plugins/megatron.py @@ -303,7 +303,7 @@ def _calculate_kld(self, output_teacher: Tensor, output_student: Tensor) -> Tens if self._config.tensor_model_parallel_size > 1: tp_group = parallel_state.get_tensor_model_parallel_group() - # Maximum value along vocab dimension across all GPUs. 
+ # Subtract maximum value along vocab dimension across all GPUs (for stability) teacher_logits_max, _ = torch.max(output_teacher, dim=-1) torch.distributed.all_reduce( teacher_logits_max, @@ -317,7 +317,7 @@ def _calculate_kld(self, output_teacher: Tensor, output_student: Tensor) -> Tens # that follows it isn't identical across TP ranks. denom_teacher = all_reduce_autograd(denom_teacher, group=tp_group) - # Maximum value along vocab dimension across all GPUs. + # Subtract maximum value along vocab dimension across all GPUs (for stability) student_logits_max, _ = torch.max(output_student, dim=-1) torch.distributed.all_reduce( student_logits_max, @@ -400,7 +400,7 @@ def __init__( super().__init__(model_config, temperature, reverse) self.top_k = top_k - def _get_global_min_threshold(self, local_teacher_logits: Tensor) -> Tensor: + def _get_global_min_threshold(self, output_teacher: Tensor) -> Tensor: """Determines the cutoff value (threshold) for the global Top-K elements. Args: @@ -409,37 +409,29 @@ def _get_global_min_threshold(self, local_teacher_logits: Tensor) -> Tensor: Returns: threshold: Tensor of shape [s, b, 1] containing the K-th largest value globally. """ - # 1. Get Local Top-K values (we don't need indices, just values) - # We clamp k to the local vocab size to avoid errors if vocab < k - local_k = min(self.top_k, local_teacher_logits.size(-1)) - local_top_vals, _ = torch.topk(local_teacher_logits, local_k, dim=-1) # [s, b, k] + # Get Local Top-K values + assert self.top_k <= output_teacher.size(-1), f"{self.top_k=}, {output_teacher.size(-1)=}" + local_top_vals, _ = torch.topk(output_teacher, self.top_k, dim=-1) # [s, b, k] # If TP is 1, local is global if self._config.tensor_model_parallel_size == 1: return local_top_vals[..., -1:] - # 2. Gather these candidates from all TP ranks - # Resulting shape will be effectively [s, b, k * tp_size] - gathered_list = [ + # Gather these candidates from all TP ranks + global_candidates = [ torch.zeros_like(local_top_vals) for _ in range(self._config.tensor_model_parallel_size) ] torch.distributed.all_gather( - gathered_list, + global_candidates, local_top_vals.contiguous(), group=parallel_state.get_tensor_model_parallel_group(), ) + global_candidates = torch.cat(global_candidates, dim=-1) # [s, b, k * tp_size] - # Concatenate along the top-k dimension - global_candidates = torch.cat(gathered_list, dim=-1) - - # 3. Find the global Top-K from the candidates - # The K-th value here is the global threshold. - # Note: We must ensure we don't ask for more than available if k*tp is small (unlikely) - global_k = min(self.top_k, global_candidates.size(-1)) - global_top_vals, _ = torch.topk(global_candidates, global_k, dim=-1) - - # The last element is the smallest of the top K, i.e., the threshold. - threshold = global_top_vals[..., -1:] # [s, b, 1] + # Find the global Top-K from the candidates + # The smallest of the top K (last element) is the global threshold. 
+ global_top_vals, _ = torch.topk(global_candidates, self.top_k, dim=-1) + threshold = global_top_vals[..., -1:] return threshold From 1531c9ddf4e63004ddcdb19e7e9084abe97c9d73 Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Wed, 7 Jan 2026 10:52:53 -0800 Subject: [PATCH 3/8] Version 2 with scatter_add Signed-off-by: Asha Anoosheh --- modelopt/torch/distill/plugins/megatron.py | 93 ++++++++++++++-------- 1 file changed, 60 insertions(+), 33 deletions(-) diff --git a/modelopt/torch/distill/plugins/megatron.py b/modelopt/torch/distill/plugins/megatron.py index f8d94c678..14e680812 100644 --- a/modelopt/torch/distill/plugins/megatron.py +++ b/modelopt/torch/distill/plugins/megatron.py @@ -337,19 +337,15 @@ def _calculate_kld(self, output_teacher: Tensor, output_student: Tensor) -> Tens slen, bsz, sharded_vocab_size ) + p, q = student_log_prob, teacher_log_prob if self._reverse: p, q = teacher_log_prob, student_log_prob - else: - p, q = student_log_prob, teacher_log_prob - loss = torch.sum(F.kl_div(p, q, reduction="none", log_target=True), dim=-1) else: + p, q = F.log_softmax(output_student, dim=-1), F.softmax(output_teacher, dim=-1) if self._reverse: p, q = F.log_softmax(output_teacher, dim=-1), F.softmax(output_student, dim=-1) - else: - p, q = F.log_softmax(output_student, dim=-1), F.softmax(output_teacher, dim=-1) - loss = torch.sum(F.kl_div(p, q, reduction="none", log_target=True), dim=-1) return loss @@ -435,39 +431,70 @@ def _get_global_min_threshold(self, output_teacher: Tensor) -> Tensor: return threshold - def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: - """Forward function. - - Args: - predictions: Student model tensors (size [s, b, h]) - targets: Teacher model tensors (size [s, b, h]) - - Returns: - KLD loss of tensors (size [b, s]) - """ - predictions, targets = self.pre_forward(predictions, targets) - - # Apply Temperature - output_teacher = targets.float() / self._temperature - output_student = predictions.float() / self._temperature - - # We determine the mask based on teacher's confidence. + def _calculate_kld(self, output_teacher: Tensor, output_student: Tensor) -> Tensor: + """Calculate KLD loss between two tensors.""" + # 1. Determine cutoff threshold based on teacher's confidence with torch.no_grad(): threshold = self._get_global_min_threshold(output_teacher) + mask = output_teacher >= threshold # elements to keep + + # Flatten tensors for simplicity + s, b = output_student.size(0), output_student.size(1) + output_teacher = output_teacher.view(s * b, -1) + output_student = output_student.view(s * b, -1) + mask = mask.view(s * b, -1) + + # 2. Extract values above threshold (Sparse Selection) + sel_teacher = torch.masked_select(output_teacher, mask) + sel_student = torch.masked_select(output_student, mask) + + # 3. Handle Indices + indices = torch.nonzero(mask) + # indices[:, 0] is exactly the row_index (0 to s*b-1) + # indices[:, 1] is the vocab_index (which we don't need for summation) + row_ids = indices[:, 0] + + # 4. Softmax Normalization + exp_teacher = torch.exp(sel_teacher) + exp_student = torch.exp(sel_student) + + # Prepare containers for the sums of shape [s * b] + denom_teacher = output_student.new_zeros(s * b) + denom_student = output_student.new_zeros(s * b) + + # We must use scatter_add because 'exp_teacher' is a 1D list of variable length + # segments. We need to sum "all values belonging to row 0", then "all for row 1", etc. 
+ denom_teacher.scatter_add_(0, row_ids, exp_teacher) + denom_student.scatter_add_(0, row_ids, exp_student) + + # Global Reduction (Tensor Parallelism) + if self._config.tensor_model_parallel_size > 1: + tp_group = parallel_state.get_tensor_model_parallel_group() + all_reduce_autograd(denom_teacher, group=tp_group) + all_reduce_autograd(denom_student, group=tp_group) - # Create mask: True if value is NOT in top-k - # We use strict inequality (<) for the threshold cut. - # (Ties might include slightly more than K, which is acceptable and numerically safer) - mask = output_teacher < threshold + # 5. KL Divergence + # Gather the calculated denominators back to the sparse elements + # If sel_teacher[i] belongs to row J, we divide by denom_teacher[J] + sel_denom_teacher = denom_teacher[row_ids] + sel_denom_student = denom_student[row_ids] - # Apply mask to both Teacher and Student - # Setting to -inf ensures they contribute 0 to the sum(exp(x)) later - output_teacher = output_teacher.masked_fill(mask, float("-inf")) - output_student = output_student.masked_fill(mask, float("-inf")) + log_prob_teacher = sel_teacher - torch.log(sel_denom_teacher) + log_prob_student = sel_student - torch.log(sel_denom_student) - loss = self._calculate_kld(output_teacher, output_student) + p, q = log_prob_student, log_prob_teacher + if self._reverse: + p, q = log_prob_teacher, log_prob_student + kl_elements = F.kl_div(p, q, reduction="none", log_target=True) - return self.post_forward(loss, tp_reduce=True) + # 6. Accumulate Loss + loss_flat = output_student.new_zeros(s * b) + loss_flat.scatter_add_(0, row_ids, kl_elements) + + # Reshape back to [s, b] for the post_forward step + loss = loss_flat.view(s, b) + + return loss class LogitsAndIntermediatesLossBalancer(mtd.DistillationLossBalancer): From 3ebf4be1779700e82e7cb7153e5175b96e089fd1 Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Wed, 7 Jan 2026 11:12:07 -0800 Subject: [PATCH 4/8] Simplify old KL loss Signed-off-by: Asha Anoosheh --- modelopt/torch/distill/plugins/megatron.py | 84 ++++++++++------------ 1 file changed, 37 insertions(+), 47 deletions(-) diff --git a/modelopt/torch/distill/plugins/megatron.py b/modelopt/torch/distill/plugins/megatron.py index 14e680812..0c7743897 100644 --- a/modelopt/torch/distill/plugins/megatron.py +++ b/modelopt/torch/distill/plugins/megatron.py @@ -297,76 +297,66 @@ def __init__( self._temperature = temperature self._reverse = reverse - def _calculate_kld(self, output_teacher: Tensor, output_student: Tensor) -> Tensor: - """Calculate KLD loss between two tensors.""" + def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: + """Forward function. + + Args: + predictions: Student model tensors (size [s, b, h]) + targets: Teacher model tensors (size [s, b, h]) + + Returns: + KLD loss of tensors (size [b, s]) + """ + predictions, targets = self.pre_forward(predictions, targets) + + # Division by temp should happen prior to finding max for both student and teacher. + output_teacher = targets.float() / self._temperature + output_student = predictions.float() / self._temperature + # Compute local softmax, and the reweight to compute global softmax. 
if self._config.tensor_model_parallel_size > 1: tp_group = parallel_state.get_tensor_model_parallel_group() # Subtract maximum value along vocab dimension across all GPUs (for stability) - teacher_logits_max, _ = torch.max(output_teacher, dim=-1) + teacher_logits_max, _ = torch.max(output_teacher, dim=-1, keepdim=True) torch.distributed.all_reduce( teacher_logits_max, op=torch.distributed.ReduceOp.MAX, group=tp_group, ) - output_teacher -= teacher_logits_max.unsqueeze(dim=-1) - - denom_teacher = torch.sum(torch.exp(output_teacher), dim=-1) - # We can't use standard reduction function here since the computation - # that follows it isn't identical across TP ranks. - denom_teacher = all_reduce_autograd(denom_teacher, group=tp_group) + output_teacher -= teacher_logits_max - # Subtract maximum value along vocab dimension across all GPUs (for stability) - student_logits_max, _ = torch.max(output_student, dim=-1) + student_logits_max, _ = torch.max(output_student, dim=-1, keepdim=True) torch.distributed.all_reduce( student_logits_max, op=torch.distributed.ReduceOp.MAX, group=tp_group, ) - output_student -= student_logits_max.unsqueeze(dim=-1).detach() + output_student -= student_logits_max.detach() - denom_student = torch.sum(torch.exp(output_student), dim=-1) + # Compute global softmax denominators + # We can't use standard reduction function here since the computation + # that follows it isn't identical across TP ranks. + denom_teacher = torch.sum(torch.exp(output_teacher), dim=-1, keepdim=True) + denom_teacher = all_reduce_autograd(denom_teacher, group=tp_group) + + denom_student = torch.sum(torch.exp(output_student), dim=-1, keepdim=True) denom_student = all_reduce_autograd(denom_student, group=tp_group) - slen, bsz, sharded_vocab_size = output_student.shape - student_log_prob = output_student - torch.log(denom_student).view(slen, bsz, 1).expand( - slen, bsz, sharded_vocab_size - ) - teacher_log_prob = output_teacher - torch.log(denom_teacher).view(slen, bsz, 1).expand( - slen, bsz, sharded_vocab_size - ) + # Compute log probabilities (log softmax) + teacher_log_prob = output_teacher - torch.log(denom_teacher) + student_log_prob = output_student - torch.log(denom_student) + # KL divergence p, q = student_log_prob, teacher_log_prob - if self._reverse: - p, q = teacher_log_prob, student_log_prob - loss = torch.sum(F.kl_div(p, q, reduction="none", log_target=True), dim=-1) - else: - p, q = F.log_softmax(output_student, dim=-1), F.softmax(output_teacher, dim=-1) - if self._reverse: - p, q = F.log_softmax(output_teacher, dim=-1), F.softmax(output_student, dim=-1) - loss = torch.sum(F.kl_div(p, q, reduction="none", log_target=True), dim=-1) - - return loss - - def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: - """Forward function. - - Args: - predictions: Student model tensors (size [s, b, h]) - targets: Teacher model tensors (size [s, b, h]) - - Returns: - KLD loss of tensors (size [b, s]) - """ - predictions, targets = self.pre_forward(predictions, targets) - - # Division by temp should happen prior to finding max for both student and teacher. 
- output_teacher = targets.float() / self._temperature - output_student = predictions.float() / self._temperature + # Compute log probabilities + p, q = F.log_softmax(output_student, dim=-1), F.log_softmax(output_teacher, dim=-1) - loss = self._calculate_kld(output_teacher, output_student) + # KL divergence + if self._reverse: + p, q = q, p + loss = torch.sum(F.kl_div(p, q, reduction="none", log_target=True), dim=-1) return self.post_forward(loss, tp_reduce=True) From e7d33a7804a58769621e61a98c451ec3b9acbdf4 Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Thu, 8 Jan 2026 08:45:17 -0800 Subject: [PATCH 5/8] Simplest top-k version with gather Signed-off-by: Asha Anoosheh --- modelopt/torch/distill/plugins/megatron.py | 131 +++++++-------------- 1 file changed, 45 insertions(+), 86 deletions(-) diff --git a/modelopt/torch/distill/plugins/megatron.py b/modelopt/torch/distill/plugins/megatron.py index 0c7743897..bbfb7e29a 100644 --- a/modelopt/torch/distill/plugins/megatron.py +++ b/modelopt/torch/distill/plugins/megatron.py @@ -364,8 +364,8 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: class TopKLogitsKLLoss(LogitsKLLoss): """Calculates KL-Divergence loss restricted to the Teacher's Top-K vocabulary entries. - Respects Tensor Parallelism by finding the global Top-K threshold without - gathering full logits. + Calculates using the global Top-K entries without gathering full logits. + NOTE: Will gather Top-K logits per rank, so mind the value of K for memory and communication. """ def __init__( @@ -373,7 +373,7 @@ def __init__( model_config: "TransformerConfig", temperature: float = 1.0, reverse: bool = False, - top_k: int = 100, + top_k: int = 1000, ): """Constructor. @@ -386,105 +386,64 @@ def __init__( super().__init__(model_config, temperature, reverse) self.top_k = top_k - def _get_global_min_threshold(self, output_teacher: Tensor) -> Tensor: - """Determines the cutoff value (threshold) for the global Top-K elements. + def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: + """Forward function. Args: - local_teacher_logits: Tensor of shape [s, b, h] + predictions: Student model tensors (size [s, b, h]) + targets: Teacher model tensors (size [s, b, h]) Returns: - threshold: Tensor of shape [s, b, 1] containing the K-th largest value globally. + Top-K KLD loss of tensors (size [b, s]) """ - # Get Local Top-K values - assert self.top_k <= output_teacher.size(-1), f"{self.top_k=}, {output_teacher.size(-1)=}" - local_top_vals, _ = torch.topk(output_teacher, self.top_k, dim=-1) # [s, b, k] - - # If TP is 1, local is global - if self._config.tensor_model_parallel_size == 1: - return local_top_vals[..., -1:] + predictions, targets = self.pre_forward(predictions, targets) - # Gather these candidates from all TP ranks - global_candidates = [ - torch.zeros_like(local_top_vals) for _ in range(self._config.tensor_model_parallel_size) - ] - torch.distributed.all_gather( - global_candidates, - local_top_vals.contiguous(), - group=parallel_state.get_tensor_model_parallel_group(), - ) - global_candidates = torch.cat(global_candidates, dim=-1) # [s, b, k * tp_size] + assert self.top_k <= targets.size(-1), f"{self.top_k=}, {targets.size(-1)=}" - # Find the global Top-K from the candidates - # The smallest of the top K (last element) is the global threshold. 
- global_top_vals, _ = torch.topk(global_candidates, self.top_k, dim=-1) - threshold = global_top_vals[..., -1:] + # Divide by temperature first + output_teacher = targets.float() / self._temperature + output_student = predictions.float() / self._temperature - return threshold + # Extract local Top-K + # We take K from each rank and then find the global Top-K of all those. + top_teacher_vals, top_idx = torch.topk(output_teacher, self.top_k, dim=-1) + top_student_vals = torch.gather(output_student, dim=-1, index=top_idx) - def _calculate_kld(self, output_teacher: Tensor, output_student: Tensor) -> Tensor: - """Calculate KLD loss between two tensors.""" - # 1. Determine cutoff threshold based on teacher's confidence - with torch.no_grad(): - threshold = self._get_global_min_threshold(output_teacher) - mask = output_teacher >= threshold # elements to keep - - # Flatten tensors for simplicity - s, b = output_student.size(0), output_student.size(1) - output_teacher = output_teacher.view(s * b, -1) - output_student = output_student.view(s * b, -1) - mask = mask.view(s * b, -1) - - # 2. Extract values above threshold (Sparse Selection) - sel_teacher = torch.masked_select(output_teacher, mask) - sel_student = torch.masked_select(output_student, mask) - - # 3. Handle Indices - indices = torch.nonzero(mask) - # indices[:, 0] is exactly the row_index (0 to s*b-1) - # indices[:, 1] is the vocab_index (which we don't need for summation) - row_ids = indices[:, 0] - - # 4. Softmax Normalization - exp_teacher = torch.exp(sel_teacher) - exp_student = torch.exp(sel_student) - - # Prepare containers for the sums of shape [s * b] - denom_teacher = output_student.new_zeros(s * b) - denom_student = output_student.new_zeros(s * b) - - # We must use scatter_add because 'exp_teacher' is a 1D list of variable length - # segments. We need to sum "all values belonging to row 0", then "all for row 1", etc. - denom_teacher.scatter_add_(0, row_ids, exp_teacher) - denom_student.scatter_add_(0, row_ids, exp_student) - - # Global Reduction (Tensor Parallelism) - if self._config.tensor_model_parallel_size > 1: + if (tp_size := self._config.tensor_model_parallel_size) > 1: tp_group = parallel_state.get_tensor_model_parallel_group() - all_reduce_autograd(denom_teacher, group=tp_group) - all_reduce_autograd(denom_student, group=tp_group) - # 5. 
KL Divergence - # Gather the calculated denominators back to the sparse elements - # If sel_teacher[i] belongs to row J, we divide by denom_teacher[J] - sel_denom_teacher = denom_teacher[row_ids] - sel_denom_student = denom_student[row_ids] + # Gather all candidates into shape [N_rows, local_k * tp_size] + all_teacher_vals = [torch.zeros_like(top_teacher_vals) for _ in range(tp_size)] + all_student_vals = [torch.zeros_like(top_student_vals) for _ in range(tp_size)] + torch.distributed.all_gather( + all_teacher_vals, top_teacher_vals.contiguous(), group=tp_group + ) + torch.distributed.all_gather( + all_student_vals, top_student_vals.contiguous(), group=tp_group + ) + all_teacher_vals = torch.cat(all_teacher_vals, dim=-1) + all_student_vals = torch.cat(all_student_vals, dim=-1) - log_prob_teacher = sel_teacher - torch.log(sel_denom_teacher) - log_prob_student = sel_student - torch.log(sel_denom_student) + # Pick the true Top-K based on Teacher values + global_top_vals, global_top_idx = torch.topk(all_teacher_vals, self.top_k, dim=-1) - p, q = log_prob_student, log_prob_teacher - if self._reverse: - p, q = log_prob_teacher, log_prob_student - kl_elements = F.kl_div(p, q, reduction="none", log_target=True) + final_teacher_logits = global_top_vals + final_student_logits = torch.gather(all_student_vals, dim=-1, index=global_top_idx) + else: + final_teacher_logits = top_teacher_vals + final_student_logits = top_student_vals - # 6. Accumulate Loss - loss_flat = output_student.new_zeros(s * b) - loss_flat.scatter_add_(0, row_ids, kl_elements) + # Standard (dense) Softmax + KL + p = F.log_softmax(final_student_logits, dim=-1) + q = F.log_softmax(final_teacher_logits, dim=-1) - # Reshape back to [s, b] for the post_forward step - loss = loss_flat.view(s, b) + # KL divergence + if self._reverse: + p, q = q, p + loss = torch.sum(F.kl_div(p, q, reduction="none", log_target=True), dim=-1) - return loss + # No need to reduce since all ranks compute same global Top-K + return self.post_forward(loss, tp_reduce=False) class LogitsAndIntermediatesLossBalancer(mtd.DistillationLossBalancer): From 1284f72f4193ece90a66b1a99b58909b2d85213c Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Fri, 9 Jan 2026 11:39:15 -0800 Subject: [PATCH 6/8] Add tests and fix bugs Signed-off-by: Asha Anoosheh --- modelopt/torch/distill/plugins/megatron.py | 62 ++--- .../torch/distill/plugins/test_megatron.py | 237 ++++++++++++++++++ 2 files changed, 259 insertions(+), 40 deletions(-) create mode 100644 tests/gpu/torch/distill/plugins/test_megatron.py diff --git a/modelopt/torch/distill/plugins/megatron.py b/modelopt/torch/distill/plugins/megatron.py index bbfb7e29a..80d3f61ca 100644 --- a/modelopt/torch/distill/plugins/megatron.py +++ b/modelopt/torch/distill/plugins/megatron.py @@ -26,6 +26,7 @@ from typing import TYPE_CHECKING import torch +import torch.distributed.nn as dist_nn import torch.nn as nn import torch.nn.functional as F import yaml @@ -57,6 +58,7 @@ class DistillationConfig: skip_lm_loss: Whether to skip computing the standard language model loss (default: ``True``). kd_loss_scale: Relative scaling factor for the distillation loss if ``skip_lm_loss`` is ``False``. logit_kl_temperature: Temperature for the logit KL-divergence loss. + logits_kl_topk: If not None, use TopKLogitsKLLoss instead of LogitsKLLoss with this top-k value. 
""" intermediate_layer_pairs: list[tuple[str, ...]] = field(default_factory=list) @@ -64,6 +66,7 @@ class DistillationConfig: skip_lm_loss: bool = True kd_loss_scale: float = 1.0 logit_kl_temperature: float = 1.0 + logits_kl_topk: int | None = None criterion: Criterion | None = None loss_balancer: mtd.DistillationLossBalancer | None = None @@ -123,9 +126,15 @@ def setup_distillation_config( if cfg.criterion is None: criterion = {} if parallel_state.is_pipeline_last_stage(): - criterion[tuple(cfg.logit_layers)] = LogitsKLLoss( - student_cfg, temperature=cfg.logit_kl_temperature - ) + # Use TopKLogitsKLLoss if logits_kl_topk is specified, otherwise use LogitsKLLoss + if cfg.logits_kl_topk is not None: + criterion[tuple(cfg.logit_layers)] = TopKLogitsKLLoss( + student_cfg, temperature=cfg.logit_kl_temperature, top_k=cfg.logits_kl_topk + ) + else: + criterion[tuple(cfg.logit_layers)] = LogitsKLLoss( + student_cfg, temperature=cfg.logit_kl_temperature + ) # NOTE: Projection layer shared among intermediate layer pairs. projection_layer = ProjectionLayer(student_cfg, teacher_cfg) @@ -335,13 +344,13 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: output_student -= student_logits_max.detach() # Compute global softmax denominators - # We can't use standard reduction function here since the computation + # We can't use standard all_reduce function here since the computation # that follows it isn't identical across TP ranks. denom_teacher = torch.sum(torch.exp(output_teacher), dim=-1, keepdim=True) - denom_teacher = all_reduce_autograd(denom_teacher, group=tp_group) + denom_teacher = dist_nn.functional.all_reduce(denom_teacher, group=tp_group) denom_student = torch.sum(torch.exp(output_student), dim=-1, keepdim=True) - denom_student = all_reduce_autograd(denom_student, group=tp_group) + denom_student = dist_nn.functional.all_reduce(denom_student, group=tp_group) # Compute log probabilities (log softmax) teacher_log_prob = output_teacher - torch.log(denom_teacher) @@ -409,17 +418,16 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: top_teacher_vals, top_idx = torch.topk(output_teacher, self.top_k, dim=-1) top_student_vals = torch.gather(output_student, dim=-1, index=top_idx) - if (tp_size := self._config.tensor_model_parallel_size) > 1: + if self._config.tensor_model_parallel_size > 1: tp_group = parallel_state.get_tensor_model_parallel_group() # Gather all candidates into shape [N_rows, local_k * tp_size] - all_teacher_vals = [torch.zeros_like(top_teacher_vals) for _ in range(tp_size)] - all_student_vals = [torch.zeros_like(top_student_vals) for _ in range(tp_size)] - torch.distributed.all_gather( - all_teacher_vals, top_teacher_vals.contiguous(), group=tp_group + # Use all_gather from torch.distributed.nn.functional to preserve gradients + all_teacher_vals = dist_nn.functional.all_gather( + top_teacher_vals.contiguous(), group=tp_group ) - torch.distributed.all_gather( - all_student_vals, top_student_vals.contiguous(), group=tp_group + all_student_vals = dist_nn.functional.all_gather( + top_student_vals.contiguous(), group=tp_group ) all_teacher_vals = torch.cat(all_teacher_vals, dim=-1) all_student_vals = torch.cat(all_student_vals, dim=-1) @@ -476,7 +484,7 @@ def forward(self, loss_dict: dict[str, Tensor]) -> Tensor: """ original_loss = loss_dict.pop(mtd.loss_balancers.STUDENT_LOSS_KEY) for _key in loss_dict: - if _key.startswith(LogitsKLLoss.__name__): + if "Logits" in _key: # class name logits_key = _key # should only be one logits_loss = 
loss_dict.pop(logits_key) intermediate_loss = sum(loss_dict.values()) / max(len(loss_dict), 1) @@ -540,32 +548,6 @@ def _init_weights(self, module): module.bias.data.zero_() -class _AllReduce(torch.autograd.Function): - """Implementation from old PyTorch `torch.distributed.nn.parallel`.""" - - @staticmethod - def forward(ctx, op, group, tensor): - ctx.group, ctx.op = group, op - tensor = tensor.clone() - torch.distributed.all_reduce(tensor, op=op, group=group) - return tensor - - @staticmethod - def backward(ctx, grad_output): - return (None, None, _AllReduce.apply(ctx.op, ctx.group, grad_output)) - - -def all_reduce_autograd( - tensor, op=torch.distributed.ReduceOp.SUM, group=torch.distributed.group.WORLD -): - """Custom all-reduce function. - - Needed instead of other all-reduce functions available when the computation following - the all-reduce call differs per rank. In KL loss, this corresponds to the different numerators. - """ - return _AllReduce.apply(op, group, tensor) - - ######################################################## diff --git a/tests/gpu/torch/distill/plugins/test_megatron.py b/tests/gpu/torch/distill/plugins/test_megatron.py new file mode 100644 index 000000000..a04659ce3 --- /dev/null +++ b/tests/gpu/torch/distill/plugins/test_megatron.py @@ -0,0 +1,237 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from functools import partial + +import torch +from _test_utils.import_helper import skip_if_no_megatron + +skip_if_no_megatron(apex_or_te_required=True) + +from _test_utils.torch.distributed.utils import spawn_multiprocess_job +from _test_utils.torch.megatron.models import get_mcore_gpt_model +from _test_utils.torch.megatron.utils import run_mcore_inference_with_dummy_input +from _test_utils.torch.misc import set_seed + +import modelopt.torch.distill as mtd +from modelopt.torch.distill.plugins.megatron import ( + DistillationConfig, + adjust_distillation_model_for_mcore, + setup_distillation_config, +) + +SEED = 1234 + + +def _test_logits_kl_loss(rank, size): + """Test basic LogitsKLLoss with simple forward/backward pass.""" + channel_divisor = 4 + + num_layers = 2 + hidden_size = channel_divisor * 2 + num_attention_heads = 4 + num_query_groups = 2 + ffn_hidden_size = channel_divisor * 2 + max_sequence_length = 8 + vocab_size = 32 + batch_size = 2 + + # Create teacher model (slightly larger) + teacher_model = get_mcore_gpt_model( + tensor_model_parallel_size=size, + pipeline_model_parallel_size=1, + initialize_megatron=True, + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_query_groups=num_query_groups, + ffn_hidden_size=ffn_hidden_size, + max_sequence_length=max_sequence_length, + vocab_size=vocab_size, + activation_func="squared_relu", + ).cuda() + + # Create student model (same size for simplicity) + student_model = get_mcore_gpt_model( + tensor_model_parallel_size=size, + pipeline_model_parallel_size=1, + initialize_megatron=False, + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_query_groups=num_query_groups, + ffn_hidden_size=ffn_hidden_size, + max_sequence_length=max_sequence_length, + vocab_size=vocab_size, + activation_func="squared_relu", + ).cuda() + + # Setup distillation config + distill_cfg = setup_distillation_config( + config_or_path=None, + student_cfg=student_model.config, + teacher_cfg=teacher_model.config, + ) + + # Convert to distillation model + kd_config = { + "teacher_model": teacher_model, + "criterion": distill_cfg.criterion, + "loss_balancer": distill_cfg.loss_balancer, + } + distillation_model = mtd.convert(student_model, mode=[("kd_loss", kd_config)]) + + # Apply Megatron-specific adjustments + adjust_distillation_model_for_mcore(distillation_model, distill_cfg) + + # Forward pass with dummy input + distillation_model.train() + run_mcore_inference_with_dummy_input(distillation_model, batch_size, hidden_size) + + # Forward and backward pass to verify gradients + prompt_tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda() + labels = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda() + position_ids = ( + torch.arange(max_sequence_length, dtype=torch.long) + .unsqueeze(0) + .repeat(batch_size, 1) + .cuda() + ) + attention_mask = torch.tril( + torch.ones((batch_size, 1, max_sequence_length, max_sequence_length), dtype=torch.bool) + ).cuda() + + student_loss = distillation_model(prompt_tokens, position_ids, attention_mask, labels=labels) + + # Compute distillation loss + loss = distillation_model.compute_kd_loss( + student_loss=student_loss, loss_reduction_fn=lambda x: x[0].mean() + ) + assert isinstance(loss, dict), "Loss should be a dictionary" + assert "kd_loss" in loss, "Should contain kd_loss key" + + # Backward pass + loss["kd_loss"].backward() + + +def _test_topk_logits_kl_loss(top_k, rank, size): + """Test 
TopKLogitsKLLoss with simple forward/backward pass.""" + channel_divisor = 4 + + num_layers = 2 + hidden_size = channel_divisor * 2 + num_attention_heads = 4 + num_query_groups = 2 + ffn_hidden_size = channel_divisor * 2 + max_sequence_length = 8 + vocab_size = 128 + batch_size = 2 + + # Create teacher model + teacher_model = get_mcore_gpt_model( + tensor_model_parallel_size=size, + pipeline_model_parallel_size=1, + initialize_megatron=True, + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_query_groups=num_query_groups, + ffn_hidden_size=ffn_hidden_size, + max_sequence_length=max_sequence_length, + vocab_size=vocab_size, + activation_func="squared_relu", + ).cuda() + + # Create student model + student_model = get_mcore_gpt_model( + tensor_model_parallel_size=size, + pipeline_model_parallel_size=1, + initialize_megatron=False, + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_query_groups=num_query_groups, + ffn_hidden_size=ffn_hidden_size, + max_sequence_length=max_sequence_length, + vocab_size=vocab_size, + activation_func="squared_relu", + ).cuda() + + # Setup distillation config with TopKLogitsKLLoss via logits_kl_topk argument + distill_cfg = setup_distillation_config( + config_or_path=DistillationConfig(logits_kl_topk=top_k), + student_cfg=student_model.config, + teacher_cfg=teacher_model.config, + ) + + # Convert to distillation model + kd_config = { + "teacher_model": teacher_model, + "criterion": distill_cfg.criterion, + "loss_balancer": distill_cfg.loss_balancer, + } + distillation_model = mtd.convert(student_model, mode=[("kd_loss", kd_config)]) + + # Apply Megatron-specific adjustments + adjust_distillation_model_for_mcore(distillation_model, distill_cfg) + + # Forward pass with dummy input + distillation_model.train() + run_mcore_inference_with_dummy_input(distillation_model, batch_size, hidden_size) + + # Forward and backward pass to verify gradients + prompt_tokens = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda() + labels = torch.randint(0, vocab_size, (batch_size, max_sequence_length)).cuda() + position_ids = ( + torch.arange(max_sequence_length, dtype=torch.long) + .unsqueeze(0) + .repeat(batch_size, 1) + .cuda() + ) + attention_mask = torch.tril( + torch.ones((batch_size, 1, max_sequence_length, max_sequence_length), dtype=torch.bool) + ).cuda() + + student_loss = distillation_model(prompt_tokens, position_ids, attention_mask, labels=labels) + + # Compute distillation loss + loss = distillation_model.compute_kd_loss( + student_loss=student_loss, loss_reduction_fn=lambda x: x[0].mean() + ) + assert isinstance(loss, dict), "Loss should be a dictionary" + assert "kd_loss" in loss, "Should contain kd_loss key" + + # Backward pass + loss["kd_loss"].backward() + + +def test_logits_kl_loss(): + """Test LogitsKLLoss with TP parallelism.""" + set_seed(SEED) + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=_test_logits_kl_loss, + backend="nccl", + ) + + +def test_topk_logits_kl_loss(top_k: int = 5): + """Test TopKLogitsKLLoss with TP parallelism.""" + set_seed(SEED) + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial(_test_topk_logits_kl_loss, top_k), + backend="nccl", + ) From 3093b8a57f42bc84e1fd266e9f337137cc82cd6b Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Fri, 9 Jan 2026 15:51:56 -0800 Subject: [PATCH 7/8] Prevent top-k size error Signed-off-by: Asha Anoosheh --- modelopt/torch/distill/plugins/megatron.py | 
10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/distill/plugins/megatron.py b/modelopt/torch/distill/plugins/megatron.py index 80d3f61ca..76af1497b 100644 --- a/modelopt/torch/distill/plugins/megatron.py +++ b/modelopt/torch/distill/plugins/megatron.py @@ -407,7 +407,10 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: """ predictions, targets = self.pre_forward(predictions, targets) - assert self.top_k <= targets.size(-1), f"{self.top_k=}, {targets.size(-1)=}" + tp_size = self._config.tensor_model_parallel_size + assert self.top_k <= targets.size(-1) * tp_size, ( + f"top_k ({self.top_k}) is larger than total vocab size ({targets.size(-1) * tp_size})" + ) # Divide by temperature first output_teacher = targets.float() / self._temperature @@ -415,10 +418,11 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: # Extract local Top-K # We take K from each rank and then find the global Top-K of all those. - top_teacher_vals, top_idx = torch.topk(output_teacher, self.top_k, dim=-1) + local_top_k = min(self.top_k, targets.size(-1)) + top_teacher_vals, top_idx = torch.topk(output_teacher, local_top_k, dim=-1) top_student_vals = torch.gather(output_student, dim=-1, index=top_idx) - if self._config.tensor_model_parallel_size > 1: + if tp_size > 1: tp_group = parallel_state.get_tensor_model_parallel_group() # Gather all candidates into shape [N_rows, local_k * tp_size] From 8c03f0f1077ef5f6e875291751a61c2a0240ce30 Mon Sep 17 00:00:00 2001 From: Asha Anoosheh Date: Sat, 10 Jan 2026 13:34:15 -0800 Subject: [PATCH 8/8] Rename config arg Signed-off-by: Asha Anoosheh --- modelopt/torch/distill/plugins/megatron.py | 10 +++++----- tests/gpu/torch/distill/plugins/test_megatron.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/modelopt/torch/distill/plugins/megatron.py b/modelopt/torch/distill/plugins/megatron.py index 76af1497b..b0eeeab74 100644 --- a/modelopt/torch/distill/plugins/megatron.py +++ b/modelopt/torch/distill/plugins/megatron.py @@ -58,7 +58,7 @@ class DistillationConfig: skip_lm_loss: Whether to skip computing the standard language model loss (default: ``True``). kd_loss_scale: Relative scaling factor for the distillation loss if ``skip_lm_loss`` is ``False``. logit_kl_temperature: Temperature for the logit KL-divergence loss. - logits_kl_topk: If not None, use TopKLogitsKLLoss instead of LogitsKLLoss with this top-k value. + logit_kl_topk: If not None, use TopKLogitsKLLoss instead of LogitsKLLoss with this top-k value. 
""" intermediate_layer_pairs: list[tuple[str, ...]] = field(default_factory=list) @@ -66,7 +66,7 @@ class DistillationConfig: skip_lm_loss: bool = True kd_loss_scale: float = 1.0 logit_kl_temperature: float = 1.0 - logits_kl_topk: int | None = None + logit_kl_topk: int | None = None criterion: Criterion | None = None loss_balancer: mtd.DistillationLossBalancer | None = None @@ -126,10 +126,10 @@ def setup_distillation_config( if cfg.criterion is None: criterion = {} if parallel_state.is_pipeline_last_stage(): - # Use TopKLogitsKLLoss if logits_kl_topk is specified, otherwise use LogitsKLLoss - if cfg.logits_kl_topk is not None: + # Use TopKLogitsKLLoss if logit_kl_topk is specified, otherwise use LogitsKLLoss + if cfg.logit_kl_topk is not None: criterion[tuple(cfg.logit_layers)] = TopKLogitsKLLoss( - student_cfg, temperature=cfg.logit_kl_temperature, top_k=cfg.logits_kl_topk + student_cfg, temperature=cfg.logit_kl_temperature, top_k=cfg.logit_kl_topk ) else: criterion[tuple(cfg.logit_layers)] = LogitsKLLoss( diff --git a/tests/gpu/torch/distill/plugins/test_megatron.py b/tests/gpu/torch/distill/plugins/test_megatron.py index a04659ce3..6e1833dd6 100644 --- a/tests/gpu/torch/distill/plugins/test_megatron.py +++ b/tests/gpu/torch/distill/plugins/test_megatron.py @@ -169,9 +169,9 @@ def _test_topk_logits_kl_loss(top_k, rank, size): activation_func="squared_relu", ).cuda() - # Setup distillation config with TopKLogitsKLLoss via logits_kl_topk argument + # Setup distillation config with TopKLogitsKLLoss via logit_kl_topk argument distill_cfg = setup_distillation_config( - config_or_path=DistillationConfig(logits_kl_topk=top_k), + config_or_path=DistillationConfig(logit_kl_topk=top_k), student_cfg=student_model.config, teacher_cfg=teacher_model.config, )