Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def _select_experts(
num_expert_group=num_expert_group,
scoring_func=scoring_func,
)
topk_weights.mul_(self.routed_scaling_factor)
if self.routed_scaling_factor != 1.0:
topk_weights.mul_(self.routed_scaling_factor)
if self.redundancy_expert_num > 0:
redundancy_topk_ids_repair(
topk_ids=topk_ids,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def _select_experts(
num_expert_group=num_expert_group,
scoring_func=scoring_func,
)
topk_weights.mul_(self.routed_scaling_factor)
if self.routed_scaling_factor != 1.0:
topk_weights.mul_(self.routed_scaling_factor)
if self.num_fused_shared_experts > 0:
pad_topk_ids = (
torch.arange(
Expand Down
87 changes: 62 additions & 25 deletions lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_dp, get_dp_world_size
from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward
from lightllm.common.basemodel.triton_kernel.norm.layernorm import layernorm_forward
from lightllm.common.basemodel.triton_kernel.norm.qk_norm import qk_rmsnorm_forward
from lightllm.common.basemodel.triton_kernel.norm.qk_norm import qk_rmsnorm_fused_forward
from .platform_op import PlatformAwareOp


Expand Down Expand Up @@ -195,47 +195,84 @@ def load_hf_weights(self, weights: Dict[str, torch.Tensor]):
self.weight += 1


class QKRMSNORMWeight(BaseWeightTpl, PlatformAwareOp):
    """Separate RMSNorm scale weights for Q and K, applied in place.

    Holds one ``(head_dim,)`` weight vector for Q and one for K, and
    normalizes both projections per attention head (GQA-friendly: Q and K
    may have different head counts, but share ``head_dim``).

    The weights are replicated on every rank (``tp_rank=0, tp_world_size=1``)
    because per-head norm scales are not tensor-parallel sharded.
    """

    def __init__(self, dim: int, q_weight_name: str, k_weight_name: str, data_type: torch.dtype):
        super().__init__(tp_rank=0, tp_world_size=1)
        # dim is the per-head dimension (head_dim), not the full hidden size.
        self.dim = dim
        self.q_weight_name = q_weight_name
        self.k_weight_name = k_weight_name
        self.data_type_ = data_type
        self._create_weight()

    def _create_weight(self):
        # NOTE(review): assumes BaseWeightTpl provides self.device_id_ — confirm.
        self.q_weight: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_)
        self.k_weight: torch.Tensor = torch.empty(self.dim, dtype=self.data_type_, device=self.device_id_)
        # Track checkpoint-load status directly on the tensors
        # (plain torch.Tensor objects accept ad-hoc Python attributes).
        self.q_weight.load_ok = False
        self.k_weight.load_ok = False

    def load_hf_weights(self, weights: Dict[str, torch.Tensor]):
        """Copy matching entries from a HF checkpoint dict into the preallocated tensors."""
        if self.q_weight_name in weights:
            self.q_weight.copy_(weights[self.q_weight_name])
            self.q_weight.load_ok = True
        if self.k_weight_name in weights:
            self.k_weight.copy_(weights[self.k_weight_name])
            self.k_weight.load_ok = True

    def verify_load(self):
        # Both Q and K weights must have been found in the checkpoint.
        return self.q_weight.load_ok and self.k_weight.load_ok

    def _native_forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        eps: float,
    ) -> None:
        """Pure-PyTorch fallback: per-head RMSNorm of q and k, in place.

        Math is done in float32 and cast back to ``self.data_type_`` before
        being written into the input tensors.
        """
        assert q.ndim == 2 and self.q_weight.ndim == 1
        assert k.ndim == 2 and self.k_weight.ndim == 1
        assert (
            q.shape[-1] % self.dim == 0
        ), f"Expected hidden_size to be multiple of {self.dim}, but found: {q.shape[-1]}"
        assert (
            k.shape[-1] % self.dim == 0
        ), f"Expected hidden_size to be multiple of {self.dim}, but found: {k.shape[-1]}"

        head_dim = self.q_weight.shape[0]

        def _norm_inplace(t: torch.Tensor, weight: torch.Tensor):
            # Normalize each head_dim-sized chunk independently.
            t_fp32 = t.to(torch.float32)
            t_fp32 = t_fp32.view(-1, head_dim)
            variance = t_fp32.pow(2).mean(dim=-1, keepdim=True)
            t_fp32 = t_fp32 * torch.rsqrt(variance + eps)
            t_fp32 = (t_fp32 * weight).to(self.data_type_)
            t_fp32 = t_fp32.view(-1, t.shape[-1])
            t.copy_(t_fp32)

        _norm_inplace(q, self.q_weight)
        _norm_inplace(k, self.k_weight)
        return

    def _triton_forward(self, q: torch.Tensor, k: torch.Tensor, eps: float) -> tuple:
        # Single fused kernel launch normalizes both q and k in place.
        assert q.ndim == 2 and self.q_weight.ndim == 1
        assert k.ndim == 2 and self.k_weight.ndim == 1
        return qk_rmsnorm_fused_forward(q=q, k=k, w_q=self.q_weight, w_k=self.k_weight, eps=eps)

    def _cuda_forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        eps: float,
    ) -> None:
        self._triton_forward(q=q, k=k, eps=eps)
        return

    def _musa_forward(self, q: torch.Tensor, k: torch.Tensor, eps: float) -> tuple:
        # musa implementation is supported by musa triton on musa platform
        return self._triton_forward(q=q, k=k, eps=eps)

    def __call__(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        eps: float,
    ) -> None:
        # PlatformAwareOp._forward dispatches to the platform-specific impl.
        return self._forward(q=q, k=k, eps=eps)
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def fused_topk(
sgl_ops.topk_softmax(
topk_weights,
topk_ids,
gating_output.float(), # TODO(woosuk): Optimize this.
gating_output,
renormalize=renormalize,
)
return topk_weights, topk_ids
Expand Down
138 changes: 138 additions & 0 deletions lightllm/common/basemodel/triton_kernel/norm/qk_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,141 @@ def qk_rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps):
num_warps=4,
)
return x


