Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions src/modalities/models/gpt2/llama3_like_initialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import math
import re
from typing import Annotated, Callable

import torch.nn as nn
from pydantic import BaseModel, Field

from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
from modalities.utils.logger_utils import get_logger

# Module-level logger used by Llama3Initializer to report parameters that
# match no initialization regex (see _init_by_fqn_regex).
logger = get_logger(name="llama3 initialization")


class Llama3InitializerConfig(BaseModel):
    """Pydantic config for :class:`Llama3Initializer`; fields mirror its constructor arguments."""

    # Total number of transformer blocks; scales the residual-projection init std
    # when depth_init is False.
    num_layers: Annotated[int, Field(strict=True, gt=0)]
    # Embedding dimension; scales the lm_head init std by 1/sqrt(n_embd).
    n_embd: Annotated[int, Field(strict=True, gt=0)]
    # If True, scale residual-projection std per layer depth (TorchTitan style);
    # otherwise use one global std derived from the total layer count.
    depth_init: bool = True


class Llama3Initializer(ModelInitializationIF):
    """
    Follows weight initialization distributions and parameterization for Llama3 as described in TorchTitan.

    Parameter FQNs from ``model.named_parameters()`` are matched against regular
    expressions; each matching weight is initialized in place with the associated
    init function. Norm weights are intentionally not covered here — per the model
    factory, they are reset via ``module.reset_parameters()`` (RMSNorm resets to ones).
    """

    def __init__(self, num_layers: int, n_embd: int, depth_init: bool) -> None:
        """
        Args:
            num_layers: total number of transformer blocks in the model.
            n_embd: embedding dimension; the lm_head std is scaled by 1/sqrt(n_embd).
            depth_init: if True, the residual projections use a per-layer std of
                0.02 / sqrt(2 * (layer_id + 1)); otherwise a single global std of
                0.02 / sqrt(2 * num_layers).
        """
        super().__init__()
        self.depth_init = depth_init

        # Std for the residual-path projections (attn.c_proj and SwiGLU V / W_2):
        # a per-layer callable when depth_init is set, a constant otherwise.
        # Defined once here so both dict entries below share the same definition.
        def _depth_scaled_std(layer_id: int) -> float:
            return 0.02 / math.sqrt(2 * (layer_id + 1))

        residual_std = _depth_scaled_std if depth_init else 0.02 / math.sqrt(2 * num_layers)
        # lm_head std and truncation bounds both scale with the embedding dimension.
        lm_head_std = 1 / math.sqrt(n_embd)

        self.regex_to_init = {
            # embedding weights
            r"transformer\.wte\.weight": (nn.init.normal_, {"mean": 0.0, "std": 1}),
            # lm head weights, truncated at +/- 3 std
            r"transformer\.lm_head\.weight": (
                nn.init.trunc_normal_,
                {
                    "mean": 0.0,
                    "std": lm_head_std,
                    "a": -3 * lm_head_std,
                    "b": 3 * lm_head_std,
                },
            ),
            # qkv projections
            r"transformer\.h\.\d+\.attn\.(q_attn|k_attn|v_attn)\.weight": (
                nn.init.trunc_normal_,
                {"mean": 0.0, "std": 0.02, "a": -2, "b": 2},
            ),
            # final attention projection in attention block (residual path)
            r"transformer\.h\.\d+\.attn\.c_proj\.weight": (
                nn.init.trunc_normal_,
                {"mean": 0.0, "std": residual_std, "a": -2, "b": 2},
            ),
            # SwiGLU: the first projection W keeps the base std ...
            r"transformer\.h\.\d+\.mlp\.(W)\.weight": (
                nn.init.trunc_normal_,
                {"mean": 0.0, "std": 0.02, "a": -2, "b": 2},
            ),
            # ... while V and W_2 (residual path) use the depth-scaled std
            r"transformer\.h\.\d+\.mlp\.(V|W_2)\.weight": (
                nn.init.trunc_normal_,
                {"mean": 0.0, "std": residual_std, "a": -2, "b": 2},
            ),
        }

    def initialize_in_place(self, model: nn.Module):
        """Initialize all matching parameters of ``model`` in place."""
        self._init_by_fqn_regex(model, self.regex_to_init, depth_init=self.depth_init)

    @staticmethod
    def _init_by_fqn_regex(model: nn.Module, regex_to_init: dict[str, tuple[Callable, dict]], depth_init: bool):
        """Initialize every parameter whose FQN fully matches exactly one regex.

        Args:
            model: model whose named parameters are initialized in place.
            regex_to_init: maps a fullmatch regex over parameter FQNs to an
                ``(init_fn, kwargs)`` pair. A callable ``"std"`` entry is resolved
                per parameter by calling it with the layer id parsed from the FQN.
            depth_init: kept for interface stability; the depth scaling is already
                baked into the callable "std" entries, so it is unused here.

        Raises:
            ValueError: if a bias parameter is found, a parameter matches multiple
                regexes, a callable std has no extractable layer id, or a regex
                matches no parameter at all.
        """
        hits = {pattern: 0 for pattern in regex_to_init}

        for parameter_name, param in model.named_parameters():
            if parameter_name.endswith("bias"):
                raise ValueError(
                    f"Bias initialization is not allowed for Llama3Initializer. Found bias parameter: {parameter_name}"
                )
            match_count = 0
            for weight_regex, (init_fn, arg_dict) in regex_to_init.items():
                if re.fullmatch(weight_regex, parameter_name):
                    # .get (not []) so init functions without a "std" kwarg
                    # (e.g. nn.init.ones_) are also supported.
                    if callable(arg_dict.get("std")):
                        # std is a function of the layer_id -> resolve it per parameter
                        layer_id_match = re.search(r"transformer\.h\.(\d+)\.", parameter_name)
                        if layer_id_match is None:
                            raise ValueError(
                                f"Could not extract layer_id from parameter name {parameter_name} "
                                "for dynamic std calculation"
                            )
                        layer_id = int(layer_id_match.group(1))
                        arg_dict = arg_dict.copy()  # create a copy of the arg_dict to avoid mutating the original
                        arg_dict["std"] = arg_dict["std"](layer_id)
                    init_fn(param, **arg_dict)
                    match_count += 1
                    hits[weight_regex] += 1

            if match_count == 0:
                # Norm weights legitimately land here: they are initialized via
                # reset_parameters() in the model factory, so this stays a warning.
                logger.warning(f"Parameter {parameter_name} did not match any regex for initialization")
            elif match_count > 1:
                raise ValueError(
                    f"Parameter {parameter_name} matched multiple regexes for initialization, which is not allowed"
                )

        for pattern, count in hits.items():
            if count == 0:
                raise ValueError(
                    f"Regex {pattern} did not match any FQNs. The model specification probably does not match LLama3."
                )
