-
Notifications
You must be signed in to change notification settings - Fork 32
优化ep并行时专家并行权重分发逻辑 #200
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
优化ep并行时专家并行权重分发逻辑 #200
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -151,6 +151,7 @@ def wrap_model(self, model, optimizer=None): | |
| device_type=device_type, | ||
| expert_shard_specs=_collect_ep_expert_shard_specs(model), | ||
| rank_to_ep_rank=_build_rank_to_ep_rank(self.ep_fsdp_device_mesh), | ||
| rank_to_ep_fsdp_rank=_build_rank_to_ep_fsdp_rank(self.ep_fsdp_device_mesh), | ||
| ) | ||
| else: | ||
| _load_rank0_full_state_dict(model, original_sd or {}) | ||
|
|
@@ -462,6 +463,19 @@ def _build_rank_to_ep_rank(ep_fsdp_device_mesh: Optional[TorchDeviceMesh]) -> Di | |
| return rank_to_ep_rank | ||
|
|
||
|
|
||
| def _build_rank_to_ep_fsdp_rank(ep_fsdp_device_mesh: Optional[TorchDeviceMesh]) -> Dict[int, int]: | ||
| if ep_fsdp_device_mesh is None: | ||
| return {} | ||
| mesh = ep_fsdp_device_mesh.mesh | ||
| if hasattr(mesh, 'detach'): | ||
| mesh = mesh.detach().cpu().numpy() | ||
| rank_to_ep_fsdp_rank = {} | ||
| for ep_fsdp_rank in range(mesh.shape[1]): | ||
| for rank in mesh[:, ep_fsdp_rank].flatten().tolist(): | ||
| rank_to_ep_fsdp_rank[int(rank)] = int(ep_fsdp_rank) | ||
| return rank_to_ep_fsdp_rank | ||
|
|
||
|
|
||
| def _find_experts_in_layer(layer_mod: nn.Module, experts_map: Dict[str, nn.Module]) -> Optional[nn.Module]: | ||
| """Find the experts module inside a decoder layer, if any.""" | ||
| for module in layer_mod.modules(): | ||
|
|
@@ -583,6 +597,7 @@ def _broadcast_sharded_state_dict( | |
| device_type: str = 'cuda', | ||
| expert_shard_specs: Optional[Dict[str, Dict[str, int]]] = None, | ||
| rank_to_ep_rank: Optional[Dict[int, int]] = None, | ||
| rank_to_ep_fsdp_rank: Optional[Dict[int, int]] = None, | ||
| ) -> None: | ||
| """Broadcast rank0 full state dict and materialize local FSDP2/EP shards.""" | ||
| from torch.distributed.tensor import DTensor, Partial, Replicate, Shard | ||
|
|
@@ -592,6 +607,7 @@ def _broadcast_sharded_state_dict( | |
| is_rank0 = (dist.get_rank() == 0) | ||
| expert_shard_specs = expert_shard_specs or {} | ||
| rank_to_ep_rank = rank_to_ep_rank or {} | ||
| rank_to_ep_fsdp_rank = rank_to_ep_fsdp_rank or {} | ||
|
|
||
| source_metadata = None | ||
| if is_rank0: | ||
|
|
@@ -629,37 +645,118 @@ def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements): | |
| stride=full_tensor.stride(), | ||
| ) | ||
|
|
||
| def _shard_tensor_for_mesh_rank(tensor, placements, mesh_size: int, mesh_rank: int): | ||
| local_tensor = tensor | ||
| for placement in placements: | ||
| if isinstance(placement, Shard): | ||
| dim = placement.dim | ||
| dim_size = local_tensor.size(dim) | ||
| base = dim_size // mesh_size | ||
| remainder = dim_size % mesh_size | ||
| start = mesh_rank * base + min(mesh_rank, remainder) | ||
| length = base + (1 if mesh_rank < remainder else 0) | ||
| local_tensor = local_tensor.narrow(dim, start, length).contiguous() | ||
| elif isinstance(placement, Replicate): | ||
| continue | ||
| elif isinstance(placement, Partial): | ||
| raise NotImplementedError('Native FSDP2 full-state loading does not support Partial placements.') | ||
| else: | ||
| raise NotImplementedError(f'Unsupported DTensor placement: {placement}') | ||
| return local_tensor | ||
|
|
||
| def _ep_expert_local_shape(logical_shape, placements, mesh_size: int, mesh_rank: int): | ||
| shape = list(logical_shape) | ||
| for placement in placements: | ||
| if isinstance(placement, Shard): | ||
| dim = placement.dim | ||
| dim_size = shape[dim] | ||
| base = dim_size // mesh_size | ||
| remainder = dim_size % mesh_size | ||
| shape[dim] = base + (1 if mesh_rank < remainder else 0) | ||
| elif isinstance(placement, Replicate): | ||
| continue | ||
| elif isinstance(placement, Partial): | ||
| raise NotImplementedError('Native FSDP2 full-state loading does not support Partial placements.') | ||
| else: | ||
| raise NotImplementedError(f'Unsupported DTensor placement: {placement}') | ||
| return tuple(shape) | ||
|
|
||
| def _scatter_ep_expert_tensor(param_name: str, full_tensor, sharded_param): | ||
| spec = expert_shard_specs[param_name] | ||
| experts_per_rank = spec['experts_per_rank'] | ||
| num_experts = spec['num_experts'] | ||
| local_shape = tuple(sharded_param.size()) | ||
| if param_name not in source_metadata: | ||
| raise KeyError(f"Missing source metadata for EP expert parameter '{param_name}'.") | ||
| _, source_dtype = source_metadata[param_name] | ||
| if isinstance(sharded_param, DTensor): | ||
| logical_shape = tuple(sharded_param.size()) | ||
| placements = sharded_param.placements | ||
| device_mesh = sharded_param.device_mesh | ||
| mesh_size = device_mesh.size() | ||
| else: | ||
| logical_shape = tuple(sharded_param.size()) | ||
| placements = (Replicate(), ) | ||
| device_mesh = None | ||
| mesh_size = 1 | ||
|
|
||
| current_rank = dist.get_rank() | ||
| current_mesh_rank = rank_to_ep_fsdp_rank.get(current_rank, 0) | ||
| local_shape = _ep_expert_local_shape(logical_shape, placements, mesh_size, current_mesh_rank) | ||
| local_tensor = torch.empty(local_shape, device=device_type, dtype=source_dtype) | ||
|
|
||
| def _rank_local_shape(rank: int): | ||
| if rank not in rank_to_ep_fsdp_rank: | ||
| raise RuntimeError(f'Missing EP-FSDP rank mapping for global rank {rank}.') | ||
| return _ep_expert_local_shape(logical_shape, placements, mesh_size, rank_to_ep_fsdp_rank[rank]) | ||
|
|
||
| if is_rank0: | ||
| if full_tensor.size(0) != num_experts: | ||
| raise RuntimeError(f"EP expert parameter '{param_name}' expects {num_experts} experts, " | ||
| f'but source state has shape {tuple(full_tensor.shape)}. ' | ||
| 'Rank0 must capture the full pre-EP state_dict before apply_expert_parallel().') | ||
| world_size = dist.get_world_size() | ||
| for rank in range(world_size): | ||
| if 0 not in rank_to_ep_rank or 0 not in rank_to_ep_fsdp_rank: | ||
| raise RuntimeError('Missing EP/EP-FSDP rank mapping for global rank 0.') | ||
| ep_rank = rank_to_ep_rank[0] | ||
| ep_fsdp_rank = rank_to_ep_fsdp_rank[0] | ||
| start = ep_rank * experts_per_rank | ||
| end = start + experts_per_rank | ||
| chunk = full_tensor[start:end].contiguous() | ||
| local_chunk = _shard_tensor_for_mesh_rank(chunk, placements, mesh_size, ep_fsdp_rank) | ||
| local_tensor.copy_(local_chunk.to(device_type)) | ||
| del chunk, local_chunk | ||
|
|
||
| for rank in range(1, world_size): | ||
| if rank not in rank_to_ep_rank: | ||
| raise RuntimeError(f'Missing EP rank mapping for global rank {rank}.') | ||
| ep_rank = rank_to_ep_rank[rank] | ||
| ep_fsdp_rank = rank_to_ep_fsdp_rank[rank] | ||
| start = ep_rank * experts_per_rank | ||
| end = start + experts_per_rank | ||
| chunk = full_tensor[start:end].contiguous() | ||
| chunk_gpu = chunk.to(device_type) | ||
| if rank == 0: | ||
| local_tensor.copy_(chunk_gpu) | ||
| else: | ||
| dist.send(chunk_gpu, dst=rank) | ||
| local_chunk = _shard_tensor_for_mesh_rank(chunk, placements, mesh_size, ep_fsdp_rank) | ||
| chunk_gpu = local_chunk.to(device_type) | ||
| dist.broadcast(chunk_gpu, src=0) | ||
| torch_util.synchronize() | ||
| del chunk, local_chunk, chunk_gpu | ||
| else: | ||
| dist.recv(local_tensor, src=0) | ||
| world_size = dist.get_world_size() | ||
| for rank in range(1, world_size): | ||
| recv_tensor = torch.empty(_rank_local_shape(rank), device=device_type, dtype=source_dtype) | ||
| dist.broadcast(recv_tensor, src=0) | ||
| torch_util.synchronize() | ||
|
Comment on lines
+740
to
+747
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Explicitly calling |
||
| if current_rank == rank: | ||
| local_tensor.copy_(recv_tensor) | ||
| del recv_tensor | ||
|
Comment on lines
+744
to
+750
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The implementation of
Consider using |
||
|
|
||
| if isinstance(sharded_param, DTensor): | ||
| return DTensor.from_local( | ||
| local_tensor, | ||
| device_mesh=device_mesh, | ||
| placements=placements, | ||
| run_check=False, | ||
| shape=logical_shape, | ||
| ) | ||
| return local_tensor | ||
|
|
||
| for param_name, sharded_param in meta_sharded_sd.items(): | ||
|
|
@@ -687,7 +784,10 @@ def _scatter_ep_expert_tensor(param_name: str, full_tensor, sharded_param): | |
| source_shape, device=device_type, dtype=source_dtype) | ||
|
|
||
| if is_ep_expert_param: | ||
| full_tensor = _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param) | ||
| sharded_sd[param_name] = _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param) | ||
| del full_tensor | ||
| torch_util.synchronize() | ||
| continue | ||
| else: | ||
| if tuple(sharded_param.size()) != tuple(source_shape): | ||
| raise RuntimeError(f"Parameter '{param_name}' shape mismatch before broadcast: " | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The PR description states that this change supports using each node's
local rank0as the weight loading source to reduce global rank 0 fan-out pressure. However, the implementation in_broadcast_sharded_state_dictstill relies exclusively on global rank 0 (dist.get_rank() == 0) as the source for all metadata and weight broadcasts. This does not appear to implement the hierarchical loading logic mentioned in the description.