diff --git a/debugging/hf-tl-logit-comparator.ipynb b/debugging/hf-tl-logit-comparator.ipynb new file mode 100644 index 000000000..99b2b3962 --- /dev/null +++ b/debugging/hf-tl-logit-comparator.ipynb @@ -0,0 +1,266 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Logit Comparator for HuggingFace and TransformerLens Outputs\n", + "This notebook is a quick and dirty tool to compare the logit outputs of a HuggingFace model and a TransformerLens model via several different metrics. It is intended to help debug issues with the TransformerLens model, such as bugs in the model's implementation. If you identify any issues, please open an issue on the [GitHub repository](https://github.com/TransformerLensOrg/TransformerLens)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer, AutoModelForCausalLM\n", + "from transformer_lens import HookedTransformer\n", + "import torch\n", + "import torch.nn.functional as F\n", + "\n", + "if torch.backends.mps.is_available():\n", + " device = \"mps\"\n", + "else:\n", + " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + "# NBVAL_IGNORE_OUTPUT\n", + "_ = torch.set_grad_enabled(False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Comparator Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [], + "source": [ + "model_name = \"EleutherAI/pythia-2.8b\" # You can change this to any model name\n", + "sentence = \"The quick brown fox\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from huggingface_hub import login\n", + "login(token=\"\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get Transformers Logits" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from transformers import AutoTokenizer, AutoModelForCausalLM\n", + "\n", + "def load_model(model_name=\"gpt2\"):\n", + " tokenizer = AutoTokenizer.from_pretrained(model_name)\n", + " model = AutoModelForCausalLM.from_pretrained(model_name)\n", + " return model, tokenizer\n", + "\n", + "def get_logits(model, tokenizer, sentence, device):\n", + " # Tokenize the input sentence\n", + " inputs = tokenizer(sentence, return_tensors=\"pt\")\n", + " \n", + " # Move inputs to the device\n", + " inputs = {k: v.to(device) for k, v in inputs.items()}\n", + " \n", + " # Generate the logits\n", + " with torch.no_grad():\n", + " outputs = model(**inputs)\n", + " \n", + " # Get the logits for all tokens\n", + " logits = outputs.logits\n", + " \n", + " return logits\n", + "\n", + "model, tokenizer = load_model(model_name)\n", + "model = model.to(device)\n", + "\n", + "hf_logits = get_logits(model, tokenizer, sentence, device)[:, -1, :]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get TransformerLens Logits" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = HookedTransformer.from_pretrained_no_processing(model_name, device=device)\n", + "tokens = model.to_tokens(sentence, prepend_bos=False)\n", + "tl_logits = model(tokens)[:, -1, :]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compare Logit Distributions\n", + "Various metrics are used to compare the logit distributions of the two models. We don't yet have standard values for what constitutes a \"good\" logit comparison, so we are working on establishing benchmarks." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"HF Logits Shape: {hf_logits.shape}\")\n", + "print(f\"TL Logits Shape: {tl_logits.shape}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Tensor Comparison" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "are_close = torch.allclose(tl_logits, hf_logits, rtol=1e-5, atol=1e-3)\n", + "print(f\"Are the logits close? {are_close}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Mean Squared Error" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Compare the logits with MSE\n", + "mse = torch.nn.functional.mse_loss(hf_logits, tl_logits)\n", + "print(f\"MSE: {mse}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Maximum Absolute Difference" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "max_diff = torch.max(torch.abs(tl_logits - hf_logits))\n", + "print(f\"Max Diff: {max_diff}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Cosine Similarity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cosine_sim = F.cosine_similarity(tl_logits, hf_logits, dim=-1).mean()\n", + "print(f\"Cosine Sim: {cosine_sim}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### KL Divergence" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def kl_div(logits1: torch.Tensor, logits2: torch.Tensor) -> torch.Tensor:\n", + " probs1 = F.softmax(logits1, dim=-1)\n", + " probs2 = F.softmax(logits2, dim=-1)\n", + " return F.kl_div(probs1.log(), probs2, reduction='batchmean')\n", + "\n", + "kl_tl_hf = kl_div(tl_logits, hf_logits)\n", + "kl_hf_tl = kl_div(hf_logits, tl_logits)\n", + "print(f\"KL(TL||HF): {kl_tl_hf}\")\n", + "print(f\"KL(HF||TL): {kl_hf_tl}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sae-l", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/demos/ARENA_Content.ipynb b/demos/ARENA_Content.ipynb index b48219f85..88b2ae025 100644 --- a/demos/ARENA_Content.ipynb +++ b/demos/ARENA_Content.ipynb @@ -40,8 +40,8 @@ "\n", " ipython = get_ipython()\n", " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", - " # ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", - " # ipython.run_line_magic(\"autoreload\", \"2\")\n", + " ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", + " ipython.run_line_magic(\"autoreload\", \"2\")\n", "\n", "if IN_GITHUB or IN_COLAB:\n", " %pip install torch\n", diff --git a/demos/Activation_Patching_in_TL_Demo.ipynb b/demos/Activation_Patching_in_TL_Demo.ipynb index ab0f7c9d1..abc033ad7 100644 --- a/demos/Activation_Patching_in_TL_Demo.ipynb +++ b/demos/Activation_Patching_in_TL_Demo.ipynb @@ -158,7 +158,8 @@ } ], "source": [ - "torch.set_grad_enabled(False)" + "# NBVAL_IGNORE_OUTPUT\n", + "_ = torch.set_grad_enabled(False)" ] }, { diff --git a/demos/BERT.ipynb b/demos/BERT.ipynb index 1851f09d2..a46b49976 100644 --- a/demos/BERT.ipynb +++ b/demos/BERT.ipynb @@ -10,7 +10,6 @@ ] }, { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -36,17 +35,9 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using renderer: colab\n" - ] - } - ], + "outputs": [], "source": [ "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n", "import plotly.io as pio\n", @@ -108,22 +99,12 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "torch.set_grad_enabled(False)" + "# NBVAL_IGNORE_OUTPUT\n", + "_ = torch.set_grad_enabled(False)" ] }, { diff --git a/demos/Colab_Compatibility.ipynb b/demos/Colab_Compatibility.ipynb index 6a3b3343e..427baa213 100644 --- a/demos/Colab_Compatibility.ipynb +++ b/demos/Colab_Compatibility.ipynb @@ -1,28 +1,10 @@ { "cells": [ { - "cell_type": "code", - "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Running as a Jupyter notebook - intended for development only!\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_47164/3507779555.py:18: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", - " ipython.magic(\"load_ext autoreload\")\n", - "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_47164/3507779555.py:19: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", - " ipython.magic(\"autoreload 2\")\n" - ] - } - ], + "cell_type": "code", + "outputs": [], + "execution_count": null, "source": [ "# NBVAL_IGNORE_OUTPUT\n", "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", @@ -54,21 +36,14 @@ " # %pip install transformer_lens\n", " %pip install transformers_stream_generator\n", " # !huggingface-cli login --token NEEL'S TOKEN" - ] + ], + "id": "b69b1bc89b92b891" }, { - "cell_type": "code", - "execution_count": 2, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "TransformerLens currently supports 216 models out of the box.\n" - ] - } - ], + "cell_type": "code", + "outputs": [], + "execution_count": null, "source": [ "import torch\n", "from transformer_lens import HookedTransformer, HookedEncoderDecoder, HookedEncoder, BertNextSentencePrediction, loading\n", @@ -84,13 +59,14 @@ "GENERATE = True\n", "# Fill this in if you have llama weights uploaded, and you with to test those models\n", "LLAMA_MODEL_PATH = \"\"" - ] + ], + "id": "d71c1b4387bb530f" }, { - "cell_type": "code", - "execution_count": 3, "metadata": {}, + "cell_type": "code", "outputs": [], + "execution_count": null, "source": [ "def mark_models_as_tested(model_set: List[str]) -> None:\n", " for model in model_set:\n", @@ -194,13 +170,14 @@ " gc.collect()\n", " if IN_COLAB:\n", " %rm -rf /root/.cache/huggingface/hub/models*" - ] + ], + "id": "4712c125cf82eecf" }, { - "cell_type": "code", - "execution_count": 4, "metadata": {}, + "cell_type": "code", "outputs": [], + "execution_count": null, "source": [ "# The following models can run in the T4 free environment\n", "free_compatible = [\n", @@ -320,13 +297,14 @@ " run_set(free_compatible)\n", "\n", "mark_models_as_tested(free_compatible)" - ] + ], + "id": "c58e8d102ccd64ef" }, { - "cell_type": "code", - "execution_count": 5, "metadata": {}, + "cell_type": "code", "outputs": [], + "execution_count": null, "source": [ "paid_gpu_models = [\n", " \"01-ai/Yi-6B\",\n", @@ -391,13 +369,14 @@ " run_set(paid_gpu_models)\n", "\n", "mark_models_as_tested(paid_gpu_models)" - ] + ], + "id": "c680d3d00c605fdb" }, { - "cell_type": "code", - "execution_count": 6, "metadata": {}, + "cell_type": "code", "outputs": [], + "execution_count": null, "source": [ "paid_cpu_models = [\n", " \"EleutherAI/gpt-j-6B\",\n", @@ -424,13 +403,14 @@ " run_set(paid_cpu_models, \"cpu\")\n", "\n", "mark_models_as_tested(paid_cpu_models)" - ] + ], + "id": "fda1709e915479f1" }, { - "cell_type": "code", - "execution_count": 7, "metadata": {}, + "cell_type": "code", "outputs": [], + "execution_count": null, "source": [ "incompatible_models = [\n", " \"01-ai/Yi-34B\",\n", @@ -456,13 +436,14 @@ "]\n", "\n", "mark_models_as_tested(incompatible_models)" - ] + ], + "id": "13a3398da75aff7d" }, { - "cell_type": "code", - "execution_count": 8, "metadata": {}, + "cell_type": "code", "outputs": [], + "execution_count": null, "source": [ "# The following models take a few extra steps to function. Check the official demo for more\n", "# information on how to use. 7b and 13b will work in the paid environment. 30b and 65b will not work\n", @@ -478,13 +459,14 @@ " run_llama_set(not_hosted_models, LLAMA_MODEL_PATH)\n", "\n", "mark_models_as_tested(not_hosted_models)" - ] + ], + "id": "9467479c832f46e1" }, { - "cell_type": "code", - "execution_count": 9, "metadata": {}, + "cell_type": "code", "outputs": [], + "execution_count": null, "source": [ "# These all work on the free version of Colab\n", "encoder_decoders = [\n", @@ -496,13 +478,14 @@ " run_encoder_decoder_set(encoder_decoders)\n", "\n", "mark_models_as_tested(encoder_decoders)" - ] + ], + "id": "631c68f258277efe" }, { - "cell_type": "code", - "execution_count": 10, "metadata": {}, + "cell_type": "code", "outputs": [], + "execution_count": null, "source": [ "# This model works on the free version of Colab\n", "encoder_only_models = [\n", @@ -516,13 +499,14 @@ " run_encoder_only_set(encoder_only_models)\n", "\n", "mark_models_as_tested(encoder_only_models)" - ] + ], + "id": "e5a2bcb0db509b7e" }, { - "cell_type": "code", - "execution_count": 11, "metadata": {}, + "cell_type": "code", "outputs": [], + "execution_count": null, "source": [ "broken_models = [\n", " \"Baidicoot/Othello-GPT-Transformer-Lens\",\n", @@ -530,9 +514,9 @@ ] }, { + "metadata": {}, "cell_type": "code", "execution_count": 12, - "metadata": {}, "outputs": [ { "name": "stdout", @@ -547,28 +531,11 @@ "# PR fails due to this notebook, most likely you need to check any new model changes to ensure that\n", "# this notebook is up to date.\n", "print(*untested_models, sep=\"\\n\")" - ] + ], + "id": "e561257b477af425" } ], - "metadata": { - "kernelspec": { - "display_name": ".venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.7" - } - }, + "metadata": {}, "nbformat": 4, - "nbformat_minor": 4 + "nbformat_minor": 5 } diff --git a/demos/Grokking_Demo.ipynb b/demos/Grokking_Demo.ipynb index 6df698ebc..7b7fe5243 100644 --- a/demos/Grokking_Demo.ipynb +++ b/demos/Grokking_Demo.ipynb @@ -39,18 +39,10 @@ ] }, { - "cell_type": "code", - "execution_count": 2, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Running as a Jupyter notebook - intended for development only!\n" - ] - } - ], + "cell_type": "code", + "outputs": [], + "execution_count": null, "source": [ "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", "import os\n", @@ -82,18 +74,10 @@ ] }, { - "cell_type": "code", - "execution_count": 3, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using renderer: notebook_connected\n" - ] - } - ], + "cell_type": "code", + "outputs": [], + "execution_count": null, "source": [ "# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh\n", "import plotly.io as pio\n", @@ -2937,10 +2921,10 @@ "evalue": "name 'train_losses' is not defined", "output_type": "error", "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/tmp/ipykernel_1229617/2975677256.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mneel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplot\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnpx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mfig\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnpx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mline\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtrain_losses\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtest_losses\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_losses\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m100\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxaxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"Epoch\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myaxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"Loss\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_y\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtitle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"Training Curve for Modular Addition\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mline_labels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'test'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtoggle_x\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtoggle_y\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_fig\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0madd_lines\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfig\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mNameError\u001b[0m: name 'train_losses' is not defined" + "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[0;31mNameError\u001B[0m Traceback (most recent call last)", + "\u001B[0;32m/tmp/ipykernel_1229617/2975677256.py\u001B[0m in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[1;32m 1\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mneel\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mplot\u001B[0m \u001B[0;32mas\u001B[0m \u001B[0mnpx\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m----> 2\u001B[0;31m \u001B[0mfig\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mnpx\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mline\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0mtrain_losses\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;36m100\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mtest_losses\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;36m100\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mx\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0mnp\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0marange\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;36m0\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mlen\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mtrain_losses\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;36m100\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mxaxis\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m\"Epoch\"\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0myaxis\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m\"Loss\"\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mlog_y\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mTrue\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mtitle\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m\"Training Curve for Modular Addition\"\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mline_labels\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m'train'\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m'test'\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mtoggle_x\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mTrue\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mtoggle_y\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mTrue\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mreturn_fig\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mTrue\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 3\u001B[0m \u001B[0madd_lines\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mfig\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", + "\u001B[0;31mNameError\u001B[0m: name 'train_losses' is not defined" ] } ], @@ -3516,11 +3500,11 @@ "evalue": "Size does not match at dimension 0 expected index [12769, 1] to be smaller than self [113, 113] apart from dimension 1", "output_type": "error", "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/tmp/ipykernel_1215793/3004607503.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mall_logits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/tmp/ipykernel_1215793/4096650173.py\u001b[0m in \u001b[0;36mloss_fn\u001b[0;34m(logits, labels)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mlogits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlogits\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat64\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mlog_probs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlogits\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog_softmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mcorrect_log_probs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlog_probs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgather\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlabels\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0mcorrect_log_probs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mtrain_logits\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mRuntimeError\u001b[0m: Size does not match at dimension 0 expected index [12769, 1] to be smaller than self [113, 113] apart from dimension 1" + "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[0;31mRuntimeError\u001B[0m Traceback (most recent call last)", + "\u001B[0;32m/tmp/ipykernel_1215793/3004607503.py\u001B[0m in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[0;32m----> 1\u001B[0;31m \u001B[0mprint\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mloss_fn\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mall_logits\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mlabels\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m", + "\u001B[0;32m/tmp/ipykernel_1215793/4096650173.py\u001B[0m in \u001B[0;36mloss_fn\u001B[0;34m(logits, labels)\u001B[0m\n\u001B[1;32m 4\u001B[0m \u001B[0mlogits\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mlogits\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mto\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mtorch\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mfloat64\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 5\u001B[0m \u001B[0mlog_probs\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mlogits\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mlog_softmax\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mdim\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m-\u001B[0m\u001B[0;36m1\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m----> 6\u001B[0;31m \u001B[0mcorrect_log_probs\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mlog_probs\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mgather\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mdim\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m-\u001B[0m\u001B[0;36m1\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mindex\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0mlabels\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;32mNone\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;36m0\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 7\u001B[0m \u001B[0;32mreturn\u001B[0m \u001B[0;34m-\u001B[0m\u001B[0mcorrect_log_probs\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mmean\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 8\u001B[0m \u001B[0mtrain_logits\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mmodel\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mtrain_data\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", + "\u001B[0;31mRuntimeError\u001B[0m: Size does not match at dimension 0 expected index [12769, 1] to be smaller than self [113, 113] apart from dimension 1" ] } ], diff --git a/demos/LLaMA.ipynb b/demos/LLaMA.ipynb index 1c0f4f67c..ea406307d 100644 --- a/demos/LLaMA.ipynb +++ b/demos/LLaMA.ipynb @@ -112,7 +112,8 @@ ") # Hooking utilities\n", "from transformer_lens import HookedTransformer\n", "\n", - "torch.set_grad_enabled(False)\n", + "# NBVAL_IGNORE_OUTPUT\n", + "_ = torch.set_grad_enabled(False)\n", "\n", "def imshow(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n", " px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale=\"RdBu\", labels={\"x\":xaxis, \"y\":yaxis}, **kwargs).show(renderer)\n", diff --git a/demos/LLaMA2_GPU_Quantized.ipynb b/demos/LLaMA2_GPU_Quantized.ipynb index a10259d4f..4370ce1e0 100644 --- a/demos/LLaMA2_GPU_Quantized.ipynb +++ b/demos/LLaMA2_GPU_Quantized.ipynb @@ -41,9 +41,9 @@ " Switched to a new branch 'llama_4bit_v2'\n", " Branch 'llama_4bit_v2' set up to track remote branch 'llama_4bit_v2' from 'origin'.\n", " Resolved https://github.com/coolvision/TransformerLens.git to commit b2b80cb92f4aa6d63a456196f0c3472b3d34c6eb\n", - " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Installing build dependencies ... \u001B[?25l\u001B[?25hdone\n", + " Getting requirements to build wheel ... \u001B[?25l\u001B[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001B[?25l\u001B[?25hdone\n", "Requirement already satisfied: accelerate>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (0.26.1)\n", "Requirement already satisfied: beartype<0.15.0,>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (0.14.1)\n", "Requirement already satisfied: datasets>=2.7.1 in /usr/local/lib/python3.10/dist-packages (from transformer-lens==0.0.0) (2.16.1)\n", @@ -234,7 +234,8 @@ ") # Hooking utilities\n", "from transformer_lens import HookedTransformer\n", "\n", - "torch.set_grad_enabled(False)\n", + "# NBVAL_IGNORE_OUTPUT\n", + "_ = torch.set_grad_enabled(False)\n", "\n", "def imshow(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n", " px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale=\"RdBu\", labels={\"x\":xaxis, \"y\":yaxis}, **kwargs).show(renderer)\n", diff --git a/demos/Main_Demo.ipynb b/demos/Main_Demo.ipynb index 1d398fe82..543d3dcac 100644 --- a/demos/Main_Demo.ipynb +++ b/demos/Main_Demo.ipynb @@ -222,7 +222,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ diff --git a/demos/Othello_GPT.ipynb b/demos/Othello_GPT.ipynb index e9a2cf7a6..d69f1a166 100644 --- a/demos/Othello_GPT.ipynb +++ b/demos/Othello_GPT.ipynb @@ -217,7 +217,8 @@ } ], "source": [ - "torch.set_grad_enabled(False)" + "# NBVAL_IGNORE_OUTPUT\n", + "_ = torch.set_grad_enabled(False)" ] }, { diff --git a/demos/Patchscopes_Generation_Demo.ipynb b/demos/Patchscopes_Generation_Demo.ipynb index 3b69ddb63..2a9109154 100644 --- a/demos/Patchscopes_Generation_Demo.ipynb +++ b/demos/Patchscopes_Generation_Demo.ipynb @@ -3494,6 +3494,13 @@ " print(f\"Generation by patching layer {target_layer_id}:\\n{gen}\\n{'='*30}\\n\")" ] }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "" + }, { "cell_type": "markdown", "metadata": {}, diff --git a/demos/Qwen.ipynb b/demos/Qwen.ipynb index 69053fb1b..96732abe8 100644 --- a/demos/Qwen.ipynb +++ b/demos/Qwen.ipynb @@ -71,8 +71,8 @@ "Requirement already satisfied: MarkupSafe>=2.0 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from jinja2->torch>=1.10->circuitsvis) (2.1.3)\n", "Requirement already satisfied: mpmath>=0.19 in /root/TransformerLens/.venv/lib/python3.10/site-packages (from sympy->torch>=1.10->circuitsvis) (1.3.0)\n", "\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", + "\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m A new release of pip is available: \u001B[0m\u001B[31;49m23.3.1\u001B[0m\u001B[39;49m -> \u001B[0m\u001B[32;49m24.0\u001B[0m\n", + "\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m To update, run: \u001B[0m\u001B[32;49mpip install --upgrade pip\u001B[0m\n", "Note: you may need to restart the kernel to use updated packages.\n" ] } @@ -168,7 +168,8 @@ "source": [ "%cd ~/TransformerLens\n", "import torch\n", - "torch.set_grad_enabled(False)\n", + "# NBVAL_IGNORE_OUTPUT\n", + "_ = torch.set_grad_enabled(False)\n", "\n", "from transformers import AutoTokenizer\n", "from transformer_lens import HookedTransformer\n", diff --git a/demos/Santa_Coder.ipynb b/demos/Santa_Coder.ipynb index af98752df..6d0c8fe19 100644 --- a/demos/Santa_Coder.ipynb +++ b/demos/Santa_Coder.ipynb @@ -103,7 +103,8 @@ ") # Hooking utilities\n", "from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache\n", "\n", - "torch.set_grad_enabled(False)\n", + "# NBVAL_IGNORE_OUTPUT\n", + "_ = torch.set_grad_enabled(False)\n", "\n", "def imshow(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n", " px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale=\"RdBu\", labels={\"x\":xaxis, \"y\":yaxis}, **kwargs).show(renderer)\n", diff --git a/demos/T5.ipynb b/demos/T5.ipynb index fb0c4897c..6e108cc6f 100644 --- a/demos/T5.ipynb +++ b/demos/T5.ipynb @@ -134,7 +134,8 @@ } ], "source": [ - "torch.set_grad_enabled(False)" + "# NBVAL_IGNORE_OUTPUT\n", + "_ = torch.set_grad_enabled(False)" ] }, { diff --git a/tests/integration/test_head_detector.py b/tests/integration/test_head_detector.py index e465f8805..1cee263de 100644 --- a/tests/integration/test_head_detector.py +++ b/tests/integration/test_head_detector.py @@ -2,7 +2,6 @@ import pytest import torch -from beartype.roar import BeartypeCallHintParamViolation from transformer_lens import HookedTransformer from transformer_lens.head_detector import ( @@ -350,8 +349,19 @@ def test_detect_head_with_cache(error_measure: ErrorMeasure, expected: torch.Ten def test_detect_head_with_invalid_head_name(): - with pytest.raises(BeartypeCallHintParamViolation) as e: + with pytest.raises(Exception) as e: detect_head(model, test_regular_sequence, "test") + # Type checking can be done by jaxtyping (TypeCheckError) or beartype (BeartypeCallHintParamViolation) + exc_name = type(e.value).__name__ + exc_msg = str(e.value).lower() + assert "TypeCheckError" in exc_name or "Beartype" in exc_name + # Check for meaningful error message (different formats for different type checkers) + assert ( + "type-check" in exc_msg + or "vector_type" in exc_msg + or "detection_pattern" in exc_msg + or "violates type" in exc_msg + ) def test_detect_head_with_zero_sequence_length(): diff --git a/tests/unit/test_svd_interpreter.py b/tests/unit/test_svd_interpreter.py index 6fea5b001..0f656692e 100644 --- a/tests/unit/test_svd_interpreter.py +++ b/tests/unit/test_svd_interpreter.py @@ -1,9 +1,15 @@ +import jaxtyping import pytest import torch -from beartype.roar import BeartypeCallHintParamViolation from transformer_lens import HookedTransformer, SVDInterpreter +# Get TypeCheckError from jaxtyping module (it may be re-exported from typeguard) +TypeCheckError = getattr(jaxtyping, "TypeCheckError", None) +if TypeCheckError is None: + # Fallback to typeguard + from typeguard import TypeCheckError + MODEL = "solu-2l" VECTOR_TYPES = ["OV", "w_in", "w_out"] ATOL = 2e-4 # Absolute tolerance - how far does a float have to be before we consider it no longer equal? @@ -125,8 +131,14 @@ def test_svd_interpreter_returns_different_answers_for_different_models(second_m def test_svd_interpreter_fails_on_invalid_vector_type(model): svd_interpreter = SVDInterpreter(model) - with pytest.raises(BeartypeCallHintParamViolation) as e: + # Type checking can be done by jaxtyping (TypeCheckError) or beartype (BeartypeCallHintParamViolation) + # Catch by checking the exception type name since jaxtyping may wrap typeguard's exception + with pytest.raises(Exception) as exc_info: svd_interpreter.get_singular_vectors("test", layer_index=0, num_vectors=4, head_index=0) + # Verify it's a type checking error (from jaxtyping, typeguard, or beartype) + exc_name = type(exc_info.value).__name__ + assert "TypeCheckError" in exc_name or "Beartype" in exc_name + assert "type-check" in str(exc_info.value).lower() or "vector_type" in str(exc_info.value) def test_svd_interpreter_fails_on_not_passing_required_head_index(model): diff --git a/transformer_lens/FactoredMatrix.py b/transformer_lens/FactoredMatrix.py index a24ce0383..2f69220df 100644 --- a/transformer_lens/FactoredMatrix.py +++ b/transformer_lens/FactoredMatrix.py @@ -34,12 +34,22 @@ def __init__( self.rdim = self.B.size(-1) self.mdim = self.B.size(-2) self.has_leading_dims = (self.A.ndim > 2) or (self.B.ndim > 2) - self.shape = torch.broadcast_shapes(self.A.shape[:-2], self.B.shape[:-2]) + ( - self.ldim, - self.rdim, - ) - self.A = self.A.broadcast_to(self.shape[:-2] + (self.ldim, self.mdim)) - self.B = self.B.broadcast_to(self.shape[:-2] + (self.mdim, self.rdim)) + try: + self.shape = torch.broadcast_shapes(self.A.shape[:-2], self.B.shape[:-2]) + ( + self.ldim, + self.rdim, + ) + except RuntimeError as e: + raise RuntimeError( + f"Shape mismatch: Cannot broadcast leading dimensions. A has shape {self.A.shape}, B has shape {self.B.shape}. {str(e)}" + ) from e + try: + self.A = self.A.broadcast_to(self.shape[:-2] + (self.ldim, self.mdim)) + self.B = self.B.broadcast_to(self.shape[:-2] + (self.mdim, self.rdim)) + except RuntimeError as e: + raise RuntimeError( + f"Shape mismatch: Cannot broadcast tensors. A has shape {self.A.shape}, B has shape {self.B.shape}, expected broadcast shape {self.shape}. {str(e)}" + ) from e @overload def __matmul__( diff --git a/transformer_lens/HookedEncoder.py b/transformer_lens/HookedEncoder.py index ff76ba75f..378f7bbd5 100644 --- a/transformer_lens/HookedEncoder.py +++ b/transformer_lens/HookedEncoder.py @@ -8,7 +8,7 @@ import logging import os -from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, overload +from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union, cast, overload import torch import torch.nn as nn @@ -21,7 +21,6 @@ from transformer_lens.ActivationCache import ActivationCache from transformer_lens.components import ( MLP, - Attention, BertBlock, BertEmbed, BertMLMHead, @@ -29,6 +28,7 @@ BertPooler, Unembed, ) +from transformer_lens.components.mlps.gated_mlp import GatedMLP from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.FactoredMatrix import FactoredMatrix from transformer_lens.hook_points import HookedRootModule, HookPoint @@ -48,6 +48,12 @@ class HookedEncoder(HookedRootModule): - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model """ + blocks: nn.ModuleList[BertBlock] # type: ignore[type-arg] + + def _get_blocks(self) -> list[BertBlock]: + """Helper to get blocks with proper typing.""" + return [cast(BertBlock, block) for block in self.blocks] + def __init__( self, cfg: Union[HookedTransformerConfig, Dict], @@ -463,86 +469,70 @@ def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]: @property def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: """Stacks the key weights across all layers""" - for block in self.blocks: - assert isinstance(block.attn, Attention) - return torch.stack([block.attn.W_K for block in self.blocks], dim=0) + return torch.stack([block.attn.W_K for block in self._get_blocks()], dim=0) @property def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: """Stacks the query weights across all layers""" - for block in self.blocks: - assert isinstance(block.attn, Attention) - return torch.stack([block.attn.W_Q for block in self.blocks], dim=0) + return torch.stack([block.attn.W_Q for block in self._get_blocks()], dim=0) @property def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: """Stacks the value weights across all layers""" - for block in self.blocks: - assert isinstance(block.attn, Attention) - return torch.stack([block.attn.W_V for block in self.blocks], dim=0) + return torch.stack([block.attn.W_V for block in self._get_blocks()], dim=0) @property def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]: """Stacks the attn output weights across all layers""" - for block in self.blocks: - assert isinstance(block.attn, Attention) - return torch.stack([block.attn.W_O for block in self.blocks], dim=0) + return torch.stack([block.attn.W_O for block in self._get_blocks()], dim=0) @property def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]: """Stacks the MLP input weights across all layers""" - for block in self.blocks: - assert isinstance(block.mlp, MLP) - return torch.stack([block.mlp.W_in for block in self.blocks], dim=0) + return torch.stack( + [cast(Union[MLP, GatedMLP], block.mlp).W_in for block in self._get_blocks()], dim=0 + ) @property def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]: """Stacks the MLP output weights across all layers""" - for block in self.blocks: - assert isinstance(block.mlp, MLP) - return torch.stack([block.mlp.W_out for block in self.blocks], dim=0) + return torch.stack( + [cast(Union[MLP, GatedMLP], block.mlp).W_out for block in self._get_blocks()], dim=0 + ) @property def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: """Stacks the key biases across all layers""" - for block in self.blocks: - assert isinstance(block.attn, Attention) - return torch.stack([block.attn.b_K for block in self.blocks], dim=0) + return torch.stack([block.attn.b_K for block in self._get_blocks()], dim=0) @property def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: """Stacks the query biases across all layers""" - for block in self.blocks: - assert isinstance(block.attn, Attention) - return torch.stack([block.attn.b_Q for block in self.blocks], dim=0) + return torch.stack([block.attn.b_Q for block in self._get_blocks()], dim=0) @property def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: """Stacks the value biases across all layers""" - for block in self.blocks: - assert isinstance(block.attn, Attention) - return torch.stack([block.attn.b_V for block in self.blocks], dim=0) + return torch.stack([block.attn.b_V for block in self._get_blocks()], dim=0) @property def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]: """Stacks the attn output biases across all layers""" - for block in self.blocks: - assert isinstance(block.attn, Attention) - return torch.stack([block.attn.b_O for block in self.blocks], dim=0) + return torch.stack([block.attn.b_O for block in self._get_blocks()], dim=0) @property def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]: """Stacks the MLP input biases across all layers""" - for block in self.blocks: - assert isinstance(block.mlp, MLP) - return torch.stack([block.mlp.b_in for block in self.blocks], dim=0) + return torch.stack( + [cast(Union[MLP, GatedMLP], block.mlp).b_in for block in self._get_blocks()], dim=0 + ) @property def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]: """Stacks the MLP output biases across all layers""" - for block in self.blocks: - assert isinstance(block.mlp, MLP) - return torch.stack([block.mlp.b_out for block in self.blocks], dim=0) + return torch.stack( + [cast(Union[MLP, GatedMLP], block.mlp).b_out for block in self._get_blocks()], dim=0 + ) @property def QK(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 518ef0afe..9627731d0 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -57,6 +57,8 @@ TransformerBlock, Unembed, ) +from transformer_lens.components.mlps.gated_mlp import GatedMLP +from transformer_lens.components.mlps.mlp import MLP from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig from transformer_lens.FactoredMatrix import FactoredMatrix from transformer_lens.hook_points import HookedRootModule, HookPoint @@ -117,6 +119,7 @@ class HookedTransformer(HookedRootModule): ln_final: nn.Module tokenizer: Optional[PreTrainedTokenizerBase] + blocks: nn.ModuleList[TransformerBlock] # type: ignore[type-arg] def __init__( self, @@ -148,7 +151,6 @@ def __init__( ) self.cfg = HookedTransformerConfig.unwrap(cfg) - if tokenizer is not None: self.set_tokenizer(tokenizer, default_padding_side=default_padding_side) elif self.cfg.tokenizer_name is not None: @@ -166,10 +168,15 @@ def __init__( if "phi" in self.cfg.tokenizer_name.lower(): use_fast = False huggingface_token = os.environ.get("HF_TOKEN", "") + add_bos_token = self.cfg.original_architecture not in [ + "OlmoForCausalLM", + "OlmoeForCausalLM", + "Olmo2ForCausalLM", + ] self.set_tokenizer( AutoTokenizer.from_pretrained( self.cfg.tokenizer_name, - add_bos_token=True, + add_bos_token=add_bos_token, trust_remote_code=self.cfg.trust_remote_code, use_fast=use_fast, token=huggingface_token if len(huggingface_token) > 0 else None, @@ -740,7 +747,14 @@ def set_tokenizer( # tokenizers like LlamaTokenizer are different when bos token is automatically/manually # prepended, and add_bos_token cannot be dynamically controlled after initialization # (https://github.com/huggingface/transformers/issues/25886). - tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer) + tokenizer_with_bos = tokenizer + if self.cfg.original_architecture not in [ + "OlmoForCausalLM", + "OlmoeForCausalLM", + "Olmo2ForCausalLM", + ]: + tokenizer_with_bos = utils.get_tokenizer_with_bos(tokenizer) + self.tokenizer = tokenizer_with_bos assert self.tokenizer is not None # keep mypy happy @@ -1389,6 +1403,34 @@ def from_pretrained( "Setting center_writing_weights=False instead." ) center_writing_weights = False + # OLMo 2 uses post-norm (norm after attention/MLP, not before), which is + # incompatible with weight processing that assumes pre-norm structure. + if cfg.original_architecture == "Olmo2ForCausalLM": + if fold_ln: + logging.warning( + "fold_ln=True is incompatible with OLMo 2's post-norm architecture. " + "Setting fold_ln=False." + ) + fold_ln = False + if center_writing_weights: + logging.warning( + "center_writing_weights=True is incompatible with OLMo 2's post-norm " + "architecture. Setting center_writing_weights=False." + ) + center_writing_weights = False + if center_unembed: + logging.warning( + "center_unembed=True is incompatible with OLMo 2's post-norm " + "architecture (uses RMSNorm which does not center). " + "Setting center_unembed=False." + ) + center_unembed = False + if fold_value_biases: + logging.warning( + "fold_value_biases=True is incompatible with OLMo 2's post-norm " + "architecture. Setting fold_value_biases=False." + ) + fold_value_biases = False if center_unembed and cfg.output_logits_soft_cap > 0.0: logging.warning( "You tried to specify center_unembed=True for a model using logit softcap, but this can't be done! Softcapping is not invariant upon adding a constant " @@ -1805,7 +1847,7 @@ def process_weights_( # but it's the easiest way to do it. self.cfg.normalization_type = "LNPre" self.ln_final = LayerNormPre(self.cfg) - for layer in self.blocks: + for layer in self._get_blocks(): layer.ln1 = LayerNormPre(self.cfg) layer.ln2 = LayerNormPre(self.cfg) if self.cfg.is_layer_norm_activation(): @@ -1814,7 +1856,7 @@ def process_weights_( # We do the same for RMSNorm if used self.cfg.normalization_type = "RMSPre" self.ln_final = RMSNormPre(self.cfg) - for layer in self.blocks: + for layer in self._get_blocks(): layer.ln1 = RMSNormPre(self.cfg) layer.ln2 = RMSNormPre(self.cfg) if self.cfg.is_layer_norm_activation(): @@ -2205,30 +2247,36 @@ def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]: # we want to do analysis on weights across all layers. If GPU memory is a bottleneck, don't use # these properties! + def _get_blocks(self) -> list[TransformerBlock]: + """Helper to get blocks with proper typing.""" + return [cast(TransformerBlock, block) for block in self.blocks] + @property def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: """Stack the key weights across all layers.""" - return torch.stack([block.attn.W_K for block in self.blocks], dim=0) + return torch.stack([block.attn.W_K for block in self._get_blocks()], dim=0) @property def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: """Stack the query weights across all layers.""" - return torch.stack([block.attn.W_Q for block in self.blocks], dim=0) + return torch.stack([block.attn.W_Q for block in self._get_blocks()], dim=0) @property def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: """Stack the value weights across all layers.""" - return torch.stack([block.attn.W_V for block in self.blocks], dim=0) + return torch.stack([block.attn.W_V for block in self._get_blocks()], dim=0) @property def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]: """Stack the attn output weights across all layers.""" - return torch.stack([block.attn.W_O for block in self.blocks], dim=0) + return torch.stack([block.attn.W_O for block in self._get_blocks()], dim=0) @property def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]: """Stack the MLP input weights across all layers.""" - return torch.stack([block.mlp.W_in for block in self.blocks], dim=0) + return torch.stack( + [cast(Union[MLP, GatedMLP], block.mlp).W_in for block in self._get_blocks()], dim=0 + ) @property def W_gate(self) -> Union[Float[torch.Tensor, "n_layers d_model d_mlp"], None]: @@ -2237,44 +2285,52 @@ def W_gate(self) -> Union[Float[torch.Tensor, "n_layers d_model d_mlp"], None]: Only works for models with gated MLPs. """ if self.cfg.gated_mlp: - return torch.stack([block.mlp.W_gate for block in self.blocks], dim=0) + return torch.stack( + [cast(GatedMLP, block.mlp).W_gate for block in self._get_blocks()], dim=0 + ) else: return None @property def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]: """Stack the MLP output weights across all layers.""" - return torch.stack([block.mlp.W_out for block in self.blocks], dim=0) + return torch.stack( + [cast(Union[MLP, GatedMLP], block.mlp).W_out for block in self._get_blocks()], dim=0 + ) @property def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: """Stack the key biases across all layers.""" - return torch.stack([block.attn.b_K for block in self.blocks], dim=0) + return torch.stack([block.attn.b_K for block in self._get_blocks()], dim=0) @property def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: """Stack the query biases across all layers.""" - return torch.stack([block.attn.b_Q for block in self.blocks], dim=0) + return torch.stack([block.attn.b_Q for block in self._get_blocks()], dim=0) @property def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: """Stack the value biases across all layers.""" - return torch.stack([block.attn.b_V for block in self.blocks], dim=0) + return torch.stack([block.attn.b_V for block in self._get_blocks()], dim=0) @property def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]: """Stack the attn output biases across all layers.""" - return torch.stack([block.attn.b_O for block in self.blocks], dim=0) + return torch.stack([block.attn.b_O for block in self._get_blocks()], dim=0) @property def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]: """Stack the MLP input biases across all layers.""" - return torch.stack([block.mlp.b_in for block in self.blocks], dim=0) + return torch.stack( + [cast(Union[MLP, GatedMLP], block.mlp).b_in for block in self._get_blocks()], dim=0 + ) @property def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]: """Stack the MLP output biases across all layers.""" - return torch.stack([block.mlp.b_out for block in self.blocks], dim=0) + return torch.stack( + [cast(Union[MLP, GatedMLP], block.mlp).b_out for block in self._get_blocks()], dim=0 + ) @property def QK(self): diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index 1c5d5b585..dc6926525 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -89,6 +89,18 @@ def __init__( if self.cfg.use_qk_norm: self.q_norm = RMSNorm(self.cfg, length=self.cfg.d_head) self.k_norm = RMSNorm(self.cfg, length=self.cfg.d_head) + + elif ( + self.cfg.original_architecture == "OlmoeForCausalLM" + or self.cfg.original_architecture == "Olmo2ForCausalLM" + ): + self.q_norm: Optional[RMSNorm] = RMSNorm(self.cfg, self.cfg.d_model) + if self.cfg.original_architecture == "Olmo2ForCausalLM": + k_norm_dim = self.cfg.d_model + else: + assert self.cfg.n_key_value_heads is not None + k_norm_dim = self.cfg.d_head * self.cfg.n_key_value_heads + self.k_norm: Optional[RMSNorm] = RMSNorm(self.cfg, k_norm_dim) else: self.q_norm = None self.k_norm = None @@ -217,6 +229,34 @@ def forward( q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input) + # OLMoE uses QK-norm. + if ( + self.cfg.original_architecture == "OlmoeForCausalLM" + or self.cfg.original_architecture == "Olmo2ForCausalLM" + ): + assert self.q_norm is not None + assert self.k_norm is not None + q = einops.rearrange( + self.q_norm( + einops.rearrange( + q, + "batch pos head_index d_head -> batch pos (head_index d_head)", + ) + ), + "batch kv_pos (head_index d_head) -> batch kv_pos head_index d_head", + head_index=q.shape[2], + ) + k = einops.rearrange( + self.k_norm( + einops.rearrange( + k, + "batch pos head_index d_head -> batch pos (head_index d_head)", + ) + ), + "batch kv_pos (head_index d_head) -> batch kv_pos head_index d_head", + head_index=k.shape[2], + ) + if past_kv_cache_entry is not None: # Appends the new keys and values to the cached values, and automatically updates the cache kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1) @@ -271,7 +311,8 @@ def forward( device=attn_scores.device, ) - attn_scores += position_bias + if position_bias is not None: # Add None check + attn_scores += position_bias if self.cfg.attention_dir == "causal": # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask. attn_scores = self.apply_causal_mask( @@ -732,7 +773,7 @@ def create_alibi_slope( @staticmethod def create_alibi_multipliers( n_heads: int, device: Optional[Union[str, torch.device]] = None - ) -> Float[torch.Tensor, "head_idx"]: + ) -> Float[torch.Tensor, "n_heads"]: """Create the ALiBi Scalar Multipliers for each Head. For n heads, the set of multipliers (m) is the geometric sequence that starts at 2^(-8/n), and diff --git a/transformer_lens/components/mlps/moe.py b/transformer_lens/components/mlps/moe.py index 1d537a187..305cdcf2b 100644 --- a/transformer_lens/components/mlps/moe.py +++ b/transformer_lens/components/mlps/moe.py @@ -88,7 +88,8 @@ def forward( # both are [batch, pos, experts_per_token] weights = self.hook_expert_weights(F.softmax(gate_logits, dim=1, dtype=torch.float)) weights, expert_indices = torch.topk(weights, self.experts_per_token, dim=-1) - weights /= weights.sum(dim=-1, keepdim=True) + if self.cfg.norm_topk_prob: + weights /= weights.sum(dim=-1, keepdim=True) expert_indices = self.hook_expert_indices(expert_indices) weights = weights.to(x.dtype) diff --git a/transformer_lens/components/transformer_block.py b/transformer_lens/components/transformer_block.py index 8fcbec0ae..a74b29092 100644 --- a/transformer_lens/components/transformer_block.py +++ b/transformer_lens/components/transformer_block.py @@ -155,20 +155,29 @@ def forward( key_input = attn_in value_input = attn_in - attn_out = ( - # hook the residual stream states that are used to calculate the - # queries, keys and values, independently. - # Then take the layer norm of these inputs, and pass these to the attention module. - self.attn( - query_input=self.ln1(query_input) - + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed), - key_input=self.ln1(key_input) - + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed), - value_input=self.ln1(value_input), + if self.cfg.original_architecture == "Olmo2ForCausalLM": + attn_out = self.attn( + query_input=query_input, + key_input=key_input, + value_input=value_input, past_kv_cache_entry=past_kv_cache_entry, attention_mask=attention_mask, ) - ) # [batch, pos, d_model] + else: + attn_out = ( + # hook the residual stream states that are used to calculate the + # queries, keys and values, independently. + # Then take the layer norm of these inputs, and pass these to the attention module. + self.attn( + query_input=self.ln1(query_input) + + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed), + key_input=self.ln1(key_input) + + (0.0 if shortformer_pos_embed is None else shortformer_pos_embed), + value_input=self.ln1(value_input), + past_kv_cache_entry=past_kv_cache_entry, + attention_mask=attention_mask, + ) + ) # [batch, pos, d_model] if self.cfg.use_normalization_before_and_after: # If we use LayerNorm both before and after, then apply the second LN after the layer # and before the hook. We do it before the hook so hook_attn_out captures "that which @@ -176,6 +185,9 @@ def forward( attn_out = self.ln1_post(attn_out) attn_out = self.hook_attn_out(attn_out) + if self.cfg.original_architecture == "Olmo2ForCausalLM": + attn_out = self.ln1(attn_out) + if resid_pre.device != attn_out.device: resid_pre = resid_pre.to(attn_out.device) @@ -184,8 +196,12 @@ def forward( mlp_in = ( resid_mid if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_mid.clone()) ) - normalized_resid_mid = self.ln2(mlp_in) - mlp_out = self.apply_mlp(normalized_resid_mid) + if self.cfg.original_architecture == "Olmo2ForCausalLM": + mlp_out = self.apply_mlp(mlp_in) + mlp_out = self.ln2(mlp_out) + else: + normalized_resid_mid = self.ln2(mlp_in) + mlp_out = self.apply_mlp(normalized_resid_mid) resid_post = self.hook_resid_post(resid_mid + mlp_out) # [batch, pos, d_model] elif self.cfg.parallel_attn_mlp: # Dumb thing done by GPT-J, both MLP and Attn read from resid_pre and write to resid_post, no resid_mid used. diff --git a/transformer_lens/config/HookedTransformerConfig.py b/transformer_lens/config/HookedTransformerConfig.py index cd4f94cc8..b18023098 100644 --- a/transformer_lens/config/HookedTransformerConfig.py +++ b/transformer_lens/config/HookedTransformerConfig.py @@ -199,6 +199,7 @@ class HookedTransformerConfig(TransformerLensConfig): attention layers. Used by models with hybrid local/global attention (e.g., Gemma 3) which use different RoPE bases for local (10k) and global (1M) attention. Defaults to None, which means the standard rotary_base is used for all layers. + norm_topk_prob (bool): Whether to normalize the top-k probabilities in the MoE layer. """ model_name: str = "custom" @@ -259,6 +260,7 @@ class HookedTransformerConfig(TransformerLensConfig): NTK_by_parts_high_freq_factor: float = 4.0 NTK_by_parts_factor: float = 8.0 NTK_original_ctx_len: int = 8192 + norm_topk_prob: bool = False def __post_init__(self): # Call parent's post_init first diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 7cbfc86a1..3b6fd614b 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -38,6 +38,9 @@ convert_neel_solu_old_weights, convert_neo_weights, convert_neox_weights, + convert_olmo2_weights, + convert_olmo_weights, + convert_olmoe_weights, convert_opt_weights, convert_phi3_weights, convert_phi_weights, @@ -1231,6 +1234,231 @@ def convert_hf_model_config(model_name: str, **kwargs: Any) -> dict[str, Any]: "local", ], } + elif official_model_name.startswith("google/gemma-2b"): + # Architecture for Gemma 2b and Gemma 2b Instruct models + cfg_dict = { + "d_model": 2048, + "d_head": 256, + "n_heads": 8, + "d_mlp": 16384, + "n_layers": 18, + "n_ctx": 8192, + "eps": 1e-06, + "d_vocab": 256000, + "act_fn": "gelu_new", + "initializer_range": 0.02, + "normalization_type": "RMS", + "rotary_base": 10000, + "rotary_dim": 256, + "positional_embedding_type": "rotary", + "use_attn_scale": True, + "n_key_value_heads": 1, + "gated_mlp": True, + "final_rms": True, + } + elif official_model_name.startswith("google/gemma-7b"): + # Architecture for Gemma 7b and Gemma 7b Instruct models + cfg_dict = { + "d_model": 3072, + "d_head": 256, + "n_heads": 16, + "d_mlp": 24576, + "n_layers": 28, + "n_ctx": 8192, + "eps": 1e-06, + "d_vocab": 256000, + "act_fn": "gelu_new", + "initializer_range": 0.02, + "normalization_type": "RMS", + "rotary_base": 10000.0, + "rotary_dim": 256, + "positional_embedding_type": "rotary", + "use_attn_scale": True, + "n_key_value_heads": 16, + "gated_mlp": True, + "final_rms": True, + } + elif official_model_name.startswith("google/gemma-2-2b"): + # Architecture for Gemma-2 2b and Gemma-2 2b Instruct models + cfg_dict = { + "d_model": 2304, + "d_head": 256, + "n_heads": 8, + "d_mlp": 9216, + "n_layers": 26, + "n_ctx": 8192, + "eps": 1e-06, + "d_vocab": 256000, + "act_fn": "gelu_pytorch_tanh", + "initializer_range": 0.02, + "normalization_type": "RMS", + "rotary_base": 10000.0, + "positional_embedding_type": "rotary", + "use_attn_scale": True, + "n_key_value_heads": 4, + "window_size": 4096, + "use_local_attn": True, + "attn_types": ["global", "local"] * 21, # Alternate global and local attn + "attn_scores_soft_cap": 50.0, + "output_logits_soft_cap": 30.0, + "gated_mlp": True, + "final_rms": True, + "use_normalization_before_and_after": True, + } + elif official_model_name.startswith("google/gemma-2-9b"): + # Architecture for Gemma-2 9b and Gemma-2 9b Instruct models + cfg_dict = { + "d_model": 3584, + "d_head": 256, + "n_heads": 16, + "d_mlp": 14336, + "n_layers": 42, + "n_ctx": 8192, + "eps": 1e-06, + "d_vocab": 256000, + "act_fn": "gelu_pytorch_tanh", + "initializer_range": 0.02, + "normalization_type": "RMS", + "rotary_base": 10000.0, + "positional_embedding_type": "rotary", + "use_attn_scale": True, + "n_key_value_heads": 8, + "window_size": 4096, + "use_local_attn": True, + "attn_types": ["global", "local"] * 21, # Alternate global and local attn + "attn_scores_soft_cap": 50.0, + "output_logits_soft_cap": 30.0, + "gated_mlp": True, + "final_rms": True, + "use_normalization_before_and_after": True, + } + elif official_model_name.startswith("google/gemma-2-27b"): + # Architecture for Gemma-2 27b and Gemma-2 27b Instruct models + cfg_dict = { + "d_model": 4608, + "d_head": 128, + "n_heads": 32, + "d_mlp": 36864, + "n_layers": 46, + "n_ctx": 8192, + "eps": 1e-06, + "d_vocab": 256000, + "act_fn": "gelu_pytorch_tanh", + "initializer_range": 0.02, + "normalization_type": "RMS", + "rotary_base": 10000.0, + "positional_embedding_type": "rotary", + "use_attn_scale": True, + "attn_scale": 12.0, + "n_key_value_heads": 16, + "window_size": 4096, + "use_local_attn": True, + "attn_types": ["global", "local"] * 23, # Alternate global and local attn + "attn_scores_soft_cap": 50.0, + "output_logits_soft_cap": 30.0, + "gated_mlp": True, + "final_rms": True, + "use_normalization_before_and_after": True, + } + elif official_model_name.startswith("allenai/OLMo-1B") and official_model_name.endswith("hf"): + cfg_dict = { + "d_model": 2048, + "d_head": 128, + "n_heads": 16, + "d_mlp": 8192, + "n_layers": 16, + "n_ctx": 2048, + "eps": 1e-05, + "d_vocab": 50304, + "act_fn": "silu", + "initializer_range": 0.02, + "normalization_type": "LN", + "rotary_base": 10000.0, + "attn_types": ["global"] * 16, + "positional_embedding_type": "rotary", + "gated_mlp": True, + } + elif official_model_name.startswith("allenai/OLMo-7B") and official_model_name.endswith("hf"): + cfg_dict = { + "d_model": 4096, + "d_head": 128, + "n_heads": 32, + "d_mlp": 11008, + "n_layers": 32, + "n_ctx": 2048, + "eps": 1e-05, + "d_vocab": 50304, + "act_fn": "silu", + "initializer_range": 0.02, + "normalization_type": "LN", + "rotary_base": 10000.0, + "attn_types": ["global"] * 32, + "positional_embedding_type": "rotary", + "gated_mlp": True, + } + elif official_model_name.startswith("allenai/OLMo-2-0425-1B"): + cfg_dict = { + "d_model": 2048, + "d_head": 128, + "n_heads": 16, + "d_mlp": 8192, + "n_layers": 16, + "n_ctx": 4096, + "eps": 1e-06, + "d_vocab": 100352, + "act_fn": "silu", + "initializer_range": 0.02, + "normalization_type": "RMS", + "rotary_base": 500000.0, + "attn_types": ["global"] * 16, + "positional_embedding_type": "rotary", + "gated_mlp": True, + } + elif official_model_name.startswith("allenai/OLMo-2-1124-7B"): + cfg_dict = { + "d_model": 4096, + "d_head": 128, + "n_heads": 32, + "d_mlp": 11008, + "n_layers": 32, + "n_ctx": 4096, + "eps": 1e-06, + "d_vocab": 100352, + "act_fn": "silu", + "initializer_range": 0.02, + "normalization_type": "RMS", + "rotary_base": 500000.0, + "attn_types": ["global"] * 32, + "positional_embedding_type": "rotary", + "gated_mlp": True, + } + elif architecture == "OlmoeForCausalLM": + cfg_dict = { + "d_model": hf_config.hidden_size, + "d_head": hf_config.hidden_size // hf_config.num_attention_heads, + "n_heads": hf_config.num_attention_heads, + "d_mlp": hf_config.intermediate_size, + "n_layers": hf_config.num_hidden_layers, + "n_ctx": hf_config.max_position_embeddings, + "eps": hf_config.rms_norm_eps, + "d_vocab": hf_config.vocab_size, + "act_fn": hf_config.hidden_act, + "num_experts": hf_config.num_experts, + "experts_per_token": hf_config.num_experts_per_tok, + "norm_topk_prob": hf_config.norm_topk_prob, + "n_key_value_heads": hf_config.num_key_value_heads, + "rotary_base": getattr( + hf_config, + "rope_theta", + hf_config.rope_parameters.get("rope_theta", 10000.0), + ), + "tie_word_embeddings": hf_config.tie_word_embeddings, + "initializer_range": hf_config.initializer_range, + "positional_embedding_type": "rotary", + "rotary_dim": hf_config.hidden_size // hf_config.num_attention_heads, + "gated_mlp": True, + "normalization_type": "RMS", + } elif architecture == "T5ForConditionalGeneration": cfg_dict = { "d_model": hf_config.d_model, @@ -1390,6 +1618,15 @@ def get_pretrained_model_config( ) fold_ln = False + # OLMo 2 uses post-norm (norm after attention/MLP, not before), so folding + # the norm weights into adjacent linear layers is not mathematically valid. + if cfg_dict.get("original_architecture") == "Olmo2ForCausalLM" and fold_ln: + logging.warning( + "fold_ln=True is incompatible with OLMo 2's post-norm architecture. " + "Setting fold_ln=False." + ) + fold_ln = False + if device is not None: cfg_dict["device"] = device @@ -1668,6 +1905,12 @@ def get_pretrained_state_dict( state_dict = convert_gemma_weights(hf_model, cfg) elif cfg.original_architecture == "Gemma3ForConditionalGeneration": state_dict = convert_gemma_weights(hf_model, cfg) + elif cfg.original_architecture == "OlmoForCausalLM": + state_dict = convert_olmo_weights(hf_model, cfg) + elif cfg.original_architecture == "Olmo2ForCausalLM": + state_dict = convert_olmo2_weights(hf_model, cfg) + elif cfg.original_architecture == "OlmoeForCausalLM": + state_dict = convert_olmoe_weights(hf_model, cfg) else: raise ValueError( f"Loading weights from the architecture is not currently supported: {cfg.original_architecture}, generated from model name {cfg.model_name}. Feel free to open an issue on GitHub to request this feature." diff --git a/transformer_lens/pretrained/weight_conversions/__init__.py b/transformer_lens/pretrained/weight_conversions/__init__.py index c5ea9581b..bba841a29 100644 --- a/transformer_lens/pretrained/weight_conversions/__init__.py +++ b/transformer_lens/pretrained/weight_conversions/__init__.py @@ -19,3 +19,6 @@ from .nanogpt import convert_nanogpt_weights from .t5 import convert_t5_weights from .neel_solu_old import convert_neel_solu_old_weights +from .olmo import convert_olmo_weights +from .olmoe import convert_olmoe_weights +from .olmo2 import convert_olmo2_weights diff --git a/transformer_lens/pretrained/weight_conversions/olmo.py b/transformer_lens/pretrained/weight_conversions/olmo.py new file mode 100644 index 000000000..afbbb6263 --- /dev/null +++ b/transformer_lens/pretrained/weight_conversions/olmo.py @@ -0,0 +1,50 @@ +import einops +import torch + +from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig + + +def convert_olmo_weights(olmo, cfg: HookedTransformerConfig): + state_dict = {} + + assert cfg.d_mlp is not None + + state_dict["embed.W_E"] = olmo.model.embed_tokens.weight + for l in range(cfg.n_layers): + olmo_layer = olmo.model.layers[l] + + W_Q = olmo_layer.self_attn.q_proj.weight + W_K = olmo_layer.self_attn.k_proj.weight + W_V = olmo_layer.self_attn.v_proj.weight + W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads) + W_K = einops.rearrange(W_K, "(i h) m->i m h", i=cfg.n_heads) + W_V = einops.rearrange(W_V, "(i h) m->i m h", i=cfg.n_heads) + state_dict[f"blocks.{l}.attn.W_Q"] = W_Q + state_dict[f"blocks.{l}.attn.W_K"] = W_K + state_dict[f"blocks.{l}.attn.W_V"] = W_V + + W_O = olmo_layer.self_attn.o_proj.weight + W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) + state_dict[f"blocks.{l}.attn.W_O"] = W_O + + state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) + + state_dict[f"blocks.{l}.mlp.W_in"] = olmo_layer.mlp.up_proj.weight.T + state_dict[f"blocks.{l}.mlp.W_gate"] = olmo_layer.mlp.gate_proj.weight.T + state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype) + + state_dict[f"blocks.{l}.mlp.W_out"] = olmo_layer.mlp.down_proj.weight.T + state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) + + state_dict[f"blocks.{l}.ln1.w"] = torch.ones(cfg.d_model, dtype=cfg.dtype) + state_dict[f"blocks.{l}.ln1.b"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) + state_dict[f"blocks.{l}.ln2.w"] = torch.ones(cfg.d_model, dtype=cfg.dtype) + state_dict[f"blocks.{l}.ln2.b"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) + + state_dict["ln_final.w"] = torch.ones(cfg.d_model, dtype=cfg.dtype) + state_dict["ln_final.b"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) + + state_dict["unembed.W_U"] = olmo.lm_head.weight.T + state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) + + return state_dict diff --git a/transformer_lens/pretrained/weight_conversions/olmo2.py b/transformer_lens/pretrained/weight_conversions/olmo2.py new file mode 100644 index 000000000..4b441dbe6 --- /dev/null +++ b/transformer_lens/pretrained/weight_conversions/olmo2.py @@ -0,0 +1,57 @@ +import einops +import torch +from transformers.models.olmo2.modeling_olmo2 import Olmo2DecoderLayer + +from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig + + +def convert_olmo2_weights(olmo2, cfg: HookedTransformerConfig): + state_dict = {} + + assert cfg.d_mlp is not None + + state_dict["embed.W_E"] = olmo2.model.embed_tokens.weight + + for l in range(cfg.n_layers): + olmo2_layer = olmo2.model.layers[l] + assert isinstance(olmo2_layer, Olmo2DecoderLayer) + + W_Q = olmo2_layer.self_attn.q_proj.weight + W_K = olmo2_layer.self_attn.k_proj.weight + W_V = olmo2_layer.self_attn.v_proj.weight + W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) + W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_heads) + W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_heads) + state_dict[f"blocks.{l}.attn.W_Q"] = W_Q + state_dict[f"blocks.{l}.attn.W_K"] = W_K + state_dict[f"blocks.{l}.attn.W_V"] = W_V + state_dict[f"blocks.{l}.attn.q_norm.w"] = olmo2_layer.self_attn.q_norm.weight + state_dict[f"blocks.{l}.attn.k_norm.w"] = olmo2_layer.self_attn.k_norm.weight + + state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) + state_dict[f"blocks.{l}.attn.b_K"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) + state_dict[f"blocks.{l}.attn.b_V"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) + + W_O = olmo2_layer.self_attn.o_proj.weight + W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) + state_dict[f"blocks.{l}.attn.W_O"] = W_O + + state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) + + state_dict[f"blocks.{l}.ln1.w"] = olmo2_layer.post_attention_layernorm.weight + + state_dict[f"blocks.{l}.mlp.W_in"] = olmo2_layer.mlp.up_proj.weight.T + state_dict[f"blocks.{l}.mlp.W_gate"] = olmo2_layer.mlp.gate_proj.weight.T + state_dict[f"blocks.{l}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype) + + state_dict[f"blocks.{l}.mlp.W_out"] = olmo2_layer.mlp.down_proj.weight.T + state_dict[f"blocks.{l}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) + + state_dict[f"blocks.{l}.ln2.w"] = olmo2_layer.post_feedforward_layernorm.weight + + state_dict["ln_final.w"] = olmo2.model.norm.weight + + state_dict["unembed.W_U"] = olmo2.lm_head.weight.T + state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) + + return state_dict diff --git a/transformer_lens/pretrained/weight_conversions/olmoe.py b/transformer_lens/pretrained/weight_conversions/olmoe.py new file mode 100644 index 000000000..16efe3dc3 --- /dev/null +++ b/transformer_lens/pretrained/weight_conversions/olmoe.py @@ -0,0 +1,69 @@ +import einops +import torch + +from transformer_lens.config.HookedTransformerConfig import HookedTransformerConfig + + +def convert_olmoe_weights(olmoe, cfg: HookedTransformerConfig): + state_dict = {} + + assert cfg.n_key_value_heads is not None + assert cfg.d_mlp is not None + assert cfg.num_experts is not None + + state_dict["embed.W_E"] = olmoe.model.embed_tokens.weight + + for l in range(cfg.n_layers): + olmoe_layer = olmoe.model.layers[l] + state_dict[f"blocks.{l}.ln1.w"] = olmoe_layer.input_layernorm.weight + + W_Q = olmoe_layer.self_attn.q_proj.weight + W_K = olmoe_layer.self_attn.k_proj.weight + W_V = olmoe_layer.self_attn.v_proj.weight + W_Q = einops.rearrange(W_Q, "(n h) m->n m h", n=cfg.n_heads) + W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads) + W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads) + state_dict[f"blocks.{l}.attn.W_Q"] = W_Q + state_dict[f"blocks.{l}.attn._W_K"] = W_K + state_dict[f"blocks.{l}.attn._W_V"] = W_V + state_dict[f"blocks.{l}.attn.q_norm.w"] = olmoe_layer.self_attn.q_norm.weight + state_dict[f"blocks.{l}.attn.k_norm.w"] = olmoe_layer.self_attn.k_norm.weight + + state_dict[f"blocks.{l}.attn.b_Q"] = torch.zeros(cfg.n_heads, cfg.d_head, dtype=cfg.dtype) + state_dict[f"blocks.{l}.attn._b_K"] = torch.zeros( + cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype + ) + state_dict[f"blocks.{l}.attn._b_V"] = torch.zeros( + cfg.n_key_value_heads, cfg.d_head, dtype=cfg.dtype + ) + + W_O = olmoe_layer.self_attn.o_proj.weight + W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) + state_dict[f"blocks.{l}.attn.W_O"] = W_O + + state_dict[f"blocks.{l}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype) + + state_dict[f"blocks.{l}.ln2.w"] = olmoe_layer.post_attention_layernorm.weight + + state_dict[f"blocks.{l}.mlp.W_gate.weight"] = olmoe_layer.mlp.gate.weight + + # HF OLMoE uses batched expert weights: + # gate_up_proj: [num_experts, 2 * intermediate_size, hidden_size] + # down_proj: [num_experts, hidden_size, intermediate_size] + # The gate_up_proj fuses gate and up projections along dim 1. + experts = olmoe_layer.mlp.experts + gate_up = experts.gate_up_proj # [num_experts, 2*d_mlp, d_model] + down = experts.down_proj # [num_experts, d_model, d_mlp] + + for e in range(cfg.num_experts): + # Split fused gate_up into gate and up projections + state_dict[f"blocks.{l}.mlp.experts.{e}.W_gate.weight"] = gate_up[e, : cfg.d_mlp, :] + state_dict[f"blocks.{l}.mlp.experts.{e}.W_in.weight"] = gate_up[e, cfg.d_mlp :, :] + state_dict[f"blocks.{l}.mlp.experts.{e}.W_out.weight"] = down[e] + + state_dict["ln_final.w"] = olmoe.model.norm.weight + + state_dict["unembed.W_U"] = olmoe.lm_head.weight.T + state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype) + + return state_dict diff --git a/transformer_lens/supported_models.py b/transformer_lens/supported_models.py index a3e8f86c3..dbca9b381 100644 --- a/transformer_lens/supported_models.py +++ b/transformer_lens/supported_models.py @@ -4,6 +4,11 @@ "01-ai/Yi-6B", "01-ai/Yi-6B-Chat", "ai-forever/mGPT", + "allenai/OLMo-1B-hf", + "allenai/OLMo-2-0425-1B", + "allenai/OLMo-2-1124-7B", + "allenai/OLMo-7B-hf", + "allenai/OLMoE-1B-7B-0924", "ArthurConmy/redwood_attn_2l", "Baidicoot/Othello-GPT-Transformer-Lens", "bigcode/santacoder", @@ -255,6 +260,11 @@ "01-ai/Yi-6B": ["yi-6b", "Yi-6B"], "01-ai/Yi-6B-Chat": ["yi-6b-chat", "Yi-6B-Chat"], "ai-forever/mGPT": ["mGPT"], + "allenai/OLMo-1B-hf": ["olmo-1b"], + "allenai/OLMo-2-0425-1B": ["olmo-2-1b"], + "allenai/OLMo-2-1124-7B": ["olmo-2-7b"], + "allenai/OLMo-7B-hf": ["olmo-7b"], + "allenai/OLMoE-1B-7B-0924": ["olmoe"], "ArthurConmy/redwood_attn_2l": ["redwood_attn_2l"], "Baidicoot/Othello-GPT-Transformer-Lens": ["othello-gpt"], "bigcode/santacoder": ["santacoder"], diff --git a/transformer_lens/utilities/__init__.py b/transformer_lens/utilities/__init__.py index 460e85db1..5ecce3896 100644 --- a/transformer_lens/utilities/__init__.py +++ b/transformer_lens/utilities/__init__.py @@ -24,6 +24,7 @@ select_compatible_kwargs, ) from .initialization_utils import ( + NonlinearityType, calc_fan_in_and_fan_out, init_kaiming_normal_, init_kaiming_uniform_, diff --git a/transformer_lens/utilities/initialization_utils.py b/transformer_lens/utilities/initialization_utils.py index 0309e7c00..064b9d842 100644 --- a/transformer_lens/utilities/initialization_utils.py +++ b/transformer_lens/utilities/initialization_utils.py @@ -6,7 +6,25 @@ from __future__ import annotations import numpy as np +import torch import torch.nn as nn +from typing_extensions import Literal + +# Type alias for valid nonlinearity values accepted by nn.init.calculate_gain +NonlinearityType = Literal[ + "linear", + "conv1d", + "conv2d", + "conv3d", + "conv_transpose1d", + "conv_transpose2d", + "conv_transpose3d", + "sigmoid", + "tanh", + "relu", + "leaky_relu", + "selu", +] def calc_fan_in_and_fan_out(tensor): @@ -52,7 +70,13 @@ def init_xavier_normal_(param, gain=1.0): return nn.init.normal_(param, mean=0.0, std=std) -def init_kaiming_uniform_(param, a=0, nonlinearity="relu", gain=1.0, mode="fan_in"): +def init_kaiming_uniform_( + param: torch.Tensor, + a: float = 0, + nonlinearity: NonlinearityType = "relu", + gain: float = 1.0, + mode: str = "fan_in", +) -> torch.Tensor: """ Initializes the input tensor using the Kaiming initialization method. @@ -68,7 +92,13 @@ def init_kaiming_uniform_(param, a=0, nonlinearity="relu", gain=1.0, mode="fan_i return nn.init.uniform_(param, -max, max) -def init_kaiming_normal_(param, a=0, nonlinearity="relu", gain=1.0, mode="fan_in"): +def init_kaiming_normal_( + param: torch.Tensor, + a: float = 0, + nonlinearity: NonlinearityType = "relu", + gain: float = 1.0, + mode: str = "fan_in", +) -> torch.Tensor: """ Initializes the input tensor using the Kaiming initialization method. diff --git a/transformer_lens/weight_processing.py b/transformer_lens/weight_processing.py index 8a2fa63cf..c06351e84 100644 --- a/transformer_lens/weight_processing.py +++ b/transformer_lens/weight_processing.py @@ -503,6 +503,30 @@ def _fold_mlp_layer_norm( ln2_w_key = ProcessWeights._get_param_key(f"blocks.{layer}.ln2.w", adapter) # CRITICAL FIX: For RMS norm (Gemma), ln2_b doesn't exist. Only require ln2_w! if ln2_w_key in state_dict: + # MoE layers: fold ln2 into router gate and each expert's W_in/W_gate + if getattr(cfg, "num_experts", None) is not None and cfg.num_experts > 0: + ln2_w = state_dict[ln2_w_key] + # Fold into router gate + router_key = f"blocks.{layer}.mlp.W_gate.weight" + if router_key in state_dict: + state_dict[router_key] = state_dict[router_key] * ln2_w[None, :] + # Fold into each expert's W_in and W_gate (SwiGLU gate) + for e in range(cfg.num_experts): + for suffix in ("W_in.weight", "W_gate.weight"): + key = f"blocks.{layer}.mlp.experts.{e}.{suffix}" + if key in state_dict: + state_dict[key] = state_dict[key] * ln2_w[None, :] + # Set ln2.w to identity + state_dict[ln2_w_key] = torch.ones_like(ln2_w) + alternate_ln2_w_key = ( + ln2_w_key.replace("ln_2", "ln2") + if "ln_2" in ln2_w_key + else ln2_w_key.replace("ln2", "ln_2") + ) + if alternate_ln2_w_key != ln2_w_key and alternate_ln2_w_key in state_dict: + state_dict[alternate_ln2_w_key] = torch.ones_like(ln2_w) + return state_dict + mlp_W_in = ProcessWeights.convert_tensor_to_tl_format( mlp_W_in_key, state_dict, state_dict.get(mlp_W_in_key), cfg, adapter, layer ) @@ -884,52 +908,56 @@ def center_writing_weights( Returns: Dict[str, torch.Tensor]: Modified state dict with centered writing weights. """ - # Make a deep copy to avoid modifying the original - embed_W_E_key = ProcessWeights._get_param_key("embed.W_E", adapter) - try: - pos_embed_W_pos_key = ( - ProcessWeights._get_param_key("pos_embed.W_pos", adapter) - if getattr(cfg, "positional_embedding_type", "standard") != "rotary" - else None - ) - except ValueError: - pos_embed_W_pos_key = None - if embed_W_E_key not in state_dict: - raise KeyError( - f"Expected embedding key '{embed_W_E_key}' not found in state_dict. Available keys: {list(state_dict.keys())[:10]}..." - ) - embed_W_E = ProcessWeights.convert_tensor_to_tl_format( - embed_W_E_key, state_dict, state_dict.get(embed_W_E_key), cfg, adapter, None - ) - assert embed_W_E is not None, f"Embedding not found at key {embed_W_E_key}" - embed_W_E = embed_W_E - embed_W_E.mean(-1, keepdim=True) - state_dict[embed_W_E_key] = ProcessWeights.convert_tensor_to_hf_format( - embed_W_E_key, embed_W_E, cfg, adapter, None - ) - - if ( - getattr(cfg, "positional_embedding_type", "standard") != "rotary" - and pos_embed_W_pos_key is not None - ): - if pos_embed_W_pos_key not in state_dict: + # Skip centering for Olmo2 models - input of attn of 1st layer is not normed + if getattr(cfg, "original_architecture", None) == "Olmo2ForCausalLM": + print("Not centering embedding weights for Olmo2ForCausalLM") + else: + # Make a deep copy to avoid modifying the original + embed_W_E_key = ProcessWeights._get_param_key("embed.W_E", adapter) + try: + pos_embed_W_pos_key = ( + ProcessWeights._get_param_key("pos_embed.W_pos", adapter) + if getattr(cfg, "positional_embedding_type", "standard") != "rotary" + else None + ) + except ValueError: + pos_embed_W_pos_key = None + if embed_W_E_key not in state_dict: raise KeyError( - f"Expected positional embedding key '{pos_embed_W_pos_key}' not found in state_dict. Available keys: {list(state_dict.keys())[:10]}..." + f"Expected embedding key '{embed_W_E_key}' not found in state_dict. Available keys: {list(state_dict.keys())[:10]}..." ) - pos_embed_W_pos = ProcessWeights.convert_tensor_to_tl_format( - pos_embed_W_pos_key, - state_dict, - state_dict.get(pos_embed_W_pos_key), - cfg, - adapter, - None, + embed_W_E = ProcessWeights.convert_tensor_to_tl_format( + embed_W_E_key, state_dict, state_dict.get(embed_W_E_key), cfg, adapter, None ) - assert ( - pos_embed_W_pos is not None - ), f"Positional embedding not found at key {pos_embed_W_pos_key}" - pos_embed_W_pos = pos_embed_W_pos - pos_embed_W_pos.mean(-1, keepdim=True) - state_dict[pos_embed_W_pos_key] = ProcessWeights.convert_tensor_to_hf_format( - pos_embed_W_pos_key, pos_embed_W_pos, cfg, adapter, None + assert embed_W_E is not None, f"Embedding not found at key {embed_W_E_key}" + embed_W_E = embed_W_E - embed_W_E.mean(-1, keepdim=True) + state_dict[embed_W_E_key] = ProcessWeights.convert_tensor_to_hf_format( + embed_W_E_key, embed_W_E, cfg, adapter, None ) + + if ( + getattr(cfg, "positional_embedding_type", "standard") != "rotary" + and pos_embed_W_pos_key is not None + ): + if pos_embed_W_pos_key not in state_dict: + raise KeyError( + f"Expected positional embedding key '{pos_embed_W_pos_key}' not found in state_dict. Available keys: {list(state_dict.keys())[:10]}..." + ) + pos_embed_W_pos = ProcessWeights.convert_tensor_to_tl_format( + pos_embed_W_pos_key, + state_dict, + state_dict.get(pos_embed_W_pos_key), + cfg, + adapter, + None, + ) + assert ( + pos_embed_W_pos is not None + ), f"Positional embedding not found at key {pos_embed_W_pos_key}" + pos_embed_W_pos = pos_embed_W_pos - pos_embed_W_pos.mean(-1, keepdim=True) + state_dict[pos_embed_W_pos_key] = ProcessWeights.convert_tensor_to_hf_format( + pos_embed_W_pos_key, pos_embed_W_pos, cfg, adapter, None + ) for l in range(cfg.n_layers): attn_W_O_key = ProcessWeights._get_param_key(f"blocks.{l}.attn.W_O", adapter) attn_b_O_key = ProcessWeights._get_param_key(f"blocks.{l}.attn.b_O", adapter) @@ -1282,9 +1310,7 @@ def process_weights( Dict[str, torch.Tensor]: Fully processed state dict. """ if fold_ln: - if getattr(cfg, "num_experts", None) and cfg.num_experts > 1: - pass - elif getattr(cfg, "normalization_type", "LN") in ["LN", "LNPre"]: + if getattr(cfg, "normalization_type", "LN") in ["LN", "LNPre"]: state_dict = ProcessWeights.fold_layer_norm( state_dict, cfg, fold_biases=True, center_weights=True, adapter=adapter )