Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions custom_ops/gpu_ops/send_cache.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstdlib>
#include <cstring>

#include "helper.h"
#include "paddle/extension.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/core/memory/memcpy.h"
#include "remote_cache_kv_ipc.h"

namespace {

// Returns true when the environment variable `name` is set to exactly
// "true" or "1"; an unset variable or any other value yields false.
bool IsEnvFlagEnabled(const char* name) {
  const char* value = std::getenv(name);
  return value != nullptr &&
         (std::strcmp(value, "true") == 0 || std::strcmp(value, "1") == 0);
}

}  // namespace

// Enqueues a "cache KV write completed" host callback on qkv's stream.
//
// Nothing is enqueued unless FLAGS_fmt_write_cache_completed_signal is
// "true" or "1". When it is:
//   * if FLAGS_use_pd_disaggregation_per_chunk is also "true" or "1",
//     RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_per_query is
//     enqueued with a null user-data pointer;
//   * otherwise, if kv_signal_data is provided,
//     RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise is enqueued
//     with the tensor's int64 data pointer as user data.
//
// Args:
//   qkv: only its stream is used here; tensor contents are not read or
//     written by this function.
//   kv_signal_data: optional tensor whose int64 data pointer is forwarded to
//     the callback.
//     NOTE(review): the callback runs on the host, so this presumably has to
//     be host-accessible memory — confirm against the caller.
//
// Raises (via PD_CHECK) if cudaLaunchHostFunc reports an error, so that a
// dropped completion signal is surfaced instead of being silently ignored.
void SendCacheFunc(const paddle::Tensor& qkv,
                   const paddle::optional<paddle::Tensor>& kv_signal_data) {
  // Environment flags are re-read on every call, so they can be toggled at
  // runtime.
  if (!IsEnvFlagEnabled("FLAGS_fmt_write_cache_completed_signal")) {
    return;
  }

  cudaError_t status = cudaSuccess;
  if (IsEnvFlagEnabled("FLAGS_use_pd_disaggregation_per_chunk")) {
    status = cudaLaunchHostFunc(
        qkv.stream(),
        &RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise_per_query,
        nullptr);
  } else if (kv_signal_data) {
    // The callback takes a mutable void*, while Tensor::data<T>() on a const
    // tensor returns a const pointer; hence the const_cast.
    status = cudaLaunchHostFunc(
        qkv.stream(),
        &RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise,
        static_cast<void*>(
            const_cast<int64_t*>(kv_signal_data.get().data<int64_t>())));
  }

  PD_CHECK(status == cudaSuccess,
           "send_cache: cudaLaunchHostFunc failed: ",
           cudaGetErrorString(status));
}