7 changes: 7 additions & 0 deletions src/modalities/registry/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
)
from modalities.models.gpt2.collator import GPT2LLMCollateFn
from modalities.models.gpt2.gpt2_model import GPT2LLMConfig
from modalities.models.gpt2.llama3_like_initialization import Llama3Initializer, Llama3InitializerConfig
from modalities.models.huggingface.huggingface_model import HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig
from modalities.models.model_factory import GPT2ModelFactory, ModelFactory
from modalities.models.parallelism.pipeline_parallelism import ComponentSelectorFromPipeline, PipelineFactory
Expand Down Expand Up @@ -240,6 +241,12 @@ class ComponentEntity:
ComposedInitializationRoutines.get_composed_model_initializer,
ComposedModelInitializationConfig,
),
ComponentEntity(
"model_initialization",
"gpt2_llama3_like",
Llama3Initializer,
Llama3InitializerConfig,
),
# losses
ComponentEntity("loss", "clm_cross_entropy_loss", CLMCrossEntropyLoss, CLMCrossEntropyLossConfig),
# optimizers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ app_state_raw:
component_key: app_state
variant_key: raw
config:
model:
model:
instance_key: initialized_model
pass_type: BY_REFERENCE
optimizer:
Expand Down Expand Up @@ -288,7 +288,7 @@ optimizer:
eps: 1e-8
weight_decay: 1e-1
weight_decay_groups_excluded: [embedding, layernorm]
wrapped_model:
wrapped_model:
instance_key: initialized_model
pass_type: BY_REFERENCE

