Merged
84 commits
1fe4d04
added and tested: OLMo-1B,OLMo-7B
jonasrohw Dec 12, 2024
0f3e3b3
fixed: numpy do not do a major upgrade!
jonasrohw Dec 13, 2024
3a101f4
fixed: dimensions of 7b to be correct
jonasrohw Dec 13, 2024
1b34ccd
tested: Loading checkpoints & model variations
jonasrohw Dec 13, 2024
f0a0a68
Reimplement OLMoE changes.
joelburget Dec 14, 2024
8c094e5
Implement TODO (norm_topk_prob)
joelburget Dec 14, 2024
7565c06
Disable bos token for OLMoE.
joelburget Dec 14, 2024
04cd309
Add q and k norm.
joelburget Dec 15, 2024
68d6961
Correct normalization type for OLMoE.
joelburget Dec 15, 2024
9afd032
Merge pull request #1 from joelburget/olmoe
jonasrohw Dec 15, 2024
96c1fbb
Merge branch 'dev' into OLMo
jonasrohw Dec 15, 2024
72fb903
ran formatting
jonasrohw Dec 15, 2024
9d3a85e
Merge branch 'dev' into OLMo
bryce13950 Feb 4, 2025
d4519b2
Merge branch 'dev' into OLMo
bryce13950 Feb 5, 2025
064310f
tmp update for olmo2
Ja1Zhou Feb 1, 2025
b1fd04b
Fix: Olmo2 uses normalization after the attention/mlp
jonasrohw Feb 15, 2025
871ba03
Merge branch 'dev' into OLMo
bryce13950 Jun 16, 2025
7939e8d
ran format
bryce13950 Jun 16, 2025
97fd1e7
fixed some type issues
bryce13950 Jun 19, 2025
9032fe7
Merge branch 'dev' into OLMo
bryce13950 Jun 24, 2025
39703c4
OLMo 2 RMS
jleechung Jul 22, 2025
1c283c1
OLMo 2 RMS
jleechung Jul 22, 2025
688a421
Tested Instruct models
jleechung Jul 22, 2025
9febc5c
Merge pull request #3 from jleechung/OLMo
jonasrohw Jul 23, 2025
86b1fce
fix: Olmo2DecoderLayer type issues
taziksh Oct 11, 2025
fa5c885
fix type assertions for attention
taziksh Oct 11, 2025
148df46
chore: bump min Python to 3.10 for jaxtyping mypy plugin compatibility
taziksh Oct 12, 2025
1c60345
Merge dev and regenerate lock file
taziksh Oct 12, 2025
7aa3a91
fix: sort imports in olmo2.py
taziksh Oct 12, 2025
c8d443b
docs: update Colab notebook for OLMo models
taziksh Oct 12, 2025
9fcf0db
Merge branch 'dev' into OLMo
jlarson4 Jan 15, 2026
856443a
added and tested: OLMo-1B,OLMo-7B
jonasrohw Dec 12, 2024
6d53a0c
fixed: dimensions of 7b to be correct
jonasrohw Dec 13, 2024
b7bf828
tested: Loading checkpoints & model variations
jonasrohw Dec 13, 2024
89dc6df
Reimplement OLMoE changes.
joelburget Dec 14, 2024
676960c
Implement TODO (norm_topk_prob)
joelburget Dec 14, 2024
e78d68d
Disable bos token for OLMoE.
joelburget Dec 14, 2024
c281e71
Add q and k norm.
joelburget Dec 15, 2024
5f0c91d
Correct normalization type for OLMoE.
joelburget Dec 15, 2024
19f5eec
ran formatting
jonasrohw Dec 15, 2024
93da62e
tmp update for olmo2
Ja1Zhou Feb 1, 2025
5c65b92
Fix: Olmo2 uses normalization after the attention/mlp
jonasrohw Feb 15, 2025
2171e2f
ran format
bryce13950 Jun 16, 2025
c4e543f
fixed some type issues
bryce13950 Jun 19, 2025
3532376
OLMo 2 RMS
jleechung Jul 22, 2025
ffb3d3b
OLMo 2 RMS
jleechung Jul 22, 2025
808bb57
Tested Instruct models
jleechung Jul 22, 2025
797872f
fix: Olmo2DecoderLayer type issues
taziksh Oct 11, 2025
a39fccd
fix type assertions for attention
taziksh Oct 11, 2025
884aeb6
chore: bump min Python to 3.10 for jaxtyping mypy plugin compatibility
taziksh Oct 12, 2025
d51ab7d
fix: sort imports in olmo2.py
taziksh Oct 12, 2025
aa6d3b8
docs: update Colab notebook for OLMo models
taziksh Oct 12, 2025
9458988
Adjust error message to improve testing
jlarson4 Jan 16, 2026
41280fe
rebasing to dev
jlarson4 Jan 16, 2026
80b8835
conflict resolution
jlarson4 Jan 16, 2026
b2ac313
Updating lock
jlarson4 Jan 16, 2026
8c92fc8
Fixed formatting, update error messages to properly test
jlarson4 Jan 17, 2026
fc3da3e
more formatting
jlarson4 Jan 17, 2026
d7e5523
fixing type error
jlarson4 Jan 17, 2026
c20859c
fix format error
jlarson4 Jan 17, 2026
5ccbf68
Fix type issues
jlarson4 Jan 17, 2026
6d3c870
Fix type issues
jlarson4 Jan 17, 2026
7151270
Fix format issues
jlarson4 Jan 17, 2026
1d0aebb
Fix format issues again
jlarson4 Jan 17, 2026
4316e8b
Fix format issues for black
jlarson4 Jan 17, 2026
040b19b
another attempt at black formatting
jlarson4 Jan 17, 2026
fb259ce
Fix format issues for black again
jlarson4 Jan 17, 2026
72521c7
Retyping the blocks in HookedTransformer and HookedEncoder
jlarson4 Jan 17, 2026
f0ddc0e
undo modulelist typing
jlarson4 Jan 17, 2026
bdbd649
Improve type checking in test_detect_head_with_invalid_head_name
jlarson4 Jan 17, 2026
0ec06b9
removing unused import
jlarson4 Jan 17, 2026
09a9bdd
Fixing Patchscopes_Generation_Demo.ipynb
jlarson4 Jan 17, 2026
7933afc
Fixing the rest of the notebooks
jlarson4 Jan 17, 2026
61347e0
Fixing the more notebooks
jlarson4 Jan 17, 2026
d4e986a
run_line_magic
jlarson4 Jan 17, 2026
4fed25a
BERT ipynb fix
jlarson4 Jan 17, 2026
7b22ce4
Trying to fix the BERT set_grad cell
jlarson4 Jan 17, 2026
0899f3c
more set_grad cell fixes
jlarson4 Jan 17, 2026
3154dbb
Merge remote-tracking branch 'origin/dev-3.x' into OLMo
jlarson4 Feb 6, 2026
f6bfd71
Updated after rebase to fix missing 3.x changes
jlarson4 Feb 6, 2026
b819f82
Merge remote-tracking branch 'origin/dev-3.x' into OLMo
jlarson4 Feb 12, 2026
bdcc977
Updating OLMo PR to work with v3.x
jlarson4 Feb 12, 2026
1e2db33
Format fix
jlarson4 Feb 12, 2026
93f2ee4
fix model ordering
jlarson4 Feb 12, 2026
266 changes: 266 additions & 0 deletions debugging/hf-tl-logit-comparator.ipynb
@@ -0,0 +1,266 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Logit Comparator for HuggingFace and TransformerLens Outputs\n",
"This notebook is a quick and dirty tool to compare the logit outputs of a HuggingFace model and a TransformerLens model via several different metrics. It is intended to help debug issues with the TransformerLens model, such as bugs in the model's implementation. If you identify any issues, please open an issue on the [GitHub repository](https://github.com/TransformerLensOrg/TransformerLens)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
"from transformer_lens import HookedTransformer\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"if torch.backends.mps.is_available():\n",
" device = \"mps\"\n",
"else:\n",
" device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"# NBVAL_IGNORE_OUTPUT\n",
"_ = torch.set_grad_enabled(False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Comparator Setup"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"model_name = \"EleutherAI/pythia-2.8b\" # You can change this to any model name\n",
"sentence = \"The quick brown fox\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from huggingface_hub import login\n",
"login(token=\"\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get Transformers Logits"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
"\n",
"def load_model(model_name=\"gpt2\"):\n",
" tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
" model = AutoModelForCausalLM.from_pretrained(model_name)\n",
" return model, tokenizer\n",
"\n",
"def get_logits(model, tokenizer, sentence, device):\n",
" # Tokenize the input sentence\n",
" inputs = tokenizer(sentence, return_tensors=\"pt\")\n",
" \n",
" # Move inputs to the device\n",
" inputs = {k: v.to(device) for k, v in inputs.items()}\n",
" \n",
" # Generate the logits\n",
" with torch.no_grad():\n",
" outputs = model(**inputs)\n",
" \n",
" # Get the logits for all tokens\n",
" logits = outputs.logits\n",
" \n",
" return logits\n",
"\n",
"model, tokenizer = load_model(model_name)\n",
"model = model.to(device)\n",
"\n",
"hf_logits = get_logits(model, tokenizer, sentence, device)[:, -1, :]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get TransformerLens Logits"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = HookedTransformer.from_pretrained_no_processing(model_name, device=device)\n",
"tokens = model.to_tokens(sentence, prepend_bos=False)\n",
"tl_logits = model(tokens)[:, -1, :]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Compare Logit Distributions\n",
"Various metrics are used to compare the logit distributions of the two models. We don't yet have standard values for what constitutes a \"good\" logit comparison, so we are working on establishing benchmarks."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(f\"HF Logits Shape: {hf_logits.shape}\")\n",
"print(f\"TL Logits Shape: {tl_logits.shape}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Tensor Comparison"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"are_close = torch.allclose(tl_logits, hf_logits, rtol=1e-5, atol=1e-3)\n",
"print(f\"Are the logits close? {are_close}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Mean Squared Error"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Compare the logits with MSE\n",
"mse = torch.nn.functional.mse_loss(hf_logits, tl_logits)\n",
"print(f\"MSE: {mse}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Maximum Absolute Difference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"max_diff = torch.max(torch.abs(tl_logits - hf_logits))\n",
"print(f\"Max Diff: {max_diff}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Cosine Similarity"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cosine_sim = F.cosine_similarity(tl_logits, hf_logits, dim=-1).mean()\n",
"print(f\"Cosine Sim: {cosine_sim}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### KL Divergence"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def kl_div(logits1: torch.Tensor, logits2: torch.Tensor) -> torch.Tensor:\n",
"    # F.kl_div(input, target) expects log-probs as input and computes KL(target || input),\n",
"    probs1 = F.softmax(logits1, dim=-1)\n",
"    return F.kl_div(F.log_softmax(logits2, dim=-1), probs1, reduction='batchmean')  # KL(P1 || P2)\n",
"\n",
"kl_tl_hf = kl_div(tl_logits, hf_logits)\n",
"kl_hf_tl = kl_div(hf_logits, tl_logits)\n",
"print(f\"KL(TL||HF): {kl_tl_hf}\")\n",
"print(f\"KL(HF||TL): {kl_hf_tl}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "sae-l",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
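The notebook's metrics (allclose, MSE, max absolute difference, cosine similarity, KL divergence) can also be bundled into one standalone helper for comparing any two logit tensors, independent of the HF/TL loading boilerplate. This is a sketch, not part of the PR; the function and key names are ours:

```python
import torch
import torch.nn.functional as F

def compare_logits(logits_a: torch.Tensor, logits_b: torch.Tensor) -> dict:
    """Compare two [batch, d_vocab] logit tensors with the notebook's metrics."""
    log_a = F.log_softmax(logits_a, dim=-1)
    log_b = F.log_softmax(logits_b, dim=-1)
    return {
        "allclose": torch.allclose(logits_a, logits_b, rtol=1e-5, atol=1e-3),
        "mse": F.mse_loss(logits_a, logits_b).item(),
        "max_abs_diff": (logits_a - logits_b).abs().max().item(),
        "cosine_sim": F.cosine_similarity(logits_a, logits_b, dim=-1).mean().item(),
        # F.kl_div(input, target) expects log-probabilities as input and
        # computes KL(target || input), so this is KL(A || B).
        "kl_a_b": F.kl_div(log_b, log_a.exp(), reduction="batchmean").item(),
    }

# Sanity check on synthetic logits: identical inputs should score as identical.
logits = torch.randn(1, 50257)
report = compare_logits(logits, logits.clone())
```

A near-zero MSE and KL on real models is a good sign the TransformerLens port matches the HuggingFace reference; what counts as an acceptable nonzero gap is, as the notebook notes, still an open benchmarking question.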
4 changes: 2 additions & 2 deletions demos/ARENA_Content.ipynb
@@ -40,8 +40,8 @@
 "\n",
 " ipython = get_ipython()\n",
 " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n",
-" # ipython.run_line_magic(\"load_ext\", \"autoreload\")\n",
-" # ipython.run_line_magic(\"autoreload\", \"2\")\n",
+" ipython.run_line_magic(\"load_ext\", \"autoreload\")\n",
+" ipython.run_line_magic(\"autoreload\", \"2\")\n",
 "\n",
 "if IN_GITHUB or IN_COLAB:\n",
 " %pip install torch\n",
3 changes: 2 additions & 1 deletion demos/Activation_Patching_in_TL_Demo.ipynb
@@ -158,7 +158,8 @@
 }
 ],
 "source": [
-"torch.set_grad_enabled(False)"
+"# NBVAL_IGNORE_OUTPUT\n",
+"_ = torch.set_grad_enabled(False)"
 ]
 },
 {
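The `_ = torch.set_grad_enabled(False)` pattern applied across these notebooks deserves a note: the call returns a guard object, and a bare call on a cell's last line echoes that object's repr as cell output, which nbval then diffs against the stored notebook. Binding the result to `_` keeps the cell silent. A minimal illustration (variable names are ours):

```python
import torch

# torch.set_grad_enabled(False) returns a context-manager-like guard object;
# assigning it to `_` discards the repr that a notebook would otherwise echo.
_ = torch.set_grad_enabled(False)
grad_off = torch.is_grad_enabled()  # False while the global switch is off

torch.set_grad_enabled(True)  # restore the default for anything that follows
```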
31 changes: 6 additions & 25 deletions demos/BERT.ipynb
@@ -10,7 +10,6 @@
 ]
 },
 {
-"attachments": {},
 "cell_type": "markdown",
 "metadata": {},
 "source": [
@@ -36,17 +35,9 @@
 },
 {
 "cell_type": "code",
-"execution_count": 2,
+"execution_count": null,
 "metadata": {},
-"outputs": [
-{
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"Using renderer: colab\n"
-]
-}
-],
+"outputs": [],
 "source": [
 "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n",
 "import plotly.io as pio\n",
@@ -108,22 +99,12 @@
 },
 {
 "cell_type": "code",
-"execution_count": 5,
+"execution_count": null,
 "metadata": {},
-"outputs": [
-{
-"data": {
-"text/plain": [
-"<torch.autograd.grad_mode.set_grad_enabled at 0x2a285a790>"
-]
-},
-"execution_count": 5,
-"metadata": {},
-"output_type": "execute_result"
-}
-],
+"outputs": [],
 "source": [
-"torch.set_grad_enabled(False)"
+"# NBVAL_IGNORE_OUTPUT\n",
+"_ = torch.set_grad_enabled(False)"
 ]
 },
 {