1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -41,6 +41,7 @@ See our guide on running MaxText in decoupled mode, without any GCP dependencies

## 🔥 Latest news 🔥

* \[April 10, 2026\] [DeepSeek-V3.2](https://arxiv.org/pdf/2512.02556) is now supported, featuring DeepSeek Sparse Attention for long context. Try it out with the [deepseek3.2-671b](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/deepseek3.2-671b.yml) config. See the [user guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md) for more details.
* \[April 2, 2026\] Gemma 4 multi-modal models (26B MoE, 31B dense) are now supported! Try them out with our [gemma4-26b](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/gemma4-26b.yml) and [gemma4-31b](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/gemma4-31b.yml) configs. For more details, see [Run_Gemma4.md](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/gemma4/Run_Gemma4.md).
* \[March 6, 2026\] New features from DeepSeek-AI are now supported: Conditional Memory via Scalable Lookup ([Engram](https://arxiv.org/abs/2601.07372)) and Manifold-Constrained Hyper-Connections ([mHC](https://arxiv.org/abs/2512.24880)). Try them out with our [deepseek-custom](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/deepseek-custom.yml) starter config.
* \[March 5, 2026\] New `tpu-post-train` [target in PyPI](https://pypi.org/project/maxtext). Please also use this installation option for running vllm_decode. See the [MaxText installation instructions](https://maxtext.readthedocs.io/en/latest/install_maxtext.html) for more info.
108 changes: 97 additions & 11 deletions tests/end_to_end/tpu/deepseek/Run_DeepSeek.md
@@ -16,19 +16,59 @@

# DeepSeek

DeepSeek is a novel family of open-weights sparse MoE models by DeepSeek AI. The currently supported models are DeepSeek V2-Lite (16B), DeepSeek V3 (671B), DeepSeek R1 (671B), DeepSeek V3.1 (671B), and DeepSeek V3.2 (671B).

* DeepSeek-V3 features advanced techniques, including Multi-Head Latent Attention (MLA), finer-grained and shared experts, Multi-Token Prediction (MTP), and FP8 mixed precision designed for enhanced efficiency and performance.

* DeepSeek R1 also uses V3 architecture. It utilizes cold-start data and large-scale reinforcement learning to incentivize chain-of-thought reasoning without relying solely on supervised fine-tuning.

* DeepSeek-V3.1 shares the same architecture as V3, but features an improved checkpoint that supports hybrid thinking modes, improved performance in agentic tasks, and higher thinking efficiency.

* DeepSeek-V3.2 introduces [DeepSeek Sparse Attention](https://arxiv.org/pdf/2512.02556) (DSA), which reduces computational complexity while preserving model performance in long-context scenarios.

**Please note:**
* To leverage MLA with Flash Attention, ensure you have the latest JAX version.
* The provided TPU configurations are examples and not mandatory.
* For V3.1 & R1, use the existing V3 671B model configurations, as they share the same architecture.
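
Because R1 and V3.1 reuse the V3 architecture, only the model config name and the tokenizer change between runs. A minimal sketch of the relevant overrides (the tokenizer repo names are assumptions based on the corresponding Hugging Face model cards):

```shell
# R1 and V3.1 reuse the deepseek3-671b (V3) model config; only the tokenizer differs.
MODEL_NAME=deepseek3-671b
TOKENIZER_PATH=deepseek-ai/DeepSeek-R1   # or deepseek-ai/DeepSeek-V3.1 (assumed repo names)
echo "model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH}"
```

Pass these alongside the usual training or decoding flags shown later in this guide.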

## Checkpoint conversion
To get started, follow the instructions at Hugging Face ([V3](https://huggingface.co/deepseek-ai/DeepSeek-V3), [V2-Lite](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite)) to download the model. V3, V3.1, and R1 currently ship mixed-precision FP8 & BF16 weights. To convert all FP8 weights to BF16, use the script [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/ckpt_scripts/deepseek_fp8_to_bf16.py). Once downloaded and converted to BF16:
* run [convert_deepseek_family_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_ckpt.py) to convert the checkpoint for MaxText compatibility in [Orbax](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) for training and fine-tuning. When converting a checkpoint with MTP layers (like DeepSeek-V3), be sure to add the `--enable_mtp` flag to process them correctly.
* run [convert_deepseek_family_unscanned_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py) to convert the checkpoint to unscanned version in Orbax for decoding.

### Checkpoint conversion for V3.2
#### 1. Download Model Weights
Download the Hugging Face weights from [deepseek-ai/DeepSeek-V3.2](https://huggingface.co/deepseek-ai/DeepSeek-V3.2) to your local environment. Weights are provided in FP8.
```bash
hf download deepseek-ai/DeepSeek-V3.2 --local-dir <local_fp8_path>
```
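
Before dequantizing, it can be worth sanity-checking that the download completed. A small sketch, assuming the standard Hugging Face sharded-safetensors layout (`model.safetensors.index.json` at the repo root) and a hypothetical local path:

```shell
LOCAL_FP8_PATH=/tmp/ds32-fp8   # hypothetical download directory; substitute your own
if [ -f "${LOCAL_FP8_PATH}/model.safetensors.index.json" ]; then
  echo "download looks complete"
else
  echo "index file missing, re-run hf download"
fi
```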

#### 2. Dequantize Weights
Convert the weights from FP8 to BF16 using the script [deepseek_fp8_to_bf16.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/deepseek_fp8_to_bf16.py) on CPU:

```bash
python3 -m maxtext.checkpoint_conversion.standalone_scripts.deepseek_fp8_to_bf16 --input-fp8-hf-path=<local_fp8_path> --output-bf16-hf-path=<local_bf16_path>
```

Alternatively, you can use the official DeepSeek script [fp8_cast_bf16.py](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py) to convert on GPU.

#### 3. Convert to MaxText-compatible Orbax format
Execute the following command to finalize the conversion. Ensure your environment variables (`$BASE_OUTPUT_PATH`, `$HF_TOKEN`, and `$DEQUANTIZED_LOCAL_WEIGHTS`) are exported before running.
Setting `scan_layers=true` generates the scanned Orbax format for training and fine-tuning; setting `scan_layers=false` generates the unscanned format for decoding.
```bash
python3 -m maxtext.checkpoint_conversion.to_maxtext \
src/maxtext/configs/base.yml \
model_name=deepseek3.2-671b \
scan_layers=true \
attention=dot_product \
base_output_directory=$BASE_OUTPUT_PATH \
hf_access_token=$HF_TOKEN \
hardware=cpu \
skip_jax_distributed_system=True \
--hf_model_path=$DEQUANTIZED_LOCAL_WEIGHTS \
--eager_load_method=safetensors \
--save_dtype=bfloat16
```
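
After conversion, the checkpoint lands under a step-numbered subdirectory of `base_output_directory`. A sketch of assembling the load path, with the `<step>/items` layout assumed from the companion v3.2 test scripts and a hypothetical bucket name:

```shell
BASE_OUTPUT_PATH="gs://my-bucket/deepseek3.2/"   # hypothetical bucket
BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/}           # strip any trailing slash, as the test scripts do
STEP=0
CKPT_PATH="${BASE_OUTPUT_PATH}/${STEP}/items"
echo "${CKPT_PATH}"   # gs://my-bucket/deepseek3.2/0/items
```

This is the value you would pass as `load_parameters_path` in the training and decoding commands below.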

## Pre-training
You can train from scratch to generate a new checkpoint. One example command to run pre-training with V3 on v5p-256:
@@ -54,13 +54,6 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
dataset_type=synthetic
```


## Fine-tuning

After you have a MaxText-compatible checkpoint, you can fine-tune it with different datasets.
@@ -136,6 +169,59 @@ python3 -m maxtext.trainers.post_train.sft.train_sft_deprecated src/maxtext/conf
dataset_type=hf
```

## Continued pre-training for V3.2 Sparse Attention
**DeepSeek Sparse Attention (DSA)** enhances the Multi-Head Latent Attention (MLA) architecture by introducing a **Lightning Indexer**, which selects the top-k tokens for attention. DeepSeek-V3.2 is instantiated from DeepSeek-V3.1 and undergoes continued pre-training to adapt this indexer via a two-stage strategy: **Dense Warm-up** and **Sparse Training**.

1. **Dense Warmup Stage**
The indexer is trained exclusively using dense indexer loss while all other model parameters remain frozen.
```sh
# Note: indexer_loss_scaling_factor must be non-zero to activate indexer training (default in base.yml is 0).
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
model_name=deepseek3.2-671b \
run_name=matmul_pre_training \
per_device_batch_size=4 \
enable_checkpointing=false \
ici_fsdp_parallelism=128 \
steps=5 \
async_checkpointing=false \
tokenizer_type=huggingface \
tokenizer_path=deepseek-ai/DeepSeek-V3.2 \
attention=flash \
dtype=bfloat16 \
weight_dtype=bfloat16 \
megablox=True \
sparse_matmul=True \
dataset_type=synthetic \
indexer_sparse_training=False \
indexer_loss_scaling_factor=0.01 \
trainable_parameters_mask=['.*indexer.*']
```
2. **Sparse Training Stage**
The indexer is trained with sparse indexer loss, while the remaining model parameters are unfrozen and updated using standard language modeling loss.
```sh
# Note: indexer_loss_scaling_factor must be non-zero to activate indexer training (default in base.yml is 0).
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
model_name=deepseek3.2-671b \
per_device_batch_size=4 \
enable_checkpointing=false \
ici_fsdp_parallelism=128 \
steps=5 \
max_target_length=1024 \
async_checkpointing=false \
tokenizer_type=huggingface \
tokenizer_path=deepseek-ai/DeepSeek-V3.2 \
attention=flash \
dtype=bfloat16 \
weight_dtype=bfloat16 \
megablox=True \
sparse_matmul=True \
dataset_type=synthetic \
indexer_sparse_training=True \
indexer_loss_scaling_factor=0.01
```
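
The two stages differ only in a handful of flags; a compact summary of the deltas (values copied from the commands above), sketched as shell variables:

```shell
# Stage 1 (dense warm-up): train only the indexer, everything else frozen.
STAGE1_FLAGS="indexer_sparse_training=False indexer_loss_scaling_factor=0.01 trainable_parameters_mask=['.*indexer.*']"
# Stage 2 (sparse training): indexer uses sparse loss, the rest of the model is unfrozen.
STAGE2_FLAGS="indexer_sparse_training=True indexer_loss_scaling_factor=0.01"
echo "${STAGE1_FLAGS}"
echo "${STAGE2_FLAGS}"
```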

## Decoding
One example command to run decoding with V3 on v5p-256, using an unscanned checkpoint for fast decoding:

@@ -215,4 +301,4 @@ To run MMLU benchmarks and validate the model's performance, follow the instruct
* General dense matmul implementation with flags `sparse_matmul=False capacity_factor=-1`.
* Dropping implementation with flag `sparse_matmul=False` and a reasonable `capacity_factor`, commonly between 1 and 1.25.
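
The three MoE matmul strategies above correspond to these flag combinations (a sketch; the dropping value of 1.25 is one point in the commonly used 1 to 1.25 range):

```shell
# Sparse matmul via the Megablox kernel (as used in the training commands above)
SPARSE_FLAGS="sparse_matmul=True megablox=True"
# General dense matmul: no token dropping
DENSE_FLAGS="sparse_matmul=False capacity_factor=-1"
# Dropping implementation: capacity_factor commonly between 1 and 1.25
DROP_FLAGS="sparse_matmul=False capacity_factor=1.25"
echo "${DENSE_FLAGS}"
```

Pass exactly one of these sets to the training or decoding command.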

See more examples in scripts for [V2-Lite](v2-16b), [V3](v3-671b), and [V3.2](v3.2-671b).
57 changes: 57 additions & 0 deletions tests/end_to_end/tpu/deepseek/v3.2-671b/1_test_deepseek.sh
@@ -0,0 +1,57 @@
#!/bin/bash

# This file is documentation for how to get started with DeepSeek v3.2.

# This file runs Step 1 on CPU.
# 1. Convert the HuggingFace checkpoint (bf16) to MaxText-compatible checkpoint (bf16):
# Scanned format is better for training; unscanned format is better for decoding.
# 2. Run logit check, pre-training, fine-tuning, and decoding.

set -ex

export MODEL_NAME='deepseek3.2-671b'
export TOKENIZER_PATH='deepseek-ai/DeepSeek-V3.2'

# Installing torch for checkpoint conversion and forward_pass_logit_checker.py
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

if [ -z "${BASE_OUTPUT_PATH}" ]; then
# Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing.
# this bucket will store all the files generated by MaxText during a run
echo "BASE_OUTPUT_PATH is not set, defaulting to a test bucket"
export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
fi
BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/}
echo using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}

# Step 1: Checkpoint conversion
# You can use the HuggingFace checkpoint at https://huggingface.co/deepseek-ai/DeepSeek-V3.2 (fp8), and dequantize it to bf16
# Non-Googlers please remember to point `BF16_HF_PATH` to GCS buckets that you own
# Copying the HF checkpoint into a local directory `/tmp` -- you are free to use a different directory
if [ -z "${CKPT_DISK_LOCATION}" ]; then
export BF16_HF_BUCKET=gs://maxtext-deepseek/deepseek3.2/hf-bf16
gcloud storage cp -r ${BF16_HF_BUCKET} /tmp
export BF16_LOCAL_PATH=/tmp/hf-bf16
fi

# on cpu
# scanned
echo $BASE_OUTPUT_PATH/0/items; \
python3 -m maxtext.checkpoint_conversion.to_maxtext src/maxtext/configs/base.yml \
model_name=deepseek3.2-671b scan_layers=true attention=dot_product \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
--hf_model_path=$BF16_LOCAL_PATH \
--eager_load_method=safetensors \
--save_dtype=bfloat16

# unscanned
echo $BASE_OUTPUT_PATH/0/items; \
python3 -m maxtext.checkpoint_conversion.to_maxtext src/maxtext/configs/base.yml \
model_name=deepseek3.2-671b scan_layers=false attention=dot_product \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
--hf_model_path=$BF16_LOCAL_PATH \
--eager_load_method=safetensors \
--save_dtype=bfloat16
55 changes: 55 additions & 0 deletions tests/end_to_end/tpu/deepseek/v3.2-671b/2_test_deepseek.sh
@@ -0,0 +1,55 @@
#!/bin/bash

# This file is documentation for how to get started with DeepSeek v3.2.

# This file runs Step 2 on v5p-128 on a daily basis.
# 1. Convert the HuggingFace checkpoint (bf16) to MaxText-compatible checkpoint (bf16):
# Scanned format is better for training; unscanned format is better for decoding.
# 2. Run logit check, pre-training, fine-tuning, and decoding.

set -ex

export MODEL_NAME='deepseek3.2-671b'
export TOKENIZER_PATH='deepseek-ai/DeepSeek-V3.2'

# Installing torch for checkpoint conversion and forward_pass_logit_checker.py
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

# e.g., $HOME/maxtext/src/maxtext
export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}"

if [ -z "${BASE_OUTPUT_PATH}" ]; then
# Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own, this script uses internal buckets for testing.
# this bucket will store all the files generated by MaxText during a run
echo "BASE_OUTPUT_PATH is not set, defaulting to a test bucket"
export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
fi
BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/}
echo using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}

# Step 2:
# We define the checkpoint paths. This way it is easier to use these paths in the `train.py` and `decode.py` commands
# export SCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/scanned/0/items
# export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned/0/items
# Use a hard-coded golden checkpoint rather than the checkpoints generated by Step 1, since Step 1 is not part of the daily test.
SCANNED_CKPT_PATH=gs://maxtext-deepseek/deepseek3.2/2026-02-20/scanned/0/items
UNSCANNED_CKPT_PATH=gs://maxtext-deepseek/deepseek3.2/2026-02-20/unscanned/0/items
# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
export DATASET_PATH=gs://maxtext-dataset

# Test whether the forward pass logits match the golden logits
# default golden_logits_path=/deps/tests/assets/golden_logits/golden_data_{MODEL_NAME}.jsonl, copied from gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl
GOLDEN_LOGITS_DISK_LOCATION="/deps/tests/assets/golden_logits/golden_data_${MODEL_NAME}.jsonl"
if [ ! -f "${GOLDEN_LOGITS_DISK_LOCATION}" ]; then
GOLDEN_LOGITS_PATH="gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl"
GOLDEN_LOGITS_DISK_LOCATION=/tmp/golden_data.jsonl
gcloud storage cp ${GOLDEN_LOGITS_PATH} ${GOLDEN_LOGITS_DISK_LOCATION}
fi

# override deepseek3.2-671b.yml with indexer_topk=2
OVERRIDE="override_model_config=True indexer_topk=2"
python3 -m tests.utils.forward_pass_logit_checker ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=true attention=dot_product per_device_batch_size=1 model_name=${MODEL_NAME} max_prefill_predict_length=4 max_target_length=4 async_checkpointing=false sparse_matmul=false ici_fsdp_parallelism=1 ici_expert_parallelism=-1 checkpoint_storage_concurrent_gb=1024 weight_dtype=float32 dtype=float32 activations_in_float32=true matmul_precision=highest float32_logits=true float32_qk_product=true ${OVERRIDE} --golden_logits_path=${GOLDEN_LOGITS_DISK_LOCATION} --max_kl_div=0.3

# Run decoding - megablox implementation
# Note decode requires the access token for huggingface tokenizer even if the model is not gated
python3 -m maxtext.inference.decode ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product sparse_matmul=True use_tokamax_gmm=False dtype=bfloat16 weight_dtype=bfloat16 per_device_batch_size=1 max_prefill_predict_length=3072 max_target_length=4096 ici_fsdp_parallelism=1 ici_tensor_parallelism=-1 ici_expert_parallelism=1 checkpoint_storage_concurrent_gb=1024 mla_naive_kvcache=false prompt="An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and outputs are all vectors. The output is "