// Registers the custom op `send_cache`, backed by SendCacheFunc.
//   Inputs:  "qkv" and an optional "kv_signal_data".
//   Outputs: "qkv_out", declared as an in-place alias of "qkv" through
//            SetInplaceMap. SendCacheFunc returns void, so the op yields no
//            separately computed output tensor.
// NOTE(review): PD_BUILD_STATIC_OP is defined outside this file (presumably
// in helper.h) — confirm its exact registration semantics there.
PD_BUILD_STATIC_OP(send_cache)
.Inputs({"qkv", paddle::Optional("kv_signal_data")})
.Outputs({"qkv_out"})
.SetInplaceMap({{"qkv", "qkv_out"}})
.SetKernelFn(PD_KERNEL(SendCacheFunc));
1 change: 1 addition & 0 deletions custom_ops/setup_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ def find_end_files(directory, end_str):
"gpu_ops/gelu_tanh.cu",
"gpu_ops/reasoning_phase_token_constraint.cu",
"gpu_ops/get_attn_mask_q.cu",
"gpu_ops/send_cache.cu",
]
sm_versions = get_sm_version(archs)
# Some kernels in this file require SM75+ instructions. Exclude them when building SM70 (V100).
Expand Down
20 changes: 16 additions & 4 deletions fastdeploy/cache_manager/cache_messager.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ def __init__(
scale_block_bytes = math.prod(key_cache_scale.shape[1:])
if key_cache_scale.dtype == paddle.bfloat16 or key_cache_scale.dtype == paddle.float16:
scale_block_bytes *= 2
elif key_cache_scale.dtype == paddle.float32:
scale_block_bytes *= 4
logger.info(f"scale_block_bytes: {scale_block_bytes}, dtype: {key_cache_scale.dtype}")
logger.info(
f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
Expand Down Expand Up @@ -571,6 +573,8 @@ def __init__(
scale_block_bytes = math.prod(key_cache_scale.shape[1:])
if key_cache_scale.dtype == paddle.bfloat16 or key_cache_scale.dtype == paddle.float16:
scale_block_bytes *= 2
elif key_cache_scale.dtype == paddle.float32:
scale_block_bytes *= 4
logger.info(f"scale_block_bytes: {scale_block_bytes}, dtype: {key_cache_scale.dtype}")
logger.info(
f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, "
Expand Down Expand Up @@ -936,6 +940,8 @@ def main():
if args.value_cache_shape:
value_cache_shape_list = [int(i) for i in args.value_cache_shape.split(",")]
total_gpu_blocks = key_cache_shape_list[0]
# key_cache_shape: [num_blocks, kv_head_num, block_size, head_dim]
block_size = key_cache_shape_list[2]
num_extra_layer_gpu_blocks = int(total_gpu_blocks * speculative_config.num_gpu_block_expand_ratio)
gpu_cache_kvs = {}
gpu_cache_k_tensors = []
Expand Down Expand Up @@ -973,10 +979,14 @@ def main():
f"key_caches_{i}_rank{rank}.device{device}",
)
if args.cache_dtype == "block_wise_fp8":
_scale_dtype = (
"float32" if envs.FD_ATTENTION_BACKEND == "FLASH_MASK_ATTN_BLACKWELL" else paddle.get_default_dtype()
)
_scale_last_dim = 4 if envs.FD_ATTENTION_BACKEND == "FLASH_MASK_ATTN_BLACKWELL" else key_cache_shape[2]
gpu_cache_kvs[f"key_cache_scales_{i}_rank{rank}_device{device}"] = paddle.full(
shape=[num_gpu_blocks, key_cache_shape[1], key_cache_shape[2]],
shape=[num_gpu_blocks, key_cache_shape[1], _scale_last_dim],
fill_value=0,
dtype=paddle.get_default_dtype(),
dtype=_scale_dtype,
)
set_data_ipc(
gpu_cache_kvs[f"key_cache_scales_{i}_rank{rank}_device{device}"],
Expand All @@ -996,9 +1006,9 @@ def main():
)
if args.cache_dtype == "block_wise_fp8":
gpu_cache_kvs[f"value_cache_scales_{i}_rank{rank}_device{device}"] = paddle.full(
shape=[num_gpu_blocks, value_cache_shape[1], value_cache_shape[2]],
shape=[num_gpu_blocks, value_cache_shape[1], _scale_last_dim],
fill_value=0,
dtype=paddle.get_default_dtype(),
dtype=_scale_dtype,
)
set_data_ipc(
gpu_cache_kvs[f"value_cache_scales_{i}_rank{rank}_device{device}"],
Expand All @@ -1023,6 +1033,7 @@ def main():
gpu_id=device,
rdma_port=args.rdma_port,
cache_dtype=args.cache_dtype,
block_size=block_size,
)
else:
cache_messager = CacheMessager(
Expand All @@ -1038,6 +1049,7 @@ def main():
gpu_id=device,
rdma_port=args.rdma_port,
cache_dtype=args.cache_dtype,
block_size=block_size,
)

cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
Expand Down
53 changes: 38 additions & 15 deletions fastdeploy/cache_manager/cache_transfer_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,16 @@ def __init__(self, args):
# compute cache bytes
self.cache_dtype = args.cache_dtype
self.cache_item_bytes = CacheConfig.get_cache_bytes(self.cache_dtype)
self.scale_item_bytes = CacheConfig.get_cache_bytes(paddle.get_default_dtype())
self.has_cache_scale = self.cache_dtype == "block_wise_fp8"
if self.has_cache_scale:
self.cache_scale_shape = [self.num_gpu_blocks, self.head_num, self.block_size]
if envs.FD_ATTENTION_BACKEND == "FLASH_MASK_ATTN_BLACKWELL":
self.scale_item_bytes = CacheConfig.get_cache_bytes("float32")
self.cache_scale_shape = [self.num_gpu_blocks, self.head_num, 4]
else:
self.scale_item_bytes = CacheConfig.get_cache_bytes(paddle.get_default_dtype())
self.cache_scale_shape = [self.num_gpu_blocks, self.head_num, self.block_size]
else:
self.scale_item_bytes = CacheConfig.get_cache_bytes(paddle.get_default_dtype())

# kv cache storage
self.storage_backend_type = args.kvcache_storage_backend
Expand Down Expand Up @@ -416,7 +422,9 @@ def _init_storage_buffer(self, args):
self.storage_backend.register_buffer(write_buffer, cache_buffer_total_bytes)

if self.has_cache_scale:
self.scale_buffer_stride_bytes = layer_num * self.head_num * self.block_size * self.scale_item_bytes
self.scale_buffer_stride_bytes = (
layer_num * self.head_num * self.cache_scale_shape[-1] * self.scale_item_bytes
)
scale_buffer_total_bytes = block_num * self.scale_buffer_stride_bytes * 2
logger.info(
f"Creating scale cpu buffer cache for all layers: {scale_buffer_total_bytes / 1024 ** 3:.2f}GB"
Expand Down Expand Up @@ -474,10 +482,16 @@ def _init_gpu_cache(self):
set_data_ipc(key_cache, key_name)

if self.cache_dtype == "block_wise_fp8":
_scale_dtype = (
"float32"
if envs.FD_ATTENTION_BACKEND == "FLASH_MASK_ATTN_BLACKWELL"
else paddle.get_default_dtype()
)
_scale_last_dim = self.cache_scale_shape[-1]
key_cache_scales = paddle.full(
shape=[num_gpu_blocks, self.key_cache_shape[1], self.key_cache_shape[2]],
shape=[num_gpu_blocks, self.key_cache_shape[1], _scale_last_dim],
fill_value=0,
dtype=paddle.get_default_dtype(),
dtype=_scale_dtype,
)
set_data_ipc(key_cache_scales, key_cache_scales_name)
if self.value_cache_shape:
Expand All @@ -486,9 +500,9 @@ def _init_gpu_cache(self):

if self.cache_dtype == "block_wise_fp8":
value_cache_scales = paddle.full(
shape=[num_gpu_blocks, self.value_cache_shape[1], self.value_cache_shape[2]],
shape=[num_gpu_blocks, self.value_cache_shape[1], _scale_last_dim],
fill_value=0,
dtype=paddle.get_default_dtype(),
dtype=_scale_dtype,
)
set_data_ipc(value_cache_scales, value_cache_scales_name)
else:
Expand All @@ -497,21 +511,27 @@ def _init_gpu_cache(self):
val_cache = paddle.empty(shape=[], dtype=cache_type)
key_cache = share_external_data_(key_cache, key_name, key_cache_shape, True)
if self.cache_dtype == "block_wise_fp8":
key_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
_scale_dtype = (
"float32"
if envs.FD_ATTENTION_BACKEND == "FLASH_MASK_ATTN_BLACKWELL"
else paddle.get_default_dtype()
)
_scale_last_dim = self.cache_scale_shape[-1]
key_cache_scales = paddle.empty(shape=[], dtype=_scale_dtype)
key_cache_scales = share_external_data_(
key_cache_scales,
key_cache_scales_name,
[num_gpu_blocks, self.key_cache_shape[1], self.key_cache_shape[2]],
[num_gpu_blocks, self.key_cache_shape[1], _scale_last_dim],
True,
)
if self.value_cache_shape:
val_cache = share_external_data_(val_cache, val_name, value_cache_shape, True)
if self.cache_dtype == "block_wise_fp8":
value_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
value_cache_scales = paddle.empty(shape=[], dtype=_scale_dtype)
value_cache_scales = share_external_data_(
value_cache_scales,
value_cache_scales_name,
[num_gpu_blocks, self.value_cache_shape[1], self.value_cache_shape[2]],
[num_gpu_blocks, self.value_cache_shape[1], _scale_last_dim],
True,
)

Expand Down Expand Up @@ -572,10 +592,13 @@ def _init_cpu_cache(self):
value_need_to_allocate_bytes = self.num_cpu_blocks * cache_item_bytes * value_cache_size
logger.info("Initializing swap space (cpu cache) for all layers.")
if self.cache_dtype == "block_wise_fp8":
cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype())
cache_scales_size = self.key_cache_shape[1] * self.key_cache_shape[2]
scales_key_need_to_allocate_bytes = self.num_cpu_blocks * cache_scales.element_size() * cache_scales_size
scales_value_need_to_allocate_bytes = self.num_cpu_blocks * cache_scales.element_size() * cache_scales_size
_scale_dtype = (
"float32" if envs.FD_ATTENTION_BACKEND == "FLASH_MASK_ATTN_BLACKWELL" else paddle.get_default_dtype()
)
_scale_item_bytes = CacheConfig.get_cache_bytes(_scale_dtype)
cache_scales_size = self.key_cache_shape[1] * self.cache_scale_shape[-1]
scales_key_need_to_allocate_bytes = self.num_cpu_blocks * _scale_item_bytes * cache_scales_size
scales_value_need_to_allocate_bytes = self.num_cpu_blocks * _scale_item_bytes * cache_scales_size
self.k_dst_ptrs = []
self.v_dst_ptrs = []
self.k_scales_ptrs = []
Expand Down
2 changes: 2 additions & 0 deletions fastdeploy/model_executor/forward_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ class ForwardMeta:

real_bsz: int = 0

seq_lens_kv: paddle.Tensor = None

def clear_caches(self):
"""Safely clean up the caches"""
if self.caches:
Expand Down
Loading
Loading