Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions lightllm/common/basemodel/attention/base_att.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,7 @@ def _find_layer_index(
self, k: torch.Tensor, v: torch.Tensor, att_state: Union["BasePrefillAttState", "BaseDecodeAttState"]
) -> int:
kv_buffer = att_state.infer_state.mem_manager.kv_buffer
layer_count = len(kv_buffer)
find_dict = {kv_buffer[i].data_ptr(): i for i in range(layer_count)}
key = min(k.data_ptr(), v.data_ptr())
assert key in find_dict
return find_dict[key]
return kv_buffer.find_layer_index(k, v)


@dataclass
Expand Down
12 changes: 12 additions & 0 deletions lightllm/common/kv_cache_mem_manager/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .kv_buffer.kv_buffer import KvBuffer
from .kv_buffer.quant_kv_buffer import QuantKvBuffer, PPLInt4QuantKvBuffer, PPLInt8QuantKvBuffer
from .mem_manager import MemoryManager, ReadOnlyStaticsMemoryManager
from .ppl_int8kv_mem_manager import PPLINT8KVMemoryManager
from .ppl_int4kv_mem_manager import PPLINT4KVMemoryManager
Expand All @@ -6,8 +8,18 @@
from .fp8_per_token_group_quant_deepseek3_2mem_manager import FP8PerTokenGroupQuantDeepseek3_2MemoryManager
from .fp8_static_per_head_quant_mem_manager import FP8StaticPerHeadQuantMemManager
from .fp8_static_per_tensor_quant_mem_manager import FP8StaticPerTensorQuantMemManager
from .kv_buffer.kv_buffer_adapter import KvBufferAdapter
from .kv_buffer.hybrid_kv_buffer import HybridKvBuffer
from .kv_buffer.hybrid_kv_buffer_adapter import HybridKvBufferAdapter

__all__ = [
"KvBuffer",
"QuantKvBuffer",
"PPLInt4QuantKvBuffer",
"PPLInt8QuantKvBuffer",
"HybridKvBuffer",
"KvBufferAdapter",
"HybridKvBufferAdapter",
"MemoryManager",
"ReadOnlyStaticsMemoryManager",
"PPLINT4KVMemoryManager",
Expand Down
21 changes: 8 additions & 13 deletions lightllm/common/kv_cache_mem_manager/deepseek2_mem_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
import os
import torch.distributed as dist
from lightllm.server.pd_io_struct import KVMoveTask
from .kv_buffer.kv_buffer import KvBuffer
from .mem_manager import MemoryManager
from typing import List, Union, Any
from lightllm.utils.log_utils import init_logger
from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node
from lightllm.distributed.pynccl import PyNcclCommunicator
from lightllm.common.kv_trans_kernel.nixl_kv_trans import mla_page_io


logger = init_logger(__name__)
Expand Down Expand Up @@ -45,7 +45,10 @@ def get_cell_size(self):
return self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype)

def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
    # MLA layout: a single tensor of shape (layer_num, size + 1, head_num, head_dim).
    # The "+ 1" reserves one extra token slot (presumably a padding/scratch
    # index — confirm against mem_manager allocation logic).
    # Wrapped in KvBuffer so attention code can use its per-layer access helpers.
    self.kv_buffer = KvBuffer(
        torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device="cuda"),
        head_num=head_num,
    )

