diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py
index c90c6552ad6..d973153b20d 100644
--- a/fastdeploy/model_executor/layers/moe/ep.py
+++ b/fastdeploy/model_executor/layers/moe/ep.py
@@ -542,12 +542,13 @@ def create_deep_ep_buffer(self):
 
 
 class EPPrefillRunner(EPRunner):
+
+    allocate_on_comm_stream = False
+
     """
     EPPrefillRunner
     """
 
-    allocate_on_comm_stream = False
-
     def __init__(
         self,
         top_k: int,
@@ -646,6 +647,7 @@ def combine(
            "async_finish": self.ep_engine.async_finish,
            "topk_weights": recv_topk_weights,
            "previous_event": event,
+            "allocate_on_comm_stream": EPPrefillRunner.allocate_on_comm_stream,
        }
        fused_moe_out, _, event = buffer.combine(**combine_args)
        return fused_moe_out, event
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
index fb2a991f781..12bbc9cd0d8 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
@@ -295,7 +295,6 @@ def apply_ep_prefill(
            token_all_num,
        )
        assert permute_input.shape[0] == token_all_num
-        del recv_x
 
        permute_scale = permute_scale.transpose([1, 0]).contiguous().transpose([1, 0])
 
diff --git a/fastdeploy/worker/tbo.py b/fastdeploy/worker/tbo.py
index 051d0499ada..bcfec83353f 100644
--- a/fastdeploy/worker/tbo.py
+++ b/fastdeploy/worker/tbo.py
@@ -16,6 +16,8 @@
 
 import threading
 
+import paddle
+
 from fastdeploy.model_executor.forward_meta import ForwardMeta
 
 event0 = threading.Event()
@@ -40,31 +42,64 @@ def let_another_thread_run():
    GLOBAL_THREAD_INFO[thread_name][0].clear()
 
 
-def split_batch_decoder_layers(forward_meta: ForwardMeta):
-    split_num = 2
-    real_bs = forward_meta.seq_lens_this_time.shape[0]
+def is_last_thread():
+    thread_name = threading.current_thread().name
 
-    res = [forward_meta] * split_num
+    return thread_name == "thread1"
 
-    if real_bs < split_num or forward_meta.ids_remove_padding.shape[0] == 0:
-        return res
 
-    mc_bs = (real_bs + split_num - 1) // split_num
+def creat_empty_forward_meta(forward_meta: ForwardMeta):
 
-    for i in range(0, split_num):
-        start_bs = i * mc_bs
+    res = ForwardMeta(
+        ids_remove_padding=forward_meta.ids_remove_padding[0:0],
+        rotary_embs=forward_meta.rotary_embs,
+        attn_backend=forward_meta.attn_backend,
+        caches=forward_meta.caches,
+    )
 
-        end_bs = start_bs + mc_bs
-        end_bs = min(end_bs, real_bs)
+    res.hidden_states = forward_meta.hidden_states[0:0]
+    res.decode_states = forward_meta.decode_states[0:0]
 
-        if start_bs >= end_bs:
-            continue
+    return res
 
-        start_token_id = forward_meta.cu_seqlens_q[start_bs].item()
-        end_token_id = forward_meta.cu_seqlens_q[end_bs].item()
 
-        if start_token_id >= end_token_id:
-            continue
+def split_batch_decoder_layers(forward_meta: ForwardMeta, fd_config):
+    split_num = 2
+    res = [creat_empty_forward_meta(forward_meta), forward_meta]
+    res[0].tbo_microbatch_id = 0
+    res[1].tbo_microbatch_id = 1
+    total_token_num = forward_meta.ids_remove_padding.shape[0]
+
+    if total_token_num < 1024:
+        return res
+
+    chunk_token_num = (total_token_num + split_num - 1) // split_num
+
+    split_sections = []
+    for i in range(0, split_num):
+        start_token_id = i * chunk_token_num
+        end_token_id = start_token_id + chunk_token_num
+        end_token_id = min(total_token_num, end_token_id)
+        split_sections.append(end_token_id)
+
+    # For multimodal image understanding, the image tokens must stay together in one chunk,
+    # so split_sections[0] needs to be shifted accordingly.
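+    # The scan below advances the split point past any run of image-patch tokens;
+    # if that run extends to the end of the batch, fall back to the unsplit layout.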
+
+    special_tokens = [
+        fd_config.model_config.image_patch_id,
+    ]
+
+    ids_remove_padding_cpu = forward_meta.ids_remove_padding.numpy().tolist()
+    detect_pos = split_sections[0]
+    while ids_remove_padding_cpu[detect_pos] in special_tokens:
+        detect_pos += 1
+        if detect_pos >= len(ids_remove_padding_cpu):
+            return res
+    split_sections[0] = detect_pos
+
+    for i in range(0, split_num):
+        start_token_id = 0 if i == 0 else split_sections[i - 1]
+        end_token_id = split_sections[i]
 
        res[i] = ForwardMeta(
            ids_remove_padding=None,
@@ -73,34 +108,53 @@ def split_batch_decoder_layers(forward_meta: ForwardMeta):
            caches=forward_meta.caches,
        )
 
+        # The tokens handled by this chunk lie inside the batch range [start_bs, end_bs).
+        start_bs = forward_meta.batch_id_per_token[start_token_id]
+        end_bs = forward_meta.batch_id_per_token[end_token_id - 1]
+        end_bs += 1
+
        if len(forward_meta.rotary_embs.shape) == 6:
            max_bs = forward_meta.rotary_embs.shape[0]
            assert max_bs == forward_meta.block_tables.shape[0]
            assert forward_meta.rotary_embs.shape[1:3] == [2, 1]
            assert forward_meta.rotary_embs.shape[4] == 1
            res[i].rotary_embs = forward_meta.rotary_embs[start_bs:end_bs]
-
+        res[i].block_tables = forward_meta.block_tables[start_bs:end_bs]
        res[i].ids_remove_padding = forward_meta.ids_remove_padding[start_token_id:end_token_id]
        res[i].batch_id_per_token = forward_meta.batch_id_per_token[start_token_id:end_token_id] - start_bs
 
-        res[i].seq_lens_encoder = forward_meta.seq_lens_encoder[start_bs:end_bs]
-        res[i].seq_lens_decoder = forward_meta.seq_lens_decoder[start_bs:end_bs]
-        res[i].seq_lens_this_time = forward_meta.seq_lens_this_time[start_bs:end_bs]
+        # The next three tensors need careful handling; they are easy to get wrong.
+        # Record how many of start_bs's tokens were taken by the left chunk,
+        # and how many of (end_bs - 1)'s tokens were taken by the right chunk.
+        start_bs_s_token_by_left_chunk = start_token_id - forward_meta.cu_seqlens_q[start_bs].item()
+        end_bs_s_token_by_right_chunk = forward_meta.cu_seqlens_q[end_bs].item() - end_token_id
 
-        res[i].block_tables = forward_meta.block_tables[start_bs:end_bs]
+        res[i].seq_lens_this_time = forward_meta.seq_lens_this_time[start_bs:end_bs] + 0
+        res[i].seq_lens_this_time[0] -= start_bs_s_token_by_left_chunk
+        res[i].seq_lens_this_time[-1] -= end_bs_s_token_by_right_chunk
+
+        res[i].seq_lens_encoder = forward_meta.seq_lens_encoder[start_bs:end_bs] + 0
+        if res[i].seq_lens_encoder[0].item() > 0:
+            res[i].seq_lens_encoder[0] -= start_bs_s_token_by_left_chunk
+        if res[i].seq_lens_encoder[-1].item() > 0:
+            res[i].seq_lens_encoder[-1] -= end_bs_s_token_by_right_chunk
+
+        res[i].seq_lens_decoder = forward_meta.seq_lens_decoder[start_bs:end_bs] + 0
+        res[i].seq_lens_decoder[0] += start_bs_s_token_by_left_chunk
+
+        cu_seqlens_q = [0] + paddle.cumsum(res[i].seq_lens_this_time).numpy().tolist()
+        res[i].cu_seqlens_q = paddle.to_tensor(cu_seqlens_q).cast("int32")
 
-        res[i].cu_seqlens_q = forward_meta.cu_seqlens_q[start_bs : end_bs + 1] - start_token_id
-        res[i].cu_seqlens_k = forward_meta.cu_seqlens_k[start_bs : end_bs + 1] - start_token_id
+        # res[i].cu_seqlens_k = res[i].cu_seqlens_q
 
        for key in GLOBAL_ATTN_BUFFERS[i]:
            setattr(res[i], key, GLOBAL_ATTN_BUFFERS[i][key])
 
        if forward_meta.attn_mask_offsets is not None:
            mask_num = forward_meta.attn_mask_offsets.shape[0]
-            token_num = forward_meta.ids_remove_padding.shape[0]
-            if mask_num == token_num * 2:
+            if mask_num == total_token_num * 2:
                res[i].attn_mask_offsets = forward_meta.attn_mask_offsets[start_token_id * 2 : end_token_id * 2]
-            elif mask_num == token_num:
+            elif mask_num == total_token_num:
                res[i].attn_mask_offsets = forward_meta.attn_mask_offsets[start_token_id:end_token_id]
            else:
                assert False, "Invalid attn_mask_offsets shape"
@@ -108,7 +162,8 @@ def split_batch_decoder_layers(forward_meta: ForwardMeta):
        # This is adapt 5.0
        if hasattr(forward_meta, "hidden_states"):
            res[i].hidden_states = forward_meta.hidden_states[start_token_id:end_token_id]
+            # This is actually not needed, since text-only inputs do not use it.
            res[i].decode_states = forward_meta.decode_states[start_bs:end_bs]
 
-        res[i].attn_backend.init_attention_metadata(res[i])
+        res[i].tbo_microbatch_id = i
    return res
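Note on the split bookkeeping: the sketch below is a minimal, self-contained illustration (plain Python, no paddle or ForwardMeta) of how split_batch_decoder_layers adjusts the boundary requests of each token chunk. The helper name toy_chunk_meta and the use of bisect are illustrative assumptions, not part of the patch.

import bisect


def toy_chunk_meta(seq_lens_this_time, cu_seqlens_q, start_token_id, end_token_id):
    # Requests touched by the token range [start_token_id, end_token_id); same result
    # as batch_id_per_token[start_token_id] and batch_id_per_token[end_token_id - 1] + 1.
    start_bs = bisect.bisect_right(cu_seqlens_q, start_token_id) - 1
    end_bs = bisect.bisect_right(cu_seqlens_q, end_token_id - 1)  # exclusive

    # Tokens of the boundary requests that belong to the neighbouring chunks.
    taken_by_left = start_token_id - cu_seqlens_q[start_bs]
    taken_by_right = cu_seqlens_q[end_bs] - end_token_id

    lens = list(seq_lens_this_time[start_bs:end_bs])
    lens[0] -= taken_by_left
    lens[-1] -= taken_by_right

    # Rebuild cumulative offsets for this chunk from the adjusted lengths,
    # mirroring the paddle.cumsum-based recomputation of cu_seqlens_q.
    cu = [0]
    for n in lens:
        cu.append(cu[-1] + n)
    return start_bs, end_bs, lens, cu


# Two requests of 6 and 4 tokens, split after token 7:
print(toy_chunk_meta([6, 4], [0, 6, 10], 0, 7))   # (0, 2, [6, 1], [0, 6, 7])
print(toy_chunk_meta([6, 4], [0, 6, 10], 7, 10))  # (1, 2, [3], [0, 3])

In the patch itself, the right chunk additionally advances seq_lens_decoder[0] by taken_by_left, since those prompt tokens have already been consumed by the left chunk.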