Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 191 additions & 0 deletions tests/integration/model_bridge/test_analysis_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
"""Tests for TransformerBridge mechanistic interpretability analysis methods.

Tests tokens_to_residual_directions, accumulated_bias, all_composition_scores,
all_head_labels, and top-level W_E/W_U/b_U properties. Validates against
HookedTransformer for correctness, not just shape/type.

Uses distilgpt2 (CI-cached).
"""

import pytest
import torch

from transformer_lens import HookedTransformer
from transformer_lens.model_bridge.bridge import TransformerBridge


@pytest.fixture(scope="module")
def bridge_compat():
    """Module-scoped distilgpt2 TransformerBridge with compatibility mode enabled."""
    bridge = TransformerBridge.boot_transformers("distilgpt2", device="cpu")
    bridge.enable_compatibility_mode()
    return bridge


@pytest.fixture(scope="module")
def reference_ht():
    """Module-scoped HookedTransformer distilgpt2 used as the ground-truth reference."""
    reference = HookedTransformer.from_pretrained("distilgpt2", device="cpu")
    return reference


class TestTopLevelWeightProperties:
    """Test W_E, W_U, b_U delegate to the correct component tensors."""

    def test_W_E_is_same_object_as_embed(self, bridge_compat):
        """bridge.W_E should be the exact same tensor as bridge.embed.W_E."""
        assert bridge_compat.W_E is bridge_compat.embed.W_E

    def test_W_U_equals_unembed(self, bridge_compat):
        """bridge.W_U should equal bridge.unembed.W_U (may be a view/transpose)."""
        assert torch.equal(bridge_compat.W_U, bridge_compat.unembed.W_U)

    def test_b_U_equals_unembed(self, bridge_compat):
        """bridge.b_U should equal bridge.unembed.b_U."""
        assert torch.equal(bridge_compat.b_U, bridge_compat.unembed.b_U)

    def test_W_E_matches_hooked_transformer(self, bridge_compat, reference_ht):
        """bridge.W_E should have HookedTransformer.W_E's shape and be non-degenerate.

        Exact values are deliberately NOT compared: weight processing
        (e.g. embedding centering) may legitimately change them. Only the
        shape and a non-zero spread are required.
        """
        assert bridge_compat.W_E.shape == reference_ht.W_E.shape
        assert bridge_compat.W_E.std() > 0
        assert reference_ht.W_E.std() > 0

    def test_W_U_matches_hooked_transformer(self, bridge_compat, reference_ht):
        """bridge.W_U values should match HookedTransformer.W_U."""
        assert bridge_compat.W_U.shape == reference_ht.W_U.shape
        max_diff = (bridge_compat.W_U - reference_ht.W_U).abs().max().item()
        assert max_diff < 1e-4, f"W_U differs by {max_diff}"


class TestTokensToResidualDirections:
    """Test tokens_to_residual_directions produces correct unembedding vectors."""

    def test_single_token_string(self, bridge_compat):
        """String token should return a 1-D vector of size d_model."""
        direction = bridge_compat.tokens_to_residual_directions("hello")
        assert direction.shape == (bridge_compat.cfg.d_model,)

    def test_single_token_int(self, bridge_compat):
        """Integer token should return a 1-D vector of size d_model."""
        direction = bridge_compat.tokens_to_residual_directions(100)
        assert direction.shape == (bridge_compat.cfg.d_model,)

    def test_equals_W_U_column(self, bridge_compat):
        """Result should be exactly the corresponding column of W_U."""
        token_id = 42
        direction = bridge_compat.tokens_to_residual_directions(token_id)
        assert torch.equal(direction, bridge_compat.W_U[:, token_id])

    def test_batch_tokens(self, bridge_compat):
        """1-D tensor of tokens should return (n_tokens, d_model)."""
        token_ids = torch.tensor([100, 200, 300])
        directions = bridge_compat.tokens_to_residual_directions(token_ids)
        assert directions.shape == (3, bridge_compat.cfg.d_model)
        # Row i must be column token_ids[i] of W_U.
        for row, token_id in zip(directions, token_ids):
            assert torch.equal(row, bridge_compat.W_U[:, token_id])

    def test_matches_hooked_transformer(self, bridge_compat, reference_ht):
        """Output should match HookedTransformer for the same tokens."""
        token_ids = torch.tensor([10, 20, 30])
        delta = bridge_compat.tokens_to_residual_directions(
            token_ids
        ) - reference_ht.tokens_to_residual_directions(token_ids)
        max_diff = delta.abs().max().item()
        assert max_diff < 1e-4, f"Residual directions differ by {max_diff}"


