Commit 2fcaeb8

Merge pull request #3565 from AI-Hypercomputer:ds3.2-xlml-tests
PiperOrigin-RevId: 897975965
2 parents ce8a7de + 20d93f6 commit 2fcaeb8

4 files changed

Lines changed: 210 additions & 11 deletions


README.md

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ See our guide on running MaxText in decoupled mode, without any GCP dependencies
## 🔥 Latest news 🔥

* \[April 10, 2026\] [DeepSeek-V3.2](https://arxiv.org/pdf/2512.02556) is now supported, featuring DeepSeek Sparse Attention for long context. Try it out with the [deepseek3.2-671b](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/deepseek3.2-671b.yml) config. See the [user guide](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/deepseek/Run_DeepSeek.md) for more details.
* \[April 2, 2026\] Gemma 4 multi-modal models (26B MoE, 31B dense) are now supported! Try them out with our [gemma4-26b](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/gemma4-26b.yml) and [gemma4-31b](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/gemma4-31b.yml) configs. For more details, see [Run_Gemma4.md](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/gemma4/Run_Gemma4.md).
* \[March 6, 2026\] New features from DeepSeek-AI are now supported: Conditional Memory via Scalable Lookup ([Engram](https://arxiv.org/abs/2601.07372)) and Manifold-Constrained Hyper-Connections ([mHC](https://arxiv.org/abs/2512.24880)). Try them out with our [deepseek-custom](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/deepseek-custom.yml) starter config.
* \[March 5, 2026\] New `tpu-post-train` [target in PyPI](https://pypi.org/project/maxtext). Please also use this installation option for running vllm_decode. See the [MaxText installation instructions](https://maxtext.readthedocs.io/en/latest/install_maxtext.html) for more info.

tests/end_to_end/tpu/deepseek/Run_DeepSeek.md

Lines changed: 97 additions & 11 deletions
@@ -16,19 +16,59 @@
# DeepSeek

DeepSeek is a novel family of open-weights sparse MoE models by DeepSeek AI. The currently supported models are DeepSeek V2-Lite (16B), DeepSeek V3 (671B), DeepSeek R1 (671B), DeepSeek V3.1 (671B), and DeepSeek V3.2 (671B).

* DeepSeek-V3 features advanced techniques, including Multi-Head Latent Attention (MLA, sketched after this list), finer-grained and shared experts, Multi-Token Prediction (MTP), and FP8 mixed precision, designed for enhanced efficiency and performance.

* DeepSeek R1 also uses the V3 architecture. It utilizes cold-start data and large-scale reinforcement learning to incentivize chain-of-thought reasoning without relying solely on supervised fine-tuning.

* DeepSeek-V3.1 shares the same architecture as V3, but features an improved checkpoint that supports hybrid thinking modes, improved performance in agentic tasks, and higher thinking efficiency.

* DeepSeek-V3.2 introduces [DeepSeek Sparse Attention](https://arxiv.org/pdf/2512.02556) (DSA), which reduces computational complexity while preserving model performance in long-context scenarios.
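MLA's central trick is caching a single low-rank latent per token and up-projecting it into per-head keys and values. Below is a minimal JAX-style sketch of that compression path; all shapes and weight names (`W_dkv`, `W_uk`, `W_uv`, `d_latent`) are chosen for illustration and are not taken from MaxText's implementation.

```python
# Illustrative sketch of MLA's KV compression (not MaxText's code).
from jax import random

d_model, d_latent, n_heads, d_head, seq = 1024, 128, 8, 64, 16
k1, k2, k3, k4 = random.split(random.PRNGKey(0), 4)

h = random.normal(k1, (seq, d_model))                   # token activations
W_dkv = random.normal(k2, (d_model, d_latent))          # shared down-projection
W_uk = random.normal(k3, (d_latent, n_heads * d_head))  # key up-projection
W_uv = random.normal(k4, (d_latent, n_heads * d_head))  # value up-projection

c_kv = h @ W_dkv                                 # (seq, d_latent): all the KV cache stores
k = (c_kv @ W_uk).reshape(seq, n_heads, d_head)  # keys recovered from the latent
v = (c_kv @ W_uv).reshape(seq, n_heads, d_head)  # values recovered from the latent
# Cache cost per token: d_latent floats instead of 2 * n_heads * d_head
# floats for standard multi-head attention.
```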
**Please note:**

* To leverage MLA with Flash Attention, ensure you have the latest JAX version.
* The provided TPU configurations are examples and not mandatory.
* For V3.1 & R1, use the existing V3 671B model configurations, as they share the same architecture.

## Checkpoint conversion

To get started, follow the instructions at HuggingFace ([V3](https://huggingface.co/deepseek-ai/DeepSeek-V3), [V2-Lite](https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite)) to download the model. The V3, V3.1, and R1 releases use mixed-precision FP8 & BF16 weights. To convert all FP8 weights to BF16, use the script [here](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/ckpt_scripts/deepseek_fp8_to_bf16.py). Once downloaded and converted to BF16:

* run [convert_deepseek_family_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_ckpt.py) to convert the checkpoint to a MaxText-compatible [Orbax](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) checkpoint for training and fine-tuning (see the hedged example after this list). When converting a checkpoint with MTP layers (like DeepSeek-V3), be sure to add the `--enable_mtp` flag to process them correctly.
* run [convert_deepseek_family_unscanned_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py) to convert the checkpoint to the unscanned Orbax version for decoding.
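For orientation, a hypothetical invocation of the scanned converter is shown below. Apart from `--enable_mtp`, every flag name and path is an assumed placeholder, not the script's confirmed interface; consult its `--help` for the real arguments.

```bash
# Hypothetical example -- flag names other than --enable_mtp are assumptions.
python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_family_ckpt \
  --base_model_path=<local_bf16_path> \
  --maxtext_model_path=gs://<your-bucket>/deepseek3-671b \
  --model_size=deepseek3-671b \
  --enable_mtp
```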
### Checkpoint conversion for V3.2

#### 1. Download Model Weights

Download the Hugging Face weights from [deepseek-ai/DeepSeek-V3.2](https://huggingface.co/deepseek-ai/DeepSeek-V3.2) to your local environment. Weights are provided in FP8.

```bash
hf download deepseek-ai/DeepSeek-V3.2 --local-dir <local_fp8_path>
```

#### 2. Dequantize Weights

Convert the weights from FP8 to BF16 using the script [deepseek_fp8_to_bf16.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/deepseek_fp8_to_bf16.py) on CPU:

```bash
python3 -m maxtext.checkpoint_conversion.standalone_scripts.deepseek_fp8_to_bf16 --input-fp8-hf-path=<local_fp8_path> --output-bf16-hf-path=<local_bf16_path>
```

Alternatively, you can use the official DeepSeek script [fp8_cast_bf16.py](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py) to convert on GPU.

#### 3. Convert to MaxText-compatible Orbax format

Execute the following command to finalize the conversion. Ensure your environment variables (`$BASE_OUTPUT_PATH`, `$HF_TOKEN`, and `$DEQUANTIZED_LOCAL_WEIGHTS`) are exported before running.
Setting `scan_layers=true` generates the scanned Orbax format for training and fine-tuning; setting `scan_layers=false` generates the unscanned format for decoding.

```bash
python3 -m maxtext.checkpoint_conversion.to_maxtext \
  src/maxtext/configs/base.yml \
  model_name=deepseek3.2-671b \
  scan_layers=true \
  attention=dot_product \
  base_output_directory=$BASE_OUTPUT_PATH \
  hf_access_token=$HF_TOKEN \
  hardware=cpu \
  skip_jax_distributed_system=True \
  --hf_model_path=$DEQUANTIZED_LOCAL_WEIGHTS \
  --eager_load_method=safetensors \
  --save_dtype=bfloat16
```

## Pre-training

You can train from scratch to generate a new checkpoint. Here is one example command to run pretraining with V3 on v5p-256:
@@ -54,13 +94,6 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
dataset_type=synthetic
```

## Fine-tuning

After you have a MaxText-compatible checkpoint, you can fine-tune it with different datasets.

@@ -136,6 +169,59 @@ python3 -m maxtext.trainers.post_train.sft.train_sft_deprecated src/maxtext/conf
dataset_type=hf
```

## Continued pre-training for V3.2 Sparse Attention

**DeepSeek Sparse Attention (DSA)** enhances the Multi-Head Latent Attention (MLA) architecture by introducing a **Lightning Indexer**, which selects the top-k tokens for attention (a toy sketch of this selection follows below). DeepSeek-V3.2 is instantiated from DeepSeek-V3.1 and undergoes continued pre-training to adapt this indexer via a two-stage strategy: **Dense Warm-up** and **Sparse Training**.
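To make the indexer's role concrete, here is a toy JAX sketch of top-k token selection; the scoring function, shapes, and names (`q_idx`, `k_idx`, `topk`) are illustrative assumptions, not MaxText's implementation.

```python
# Toy sketch of DSA-style top-k token selection (illustrative only).
import jax.numpy as jnp
from jax import random

seq, d, topk = 8, 16, 4
rng = random.PRNGKey(0)
q_idx = random.normal(random.fold_in(rng, 1), (seq, d))  # indexer queries
k_idx = random.normal(random.fold_in(rng, 2), (seq, d))  # indexer keys

scores = q_idx @ k_idx.T                                 # (seq, seq) index scores
causal = jnp.tril(jnp.ones((seq, seq), dtype=bool))
scores = jnp.where(causal, scores, -jnp.inf)             # mask out future tokens

# Keep only the top-k scored past tokens per query; everything else is
# excluded from the (otherwise dense) attention computation.
kth_best = jnp.sort(scores, axis=-1)[:, -topk]           # k-th largest per row
sparse_mask = causal & (scores >= kth_best[:, None])
print(sparse_mask.sum(axis=-1))                          # at most topk per query
```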
1. **Dense Warm-up Stage**
The indexer is trained exclusively using the dense indexer loss while all other model parameters remain frozen.
```sh
# indexer_loss_scaling_factor must be non-zero to activate indexer training;
# the default in base.yml is 0.
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
  model_name=deepseek3.2-671b \
  run_name=matmul_pre_training \
  per_device_batch_size=4 \
  enable_checkpointing=false \
  ici_fsdp_parallelism=128 \
  steps=5 \
  async_checkpointing=false \
  tokenizer_type=huggingface \
  tokenizer_path=deepseek-ai/DeepSeek-V3.2 \
  attention=flash \
  dtype=bfloat16 \
  weight_dtype=bfloat16 \
  megablox=True \
  sparse_matmul=True \
  dataset_type=synthetic \
  indexer_sparse_training=False \
  indexer_loss_scaling_factor=0.01 \
  trainable_parameters_mask=['.*indexer.*']
```
2. **Sparse Training Stage**
The indexer is trained with the sparse indexer loss, while the remaining model parameters are unfrozen and updated using the standard language modeling loss.
```sh
# indexer_loss_scaling_factor must be non-zero to activate indexer training;
# the default in base.yml is 0.
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
  base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
  model_name=deepseek3.2-671b \
  per_device_batch_size=4 \
  enable_checkpointing=false \
  ici_fsdp_parallelism=128 \
  steps=5 \
  max_target_length=1024 \
  async_checkpointing=false \
  tokenizer_type=huggingface \
  tokenizer_path=deepseek-ai/DeepSeek-V3.2 \
  attention=flash \
  dtype=bfloat16 \
  weight_dtype=bfloat16 \
  megablox=True \
  sparse_matmul=True \
  dataset_type=synthetic \
  indexer_sparse_training=True \
  indexer_loss_scaling_factor=0.01
```

## Decoding

One example command to run decoding with V3 on v5p-256 with an unscanned checkpoint for fast decoding:

@@ -215,4 +301,4 @@ To run MMLU benchmarks and validate the model's performance, follow the instruct
* General dense matmul implementation with flags `sparse_matmul=False capacity_factor=-1`.
* Dropping implementation with flag `sparse_matmul=False` and a reasonable `capacity_factor`, commonly from 1 to 1.25 (see the example below).
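For instance, to try the dropping implementation with the V3 pre-training command shown earlier, the matmul flags would combine roughly as follows; the step count and capacity value are illustrative:

```bash
# Dropping implementation: dense matmul with token dropping (illustrative values).
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
  model_name=deepseek3-671b \
  sparse_matmul=False \
  capacity_factor=1.25 \
  dataset_type=synthetic \
  steps=5
```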

See more examples in scripts for [V2-Lite](v2-16b), [V3](v3-671b), and [V3.2](v3.2-671b).
Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
#!/bin/bash

# This file is documentation for how to get started with DeepSeek v3.2.

# This file runs Step 1 on CPU.
# 1. Convert the HuggingFace checkpoint (bf16) to a MaxText-compatible checkpoint (bf16):
#    Scanned format is better for training; unscanned format is better for decoding.
# 2. Run logit check, pre-training, fine-tuning, and decoding.

set -ex

export MODEL_NAME='deepseek3.2-671b'
export TOKENIZER_PATH='deepseek-ai/DeepSeek-V3.2'

# Install torch for checkpoint conversion and forward_pass_logit_checker.py
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

if [ -z "${BASE_OUTPUT_PATH}" ]; then
  # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own; this script uses internal buckets for testing.
  # This bucket will store all the files generated by MaxText during a run.
  echo "BASE_OUTPUT_PATH is not set"
  export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
fi
BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/}
echo using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}

# Step 1: Checkpoint conversion
# You can use the HuggingFace checkpoint at https://huggingface.co/deepseek-ai/DeepSeek-V3.2 (fp8) and dequantize it to bf16.
# Non-Googlers please remember to point `BF16_HF_BUCKET` to GCS buckets that you own.
# Copy the HF checkpoint into a local directory `/tmp` -- you are free to use a different directory.
if [ -z "${CKPT_DISK_LOCATION}" ]; then
  export BF16_HF_BUCKET=gs://maxtext-deepseek/deepseek3.2/hf-bf16
  gcloud storage cp -r ${BF16_HF_BUCKET} /tmp
  export BF16_LOCAL_PATH=/tmp/hf-bf16
fi

# on cpu
# scanned
echo $BASE_OUTPUT_PATH/0/items; \
python3 -m maxtext.checkpoint_conversion.to_maxtext src/maxtext/configs/base.yml \
  model_name=deepseek3.2-671b scan_layers=true attention=dot_product \
  base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
  hardware=cpu skip_jax_distributed_system=True \
  --hf_model_path=$BF16_LOCAL_PATH \
  --eager_load_method=safetensors \
  --save_dtype=bfloat16

# unscanned
echo $BASE_OUTPUT_PATH/0/items; \
python3 -m maxtext.checkpoint_conversion.to_maxtext src/maxtext/configs/base.yml \
  model_name=deepseek3.2-671b scan_layers=false attention=dot_product \
  base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
  hardware=cpu skip_jax_distributed_system=True \
  --hf_model_path=$BF16_LOCAL_PATH \
  --eager_load_method=safetensors \
  --save_dtype=bfloat16
Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
#!/bin/bash

# This file is documentation for how to get started with DeepSeek v3.2.

# This file runs Step 2 on v5p-128 on a daily basis.
# 1. Convert the HuggingFace checkpoint (bf16) to a MaxText-compatible checkpoint (bf16):
#    Scanned format is better for training; unscanned format is better for decoding.
# 2. Run logit check, pre-training, fine-tuning, and decoding.

set -ex

export MODEL_NAME='deepseek3.2-671b'
export TOKENIZER_PATH='deepseek-ai/DeepSeek-V3.2'

# Install torch for checkpoint conversion and forward_pass_logit_checker.py
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu

# e.g., $HOME/maxtext/src/maxtext
export MAXTEXT_PKG_DIR="${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext}"

if [ -z "${BASE_OUTPUT_PATH}" ]; then
  # Non-Googlers please remember to point `BASE_OUTPUT_PATH` to GCS buckets that you own; this script uses internal buckets for testing.
  # This bucket will store all the files generated by MaxText during a run.
  echo "BASE_OUTPUT_PATH is not set"
  export BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M)
fi
BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/}
echo using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH}

# Step 2:
# We define the checkpoint paths so that it is easier to use them in the `train.py` and `decode.py` commands.
# export SCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/scanned/0/items
# export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_PATH}/unscanned/0/items
# Use a hard-coded golden checkpoint rather than the checkpoints generated by Step 1, since Step 1 is not part of the daily test.
SCANNED_CKPT_PATH=gs://maxtext-deepseek/deepseek3.2/2026-02-20/scanned/0/items
UNSCANNED_CKPT_PATH=gs://maxtext-deepseek/deepseek3.2/2026-02-20/unscanned/0/items
# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
export DATASET_PATH=gs://maxtext-dataset

# Test whether the forward pass logits match the golden logits.
# Default golden_logits_path=/deps/tests/assets/golden_logits/golden_data_{MODEL_NAME}.jsonl, copied from gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl
GOLDEN_LOGITS_DISK_LOCATION="/deps/tests/assets/golden_logits/golden_data_${MODEL_NAME}.jsonl"
if [ ! -f "${GOLDEN_LOGITS_DISK_LOCATION}" ]; then
  GOLDEN_LOGITS_PATH="gs://maxtext-test-assets/golden_data_${MODEL_NAME}.jsonl"
  GOLDEN_LOGITS_DISK_LOCATION=/tmp/golden_data.jsonl
  gcloud storage cp ${GOLDEN_LOGITS_PATH} ${GOLDEN_LOGITS_DISK_LOCATION}
fi

# Override deepseek3.2-671b.yml with indexer_topk=2
OVERRIDE="override_model_config=True indexer_topk=2"
python3 -m tests.utils.forward_pass_logit_checker \
  ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml \
  base_output_directory=${BASE_OUTPUT_PATH} run_name=forward_logits_check \
  load_parameters_path=${SCANNED_CKPT_PATH} scan_layers=true attention=dot_product \
  per_device_batch_size=1 model_name=${MODEL_NAME} \
  max_prefill_predict_length=4 max_target_length=4 async_checkpointing=false \
  sparse_matmul=false ici_fsdp_parallelism=1 ici_expert_parallelism=-1 \
  checkpoint_storage_concurrent_gb=1024 weight_dtype=float32 dtype=float32 \
  activations_in_float32=true matmul_precision=highest float32_logits=true \
  float32_qk_product=true ${OVERRIDE} \
  --golden_logits_path=${GOLDEN_LOGITS_DISK_LOCATION} --max_kl_div=0.3

# Run decoding - megablox implementation
# Note: decode requires the HuggingFace access token for the tokenizer even if the model is not gated.
python3 -m maxtext.inference.decode \
  ${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml \
  base_output_directory=${BASE_OUTPUT_PATH} run_name=decode model_name=${MODEL_NAME} \
  tokenizer_type=huggingface tokenizer_path=${TOKENIZER_PATH} hf_access_token=${HF_TOKEN} \
  load_parameters_path=${UNSCANNED_CKPT_PATH} scan_layers=False attention=dot_product \
  sparse_matmul=True use_tokamax_gmm=False dtype=bfloat16 weight_dtype=bfloat16 \
  per_device_batch_size=1 max_prefill_predict_length=3072 max_target_length=4096 \
  ici_fsdp_parallelism=1 ici_tensor_parallelism=-1 ici_expert_parallelism=1 \
  checkpoint_storage_concurrent_gb=1024 mla_naive_kvcache=false \
  prompt="An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and outputs are all vectors. The output is "
