33 changes: 33 additions & 0 deletions lightllm/distributed/communication_op.py
@@ -189,6 +189,8 @@ def all_reduce(
     op: ReduceOp = ReduceOp.SUM,
     async_op: bool = False,
 ) -> None:
+    if _is_single_group(group=group):
+        return
     if isinstance(group, CustomProcessGroup):
         return group.all_reduce(input_)
     else:
@@ -201,6 +203,9 @@ def all_gather_into_tensor(
     group: Optional[Union[ProcessGroup, CustomProcessGroup]] = None,
     async_op: bool = False,
 ) -> None:
+    if _is_single_group(group=group):
+        output_.copy_(input_)
+        return
     if isinstance(group, CustomProcessGroup):
         return group.all_gather_into_tensor(output_, input_)
     else:
@@ -213,6 +218,10 @@ def all_gather(
     group: Optional[Union[ProcessGroup, CustomProcessGroup]] = None,
     async_op: bool = False,
 ) -> None:
+    if _is_single_group(group=group):
+        if len(output_) > 0:
+            output_[0].copy_(input_)
+        return
     # TODO: custom ops are not supported here yet.
     if isinstance(group, CustomProcessGroup):
         return dist.all_gather(output_, input_, group.device_group, async_op)
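A subtlety in the all_gather fast path above: with a single-rank group, the output_ list holds exactly one tensor, so the gather degenerates to one local copy. A minimal standalone sketch of that behavior (the helper name is hypothetical, not part of this PR):

import torch
from typing import List

def single_rank_all_gather(output_: List[torch.Tensor], input_: torch.Tensor) -> None:
    # With one rank, the "gathered" result is just the local tensor.
    if len(output_) > 0:
        output_[0].copy_(input_)

bufs = [torch.zeros(3)]
single_rank_all_gather(bufs, torch.tensor([1.0, 2.0, 3.0]))
assert torch.equal(bufs[0], torch.tensor([1.0, 2.0, 3.0]))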
@@ -227,11 +236,35 @@ def reduce_scatter_tensor(
     group: Optional[Union[ProcessGroup, CustomProcessGroup]] = None,
     async_op=False,
 ):
+    if _is_single_group(group=group):
+        output.copy_(input)
+        return
     # No custom-op implementation yet.
     if isinstance(group, CustomProcessGroup):
         return dist.reduce_scatter_tensor(output, input, op=op, group=group.device_group, async_op=async_op)
     else:
         return dist.reduce_scatter_tensor(output, input, op=op, group=group, async_op=async_op)


+def broadcast(
+    tensor: torch.Tensor,
+    src: int,
+    group: Optional[Union[ProcessGroup, CustomProcessGroup]] = None,
+    async_op: bool = False,
+) -> None:
+    if _is_single_group(group=group):
+        return
+    if isinstance(group, CustomProcessGroup):
+        return dist.broadcast(tensor, src=src, group=group.device_group, async_op=async_op)
+    else:
+        return dist.broadcast(tensor, src=src, group=group, async_op=async_op)
+
+
+def _is_single_group(group: Optional[Union[ProcessGroup, CustomProcessGroup]]) -> bool:
+    if isinstance(group, CustomProcessGroup):
+        return group.dp_world_size == 1
+    else:
+        return dist.get_world_size(group=group) == 1
+
+
 dist_group_manager = DistributeGroupManager()
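Taken together, these additions let every collective in communication_op.py short-circuit when the group contains a single rank, so single-GPU (or dp_world_size == 1) deployments never enter NCCL at all. A rough sketch of the dispatch pattern, with a plain world-size check standing in for the real _is_single_group/CustomProcessGroup plumbing:

import torch

def all_reduce_sketch(t: torch.Tensor, world_size: int) -> None:
    # One-rank group: the local tensor already equals the reduced result.
    if world_size == 1:
        return
    # A multi-rank path would call torch.distributed.all_reduce(t, ...).
    raise NotImplementedError

x = torch.ones(2)
all_reduce_sketch(x, world_size=1)  # no-op; x is unchanged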
19 changes: 12 additions & 7 deletions lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -31,7 +31,12 @@
     enable_radix_tree_timer_merge,
     get_radix_tree_merge_update_delta,
 )
-from lightllm.distributed import dist_group_manager
+from lightllm.distributed.communication_op import (
+    dist_group_manager,
+    all_gather_into_tensor,
+    all_reduce,
+    broadcast,
+)
 from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer
 from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack
 from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel
@@ -368,7 +373,7 @@ def _try_read_new_reqs_normal(self):
         self.node_broadcast_tensor.fill_(0)
 
         src_rank_id = self.args.node_rank * self.node_world_size
-        dist.broadcast(self.node_broadcast_tensor, src=src_rank_id, group=self.node_nccl_group, async_op=False)
+        broadcast(self.node_broadcast_tensor, src=src_rank_id, group=self.node_nccl_group, async_op=False)
         new_buffer_is_ready = self.node_broadcast_tensor.detach().item()
         if new_buffer_is_ready:
             self._read_reqs_buffer_and_init_reqs()
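The call sites change only in which symbol they invoke; the pattern stays the same: the node's master rank writes a readiness flag into node_broadcast_tensor, broadcasts it, and every rank then reads the flag locally. A simplified sketch of that signaling (broadcast_fn is a stand-in; on one rank it is a no-op, as in the wrapper above):

import torch

def broadcast_fn(tensor: torch.Tensor, src: int, world_size: int) -> None:
    # Single-rank fast path: the tensor already holds the source's value.
    if world_size == 1:
        return
    raise NotImplementedError("a multi-rank path would call dist.broadcast")

flag = torch.zeros(1, dtype=torch.int64)
flag.fill_(1)  # source rank marks the shared buffer as ready
broadcast_fn(flag, src=0, world_size=1)
assert flag.detach().item() == 1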
@@ -382,7 +387,7 @@ def _try_read_new_reqs_normal(self):
         self.node_broadcast_tensor.fill_(0)
 
         src_rank_id = self.args.node_rank * self.node_world_size
-        dist.broadcast(self.node_broadcast_tensor, src=src_rank_id, group=self.node_nccl_group, async_op=False)
+        broadcast(self.node_broadcast_tensor, src=src_rank_id, group=self.node_nccl_group, async_op=False)
         new_buffer_is_ready = self.node_broadcast_tensor.detach().item()
         if new_buffer_is_ready:
             self._read_nixl_trans_io_buffer_and_update_req_status()
@@ -396,7 +401,7 @@ def _try_read_new_reqs_multinode_tp(self):
             self.multinode_tp_gather_item_tensor.fill_(1)
         else:
             self.multinode_tp_gather_item_tensor.fill_(0)
-        dist.all_gather_into_tensor(
+        all_gather_into_tensor(
             self.multinode_tp_all_gather_tensor,
             self.multinode_tp_gather_item_tensor,
             group=self.multinode_tp_nccl_group,
@@ -806,12 +811,12 @@ def _dp_all_gather_prefill_and_decode_req_num(
"""
current_dp_prefill_num = len(prefill_reqs)
self.dp_gather_item_tensor.fill_(current_dp_prefill_num)
dist.all_gather_into_tensor(self.dp_all_gather_tensor, self.dp_gather_item_tensor, group=None, async_op=False)
all_gather_into_tensor(self.dp_all_gather_tensor, self.dp_gather_item_tensor, group=None, async_op=False)
dp_prefill_req_nums = self.dp_all_gather_tensor.cpu().numpy()

current_dp_decode_num = len(decode_reqs)
self.dp_gather_item_tensor.fill_(current_dp_decode_num)
dist.all_gather_into_tensor(self.dp_all_gather_tensor, self.dp_gather_item_tensor, group=None, async_op=False)
all_gather_into_tensor(self.dp_all_gather_tensor, self.dp_gather_item_tensor, group=None, async_op=False)
dp_decode_req_nums = self.dp_all_gather_tensor.cpu().numpy()

return dp_prefill_req_nums, dp_decode_req_nums
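The gather above turns each DP rank's scalar request count into a vector indexed by rank: rank i writes its count into a one-element tensor, and element i of the gathered result then holds rank i's value. A toy illustration with plain tensors (three ranks and their counts are made up):

import torch

# Pretend three DP ranks reported these local prefill counts.
per_rank_counts = [torch.tensor([4]), torch.tensor([0]), torch.tensor([7])]

# all_gather_into_tensor concatenates the one-element tensors in rank order.
dp_prefill_req_nums = torch.cat(per_rank_counts).numpy()
print(dp_prefill_req_nums)  # [4 0 7] -- one entry per DP rank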
@@ -822,7 +827,7 @@ def _dp_all_reduce_decode_req_num(self, decode_reqs: List[InferReq]) -> int:
"""
current_dp_decode_num = len(decode_reqs)
self.dp_reduce_tensor.fill_(current_dp_decode_num)
dist.all_reduce(self.dp_reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False)
all_reduce(self.dp_reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False)
max_decode_num = self.dp_reduce_tensor.item()
return max_decode_num
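ReduceOp.MAX here gives every DP rank the largest decode batch size across ranks, so ranks with fewer (or zero) decode requests can still step in lockstep with the busiest one. In toy form (the per-rank values are illustrative):

import torch

local_decode_nums = torch.tensor([3, 0, 5])  # one value per DP rank
max_decode_num = int(local_decode_nums.max().item())
print(max_decode_num)  # 5 -- every rank runs decode steps up to this count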