class TestAccumulatedBias:
    """Test accumulated_bias sums biases correctly."""

    def test_layer_zero_is_zeros(self, bridge_compat):
        """accumulated_bias(0) should be all zeros (no layers processed)."""
        bias = bridge_compat.accumulated_bias(0)
        assert bias.shape == (bridge_compat.cfg.d_model,)
        assert torch.allclose(bias, torch.zeros_like(bias))

    def test_layer_one_includes_first_block(self, bridge_compat):
        """accumulated_bias(1) should include block 0's biases and be non-zero."""
        bias = bridge_compat.accumulated_bias(1)
        assert bias.shape == (bridge_compat.cfg.d_model,)
        # distilgpt2 has biases, so block 0 must contribute something non-zero.
        assert bias.norm() > 0

    def test_monotonically_increasing_norm(self, bridge_compat):
        """Accumulated bias norm should generally increase with more layers."""
        # Not strictly monotonic, but the full-depth bias must exceed the empty one.
        full_depth = bridge_compat.accumulated_bias(bridge_compat.cfg.n_layers)
        empty = bridge_compat.accumulated_bias(0)
        assert full_depth.norm() > empty.norm()

    def test_matches_hooked_transformer(self, bridge_compat, reference_ht):
        """Output should match HookedTransformer."""
        for layer in (0, 1, 3, bridge_compat.cfg.n_layers):
            delta = bridge_compat.accumulated_bias(layer) - reference_ht.accumulated_bias(layer)
            max_diff = delta.abs().max().item()
            assert max_diff < 1e-4, f"accumulated_bias({layer}) differs by {max_diff}"

    def test_mlp_input_flag(self, bridge_compat, reference_ht):
        """mlp_input=True should include the current layer's attn bias."""
        bridge_bias = bridge_compat.accumulated_bias(1, mlp_input=True)
        reference_bias = reference_ht.accumulated_bias(1, mlp_input=True)
        max_diff = (bridge_bias - reference_bias).abs().max().item()
        assert max_diff < 1e-4, f"accumulated_bias(1, mlp_input=True) differs by {max_diff}"


class TestAllCompositionScores:
    """Test all_composition_scores produces correct composition score matrices."""

    def test_shape(self, bridge_compat):
        """Shape should be (n_layers, n_heads, n_layers, n_heads)."""
        cfg = bridge_compat.cfg
        expected_shape = (cfg.n_layers, cfg.n_heads, cfg.n_layers, cfg.n_heads)
        assert bridge_compat.all_composition_scores("Q").shape == expected_shape

    def test_upper_triangular_masking(self, bridge_compat):
        """Scores should be zero where left_layer >= right_layer."""
        scores = bridge_compat.all_composition_scores("Q")
        for left_layer in range(bridge_compat.cfg.n_layers):
            # Every right layer at or below the left layer must be masked out.
            for right_layer in range(left_layer + 1):
                block = scores[left_layer, :, right_layer, :]
                assert (
                    block == 0
                ).all(), f"Scores at L{left_layer}->L{right_layer} should be zero (upper triangular)"

    def test_nonzero_above_diagonal(self, bridge_compat):
        """At least some scores above the diagonal should be non-zero."""
        scores = bridge_compat.all_composition_scores("Q")
        # L0 heads composing into L1 heads is the first unmasked block.
        assert scores[0, :, 1, :].abs().sum() > 0

    def test_all_modes_work(self, bridge_compat):
        """Q, K, V modes should all produce valid tensors."""
        for mode in ("Q", "K", "V"):
            mode_scores = bridge_compat.all_composition_scores(mode)
            assert not torch.isnan(mode_scores).any(), f"NaN in {mode} composition scores"

    def test_invalid_mode_raises(self, bridge_compat):
        """Invalid mode should raise ValueError."""
        with pytest.raises(ValueError, match="mode must be one of"):
            bridge_compat.all_composition_scores("X")


