From 8fa1faa82c634bb26d3d0581c53cd282e352e75f Mon Sep 17 00:00:00 2001 From: Phineas1500 <41450967+Phineas1500@users.noreply.github.com> Date: Fri, 3 Apr 2026 18:50:52 -0400 Subject: [PATCH] Add Gemma 4 E2B/E4B support --- examples/models/BUCK | 1 + examples/models/gemma4/BUCK | 24 + examples/models/gemma4/README.md | 52 +++ examples/models/gemma4/__init__.py | 19 + examples/models/gemma4/config/e2b_config.json | 40 ++ examples/models/gemma4/config/e4b_config.json | 40 ++ examples/models/gemma4/convert_weights.py | 157 +++++++ examples/models/llama/attention.py | 337 +++++++++++++- examples/models/llama/export_llama_lib.py | 14 +- examples/models/llama/feed_forward.py | 10 +- examples/models/llama/llama_transformer.py | 207 ++++++++- examples/models/llama/model_args.py | 20 +- examples/models/llama/norm.py | 16 +- examples/models/llama/rope.py | 205 ++++++-- .../llama/source_transformation/sdpa.py | 18 +- examples/models/llama/tests/BUCK | 14 + .../models/llama/tests/test_gemma4_support.py | 438 ++++++++++++++++++ examples/models/model_factory.py | 7 +- examples/models/test/BUCK | 15 + examples/models/test/test_model_factory.py | 64 +++ exir/_serialize/_flatbuffer.py | 16 +- exir/_serialize/test/test_flatbuffer.py | 17 + extension/llm/export/config/llm_config.py | 2 + 23 files changed, 1658 insertions(+), 75 deletions(-) create mode 100644 examples/models/gemma4/BUCK create mode 100644 examples/models/gemma4/README.md create mode 100644 examples/models/gemma4/__init__.py create mode 100644 examples/models/gemma4/config/e2b_config.json create mode 100644 examples/models/gemma4/config/e4b_config.json create mode 100644 examples/models/gemma4/convert_weights.py create mode 100644 examples/models/llama/tests/test_gemma4_support.py create mode 100644 examples/models/test/BUCK create mode 100644 examples/models/test/test_model_factory.py diff --git a/examples/models/BUCK b/examples/models/BUCK index a2b6789a95e..f8cd2509428 100644 --- a/examples/models/BUCK +++ b/examples/models/BUCK @@ -26,6 +26,7 @@ fbcode_target(_kind = python_library, "//executorch/examples/models/toy_model:toy_model", # @manual "//executorch/examples/models/wav2letter:w2l_model", # @manual "//executorch/examples/models/llama3_2_vision:multimodal_lib", # @manual + "//executorch/examples/models/gemma4:gemma4", # @manual "//executorch/examples/models/gemma3:gemma3", # @manual "//executorch/examples/models/qwen2_5:qwen2_5", # @manual "//executorch/examples/models/qwen3:qwen3", # @manual diff --git a/examples/models/gemma4/BUCK b/examples/models/gemma4/BUCK new file mode 100644 index 00000000000..0056893c45c --- /dev/null +++ b/examples/models/gemma4/BUCK @@ -0,0 +1,24 @@ +load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target") +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +oncall("executorch") + +fbcode_target(_kind = runtime.python_library, + name = "gemma4", + srcs = [ + "__init__.py", + "convert_weights.py", + ], + base_module = "executorch.examples.models.gemma4", + resources = { + "config/e2b_config.json": "config/e2b_config.json", + "config/e4b_config.json": "config/e4b_config.json", + }, + deps = [ + "//caffe2:torch", + "//executorch/examples/models/llama:llama2_model", + "//executorch/examples/models:checkpoint", + "fbsource//third-party/pypi/safetensors:safetensors", + ], + visibility = ["PUBLIC"], +) diff --git a/examples/models/gemma4/README.md b/examples/models/gemma4/README.md new file mode 100644 index 00000000000..65e63f5b699 --- /dev/null +++ b/examples/models/gemma4/README.md @@ -0,0 +1,52 @@ +# Summary + +This example adds native ExecuTorch text-only export support for Google's Gemma 4 `E2B` and `E4B` models through the existing Llama-style export path. + +The current scope is the decoder-only text model. It does not include the multimodal image or audio towers from the full Gemma 4 release. + +# Supported models + +- `google/gemma-4-E2B` +- `google/gemma-4-E4B` + +# Exporting the model + +The exporter can download and convert the Hugging Face checkpoint automatically, or you can point it at a pre-converted local checkpoint. + +## Export Gemma 4 E2B + +```bash +PYTHONPATH=.:.. python examples/models/llama/export_llama.py \ + --model gemma4_e2b \ + --params examples/models/gemma4/config/e2b_config.json \ + --dtype-override bf16 \ + --output-dir ./gemma4_e2b_out +``` + +## Export Gemma 4 E4B + +```bash +PYTHONPATH=.:.. python examples/models/llama/export_llama.py \ + --model gemma4_e4b \ + --params examples/models/gemma4/config/e4b_config.json \ + --dtype-override bf16 \ + --output-dir ./gemma4_e4b_out +``` + +## Export with KV cache and custom SDPA + +```bash +PYTHONPATH=.:.. python examples/models/llama/export_llama.py \ + --model gemma4_e4b \ + --params examples/models/gemma4/config/e4b_config.json \ + --dtype-override bf16 \ + --use_kv_cache \ + --use_sdpa_with_kv_cache \ + --disable_dynamic_shape \ + --output-dir ./gemma4_e4b_kv_out +``` + +# Notes + +- The Gemma 4 exporter uses the native ExecuTorch text runtime and the local `convert_weights.py` checkpoint conversion flow. +- In local source-tree workflows, `PYTHONPATH=.:..` makes both `examples.*` and `executorch.*` imports work consistently. diff --git a/examples/models/gemma4/__init__.py b/examples/models/gemma4/__init__.py new file mode 100644 index 00000000000..805646ee29d --- /dev/null +++ b/examples/models/gemma4/__init__.py @@ -0,0 +1,19 @@ +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.examples.models.gemma4.convert_weights import convert_weights + +__all__ = ["Gemma4Model", "convert_weights"] + + +def __getattr__(name): + if name == "Gemma4Model": + from executorch.examples.models.llama.model import Llama2Model + + class Gemma4Model(Llama2Model): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + globals()["Gemma4Model"] = Gemma4Model + return Gemma4Model + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/examples/models/gemma4/config/e2b_config.json b/examples/models/gemma4/config/e2b_config.json new file mode 100644 index 00000000000..dd16e38fd32 --- /dev/null +++ b/examples/models/gemma4/config/e2b_config.json @@ -0,0 +1,40 @@ +{ + "dim": 1536, + "hidden_dim": 6144, + "n_layers": 35, + "n_heads": 8, + "n_kv_heads": 1, + "head_dim": 256, + "global_head_dim": 512, + "vocab_size": 262144, + "vocab_size_per_layer_input": 262144, + "hidden_size_per_layer_input": 256, + "num_kv_shared_layers": 20, + "use_double_wide_mlp": true, + "act_fn": "gelu_pytorch_tanh", + "norm_eps": 1e-06, + "post_attention_norm": true, + "post_ffn_norm": true, + "apply_embedding": true, + "embedding_scale_factor": 39.191835884530846, + "use_hf_rope": true, + "attention_qkv_bias": false, + "attention_type": "gemma4_mha", + "attention_multiplier": 1.0, + "final_logit_softcapping": 30.0, + "use_qk_norm": true, + "qk_norm_before_rope": true, + "sliding_window": 512, + "layer_types": ["sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention"], + "rope_parameters": { + "full_attention": { + "partial_rotary_factor": 0.25, + "rope_theta": 1000000.0, + "rope_type": "proportional" + }, + "sliding_attention": { + "rope_theta": 10000.0, + "rope_type": "default" + } + } +} diff --git a/examples/models/gemma4/config/e4b_config.json b/examples/models/gemma4/config/e4b_config.json new file mode 100644 index 00000000000..74942146c33 --- /dev/null +++ b/examples/models/gemma4/config/e4b_config.json @@ -0,0 +1,40 @@ +{ + "dim": 2560, + "hidden_dim": 10240, + "n_layers": 42, + "n_heads": 8, + "n_kv_heads": 2, + "head_dim": 256, + "global_head_dim": 512, + "vocab_size": 262144, + "vocab_size_per_layer_input": 262144, + "hidden_size_per_layer_input": 256, + "num_kv_shared_layers": 18, + "use_double_wide_mlp": false, + "act_fn": "gelu_pytorch_tanh", + "norm_eps": 1e-06, + "post_attention_norm": true, + "post_ffn_norm": true, + "apply_embedding": true, + "embedding_scale_factor": 50.59644256269407, + "use_hf_rope": true, + "attention_qkv_bias": false, + "attention_type": "gemma4_mha", + "attention_multiplier": 1.0, + "final_logit_softcapping": 30.0, + "use_qk_norm": true, + "qk_norm_before_rope": true, + "sliding_window": 512, + "layer_types": ["sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention", "full_attention"], + "rope_parameters": { + "full_attention": { + "partial_rotary_factor": 0.25, + "rope_theta": 1000000.0, + "rope_type": "proportional" + }, + "sliding_attention": { + "rope_theta": 10000.0, + "rope_type": "default" + } + } +} diff --git a/examples/models/gemma4/convert_weights.py b/examples/models/gemma4/convert_weights.py new file mode 100644 index 00000000000..d84a664b7b3 --- /dev/null +++ b/examples/models/gemma4/convert_weights.py @@ -0,0 +1,157 @@ +import argparse +import json +import os +from typing import Dict + +import torch +from executorch.examples.models.checkpoint import ( + get_mapped_key, + load_checkpoint_from_pytorch_model, +) + + +_GEMMA4_TO_EXECUTORCH = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.embed_tokens_per_layer.weight": "embed_tokens_per_layer.weight", + "model.per_layer_model_projection.weight": "per_layer_model_projection.weight", + "model.per_layer_projection_norm.weight": "per_layer_projection_norm.weight", + "model.norm.weight": "norm.weight", + "lm_head.weight": "output.weight", + "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", + "model.layers.{}.self_attn.q_norm.weight": "layers.{}.attention.q_norm_fn.weight", + "model.layers.{}.self_attn.k_norm.weight": "layers.{}.attention.k_norm_fn.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.post_attention_norm.weight", + "model.layers.{}.pre_feedforward_layernorm.weight": "layers.{}.ffn_norm.weight", + "model.layers.{}.post_feedforward_layernorm.weight": "layers.{}.post_ffn_norm.weight", + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", + "model.layers.{}.layer_scalar": "layers.{}.layer_scalar", + "model.layers.{}.per_layer_input_gate.weight": "layers.{}.per_layer_input_gate.weight", + "model.layers.{}.per_layer_projection.weight": "layers.{}.per_layer_projection.weight", + "model.layers.{}.post_per_layer_input_norm.weight": "layers.{}.post_per_layer_input_norm.weight", +} + + +_IGNORED_UNMAPPED_SUFFIXES = ( + "rotary_emb.inv_freq", + "self_attn.v_norm.weight", +) + + +def _load_checkpoint_from_safetensors(input_dir: str) -> Dict: + from safetensors.torch import load_file + + index_path = os.path.join(input_dir, "model.safetensors.index.json") + if os.path.exists(index_path): + with open(index_path, "r") as f: + index = json.load(f) + weight_map = index["weight_map"] + checkpoint_shards = sorted(set(weight_map.values())) + + merged_state_dict = {} + shard_to_weight_names = {} + for weight_name, shard in weight_map.items(): + shard_to_weight_names.setdefault(shard, []).append(weight_name) + + for shard in checkpoint_shards: + shard_weights = load_file(os.path.join(input_dir, shard)) + for weight_name in shard_to_weight_names[shard]: + merged_state_dict[weight_name] = shard_weights[weight_name] + return merged_state_dict + + model_path = os.path.join(input_dir, "model.safetensors") + if os.path.exists(model_path): + return load_file(model_path) + + raise FileNotFoundError(f"Could not find safetensors checkpoint in {input_dir}") + + +def load_checkpoint(input_dir: str) -> Dict: + try: + print("Loading checkpoint from pytorch_model directory") + return load_checkpoint_from_pytorch_model(input_dir) + except FileNotFoundError: + print( + "Could not find pytorch_model checkpoints in directory, trying safetensors" + ) + + try: + print("Loading checkpoint from safetensors directory") + return _load_checkpoint_from_safetensors(input_dir) + except FileNotFoundError: + pass + + raise FileNotFoundError(f"Could not find checkpoint in {input_dir}") + + +def gemma4_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + converted_state_dict = {} + + for key, value in state_dict.items(): + normalized_key = key + if normalized_key.startswith("model.language_model."): + normalized_key = normalized_key.replace("model.language_model.", "model.", 1) + + if not normalized_key.startswith( + ( + "model.layers.", + "model.embed_tokens.", + "model.embed_tokens_per_layer.", + "model.per_layer_model_projection.", + "model.per_layer_projection_norm.", + "model.norm.", + "lm_head.", + ) + ): + continue + + try: + new_key = get_mapped_key(normalized_key, _GEMMA4_TO_EXECUTORCH) + except Exception as err: + if normalized_key.endswith(_IGNORED_UNMAPPED_SUFFIXES): + continue + raise ValueError( + f"Unexpected checkpoint key not mapped for Gemma4 export: {key}" + ) from err + converted_state_dict[new_key] = value + + if "output.weight" not in converted_state_dict: + converted_state_dict["output.weight"] = converted_state_dict[ + "tok_embeddings.weight" + ] + + return converted_state_dict + + +def convert_weights(input_dir: str, output_file: str) -> None: + print("Loading checkpoint...") + state_dict = load_checkpoint(input_dir) + print("Converting checkpoint...") + state_dict = gemma4_to_meta(state_dict) + print("Saving checkpoint...") + torch.save(state_dict, output_file) + print("Done.") + + +def main(): + parser = argparse.ArgumentParser( + description="Convert Gemma4 weights to ExecuTorch meta format." + ) + parser.add_argument( + "input_dir", + type=str, + help="Path to directory containing safetensor or PyTorch checkpoint files.", + ) + parser.add_argument("output", type=str, help="Path to the output checkpoint") + + args = parser.parse_args() + convert_weights(args.input_dir, args.output) + + +if __name__ == "__main__": + main() diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index 1cb7ba866b7..c38ad8a88fa 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -24,7 +24,7 @@ class ForwardOptions(TypedDict, total=False): # YOCO (You Only Cache Once): shared K/V from a donor layer. # When provided, the attention layer skips its own K/V projection # and reuses the donor's K/V instead. - shared_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] + shared_kv: Optional[Any] class Attention(nn.Module, ABC): @@ -119,12 +119,14 @@ def __init__( head_dim: int, n_rep: int, max_context_len: int, + scale: Optional[float] = None, ): super().__init__() self.dim = dim self.head_dim = head_dim self.n_rep = n_rep self.max_context_len = max_context_len + self.scale = scale def forward( self, @@ -141,7 +143,14 @@ def forward( # can natively support GQA now. But needs enable_gqa=True k = k.repeat_interleave(self.n_rep, dim=1) v = v.repeat_interleave(self.n_rep, dim=1) - y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=0.0, + scale=self.scale, + ) return y.transpose(1, 2).reshape(bsz, seqlen, self.dim) @@ -156,6 +165,20 @@ def _create_causal_mask_for_ring_buffer( return attn_mask +def _create_sliding_window_mask( + seq_len: int, + key_len: int, + window_size: int, + device: torch.device, + dtype: torch.dtype, +) -> torch.Tensor: + q_pos = torch.arange(seq_len, device=device, dtype=torch.long).view(-1, 1) + k_pos = torch.arange(key_len, device=device, dtype=torch.long).view(1, -1) + delta = q_pos - k_pos + attn_mask = (delta >= 0) & (delta < window_size) + return torch.where(attn_mask, 0.0, float("-inf")).to(dtype=dtype) + + class CacheUpdateStrategy(Enum): RING_BUFFER = "RingBuffer" INVALID = "Invalid" @@ -375,6 +398,7 @@ def __init__( self.qk_norm_before_rope = args.qk_norm_before_rope self.use_q_gate = args.use_q_gate self.enable_dynamic_shape = args.enable_dynamic_shape + self.attention_scale = args.attention_multiplier q_out_dim = self.n_heads * self.head_dim * (2 if self.use_q_gate else 1) # YOCO: Determine if this is a KV shared layer (receives shared KV from donor). @@ -412,6 +436,7 @@ def __init__( head_dim=self.head_dim, n_rep=self.n_rep, max_context_len=self.max_context_len, + scale=self.attention_scale, ) def _init_norms(self, args: ModelArgs) -> None: @@ -592,7 +617,14 @@ def forward( mask = self.mask[:seqlen, :seqlen] - output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + output = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=0.0, + scale=self.attention_scale, + ) output = output.transpose(1, 2).reshape(bsz, seqlen, -1) if gate is not None: @@ -603,6 +635,305 @@ def forward( return output, None +@register_attention("gemma4_mha") +class AttentionGemma4MHA(Attention): + def __init__( + self, + args: ModelArgs, + layer_id: int, + rope: Rope, + **_kwargs: Any, + ): + super().__init__() + self.use_kv_cache = args.use_kv_cache + self.n_heads = args.n_heads + self.layer_id = layer_id + self.layer_type = ( + args.layer_types[layer_id] if args.layer_types is not None else None + ) + self.is_sliding = self.layer_type == "sliding_attention" + self.sliding_window = args.sliding_window if self.is_sliding else None + self.use_alternative_attention = args.attention_k_eq_v and not self.is_sliding + self.n_kv_heads = ( + args.num_global_key_value_heads + if self.use_alternative_attention + else (self.n_heads if args.n_kv_heads is None else args.n_kv_heads) + ) + assert self.n_heads % self.n_kv_heads == 0 + model_parallel_size = 1 + self.n_local_heads = self.n_heads // model_parallel_size + self.n_local_kv_heads = self.n_kv_heads // model_parallel_size + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = ( + args.global_head_dim + if not self.is_sliding and args.global_head_dim is not None + else args.head_dim + ) + self.max_batch_size = args.max_batch_size + self.max_context_len = args.max_context_len + self.dim = args.dim + self.attention_qkv_bias = args.attention_qkv_bias + self.enable_dynamic_shape = args.enable_dynamic_shape + self.attention_scale = ( + 1.0 if args.attention_multiplier is None else args.attention_multiplier + ) + + num_kv_shared = args.num_kv_shared_layers + first_shared = args.n_layers - num_kv_shared + self.is_kv_shared_layer = layer_id >= first_shared > 0 + self.has_kv_weights = not self.is_kv_shared_layer + self.store_full_length_kv = False + if num_kv_shared > 0 and first_shared > 0 and not self.is_kv_shared_layer: + if args.layer_types is None: + self.store_full_length_kv = layer_id == first_shared - 1 + else: + target_type = args.layer_types[layer_id] + for donor_idx in range(first_shared - 1, -1, -1): + if args.layer_types[donor_idx] == target_type: + self.store_full_length_kv = donor_idx == layer_id + break + + q_out_dim = self.n_heads * self.head_dim + self.wq = _create_projection( + args, args.dim, q_out_dim, ("q_proj",), bias=self.attention_qkv_bias + ) + if self.has_kv_weights: + kv_dim = self.n_kv_heads * self.head_dim + self.wk = _create_projection( + args, args.dim, kv_dim, ("k_proj",), bias=self.attention_qkv_bias + ) + self.wv = ( + _create_projection( + args, + args.dim, + kv_dim, + ("v_proj",), + bias=self.attention_qkv_bias, + ) + if not self.use_alternative_attention + else None + ) + else: + self.wk = None + self.wv = None + self.wo = _create_projection( + args, + self.n_heads * self.head_dim, + args.dim, + ("output_proj", "o_proj"), + bias=False, + ) + + self.q_norm_fn = RMSNorm( + self.head_dim, + eps=args.norm_eps, + add_unit_offset=args.rms_norm_add_unit_offset, + ) + self.k_norm_fn = RMSNorm( + self.head_dim, + eps=args.norm_eps, + add_unit_offset=args.rms_norm_add_unit_offset, + ) + self.v_norm_fn = RMSNorm( + self.head_dim, + eps=args.norm_eps, + add_unit_offset=args.rms_norm_add_unit_offset, + with_scale=False, + ) + + self.rope = rope + causal_mask = torch.tril( + torch.ones( + self.max_context_len, + self.max_context_len, + dtype=torch.bool, + device="cpu", + ) + ) + self.register_buffer("mask", causal_mask, persistent=False) + + if self.use_kv_cache: + if self.has_kv_weights: + if self.is_sliding and self.sliding_window is not None: + self.kv_cache = RingKVCache( + args.max_batch_size, + self.sliding_window, + self.n_kv_heads, + self.head_dim, + args.enable_dynamic_shape, + ) + else: + self.kv_cache = KVCache( + args.max_batch_size, + args.max_context_len, + self.n_kv_heads, + self.head_dim, + args.enable_dynamic_shape, + ) + else: + self.kv_cache = None + self.SDPA = SDPA( + dim=self.n_local_heads * self.head_dim, + head_dim=self.head_dim, + n_rep=self.n_rep, + max_context_len=self.max_context_len, + scale=self.attention_scale, + ) + + def _unpack_shared_kv( + self, shared_kv: Any + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[int]]: + if isinstance(shared_kv, dict): + return ( + shared_kv["k"], + shared_kv["v"], + shared_kv.get("cache_positions"), + shared_kv.get("window_size"), + ) + k, v = shared_kv + return k, v, None, self.sliding_window + + def _build_kv_payload( + self, k: torch.Tensor, v: torch.Tensor + ) -> Any: + if ( + self.use_kv_cache + and self.is_sliding + and self.kv_cache is not None + and getattr(self.kv_cache, "is_ring_buffer", False) + ): + return { + "k": k, + "v": v, + "cache_positions": self.kv_cache.cache_positions_manager.cache_positions.clone(), + "window_size": self.sliding_window, + } + return (k, v) + + def _build_sliding_mask( + self, q: torch.Tensor, k: torch.Tensor, seqlen: int + ) -> torch.Tensor: + assert self.sliding_window is not None + return _create_sliding_window_mask( + seqlen, + k.shape[2], + self.sliding_window, + q.device, + q.dtype, + ) + + def forward( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + **kwargs: ForwardOptions, + ) -> Tuple[torch.Tensor, Optional[Any]]: + del freqs_cos, freqs_sin + input_pos = kwargs.get("input_pos") + shared_kv = kwargs.get("shared_kv") + bsz, seqlen, _ = x.shape + freqs_cos, freqs_sin = self.rope.get_freqs_for_layer_type( + self.layer_type, input_pos, seqlen + ) + + q = self.wq(x).view(bsz, seqlen, self.n_local_heads, self.head_dim) + q = self.q_norm_fn(q) + q = self.rope.forward_to_tensor(q, freqs_cos, freqs_sin) + q = q.transpose(1, 2) + + shared_cache_positions = None + shared_window_size = self.sliding_window + if shared_kv is not None: + k, v, shared_cache_positions, shared_window_size = self._unpack_shared_kv( + shared_kv + ) + else: + assert self.wk is not None, ( + "wk projection is required when shared_kv is not provided. " + "This Gemma4 layer expects shared KV from an earlier donor layer." + ) + k = self.wk(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + value_states = ( + self.wv(x).view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) + if self.wv is not None + else k + ) + k = self.k_norm_fn(k) + k = self.rope.forward_to_tensor(k, freqs_cos, freqs_sin) + k = k.transpose(1, 2) + v = self.v_norm_fn(value_states).transpose(1, 2) + + if self.use_kv_cache: + assert input_pos is not None + if self.enable_dynamic_shape: + start_pos = input_pos[-1].item() + torch._check_is_size(start_pos) + torch._check(start_pos < self.max_context_len) + seq_length = q.size(2) + attn_mask = self.mask.narrow(0, start_pos, seq_length) + else: + attn_mask = self.mask[input_pos] + + if shared_kv is None: + assert self.kv_cache is not None, ( + "kv_cache is required when shared_kv is not provided. " + "This Gemma4 layer expects shared KV from an earlier donor layer." + ) + k, v = self.kv_cache.update(input_pos, k, v) + if getattr(self.kv_cache, "is_ring_buffer", False): + attn_mask = self.kv_cache.create_causal_mask_for_ring_buffer( + input_pos[0].item(), seqlen + ) + elif ( + self.is_sliding + and shared_cache_positions is not None + and shared_window_size is not None + ): + attn_mask = _create_causal_mask_for_ring_buffer( + shared_cache_positions.to(device=q.device), + shared_window_size, + input_pos[0].item(), + seqlen, + ).to(dtype=q.dtype) + + output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask) + update = None + if shared_kv is None and self.store_full_length_kv: + update = {"kv_to_share": {self.layer_id: self._build_kv_payload(k, v)}} + return self.wo(output), update + + k_to_share = k + v_to_share = v + k = k.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + if self.is_sliding and self.sliding_window is not None: + mask = self._build_sliding_mask(q, k, seqlen) + else: + mask = self.mask[:seqlen, : k.shape[2]].to(device=q.device, dtype=q.dtype) + + output = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=0.0, + scale=self.attention_scale, + ) + output = output.transpose(1, 2).reshape(bsz, seqlen, -1) + output = self.wo(output) + + update = None + if shared_kv is None and self.store_full_length_kv: + update = { + "kv_to_share": { + self.layer_id: self._build_kv_payload(k_to_share, v_to_share) + } + } + return output, update + + def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor: inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) return x * inv_norm diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 7d6371add44..6e4ef81e411 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -110,6 +110,8 @@ "qwen3_5_0_8b", "qwen3_5_2b", "qwen3_5_4b", + "gemma4_e2b", + "gemma4_e4b", "phi_4_mini", "smollm2", "lfm2_350m", # hybrid @@ -130,6 +132,8 @@ "qwen3_5_0_8b": "Qwen/Qwen3.5-0.8B", "qwen3_5_2b": "Qwen/Qwen3.5-2B", "qwen3_5_4b": "Qwen/Qwen3.5-4B", + "gemma4_e2b": "google/gemma-4-E2B", + "gemma4_e4b": "google/gemma-4-E4B", "lfm2_350m": "LiquidAI/LFM2-350M", "lfm2_700m": "LiquidAI/LFM2-700M", "lfm2_1_2b": "LiquidAI/LFM2-1.2B", @@ -655,6 +659,8 @@ def export_llama( # noqa: C901 from executorch.examples.models.qwen3_5 import convert_weights elif model_name.startswith("qwen3"): from executorch.examples.models.qwen3 import convert_weights + elif model_name.startswith("gemma4"): + from executorch.examples.models.gemma4 import convert_weights elif model_name == "phi_4_mini": from executorch.examples.models.phi_4_mini import convert_weights elif model_name == "smollm2": @@ -1520,8 +1526,12 @@ def _load_llama_model(llm_config: LlmConfig) -> "LLMEdgeManager": modelname = llm_config.base.model_class.value if modelname in EXECUTORCH_DEFINED_MODELS: - module_name = "llama" - model_class_name = "Llama2Model" # TODO: Change to "LlamaModel" in examples/models/llama/model.py. + if modelname.startswith("gemma4"): + module_name = "gemma4" + model_class_name = "Gemma4Model" + else: + module_name = "llama" + model_class_name = "Llama2Model" # TODO: Change to "LlamaModel" in examples/models/llama/model.py. elif modelname in TORCHTUNE_DEFINED_MODELS: if modelname == "llama3_2_vision": module_name = "llama3_2_vision" diff --git a/examples/models/llama/feed_forward.py b/examples/models/llama/feed_forward.py index 786567273c0..3ce3d1409f5 100644 --- a/examples/models/llama/feed_forward.py +++ b/examples/models/llama/feed_forward.py @@ -1,3 +1,5 @@ +from typing import Callable + import torch.nn.functional as F from executorch.examples.models.llama.lora import LoRALinear @@ -6,14 +8,15 @@ class FeedForward(nn.Module): - def __init__(self, dim: int, hidden_dim: int): + def __init__(self, dim: int, hidden_dim: int, act_fn: Callable = F.silu): super().__init__() self.w1 = nn.Linear(dim, hidden_dim, bias=False) self.w2 = nn.Linear(hidden_dim, dim, bias=False) self.w3 = nn.Linear(dim, hidden_dim, bias=False) + self.act_fn = act_fn def forward(self, x): - return self.w2(F.silu(self.w1(x)) * self.w3(x)) + return self.w2(self.act_fn(self.w1(x)) * self.w3(x)) class LoRAFeedForward(nn.Module): @@ -63,6 +66,7 @@ def __init__(self, dim: int, hidden_dim: int, args: ModelArgs): if "up_proj" in args.target_modules else nn.Linear(dim, hidden_dim, bias=False) ) + self.act_fn = args.act_fn.get_function() def forward(self, x): - return self.w2(F.silu(self.w1(x)) * self.w3(x)) + return self.w2(self.act_fn(self.w1(x)) * self.w3(x)) diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py index cb87995aaf6..a577d3195ea 100644 --- a/examples/models/llama/llama_transformer.py +++ b/examples/models/llama/llama_transformer.py @@ -51,6 +51,28 @@ def _is_kv_shared_layer( return layer_idx >= first_shared and first_shared > 0 +def _get_kv_donor_layer_idx( + layer_idx: int, + n_layers: int, + num_kv_shared_layers: int, + layer_types: Optional[list] = None, +) -> Optional[int]: + if not _is_kv_shared_layer(layer_idx, n_layers, num_kv_shared_layers): + return None + + first_shared = n_layers - num_kv_shared_layers + if first_shared <= 0: + return None + if layer_types is None: + return first_shared - 1 + + target_type = layer_types[layer_idx] + for donor_idx in range(first_shared - 1, -1, -1): + if layer_types[donor_idx] == target_type: + return donor_idx + return None + + class ConditionalFeedForward(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -98,7 +120,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs, attention: Attention): + def __init__(self, args: ModelArgs, attention: Attention, layer_id: int): """ Transformer block with support for pre-norm and post-norm. Args: @@ -113,10 +135,16 @@ def __init__(self, args: ModelArgs, attention: Attention): self.dim = args.dim self.head_dim = args.head_dim self.attention = attention + self.layer_id = layer_id + self.hidden_size_per_layer_input = args.hidden_size_per_layer_input assert ( args.hidden_dim is not None ), "`hidden_dim` must be set in ModelArgs to construct a TransformerBlock." + ffn_hidden_dim = args.hidden_dim + if _is_kv_shared_layer(layer_id, args.n_layers, args.num_kv_shared_layers): + if args.use_double_wide_mlp: + ffn_hidden_dim *= 2 if args.moe: self.block_sparse_moe = MOEFeedForward(args) elif args.target_modules is not None and ( @@ -124,9 +152,13 @@ def __init__(self, args: ModelArgs, attention: Attention): or "up_proj" in args.target_modules or "gate_proj" in args.target_modules ): - self.feed_forward = LoRAFeedForward(args.dim, args.hidden_dim, args) + self.feed_forward = LoRAFeedForward(args.dim, ffn_hidden_dim, args) else: - self.feed_forward = FeedForward(dim=args.dim, hidden_dim=args.hidden_dim) + self.feed_forward = FeedForward( + dim=args.dim, + hidden_dim=ffn_hidden_dim, + act_fn=args.act_fn.get_function(), + ) if isinstance(self.attention, AttentionSkip): self.attention_norm = nn.Identity() @@ -141,6 +173,39 @@ def __init__(self, args: ModelArgs, attention: Attention): eps=args.norm_eps, add_unit_offset=args.rms_norm_add_unit_offset, ) + self.post_attention_norm = ( + RMSNorm( + args.dim, + eps=args.norm_eps, + add_unit_offset=args.rms_norm_add_unit_offset, + ) + if args.post_attention_norm and not isinstance(self.attention, AttentionSkip) + else None + ) + self.post_ffn_norm = ( + RMSNorm( + args.dim, + eps=args.norm_eps, + add_unit_offset=args.rms_norm_add_unit_offset, + ) + if args.post_ffn_norm + else None + ) + if self.hidden_size_per_layer_input > 0: + self.per_layer_act = args.act_fn.get_function() + self.per_layer_input_gate = nn.Linear( + args.dim, self.hidden_size_per_layer_input, bias=False + ) + self.per_layer_projection = nn.Linear( + self.hidden_size_per_layer_input, args.dim, bias=False + ) + self.post_per_layer_input_norm = RMSNorm( + args.dim, + eps=args.norm_eps, + add_unit_offset=args.rms_norm_add_unit_offset, + ) + if args.attention_type == "gemma4_mha": + self.register_buffer("layer_scalar", torch.ones(1)) @classmethod def from_type(cls, layer_id, args, rope) -> "TransformerBlock": @@ -158,19 +223,42 @@ def from_type(cls, layer_id, args, rope) -> "TransformerBlock": ) cls = ATTENTION_REGISTRY[args.attention_type] attention = cls(args, layer_id, rope, **args.attention_kwargs) - return TransformerBlock(args, attention) + return TransformerBlock(args, attention, layer_id) - def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x: 1xN + def forward( + self, + x, + freqs_cos, + freqs_sin, + attn_options: ForwardOptions, + per_layer_input: Optional[torch.Tensor] = None, + ): # x: 1xN h, attn_options_update = self.attention( self.attention_norm(x), freqs_cos, freqs_sin, **attn_options ) if not isinstance(self.attention, AttentionSkip): + if self.post_attention_norm is not None: + h = self.post_attention_norm(h) h = x + h if hasattr(self, "block_sparse_moe"): - out = h + self.block_sparse_moe(self.ffn_norm(h)) + ffn_out = self.block_sparse_moe(self.ffn_norm(h)) else: - out = h + self.feed_forward(self.ffn_norm(h)) + ffn_out = self.feed_forward(self.ffn_norm(h)) + if self.post_ffn_norm is not None: + ffn_out = self.post_ffn_norm(ffn_out) + out = h + ffn_out + + if per_layer_input is not None and self.hidden_size_per_layer_input > 0: + residual = out + out = self.per_layer_input_gate(out) + out = self.per_layer_act(out) + out = out * per_layer_input + out = self.per_layer_projection(out) + out = self.post_per_layer_input_norm(out) + out = residual + out + if hasattr(self, "layer_scalar"): + out = out * self.layer_scalar return out, attn_options_update @@ -190,12 +278,38 @@ def __init__(self, params: ModelArgs, layers: nn.ModuleList, rope: Rope): self.n_layers = params.n_layers self.apply_embedding = params.apply_embedding self.apply_output = params.apply_output + self.embedding_scale_factor = params.embedding_scale_factor + self.final_logit_softcapping = params.final_logit_softcapping + self.hidden_size_per_layer_input = params.hidden_size_per_layer_input self.tok_embeddings = ( nn.Embedding(params.vocab_size, params.dim) if self.apply_embedding else None ) + if self.hidden_size_per_layer_input > 0: + assert params.vocab_size_per_layer_input is not None + self.embed_tokens_per_layer = nn.Embedding( + params.vocab_size_per_layer_input, + params.n_layers * self.hidden_size_per_layer_input, + ) + self.per_layer_embedding_scale_factor = ( + params.hidden_size_per_layer_input**0.5 + ) + self.per_layer_input_scale = 2.0**-0.5 + self.per_layer_model_projection = nn.Linear( + params.dim, + params.n_layers * self.hidden_size_per_layer_input, + bias=False, + ) + self.per_layer_model_projection_scale = params.dim**-0.5 + self.per_layer_projection_norm = RMSNorm( + self.hidden_size_per_layer_input, + eps=params.norm_eps, + add_unit_offset=params.rms_norm_add_unit_offset, + ) + else: + self.embed_tokens_per_layer = None self.layers = layers self.rope = rope self.norm = RMSNorm( @@ -224,32 +338,53 @@ def _forward_layers( freqs_sin: torch.Tensor, attn_options_: Dict, seqlen: int, + per_layer_inputs: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[Any]]: """Run transformer layers with YOCO KV sharing support.""" attn_options_update = None - shared_kv: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {} + shared_kv: Dict[int, Any] = {} is_prefill = seqlen > 1 for layer_idx, layer in enumerate(self.layers): is_shared = _is_kv_shared_layer( layer_idx, self.n_layers, self.num_kv_shared_layers ) - if is_shared and is_prefill: + if ( + is_shared + and is_prefill + and self.params.attention_type != "gemma4_mha" + ): continue if is_shared: - donor_idx = self.n_layers - self.num_kv_shared_layers - 1 + donor_idx = _get_kv_donor_layer_idx( + layer_idx, + self.n_layers, + self.num_kv_shared_layers, + self.params.layer_types, + ) if donor_idx in shared_kv: attn_options_["shared_kv"] = shared_kv[donor_idx] - h, attn_options_update = layer(h, freqs_cos, freqs_sin, attn_options_) + layer_per_input = ( + per_layer_inputs[:, :, layer_idx, :] + if per_layer_inputs is not None + else None + ) + h, attn_options_update = layer( + h, + freqs_cos, + freqs_sin, + attn_options_, + per_layer_input=layer_per_input, + ) - if _is_kv_donor_layer(layer_idx, self.n_layers, self.num_kv_shared_layers): - assert ( - attn_options_update is not None - and "kv_to_share" in attn_options_update - ), f"Donor layer {layer_idx} must produce kv_to_share" - shared_kv[layer_idx] = attn_options_update["kv_to_share"] + if attn_options_update is not None and "kv_to_share" in attn_options_update: + kv_to_share = attn_options_update["kv_to_share"] + if isinstance(kv_to_share, dict): + shared_kv.update(kv_to_share) + else: + shared_kv[layer_idx] = kv_to_share if attn_options_update is not None: attn_options_.update(**attn_options_update) @@ -276,6 +411,8 @@ def forward( ) if self.apply_embedding and tokens is not None and h is None: h = self.tok_embeddings(tokens) + if self.embedding_scale_factor != 1.0: + h = h * self.embedding_scale_factor if attn_options is None: attn_options = {} @@ -283,11 +420,35 @@ def forward( freqs_cos, freqs_sin = self.rope.get_freqs( attn_options.get("input_pos"), seqlen ) + per_layer_inputs = None + if self.hidden_size_per_layer_input > 0: + per_layer_projection = ( + self.per_layer_model_projection(h) * self.per_layer_model_projection_scale + ) + per_layer_projection = per_layer_projection.reshape( + *h.shape[:-1], + self.n_layers, + self.hidden_size_per_layer_input, + ) + per_layer_projection = self.per_layer_projection_norm(per_layer_projection) + + if tokens is not None and self.embed_tokens_per_layer is not None: + per_layer_inputs = ( + self.embed_tokens_per_layer(tokens) + * self.per_layer_embedding_scale_factor + ).reshape( + *tokens.shape, self.n_layers, self.hidden_size_per_layer_input + ) + per_layer_inputs = ( + per_layer_projection + per_layer_inputs + ) * self.per_layer_input_scale + else: + per_layer_inputs = per_layer_projection attn_options_ = attn_options.copy() if attn_options is not None else {} h, attn_options_update = self._forward_layers( - h, freqs_cos, freqs_sin, attn_options_, seqlen + h, freqs_cos, freqs_sin, attn_options_, seqlen, per_layer_inputs ) if not self.generate_full_logits: @@ -298,6 +459,10 @@ def forward( if self.apply_output: logits = self.output(h) + if self.final_logit_softcapping is not None: + logits = logits / self.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.final_logit_softcapping if self.output_prune_map is not None: if self.generate_full_logits: @@ -358,7 +523,7 @@ def construct_transformer(model_args: ModelArgs) -> Transformer: and model_args.layer_types[layer_id] == "skip_attention" ): attention = AttentionSkip() - transformer_block = TransformerBlock(model_args, attention) + transformer_block = TransformerBlock(model_args, attention, layer_id) layers.append(transformer_block) elif ( model_args.layer_types @@ -373,13 +538,13 @@ def construct_transformer(model_args: ModelArgs) -> Transformer: attention = linear_cls( model_args, layer_id, rope, **model_args.attention_kwargs ) - transformer_block = TransformerBlock(model_args, attention) + transformer_block = TransformerBlock(model_args, attention, layer_id) layers.append(transformer_block) else: attention = cls( model_args, layer_id, rope, **model_args.attention_kwargs ) # pyre-ignore[45] - transformer_block = TransformerBlock(model_args, attention) + transformer_block = TransformerBlock(model_args, attention, layer_id) layers.append(transformer_block) return Transformer(model_args, layers, rope) diff --git a/examples/models/llama/model_args.py b/examples/models/llama/model_args.py index 402c6c39750..d62d21a1d58 100644 --- a/examples/models/llama/model_args.py +++ b/examples/models/llama/model_args.py @@ -11,10 +11,13 @@ class ActFn(Enum): SILU = "silu" GELU = "gelu" GELU_APPROX = "gelu_approx" + GELU_PYTORCH_TANH = "gelu_pytorch_tanh" @classmethod def from_string(cls, value: str) -> "ActFn": """Convert string to ActFn enum.""" + if value == "gelu_pytorch_tanh": + return cls.GELU_PYTORCH_TANH try: return cls(value) except ValueError: @@ -29,7 +32,7 @@ def get_function(self): return F.silu elif self == ActFn.GELU: return F.gelu - elif self == ActFn.GELU_APPROX: + elif self == ActFn.GELU_APPROX or self == ActFn.GELU_PYTORCH_TANH: return partial(F.gelu, approximate="tanh") else: raise ValueError(f"Unsupported activation function: {self}") @@ -110,6 +113,7 @@ class ModelArgs: None # Interval at which to skip RoPE. From Rope to Nope and Back Again: A New Hybrid Attention Strategy (https://huggingface.co/papers/2501.18795). ) partial_rotary_factor: float = 1.0 + rope_parameters: Optional[Dict[str, Dict[str, Any]]] = None rope_theta: Optional[float] = ( None # The official name to override self.rope_freq_base. ) @@ -145,6 +149,12 @@ class ModelArgs: attention_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) # Hybrid models can have layer types different from attention layer_types: Optional[list] = None + vocab_size_per_layer_input: Optional[int] = None + hidden_size_per_layer_input: int = 0 + num_global_key_value_heads: Optional[int] = None + global_head_dim: Optional[int] = None + attention_k_eq_v: bool = False + use_double_wide_mlp: bool = False model_architecture: Optional[str] = ( None # Architecture of model. For HF models, please refer to the HF model.config.architectures. This is used in QNN backend only for now. ) @@ -162,6 +172,10 @@ class ModelArgs: def __post_init__(self): # noqa: C901 if self.n_kv_heads is None: self.n_kv_heads = self.n_heads + if self.num_global_key_value_heads is None: + self.num_global_key_value_heads = self.n_kv_heads + if self.global_head_dim is None: + self.global_head_dim = self.head_dim # rope_theta overrides rope_freq_base since it's the official name. if self.rope_theta is not None: @@ -188,6 +202,10 @@ def find_multiple(n: int, k: int) -> int: if self.head_dim is None: self.head_dim = self.dim // self.n_heads + if self.global_head_dim is None: + self.global_head_dim = self.head_dim + if self.vocab_size_per_layer_input is None: + self.vocab_size_per_layer_input = self.vocab_size if self.linear_key_head_dim is None: self.linear_key_head_dim = self.head_dim diff --git a/examples/models/llama/norm.py b/examples/models/llama/norm.py index 0189c88b13b..b5d518274fa 100644 --- a/examples/models/llama/norm.py +++ b/examples/models/llama/norm.py @@ -10,7 +10,13 @@ class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False): + def __init__( + self, + dim: int, + eps: float = 1e-6, + add_unit_offset: bool = False, + with_scale: bool = True, + ): """ Initialize the RMSNorm normalization layer. @@ -29,7 +35,11 @@ def __init__(self, dim: int, eps: float = 1e-6, add_unit_offset: bool = False): self.dim = dim self.eps = eps self.add_unit_offset = add_unit_offset - self.weight = nn.Parameter(torch.ones(dim)) + self.with_scale = with_scale + if with_scale: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.register_parameter("weight", None) def _norm(self, x): """ @@ -56,6 +66,8 @@ def forward(self, x): """ output = self._norm(x.float()).type_as(x) + if not self.with_scale: + return output if self.add_unit_offset: return output * (1.0 + self.weight.float()).type_as(x) return output * self.weight.type_as(x) diff --git a/examples/models/llama/rope.py b/examples/models/llama/rope.py index ea4e6b37243..54d01b8eb4c 100644 --- a/examples/models/llama/rope.py +++ b/examples/models/llama/rope.py @@ -165,6 +165,41 @@ def hf_precompute_freqs_cis( return freqs_cos, freqs_sin +def hf_precompute_proportional_freqs_cis( + dim: int, + end: int, + theta: float, + partial_rotary_factor: float = 1.0, + device: Union[str, torch.device] = "cpu", +): + rope_angles = int(partial_rotary_factor * dim // 2) + inv_freq_rotated = 1.0 / ( + theta + ** ( + torch.arange(0, 2 * rope_angles, 2, device=device, dtype=torch.int64).float() + / dim + ) + ) + nope_angles = dim // 2 - rope_angles + if nope_angles > 0: + inv_freq = torch.cat( + ( + inv_freq_rotated, + torch.zeros(nope_angles, dtype=torch.float32, device=device), + ), + dim=0, + ) + else: + inv_freq = inv_freq_rotated + + t = torch.arange(end, device=inv_freq.device, dtype=torch.int64).type_as(inv_freq) + freqs = torch.outer(t, inv_freq).float() + emb = torch.cat((freqs, freqs), dim=-1) + freqs_cos = torch.cos(emb) + freqs_sin = torch.sin(emb) + return freqs_cos, freqs_sin + + # Based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L135 def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -193,15 +228,25 @@ def hf_apply_rotary_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) + cos_q = cos.unsqueeze(unsqueeze_dim).to(dtype=q.dtype) + sin_q = sin.unsqueeze(unsqueeze_dim).to(dtype=q.dtype) + if k.dtype == q.dtype: + cos_k = cos_q + sin_k = sin_q + else: + cos_k = cos.unsqueeze(unsqueeze_dim).to(dtype=k.dtype) + sin_k = sin.unsqueeze(unsqueeze_dim).to(dtype=k.dtype) rotary_dim = cos.shape[-1] q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] - q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1) - k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1) + q_embed = torch.cat( + [(q_rot * cos_q) + (rotate_half(q_rot) * sin_q), q_pass], dim=-1 + ) + k_embed = torch.cat( + [(k_rot * cos_k) + (rotate_half(k_rot) * sin_k), k_pass], dim=-1 + ) return q_embed, k_embed @@ -224,8 +269,8 @@ def hf_apply_rotary_emb_to_k(k, cos, sin, position_ids=None, unsqueeze_dim=1): Returns: `torch.Tensor` the key tensor rotated using the Rotary Position Embedding. """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) + cos = cos.unsqueeze(unsqueeze_dim).to(dtype=k.dtype) + sin = sin.unsqueeze(unsqueeze_dim).to(dtype=k.dtype) k_embed = (k * cos) + (rotate_half(k) * sin) return k_embed @@ -234,6 +279,7 @@ class Rope(torch.nn.Module): def __init__(self, params: ModelArgs): super().__init__() self.params = params + self._layer_freq_buffer_names = {} # Choose the appropriate RoPE implementation if self.params.use_hf_rope: @@ -243,6 +289,7 @@ def __init__(self, params: ModelArgs): device=getattr(self.params, "device", "cpu"), ) self.apply_rotary_emb = hf_apply_rotary_emb + self.apply_rotary_emb_to_tensor = hf_apply_rotary_emb_to_k else: self.precompute_freqs_cis = partial( precompute_freqs_cis, @@ -252,6 +299,7 @@ def __init__(self, params: ModelArgs): device=getattr(self.params, "device", "cpu"), ) self.apply_rotary_emb = RotaryEmbedding() + self.apply_rotary_emb_to_tensor = apply_rotary_emb_to_k # Precompute frequencies freqs_cos, freqs_sin = self.precompute_freqs_cis( @@ -265,6 +313,97 @@ def __init__(self, params: ModelArgs): ) self.register_buffer("freqs_cos", freqs_cos, persistent=False) self.register_buffer("freqs_sin", freqs_sin, persistent=False) + self._register_layer_type_freqs() + + def _register_layer_type_freqs(self) -> None: + if not self.params.rope_parameters: + return + + max_context_len = ( + self.params.max_context_len + if self.params.ffn_dim_multiplier is None + else self.params.max_context_len * 2 + ) + device = getattr(self.params, "device", "cpu") + for layer_type, rope_params in self.params.rope_parameters.items(): + if rope_params is None: + continue + rope_type = rope_params.get("rope_type", "default") + head_dim = ( + self.params.global_head_dim + if layer_type == "full_attention" + else self.params.head_dim + ) + rope_theta = rope_params.get( + "rope_theta", + self.params.local_rope_theta + if layer_type == "sliding_attention" + else self.params.rope_freq_base, + ) + partial_rotary_factor = rope_params.get( + "partial_rotary_factor", self.params.partial_rotary_factor + ) + + if self.params.use_hf_rope: + if rope_type == "proportional": + freqs_cos, freqs_sin = hf_precompute_proportional_freqs_cis( + head_dim, + max_context_len, + rope_theta, + partial_rotary_factor=partial_rotary_factor, + device=device, + ) + else: + freqs_cos, freqs_sin = hf_precompute_freqs_cis( + head_dim, + max_context_len, + rope_theta, + partial_rotary_factor=partial_rotary_factor, + device=device, + ) + else: + freqs_cos, freqs_sin = precompute_freqs_cis( + head_dim, + max_context_len, + rope_theta, + use_scaled=self.params.use_scaled_rope, + scale_factor=self.params.rope_scale_factor, + high_freq_factor=self.params.high_freq_factor, + device=device, + ) + + cos_name = f"{layer_type}_freqs_cos" + sin_name = f"{layer_type}_freqs_sin" + self.register_buffer(cos_name, freqs_cos, persistent=False) + self.register_buffer(sin_name, freqs_sin, persistent=False) + self._layer_freq_buffer_names[layer_type] = (cos_name, sin_name) + + def _slice_freqs( + self, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + input_pos: Optional[torch.Tensor], + seq_len: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if self.params.use_kv_cache: + assert ( + input_pos is not None + ), "input_pos must be provided when use_kv_cache is True" + + if self.params.enable_dynamic_shape: + input_pos_item = input_pos[-1].item() + torch._check_is_size(input_pos_item) + torch._check(input_pos_item < self.params.max_context_len) + freqs_cos = freqs_cos.narrow(0, input_pos_item, seq_len) + freqs_sin = freqs_sin.narrow(0, input_pos_item, seq_len) + else: + freqs_cos = freqs_cos[input_pos] + freqs_sin = freqs_sin[input_pos] + else: + assert input_pos is None, "input_pos is unused when use_kv_cache is False" + freqs_cos = freqs_cos[:seq_len] + freqs_sin = freqs_sin[:seq_len] + return freqs_cos, freqs_sin def forward( self, @@ -275,6 +414,18 @@ def forward( ): return self.apply_rotary_emb(q, k, freqs_cos, freqs_sin) + def forward_to_tensor( + self, + x: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, + ) -> torch.Tensor: + if self.params.use_hf_rope: + return self.apply_rotary_emb_to_tensor( + x, freqs_cos, freqs_sin, unsqueeze_dim=1 + ) + return self.apply_rotary_emb_to_tensor(x, freqs_cos, freqs_sin) + def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int): """ Get the precomputed frequencies for the given input position and sequence length. @@ -286,33 +437,23 @@ def get_freqs(self, input_pos: Optional[torch.Tensor], seq_len: int): Returns: Tuple[torch.Tensor, torch.Tensor]: The precomputed frequencies for the given input position and sequence length. """ - if self.params.use_kv_cache: - assert ( - input_pos is not None - ), "input_pos must be provided when use_kv_cache is True" + return self._slice_freqs(self.freqs_cos, self.freqs_sin, input_pos, seq_len) - if self.params.enable_dynamic_shape: - # when KV cache is used, seqlen is most likely 1. We want to slice from the start_pos. - input_pos_item = input_pos[-1].item() - torch._check_is_size(input_pos_item) - torch._check(input_pos_item < self.params.max_context_len) - # pyre-ignore: Incompatible parameter type [6]: torch.narrow does expect int or Tensor - freqs_cos = self.freqs_cos.narrow(0, input_pos_item, seq_len) - # pyre-ignore: Incompatible parameter type [6] - freqs_sin = self.freqs_sin.narrow(0, input_pos_item, seq_len) - else: - # When not using dynamic shape, use of the .item results in - # symints, due to querying the data from tensor. - # this path avoids that for mps backend, although probably mps backend - # can support dynamic shape? - freqs_cos = self.freqs_cos[input_pos] - freqs_sin = self.freqs_sin[input_pos] - - else: - assert input_pos is None, "input_pos is unused when use_kv_cache is False" - freqs_cos = self.freqs_cos[:seq_len] - freqs_sin = self.freqs_sin[:seq_len] - return freqs_cos, freqs_sin + def get_freqs_for_layer_type( + self, + layer_type: Optional[str], + input_pos: Optional[torch.Tensor], + seq_len: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if ( + layer_type is None + or layer_type not in self._layer_freq_buffer_names + ): + return self.get_freqs(input_pos, seq_len) + cos_name, sin_name = self._layer_freq_buffer_names[layer_type] + return self._slice_freqs( + getattr(self, cos_name), getattr(self, sin_name), input_pos, seq_len + ) def get_freqs_using_indices(self, indices: torch.Tensor): """ diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py index c54e689ba8d..d9abc1a2c42 100644 --- a/examples/models/llama/source_transformation/sdpa.py +++ b/examples/models/llama/source_transformation/sdpa.py @@ -23,10 +23,12 @@ def __init__( self, dim: int, use_attention_mask: bool = False, + scale: float | None = None, ): super().__init__() self.dim = dim self.use_attention_mask = use_attention_mask + self.scale = scale def forward( self, @@ -58,6 +60,7 @@ def forward( mask, # Attention mask 0, # dropout probability. Ignored by the code False, # is_causal + self.scale, ) else: output = torch.ops.llama.custom_sdpa( @@ -68,6 +71,7 @@ def forward( None, # Attention mask 0, # dropout probability. Ignored by the code True, # is_causal + self.scale, ) return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype) @@ -83,6 +87,7 @@ def _replace_sdpa_with_custom_op( SDPACustom( child.dim, use_attention_mask=use_attention_mask, + scale=child.scale, ), ) else: @@ -118,7 +123,11 @@ class QuantizedSDPA(torch.nn.Module): """ def __init__( - self, dim: int, kv_cache: QuantizedKVCache, use_attention_mask: bool = False + self, + dim: int, + kv_cache: QuantizedKVCache, + use_attention_mask: bool = False, + scale: float | None = None, ): super().__init__() self.dim = dim @@ -126,6 +135,7 @@ def __init__( self.float_dtype = torch.float32 self.kv_cache = kv_cache self.use_attention_mask = use_attention_mask + self.scale = scale def forward( self, @@ -172,7 +182,7 @@ def forward( mask, 0, False, - None, + self.scale, q_zero_point_int8, q_scale_fp32, k_zero_point_int8, @@ -189,7 +199,7 @@ def forward( None, 0, True, - None, + self.scale, q_zero_point_int8, q_scale_fp32, k_zero_point_int8, @@ -208,7 +218,7 @@ def _update_attention_module_with_quantized_sdpa( assert sdpa is not None # TODO: add support for SDPA with attention mask # pyre-ignore - setattr(module, "SDPA", QuantizedSDPA(sdpa.dim, kv_cache)) # noqa: B010 + setattr(module, "SDPA", QuantizedSDPA(sdpa.dim, kv_cache, scale=sdpa.scale)) # noqa: B010 def _replace_sdpa_with_quantized_sdpa(module: torch.nn.Module): diff --git a/examples/models/llama/tests/BUCK b/examples/models/llama/tests/BUCK index c01fa9f2151..eae725b034a 100644 --- a/examples/models/llama/tests/BUCK +++ b/examples/models/llama/tests/BUCK @@ -127,3 +127,17 @@ fbcode_target(_kind = python_unittest, "//executorch/examples/models/llama:llama_transformer", ], ) + +fbcode_target(_kind = python_unittest, + name = "test_gemma4_support", + srcs = [ + "test_gemma4_support.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/examples/models/gemma4:gemma4", + "//executorch/examples/models/llama:llama_transformer", + "//executorch/extension/export_util:export_util", + "//executorch/extension/pybindings:portable_lib", + ], +) diff --git a/examples/models/llama/tests/test_gemma4_support.py b/examples/models/llama/tests/test_gemma4_support.py new file mode 100644 index 00000000000..6cb897dff16 --- /dev/null +++ b/examples/models/llama/tests/test_gemma4_support.py @@ -0,0 +1,438 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.examples.models.gemma4.convert_weights import gemma4_to_meta +from executorch.examples.models.llama.attention import ( + AttentionGemma4MHA, + KVCache, + RingKVCache, +) +from executorch.examples.models.llama.llama_transformer import ( + _get_kv_donor_layer_idx, + construct_transformer, +) +from executorch.examples.models.llama.model_args import ModelArgs +from executorch.examples.models.llama.rope import Rope +from executorch.extension.export_util.utils import export_to_edge +from executorch.extension.pybindings.portable_lib import _load_for_executorch_from_buffer + + +class _Gemma4NextTokenModule(torch.nn.Module): + def __init__(self, model: torch.nn.Module): + super().__init__() + self.model = model + + def forward(self, tokens: torch.Tensor) -> torch.Tensor: + return self.model(tokens=tokens) + + +class Gemma4SupportTest(unittest.TestCase): + def test_hf_rope_preserves_input_dtype(self): + args = ModelArgs( + dim=32, + hidden_dim=64, + n_layers=1, + n_heads=4, + n_kv_heads=2, + head_dim=8, + max_context_len=16, + use_hf_rope=True, + ) + rope = Rope(args) + + x = torch.randn(1, 3, 4, 8, dtype=torch.bfloat16) + freqs_cos, freqs_sin = rope.get_freqs(None, x.shape[1]) + + rotated = rope.forward_to_tensor(x, freqs_cos, freqs_sin) + + self.assertEqual(rotated.dtype, x.dtype) + + def test_dual_rope_tables_use_layer_specific_head_dims(self): + args = ModelArgs( + dim=64, + hidden_dim=128, + n_layers=4, + n_heads=4, + n_kv_heads=2, + head_dim=16, + global_head_dim=32, + max_context_len=32, + use_hf_rope=True, + rope_parameters={ + "sliding_attention": { + "rope_type": "default", + "rope_theta": 10000.0, + }, + "full_attention": { + "rope_type": "proportional", + "rope_theta": 1000000.0, + "partial_rotary_factor": 0.25, + }, + }, + ) + + rope = Rope(args) + sliding_cos, _ = rope.get_freqs_for_layer_type("sliding_attention", None, 3) + full_cos, _ = rope.get_freqs_for_layer_type("full_attention", None, 3) + + self.assertEqual(sliding_cos.shape, (3, 16)) + self.assertEqual(full_cos.shape, (3, 32)) + + def test_shared_layers_pick_same_type_donors(self): + layer_types = [ + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention", + ] + + self.assertEqual( + _get_kv_donor_layer_idx( + 4, n_layers=6, num_kv_shared_layers=2, layer_types=layer_types + ), + 2, + ) + self.assertEqual( + _get_kv_donor_layer_idx( + 5, n_layers=6, num_kv_shared_layers=2, layer_types=layer_types + ), + 3, + ) + + def test_gemma4_attention_uses_ring_cache_for_sliding_layers(self): + args = ModelArgs( + dim=32, + hidden_dim=64, + n_layers=4, + n_heads=4, + n_kv_heads=2, + head_dim=8, + global_head_dim=16, + max_batch_size=1, + max_context_len=16, + use_kv_cache=True, + enable_dynamic_shape=False, + attention_type="gemma4_mha", + layer_types=[ + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention", + ], + sliding_window=4, + ) + rope = Rope(args) + + sliding_attn = AttentionGemma4MHA(args, layer_id=0, rope=rope) + full_attn = AttentionGemma4MHA(args, layer_id=1, rope=rope) + + self.assertIsInstance(sliding_attn.kv_cache, RingKVCache) + self.assertIsInstance(full_attn.kv_cache, KVCache) + + def test_gemma4_attention_uses_unit_attention_scale(self): + args = ModelArgs( + dim=32, + hidden_dim=64, + n_layers=2, + n_heads=4, + n_kv_heads=2, + head_dim=8, + global_head_dim=16, + attention_type="gemma4_mha", + attention_multiplier=1.0, + layer_types=["sliding_attention", "full_attention"], + sliding_window=4, + use_kv_cache=True, + max_batch_size=1, + max_context_len=16, + ) + rope = Rope(args) + + sliding_attn = AttentionGemma4MHA(args, layer_id=0, rope=rope) + full_attn = AttentionGemma4MHA(args, layer_id=1, rope=rope) + + self.assertEqual(sliding_attn.attention_scale, 1.0) + self.assertEqual(full_attn.attention_scale, 1.0) + self.assertEqual(sliding_attn.SDPA.scale, 1.0) + self.assertEqual(full_attn.SDPA.scale, 1.0) + + def test_transformer_executes_shared_layers_and_softcaps_logits(self): + args = ModelArgs( + dim=32, + hidden_dim=64, + n_layers=4, + n_heads=4, + n_kv_heads=2, + head_dim=8, + global_head_dim=16, + vocab_size=64, + vocab_size_per_layer_input=64, + hidden_size_per_layer_input=4, + num_kv_shared_layers=2, + use_double_wide_mlp=True, + act_fn="gelu_pytorch_tanh", + norm_eps=1e-6, + post_attention_norm=True, + post_ffn_norm=True, + apply_embedding=True, + embedding_scale_factor=1.5, + use_hf_rope=True, + attention_type="gemma4_mha", + final_logit_softcapping=2.0, + layer_types=[ + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention", + ], + sliding_window=4, + max_batch_size=1, + max_seq_len=16, + max_context_len=16, + use_kv_cache=True, + generate_full_logits=True, + enable_dynamic_shape=False, + rope_parameters={ + "sliding_attention": { + "rope_type": "default", + "rope_theta": 10000.0, + }, + "full_attention": { + "rope_type": "proportional", + "rope_theta": 1000000.0, + "partial_rotary_factor": 0.25, + }, + }, + ) + + torch.manual_seed(0) + model = construct_transformer(args) + self.assertEqual(model.layers[0].feed_forward.w1.weight.shape[0], 64) + self.assertEqual(model.layers[2].feed_forward.w1.weight.shape[0], 128) + + shared_layer_calls = [] + + def _make_hook(layer_idx): + def _hook(_module, _inputs, _outputs): + shared_layer_calls.append(layer_idx) + + return _hook + + hooks = [ + model.layers[2].register_forward_hook(_make_hook(2)), + model.layers[3].register_forward_hook(_make_hook(3)), + ] + try: + tokens = torch.tensor([[1, 2, 3]], dtype=torch.long) + logits = model( + tokens=tokens, + attn_options={"input_pos": torch.tensor([0, 1, 2], dtype=torch.long)}, + ) + finally: + for hook in hooks: + hook.remove() + + self.assertEqual(shared_layer_calls, [2, 3]) + self.assertEqual(logits.shape, (1, 3, 64)) + self.assertLessEqual(float(logits.detach().abs().max()), 2.0001) + + def test_per_layer_token_embeddings_are_scaled_before_mix(self): + args = ModelArgs( + dim=16, + hidden_dim=32, + n_layers=2, + n_heads=4, + n_kv_heads=2, + head_dim=4, + vocab_size=16, + vocab_size_per_layer_input=16, + hidden_size_per_layer_input=4, + attention_type="gemma4_mha", + max_seq_len=8, + max_context_len=8, + ) + + model = construct_transformer(args) + model.per_layer_model_projection.weight.data.zero_() + model.per_layer_projection_norm.weight.data.fill_(1.0) + model.embed_tokens_per_layer.weight.data.zero_() + model.embed_tokens_per_layer.weight.data[1].fill_(1.0) + + captured = {} + + def _capture_forward_layers( + h, freqs_cos, freqs_sin, attn_options_, seqlen, per_layer_inputs=None + ): + captured["per_layer_inputs"] = per_layer_inputs.detach().clone() + return h, None + + model._forward_layers = _capture_forward_layers + _ = model(tokens=torch.tensor([[1]], dtype=torch.long)) + + expected_scale = (args.hidden_size_per_layer_input**0.5) * (2.0**-0.5) + expected = torch.full( + (1, 1, args.n_layers, args.hidden_size_per_layer_input), + expected_scale, + ) + self.assertTrue(torch.allclose(captured["per_layer_inputs"], expected)) + + def test_gemma4_layer_scalar_scales_block_output(self): + args = ModelArgs( + dim=16, + hidden_dim=32, + n_layers=1, + n_heads=4, + n_kv_heads=2, + head_dim=4, + vocab_size=32, + attention_type="gemma4_mha", + max_seq_len=8, + max_context_len=8, + ) + + torch.manual_seed(0) + model = construct_transformer(args) + layer = model.layers[0] + self.assertTrue(hasattr(layer, "layer_scalar")) + + tokens = torch.tensor([[1, 2, 3]], dtype=torch.long) + base_logits = model(tokens=tokens) + + layer.layer_scalar.zero_() + scaled_logits = model(tokens=tokens) + + self.assertGreater(float(base_logits.detach().abs().max()), 0.0) + self.assertTrue(torch.allclose(scaled_logits, torch.zeros_like(scaled_logits))) + + def test_tiny_gemma4_export_runtime_matches_eager(self): + args = ModelArgs( + dim=32, + hidden_dim=64, + n_layers=4, + n_heads=4, + n_kv_heads=2, + head_dim=8, + global_head_dim=16, + vocab_size=32, + vocab_size_per_layer_input=32, + hidden_size_per_layer_input=4, + num_kv_shared_layers=2, + act_fn="gelu_pytorch_tanh", + norm_eps=1e-6, + post_attention_norm=True, + post_ffn_norm=True, + apply_embedding=True, + embedding_scale_factor=1.5, + use_hf_rope=True, + attention_type="gemma4_mha", + attention_multiplier=1.0, + final_logit_softcapping=2.0, + layer_types=[ + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention", + ], + sliding_window=4, + max_batch_size=1, + max_seq_len=16, + max_context_len=16, + generate_full_logits=False, + rope_parameters={ + "sliding_attention": { + "rope_type": "default", + "rope_theta": 10000.0, + }, + "full_attention": { + "rope_type": "proportional", + "rope_theta": 1000000.0, + "partial_rotary_factor": 0.25, + }, + }, + ) + + torch.manual_seed(0) + model = _Gemma4NextTokenModule(construct_transformer(args)).eval() + tokens = torch.tensor([[1, 2, 3, 4]], dtype=torch.long) + + edge_program = export_to_edge(model, (tokens,), strict=True, verbose=False) + executorch_program = edge_program.to_executorch() + executorch_model = _load_for_executorch_from_buffer(executorch_program.buffer) + + with torch.no_grad(): + eager_logits = model(tokens) + executorch_logits = executorch_model.run_method("forward", (tokens,))[0] + + self.assertEqual(eager_logits.shape, executorch_logits.shape) + self.assertTrue( + torch.allclose(eager_logits, executorch_logits, rtol=1e-4, atol=1e-4) + ) + self.assertEqual( + int(torch.argmax(eager_logits, dim=-1).item()), + int(torch.argmax(executorch_logits, dim=-1).item()), + ) + + +class Gemma4ConvertWeightsTest(unittest.TestCase): + def test_maps_text_and_per_layer_weights_from_multimodal_checkpoint(self): + state_dict = { + "model.language_model.embed_tokens.weight": torch.randn(16, 8), + "model.language_model.embed_tokens_per_layer.weight": torch.randn(16, 12), + "model.language_model.per_layer_model_projection.weight": torch.randn(12, 8), + "model.language_model.per_layer_projection_norm.weight": torch.randn(4), + "model.language_model.norm.weight": torch.randn(8), + "model.language_model.layers.0.input_layernorm.weight": torch.randn(8), + "model.language_model.layers.0.self_attn.q_proj.weight": torch.randn(16, 8), + "model.language_model.layers.0.self_attn.k_proj.weight": torch.randn(8, 8), + "model.language_model.layers.0.self_attn.v_proj.weight": torch.randn(8, 8), + "model.language_model.layers.0.self_attn.o_proj.weight": torch.randn(8, 8), + "model.language_model.layers.0.self_attn.q_norm.weight": torch.randn(4), + "model.language_model.layers.0.self_attn.k_norm.weight": torch.randn(4), + "model.language_model.layers.0.self_attn.v_norm.weight": torch.randn(4), + "model.language_model.layers.0.post_attention_layernorm.weight": torch.randn(8), + "model.language_model.layers.0.pre_feedforward_layernorm.weight": torch.randn(8), + "model.language_model.layers.0.post_feedforward_layernorm.weight": torch.randn(8), + "model.language_model.layers.0.mlp.gate_proj.weight": torch.randn(12, 8), + "model.language_model.layers.0.mlp.down_proj.weight": torch.randn(8, 12), + "model.language_model.layers.0.mlp.up_proj.weight": torch.randn(12, 8), + "model.language_model.layers.0.layer_scalar": torch.ones(1), + "model.language_model.layers.0.per_layer_input_gate.weight": torch.randn(4, 8), + "model.language_model.layers.0.per_layer_projection.weight": torch.randn(8, 4), + "model.language_model.layers.0.post_per_layer_input_norm.weight": torch.randn(8), + "model.vision_tower.weight": torch.randn(8, 8), + } + + converted = gemma4_to_meta(state_dict) + + self.assertIn("tok_embeddings.weight", converted) + self.assertIn("embed_tokens_per_layer.weight", converted) + self.assertIn("per_layer_model_projection.weight", converted) + self.assertIn("layers.0.attention.wq.weight", converted) + self.assertIn("layers.0.post_attention_norm.weight", converted) + self.assertIn("layers.0.layer_scalar", converted) + self.assertIn("layers.0.per_layer_projection.weight", converted) + self.assertIn("output.weight", converted) + self.assertNotIn("layers.0.attention.v_norm_fn.weight", converted) + + def test_raises_on_unexpected_text_key(self): + state_dict = { + "model.language_model.embed_tokens.weight": torch.randn(16, 8), + "model.language_model.layers.0.unknown.weight": torch.randn(8, 8), + } + + with self.assertRaisesRegex( + ValueError, "Unexpected checkpoint key not mapped for Gemma4 export" + ): + gemma4_to_meta(state_dict) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/model_factory.py b/examples/models/model_factory.py index 5b66aef8de7..41bb4203d6a 100644 --- a/examples/models/model_factory.py +++ b/examples/models/model_factory.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import importlib -import os from typing import Any, Dict, Tuple import torch @@ -34,10 +33,8 @@ def create_model( Raises: ValueError: If the provided model class is not found in the module. """ - package_prefix = "executorch." if not os.getcwd().endswith("executorch") else "" - module = importlib.import_module( - f"{package_prefix}examples.models.{module_name}" - ) + package_root = __package__ or "executorch.examples.models" + module = importlib.import_module(f"{package_root}.{module_name}") if hasattr(module, model_class_name): model_class = getattr(module, model_class_name) diff --git a/examples/models/test/BUCK b/examples/models/test/BUCK new file mode 100644 index 00000000000..ceb1d49fb8e --- /dev/null +++ b/examples/models/test/BUCK @@ -0,0 +1,15 @@ +load("@fbcode_macros//build_defs:build_file_migration.bzl", "fbcode_target", "non_fbcode_target") +load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest") + +oncall("executorch") + +fbcode_target(_kind = python_unittest, + name = "test_model_factory", + srcs = [ + "test_model_factory.py", + ], + deps = [ + "//caffe2:torch", + "//executorch/examples/models:models", + ], +) diff --git a/examples/models/test/test_model_factory.py b/examples/models/test/test_model_factory.py new file mode 100644 index 00000000000..3ab9452b0c3 --- /dev/null +++ b/examples/models/test/test_model_factory.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import types +import unittest +from unittest.mock import patch + +import torch + +from executorch.examples.models.model_factory import EagerModelFactory + + +class _FakeModel(torch.nn.Module): + def __init__(self, **kwargs): + super().__init__() + self.kwargs = kwargs + + def get_eager_model(self): + return self + + def get_example_inputs(self): + return (torch.ones(1),) + + def get_example_kwarg_inputs(self): + return {"input_pos": torch.tensor([0])} + + def get_dynamic_shapes(self): + return {"tokens": None} + + +class ModelFactoryTest(unittest.TestCase): + def test_create_model_imports_from_package_root(self) -> None: + fake_module = types.SimpleNamespace(AddModule=_FakeModel) + + with patch( + "executorch.examples.models.model_factory.importlib.import_module", + return_value=fake_module, + ) as mock_import: + model, example_inputs, example_kwarg_inputs, dynamic_shapes = ( + EagerModelFactory.create_model("toy_model", "AddModule", foo="bar") + ) + + mock_import.assert_called_once_with("executorch.examples.models.toy_model") + self.assertEqual(model.kwargs, {"foo": "bar"}) + self.assertEqual(len(example_inputs), 1) + self.assertTrue(torch.equal(example_inputs[0], torch.ones(1))) + self.assertEqual(set(example_kwarg_inputs.keys()), {"input_pos"}) + self.assertTrue( + torch.equal(example_kwarg_inputs["input_pos"], torch.tensor([0])) + ) + self.assertEqual(dynamic_shapes, {"tokens": None}) + + def test_create_model_loads_real_toy_model(self) -> None: + model, example_inputs, example_kwarg_inputs, dynamic_shapes = ( + EagerModelFactory.create_model("toy_model", "AddModule") + ) + + self.assertEqual(type(model).__name__, "AddModule") + self.assertEqual(len(example_inputs), 2) + self.assertIsNone(example_kwarg_inputs) + self.assertIsNone(dynamic_shapes) diff --git a/exir/_serialize/_flatbuffer.py b/exir/_serialize/_flatbuffer.py index 77d0d073907..3dc58cb4271 100644 --- a/exir/_serialize/_flatbuffer.py +++ b/exir/_serialize/_flatbuffer.py @@ -12,10 +12,10 @@ import re import shutil import subprocess - import tempfile from dataclasses import dataclass +from pathlib import Path from typing import Callable, Dict, List, Optional, Sequence # If this environment variable is set to true, save the flatc input files when @@ -125,7 +125,19 @@ def __init__(self, resource_names: Sequence[str]) -> None: # Map each name to its contents. self._files: Dict[str, bytes] = {} for name in resource_names: - self._files[name] = importlib.resources.read_binary(__package__, name) + self._files[name] = self._read_binary_resource(name) + + @staticmethod + def _read_binary_resource(name: str) -> bytes: + try: + return importlib.resources.read_binary(__package__, name) + except FileNotFoundError: + # Editable/source-tree usage may not have copied the schemas into + # exir/_serialize yet. Fall back to the source schema directory. + repo_schema = Path(__file__).resolve().parents[2] / "schema" / name + if repo_schema.exists(): + return repo_schema.read_bytes() + raise def patch_files(self, patch_fn: Callable[[bytes], bytes]) -> None: """Uses the provided patching function to update the contents of all diff --git a/exir/_serialize/test/test_flatbuffer.py b/exir/_serialize/test/test_flatbuffer.py index 801ddca112d..3ee0b49bf5f 100644 --- a/exir/_serialize/test/test_flatbuffer.py +++ b/exir/_serialize/test/test_flatbuffer.py @@ -11,6 +11,7 @@ import shutil import tempfile import unittest +from pathlib import Path from typing import Dict, Optional, Sequence from unittest.mock import patch @@ -72,6 +73,22 @@ def test_load_patch_and_write(self) -> None: read_file(out_dir, "resource-2"), b"resource-2 data PATCHED" ) + def test_falls_back_to_source_schema_when_package_resource_missing(self) -> None: + expected = ( + Path(_flatbuffer.__file__).resolve().parents[2] / "schema" / "program.fbs" + ).read_bytes() + + with patch.object( + _flatbuffer.importlib.resources, + "read_binary", + side_effect=FileNotFoundError, + ): + rf = _ResourceFiles(("program.fbs",)) + + with tempfile.TemporaryDirectory() as out_dir: + rf.write_to(out_dir) + self.assertEqual(read_file(out_dir, "program.fbs"), expected) + # Fake resource files to use when testing alignment-patching. SCHEMA_FILES: Dict[str, bytes] = { diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index e126ef54456..606df6794de 100644 --- a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -47,6 +47,8 @@ class ModelType(str, Enum): qwen3_5_0_8b = "qwen3_5_0_8b" qwen3_5_2b = "qwen3_5_2b" qwen3_5_4b = "qwen3_5_4b" + gemma4_e2b = "gemma4_e2b" + gemma4_e4b = "gemma4_e4b" phi_4_mini = "phi_4_mini" smollm2 = "smollm2" lfm2_350m = "lfm2_350m"