diff --git a/modelopt/torch/speculative/eagle/utils.py b/modelopt/torch/speculative/eagle/utils.py
index 281528788..d77ed298a 100644
--- a/modelopt/torch/speculative/eagle/utils.py
+++ b/modelopt/torch/speculative/eagle/utils.py
@@ -36,7 +36,6 @@
 """Eagle model utils."""
 
 import torch
-from torch import nn
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
@@ -71,21 +70,3 @@ def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = No
     inverted_mask = 1.0 - expanded_mask
 
     return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
-
-
-class RMSNorm(nn.Module):
-    """Borrowed from LlamaRMSNorm class."""
-
-    def __init__(self, hidden_size, eps=1e-6):
-        """LlamaRMSNorm is equivalent to T5LayerNorm."""
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        """Forward function for RMSNorm."""
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py
index 39df7b9b7..517ddd9b4 100644
--- a/modelopt/torch/speculative/plugins/transformers.py
+++ b/modelopt/torch/speculative/plugins/transformers.py
@@ -49,7 +49,7 @@
 
 from ..eagle.conversion import EagleDMRegistry
 from ..eagle.eagle_model import EagleModel
-from ..eagle.utils import RMSNorm, expand_mask, make_causal_mask
+from ..eagle.utils import expand_mask, make_causal_mask
 from ..medusa.conversion import MedusaDMRegistry
 from ..medusa.medusa_model import MedusaModel
 from ..utils import (
@@ -219,7 +219,7 @@ def __init__(self, config, decoder_layer_cls, bias=False):
             [decoder_layer_cls(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         if config.use_last_layernorm:
-            self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
+            self.norm = LlamaRMSNorm(config.hidden_size, config.rms_norm_eps)
 
         # Optionally, we use a smaller vocab table for eagle module
         if config.draft_vocab_size != config.vocab_size or config.has_lm_head:
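The RMSNorm helper deleted above was a verbatim copy of LlamaRMSNorm, so pointing `self.norm` at the upstream class should be behavior-preserving. Below is a minimal equivalence sketch, assuming a transformers release where `LlamaRMSNorm(hidden_size, eps)` keeps the same float32 variance computation; the `RemovedRMSNorm` class is a local re-creation of the deleted code kept only for comparison, not part of this change.

```python
# Sanity check (sketch): transformers' LlamaRMSNorm vs. the removed local copy.
import torch
from torch import nn
from transformers.models.llama.modeling_llama import LlamaRMSNorm


class RemovedRMSNorm(nn.Module):
    """Re-creation of the deleted eagle.utils.RMSNorm, for comparison only."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


hidden_size, eps = 64, 1e-6
x = torch.randn(2, 8, hidden_size)
old_norm = RemovedRMSNorm(hidden_size, eps)
new_norm = LlamaRMSNorm(hidden_size, eps)
# Both start from all-ones weights and apply the same math, so outputs should match.
assert torch.allclose(old_norm(x), new_norm(x))
```

If the check holds, the only observable difference after this change is where the class is defined, which is the point of dropping the duplicated implementation.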