From ca7ab69313beee71968f0078a47adb855713b400 Mon Sep 17 00:00:00 2001 From: Rohan Bierneni Date: Thu, 19 Feb 2026 19:35:30 +0000 Subject: [PATCH] Update readme with model support Update test scripts & model ReadMe Update flags for script --- README.md | 1 + .../1_test_qwen3_next_80b_a3b.sh | 77 ++++++------ .../2_test_qwen3_next_80b_a3b.sh | 65 ++++++++++ .../tpu/qwen/next/run_qwen3_next.md | 119 ++++++++++++++---- 4 files changed, 203 insertions(+), 59 deletions(-) create mode 100644 tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/2_test_qwen3_next_80b_a3b.sh diff --git a/README.md b/README.md index 2b751070f0..026771621b 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,7 @@ See our guide on running MaxText in decoupled mode, without any GCP dependencies ## 🔥 Latest news 🔥 +* \[February 19, 2026\] [Qwen3-Next](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/qwen/next/run_qwen3_next.md) is now supported. * \[December 22, 2025\] [Muon optimizer](https://kellerjordan.github.io/posts/muon) is now supported. * \[December 10, 2025\] DeepSeek V3.1 is now supported. Use existing configs for [DeepSeek V3 671B](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/deepseek3-671b.yml) and load in V3.1 checkpoint to use model. * \[December 9, 2025\] [New RL and SFT Notebook tutorials](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/examples) are available. diff --git a/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh b/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh index 77476eaf6b..a8753241fa 100644 --- a/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh +++ b/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh @@ -1,8 +1,11 @@ #!/bin/bash -# This script validates a pre-converted MaxText checkpoint against its original -# HuggingFace counterpart to ensure numerical correctness. 
+# This file is documentation for how to get started with Qwen3 Next. +# This file runs Step 1 on CPU. +# 1. Convert the HuggingFace checkpoint (bf16) to MaxText-compatible checkpoint (bf16): +# Scanned format is better for training; unscanned format is better for decoding. +# 2. Run logit check, pre-training, fine-tuning, and decoding. # --- # Example Usage: # @@ -17,43 +20,41 @@ set -ex -# --- Configuration & Input Validation --- +export MODEL_NAME='qwen3-next-80b-a3b' +export TOKENIZER_PATH='Qwen/Qwen3-Next-80B-A3B-Instruct' -if [ -z "${MAXTEXT_CHECKPOINT_PATH}" ]; then - echo "ERROR: The MAXTEXT_CHECKPOINT_PATH environment variable is not set." - echo "Please set it to the full GCS path of the pre-converted MaxText checkpoint weights." - exit 1 -fi +# Installing torch for checkpoint conversion and forward_pass_logit_checker.py +python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu -# Set a default for the HF model path if it's not provided by the user -if [ -z "${HF_MODEL_PATH}" ]; then - export HF_MODEL_PATH="Qwen/Qwen3-Next-80B-A3B-Instruct" - echo "HF_MODEL_PATH is not set, using default: ${HF_MODEL_PATH}" +# Ensure HF_TOKEN is set +if [ -z "${HF_TOKEN}" ]; then + echo "Error: HF_TOKEN environment variable is not set. Please export your Hugging Face token." + echo "Example: export HF_TOKEN=hf_..." + exit 1 fi -# Install dependencies required for the logit checker. -python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu - -# --- Run the Forward Pass Logit Checker --- - -echo "Validating MaxText checkpoint at ${MAXTEXT_CHECKPOINT_PATH}" -echo "Against original HF model: ${HF_MODEL_PATH}" - -# This command runs the core validation logic. 
-JAX_PLATFORMS=cpu python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \ - tokenizer_type=huggingface \ - tokenizer_path="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/assets/tokenizers}}"/qwen3-tokenizer \ - megablox=False \ - sparse_matmul=False \ - load_parameters_path=${MAXTEXT_CHECKPOINT_PATH} \ - model_name=qwen3-next-80b-a3b \ - checkpoint_storage_concurrent_gb=1024 \ - skip_jax_distributed_system=True \ - dtype=float32 \ - weight_dtype=float32 \ - matmul_precision=highest \ - --hf_model_path=${HF_MODEL_PATH} \ - --max_kl_div=0.03 \ - --run_hf_model=True - -echo "Validation complete." +if [ -z "${BASE_OUTPUT_PATH}" ]; then + # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing. + # this bucket will store all the files generated by MaxText during a run + export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M) + echo "BASE_OUTPUT_PATH is not set" +fi +BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/} +echo using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH} + +# 1.1 Convert checkpoint to `scanned` format, more suitable for training +JAX_PLATFORMS=cpu python3 -m maxtext.checkpoint_conversion.to_maxtext src/maxtext/configs/base.yml \ + model_name=qwen3-next-80b-a3b \ + base_output_directory=${BASE_OUTPUT_PATH}/scanned \ + hf_access_token=${HF_TOKEN} \ + scan_layers=true \ + use_multimodal=false + +# 1.2 Convert checkpoint to `unscanned` format, more suitable for decoding +JAX_PLATFORMS=cpu python3 -m maxtext.checkpoint_conversion.to_maxtext src/maxtext/configs/base.yml \ + model_name=qwen3-next-80b-a3b \ + base_output_directory=${BASE_OUTPUT_PATH}/unscanned \ + hf_access_token=${HF_TOKEN} \ + scan_layers=false \ + use_multimodal=false + \ No newline at end of file diff --git a/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/2_test_qwen3_next_80b_a3b.sh 
b/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/2_test_qwen3_next_80b_a3b.sh new file mode 100644 index 0000000000..8ef6a3969a --- /dev/null +++ b/tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/2_test_qwen3_next_80b_a3b.sh @@ -0,0 +1,65 @@ +#!/bin/bash + +# This file is documentation for how to get started with Qwen3 Next. + +# This file runs Step 2 on v5p-128 on a daily basis. +# 1. Convert the HuggingFace checkpoint (bf16) to MaxText-compatible checkpoint (bf16): +# Scanned format is better for training; unscanned format is better for decoding. +# 2. Run logit check, pretraining, finetuning, and decoding. + +# The golden logit can be generated by: +# python3 -m tests.assets.logits_generation.generate_hf_golden_logits --model-id=Qwen/Qwen3-Next-80B-A3B-Instruct --output-path=golden_data_qwen3-next-80b-a3b.jsonl --prompts='I love to' --hf-model-path=$local_bf16_path --trust-remote-code=False --hf-load-dtype=bfloat16 + +set -ex + +export PYTHONPATH=$PYTHONPATH:$(pwd)/src + +export MODEL_NAME='qwen3-next-80b-a3b' +export TOKENIZER_PATH='Qwen/Qwen3-Next-80B-A3B-Instruct' + +# Installing torch for checkpoint conversion and forward_pass_logit_checker.py +python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu + +# e.g., $HOME/maxtext/src/MaxText +export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}" + +if [ -z "${BASE_OUTPUT_PATH}" ]; then + # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing. + # this bucket will store all the files generated by MaxText during a run + export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M) + echo "BASE_OUTPUT_PATH is not set" +fi +BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/} +echo using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH} + +# Step 2: +# We define the checkpoint paths. 
This way it is easier to use these paths in the `train.py` and `decode.py` commands +# export SCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/scanned/0/items +# export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned/0/items +# Use a hard-coded golden checkpoint, rather than checkpoints generated by Step 1 as it is not in daily test. +SCANNED_CKPT_PATH=gs://maxtext-model-checkpoints/qwen3-next-80b-a3b/scanned/0/items +UNSCANNED_CKPT_PATH=gs://maxtext-model-checkpoints/qwen3-next-80b-a3b/unscanned/0/items +# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data +export DATASET_PATH=gs://maxtext-dataset + +# Test whether the forward pass logits match the golden logits +# default golden_logits_path=/deps/tests/assets/golden_logits/golden_data_{MODEL_NAME}.jsonl, copied from gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl +GOLDEN_LOGITS_DISK_LOCATION="/deps/tests/assets/golden_logits/golden_data_${MODEL_NAME}.jsonl" +if [ ! -f "${GOLDEN_LOGITS_DISK_LOCATION}" ]; then + GOLDEN_LOGITS_PATH="gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl" + GOLDEN_LOGITS_DISK_LOCATION=/tmp/golden_data.jsonl + gcloud storage cp ${GOLDEN_LOGITS_PATH} ${GOLDEN_LOGITS_DISK_LOCATION} +fi + +python3 -m tests.utils.forward_pass_logit_checker ${MAXTEXT_PKG_DIR}/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=true attention=dot_product per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 async_checkpointing=false sparse_matmul=True ici_fsdp_parallelism=1 ici_expert_parallelism=-1 checkpoint_storage_concurrent_gb=1024 weight_dtype=float32 dtype=float32 activations_in_float32=true matmul_precision=highest float32_logits=true float32_qk_product=true --golden_logits_path=${GOLDEN_LOGITS_DISK_LOCATION} --atol=1.5 --rtol=1.5 --max_kl_div=0.1 + +# Run pre-training - tokamax_gmm implementation +python3 -m 
maxtext.trainers.pre_train.train ${MAXTEXT_PKG_DIR}/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=tokamax_gmm_pre_training model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=synthetic enable_checkpointing=false attention=flash sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 steps=5 max_target_length=1024 + +# Run fine-tuning - tokamax_gmm implementation +python3 -m maxtext.trainers.pre_train.train ${MAXTEXT_PKG_DIR}/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=tokamax_gmm_fine_tuning model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_path=${DATASET_PATH} enable_checkpointing=true async_checkpointing=false load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=True attention=flash sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 steps=5 max_target_length=1024 checkpoint_storage_concurrent_gb=1024 + + +# Run decoding - tokamax_gmm implementation +# Note decode requires the access token for huggingface tokenizer even if the model is not gated +python3 -m maxtext.decode ${MAXTEXT_PKG_DIR}/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=64 max_target_length=512 ici_fsdp_parallelism=1 ici_tensor_parallelism=1 ici_expert_parallelism=-1 checkpoint_storage_concurrent_gb=1024 prompt="An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and outputs are all vectors. 
The output is " \ No newline at end of file diff --git a/tests/end_to_end/tpu/qwen/next/run_qwen3_next.md b/tests/end_to_end/tpu/qwen/next/run_qwen3_next.md index 25898cc3b1..86bb95bd97 100644 --- a/tests/end_to_end/tpu/qwen/next/run_qwen3_next.md +++ b/tests/end_to_end/tpu/qwen/next/run_qwen3_next.md @@ -7,6 +7,31 @@ For more details on the architecture, see the [Qwen3 Technical Blog](https://qwe * * * * * +Pre-Training +--------------------- +You can train from scratch to generate a new checkpoint. One example command to run pretraining with Qwen3-Next on v5p-64. + +```sh +python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \ + base_output_directory=${BASE_OUTPUT_DIRECTORY} \ + run_name=q3_next_pre_training \ + per_device_batch_size=1 \ + enable_checkpointing=false \ + model_name=qwen3-next-80b-a3b \ + ici_fsdp_parallelism=-1 \ + steps=5 \ + max_target_length=1024 \ + async_checkpointing=false \ + tokenizer_type=huggingface \ + tokenizer_path=src/maxtext/assets/tokenizers/qwen3-tokenizer \ + attention=flash \ + dtype=bfloat16 \ + weight_dtype=bfloat16 \ + megablox=False \ + sparse_matmul=False \ + dataset_type=synthetic +``` + Checkpoint Conversion --------------------- @@ -22,18 +47,20 @@ To get started, you first need a MaxText-compatible checkpoint. 2. **Convert the Checkpoint**: Run the `convert_qwen3_next_scanned.py` script to convert the downloaded Hugging Face weights into the Orbax format required by MaxText. 
``` - python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_qwen3_next_scanned \ - --base_model_path /path/to/qwen3_next_hf_checkpoint \ - --maxtext_model_path gs://your-gcs-bucket/qwen3_next_maxtext_ckpt \ - --model_size qwen3-next-80b-a3b + JAX_PLATFORMS=cpu python3 -m maxtext.checkpoint_conversion.to_maxtext src/maxtext/configs/base.yml \ + model_name=qwen3-next-80b-a3b \ + base_output_directory=gs://your-gcs-bucket/qwen3_next_maxtext_ckpt \ + hf_access_token=${HF_TOKEN} \ + scan_layers=true \ + use_multimodal=false # Set scan_layers=false for an unscanned checkpoint ``` * * * * * -Pre-training and Fine-tuning +Fine-tuning ---------------------------- -After converting the checkpoint, you can use it for fine-tuning or start a pre-training run from scratch. The command below is an example for fine-tuning on a v5p-512 slice. To pre-train, simply remove the `load_parameters_path` argument. +After converting the checkpoint, you can use it for fine-tuning. The command below is an example for fine-tuning on a v5p-64 slice. ``` python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \ @@ -43,12 +70,38 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \ run_name=qwen3_next_finetuning \ per_device_batch_size=1 \ model_name=qwen3-next-80b-a3b \ - steps=500 \ - max_target_length=8192 \ - ici_fsdp_parallelism=256 \ + steps=30 \ + max_target_length=4096 \ + ici_fsdp_parallelism=-1 \ tokenizer_type=huggingface \ tokenizer_path=src/maxtext/assets/tokenizers/qwen3-tokenizer +``` + +## Decoding +One example command to run decoding with Qwen3-Next on v5p-64 with unscanned checkpoint for fast decoding. 
+```sh +python3 -m maxtext.decode src/maxtext/configs/base.yml \ + base_output_directory=${BASE_OUTPUT_DIRECTORY} \ + load_parameters_path=${CONVERTED_CHECKPOINT} \ + run_name=q3-next-decode \ + per_device_batch_size=1 \ + enable_checkpointing=false \ + model_name=qwen3-next-80b-a3b \ + max_prefill_predict_length=64 \ + max_target_length=1024 \ + tokenizer_type=huggingface \ + tokenizer_path=src/maxtext/assets/tokenizers/qwen3-tokenizer \ + attention=dot_product \ + dtype=bfloat16 \ + weight_dtype=bfloat16 \ + megablox=False \ + sparse_matmul=False \ + ici_tensor_parallelism=1 \ + ici_fsdp_parallelism=1 \ + ici_expert_parallelism=-1 \ + prompt="An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and outputs are all vectors. The output is " \ + scan_layers=False ``` * * * * * @@ -56,27 +109,51 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \ Correctness Validation ---------------------- -To verify that the MaxText implementation is numerically equivalent to the original Hugging Face model, you can run the end-to-end test scripts. These scripts automate the logit comparison test for each model. +We perform two primary checks: -Before running, you must set the `MAXTEXT_CHECKPOINT_PATH` environment variable. You can also optionally set `HF_MODEL_PATH` to point to a local copy of the Hugging Face model. +* **Logit Comparison**: We compare the logits generated by our implementation against those from a HuggingFace implementation for a set of given prompts. +* **MMLU Score Validation**: We validate the MMLU score against established benchmarks. 
-### Qwen3-Next-80B-A3B - -Bash +One example command to generate golden logits from HuggingFace for Qwen3-Next: +```sh +python3 -m tests.assets.logits_generation.generate_hf_golden_logits \ + --model-id=Qwen/Qwen3-Next-80B-A3B-Instruct \ + --output-path=golden_Qwen3_Next.jsonl \ + --prompts='I love to;Today is a;What is the' ``` -# Set the required path to your converted MaxText checkpoint -export MAXTEXT_CHECKPOINT_PATH=gs://your-gcs-bucket/qwen3-next-80b-a3b_maxtext_ckpt/0/items/ -# (Optional) Set the path to your local Hugging Face checkpoint -# export HF_MODEL_PATH=/path/to/local/qwen3-next-80b-a3b_hf_checkpoint +You should see logs like the following: + +``` +... +File is stored locally at golden_Qwen3_Next.jsonl. +``` -# Execute the validation script -bash tests/end_to_end/tpu/qwen/next/qwen3-next-80b-a3b/1_test_qwen3_next_80b_a3b.sh +Run the command below to compare logits between HuggingFace and MaxText. +```sh +python3 -m tests.utils.forward_pass_logit_checker \ + src/maxtext/configs/base.yml \ + tokenizer_type=huggingface \ + tokenizer_path=Qwen/Qwen3-Next-80B-A3B-Instruct \ + load_parameters_path=${CONVERTED_CHECKPOINT} \ + run_name=forward_pass_test_qwen3_next \ + per_device_batch_size=1 \ + model_name=qwen3-next-80b-a3b \ + max_prefill_predict_length=4 \ + max_target_length=4 \ + scan_layers=false \ + sparse_matmul=False \ + dtype=float32 \ + activations_in_float32=true \ + matmul_precision=high \ + --max_kl_div=2e-4 \ + --golden_logits_path=${PWD}/golden_Qwen3_Next.jsonl ``` +To run MMLU benchmarks and validate the model's performance, follow the instructions provided [here](../../../benchmarks/api_server/README.md). + ## Supported MoE Strategies This model implementation supports both **Token Dropping** and **Dropless** strategies for Mixture of Experts routing. 
Take a look at the MaxText [documentation](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/reference/core_concepts/moe_configuration.md) on MoE configs and flags to set based on desired strategy. -