diff --git a/docs/guides/checkpointing_solutions/convert_checkpoint.md b/docs/guides/checkpointing_solutions/convert_checkpoint.md index 97a15616bf..c896a22407 100644 --- a/docs/guides/checkpointing_solutions/convert_checkpoint.md +++ b/docs/guides/checkpointing_solutions/convert_checkpoint.md @@ -16,6 +16,7 @@ The following models are supported: | **Mixtral** | 8x7B, 8x22B | √ | √ | √ | √ | | **GPT-OSS** | 20B, 120B | √ | √ | √ | √ | | **DeepSeek3** | 671B | - | - | √ | - | +| **Qwen3 Next** | 80B | √ | √ | √ | √ | ## Prerequisites diff --git a/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py b/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py index d91b7987ca..f21eb0db32 100644 --- a/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py +++ b/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py @@ -701,6 +701,49 @@ }, ) +qwen3_next_80b_a3b_dict = { + "architectures": ["Qwen3NextForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "decoder_sparse_step": 1, + "eos_token_id": 151645, + "full_attention_interval": 4, + "head_dim": 256, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 5120, + "linear_conv_kernel_dim": 4, + "linear_key_head_dim": 128, + "linear_num_key_heads": 16, + "linear_num_value_heads": 32, + "linear_value_head_dim": 128, + "max_position_embeddings": 262144, + "mlp_only_layers": [], + "model_type": "qwen3_next", + "moe_intermediate_size": 512, + "norm_topk_prob": True, + "num_attention_heads": 16, + "num_experts": 512, + "num_experts_per_tok": 10, + "num_hidden_layers": 48, + "num_key_value_heads": 2, + "output_router_logits": False, + "partial_rotary_factor": 0.25, + "rms_norm_eps": 1e-06, + "rope_scaling": None, + "rope_theta": 10000000, + "router_aux_loss_coef": 0.001, + "shared_expert_intermediate_size": 512, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.57.0.dev0", + "use_cache": True, + "use_sliding_window": False, + "vocab_size": 151936, +} +qwen3_next_80b_a3b_config = transformers.Qwen3NextConfig(**qwen3_next_80b_a3b_dict) + # from https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/config.json mixtral_8x7b_dict = { @@ -789,6 +832,7 @@ "gpt-oss-20b": gpt_oss_20b_config, "gpt-oss-120b": gpt_oss_120b_config, "qwen3-omni-30b-a3b": qwen3_omni_30b_a3b_config, + "qwen3-next-80b-a3b": qwen3_next_80b_a3b_config, "mixtral-8x7b": mixtral_8x7b_config, "mixtral-8x22b": mixtral_8x22b_config, } diff --git a/src/maxtext/checkpoint_conversion/utils/hf_shape.py b/src/maxtext/checkpoint_conversion/utils/hf_shape.py index 081017dd96..cc49293f8d 100644 --- a/src/maxtext/checkpoint_conversion/utils/hf_shape.py +++ b/src/maxtext/checkpoint_conversion/utils/hf_shape.py @@ -349,6 +349,102 @@ def DEEPSEEK_HF_WEIGHTS_TO_SHAPE(config): return mapping +def QWEN3_NEXT_HF_WEIGHTS_TO_SHAPE(config): + """Returns mapping between HuggingFace Qwen3-Next weights path and their shape.""" + # --- Extract Core Config Values --- + hidden_size = config["hidden_size"] + num_hidden_layers = config["num_hidden_layers"] + vocab_size = config["vocab_size"] + num_attention_heads = config["num_attention_heads"] + num_key_value_heads = config["num_key_value_heads"] + num_experts = config["num_experts"] + head_dim = config["head_dim"] + linear_conv_kernel_dim = config["linear_conv_kernel_dim"] + linear_key_head_dim = config["linear_key_head_dim"] + linear_num_key_heads = config["linear_num_key_heads"] + linear_num_value_heads = 
config["linear_num_value_heads"] + moe_intermediate_size = config["moe_intermediate_size"] + shared_expert_intermediate_size = config["shared_expert_intermediate_size"] + cycle_interval = config["full_attention_interval"] + + # --- Calculated Values --- + q_dim = num_attention_heads * head_dim + kv_dim = num_key_value_heads * head_dim + + linear_k_dim = linear_num_key_heads * linear_key_head_dim + linear_v_dim = linear_num_value_heads * head_dim + conv_dim = 2 * linear_k_dim + linear_v_dim + qkvz_dim = 2 * linear_k_dim + 2 * linear_v_dim + ba_dim = 2 * linear_num_value_heads + + # --- Initialize Mapping --- + mapping = { + "model.embed_tokens.weight": [vocab_size, hidden_size], + "model.norm.weight": [hidden_size], + "lm_head.weight": [vocab_size, hidden_size], + } + + for layer_idx in range(num_hidden_layers): + layer_prefix = f"model.layers.{layer_idx}" + + # Standard Layer Norms + mapping[f"{layer_prefix}.input_layernorm.weight"] = [hidden_size] + mapping[f"{layer_prefix}.post_attention_layernorm.weight"] = [hidden_size] + + is_full_attention_layer = (layer_idx + 1) % cycle_interval == 0 + + if is_full_attention_layer: + # Full Attention Block + mapping.update( + { + f"{layer_prefix}.self_attn.q_proj.weight": [2 * q_dim, hidden_size], + f"{layer_prefix}.self_attn.k_proj.weight": [kv_dim, hidden_size], + f"{layer_prefix}.self_attn.v_proj.weight": [kv_dim, hidden_size], + f"{layer_prefix}.self_attn.o_proj.weight": [hidden_size, q_dim], + f"{layer_prefix}.self_attn.q_norm.weight": [head_dim], + f"{layer_prefix}.self_attn.k_norm.weight": [head_dim], + } + ) + else: + # Linear Attention (GDN) Block + mapping.update( + { + f"{layer_prefix}.linear_attn.in_proj_qkvz.weight": [qkvz_dim, hidden_size], + f"{layer_prefix}.linear_attn.in_proj_ba.weight": [ba_dim, hidden_size], + f"{layer_prefix}.linear_attn.conv1d.weight": [conv_dim, 1, linear_conv_kernel_dim], + f"{layer_prefix}.linear_attn.A_log": [linear_num_value_heads], + f"{layer_prefix}.linear_attn.dt_bias": [linear_num_value_heads], + f"{layer_prefix}.linear_attn.norm.weight": [head_dim], + f"{layer_prefix}.linear_attn.out_proj.weight": [hidden_size, linear_v_dim], + } + ) + + # --- MLP Logic (MoE + Shared) --- + mapping.update( + { + # Router + f"{layer_prefix}.mlp.gate.weight": [num_experts, hidden_size], + # Shared Experts (SwiGLU - Separate Weights) + f"{layer_prefix}.mlp.shared_expert.gate_proj.weight": [shared_expert_intermediate_size, hidden_size], + f"{layer_prefix}.mlp.shared_expert.up_proj.weight": [shared_expert_intermediate_size, hidden_size], + f"{layer_prefix}.mlp.shared_expert.down_proj.weight": [hidden_size, shared_expert_intermediate_size], + # Shared Expert Gate (learned scaling factor) + f"{layer_prefix}.mlp.shared_expert_gate.weight": [1, hidden_size], + } + ) + + # Routed Experts Loop + # Note: HF typically stores experts as a ModuleList + for e in range(num_experts): + mapping.update( + { + f"{layer_prefix}.mlp.experts.{e}.gate_proj.weight": [moe_intermediate_size, hidden_size], + f"{layer_prefix}.mlp.experts.{e}.up_proj.weight": [moe_intermediate_size, hidden_size], + f"{layer_prefix}.mlp.experts.{e}.down_proj.weight": [hidden_size, moe_intermediate_size], + } + ) + + def GPT_OSS_HF_WEIGHTS_TO_SHAPE(config): """Returns mapping between HuggingFace GptOss weights path and their shape.""" # --- Extract Core Config Values --- diff --git a/src/maxtext/checkpoint_conversion/utils/param_mapping.py b/src/maxtext/checkpoint_conversion/utils/param_mapping.py index d4f7317969..0bb0adebd7 100644 --- 
a/src/maxtext/checkpoint_conversion/utils/param_mapping.py +++ b/src/maxtext/checkpoint_conversion/utils/param_mapping.py @@ -792,6 +792,242 @@ def reshape_kernel(input_tensor, target_shape): return mapping +def QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False): + """ + Returns mapping from MaxText to HuggingFace Qwen3-Next weight paths. + All MaxText keys start with 'params-' and use '-' separators for scanned layers. + """ + num_main_layers = config["num_hidden_layers"] + num_experts = config["num_experts"] + layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval + + # 1. Non-layer specific weight mappings + mapping = { + "params-token_embedder-embedding": "model.embed_tokens.weight", + "params-decoder-decoder_norm-scale": "model.norm.weight", + "params-decoder-logits_dense-kernel": "lm_head.weight", + } + + if scan_layers: + # 2. Scan over block cycles + for block_idx in range(layer_cycle_interval): + hf_indices = list(range(block_idx, num_main_layers, layer_cycle_interval)) + prefix = f"params-decoder-layers-layer_{block_idx}" + + # Layer norms + mapping[f"{prefix}-input_layernorm-scale"] = [f"model.layers.{i}.input_layernorm.weight" for i in hf_indices] + mapping[f"{prefix}-post_attention_layernorm-scale"] = [ + f"model.layers.{i}.post_attention_layernorm.weight" for i in hf_indices + ] + + # Handle Interleaved Attention (Linear vs Full) + is_full_attention_layer = (block_idx + 1) % layer_cycle_interval == 0 + + if is_full_attention_layer: + mapping.update( + { + f"{prefix}-attention-attention-query-kernel": [ + f"model.layers.{i}.self_attn.q_proj.weight" for i in hf_indices + ], + f"{prefix}-attention-attention-key-kernel": [ + f"model.layers.{i}.self_attn.k_proj.weight" for i in hf_indices + ], + f"{prefix}-attention-attention-value-kernel": [ + f"model.layers.{i}.self_attn.v_proj.weight" for i in hf_indices + ], + f"{prefix}-attention-attention-out-kernel": [ + f"model.layers.{i}.self_attn.o_proj.weight" for i in hf_indices + ], + f"{prefix}-attention-attention-query_norm-scale": [ + f"model.layers.{i}.self_attn.q_norm.weight" for i in hf_indices + ], + f"{prefix}-attention-attention-key_norm-scale": [ + f"model.layers.{i}.self_attn.k_norm.weight" for i in hf_indices + ], + } + ) + else: + # Linear/Hybrid Attention Block + mapping.update( + { + f"{prefix}-attention-in_proj_qkvz-kernel": [ + f"model.layers.{i}.linear_attn.in_proj_qkvz.weight" for i in hf_indices + ], + f"{prefix}-attention-in_proj_ba-kernel": [ + f"model.layers.{i}.linear_attn.in_proj_ba.weight" for i in hf_indices + ], + f"{prefix}-attention-conv1d-kernel": [f"model.layers.{i}.linear_attn.conv1d.weight" for i in hf_indices], + f"{prefix}-attention-A_log": [f"model.layers.{i}.linear_attn.A_log" for i in hf_indices], + f"{prefix}-attention-dt_bias": [f"model.layers.{i}.linear_attn.dt_bias" for i in hf_indices], + f"{prefix}-attention-norm-rms_norm-scale": [ + f"model.layers.{i}.linear_attn.norm.weight" for i in hf_indices + ], + f"{prefix}-attention-out_proj-kernel": [ + f"model.layers.{i}.linear_attn.out_proj.weight" for i in hf_indices + ], + } + ) + + # 3. 
Handle MLP: Gates and Shared Experts + mapping.update( + { + f"{prefix}-mlp-routed_experts-gate-kernel": [f"model.layers.{i}.mlp.gate.weight" for i in hf_indices], + f"{prefix}-mlp-shared_expert-wi_0-kernel": [ + f"model.layers.{i}.mlp.shared_expert.gate_proj.weight" for i in hf_indices + ], + f"{prefix}-mlp-shared_expert-wi_1-kernel": [ + f"model.layers.{i}.mlp.shared_expert.up_proj.weight" for i in hf_indices + ], + f"{prefix}-mlp-shared_expert-wo-kernel": [ + f"model.layers.{i}.mlp.shared_expert.down_proj.weight" for i in hf_indices + ], + f"{prefix}-mlp-shared_expert_gate-kernel": [ + f"model.layers.{i}.mlp.shared_expert_gate.weight" for i in hf_indices + ], + } + ) + + # 4. Handle MoE Routed Experts + mapping.update( + { + f"{prefix}-mlp-routed_experts-wi_0": [ + [f"model.layers.{i}.mlp.experts.{e}.gate_proj.weight" for i in hf_indices] for e in range(num_experts) + ], + f"{prefix}-mlp-routed_experts-wi_1": [ + [f"model.layers.{i}.mlp.experts.{e}.up_proj.weight" for i in hf_indices] for e in range(num_experts) + ], + f"{prefix}-mlp-routed_experts-wo": [ + [f"model.layers.{i}.mlp.experts.{e}.down_proj.weight" for i in hf_indices] for e in range(num_experts) + ], + } + ) + else: + # Unscanned layer mapping + for i in range(num_main_layers): + prefix = f"params-decoder-layers_{i}" + + # Layer Norms + mapping[f"{prefix}-input_layernorm-scale"] = f"model.layers.{i}.input_layernorm.weight" + mapping[f"{prefix}-post_attention_layernorm-scale"] = f"model.layers.{i}.post_attention_layernorm.weight" + + # Determine layer type based on cycle interval + # Assuming block logic: layer i corresponds to block_idx = i % interval + block_idx = i % layer_cycle_interval + is_full_attention_layer = (block_idx + 1) % layer_cycle_interval == 0 + + if is_full_attention_layer: + mapping.update( + { + f"{prefix}-attention-attention-query-kernel": f"model.layers.{i}.self_attn.q_proj.weight", + f"{prefix}-attention-attention-key-kernel": f"model.layers.{i}.self_attn.k_proj.weight", + f"{prefix}-attention-attention-value-kernel": f"model.layers.{i}.self_attn.v_proj.weight", + f"{prefix}-attention-attention-out-kernel": f"model.layers.{i}.self_attn.o_proj.weight", + f"{prefix}-attention-attention-query_norm-scale": f"model.layers.{i}.self_attn.q_norm.weight", + f"{prefix}-attention-attention-key_norm-scale": f"model.layers.{i}.self_attn.k_norm.weight", + } + ) + else: + # Linear/Hybrid Attention Block + mapping.update( + { + f"{prefix}-attention-in_proj_qkvz-kernel": f"model.layers.{i}.linear_attn.in_proj_qkvz.weight", + f"{prefix}-attention-in_proj_ba-kernel": f"model.layers.{i}.linear_attn.in_proj_ba.weight", + f"{prefix}-attention-conv1d-kernel": f"model.layers.{i}.linear_attn.conv1d.weight", + f"{prefix}-attention-A_log": f"model.layers.{i}.linear_attn.A_log", + f"{prefix}-attention-dt_bias": f"model.layers.{i}.linear_attn.dt_bias", + f"{prefix}-attention-norm-rms_norm-scale": f"model.layers.{i}.linear_attn.norm.weight", + f"{prefix}-attention-out_proj-kernel": f"model.layers.{i}.linear_attn.out_proj.weight", + } + ) + + # MLP: Gates and Shared Experts + mapping.update( + { + f"{prefix}-mlp-routed_experts-gate-kernel": f"model.layers.{i}.mlp.gate.weight", + f"{prefix}-mlp-shared_expert-wi_0-kernel": f"model.layers.{i}.mlp.shared_expert.gate_proj.weight", + f"{prefix}-mlp-shared_expert-wi_1-kernel": f"model.layers.{i}.mlp.shared_expert.up_proj.weight", + f"{prefix}-mlp-shared_expert-wo-kernel": f"model.layers.{i}.mlp.shared_expert.down_proj.weight", + f"{prefix}-mlp-shared_expert_gate-kernel": 
f"model.layers.{i}.mlp.shared_expert_gate.weight", + } + ) + + # MoE Routed Experts (List of expert weights for this specific layer) + mapping.update( + { + f"{prefix}-mlp-routed_experts-wi_0": [ + f"model.layers.{i}.mlp.experts.{e}.gate_proj.weight" for e in range(num_experts) + ], + f"{prefix}-mlp-routed_experts-wi_1": [ + f"model.layers.{i}.mlp.experts.{e}.up_proj.weight" for e in range(num_experts) + ], + f"{prefix}-mlp-routed_experts-wo": [ + f"model.layers.{i}.mlp.experts.{e}.down_proj.weight" for e in range(num_experts) + ], + } + ) + return mapping + + +def QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False): + """ + Transformation hooks for parameters using hyphenated 'params-' MaxText keys. + """ + + def transpose(input_tensor, target_shape=None): + return input_tensor.T + + def reshape_kernel(input_tensor, target_shape): + if saving_to_hf: + flipped_target_shape = np.flip(np.array(target_shape)) + return input_tensor.reshape(flipped_target_shape).T + else: + return input_tensor.T.reshape(target_shape) + + def permute_conv(input_tensor, target_shape=None): + # MT: [K, 1, C] <-> HF: [C, 1, K] + return input_tensor.transpose(2, 1, 0) + + # Initialize Hooks + hooks = { + "params-decoder-logits_dense-kernel": transpose, + } + + layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval + num_main_layers = config["num_hidden_layers"] + loop_indices = range(layer_cycle_interval) if scan_layers else range(num_main_layers) + + for i in loop_indices: + if scan_layers: + prefix = f"params-decoder-layers-layer_{i}" + block_idx = i + else: + prefix = f"params-decoder-layers_{i}" + block_idx = i % layer_cycle_interval + is_full_attention_layer = (block_idx + 1) % layer_cycle_interval == 0 + + if is_full_attention_layer: + for key in ["query", "key", "value", "out"]: + hooks[f"{prefix}-attention-attention-{key}-kernel"] = reshape_kernel + else: + hooks[f"{prefix}-attention-in_proj_qkvz-kernel"] = transpose + hooks[f"{prefix}-attention-in_proj_ba-kernel"] = transpose + hooks[f"{prefix}-attention-out_proj-kernel"] = transpose + hooks[f"{prefix}-attention-conv1d-kernel"] = permute_conv + + mlp_prefix = f"{prefix}-mlp" + hooks[f"{mlp_prefix}-routed_experts-gate-kernel"] = transpose + hooks[f"{mlp_prefix}-shared_expert-wi_0-kernel"] = transpose + hooks[f"{mlp_prefix}-shared_expert-wi_1-kernel"] = transpose + hooks[f"{mlp_prefix}-shared_expert-wo-kernel"] = transpose + hooks[f"{mlp_prefix}-shared_expert_gate-kernel"] = transpose + + hooks[f"{mlp_prefix}-routed_experts-wi_0"] = transpose + hooks[f"{mlp_prefix}-routed_experts-wi_1"] = transpose + hooks[f"{mlp_prefix}-routed_experts-wo"] = transpose + + return hooks + + def DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False): """Generates a parameter mapping from MaxText to HuggingFace Deepseek weight paths. 
@@ -2098,6 +2334,7 @@ def pad_hf_embedding_layer(input_tensor, target_shape): "gpt-oss-20b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING, "gpt-oss-120b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING, "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING, + "qwen3-next-80b-a3b": QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING, "mixtral-8x7b": MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING, "mixtral-8x22b": MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING, "olmo3-7b": OLMO3_MAXTEXT_TO_HF_PARAM_MAPPING, @@ -2129,6 +2366,7 @@ def pad_hf_embedding_layer(input_tensor, target_shape): "gpt-oss-20b": GPT_OSS_TO_HF_PARAM_HOOK_FN, "gpt-oss-120b": GPT_OSS_TO_HF_PARAM_HOOK_FN, "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "qwen3-next-80b-a3b": QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN, "mixtral-8x7b": MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN, "mixtral-8x22b": MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN, "olmo3-7b": OLMO3_MAXTEXT_TO_HF_PARAM_HOOK_FN, diff --git a/src/maxtext/checkpoint_conversion/utils/utils.py b/src/maxtext/checkpoint_conversion/utils/utils.py index 5a0ecfe940..31607883f8 100644 --- a/src/maxtext/checkpoint_conversion/utils/utils.py +++ b/src/maxtext/checkpoint_conversion/utils/utils.py @@ -82,6 +82,7 @@ "gpt-oss-20b": "openai/gpt-oss-20b", "gpt-oss-120b": "openai/gpt-oss-120b", "qwen3-omni-30b-a3b": "Qwen/Qwen3-Omni-30B-A3B-Instruct", + "qwen3-next-80b-a3b": "Qwen/Qwen3-Next-80B-A3B-Instruct", "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1", "mixtral-8x22b": "mistralai/Mixtral-8x22B-Instruct-v0.1", "olmo3-7b": "allenai/Olmo-3-7B-Instruct",
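For reference (not part of the patch), a small standalone sketch of the dimension arithmetic and hybrid layer schedule that the Qwen3-Next mappings above encode, evaluated with values copied from qwen3_next_80b_a3b_dict. The "query plus output gate" reading of the 2 * q_dim q_proj width is an interpretation of the shape mapping, not something stated in the diff.

```python
cfg = {
    "hidden_size": 2048,
    "num_hidden_layers": 48,
    "num_attention_heads": 16,
    "head_dim": 256,
    "num_key_value_heads": 2,
    "full_attention_interval": 4,
    "linear_key_head_dim": 128,
    "linear_num_key_heads": 16,
    "linear_value_head_dim": 128,
    "linear_num_value_heads": 32,
}

# Gated DeltaNet (linear_attn) projection widths.
linear_k_dim = cfg["linear_num_key_heads"] * cfg["linear_key_head_dim"]      # 2048
linear_v_dim = cfg["linear_num_value_heads"] * cfg["linear_value_head_dim"]  # 4096
qkvz_dim = 2 * linear_k_dim + 2 * linear_v_dim  # 12288 -> in_proj_qkvz rows
ba_dim = 2 * cfg["linear_num_value_heads"]      # 64    -> in_proj_ba rows
conv_dim = 2 * linear_k_dim + linear_v_dim      # 8192  -> conv1d channels

# Full-attention projection widths (q_proj is 2 * q_dim wide: query plus output gate).
q_dim = cfg["num_attention_heads"] * cfg["head_dim"]   # 4096
kv_dim = cfg["num_key_value_heads"] * cfg["head_dim"]  # 512

# Every full_attention_interval-th layer is full attention, the rest use linear_attn:
# layers 3, 7, ..., 47, i.e. 12 of the 48 layers.
full_attention_layers = [
    i for i in range(cfg["num_hidden_layers"]) if (i + 1) % cfg["full_attention_interval"] == 0
]
assert len(full_attention_layers) == cfg["num_hidden_layers"] // cfg["full_attention_interval"]
```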