class TestAllHeadLabels:
    """Test all_head_labels produces correct labels."""

    def test_count(self, bridge_compat):
        """Should have n_layers * n_heads labels."""
        cfg = bridge_compat.cfg
        assert len(bridge_compat.all_head_labels) == cfg.n_layers * cfg.n_heads

    def test_format(self, bridge_compat):
        """Labels should follow L{layer}H{head} format."""
        labels = bridge_compat.all_head_labels
        assert labels[:2] == ["L0H0", "L0H1"]
        # The first head of layer 1 sits right after layer 0's heads.
        assert labels[bridge_compat.cfg.n_heads] == "L1H0"

    def test_matches_hooked_transformer(self, bridge_compat, reference_ht):
        """Should match HookedTransformer's labels exactly."""
        assert bridge_compat.all_head_labels == reference_ht.all_head_labels()
128 changes: 128 additions & 0 deletions transformer_lens/model_bridge/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,21 @@ def b_out(self) -> torch.Tensor:
"""Stack the MLP output biases across all layers."""
return self._stack_block_params("mlp.b_out")

@property
def W_U(self) -> torch.Tensor:
    """Unembedding matrix of shape (d_model, d_vocab).

    Projects residual-stream vectors onto vocabulary logits; delegates to
    the unembed component's W_U.
    """
    unembed_component = self.unembed
    return unembed_component.W_U

@property
def b_U(self) -> torch.Tensor:
    """Unembedding bias of shape (d_vocab); delegates to the unembed component."""
    unembed_component = self.unembed
    return unembed_component.b_U

@property
def W_E(self) -> torch.Tensor:
    """Token embedding matrix of shape (d_vocab, d_model); delegates to the embed component."""
    embed_component = self.embed
    return embed_component.W_E

@property
def QK(self):
    """QK circuit as a FactoredMatrix: the lazy product W_Q @ W_K^T.

    The matrix product is not materialized; FactoredMatrix holds the two
    factors and exposes product semantics.
    """
    return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1))
Expand All @@ -1111,6 +1126,119 @@ def QK(self):
def OV(self):
    """OV circuit as a FactoredMatrix: the lazy product W_V @ W_O (not materialized)."""
    return FactoredMatrix(self.W_V, self.W_O)

# ------------------------------------------------------------------
# Mechanistic interpretability analysis methods
# ------------------------------------------------------------------

