Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions transformer_lens/config/TransformerBridgeConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(
d_vocab: int = -1,
architecture: Optional[str] = None,
tokenizer_prepends_bos: bool = True,
tokenizer_appends_eos: bool = False,
default_padding_side: Optional[str] = None,
# HookedTransformerConfig compatibility fields
model_name: str = "custom",
Expand Down Expand Up @@ -103,6 +104,7 @@ def __init__(

# Tokenizer configuration
self.tokenizer_prepends_bos = tokenizer_prepends_bos
self.tokenizer_appends_eos = tokenizer_appends_eos
self.default_padding_side = default_padding_side

# Attention weight processing configuration
Expand Down
8 changes: 8 additions & 0 deletions transformer_lens/factories/architecture_adapter_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
NeelSoluOldArchitectureAdapter,
NeoArchitectureAdapter,
NeoxArchitectureAdapter,
Olmo2ArchitectureAdapter,
Olmo3ArchitectureAdapter,
OlmoArchitectureAdapter,
OlmoeArchitectureAdapter,
OptArchitectureAdapter,
Phi3ArchitectureAdapter,
PhiArchitectureAdapter,
Expand Down Expand Up @@ -51,6 +55,10 @@
"NeoForCausalLM": NeoArchitectureAdapter,
"NeoXForCausalLM": NeoxArchitectureAdapter,
"NeelSoluOldForCausalLM": NeelSoluOldArchitectureAdapter,
"OlmoForCausalLM": OlmoArchitectureAdapter,
"Olmo2ForCausalLM": Olmo2ArchitectureAdapter,
"Olmo3ForCausalLM": Olmo3ArchitectureAdapter,
"OlmoeForCausalLM": OlmoeArchitectureAdapter,
"OPTForCausalLM": OptArchitectureAdapter,
"PhiForCausalLM": PhiArchitectureAdapter,
"Phi3ForCausalLM": Phi3ArchitectureAdapter,
Expand Down
19 changes: 12 additions & 7 deletions transformer_lens/model_bridge/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,15 @@ def to_tokens(
truncation=truncate,
max_length=self.cfg.n_ctx if truncate else None,
)["input_ids"]
# Strip trailing EOS tokens that some tokenizers auto-append
# (e.g., OLMo's GPTNeoXTokenizer appends <|endoftext|> to all inputs)
if (
getattr(self.cfg, "tokenizer_appends_eos", False)
and self.tokenizer.eos_token_id is not None
):
# Remove trailing EOS from each sequence, but keep at least 1 token
while tokens.shape[-1] > 1 and (tokens[:, -1] == self.tokenizer.eos_token_id).all():
tokens = tokens[:, :-1]
if not prepend_bos and tokenizer_prepends_bos:
tokens = utils.get_tokens_with_bos_removed(self.tokenizer, tokens)
if move_to_device:
Expand Down Expand Up @@ -1712,16 +1721,12 @@ def generate(
Generated sequence as string, list of strings, or tensor depending on input type and return_type.
If output_logits=True, returns a ModelOutput-like object with 'sequences' and 'logits' attributes.
"""
# Convert input to tokens
# Convert input to tokens using to_tokens() for consistent special token handling
if isinstance(input, str):
input_tokens = self.tokenizer(
input, return_tensors="pt", padding=False, truncation=False
)["input_ids"].to(self.cfg.device)
input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
input_type = "str"
elif isinstance(input, list):
input_tokens = self.tokenizer(
input, return_tensors="pt", padding=True, truncation=False
)["input_ids"].to(self.cfg.device)
input_tokens = self.to_tokens(input, move_to_device=True, truncate=False)
input_type = "list"
else:
input_tokens = input.to(self.cfg.device)
Expand Down
16 changes: 15 additions & 1 deletion transformer_lens/model_bridge/sources/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,21 @@ def boot(
default_padding_side=default_padding_side,
)
if tokenizer is not None:
adapter.cfg.tokenizer_prepends_bos = len(tokenizer.encode("")) > 0
# Detect whether the tokenizer auto-prepends BOS or auto-appends EOS.
# We encode a non-empty test string and check the first/last tokens.
# Using encode("") is unreliable because setup_tokenizer may set
# bos_token = eos_token, making them indistinguishable.
encoded_test = tokenizer.encode("a")
adapter.cfg.tokenizer_prepends_bos = (
len(encoded_test) > 1
and tokenizer.bos_token_id is not None
and encoded_test[0] == tokenizer.bos_token_id
)
adapter.cfg.tokenizer_appends_eos = (
len(encoded_test) > 1
and tokenizer.eos_token_id is not None
and encoded_test[-1] == tokenizer.eos_token_id
)
bridge = TransformerBridge(hf_model, adapter, tokenizer)
return bridge

Expand Down
16 changes: 16 additions & 0 deletions transformer_lens/model_bridge/supported_architectures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@
from transformer_lens.model_bridge.supported_architectures.nanogpt import (
NanogptArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.olmo import (
OlmoArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.olmo2 import (
Olmo2ArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.olmo3 import (
Olmo3ArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.olmoe import (
OlmoeArchitectureAdapter,
)
from transformer_lens.model_bridge.supported_architectures.neel_solu_old import (
NeelSoluOldArchitectureAdapter,
)
Expand Down Expand Up @@ -97,6 +109,10 @@
"NeelSoluOldArchitectureAdapter",
"NeoArchitectureAdapter",
"NeoxArchitectureAdapter",
"OlmoArchitectureAdapter",
"Olmo2ArchitectureAdapter",
"Olmo3ArchitectureAdapter",
"OlmoeArchitectureAdapter",
"OptArchitectureAdapter",
"PhiArchitectureAdapter",
"Phi3ArchitectureAdapter",
Expand Down
174 changes: 174 additions & 0 deletions transformer_lens/model_bridge/supported_architectures/olmo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
"""OLMo architecture adapter."""

from typing import Any

from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion
from transformer_lens.conversion_utils.param_processing_conversion import (
ParamProcessingConversion,
)
from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
from transformer_lens.model_bridge.generalized_components import (
BlockBridge,
EmbeddingBridge,
GatedMLPBridge,
LinearBridge,
NormalizationBridge,
PositionEmbeddingsAttentionBridge,
RotaryEmbeddingBridge,
UnembeddingBridge,
)


class OlmoArchitectureAdapter(ArchitectureAdapter):
    """Architecture adapter for OLMo (v1) models.

    OLMo v1 uses a pre-norm architecture with a custom non-learnable LayerNorm
    (fixed weight=1, bias=0), rotary position embeddings (RoPE), and gated MLP
    (SwiGLU). Key differences from later OLMo variants:

    - Pre-norm: LayerNorm is applied BEFORE attention and BEFORE MLP.
    - Non-learnable LayerNorm: Weight and bias are not trainable parameters.
      Delegating to HF's native forward via NormalizationBridge handles this correctly.
    - No Q/K normalization in attention.
    - Optional QKV clipping (handled by HF's native attention forward).

    Optional Parameters (may not exist in state_dict):
    -------------------------------------------------
    - blocks.{i}.attn.b_Q  - No bias on query projection
    - blocks.{i}.attn.b_K  - No bias on key projection
    - blocks.{i}.attn.b_V  - No bias on value projection
    - blocks.{i}.attn.b_O  - No bias on output projection
    - blocks.{i}.mlp.b_in  - No bias on MLP up_proj
    - blocks.{i}.mlp.b_gate - No bias on MLP gate_proj
    - blocks.{i}.mlp.b_out - No bias on MLP down_proj
    """

    def __init__(self, cfg: Any) -> None:
        """Initialize the OLMo architecture adapter.

        Args:
            cfg: Model configuration. Must expose ``d_model``, ``n_heads``,
                ``n_layers``, and ``d_vocab``; may optionally expose
                ``n_key_value_heads`` for grouped-query attention (GQA).
        """
        super().__init__(cfg)

        # Set config variables for weight processing
        self.cfg.normalization_type = "LN"
        self.cfg.positional_embedding_type = "rotary"
        self.cfg.final_rms = False
        self.cfg.gated_mlp = True
        self.cfg.attn_only = False
        self.cfg.uses_rms_norm = False
        # Force eager attention for numerical consistency with benchmark reference
        self.cfg.attn_implementation = "eager"

        self.default_config = {
            "d_model": cfg.d_model,
            "d_head": cfg.d_model // cfg.n_heads,
            "n_heads": cfg.n_heads,
            "n_layers": cfg.n_layers,
            "d_vocab": cfg.d_vocab,
        }

        # GQA support: record the KV-head count only when the source config
        # actually provides one.
        if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None:
            self.default_config["n_key_value_heads"] = cfg.n_key_value_heads
            self.cfg.n_key_value_heads = cfg.n_key_value_heads

        # FIX: the original read `self.cfg.n_key_value_heads` unconditionally,
        # which raises AttributeError for plain MHA configs that never define
        # the attribute (the guard above only sets it when present).  Use
        # getattr with a None default so MHA models fall back to n_heads.
        n_kv_heads = getattr(self.cfg, "n_key_value_heads", None)
        if n_kv_heads is None:
            n_kv_heads = self.cfg.n_heads

        # Reshape HF's fused (heads*d_head, d_model) projection matrices into
        # per-head tensors expected by the weight-processing pipeline.
        self.weight_processing_conversions = {
            "blocks.{i}.attn.q.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads),
            ),
            "blocks.{i}.attn.k.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.v.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads),
            ),
            "blocks.{i}.attn.o.weight": ParamProcessingConversion(
                tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads),
            ),
        }

        # Component mapping — PRE-NORM architecture:
        #   ln1 = input_layernorm (applied BEFORE attention)
        #   ln2 = post_attention_layernorm (applied BEFORE MLP)
        self.component_mapping = {
            "embed": EmbeddingBridge(name="model.embed_tokens"),
            "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg),
            "blocks": BlockBridge(
                name="model.layers",
                submodules={
                    "ln1": NormalizationBridge(
                        name="input_layernorm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "ln2": NormalizationBridge(
                        name="post_attention_layernorm",
                        config=self.cfg,
                        use_native_layernorm_autograd=True,
                    ),
                    "attn": PositionEmbeddingsAttentionBridge(
                        name="self_attn",
                        config=self.cfg,
                        submodules={
                            "q": LinearBridge(name="q_proj"),
                            "k": LinearBridge(name="k_proj"),
                            "v": LinearBridge(name="v_proj"),
                            "o": LinearBridge(name="o_proj"),
                        },
                        requires_attention_mask=True,
                        requires_position_embeddings=True,
                    ),
                    "mlp": GatedMLPBridge(
                        name="mlp",
                        config=self.cfg,
                        submodules={
                            "gate": LinearBridge(name="gate_proj"),
                            "in": LinearBridge(name="up_proj"),
                            "out": LinearBridge(name="down_proj"),
                        },
                    ),
                },
            ),
            "ln_final": NormalizationBridge(
                name="model.norm",
                config=self.cfg,
                use_native_layernorm_autograd=True,
            ),
            "unembed": UnembeddingBridge(name="lm_head", config=self.cfg),
        }

    def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None:
        """Set up rotary embedding references for OLMo component testing.

        OLMo uses RoPE (Rotary Position Embeddings). We set the rotary_emb
        reference on all attention bridge instances for component testing.

        Args:
            hf_model: The HuggingFace OLMo model instance. Must expose
                ``model.rotary_emb``.
            bridge_model: The TransformerBridge model (if available).
        """
        # Get rotary embedding instance from the model
        rotary_emb = hf_model.model.rotary_emb

        # Force HF model to use "eager" attention to match bridge implementation
        if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"):
            hf_model.config._attn_implementation = "eager"

        if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"):
            for layer in hf_model.model.layers:
                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"):
                    layer.self_attn.config._attn_implementation = "eager"

        # Set rotary_emb on actual bridge instances in bridge_model if available
        if bridge_model is not None and hasattr(bridge_model, "blocks"):
            for block in bridge_model.blocks:
                if hasattr(block, "attn"):
                    block.attn.set_rotary_emb(rotary_emb)

        # Also set on the template for get_generalized_component() calls
        attn_bridge = self.get_generalized_component("blocks.0.attn")
        attn_bridge.set_rotary_emb(rotary_emb)
Loading