From 6c4a213abc241047bd0b2101b75a190d7f70f8c9 Mon Sep 17 00:00:00 2001 From: jlarson Date: Tue, 10 Feb 2026 19:27:07 -0600 Subject: [PATCH 1/6] Testing R1 Distills to confirm functional in TransformerLens --- transformer_lens/benchmarks/main_benchmark.py | 18 +++++++---- .../rotary_embedding.py | 9 +++++- .../model_bridge/sources/transformers.py | 11 +++++++ .../supported_architectures/gemma3.py | 9 +++--- .../supported_architectures/qwen2.py | 4 +-- transformer_lens/supported_models.py | 30 +++++++++++++++++++ 6 files changed, 68 insertions(+), 13 deletions(-) diff --git a/transformer_lens/benchmarks/main_benchmark.py b/transformer_lens/benchmarks/main_benchmark.py index 091a87873..24e981f60 100644 --- a/transformer_lens/benchmarks/main_benchmark.py +++ b/transformer_lens/benchmarks/main_benchmark.py @@ -824,10 +824,16 @@ def cleanup_model(model, model_name_str: str): try: # Load a lightweight version without weights to get config bridge_config_only = TransformerBridge.boot_transformers(model_name, device=device, dtype=bridge_dtype, load_weights=False) # type: ignore[attr-defined] - # Extract attn_implementation for HF model loading + # Extract attn_implementation for HF model loading. + # First check if adapter explicitly sets it (e.g. qwen3, gemma3). if hasattr(bridge_config_only.adapter.cfg, "attn_implementation"): attn_implementation = bridge_config_only.adapter.cfg.attn_implementation - if verbose and attn_implementation: + # TransformerBridge always loads HF models with output_attentions=True + # (see sources/transformers.py), which causes HF to fall back from SDPA + # to eager attention. We must match this in the reference model. + if attn_implementation is None: + attn_implementation = "eager" + if verbose: print(f"✓ Detected attn_implementation={attn_implementation}") # Clean up config-only bridge immediately to free memory del bridge_config_only @@ -841,13 +847,14 @@ def cleanup_model(model, model_name_str: str): try: if verbose: print("Loading HuggingFace reference model...") - # Match attn_implementation from bridge to ensure numerical consistency + # Match loading path to TransformerBridge: no device_map, explicit .to(device) + # Using device_map causes different weight materialization than .to(device), + # which produces numerical divergence for bfloat16 models. 
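+            # Hypothetical repro of that divergence (names illustrative, not part of this patch):
+            #   m1 = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16, device_map=device)
+            #   m2 = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.bfloat16).to(device)
+            #   (m1(ids).logits - m2(ids).logits).abs().max()  # can be nonzero at bf16 precision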
hf_kwargs = { - "device_map": device, "low_cpu_mem_usage": True, # Reduce memory spikes during loading } if attn_implementation is not None: - hf_kwargs["attn_implementation"] = attn_implementation + hf_kwargs["attn_implementation"] = attn_implementation # type: ignore[assignment] if verbose: print(f"Using attn_implementation={attn_implementation}") # Use appropriate AutoModel class (e.g., AutoModelForSeq2SeqLM for T5) @@ -855,6 +862,7 @@ def cleanup_model(model, model_name_str: str): if verbose and auto_model_class != AutoModelForCausalLM: print(f"Using {auto_model_class.__name__} for encoder-decoder model") hf_model = auto_model_class.from_pretrained(model_name, **hf_kwargs) # type: ignore[arg-type] + hf_model = hf_model.to(device) hf_model.eval() # Detect dtype from HF model try: diff --git a/transformer_lens/model_bridge/generalized_components/rotary_embedding.py b/transformer_lens/model_bridge/generalized_components/rotary_embedding.py index c3bb81378..3af922a04 100644 --- a/transformer_lens/model_bridge/generalized_components/rotary_embedding.py +++ b/transformer_lens/model_bridge/generalized_components/rotary_embedding.py @@ -72,7 +72,14 @@ def get_random_inputs( head_dim = 256 x = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype) position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1) - return {"args": (x, position_ids)} + args: tuple = (x, position_ids) + # Gemma3's rotary embedding requires a layer_type argument (e.g., "sliding_attention") + # to select the correct inv_freq buffer. Without it, forward() tries to access + # "None_inv_freq" which doesn't exist. + if self.original_component is not None and hasattr(self.original_component, "layer_types"): + layer_type = self.original_component.layer_types[0] # type: ignore[index] + args = (x, position_ids, layer_type) + return {"args": args} def forward(self, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass through the rotary embedding bridge. diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py index 90675b167..9628bcbb9 100644 --- a/transformer_lens/model_bridge/sources/transformers.py +++ b/transformer_lens/model_bridge/sources/transformers.py @@ -86,6 +86,12 @@ def map_default_transformer_lens_config(hf_config): tl_config.n_ctx = hf_config.max_position_embeddings elif hasattr(hf_config, "max_length"): tl_config.n_ctx = hf_config.max_length + elif hasattr(hf_config, "seq_length"): + tl_config.n_ctx = hf_config.seq_length + else: + # Models like Bloom use ALiBi (no positional embeddings) and have no + # context length field. Default to 2048 as a reasonable fallback. + tl_config.n_ctx = 2048 if hasattr(hf_config, "n_inner"): tl_config.d_mlp = hf_config.n_inner elif hasattr(hf_config, "intermediate_size"): @@ -237,6 +243,11 @@ def boot( device = get_device() adapter.cfg.device = str(device) model_class = get_hf_model_class_for_architecture(architecture) + # Ensure pad_token_id exists on HF config. Transformers v5 raises AttributeError + # for missing config attributes (instead of returning None), which crashes models + # like Phi-1 that access config.pad_token_id during __init__. 
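+    # A minimal sketch of the failure this guards against (assuming v5's
+    # strict config attribute access; model ID illustrative):
+    #   cfg = AutoConfig.from_pretrained("microsoft/phi-1")
+    #   cfg.pad_token_id  # v5: raises AttributeError; v4: returned None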
+ if not hasattr(hf_config, "pad_token_id") or "pad_token_id" not in hf_config.__dict__: + hf_config.pad_token_id = getattr(hf_config, "eos_token_id", None) model_kwargs = {"config": hf_config, "torch_dtype": dtype} if hasattr(adapter.cfg, "attn_implementation") and adapter.cfg.attn_implementation is not None: model_kwargs["attn_implementation"] = adapter.cfg.attn_implementation diff --git a/transformer_lens/model_bridge/supported_architectures/gemma3.py b/transformer_lens/model_bridge/supported_architectures/gemma3.py index 76ee59b3b..4e37ba7a6 100644 --- a/transformer_lens/model_bridge/supported_architectures/gemma3.py +++ b/transformer_lens/model_bridge/supported_architectures/gemma3.py @@ -127,7 +127,6 @@ def __init__(self, cfg: Any) -> None: self.component_mapping = { "embed": EmbeddingBridge(name="model.embed_tokens"), "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"), - "rotary_emb_local": RotaryEmbeddingBridge(name="model.rotary_emb_local"), "blocks": BlockBridge( name="model.layers", submodules={ @@ -224,8 +223,8 @@ def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> No hf_model: The HuggingFace Gemma-3 model instance bridge_model: The TransformerBridge model (if available, set rotary_emb on actual instances) """ - # Get rotary embedding instances from the model - rotary_emb_local = hf_model.model.rotary_emb_local # Used by 22/26 layers + # Get the shared rotary embedding from the model (contains both global and local RoPE) + rotary_emb = hf_model.model.rotary_emb # Force HF model to use "eager" attention to match bridge implementation # Bridge uses "eager" to support output_attentions for hook compatibility @@ -244,7 +243,7 @@ def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> No # Set on each layer's actual attention bridge instance for block in bridge_model.blocks: if hasattr(block, "attn"): - block.attn.set_rotary_emb(rotary_emb_local) + block.attn.set_rotary_emb(rotary_emb) # Enable native autograd for q_norm/k_norm to match HF exactly if hasattr(block.attn, "original_component"): @@ -256,4 +255,4 @@ def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> No # Also set on the template for get_generalized_component() calls attn_bridge = self.get_generalized_component("blocks.0.attn") - attn_bridge.set_rotary_emb(rotary_emb_local) + attn_bridge.set_rotary_emb(rotary_emb) diff --git a/transformer_lens/model_bridge/supported_architectures/qwen2.py b/transformer_lens/model_bridge/supported_architectures/qwen2.py index fbe94fe77..8a905e7c0 100644 --- a/transformer_lens/model_bridge/supported_architectures/qwen2.py +++ b/transformer_lens/model_bridge/supported_architectures/qwen2.py @@ -62,13 +62,13 @@ def __init__(self, cfg: Any) -> None: "blocks.{i}.attn.k.weight": ParamProcessingConversion( tensor_conversion=RearrangeTensorConversion( "(n h) m -> n m h", - n=getattr(self.cfg, "num_key_value_heads", self.cfg.n_heads), + n=getattr(self.cfg, "n_key_value_heads", self.cfg.n_heads), ), ), "blocks.{i}.attn.v.weight": ParamProcessingConversion( tensor_conversion=RearrangeTensorConversion( "(n h) m -> n m h", - n=getattr(self.cfg, "num_key_value_heads", self.cfg.n_heads), + n=getattr(self.cfg, "n_key_value_heads", self.cfg.n_heads), ), ), "blocks.{i}.attn.o.weight": ParamProcessingConversion( diff --git a/transformer_lens/supported_models.py b/transformer_lens/supported_models.py index 18f7bb377..ac103736f 100644 --- a/transformer_lens/supported_models.py +++ b/transformer_lens/supported_models.py @@ 
-15,6 +15,12 @@ "codellama/CodeLlama-7b-hf", "codellama/CodeLlama-7b-Instruct-hf", "codellama/CodeLlama-7b-Python-hf", + "deepseek-ai/DeepSeek-R1-Distill-Llama-8B", + "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", + "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", + "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", + "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B", + "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", "distilgpt2", "EleutherAI/gpt-j-6B", "EleutherAI/gpt-neo-1.3B", @@ -254,6 +260,30 @@ "codellama/CodeLlama-7b-Instruct-hf", ], "codellama/CodeLlama-7b-Python-hf": ["CodeLlama-7b-python", "codellama/CodeLlama-7b-Python-hf"], + "deepseek-ai/DeepSeek-R1-Distill-Llama-8B": [ + "deepseek-r1-distill-llama-8b", + "deepseek-r1-distill-llama-8b-chat", + ], + "deepseek-ai/DeepSeek-R1-Distill-Llama-70B": [ + "deepseek-r1-distill-llama-70b", + "deepseek-r1-distill-llama-70b-chat", + ], + "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B": [ + "deepseek-r1-distill-qwen-1.5b", + "deepseek-r1-distill-qwen-1.5b-chat", + ], + "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B": [ + "deepseek-r1-distill-qwen-7b", + "deepseek-r1-distill-qwen-7b-chat", + ], + "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B": [ + "deepseek-r1-distill-qwen-14b", + "deepseek-r1-distill-qwen-14b-chat", + ], + "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B": [ + "deepseek-r1-distill-qwen-32b", + "deepseek-r1-distill-qwen-32b-chat", + ], "distilgpt2": ["distillgpt2", "distill-gpt2", "distil-gpt2", "gpt2-xs"], "EleutherAI/gpt-j-6B": ["gpt-j-6B", "gpt-j", "gptj"], "EleutherAI/gpt-neo-1.3B": ["gpt-neo-1.3B", "gpt-neo-medium", "neo-medium"], From fe7067aa9d32a2528bcd9842b3e8578da4e98034 Mon Sep 17 00:00:00 2001 From: jlarson Date: Tue, 10 Feb 2026 19:58:21 -0600 Subject: [PATCH 2/6] Updating order to be alphabetical --- transformer_lens/supported_models.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/transformer_lens/supported_models.py b/transformer_lens/supported_models.py index ac103736f..3adf140f8 100644 --- a/transformer_lens/supported_models.py +++ b/transformer_lens/supported_models.py @@ -15,12 +15,12 @@ "codellama/CodeLlama-7b-hf", "codellama/CodeLlama-7b-Instruct-hf", "codellama/CodeLlama-7b-Python-hf", - "deepseek-ai/DeepSeek-R1-Distill-Llama-8B", "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", + "deepseek-ai/DeepSeek-R1-Distill-Llama-8B", "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", - "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B", "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", + "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", "distilgpt2", "EleutherAI/gpt-j-6B", "EleutherAI/gpt-neo-1.3B", @@ -260,22 +260,18 @@ "codellama/CodeLlama-7b-Instruct-hf", ], "codellama/CodeLlama-7b-Python-hf": ["CodeLlama-7b-python", "codellama/CodeLlama-7b-Python-hf"], - "deepseek-ai/DeepSeek-R1-Distill-Llama-8B": [ - "deepseek-r1-distill-llama-8b", - "deepseek-r1-distill-llama-8b-chat", - ], "deepseek-ai/DeepSeek-R1-Distill-Llama-70B": [ "deepseek-r1-distill-llama-70b", "deepseek-r1-distill-llama-70b-chat", ], + "deepseek-ai/DeepSeek-R1-Distill-Llama-8B": [ + "deepseek-r1-distill-llama-8b", + "deepseek-r1-distill-llama-8b-chat", + ], "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B": [ "deepseek-r1-distill-qwen-1.5b", "deepseek-r1-distill-qwen-1.5b-chat", ], - "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B": [ - "deepseek-r1-distill-qwen-7b", - "deepseek-r1-distill-qwen-7b-chat", - ], "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B": [ "deepseek-r1-distill-qwen-14b", "deepseek-r1-distill-qwen-14b-chat", @@ -284,6 +280,10 @@ 
"deepseek-r1-distill-qwen-32b", "deepseek-r1-distill-qwen-32b-chat", ], + "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B": [ + "deepseek-r1-distill-qwen-7b", + "deepseek-r1-distill-qwen-7b-chat", + ], "distilgpt2": ["distillgpt2", "distill-gpt2", "distil-gpt2", "gpt2-xs"], "EleutherAI/gpt-j-6B": ["gpt-j-6B", "gpt-j", "gptj"], "EleutherAI/gpt-neo-1.3B": ["gpt-neo-1.3B", "gpt-neo-medium", "neo-medium"], From f8de02ae5cd9ba13cd371f9dc475170a5a6a5657 Mon Sep 17 00:00:00 2001 From: jlarson Date: Wed, 11 Feb 2026 10:47:02 -0600 Subject: [PATCH 3/6] Setup StableLM architecture adapter --- tests/mocks/models.py | 36 ++++ transformer_lens/benchmarks/main_benchmark.py | 7 + .../factories/architecture_adapter_factory.py | 2 + .../model_bridge/sources/transformers.py | 1 + .../supported_architectures/__init__.py | 4 + .../supported_architectures/stablelm.py | 180 ++++++++++++++++++ 6 files changed, 230 insertions(+) create mode 100644 transformer_lens/model_bridge/supported_architectures/stablelm.py diff --git a/tests/mocks/models.py b/tests/mocks/models.py index ada5b26da..d1a8e0978 100644 --- a/tests/mocks/models.py +++ b/tests/mocks/models.py @@ -35,3 +35,39 @@ def __init__(self): self.model.norm = nn.LayerNorm(512) self.lm_head = nn.Linear(512, 1000) # Add missing lm_head self.embed_tokens = self.model.embed_tokens # For shared embedding/unembedding + + +class MockStableLmModel(nn.Module): + """A mock implementation of the StableLM model architecture for testing purposes. + + Replicates the key architectural components of StableLM: + - Embedding layer (embed_tokens) + - Rotary embedding (rotary_emb) + - Multiple transformer layers with: + - Input and post-attention layer norms (standard LayerNorm) + - Self-attention with Q, K, V, O projections (Q/K/V have bias) + - MLP with gate, up, and down projections (no bias) + - Final layer norm + - LM head (tied to embed_tokens) + """ + + def __init__(self): + super().__init__() + self.model = nn.Module() + self.model.embed_tokens = nn.Embedding(1000, 512) + self.model.rotary_emb = nn.Module() # Mock rotary embedding + self.model.layers = nn.ModuleList([nn.Module() for _ in range(2)]) + for layer in self.model.layers: + layer.input_layernorm = nn.LayerNorm(512) + layer.post_attention_layernorm = nn.LayerNorm(512) + layer.self_attn = nn.Module() + layer.self_attn.q_proj = nn.Linear(512, 512, bias=True) + layer.self_attn.k_proj = nn.Linear(512, 512, bias=True) + layer.self_attn.v_proj = nn.Linear(512, 512, bias=True) + layer.self_attn.o_proj = nn.Linear(512, 512, bias=False) + layer.mlp = nn.Module() + layer.mlp.gate_proj = nn.Linear(512, 2048, bias=False) + layer.mlp.up_proj = nn.Linear(512, 2048, bias=False) + layer.mlp.down_proj = nn.Linear(2048, 512, bias=False) + self.model.norm = nn.LayerNorm(512) + self.lm_head = nn.Linear(512, 1000, bias=False) diff --git a/transformer_lens/benchmarks/main_benchmark.py b/transformer_lens/benchmarks/main_benchmark.py index 24e981f60..132ce69bb 100644 --- a/transformer_lens/benchmarks/main_benchmark.py +++ b/transformer_lens/benchmarks/main_benchmark.py @@ -861,6 +861,13 @@ def cleanup_model(model, model_name_str: str): auto_model_class = get_auto_model_class(model_name) if verbose and auto_model_class != AutoModelForCausalLM: print(f"Using {auto_model_class.__name__} for encoder-decoder model") + # Ensure pad_token_id exists on HF config. Transformers v5 raises + # AttributeError for missing config attributes, which crashes models + # like StableLM that access config.pad_token_id during __init__. 
+ hf_config = AutoConfig.from_pretrained(model_name) + if not hasattr(hf_config, "pad_token_id") or "pad_token_id" not in hf_config.__dict__: + hf_config.pad_token_id = getattr(hf_config, "eos_token_id", None) + hf_kwargs["config"] = hf_config hf_model = auto_model_class.from_pretrained(model_name, **hf_kwargs) # type: ignore[arg-type] hf_model = hf_model.to(device) hf_model.eval() diff --git a/transformer_lens/factories/architecture_adapter_factory.py b/transformer_lens/factories/architecture_adapter_factory.py index 7d6c7f4c1..aa83dd402 100644 --- a/transformer_lens/factories/architecture_adapter_factory.py +++ b/transformer_lens/factories/architecture_adapter_factory.py @@ -29,6 +29,7 @@ Qwen2ArchitectureAdapter, Qwen3ArchitectureAdapter, QwenArchitectureAdapter, + StableLmArchitectureAdapter, T5ArchitectureAdapter, ) @@ -56,6 +57,7 @@ "QwenForCausalLM": QwenArchitectureAdapter, "Qwen2ForCausalLM": Qwen2ArchitectureAdapter, "Qwen3ForCausalLM": Qwen3ArchitectureAdapter, + "StableLmForCausalLM": StableLmArchitectureAdapter, "T5ForConditionalGeneration": T5ArchitectureAdapter, "NanoGPTForCausalLM": NanogptArchitectureAdapter, "MinGPTForCausalLM": MingptArchitectureAdapter, diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py index 9628bcbb9..a4124fd35 100644 --- a/transformer_lens/model_bridge/sources/transformers.py +++ b/transformer_lens/model_bridge/sources/transformers.py @@ -152,6 +152,7 @@ def determine_architecture_from_hf_config(hf_config): "qwen": "QwenForCausalLM", "qwen2": "Qwen2ForCausalLM", "qwen3": "Qwen3ForCausalLM", + "stablelm": "StableLmForCausalLM", "t5": "T5ForConditionalGeneration", } if model_type in model_type_mappings: diff --git a/transformer_lens/model_bridge/supported_architectures/__init__.py b/transformer_lens/model_bridge/supported_architectures/__init__.py index 23bbabada..a07cb3c03 100644 --- a/transformer_lens/model_bridge/supported_architectures/__init__.py +++ b/transformer_lens/model_bridge/supported_architectures/__init__.py @@ -73,6 +73,9 @@ from transformer_lens.model_bridge.supported_architectures.qwen3 import ( Qwen3ArchitectureAdapter, ) +from transformer_lens.model_bridge.supported_architectures.stablelm import ( + StableLmArchitectureAdapter, +) from transformer_lens.model_bridge.supported_architectures.t5 import ( T5ArchitectureAdapter, ) @@ -101,5 +104,6 @@ "QwenArchitectureAdapter", "Qwen2ArchitectureAdapter", "Qwen3ArchitectureAdapter", + "StableLmArchitectureAdapter", "T5ArchitectureAdapter", ] diff --git a/transformer_lens/model_bridge/supported_architectures/stablelm.py b/transformer_lens/model_bridge/supported_architectures/stablelm.py new file mode 100644 index 000000000..56cd272e3 --- /dev/null +++ b/transformer_lens/model_bridge/supported_architectures/stablelm.py @@ -0,0 +1,180 @@ +"""StableLM architecture adapter.""" + +from typing import Any + +from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion +from transformer_lens.conversion_utils.param_processing_conversion import ( + ParamProcessingConversion, +) +from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter +from transformer_lens.model_bridge.generalized_components import ( + BlockBridge, + EmbeddingBridge, + GatedMLPBridge, + LinearBridge, + NormalizationBridge, + PositionEmbeddingsAttentionBridge, + RotaryEmbeddingBridge, + UnembeddingBridge, +) + + +class StableLmArchitectureAdapter(ArchitectureAdapter): + """Architecture adapter for StableLM 
models. + + StableLM uses a Llama-like architecture with separate Q/K/V projections and + gated MLP, but differs in using standard LayerNorm (not RMSNorm) and partial + rotary embeddings (25% of head dimensions by default). + + Supports optional features: + - Grouped Query Attention (num_key_value_heads != num_attention_heads) + - QKV bias (use_qkv_bias=True on some models like stable-code-3b) + - Parallel residual connections (use_parallel_residual=True) + - Per-head QK LayerNorm (qk_layernorm=True) + + Optional Parameters (may not exist in state_dict): + ------------------------------------------------- + - blocks.{i}.attn.b_Q - Only present when use_qkv_bias=True + - blocks.{i}.attn.b_K - Only present when use_qkv_bias=True + - blocks.{i}.attn.b_V - Only present when use_qkv_bias=True + - blocks.{i}.attn.b_O - No bias on output projection + - blocks.{i}.mlp.b_in - No bias on MLP up_proj + - blocks.{i}.mlp.b_gate - No bias on MLP gate_proj + - blocks.{i}.mlp.b_out - No bias on MLP down_proj + """ + + def __init__(self, cfg: Any) -> None: + """Initialize the StableLM architecture adapter.""" + super().__init__(cfg) + + # Set config variables for weight processing + self.cfg.normalization_type = "LN" + self.cfg.positional_embedding_type = "rotary" + self.cfg.final_rms = False + self.cfg.gated_mlp = True + self.cfg.attn_only = False + self.cfg.uses_rms_norm = False + # Force eager attention for numerical consistency with benchmark reference + # PositionEmbeddingsAttentionBridge delegates to native HF attention, so + # both bridge and reference must use the same implementation + self.cfg.attn_implementation = "eager" + + self.default_config = { + "d_model": cfg.d_model, + "d_head": cfg.d_model // cfg.n_heads, + "n_heads": cfg.n_heads, + "n_layers": cfg.n_layers, + "d_vocab": cfg.d_vocab, + } + + # GQA support + if hasattr(cfg, "n_key_value_heads") and cfg.n_key_value_heads is not None: + self.default_config["n_key_value_heads"] = cfg.n_key_value_heads + self.cfg.n_key_value_heads = cfg.n_key_value_heads + + n_kv_heads = getattr(self.cfg, "n_key_value_heads", self.cfg.n_heads) + + self.weight_processing_conversions = { + "blocks.{i}.attn.q.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=self.cfg.n_heads), + ), + "blocks.{i}.attn.k.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads), + ), + "blocks.{i}.attn.v.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("(n h) m -> n m h", n=n_kv_heads), + ), + "blocks.{i}.attn.o.weight": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("m (n h) -> n h m", n=self.cfg.n_heads), + ), + # Bias conversions for models with use_qkv_bias=True + "blocks.{i}.attn.q.bias": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("(n h) -> n h", n=self.cfg.n_heads), + ), + "blocks.{i}.attn.k.bias": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("(n h) -> n h", n=n_kv_heads), + ), + "blocks.{i}.attn.v.bias": ParamProcessingConversion( + tensor_conversion=RearrangeTensorConversion("(n h) -> n h", n=n_kv_heads), + ), + } + + self.component_mapping = { + "embed": EmbeddingBridge(name="model.embed_tokens"), + "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"), + "blocks": BlockBridge( + name="model.layers", + submodules={ + "ln1": NormalizationBridge( + name="input_layernorm", + config=self.cfg, + use_native_layernorm_autograd=True, + ), + "ln2": 
NormalizationBridge( + name="post_attention_layernorm", + config=self.cfg, + use_native_layernorm_autograd=True, + ), + "attn": PositionEmbeddingsAttentionBridge( + name="self_attn", + config=self.cfg, + submodules={ + "q": LinearBridge(name="q_proj"), + "k": LinearBridge(name="k_proj"), + "v": LinearBridge(name="v_proj"), + "o": LinearBridge(name="o_proj"), + }, + requires_attention_mask=True, + requires_position_embeddings=True, + ), + "mlp": GatedMLPBridge( + name="mlp", + config=self.cfg, + submodules={ + "gate": LinearBridge(name="gate_proj"), + "in": LinearBridge(name="up_proj"), + "out": LinearBridge(name="down_proj"), + }, + ), + }, + ), + "ln_final": NormalizationBridge( + name="model.norm", + config=self.cfg, + use_native_layernorm_autograd=True, + ), + "unembed": UnembeddingBridge(name="lm_head", config=self.cfg), + } + + def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None: + """Set up rotary embedding references for StableLM component testing. + + StableLM uses RoPE (Rotary Position Embeddings) with partial rotation. + We set the rotary_emb reference on all attention bridge instances and + force eager attention for numerical consistency. + + Args: + hf_model: The HuggingFace StableLM model instance + bridge_model: The TransformerBridge model (if available) + """ + rotary_emb = hf_model.model.rotary_emb + + # Force HF model to use "eager" attention to match bridge implementation + # Bridge uses "eager" to support output_attentions for hook compatibility + # SDPA and eager are mathematically equivalent but have numerical differences + if hasattr(hf_model, "config") and hasattr(hf_model.config, "_attn_implementation"): + hf_model.config._attn_implementation = "eager" + + # Also set on all attention layers + if hasattr(hf_model, "model") and hasattr(hf_model.model, "layers"): + for layer in hf_model.model.layers: + if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "config"): + layer.self_attn.config._attn_implementation = "eager" + + if bridge_model is not None and hasattr(bridge_model, "blocks"): + for block in bridge_model.blocks: + if hasattr(block, "attn"): + block.attn.set_rotary_emb(rotary_emb) + + attn_bridge = self.get_generalized_component("blocks.0.attn") + attn_bridge.set_rotary_emb(rotary_emb) From 0c6bfe6a6815ecbb3eefc2d573d8573e0899b3ec Mon Sep 17 00:00:00 2001 From: jlarson Date: Wed, 11 Feb 2026 13:24:02 -0600 Subject: [PATCH 4/6] Resolved weight and qk issues with stablelm. 
Added more models --- .../model_bridge/sources/transformers.py | 2 + .../supported_architectures/stablelm.py | 149 ++++++++++++++---- transformer_lens/supported_models.py | 12 ++ transformer_lens/weight_processing.py | 11 +- 4 files changed, 138 insertions(+), 36 deletions(-) diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py index a4124fd35..b46a4d67c 100644 --- a/transformer_lens/model_bridge/sources/transformers.py +++ b/transformer_lens/model_bridge/sources/transformers.py @@ -110,6 +110,8 @@ def map_default_transformer_lens_config(hf_config): tl_config.experts_per_token = hf_config.num_experts_per_tok if hasattr(hf_config, "sliding_window") and hf_config.sliding_window is not None: tl_config.sliding_window = hf_config.sliding_window + if getattr(hf_config, "use_parallel_residual", False): + tl_config.parallel_attn_mlp = True tl_config.default_prepend_bos = True return tl_config diff --git a/transformer_lens/model_bridge/supported_architectures/stablelm.py b/transformer_lens/model_bridge/supported_architectures/stablelm.py index 56cd272e3..7a8d77c5f 100644 --- a/transformer_lens/model_bridge/supported_architectures/stablelm.py +++ b/transformer_lens/model_bridge/supported_architectures/stablelm.py @@ -2,10 +2,13 @@ from typing import Any +import torch + from transformer_lens.conversion_utils.conversion_steps import RearrangeTensorConversion from transformer_lens.conversion_utils.param_processing_conversion import ( ParamProcessingConversion, ) +from transformer_lens.hook_points import HookPoint from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter from transformer_lens.model_bridge.generalized_components import ( BlockBridge, @@ -72,7 +75,7 @@ def __init__(self, cfg: Any) -> None: self.default_config["n_key_value_heads"] = cfg.n_key_value_heads self.cfg.n_key_value_heads = cfg.n_key_value_heads - n_kv_heads = getattr(self.cfg, "n_key_value_heads", self.cfg.n_heads) + n_kv_heads = self.cfg.n_key_value_heads if self.cfg.n_key_value_heads is not None else self.cfg.n_heads self.weight_processing_conversions = { "blocks.{i}.attn.q.weight": ParamProcessingConversion( @@ -99,44 +102,56 @@ def __init__(self, cfg: Any) -> None: ), } + # When parallel_attn_mlp=True (HF: use_parallel_residual=True), both attn + # and MLP read from ln1 output: + # x = x + attn(ln1(x)) + mlp(ln1(x)) + # When False, they are sequential with separate norms: + # x = x + attn(ln1(x)); x = x + mlp(ln2(x)) + # HF sets post_attention_layernorm=None when use_parallel_residual=True, + # so we must not include ln2 in that case. 
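+        # The HF decoder layer makes the None case explicit (paraphrased from
+        # modeling_stablelm.StableLmDecoderLayer.__init__; exact form may vary
+        # by transformers version):
+        #   self.post_attention_layernorm = (
+        #       None if config.use_parallel_residual
+        #       else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        #   )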
+ use_parallel_residual = getattr(cfg, "parallel_attn_mlp", False) + + block_submodules: dict[str, Any] = { + "ln1": NormalizationBridge( + name="input_layernorm", + config=self.cfg, + use_native_layernorm_autograd=True, + ), + } + if not use_parallel_residual: + block_submodules["ln2"] = NormalizationBridge( + name="post_attention_layernorm", + config=self.cfg, + use_native_layernorm_autograd=True, + ) + block_submodules["attn"] = PositionEmbeddingsAttentionBridge( + name="self_attn", + config=self.cfg, + submodules={ + "q": LinearBridge(name="q_proj"), + "k": LinearBridge(name="k_proj"), + "v": LinearBridge(name="v_proj"), + "o": LinearBridge(name="o_proj"), + }, + requires_attention_mask=True, + requires_position_embeddings=True, + ) + block_submodules["mlp"] = GatedMLPBridge( + name="mlp", + config=self.cfg, + submodules={ + "gate": LinearBridge(name="gate_proj"), + "in": LinearBridge(name="up_proj"), + "out": LinearBridge(name="down_proj"), + }, + ) + self.component_mapping = { "embed": EmbeddingBridge(name="model.embed_tokens"), "rotary_emb": RotaryEmbeddingBridge(name="model.rotary_emb"), "blocks": BlockBridge( name="model.layers", - submodules={ - "ln1": NormalizationBridge( - name="input_layernorm", - config=self.cfg, - use_native_layernorm_autograd=True, - ), - "ln2": NormalizationBridge( - name="post_attention_layernorm", - config=self.cfg, - use_native_layernorm_autograd=True, - ), - "attn": PositionEmbeddingsAttentionBridge( - name="self_attn", - config=self.cfg, - submodules={ - "q": LinearBridge(name="q_proj"), - "k": LinearBridge(name="k_proj"), - "v": LinearBridge(name="v_proj"), - "o": LinearBridge(name="o_proj"), - }, - requires_attention_mask=True, - requires_position_embeddings=True, - ), - "mlp": GatedMLPBridge( - name="mlp", - config=self.cfg, - submodules={ - "gate": LinearBridge(name="gate_proj"), - "in": LinearBridge(name="up_proj"), - "out": LinearBridge(name="down_proj"), - }, - ), - }, + submodules=block_submodules, ), "ln_final": NormalizationBridge( name="model.norm", @@ -146,6 +161,72 @@ def __init__(self, cfg: Any) -> None: "unembed": UnembeddingBridge(name="lm_head", config=self.cfg), } + def setup_hook_compatibility(self, bridge: Any) -> None: + """Inject hook points for QK LayerNorm on models with qk_layernorm=True. + + StableLM v2 models (e.g., stablelm-2-12b) apply per-head LayerNorm to Q and K + after projection but before rotary embedding. The native HF attention handles + this internally, but we inject hooks so researchers can observe/intervene on + the post-norm Q/K values. + + Adds to each attention bridge: + - hook_q_layernorm: fires after q_layernorm(query_states) + - hook_k_layernorm: fires after k_layernorm(key_states) + + This runs during bridge __init__ via _setup_hook_compatibility(), after + component setup but before hook registry finalization. The hook registry + scanner skips _original_component subtrees, so we register hooks directly + in bridge._hook_registry with canonical TL-style names. 
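+        Example (hypothetical usage once the hooks are registered; `tokens`
+        and the intervention are illustrative, not part of this patch):
+
+            def scale_q(value, hook):
+                return value * 0.5
+
+            logits = bridge.run_with_hooks(
+                tokens, fwd_hooks=[("blocks.0.attn.hook_q_layernorm", scale_q)]
+            )
+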
+ + Args: + bridge: The TransformerBridge instance (fully initialized) + """ + if not hasattr(bridge, "blocks"): + return + + for i, block in enumerate(bridge.blocks): + if not hasattr(block, "attn"): + continue + attn_bridge = block.attn + hf_attn = getattr(attn_bridge, "original_component", None) + if hf_attn is None: + continue + if not getattr(hf_attn, "qk_layernorm", False): + continue + + # Add hook points to the attention bridge as proper submodules + attn_bridge.add_module("hook_q_layernorm", HookPoint()) + attn_bridge.add_module("hook_k_layernorm", HookPoint()) + + # Register directly in bridge's hook registry with canonical names + # (the scanner skips _original_component subtrees so won't find these) + q_name = f"blocks.{i}.attn.hook_q_layernorm" + k_name = f"blocks.{i}.attn.hook_k_layernorm" + attn_bridge.hook_q_layernorm.name = q_name + attn_bridge.hook_k_layernorm.name = k_name + bridge._hook_registry[q_name] = attn_bridge.hook_q_layernorm + bridge._hook_registry[k_name] = attn_bridge.hook_k_layernorm + + # Wrap the HF q_layernorm/k_layernorm forward methods to fire hooks + original_q_ln_forward = hf_attn.q_layernorm.forward + original_k_ln_forward = hf_attn.k_layernorm.forward + + # Use a closure factory to capture the correct references + def _make_hooked_forward( + original_forward: Any, hook: HookPoint + ) -> Any: + def hooked_forward(hidden_states: torch.Tensor) -> torch.Tensor: + result = original_forward(hidden_states) + return hook(result) + return hooked_forward + + hf_attn.q_layernorm.forward = _make_hooked_forward( # type: ignore[method-assign] + original_q_ln_forward, attn_bridge.hook_q_layernorm + ) + hf_attn.k_layernorm.forward = _make_hooked_forward( # type: ignore[method-assign] + original_k_ln_forward, attn_bridge.hook_k_layernorm + ) + def setup_component_testing(self, hf_model: Any, bridge_model: Any = None) -> None: """Set up rotary embedding references for StableLM component testing. 
diff --git a/transformer_lens/supported_models.py b/transformer_lens/supported_models.py index 3adf140f8..4eb851d28 100644 --- a/transformer_lens/supported_models.py +++ b/transformer_lens/supported_models.py @@ -222,10 +222,16 @@ "roneneldan/TinyStories-Instruct-3M", "roneneldan/TinyStories-Instruct-8M", "roneneldan/TinyStories-Instuct-1Layer-21M", + "stabilityai/stable-code-3b", + "stabilityai/stable-code-instruct-3b", + "stabilityai/stablelm-2-1_6b", + "stabilityai/stablelm-2-zephyr-1_6b", + "stabilityai/stablelm-3b-4e1t", "stabilityai/stablelm-base-alpha-3b", "stabilityai/stablelm-base-alpha-7b", "stabilityai/stablelm-tuned-alpha-3b", "stabilityai/stablelm-tuned-alpha-7b", + "stabilityai/stablelm-zephyr-3b", "stanford-crfm/alias-gpt2-small-x21", "stanford-crfm/arwen-gpt2-medium-x21", "stanford-crfm/battlestar-gpt2-small-x49", @@ -576,10 +582,16 @@ "roneneldan/TinyStories-Instruct-3M": ["tiny-stories-instruct-3M"], "roneneldan/TinyStories-Instruct-8M": ["tiny-stories-instruct-8M"], "roneneldan/TinyStories-Instuct-1Layer-21M": ["tiny-stories-instruct-1L-21M"], + "stabilityai/stable-code-3b": ["stable-code-3b"], + "stabilityai/stable-code-instruct-3b": ["stable-code-instruct-3b"], + "stabilityai/stablelm-2-1_6b": ["stablelm-2-1.6b"], + "stabilityai/stablelm-2-zephyr-1_6b": ["stablelm-2-zephyr-1.6b"], + "stabilityai/stablelm-3b-4e1t": ["stablelm-3b-4e1t", "stablelm-3b"], "stabilityai/stablelm-base-alpha-3b": ["stablelm-base-alpha-3b", "stablelm-base-3b"], "stabilityai/stablelm-base-alpha-7b": ["stablelm-base-alpha-7b", "stablelm-base-7b"], "stabilityai/stablelm-tuned-alpha-3b": ["stablelm-tuned-alpha-3b", "stablelm-tuned-3b"], "stabilityai/stablelm-tuned-alpha-7b": ["stablelm-tuned-alpha-7b", "stablelm-tuned-7b"], + "stabilityai/stablelm-zephyr-3b": ["stablelm-zephyr-3b"], "stanford-crfm/alias-gpt2-small-x21": [ "stanford-gpt2-small-a", "alias-gpt2-small-x21", diff --git a/transformer_lens/weight_processing.py b/transformer_lens/weight_processing.py index 31318f7a7..8a2fa63cf 100644 --- a/transformer_lens/weight_processing.py +++ b/transformer_lens/weight_processing.py @@ -528,8 +528,12 @@ def _fold_mlp_layer_norm( mlp_b_in = ProcessWeights.convert_tensor_to_tl_format( mlp_b_in_key, state_dict, state_dict.get(mlp_b_in_key), cfg, adapter, layer ) - assert mlp_b_in is not None, f"MLP b_in not found at key {mlp_b_in_key}" - new_mlp_b_in = mlp_b_in + (mlp_W_in * ln2_b_broadcast).sum(sum_dim) + ln2_b_folded = (mlp_W_in * ln2_b_broadcast).sum(sum_dim) + if mlp_b_in is not None: + new_mlp_b_in = mlp_b_in + ln2_b_folded + else: + # MLP has no bias — create one from the folded LN bias + new_mlp_b_in = ln2_b_folded state_dict[mlp_b_in_key] = ProcessWeights.convert_tensor_to_hf_format( mlp_b_in_key, new_mlp_b_in, cfg, adapter, layer ) @@ -1554,6 +1558,9 @@ def convert_tensor_to_tl_format( # (string mappings are handled elsewhere in the architecture adapter) return tensor else: + # Skip conversion for optional parameters that don't exist (e.g. 
biases) + if tensor is None and param_name not in model_state_dict: + return None # Let ParamProcessingConversion handle the fetching and conversion return param_conversion.convert(model_state_dict, param_name) else: From a561675e8c0fed8cba28c7ae930aa8eb18470856 Mon Sep 17 00:00:00 2001 From: jlarson Date: Wed, 11 Feb 2026 14:29:14 -0600 Subject: [PATCH 5/6] Added more models --- transformer_lens/supported_models.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/transformer_lens/supported_models.py b/transformer_lens/supported_models.py index 4eb851d28..a3e8f86c3 100644 --- a/transformer_lens/supported_models.py +++ b/transformer_lens/supported_models.py @@ -224,7 +224,10 @@ "roneneldan/TinyStories-Instuct-1Layer-21M", "stabilityai/stable-code-3b", "stabilityai/stable-code-instruct-3b", + "stabilityai/stablelm-2-12b", + "stabilityai/stablelm-2-12b-chat", "stabilityai/stablelm-2-1_6b", + "stabilityai/stablelm-2-1_6b-chat", "stabilityai/stablelm-2-zephyr-1_6b", "stabilityai/stablelm-3b-4e1t", "stabilityai/stablelm-base-alpha-3b", @@ -584,7 +587,10 @@ "roneneldan/TinyStories-Instuct-1Layer-21M": ["tiny-stories-instruct-1L-21M"], "stabilityai/stable-code-3b": ["stable-code-3b"], "stabilityai/stable-code-instruct-3b": ["stable-code-instruct-3b"], + "stabilityai/stablelm-2-12b": ["stablelm-2-12b"], + "stabilityai/stablelm-2-12b-chat": ["stablelm-2-12b-chat"], "stabilityai/stablelm-2-1_6b": ["stablelm-2-1.6b"], + "stabilityai/stablelm-2-1_6b-chat": ["stablelm-2-1.6b-chat"], "stabilityai/stablelm-2-zephyr-1_6b": ["stablelm-2-zephyr-1.6b"], "stabilityai/stablelm-3b-4e1t": ["stablelm-3b-4e1t", "stablelm-3b"], "stabilityai/stablelm-base-alpha-3b": ["stablelm-base-alpha-3b", "stablelm-base-3b"], From 6238f5a2afae1231a7bf2a35a1eb33b460241839 Mon Sep 17 00:00:00 2001 From: jlarson Date: Wed, 11 Feb 2026 14:39:07 -0600 Subject: [PATCH 6/6] reformatted --- .../model_bridge/supported_architectures/stablelm.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/transformer_lens/model_bridge/supported_architectures/stablelm.py b/transformer_lens/model_bridge/supported_architectures/stablelm.py index 7a8d77c5f..4a16f458e 100644 --- a/transformer_lens/model_bridge/supported_architectures/stablelm.py +++ b/transformer_lens/model_bridge/supported_architectures/stablelm.py @@ -75,7 +75,11 @@ def __init__(self, cfg: Any) -> None: self.default_config["n_key_value_heads"] = cfg.n_key_value_heads self.cfg.n_key_value_heads = cfg.n_key_value_heads - n_kv_heads = self.cfg.n_key_value_heads if self.cfg.n_key_value_heads is not None else self.cfg.n_heads + n_kv_heads = ( + self.cfg.n_key_value_heads + if self.cfg.n_key_value_heads is not None + else self.cfg.n_heads + ) self.weight_processing_conversions = { "blocks.{i}.attn.q.weight": ParamProcessingConversion( @@ -212,12 +216,11 @@ def setup_hook_compatibility(self, bridge: Any) -> None: original_k_ln_forward = hf_attn.k_layernorm.forward # Use a closure factory to capture the correct references - def _make_hooked_forward( - original_forward: Any, hook: HookPoint - ) -> Any: + def _make_hooked_forward(original_forward: Any, hook: HookPoint) -> Any: def hooked_forward(hidden_states: torch.Tensor) -> torch.Tensor: result = original_forward(hidden_states) return hook(result) + return hooked_forward hf_attn.q_layernorm.forward = _make_hooked_forward( # type: ignore[method-assign]
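
For context on the _fold_mlp_layer_norm change in PATCH 4: LayerNorm folding rests on the identity W(x + b) = Wx + Wb, so ln2's bias can be absorbed into the MLP's input bias as b_in + W_in @ ln_b. StableLM's MLP has no input bias, hence the new branch that creates one from the folded term alone. A runnable sketch of the identity (dimensions hypothetical, HF weight layout assumed):

    import torch

    d_model, d_mlp = 512, 2048          # hypothetical StableLM-like dims
    W_in = torch.randn(d_mlp, d_model)  # up_proj weight, HF layout [out, in]
    ln_b = torch.randn(d_model)         # ln2 (post_attention_layernorm) bias
    x = torch.randn(d_model)            # normalized, pre-bias activation

    pre = W_in @ (x + ln_b)              # bias applied before the linear
    folded = (W_in @ x) + (W_in @ ln_b)  # equivalent: bias folded into b_in
    assert torch.allclose(pre, folded, atol=1e-4)

    new_b_in = W_in @ ln_b  # the bias created when mlp_b_in is None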