From afe6dddba96386f5e13ed3bb71c5134e1df7904e Mon Sep 17 00:00:00 2001 From: qq_30035749 Date: Thu, 21 May 2026 16:45:25 +0800 Subject: [PATCH] fix: correct indentation and add rank_to_ep_fsdp_rank mapping for expert parallel FSDP loading --- cookbook/transformers/deepseek_v4_flash.py | 10 +- .../transformers/strategy/native_fsdp.py | 118 ++++++++++++++++-- 2 files changed, 114 insertions(+), 14 deletions(-) diff --git a/cookbook/transformers/deepseek_v4_flash.py b/cookbook/transformers/deepseek_v4_flash.py index 869f4cc8..93d8aefc 100644 --- a/cookbook/transformers/deepseek_v4_flash.py +++ b/cookbook/transformers/deepseek_v4_flash.py @@ -88,11 +88,11 @@ def train(): ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES, fsdp_config={ 'reshard_after_forward': RESHARD_AFTER_FORWARD, - 'expert_parallel': { - 'enabled': True, - 'router_dtype': 'fp32', - 'keep_router_logits': False, - }, + 'expert_parallel': { + 'enabled': True, + 'router_dtype': 'fp32', + 'keep_router_logits': False, + }, }, ) diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index 9e3bbad7..bdebc3c5 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -151,6 +151,7 @@ def wrap_model(self, model, optimizer=None): device_type=device_type, expert_shard_specs=_collect_ep_expert_shard_specs(model), rank_to_ep_rank=_build_rank_to_ep_rank(self.ep_fsdp_device_mesh), + rank_to_ep_fsdp_rank=_build_rank_to_ep_fsdp_rank(self.ep_fsdp_device_mesh), ) else: _load_rank0_full_state_dict(model, original_sd or {}) @@ -462,6 +463,19 @@ def _build_rank_to_ep_rank(ep_fsdp_device_mesh: Optional[TorchDeviceMesh]) -> Di return rank_to_ep_rank +def _build_rank_to_ep_fsdp_rank(ep_fsdp_device_mesh: Optional[TorchDeviceMesh]) -> Dict[int, int]: + if ep_fsdp_device_mesh is None: + return {} + mesh = ep_fsdp_device_mesh.mesh + if hasattr(mesh, 'detach'): + mesh = mesh.detach().cpu().numpy() + rank_to_ep_fsdp_rank = {} + for ep_fsdp_rank in range(mesh.shape[1]): + for rank in mesh[:, ep_fsdp_rank].flatten().tolist(): + rank_to_ep_fsdp_rank[int(rank)] = int(ep_fsdp_rank) + return rank_to_ep_fsdp_rank + + def _find_experts_in_layer(layer_mod: nn.Module, experts_map: Dict[str, nn.Module]) -> Optional[nn.Module]: """Find the experts module inside a decoder layer, if any.""" for module in layer_mod.modules(): @@ -583,6 +597,7 @@ def _broadcast_sharded_state_dict( device_type: str = 'cuda', expert_shard_specs: Optional[Dict[str, Dict[str, int]]] = None, rank_to_ep_rank: Optional[Dict[int, int]] = None, + rank_to_ep_fsdp_rank: Optional[Dict[int, int]] = None, ) -> None: """Broadcast rank0 full state dict and materialize local FSDP2/EP shards.""" from torch.distributed.tensor import DTensor, Partial, Replicate, Shard @@ -592,6 +607,7 @@ def _broadcast_sharded_state_dict( is_rank0 = (dist.get_rank() == 0) expert_shard_specs = expert_shard_specs or {} rank_to_ep_rank = rank_to_ep_rank or {} + rank_to_ep_fsdp_rank = rank_to_ep_fsdp_rank or {} source_metadata = None if is_rank0: @@ -629,37 +645,118 @@ def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements): stride=full_tensor.stride(), ) + def _shard_tensor_for_mesh_rank(tensor, placements, mesh_size: int, mesh_rank: int): + local_tensor = tensor + for placement in placements: + if isinstance(placement, Shard): + dim = placement.dim + dim_size = local_tensor.size(dim) + base = dim_size // mesh_size + remainder = dim_size % mesh_size + start = mesh_rank * base + min(mesh_rank, remainder) + length = base + (1 if mesh_rank < remainder else 0) + local_tensor = local_tensor.narrow(dim, start, length).contiguous() + elif isinstance(placement, Replicate): + continue + elif isinstance(placement, Partial): + raise NotImplementedError('Native FSDP2 full-state loading does not support Partial placements.') + else: + raise NotImplementedError(f'Unsupported DTensor placement: {placement}') + return local_tensor + + def _ep_expert_local_shape(logical_shape, placements, mesh_size: int, mesh_rank: int): + shape = list(logical_shape) + for placement in placements: + if isinstance(placement, Shard): + dim = placement.dim + dim_size = shape[dim] + base = dim_size // mesh_size + remainder = dim_size % mesh_size + shape[dim] = base + (1 if mesh_rank < remainder else 0) + elif isinstance(placement, Replicate): + continue + elif isinstance(placement, Partial): + raise NotImplementedError('Native FSDP2 full-state loading does not support Partial placements.') + else: + raise NotImplementedError(f'Unsupported DTensor placement: {placement}') + return tuple(shape) + def _scatter_ep_expert_tensor(param_name: str, full_tensor, sharded_param): spec = expert_shard_specs[param_name] experts_per_rank = spec['experts_per_rank'] num_experts = spec['num_experts'] - local_shape = tuple(sharded_param.size()) if param_name not in source_metadata: raise KeyError(f"Missing source metadata for EP expert parameter '{param_name}'.") _, source_dtype = source_metadata[param_name] + if isinstance(sharded_param, DTensor): + logical_shape = tuple(sharded_param.size()) + placements = sharded_param.placements + device_mesh = sharded_param.device_mesh + mesh_size = device_mesh.size() + else: + logical_shape = tuple(sharded_param.size()) + placements = (Replicate(), ) + device_mesh = None + mesh_size = 1 + + current_rank = dist.get_rank() + current_mesh_rank = rank_to_ep_fsdp_rank.get(current_rank, 0) + local_shape = _ep_expert_local_shape(logical_shape, placements, mesh_size, current_mesh_rank) local_tensor = torch.empty(local_shape, device=device_type, dtype=source_dtype) + def _rank_local_shape(rank: int): + if rank not in rank_to_ep_fsdp_rank: + raise RuntimeError(f'Missing EP-FSDP rank mapping for global rank {rank}.') + return _ep_expert_local_shape(logical_shape, placements, mesh_size, rank_to_ep_fsdp_rank[rank]) + if is_rank0: if full_tensor.size(0) != num_experts: raise RuntimeError(f"EP expert parameter '{param_name}' expects {num_experts} experts, " f'but source state has shape {tuple(full_tensor.shape)}. ' 'Rank0 must capture the full pre-EP state_dict before apply_expert_parallel().') world_size = dist.get_world_size() - for rank in range(world_size): + if 0 not in rank_to_ep_rank or 0 not in rank_to_ep_fsdp_rank: + raise RuntimeError('Missing EP/EP-FSDP rank mapping for global rank 0.') + ep_rank = rank_to_ep_rank[0] + ep_fsdp_rank = rank_to_ep_fsdp_rank[0] + start = ep_rank * experts_per_rank + end = start + experts_per_rank + chunk = full_tensor[start:end].contiguous() + local_chunk = _shard_tensor_for_mesh_rank(chunk, placements, mesh_size, ep_fsdp_rank) + local_tensor.copy_(local_chunk.to(device_type)) + del chunk, local_chunk + + for rank in range(1, world_size): if rank not in rank_to_ep_rank: raise RuntimeError(f'Missing EP rank mapping for global rank {rank}.') ep_rank = rank_to_ep_rank[rank] + ep_fsdp_rank = rank_to_ep_fsdp_rank[rank] start = ep_rank * experts_per_rank end = start + experts_per_rank chunk = full_tensor[start:end].contiguous() - chunk_gpu = chunk.to(device_type) - if rank == 0: - local_tensor.copy_(chunk_gpu) - else: - dist.send(chunk_gpu, dst=rank) + local_chunk = _shard_tensor_for_mesh_rank(chunk, placements, mesh_size, ep_fsdp_rank) + chunk_gpu = local_chunk.to(device_type) + dist.broadcast(chunk_gpu, src=0) + torch_util.synchronize() + del chunk, local_chunk, chunk_gpu else: - dist.recv(local_tensor, src=0) + world_size = dist.get_world_size() + for rank in range(1, world_size): + recv_tensor = torch.empty(_rank_local_shape(rank), device=device_type, dtype=source_dtype) + dist.broadcast(recv_tensor, src=0) + torch_util.synchronize() + if current_rank == rank: + local_tensor.copy_(recv_tensor) + del recv_tensor + if isinstance(sharded_param, DTensor): + return DTensor.from_local( + local_tensor, + device_mesh=device_mesh, + placements=placements, + run_check=False, + shape=logical_shape, + ) return local_tensor for param_name, sharded_param in meta_sharded_sd.items(): @@ -687,7 +784,10 @@ def _scatter_ep_expert_tensor(param_name: str, full_tensor, sharded_param): source_shape, device=device_type, dtype=source_dtype) if is_ep_expert_param: - full_tensor = _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param) + sharded_sd[param_name] = _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param) + del full_tensor + torch_util.synchronize() + continue else: if tuple(sharded_param.size()) != tuple(source_shape): raise RuntimeError(f"Parameter '{param_name}' shape mismatch before broadcast: "