@triton.jit
def _qk_rms_norm_fused_kernel(
    # Q pointers & strides
    Q_ptr,
    WQ_ptr,
    stride_q_row,
    stride_q_col,
    # K pointers & strides
    K_ptr,
    WK_ptr,
    stride_k_row,
    stride_k_col,
    # Dimensions
    num_heads_q: tl.constexpr,  # number of Q heads (used to split the combined head axis)
    head_dim: tl.constexpr,
    eps,
    BLOCK_SIZE: tl.constexpr,
):
    """Fused in-place RMSNorm over Q and K.

    Grid is (num_tokens, num_heads_q + num_heads_k): each program normalizes
    one head of one token. Head indices in [0, num_heads_q) address Q; the
    remainder address K. BLOCK_SIZE must equal head_dim (no tail masking).
    """
    # program_id(0): which token (row) this program handles
    row_idx = tl.program_id(0)
    # program_id(1): combined head index in [0, num_heads_q + num_heads_k)
    combo_head_idx = tl.program_id(1)

    # Intra-head column offsets (0 .. BLOCK_SIZE); BLOCK_SIZE == head_dim.
    offs = tl.arange(0, BLOCK_SIZE)

    # === Branch: does this program handle a Q head or a K head? ===
    if combo_head_idx < num_heads_q:
        # ------------------ Q path ------------------
        # The Q head index is combo_head_idx itself.
        Q_ptr += row_idx * stride_q_row

        # Element offsets: head offset + intra-head column offset.
        q_ptr_offset = (combo_head_idx * head_dim + offs) * stride_q_col

        # Load Q data and normalize in float32.
        x = tl.load(Q_ptr + q_ptr_offset).to(tl.float32)
        # RMSNorm: x * rsqrt(mean(x^2) + eps) * w
        var = tl.sum(x * x, axis=0) / head_dim
        rstd = 1 / tl.sqrt(var + eps)

        # All Q heads share one (head_dim,) weight vector.
        w = tl.load(WQ_ptr + offs)

        x *= rstd
        y = x.to(w.dtype) * w

        # Write the normalized values back into Q.
        tl.store(Q_ptr + q_ptr_offset, y)

    else:
        # ------------------ K path ------------------
        # Re-map to a K head index starting at 0.
        k_head_idx = combo_head_idx - num_heads_q

        K_ptr += row_idx * stride_k_row
        k_ptr_offset = (k_head_idx * head_dim + offs) * stride_k_col

        # Load K data and normalize in float32.
        x = tl.load(K_ptr + k_ptr_offset).to(tl.float32)
        var = tl.sum(x * x, axis=0) / head_dim
        rstd = 1 / tl.sqrt(var + eps)

        # K's own weight vector.
        w = tl.load(WK_ptr + offs)
        x *= rstd

        y = x.to(w.dtype) * w

        # Write the normalized values back into K.
        tl.store(K_ptr + k_ptr_offset, y)


