Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@ export TOKENIZER_PATH='deepseek-ai/DeepSeek-V2-Lite'
# Installing torch for checkpoint conversion and forward_pass_logit_checker.py
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

# e.g., $HOME/maxtext/src/MaxText
export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"

if [ -z "${BASE_OUTPUT_PATH}" ]; then
# Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing.
# this bucket will store all the files generated by MaxText during a run
Expand Down
3 changes: 0 additions & 3 deletions tests/end_to_end/tpu/deepseek/v3-671b/2_test_deepseek.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ export TOKENIZER_PATH='deepseek-ai/DeepSeek-V3'
# Installing torch for checkpoint conversion and forward_pass_logit_checker.py
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

# e.g., $HOME/maxtext/src/MaxText
export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}"

if [ -z "${BASE_OUTPUT_PATH}" ]; then
# Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing.
# this bucket will store all the files generated by MaxText during a run
Expand Down
48 changes: 0 additions & 48 deletions tests/end_to_end/tpu/llama4/1_test_llama4.sh

This file was deleted.

39 changes: 0 additions & 39 deletions tests/end_to_end/tpu/llama4/2_test_llama4.sh

This file was deleted.

40 changes: 40 additions & 0 deletions tests/end_to_end/tpu/llama4/scout/1_test_llama4.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#!/bin/bash

# This file is documentation for how to get started with Llama4 Scout.

# This file runs Step 1 on CPU (on a daily basis).
# 1. Convert the HuggingFace checkpoint (bf16) to MaxText-compatible checkpoint (bf16):
#    Scanned format is better for training; unscanned format is better for decoding.
# 2. Run logit check, pre-training, fine-tuning, and decoding.

set -ex

export MODEL_NAME='llama4-17b-16e'
export TOKENIZER_PATH='meta-llama/Llama-4-Scout-17B-16E'

# Installing torch for checkpoint conversion and forward_pass_logit_checker.py
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

if [ -z "${BASE_OUTPUT_PATH}" ]; then
  # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own,
  # this script uses internal buckets for testing.
  # This bucket will store all the files generated by MaxText during a run.
  export BASE_OUTPUT_PATH="gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)"
  echo "BASE_OUTPUT_PATH was not set; defaulting to ${BASE_OUTPUT_PATH}"
fi
# Strip a trailing slash so path concatenation below does not produce '//'.
BASE_OUTPUT_PATH="${BASE_OUTPUT_PATH%/}"
echo "using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"

# Step 1: Checkpoint conversion
# You can use the HuggingFace checkpoint at TODO
# Assume HF checkpoints are uploaded to GCS bucket at CKPT_BUCKET.
# Non-Googlers please remember to point `CKPT_BUCKET` to GCS buckets that you own.
# Copying the HF checkpoint into a local directory `/tmp` -- you are free to use a different directory.
if [ -z "${CKPT_DISK_LOCATION}" ]; then
  export CKPT_BUCKET=gs://maxtext-llama/llama4-17b-16e/hf-checkpoint
  gcloud storage cp -r "${CKPT_BUCKET}" /tmp
  export CKPT_DISK_LOCATION=/tmp/hf-checkpoint
fi

# Convert checkpoint to `unscanned` format, more suitable for decoding.
# JAX_PLATFORMS=cpu keeps the conversion on CPU even if accelerators are present.
JAX_PLATFORMS=cpu python3 -m maxtext.checkpoint_conversion.standalone_scripts.llama4_ckpt_unscanned \
  --base-model-path "${CKPT_DISK_LOCATION}" \
  --maxtext-model-path "${BASE_OUTPUT_PATH}/unscanned" \
  --model-size "${MODEL_NAME}" \
  --huggingface-checkpoint True
echo "Wrote MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_PATH}/unscanned/0/items"
54 changes: 54 additions & 0 deletions tests/end_to_end/tpu/llama4/scout/2_test_llama4.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#!/bin/bash

# This file is documentation for how to get started with Llama4 Scout.

# This file runs Step 2 on v6e-256 (on a daily basis).
# 1. Convert the HuggingFace checkpoint (bf16) to MaxText-compatible checkpoint (bf16):
#    Scanned format is better for training; unscanned format is better for decoding.
# 2. Run logit check, pre-training, fine-tuning, and decoding.

# The golden logit can be generated by:
# tests/assets/logits_generation/golden_llama4_17b_16e_128e_export.ipynb

set -ex

export MODEL_NAME='llama4-17b-16e'
export TOKENIZER_PATH='meta-llama/Llama-4-Scout-17B-16E'

# Installing torch for checkpoint conversion and forward_pass_logit_checker.py
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu


if [ -z "${BASE_OUTPUT_PATH}" ]; then
  # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own,
  # this script uses internal buckets for testing.
  # This bucket will store all the files generated by MaxText during a run.
  export BASE_OUTPUT_PATH="gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)"
  echo "BASE_OUTPUT_PATH was not set; defaulting to ${BASE_OUTPUT_PATH}"
fi
# Strip a trailing slash so path concatenation below does not produce '//'.
BASE_OUTPUT_PATH="${BASE_OUTPUT_PATH%/}"
echo "using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}"

