diff --git a/tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh b/tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh
index 4f589370cc..29268dd6e7 100644
--- a/tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh
+++ b/tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh
@@ -20,9 +20,6 @@ export TOKENIZER_PATH='deepseek-ai/DeepSeek-V2-Lite'
 # Installing torch for checkpoint conversion and forward_pass_logit_checker.py
 python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
 
-# e.g., $HOME/maxtext/src/MaxText
-export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"
-
 if [ -z "${BASE_OUTPUT_PATH}" ]; then
   # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing.
   # this bucket will store all the files generated by MaxText during a run
diff --git a/tests/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh b/tests/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh
index 8cc2bb87c4..705b8f4f53 100644
--- a/tests/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh
+++ b/tests/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh
@@ -18,9 +18,6 @@ export TOKENIZER_PATH='deepseek-ai/DeepSeek-V3'
 # Installing torch for checkpoint conversion and forward_pass_logit_checker.py
 python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
 
-# e.g., $HOME/maxtext/src/MaxText
-export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"
-
 if [ -z "${BASE_OUTPUT_PATH}" ]; then
   # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing.
   # this bucket will store all the files generated by MaxText during a run
diff --git a/tests/end_to_end/tpu/llama4/1_test_llama4.sh b/tests/end_to_end/tpu/llama4/1_test_llama4.sh
deleted file mode 100644
index f26d13bc71..0000000000
--- a/tests/end_to_end/tpu/llama4/1_test_llama4.sh
+++ /dev/null
@@ -1,48 +0,0 @@
-#!/bin/bash
-
-# This file, combined with step 2 in the same directory, runs on daily basis and demonstrates:
-# 1. Converts the Llama4-Maverick HuggingFace checkpoint to MaxText (Orbax) format using a CPU VM.
-# 2. Takes the MaxText (unscanned Orbax) checkpoint to run inference on a TPU VM.
-
-# The flow of this file is to convert the Llama4 (Scout/Maverick) HuggingFace checkpoint to MaxText (Orbax) format using a CPU VM.
-
-# Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; export MODEL_VARIATION=llama4-17b-[16e/128e]; bash tests/end_to_end/tpu/llama4/1_test_llama4.sh
-# Use the same BASE_OUTPUT_PATH and MODEL_VARIATION for both 1_test_llama4.sh & 1_test_llama4.sh.
-
-# In order to generate the Llama4 golden logits, please see this script: tests/assets/logits_generation/golden_llama4_17b_16e_128e_export.ipynb
-
-set -ex
-idx=$(date +%Y-%m-%d)
-
-
-if [ -z "${BASE_OUTPUT_PATH}" ]; then
-  # Non-Googlers please remember to point BASE_OUTPUT_PATH to GCS buckets that you own, this script uses internal buckets for testing.
-  export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d)/
-  echo "BASE_OUTPUT_PATH is not set"
-fi
-BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/}
-echo using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}
-
-# By default, we'll use "llama4-17b-16e"
-if [ -z "${MODEL_VARIATION}" ]; then
-  export MODEL_VARIATION="llama4-17b-16e"
-  echo "MODEL_VARIATION is not set, using MODEL_VARIATION = ${MODEL_VARIATION}"
-fi
-
-python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
-
-# Step 1:
-# After downloading checkpoints, copy them to GCS bucket at $CHKPT_BUCKET \
-# Non-Googlers please remember to use separate GCS paths for uploading model weights from HuggingFace ($CHKPT_BUCKET) and MaxText compatible weights
-# ($MODEL_BUCKET). Non-Googlers please remember to point these variables to GCS buckets that you own, this script uses internal buckets for testing.
-# You can use the HuggingFace checkpoint at https://huggingface.co/meta-llama/Llama-4-Maverick-17B-128E for Scout and
-# https://huggingface.co/meta-llama/Llama-4-Maverick-17B-128E for Maverick
-export CHKPT_BUCKET=gs://maxtext-llama/${MODEL_VARIATION}/hf-checkpoint/
-
-# In the following command, we are copying the HF checkpoint into a local directory `tmp` -- you are free to use a different local directory than /tmp/,
-gcloud storage cp -r ${CHKPT_BUCKET} /tmp
-
-export LOCATION_OF_HF_CHKPT_ON_DISK=/tmp/hf-checkpoint
-
-JAX_PLATFORMS=cpu python3 -m maxtext.checkpoint_conversion.standalone_scripts.llama4_ckpt_unscanned --base-model-path ${LOCATION_OF_HF_CHKPT_ON_DISK} --maxtext-model-path ${BASE_OUTPUT_PATH}/unscanned --model-size ${MODEL_VARIATION} --huggingface-checkpoint True
-echo "Wrote MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/unscanned/0/items"
diff --git a/tests/end_to_end/tpu/llama4/2_test_llama4.sh b/tests/end_to_end/tpu/llama4/2_test_llama4.sh
deleted file mode 100644
index d23ef4cd50..0000000000
--- a/tests/end_to_end/tpu/llama4/2_test_llama4.sh
+++ /dev/null
@@ -1,39 +0,0 @@
-#!/bin/bash
-
-# This file, combined with step 1 in the same directory, runs on daily basis and demonstrates:
-# 1. Converts the Llama4-Maverick HuggingFace checkpoint to MaxText (Orbax) format using a CPU VM.
-# 2. Takes the MaxText (unscanned Orbax) checkpoint to run inference on a TPU VM.
-
-# The flow of this file is to take the MaxText (unscanned Orbax) checkpoint and run inference on a TPU VM.
-
-# Example Usage: export BASE_OUTPUT_PATH=/path/to/GCS/bucket; export MODEL_VARIATION=llama4-17b-[16e/128e]; bash tests/end_to_end/tpu/llama4/2_test_llama4.sh
-# Use the same BASE_OUTPUT_PATH and MODEL_VARIATION for both 1_test_llama4.sh & 1_test_llama4.sh.
-
-# In order to generate the Llama4 golden logits, please see this script: tests/assets/logits_generation/golden_llama4_17b_16e_128e_export.ipynb
-
-set -ex
-idx=$(date +%Y-%m-%d)
-
-# By default, we'll use "llama4-17b-16e"
-if [ -z "${MODEL_VARIATION}" ]; then
-  export MODEL_VARIATION="llama4-17b-16e"
-  echo "MODEL_VARIATION is not set, using MODEL_VARIATION = ${MODEL_VARIATION}"
-  export TOKENIZER_PATH=meta-llama/Llama-4-Scout-17B-16E
-fi
-
-# Installing torch for deps in forward_pass_logit_checker.py
-python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
-
-if [ -z "${BASE_OUTPUT_PATH}" ]; then
-  # Non-Googlers please remember to point BASE_OUTPUT_PATH to GCS buckets that you own, this script uses internal buckets for testing.
-  export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d)
-  echo "BASE_OUTPUT_PATH is not set"
-fi
-BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/}
-echo using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}
-
-
-export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned/0/items
-
-# Step 2: run logit checking
-python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"//base.yml tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} run_name=forward_pass_test_${MODEL_VARIATION} attention=dot_product per_device_batch_size=1 model_name=${MODEL_VARIATION} max_prefill_predict_length=4 max_target_length=4 scan_layers=false --atol=0.01 --rtol=0.01 async_checkpointing=false sparse_matmul=false weight_dtype=float32 dtype=float32 activations_in_float32=true matmul_precision=float32 float32_logits=true float32_qk_product=true ici_expert_parallelism=16
diff --git a/tests/end_to_end/tpu/llama4/scout/1_test_llama4.sh b/tests/end_to_end/tpu/llama4/scout/1_test_llama4.sh
new file mode 100644
index 0000000000..4d911fa9f3
--- /dev/null
+++ b/tests/end_to_end/tpu/llama4/scout/1_test_llama4.sh
@@ -0,0 +1,40 @@
+#!/bin/bash
+
+# This file is documentation for how to get started with Llama4 Scout.
+
+# This file runs Step 1 on CPU (on a daily basis).
+# 1. Convert the HuggingFace checkpoint (bf16) to MaxText-compatible checkpoint (bf16):
+#    Scanned format is better for training; unscanned format is better for decoding.
+# 2. Run logit check, pre-training, fine-tuning, and decoding.
+
+set -ex
+
+export MODEL_NAME='llama4-17b-16e'
+export TOKENIZER_PATH='meta-llama/Llama-4-Scout-17B-16E'
+
+# Installing torch for checkpoint conversion and forward_pass_logit_checker.py
+python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
+
+if [ -z "${BASE_OUTPUT_PATH}" ]; then
+  # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing.
+  # this bucket will store all the files generated by MaxText during a run
+  export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
+  echo "BASE_OUTPUT_PATH is not set"
+fi
+BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/}
+echo using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}
+
+# Step 1: Checkpoint conversion
+# You can use the HuggingFace checkpoint at TODO
+# Assume HF checkpoints are uploaded to a GCS bucket at CKPT_BUCKET
+# Non-Googlers please remember to point `CKPT_BUCKET` to GCS buckets that you own
+# Copying the HF checkpoint into a local directory `/tmp` -- you are free to use a different directory
+if [ -z "${CKPT_DISK_LOCATION}" ]; then
+  export CKPT_BUCKET=gs://maxtext-llama/llama4-17b-16e/hf-checkpoint
+  gcloud storage cp -r ${CKPT_BUCKET} /tmp
+  export CKPT_DISK_LOCATION=/tmp/hf-checkpoint
+fi
+
+# Convert checkpoint to `unscanned` format, more suitable for decoding
+JAX_PLATFORMS=cpu python3 -m maxtext.checkpoint_conversion.standalone_scripts.llama4_ckpt_unscanned --base-model-path ${CKPT_DISK_LOCATION} --maxtext-model-path ${BASE_OUTPUT_PATH}/unscanned --model-size ${MODEL_NAME} --huggingface-checkpoint True
+echo "Wrote MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/unscanned/0/items"
\ No newline at end of file
diff --git a/tests/end_to_end/tpu/llama4/scout/2_test_llama4.sh b/tests/end_to_end/tpu/llama4/scout/2_test_llama4.sh
new file mode 100644
index 0000000000..fd588d4e7c
--- /dev/null
+++ b/tests/end_to_end/tpu/llama4/scout/2_test_llama4.sh
@@ -0,0 +1,54 @@
+#!/bin/bash
+
+# This file is documentation for how to get started with Llama4 Scout.
+
+# This file runs Step 2 on v6e-256 (on a daily basis).
+# 1. Convert the HuggingFace checkpoint (bf16) to MaxText-compatible checkpoint (bf16):
+#    Scanned format is better for training; unscanned format is better for decoding.
+# 2. Run logit check, pre-training, fine-tuning, and decoding.
+
+# The golden logits can be generated by:
+# tests/assets/logits_generation/golden_llama4_17b_16e_128e_export.ipynb
+
+set -ex
+
+export MODEL_NAME='llama4-17b-16e'
+export TOKENIZER_PATH='meta-llama/Llama-4-Scout-17B-16E'
+
+# Installing torch for checkpoint conversion and forward_pass_logit_checker.py
+python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
+
+
+if [ -z "${BASE_OUTPUT_PATH}" ]; then
+  # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing.
+  # this bucket will store all the files generated by MaxText during a run
+  export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
+  echo "BASE_OUTPUT_PATH is not set"
+fi
+BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/}
+echo using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}
+
+# Step 2:
+# We define the checkpoint paths. This way it is easier to use these paths in the `train.py` and `decode.py` commands
+export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned/0/items
+# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
+export DATASET_PATH=gs://maxtext-dataset
+
+# Test whether the forward pass logits match the golden logits
+# default golden_logits_path=/deps/tests/assets/golden_logits/golden_data_{MODEL_NAME}.jsonl, copied from gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl
+GOLDEN_LOGITS_DISK_LOCATION="/deps/tests/assets/golden_logits/golden_data_${MODEL_NAME}.jsonl"
+if [ ! -f "${GOLDEN_LOGITS_DISK_LOCATION}" ]; then
-f "${GOLDEN_LOGITS_DISK_LOCATION}" ]; then + GOLDEN_LOGITS_PATH="gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl" + GOLDEN_LOGITS_DISK_LOCATION=/tmp/golden_data.jsonl + gcloud storage cp ${GOLDEN_LOGITS_PATH} ${GOLDEN_LOGITS_DISK_LOCATION} +fi + +python3 -m tests.utils.forward_pass_logit_checker ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=false attention=dot_product per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 async_checkpointing=false sparse_matmul=false ici_fsdp_parallelism=-1 ici_expert_parallelism=16 checkpoint_storage_concurrent_gb=1024 weight_dtype=float32 dtype=float32 activations_in_float32=true matmul_precision=highest float32_logits=true float32_qk_product=true --golden_logits_path=${GOLDEN_LOGITS_DISK_LOCATION} --atol=0.01 --rtol=0.01 + + +# Run pre-training - tokamax_gmm implementation +python3 -m maxtext.trainers.pre_train.train ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=tokamax_gmm_pre_training model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} dataset_type=synthetic enable_checkpointing=false attention=flash sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 steps=5 max_target_length=1024 ici_fsdp_parallelism=-1 + +# Run decoding - tokamax_gmm implementation +# Note decode requires the access token for huggingface tokenizer even if the model is not gated +python3 -m maxtext.decode ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=512 max_target_length=1024 ici_fsdp_parallelism=1 ici_tensor_parallelism=-1 ici_expert_parallelism=16 checkpoint_storage_concurrent_gb=1024 prompt="I love to" \ No newline at end of file diff --git a/tests/unit/train_compile_test.py b/tests/unit/train_compile_test.py index 609914eed9..c76e39b6d0 100644 --- a/tests/unit/train_compile_test.py +++ b/tests/unit/train_compile_test.py @@ -636,27 +636,6 @@ def test_pipeline_subset(self): ) ) - @pytest.mark.cpu_only - def test_moe_llama4_17b_16e(self): - compiled_trainstep_file = "/tmp/test_moe_llama4_17b_16e.pickle" - train_compile_main( - ( - "", - get_test_config_path(), - f"compiled_trainstep_file={compiled_trainstep_file}", - "compile_topology=v5p-128", - "compile_topology_num_slices=1", - "model_name=llama4-17b-16e", - "per_device_batch_size=1", - "max_target_length=1024", - "dtype=bfloat16", - "weight_dtype=bfloat16", - "scan_layers=True", - "ici_fsdp_parallelism=16", - "ici_tensor_parallelism=4", - ) - ) - @pytest.mark.cpu_only def test_moe_gpt_oss_20b_sparse_matmul(self): compiled_trainstep_file = "/tmp/test_moe_gpt_oss_20b_sparse_matmul.pickle"