def alloc_kv_move_buffer(self, max_req_total_len):
self.kv_move_buffer = torch.empty(
Expand Down Expand Up @@ -77,11 +80,8 @@ def write_mem_to_page_kv_move_buffer(
pin_mem_indexes.numpy()[:] = mem_indexes
mem_indexes_gpu = pin_mem_indexes.cuda(non_blocking=True)
dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)]
mla_page_io(
mem_indexes=mem_indexes_gpu,
page_tensor=cur_page,
kv_buffer=dp_mems[0].kv_buffer,
mode="write",
dp_mems[0].kv_buffer_adapter.write_to_page_buffer(
mem_indexes=mem_indexes_gpu, page_tensor=cur_page, is_mla=True
)
return

Expand All @@ -99,12 +99,7 @@ def read_page_kv_move_buffer_to_mem(
mem_indexes_gpu = pin_mem_indexes.cuda(non_blocking=True)
dp_mems = mem_managers[(dp_index * dp_world_size) : ((dp_index + 1) * dp_world_size)]
for mem in dp_mems:
mla_page_io(
mem_indexes=mem_indexes_gpu,
page_tensor=cur_page,
kv_buffer=mem.kv_buffer,
mode="read",
)
mem.kv_buffer_adapter.read_from_page_buffer(mem_indexes=mem_indexes_gpu, page_tensor=cur_page, is_mla=True)

def send_to_decode_node(
self,
Expand Down
9 changes: 9 additions & 0 deletions lightllm/common/kv_cache_mem_manager/kv_buffer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .kv_buffer import KvBuffer
from .quant_kv_buffer import QuantKvBuffer, PPLInt4QuantKvBuffer, PPLInt8QuantKvBuffer

__all__ = [
"KvBuffer",
"QuantKvBuffer",
"PPLInt4QuantKvBuffer",
"PPLInt8QuantKvBuffer",
]
95 changes: 95 additions & 0 deletions lightllm/common/kv_cache_mem_manager/kv_buffer/hybrid_kv_buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from typing import Any, List, Optional

import torch

from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager

from .kv_buffer import KvBuffer


class HybridKvBuffer(KvBuffer):
    # Hybrid kv storage: full-attention layers keep dense kv tensors in
    # self._buffers (None entries mark layers with no kv cache), while
    # linear-attention layers keep their state in a MambaCacheManager.
    def __init__(
        self,
        buffers: List[Optional[torch.Tensor]],
        head_num: int,
        full_attention_interval: int,
        mamba_cache_size: int,
        linear_attn_layer_num: int,
        conv_state_dtype: torch.dtype,
        ssm_state_dtype: torch.dtype,
        conv_kernel_size: int,
        num_linear_k_heads: int,
        num_linear_v_heads: int,
        head_linear_k_dim: int,
        head_linear_v_dim: int,
    ):
        """Wrap per-layer kv tensors plus a mamba cache for linear-attn layers.

        Args:
            buffers: one entry per model layer; None for layers that hold no
                kv tensor (see get_index_kv_buffer / copy_kv_to_mem_manager).
            head_num: split point on the head axis between k and v views
                (see get_att_input_params).
            full_attention_interval: used by get_mamba_cache to map a global
                layer index onto a linear-attention layer index.
            mamba_cache_size..head_linear_v_dim: forwarded verbatim to
                MambaCacheManager.
        """
        # NOTE(review): KvBuffer.__init__ is deliberately not called — this
        # subclass stores a list in self._buffers instead of the single
        # self._buffer tensor of the base class.
        self._buffers = buffers
        self._head_num = head_num
        self._full_attention_interval = full_attention_interval
        self.mamba_cache_manager = MambaCacheManager(
            size=mamba_cache_size,
            layer_num=linear_attn_layer_num,
            conv_state_dtype=conv_state_dtype,
            ssm_state_dtype=ssm_state_dtype,
            conv_kernel_size=conv_kernel_size,
            num_linear_k_heads=num_linear_k_heads,
            num_linear_v_heads=num_linear_v_heads,
            head_linear_k_dim=head_linear_k_dim,
            head_linear_v_dim=head_linear_v_dim,
        )

def create_adapter(self):
    """Build the adapter that carries the page-io/cpu-cache business logic."""
    # Imported lazily: the adapter module imports this module at its top level.
    from .hybrid_kv_buffer_adapter import HybridKvBufferAdapter

    adapter = HybridKvBufferAdapter(self)
    return adapter

def get_mamba_cache(self, layer_idx: int):
    """Fetch the mamba cache entry for a *global* layer index.

    The mamba cache only stores linear-attention layers, so the number of
    full-attention layers seen so far (layer_idx // _full_attention_interval)
    is subtracted to obtain the linear-attention layer index.
    """
    full_layers_before = layer_idx // self._full_attention_interval
    return self.mamba_cache_manager.get_mamba_cache(layer_idx - full_layers_before)

def __getitem__(self, item):
    """Index into the per-layer buffer list.

    `_buffers` is a plain Python list, so the multi-dimensional indexing
    used by MemoryManager (tuple keys) must be split manually: the first
    element selects the layer tensor, the remainder indexes into it.
    """
    if isinstance(item, tuple):
        return self._buffers[item[0]][item[1:]]
    return self._buffers[item]

def __setitem__(self, key, value):
    """Item assignment counterpart of __getitem__ (same tuple convention)."""
    if isinstance(key, tuple):
        self._buffers[key[0]][key[1:]] = value
    else:
        self._buffers[key] = value
Comment on lines +50 to +51
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

HybridKvBuffer does not support multi-dimensional indexing or item assignment, which are used in MemoryManager (e.g., lines 234 and 271). Because self._buffers is a list, indexing it with a tuple or slice will fail. This will cause crashes in models using HybridKvBuffer (like Qwen3Next) when features like PD separation are enabled.

Suggested change
def __getitem__(self, item):
return self._buffers[item]
def __getitem__(self, item):
if isinstance(item, tuple):
return self._buffers[item[0]][item[1:]]
return self._buffers[item]
def __setitem__(self, key, value):
if isinstance(key, tuple):
self._buffers[key[0]][key[1:]] = value
else:
self._buffers[key] = value


def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor) -> None:
    """Scatter the rows of `kv` into this layer's buffer at positions `mem_index`."""
    from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv

    dst = self._buffers[layer_index]
    if dst is None:
        # Layers without kv storage must never be written to.
        raise RuntimeError(f"layer {layer_index} does not have kv cache storage")
    destindex_copy_kv(kv, mem_index, dst)

def get_att_input_params(self, layer_index: int) -> Any:
    """Return (k, v) views of one layer's buffer.

    The layer tensor stores k and v concatenated along the head axis:
    heads [0, _head_num) are k, the remainder are v.
    """
    buf = self._buffers[layer_index]
    if buf is None:
        raise RuntimeError(f"layer {layer_index} does not have kv cache storage")
    return buf[:, : self._head_num, :], buf[:, self._head_num :, :]

def get_index_kv_buffer(self, index: Any) -> dict:
    """Export the kv rows at `index` for every layer (None for kv-less layers)."""
    per_layer = []
    for buf in self._buffers:
        per_layer.append(None if buf is None else buf[index])
    return {"kv_buffer": per_layer}

def load_index_kv_buffer(self, index: Any, payload: dict) -> None:
    """Import kv rows previously exported by get_index_kv_buffer (in-place copy)."""
    for layer_index, incoming in enumerate(payload["kv_buffer"]):
        if incoming is None:
            # Nothing was exported for this layer; skip it.
            continue
        dst = self._buffers[layer_index]
        if dst is None:
            raise RuntimeError(f"layer {layer_index} does not have kv cache storage")
        dst[index].copy_(incoming)

def get_device(self) -> int:
    """Device index of the first layer that actually holds a kv tensor."""
    for buf in self._buffers:
        if buf is None:
            continue
        return buf.get_device()
    raise RuntimeError("HybridKvBuffer does not contain any kv cache tensor")

def find_layer_index(self, k: torch.Tensor, v: torch.Tensor) -> int:
    """Recover the layer index that the k/v views were sliced from.

    k and v are views into a single layer tensor (see get_att_input_params),
    so the smaller of their data pointers equals that tensor's base pointer.
    The pointer->index map is built once and cached on the instance: the
    buffers are pre-allocated, so their base pointers never change, and
    rebuilding the dict on every call wastes time on the inference hot path.
    """
    key = min(k.data_ptr(), v.data_ptr())
    ptr_map = getattr(self, "_layer_ptr_to_index", None)
    if ptr_map is None:
        ptr_map = {
            buf.data_ptr(): layer_index
            for layer_index, buf in enumerate(self._buffers)
            if buf is not None
        }
        self._layer_ptr_to_index = ptr_map
    assert key in ptr_map
    return ptr_map[key]
Comment on lines +87 to +95
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to KvBuffer, the dictionary of layer pointers in HybridKvBuffer.find_layer_index should be cached to improve performance.

Suggested change
def find_layer_index(self, k: torch.Tensor, v: torch.Tensor) -> int:
key = min(k.data_ptr(), v.data_ptr())
find_dict = {
layer_buffer.data_ptr(): layer_index
for layer_index, layer_buffer in enumerate(self._buffers)
if layer_buffer is not None
}
assert key in find_dict
return find_dict[key]
def find_layer_index(self, k: torch.Tensor, v: torch.Tensor) -> int:
key = min(k.data_ptr(), v.data_ptr())
if not hasattr(self, "_layer_ptr_to_idx"):
self._layer_ptr_to_idx = {
layer_buffer.data_ptr(): layer_index
for layer_index, layer_buffer in enumerate(self._buffers)
if layer_buffer is not None
}
assert key in self._layer_ptr_to_idx
return self._layer_ptr_to_idx[key]

Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from typing import Optional

import torch

from .hybrid_kv_buffer import HybridKvBuffer
from .kv_buffer_adapter import KvBufferAdapter


class HybridKvBufferAdapter(KvBufferAdapter):
    """Adapter for HybridKvBuffer.

    Hybrid (partial kv) layouts do not yet participate in page io, cpu
    cache, or dp kv transfer, so every business operation raises
    NotImplementedError rather than operating on incomplete storage.
    """

    def __init__(self, kv_buffer: HybridKvBuffer):
        super().__init__(kv_buffer)

    def _unsupported(self, feature: str):
        # Single place that formats the refusal message for every operation.
        raise NotImplementedError(f"{self.__class__.__name__} does not support {feature}")

    def write_to_page_buffer(
        self, mem_indexes: torch.Tensor, page_tensor: torch.Tensor, tp_index: int, tp_world_size: int
    ) -> None:
        self._unsupported("paged kv write")

    def read_from_page_buffer(
        self, mem_indexes: torch.Tensor, page_tensor: torch.Tensor, tp_index: int, tp_world_size: int
    ) -> None:
        self._unsupported("paged kv read")

    def write_from_mla_page_buffer(self, mem_indexes: torch.Tensor, page_tensor: torch.Tensor) -> None:
        self._unsupported("mla paged kv write")

    def read_from_mla_page_buffer(self, mem_indexes: torch.Tensor, page_tensor: torch.Tensor) -> None:
        self._unsupported("mla paged kv read")

    def load_from_cpu_cache(
        self,
        gpu_mem_indexes: torch.Tensor,
        cpu_kv_cache: torch.Tensor,
        cpu_kv_cache_scale: Optional[torch.Tensor],
        page_indexes: torch.Tensor,
        tp_index: int,
        tp_world_size: int,
        grid_num: int,
    ) -> None:
        self._unsupported("cpu cache load")

    def offload_to_cpu_cache(
        self,
        token_indexes: torch.Tensor,
        cpu_kv_cache: torch.Tensor,
        cpu_kv_cache_scale: Optional[torch.Tensor],
        page_indexes: torch.Tensor,
        page_readies: torch.Tensor,
        tp_index: int,
        tp_world_size: int,
        grid_num: int,
    ) -> None:
        self._unsupported("cpu cache offload")

    def copy_kv_from_other_dp_ranks(
        self,
        mem_managers,
        move_token_indexes: torch.Tensor,
        token_dp_indexes: torch.Tensor,
        mem_indexes: torch.Tensor,
        dp_size_in_node: int,
        rank_in_dp: int,
    ) -> None:
        self._unsupported("dp kv copy")
65 changes: 65 additions & 0 deletions lightllm/common/kv_cache_mem_manager/kv_buffer/kv_buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from typing import Any, Optional

import torch


class KvBuffer:
    """Data wrapper for the KV cache.

    This class is responsible for the storage and access semantics of the
    kv buffer itself — what the cache stores, how it is read and written
    per layer, and how it is exported/imported. Methods here should stay
    focused on data operations on the kv buffer itself and must not carry
    business-flow logic such as page io, cpu cache, or dp transfer.
    """

    def __init__(self, buffer: torch.Tensor, head_num: int):
        # buffer: the full kv tensor; layer-major (buffer[i] is one layer's
        #   storage — see find_layer_index). Assumed shape is
        #   (layer_num, token_slots, heads, head_dim) — TODO confirm at callers.
        # head_num: split point on the head axis between the k and v views
        #   (see get_att_input_params).
        self._buffer = buffer
        self._head_num = head_num

def create_adapter(self):
    """Create the adapter that owns the business logic; KvBuffer only
    supplies the underlying storage object."""
    # Imported lazily, presumably to avoid an import cycle with the adapter
    # module at load time.
    from .kv_buffer_adapter import KvBufferAdapter

    adapter = KvBufferAdapter(self)
    return adapter

def __getitem__(self, item):
    """Delegate indexing straight to the underlying tensor
    (tuples, slices, and tensor indices all pass through)."""
    return self._buffer[item]

def __setitem__(self, key, value):
    """Delegate item assignment to the underlying tensor.

    Required because callers that previously held a raw tensor assign into
    the kv buffer directly (e.g. ``self.kv_buffer[idx] = value``); without
    this, such assignments raise TypeError on a KvBuffer instance.
    """
    self._buffer[key] = value
Comment on lines +25 to +26
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The KvBuffer class is missing the __setitem__ method. Since self.kv_buffer in MemoryManager is now a KvBuffer object rather than a raw tensor, direct item assignment (e.g., self.kv_buffer[slice, ...] = value) will raise a TypeError. This is critical for methods like MemoryManager._write_kv_move_data and Deepseek2MemoryManager._write_kv_move_data which rely on this behavior for PD separation.

Suggested change
def __getitem__(self, item):
return self._buffer[item]
def __getitem__(self, item):
return self._buffer[item]
def __setitem__(self, key, value):
self._buffer[key] = value


@property
def shape(self):
    # Mirror the underlying tensor's shape so callers that treated the kv
    # buffer as a raw tensor keep working.
    return self._buffer.shape

def get_storage_tensor(self) -> torch.Tensor:
    """Return the raw backing tensor (no copy)."""
    return self._buffer

def get_storage_data_ptr(self) -> int:
    """Base data pointer of the backing tensor."""
    return self._buffer.data_ptr()

def get_scale_buffer(self) -> Optional[torch.Tensor]:
    """Unquantized kv has no scale tensor; presumably overridden by the
    quantized variants (QuantKvBuffer) — verify there."""
    return None

def copy_kv_to_mem_manager(self, layer_index: int, mem_index: torch.Tensor, kv: torch.Tensor) -> None:
    """Scatter the rows of `kv` into layer `layer_index` at positions `mem_index`."""
    from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv

    dst = self._buffer[layer_index]
    destindex_copy_kv(kv, mem_index, dst)

def get_att_input_params(self, layer_index: int) -> Any:
    """Return (k, v) views of one layer: heads [0, _head_num) are k, the rest v."""
    hn = self._head_num
    layer = self._buffer[layer_index]
    return layer[:, :hn, :], layer[:, hn:, :]

def get_index_kv_buffer(self, index: Any) -> dict:
    """Export the kv rows at `index` across all layers."""
    selected = self._buffer[:, index]
    return {"kv_buffer": selected}

def load_index_kv_buffer(self, index: Any, payload: dict) -> None:
    """Import kv rows previously exported by get_index_kv_buffer (in-place copy)."""
    incoming = payload["kv_buffer"]
    self._buffer[:, index].copy_(incoming)

def get_device(self) -> int:
    # torch convention: -1 for CPU tensors, otherwise the CUDA device ordinal.
    return self._buffer.get_device()

def find_layer_index(self, k: torch.Tensor, v: torch.Tensor) -> int:
    """Recover the layer index that the k/v views were sliced from.

    k and v are views into one layer of self._buffer (see
    get_att_input_params), so the smaller of their data pointers equals that
    layer's base pointer. The pointer->index map is built once and cached on
    the instance: the buffer is pre-allocated, so layer base pointers never
    change, and rebuilding the dict on every call wastes time during inference.
    """
    key = min(k.data_ptr(), v.data_ptr())
    ptr_map = getattr(self, "_layer_ptr_to_index", None)
    if ptr_map is None:
        ptr_map = {self._buffer[i].data_ptr(): i for i in range(len(self._buffer))}
        self._layer_ptr_to_index = ptr_map
    assert key in ptr_map
    return ptr_map[key]
Comment on lines +61 to +65
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The find_layer_index method builds a dictionary of layer pointers on every call. Since the KV buffer is pre-allocated and its layer pointers are constant, this dictionary should be cached to avoid unnecessary overhead during inference.

Suggested change
def find_layer_index(self, k: torch.Tensor, v: torch.Tensor) -> int:
key = min(k.data_ptr(), v.data_ptr())
find_dict = {self._buffer[i].data_ptr(): i for i in range(len(self._buffer))}
assert key in find_dict
return find_dict[key]
def find_layer_index(self, k: torch.Tensor, v: torch.Tensor) -> int:
key = min(k.data_ptr(), v.data_ptr())
if not hasattr(self, "_layer_ptr_to_idx"):
self._layer_ptr_to_idx = {self._buffer[i].data_ptr(): i for i in range(len(self._buffer))}
assert key in self._layer_ptr_to_idx
return self._layer_ptr_to_idx[key]

Loading
Loading