diff --git a/custom_ops/gpu_ops/send_cache.cu b/custom_ops/gpu_ops/send_cache.cu new file mode 100644 index 00000000000..7eaeed68dbb --- /dev/null +++ b/custom_ops/gpu_ops/send_cache.cu @@ -0,0 +1,54 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "helper.h" +#include "paddle/extension.h" +#include "paddle/phi/backends/context_pool.h" +#include "paddle/phi/core/memory/memcpy.h" +#include "remote_cache_kv_ipc.h" + +void SendCacheFunc(const paddle::Tensor& qkv, + const paddle::optional& kv_signal_data) { + const char* fmt_write_cache_completed_signal_str = + std::getenv("FLAGS_fmt_write_cache_completed_signal"); + const char* FLAGS_use_pd_disaggregation_per_chunk = + std::getenv("FLAGS_use_pd_disaggregation_per_chunk"); + if (fmt_write_cache_completed_signal_str && + (std::strcmp(fmt_write_cache_completed_signal_str, "true") == 0 || + std::strcmp(fmt_write_cache_completed_signal_str, "1") == 0)) { + if (FLAGS_use_pd_disaggregation_per_chunk && + (std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "true") == 0 || + std::strcmp(FLAGS_use_pd_disaggregation_per_chunk, "1") == 0)) { + cudaLaunchHostFunc( + qkv.stream(), + &(RemoteCacheKvIpc:: + save_cache_kv_complete_signal_layerwise_per_query), + (void*)nullptr); + } else { + if (kv_signal_data) { + cudaLaunchHostFunc( + qkv.stream(), + &RemoteCacheKvIpc::save_cache_kv_complete_signal_layerwise, + (void*)(const_cast( + kv_signal_data.get().data()))); + } + } + } +} + +PD_BUILD_STATIC_OP(send_cache) + .Inputs({"qkv", paddle::Optional("kv_signal_data")}) + .Outputs({"qkv_out"}) + .SetInplaceMap({{"qkv", "qkv_out"}}) + .SetKernelFn(PD_KERNEL(SendCacheFunc)); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index 1e700f87634..8237870b045 100644 --- a/custom_ops/setup_ops.py +++ b/custom_ops/setup_ops.py @@ -339,6 +339,7 @@ def find_end_files(directory, end_str): "gpu_ops/gelu_tanh.cu", "gpu_ops/reasoning_phase_token_constraint.cu", "gpu_ops/get_attn_mask_q.cu", + "gpu_ops/send_cache.cu", ] sm_versions = get_sm_version(archs) # Some kernels in this file require SM75+ instructions. Exclude them when building SM70 (V100). diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index b934c3e74c7..32c8af537e3 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -226,6 +226,8 @@ def __init__( scale_block_bytes = math.prod(key_cache_scale.shape[1:]) if key_cache_scale.dtype == paddle.bfloat16 or key_cache_scale.dtype == paddle.float16: scale_block_bytes *= 2 + elif key_cache_scale.dtype == paddle.float32: + scale_block_bytes *= 4 logger.info(f"scale_block_bytes: {scale_block_bytes}, dtype: {key_cache_scale.dtype}") logger.info( f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, " @@ -571,6 +573,8 @@ def __init__( scale_block_bytes = math.prod(key_cache_scale.shape[1:]) if key_cache_scale.dtype == paddle.bfloat16 or key_cache_scale.dtype == paddle.float16: scale_block_bytes *= 2 + elif key_cache_scale.dtype == paddle.float32: + scale_block_bytes *= 4 logger.info(f"scale_block_bytes: {scale_block_bytes}, dtype: {key_cache_scale.dtype}") logger.info( f"layers {num_layers} cache_shape: {cache_shape}, max_block_num: {max_block_num}, " @@ -936,6 +940,8 @@ def main(): if args.value_cache_shape: value_cache_shape_list = [int(i) for i in args.value_cache_shape.split(",")] total_gpu_blocks = key_cache_shape_list[0] + # key_cache_shape: [num_blocks, kv_head_num, block_size, head_dim] + block_size = key_cache_shape_list[2] num_extra_layer_gpu_blocks = int(total_gpu_blocks * speculative_config.num_gpu_block_expand_ratio) gpu_cache_kvs = {} gpu_cache_k_tensors = [] @@ -973,10 +979,14 @@ def main(): f"key_caches_{i}_rank{rank}.device{device}", ) if args.cache_dtype == "block_wise_fp8": + _scale_dtype = ( + "float32" if envs.FD_ATTENTION_BACKEND == "FLASH_MASK_ATTN_BLACKWELL" else paddle.get_default_dtype() + ) + _scale_last_dim = 4 if envs.FD_ATTENTION_BACKEND == "FLASH_MASK_ATTN_BLACKWELL" else key_cache_shape[2] gpu_cache_kvs[f"key_cache_scales_{i}_rank{rank}_device{device}"] = paddle.full( - shape=[num_gpu_blocks, key_cache_shape[1], key_cache_shape[2]], + shape=[num_gpu_blocks, key_cache_shape[1], _scale_last_dim], fill_value=0, - dtype=paddle.get_default_dtype(), + dtype=_scale_dtype, ) set_data_ipc( gpu_cache_kvs[f"key_cache_scales_{i}_rank{rank}_device{device}"], @@ -996,9 +1006,9 @@ def main(): ) if args.cache_dtype == "block_wise_fp8": gpu_cache_kvs[f"value_cache_scales_{i}_rank{rank}_device{device}"] = paddle.full( - shape=[num_gpu_blocks, value_cache_shape[1], value_cache_shape[2]], + shape=[num_gpu_blocks, value_cache_shape[1], _scale_last_dim], fill_value=0, - dtype=paddle.get_default_dtype(), + dtype=_scale_dtype, ) set_data_ipc( gpu_cache_kvs[f"value_cache_scales_{i}_rank{rank}_device{device}"], @@ -1023,6 +1033,7 @@ def main(): gpu_id=device, rdma_port=args.rdma_port, cache_dtype=args.cache_dtype, + block_size=block_size, ) else: cache_messager = CacheMessager( @@ -1038,6 +1049,7 @@ def main(): gpu_id=device, rdma_port=args.rdma_port, cache_dtype=args.cache_dtype, + block_size=block_size, ) cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32) diff --git a/fastdeploy/cache_manager/cache_transfer_manager.py b/fastdeploy/cache_manager/cache_transfer_manager.py index 85a113adf66..c06bfa20c5f 100644 --- a/fastdeploy/cache_manager/cache_transfer_manager.py +++ b/fastdeploy/cache_manager/cache_transfer_manager.py @@ -175,10 +175,16 @@ def __init__(self, args): # compute cache bytes self.cache_dtype = args.cache_dtype self.cache_item_bytes = CacheConfig.get_cache_bytes(self.cache_dtype) - self.scale_item_bytes = CacheConfig.get_cache_bytes(paddle.get_default_dtype()) self.has_cache_scale = self.cache_dtype == "block_wise_fp8" if self.has_cache_scale: - self.cache_scale_shape = [self.num_gpu_blocks, self.head_num, self.block_size] + if envs.FD_ATTENTION_BACKEND == "FLASH_MASK_ATTN_BLACKWELL": + self.scale_item_bytes = CacheConfig.get_cache_bytes("float32") + self.cache_scale_shape = [self.num_gpu_blocks, self.head_num, 4] + else: + self.scale_item_bytes = CacheConfig.get_cache_bytes(paddle.get_default_dtype()) + self.cache_scale_shape = [self.num_gpu_blocks, self.head_num, self.block_size] + else: + self.scale_item_bytes = CacheConfig.get_cache_bytes(paddle.get_default_dtype()) # kv cache storage self.storage_backend_type = args.kvcache_storage_backend @@ -416,7 +422,9 @@ def _init_storage_buffer(self, args): self.storage_backend.register_buffer(write_buffer, cache_buffer_total_bytes) if self.has_cache_scale: - self.scale_buffer_stride_bytes = layer_num * self.head_num * self.block_size * self.scale_item_bytes + self.scale_buffer_stride_bytes = ( + layer_num * self.head_num * self.cache_scale_shape[-1] * self.scale_item_bytes + ) scale_buffer_total_bytes = block_num * self.scale_buffer_stride_bytes * 2 logger.info( f"Creating scale cpu buffer cache for all layers: {scale_buffer_total_bytes / 1024 ** 3:.2f}GB" @@ -474,10 +482,16 @@ def _init_gpu_cache(self): set_data_ipc(key_cache, key_name) if self.cache_dtype == "block_wise_fp8": + _scale_dtype = ( + "float32" + if envs.FD_ATTENTION_BACKEND == "FLASH_MASK_ATTN_BLACKWELL" + else paddle.get_default_dtype() + ) + _scale_last_dim = self.cache_scale_shape[-1] key_cache_scales = paddle.full( - shape=[num_gpu_blocks, self.key_cache_shape[1], self.key_cache_shape[2]], + shape=[num_gpu_blocks, self.key_cache_shape[1], _scale_last_dim], fill_value=0, - dtype=paddle.get_default_dtype(), + dtype=_scale_dtype, ) set_data_ipc(key_cache_scales, key_cache_scales_name) if self.value_cache_shape: @@ -486,9 +500,9 @@ def _init_gpu_cache(self): if self.cache_dtype == "block_wise_fp8": value_cache_scales = paddle.full( - shape=[num_gpu_blocks, self.value_cache_shape[1], self.value_cache_shape[2]], + shape=[num_gpu_blocks, self.value_cache_shape[1], _scale_last_dim], fill_value=0, - dtype=paddle.get_default_dtype(), + dtype=_scale_dtype, ) set_data_ipc(value_cache_scales, value_cache_scales_name) else: @@ -497,21 +511,27 @@ def _init_gpu_cache(self): val_cache = paddle.empty(shape=[], dtype=cache_type) key_cache = share_external_data_(key_cache, key_name, key_cache_shape, True) if self.cache_dtype == "block_wise_fp8": - key_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype()) + _scale_dtype = ( + "float32" + if envs.FD_ATTENTION_BACKEND == "FLASH_MASK_ATTN_BLACKWELL" + else paddle.get_default_dtype() + ) + _scale_last_dim = self.cache_scale_shape[-1] + key_cache_scales = paddle.empty(shape=[], dtype=_scale_dtype) key_cache_scales = share_external_data_( key_cache_scales, key_cache_scales_name, - [num_gpu_blocks, self.key_cache_shape[1], self.key_cache_shape[2]], + [num_gpu_blocks, self.key_cache_shape[1], _scale_last_dim], True, ) if self.value_cache_shape: val_cache = share_external_data_(val_cache, val_name, value_cache_shape, True) if self.cache_dtype == "block_wise_fp8": - value_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype()) + value_cache_scales = paddle.empty(shape=[], dtype=_scale_dtype) value_cache_scales = share_external_data_( value_cache_scales, value_cache_scales_name, - [num_gpu_blocks, self.value_cache_shape[1], self.value_cache_shape[2]], + [num_gpu_blocks, self.value_cache_shape[1], _scale_last_dim], True, ) @@ -572,10 +592,13 @@ def _init_cpu_cache(self): value_need_to_allocate_bytes = self.num_cpu_blocks * cache_item_bytes * value_cache_size logger.info("Initializing swap space (cpu cache) for all layers.") if self.cache_dtype == "block_wise_fp8": - cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype()) - cache_scales_size = self.key_cache_shape[1] * self.key_cache_shape[2] - scales_key_need_to_allocate_bytes = self.num_cpu_blocks * cache_scales.element_size() * cache_scales_size - scales_value_need_to_allocate_bytes = self.num_cpu_blocks * cache_scales.element_size() * cache_scales_size + _scale_dtype = ( + "float32" if envs.FD_ATTENTION_BACKEND == "FLASH_MASK_ATTN_BLACKWELL" else paddle.get_default_dtype() + ) + _scale_item_bytes = CacheConfig.get_cache_bytes(_scale_dtype) + cache_scales_size = self.key_cache_shape[1] * self.cache_scale_shape[-1] + scales_key_need_to_allocate_bytes = self.num_cpu_blocks * _scale_item_bytes * cache_scales_size + scales_value_need_to_allocate_bytes = self.num_cpu_blocks * _scale_item_bytes * cache_scales_size self.k_dst_ptrs = [] self.v_dst_ptrs = [] self.k_scales_ptrs = [] diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index 44cf528bed3..05b97b53034 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -164,6 +164,8 @@ class ForwardMeta: real_bsz: int = 0 + seq_lens_kv: paddle.Tensor = None + def clear_caches(self): """Safely clean up the caches""" if self.caches: diff --git a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py index bdb018ec269..abfb4ce5987 100644 --- a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py @@ -48,12 +48,27 @@ from fastdeploy.spec_decode import SpecMethod if current_platform.is_cuda(): - from fastdeploy.model_executor.ops.gpu import merge_prefill_decode_output + from fastdeploy.model_executor.ops.gpu import ( + merge_prefill_decode_output, + send_cache, + ) else: merge_prefill_decode_output = None + send_cache = None +from fastdeploy import envs from fastdeploy.model_executor.utils import get_sm_version +_use_blackwell_attn = envs.FD_ATTENTION_BACKEND == "FLASH_MASK_ATTN_BLACKWELL" + +if _use_blackwell_attn: + try: + import blackwell_ops + except: + assert False, "FLASH_MASK_ATTN_BLACKWELL requires blackwell_ops, but import failed." +else: + blackwell_ops = None + @dataclass class FlashMaskAttentionMetadata(AttentionMetadata): @@ -143,7 +158,45 @@ def get_kv_cache_shape( value_cache_shape = key_cache_shape return key_cache_shape, value_cache_shape + def get_kv_cache_scale_shape(self, max_num_blocks): + if _use_blackwell_attn: + kv_scale_shape = [max_num_blocks, self.kv_num_heads, 4] + kv_cache_scale_dtype = "float32" + else: + kv_scale_shape = [max_num_blocks, self.kv_num_heads, self.block_size] + kv_cache_scale_dtype = paddle.get_default_dtype() + return kv_scale_shape, kv_cache_scale_dtype + + def init_blackwell_attention_metadata(self, forward_meta: ForwardMeta): + self.actual_cu_seq_k = paddle.ones_like(forward_meta.cu_seqlens_k) + max_token_num = blackwell_ops.flash_attn_get_qk_token( + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.cu_seqlens_q, + forward_meta.cu_seqlens_k, + self.actual_cu_seq_k, + forward_meta.seq_lens_kv, + self.kv_num_heads, + )[0] + self.max_enc_len = max_token_num[0] + self.max_dec_len = max_token_num[1] + self.q_token_num = max_token_num[2] + self.kv_token_num = max_token_num[3] + + if self.max_enc_len > 0: + self.q_input_encoder = paddle.zeros( + [self.q_token_num, self.num_heads, self.head_dim], dtype=paddle.bfloat16 + ) + self.k_input_encoder = paddle.zeros( + [self.kv_token_num, self.kv_num_heads, self.head_dim], dtype=paddle.bfloat16 + ) + self.v_input_encoder = paddle.zeros( + [self.kv_token_num, self.kv_num_heads, self.head_dim], dtype=paddle.bfloat16 + ) + def init_attention_metadata(self, forward_meta: ForwardMeta): + if _use_blackwell_attn: + self.init_blackwell_attention_metadata(forward_meta) metadata = FlashMaskAttentionMetadata() # metadata only save pd_disaggregation info. metadata.kv_signal_data_list = [None] * self.num_layers @@ -170,6 +223,97 @@ def init_attention_metadata(self, forward_meta: ForwardMeta): self.attention_metadata = metadata + def forward_mixed_blackwell( + self, + qkv: paddle.Tensor, + layer: Attention, + forward_meta: ForwardMeta, + cache_k: paddle.Tensor, + cache_v: paddle.Tensor, + cache_k_scales: paddle.Tensor, + q_norm_weight: paddle.Tensor, + k_norm_weight: paddle.Tensor, + kv_signal_data: paddle.Tensor, + ): + attn_out = paddle.zeros([qkv.shape[0], self.num_heads * self.head_dim], qkv.dtype) + cache_quant_type_str = getattr(layer, "cache_quant_type_str", "none") + if self.max_enc_len > 0: + blackwell_ops.flash_attn_write_cache_kv_encoder( + qkv, + forward_meta.cu_seqlens_q, + self.actual_cu_seq_k, + forward_meta.rotary_embs, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + cache_k, + cache_v, + forward_meta.block_tables, + self.q_input_encoder, + self.k_input_encoder, + self.v_input_encoder, + cache_k_scales, + q_norm_weight, + k_norm_weight, + self.num_heads, + self.kv_num_heads, + self.head_dim, + self.max_enc_len, + self.max_dec_len, + self.max_seq_len, + cache_quant_type_str, + ) + + send_cache(qkv, kv_signal_data) + blackwell_ops.flash_encoder_attn_fwd( + self.q_input_encoder, + self.k_input_encoder, + self.v_input_encoder, + forward_meta.cu_seqlens_q, + self.actual_cu_seq_k, + attn_out, + None, + ) + if self.max_dec_len > 0: + q_output, q_dequant_scale = blackwell_ops.flash_attn_write_cache_kv_decoder( + qkv, + forward_meta.cu_seqlens_q, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.rotary_embs, + cache_k, + cache_v, + forward_meta.block_tables, + cache_k_scales, + q_norm_weight, + k_norm_weight, + self.num_heads, + self.kv_num_heads, + self.head_dim, + self.max_seq_len, + cache_quant_type_str, + ) + + blackwell_ops.flash_decoder_attn_fwd( + q_output, + forward_meta.cu_seqlens_q, + forward_meta.seq_lens_kv, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + cache_k, + cache_v, + forward_meta.block_tables, + attn_out, + q_dequant_scale, + cache_k_scales, + self.num_heads, + self.kv_num_heads, + self.head_dim, + self.max_seq_len, + self.speculate_max_draft_token_num, + cache_quant_type_str, + ) + return attn_out + def forward_mixed( self, q: paddle.Tensor, @@ -211,7 +355,18 @@ def forward_mixed( os.environ["FLAGS_fmt_write_cache_completed_signal"] = "0" elif forward_meta.tbo_microbatch_id == 1: os.environ["FLAGS_fmt_write_cache_completed_signal"] = "1" - + if _use_blackwell_attn: + return self.forward_mixed_blackwell( + qkv, + layer, + forward_meta, + cache_k, + cache_v, + cache_k_scales, + q_norm_weight, + k_norm_weight, + metadata.kv_signal_data_list[layer.layer_id], + ) if layer.layer_id == 0: get_block_shape_and_split_kv_block( forward_meta.seq_lens_encoder, @@ -373,7 +528,6 @@ def forward_mixed( self.causal, self.speculative_method is not None, ) - if use_fa_do_prefill: merge_prefill_decode_output( res_encoder, diff --git a/fastdeploy/platforms/base.py b/fastdeploy/platforms/base.py index b2eceb0aeb8..79c106b60d7 100644 --- a/fastdeploy/platforms/base.py +++ b/fastdeploy/platforms/base.py @@ -31,6 +31,7 @@ class _Backend(enum.Enum): HPU_ATTN = enum.auto() FLASH_MASK_ATTN = enum.auto() DECODE_UNIFIED_ATTN = enum.auto() + FLASH_MASK_ATTN_BLACKWELL = enum.auto() class Platform: diff --git a/fastdeploy/platforms/cuda.py b/fastdeploy/platforms/cuda.py index de9e40a249b..6fae987097e 100644 --- a/fastdeploy/platforms/cuda.py +++ b/fastdeploy/platforms/cuda.py @@ -76,6 +76,9 @@ def get_attention_backend_cls(cls, selected_backend: _Backend): elif selected_backend == _Backend.DECODE_UNIFIED_ATTN: logger.info("Using DECODE UNIFIED ATTN backend.") return "fastdeploy.model_executor.layers.attention.DecodeUnifiedAttentionBackend" + elif selected_backend == _Backend.FLASH_MASK_ATTN_BLACKWELL: + logger.info("Using FLASH MASK ATTN BLACKWELL backend.") + return "fastdeploy.model_executor.layers.attention.FlashMaskAttentionBackend" else: raise ValueError( "Invalid attention backend you specified.\n" diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 6c5149e253b..d1abbb57827 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -233,6 +233,11 @@ def initialize_kv_cache(self, main_model_num_blocks, profile: bool = False): ) if kv_cache_quant_type == "block_wise_fp8": kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]] + kv_cache_scale_dtype = paddle.get_default_dtype() + if hasattr(self.attn_backends[0], "get_kv_cache_scale_shape"): + kv_cache_scale_shape, kv_cache_scale_dtype = self.attn_backends[0].get_kv_cache_scale_shape( + max_num_blocks=self.num_gpu_blocks + ) local_rank = self.local_rank % self.parallel_config.tensor_parallel_size cache_ready_signal_data = np.zeros(shape=[self.parallel_config.tensor_parallel_size], dtype=np.int32) @@ -284,13 +289,13 @@ def initialize_kv_cache(self, main_model_num_blocks, profile: bool = False): if kv_cache_quant_type == "block_wise_fp8": scale_key_cache_name = f"key_cache_scales_{i}_rank{local_rank}.device{self.device_id}" scale_val_cache_name = f"value_cache_scales_{i}_rank{local_rank}.device{self.device_id}" - key_scale_cache = paddle.empty(shape=[], dtype=paddle.get_default_dtype()) + key_scale_cache = paddle.empty(shape=[], dtype=kv_cache_scale_dtype) key_scale_cache = self._share_external_data( key_scale_cache, scale_key_cache_name, kv_cache_scale_shape ) self.cache_kvs_map[scale_key_cache_name] = key_scale_cache cache_kvs_list.append(key_scale_cache) - value_scale_cache = paddle.empty(shape=[], dtype=paddle.get_default_dtype()) + value_scale_cache = paddle.empty(shape=[], dtype=kv_cache_scale_dtype) value_scale_cache = self._share_external_data( value_scale_cache, scale_val_cache_name, kv_cache_scale_shape ) @@ -329,7 +334,7 @@ def initialize_kv_cache(self, main_model_num_blocks, profile: bool = False): key_cache_scales = paddle.full( shape=kv_cache_scale_shape, fill_value=0, - dtype=paddle.get_default_dtype(), + dtype=kv_cache_scale_dtype, ) key_cache_scales_name = f"key_cache_scales_{i}_rank{local_rank}.device{self.device_id}" set_data_ipc(key_cache_scales, key_cache_scales_name) @@ -339,7 +344,7 @@ def initialize_kv_cache(self, main_model_num_blocks, profile: bool = False): val_cache_scales = paddle.full( shape=kv_cache_scale_shape, fill_value=0, - dtype=paddle.get_default_dtype(), + dtype=kv_cache_scale_dtype, ) val_cache_scales_name = f"value_cache_scales_{i}_rank{local_rank}.device{self.device_id}" set_data_ipc(val_cache_scales, val_cache_scales_name) @@ -702,6 +707,7 @@ def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_ru kv_tile_ids_per_batch=self.model_inputs["kv_tile_ids_per_batch"], kv_num_blocks_x_cpu=self.model_inputs["kv_num_blocks_x_cpu"], attn_mask_offsets=self.model_inputs["attn_mask_offsets"] if self.use_attn_mask_offset else None, + seq_lens_kv=self.model_inputs["seq_lens_kv"], ) if "decode_block_indices" in self.model_inputs: self.forward_meta.decode_block_indices = self.model_inputs["decode_block_indices"] diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 9cbb72636ef..0cc844949ce 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -1420,6 +1420,7 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): kv_num_blocks_x_cpu=self.share_inputs["kv_num_blocks_x_cpu"], attn_mask_offsets=self.share_inputs["attn_mask_offsets"] if self.enable_mm else None, routing_replay_table=routing_replay_table, + seq_lens_kv=self.share_inputs["seq_lens_kv"], ) # Decode attention split ops buffers (assigned after construction due to ForwardMeta __getattr__) @@ -1518,6 +1519,11 @@ def initialize_kv_cache(self, profile: bool = False) -> None: indexer_cache_shape = [] if kv_cache_quant_type == "block_wise_fp8": kv_cache_scale_shape = [key_cache_shape[0], key_cache_shape[1], key_cache_shape[2]] + kv_cache_scale_dtype = paddle.get_default_dtype() + if hasattr(self.attn_backends[0], "get_kv_cache_scale_shape"): + kv_cache_scale_shape, kv_cache_scale_dtype = self.attn_backends[0].get_kv_cache_scale_shape( + max_num_blocks=max_block_num + ) local_rank = self.local_rank % self.parallel_config.tensor_parallel_size # Check if gpu runner needs to create kv cache @@ -1569,13 +1575,13 @@ def initialize_kv_cache(self, profile: bool = False) -> None: cache_kvs_list.extend([key_cache]) if kv_cache_quant_type == "block_wise_fp8": key_cache_scales = paddle.full( - shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype() + shape=kv_cache_scale_shape, fill_value=0, dtype=kv_cache_scale_dtype ) set_data_ipc(key_cache_scales, key_cache_scales_name) self.cache_kvs_map[key_cache_scales_name] = key_cache_scales if value_cache_shape: val_cache_scales = paddle.full( - shape=kv_cache_scale_shape, fill_value=0, dtype=paddle.get_default_dtype() + shape=kv_cache_scale_shape, fill_value=0, dtype=kv_cache_scale_dtype ) set_data_ipc(val_cache_scales, value_cache_scales_name) self.cache_kvs_map[value_cache_scales_name] = val_cache_scales @@ -1590,7 +1596,7 @@ def initialize_kv_cache(self, profile: bool = False) -> None: key_cache = share_external_data(key_cache, key_cache_name, key_cache_shape) self.cache_kvs_map[key_cache_name] = key_cache if kv_cache_quant_type == "block_wise_fp8": - key_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype()) + key_cache_scales = paddle.empty(shape=[], dtype=kv_cache_scale_dtype) key_cache_scales = share_external_data( key_cache_scales, key_cache_scales_name, kv_cache_scale_shape ) @@ -1601,7 +1607,7 @@ def initialize_kv_cache(self, profile: bool = False) -> None: self.cache_kvs_map[val_cache_name] = val_cache cache_kvs_list.extend([key_cache, val_cache]) if kv_cache_quant_type == "block_wise_fp8": - val_cache_scales = paddle.empty(shape=[], dtype=paddle.get_default_dtype()) + val_cache_scales = paddle.empty(shape=[], dtype=kv_cache_scale_dtype) val_cache_scales = share_external_data( val_cache_scales, value_cache_scales_name, kv_cache_scale_shape ) diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index 96b1694c895..a203de535dd 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -385,6 +385,7 @@ def init_share_inputs(self): self.mask_rollback = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") self.preempted_idx = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32", device="cpu") self.last_preempted_idx = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32", device="cpu") + self.seq_lens_kv = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") def swap_states(self, i1, i2) -> None: """Swap the data at indices i1 and i2 for all array-like attributes""" @@ -909,6 +910,9 @@ def init_share_inputs(self): self.recompute_token_num = paddle.full( [self.scheduler_config.max_num_seqs, 1], self.num_model_steps - 1, dtype="int32" ) + + self.seq_lens_kv = paddle.full(shape=[self.scheduler_config.max_num_seqs, 1], fill_value=0, dtype="int32") + # attn_mask if self.enable_mm: self.decode_states = paddle.full(