diff --git a/README.md b/README.md index 9024e288d8..9ba9e45b08 100644 --- a/README.md +++ b/README.md @@ -41,6 +41,7 @@ See our guide on running MaxText in decoupled mode, without any GCP dependencies ## 🔥 Latest news 🔥 +* \[April 10, 2026\] [DeepSeek-V3.2](https://arxiv.org/pdf/2512.02556) is now supported, featuring DeepSeek Sparse Attention for long context. Try it out with the [deepseek3.2-671b](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/deepseek3.2-671b.yml) config. See the [user guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md) for more details. * \[April 2, 2026\] Gemma 4 multi-modal models (26B MoE, 31B dense) are now supported! Try them out with our [gemma4-26b](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/gemma4-26b.yml) and [gemma4-31b](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/gemma4-31b.yml) configs. For more details, see [Run_Gemma4.md](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/gemma4/Run_Gemma4.md). * \[March 6, 2026\] New features from DeepSeek-AI are now supported: Conditional Memory via Scalable Lookup ([Engram](https://arxiv.org/abs/2601.07372)) and Manifold-Constrained Hyper-Connections ([mHC](https://arxiv.org/abs/2512.24880)). Try them out with our [deepseek-custom](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/deepseek-custom.yml) starter config. * \[March 5, 2026\] New `tpu-post-train` [target in PyPI](https://pypi.org/project/maxtext). Please also use this installation option for running vllm_decode. See the [MaxText installation instructions](https://maxtext.readthedocs.io/en/latest/install_maxtext.html) for more info. 
diff --git a/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md b/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md index 10651f7186..fe89cc7481 100644 --- a/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md +++ b/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md @@ -16,19 +16,59 @@ # DeepSeek -DeepSeek is a novel family of open-weights sparse MoE models by DeepSeek AI. The currently supported models are DeepSeek V3.1 (671B), DeepSeek V3 (671B), DeepSeek R1 (671B), and DeepSeek V2-Lite (16B). +DeepSeek is a novel family of open-weights sparse MoE models by DeepSeek AI. The currently supported models are DeepSeek V2-Lite (16B), DeepSeek V3 (671B), DeepSeek R1 (671B), DeepSeek V3.1 (671B), and DeepSeek V3.2 (671B). * DeepSeek-V3 features advanced techniques, including Multi-Head Latent Attention (MLA), finer-grained and shared experts, Multi-Token Prediction (MTP), and FP8 mixed precision designed for enhanced efficiency and performance. -* DeepSeek V3.1 shares the same architecture as V3, but features an improved checkpoint that supports hybrid thinking modes, improved performance in agentic tasks, and higher thinking efficiency. - * DeepSeek R1 also uses V3 architecture. It utilizes cold-start data and large-scale reinforcement learning to incentivize chain-of-thought reasoning without relying solely on supervised fine-tuning. +* DeepSeek-V3.1 shares the same architecture as V3, but features an improved checkpoint that supports hybrid thinking modes, improved performance in agentic tasks, and higher thinking efficiency. + +* DeepSeek-V3.2 introduces [DeepSeek Sparse Attention](https://arxiv.org/pdf/2512.02556) (DSA), which reduces computational complexity while preserving model performance in long-context scenarios. + **Please note:** * To leverage MLA with Flash Attention, ensure you have the latest JAX version. * The provided TPU configurations are examples and not mandatory. * For V3.1 & R1, use the existing V3 671B model configurations, as they share the same architecture.
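As background for the DSA mechanism mentioned above: a lightweight indexer assigns each query a cheap score for every key and keeps only the top-k keys for full attention. Below is a rough NumPy sketch of indexer-guided top-k attention — illustrative only, not MaxText's implementation, and all names in it are hypothetical:

```python
import numpy as np

def dsa_topk_attention(q, k, v, idx_q, idx_k, topk):
    """Toy DSA-style attention: a cheap "lightning indexer" scores every
    (query, key) pair, and full attention is computed only over each
    query's top-k selected keys. q, k, v: [T, d]; idx_q, idx_k: [T, d_idx]."""
    T, d = q.shape
    causal = np.tril(np.ones((T, T), dtype=bool))
    # 1) Indexer pass: low-dimensional scores decide which keys survive.
    idx_scores = np.where(causal, idx_q @ idx_k.T, -np.inf)
    kth_best = np.sort(idx_scores, axis=-1)[:, -topk][:, None]
    keep = (idx_scores >= kth_best) & causal  # at most topk keys per query
    # 2) Main attention, restricted to the selected keys.
    scores = np.where(keep, (q @ k.T) / np.sqrt(d), -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights
```

In DeepSeek-V3.2 the indexer itself is a small learned component (adapted during continued pre-training); this sketch only illustrates the selection math.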
+## Checkpoint conversion +To get started, follow the instructions at HuggingFace ([V3](https://huggingface.co/deepseek-ai/DeepSeek-V3), [V2-Lite](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite)) to download the model. For V3, V3.1, and R1, the released weights currently use mixed FP8 & BF16 precision. To convert all FP8 weights to BF16, use the script [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/ckpt_scripts/deepseek_fp8_to_bf16.py). Once downloaded and converted to BF16: +* run [convert_deepseek_family_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_ckpt.py) to convert the checkpoint to a MaxText-compatible [Orbax](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) checkpoint for training and fine-tuning. When converting a checkpoint with MTP layers (like DeepSeek-V3), be sure to add the `--enable_mtp` flag to process them correctly. +* run [convert_deepseek_family_unscanned_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py) to convert the checkpoint to an unscanned Orbax version for decoding. + +### Checkpoint conversion for V3.2 +#### 1. Download Model Weights +Download the Hugging Face weights from [deepseek-ai/DeepSeek-V3.2](https://huggingface.co/deepseek-ai/DeepSeek-V3.2) to your local environment. Weights are provided in FP8. +```bash +hf download deepseek-ai/DeepSeek-V3.2 --local-dir <local_dir> +``` + +#### 2.
Dequantize Weights +Convert the weights from FP8 to BF16 using the script [deepseek_fp8_to_bf16.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/deepseek_fp8_to_bf16.py) on CPU: + +```bash +python3 -m maxtext.checkpoint_conversion.standalone_scripts.deepseek_fp8_to_bf16 --input-fp8-hf-path=<input_fp8_hf_path> --output-bf16-hf-path=<output_bf16_hf_path> +``` + +Alternatively, you can use the official DeepSeek script [fp8_cast_bf16.py](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py) to perform the conversion on GPU. + +#### 3. Convert to MaxText-compatible Orbax format +Execute the following command to finalize the conversion. Ensure your environment variables (`$BASE_OUTPUT_PATH`, `$HF_TOKEN`, and `$DEQUANTIZED_LOCAL_WEIGHTS`) are exported before running. +Setting `scan_layers=true` generates the scanned Orbax format for training and fine-tuning; setting `scan_layers=false` generates the unscanned format for decoding. +```bash +python3 -m maxtext.checkpoint_conversion.to_maxtext \ + src/maxtext/configs/base.yml \ + model_name=deepseek3.2-671b \ + scan_layers=true \ + attention=dot_product \ + base_output_directory=$BASE_OUTPUT_PATH \ + hf_access_token=$HF_TOKEN \ + hardware=cpu \ + skip_jax_distributed_system=True \ + --hf_model_path=$DEQUANTIZED_LOCAL_WEIGHTS \ + --eager_load_method=safetensors \ + --save_dtype=bfloat16 +``` ## Pre-training You can train from scratch to generate a new checkpoint. One example command to run pretraining with V3 on v5p-256. @@ -54,13 +94,6 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \ dataset_type=synthetic ``` - -## Checkpoint conversion -To get started, follow the instructions at HuggingFace ([V3](https://huggingface.co/deepseek-ai/DeepSeek-V3), [V2-Lite](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite)) to download the model. Currently for V3, V3.1, and R1, it uses mixed precision fp8 & bf16 weights.
To convert all FP8 weights to BF16, use the script [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/ckpt_scripts/deepseek_fp8_to_bf16.py). Once downloaded and converted to BF16: -* run [convert_deepseek_family_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_ckpt.py) to convert the checkpoint for MaxText compatibility in [Orbax](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) for training and fine-tuning. When converting a checkpoint with MTP layers (like DeepSeek-V3), be sure to add the `--enable_mtp` flag to process them correctly. -* run [convert_deepseek_family_unscanned_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py) to convert the checkpoint to unscanned version in Orbax for decoding. - - ## Fine-tuning After you have a MaxText compatible checkpoint, you could fine-tune it with different datasets. @@ -136,6 +169,59 @@ python3 -m maxtext.trainers.post_train.sft.train_sft_deprecated src/maxtext/conf dataset_type=hf ``` +## Continued pre-training for V3.2 Sparse Attention +**DeepSeek Sparse Attention (DSA)** enhances the Multi-Head Latent Attention (MLA) architecture by introducing a **Lightning Indexer**, which selects the top-k tokens for attention. DeepSeek-V3.2 is instantiated from DeepSeek-V3.1 and undergoes continued pre-training to adapt this indexer via a two-stage strategy: **Dense Warm-up** and **Sparse Training**. + +1. **Dense Warm-up Stage** +The indexer is trained exclusively using dense indexer loss while all other model parameters remain frozen.
+```sh +python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \ + model_name=deepseek3.2-671b \ + run_name=matmul_pre_training \ + per_device_batch_size=4 \ + enable_checkpointing=false \ + ici_fsdp_parallelism=128 \ + steps=5 \ + async_checkpointing=false \ + tokenizer_type=huggingface \ + tokenizer_path=deepseek-ai/DeepSeek-V3.2 \ + attention=flash \ + dtype=bfloat16 \ + weight_dtype=bfloat16 \ + megablox=True \ + sparse_matmul=True \ + dataset_type=synthetic \ + indexer_sparse_training=False \ + indexer_loss_scaling_factor=0.01 \ + trainable_parameters_mask=['.*indexer.*'] # indexer_loss_scaling_factor must be non-zero to activate indexer training; the default in base.yml is 0. +``` +2. **Sparse Training Stage** +The indexer is trained with sparse indexer loss, while the remaining model parameters are unfrozen and updated using the standard language modeling loss. +```sh +python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \ + base_output_directory=${BASE_OUTPUT_DIRECTORY?} \ + model_name=deepseek3.2-671b \ + per_device_batch_size=4 \ + enable_checkpointing=false \ + ici_fsdp_parallelism=128 \ + steps=5 \ + max_target_length=1024 \ + async_checkpointing=false \ + tokenizer_type=huggingface \ + tokenizer_path=deepseek-ai/DeepSeek-V3.2 \ + attention=flash \ + dtype=bfloat16 \ + weight_dtype=bfloat16 \ + megablox=True \ + sparse_matmul=True \ + dataset_type=synthetic \ + indexer_sparse_training=True \ + indexer_loss_scaling_factor=0.01 # Must be non-zero to activate indexer training; the default in base.yml is 0. +``` + ## Decoding One example command to run decoding with V3 on v5p-256 with unscanned checkpoint for fast decoding. @@ -215,4 +301,4 @@ To run MMLU benchmarks and validate the model's performance, follow the instruct * General dense matmul implementation with flag `sparse_matmul=False capacity_factor=-1`.
* Dropping implementation with flag `sparse_matmul=False` and reasonable `capacity_factor`, commonly used from 1 to 1.25. -See more examples in scripts for [V3](v3-671b/test_deepseek.sh) and [V2-Lite](v2-16b/test_deepseek.sh). +See more examples in scripts for [V2-Lite](v2-16b), [V3](v3-671b), and [V3.2](v3.2-671b). diff --git a/tests/end_to_end/tpu/deepseek/v3.2-671b/1_test_deepseek.sh b/tests/end_to_end/tpu/deepseek/v3.2-671b/1_test_deepseek.sh new file mode 100644 index 0000000000..5adeb4955c --- /dev/null +++ b/tests/end_to_end/tpu/deepseek/v3.2-671b/1_test_deepseek.sh @@ -0,0 +1,57 @@ +#!/bin/bash + +# This file is documentation for how to get started with DeepSeek v3.2. + +# This file runs Step 1 on CPU. +# 1. Convert the HuggingFace checkpoint (bf16) to MaxText-compatible checkpoint (bf16): +# Scanned format is better for training; unscanned format is better for decoding. +# 2. Run logit check, pre-training, fine-tuning, and decoding. + +set -ex + +export MODEL_NAME='deepseek3.2-671b' +export TOKENIZER_PATH='deepseek-ai/DeepSeek-V3.2' + +# Installing torch for checkpoint conversion and forward_pass_logit_checker.py +python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu + +if [ -z "${BASE_OUTPUT_PATH}" ]; then + # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing. 
+ # this bucket will store all the files generated by MaxText during a run + export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M) + echo "BASE_OUTPUT_PATH was not set; using a default" +fi +BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/} +echo using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH} + +# Step 1: Checkpoint conversion +# You can use the HuggingFace checkpoint at https://huggingface.co/deepseek-ai/DeepSeek-V3.2 (fp8), and dequantize it to bf16 +# Non-Googlers please remember to point `BF16_HF_PATH` to GCS buckets that you own +# Copying the HF checkpoint into a local directory `/tmp` -- you are free to use a different directory +BF16_HF_PATH=gs://maxtext-deepseek/deepseek3.2/hf-bf16 +if [ -z "${BF16_LOCAL_PATH}" ]; then + gcloud storage cp -r ${BF16_HF_PATH} /tmp + export BF16_LOCAL_PATH=/tmp/hf-bf16 +fi + +# on cpu +# scanned +echo $BASE_OUTPUT_PATH/scanned/0/items; \ +python3 -m maxtext.checkpoint_conversion.to_maxtext src/maxtext/configs/base.yml \ +model_name=deepseek3.2-671b scan_layers=true attention=dot_product \ +base_output_directory=$BASE_OUTPUT_PATH/scanned hf_access_token=$HF_TOKEN \ +hardware=cpu skip_jax_distributed_system=True \ +--hf_model_path=$BF16_LOCAL_PATH \ +--eager_load_method=safetensors \ +--save_dtype=bfloat16 + +# unscanned +echo $BASE_OUTPUT_PATH/unscanned/0/items; \ +python3 -m maxtext.checkpoint_conversion.to_maxtext src/maxtext/configs/base.yml \ +model_name=deepseek3.2-671b scan_layers=false attention=dot_product \ +base_output_directory=$BASE_OUTPUT_PATH/unscanned hf_access_token=$HF_TOKEN \ +hardware=cpu skip_jax_distributed_system=True \ +--hf_model_path=$BF16_LOCAL_PATH \ +--eager_load_method=safetensors \ +--save_dtype=bfloat16 diff --git a/tests/end_to_end/tpu/deepseek/v3.2-671b/2_test_deepseek.sh b/tests/end_to_end/tpu/deepseek/v3.2-671b/2_test_deepseek.sh new file mode 100644 index 0000000000..17e5180977 --- /dev/null +++ b/tests/end_to_end/tpu/deepseek/v3.2-671b/2_test_deepseek.sh @@ -0,0 +1,55 @@
+#!/bin/bash + +# This file is documentation for how to get started with DeepSeek v3.2. + +# This file runs Step 2 on v5p-128 on a daily basis. +# 1. Convert the HuggingFace checkpoint (bf16) to MaxText-compatible checkpoint (bf16): +# Scanned format is better for training; unscanned format is better for decoding. +# 2. Run logit check, pre-training, fine-tuning, and decoding. + +set -ex + +export MODEL_NAME='deepseek3.2-671b' +export TOKENIZER_PATH='deepseek-ai/DeepSeek-V3.2' + +# Installing torch for checkpoint conversion and forward_pass_logit_checker.py +python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu + +# e.g., $HOME/maxtext/src/maxtext +export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}" + +if [ -z "${BASE_OUTPUT_PATH}" ]; then + # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing. + # this bucket will store all the files generated by MaxText during a run + export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M) + echo "BASE_OUTPUT_PATH was not set; using a default" +fi +BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/} +echo using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH} + +# Step 2: +# We define the checkpoint paths. This way it is easier to use these paths in the `train.py` and `decode.py` commands +# export SCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/scanned/0/items +# export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned/0/items +# Use a hard-coded golden checkpoint rather than the checkpoints generated by Step 1, since Step 1 is not part of the daily test.
+SCANNED_CKPT_PATH=gs://maxtext-deepseek/deepseek3.2/2026-02-20/scanned/0/items +UNSCANNED_CKPT_PATH=gs://maxtext-deepseek/deepseek3.2/2026-02-20/unscanned/0/items +# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data +export DATASET_PATH=gs://maxtext-dataset + +# Test whether the forward pass logits match the golden logits +# default golden_logits_path=/deps/tests/assets/golden_logits/golden_data_{MODEL_NAME}.jsonl, copied from gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl +GOLDEN_LOGITS_DISK_LOCATION="/deps/tests/assets/golden_logits/golden_data_${MODEL_NAME}.jsonl" +if [ ! -f "${GOLDEN_LOGITS_DISK_LOCATION}" ]; then + GOLDEN_LOGITS_PATH="gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl" + GOLDEN_LOGITS_DISK_LOCATION=/tmp/golden_data.jsonl + gcloud storage cp ${GOLDEN_LOGITS_PATH} ${GOLDEN_LOGITS_DISK_LOCATION} +fi + +# override deepseek3.2-671b.yml with indexer_topk=2 +OVERRIDE="override_model_config=True indexer_topk=2" +python3 -m tests.utils.forward_pass_logit_checker ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=true attention=dot_product per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 async_checkpointing=false sparse_matmul=false ici_fsdp_parallelism=1 ici_expert_parallelism=-1 checkpoint_storage_concurrent_gb=1024 weight_dtype=float32 dtype=float32 activations_in_float32=true matmul_precision=highest float32_logits=true float32_qk_product=true ${OVERRIDE} --golden_logits_path=${GOLDEN_LOGITS_DISK_LOCATION} --max_kl_div=0.3 + +# Run decoding - megablox implementation +# Note decode requires the access token for huggingface tokenizer even if the model is not gated +python3 -m maxtext.inference.decode ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml 
base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True use_tokamax_gmm=False dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=3072 max_target_length=4096 ici_fsdp_parallelism=1 ici_tensor_parallelism=-1 ici_expert_parallelism=1 checkpoint_storage_concurrent_gb=1024 mla_naive_kvcache=false prompt="An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and outputs are all vectors. The output is "
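One detail worth calling out from the continued pre-training section above: the dense warm-up stage freezes everything except the indexer via `trainable_parameters_mask=['.*indexer.*']`. The general idea behind such a mask is regex-based gradient gating; below is a minimal Python sketch of that idea (a hypothetical helper, not MaxText's actual implementation):

```python
import re

def mask_gradients(grads, patterns):
    """Keep gradients only for parameter names matching any regex in
    `patterns`; zero everything else so those weights stay frozen.
    `grads` is a flat dict of {parameter_name: gradient}."""
    regexes = [re.compile(p) for p in patterns]
    return {
        name: g if any(r.search(name) for r in regexes) else 0.0 * g
        for name, g in grads.items()
    }
```

Unfreezing for the sparse-training stage then amounts to dropping the mask so every parameter receives its gradient again.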