From 3d69924e43f86d73b4e609ab3386ab3cade0375d Mon Sep 17 00:00:00 2001
From: Snehal Verma
Date: Sat, 11 Apr 2026 00:21:15 +0000
Subject: [PATCH] Kimi-K2 launch announcement and user guide

---
 README.md                             |   1 +
 tests/end_to_end/tpu/kimi/Run_Kimi.md | 157 ++++++++++++++++++++++++++
 2 files changed, 158 insertions(+)
 create mode 100644 tests/end_to_end/tpu/kimi/Run_Kimi.md

diff --git a/README.md b/README.md
index 9024e288d8..b0722d982e 100644
--- a/README.md
+++ b/README.md
@@ -41,6 +41,7 @@ See our guide on running MaxText in decoupled mode, without any GCP dependencies

## 🔥 Latest news 🔥

* \[April 10, 2026\] [Kimi-K2](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/kimi-k2-1t.yml) (32B activated, 1T total) is now supported.
* \[April 2, 2026\] Gemma 4 multi-modal models (26B MoE, 31B dense) are now supported! Try them out with our [gemma4-26b](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/gemma4-26b.yml) and [gemma4-31b](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/gemma4-31b.yml) configs. For more details, see [Run_Gemma4.md](https://github.com/AI-Hypercomputer/maxtext/blob/main/tests/end_to_end/tpu/gemma4/Run_Gemma4.md).
* \[March 6, 2026\] New features from DeepSeek-AI are now supported: Conditional Memory via Scalable Lookup ([Engram](https://arxiv.org/abs/2601.07372)) and Manifold-Constrained Hyper-Connections ([mHC](https://arxiv.org/abs/2512.24880)). Try them out with our [deepseek-custom](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/models/deepseek-custom.yml) starter config.
* \[March 5, 2026\] New `tpu-post-train` [target in PyPI](https://pypi.org/project/maxtext). Please also use this installation option for running vllm_decode. See the [MaxText installation instructions](https://maxtext.readthedocs.io/en/latest/install_maxtext.html) for more info.
diff --git a/tests/end_to_end/tpu/kimi/Run_Kimi.md b/tests/end_to_end/tpu/kimi/Run_Kimi.md
new file mode 100644
index 0000000000..7519d9e9ae
--- /dev/null
+++ b/tests/end_to_end/tpu/kimi/Run_Kimi.md
@@ -0,0 +1,157 @@

# Kimi

Kimi is a family of high-performance, open-weights sparse MoE models by Moonshot AI designed for agentic intelligence. The currently supported model is **Kimi K2 (1.04T)**.

* **Kimi K2** features 1.04 trillion total parameters with 32 billion activated parameters. It uses **Multi-Head Latent Attention (MLA)** and an ultra-sparse MoE with **384 experts**, optimized for long-context and agentic tasks.
* **MuonClip Optimizer**: Kimi K2 was trained with the token-efficient Muon optimizer combined with a novel **QK-clip** technique to ensure training stability and eliminate loss spikes during large-scale pre-training.
* **Agentic Excellence**: K2 is post-trained with a large-scale agentic data synthesis pipeline and Reinforcement Learning (RL), achieving state-of-the-art performance on benchmarks such as Tau2-Bench and SWE-Bench.

## Pre-training

You can train from scratch to generate a new checkpoint. Below is an example command for pre-training Kimi K2 on a v5p-512 (adjust parallelism for the 1T-parameter scale):
```sh
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
  base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
  run_name=kimi_k2_pre_training \
  per_device_batch_size=1 \
  enable_checkpointing=false \
  model_name=kimi-k2-1t \
  ici_fsdp_parallelism=64 \
  ici_expert_parallelism=8 \
  steps=5 \
  max_target_length=4096 \
  async_checkpointing=false \
  tokenizer_type=huggingface \
  tokenizer_path=moonshotai/Kimi-K2-Instruct \
  attention=flash \
  dtype=bfloat16 \
  weight_dtype=bfloat16 \
  megablox=False \
  sparse_matmul=False \
  dataset_type=synthetic
```

## Checkpoint conversion

To get started, download the model from [HuggingFace](https://huggingface.co/moonshotai/Kimi-K2-Instruct). Kimi K2's trillion-parameter checkpoint requires efficient sharding.

* Run [convert_deepseek_family_ckpt.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_ckpt.py) to convert the checkpoint to the MaxText-compatible [Orbax](https://orbax.readthedocs.io/en/latest/guides/checkpoint/orbax_checkpoint_101.html) format for training and fine-tuning.
* Note that Kimi K2 uses **YaRN** to extend the context window to 128k; make sure your configuration carries these positional-embedding settings through conversion so that decoding behaves correctly.

## Fine-tuning

After you have a MaxText-compatible checkpoint, you can fine-tune Kimi K2. The Kimi team recommends using the **Muon optimizer** during fine-tuning to maintain the token efficiency established during pre-training.
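For intuition, Muon replaces each weight matrix's gradient with an approximately orthogonalized version before the descent step. Below is a minimal NumPy sketch of that idea using a classical Newton-Schulz iteration; the function names, step count, and learning rate are illustrative assumptions, not MaxText or Moonshot API, and the production MuonClip optimizer additionally applies QK-clip to the attention projections, which is omitted here.

```python
import numpy as np

def newton_schulz(grad: np.ndarray, steps: int = 50) -> np.ndarray:
    """Approximately orthogonalize `grad`, pushing its singular values toward 1.

    Classical Newton-Schulz iteration X <- 0.5 * X @ (3I - X^T X); it converges
    when the spectral norm of the starting point is below sqrt(3), which the
    Frobenius normalization below guarantees.
    """
    x = grad / np.linalg.norm(grad)  # Frobenius norm 1 => spectral norm <= 1
    eye = np.eye(grad.shape[1])
    for _ in range(steps):
        x = 0.5 * x @ (3.0 * eye - x.T @ x)
    return x

def muon_like_step(weight: np.ndarray, grad: np.ndarray, lr: float = 0.02) -> np.ndarray:
    """One Muon-style update: descend along the orthogonalized gradient."""
    return weight - lr * newton_schulz(grad)

# Toy example on a small square weight matrix.
rng = np.random.default_rng(0)
weight = rng.normal(size=(4, 4))
grad = rng.normal(size=(4, 4))
direction = newton_schulz(grad)   # near-orthogonal update direction
updated = muon_like_step(weight, grad)
```

The key property is that `direction.T @ direction` is close to the identity, so every direction in the gradient contributes at a similar scale regardless of its original singular value.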
Example command for general fine-tuning on a v5p-512:

```sh
python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
  base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
  dataset_path=${DATASET_PATH?} \
  load_parameters_path=${CONVERTED_CHECKPOINT?} \
  run_name=kimi_k2_fine_tuning \
  per_device_batch_size=2 \
  model_name=kimi-k2-1t \
  steps=50 \
  max_target_length=4096 \
  tokenizer_type=huggingface \
  tokenizer_path=moonshotai/Kimi-K2-Instruct \
  attention=flash \
  dtype=bfloat16 \
  weight_dtype=bfloat16 \
  ici_expert_parallelism=64 \
  ici_fsdp_parallelism=8
```

## Decoding

Below is an example command to run decoding with Kimi K2. Given the model's 1T size, high tensor parallelism is recommended.

```sh
python3 -m maxtext.inference.decode src/maxtext/configs/base.yml \
  base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
  load_parameters_path=${CONVERTED_CHECKPOINT?} \
  run_name=kimi_decode \
  per_device_batch_size=1 \
  model_name=kimi-k2-1t \
  max_target_length=2048 \
  tokenizer_type=huggingface \
  tokenizer_path=moonshotai/Kimi-K2-Instruct \
  attention=dot_product \
  ici_tensor_parallelism=128 \
  ici_fsdp_parallelism=1 \
  prompt="The primary goal of agentic intelligence is to " \
  scan_layers=False
```

## Correctness

To verify the correctness of the Kimi K2 implementation (see the [Kimi K2 Technical Report](https://arxiv.org/pdf/2507.20534) for the model's own evaluation suite), we perform two primary validation steps:

* **Logit Comparison**: We compare the logits generated by our implementation against those from a HuggingFace implementation for a set of given prompts.
* **MMLU Score Validation**: We validate the MMLU score against established benchmarks.
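Conceptually, the logit check computes a per-token KL divergence between the HuggingFace ("golden") next-token distribution and MaxText's, and fails if any position exceeds a small threshold. A minimal NumPy sketch of that metric (the helper names here are illustrative, not the actual API of the checker):

```python
import numpy as np

def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the vocabulary (last) axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def per_token_kl(golden_logits: np.ndarray, test_logits: np.ndarray) -> np.ndarray:
    """KL(golden || test) at each position: sum_v p_v * (log p_v - log q_v)."""
    log_p = log_softmax(golden_logits)
    log_q = log_softmax(test_logits)
    return np.sum(np.exp(log_p) * (log_p - log_q), axis=-1)

# Toy check over a 3-word vocabulary and two token positions: identical logits
# give zero divergence, while a small non-uniform perturbation gives a small
# positive value that a threshold such as 2e-4 can accept or reject.
golden = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]])
kl_same = per_token_kl(golden, golden.copy())
kl_perturbed = per_token_kl(golden, golden + np.array([0.01, 0.0, -0.01]))
```

Note that a uniform shift of all logits leaves the softmax (and hence the KL) unchanged, so the metric is sensitive only to perturbations that change the relative probabilities.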
### Logit Comparison

Use the following example to generate "golden" logits from the HuggingFace reference model for Kimi K2.

```sh
python3 -m tests.assets.logits_generation.generate_hf_golden_logits \
  --model-id=moonshotai/Kimi-K2-Instruct \
  --output-path=golden_Kimi-K2.jsonl \
  --prompts='I love to;Today is a;What is the'
```

You should see logs confirming the file location:

```
...
File is stored locally at golden_Kimi-K2.jsonl.
```

Run the command below to compare the logits between the HuggingFace reference and your MaxText implementation. This verifies that the converted weights, including the **MLA** attention parameters (64 heads for K2), are mapped correctly.

```sh
python3 -m tests.utils.forward_pass_logit_checker \
  src/maxtext/configs/base.yml \
  tokenizer_type=huggingface \
  tokenizer_path=moonshotai/Kimi-K2-Instruct \
  load_parameters_path=${CONVERTED_CHECKPOINT?} \
  run_name=forward_pass_test_kimi_k2 \
  per_device_batch_size=1 \
  model_name=kimi-k2-1t \
  max_prefill_predict_length=4 \
  max_target_length=4 \
  scan_layers=false \
  sparse_matmul=False \
  dtype=float32 \
  activations_in_float32=true \
  matmul_precision=high \
  --max_kl_div=2e-4 \
  --golden_logits_path=${PWD}/golden_Kimi-K2.jsonl
```

To run MMLU benchmarks and validate the model's performance, follow the instructions provided [here](../../../benchmarks/api_server/README.md).

## Supported MoE strategies

* Dropless
  * [MegaBlocks](https://arxiv.org/abs/2211.15841) implementation with flags `sparse_matmul=True megablox=True`.
  * [JAX ragged_dot](https://github.com/jax-ml/jax/blob/a8fb0e01f8d083fff337d3c26375bb1b77344a99/jax/_src/lax/lax.py#L2415) implementation with flags `sparse_matmul=True megablox=False`.
  * General dense matmul implementation with flags `sparse_matmul=False capacity_factor=-1`.
* Dropping implementation with flag `sparse_matmul=False` and a reasonable `capacity_factor`, commonly 1 to 1.25.
\ No newline at end of file