From f75c7e2bc80aefc797b07e4471b2fdb6d6823ec4 Mon Sep 17 00:00:00 2001
From: Rohan Bierneni
Date: Thu, 29 Jan 2026 18:00:32 +0000
Subject: [PATCH] Add Qwen3-Next to checkpoint util (Squashed)

Lower the decoding length to 128 instead of 512
Update test script
Add manual calcs for each hf shape instead of hardcoded values
Run pylint
Fix config values and remove unused vars
Fix /configs path in script
Fix decode path in script
Update model README
Update scripts to use new train path
Reset qwen3 test files to match main
Undo the temp fix to get training working
Update reshape function to what other models use
Remove whitespaces
---
 .../convert_checkpoint.md | 1 +
 .../utils/hf_model_configs.py | 44 ++++
 .../checkpoint_conversion/utils/hf_shape.py | 98 +++++++
 .../utils/param_mapping.py | 238 ++++++++++++++++++
 .../checkpoint_conversion/utils/utils.py | 1 +
 5 files changed, 382 insertions(+)

diff --git a/docs/guides/checkpointing_solutions/convert_checkpoint.md b/docs/guides/checkpointing_solutions/convert_checkpoint.md
index 97a15616bf..c896a22407 100644
--- a/docs/guides/checkpointing_solutions/convert_checkpoint.md
+++ b/docs/guides/checkpointing_solutions/convert_checkpoint.md
@@ -16,6 +16,7 @@ The following models are supported:
 | **Mixtral** | 8x7B, 8x22B | √ | √ | √ | √ |
 | **GPT-OSS** | 20B, 120B | √ | √ | √ | √ |
 | **DeepSeek3** | 671B | - | - | √ | - |
+| **Qwen3 Next** | 80B | √ | √ | √ | √ |
 
 ## Prerequisites
 
diff --git a/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py b/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py
index d91b7987ca..f21eb0db32 100644
--- a/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py
+++ b/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py
@@ -701,6 +701,49 @@
     },
 )
 
+qwen3_next_80b_a3b_dict = {
+    "architectures": ["Qwen3NextForCausalLM"],
+    "attention_dropout": 0.0,
+    "bos_token_id": 151643,
+    "decoder_sparse_step": 1,
+    "eos_token_id": 151645,
+    "full_attention_interval": 4,
+    "head_dim": 256,
+    "hidden_act": "silu",
+    "hidden_size": 2048,
+    "initializer_range": 0.02,
+    "intermediate_size": 5120,
+    "linear_conv_kernel_dim": 4,
+    "linear_key_head_dim": 128,
+    "linear_num_key_heads": 16,
+    "linear_num_value_heads": 32,
+    "linear_value_head_dim": 128,
+    "max_position_embeddings": 262144,
+    "mlp_only_layers": [],
+    "model_type": "qwen3_next",
+    "moe_intermediate_size": 512,
+    "norm_topk_prob": True,
+    "num_attention_heads": 16,
+    "num_experts": 512,
+    "num_experts_per_tok": 10,
+    "num_hidden_layers": 48,
+    "num_key_value_heads": 2,
+    "output_router_logits": False,
+    "partial_rotary_factor": 0.25,
+    "rms_norm_eps": 1e-06,
+    "rope_scaling": None,
+    "rope_theta": 10000000,
+    "router_aux_loss_coef": 0.001,
+    "shared_expert_intermediate_size": 512,
+    "tie_word_embeddings": False,
+    "torch_dtype": "bfloat16",
+    "transformers_version": "4.57.0.dev0",
+    "use_cache": True,
+    "use_sliding_window": False,
+    "vocab_size": 151936,
+}
+qwen3_next_80b_a3b_config = transformers.Qwen3NextConfig(**qwen3_next_80b_a3b_dict)
+
 
 # from https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/blob/main/config.json
 mixtral_8x7b_dict = {
@@ -789,6 +832,7 @@
     "gpt-oss-20b": gpt_oss_20b_config,
     "gpt-oss-120b": gpt_oss_120b_config,
     "qwen3-omni-30b-a3b": qwen3_omni_30b_a3b_config,
+    "qwen3-next-80b-a3b": qwen3_next_80b_a3b_config,
     "mixtral-8x7b": mixtral_8x7b_config,
     "mixtral-8x22b": mixtral_8x22b_config,
 }
diff --git a/src/maxtext/checkpoint_conversion/utils/hf_shape.py b/src/maxtext/checkpoint_conversion/utils/hf_shape.py
index 081017dd96..cc49293f8d 100644
--- a/src/maxtext/checkpoint_conversion/utils/hf_shape.py
+++ b/src/maxtext/checkpoint_conversion/utils/hf_shape.py
@@ -349,6 +349,104 @@ def DEEPSEEK_HF_WEIGHTS_TO_SHAPE(config):
   return mapping
 
 
+def QWEN3_NEXT_HF_WEIGHTS_TO_SHAPE(config):
+  """Returns mapping between HuggingFace Qwen3-Next weights path and their shape."""
+  # --- Extract Core Config Values ---
+  hidden_size = config["hidden_size"]
+  num_hidden_layers = config["num_hidden_layers"]
+  vocab_size = config["vocab_size"]
+  num_attention_heads = config["num_attention_heads"]
+  num_key_value_heads = config["num_key_value_heads"]
+  num_experts = config["num_experts"]
+  head_dim = config["head_dim"]
+  linear_conv_kernel_dim = config["linear_conv_kernel_dim"]
+  linear_key_head_dim = config["linear_key_head_dim"]
+  linear_num_key_heads = config["linear_num_key_heads"]
+  linear_num_value_heads = config["linear_num_value_heads"]
+  linear_value_head_dim = config["linear_value_head_dim"]
+  moe_intermediate_size = config["moe_intermediate_size"]
+  shared_expert_intermediate_size = config["shared_expert_intermediate_size"]
+  cycle_interval = config["full_attention_interval"]
+
+  # --- Calculated Values ---
+  q_dim = num_attention_heads * head_dim
+  kv_dim = num_key_value_heads * head_dim
+
+  linear_k_dim = linear_num_key_heads * linear_key_head_dim
+  linear_v_dim = linear_num_value_heads * linear_value_head_dim
+  conv_dim = 2 * linear_k_dim + linear_v_dim
+  qkvz_dim = 2 * linear_k_dim + 2 * linear_v_dim
+  ba_dim = 2 * linear_num_value_heads
+
+  # --- Initialize Mapping ---
+  mapping = {
+      "model.embed_tokens.weight": [vocab_size, hidden_size],
+      "model.norm.weight": [hidden_size],
+      "lm_head.weight": [vocab_size, hidden_size],
+  }
+
+  for layer_idx in range(num_hidden_layers):
+    layer_prefix = f"model.layers.{layer_idx}"
+
+    # Standard Layer Norms
+    mapping[f"{layer_prefix}.input_layernorm.weight"] = [hidden_size]
+    mapping[f"{layer_prefix}.post_attention_layernorm.weight"] = [hidden_size]
+
+    is_full_attention_layer = (layer_idx + 1) % cycle_interval == 0
+
+    if is_full_attention_layer:
+      # Full Attention Block
+      mapping.update(
+          {
+              f"{layer_prefix}.self_attn.q_proj.weight": [2 * q_dim, hidden_size],
+              f"{layer_prefix}.self_attn.k_proj.weight": [kv_dim, hidden_size],
+              f"{layer_prefix}.self_attn.v_proj.weight": [kv_dim, hidden_size],
+              f"{layer_prefix}.self_attn.o_proj.weight": [hidden_size, q_dim],
+              f"{layer_prefix}.self_attn.q_norm.weight": [head_dim],
+              f"{layer_prefix}.self_attn.k_norm.weight": [head_dim],
+          }
+      )
+    else:
+      # Linear Attention (GDN) Block
+      mapping.update(
+          {
+              f"{layer_prefix}.linear_attn.in_proj_qkvz.weight": [qkvz_dim, hidden_size],
+              f"{layer_prefix}.linear_attn.in_proj_ba.weight": [ba_dim, hidden_size],
+              f"{layer_prefix}.linear_attn.conv1d.weight": [conv_dim, 1, linear_conv_kernel_dim],
+              f"{layer_prefix}.linear_attn.A_log": [linear_num_value_heads],
+              f"{layer_prefix}.linear_attn.dt_bias": [linear_num_value_heads],
+              f"{layer_prefix}.linear_attn.norm.weight": [linear_value_head_dim],
+              f"{layer_prefix}.linear_attn.out_proj.weight": [hidden_size, linear_v_dim],
+          }
+      )
+
+    # --- MLP Logic (MoE + Shared) ---
+    mapping.update(
+        {
+            # Router
+            f"{layer_prefix}.mlp.gate.weight": [num_experts, hidden_size],
+            # Shared Experts (SwiGLU - Separate Weights)
+            f"{layer_prefix}.mlp.shared_expert.gate_proj.weight": [shared_expert_intermediate_size, hidden_size],
+            f"{layer_prefix}.mlp.shared_expert.up_proj.weight": [shared_expert_intermediate_size, hidden_size],
+            f"{layer_prefix}.mlp.shared_expert.down_proj.weight": [hidden_size, shared_expert_intermediate_size],
+            # Shared Expert Gate (learned scaling factor)
+            f"{layer_prefix}.mlp.shared_expert_gate.weight": [1, hidden_size],
+        }
+    )
+
+    # Routed Experts Loop
+    # Note: HF typically stores experts as a ModuleList
+    for e in range(num_experts):
+      mapping.update(
+          {
+              f"{layer_prefix}.mlp.experts.{e}.gate_proj.weight": [moe_intermediate_size, hidden_size],
+              f"{layer_prefix}.mlp.experts.{e}.up_proj.weight": [moe_intermediate_size, hidden_size],
+              f"{layer_prefix}.mlp.experts.{e}.down_proj.weight": [hidden_size, moe_intermediate_size],
+          }
+      )
+  return mapping
+
+
 def GPT_OSS_HF_WEIGHTS_TO_SHAPE(config):
   """Returns mapping between HuggingFace GptOss weights path and their shape."""
   # --- Extract Core Config Values ---
diff --git a/src/maxtext/checkpoint_conversion/utils/param_mapping.py b/src/maxtext/checkpoint_conversion/utils/param_mapping.py
index d4f7317969..0bb0adebd7 100644
--- a/src/maxtext/checkpoint_conversion/utils/param_mapping.py
+++ b/src/maxtext/checkpoint_conversion/utils/param_mapping.py
@@ -792,6 +792,242 @@ def reshape_kernel(input_tensor, target_shape):
   return mapping
 
 
+def QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
+  """
+  Returns mapping from MaxText to HuggingFace Qwen3-Next weight paths.
+  All MaxText keys start with 'params-' and use '-' separators for scanned layers.
+  """
+  num_main_layers = config["num_hidden_layers"]
+  num_experts = config["num_experts"]
+  layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval
+
+  # 1. Non-layer specific weight mappings
+  mapping = {
+      "params-token_embedder-embedding": "model.embed_tokens.weight",
+      "params-decoder-decoder_norm-scale": "model.norm.weight",
+      "params-decoder-logits_dense-kernel": "lm_head.weight",
+  }
+
+  if scan_layers:
+    # 2. Scan over block cycles
+    for block_idx in range(layer_cycle_interval):
+      hf_indices = list(range(block_idx, num_main_layers, layer_cycle_interval))
+      prefix = f"params-decoder-layers-layer_{block_idx}"
+
+      # Layer norms
+      mapping[f"{prefix}-input_layernorm-scale"] = [f"model.layers.{i}.input_layernorm.weight" for i in hf_indices]
+      mapping[f"{prefix}-post_attention_layernorm-scale"] = [
+          f"model.layers.{i}.post_attention_layernorm.weight" for i in hf_indices
+      ]
+
+      # Handle Interleaved Attention (Linear vs Full)
+      is_full_attention_layer = (block_idx + 1) % layer_cycle_interval == 0
+
+      if is_full_attention_layer:
+        mapping.update(
+            {
+                f"{prefix}-attention-attention-query-kernel": [
+                    f"model.layers.{i}.self_attn.q_proj.weight" for i in hf_indices
+                ],
+                f"{prefix}-attention-attention-key-kernel": [
+                    f"model.layers.{i}.self_attn.k_proj.weight" for i in hf_indices
+                ],
+                f"{prefix}-attention-attention-value-kernel": [
+                    f"model.layers.{i}.self_attn.v_proj.weight" for i in hf_indices
+                ],
+                f"{prefix}-attention-attention-out-kernel": [
+                    f"model.layers.{i}.self_attn.o_proj.weight" for i in hf_indices
+                ],
+                f"{prefix}-attention-attention-query_norm-scale": [
+                    f"model.layers.{i}.self_attn.q_norm.weight" for i in hf_indices
+                ],
+                f"{prefix}-attention-attention-key_norm-scale": [
+                    f"model.layers.{i}.self_attn.k_norm.weight" for i in hf_indices
+                ],
+            }
+        )
+      else:
+        # Linear/Hybrid Attention Block
+        mapping.update(
+            {
+                f"{prefix}-attention-in_proj_qkvz-kernel": [
+                    f"model.layers.{i}.linear_attn.in_proj_qkvz.weight" for i in hf_indices
+                ],
+                f"{prefix}-attention-in_proj_ba-kernel": [
+                    f"model.layers.{i}.linear_attn.in_proj_ba.weight" for i in hf_indices
+                ],
+                f"{prefix}-attention-conv1d-kernel": [f"model.layers.{i}.linear_attn.conv1d.weight" for i in hf_indices],
+                f"{prefix}-attention-A_log": [f"model.layers.{i}.linear_attn.A_log" for i in hf_indices],
+                f"{prefix}-attention-dt_bias": [f"model.layers.{i}.linear_attn.dt_bias" for i in hf_indices],
+                f"{prefix}-attention-norm-rms_norm-scale": [
+                    f"model.layers.{i}.linear_attn.norm.weight" for i in hf_indices
+                ],
+                f"{prefix}-attention-out_proj-kernel": [
+                    f"model.layers.{i}.linear_attn.out_proj.weight" for i in hf_indices
+                ],
+            }
+        )
+
+      # 3. Handle MLP: Gates and Shared Experts
+      mapping.update(
+          {
+              f"{prefix}-mlp-routed_experts-gate-kernel": [f"model.layers.{i}.mlp.gate.weight" for i in hf_indices],
+              f"{prefix}-mlp-shared_expert-wi_0-kernel": [
+                  f"model.layers.{i}.mlp.shared_expert.gate_proj.weight" for i in hf_indices
+              ],
+              f"{prefix}-mlp-shared_expert-wi_1-kernel": [
+                  f"model.layers.{i}.mlp.shared_expert.up_proj.weight" for i in hf_indices
+              ],
+              f"{prefix}-mlp-shared_expert-wo-kernel": [
+                  f"model.layers.{i}.mlp.shared_expert.down_proj.weight" for i in hf_indices
+              ],
+              f"{prefix}-mlp-shared_expert_gate-kernel": [
+                  f"model.layers.{i}.mlp.shared_expert_gate.weight" for i in hf_indices
+              ],
+          }
+      )
+
+      # 4. Handle MoE Routed Experts
+      mapping.update(
+          {
+              f"{prefix}-mlp-routed_experts-wi_0": [
+                  [f"model.layers.{i}.mlp.experts.{e}.gate_proj.weight" for i in hf_indices] for e in range(num_experts)
+              ],
+              f"{prefix}-mlp-routed_experts-wi_1": [
+                  [f"model.layers.{i}.mlp.experts.{e}.up_proj.weight" for i in hf_indices] for e in range(num_experts)
+              ],
+              f"{prefix}-mlp-routed_experts-wo": [
+                  [f"model.layers.{i}.mlp.experts.{e}.down_proj.weight" for i in hf_indices] for e in range(num_experts)
+              ],
+          }
+      )
+  else:
+    # Unscanned layer mapping
+    for i in range(num_main_layers):
+      prefix = f"params-decoder-layers_{i}"
+
+      # Layer Norms
+      mapping[f"{prefix}-input_layernorm-scale"] = f"model.layers.{i}.input_layernorm.weight"
+      mapping[f"{prefix}-post_attention_layernorm-scale"] = f"model.layers.{i}.post_attention_layernorm.weight"
+
+      # Determine layer type based on cycle interval
+      # Assuming block logic: layer i corresponds to block_idx = i % interval
+      block_idx = i % layer_cycle_interval
+      is_full_attention_layer = (block_idx + 1) % layer_cycle_interval == 0
+
+      if is_full_attention_layer:
+        mapping.update(
+            {
+                f"{prefix}-attention-attention-query-kernel": f"model.layers.{i}.self_attn.q_proj.weight",
+                f"{prefix}-attention-attention-key-kernel": f"model.layers.{i}.self_attn.k_proj.weight",
+                f"{prefix}-attention-attention-value-kernel": f"model.layers.{i}.self_attn.v_proj.weight",
+                f"{prefix}-attention-attention-out-kernel": f"model.layers.{i}.self_attn.o_proj.weight",
+                f"{prefix}-attention-attention-query_norm-scale": f"model.layers.{i}.self_attn.q_norm.weight",
+                f"{prefix}-attention-attention-key_norm-scale": f"model.layers.{i}.self_attn.k_norm.weight",
+            }
+        )
+      else:
+        # Linear/Hybrid Attention Block
+        mapping.update(
+            {
+                f"{prefix}-attention-in_proj_qkvz-kernel": f"model.layers.{i}.linear_attn.in_proj_qkvz.weight",
+                f"{prefix}-attention-in_proj_ba-kernel": f"model.layers.{i}.linear_attn.in_proj_ba.weight",
+                f"{prefix}-attention-conv1d-kernel": f"model.layers.{i}.linear_attn.conv1d.weight",
+                f"{prefix}-attention-A_log": f"model.layers.{i}.linear_attn.A_log",
+                f"{prefix}-attention-dt_bias": f"model.layers.{i}.linear_attn.dt_bias",
+                f"{prefix}-attention-norm-rms_norm-scale": f"model.layers.{i}.linear_attn.norm.weight",
+                f"{prefix}-attention-out_proj-kernel": f"model.layers.{i}.linear_attn.out_proj.weight",
+            }
+        )
+
+      # MLP: Gates and Shared Experts
+      mapping.update(
+          {
+              f"{prefix}-mlp-routed_experts-gate-kernel": f"model.layers.{i}.mlp.gate.weight",
+              f"{prefix}-mlp-shared_expert-wi_0-kernel": f"model.layers.{i}.mlp.shared_expert.gate_proj.weight",
+              f"{prefix}-mlp-shared_expert-wi_1-kernel": f"model.layers.{i}.mlp.shared_expert.up_proj.weight",
+              f"{prefix}-mlp-shared_expert-wo-kernel": f"model.layers.{i}.mlp.shared_expert.down_proj.weight",
+              f"{prefix}-mlp-shared_expert_gate-kernel": f"model.layers.{i}.mlp.shared_expert_gate.weight",
+          }
+      )
+
+      # MoE Routed Experts (List of expert weights for this specific layer)
+      mapping.update(
+          {
+              f"{prefix}-mlp-routed_experts-wi_0": [
+                  f"model.layers.{i}.mlp.experts.{e}.gate_proj.weight" for e in range(num_experts)
+              ],
+              f"{prefix}-mlp-routed_experts-wi_1": [
+                  f"model.layers.{i}.mlp.experts.{e}.up_proj.weight" for e in range(num_experts)
+              ],
+              f"{prefix}-mlp-routed_experts-wo": [
+                  f"model.layers.{i}.mlp.experts.{e}.down_proj.weight" for e in range(num_experts)
+              ],
+          }
+      )
+  return mapping
+
+
+def QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
+  """
+  Transformation hooks for parameters using hyphenated 'params-' MaxText keys.
+  """
+
+  def transpose(input_tensor, target_shape=None):
+    return input_tensor.T
+
+  def reshape_kernel(input_tensor, target_shape):
+    if saving_to_hf:
+      flipped_target_shape = np.flip(np.array(target_shape))
+      return input_tensor.reshape(flipped_target_shape).T
+    else:
+      return input_tensor.T.reshape(target_shape)
+
+  def permute_conv(input_tensor, target_shape=None):
+    # MT: [K, 1, C] <-> HF: [C, 1, K]
+    return input_tensor.transpose(2, 1, 0)
+
+  # Initialize Hooks
+  hooks = {
+      "params-decoder-logits_dense-kernel": transpose,
+  }
+
+  layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval
+  num_main_layers = config["num_hidden_layers"]
+  loop_indices = range(layer_cycle_interval) if scan_layers else range(num_main_layers)
+
+  for i in loop_indices:
+    if scan_layers:
+      prefix = f"params-decoder-layers-layer_{i}"
+      block_idx = i
+    else:
+      prefix = f"params-decoder-layers_{i}"
+      block_idx = i % layer_cycle_interval
+    is_full_attention_layer = (block_idx + 1) % layer_cycle_interval == 0
+
+    if is_full_attention_layer:
+      for key in ["query", "key", "value", "out"]:
+        hooks[f"{prefix}-attention-attention-{key}-kernel"] = reshape_kernel
+    else:
+      hooks[f"{prefix}-attention-in_proj_qkvz-kernel"] = transpose
+      hooks[f"{prefix}-attention-in_proj_ba-kernel"] = transpose
+      hooks[f"{prefix}-attention-out_proj-kernel"] = transpose
+      hooks[f"{prefix}-attention-conv1d-kernel"] = permute_conv
+
+    mlp_prefix = f"{prefix}-mlp"
+    hooks[f"{mlp_prefix}-routed_experts-gate-kernel"] = transpose
+    hooks[f"{mlp_prefix}-shared_expert-wi_0-kernel"] = transpose
+    hooks[f"{mlp_prefix}-shared_expert-wi_1-kernel"] = transpose
+    hooks[f"{mlp_prefix}-shared_expert-wo-kernel"] = transpose
+    hooks[f"{mlp_prefix}-shared_expert_gate-kernel"] = transpose
+
+    hooks[f"{mlp_prefix}-routed_experts-wi_0"] = transpose
+    hooks[f"{mlp_prefix}-routed_experts-wi_1"] = transpose
+    hooks[f"{mlp_prefix}-routed_experts-wo"] = transpose
+
+  return hooks
+
+
 def DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
   """Generates a parameter mapping from MaxText to HuggingFace Deepseek weight paths.
@@ -2098,6 +2334,7 @@ def pad_hf_embedding_layer(input_tensor, target_shape):
     "gpt-oss-20b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
     "gpt-oss-120b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
     "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-next-80b-a3b": QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING,
     "mixtral-8x7b": MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING,
     "mixtral-8x22b": MIXTRAL_MAXTEXT_TO_HF_PARAM_MAPPING,
     "olmo3-7b": OLMO3_MAXTEXT_TO_HF_PARAM_MAPPING,
@@ -2129,6 +2366,7 @@ def pad_hf_embedding_layer(input_tensor, target_shape):
     "gpt-oss-20b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
     "gpt-oss-120b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
     "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-next-80b-a3b": QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "mixtral-8x7b": MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "mixtral-8x22b": MIXTRAL_MAXTEXT_TO_HF_PARAM_HOOK_FN,
     "olmo3-7b": OLMO3_MAXTEXT_TO_HF_PARAM_HOOK_FN,
diff --git a/src/maxtext/checkpoint_conversion/utils/utils.py b/src/maxtext/checkpoint_conversion/utils/utils.py
index 5a0ecfe940..31607883f8 100644
--- a/src/maxtext/checkpoint_conversion/utils/utils.py
+++ b/src/maxtext/checkpoint_conversion/utils/utils.py
@@ -82,6 +82,7 @@
     "gpt-oss-20b": "openai/gpt-oss-20b",
     "gpt-oss-120b": "openai/gpt-oss-120b",
     "qwen3-omni-30b-a3b": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
+    "qwen3-next-80b-a3b": "Qwen/Qwen3-Next-80B-A3B-Instruct",
    "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",
     "mixtral-8x22b": "mistralai/Mixtral-8x22B-Instruct-v0.1",
     "olmo3-7b": "allenai/Olmo-3-7B-Instruct",
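
For reference only, outside the diff: a minimal sketch that recomputes the derived Qwen3-Next dimensions the shape mapping above relies on, using the values from qwen3_next_80b_a3b_dict. All names in the snippet are illustrative, not part of the patch.

# Sanity-check the derived gated-delta-net / attention dimensions for Qwen3-Next-80B-A3B.
cfg = {
    "hidden_size": 2048,
    "head_dim": 256,
    "num_attention_heads": 16,
    "num_key_value_heads": 2,
    "num_hidden_layers": 48,
    "full_attention_interval": 4,
    "linear_key_head_dim": 128,
    "linear_num_key_heads": 16,
    "linear_num_value_heads": 32,
    "linear_value_head_dim": 128,
}
key_dim = cfg["linear_num_key_heads"] * cfg["linear_key_head_dim"]        # 2048
value_dim = cfg["linear_num_value_heads"] * cfg["linear_value_head_dim"]  # 4096
print("in_proj_qkvz rows:", 2 * key_dim + 2 * value_dim)                  # 12288
print("in_proj_ba rows:", 2 * cfg["linear_num_value_heads"])              # 64
print("conv1d channels:", 2 * key_dim + value_dim)                        # 8192
print("q_proj rows:", 2 * cfg["num_attention_heads"] * cfg["head_dim"])   # 8192, matches [2 * q_dim, hidden_size]
print("k_proj/v_proj rows:", cfg["num_key_value_heads"] * cfg["head_dim"])  # 512
full_attn = [i for i in range(cfg["num_hidden_layers"]) if (i + 1) % cfg["full_attention_interval"] == 0]
print("full-attention layers:", full_attn)  # [3, 7, 11, ..., 47]; all other layers use linear_attn (GDN)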