Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
6c4a213
Testing R1 Distills to confirm functional in TransformerLens
jlarson4 Feb 11, 2026
fe7067a
Updating order to be alphabetical
jlarson4 Feb 11, 2026
f8de02a
Setup StableLM architecture adapter
jlarson4 Feb 11, 2026
0c6bfe6
Resolved weight and qk issues with stablelm. Added more models
jlarson4 Feb 11, 2026
a561675
Added more models
jlarson4 Feb 11, 2026
7d07205
Merge remote-tracking branch 'origin/dev-3.x' into feature/StableLM-a…
jlarson4 Feb 11, 2026
6238f5a
reformatted
jlarson4 Feb 11, 2026
ae378aa
Created an ArchitectureAdapter for OpenElm, handled trusting remote code
jlarson4 Feb 12, 2026
b4dfd2a
Fix formatting
jlarson4 Feb 12, 2026
fc4a19f
Removed test file, update benchmark
jlarson4 Feb 12, 2026
16d2361
Add mock model test
jlarson4 Feb 12, 2026
688986b
Merge branch 'dev-3.x' into feature/OpenELM-architecture-adapter
jlarson4 Feb 12, 2026
21d18d2
More benchmark adjustments
jlarson4 Feb 12, 2026
4630b8b
removed improperly listed supported models
jlarson4 Feb 16, 2026
0f1dc31
Merge remote-tracking branch 'origin/dev-3.x' into feature/OpenELM-ar…
jlarson4 Feb 17, 2026
f760e74
Updating to resolve existing weight diff issues
jlarson4 Feb 17, 2026
2179be5
began working through issues with existing architecture benchmarks
jlarson4 Feb 17, 2026
b59ecf1
Resolve any existing weight folding issues we can possibly resolve
jlarson4 Feb 17, 2026
3f26fe4
Fixing test failures
jlarson4 Feb 17, 2026
6dc104b
Clean up format and other changes
jlarson4 Feb 17, 2026
123425f
Added text quality benchmark, updated to pass CI
jlarson4 Feb 18, 2026
c84f5cd
Merge remote-tracking branch 'origin/dev-3.x' into debugging/architec…
jlarson4 Feb 18, 2026
cb9e18f
Cleaned up comment, tightened tolerances further for bfloat16 models
jlarson4 Feb 18, 2026
dfd089d
Merge remote-tracking branch 'origin/dev-3.x' into debugging/architec…
jlarson4 Feb 18, 2026
6cc9d6e
Removed unnecessary testing file
jlarson4 Feb 18, 2026
5bd5798
Cleanup of redundant code
jlarson4 Feb 18, 2026
e855e57
Resolve type issues and format issues
jlarson4 Feb 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions tests/unit/test_weight_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,11 @@ def test_fold_layer_no_adapter_transformer_lens_format(self, basic_config):
"""
cfg = basic_config
cfg.n_layers = 1 # Test with single layer for simplicity
# Match config to the tensor dimensions used below
cfg.d_model = 4
cfg.n_heads = 2
cfg.d_head = 2
cfg.d_mlp = 8

# Create a state dict with known values for deterministic testing
state_dict = {}
Expand Down
9 changes: 8 additions & 1 deletion transformer_lens/benchmarks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@
validate_hook_shape_compatibility,
)
from transformer_lens.benchmarks.main_benchmark import run_benchmark_suite
from transformer_lens.benchmarks.utils import BenchmarkResult, BenchmarkSeverity, PhaseReferenceData
from transformer_lens.benchmarks.text_quality import benchmark_text_quality
from transformer_lens.benchmarks.utils import (
BenchmarkResult,
BenchmarkSeverity,
PhaseReferenceData,
)
from transformer_lens.benchmarks.weight_processing import (
benchmark_weight_modification,
benchmark_weight_processing,
Expand Down Expand Up @@ -72,6 +77,8 @@
"benchmark_generation",
"benchmark_generation_with_kv_cache",
"benchmark_multiple_generation_calls",
# Text quality benchmarks
"benchmark_text_quality",
# Weight processing benchmarks
"benchmark_weight_processing",
"benchmark_weight_sharing",
Expand Down
20 changes: 14 additions & 6 deletions transformer_lens/benchmarks/activation_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

from transformer_lens import HookedTransformer
from transformer_lens.ActivationCache import ActivationCache
from transformer_lens.benchmarks.utils import BenchmarkResult, BenchmarkSeverity
from transformer_lens.benchmarks.utils import (
BenchmarkResult,
BenchmarkSeverity,
safe_allclose,
)
from transformer_lens.model_bridge import TransformerBridge


Expand Down Expand Up @@ -65,9 +69,10 @@ def benchmark_run_with_cache(
if missing_patterns:
return BenchmarkResult(
name="run_with_cache",
severity=BenchmarkSeverity.WARNING,
severity=BenchmarkSeverity.DANGER,
message=f"Cache missing expected patterns: {missing_patterns}",
details={"missing": missing_patterns, "cache_keys_count": len(cache_keys)},
passed=False,
)

# Verify cached tensors are actually tensors
Expand All @@ -79,9 +84,10 @@ def benchmark_run_with_cache(
if non_tensor_keys:
return BenchmarkResult(
name="run_with_cache",
severity=BenchmarkSeverity.WARNING,
severity=BenchmarkSeverity.DANGER,
message=f"Cache contains {len(non_tensor_keys)} non-tensor values",
details={"non_tensor_keys": non_tensor_keys[:5]},
passed=False,
)

if reference_model is not None:
Expand Down Expand Up @@ -175,9 +181,11 @@ def benchmark_activation_cache(
continue

# Check values
if not torch.allclose(bridge_tensor, reference_tensor, atol=tolerance, rtol=0):
max_diff = torch.max(torch.abs(bridge_tensor - reference_tensor)).item()
mean_diff = torch.mean(torch.abs(bridge_tensor - reference_tensor)).item()
if not safe_allclose(bridge_tensor, reference_tensor, atol=tolerance, rtol=0.0):
b = bridge_tensor.float()
r = reference_tensor.float()
max_diff = torch.max(torch.abs(b - r)).item()
mean_diff = torch.mean(torch.abs(b - r)).item()
mismatches.append(
f"{key}: Value mismatch - max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}"
)
Expand Down
26 changes: 17 additions & 9 deletions transformer_lens/benchmarks/backward_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

from transformer_lens import HookedTransformer
from transformer_lens.benchmarks.hook_structure import validate_hook_shape_compatibility
from transformer_lens.benchmarks.utils import BenchmarkResult, BenchmarkSeverity
from transformer_lens.benchmarks.utils import (
BenchmarkResult,
BenchmarkSeverity,
safe_allclose,
)
from transformer_lens.model_bridge import TransformerBridge


Expand Down Expand Up @@ -167,14 +171,14 @@ def hook_fn(tensor, hook):

if bridge_finite.numel() > 0 and reference_finite.numel() > 0:
# Compare finite values
if not torch.allclose(
if not safe_allclose(
bridge_finite, reference_finite, atol=abs_tolerance, rtol=rel_tolerance
):
max_diff = torch.max(torch.abs(bridge_finite - reference_finite)).item()
mean_diff = torch.mean(torch.abs(bridge_finite - reference_finite)).item()
rel_diff = torch.abs(bridge_finite - reference_finite) / (
torch.abs(bridge_finite) + 1e-8
)
bf = bridge_finite.float()
rf = reference_finite.float()
max_diff = torch.max(torch.abs(bf - rf)).item()
mean_diff = torch.mean(torch.abs(bf - rf)).item()
rel_diff = torch.abs(bf - rf) / (torch.abs(bf) + 1e-8)
mean_rel = rel_diff.mean().item()
mismatches.append(
f"{hook_name}: Value mismatch - max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}, mean_rel={mean_rel:.6f}"
Expand All @@ -195,11 +199,13 @@ def hook_fn(tensor, hook):
"hook_k",
"ln1.hook_",
"ln2.hook_",
"ln_final.hook_",
"hook_resid_mid",
"hook_resid_pre",
"hook_resid_post",
"hook_embed",
"hook_pos_embed",
"unembed.hook_",
"mlp.hook_post",
"mlp.hook_pre",
"hook_mlp_out",
Expand Down Expand Up @@ -431,10 +437,12 @@ def hook_fn(tensor, hook):
reference_finite = reference_grad[torch.isfinite(reference_grad)]

if bridge_finite.numel() > 0 and reference_finite.numel() > 0:
if not torch.allclose(
if not safe_allclose(
bridge_finite, reference_finite, atol=abs_tolerance, rtol=rel_tolerance
):
max_diff = torch.max(torch.abs(bridge_finite - reference_finite)).item()
max_diff = torch.max(
torch.abs(bridge_finite.float() - reference_finite.float())
).item()
mismatches.append(f"{hook_name}: max_diff={max_diff:.6f}")

if mismatches:
Expand Down
Loading