
Optimize expert-parallel weight distribution logic under EP parallelism #200

Draft
meichangsu1 wants to merge 1 commit into modelscope:main from meichangsu1:dsv4_fsdp2_ljl

Conversation

@meichangsu1
Collaborator

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

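Optimizes weight initialization and distribution when EP (expert parallelism) and memory_efficient_init=True are enabled together.

Previously, in multi-node EP + FSDP2 setups, weight initialization relied mainly on global rank0 loading the full weights and distributing the expert weights to every rank. As the number of machines and ranks grows, this global fan-out tends to exhaust HCCL communication resources, trigger broadcast/send timeouts, or make initialization hang.

This PR adjusts the memory efficient init flow for the EP case:

  • When memory_efficient_init=True and EP is enabled, each node's local rank0 can act as the weight loading source for that node.
  • Global rank0 no longer distributes expert weights to all ranks, which lowers cross-node communication pressure.
  • The original rank0 broadcast loading path is kept for the non-EP case.
  • The sharded loading of EP expert parameters after the FSDP2 wrap is optimized, reducing the resource pressure caused by HCCL fan-out during initialization.
  • Non-persistent buffers stay synchronized, so buffers such as RoPE are restored correctly after meta initialization.

This change is mainly intended to improve initialization stability for DeepSeek V4 / large MoE models in multi-node EP + FSDP2 setups.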
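To make the fan-out argument concrete, here is a minimal sketch that counts how many other ranks each weight source has to serve under the two schemes. The helper names `source_rank` and `fan_out` are hypothetical and not taken from this PR, and the sketch assumes ranks are numbered contiguously within each node.

```python
def source_rank(rank: int, local_world_size: int, per_node: bool) -> int:
    """Return the global rank that `rank` receives its weights from."""
    if not per_node:
        return 0  # original path: everything comes from global rank 0
    # First rank on the same node (assumes contiguous rank numbering per node).
    return (rank // local_world_size) * local_world_size


def fan_out(world_size: int, local_world_size: int, per_node: bool) -> dict:
    """Count how many *other* ranks each source rank has to serve."""
    counts = {}
    for rank in range(world_size):
        src = source_rank(rank, local_world_size, per_node)
        if src != rank:
            counts[src] = counts.get(src, 0) + 1
    return counts


if __name__ == "__main__":
    # Hypothetical job: 4 nodes x 8 devices.
    print(fan_out(32, 8, per_node=False))  # {0: 31}
    print(fan_out(32, 8, per_node=True))   # {0: 7, 8: 7, 16: 7, 24: 7}
```

With node-local sources, the number of peers per source is bounded by the node size rather than by the total rank count, and none of the transfers has to cross a node boundary.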

Experiment results

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request updates the expert parallel FSDP strategy to support DTensor sharding during weight loading. While the changes introduce necessary logic for sharding tensors across mesh ranks, the implementation of _scatter_ep_expert_tensor is highly inefficient. Using dist.broadcast in a loop over all ranks creates O(N^2) communication complexity and excessive memory overhead. Furthermore, the inclusion of explicit device synchronizations within these loops will significantly degrade performance. Finally, there is a discrepancy between the PR description's claim of hierarchical loading and the implementation, which still relies solely on global rank 0.

Comment on lines +744 to +750
    for rank in range(1, world_size):
        recv_tensor = torch.empty(_rank_local_shape(rank), device=device_type, dtype=source_dtype)
        dist.broadcast(recv_tensor, src=0)
        torch_util.synchronize()
        if current_rank == rank:
            local_tensor.copy_(recv_tensor)
        del recv_tensor

high

The implementation of _scatter_ep_expert_tensor using dist.broadcast in a loop over all ranks is highly inefficient.

  1. Communication Complexity: This approach results in O(N^2) total communication volume (where N is the world size) because every rank participates in every broadcast, receiving data intended for every other rank and discarding it. For large clusters (e.g., 1024+ GPUs), this will be a massive bottleneck.
  2. Memory Overhead: Every rank performs a GPU allocation (torch.empty) in every iteration of the loop to match the shape of the broadcasted chunk. This adds significant allocation/deallocation overhead.

Consider using dist.scatter if the shards can be prepared on rank 0, or stick to point-to-point dist.send/dist.recv which only moves O(Weight) data in total. If send/recv is unstable, a hierarchical broadcast (to local rank 0s first) would be more appropriate.
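A rough way to quantify this comment is to count shard deliveries, that is, how many times some rank receives a shard. The sketch below models the loop quoted above (one broadcast per non-source rank, each received by every non-source rank) against a point-to-point scheme. The function names are hypothetical and the model ignores how the communication backend actually routes a broadcast.

```python
def broadcast_loop_deliveries(world_size: int) -> int:
    """Shard deliveries when rank 0 broadcasts each rank's shard to everyone."""
    broadcasts = world_size - 1  # one broadcast per non-source rank
    receivers = world_size - 1   # every non-source rank receives each broadcast
    return broadcasts * receivers


def point_to_point_deliveries(world_size: int) -> int:
    """Shard deliveries when rank 0 sends each rank only its own shard."""
    return world_size - 1


if __name__ == "__main__":
    for n in (8, 64, 1024):
        print(n, broadcast_loop_deliveries(n), point_to_point_deliveries(n))
    # 8 49 7
    # 64 3969 63
    # 1024 1046529 1023
```

In this model only one delivery in every `world_size - 1` is actually used by the broadcast loop; the rest are received and discarded.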

Comment on lines 607 to 613
is_rank0 = (dist.get_rank() == 0)
expert_shard_specs = expert_shard_specs or {}
rank_to_ep_rank = rank_to_ep_rank or {}
rank_to_ep_fsdp_rank = rank_to_ep_fsdp_rank or {}

source_metadata = None
if is_rank0:

medium

The PR description states that this change supports using each node's local rank0 as the weight loading source to reduce global rank 0 fan-out pressure. However, the implementation in _broadcast_sharded_state_dict still relies exclusively on global rank 0 (dist.get_rank() == 0) as the source for all metadata and weight broadcasts. This does not appear to implement the hierarchical loading logic mentioned in the description.
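For illustration, a source check that matches the PR description could look roughly like the sketch below. It assumes the job is launched with torchrun, which exports `RANK` and `LOCAL_RANK` as environment variables; `is_weight_source` is a hypothetical helper, not a function from this PR.

```python
import os
from typing import Mapping


def is_weight_source(env: Mapping[str, str], ep_enabled: bool) -> bool:
    """Decide whether this process should load weights from disk.

    With EP enabled, every node's local rank 0 is a source; otherwise only
    global rank 0 is, which mirrors the existing `dist.get_rank() == 0` check.
    """
    rank = int(env.get("RANK", "0"))
    local_rank = int(env.get("LOCAL_RANK", "0"))
    if ep_enabled:
        return local_rank == 0
    return rank == 0


if __name__ == "__main__":
    # Reads the variables exported by the launcher for the current process.
    print(is_weight_source(os.environ, ep_enabled=True))
```

A rank with `RANK=8` and `LOCAL_RANK=0` is a source under this check when EP is enabled, but not under the global rank 0 check.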

Comment on lines +740 to +747
torch_util.synchronize()
del chunk, local_chunk, chunk_gpu
else:
dist.recv(local_tensor, src=0)
world_size = dist.get_world_size()
for rank in range(1, world_size):
recv_tensor = torch.empty(_rank_local_shape(rank), device=device_type, dtype=source_dtype)
dist.broadcast(recv_tensor, src=0)
torch_util.synchronize()

medium

Explicitly calling torch_util.synchronize() (which likely performs a device-wide synchronization) inside the loop after every broadcast will significantly degrade performance. Standard dist.broadcast calls on CUDA/NPU tensors are already synchronized with respect to the communication stream. Draining the GPU command pipeline in every iteration will make the initialization process extremely slow, especially for models with many layers.
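The restructuring this comment points at can be sketched as a loop that issues all broadcasts and synchronizes once at the end. The version below is backend-agnostic: `allocate`, `broadcast`, and `synchronize` are injected callables standing in for `torch.empty`, `dist.broadcast`, and `torch_util.synchronize`, and `receive_shards` is a hypothetical name. Whether a single trailing synchronization is sufficient on HCCL is an assumption the author would need to verify.

```python
from typing import Any, Callable, Optional, Sequence


def receive_shards(
    shapes_by_rank: Sequence[Any],
    current_rank: int,
    allocate: Callable[[Any], Any],
    broadcast: Callable[[Any, int], None],
    synchronize: Callable[[], None],
) -> Optional[Any]:
    """Take part in one broadcast per rank and keep only this rank's shard.

    shapes_by_rank[i] is the shard shape for rank i + 1 (rank 0 is the source).
    """
    local = None
    for rank, shape in enumerate(shapes_by_rank, start=1):
        buf = allocate(shape)
        broadcast(buf, 0)  # every rank participates; the source is rank 0
        if rank == current_rank:
            local = buf    # keep a reference only to our own shard
    synchronize()          # a single device synchronization after the loop
    return local
```

Because the collaborators are passed in, the control flow can be exercised with plain Python fakes before it is wired to the real collectives.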

@meichangsu1 meichangsu1 marked this pull request as draft May 21, 2026 09:33