From 429845591f9154c4b785751a98b692362c3c96a6 Mon Sep 17 00:00:00 2001
From: niushengxiao
Date: Fri, 6 Feb 2026 18:52:52 +0800
Subject: [PATCH 1/2] fix: fix a memleak

---
 lightllm/distributed/communication_op.py     | 26 +++++++++++++++++++
 .../model_infer/mode_backend/base_backend.py | 19 +++++++++-----
 2 files changed, 38 insertions(+), 7 deletions(-)

diff --git a/lightllm/distributed/communication_op.py b/lightllm/distributed/communication_op.py
index d606d757c..43d376abd 100644
--- a/lightllm/distributed/communication_op.py
+++ b/lightllm/distributed/communication_op.py
@@ -189,6 +189,8 @@ def all_reduce(
     op: ReduceOp = ReduceOp.SUM,
     async_op: bool = False,
 ) -> None:
+    if get_global_world_size() == 1:
+        return
     if isinstance(group, CustomProcessGroup):
         return group.all_reduce(input_)
     else:
@@ -201,6 +203,9 @@ def all_gather_into_tensor(
     group: Optional[Union[ProcessGroup, CustomProcessGroup]] = None,
     async_op: bool = False,
 ) -> None:
+    if get_global_world_size() == 1:
+        output_.copy_(input_)
+        return
     if isinstance(group, CustomProcessGroup):
         return group.all_gather_into_tensor(output_, input_)
     else:
@@ -213,6 +218,10 @@ def all_gather(
     group: Optional[Union[ProcessGroup, CustomProcessGroup]] = None,
     async_op: bool = False,
 ) -> None:
+    if get_global_world_size() == 1:
+        if len(output_) > 0:
+            output_[0].copy_(input_)
+        return
     # todo: no custom-op support here yet.
     if isinstance(group, CustomProcessGroup):
         return dist.all_gather(output_, input_, group.device_group, async_op)
@@ -227,6 +236,9 @@ def reduce_scatter_tensor(
     group: Optional[Union[ProcessGroup, CustomProcessGroup]] = None,
     async_op=False,
 ):
+    if get_global_world_size() == 1:
+        output.copy_(input)
+        return
     # No custom-op implementation here yet.
     if isinstance(group, CustomProcessGroup):
         return dist.reduce_scatter_tensor(output, input, op=op, group=group.device_group, async_op=async_op)
@@ -234,4 +246,18 @@ def reduce_scatter_tensor(
     else:
         return dist.reduce_scatter_tensor(output, input, op=op, group=group, async_op=async_op)
 
+def broadcast(
+    tensor: torch.Tensor,
+    src: int,
+    group: Optional[Union[ProcessGroup, CustomProcessGroup]] = None,
+    async_op: bool = False,
+) -> None:
+    if get_global_world_size() == 1:
+        return
+    if isinstance(group, CustomProcessGroup):
+        return dist.broadcast(tensor, src=src, group=group.device_group, async_op=async_op)
+    else:
+        return dist.broadcast(tensor, src=src, group=group, async_op=async_op)
+
+
 dist_group_manager = DistributeGroupManager()
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index 64310d6b0..8b085c45e 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -31,7 +31,12 @@
     enable_radix_tree_timer_merge,
     get_radix_tree_merge_update_delta,
 )
-from lightllm.distributed import dist_group_manager
+from lightllm.distributed.communication_op import (
+    dist_group_manager,
+    all_gather_into_tensor,
+    all_reduce,
+    broadcast,
+)
 from lightllm.server.core.objs.shm_objs_io_buffer import ShmObjsIOBuffer
 from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack
 from lightllm.models.deepseek_mtp.model import Deepseek3MTPModel
@@ -368,7 +373,7 @@ def _try_read_new_reqs_normal(self):
 
         self.node_broadcast_tensor.fill_(0)
         src_rank_id = self.args.node_rank * self.node_world_size
-        dist.broadcast(self.node_broadcast_tensor, src=src_rank_id, group=self.node_nccl_group, async_op=False)
+        broadcast(self.node_broadcast_tensor, src=src_rank_id, group=self.node_nccl_group, async_op=False)
         new_buffer_is_ready = self.node_broadcast_tensor.detach().item()
         if new_buffer_is_ready:
             self._read_reqs_buffer_and_init_reqs()
@@ -382,7 +387,7 @@
 
         self.node_broadcast_tensor.fill_(0)
         src_rank_id = self.args.node_rank * self.node_world_size
-        dist.broadcast(self.node_broadcast_tensor, src=src_rank_id, group=self.node_nccl_group, async_op=False)
+        broadcast(self.node_broadcast_tensor, src=src_rank_id, group=self.node_nccl_group, async_op=False)
         new_buffer_is_ready = self.node_broadcast_tensor.detach().item()
         if new_buffer_is_ready:
             self._read_nixl_trans_io_buffer_and_update_req_status()
@@ -396,7 +401,7 @@ def _try_read_new_reqs_multinode_tp(self):
             self.multinode_tp_gather_item_tensor.fill_(1)
         else:
             self.multinode_tp_gather_item_tensor.fill_(0)
-        dist.all_gather_into_tensor(
+        all_gather_into_tensor(
            self.multinode_tp_all_gather_tensor,
            self.multinode_tp_gather_item_tensor,
            group=self.multinode_tp_nccl_group,
@@ -806,12 +811,12 @@ def _dp_all_gather_prefill_and_decode_req_num(
         """
         current_dp_prefill_num = len(prefill_reqs)
         self.dp_gather_item_tensor.fill_(current_dp_prefill_num)
-        dist.all_gather_into_tensor(self.dp_all_gather_tensor, self.dp_gather_item_tensor, group=None, async_op=False)
+        all_gather_into_tensor(self.dp_all_gather_tensor, self.dp_gather_item_tensor, group=None, async_op=False)
         dp_prefill_req_nums = self.dp_all_gather_tensor.cpu().numpy()
 
         current_dp_decode_num = len(decode_reqs)
         self.dp_gather_item_tensor.fill_(current_dp_decode_num)
-        dist.all_gather_into_tensor(self.dp_all_gather_tensor, self.dp_gather_item_tensor, group=None, async_op=False)
+        all_gather_into_tensor(self.dp_all_gather_tensor, self.dp_gather_item_tensor, group=None, async_op=False)
         dp_decode_req_nums = self.dp_all_gather_tensor.cpu().numpy()
 
         return dp_prefill_req_nums, dp_decode_req_nums
@@ -822,7 +827,7 @@ def _dp_all_reduce_decode_req_num(self, decode_reqs: List[InferReq]) -> int:
         """
         current_dp_decode_num = len(decode_reqs)
         self.dp_reduce_tensor.fill_(current_dp_decode_num)
-        dist.all_reduce(self.dp_reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False)
+        all_reduce(self.dp_reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False)
         max_decode_num = self.dp_reduce_tensor.item()
         return max_decode_num
 

From 6f8fd42955b053c2ecf8ef71dd15c78039ec2179 Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Mon, 9 Feb 2026 02:35:54 +0000
Subject: [PATCH 2/2] add _is_single_group

---
 lightllm/distributed/communication_op.py | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/lightllm/distributed/communication_op.py b/lightllm/distributed/communication_op.py
index 43d376abd..52d4e61da 100644
--- a/lightllm/distributed/communication_op.py
+++ b/lightllm/distributed/communication_op.py
@@ -189,7 +189,7 @@ def all_reduce(
     op: ReduceOp = ReduceOp.SUM,
     async_op: bool = False,
 ) -> None:
-    if get_global_world_size() == 1:
+    if _is_single_group(group=group):
         return
     if isinstance(group, CustomProcessGroup):
         return group.all_reduce(input_)
@@ -203,7 +203,7 @@ def all_gather_into_tensor(
     group: Optional[Union[ProcessGroup, CustomProcessGroup]] = None,
     async_op: bool = False,
 ) -> None:
-    if get_global_world_size() == 1:
+    if _is_single_group(group=group):
         output_.copy_(input_)
         return
     if isinstance(group, CustomProcessGroup):
@@ -218,7 +218,7 @@ def all_gather(
     group: Optional[Union[ProcessGroup, CustomProcessGroup]] = None,
     async_op: bool = False,
 ) -> None:
-    if get_global_world_size() == 1:
+    if _is_single_group(group=group):
         if len(output_) > 0:
             output_[0].copy_(input_)
         return
@@ -236,7 +236,7 @@ def reduce_scatter_tensor(
     group: Optional[Union[ProcessGroup, CustomProcessGroup]] = None,
     async_op=False,
 ):
-    if get_global_world_size() == 1:
+    if _is_single_group(group=group):
         output.copy_(input)
         return
     # No custom-op implementation here yet.
@@ -252,7 +252,7 @@ def broadcast(
     group: Optional[Union[ProcessGroup, CustomProcessGroup]] = None,
     async_op: bool = False,
 ) -> None:
-    if get_global_world_size() == 1:
+    if _is_single_group(group=group):
         return
     if isinstance(group, CustomProcessGroup):
         return dist.broadcast(tensor, src=src, group=group.device_group, async_op=async_op)
@@ -260,4 +260,11 @@ def broadcast(
         return dist.broadcast(tensor, src=src, group=group, async_op=async_op)
 
 
+def _is_single_group(group: Optional[Union[ProcessGroup, CustomProcessGroup]]) -> bool:
+    if isinstance(group, CustomProcessGroup):
+        return group.dp_world_size == 1
+    else:
+        return dist.get_world_size(group=group) == 1
+
+
 dist_group_manager = DistributeGroupManager()
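Below is a minimal, hypothetical sketch (not part of the patch) of the single-rank fast path these wrappers now take: when the group holds only one rank, the collective degenerates into a local copy (or a plain no-op for all_reduce and broadcast), so no real collective call is issued. It uses plain torch.distributed with a one-process gloo group; lightllm's CustomProcessGroup and its dp_world_size branch are omitted, and the names _is_single_group_sketch and all_gather_into_tensor_sketch are invented for illustration.

# Hypothetical sketch: mirrors the single-rank fast path of the wrappers above
# using plain torch.distributed and a one-process gloo group.
import os

import torch
import torch.distributed as dist


def _is_single_group_sketch(group=None) -> bool:
    # The real helper also special-cases lightllm's CustomProcessGroup via its
    # dp_world_size; here we only ask torch.distributed for the group size.
    return dist.get_world_size(group=group) == 1


def all_gather_into_tensor_sketch(output_: torch.Tensor, input_: torch.Tensor, group=None) -> None:
    if _is_single_group_sketch(group):
        # With a single rank the gathered result is just the input, so skip the
        # collective entirely and copy locally.
        output_.copy_(input_)
        return
    dist.all_gather_into_tensor(output_, input_, group=group)


if __name__ == "__main__":
    # One-process "cluster" purely for illustration.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    inp = torch.arange(4, dtype=torch.float32)
    out = torch.empty_like(inp)  # with world_size == 1 the gathered shape equals the input shape
    all_gather_into_tensor_sketch(out, inp)
    assert torch.equal(out, inp)

    dist.destroy_process_group()

Run inside a real multi-rank job, the same wrapper falls through to the actual collective. That is also why the second commit replaces the global get_global_world_size() == 1 check with the per-group _is_single_group (group.dp_world_size for CustomProcessGroup, dist.get_world_size(group=group) otherwise): the fast path is then taken whenever the particular group has a single rank, even if the global world size is larger.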