def tokens_to_residual_directions(
    self,
    tokens: Union[str, int, torch.Tensor],
) -> torch.Tensor:
    """Map tokens to their unembedding vectors (residual stream directions).

    Returns the columns of W_U corresponding to the given tokens — i.e. the
    directions in the residual stream that the model dots with to produce the
    logit for each token.

    WARNING: If you use this without folding in LayerNorm (compatibility mode),
    the results will be misleading because LN weights change the unembed map.

    Args:
        tokens: A single token (str, int, or scalar tensor), a 1-D tensor of
            token IDs, or a 2-D batch of token IDs.

    Returns:
        Tensor of unembedding vectors with shape matching the input token shape
        plus a trailing d_model dimension.

    Raises:
        ValueError: If ``tokens`` is none of the supported single-token types.
    """
    if isinstance(tokens, torch.Tensor) and tokens.numel() > 1:
        # Advanced indexing yields (d_model, *tokens.shape); move d_model to
        # the end so the result matches the token shape plus d_model.
        # torch.movedim is equivalent to the previous
        # einops.rearrange("d_model ... -> ... d_model") but avoids the
        # third-party dependency.
        return torch.movedim(self.W_U[:, tokens], 0, -1)
    # Single-token path: normalize the input to an int token ID, then take
    # the corresponding W_U column.
    if isinstance(tokens, str):
        token = self.to_single_token(tokens)
    elif isinstance(tokens, int):
        token = tokens
    elif isinstance(tokens, torch.Tensor) and tokens.numel() == 1:
        token = int(tokens.item())
    else:
        raise ValueError(f"Invalid token type: {type(tokens)}")
    return self.W_U[:, token]

def accumulated_bias(
    self,
    layer: int,
    mlp_input: bool = False,
    include_mlp_biases: bool = True,
) -> torch.Tensor:
    """Sum of attention and MLP output biases up to the input of a given layer.

    Args:
        layer: Layer number in [0, n_layers]. 0 means no layers, n_layers means all.
        mlp_input: If True, include the attention output bias of the target layer
            (i.e. bias up to the MLP input of that layer).
        include_mlp_biases: Whether to include MLP biases. Useful to set False when
            expanding attn_out into individual heads but keeping mlp_out as-is.

    Returns:
        Tensor of shape [d_model] with the accumulated bias.
    """
    total = torch.zeros(self.cfg.d_model, device=self.cfg.device)

    def _plus(bias):
        # A missing bias (attribute absent or None) contributes nothing.
        return total if bias is None else total + bias

    for block in self.blocks[:layer]:
        total = _plus(getattr(block.attn, "b_O", None))
        if include_mlp_biases:
            total = _plus(getattr(block.mlp, "b_out", None))
    if mlp_input:
        assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer"
        total = _plus(getattr(self.blocks[layer].attn, "b_O", None))
    return total

def all_composition_scores(self, mode: str) -> torch.Tensor:
    """Composition scores for all pairs of heads.

    Returns an (n_layers, n_heads, n_layers, n_heads) tensor that is upper
    triangular on the layer axes (a head can only compose with later heads).

    See https://transformer-circuits.pub/2021/framework/index.html

    Args:
        mode: One of "Q", "K", "V" — which composition type to compute.

    Raises:
        ValueError: If ``mode`` is not one of "Q", "K", "V".
    """
    # The left factor is always the OV circuit; the right factor depends
    # on which input of the downstream head is being composed into.
    right_factories = {
        "Q": lambda: self.QK,
        "K": lambda: self.QK.T,
        "V": lambda: self.OV,
    }
    if mode not in right_factories:
        raise ValueError(f"mode must be one of ['Q', 'K', 'V'] not {mode}")
    raw_scores = utils.composition_scores(self.OV, right_factories[mode](), broadcast_dims=True)
    # Zero out every pair where the writing layer is not strictly earlier
    # than the reading layer.
    layer_index = torch.arange(self.cfg.n_layers, device=self.cfg.device)
    keep = layer_index[:, None, None, None] < layer_index[None, None, :, None]
    return raw_scores.where(keep, torch.zeros_like(raw_scores))

@property
def all_head_labels(self) -> list[str]:
    """Human-readable labels for all attention heads, e.g. ['L0H0', 'L0H1', ...]."""
    labels: list[str] = []
    for layer in range(self.cfg.n_layers):
        labels.extend(f"L{layer}H{head}" for head in range(self.cfg.n_heads))
    return labels

def parameters(self, recurse: bool = True) -> Iterator[nn.Parameter]:
"""Returns parameters following standard PyTorch semantics.

Expand Down
Loading