19 changes: 0 additions & 19 deletions modelopt/torch/speculative/eagle/utils.py
@@ -36,7 +36,6 @@
 """Eagle model utils."""
 
 import torch
-from torch import nn
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
@@ -71,21 +70,3 @@ def expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: int | None = No
     inverted_mask = 1.0 - expanded_mask
 
     return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
-
-
-class RMSNorm(nn.Module):
-    """Borrowed from LlamaRMSNorm class."""
-
-    def __init__(self, hidden_size, eps=1e-6):
-        """LlamaRMSNorm is equivalent to T5LayerNorm."""
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        """Forward function for RMSNorm."""
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
4 changes: 2 additions & 2 deletions modelopt/torch/speculative/plugins/transformers.py
@@ -49,7 +49,7 @@
 
 from ..eagle.conversion import EagleDMRegistry
 from ..eagle.eagle_model import EagleModel
-from ..eagle.utils import RMSNorm, expand_mask, make_causal_mask
+from ..eagle.utils import expand_mask, make_causal_mask
 from ..medusa.conversion import MedusaDMRegistry
 from ..medusa.medusa_model import MedusaModel
 from ..utils import (
@@ -219,7 +219,7 @@ def __init__(self, config, decoder_layer_cls, bias=False):
             [decoder_layer_cls(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         if config.use_last_layernorm:
-            self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
+            self.norm = LlamaRMSNorm(config.hidden_size, config.rms_norm_eps)
 
         # Optionally, we use a smaller vocab table for eagle module
         if config.draft_vocab_size != config.vocab_size or config.has_lm_head:
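The deleted local `RMSNorm` mirrors `LlamaRMSNorm` from `transformers.models.llama.modeling_llama`, which is presumably why the plugin can switch to the upstream class directly (assuming `LlamaRMSNorm` is already imported elsewhere in `plugins/transformers.py`). A minimal sketch of that equivalence check, not part of the change itself:

```python
# Quick equivalence check between the upstream LlamaRMSNorm and the removed
# local RMSNorm (illustrative sizes/eps; assumes transformers is installed).
import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm

hidden_size, eps = 64, 1e-6
x = torch.randn(2, 8, hidden_size)

norm = LlamaRMSNorm(hidden_size, eps=eps)  # weight is initialized to ones

# Reference computation, mirroring the deleted RMSNorm.forward above.
variance = x.pow(2).mean(-1, keepdim=True)
reference = x * torch.rsqrt(variance + eps)

assert torch.allclose(norm(x), reference, atol=1e-6)
```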