@@ -16,6 +16,7 @@ The following models are supported:
| **Mixtral** | 8x7B, 8x22B | √ | √ | √ | √ |
| **GPT-OSS** | 20B, 120B | √ | √ | √ | √ |
| **DeepSeek3** | 671B | - | - | √ | - |
| **Qwen3 Next** | 80B | √ | √ | √ | √ |

## Prerequisites

44 changes: 44 additions & 0 deletions src/maxtext/checkpoint_conversion/utils/hf_model_configs.py
@@ -701,6 +701,49 @@
},
)

qwen3_next_80b_a3b_dict = {
"architectures": ["Qwen3NextForCausalLM"],
"attention_dropout": 0.0,
"bos_token_id": 151643,
"decoder_sparse_step": 1,
"eos_token_id": 151645,
"full_attention_interval": 4,
"head_dim": 256,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 5120,
"linear_conv_kernel_dim": 4,
"linear_key_head_dim": 128,
"linear_num_key_heads": 16,
"linear_num_value_heads": 32,
"linear_value_head_dim": 128,
"max_position_embeddings": 262144,
"mlp_only_layers": [],
"model_type": "qwen3_next",
"moe_intermediate_size": 512,
"norm_topk_prob": True,
"num_attention_heads": 16,
"num_experts": 512,
"num_experts_per_tok": 10,
"num_hidden_layers": 48,
"num_key_value_heads": 2,
"output_router_logits": False,
"partial_rotary_factor": 0.25,
"rms_norm_eps": 1e-06,
"rope_scaling": None,
"rope_theta": 10000000,
"router_aux_loss_coef": 0.001,
"shared_expert_intermediate_size": 512,
"tie_word_embeddings": False,
"torch_dtype": "bfloat16",
"transformers_version": "4.57.0.dev0",
"use_cache": True,
"use_sliding_window": False,
"vocab_size": 151936,
}
qwen3_next_80b_a3b_config = transformers.Qwen3NextConfig(**qwen3_next_80b_a3b_dict)


# from https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/config.json
mixtral_8x7b_dict = {
@@ -789,6 +832,7 @@
"gpt-oss-20b": gpt_oss_20b_config,
"gpt-oss-120b": gpt_oss_120b_config,
"qwen3-omni-30b-a3b": qwen3_omni_30b_a3b_config,
"qwen3-next-80b-a3b": qwen3_next_80b_a3b_config,
"mixtral-8x7b": mixtral_8x7b_config,
"mixtral-8x22b": mixtral_8x22b_config,
}
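
As a quick orientation (not part of the diff), a minimal sketch of how the new config entry can be rebuilt and inspected; it assumes a recent transformers release that ships `Qwen3NextConfig` (the dict above was generated with 4.57.0.dev0) and reuses only names that appear in the hunk above:

```python
import transformers

# Sketch only: rebuilds the config from the dict added above and checks a few
# of the values that drive the shape mapping in hf_shape.py.
config = transformers.Qwen3NextConfig(**qwen3_next_80b_a3b_dict)

assert config.model_type == "qwen3_next"
assert config.full_attention_interval == 4  # one full-attention layer per block of 4
assert config.num_experts == 512 and config.num_experts_per_tok == 10
assert config.head_dim == 256 and config.linear_value_head_dim == 128
```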
96 changes: 96 additions & 0 deletions src/maxtext/checkpoint_conversion/utils/hf_shape.py
@@ -349,6 +349,102 @@ def DEEPSEEK_HF_WEIGHTS_TO_SHAPE(config):
return mapping


def QWEN3_NEXT_HF_WEIGHTS_TO_SHAPE(config):
  """Returns mapping between HuggingFace Qwen3-Next weights path and their shape."""
  # --- Extract Core Config Values ---
  hidden_size = config["hidden_size"]
  num_hidden_layers = config["num_hidden_layers"]
  vocab_size = config["vocab_size"]
  num_attention_heads = config["num_attention_heads"]
  num_key_value_heads = config["num_key_value_heads"]
  num_experts = config["num_experts"]
  head_dim = config["head_dim"]
  linear_conv_kernel_dim = config["linear_conv_kernel_dim"]
  linear_key_head_dim = config["linear_key_head_dim"]
  linear_num_key_heads = config["linear_num_key_heads"]
  linear_num_value_heads = config["linear_num_value_heads"]
  linear_value_head_dim = config["linear_value_head_dim"]
  moe_intermediate_size = config["moe_intermediate_size"]
  shared_expert_intermediate_size = config["shared_expert_intermediate_size"]
  cycle_interval = config["full_attention_interval"]

  # --- Calculated Values ---
  q_dim = num_attention_heads * head_dim
  kv_dim = num_key_value_heads * head_dim

  linear_k_dim = linear_num_key_heads * linear_key_head_dim
  linear_v_dim = linear_num_value_heads * linear_value_head_dim
  conv_dim = 2 * linear_k_dim + linear_v_dim
  qkvz_dim = 2 * linear_k_dim + 2 * linear_v_dim
  ba_dim = 2 * linear_num_value_heads
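  # For the 80B config above (head_dim=256, linear_key_head_dim=linear_value_head_dim=128,
  # 16 attention heads / 2 KV heads, 16 linear K heads / 32 linear V heads) these work out to:
  #   q_dim = 4096, kv_dim = 512, linear_k_dim = 2048, linear_v_dim = 4096,
  #   conv_dim = 8192, qkvz_dim = 12288, ba_dim = 64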

  # --- Initialize Mapping ---
  mapping = {
      "model.embed_tokens.weight": [vocab_size, hidden_size],
      "model.norm.weight": [hidden_size],
      "lm_head.weight": [vocab_size, hidden_size],
  }

  for layer_idx in range(num_hidden_layers):
    layer_prefix = f"model.layers.{layer_idx}"

    # Standard Layer Norms
    mapping[f"{layer_prefix}.input_layernorm.weight"] = [hidden_size]
    mapping[f"{layer_prefix}.post_attention_layernorm.weight"] = [hidden_size]

    is_full_attention_layer = (layer_idx + 1) % cycle_interval == 0
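    # With full_attention_interval = 4, layers 3, 7, 11, ... (0-indexed) use full
    # attention; the other layers in each cycle use the gated DeltaNet linear-attention block.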

    if is_full_attention_layer:
      # Full Attention Block
      mapping.update(
          {
              f"{layer_prefix}.self_attn.q_proj.weight": [2 * q_dim, hidden_size],
              f"{layer_prefix}.self_attn.k_proj.weight": [kv_dim, hidden_size],
              f"{layer_prefix}.self_attn.v_proj.weight": [kv_dim, hidden_size],
              f"{layer_prefix}.self_attn.o_proj.weight": [hidden_size, q_dim],
              f"{layer_prefix}.self_attn.q_norm.weight": [head_dim],
              f"{layer_prefix}.self_attn.k_norm.weight": [head_dim],
          }
      )
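      # q_proj is 2 * q_dim because the full-attention layers use gated attention:
      # the query projection and a per-head output gate are fused into one matrix.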
    else:
      # Linear Attention (GDN) Block
      mapping.update(
          {
              f"{layer_prefix}.linear_attn.in_proj_qkvz.weight": [qkvz_dim, hidden_size],
              f"{layer_prefix}.linear_attn.in_proj_ba.weight": [ba_dim, hidden_size],
              f"{layer_prefix}.linear_attn.conv1d.weight": [conv_dim, 1, linear_conv_kernel_dim],
              f"{layer_prefix}.linear_attn.A_log": [linear_num_value_heads],
              f"{layer_prefix}.linear_attn.dt_bias": [linear_num_value_heads],
              f"{layer_prefix}.linear_attn.norm.weight": [linear_value_head_dim],
              f"{layer_prefix}.linear_attn.out_proj.weight": [hidden_size, linear_v_dim],
          }
      )
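      # conv1d is depthwise (in_channels / groups == 1), hence the middle dimension of 1
      # in its [conv_dim, 1, linear_conv_kernel_dim] weight.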

    # --- MLP Logic (MoE + Shared) ---
    mapping.update(
        {
            # Router
            f"{layer_prefix}.mlp.gate.weight": [num_experts, hidden_size],
            # Shared Experts (SwiGLU - Separate Weights)
            f"{layer_prefix}.mlp.shared_expert.gate_proj.weight": [shared_expert_intermediate_size, hidden_size],
            f"{layer_prefix}.mlp.shared_expert.up_proj.weight": [shared_expert_intermediate_size, hidden_size],
            f"{layer_prefix}.mlp.shared_expert.down_proj.weight": [hidden_size, shared_expert_intermediate_size],
            # Shared Expert Gate (learned scaling factor)
            f"{layer_prefix}.mlp.shared_expert_gate.weight": [1, hidden_size],
        }
    )

    # Routed Experts Loop
    # Note: HF typically stores experts as a ModuleList
    for e in range(num_experts):
      mapping.update(
          {
              f"{layer_prefix}.mlp.experts.{e}.gate_proj.weight": [moe_intermediate_size, hidden_size],
              f"{layer_prefix}.mlp.experts.{e}.up_proj.weight": [moe_intermediate_size, hidden_size],
              f"{layer_prefix}.mlp.experts.{e}.down_proj.weight": [hidden_size, moe_intermediate_size],
          }
      )

  return mapping


def GPT_OSS_HF_WEIGHTS_TO_SHAPE(config):
"""Returns mapping between HuggingFace GptOss weights path and their shape."""
# --- Extract Core Config Values ---
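
To close, a hedged sanity-check sketch for the new shape mapping; it is not part of the PR, the import paths are inferred from the file headers above (they may differ in an installed package), and it relies on `QWEN3_NEXT_HF_WEIGHTS_TO_SHAPE` returning its `mapping` dict:

```python
# Sketch only: import paths follow the file paths shown in the diff headers above.
from maxtext.checkpoint_conversion.utils.hf_model_configs import qwen3_next_80b_a3b_dict
from maxtext.checkpoint_conversion.utils.hf_shape import QWEN3_NEXT_HF_WEIGHTS_TO_SHAPE

shapes = QWEN3_NEXT_HF_WEIGHTS_TO_SHAPE(qwen3_next_80b_a3b_dict)

# Layer 3 is the first full-attention layer (full_attention_interval = 4); its
# fused query+gate projection is [2 * 16 * 256, 2048].
assert shapes["model.layers.3.self_attn.q_proj.weight"] == [8192, 2048]

# Layer 0 uses the gated DeltaNet (linear-attention) block.
assert shapes["model.layers.0.linear_attn.in_proj_qkvz.weight"] == [12288, 2048]

# Every decoder layer carries the MoE router plus 512 routed experts.
assert shapes["model.layers.0.mlp.gate.weight"] == [512, 2048]
assert "model.layers.0.mlp.experts.511.down_proj.weight" in shapes
```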