diff --git a/lightllm/common/basemodel/attention/base_att.py b/lightllm/common/basemodel/attention/base_att.py
index 1286a46ec2..652d6228ca 100644
--- a/lightllm/common/basemodel/attention/base_att.py
+++ b/lightllm/common/basemodel/attention/base_att.py
@@ -41,11 +41,7 @@ def _find_layer_index(
         self, k: torch.Tensor, v: torch.Tensor, att_state: Union["BasePrefillAttState", "BaseDecodeAttState"]
     ) -> int:
         kv_buffer = att_state.infer_state.mem_manager.kv_buffer
-        layer_count = len(kv_buffer)
-        find_dict = {kv_buffer[i].data_ptr(): i for i in range(layer_count)}
-        key = min(k.data_ptr(), v.data_ptr())
-        assert key in find_dict
-        return find_dict[key]
+        return kv_buffer.find_layer_index(k, v)
 
 
 @dataclass
diff --git a/lightllm/common/kv_cache_mem_manager/__init__.py b/lightllm/common/kv_cache_mem_manager/__init__.py
index 79e75b3485..82bce0a805 100644
--- a/lightllm/common/kv_cache_mem_manager/__init__.py
+++ b/lightllm/common/kv_cache_mem_manager/__init__.py
@@ -1,3 +1,5 @@
+from .kv_buffer.kv_buffer import KvBuffer
+from .kv_buffer.quant_kv_buffer import QuantKvBuffer, PPLInt4QuantKvBuffer, PPLInt8QuantKvBuffer
 from .mem_manager import MemoryManager, ReadOnlyStaticsMemoryManager
 from .ppl_int8kv_mem_manager import PPLINT8KVMemoryManager
 from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
@@ -6,8 +8,18 @@
 from .fp8_per_token_group_quant_deepseek3_2mem_manager import FP8PerTokenGroupQuantDeepseek3_2MemoryManager
 from .fp8_static_per_head_quant_mem_manager import FP8StaticPerHeadQuantMemManager
 from .fp8_static_per_tensor_quant_mem_manager import FP8StaticPerTensorQuantMemManager
+from .kv_buffer.kv_buffer_adapter import KvBufferAdapter
+from .kv_buffer.hybrid_kv_buffer import HybridKvBuffer
+from .kv_buffer.hybrid_kv_buffer_adapter import HybridKvBufferAdapter
 
 __all__ = [
+    "KvBuffer",
+    "QuantKvBuffer",
+    "PPLInt4QuantKvBuffer",
+    "PPLInt8QuantKvBuffer",
+    "HybridKvBuffer",
+    "KvBufferAdapter",
+    "HybridKvBufferAdapter",
     "MemoryManager",
     "ReadOnlyStaticsMemoryManager",
     "PPLINT4KVMemoryManager",
diff --git a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py
index 3d93e1b070..53bb8c009f 100644
--- a/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py
+++ b/lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py
@@ -2,13 +2,13 @@
 import os
 import torch.distributed as dist
 from lightllm.server.pd_io_struct import KVMoveTask
+from .kv_buffer.kv_buffer import KvBuffer
 from .mem_manager import MemoryManager
 from typing import List, Union, Any
 from lightllm.utils.log_utils import init_logger
 from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
 from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node
 from lightllm.distributed.pynccl import PyNcclCommunicator
-from lightllm.common.kv_trans_kernel.nixl_kv_trans import mla_page_io
 
 logger = init_logger(__name__)
 
@@ -45,7 +45,10 @@ def get_cell_size(self):
         return self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype)
 
     def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
-        self.kv_buffer = torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+        self.kv_buffer = KvBuffer(
+            torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device="cuda"),
+            head_num=head_num,
+        )
 
     def alloc_kv_move_buffer(self, max_req_total_len):
         self.kv_move_buffer = torch.empty(
@@ -77,11 +80,8 @@ def write_mem_to_page_kv_move_buffer(
         pin_mem_indexes.numpy()[:] = mem_indexes
         mem_indexes_gpu = pin_mem_indexes.cuda(non_blocking=True)
         dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)]
-        mla_page_io(
-            mem_indexes=mem_indexes_gpu,
-            page_tensor=cur_page,
-            kv_buffer=dp_mems[0].kv_buffer,
-            mode="write",
+        dp_mems[0].kv_buffer_adapter.write_to_page_buffer(
+            mem_indexes=mem_indexes_gpu, page_tensor=cur_page, is_mla=True
         )
         return
 
@@ -99,12 +99,7 @@ def read_page_kv_move_buffer_to_mem(
         mem_indexes_gpu = pin_mem_indexes.cuda(non_blocking=True)
         dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)]
         for mem in dp_mems:
-            mla_page_io(
-                mem_indexes=mem_indexes_gpu,
-                page_tensor=cur_page,
-                kv_buffer=mem.kv_buffer,
-                mode="read",
-            )
+            mem.kv_buffer_adapter.read_from_page_buffer(mem_indexes=mem_indexes_gpu, page_tensor=cur_page, is_mla=True)
 
     def send_to_decode_node(
         self,
diff --git a/lightllm/common/kv_cache_mem_manager/kv_buffer/__init__.py b/lightllm/common/kv_cache_mem_manager/kv_buffer/__init__.py
new file mode 100644
index 0000000000..f6532f8172
--- /dev/null
+++ b/lightllm/common/kv_cache_mem_manager/kv_buffer/__init__.py
@@ -0,0 +1,9 @@
+from .kv_buffer import KvBuffer
+from .quant_kv_buffer import QuantKvBuffer, PPLInt4QuantKvBuffer, PPLInt8QuantKvBuffer
+
+__all__ = [
+    "KvBuffer",
+    "QuantKvBuffer",
+    "PPLInt4QuantKvBuffer",
+    "PPLInt8QuantKvBuffer",
+]
diff --git a/lightllm/common/kv_cache_mem_manager/kv_buffer/hybrid_kv_buffer.py b/lightllm/common/kv_cache_mem_manager/kv_buffer/hybrid_kv_buffer.py
new file mode 100644
index 0000000000..0dba722cdb
--- /dev/null
+++ b/lightllm/common/kv_cache_mem_manager/kv_buffer/hybrid_kv_buffer.py
@@ -0,0 +1,95 @@
+from typing import Any, List, Optional
+
+import torch
+
+from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager
+
+from .kv_buffer import KvBuffer
+
+
+class HybridKvBuffer(KvBuffer):
+    def __init__(
+        self,
+        buffers: List[Optional[torch.Tensor]],
+        head_num: int,
+        full_attention_interval: int,
+        mamba_cache_size: int,
+        linear_attn_layer_num: int,
+        conv_state_dtype: torch.dtype,
+        ssm_state_dtype: torch.dtype,
+        conv_kernel_size: int,
+        num_linear_k_heads: int,
+        num_linear_v_heads: int,
+        head_linear_k_dim: int,
+        head_linear_v_dim: int,
+    ):
+        self._buffers = buffers
+        self._head_num = head_num
+        self._full_attention_interval = full_attention_interval
+        self.mamba_cache_manager = MambaCacheManager(
+            size=mamba_cache_size,
+            layer_num=linear_attn_layer_num,
+            conv_state_dtype=conv_state_dtype,
+            ssm_state_dtype=ssm_state_dtype,
+            conv_kernel_size=conv_kernel_size,
+            num_linear_k_heads=num_linear_k_heads,
+            num_linear_v_heads=num_linear_v_heads,
+            head_linear_k_dim=head_linear_k_dim,
+            head_linear_v_dim=head_linear_v_dim,
+        )
+
+    def create_adapter(self):
+        from .hybrid_kv_buffer_adapter import HybridKvBufferAdapter
+
+        return HybridKvBufferAdapter(self)
+
+    def get_mamba_cache(self, layer_idx: int):
+        layer_idx_in_linear = layer_idx - (layer_idx // self._full_attention_interval)
+        return self.mamba_cache_manager.get_mamba_cache(layer_idx_in_linear)
+
+    def __getitem__(self, item):
+        return self._buffers[item]
+
+    def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor) -> None:
+        from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv
+
+        layer_buffer = self._buffers[layer_index]
+        if layer_buffer is None:
+            raise RuntimeError(f"layer {layer_index} does not have kv cache storage")
+        destindex_copy_kv(kv, mem_index, layer_buffer)
+
+    def get_att_input_params(self, layer_index: int) -> Any:
+        layer_buffer = self._buffers[layer_index]
+        if layer_buffer is None:
+            raise RuntimeError(f"layer {layer_index} does not have kv cache storage")
+        k = layer_buffer[:, : self._head_num, :]
+        v = layer_buffer[:, self._head_num :, :]
+        return k, v
+
+    def get_index_kv_buffer(self, index: Any) -> dict:
+        return {"kv_buffer": [None if layer_buffer is None else layer_buffer[index] for layer_buffer in self._buffers]}
+
+    def load_index_kv_buffer(self, index: Any, payload: dict) -> None:
+        for layer_index, layer_payload in enumerate(payload["kv_buffer"]):
+            if layer_payload is None:
+                continue
+            layer_buffer = self._buffers[layer_index]
+            if layer_buffer is None:
+                raise RuntimeError(f"layer {layer_index} does not have kv cache storage")
+            layer_buffer[index].copy_(layer_payload)
+
+    def get_device(self) -> int:
+        for layer_buffer in self._buffers:
+            if layer_buffer is not None:
+                return layer_buffer.get_device()
+        raise RuntimeError("HybridKvBuffer does not contain any kv cache tensor")
+
+    def find_layer_index(self, k: torch.Tensor, v: torch.Tensor) -> int:
+        key = min(k.data_ptr(), v.data_ptr())
+        find_dict = {
+            layer_buffer.data_ptr(): layer_index
+            for layer_index, layer_buffer in enumerate(self._buffers)
+            if layer_buffer is not None
+        }
+        assert key in find_dict
+        return find_dict[key]
diff --git a/lightllm/common/kv_cache_mem_manager/kv_buffer/hybrid_kv_buffer_adapter.py b/lightllm/common/kv_cache_mem_manager/kv_buffer/hybrid_kv_buffer_adapter.py
new file mode 100644
index 0000000000..3c469448e1
--- /dev/null
+++ b/lightllm/common/kv_cache_mem_manager/kv_buffer/hybrid_kv_buffer_adapter.py
@@ -0,0 +1,63 @@
+from typing import Optional
+
+import torch
+
+from .hybrid_kv_buffer import HybridKvBuffer
+from .kv_buffer_adapter import KvBufferAdapter
+
+
+class HybridKvBufferAdapter(KvBufferAdapter):
+    def __init__(self, kv_buffer: HybridKvBuffer):
+        super().__init__(kv_buffer)
+
+    def write_to_page_buffer(
+        self,
+        mem_indexes: torch.Tensor,
+        page_tensor: torch.Tensor,
+        tp_index: int = 0,
+        tp_world_size: int = 1,
+        is_mla: bool = False,
+    ) -> None:
+        # Keep the override signature in sync with KvBufferAdapter so callers
+        # passing is_mla get the intended NotImplementedError, not a TypeError.
+        raise NotImplementedError(f"{self.__class__.__name__} does not support paged kv write")
+
+    def read_from_page_buffer(
+        self,
+        mem_indexes: torch.Tensor,
+        page_tensor: torch.Tensor,
+        tp_index: int = 0,
+        tp_world_size: int = 1,
+        is_mla: bool = False,
+    ) -> None:
+        raise NotImplementedError(f"{self.__class__.__name__} does not support paged kv read")
+
+    def write_from_mla_page_buffer(self, mem_indexes: torch.Tensor, page_tensor: torch.Tensor) -> None:
+        raise NotImplementedError(f"{self.__class__.__name__} does not support mla paged kv write")
+
+    def read_from_mla_page_buffer(self, mem_indexes: torch.Tensor, page_tensor: torch.Tensor) -> None:
+        raise NotImplementedError(f"{self.__class__.__name__} does not support mla paged kv read")
+
+    def load_from_cpu_cache(
+        self,
+        gpu_mem_indexes: torch.Tensor,
+        cpu_kv_cache: torch.Tensor,
+        cpu_kv_cache_scale: Optional[torch.Tensor],
+        page_indexes: torch.Tensor,
+        tp_index: int,
+        tp_world_size: int,
+        grid_num: int,
+    ) -> None:
+        raise NotImplementedError(f"{self.__class__.__name__} does not support cpu cache load")
+
+    def offload_to_cpu_cache(
+        self,
+        token_indexes: torch.Tensor,
+        cpu_kv_cache: torch.Tensor,
+        cpu_kv_cache_scale: Optional[torch.Tensor],
+        page_indexes: torch.Tensor,
+        page_readies: torch.Tensor,
+        tp_index: int,
+        tp_world_size: int,
+        grid_num: int,
+    ) -> None:
+        raise NotImplementedError(f"{self.__class__.__name__} does not support cpu cache offload")
+
+    def copy_kv_from_other_dp_ranks(
+        self,
+        mem_managers,
+        move_token_indexes: torch.Tensor,
+        token_dp_indexes: torch.Tensor,
+        mem_indexes: torch.Tensor,
+        dp_size_in_node: int,
+        rank_in_dp: int,
+    ) -> None:
+        raise NotImplementedError(f"{self.__class__.__name__} does not support dp kv copy")
diff --git a/lightllm/common/kv_cache_mem_manager/kv_buffer/kv_buffer.py b/lightllm/common/kv_cache_mem_manager/kv_buffer/kv_buffer.py
new file mode 100644
index 0000000000..a8e6963714
--- /dev/null
+++ b/lightllm/common/kv_cache_mem_manager/kv_buffer/kv_buffer.py
@@ -0,0 +1,65 @@
+from typing import Any, Optional
+
+import torch
+
+
+class KvBuffer:
+    """Data wrapper around the kv cache storage.
+
+    This class owns the storage and access semantics of the kv buffer itself:
+    what the cache holds, how it is read and written per layer, and how it is
+    exported and re-imported. Its methods should therefore stay focused on
+    data operations on the kv buffer and must not carry workflow logic such
+    as page io, cpu cache, or dp transfer.
+    """
+
+    def __init__(self, buffer: torch.Tensor, head_num: int):
+        self._buffer = buffer
+        self._head_num = head_num
+
+    def create_adapter(self):
+        # Workflow logic lives in the adapter; KvBuffer only provides the
+        # underlying storage object.
+        from .kv_buffer_adapter import KvBufferAdapter
+
+        return KvBufferAdapter(self)
+
+    def __getitem__(self, item):
+        return self._buffer[item]
+
+    @property
+    def shape(self):
+        return self._buffer.shape
+
+    def get_storage_tensor(self) -> torch.Tensor:
+        return self._buffer
+
+    def get_storage_data_ptr(self) -> int:
+        return self._buffer.data_ptr()
+
+    def get_scale_buffer(self) -> Optional[torch.Tensor]:
+        return None
+
+    def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor) -> None:
+        from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv
+
+        destindex_copy_kv(kv, mem_index, self._buffer[layer_index])
+
+    def get_att_input_params(self, layer_index: int) -> Any:
+        layer_buffer = self._buffer[layer_index]
+        k = layer_buffer[:, : self._head_num, :]
+        v = layer_buffer[:, self._head_num :, :]
+        return k, v
+
+    def get_index_kv_buffer(self, index: Any) -> dict:
+        return {"kv_buffer": self._buffer[:, index]}
+
+    def load_index_kv_buffer(self, index: Any, payload: dict) -> None:
+        self._buffer[:, index].copy_(payload["kv_buffer"])
+
+    def get_device(self) -> int:
+        return self._buffer.get_device()
+
+    def find_layer_index(self, k: torch.Tensor, v: torch.Tensor) -> int:
+        key = min(k.data_ptr(), v.data_ptr())
+        find_dict = {self._buffer[i].data_ptr(): i for i in range(len(self._buffer))}
+        assert key in find_dict
+        return find_dict[key]
diff --git a/lightllm/common/kv_cache_mem_manager/kv_buffer/kv_buffer_adapter.py b/lightllm/common/kv_cache_mem_manager/kv_buffer/kv_buffer_adapter.py
new file mode 100644
index 0000000000..48bb072441
--- /dev/null
+++ b/lightllm/common/kv_cache_mem_manager/kv_buffer/kv_buffer_adapter.py
@@ -0,0 +1,143 @@
+from typing import Optional
+
+import torch
+
+from lightllm.common.basemodel.triton_kernel.kv_cache_offload import offload_gpu_kv_to_cpu, load_cpu_kv_to_gpu
+from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_for_dp
+from lightllm.common.kv_trans_kernel.nixl_kv_trans import mla_page_io, page_io
+
+from .kv_buffer import KvBuffer
+
+
+class KvBufferAdapter:
+    """Workflow adapter for kv-buffer-related operations.
+
+    This class hosts the workflow logic around the kv buffer: page io, cpu
+    cache load/offload, and dp transfer. These flows use the kv buffer, but
+    the buffer is only one of their inputs rather than the sole concern here;
+    the methods usually combine page tensors, index tensors, and tp/dp
+    context to complete an operation.
+    """
+
+    def __init__(self, kv_buffer: KvBuffer):
+        self.kv_buffer = kv_buffer
+        # The dp copy path caches the base addresses of the remote kv buffers
+        # so they are not rebuilt on every call.
+        self._mem_ptrs_tensor = None
+
+    def write_to_page_buffer(
+        self,
+        mem_indexes: torch.Tensor,
+        page_tensor: torch.Tensor,
+        self.quant_group_size = quant_group_size
+
+    def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor) -> None:
+        raise NotImplementedError("QuantKvBuffer subclass must implement quantized token writes")
+
+    def get_scale_buffer(self) -> torch.Tensor:
+        return self.scale_buffer
+
+    def get_att_input_params(self, layer_index: int) -> Any:
+        layer_buffer = self._buffer[layer_index]
+        layer_scale_buffer = self.scale_buffer[layer_index]
+        k = layer_buffer[:, : self._head_num, :]
+        k_scale = layer_scale_buffer[:, : self._head_num, :]
+        v = layer_buffer[:, self._head_num :, :]
+        v_scale = layer_scale_buffer[:, self._head_num :, :]
+        return (k, k_scale), (v, v_scale)
+
+    def get_index_kv_buffer(self, index: Any) -> dict:
+        return {
+            "kv_buffer": self._buffer[:, index],
+            "scale_buffer": self.scale_buffer[:, index],
+        }
+
+    def load_index_kv_buffer(self, index: Any, payload: dict) -> None:
+        self._buffer[:, index].copy_(payload["kv_buffer"])
+        self.scale_buffer[:, index].copy_(payload["scale_buffer"])
+
+
+class PPLInt4QuantKvBuffer(QuantKvBuffer):
+    def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor) -> None:
+        from lightllm.common.basemodel.triton_kernel.kv_copy.ppl_int4kv_copy_kv import destindex_copy_int4kv
+
+        destindex_copy_int4kv(
+            kv,
+            mem_index,
+            self[layer_index],
+            self.scale_buffer[layer_index],
+            quant_group_size=self.quant_group_size,
+        )
+
+
+class PPLInt8QuantKvBuffer(QuantKvBuffer):
+    def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor) -> None:
+        from lightllm.common.basemodel.triton_kernel.kv_copy.ppl_int8kv_copy_kv import destindex_copy_quantize_kv
+
+        destindex_copy_quantize_kv(
+            kv,
+            mem_index,
+            self[layer_index],
+            self.scale_buffer[layer_index],
+            quant_group_dim=self.quant_group_size,
+        )
diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py
index 1203cbdec7..500230d448 100755
--- a/lightllm/common/kv_cache_mem_manager/mem_manager.py
+++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py
@@ -4,6 +4,7 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from typing import List, Union, Tuple, Any
+from .kv_buffer.kv_buffer import KvBuffer
 from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_for_dp
 from lightllm.server.pd_io_struct import KVMoveTask
 from lightllm.utils.log_utils import init_logger
@@ -64,21 +65,21 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
             head_dim,
             layer_num,
         )
+        self._init_kv_buffer_adapter()
         self.HOLD_TOKEN_MEMINDEX = self.size
 
+    def _init_kv_buffer_adapter(self):
+        self.kv_buffer_adapter = self.kv_buffer.create_adapter()
+
     def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor):
         """
        Copy the kv produced by one layer into the mem manager slots at mem_index.
         """
-        from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv
-
-        destindex_copy_kv(kv, mem_index, self.kv_buffer[layer_index])
+        self.kv_buffer.copy_kv_to_mem_manager(layer_index, mem_index, kv)
         return
 
     def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]:
-        k = self.kv_buffer[layer_index][:, : self.head_num, :]
-        v = self.kv_buffer[layer_index][:, self.head_num :, :]
-        return k, v
+        return self.kv_buffer.get_att_input_params(layer_index)
 
     def get_cell_size(self):
         return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype)
@@ -108,7 +109,10 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
         # The slot allocated here is not actually managed internally; this token is reserved for special
         # running modes (e.g. padding requests under multi-dp overlap-microbatch so inference can run
         # normally). Its index is `size`, kept in the HOLD_TOKEN_MEMINDEX member, which plays a role
         # similar to HOLD_REQUEST_ID in req_manager.
-        self.kv_buffer = torch.empty((layer_num, size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda")
+        self.kv_buffer = KvBuffer(
+            torch.empty((layer_num, size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda"),
+            head_num=head_num,
+        )
 
     def alloc_kv_move_buffer(self, max_req_total_len):
         """
@@ -148,17 +152,15 @@ def write_mem_to_page_kv_move_buffer(
         pin_mem_indexes = self._buffer_mem_indexes_tensors[page_index][0 : len(mem_indexes)]
         pin_mem_indexes.numpy()[:] = mem_indexes
         mem_indexes_gpu = pin_mem_indexes.cuda(non_blocking=True)
-        repeat_count = dp_world_size * self.kv_buffer.shape[2] // self.kv_move_buffer.shape[3]
+        repeat_count = dp_world_size * (2 * self.head_num) // self.kv_move_buffer.shape[3]
         dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)]
         for tp_index in range(dp_world_size):
             if tp_index % repeat_count == 0:
-                page_io(
+                dp_mems[tp_index].kv_buffer_adapter.write_to_page_buffer(
                     mem_indexes=mem_indexes_gpu,
                     page_tensor=cur_page,
-                    kv_buffer=dp_mems[tp_index].kv_buffer,
                     tp_index=tp_index,
                     tp_world_size=dp_world_size,
-                    mode="write",
                 )
         # keep for debug
         # logger.info(f"src token tensor {self.kv_buffer[:, mem_indexes[0], 0, 0]}")
@@ -182,13 +184,11 @@ def read_page_kv_move_buffer_to_mem(
             non_blocking=True
         )
         for tp_index in range(dp_world_size):
-            page_io(
+            dp_mems[tp_index].kv_buffer_adapter.read_from_page_buffer(
                 mem_indexes=mem_indexes_gpu,
                 page_tensor=cur_page,
-                kv_buffer=dp_mems[tp_index].kv_buffer,
                 tp_index=tp_index,
                 tp_world_size=dp_world_size,
-                mode="read",
             )
         # keep for debug
         # logger.info(f"dst token tensor {self.kv_buffer[:, mem_indexes[0], 0, 0]}")
@@ -340,6 +340,7 @@ def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: to
 
     def _free_buffers(self):
         self.kv_buffer = None
+        self.kv_buffer_adapter = None
 
     def alloc(self, need_size) -> torch.Tensor:
         if need_size > self.mark_end - self.mark_start:
@@ -414,13 +415,14 @@ def resize_mem(self, new_size):
         self.shared_can_use_token_num.set_value(self.can_use_mem_size)
         self._free_buffers()
         self._init_buffers(size, dtype, head_num, head_dim, layer_num)
+        self._init_kv_buffer_adapter()
         return
 
     def get_index_kv_buffer(self, index):
-        return {"kv_buffer": self.kv_buffer[:, index]}
+        return self.kv_buffer.get_index_kv_buffer(index)
 
     def load_index_kv_buffer(self, index, load_tensor_dict):
-        self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"])
+        self.kv_buffer.load_index_kv_buffer(index, load_tensor_dict)
 
     def copy_kv_from_other_dp_ranks(
         self,
@@ -431,20 +433,11 @@ def copy_kv_from_other_dp_ranks(
         dp_size_in_node: int,
         rank_in_dp: int,
     ):
-        if not hasattr(self, "mem_ptrs_tensor"):
-            # Build a 2D tensor with shape (layer_num, mem_num).
-            mems_ptr_list = []
-            for i in range(0, len(mem_managers)):
-                mems_ptr_list.append(mem_managers[i].kv_buffer.data_ptr())
-            self.mem_ptrs_tensor = torch.tensor(mems_ptr_list, dtype=torch.uint64, device="cpu", pin_memory=True)
-
-        # Transfer all layers in one pass.
-        kv_trans_for_dp(
-            input_mems=self.mem_ptrs_tensor.cuda(non_blocking=True),
-            input_idx=move_token_indexes,
-            input_dp_idx=token_dp_indexes,
-            output=self.kv_buffer,
-            output_idx=mem_indexes,
+        self.kv_buffer_adapter.copy_kv_from_other_dp_ranks(
+            mem_managers=mem_managers,
+            move_token_indexes=move_token_indexes,
+            token_dp_indexes=token_dp_indexes,
+            mem_indexes=mem_indexes,
             dp_size_in_node=dp_size_in_node,
             rank_in_dp=rank_in_dp,
         )
diff --git a/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py b/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py
index 559980dc12..f567e444b7 100755
--- a/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py
+++ b/lightllm/common/kv_cache_mem_manager/ppl_int4kv_mem_manager.py
@@ -1,6 +1,6 @@
 import torch
-from typing import Tuple, Any
 from .mem_manager import MemoryManager
+from .kv_buffer.quant_kv_buffer import PPLInt4QuantKvBuffer
 
 
 class PPLINT4KVMemoryManager(MemoryManager):
@@ -9,28 +9,6 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True,
         self.group_quant_size = 8
         super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy=always_copy, mem_fraction=mem_fraction)
 
-    def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor):
-        """
-        Copy the kv produced by one layer into the mem manager slots at mem_index.
-        """
-        from ..basemodel.triton_kernel.kv_copy.ppl_int4kv_copy_kv import destindex_copy_int4kv
-
-        destindex_copy_int4kv(
-            kv,
-            mem_index,
-            self.kv_buffer[layer_index],
-            self.scale_buffer[layer_index],
-            quant_group_size=self.group_quant_size,
-        )
-        return
-
-    def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]:
-        k = self.kv_buffer[layer_index][:, : self.head_num, :]
-        k_scale = self.scale_buffer[layer_index][:, : self.head_num, :]
-        v = self.kv_buffer[layer_index][:, self.head_num :, :]
-        v_scale = self.scale_buffer[layer_index][:, self.head_num :, :]
-        return (k, k_scale), (v, v_scale)
-
     def get_cell_size(self):
         return 2 * self.head_num * self.head_dim // 2 * self.layer_num * torch._utils._element_size(
             self.kv_dtype
@@ -39,20 +17,11 @@ def get_cell_size(self):
         )
 
     def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
-        self.kv_buffer = torch.empty(
-            (layer_num, size + 1, 2 * head_num, head_dim // 2), dtype=torch.int8, device="cuda"
-        )
-        self.scale_buffer = torch.empty(
-            (layer_num, size + 1, 2 * head_num, head_dim // self.group_quant_size), dtype=dtype, device="cuda"
+        self.kv_buffer = PPLInt4QuantKvBuffer(
+            torch.empty((layer_num, size + 1, 2 * head_num, head_dim // 2), dtype=torch.int8, device="cuda"),
+            scale_buffer=torch.empty(
+                (layer_num, size + 1, 2 * head_num, head_dim // self.group_quant_size), dtype=dtype, device="cuda"
+            ),
+            head_num=head_num,
+            quant_group_size=self.group_quant_size,
         )
-
-    def _free_buffers(self):
-        self.kv_buffer = None
-        self.scale_buffer = None
-
-    def get_index_kv_buffer(self, index):
-        return {"kv_buffer": self.kv_buffer[:, index], "scale_buffer": self.scale_buffer[:, index]}
-
-    def load_index_kv_buffer(self, index, load_tensor_dict):
-        self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"])
-        self.scale_buffer[:, index].copy_(load_tensor_dict["scale_buffer"])
diff --git a/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py b/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py
index 951d72e2c8..ad67a3ca0d 100755
--- a/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py
+++ b/lightllm/common/kv_cache_mem_manager/ppl_int8kv_mem_manager.py
@@ -1,6 +1,6 @@
 import torch
-from typing import Tuple, Any
 from .mem_manager import MemoryManager
+from .kv_buffer.quant_kv_buffer import PPLInt8QuantKvBuffer
 
 
 class PPLINT8KVMemoryManager(MemoryManager):
@@ -9,28 +9,6 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True,
         self.group_quant_size = 8
         super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy=always_copy, mem_fraction=mem_fraction)
 
-    def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor):
-        """
-        Copy the kv produced by one layer into the mem manager slots at mem_index.
-        """
-        from ..basemodel.triton_kernel.kv_copy.ppl_int8kv_copy_kv import destindex_copy_quantize_kv
-
-        destindex_copy_quantize_kv(
-            kv,
-            mem_index,
-            self.kv_buffer[layer_index],
-            self.scale_buffer[layer_index],
-            quant_group_dim=self.group_quant_size,
-        )
-        return
-
-    def get_att_input_params(self, layer_index: int) -> Tuple[Any, Any]:
-        k = self.kv_buffer[layer_index][:, : self.head_num, :]
-        k_scale = self.scale_buffer[layer_index][:, : self.head_num, :]
-        v = self.kv_buffer[layer_index][:, self.head_num :, :]
-        v_scale = self.scale_buffer[layer_index][:, self.head_num :, :]
-        return (k, k_scale), (v, v_scale)
-
     def get_cell_size(self):
         return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(
             self.kv_dtype
@@ -39,18 +17,11 @@ def get_cell_size(self):
         )
 
     def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
-        self.kv_buffer = torch.empty((layer_num, size + 1, 2 * head_num, head_dim), dtype=torch.int8, device="cuda")
-        self.scale_buffer = torch.empty(
-            (layer_num, size + 1, 2 * head_num, head_dim // self.group_quant_size), dtype=dtype, device="cuda"
+        self.kv_buffer = PPLInt8QuantKvBuffer(
+            torch.empty((layer_num, size + 1, 2 * head_num, head_dim), dtype=torch.int8, device="cuda"),
+            scale_buffer=torch.empty(
+                (layer_num, size + 1, 2 * head_num, head_dim // self.group_quant_size), dtype=dtype, device="cuda"
+            ),
+            head_num=head_num,
+            quant_group_size=self.group_quant_size,
         )
-
-    def _free_buffers(self):
-        self.kv_buffer = None
-        self.scale_buffer = None
-
-    def get_index_kv_buffer(self, index):
-        return {"kv_buffer": self.kv_buffer[:, index], "scale_buffer": self.scale_buffer[:, index]}
-
-    def load_index_kv_buffer(self, index, load_tensor_dict):
-        self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"])
-        self.scale_buffer[:, index].copy_(load_tensor_dict["scale_buffer"])
diff --git a/lightllm/common/mamba_cache_mem_manager/cache_manager.py b/lightllm/common/mamba_cache_mem_manager/cache_manager.py
index 8602f2e67e..97bdc0a4d3 100644
--- a/lightllm/common/mamba_cache_mem_manager/cache_manager.py
+++ b/lightllm/common/mamba_cache_mem_manager/cache_manager.py
@@ -199,7 +199,7 @@ def profile_size(
     ):
         start_args = get_env_start_args()
         if self.size is not None and not start_args.disable_dynamic_prompt_cache:
-            assert self.size < start_args.running_max_req_size * 2, (
+            assert self.size >= start_args.running_max_req_size * 2, (
                 f"error mamba_cache_size({self.size}), ",
                 f"mamba_cache_size should be at least running_max_req_size * 2",
                 f"({start_args.running_max_req_size * 2}), ",
diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py
index 169e6ac2d8..3fe7c563d4 100644
--- a/lightllm/common/req_manager.py
+++ b/lightllm/common/req_manager.py
@@ -237,11 +237,13 @@ def gen_cpu_out_token_counter_sampling_params(self, req_objs: List):
 
 class ReqManagerForMamba(ReqManager):
     def __init__(self, max_request_num, max_sequence_length, mem_manager):
+        from lightllm.common.kv_cache_mem_manager.kv_buffer.hybrid_kv_buffer import HybridKvBuffer
         from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager
 
         super().__init__(max_request_num, max_sequence_length, mem_manager)
         self.mtp_step = get_env_start_args().mtp_step
-        self.buffer_mem_manager: MambaCacheManager = self.mem_manager.mamba_cache_mem_manager
+        assert isinstance(self.mem_manager.kv_buffer, HybridKvBuffer)
+        self.buffer_mem_manager: MambaCacheManager = self.mem_manager.kv_buffer.mamba_cache_manager
         self.req_to_buffer_index = torch.zeros(
             (self.max_request_num + 1, self.mtp_step + 1), dtype=torch.int32, device="cuda"
         )
diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
index dce5e96b31..819d62feea 100644
--- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
@@ -230,7 +230,7 @@ def gdn_forward(
         assert isinstance(infer_state.mem_manager, Qwen3NextHybridMemManager)
         input = input.view(-1, self.embed_dim_)
 
-        conv_states, ssm_states = infer_state.mem_manager.get_mamba_cache(self.layer_num_)
+        conv_states, ssm_states = infer_state.mem_manager.kv_buffer.get_mamba_cache(self.layer_num_)
 
         mixed_qkvzba = layer_weight.linear_in_proj.mm(input)
         mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba, is_decode=not is_prefill)
diff --git a/lightllm/models/qwen3next/mem_manager.py b/lightllm/models/qwen3next/mem_manager.py
index 12a6d56b8c..2aff2b2142 100644
--- a/lightllm/models/qwen3next/mem_manager.py
+++ b/lightllm/models/qwen3next/mem_manager.py
@@ -1,9 +1,10 @@
 import torch
-from typing import Tuple
+import torch.distributed as dist
 from lightllm.utils.log_utils import init_logger
+from lightllm.common.kv_cache_mem_manager.kv_buffer.hybrid_kv_buffer import HybridKvBuffer
 from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager
-from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager
-from lightllm.server.core.objs.start_args_type import StartArgs
+from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
 
 logger = init_logger(__name__)
 
@@ -36,41 +37,93 @@ def __init__(
         self.layer_num = layer_num
         self.full_attn_layer_num = layer_num // full_attention_interval
         self.linear_attn_layer_num = layer_num - self.full_attn_layer_num
+        self.linear_attn_cache_size = linear_attn_cache_size
+        self.conv_state_dtype = conv_state_dtype
+        self.ssm_state_dtype = ssm_state_dtype
+        self.conv_kernel_size = conv_kernel_size
+        self.num_linear_k_heads = num_linear_k_heads
+        self.num_linear_v_heads = num_linear_v_heads
+        self.head_linear_k_dim = head_linear_k_dim
+        self.head_linear_v_dim = head_linear_v_dim
 
-        self.mamba_cache_mem_manager = MambaCacheManager(
-            size=linear_attn_cache_size,
-            layer_num=self.linear_attn_layer_num,
-            conv_state_dtype=conv_state_dtype,
-            ssm_state_dtype=ssm_state_dtype,
-            conv_kernel_size=conv_kernel_size,
-            num_linear_k_heads=num_linear_k_heads,
-            num_linear_v_heads=num_linear_v_heads,
-            head_linear_k_dim=head_linear_k_dim,
-            head_linear_v_dim=head_linear_v_dim,
+        super().__init__(full_attn_cache_size, dtype, num_kv_heads, head_dim, layer_num, always_copy, mem_fraction)
+
+    def profile_size(self, mem_fraction):
+        if self.size is not None:
+            return
+
+        world_size = dist.get_world_size()
+        total_memory = get_total_gpu_memory()
+        available_memory = get_available_gpu_memory(world_size) - total_memory * (1 - mem_fraction)
+
+        conv_dim = (
+            self.head_linear_k_dim * self.num_linear_k_heads * 2 + self.head_linear_v_dim * self.num_linear_v_heads
+        )
+        mamba_cell_size = (
+            self.linear_attn_layer_num
+            * conv_dim
+            * (self.conv_kernel_size - 1)
+            * torch._utils._element_size(self.conv_state_dtype)
+        ) + (
+            self.linear_attn_layer_num
+            * self.num_linear_v_heads
+            * self.head_linear_k_dim
+            * self.head_linear_v_dim
+            * torch._utils._element_size(self.ssm_state_dtype)
+        )
-        )
-        super().__init__(full_attn_cache_size, dtype, num_kv_heads, head_dim, layer_num, always_copy, mem_fraction)
+        if self.linear_attn_cache_size is None:
+            start_args = get_env_start_args()
+            mamba_cache_ratio = start_args.mamba_cache_ratio if start_args.mamba_cache_ratio is not None else 0.5
+            self.linear_attn_cache_size = int(available_memory * mamba_cache_ratio * 1024 ** 3 / mamba_cell_size)
+        reserved_mamba_memory = self.linear_attn_cache_size * mamba_cell_size / (1024 ** 3)
+        available_memory -= reserved_mamba_memory
+
+        cell_size = self.get_cell_size()
+        self.size = int(available_memory * 1024 ** 3 / cell_size)
+        if world_size > 1:
+            tensor = torch.tensor(self.size, dtype=torch.int64, device="cuda")
+            dist.all_reduce(tensor, op=dist.ReduceOp.MIN)
+            self.size = tensor.item()
+
+        logger.info(
+            f"{available_memory} GB space is available for full attention kv cache after reserving "
+            f"{reserved_mamba_memory} GB for mamba cache\n"
+            f"{cell_size / 1024 ** 2} MB is the size of one token kv cache\n"
+            f"{self.size} is the profiled max_total_token_num with the mem_fraction {mem_fraction}\n"
+        )
+        return
 
     def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
         # KV buffer layout: [None, None, None, kv_cache, None, None, None, kv_cache, ...,
         #                    None, kv_cache, mtp_kv_cache, mtp_kv_cache]
         # Only full attention layers have KV cache.
-        self.kv_buffer = [None for _ in range(self.layer_num)]
+        kv_buffers = [None for _ in range(self.layer_num)]
         for layer_id in range(self.full_attn_layer_num):
-            self.kv_buffer[(layer_id + 1) * self.full_attention_interval - 1] = torch.empty(
+            kv_buffers[(layer_id + 1) * self.full_attention_interval - 1] = torch.empty(
                 (size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda"
             )
+        self.kv_buffer = HybridKvBuffer(
+            kv_buffers,
+            head_num=head_num,
+            full_attention_interval=self.full_attention_interval,
+            mamba_cache_size=self.linear_attn_cache_size,
+            linear_attn_layer_num=self.linear_attn_layer_num,
+            conv_state_dtype=self.conv_state_dtype,
+            ssm_state_dtype=self.ssm_state_dtype,
+            conv_kernel_size=self.conv_kernel_size,
+            num_linear_k_heads=self.num_linear_k_heads,
+            num_linear_v_heads=self.num_linear_v_heads,
+            head_linear_k_dim=self.head_linear_k_dim,
+            head_linear_v_dim=self.head_linear_v_dim,
+        )
 
     def free_all(self):
         super().free_all()
-        self.mamba_cache_mem_manager.free_all()
+        self.kv_buffer.mamba_cache_manager.free_all()
         return
 
     def get_cell_size(self):
         # Only full attention layers and MTP layers have KV cache
         kv_cache_layer_num = self.full_attn_layer_num
         return 2 * self.head_num * self.head_dim * kv_cache_layer_num * torch._utils._element_size(self.dtype)
-
-    def get_mamba_cache(self, layer_idx: int):
-        layer_idx_in_linear = layer_idx - (layer_idx // self.full_attention_interval)
-        return self.mamba_cache_mem_manager.get_mamba_cache(layer_idx_in_linear)
diff --git a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py
index 44bb269ed8..b7bbbac14f 100644
--- a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py
+++ b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py
@@ -4,6 +4,7 @@
 from sortedcontainers import SortedSet
 from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode
+from lightllm.common.kv_cache_mem_manager.kv_buffer.hybrid_kv_buffer import HybridKvBuffer
 from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager
 from lightllm.utils.log_utils import init_logger
 
@@ -13,8 +14,8 @@ class HybridRadixCache(RadixCache):
     def __init__(self, unique_name, total_token_num, rank_in_node, kv_cache_mem_manager):
         super().__init__(unique_name, total_token_num, rank_in_node, kv_cache_mem_manager)
 
-        assert hasattr(kv_cache_mem_manager, "mamba_cache_mem_manager")
-        self.buffer_mem_manager: MambaCacheManager = kv_cache_mem_manager.mamba_cache_mem_manager
+        assert isinstance(kv_cache_mem_manager.kv_buffer, HybridKvBuffer)
+        self.buffer_mem_manager: MambaCacheManager = kv_cache_mem_manager.kv_buffer.mamba_cache_manager
         self.evict_buffer_set: Set[TreeNode] = SortedSet(key=lambda x: (x.buffer_time,))
 
     def match_prefix(self, key, update_refs=False):
diff --git a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py
index 7c4168a937..7849dd74ce 100644
--- a/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py
+++ b/lightllm/server/router/model_infer/mode_backend/multi_level_kv_cache.py
@@ -9,7 +9,6 @@
 from lightllm.utils.envs_utils import get_env_start_args
 from ..infer_batch import InferReq
 from lightllm.utils.dist_utils import create_new_group_for_current_dp
-from lightllm.common.basemodel.triton_kernel.kv_cache_offload import offload_gpu_kv_to_cpu, load_cpu_kv_to_gpu
 from lightllm.server.router.model_infer.infer_batch import g_infer_context
 from lightllm.utils.log_utils import init_logger
 
@@ -96,29 +95,26 @@ def load_cpu_cache_to_reqs(self, reqs: List[InferReq]):
             grid_num = 16
 
             mem_manager = self.backend.model.mem_manager
-            if hasattr(mem_manager, "scale_buffer") and mem_manager.scale_buffer is not None:
+            kv_scale_buffer = mem_manager.kv_buffer.get_scale_buffer()
+            if kv_scale_buffer is not None:
                 cpu_cache_meta = self.cpu_cache_client.kv_cache_tensor_meta
                 cpu_kv_cache = self.cpu_cache_client.cpu_kv_cache_tensor[
                     :, :, :, :, 0 : cpu_cache_meta.head_dim
                 ]
                 cpu_kv_cache_scale = self.cpu_cache_client.cpu_kv_cache_tensor[
                     :, :, :, :, cpu_cache_meta.head_dim :
-                ].view(mem_manager.scale_buffer.dtype)
-                gpu_kv_cache_scale = mem_manager.scale_buffer
+                ].view(kv_scale_buffer.dtype)
             else:
                 cpu_kv_cache = self.cpu_cache_client.cpu_kv_cache_tensor
                 cpu_kv_cache_scale = None
-                gpu_kv_cache_scale = None
 
             mem_indexes_cuda = mem_indexes.cuda(non_blocking=True)
             page_indexes_cuda = torch.tensor(need_pages, dtype=torch.int32, device="cpu").cuda(
                 non_blocking=True
            )
             # Copy the contents of the cpu pages into the gpu pages.
-            load_cpu_kv_to_gpu(
+            mem_manager.kv_buffer_adapter.load_from_cpu_cache(
                 gpu_mem_indexes=mem_indexes_cuda,
-                gpu_kv_cache=mem_manager.kv_buffer,
-                gpu_kv_cache_scale=gpu_kv_cache_scale,
                 cpu_kv_cache=cpu_kv_cache,
                 cpu_kv_cache_scale=cpu_kv_cache_scale,
                 page_indexes=page_indexes_cuda,
@@ -260,23 +256,20 @@ def _start_kv_cache_offload_task(
             grid_num = 16
 
             mem_manager = self.backend.model.mem_manager
-            if hasattr(mem_manager, "scale_buffer") and mem_manager.scale_buffer is not None:
+            kv_scale_buffer = mem_manager.kv_buffer.get_scale_buffer()
+            if kv_scale_buffer is not None:
                 cpu_cache_meta = self.cpu_cache_client.kv_cache_tensor_meta
                 cpu_kv_cache = self.cpu_cache_client.cpu_kv_cache_tensor[:, :, :, :, 0 : cpu_cache_meta.head_dim]
                 cpu_kv_cache_scale = self.cpu_cache_client.cpu_kv_cache_tensor[
                     :, :, :, :, cpu_cache_meta.head_dim :
-                ].view(mem_manager.scale_buffer.dtype)
-                gpu_kv_cache_scale = mem_manager.scale_buffer
+                ].view(kv_scale_buffer.dtype)
             else:
                 cpu_kv_cache = self.cpu_cache_client.cpu_kv_cache_tensor
                 cpu_kv_cache_scale = None
-                gpu_kv_cache_scale = None
 
             # assert max(page_list) < self.cpu_cache_client.cpu_kv_cache_tensor.shape[0]
-            offload_gpu_kv_to_cpu(
+            mem_manager.kv_buffer_adapter.offload_to_cpu_cache(
                 token_indexes=token_indexes,
-                gpu_kv_cache=mem_manager.kv_buffer,
-                gpu_kv_cache_scale=gpu_kv_cache_scale,
                 cpu_kv_cache=cpu_kv_cache,
                 cpu_kv_cache_scale=cpu_kv_cache_scale,
                 page_indexes=cuda_page_indexes,
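
Reviewer note, not part of the patch: a minimal sketch of how the refactored pieces compose, assuming a lightllm checkout where these modules import cleanly. `MemoryManager` now owns a `KvBuffer` for storage semantics and derives a `KvBufferAdapter` (via `create_adapter`) for workflow operations; the toy shapes below are illustrative, not a profiled configuration.

```python
# Sketch only: KvBuffer wraps the raw (layer, token, 2 * head, dim) tensor and
# answers storage questions; the adapter layers page io / cpu cache / dp copy
# workflows on top of it. Toy CPU-sized shapes, not a real CUDA buffer.
import torch

from lightllm.common.kv_cache_mem_manager import KvBuffer

layer_num, size, head_num, head_dim = 4, 16, 2, 8
storage = torch.empty((layer_num, size + 1, 2 * head_num, head_dim))
kv_buffer = KvBuffer(storage, head_num=head_num)

# Storage-side API: per-layer k/v views split on the head axis.
k, v = kv_buffer.get_att_input_params(layer_index=0)
assert k.shape == (size + 1, head_num, head_dim) and v.shape == k.shape

# Workflow-side API lives on the adapter (its module pulls in the triton
# kernels, so this line assumes a GPU-capable install):
# adapter = kv_buffer.create_adapter()
```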
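The layer-index mapping that `HybridKvBuffer.get_mamba_cache` applies is worth a worked check. With `full_attention_interval = n`, `_init_buffers` places kv tensors at global layers n-1, 2n-1, ...; every other layer is linear attention, and its mamba slot id simply skips the full-attention layers seen so far. A standalone check with an illustrative interval of 4 (the real value comes from the model config):

```python
full_attention_interval = 4  # illustrative

def linear_slot(layer_idx: int) -> int:
    # Mirrors HybridKvBuffer.get_mamba_cache: drop one slot for each
    # full-attention layer that precedes layer_idx.
    return layer_idx - (layer_idx // full_attention_interval)

# Layers 3 and 7 are full attention (they own kv_buffers entries); the
# remaining layers map densely onto mamba cache slots 0..5.
linear_layers = [i for i in range(8) if (i + 1) % full_attention_interval != 0]
assert linear_layers == [0, 1, 2, 4, 5, 6]
assert [linear_slot(i) for i in linear_layers] == [0, 1, 2, 3, 4, 5]
```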
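On `find_layer_index`: the hook keeps the pointer-matching trick that base_att.py previously inlined. It works because k is a zero-offset view of its layer's buffer, so `min(k.data_ptr(), v.data_ptr())` recovers the layer tensor's base address. A small standalone illustration (toy shapes, CPU tensor):

```python
import torch

layer = torch.empty((5, 4, 8))  # (tokens, 2 * head_num, head_dim)
head_num = 2
k, v = layer[:, :head_num, :], layer[:, head_num:, :]

# k starts at element 0 of the layer buffer and v at offset head_num *
# head_dim, so the smaller of the two pointers is the buffer's base address.
assert k.data_ptr() == layer.data_ptr()
assert min(k.data_ptr(), v.data_ptr()) == layer.data_ptr()
```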
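The `mamba_cell_size` arithmetic in `Qwen3NextHybridMemManager.profile_size` charges each linear-attention layer for one conv state of `conv_dim * (conv_kernel_size - 1)` elements plus one ssm state of `num_linear_v_heads * head_linear_k_dim * head_linear_v_dim` elements. A worked instance with small made-up dimensions (not Qwen3-Next's real config):

```python
# Toy numbers only; they stand in for the model-config values the patch reads.
linear_attn_layer_num = 3
conv_kernel_size = 4
num_linear_k_heads, num_linear_v_heads = 2, 2
head_linear_k_dim, head_linear_v_dim = 8, 16
conv_elem, ssm_elem = 2, 4  # bytes per element, e.g. bf16 conv / fp32 ssm state

conv_dim = head_linear_k_dim * num_linear_k_heads * 2 + head_linear_v_dim * num_linear_v_heads
conv_bytes = linear_attn_layer_num * conv_dim * (conv_kernel_size - 1) * conv_elem
ssm_bytes = linear_attn_layer_num * num_linear_v_heads * head_linear_k_dim * head_linear_v_dim * ssm_elem

assert conv_dim == 64 and conv_bytes == 1152 and ssm_bytes == 3072
mamba_cell_size = conv_bytes + ssm_bytes  # 4224 bytes per mamba cache slot
```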