From 476166a6bc698e7cf1ee184099918e3fda2b62b2 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Fri, 6 Feb 2026 08:22:48 +0000 Subject: [PATCH 01/11] more batch tune --- ...num=1,use_fp8_w8a8=false}_NVIDIA_H200.json | 54 +++++++++++++++++++ ...num=8,use_fp8_w8a8=false}_NVIDIA_H200.json | 54 +++++++++++++++++++ .../{topk_num=8}_NVIDIA_H200.json | 28 ++++++++++ ...orch.bfloat16,topk_num=8}_NVIDIA_H200.json | 36 +++++++++++++ ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 36 +++++++++++++ 5 files changed, 208 insertions(+) diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json index c8100c6762..519fd497f1 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=192,N=2048,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -17,6 +17,15 @@ "num_stages": 3, "num_warps": 4 }, + "192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, "2048": { "BLOCK_SIZE_K": 32, "BLOCK_SIZE_M": 32, @@ -35,6 +44,15 @@ "num_stages": 2, "num_warps": 4 }, + "384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, "512": { "BLOCK_SIZE_K": 64, "BLOCK_SIZE_M": 16, @@ -53,6 +71,24 @@ "num_stages": 2, "num_warps": 4 }, + "640": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, "8": { "BLOCK_SIZE_K": 32, "BLOCK_SIZE_M": 16, @@ -79,5 +115,23 @@ "NEED_TRANS": false, "num_stages": 2, "num_warps": 4 + }, + "896": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 } } \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json index 4142ee9835..26a6d63c47 100644 --- 
a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=384,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -26,6 +26,24 @@ "num_stages": 3, "num_warps": 4 }, + "112": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "12": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, "128": { "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 16, @@ -44,6 +62,15 @@ "num_stages": 3, "num_warps": 4 }, + "24": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, "256": { "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 32, @@ -62,6 +89,15 @@ "num_stages": 3, "num_warps": 4 }, + "48": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, "64": { "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 16, @@ -79,5 +115,23 @@ "NEED_TRANS": false, "num_stages": 3, "num_warps": 8 + }, + "80": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "96": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 } } \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=8}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=8}_NVIDIA_H200.json index 002b842cbb..ea17f7f5ae 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=8}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=8}_NVIDIA_H200.json @@ -11,6 +11,14 @@ "BLOCK_SIZE": 256, "num_warps": 4 }, + "112": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "12": { + "BLOCK_SIZE": 512, + "num_warps": 8 + }, "128": { "BLOCK_SIZE": 256, "num_warps": 8 @@ -19,6 +27,14 @@ "BLOCK_SIZE": 128, "num_warps": 8 }, + "2": { + "BLOCK_SIZE": 256, + "num_warps": 4 + }, + "24": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, "256": { "BLOCK_SIZE": 128, "num_warps": 8 @@ -27,6 +43,10 @@ "BLOCK_SIZE": 128, "num_warps": 8 }, + "48": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, "64": { "BLOCK_SIZE": 128, "num_warps": 8 @@ -34,5 +54,13 @@ "8": { "BLOCK_SIZE": 128, "num_warps": 8 + }, + "80": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "96": { + "BLOCK_SIZE": 128, + "num_warps": 8 } } \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json index 
bc904bb7f8..6f5752b8da 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H200.json @@ -17,6 +17,18 @@ "NUM_STAGE": 4, "num_warps": 1 }, + "112": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "12": { + "BLOCK_DIM": 64, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 4 + }, "128": { "BLOCK_DIM": 1024, "BLOCK_M": 1, @@ -29,6 +41,12 @@ "NUM_STAGE": 1, "num_warps": 2 }, + "24": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, "256": { "BLOCK_DIM": 1024, "BLOCK_M": 1, @@ -41,6 +59,12 @@ "NUM_STAGE": 4, "num_warps": 4 }, + "48": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, "64": { "BLOCK_DIM": 128, "BLOCK_M": 1, @@ -52,5 +76,17 @@ "BLOCK_M": 1, "NUM_STAGE": 1, "num_warps": 16 + }, + "80": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 1 + }, + "96": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 } } \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H200.json index 50499a3e7e..37f18f4549 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=192,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -11,6 +11,12 @@ "NUM_STAGES": 1, "num_warps": 8 }, + "192": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 4 + }, "2048": { "BLOCK_M": 1, "BLOCK_N": 256, @@ -23,6 +29,12 @@ "NUM_STAGES": 1, "num_warps": 8 }, + "384": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, "512": { "BLOCK_M": 1, "BLOCK_N": 128, @@ -35,6 +47,18 @@ "NUM_STAGES": 4, "num_warps": 1 }, + "640": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "768": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, "8": { "BLOCK_M": 1, "BLOCK_N": 64, @@ -52,5 +76,17 @@ "BLOCK_N": 256, "NUM_STAGES": 4, "num_warps": 1 + }, + "896": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "96": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 4 } } \ No newline at end of file From f715bb9dfd9e013d77e0f0e9bd1cbbbbdb61d947 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Fri, 6 Feb 2026 08:27:50 +0000 Subject: [PATCH 02/11] delete the useless process --- .../layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py | 3 ++- .../layer_weights/meta_weights/fused_moe/impl/triton_impl.py | 3 ++- .../common/basemodel/triton_kernel/fused_moe/topk_select.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py index f00d572d9d..bdd86eb51e 100644 --- 
a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py @@ -46,7 +46,8 @@ def _select_experts( num_expert_group=num_expert_group, scoring_func=scoring_func, ) - topk_weights.mul_(self.routed_scaling_factor) + if self.routed_scaling_factor != 1.0: + topk_weights.mul_(self.routed_scaling_factor) if self.redundancy_expert_num > 0: redundancy_topk_ids_repair( topk_ids=topk_ids, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py index 8bcdb4bf90..d6e923a115 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -57,7 +57,8 @@ def _select_experts( num_expert_group=num_expert_group, scoring_func=scoring_func, ) - topk_weights.mul_(self.routed_scaling_factor) + if self.routed_scaling_factor != 1.0: + topk_weights.mul_(self.routed_scaling_factor) if self.num_fused_shared_experts > 0: pad_topk_ids = ( torch.arange( diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py b/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py index 72c3a381ed..59d1f825a3 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/topk_select.py @@ -46,7 +46,7 @@ def fused_topk( sgl_ops.topk_softmax( topk_weights, topk_ids, - gating_output.float(), # TODO(woosuk): Optimize this. + gating_output, renormalize=renormalize, ) return topk_weights, topk_ids From fe942a1f5361da7c4c9fa8cc0d64ed721ac5a581 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Fri, 6 Feb 2026 08:44:21 +0000 Subject: [PATCH 03/11] add random_sample --- .../mode_backend/generic_post_process.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py index ca3901ebd0..bae4505992 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py @@ -16,6 +16,7 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): b_length_penalty_param, b_mask_eos_reqs, is_all_greedy, + is_all_random, ) = _get_post_sample_tensors(reqs) eos_ids = g_pin_mem_manager.gen_from_list(key="eos_ids", data=eos_id, dtype=torch.int32).cuda(non_blocking=True) @@ -68,6 +69,11 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): batch_next_token_probs = torch.gather(probs, dim=1, index=batch_next_token_ids.view(-1, 1)) return batch_next_token_ids.view(-1), torch.log(batch_next_token_probs).view(-1) + elif is_all_random: + batch_next_token_ids = _random_sample(probs) + batch_next_token_probs = torch.gather(probs, dim=1, index=batch_next_token_ids.view(-1, 1)) + return batch_next_token_ids.view(-1), torch.log(batch_next_token_probs).view(-1) + elif get_env_start_args().sampling_backend == "triton": probs_sort, probs_idx = _top_p_top_k(probs, b_top_ps, b_top_ks) sampled_index = torch.multinomial(probs_sort, num_samples=1, replacement=True) @@ -104,6 +110,12 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor return probs_sort, probs_idx +def 
_random_sample(probs: torch.Tensor): + q = torch.empty_like(probs) + q.exponential_() + return probs.div_(q).argmax(dim=-1).view(-1) + + def _get_post_sample_tensors(reqs: List[InferReq]): req_idxes: List[int] = [] temperatures: List[float] = [] @@ -112,6 +124,7 @@ def _get_post_sample_tensors(reqs: List[InferReq]): length_penalty_param: List[int] = [] mask_eos_reqs: List[bool] = [] is_all_greedy = True + is_all_random = True for i, req_obj in enumerate(reqs): sample_param = req_obj.sampling_param @@ -127,6 +140,8 @@ def _get_post_sample_tensors(reqs: List[InferReq]): top_ks.append(top_k_val) if top_k_val > 1: is_all_greedy = False + if top_k_val != -1 or shm_param.top_p != 1.0: + is_all_random = False req_idxes.append(req_obj.req_idx) req_idxes_cpu = g_pin_mem_manager.gen_from_list(key="req_idxes", data=req_idxes, dtype=torch.int32) @@ -146,4 +161,5 @@ length_penalty_param_cpu.cuda(non_blocking=True), mask_eos_reqs_cpu.cuda(non_blocking=True), is_all_greedy, + is_all_random, ) From 31cc903c3b905783212d1b3616e35ac7e5c91de5 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Fri, 6 Feb 2026 09:06:43 +0000 Subject: [PATCH 04/11] add qk_rmsnorm_fused --- .../basemodel/triton_kernel/norm/qk_norm.py | 136 ++++++++++++++++++ .../layer_infer/transformer_layer_infer.py | 7 +- .../triton_kernel/test_qk_rmsnorm_fused.py | 52 +++++++ 3 files changed, 193 insertions(+), 2 deletions(-) create mode 100644 unit_tests/common/basemodel/triton_kernel/test_qk_rmsnorm_fused.py diff --git a/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py b/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py index 40322e5093..d3fc1b7dbe 100644 --- a/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py +++ b/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py @@ -64,3 +64,139 @@ def qk_rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps): num_warps=4, ) return x + + +@triton.jit +def _qk_rms_norm_fused_kernel( + # Q Pointers & Strides + Q_ptr, + WQ_ptr, + stride_q_row, + stride_q_col, + # K Pointers & Strides + K_ptr, + WK_ptr, + stride_k_row, + stride_k_col, + # Dimensions + num_heads_q: tl.constexpr, # number of Q heads (used as the branch boundary) + head_dim: tl.constexpr, + eps, + BLOCK_SIZE: tl.constexpr, +): + # PID 0: which token (row) to process + row_idx = tl.program_id(0) + # PID 1: which head to process (combined index) + # ranges over [0, num_heads_q + num_heads_k) + combo_head_idx = tl.program_id(1) + + # shared offsets (0 ~ head_dim) + offs = tl.arange(0, BLOCK_SIZE) + + # === branch: decide whether this program handles Q or K === + if combo_head_idx < num_heads_q: + # ------------------ handle Q ------------------ + # pointer arithmetic + # for Q, the actual head index is combo_head_idx itself + Q_ptr += row_idx * stride_q_row + + # locate Q data: base + row offset + head offset + column offset + q_ptr_offset = (combo_head_idx * head_dim + offs) * stride_q_col + + # load Q data + x = tl.load(Q_ptr + q_ptr_offset).to(tl.float32) + # RMSNorm computation + var = tl.sum(x * x, axis=0) / head_dim + rstd = 1 / tl.sqrt(var + eps) + + # load Q weight (all heads share one weight vector of dim=head_dim) + w = tl.load(WQ_ptr + offs).to(tl.float32) + + y = x * rstd * w + + # write Q back + tl.store(Q_ptr + q_ptr_offset, y.to(Q_ptr.dtype.element_ty)) + + else: + # ------------------ handle K ------------------ + # remap K's head index (starting from 0) + k_head_idx = combo_head_idx - num_heads_q + + # pointer arithmetic + K_ptr += row_idx * stride_k_row + k_ptr_offset = (k_head_idx * head_dim + offs) * stride_k_col + + # load K data + x = tl.load(K_ptr + k_ptr_offset).to(tl.float32) + # RMSNorm computation + var = tl.sum(x * x, axis=0) / head_dim + rstd = 1 / tl.sqrt(var + eps) +
+ # load K weight + w = tl.load(WK_ptr + offs).to(tl.float32) + + y = x * rstd * w + + # write K back + tl.store(K_ptr + k_ptr_offset, y.to(K_ptr.dtype.element_ty)) + + +def qk_rmsnorm_fused_forward(q: torch.Tensor, k: torch.Tensor, w_q: torch.Tensor, w_k: torch.Tensor, eps: float = 1e-6): + """ + In-place RMSNorm for both Q and K in a single kernel launch. + Supports GQA (different number of heads for Q and K). + + Args: + q: (Total_Tokens, Hidden_Q) or (B, S, H_q, D) -> flattened to 2D inside + k: (Total_Tokens, Hidden_K) + w_q: (head_dim,) Scale parameter for Q + w_k: (head_dim,) Scale parameter for K + """ + # 1. shape and contiguity checks + # view both inputs as 2D (Total_Tokens, Hidden_Size) + q_view = q.view(-1, q.shape[-1]) + k_view = k.view(-1, k.shape[-1]) + + assert w_q.is_contiguous() and w_k.is_contiguous() + + M = q_view.shape[0] # Total Tokens + assert k_view.shape[0] == M, "Q and K must have the same number of tokens" + + head_dim = w_q.shape[0] + assert w_k.shape[0] == head_dim, "Head dim of Q and K must match" + + # compute head counts + N_q = q_view.shape[1] + N_k = k_view.shape[1] + + assert N_q % head_dim == 0 + assert N_k % head_dim == 0 + + num_heads_q = N_q // head_dim + num_heads_k = N_k // head_dim + + # 2. block size setup + BLOCK_SIZE = triton.next_power_of_2(head_dim) + assert BLOCK_SIZE == head_dim, "Currently only supports power-of-2 head_dim (e.g., 64, 128)" + + # 3. launch the kernel + # grid: (num_tokens, num_heads_q + num_heads_k) + grid = (M, num_heads_q + num_heads_k) + + _qk_rms_norm_fused_kernel[grid]( + q_view, + w_q, + q_view.stride(0), + q_view.stride(1), + k_view, + w_k, + k_view.stride(0), + k_view.stride(1), + num_heads_q=num_heads_q, + head_dim=head_dim, + eps=eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=4, + ) + + return q, k diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 9eccddffc1..3915816273 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -5,6 +5,7 @@ import numpy as np import triton from typing import Tuple +from lightllm.common.basemodel.triton_kernel.norm.qk_norm import qk_rmsnorm_fused_forward from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.infer_struct import LlamaInferStateInfo @@ -64,9 +65,11 @@ def _get_qkv( q, cache_kv = qkv.split( [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 ) - layer_weight.q_norm_weight_(q, eps=self.eps_) - layer_weight.k_norm_weight_( + qk_rmsnorm_fused_forward( + q, cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], + layer_weight.q_norm_weight_.weight, + layer_weight.k_norm_weight_.weight, eps=self.eps_, ) cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) diff --git a/unit_tests/common/basemodel/triton_kernel/test_qk_rmsnorm_fused.py b/unit_tests/common/basemodel/triton_kernel/test_qk_rmsnorm_fused.py new file mode 100644 index 0000000000..2210972caa --- /dev/null +++ b/unit_tests/common/basemodel/triton_kernel/test_qk_rmsnorm_fused.py @@ -0,0 +1,52 @@ +import pytest +import torch + +from lightllm.common.basemodel.triton_kernel.norm.qk_norm import ( + qk_rmsnorm_fused_forward, + qk_rmsnorm_forward, +) + + +def test_qk_rmsnorm_fused_matches_reference(): + """Compare fused QK RMSNorm with separate 
reference RMSNorm kernels.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA is required for qk_rmsnorm_fused test") + + torch.manual_seed(0) + + # simulated config: Batch=2, Seq=128, Head_Dim=128 + # Q: 16 heads, K: 4 heads (GQA setting) + B, S, D = 2, 128, 128 + H_Q = 16 + H_K = 4 + + q = torch.randn((B * S, H_Q * D), device="cuda", dtype=torch.bfloat16) + k = torch.randn((B * S, H_K * D), device="cuda", dtype=torch.bfloat16) + + w_q = torch.ones((D,), device="cuda", dtype=torch.bfloat16) + w_k = torch.ones((D,), device="cuda", dtype=torch.bfloat16) + + # clone for comparison (the reference computes on fresh tensors) + q_ref = q.clone() + k_ref = k.clone() + + # the fused kernel computes in place + q_out, k_out = qk_rmsnorm_fused_forward(q, k, w_q, w_k, eps=1e-6) + + # reference: apply RMSNorm to Q and K separately + q_ref_out = qk_rmsnorm_forward(q_ref, w_q, eps=1e-6) + k_ref_out = qk_rmsnorm_forward(k_ref, w_k, eps=1e-6) + + # fused is in-place; the returned q_out/k_out should alias q/k + assert q_out.data_ptr() == q.data_ptr() + assert k_out.data_ptr() == k.data_ptr() + + # error tolerance: bfloat16 math, so allow a reasonable atol + q_max_diff = (q_out - q_ref_out).abs().max().item() + k_max_diff = (k_out - k_ref_out).abs().max().item() + + print(f"Q max diff: {q_max_diff}") + print(f"K max diff: {k_max_diff}") + + assert q_max_diff < 1e-5, f"Q max diff too large: {q_max_diff}" + assert k_max_diff < 1e-5, f"K max diff too large: {k_max_diff}" From f25a36ada84f51cb864cc948648c8d9bcdf2964c Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Mon, 9 Feb 2026 07:00:33 +0000 Subject: [PATCH 05/11] fix --- .../router/model_infer/mode_backend/generic_post_process.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py index bae4505992..2afa06d38e 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py @@ -113,7 +113,7 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor def _random_sample(probs: torch.Tensor): q = torch.empty_like(probs) q.exponential_() - return probs.div_(q).argmax(dim=-1).view(-1) + return probs.div(q).argmax(dim=-1).view(-1) From c603038227063af1329a9f1e2fa930429412bb1d Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Fri, 27 Feb 2026 02:50:06 +0000 Subject: [PATCH 06/11] refactor QKRMSNORMWeight --- .../layer_weights/meta_weights/norm_weight.py | 87 +++++++++++++------ .../layer_infer/transformer_layer_infer.py | 5 +- .../layer_weights/transformer_layer_weight.py | 10 +-- .../layer_infer/transformer_layer_infer.py | 8 +- .../layer_infer/transformer_layer_infer.py | 5 +- .../layer_infer/transformer_layer_infer.py | 5 +- 6 files changed, 71 insertions(+), 49 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index c922bffc45..20416606f9 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -4,7 +4,7 @@ from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_dp, get_dp_world_size from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward from lightllm.common.basemodel.triton_kernel.norm.layernorm import layernorm_forward -from 
lightllm.common.basemodel.triton_kernel.norm.qk_norm import qk_rmsnorm_forward +from lightllm.common.basemodel.triton_kernel.norm.qk_norm import qk_rmsnorm_fused_forward from .platform_op import PlatformAwareOp @@ -195,47 +195,84 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]): self.weight += 1 -class QKRMSNORMWeight(RMSNormWeight): - def __init__(self, dim: int, weight_name: str, data_type: torch.dtype): - super().__init__(dim=dim, weight_name=weight_name, data_type=data_type) +class QKRMSNORMWeight(BaseWeightTpl, PlatformAwareOp): + def __init__(self, dim: int, q_weight_name: str, k_weight_name: str, data_type: torch.dtype): + super().__init__(tp_rank=0, tp_world_size=1) + self.dim = dim + self.q_weight_name = q_weight_name + self.k_weight_name = k_weight_name + self.data_type_ = data_type + self._create_weight() + + def _create_weight(self): + self.q_weight: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) + self.k_weight: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_) + self.q_weight.load_ok = False + self.k_weight.load_ok = False + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): + if self.q_weight_name in weights: + self.q_weight.copy_(weights[self.q_weight_name]) + self.q_weight.load_ok = True + if self.k_weight_name in weights: + self.k_weight.copy_(weights[self.k_weight_name]) + self.k_weight.load_ok = True + + def verify_load(self): + return self.q_weight.load_ok and self.k_weight.load_ok def _native_forward( self, - input: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, eps: float, ) -> None: - assert input.ndim == 2 and self.weight.ndim == 1 - assert input.shape[-1] == self.dim, f"Expected hidden_size to be {self.dim}, but found: {input.shape[-1]}" - head_dim = self.weight.shape[0] - x = input.to(torch.float32) - x = x.view(-1, head_dim) - x_var = x - variance = x_var.pow(2).mean(dim=-1, keepdim=True) - x = x * torch.rsqrt(variance + eps) - x = (x * self.weight).to(self.data_type_) - x = x.view(-1, input.shape[-1]) - input.copy_(x) + assert q.ndim == 2 and self.q_weight.ndim == 1 + assert k.ndim == 2 and self.k_weight.ndim == 1 + assert ( + q.shape[-1] % self.dim == 0 + ), f"Expected hidden_size to be a multiple of {self.dim}, but found: {q.shape[-1]}" + assert ( + k.shape[-1] % self.dim == 0 + ), f"Expected hidden_size to be a multiple of {self.dim}, but found: {k.shape[-1]}" + + head_dim = self.q_weight.shape[0] + + def _norm_inplace(t: torch.Tensor, weight: torch.Tensor): + t_fp32 = t.to(torch.float32) + t_fp32 = t_fp32.view(-1, head_dim) + variance = t_fp32.pow(2).mean(dim=-1, keepdim=True) + t_fp32 = t_fp32 * torch.rsqrt(variance + eps) + t_fp32 = (t_fp32 * weight).to(self.data_type_) + t_fp32 = t_fp32.view(-1, t.shape[-1]) + t.copy_(t_fp32) + + _norm_inplace(q, self.q_weight) + _norm_inplace(k, self.k_weight) return - def _triton_forward(self, input: torch.Tensor, eps: float) -> torch.Tensor: - assert input.ndim == 2 and self.weight.ndim == 1 - return qk_rmsnorm_forward(x=input, weight=self.weight, eps=eps) + def _triton_forward(self, q: torch.Tensor, k: torch.Tensor, eps: float) -> tuple: + assert q.ndim == 2 and self.q_weight.ndim == 1 + assert k.ndim == 2 and self.k_weight.ndim == 1 + return qk_rmsnorm_fused_forward(q=q, k=k, w_q=self.q_weight, w_k=self.k_weight, eps=eps) def _cuda_forward( self, - input: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, eps: float, ) -> None: - self._triton_forward(input=input, eps=eps) + self._triton_forward(q=q, k=k, eps=eps) 
return - def _musa_forward(self, input: torch.Tensor, eps: float) -> torch.Tensor: + def _musa_forward(self, q: torch.Tensor, k: torch.Tensor, eps: float) -> tuple: # musa implementation is supported by musa triton on musa platform - return self._triton_forward(input=input, eps=eps) + return self._triton_forward(q=q, k=k, eps=eps) def __call__( self, - input: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, eps: float, ) -> None: - return self._forward(input=input, eps=eps) + return self._forward(q=q, k=k, eps=eps) diff --git a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py index 82331f8fb8..725b0cc02e 100644 --- a/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3/layer_infer/transformer_layer_infer.py @@ -24,11 +24,8 @@ def _get_qkv( input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) cache_kv = layer_weight.kv_proj.mm(input) - layer_weight.q_norm_weight_( + layer_weight.qk_norm_weight_( q, - eps=self.eps_, - ) - layer_weight.k_norm_weight_( cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], eps=self.eps_, ) diff --git a/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py index 7d2163f283..e1bec490cb 100644 --- a/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3/layer_weights/transformer_layer_weight.py @@ -19,13 +19,9 @@ def _init_weight_names(self): def _init_norm(self): super()._init_norm() - self.q_norm_weight_ = QKRMSNORMWeight( + self.qk_norm_weight_ = QKRMSNORMWeight( dim=self.head_dim, - weight_name=self._q_norm_name, - data_type=self.data_type_, - ) - self.k_norm_weight_ = QKRMSNORMWeight( - dim=self.head_dim, - weight_name=self._k_norm_name, + q_weight_name=self._q_norm_name, + k_weight_name=self._k_norm_name, data_type=self.data_type_, ) diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 3915816273..b85216f22c 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -65,11 +65,9 @@ def _get_qkv( q, cache_kv = qkv.split( [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 ) - qk_rmsnorm_fused_forward( + layer_weight.qk_norm_weight_( q, cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], - layer_weight.q_norm_weight_.weight, - layer_weight.k_norm_weight_.weight, eps=self.eps_, ) cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) @@ -98,8 +96,8 @@ def _tpsp_get_qkv( input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) cache_kv = layer_weight.kv_proj.mm(input) - layer_weight.q_norm_weight_(q, eps=self.eps_) - layer_weight.k_norm_weight_( + layer_weight.qk_norm_weight_( + q, cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], eps=self.eps_, ) diff --git a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py index d34babaabe..8951a04381 100644 --- a/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl/layer_infer/transformer_layer_infer.py @@ -29,11 +29,8 @@ def _get_qkv( input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) cache_kv = layer_weight.kv_proj.mm(input) - 
layer_weight.q_norm_weight_( + layer_weight.qk_norm_weight_( q, - eps=self.eps_, - ) - layer_weight.k_norm_weight_( cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], eps=self.eps_, ) diff --git a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py index 4ccc6da372..e90d5552e5 100644 --- a/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py @@ -25,11 +25,8 @@ def _get_qkv( input = input.view(-1, self.embed_dim_) q = layer_weight.q_proj.mm(input) cache_kv = layer_weight.kv_proj.mm(input) - layer_weight.q_norm_weight_( + layer_weight.qk_norm_weight_( q, - eps=self.eps_, - ) - layer_weight.k_norm_weight_( cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], eps=self.eps_, ) From 6069e90fbbba4eb49677bda690eea663a463b1cc Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Fri, 27 Feb 2026 06:33:55 +0000 Subject: [PATCH 07/11] qknorm x.to(w.dtype) --- .../common/basemodel/triton_kernel/norm/qk_norm.py | 14 ++++++++------ .../triton_kernel/test_qk_rmsnorm_fused.py | 12 ++++++++++++ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py b/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py index d3fc1b7dbe..fab0141158 100644 --- a/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py +++ b/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py @@ -110,12 +110,13 @@ def _qk_rms_norm_fused_kernel( rstd = 1 / tl.sqrt(var + eps) # load Q weight (all heads share one weight vector of dim=head_dim) - w = tl.load(WQ_ptr + offs).to(tl.float32) + w = tl.load(WQ_ptr + offs) - y = x * rstd * w + x *= rstd + y = x.to(w.dtype) * w # write Q back - tl.store(Q_ptr + q_ptr_offset, y.to(Q_ptr.dtype.element_ty)) + tl.store(Q_ptr + q_ptr_offset, y) else: # ------------------ handle K ------------------ @@ -133,12 +134,13 @@ def _qk_rms_norm_fused_kernel( rstd = 1 / tl.sqrt(var + eps) # load K weight - w = tl.load(WK_ptr + offs).to(tl.float32) + w = tl.load(WK_ptr + offs) + x *= rstd - y = x * rstd * w + y = x.to(w.dtype) * w # write K back - tl.store(K_ptr + k_ptr_offset, y.to(K_ptr.dtype.element_ty)) + tl.store(K_ptr + k_ptr_offset, y) def qk_rmsnorm_fused_forward(q: torch.Tensor, k: torch.Tensor, w_q: torch.Tensor, w_k: torch.Tensor, eps: float = 1e-6): diff --git a/unit_tests/common/basemodel/triton_kernel/test_qk_rmsnorm_fused.py b/unit_tests/common/basemodel/triton_kernel/test_qk_rmsnorm_fused.py index 2210972caa..1e7d5080ca 100644 --- a/unit_tests/common/basemodel/triton_kernel/test_qk_rmsnorm_fused.py +++ b/unit_tests/common/basemodel/triton_kernel/test_qk_rmsnorm_fused.py @@ -7,6 +7,16 @@ ) +def torch_qk_rmsnorm(q, w_q, eps=1e-6): + input_dtype = q.dtype + head_dim = w_q.shape[0] + q_fp32 = q.to(torch.float32) + q_fp32 = q_fp32.view(-1, head_dim) + variance = q_fp32.pow(2).mean(dim=-1, keepdim=True) + q_fp32 = q_fp32 * torch.rsqrt(variance + eps) + return (q_fp32 * w_q.to(input_dtype)).view_as(q) + + def test_qk_rmsnorm_fused_matches_reference(): + """Compare fused QK RMSNorm with separate reference RMSNorm kernels.""" if not torch.cuda.is_available(): @@ -36,6 +46,8 @@ def test_qk_rmsnorm_fused_matches_reference(): # reference: apply RMSNorm to Q and K separately q_ref_out = qk_rmsnorm_forward(q_ref, w_q, eps=1e-6) k_ref_out = qk_rmsnorm_forward(k_ref, w_k, eps=1e-6) + # q_ref_out = torch_qk_rmsnorm(q_ref, w_q, eps=1e-6) + # k_ref_out = torch_qk_rmsnorm(k_ref, w_k, eps=1e-6) # fused is in-place; the returned q_out/k_out should alias q/k
assert q_out.data_ptr() == q.data_ptr() From 33e3def1c01babd4c0853d125ef6a58eb6ba4502 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Fri, 27 Feb 2026 06:40:33 +0000 Subject: [PATCH 08/11] fix top_k_val != req_obj.vocab_size --- .../mode_backend/generic_post_process.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py index 2afa06d38e..81ad96b285 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py @@ -16,7 +16,8 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): b_length_penalty_param, b_mask_eos_reqs, is_all_greedy, - is_all_random, + skip_top_k, + skip_top_p, ) = _get_post_sample_tensors(reqs) eos_ids = g_pin_mem_manager.gen_from_list(key="eos_ids", data=eos_id, dtype=torch.int32).cuda(non_blocking=True) @@ -69,7 +70,7 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): batch_next_token_probs = torch.gather(probs, dim=1, index=batch_next_token_ids.view(-1, 1)) return batch_next_token_ids.view(-1), torch.log(batch_next_token_probs).view(-1) - elif is_all_random: + elif skip_top_k and skip_top_p: batch_next_token_ids = _random_sample(probs) batch_next_token_probs = torch.gather(probs, dim=1, index=batch_next_token_ids.view(-1, 1)) return batch_next_token_ids.view(-1), torch.log(batch_next_token_probs).view(-1) @@ -124,7 +125,8 @@ def _get_post_sample_tensors(reqs: List[InferReq]): length_penalty_param: List[int] = [] mask_eos_reqs: List[bool] = [] is_all_greedy = True - is_all_random = True + skip_top_k = True + skip_top_p = True for i, req_obj in enumerate(reqs): sample_param = req_obj.sampling_param @@ -140,8 +142,10 @@ def _get_post_sample_tensors(reqs: List[InferReq]): top_ks.append(top_k_val) if top_k_val > 1: is_all_greedy = False - if top_k_val != -1 or shm_param.top_p != 1.0: - is_all_random = False + if top_k_val != req_obj.vocab_size: + skip_top_k = False + if shm_param.top_p != 1.0: + skip_top_p = False req_idxes.append(req_obj.req_idx) req_idxes_cpu = g_pin_mem_manager.gen_from_list(key="req_idxes", data=req_idxes, dtype=torch.int32) @@ -161,5 +165,6 @@ def _get_post_sample_tensors(reqs: List[InferReq]): length_penalty_param_cpu.cuda(non_blocking=True), mask_eos_reqs_cpu.cuda(non_blocking=True), is_all_greedy, - is_all_random, + skip_top_k, + skip_top_p, ) From 847a2ccc2d045cee2e8d3e62c8fc7c791445fb08 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Fri, 27 Feb 2026 09:20:28 +0000 Subject: [PATCH 09/11] support seed --- lightllm/server/api_models.py | 2 ++ lightllm/server/api_openai.py | 2 ++ lightllm/server/core/objs/py_sampling_params.py | 3 +++ lightllm/server/core/objs/sampling_params.py | 3 +++ lightllm/server/router/model_infer/infer_batch.py | 6 ++++++ .../model_infer/mode_backend/generic_post_process.py | 7 +++++-- 6 files changed, 21 insertions(+), 2 deletions(-) diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py index f30ecc55fe..e194a24d0f 100644 --- a/lightllm/server/api_models.py +++ b/lightllm/server/api_models.py @@ -142,6 +142,7 @@ class CompletionRequest(BaseModel): top_k: Optional[int] = -1 repetition_penalty: Optional[float] = 1.0 ignore_eos: Optional[bool] = False + seed: Optional[int] = -1 # Class variables to store loaded default 
values _loaded_defaults: ClassVar[Dict[str, Any]] = {} @@ -217,6 +218,7 @@ class ChatCompletionRequest(BaseModel): top_k: Optional[int] = -1 repetition_penalty: Optional[float] = 1.0 ignore_eos: Optional[bool] = False + seed: Optional[int] = -1 role_settings: Optional[Dict[str, str]] = None character_settings: Optional[List[Dict[str, str]]] = None diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index d91bb1d947..11e24612b0 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -208,6 +208,7 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req "n": request.n, "best_of": request.n, "add_special_tokens": False, + "seed": request.seed, } # Structured output handling @@ -537,6 +538,7 @@ async def completions_impl(request: CompletionRequest, raw_request: Request) -> "n": request.n, "best_of": request.best_of, "add_special_tokens": False, + "seed": request.seed, } if request.response_format: diff --git a/lightllm/server/core/objs/py_sampling_params.py b/lightllm/server/core/objs/py_sampling_params.py index 887f360c84..51c6add0a5 100644 --- a/lightllm/server/core/objs/py_sampling_params.py +++ b/lightllm/server/core/objs/py_sampling_params.py @@ -60,6 +60,7 @@ def __init__( move_kv_to_decode_node: Optional[dict] = None, # suggest dp index, deepseekv2 dp mode, use to suggest used dp_index suggested_dp_index: Optional[int] = None, + seed: Optional[int] = -1, ) -> None: self.best_of = best_of self.n = n @@ -91,6 +92,7 @@ def __init__( self.group_request_id = group_request_id self.move_kv_to_decode_node = move_kv_to_decode_node self.suggested_dp_index = suggested_dp_index + self.seed = seed if self.do_sample is False: self.temperature = 1.0 self.top_p = 1.0 @@ -268,6 +270,7 @@ def to_dict(self): ret["guided_json"] = self.guided_json ret["allowed_token_ids"] = self.allowed_token_ids ret["move_kv_to_decode_node"] = self.move_kv_to_decode_node + ret["seed"] = self.seed return ret def to_origin_dict(self): diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index d955aa6a87..b9ad314dd9 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -321,6 +321,7 @@ class SamplingParams(ctypes.Structure): ), # whether to add spaces between special tokens when decoding ("print_eos_token", ctypes.c_bool), # eos_id will be always ignored except the value is set to True ("disable_prompt_cache", ctypes.c_bool), # whether to disable prompt cache + ("seed", ctypes.c_int64), # random seed ] _do_sample: bool = False @@ -356,6 +357,7 @@ def init(self, tokenizer, **kwargs): self.add_special_tokens = kwargs.get("add_special_tokens", True) self.add_spaces_between_special_tokens = kwargs.get("add_spaces_between_special_tokens", True) self.print_eos_token = kwargs.get("print_eos_token", False) + self.seed = kwargs.get("seed", -1) self.exponential_decay_length_penalty = ExponentialDecayLengthPenalty() self.exponential_decay_length_penalty.initialize(kwargs.get("exponential_decay_length_penalty", (1, 1.0))) @@ -497,6 +499,7 @@ def to_dict(self): "add_spaces_between_special_tokens": self.add_spaces_between_special_tokens, "print_eos_token": self.print_eos_token, "disable_prompt_cache": self.disable_prompt_cache, + "seed": self.seed, } def to_origin_dict(self): diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 4b8b3c538f..4d6aac22be 100644 --- 
a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -364,6 +364,12 @@ def __init__( self.decode_need_token_num = self._normal_decode_need_token_num self._init_all_state() + + self.generator = None + if getattr(self.sampling_param.shm_param, "seed", -1) != -1: + self.generator = torch.Generator(device="cuda") + self.generator.manual_seed(self.sampling_param.shm_param.seed) + if init_prefix_cache: self._match_radix_cache() return diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py index 81ad96b285..0592c104bc 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py @@ -71,7 +71,7 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): return batch_next_token_ids.view(-1), torch.log(batch_next_token_probs).view(-1) elif skip_top_k and skip_top_p: - batch_next_token_ids = _random_sample(probs) + batch_next_token_ids = _random_sample(probs, reqs) batch_next_token_probs = torch.gather(probs, dim=1, index=batch_next_token_ids.view(-1, 1)) return batch_next_token_ids.view(-1), torch.log(batch_next_token_probs).view(-1) @@ -111,9 +111,12 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor return probs_sort, probs_idx -def _random_sample(probs: torch.Tensor): +def _random_sample(probs: torch.Tensor, reqs: List[InferReq]): q = torch.empty_like(probs) q.exponential_() + for i, req in enumerate(reqs): + if req.generator is not None: + q[i].exponential_(generator=req.generator) return probs.div(q).argmax(dim=-1).view(-1) From 604074401ea46a8201eb0bdfae5ab40bdc627546 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Sat, 28 Feb 2026 06:02:38 +0000 Subject: [PATCH 10/11] add topp_topk_sampling --- .../mode_backend/generic_post_process.py | 36 +++++++++++-------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py index 0592c104bc..e26e0911e5 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py @@ -1,5 +1,5 @@ import torch -from typing import List +from typing import List, Tuple from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty from lightllm.common.basemodel.triton_kernel.apply_penalty_gpu_cache import apply_penalty_gpu_cache from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context @@ -75,7 +75,26 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): batch_next_token_probs = torch.gather(probs, dim=1, index=batch_next_token_ids.view(-1, 1)) return batch_next_token_ids.view(-1), torch.log(batch_next_token_probs).view(-1) - elif get_env_start_args().sampling_backend == "triton": + else: + batch_next_token_ids, batch_next_token_logprobs = _top_p_top_k_sample(probs, b_top_ps, b_top_ks) + return batch_next_token_ids.view(-1), batch_next_token_logprobs.view(-1) + + +def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor): + probs_sort, probs_idx = probs.sort(dim=-1, descending=True) + + probs_sum = torch.cumsum(probs_sort, dim=-1) + probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0 
+ + probs_sort[torch.arange(0, probs.shape[-1], device="cuda").view(1, -1) >= top_ks.view(-1, 1)] = 0.0 + + return probs_sort, probs_idx + + +def _top_p_top_k_sample( + probs: torch.Tensor, b_top_ps: torch.Tensor, b_top_ks: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + if get_env_start_args().sampling_backend == "triton": probs_sort, probs_idx = _top_p_top_k(probs, b_top_ps, b_top_ks) sampled_index = torch.multinomial(probs_sort, num_samples=1, replacement=True) next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index) @@ -97,18 +116,7 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): batch_next_token_probs = torch.gather(probs, dim=1, index=int64_batch_next_token_ids.view(-1, 1)) return batch_next_token_ids.view(-1), torch.log(batch_next_token_probs).view(-1) else: - assert False, "dead path" - - -def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor): - probs_sort, probs_idx = probs.sort(dim=-1, descending=True) - - probs_sum = torch.cumsum(probs_sort, dim=-1) - probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0 - - probs_sort[torch.arange(0, probs.shape[-1], device="cuda").view(1, -1) >= top_ks.view(-1, 1)] = 0.0 - - return probs_sort, probs_idx + assert False, "Unsupported sampling backend for top_p_top_k_sample" def _random_sample(probs: torch.Tensor, reqs: List[InferReq]): From 812cba1b64b880e78fe2aa259f7203c016d171b0 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Mon, 2 Mar 2026 08:31:52 +0000 Subject: [PATCH 11/11] fix sample use exponential_ --- .../server/router/model_infer/infer_batch.py | 2 +- .../mode_backend/generic_post_process.py | 32 ++++++++++++++----- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 4d6aac22be..1b4a1ca5cb 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -366,7 +366,7 @@ def __init__( self._init_all_state() self.generator = None - if getattr(self.sampling_param.shm_param, "seed", -1) != -1: + if self.sampling_param.shm_param.seed != -1: self.generator = torch.Generator(device="cuda") self.generator.manual_seed(self.sampling_param.shm_param.seed) diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py index e26e0911e5..f3ad03662e 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py @@ -18,6 +18,7 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): is_all_greedy, skip_top_k, skip_top_p, + exist_req_use_random_seed, ) = _get_post_sample_tensors(reqs) eos_ids = g_pin_mem_manager.gen_from_list(key="eos_ids", data=eos_id, dtype=torch.int32).cuda(non_blocking=True) @@ -71,12 +72,15 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): return batch_next_token_ids.view(-1), torch.log(batch_next_token_probs).view(-1) elif skip_top_k and skip_top_p: + # top_k equals the full vocab and top_p equals 1.0, so top-k/top-p filtering is a no-op; + # sample directly to speed things up batch_next_token_ids = _random_sample(probs, reqs, exist_req_use_random_seed) batch_next_token_probs = torch.gather(probs, dim=1, index=batch_next_token_ids.view(-1, 1)) return batch_next_token_ids.view(-1), torch.log(batch_next_token_probs).view(-1) 
else: - batch_next_token_ids, batch_next_token_logprobs = _top_p_top_k_sample(probs, b_top_ps, b_top_ks) + batch_next_token_ids, batch_next_token_logprobs = _top_p_top_k_sample( + reqs, probs, b_top_ps, b_top_ks, exist_req_use_random_seed + ) return batch_next_token_ids.view(-1), batch_next_token_logprobs.view(-1) @@ -92,11 +96,18 @@ def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor def _top_p_top_k_sample( - probs: torch.Tensor, b_top_ps: torch.Tensor, b_top_ks: torch.Tensor + reqs: List[InferReq], + probs: torch.Tensor, + b_top_ps: torch.Tensor, + b_top_ks: torch.Tensor, + exist_req_use_random_seed: bool, ) -> Tuple[torch.Tensor, torch.Tensor]: if get_env_start_args().sampling_backend == "triton": probs_sort, probs_idx = _top_p_top_k(probs, b_top_ps, b_top_ks) - sampled_index = torch.multinomial(probs_sort, num_samples=1, replacement=True) + if not exist_req_use_random_seed: + sampled_index = torch.multinomial(probs_sort, num_samples=1, replacement=True) + else: + sampled_index = _random_sample(probs_sort, reqs, exist_req_use_random_seed).view(-1, 1) next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index) next_token_logprobs = torch.log(torch.gather(probs_sort, dim=1, index=sampled_index)) return next_token_ids.view(-1), next_token_logprobs.view(-1) @@ -119,12 +130,13 @@ def _top_p_top_k_sample( assert False, "Unsupported sampling backend for top_p_top_k_sample" -def _random_sample(probs: torch.Tensor, reqs: List[InferReq]): +def _random_sample(probs: torch.Tensor, reqs: List[InferReq], exist_req_use_random_seed: bool): q = torch.empty_like(probs) q.exponential_() - for i, req in enumerate(reqs): - if req.generator is not None: - q[i].exponential_(generator=req.generator) + if exist_req_use_random_seed: + for i, req in enumerate(reqs): + if req.generator is not None: + q[i].exponential_(generator=req.generator) return probs.div(q).argmax(dim=-1).view(-1) @@ -138,6 +150,7 @@ def _get_post_sample_tensors(reqs: List[InferReq]): is_all_greedy = True skip_top_k = True skip_top_p = True + exist_req_use_random_seed = False for i, req_obj in enumerate(reqs): sample_param = req_obj.sampling_param @@ -157,6 +170,8 @@ def _get_post_sample_tensors(reqs: List[InferReq]): skip_top_k = False if shm_param.top_p != 1.0: skip_top_p = False + if req_obj.generator is not None: + exist_req_use_random_seed = True req_idxes.append(req_obj.req_idx) req_idxes_cpu = g_pin_mem_manager.gen_from_list(key="req_idxes", data=req_idxes, dtype=torch.int32) @@ -178,4 +193,5 @@ def _get_post_sample_tensors(reqs: List[InferReq]): is_all_greedy, skip_top_k, skip_top_p, + exist_req_use_random_seed, )
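Note on the random-sampling fast path built up across patches 03, 05, 08 and 11: _random_sample draws tokens with the exponential-race form of the Gumbel-max trick. For i.i.d. q_i ~ Exp(1), the index argmax_i(p_i / q_i) is selected with probability p_i, so no CDF inversion or torch.multinomial call is needed, and a per-request torch.Generator (patch 09) makes the draw reproducible. A minimal standalone sketch of the idea, assuming CPU tensors and an illustrative batch/vocab size and seed value not taken from the patches:

import torch

# Exponential-race sampling: with i.i.d. q_i ~ Exp(1),
# argmax_i(p_i / q_i) selects index i with probability p_i.
# Equivalent to Gumbel-max, since log(p_i / q_i) = log p_i - log q_i
# and -log q_i is Gumbel(0, 1) noise.
probs = torch.softmax(torch.randn(4, 32000), dim=-1)  # hypothetical batch of 4 rows
gen = torch.Generator().manual_seed(1234)  # stands in for a per-request generator
noise = torch.empty_like(probs).exponential_(generator=gen)
next_token_ids = probs.div(noise).argmax(dim=-1)  # one sampled token id per row

The out-of-place div (patch 05) matters here because sample() gathers from probs again afterwards to compute log-probabilities; the earlier in-place div_ corrupted those values.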