Expand Down
92 changes: 90 additions & 2 deletions tests/test_initialization_fsdpx.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,16 @@
from torch.distributed.fsdp import StateDictType

from modalities.__main__ import Main
from modalities.config.config import ProcessGroupBackendType
from modalities.config.pydantic_if_types import PydanticFSDP1ModuleType, PydanticFSDP2ModuleType
from modalities.config.component_factory import ComponentFactory
from modalities.config.config import ProcessGroupBackendType, load_app_config_dict
from modalities.config.pydantic_if_types import (
PydanticFSDP1ModuleType,
PydanticFSDP2ModuleType,
PydanticPytorchModuleType,
)
from modalities.models.gpt2.gpt2_model import GPT2LLM, GPT2Block
from modalities.registry.components import COMPONENTS
from modalities.registry.registry import Registry
from tests.end2end_tests.custom_components import MultiProcessingCudaEnv


Expand Down Expand Up @@ -493,3 +501,83 @@ def _get_fdsp2_state_dict(model: FSDP2) -> dict[str, Any]:
model=model, optimizers=[], options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
)[0]
return model_state


class TestLlama3LikeInitialization:
    """Verifies the weight statistics produced by the gpt2_llama3_like initializer."""

    @pytest.mark.parametrize("depth_init", [True, False])
    def test_llama3_like_initialization(self, depth_init: bool):
        yaml_path = Path(__file__).parent / "test_yaml_configs/llama3_config_initalization.yaml"
        # must stay in sync with n_layer / n_embd in the YAML config
        n_layer, n_embd = 4, 256
        model = self._get_components(config_file_path=yaml_path, depth_init=depth_init)

        self._test_wte(model=model)
        self._test_lm_head(model=model, n_embd=n_embd)
        for layer_id, (_, block) in enumerate(model.transformer["h"].items()):
            self._test_qkv_proj(gpt2_block=block)
            self._test_c_proj(gpt2_block=block, depth_init=depth_init, n_layer=n_layer, layer_id=layer_id)
            self._test_swiglu_proj(gpt2_block=block, depth_init=depth_init, n_layer=n_layer, layer_id=layer_id)

    def _get_components(self, config_file_path: Path, depth_init: bool) -> GPT2LLM:
        # Build only the initialized model from the YAML config, overriding depth_init.
        config_dict = load_app_config_dict(config_file_path=config_file_path)
        config_dict["initialized_model"]["config"]["model_initializer"]["config"]["depth_init"] = depth_init

        class ComponentsInstantiationModel(BaseModel):
            initialized_model: PydanticPytorchModuleType

        factory = ComponentFactory(registry=Registry(COMPONENTS))
        components: ComponentsInstantiationModel = factory.build_components(
            config_dict=config_dict, components_model_type=ComponentsInstantiationModel
        )
        return components.initialized_model

    @staticmethod
    def _expected_residual_std(depth_init: bool, n_layer: int, layer_id: int) -> float:
        # std used for the residual projections (attn.c_proj, mlp.V, mlp.W_2)
        denominator = 2 * (layer_id + 1) if depth_init else 2 * n_layer
        return 0.02 / math.sqrt(denominator)

    @staticmethod
    def _assert_weight_stats(weight, std: float, bound: float, abs_tol: float):
        # shared check: empirical std, mean near zero, and truncation bounds
        assert weight.std().detach().cpu() == pytest.approx(std, abs=abs_tol)
        assert weight.mean().detach().cpu() == pytest.approx(0, abs=abs_tol)
        assert weight.max().detach().cpu() <= bound
        assert weight.min().detach().cpu() >= -bound

    def _test_wte(self, model: GPT2LLM):
        # embedding weights: plain normal with std 1, no truncation
        self._assert_weight_stats(model.transformer.wte.weight, std=1, bound=math.inf, abs_tol=1e-2)

    def _test_lm_head(self, model: GPT2LLM, n_embd: int):
        # lm head: std scaled by 1/sqrt(n_embd), truncated at +/- 3 std
        scale = 1 / math.sqrt(n_embd)
        self._assert_weight_stats(model.transformer.lm_head.weight, std=scale, bound=3 * scale, abs_tol=1e-3)

    def _test_qkv_proj(self, gpt2_block: GPT2Block):
        for proj in (gpt2_block.attn.q_attn, gpt2_block.attn.k_attn, gpt2_block.attn.v_attn):
            self._assert_weight_stats(proj.weight, std=0.02, bound=2, abs_tol=1e-3)

    def _test_c_proj(self, gpt2_block: GPT2Block, depth_init: bool, n_layer: int, layer_id: int):
        expected = self._expected_residual_std(depth_init=depth_init, n_layer=n_layer, layer_id=layer_id)
        self._assert_weight_stats(gpt2_block.attn.c_proj.weight, std=expected, bound=2, abs_tol=1e-3)

    def _test_swiglu_proj(self, gpt2_block: GPT2Block, depth_init: bool, n_layer: int, layer_id: int):
        expected = self._expected_residual_std(depth_init=depth_init, n_layer=n_layer, layer_id=layer_id)
        # V and W_2 sit on the residual path and use the depth-scaled std ...
        for proj in (gpt2_block.mlp.V, gpt2_block.mlp.W_2):
            self._assert_weight_stats(proj.weight, std=expected, bound=2, abs_tol=1e-3)
        # ... while the first SwiGLU projection W keeps the base std of 0.02
        self._assert_weight_stats(gpt2_block.mlp.W.weight, std=0.02, bound=2, abs_tol=1e-3)