def qk_rmsnorm_fused_forward(q: torch.Tensor, k: torch.Tensor, w_q: torch.Tensor, w_k: torch.Tensor, eps: float = 1e-6):
    """
    In-place RMSNorm for both Q and K in a single kernel launch.
    Supports GQA (different number of heads for Q and K).

    Args:
        q: (Total_Tokens, Hidden_Q) or (B, S, H_q, D) -> flattened to 2D inside
        k: (Total_Tokens, Hidden_K)
        w_q: (head_dim,) Scale parameter for Q
        w_k: (head_dim,) Scale parameter for K

    Returns:
        The (q, k) pair, normalized in place.
    """
    # 1. Shape / contiguity checks.
    # View both inputs as a 2D (Total_Tokens, Hidden_Size) layout.
    q_view = q.view(-1, q.shape[-1])
    k_view = k.view(-1, k.shape[-1])

    assert w_q.is_contiguous() and w_k.is_contiguous()

    M = q_view.shape[0]  # Total Tokens
    assert k_view.shape[0] == M, "Q and K must have the same number of tokens"

    head_dim = w_q.shape[0]
    assert w_k.shape[0] == head_dim, "Head dim of Q and K must match"

    # Derive the per-tensor head counts from the hidden sizes.
    N_q = q_view.shape[1]
    N_k = k_view.shape[1]

    assert N_q % head_dim == 0
    assert N_k % head_dim == 0

    num_heads_q = N_q // head_dim
    num_heads_k = N_k // head_dim

    # 2. Block-size setup: the kernel loads a full head per program with no
    # tail mask, so head_dim must already be a power of two.
    BLOCK_SIZE = triton.next_power_of_2(head_dim)
    assert BLOCK_SIZE == head_dim, "Currently only supports head_dim power of 2 (e.g., 64, 128)"

    # 3. Kernel launch.
    # Grid: (num_tokens, num_heads_q + num_heads_k)
    grid = (M, num_heads_q + num_heads_k)

    _qk_rms_norm_fused_kernel[grid](
        q_view,
        w_q,
        q_view.stride(0),
        q_view.stride(1),
        k_view,
        w_k,
        k_view.stride(0),
        k_view.stride(1),
        num_heads_q=num_heads_q,
        head_dim=head_dim,
        eps=eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=4,
    )

    return q, k
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@
"num_stages": 3,
"num_warps": 4
},
"192": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 64,
"NEED_TRANS": false,
"num_stages": 2,
"num_warps": 4
},
"2048": {
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_M": 32,
Expand All @@ -35,6 +44,15 @@
"num_stages": 2,
"num_warps": 4
},
"384": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 16,
"NEED_TRANS": false,
"num_stages": 2,
"num_warps": 4
},
"512": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
Expand All @@ -53,6 +71,24 @@
"num_stages": 2,
"num_warps": 4
},
"640": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 1,
"NEED_TRANS": false,
"num_stages": 2,
"num_warps": 4
},
"768": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 64,
"NEED_TRANS": false,
"num_stages": 2,
"num_warps": 4
},
"8": {
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_M": 16,
Expand All @@ -79,5 +115,23 @@
"NEED_TRANS": false,
"num_stages": 2,
"num_warps": 4
},
"896": {
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 64,
"NEED_TRANS": false,
"num_stages": 3,
"num_warps": 4
},
"96": {
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"GROUP_SIZE_M": 32,
"NEED_TRANS": false,
"num_stages": 4,
"num_warps": 4
}
}
Loading