From c77c837e0a92a8ebd44d33c962dde064de782c92 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Thu, 5 Feb 2026 09:59:11 +0000 Subject: [PATCH 1/2] optimize support qkv repeat --- .../meta_weights/mm_weight/mm_weight.py | 15 ++++-- .../meta_weights/mm_weight/rowmm_weight.py | 47 ++++++++++++++++--- 2 files changed, 52 insertions(+), 10 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index 3630bc2c0..2e666f088 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -104,6 +104,9 @@ def _create_weight(self): ) return + def _get_param_slicer(self, sub_child_index: int): + return self.param_slicer + # 执行顺序 def _load_weight( self, param_name: Union[str, List[str]], weights: Dict[str, torch.Tensor], sub_child_index: int @@ -113,7 +116,8 @@ def _load_weight( if quanted_param_name in weights: param_name = quanted_param_name if param_name in weights: - weight = self.param_slicer._slice_weight(weights[param_name]) + slicer = self._get_param_slicer(sub_child_index) + weight = slicer._slice_weight(weights[param_name]) self.quant_method.load_weight(weight, self.mm_param_list[sub_child_index]) return @@ -121,7 +125,8 @@ def _load_bias( self, param_name: Union[str, List[str]], weights: Dict[str, torch.Tensor], sub_child_index: int ) -> None: if param_name in weights: - bias = self.param_slicer._slice_bias(weights[param_name]) + slicer = self._get_param_slicer(sub_child_index) + bias = slicer._slice_bias(weights[param_name]) self.bias_list[sub_child_index].copy_(bias) self.bias_list[sub_child_index].load_ok = True return @@ -130,7 +135,8 @@ def _load_weight_scale( self, param_name: Union[str, List[str]], weights: Dict[str, torch.Tensor], sub_child_index: int ) -> None: if param_name in weights: - weight_scale = self.param_slicer._slice_weight_scale(weights[param_name]) + slicer = self._get_param_slicer(sub_child_index) + weight_scale = slicer._slice_weight_scale(weights[param_name]) self.quant_method.load_weight_scale(weight_scale, self.mm_param_list[sub_child_index]) return @@ -138,7 +144,8 @@ def _load_weight_zero_point( self, param_name: Union[str, List[str]], weights: Dict[str, torch.Tensor], sub_child_index: int ) -> None: if param_name in weights: - weight_zero_point = self.param_slicer._slice_weight_zero_point(weights[param_name]) + slicer = self._get_param_slicer(sub_child_index) + weight_zero_point = slicer._slice_weight_zero_point(weights[param_name]) self.quant_method.load_weight_zero_point(weight_zero_point, self.mm_param_list[sub_child_index]) return diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py index 0b8b6ac56..fb5039836 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py @@ -108,15 +108,18 @@ def __init__( ) -> None: self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() - self.repeat_times = 1 + self.q_repeat_times = 1 + self.kv_repeat_times = 1 assert q_head_num % self.tp_world_size_ == 0, ( f"q_head_num must be divisible by tp_world_size_, " f"but found: {q_head_num} % {self.tp_world_size_}" ) - assert kv_head_num % self.tp_world_size_ == 0, ( - f"kv_head_num must be divisible by tp_world_size_" f"but found: {kv_head_num} % {self.tp_world_size_}" + assert kv_head_num % self.tp_world_size_ == 0 or self.tp_world_size_ % kv_head_num == 0, ( + f"kv_head_num must be divisible by tp_world_size_ or " + f"tp_world_size_ must be divisible by kv_head_num, " + f"but found: {kv_head_num} % {self.tp_world_size_}" ) q_hidden_size = (q_head_num // self.tp_world_size_) * head_dim - kv_hidden_size = (kv_head_num // self.tp_world_size_) * head_dim + kv_hidden_size = self._get_tp_padded_head_num(kv_head_num) * head_dim out_dims = [q_hidden_size, kv_hidden_size, kv_hidden_size] super().__init__( in_dim=in_dim, @@ -128,13 +131,45 @@ def __init__( tp_rank=self.tp_rank_, tp_world_size=self.tp_world_size_, ) - self.param_slicer = get_row_slice_mixin( + self.q_param_slicer = get_row_slice_mixin( self.quant_method.method_name, tp_rank=self.tp_rank_, tp_world_size=self.tp_world_size_, - repeat_times=self.repeat_times, + repeat_times=self.q_repeat_times, + ) + self.kv_param_slicer = get_row_slice_mixin( + self.quant_method.method_name, + tp_rank=self.tp_rank_, + tp_world_size=self.tp_world_size_, + repeat_times=self.kv_repeat_times, ) + def _get_param_slicer(self, sub_child_index: int): + """ + sub_child_index: + 0 -> q + 1 -> k + 2 -> v + q 使用 q_param_slicer, k / v 使用 kv_param_slicer. + """ + if sub_child_index == 0: + return self.q_param_slicer + else: + return self.kv_param_slicer + + def _get_tp_padded_head_num(self, head_num: int): + if head_num % self.tp_world_size_ == 0: + return head_num // self.tp_world_size_ + elif self.tp_world_size_ % head_num == 0: + self.kv_repeat_times = self.tp_world_size_ // head_num + return self.kv_repeat_times * head_num // self.tp_world_size_ + else: + raise ValueError( + f"head_num must be divisible by tp_world_size_ or " + f"tp_world_size_ must be divisible by head_num, " + f"but found: {head_num} % {self.tp_world_size_}" + ) + class ROWBMMWeight(BMMWeightTpl): def __init__( From 597da700b36d86a8d36f99145c306cb27df2a75e Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Fri, 6 Feb 2026 02:20:41 +0000 Subject: [PATCH 2/2] add comments. --- .../layer_weights/meta_weights/mm_weight/mm_weight.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index 2e666f088..502169914 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -105,6 +105,13 @@ def _create_weight(self): return def _get_param_slicer(self, sub_child_index: int): + """ + 在部分子类场景中,可能需要不同的切片器,比如qkv场景 + 这里提供一个接口,子类可以重写,这样不同的组成部分可以使用不同的切片器 + 例如 QKVROWNMMWeight,它的q和kv使用不同的切片器 + 当然,大部分场景下,都是返回同一个切片器 + sub_child_index: 用于区分是第几个weight, 方便子类重写时使用 + """ return self.param_slicer # 执行顺序