# Step 2:
# We define the checkpoint paths. This way it is easier to use these paths in the `train.py` and `decode.py` commands.
export UNSCANNED_CKPT_PATH="${BASE_OUTPUT_PATH}/unscanned/0/items"
# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data.
export DATASET_PATH=gs://maxtext-dataset

# Test whether the forward pass logits match the golden logits.
# default golden_logits_path=/deps/tests/assets/golden_logits/golden_data_{MODEL_NAME}.jsonl, copied from gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl
GOLDEN_LOGITS_DISK_LOCATION="/deps/tests/assets/golden_logits/golden_data_${MODEL_NAME}.jsonl"
if [ ! -f "${GOLDEN_LOGITS_DISK_LOCATION}" ]; then
  # Fall back to fetching the golden logits from the test-assets bucket.
  GOLDEN_LOGITS_PATH="gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl"
  GOLDEN_LOGITS_DISK_LOCATION=/tmp/golden_data.jsonl
  gcloud storage cp "${GOLDEN_LOGITS_PATH}" "${GOLDEN_LOGITS_DISK_LOCATION}"
fi

# Forward-pass logit check in float32 with deterministic (dot_product) attention, tight tolerances.
python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml" \
  base_output_directory="${BASE_OUTPUT_PATH}" run_name=forward_logits_check \
  load_parameters_path="${UNSCANNED_CKPT_PATH}" scan_layers=false attention=dot_product \
  per_device_batch_size=1 model_name="${MODEL_NAME}" max_prefill_predict_length=4 \
  max_target_length=4 async_checkpointing=false sparse_matmul=false \
  ici_fsdp_parallelism=-1 ici_expert_parallelism=16 checkpoint_storage_concurrent_gb=1024 \
  weight_dtype=float32 dtype=float32 activations_in_float32=true matmul_precision=highest \
  float32_logits=true float32_qk_product=true \
  --golden_logits_path="${GOLDEN_LOGITS_DISK_LOCATION}" --atol=0.01 --rtol=0.01


# Run pre-training - tokamax_gmm implementation
python3 -m maxtext.trainers.pre_train.train "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml" \
  base_output_directory="${BASE_OUTPUT_PATH}" run_name=tokamax_gmm_pre_training \
  model_name="${MODEL_NAME}" tokenizer_type=huggingface tokenizer_path="${TOKENIZER_PATH}" \
  dataset_type=synthetic enable_checkpointing=false attention=flash sparse_matmul=True \
  use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=4 \
  steps=5 max_target_length=1024 ici_fsdp_parallelism=-1

# Run decoding - tokamax_gmm implementation
# Note decode requires the access token for huggingface tokenizer even if the model is not gated.
python3 -m maxtext.decode "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml" \
  base_output_directory="${BASE_OUTPUT_PATH}" run_name=decode model_name="${MODEL_NAME}" \
  tokenizer_type=huggingface tokenizer_path="${TOKENIZER_PATH}" hf_access_token="${HF_TOKEN}" \
  load_parameters_path="${UNSCANNED_CKPT_PATH}" scan_layers=False attention=dot_product \
  sparse_matmul=True use_tokamax_gmm=True dtype=bfloat16 weight_dtype=bfloat16 \
  per_device_batch_size=1 max_prefill_predict_length=512 max_target_length=1024 \
  ici_fsdp_parallelism=1 ici_tensor_parallelism=-1 ici_expert_parallelism=16 \
  checkpoint_storage_concurrent_gb=1024 prompt="I love to"
21 changes: 0 additions & 21 deletions tests/unit/train_compile_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,27 +636,6 @@ def test_pipeline_subset(self):
)
)

@pytest.mark.cpu_only
# AOT-compile test for the llama4-17b-16e MoE model: runs train_compile_main
# against a v5p-128 topology and writes the compiled trainstep to a pickle file.
# NOTE(review): diff rendering has stripped the original indentation of this
# method; the enclosing test class is not visible here.
def test_moe_llama4_17b_16e(self):
# Destination for the ahead-of-time compiled train step artifact.
compiled_trainstep_file = "/tmp/test_moe_llama4_17b_16e.pickle"
train_compile_main(
(
"",
get_test_config_path(),
f"compiled_trainstep_file={compiled_trainstep_file}",
# Compile for a single v5p-128 slice without running on real hardware.
"compile_topology=v5p-128",
"compile_topology_num_slices=1",
"model_name=llama4-17b-16e",
"per_device_batch_size=1",
"max_target_length=1024",
"dtype=bfloat16",
"weight_dtype=bfloat16",
"scan_layers=True",
# 16 * 4 = 64 chips: fsdp x tensor sharding must match the topology size.
"ici_fsdp_parallelism=16",
"ici_tensor_parallelism=4",
)
)

@pytest.mark.cpu_only
def test_moe_gpt_oss_20b_sparse_matmul(self):
compiled_trainstep_file = "/tmp/test_moe_gpt_oss_20b_sparse_matmul.pickle"
Expand Down
Loading