60 changes: 60 additions & 0 deletions tests/test_yaml_configs/llama3_config_initalization.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Minimal config for testing the gpt2_llama3_like model initialization
# (see tests/test_initialization_fsdpx.py::TestLlama3LikeInitialization).
# NOTE: n_layer and n_embd are mirrored as hard-coded values in the test.
initialized_model:
  component_key: model
  variant_key: model_initialized
  config:
    model:
      instance_key: model_raw
      pass_type: BY_REFERENCE
    model_initializer:
      component_key: model_initialization
      variant_key: gpt2_llama3_like
      config:
        num_layers: ${model_raw.config.n_layer}
        n_embd: ${model_raw.config.n_embd}
        depth_init: False  # overridden by the test for both parametrizations


model_raw:
  component_key: model
  variant_key: gpt2
  config:
    use_meta_device: true
    use_weight_tying: false
    sample_key: "input_ids"
    poe_type: NOPE
    sequence_length: 128
    prediction_key: "logits"
    vocab_size: 2048 # 2K vocab for testing
    n_layer: 4 # 4 layers for testing
    n_head_q: 32
    n_head_kv: 8
    ffn_hidden: 128 # 128 ffn hidden dim for testing
    n_embd: 256 # 256 embedding dim for testing
    dropout: 0.0
    bias: false
    attention_config:
      qkv_transforms:
        - type_hint: RotaryTransform
          config:
            n_embd: ${model_raw.config.n_embd}
            n_head: ${model_raw.config.n_head_q}
            seq_length_dim: -2
            base_freq: 500000
    attention_implementation: pytorch_flash
    activation_type: swiglu
    attention_norm_config:
      norm_type: pytorch_rms_norm
      config:
        normalized_shape: ${model_raw.config.n_embd}
        eps: 1.0e-05
    ffn_norm_config:
      norm_type: pytorch_rms_norm
      config:
        normalized_shape: ${model_raw.config.n_embd}
        eps: 1.0e-05
    lm_head_norm_config:
      norm_type: pytorch_rms_norm
      config:
        normalized_shape: ${model_raw.config.n_embd}
        eps: 1.0e-05