Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 63 additions & 20 deletions examples/models/qwen3_5_moe/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ def load_prequantized_model(prequantized_dir, max_seq_len=4096):
# Any missing weight key indicates a version mismatch between the
# checkpoint and the model (e.g., unfused vs fused projections).
runtime_prefixes = (
".mask",
".inv_freq",
".kv_cache.",
".conv_state",
".recurrent_state",
Expand Down Expand Up @@ -312,10 +314,11 @@ def _materialize_buffers(model, config):
"""Materialize meta-device buffers before torch.export.

Replaces meta buffers with real tensors on CPU, recomputes RoPE
inv_freq and causal masks.
inv_freq and causal masks. State buffers (KV cache, conv/recurrent
state) are zero-initialized registered buffers that will be shared
across methods via share_mutable_buffers.
"""
# State buffers (KV cache, conv/recurrent state) are bf16 to match
# compute dtype. Masks stay bool, inv_freq stays float32.
# Masks stay bool, inv_freq stays float32.
for fqn, buf in list(model.named_buffers()):
if buf.device.type == "meta":
dtype = torch.bfloat16 if buf.dtype != torch.bool else torch.bool
Expand Down Expand Up @@ -378,7 +381,18 @@ def _apply_turboquant(model, config):


def export_and_lower(model, config, args):
"""Export model to .pte via torch.export + CUDA backend."""
"""Export model to .pte via torch.export + CUDA backend.

Exports two methods:
- "decode": decode path (T=1), uses native PyTorch recurrent FLA
so AOTI can fuse with surrounding ops for maximum decode throughput.
- "prefill": prefill path (T>=2), uses chunked FLA triton_op with
dynamic sequence length.

Both methods share mutable state buffers (KV cache, conv_state,
recurrent_state) via share_mutable_buffers=True. The model uses
registered buffers with in-place updates — no state in/out args.
"""
import torch._inductor.config as inductor_config

from executorch.backends.cuda.cuda_backend import CudaBackend
Expand All @@ -398,25 +412,39 @@ def export_and_lower(model, config, args):
# -O0 compiles ~8x faster than -O1 with no measurable runtime impact.
inductor_config.aot_inductor.compile_wrapper_opt_level = "O0"

# Dynamic shapes
example_tokens = torch.tensor([[0, 1]], dtype=torch.long)
example_input_pos = torch.tensor([0, 1], dtype=torch.long)
seq_dim = Dim("seq_len", min=1, max=config.max_seq_len - 1)
dynamic_shapes = ({1: seq_dim}, {0: seq_dim})

print("Exporting with torch.export...")
# --- Decode method (T=1, static shape) ---
print("Exporting decode method...")
decode_tokens = torch.tensor([[0]], dtype=torch.long)
decode_pos = torch.tensor([0], dtype=torch.long)
with torch.no_grad():
decode_ep = export(
model,
(decode_tokens, decode_pos),
strict=True,
)
print("Decode export successful!")

# --- Prefill method (T>=2, dynamic shape) ---
print("Exporting prefill method...")
prefill_tokens = torch.tensor([[0, 1]], dtype=torch.long)
prefill_pos = torch.tensor([0, 1], dtype=torch.long)
seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1)
prefill_dynamic_shapes = (
{1: seq_dim}, # tokens
{0: seq_dim}, # input_pos
)
with torch.no_grad():
exported = export(
prefill_ep = export(
model,
(example_tokens, example_input_pos),
dynamic_shapes=dynamic_shapes,
(prefill_tokens, prefill_pos),
dynamic_shapes=prefill_dynamic_shapes,
strict=True,
)
print("Export successful!")
print("Prefill export successful!")

# Lower with CUDA backend
# Lower with CUDA backend (per-method partitioners to avoid so_blob collision)
print("Lowering to ExecuTorch with CUDA...")
compile_specs = [CudaBackend.generate_method_name_compile_spec("forward")]

metadata = {
"get_max_seq_len": config.max_seq_len,
"get_vocab_size": config.vocab_size,
Expand All @@ -426,8 +454,19 @@ def export_and_lower(model, config, args):
"enable_dynamic_shape": True,
}
et_prog = to_edge_transform_and_lower(
exported,
partitioner=[CudaPartitioner(compile_specs)],
{"decode": decode_ep, "prefill": prefill_ep},
partitioner={
"decode": [
CudaPartitioner(
[CudaBackend.generate_method_name_compile_spec("decode")]
)
],
"prefill": [
CudaPartitioner(
[CudaBackend.generate_method_name_compile_spec("prefill")]
)
],
},
compile_config=EdgeCompileConfig(
_check_ir_validity=False,
_skip_dim_order=True,
Expand All @@ -438,7 +477,11 @@ def export_and_lower(model, config, args):
config=ExecutorchBackendConfig(
extract_delegate_segments=True,
do_quant_fusion_and_const_prop=True,
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
memory_planning_pass=MemoryPlanningPass(
alloc_graph_input=False,
share_mutable_buffers=True,
),
emit_mutable_buffer_names=True,
),
)

Expand Down
186 changes: 168 additions & 18 deletions examples/models/qwen3_5_moe/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@

#include <gflags/gflags.h>

#include <executorch/extension/llm/runner/text_llm_runner.h>
#include <executorch/extension/llm/runner/llm_runner_helper.h>
#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/sampler/util.h>
#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>
#include <executorch/runtime/platform/log.h>
#include <pytorch/tokenizers/hf_tokenizer.h>

#include <chrono>
#include <string>
#include <vector>

Expand All @@ -24,6 +28,13 @@ DEFINE_double(temperature, 0.8, "Sampling temperature (0 = greedy).");
DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");

namespace llm = ::executorch::extension::llm;
using ::executorch::extension::from_blob;
using ::executorch::extension::Module;
using ::executorch::extension::TensorPtr;
using ::executorch::runtime::Error;
using ::executorch::runtime::EValue;

using SizesType = executorch::aten::SizesType;

int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
Expand All @@ -37,11 +48,6 @@ int main(int argc, char** argv) {
return 1;
}

std::vector<std::string> data_files;
if (!FLAGS_data_path.empty()) {
data_files.push_back(FLAGS_data_path);
}

// Load tokenizer
auto tokenizer = std::make_unique<tokenizers::HFTokenizer>();
auto tok_status = tokenizer->load(FLAGS_tokenizer_path);
Expand All @@ -53,25 +59,169 @@ int main(int argc, char** argv) {
return 1;
}

// Create LLM runner
auto runner = llm::create_text_llm_runner(
FLAGS_model_path, std::move(tokenizer), data_files, FLAGS_temperature);
// Create Module with share_memory_arenas=true so the "prefill" and "decode"
// methods share mutable buffers (KV cache, conv_state, recurrent_state).
std::vector<std::string> data_files;
if (!FLAGS_data_path.empty()) {
data_files.push_back(FLAGS_data_path);
}
auto module = std::make_unique<Module>(
FLAGS_model_path,
data_files,
Module::LoadMode::File,
/*event_tracer=*/nullptr,
/*memory_allocator=*/nullptr,
/*temp_allocator=*/nullptr,
/*share_memory_arenas=*/true);

// Get metadata
auto metadata_result = llm::get_llm_metadata(tokenizer.get(), module.get());
if (metadata_result.error() != Error::Ok) {
ET_LOG(Error, "Failed to get metadata from model");
return 1;
}
auto metadata = metadata_result.get();

printf("Loading methods...\n");

if (runner == nullptr) {
ET_LOG(Error, "Failed to create runner");
// Load both methods
auto err = module->load_method("prefill");
if (err != Error::Ok) {
ET_LOG(Error, "Failed to load prefill method");
return 1;
}
err = module->load_method("decode");
if (err != Error::Ok) {
ET_LOG(Error, "Failed to load decode method");
return 1;
}

// Generate
llm::GenerationConfig config;
config.temperature = FLAGS_temperature;
config.max_new_tokens = FLAGS_max_new_tokens;
// Get EOS ids
auto eos_ids = llm::get_eos_ids(tokenizer.get(), module.get());

auto error = runner->generate(FLAGS_prompt.c_str(), config);
if (error != executorch::runtime::Error::Ok) {
ET_LOG(Error, "Generation failed");
// Encode prompt
auto encode_result = tokenizer->encode(FLAGS_prompt);
if (!encode_result.ok()) {
ET_LOG(Error, "Failed to encode prompt");
return 1;
}
auto prompt_tokens = std::move(*encode_result);
int64_t num_prompt_tokens = prompt_tokens.size();
printf("Prompt tokens: %ld\n", num_prompt_tokens);

// ---------------------------------------------------------------
// Prefill — process all prompt tokens at once
// ---------------------------------------------------------------
std::vector<int64_t> pos_data(num_prompt_tokens);
for (int64_t i = 0; i < num_prompt_tokens; i++) {
pos_data[i] = i;
}

auto S = [](int64_t v) -> SizesType { return static_cast<SizesType>(v); };
std::vector<int64_t> token_data(prompt_tokens.begin(), prompt_tokens.end());
auto tokens_tensor = from_blob(
token_data.data(),
{1, S(num_prompt_tokens)},
executorch::aten::ScalarType::Long);
auto pos_tensor = from_blob(
pos_data.data(),
{S(num_prompt_tokens)},
executorch::aten::ScalarType::Long);

std::vector<EValue> prefill_inputs;
prefill_inputs.push_back(tokens_tensor);
prefill_inputs.push_back(pos_tensor);

auto prefill_start = std::chrono::steady_clock::now();
auto prefill_result = module->execute("prefill", prefill_inputs);
if (prefill_result.error() != Error::Ok) {
ET_LOG(Error, "Prefill failed");
return 1;
}
auto& prefill_outputs = prefill_result.get();
auto prefill_end = std::chrono::steady_clock::now();

auto logits_tensor = prefill_outputs[0].toTensor();
auto logits_ptr =
std::make_shared<executorch::aten::Tensor>(std::move(logits_tensor));
uint64_t cur_token = llm::logits_to_token(*logits_ptr, FLAGS_temperature);

double prefill_ms =
std::chrono::duration<double, std::milli>(prefill_end - prefill_start)
.count();
printf(
"Prefill: %ld tokens in %.1f ms (%.1f tok/s)\n",
num_prompt_tokens,
prefill_ms,
num_prompt_tokens * 1000.0 / prefill_ms);

// ---------------------------------------------------------------
// Decode — generate tokens one at a time
// ---------------------------------------------------------------
llm::Stats stats;
int64_t pos = num_prompt_tokens;
uint64_t prev_token;

std::vector<int64_t> decode_token_data = {static_cast<int64_t>(cur_token)};
std::vector<int64_t> decode_pos_data = {pos};
auto decode_tokens = from_blob(
decode_token_data.data(), {1, 1}, executorch::aten::ScalarType::Long);
auto decode_pos = from_blob(
decode_pos_data.data(), {1}, executorch::aten::ScalarType::Long);

auto decode_start = std::chrono::steady_clock::now();

for (int32_t step = 0; step < FLAGS_max_new_tokens; step++) {
decode_token_data[0] = static_cast<int64_t>(cur_token);
decode_pos_data[0] = pos;

std::vector<EValue> decode_inputs;
decode_inputs.push_back(EValue(decode_tokens));
decode_inputs.push_back(EValue(decode_pos));

auto decode_result = module->execute("decode", decode_inputs);
if (decode_result.error() != Error::Ok) {
ET_LOG(Error, "Decode step %d failed", step);
return 1;
}
auto& decode_outputs = decode_result.get();

auto step_logits = decode_outputs[0].toTensor();
auto step_logits_ptr =
std::make_shared<executorch::aten::Tensor>(std::move(step_logits));

prev_token = cur_token;
stats.on_sampling_begin();
cur_token = llm::logits_to_token(*step_logits_ptr, FLAGS_temperature);
stats.on_sampling_end();

pos++;

auto decode_str = tokenizer->decode(prev_token, cur_token);
if (decode_str.ok()) {
printf("%s", decode_str->c_str());
fflush(stdout);
}

if (eos_ids.find(cur_token) != eos_ids.end()) {
printf("\n");
break;
}
}

auto decode_end = std::chrono::steady_clock::now();

printf("\n");
int64_t num_generated = pos - num_prompt_tokens;
double decode_ms =
std::chrono::duration<double, std::milli>(decode_end - decode_start)
.count();
printf(
"Decode: %ld tokens in %.1f ms (%.1f tok/s)\n",
num_generated,
decode_ms,
num_generated * 1000.0 / decode_ms);
printf("Prompt tokens: %ld\n", num_prompt_tokens);

return 0;
}
Loading
Loading