From 5f5e616bec1c455d82a9b6597b2582b2df09aad4 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Wed, 4 Mar 2026 18:36:36 +0100 Subject: [PATCH 1/8] feat: implemented Llama3-like initialization for GPT2 models --- .../models/gpt2/llama3_like_initialization.py | 106 ++++++++++++++++++ src/modalities/registry/components.py | 7 ++ 2 files changed, 113 insertions(+) create mode 100644 src/modalities/models/gpt2/llama3_like_initialization.py diff --git a/src/modalities/models/gpt2/llama3_like_initialization.py b/src/modalities/models/gpt2/llama3_like_initialization.py new file mode 100644 index 000000000..71127ddfe --- /dev/null +++ b/src/modalities/models/gpt2/llama3_like_initialization.py @@ -0,0 +1,106 @@ +import math +import re +from functools import partial +from typing import Annotated + +import torch.nn as nn +from pydantic import BaseModel, Field + +from modalities.nn.model_initialization.initialization_if import ModelInitializationIF +from modalities.utils.logger_utils import get_logger + +logger = get_logger(name="llama3 initialization") + + +class Llama3InitializerConfig(BaseModel): + num_layers: Annotated[int, Field(strict=True, gt=0)] + n_embd: Annotated[int, Field(strict=True, gt=0)] + + +class Llama3Initializer(ModelInitializationIF): + """ + Follows weight initialization distributions and parameterization for Llama3 as described in TorchTitan. 
+ """ + + def __init__(self, num_layers: int, n_embd: int) -> None: + super().__init__() + + self.regex_to_init = { + # embedding weights + r"transformer\.wte\.weight": partial(nn.init.normal_, mean=0.0, std=1), + r"transformer\.wpe\.weight": partial(nn.init.normal_, mean=0.0, std=1), + # lm head weights + r"transformer\.lm_head\.weight": partial( + nn.init.trunc_normal_, + mean=0.0, + std=1 / math.sqrt(n_embd), + a=-3 / math.sqrt(n_embd), + b=3 / math.sqrt(n_embd), + ), + # qkv projections + r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.weight": partial( + nn.init.trunc_normal_, + mean=0.0, + std=0.02, + a=-2, + b=2, + ), + r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.bias": partial( + nn.init.trunc_normal_, + mean=0.0, + std=0.02, + a=-2, + b=2, + ), + # final attention projection in attention block + r"transformer\.h\.\d+\.attn\.c_proj\.weight": partial( + nn.init.trunc_normal_, + mean=0.0, + std=0.02 / math.sqrt(2 * num_layers), + a=-2, + b=2, + ), + r"transformer\.h\.\d+\.attn\.c_proj\.bias": partial( + nn.init.trunc_normal_, + mean=0.0, + std=0.02 / math.sqrt(2 * num_layers), + a=-2, + b=2, + ), + # SwiGLU + r"transformer\.h\.\w+\.mlp\.(W)\.weight": partial( + nn.init.trunc_normal_, + mean=0.0, + std=0.02, + a=-2, + b=2, + ), + r"transformer\.h\.\w+\.mlp\.(W)\.bias": nn.init.zeros_, + r"transformer\.h\.\w+\.mlp\.(V|W_2)\.weight": partial( + nn.init.trunc_normal_, + mean=0.0, + std=0.02 / math.sqrt(2 * num_layers), + a=-2, + b=2, + ), + r"transformer\.h\.\w+\.mlp\.(V|W_2)\.bias": nn.init.zeros_, + } + + def initialize_in_place(self, model: nn.Module): + self._init_by_fqn_regex(model, self.regex_to_init) + + @staticmethod + def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, partial]): + for parameter_name, p in model.named_parameters(): + match_count = 0 + for weight_regex in regex_to_init.keys(): + if re.fullmatch(weight_regex, parameter_name): + init_fn = regex_to_init[weight_regex] + init_fn(p) + match_count += 1 + if 
match_count == 0: + logger.warning(f"Parameter {parameter_name} did not match any regex for initialization") + elif match_count > 1: + raise ValueError( + f"Parameter {parameter_name} matched multiple regexes for initialization, which is not allowed" + ) diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 551e521e1..4c46bf1e8 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -92,6 +92,7 @@ ) from modalities.models.gpt2.collator import GPT2LLMCollateFn from modalities.models.gpt2.gpt2_model import GPT2LLMConfig +from modalities.models.gpt2.llama3_like_initialization import Llama3Initializer, Llama3InitializerConfig from modalities.models.huggingface.huggingface_model import HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig from modalities.models.model_factory import GPT2ModelFactory, ModelFactory from modalities.models.parallelism.pipeline_parallelism import ComponentSelectorFromPipeline, PipelineFactory @@ -240,6 +241,12 @@ class ComponentEntity: ComposedInitializationRoutines.get_composed_model_initializer, ComposedModelInitializationConfig, ), + ComponentEntity( + "model_initialization", + "llama3_like", + Llama3Initializer, + Llama3InitializerConfig, + ), # losses ComponentEntity("loss", "clm_cross_entropy_loss", CLMCrossEntropyLoss, CLMCrossEntropyLossConfig), # optimizers From 4dea4965acdb6319c84f1cc4380a527459a86b69 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Wed, 4 Mar 2026 18:37:04 +0100 Subject: [PATCH 2/8] feat: implemented llama3 weight init tests --- ...arm_start_from_step_4_fsdp2_grad_accu.yaml | 4 +- tests/test_initialization_fsdpx.py | 95 ++++++++++++++++++- 2 files changed, 95 insertions(+), 4 deletions(-) diff --git a/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml b/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml index 8815604fa..8af01a926 
100644 --- a/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml +++ b/tests/end2end_tests/configs/gpt2_warm_start_from_step_4_fsdp2_grad_accu.yaml @@ -177,7 +177,7 @@ app_state_raw: component_key: app_state variant_key: raw config: - model: + model: instance_key: initialized_model pass_type: BY_REFERENCE optimizer: @@ -288,7 +288,7 @@ optimizer: eps: 1e-8 weight_decay: 1e-1 weight_decay_groups_excluded: [embedding, layernorm] - wrapped_model: + wrapped_model: instance_key: initialized_model pass_type: BY_REFERENCE diff --git a/tests/test_initialization_fsdpx.py b/tests/test_initialization_fsdpx.py index 78e4db158..4ab0021c7 100644 --- a/tests/test_initialization_fsdpx.py +++ b/tests/test_initialization_fsdpx.py @@ -18,8 +18,16 @@ from torch.distributed.fsdp import StateDictType from modalities.__main__ import Main -from modalities.config.config import ProcessGroupBackendType -from modalities.config.pydantic_if_types import PydanticFSDP1ModuleType, PydanticFSDP2ModuleType +from modalities.config.component_factory import ComponentFactory +from modalities.config.config import ProcessGroupBackendType, load_app_config_dict +from modalities.config.pydantic_if_types import ( + PydanticFSDP1ModuleType, + PydanticFSDP2ModuleType, + PydanticPytorchModuleType, +) +from modalities.models.gpt2.gpt2_model import GPT2LLM, GPT2Block +from modalities.registry.components import COMPONENTS +from modalities.registry.registry import Registry from tests.end2end_tests.custom_components import MultiProcessingCudaEnv @@ -493,3 +501,86 @@ def _get_fdsp2_state_dict(model: FSDP2) -> dict[str, Any]: model=model, optimizers=[], options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True) )[0] return model_state + + +class TestLlama3LikeInitialization: + @pytest.mark.parametrize("has_bias", [True, False]) + def test_llama3_like_initialization(self, has_bias: bool): + config_file_path = Path(__file__).parent / 
"test_yaml_configs/llama3_config_initalization.yaml" + n_layer = 4 + model = self._get_components(config_file_path=config_file_path, has_bias=has_bias) + self._test_wte(model=model) + self._test_lm_head(model=model) + + for block in model.transformer.h: + self._test_qkv_proj(gpt2_block=block, has_bias=has_bias) + self._test_c_proj(gpt2_block=block, has_bias=has_bias, n_layer=n_layer) + self._test_swiglu_proj(gpt2_block=block, has_bias=has_bias, n_layer=n_layer) + + def _get_components(self, config_file_path: Path, has_bias: bool) -> GPT2LLM: + config_dict = load_app_config_dict( + config_file_path=config_file_path, + ) + config_dict["model_raw"]["config"]["bias"] = has_bias + registry = Registry(COMPONENTS) + component_factory = ComponentFactory(registry=registry) + + class ComponentsInstantiationModel(BaseModel): + initialized_model: PydanticPytorchModuleType + + components: ComponentsInstantiationModel = component_factory.build_components( + config_dict=config_dict, components_model_type=ComponentsInstantiationModel + ) + return components.initialized_model + + def _test_wte(self, model: GPT2LLM): + assert model.transformer.wte.weight.std().detach().cpu() == pytest.approx(1, abs=1e-3) + assert model.transformer.wte.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3) + + def _test_lm_head(self, model: GPT2LLM, n_emb: int): + assert model.transformer.lm_head.weight.std().detach().cpu() == pytest.approx(1 / math.sqrt(n_emb), abs=1e-3) + assert model.transformer.lm_head.weight.max().detach().cpu() <= 3 / math.sqrt(n_emb) + assert model.transformer.lm_head.weight.min().detach().cpu() >= -3 / math.sqrt(n_emb) + assert model.transformer.lm_head.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3) + + def _test_qkv_proj(self, gpt2_block: GPT2Block, has_bias: bool): + layers = (gpt2_block.attn.q_attn, gpt2_block.attn.k_attn, gpt2_block.attn.v_attn) + for layer in layers: + assert layer.weight.std().detach().cpu() == pytest.approx(0.02, abs=1e-3) + assert 
layer.weight.max().detach().cpu() <= 2 + assert layer.weight.min().detach().cpu() >= -2 + if has_bias: + assert layer.bias is not None + assert layer.bias.std().detach().cpu() == pytest.approx(0.02, abs=1e-3) + assert layer.bias.max().detach().cpu() <= 2 + assert layer.bias.min().detach().cpu() >= -2 + + def _test_c_proj(self, gpt2_block: GPT2Block, has_bias: bool, n_layer: int): + layer = gpt2_block.attn.c_proj + assert layer.weight.std().detach().cpu() == pytest.approx(0.02 / math.sqrt(2 * n_layer), abs=1e-3) + assert layer.weight.max().detach().cpu() <= 2 + assert layer.weight.min().detach().cpu() >= -2 + + if has_bias: + assert layer.bias is not None + assert layer.bias.std().detach().cpu() == pytest.approx(0.02 / math.sqrt(2 * n_layer), abs=1e-3) + assert layer.bias.max().detach().cpu() <= 2 + assert layer.bias.min().detach().cpu() >= -2 + + def _test_swiglu_proj(self, gpt2_block: GPT2Block, has_bias: bool, n_layer: int): + layers = (gpt2_block.mlp.V, gpt2_block.mlp.W_2) + for layer in layers: + assert layer.weight.std().detach().cpu() == pytest.approx(0.02 / math.sqrt(2 * n_layer), abs=1e-3) + assert layer.weight.max().detach().cpu() <= 2 + assert layer.weight.min().detach().cpu() >= -2 + + if has_bias: + # all zero bias + assert layer.bias is not None and torch.all(layer.bias == 0) + + layer = gpt2_block.mlp.W + assert layer.weight.std().detach().cpu() == pytest.approx(0.02, abs=1e-3) + assert layer.weight.max().detach().cpu() <= 2 + assert layer.weight.min().detach().cpu() >= -2 + if has_bias: + assert layer.bias is not None and torch.all(layer.bias == 0) From b7043313d801da9f96f34b1c778eea7e468a3afa Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Wed, 4 Mar 2026 18:42:55 +0100 Subject: [PATCH 3/8] feat: added Llama3-like initialization test config --- .../llama3_config_initalization.yaml | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 
tests/test_yaml_configs/llama3_config_initalization.yaml diff --git a/tests/test_yaml_configs/llama3_config_initalization.yaml b/tests/test_yaml_configs/llama3_config_initalization.yaml new file mode 100644 index 000000000..a658c64f0 --- /dev/null +++ b/tests/test_yaml_configs/llama3_config_initalization.yaml @@ -0,0 +1,59 @@ +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: model_raw + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: llama3_like + config: + num_layers: ${model_raw.config.n_layer} + n_embd: ${model_raw.config.n_embd} + + +model_raw: + component_key: model + variant_key: gpt2 + config: + use_meta_device: true + use_weight_tying: false + sample_key: "input_ids" + poe_type: NOPE + sequence_length: 128 + prediction_key: "logits" + vocab_size: 2048 # 2K vocab for testing + n_layer: 4 # 4 layers for testing + n_head_q: 32 + n_head_kv: 8 + ffn_hidden: 128 # 128 ffn hidden dim for testing + n_embd: 256 # 256 embedding dim for testing + dropout: 0.0 + bias: true + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} + seq_length_dim: -2 + base_freq: 500000 + attention_implementation: pytorch_flash + activation_type: swiglu + attention_norm_config: + norm_type: pytorch_rms_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1.0e-05 + ffn_norm_config: + norm_type: pytorch_rms_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1.0e-05 + lm_head_norm_config: + norm_type: pytorch_rms_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1.0e-05 + From 34a8621144655855851dd055e72f78334f9fa665 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Wed, 4 Mar 2026 19:48:44 +0100 Subject: [PATCH 4/8] chore: improved test coverage for llama-like weight init --- 
tests/test_initialization_fsdpx.py | 38 +++++++++++++------ .../llama3_config_initalization.yaml | 1 + 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/tests/test_initialization_fsdpx.py b/tests/test_initialization_fsdpx.py index 4ab0021c7..9fcbc15f6 100644 --- a/tests/test_initialization_fsdpx.py +++ b/tests/test_initialization_fsdpx.py @@ -508,11 +508,12 @@ class TestLlama3LikeInitialization: def test_llama3_like_initialization(self, has_bias: bool): config_file_path = Path(__file__).parent / "test_yaml_configs/llama3_config_initalization.yaml" n_layer = 4 + n_embd = 256 model = self._get_components(config_file_path=config_file_path, has_bias=has_bias) self._test_wte(model=model) - self._test_lm_head(model=model) + self._test_lm_head(model=model, n_embd=n_embd) - for block in model.transformer.h: + for _, block in model.transformer["h"].items(): self._test_qkv_proj(gpt2_block=block, has_bias=has_bias) self._test_c_proj(gpt2_block=block, has_bias=has_bias, n_layer=n_layer) self._test_swiglu_proj(gpt2_block=block, has_bias=has_bias, n_layer=n_layer) @@ -522,6 +523,7 @@ def _get_components(self, config_file_path: Path, has_bias: bool) -> GPT2LLM: config_file_path=config_file_path, ) config_dict["model_raw"]["config"]["bias"] = has_bias + config_dict["initialized_model"]["config"]["model_initializer"]["config"]["bias"] = has_bias registry = Registry(COMPONENTS) component_factory = ComponentFactory(registry=registry) @@ -534,38 +536,45 @@ class ComponentsInstantiationModel(BaseModel): return components.initialized_model def _test_wte(self, model: GPT2LLM): - assert model.transformer.wte.weight.std().detach().cpu() == pytest.approx(1, abs=1e-3) - assert model.transformer.wte.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3) + assert model.transformer.wte.weight.std().detach().cpu() == pytest.approx(1, abs=1e-2) + assert model.transformer.wte.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-2) - def _test_lm_head(self, model: GPT2LLM, 
n_emb: int): - assert model.transformer.lm_head.weight.std().detach().cpu() == pytest.approx(1 / math.sqrt(n_emb), abs=1e-3) - assert model.transformer.lm_head.weight.max().detach().cpu() <= 3 / math.sqrt(n_emb) - assert model.transformer.lm_head.weight.min().detach().cpu() >= -3 / math.sqrt(n_emb) + def _test_lm_head(self, model: GPT2LLM, n_embd: int): + assert model.transformer.lm_head.weight.std().detach().cpu() == pytest.approx(1 / math.sqrt(n_embd), abs=1e-3) + assert model.transformer.lm_head.weight.max().detach().cpu() <= 3 / math.sqrt(n_embd) + assert model.transformer.lm_head.weight.min().detach().cpu() >= -3 / math.sqrt(n_embd) assert model.transformer.lm_head.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3) def _test_qkv_proj(self, gpt2_block: GPT2Block, has_bias: bool): layers = (gpt2_block.attn.q_attn, gpt2_block.attn.k_attn, gpt2_block.attn.v_attn) for layer in layers: - assert layer.weight.std().detach().cpu() == pytest.approx(0.02, abs=1e-3) + assert layer.weight.std().detach().cpu() == pytest.approx(0.02, abs=1e-2) assert layer.weight.max().detach().cpu() <= 2 assert layer.weight.min().detach().cpu() >= -2 + assert layer.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3) + if has_bias: assert layer.bias is not None - assert layer.bias.std().detach().cpu() == pytest.approx(0.02, abs=1e-3) + assert layer.bias.std().detach().cpu() == pytest.approx(0.02, abs=1e-2) assert layer.bias.max().detach().cpu() <= 2 assert layer.bias.min().detach().cpu() >= -2 + else: + assert layer.bias is None def _test_c_proj(self, gpt2_block: GPT2Block, has_bias: bool, n_layer: int): layer = gpt2_block.attn.c_proj - assert layer.weight.std().detach().cpu() == pytest.approx(0.02 / math.sqrt(2 * n_layer), abs=1e-3) + assert layer.weight.std().detach().cpu() == pytest.approx(0.02 / math.sqrt(2 * n_layer), abs=1e-2) assert layer.weight.max().detach().cpu() <= 2 assert layer.weight.min().detach().cpu() >= -2 + assert layer.weight.mean().detach().cpu() == 
pytest.approx(0, abs=1e-3) if has_bias: assert layer.bias is not None assert layer.bias.std().detach().cpu() == pytest.approx(0.02 / math.sqrt(2 * n_layer), abs=1e-3) assert layer.bias.max().detach().cpu() <= 2 assert layer.bias.min().detach().cpu() >= -2 + else: + assert layer.bias is None def _test_swiglu_proj(self, gpt2_block: GPT2Block, has_bias: bool, n_layer: int): layers = (gpt2_block.mlp.V, gpt2_block.mlp.W_2) @@ -573,14 +582,21 @@ def _test_swiglu_proj(self, gpt2_block: GPT2Block, has_bias: bool, n_layer: int) assert layer.weight.std().detach().cpu() == pytest.approx(0.02 / math.sqrt(2 * n_layer), abs=1e-3) assert layer.weight.max().detach().cpu() <= 2 assert layer.weight.min().detach().cpu() >= -2 + assert layer.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3) if has_bias: # all zero bias assert layer.bias is not None and torch.all(layer.bias == 0) + else: + assert layer.bias is None layer = gpt2_block.mlp.W assert layer.weight.std().detach().cpu() == pytest.approx(0.02, abs=1e-3) assert layer.weight.max().detach().cpu() <= 2 assert layer.weight.min().detach().cpu() >= -2 + assert layer.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3) + if has_bias: assert layer.bias is not None and torch.all(layer.bias == 0) + else: + assert layer.bias is None diff --git a/tests/test_yaml_configs/llama3_config_initalization.yaml b/tests/test_yaml_configs/llama3_config_initalization.yaml index a658c64f0..89ecae2cb 100644 --- a/tests/test_yaml_configs/llama3_config_initalization.yaml +++ b/tests/test_yaml_configs/llama3_config_initalization.yaml @@ -11,6 +11,7 @@ initialized_model: config: num_layers: ${model_raw.config.n_layer} n_embd: ${model_raw.config.n_embd} + bias: ${model_raw.config.bias} model_raw: From c7bcaaa42b24d6e7596d1b52d289417090b299cb Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Wed, 4 Mar 2026 19:50:03 +0100 Subject: [PATCH 5/8] refactor: we only init bias if set to true in config. 
added consistency check --- .../models/gpt2/llama3_like_initialization.py | 51 ++++++++++++------- 1 file changed, 33 insertions(+), 18 deletions(-) diff --git a/src/modalities/models/gpt2/llama3_like_initialization.py b/src/modalities/models/gpt2/llama3_like_initialization.py index 71127ddfe..d0b7888e6 100644 --- a/src/modalities/models/gpt2/llama3_like_initialization.py +++ b/src/modalities/models/gpt2/llama3_like_initialization.py @@ -15,6 +15,7 @@ class Llama3InitializerConfig(BaseModel): num_layers: Annotated[int, Field(strict=True, gt=0)] n_embd: Annotated[int, Field(strict=True, gt=0)] + bias: bool class Llama3Initializer(ModelInitializationIF): @@ -22,13 +23,12 @@ class Llama3Initializer(ModelInitializationIF): Follows weight initialization distributions and parameterization for Llama3 as described in TorchTitan. """ - def __init__(self, num_layers: int, n_embd: int) -> None: + def __init__(self, num_layers: int, n_embd: int, bias: bool) -> None: super().__init__() self.regex_to_init = { # embedding weights r"transformer\.wte\.weight": partial(nn.init.normal_, mean=0.0, std=1), - r"transformer\.wpe\.weight": partial(nn.init.normal_, mean=0.0, std=1), # lm head weights r"transformer\.lm_head\.weight": partial( nn.init.trunc_normal_, @@ -45,13 +45,6 @@ def __init__(self, num_layers: int, n_embd: int) -> None: a=-2, b=2, ), - r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.bias": partial( - nn.init.trunc_normal_, - mean=0.0, - std=0.02, - a=-2, - b=2, - ), # final attention projection in attention block r"transformer\.h\.\d+\.attn\.c_proj\.weight": partial( nn.init.trunc_normal_, @@ -60,13 +53,6 @@ def __init__(self, num_layers: int, n_embd: int) -> None: a=-2, b=2, ), - r"transformer\.h\.\d+\.attn\.c_proj\.bias": partial( - nn.init.trunc_normal_, - mean=0.0, - std=0.02 / math.sqrt(2 * num_layers), - a=-2, - b=2, - ), # SwiGLU r"transformer\.h\.\w+\.mlp\.(W)\.weight": partial( nn.init.trunc_normal_, @@ -75,7 +61,6 @@ def __init__(self, num_layers: int, 
n_embd: int) -> None: a=-2, b=2, ), - r"transformer\.h\.\w+\.mlp\.(W)\.bias": nn.init.zeros_, r"transformer\.h\.\w+\.mlp\.(V|W_2)\.weight": partial( nn.init.trunc_normal_, mean=0.0, @@ -83,14 +68,37 @@ def __init__(self, num_layers: int, n_embd: int) -> None: a=-2, b=2, ), - r"transformer\.h\.\w+\.mlp\.(V|W_2)\.bias": nn.init.zeros_, } + if bias: + self.regex_to_init = { + **self.regex_to_init, + **{ + r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.bias": partial( + nn.init.trunc_normal_, + mean=0.0, + std=0.02, + a=-2, + b=2, + ), + r"transformer\.h\.\d+\.attn\.c_proj\.bias": partial( + nn.init.trunc_normal_, + mean=0.0, + std=0.02 / math.sqrt(2 * num_layers), + a=-2, + b=2, + ), + r"transformer\.h\.\w+\.mlp\.(W)\.bias": nn.init.zeros_, + r"transformer\.h\.\w+\.mlp\.(V|W_2)\.bias": nn.init.zeros_, + }, + } def initialize_in_place(self, model: nn.Module): self._init_by_fqn_regex(model, self.regex_to_init) @staticmethod def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, partial]): + hits = {k: 0 for k in regex_to_init.keys()} + for parameter_name, p in model.named_parameters(): match_count = 0 for weight_regex in regex_to_init.keys(): @@ -98,9 +106,16 @@ def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, partial]): init_fn = regex_to_init[weight_regex] init_fn(p) match_count += 1 + hits[weight_regex] += 1 if match_count == 0: logger.warning(f"Parameter {parameter_name} did not match any regex for initialization") elif match_count > 1: raise ValueError( f"Parameter {parameter_name} matched multiple regexes for initialization, which is not allowed" ) + + for k, count in hits.items(): + if count == 0: + raise ValueError( + f"Regex {k} did not match any FQNs. The model specification probably does not match LLama3." 
+ ) From 2a171aaa257f8a6176667b6f659308e451d80817 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Thu, 5 Mar 2026 22:41:29 +0100 Subject: [PATCH 6/8] refactor: changed the initialization to allow for depth_init --- .../models/gpt2/llama3_like_initialization.py | 131 ++++++++++-------- src/modalities/registry/components.py | 2 +- 2 files changed, 76 insertions(+), 57 deletions(-) diff --git a/src/modalities/models/gpt2/llama3_like_initialization.py b/src/modalities/models/gpt2/llama3_like_initialization.py index d0b7888e6..b85d81766 100644 --- a/src/modalities/models/gpt2/llama3_like_initialization.py +++ b/src/modalities/models/gpt2/llama3_like_initialization.py @@ -1,7 +1,6 @@ import math import re -from functools import partial -from typing import Annotated +from typing import Annotated, Callable import torch.nn as nn from pydantic import BaseModel, Field @@ -15,7 +14,7 @@ class Llama3InitializerConfig(BaseModel): num_layers: Annotated[int, Field(strict=True, gt=0)] n_embd: Annotated[int, Field(strict=True, gt=0)] - bias: bool + depth_init: bool = True class Llama3Initializer(ModelInitializationIF): @@ -23,90 +22,110 @@ class Llama3Initializer(ModelInitializationIF): Follows weight initialization distributions and parameterization for Llama3 as described in TorchTitan. 
""" - def __init__(self, num_layers: int, n_embd: int, bias: bool) -> None: + def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None: super().__init__() + self.depth_init = depth_init self.regex_to_init = { # embedding weights - r"transformer\.wte\.weight": partial(nn.init.normal_, mean=0.0, std=1), + r"transformer\.wte\.weight": (nn.init.normal_, {"mean": 0.0, "std": 1}), # lm head weights - r"transformer\.lm_head\.weight": partial( + r"transformer\.lm_head\.weight": ( nn.init.trunc_normal_, - mean=0.0, - std=1 / math.sqrt(n_embd), - a=-3 / math.sqrt(n_embd), - b=3 / math.sqrt(n_embd), + { + "mean": 0.0, + "std": 1 / math.sqrt(n_embd), + "a": -3 / math.sqrt(n_embd), + "b": 3 / math.sqrt(n_embd), + }, ), # qkv projections - r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.weight": partial( + r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.weight": ( nn.init.trunc_normal_, - mean=0.0, - std=0.02, - a=-2, - b=2, + { + "mean": 0.0, + "std": 0.02, + "a": -2, + "b": 2, + }, ), # final attention projection in attention block - r"transformer\.h\.\d+\.attn\.c_proj\.weight": partial( + r"transformer\.h\.\d+\.attn\.c_proj\.weight": ( nn.init.trunc_normal_, - mean=0.0, - std=0.02 / math.sqrt(2 * num_layers), - a=-2, - b=2, + { + "mean": 0.0, + "std": ( + (lambda layer_id: 0.02 / math.sqrt(2 * (layer_id + 1))) + if depth_init + else 0.02 / math.sqrt(2 * num_layers) + ), + "a": -2, + "b": 2, + }, ), # SwiGLU - r"transformer\.h\.\w+\.mlp\.(W)\.weight": partial( + r"transformer\.h\.\d+\.mlp\.(W)\.weight": ( nn.init.trunc_normal_, - mean=0.0, - std=0.02, - a=-2, - b=2, + { + "mean": 0.0, + "std": 0.02, + "a": -2, + "b": 2, + }, ), - r"transformer\.h\.\w+\.mlp\.(V|W_2)\.weight": partial( + r"transformer\.h\.\d+\.mlp\.(V|W_2)\.weight": ( nn.init.trunc_normal_, - mean=0.0, - std=0.02 / math.sqrt(2 * num_layers), - a=-2, - b=2, - ), - } - if bias: - self.regex_to_init = { - **self.regex_to_init, - **{ - 
r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.bias": partial( - nn.init.trunc_normal_, - mean=0.0, - std=0.02, - a=-2, - b=2, + { + "mean": 0.0, + "std": ( + (lambda layer_id: 0.02 / math.sqrt(2 * (layer_id + 1))) + if depth_init + else 0.02 / math.sqrt(2 * num_layers) ), - r"transformer\.h\.\d+\.attn\.c_proj\.bias": partial( - nn.init.trunc_normal_, - mean=0.0, - std=0.02 / math.sqrt(2 * num_layers), - a=-2, - b=2, - ), - r"transformer\.h\.\w+\.mlp\.(W)\.bias": nn.init.zeros_, - r"transformer\.h\.\w+\.mlp\.(V|W_2)\.bias": nn.init.zeros_, + "a": -2, + "b": 2, }, - } + ), + } def initialize_in_place(self, model: nn.Module): - self._init_by_fqn_regex(model, self.regex_to_init) + self._init_by_fqn_regex(model, self.regex_to_init, depth_init=self.depth_init) @staticmethod - def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, partial]): + def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, tuple[Callable, dict]], depth_init: bool): hits = {k: 0 for k in regex_to_init.keys()} for parameter_name, p in model.named_parameters(): + if parameter_name.endswith("bias"): + raise ValueError( + f"Bias initialization is not allowed for Llama3Initializer. 
Found bias parameter: {parameter_name}" + ) match_count = 0 for weight_regex in regex_to_init.keys(): if re.fullmatch(weight_regex, parameter_name): - init_fn = regex_to_init[weight_regex] - init_fn(p) + init_fn, arg_dict = regex_to_init[weight_regex] + if arg_dict["std"] is not None and callable(arg_dict["std"]): + if not depth_init: + raise ValueError( + "Dynamic std calculation is only allowed if depth_init " + f"is True, but got depth_init={depth_init}" + ) + + # If std is a function, call it with the layer_id + layer_id_match = re.search(r"transformer\.h\.(\d+)\.", parameter_name) + if layer_id_match is not None: + layer_id = int(layer_id_match.group(1)) + arg_dict = arg_dict.copy() # create a copy of the arg_dict to avoid mutating the original + arg_dict["std"] = arg_dict["std"](layer_id) + else: + raise ValueError( + f"Could not extract layer_id from parameter name {parameter_name} " + "for dynamic std calculation" + ) + init_fn(p, **arg_dict) match_count += 1 hits[weight_regex] += 1 + if match_count == 0: logger.warning(f"Parameter {parameter_name} did not match any regex for initialization") elif match_count > 1: diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 4c46bf1e8..67f100f0f 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -243,7 +243,7 @@ class ComponentEntity: ), ComponentEntity( "model_initialization", - "llama3_like", + "gpt2_llama3_like", Llama3Initializer, Llama3InitializerConfig, ), From 43a9d50f38a7e706a114c773f9f69f2e5bcc2d6e Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Thu, 5 Mar 2026 22:42:02 +0100 Subject: [PATCH 7/8] refactor: removed bias from llama3 init test and added depth_init tests --- tests/test_initialization_fsdpx.py | 67 +++++++------------ .../llama3_config_initalization.yaml | 6 +- 2 files changed, 27 insertions(+), 46 deletions(-) diff --git a/tests/test_initialization_fsdpx.py 
b/tests/test_initialization_fsdpx.py index 9fcbc15f6..f1eb4210a 100644 --- a/tests/test_initialization_fsdpx.py +++ b/tests/test_initialization_fsdpx.py @@ -504,26 +504,25 @@ def _get_fdsp2_state_dict(model: FSDP2) -> dict[str, Any]: class TestLlama3LikeInitialization: - @pytest.mark.parametrize("has_bias", [True, False]) - def test_llama3_like_initialization(self, has_bias: bool): + @pytest.mark.parametrize("depth_init", [True, False]) + def test_llama3_like_initialization(self, depth_init: bool): config_file_path = Path(__file__).parent / "test_yaml_configs/llama3_config_initalization.yaml" n_layer = 4 n_embd = 256 - model = self._get_components(config_file_path=config_file_path, has_bias=has_bias) + model = self._get_components(config_file_path=config_file_path, depth_init=depth_init) self._test_wte(model=model) self._test_lm_head(model=model, n_embd=n_embd) - for _, block in model.transformer["h"].items(): - self._test_qkv_proj(gpt2_block=block, has_bias=has_bias) - self._test_c_proj(gpt2_block=block, has_bias=has_bias, n_layer=n_layer) - self._test_swiglu_proj(gpt2_block=block, has_bias=has_bias, n_layer=n_layer) + for layer_id, (_, block) in enumerate(model.transformer["h"].items()): + self._test_qkv_proj(gpt2_block=block) + self._test_c_proj(gpt2_block=block, depth_init=depth_init, n_layer=n_layer, layer_id=layer_id) + self._test_swiglu_proj(gpt2_block=block, depth_init=depth_init, n_layer=n_layer, layer_id=layer_id) - def _get_components(self, config_file_path: Path, has_bias: bool) -> GPT2LLM: + def _get_components(self, config_file_path: Path, depth_init: bool) -> GPT2LLM: config_dict = load_app_config_dict( config_file_path=config_file_path, ) - config_dict["model_raw"]["config"]["bias"] = has_bias - config_dict["initialized_model"]["config"]["model_initializer"]["config"]["bias"] = has_bias + config_dict["initialized_model"]["config"]["model_initializer"]["config"]["depth_init"] = depth_init registry = Registry(COMPONENTS) component_factory = 
ComponentFactory(registry=registry) @@ -545,58 +544,40 @@ def _test_lm_head(self, model: GPT2LLM, n_embd: int): assert model.transformer.lm_head.weight.min().detach().cpu() >= -3 / math.sqrt(n_embd) assert model.transformer.lm_head.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3) - def _test_qkv_proj(self, gpt2_block: GPT2Block, has_bias: bool): + def _test_qkv_proj(self, gpt2_block: GPT2Block): layers = (gpt2_block.attn.q_attn, gpt2_block.attn.k_attn, gpt2_block.attn.v_attn) for layer in layers: - assert layer.weight.std().detach().cpu() == pytest.approx(0.02, abs=1e-2) + assert layer.weight.std().detach().cpu() == pytest.approx(0.02, abs=1e-3) assert layer.weight.max().detach().cpu() <= 2 assert layer.weight.min().detach().cpu() >= -2 assert layer.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3) - if has_bias: - assert layer.bias is not None - assert layer.bias.std().detach().cpu() == pytest.approx(0.02, abs=1e-2) - assert layer.bias.max().detach().cpu() <= 2 - assert layer.bias.min().detach().cpu() >= -2 - else: - assert layer.bias is None - - def _test_c_proj(self, gpt2_block: GPT2Block, has_bias: bool, n_layer: int): + def _test_c_proj(self, gpt2_block: GPT2Block, depth_init: bool, n_layer: int, layer_id: int): layer = gpt2_block.attn.c_proj - assert layer.weight.std().detach().cpu() == pytest.approx(0.02 / math.sqrt(2 * n_layer), abs=1e-2) + if depth_init: + assert layer.weight.std().detach().cpu() == pytest.approx(0.02 / math.sqrt(2 * (layer_id + 1)), abs=1e-3) + else: + assert layer.weight.std().detach().cpu() == pytest.approx(0.02 / math.sqrt(2 * n_layer), abs=1e-3) + assert layer.weight.max().detach().cpu() <= 2 assert layer.weight.min().detach().cpu() >= -2 assert layer.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3) - if has_bias: - assert layer.bias is not None - assert layer.bias.std().detach().cpu() == pytest.approx(0.02 / math.sqrt(2 * n_layer), abs=1e-3) - assert layer.bias.max().detach().cpu() <= 2 - assert 
layer.bias.min().detach().cpu() >= -2 - else: - assert layer.bias is None - - def _test_swiglu_proj(self, gpt2_block: GPT2Block, has_bias: bool, n_layer: int): + def _test_swiglu_proj(self, gpt2_block: GPT2Block, depth_init: bool, n_layer: int, layer_id: int): layers = (gpt2_block.mlp.V, gpt2_block.mlp.W_2) for layer in layers: - assert layer.weight.std().detach().cpu() == pytest.approx(0.02 / math.sqrt(2 * n_layer), abs=1e-3) + if depth_init: + assert layer.weight.std().detach().cpu() == pytest.approx( + 0.02 / math.sqrt(2 * (layer_id + 1)), abs=1e-3 + ) + else: + assert layer.weight.std().detach().cpu() == pytest.approx(0.02 / math.sqrt(2 * n_layer), abs=1e-3) assert layer.weight.max().detach().cpu() <= 2 assert layer.weight.min().detach().cpu() >= -2 assert layer.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3) - if has_bias: - # all zero bias - assert layer.bias is not None and torch.all(layer.bias == 0) - else: - assert layer.bias is None - layer = gpt2_block.mlp.W assert layer.weight.std().detach().cpu() == pytest.approx(0.02, abs=1e-3) assert layer.weight.max().detach().cpu() <= 2 assert layer.weight.min().detach().cpu() >= -2 assert layer.weight.mean().detach().cpu() == pytest.approx(0, abs=1e-3) - - if has_bias: - assert layer.bias is not None and torch.all(layer.bias == 0) - else: - assert layer.bias is None diff --git a/tests/test_yaml_configs/llama3_config_initalization.yaml b/tests/test_yaml_configs/llama3_config_initalization.yaml index 89ecae2cb..42172c6d0 100644 --- a/tests/test_yaml_configs/llama3_config_initalization.yaml +++ b/tests/test_yaml_configs/llama3_config_initalization.yaml @@ -7,11 +7,11 @@ initialized_model: pass_type: BY_REFERENCE model_initializer: component_key: model_initialization - variant_key: llama3_like + variant_key: gpt2_llama3_like config: num_layers: ${model_raw.config.n_layer} n_embd: ${model_raw.config.n_embd} - bias: ${model_raw.config.bias} + depth_init: False model_raw: @@ -31,7 +31,7 @@ model_raw: 
ffn_hidden: 128 # 128 ffn hidden dim for testing n_embd: 256 # 256 embedding dim for testing dropout: 0.0 - bias: true + bias: false attention_config: qkv_transforms: - type_hint: RotaryTransform From 2549e8b0aa8760dfa106d5e0ac136947948a5b1d Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Fri, 6 Mar 2026 10:41:00 +0100 Subject: [PATCH 8/8] chore: removed redundant consistency check --- src/modalities/models/gpt2/llama3_like_initialization.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/modalities/models/gpt2/llama3_like_initialization.py b/src/modalities/models/gpt2/llama3_like_initialization.py index b85d81766..6c0df29d2 100644 --- a/src/modalities/models/gpt2/llama3_like_initialization.py +++ b/src/modalities/models/gpt2/llama3_like_initialization.py @@ -105,12 +105,6 @@ def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, tuple[Callable if re.fullmatch(weight_regex, parameter_name): init_fn, arg_dict = regex_to_init[weight_regex] if arg_dict["std"] is not None and callable(arg_dict["std"]): - if not depth_init: - raise ValueError( - "Dynamic std calculation is only allowed if depth_init " - f"is True, but got depth_init={depth_init}" - ) - # If std is a function, call it with the layer_id layer_id_match = re.search(r"transformer\.h\.(\d+)\.", parameter_name) if layer_id_match is not None: