diff --git a/transformer_lens/config/TransformerBridgeConfig.py b/transformer_lens/config/TransformerBridgeConfig.py index aaf1b16e6..c1b9f2744 100644 --- a/transformer_lens/config/TransformerBridgeConfig.py +++ b/transformer_lens/config/TransformerBridgeConfig.py @@ -26,6 +26,7 @@ def __init__( d_vocab: int = -1, architecture: Optional[str] = None, tokenizer_prepends_bos: bool = True, + tokenizer_appends_eos: bool = False, default_padding_side: Optional[str] = None, # HookedTransformerConfig compatibility fields model_name: str = "custom", @@ -103,6 +104,7 @@ def __init__( # Tokenizer configuration self.tokenizer_prepends_bos = tokenizer_prepends_bos + self.tokenizer_appends_eos = tokenizer_appends_eos self.default_padding_side = default_padding_side # Attention weight processing configuration diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py index aa83dd402..f3fba33f4 100644 --- a/transformer_lens/factories/architecture_adapter_factory.py +++ b/transformer_lens/factories/architecture_adapter_factory.py @@ -23,6 +23,10 @@ NeelSoluOldArchitectureAdapter, NeoArchitectureAdapter, NeoxArchitectureAdapter, + Olmo2ArchitectureAdapter, + Olmo3ArchitectureAdapter, + OlmoArchitectureAdapter, + OlmoeArchitectureAdapter, OptArchitectureAdapter, Phi3ArchitectureAdapter, PhiArchitectureAdapter, @@ -51,6 +55,10 @@ "NeoForCausalLM": NeoArchitectureAdapter, "NeoXForCausalLM": NeoxArchitectureAdapter, "NeelSoluOldForCausalLM": NeelSoluOldArchitectureAdapter, + "OlmoForCausalLM": OlmoArchitectureAdapter, + "Olmo2ForCausalLM": Olmo2ArchitectureAdapter, + "Olmo3ForCausalLM": Olmo3ArchitectureAdapter, + "OlmoeForCausalLM": OlmoeArchitectureAdapter, "OPTForCausalLM": OptArchitectureAdapter, "PhiForCausalLM": PhiArchitectureAdapter, "Phi3ForCausalLM": Phi3ArchitectureAdapter, diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index 9aa5855a2..3214cbe55 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -806,6 +806,15 @@ def to_tokens( truncation=truncate, max_length=self.cfg.n_ctx if truncate else None, )["input_ids"] + # Strip trailing EOS tokens that some tokenizers auto-append + # (e.g., OLMo's GPTNeoXTokenizer appends <|endoftext|> to all inputs) + if ( + getattr(self.cfg, "tokenizer_appends_eos", False) + and self.tokenizer.eos_token_id is not None + ): + # Remove trailing EOS from each sequence, but keep at least 1 token + while tokens.shape[-1] > 1 and (tokens[:, -1] == self.tokenizer.eos_token_id).all(): + tokens = tokens[:, :-1] if not prepend_bos and tokenizer_prepends_bos: tokens = utils.get_tokens_with_bos_removed(self.tokenizer, tokens) if move_to_device: @@ -1712,16 +1721,12 @@ def generate( Generated sequence as string, list of strings, or tensor depending on input type and return_type. If output_logits=True, returns a ModelOutput-like object with 'sequences' and 'logits' attributes. """ - # Convert input to tokens + # Convert input to tokens using to_tokens() for consistent special token handling if isinstance(input, str): - input_tokens = self.tokenizer( - input, return_tensors="pt", padding=False, truncation=False - )["input_ids"].to(self.cfg.device) + input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) input_type = "str" elif isinstance(input, list): - input_tokens = self.tokenizer( - input, return_tensors="pt", padding=True, truncation=False - )["input_ids"].to(self.cfg.device) + input_tokens = self.to_tokens(input, move_to_device=True, truncate=False) input_type = "list" else: input_tokens = input.to(self.cfg.device) diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py index b46a4d67c..d297bd1f8 100644 --- a/transformer_lens/model_bridge/sources/transformers.py +++ b/transformer_lens/model_bridge/sources/transformers.py @@ -290,7 +290,21 @@ def boot( default_padding_side=default_padding_side, ) if tokenizer is not None: - adapter.cfg.tokenizer_prepends_bos = len(tokenizer.encode("")) > 0 + # Detect whether the tokenizer auto-prepends BOS or auto-appends EOS. + # We encode a non-empty test string and check the first/last tokens. + # Using encode("") is unreliable because setup_tokenizer may set + # bos_token = eos_token, making them indistinguishable. + encoded_test = tokenizer.encode("a") + adapter.cfg.tokenizer_prepends_bos = ( + len(encoded_test) > 1 + and tokenizer.bos_token_id is not None + and encoded_test[0] == tokenizer.bos_token_id + ) + adapter.cfg.tokenizer_appends_eos = ( + len(encoded_test) > 1 + and tokenizer.eos_token_id is not None + and encoded_test[-1] == tokenizer.eos_token_id + ) bridge = TransformerBridge(hf_model, adapter, tokenizer) return bridge diff --git a/transformer_lens/model_bridge/supported_architectures/__init__.py b/transformer_lens/model_bridge/supported_architectures/__init__.py index a07cb3c03..0afe524a0 100644 --- a/transformer_lens/model_bridge/supported_architectures/__init__.py +++ b/transformer_lens/model_bridge/supported_architectures/__init__.py @@ -43,6 +43,18 @@ from transformer_lens.model_bridge.supported_architectures.nanogpt import ( NanogptArchitectureAdapter, ) +from transformer_lens.model_bridge.supported_architectures.olmo import ( + OlmoArchitectureAdapter, +) +from transformer_lens.model_bridge.supported_architectures.olmo2 import ( + Olmo2ArchitectureAdapter, +) +from transformer_lens.model_bridge.supported_architectures.olmo3 import ( + Olmo3ArchitectureAdapter, +) +from transformer_lens.model_bridge.supported_architectures.olmoe import ( + OlmoeArchitectureAdapter, +) from transformer_lens.model_bridge.supported_architectures.neel_solu_old import ( NeelSoluOldArchitectureAdapter, ) @@ -97,6 +109,10 @@ "NeelSoluOldArchitectureAdapter", "NeoArchitectureAdapter", "NeoxArchitectureAdapter", + "OlmoArchitectureAdapter", + "Olmo2ArchitectureAdapter", + "Olmo3ArchitectureAdapter", + "OlmoeArchitectureAdapter", "OptArchitectureAdapter", "PhiArchitectureAdapter", "Phi3ArchitectureAdapter", diff --git a/transformer_lens/model_bridge/supported_architectures/olmo.py b/transformer_lens/model_bridge/supported_architectures/olmo.py new file mode 100644 index 000000000..bec03e64d --- /dev/null +++ b/transformer_lens/model_bridge/supported_architectures/olmo.py @@ -0,0 +1,174 @@ +"""OLMo architecture adapter.""" + +from typing import Any + +from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion +from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, +) +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + EmbeddingBridge, + GatedMLPBridge, + LinearBridge, + NormalizationBridge, + PositionEmbeddingsAttentionBridge, + RotaryEmbeddingBridge, + UnembeddingBridge, +) + + +class OlmoArchitectureAdapter(ArchitectureAdapter): + """Architecture adapter for OLMo (v1) models. + + OLMo v1 uses a pre-norm architecture with a custom non-learnable LayerNorm + (fixed weight=1, bias=0), rotary position embeddings (RoPE), and gated MLP + (SwiGLU). Key differences from later OLMo variants: + + - Pre-norm: LayerNorm is applied BEFORE attention and BEFORE MLP. + - Non-learnable LayerNorm: Weight and bias are not trainable parameters. + Delegating to HF's native forward via NormalizationBridge handles this correctly. + - No Q/K normalization in attention. + - Optional QKV clipping (handled by HF's native attention forward). + + Optional Parameters (may not exist in state_dict): + ------------------------------------------------- + - blocks.{i}.attn.b_Q - No bias on query projection + - blocks.{i}.attn.b_K - No bias on key projection + - blocks.{i}.attn.b_V - No bias on value projection + - blocks.{i}.attn.b_O - No bias on output projection + - blocks.{i}.mlp.b_in - No bias on MLP up_proj + - blocks.{i}.mlp.b_gate - No bias on MLP gate_proj + - blocks.{i}.mlp.b_out - No bias on MLP down_proj + """ + + def __init__(self, cfg: Any) -> None: + """Initialize the OLMo architecture adapter.""" + super().__init__(cfg) + + # Set config variables for weight processing + self.cfg.normalization_type = "LN" + self.cfg.positional_embedding_type = "rotary" + self.cfg.final_rms = False + self.cfg.gated_mlp = True + self.cfg.attn_only = False + self.cfg.uses_rms_norm = False + # Force eager attention for numerical consistency with benchmark reference + self.cfg.attn_implementation = "eager" + + self.default_config = { + "d_model": cfg.d_model, + "d_head": cfg.d_model // cfg.n_heads, + "n_heads": cfg.n_heads, + "n_layers": cfg.n_layers, + "d_vocab": cfg.d_vocab, + } + + # GQA support + if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None: + self.default_config["n_key_value_heads"] = cfg.n_key_value_heads + self.cfg.n_key_value_heads = cfg.n_key_value_heads + + n_kv_heads = ( + self.cfg.n_key_value_heads + if self.cfg.n_key_value_heads is not None + else self.cfg.n_heads + ) + + self.weight_processing_conversions = { + "blocks.{i}.attn.q.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads), + ), + "blocks.{i}.attn.k.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads), + ), + "blocks.{i}.attn.v.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads), + ), + "blocks.{i}.attn.o.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads), + ), + } + + # Component mapping — PRE-NORM architecture: + # ln1 = input_layernorm (applied BEFORE attention) + # ln2 = post_attention_layernorm (applied BEFORE MLP) + self.component_mapping = { + "embed": EmbeddingBridge(name="model.embed_tokens"), + "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg), + "blocks": BlockBridge( + name="model.layers", + submodules={ + "ln1": NormalizationBridge( + name="input_layernorm", + config=self.cfg, + use_native_layernorm_autograd=True, + ), + "ln2": NormalizationBridge( + name="post_attention_layernorm", + config=self.cfg, + use_native_layernorm_autograd=True, + ), + "attn": PositionEmbeddingsAttentionBridge( + name="self_attn", + config=self.cfg, + submodules={ + "q": LinearBridge(name="q_proj"), + "k": LinearBridge(name="k_proj"), + "v": LinearBridge(name="v_proj"), + "o": LinearBridge(name="o_proj"), + }, + requires_attention_mask=True, + requires_position_embeddings=True, + ), + "mlp": GatedMLPBridge( + name="mlp", + config=self.cfg, + submodules={ + "gate": LinearBridge(name="gate_proj"), + "in": LinearBridge(name="up_proj"), + "out": LinearBridge(name="down_proj"), + }, + ), + }, + ), + "ln_final": NormalizationBridge( + name="model.norm", + config=self.cfg, + use_native_layernorm_autograd=True, + ), + "unembed": UnembeddingBridge(name="lm_head", config=self.cfg), + } + + def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None: + """Set up rotary embedding references for OLMo component testing. + + OLMo uses RoPE (Rotary Position Embeddings). We set the rotary_emb + reference on all attention bridge instances for component testing. + + Args: + hf_model: The HuggingFace OLMo model instance + bridge_model: The TransformerBridge model (if available) + """ + # Get rotary embedding instance from the model + rotary_emb = hf_model.model.rotary_emb + + # Force HF model to use "eager" attention to match bridge implementation + if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"): + hf_model.config._attn_implementation = "eager" + + if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"): + for layer in hf_model.model.layers: + if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"): + layer.self_attn.config._attn_implementation = "eager" + + # Set rotary_emb on actual bridge instances in bridge_model if available + if bridge_model is not None and hasattr(bridge_model, "blocks"): + for block in bridge_model.blocks: + if hasattr(block, "attn"): + block.attn.set_rotary_emb(rotary_emb) + + # Also set on the template for get_generalized_component() calls + attn_bridge = self.get_generalized_component("blocks.0.attn") + attn_bridge.set_rotary_emb(rotary_emb) diff --git a/transformer_lens/model_bridge/supported_architectures/olmo2.py b/transformer_lens/model_bridge/supported_architectures/olmo2.py new file mode 100644 index 000000000..719607199 --- /dev/null +++ b/transformer_lens/model_bridge/supported_architectures/olmo2.py @@ -0,0 +1,174 @@ +"""OLMo 2 architecture adapter.""" + +from typing import Any + +from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion +from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, +) +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + EmbeddingBridge, + GatedMLPBridge, + LinearBridge, + PositionEmbeddingsAttentionBridge, + RMSNormalizationBridge, + RotaryEmbeddingBridge, + UnembeddingBridge, +) + + +class Olmo2ArchitectureAdapter(ArchitectureAdapter): + """Architecture adapter for OLMo 2 models. + + OLMo 2 uses a post-norm architecture with RMSNorm, Q/K normalization in attention, + rotary position embeddings (RoPE), and gated MLP (SwiGLU). Key differences from + pre-norm models like Llama: + + - Post-norm: RMSNorm is applied AFTER attention and AFTER MLP, not before. + ln1 maps to post_attention_layernorm, ln2 maps to post_feedforward_layernorm. + - Q/K normalization: Per-head RMSNorm applied to queries and keys after projection. + - No biases on any projections. + + Optional Parameters (may not exist in state_dict): + ------------------------------------------------- + - blocks.{i}.attn.b_Q - No bias on query projection + - blocks.{i}.attn.b_K - No bias on key projection + - blocks.{i}.attn.b_V - No bias on value projection + - blocks.{i}.attn.b_O - No bias on output projection + - blocks.{i}.mlp.b_in - No bias on MLP up_proj + - blocks.{i}.mlp.b_gate - No bias on MLP gate_proj + - blocks.{i}.mlp.b_out - No bias on MLP down_proj + - blocks.{i}.ln1.b - RMSNorm has no bias + - blocks.{i}.ln2.b - RMSNorm has no bias + - ln_final.b - RMSNorm has no bias + """ + + def __init__(self, cfg: Any) -> None: + """Initialize the OLMo 2 architecture adapter.""" + super().__init__(cfg) + + # Set config variables for weight processing + self.cfg.normalization_type = "RMS" + self.cfg.positional_embedding_type = "rotary" + self.cfg.final_rms = True + self.cfg.gated_mlp = True + self.cfg.attn_only = False + self.cfg.uses_rms_norm = True + # Force eager attention for numerical consistency with benchmark reference. + # PositionEmbeddingsAttentionBridge delegates to native HF attention, so + # both bridge and reference must use the same implementation. + self.cfg.attn_implementation = "eager" + + self.default_config = { + "d_model": cfg.d_model, + "d_head": cfg.d_model // cfg.n_heads, + "n_heads": cfg.n_heads, + "n_layers": cfg.n_layers, + "d_vocab": cfg.d_vocab, + } + + # GQA support + if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None: + self.default_config["n_key_value_heads"] = cfg.n_key_value_heads + self.cfg.n_key_value_heads = cfg.n_key_value_heads + + n_kv_heads = ( + self.cfg.n_key_value_heads + if self.cfg.n_key_value_heads is not None + else self.cfg.n_heads + ) + + self.weight_processing_conversions = { + "blocks.{i}.attn.q.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads), + ), + "blocks.{i}.attn.k.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads), + ), + "blocks.{i}.attn.v.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads), + ), + "blocks.{i}.attn.o.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads), + ), + } + + # Component mapping — POST-NORM architecture: + # ln1 = post_attention_layernorm (applied AFTER attention) + # ln2 = post_feedforward_layernorm (applied AFTER MLP) + self.component_mapping = { + "embed": EmbeddingBridge(name="model.embed_tokens"), + "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg), + "blocks": BlockBridge( + name="model.layers", + submodules={ + "ln1": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg), + "ln2": RMSNormalizationBridge( + name="post_feedforward_layernorm", config=self.cfg + ), + "attn": PositionEmbeddingsAttentionBridge( + name="self_attn", + config=self.cfg, + submodules={ + "q": LinearBridge(name="q_proj"), + "k": LinearBridge(name="k_proj"), + "v": LinearBridge(name="v_proj"), + "o": LinearBridge(name="o_proj"), + "q_norm": RMSNormalizationBridge(name="q_norm", config=self.cfg), + "k_norm": RMSNormalizationBridge(name="k_norm", config=self.cfg), + }, + requires_attention_mask=True, + requires_position_embeddings=True, + ), + "mlp": GatedMLPBridge( + name="mlp", + config=self.cfg, + submodules={ + "gate": LinearBridge(name="gate_proj"), + "in": LinearBridge(name="up_proj"), + "out": LinearBridge(name="down_proj"), + }, + ), + }, + ), + "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg), + "unembed": UnembeddingBridge(name="lm_head", config=self.cfg), + } + + def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None: + """Set up rotary embedding references for OLMo 2 component testing. + + OLMo 2 uses RoPE (Rotary Position Embeddings). We set the rotary_emb + reference on all attention bridge instances for component testing. + + We also force the HF model to use "eager" attention to match the bridge's + implementation. The bridge uses "eager" to support output_attentions for hooks. + + Args: + hf_model: The HuggingFace OLMo 2 model instance + bridge_model: The TransformerBridge model (if available) + """ + # Get rotary embedding instance from the model + rotary_emb = hf_model.model.rotary_emb + + # Force HF model to use "eager" attention to match bridge implementation + if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"): + hf_model.config._attn_implementation = "eager" + + # Also set on all attention layers + if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"): + for layer in hf_model.model.layers: + if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"): + layer.self_attn.config._attn_implementation = "eager" + + # Set rotary_emb on actual bridge instances in bridge_model if available + if bridge_model is not None and hasattr(bridge_model, "blocks"): + for block in bridge_model.blocks: + if hasattr(block, "attn"): + block.attn.set_rotary_emb(rotary_emb) + + # Also set on the template for get_generalized_component() calls + attn_bridge = self.get_generalized_component("blocks.0.attn") + attn_bridge.set_rotary_emb(rotary_emb) diff --git a/transformer_lens/model_bridge/supported_architectures/olmo3.py b/transformer_lens/model_bridge/supported_architectures/olmo3.py new file mode 100644 index 000000000..c7c74c9b6 --- /dev/null +++ b/transformer_lens/model_bridge/supported_architectures/olmo3.py @@ -0,0 +1,17 @@ +"""OLMo 3 architecture adapter.""" + +from transformer_lens.model_bridge.supported_architectures.olmo2 import ( + Olmo2ArchitectureAdapter, +) + + +class Olmo3ArchitectureAdapter(Olmo2ArchitectureAdapter): + """Architecture adapter for OLMo 3 / OLMo 3.1 models. + + OLMo 3 is architecturally identical to OLMo 2 at the weight and component level. + The only difference is sliding window attention on some layers (configurable via + layer_types), which is handled by the HF model's forward pass (mask creation) + and does not affect weight structure or component mapping. + """ + + pass diff --git a/transformer_lens/model_bridge/supported_architectures/olmoe.py b/transformer_lens/model_bridge/supported_architectures/olmoe.py new file mode 100644 index 000000000..ae7611b0f --- /dev/null +++ b/transformer_lens/model_bridge/supported_architectures/olmoe.py @@ -0,0 +1,168 @@ +"""OLMoE (Mixture of Experts) architecture adapter.""" + +from typing import Any + +from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion +from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, +) +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + EmbeddingBridge, + LinearBridge, + MoEBridge, + PositionEmbeddingsAttentionBridge, + RMSNormalizationBridge, + RotaryEmbeddingBridge, + UnembeddingBridge, +) + + +class OlmoeArchitectureAdapter(ArchitectureAdapter): + """Architecture adapter for OLMoE (Mixture of Experts) models. + + OLMoE uses a pre-norm architecture with RMSNorm, Q/K normalization in attention, + rotary position embeddings (RoPE), and sparse Mixture of Experts MLP. Key features: + + - Pre-norm: RMSNorm applied BEFORE attention and BEFORE MLP. + - Q/K normalization: RMSNorm applied to queries and keys after projection. + - Sparse MoE: 64 experts with top-8 routing (configurable). + - Batched expert parameters: gate_up_proj [num_experts, 2*d_mlp, d_model] and + down_proj [num_experts, d_model, d_mlp] as single tensors, not a ModuleList. + - Optional QKV clipping (handled by HF's native attention forward). + - No biases on any projections. + + Optional Parameters (may not exist in state_dict): + ------------------------------------------------- + - blocks.{i}.attn.b_Q - No bias on query projection + - blocks.{i}.attn.b_K - No bias on key projection + - blocks.{i}.attn.b_V - No bias on value projection + - blocks.{i}.attn.b_O - No bias on output projection + - blocks.{i}.ln1.b - RMSNorm has no bias + - blocks.{i}.ln2.b - RMSNorm has no bias + - ln_final.b - RMSNorm has no bias + """ + + def __init__(self, cfg: Any) -> None: + """Initialize the OLMoE architecture adapter.""" + super().__init__(cfg) + + # Set config variables for weight processing + self.cfg.normalization_type = "RMS" + self.cfg.positional_embedding_type = "rotary" + self.cfg.final_rms = False + self.cfg.gated_mlp = True + self.cfg.attn_only = False + self.cfg.uses_rms_norm = True + # Force eager attention for numerical consistency with benchmark reference + self.cfg.attn_implementation = "eager" + + self.default_config = { + "d_model": cfg.d_model, + "d_head": cfg.d_model // cfg.n_heads, + "n_heads": cfg.n_heads, + "n_layers": cfg.n_layers, + "d_vocab": cfg.d_vocab, + } + + # GQA support + if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None: + self.default_config["n_key_value_heads"] = cfg.n_key_value_heads + self.cfg.n_key_value_heads = cfg.n_key_value_heads + + n_kv_heads = ( + self.cfg.n_key_value_heads + if self.cfg.n_key_value_heads is not None + else self.cfg.n_heads + ) + + self.weight_processing_conversions = { + "blocks.{i}.attn.q.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads), + ), + "blocks.{i}.attn.k.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads), + ), + "blocks.{i}.attn.v.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads), + ), + "blocks.{i}.attn.o.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads), + ), + } + + # Component mapping — PRE-NORM architecture: + # ln1 = input_layernorm (applied BEFORE attention) + # ln2 = post_attention_layernorm (applied BEFORE MLP) + self.component_mapping = { + "embed": EmbeddingBridge(name="model.embed_tokens"), + "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb", config=self.cfg), + "blocks": BlockBridge( + name="model.layers", + submodules={ + "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg), + "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg), + "attn": PositionEmbeddingsAttentionBridge( + name="self_attn", + config=self.cfg, + submodules={ + "q": LinearBridge(name="q_proj"), + "k": LinearBridge(name="k_proj"), + "v": LinearBridge(name="v_proj"), + "o": LinearBridge(name="o_proj"), + "q_norm": RMSNormalizationBridge(name="q_norm", config=self.cfg), + "k_norm": RMSNormalizationBridge(name="k_norm", config=self.cfg), + }, + requires_attention_mask=True, + requires_position_embeddings=True, + ), + # OLMoE uses batched expert parameters (gate_up_proj, down_proj + # as 3D tensors) rather than a ModuleList of individual experts. + # MoEBridge wraps the entire MLP module and delegates to HF's + # native forward pass. The gate (router) is mapped as a submodule + # for hook access. + "mlp": MoEBridge( + name="mlp", + config=self.cfg, + submodules={ + "gate": LinearBridge(name="gate"), + }, + ), + }, + ), + "ln_final": RMSNormalizationBridge(name="model.norm", config=self.cfg), + "unembed": UnembeddingBridge(name="lm_head", config=self.cfg), + } + + def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None: + """Set up rotary embedding references for OLMoE component testing. + + OLMoE uses RoPE (Rotary Position Embeddings). We set the rotary_emb + reference on all attention bridge instances for component testing. + + Args: + hf_model: The HuggingFace OLMoE model instance + bridge_model: The TransformerBridge model (if available) + """ + # Get rotary embedding instance from the model + rotary_emb = hf_model.model.rotary_emb + + # Force HF model to use "eager" attention to match bridge implementation + if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"): + hf_model.config._attn_implementation = "eager" + + if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"): + for layer in hf_model.model.layers: + if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"): + layer.self_attn.config._attn_implementation = "eager" + + # Set rotary_emb on actual bridge instances in bridge_model if available + if bridge_model is not None and hasattr(bridge_model, "blocks"): + for block in bridge_model.blocks: + if hasattr(block, "attn"): + block.attn.set_rotary_emb(rotary_emb) + + # Also set on the template for get_generalized_component() calls + attn_bridge = self.get_generalized_component("blocks.0.attn") + attn_bridge.set_rotary_emb(rotary_emb)