1 change: 1 addition & 0 deletions README.md
@@ -41,6 +41,7 @@ See our guide on running MaxText in decoupled mode, without any GCP dependencies

## 🔥 Latest news 🔥

* \[April 10, 2026\] [Kimi-K2](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/kimi-k2-1t.yml) (32B activated, 1T total) is now supported.
* \[April 2, 2026\] Gemma 4 multi-modal models (26B MoE, 31B dense) are now supported! Try them out with our [gemma4-26b](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/gemma4-26b.yml) and [gemma4-31b](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/gemma4-31b.yml) configs. For more details, see [Run_Gemma4.md](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/gemma4/Run_Gemma4.md).
* \[March 6, 2026\] New features from DeepSeek-AI are now supported: Conditional Memory via Scalable Lookup ([Engram](https://arxiv.org/abs/2601.07372)) and Manifold-Constrained Hyper-Connections ([mHC](https://arxiv.org/abs/2512.24880)). Try them out with our [deepseek-custom](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/deepseek-custom.yml) starter config.
* \[March 5, 2026\] New `tpu-post-train` [target in PyPI](https://pypi.org/project/maxtext). Please also use this installation option for running vllm_decode. See the [MaxText installation instructions](https://maxtext.readthedocs.io/en/latest/install_maxtext.html) for more info.
157 changes: 157 additions & 0 deletions tests/end_to_end/tpu/kimi/Run_Kimi.md
@@ -0,0 +1,157 @@
<!--
# Copyright 2023-2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->

# Kimi

Kimi is a family of high-performance, open-weights sparse MoE models from Moonshot AI, designed for agentic intelligence. The currently supported model is **Kimi K2 (1.04T)**.

* **Kimi K2** features a massive 1.04 trillion total parameters with 32 billion activated parameters. It utilizes **Multi-Head Latent Attention (MLA)** and an ultra-sparse MoE with **384 experts**, optimized for long-context and agentic tasks.
* **MuonClip Optimizer**: Kimi K2 was trained using the token-efficient Muon optimizer combined with a novel **QK-clip** technique to ensure training stability and eliminate loss spikes during large-scale pre-training.
* **Agentic Excellence**: K2 is specifically post-trained using a large-scale agentic data synthesis pipeline and Reinforcement Learning (RL), achieving state-of-the-art performance on benchmarks like Tau2-Bench and SWE-Bench.
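The QK-clip idea above can be sketched simply: when the running maximum attention logit exceeds a threshold, the query and key projections are rescaled so their product shrinks back under it, preventing logit blow-ups without touching other weights. A minimal NumPy sketch — the threshold value, function name, and per-matrix bookkeeping are illustrative, not the actual Kimi K2 training configuration:

```python
import numpy as np

def qk_clip(w_q, w_k, max_logit, tau=100.0):
    """Rescale query/key projections when the observed max attention
    logit exceeds a threshold tau, bounding future logits.

    Splitting the shrink factor evenly between W_q and W_k keeps the
    attention logits (which depend on the product W_q W_k^T) scaled by
    exactly tau / max_logit.
    """
    if max_logit <= tau:
        return w_q, w_k          # within bounds: no-op
    gamma = tau / max_logit      # desired shrink factor for the logits
    scale = np.sqrt(gamma)       # split evenly between the two matrices
    return w_q * scale, w_k * scale

# usage: after an optimizer step, with max_logit tracked in the forward pass
w_q, w_k = np.ones((4, 4)), np.ones((4, 4))
w_q, w_k = qk_clip(w_q, w_k, max_logit=400.0, tau=100.0)
```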

## Pre-training
You can train from scratch to generate a new checkpoint. The following example runs pre-training with Kimi K2 on a v5p-512 slice (adjust the parallelism values for the 1T-parameter scale):

```sh
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
run_name=kimi_k2_pre_training \
per_device_batch_size=1 \
enable_checkpointing=false \
model_name=kimi-k2-1t \
ici_fsdp_parallelism=64 \
ici_expert_parallelism=8 \
steps=5 \
max_target_length=4096 \
async_checkpointing=false \
tokenizer_type=huggingface \
tokenizer_path=moonshotai/Kimi-K2-Instruct \
attention=flash \
dtype=bfloat16 \
weight_dtype=bfloat16 \
megablox=False \
sparse_matmul=False \
dataset_type=synthetic
```
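The `ici_*` values above define the device mesh, so their product generally needs to match the number of devices in the slice. A quick sanity check (the device count passed below is an illustrative assumption, not a statement about any particular slice topology):

```python
import math

def check_mesh(num_devices, **parallelism):
    """Verify that the product of mesh-axis sizes matches the device count."""
    product = math.prod(parallelism.values())
    if product != num_devices:
        raise ValueError(
            f"mesh product {product} != {num_devices} devices: {parallelism}")
    return True

# the pre-training command above uses fsdp=64 x expert=8, a 512-way mesh
check_mesh(512, ici_fsdp_parallelism=64, ici_expert_parallelism=8)
```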

## Checkpoint conversion
To get started, download the model from [HuggingFace](https://huggingface.co/moonshotai/Kimi-K2-Instruct). Kimi K2 uses a trillion-parameter architecture that requires efficient sharding.
* Run [convert_deepseek_family_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_ckpt.py) to convert the checkpoint for MaxText compatibility in [Orbax](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) for training and fine-tuning.
* Note that Kimi K2 utilizes **YaRN** for context window extension to 128k; ensure your configuration reflects these positional embedding settings during conversion for decoding.
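To illustrate what context-window scaling does to rotary embeddings, the sketch below shows plain position interpolation: every rotary frequency is divided by the extension factor. YaRN itself refines this with a per-frequency ramp and an attention temperature, so treat this as a simplified stand-in, not the exact Kimi K2 scheme (the extension factor used is also hypothetical):

```python
import numpy as np

def scaled_rope_freqs(head_dim, scale, base=10000.0):
    """Simplified RoPE frequency scaling for context extension.

    Dividing the inverse frequencies by `scale` stretches every
    rotary wavelength by that factor, so positions up to
    scale * original_context map into the trained rotation range.
    """
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2) / head_dim))
    return inv_freq / scale

# e.g. a hypothetical 32x context extension
freqs = scaled_rope_freqs(head_dim=64, scale=32.0)
```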

## Fine-tuning
After you have a MaxText compatible checkpoint, you can fine-tune Kimi K2. The Kimi team recommends using the **Muon optimizer** during fine-tuning to maintain the token efficiency established during pre-training.
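Muon's core transform replaces each weight-matrix momentum with an (approximate) orthogonalization computed by a few Newton–Schulz iterations. A minimal NumPy sketch of that transform, using the commonly published quintic coefficients — MaxText's actual Muon implementation may differ in details:

```python
import numpy as np

def newton_schulz_orthogonalize(g, steps=5):
    """Approximately orthogonalize a momentum matrix (Muon's core step).

    The quintic iteration pushes all singular values toward 1 (it
    oscillates in a band around 1 rather than converging exactly),
    equalizing update magnitude across directions.
    """
    a, b, c = 3.4445, -4.7750, 2.0315   # published quintic coefficients
    x = g / (np.linalg.norm(g) + 1e-7)  # Frobenius norm bounds the spectral norm
    transposed = x.shape[0] > x.shape[1]
    if transposed:                      # iterate on the smaller Gram matrix
        x = x.T
    for _ in range(steps):
        s = x @ x.T
        x = a * x + (b * s + c * (s @ s)) @ x
    return x.T if transposed else x
```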

Example command for General Fine-tuning on v5p-512:

```sh
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
dataset_path=${DATASET_PATH?} \
load_parameters_path=${CONVERTED_CHECKPOINT?} \
run_name=kimi_k2_fine_tuning \
per_device_batch_size=2 \
model_name=kimi-k2-1t \
steps=50 \
max_target_length=4096 \
tokenizer_type=huggingface \
tokenizer_path=moonshotai/Kimi-K2-Instruct \
attention=flash \
dtype=bfloat16 \
weight_dtype=bfloat16 \
ici_expert_parallelism=64 \
ici_fsdp_parallelism=8
```

## Decoding
Example command to run decoding with Kimi K2. Given its 1T size, high tensor parallelism is recommended.

```sh
python3 -m maxtext.inference.decode src/maxtext/configs/base.yml \
base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
load_parameters_path=${CONVERTED_CHECKPOINT?} \
run_name=kimi_decode \
per_device_batch_size=1 \
model_name=kimi-k2-1t \
max_target_length=2048 \
tokenizer_type=huggingface \
tokenizer_path=moonshotai/Kimi-K2-Instruct \
attention=dot_product \
ici_tensor_parallelism=128 \
ici_fsdp_parallelism=1 \
prompt="The primary goal of agentic intelligence is to " \
scan_layers=False
```
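A back-of-envelope check motivates the 128-way tensor parallelism: 1.04T bf16 parameters occupy about 2.08 TB of weights, which must be split until a per-device share fits in HBM. The helper below is illustrative and ignores KV cache, activations, and any replication:

```python
def bytes_per_device(num_params, bytes_per_param, tensor_parallelism):
    """Weight memory per device under pure tensor parallelism."""
    return num_params * bytes_per_param / tensor_parallelism

# 1.04T params, 2 bytes each (bfloat16), sharded 128 ways
gib = bytes_per_device(1.04e12, 2, 128) / 2**30
print(f"{gib:.1f} GiB of weights per device")  # prints "15.1 GiB of weights per device"
```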


## Correctness

To verify the correctness of the Kimi K2 implementation (benchmarks follow the [Kimi K2 Technical Report](https://arxiv.org/pdf/2507.20534)), we perform two primary validation steps:

* **Logit Comparison**: We compare the logits generated by our implementation against those from a HuggingFace implementation for a set of given prompts.
* **MMLU Score Validation**: We validate the MMLU score against established benchmarks.

### Logit Comparison

Use the following example to generate "golden" logits from the HuggingFace reference model for Kimi K2.

```sh
python3 -m tests.assets.logits_generation.generate_hf_golden_logits \
--model-id=moonshotai/Kimi-K2-Instruct \
--output-path=golden_Kimi-K2.jsonl \
--prompts='I love to;Today is a;What is the'
```

You should see logs confirming the file location:

```
...
File is stored locally at golden_Kimi-K2.jsonl.
```

Run the command below to compare logits between the HuggingFace reference and your MaxText implementation. This verifies that the converted weights, including the **MLA** attention parameters (64 heads for K2), are mapped correctly.

```sh
python3 -m tests.utils.forward_pass_logit_checker \
src/maxtext/configs/base.yml \
tokenizer_type=huggingface \
tokenizer_path=moonshotai/Kimi-K2-Instruct \
load_parameters_path=${CONVERTED_CHECKPOINT?} \
run_name=forward_pass_test_kimi_k2 \
per_device_batch_size=1 \
model_name=kimi-k2-1t \
max_prefill_predict_length=4 \
max_target_length=4 \
scan_layers=false \
sparse_matmul=False \
dtype=float32 \
activations_in_float32=true \
matmul_precision=high \
--max_kl_div=2e-4 \
--golden_logits_path=${PWD}/golden_Kimi-K2.jsonl
```
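The checker gates on the KL divergence between golden and MaxText logits (`--max_kl_div=2e-4` above). A minimal sketch of that metric for a single token position — the actual checker's aggregation across positions may differ:

```python
import numpy as np

def kl_div_from_logits(golden_logits, test_logits):
    """KL(golden || test) over the vocabulary for one token position."""
    def log_softmax(x):
        x = x - x.max()                       # stabilize before exponentiating
        return x - np.log(np.exp(x).sum())
    p_log = log_softmax(np.asarray(golden_logits, dtype=np.float64))
    q_log = log_softmax(np.asarray(test_logits, dtype=np.float64))
    return float(np.sum(np.exp(p_log) * (p_log - q_log)))
```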

To run MMLU benchmarks and validate the model's performance, follow the instructions provided [here](../../../benchmarks/api_server/README.md).

## Supported MoE strategies
* Dropless
* [MegaBlocks](https://arxiv.org/abs/2211.15841) implementation with flag `sparse_matmul=True megablox=True`.
* [JAX ragged_dot](https://github.com/jax-ml/jax/blob/a8fb0e01f8d083fff337d3c26375bb1b77344a99/jax/_src/lax/lax.py#L2415) implementation with flag `sparse_matmul=True megablox=False`.
* General dense matmul implementation with flag `sparse_matmul=False capacity_factor=-1`.
* Dropping implementation with flag `sparse_matmul=False` and a reasonable `capacity_factor`, commonly set between 1 and 1.25.
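For the dropping strategy, the capacity factor bounds how many routed tokens each expert accepts before overflow tokens are dropped; the dense path never drops. A sketch of that arithmetic (the helper name is illustrative):

```python
import math

def expert_capacity(tokens_per_batch, num_experts, capacity_factor):
    """Tokens each expert accepts before overflow tokens are dropped.

    capacity_factor=1 gives a perfectly balanced budget; values above 1
    add slack for imbalanced routing; a negative value means dropless.
    """
    if capacity_factor < 0:
        return tokens_per_batch  # dropless: any expert may take every token
    return math.ceil(tokens_per_batch / num_experts * capacity_factor)

# e.g. 4096 routed tokens over Kimi K2's 384 experts at factor 1.25
cap = expert_capacity(4096, 384, 1.25)
```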