diff --git a/examples/models/qwen3_5_moe/export.py b/examples/models/qwen3_5_moe/export.py index 7437bc5f461..19a720a2e79 100644 --- a/examples/models/qwen3_5_moe/export.py +++ b/examples/models/qwen3_5_moe/export.py @@ -107,6 +107,8 @@ def load_prequantized_model(prequantized_dir, max_seq_len=4096): # Any missing weight key indicates a version mismatch between the # checkpoint and the model (e.g., unfused vs fused projections). runtime_prefixes = ( + ".mask", + ".inv_freq", ".kv_cache.", ".conv_state", ".recurrent_state", @@ -312,10 +314,11 @@ def _materialize_buffers(model, config): """Materialize meta-device buffers before torch.export. Replaces meta buffers with real tensors on CPU, recomputes RoPE - inv_freq and causal masks. + inv_freq and causal masks. State buffers (KV cache, conv/recurrent + state) are zero-initialized registered buffers that will be shared + across methods via share_mutable_buffers. """ - # State buffers (KV cache, conv/recurrent state) are bf16 to match - # compute dtype. Masks stay bool, inv_freq stays float32. + # Masks stay bool, inv_freq stays float32. for fqn, buf in list(model.named_buffers()): if buf.device.type == "meta": dtype = torch.bfloat16 if buf.dtype != torch.bool else torch.bool @@ -378,7 +381,18 @@ def _apply_turboquant(model, config): def export_and_lower(model, config, args): - """Export model to .pte via torch.export + CUDA backend.""" + """Export model to .pte via torch.export + CUDA backend. + + Exports two methods: + - "decode": decode path (T=1), uses native PyTorch recurrent FLA + so AOTI can fuse with surrounding ops for maximum decode throughput. + - "prefill": prefill path (T>=2), uses chunked FLA triton_op with + dynamic sequence length. + + Both methods share mutable state buffers (KV cache, conv_state, + recurrent_state) via share_mutable_buffers=True. The model uses + registered buffers with in-place updates — no state in/out args. 
+ """ import torch._inductor.config as inductor_config from executorch.backends.cuda.cuda_backend import CudaBackend @@ -398,25 +412,39 @@ def export_and_lower(model, config, args): # -O0 compiles ~8x faster than -O1 with no measurable runtime impact. inductor_config.aot_inductor.compile_wrapper_opt_level = "O0" - # Dynamic shapes - example_tokens = torch.tensor([[0, 1]], dtype=torch.long) - example_input_pos = torch.tensor([0, 1], dtype=torch.long) - seq_dim = Dim("seq_len", min=1, max=config.max_seq_len - 1) - dynamic_shapes = ({1: seq_dim}, {0: seq_dim}) - - print("Exporting with torch.export...") + # --- Decode method (T=1, static shape) --- + print("Exporting decode method...") + decode_tokens = torch.tensor([[0]], dtype=torch.long) + decode_pos = torch.tensor([0], dtype=torch.long) + with torch.no_grad(): + decode_ep = export( + model, + (decode_tokens, decode_pos), + strict=True, + ) + print("Decode export successful!") + + # --- Prefill method (T>=2, dynamic shape) --- + print("Exporting prefill method...") + prefill_tokens = torch.tensor([[0, 1]], dtype=torch.long) + prefill_pos = torch.tensor([0, 1], dtype=torch.long) + seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1) + prefill_dynamic_shapes = ( + {1: seq_dim}, # tokens + {0: seq_dim}, # input_pos + ) with torch.no_grad(): - exported = export( + prefill_ep = export( model, - (example_tokens, example_input_pos), - dynamic_shapes=dynamic_shapes, + (prefill_tokens, prefill_pos), + dynamic_shapes=prefill_dynamic_shapes, strict=True, ) - print("Export successful!") + print("Prefill export successful!") - # Lower with CUDA backend + # Lower with CUDA backend (per-method partitioners to avoid so_blob collision) print("Lowering to ExecuTorch with CUDA...") - compile_specs = [CudaBackend.generate_method_name_compile_spec("forward")] + metadata = { "get_max_seq_len": config.max_seq_len, "get_vocab_size": config.vocab_size, @@ -426,8 +454,19 @@ def export_and_lower(model, config, args): 
"enable_dynamic_shape": True, } et_prog = to_edge_transform_and_lower( - exported, - partitioner=[CudaPartitioner(compile_specs)], + {"decode": decode_ep, "prefill": prefill_ep}, + partitioner={ + "decode": [ + CudaPartitioner( + [CudaBackend.generate_method_name_compile_spec("decode")] + ) + ], + "prefill": [ + CudaPartitioner( + [CudaBackend.generate_method_name_compile_spec("prefill")] + ) + ], + }, compile_config=EdgeCompileConfig( _check_ir_validity=False, _skip_dim_order=True, @@ -438,7 +477,11 @@ def export_and_lower(model, config, args): config=ExecutorchBackendConfig( extract_delegate_segments=True, do_quant_fusion_and_const_prop=True, - memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + memory_planning_pass=MemoryPlanningPass( + alloc_graph_input=False, + share_mutable_buffers=True, + ), + emit_mutable_buffer_names=True, ), ) diff --git a/examples/models/qwen3_5_moe/main.cpp b/examples/models/qwen3_5_moe/main.cpp index 266d0e65419..c327d9e91fd 100644 --- a/examples/models/qwen3_5_moe/main.cpp +++ b/examples/models/qwen3_5_moe/main.cpp @@ -8,11 +8,15 @@ #include -#include +#include +#include +#include #include +#include #include #include +#include #include #include @@ -24,6 +28,13 @@ DEFINE_double(temperature, 0.8, "Sampling temperature (0 = greedy)."); DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate."); namespace llm = ::executorch::extension::llm; +using ::executorch::extension::from_blob; +using ::executorch::extension::Module; +using ::executorch::extension::TensorPtr; +using ::executorch::runtime::Error; +using ::executorch::runtime::EValue; + +using SizesType = executorch::aten::SizesType; int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); @@ -37,11 +48,6 @@ int main(int argc, char** argv) { return 1; } - std::vector data_files; - if (!FLAGS_data_path.empty()) { - data_files.push_back(FLAGS_data_path); - } - // Load tokenizer auto tokenizer = std::make_unique(); auto tok_status = 
tokenizer->load(FLAGS_tokenizer_path); @@ -53,25 +59,169 @@ int main(int argc, char** argv) { return 1; } - // Create LLM runner - auto runner = llm::create_text_llm_runner( - FLAGS_model_path, std::move(tokenizer), data_files, FLAGS_temperature); + // Create Module with share_memory_arenas=true so prefill and decode + // share mutable buffers (KV cache, conv_state, recurrent_state). + std::vector data_files; + if (!FLAGS_data_path.empty()) { + data_files.push_back(FLAGS_data_path); + } + auto module = std::make_unique( + FLAGS_model_path, + data_files, + Module::LoadMode::File, + /*event_tracer=*/nullptr, + /*memory_allocator=*/nullptr, + /*temp_allocator=*/nullptr, + /*share_memory_arenas=*/true); + + // Get metadata + auto metadata_result = llm::get_llm_metadata(tokenizer.get(), module.get()); + if (metadata_result.error() != Error::Ok) { + ET_LOG(Error, "Failed to get metadata from model"); + return 1; + } + auto metadata = metadata_result.get(); + + printf("Loading methods...\n"); - if (runner == nullptr) { - ET_LOG(Error, "Failed to create runner"); + // Load both methods + auto err = module->load_method("prefill"); + if (err != Error::Ok) { + ET_LOG(Error, "Failed to load prefill method"); + return 1; + } + err = module->load_method("decode"); + if (err != Error::Ok) { + ET_LOG(Error, "Failed to load decode method"); return 1; } - // Generate - llm::GenerationConfig config; - config.temperature = FLAGS_temperature; - config.max_new_tokens = FLAGS_max_new_tokens; + // Get EOS ids + auto eos_ids = llm::get_eos_ids(tokenizer.get(), module.get()); - auto error = runner->generate(FLAGS_prompt.c_str(), config); - if (error != executorch::runtime::Error::Ok) { - ET_LOG(Error, "Generation failed"); + // Encode prompt + auto encode_result = tokenizer->encode(FLAGS_prompt); + if (!encode_result.ok()) { + ET_LOG(Error, "Failed to encode prompt"); return 1; } + auto prompt_tokens = std::move(*encode_result); + int64_t num_prompt_tokens = prompt_tokens.size(); + 
printf("Prompt tokens: %ld\n", num_prompt_tokens); + + // --------------------------------------------------------------- + // Prefill — process all prompt tokens at once + // --------------------------------------------------------------- + std::vector pos_data(num_prompt_tokens); + for (int64_t i = 0; i < num_prompt_tokens; i++) { + pos_data[i] = i; + } + + auto S = [](int64_t v) -> SizesType { return static_cast(v); }; + std::vector token_data(prompt_tokens.begin(), prompt_tokens.end()); + auto tokens_tensor = from_blob( + token_data.data(), + {1, S(num_prompt_tokens)}, + executorch::aten::ScalarType::Long); + auto pos_tensor = from_blob( + pos_data.data(), + {S(num_prompt_tokens)}, + executorch::aten::ScalarType::Long); + + std::vector prefill_inputs; + prefill_inputs.push_back(tokens_tensor); + prefill_inputs.push_back(pos_tensor); + + auto prefill_start = std::chrono::steady_clock::now(); + auto prefill_result = module->execute("prefill", prefill_inputs); + if (prefill_result.error() != Error::Ok) { + ET_LOG(Error, "Prefill failed"); + return 1; + } + auto& prefill_outputs = prefill_result.get(); + auto prefill_end = std::chrono::steady_clock::now(); + + auto logits_tensor = prefill_outputs[0].toTensor(); + auto logits_ptr = + std::make_shared(std::move(logits_tensor)); + uint64_t cur_token = llm::logits_to_token(*logits_ptr, FLAGS_temperature); + + double prefill_ms = + std::chrono::duration(prefill_end - prefill_start) + .count(); + printf( + "Prefill: %ld tokens in %.1f ms (%.1f tok/s)\n", + num_prompt_tokens, + prefill_ms, + num_prompt_tokens * 1000.0 / prefill_ms); + + // --------------------------------------------------------------- + // Decode — generate tokens one at a time + // --------------------------------------------------------------- + llm::Stats stats; + int64_t pos = num_prompt_tokens; + uint64_t prev_token; + + std::vector decode_token_data = {static_cast(cur_token)}; + std::vector decode_pos_data = {pos}; + auto decode_tokens = from_blob( 
+ decode_token_data.data(), {1, 1}, executorch::aten::ScalarType::Long); + auto decode_pos = from_blob( + decode_pos_data.data(), {1}, executorch::aten::ScalarType::Long); + + auto decode_start = std::chrono::steady_clock::now(); + + for (int32_t step = 0; step < FLAGS_max_new_tokens; step++) { + decode_token_data[0] = static_cast(cur_token); + decode_pos_data[0] = pos; + + std::vector decode_inputs; + decode_inputs.push_back(EValue(decode_tokens)); + decode_inputs.push_back(EValue(decode_pos)); + + auto decode_result = module->execute("decode", decode_inputs); + if (decode_result.error() != Error::Ok) { + ET_LOG(Error, "Decode step %d failed", step); + return 1; + } + auto& decode_outputs = decode_result.get(); + + auto step_logits = decode_outputs[0].toTensor(); + auto step_logits_ptr = + std::make_shared(std::move(step_logits)); + + prev_token = cur_token; + stats.on_sampling_begin(); + cur_token = llm::logits_to_token(*step_logits_ptr, FLAGS_temperature); + stats.on_sampling_end(); + + pos++; + + auto decode_str = tokenizer->decode(prev_token, cur_token); + if (decode_str.ok()) { + printf("%s", decode_str->c_str()); + fflush(stdout); + } + + if (eos_ids.find(cur_token) != eos_ids.end()) { + printf("\n"); + break; + } + } + + auto decode_end = std::chrono::steady_clock::now(); + + printf("\n"); + int64_t num_generated = pos - num_prompt_tokens; + double decode_ms = + std::chrono::duration(decode_end - decode_start) + .count(); + printf( + "Decode: %ld tokens in %.1f ms (%.1f tok/s)\n", + num_generated, + decode_ms, + num_generated * 1000.0 / decode_ms); + printf("Prompt tokens: %ld\n", num_prompt_tokens); return 0; } diff --git a/examples/models/qwen3_5_moe/model.py b/examples/models/qwen3_5_moe/model.py index d9f127d9ed1..751915fb123 100644 --- a/examples/models/qwen3_5_moe/model.py +++ b/examples/models/qwen3_5_moe/model.py @@ -350,6 +350,12 @@ def __init__(self, config): ) def forward(self, x, input_pos): + """GatedDeltaNet with trace-time dispatch. 
+ + When traced with T=1: uses native PyTorch recurrent delta rule + (AOTI fuses with surrounding ops for maximum decode throughput). + When traced with T>1: uses chunked FLA via triton_op. + """ B, T, _ = x.size() # Reset state at position 0 @@ -406,13 +412,43 @@ def forward(self, x, input_pos): beta = b.sigmoid() g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) - # FLA Triton kernel (returns final_state separately, does not mutate initial_state) - output, state = torch.ops.triton.chunk_gated_delta_rule( - q, k, v, g, beta, self.recurrent_state[:B] - ) + if T == 1: + # Native recurrent delta rule — AOTI fuses with surrounding ops + scale = self.head_k_dim**-0.5 - with torch.no_grad(): - self.recurrent_state[:B].copy_(state) + q_s = q[:, 0].float() # [B, H, K] + k_s = k[:, 0].float() # [B, H, K] + v_s = v[:, 0].float() # [B, H, V] + g_s = g[:, 0] # [B, H] + beta_s = beta[:, 0] # [B, H] + + state = self.recurrent_state[:B].float() # [B, H, K, V] + + # Decay state by exp(g) + decay = torch.exp(g_s).unsqueeze(-1).unsqueeze(-1) # [B, H, 1, 1] + state = state * decay + + # Sk = state @ k (project state by key) + Sk = torch.einsum("bhkv,bhk->bhv", state, k_s) + + # Delta rule state update + delta = beta_s.unsqueeze(-1) * (v_s - Sk) # [B, H, V] + state = state + torch.einsum("bhk,bhv->bhkv", k_s, delta) + + # Output = state @ q * scale + output = torch.einsum("bhkv,bhk->bhv", state, q_s) * scale + output = output.unsqueeze(1).to(q.dtype) # [B, 1, H, V] + + with torch.no_grad(): + self.recurrent_state[:B].copy_(state.to(self.recurrent_state.dtype)) + else: + # Chunked FLA triton_op for prefill + output, new_state = torch.ops.triton.chunk_gated_delta_rule( + q, k, v, g, beta, self.recurrent_state[:B] + ) + + with torch.no_grad(): + self.recurrent_state[:B].copy_(new_state) # Output: RMSNorm(output) * silu(z) output = output.reshape(-1, self.head_v_dim)