Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions cookbook/transformers/deepseek_v4_flash.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ def train():
ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES,
fsdp_config={
'reshard_after_forward': RESHARD_AFTER_FORWARD,
'expert_parallel': {
'enabled': True,
'router_dtype': 'fp32',
'keep_router_logits': False,
},
'expert_parallel': {
'enabled': True,
'router_dtype': 'fp32',
'keep_router_logits': False,
},
},
)

Expand Down
118 changes: 109 additions & 9 deletions src/twinkle/model/transformers/strategy/native_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def wrap_model(self, model, optimizer=None):
device_type=device_type,
expert_shard_specs=_collect_ep_expert_shard_specs(model),
rank_to_ep_rank=_build_rank_to_ep_rank(self.ep_fsdp_device_mesh),
rank_to_ep_fsdp_rank=_build_rank_to_ep_fsdp_rank(self.ep_fsdp_device_mesh),
)
else:
_load_rank0_full_state_dict(model, original_sd or {})
Expand Down Expand Up @@ -462,6 +463,19 @@ def _build_rank_to_ep_rank(ep_fsdp_device_mesh: Optional[TorchDeviceMesh]) -> Di
return rank_to_ep_rank


def _build_rank_to_ep_fsdp_rank(ep_fsdp_device_mesh: Optional[TorchDeviceMesh]) -> Dict[int, int]:
if ep_fsdp_device_mesh is None:
return {}
mesh = ep_fsdp_device_mesh.mesh
if hasattr(mesh, 'detach'):
mesh = mesh.detach().cpu().numpy()
rank_to_ep_fsdp_rank = {}
for ep_fsdp_rank in range(mesh.shape[1]):
for rank in mesh[:, ep_fsdp_rank].flatten().tolist():
rank_to_ep_fsdp_rank[int(rank)] = int(ep_fsdp_rank)
return rank_to_ep_fsdp_rank


def _find_experts_in_layer(layer_mod: nn.Module, experts_map: Dict[str, nn.Module]) -> Optional[nn.Module]:
"""Find the experts module inside a decoder layer, if any."""
for module in layer_mod.modules():
Expand Down Expand Up @@ -583,6 +597,7 @@ def _broadcast_sharded_state_dict(
device_type: str = 'cuda',
expert_shard_specs: Optional[Dict[str, Dict[str, int]]] = None,
rank_to_ep_rank: Optional[Dict[int, int]] = None,
rank_to_ep_fsdp_rank: Optional[Dict[int, int]] = None,
) -> None:
"""Broadcast rank0 full state dict and materialize local FSDP2/EP shards."""
from torch.distributed.tensor import DTensor, Partial, Replicate, Shard
Expand All @@ -592,6 +607,7 @@ def _broadcast_sharded_state_dict(
is_rank0 = (dist.get_rank() == 0)
expert_shard_specs = expert_shard_specs or {}
rank_to_ep_rank = rank_to_ep_rank or {}
rank_to_ep_fsdp_rank = rank_to_ep_fsdp_rank or {}

source_metadata = None
if is_rank0:
Comment on lines 607 to 613
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The PR description states that this change supports using each node's local rank0 as the weight loading source to reduce global rank 0 fan-out pressure. However, the implementation in _broadcast_sharded_state_dict still relies exclusively on global rank 0 (dist.get_rank() == 0) as the source for all metadata and weight broadcasts. This does not appear to implement the hierarchical loading logic mentioned in the description.

Expand Down Expand Up @@ -629,37 +645,118 @@ def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements):
stride=full_tensor.stride(),
)

def _shard_tensor_for_mesh_rank(tensor, placements, mesh_size: int, mesh_rank: int):
local_tensor = tensor
for placement in placements:
if isinstance(placement, Shard):
dim = placement.dim
dim_size = local_tensor.size(dim)
base = dim_size // mesh_size
remainder = dim_size % mesh_size
start = mesh_rank * base + min(mesh_rank, remainder)
length = base + (1 if mesh_rank < remainder else 0)
local_tensor = local_tensor.narrow(dim, start, length).contiguous()
elif isinstance(placement, Replicate):
continue
elif isinstance(placement, Partial):
raise NotImplementedError('Native FSDP2 full-state loading does not support Partial placements.')
else:
raise NotImplementedError(f'Unsupported DTensor placement: {placement}')
return local_tensor

def _ep_expert_local_shape(logical_shape, placements, mesh_size: int, mesh_rank: int):
shape = list(logical_shape)
for placement in placements:
if isinstance(placement, Shard):
dim = placement.dim
dim_size = shape[dim]
base = dim_size // mesh_size
remainder = dim_size % mesh_size
shape[dim] = base + (1 if mesh_rank < remainder else 0)
elif isinstance(placement, Replicate):
continue
elif isinstance(placement, Partial):
raise NotImplementedError('Native FSDP2 full-state loading does not support Partial placements.')
else:
raise NotImplementedError(f'Unsupported DTensor placement: {placement}')
return tuple(shape)

def _scatter_ep_expert_tensor(param_name: str, full_tensor, sharded_param):
spec = expert_shard_specs[param_name]
experts_per_rank = spec['experts_per_rank']
num_experts = spec['num_experts']
local_shape = tuple(sharded_param.size())
if param_name not in source_metadata:
raise KeyError(f"Missing source metadata for EP expert parameter '{param_name}'.")
_, source_dtype = source_metadata[param_name]
if isinstance(sharded_param, DTensor):
logical_shape = tuple(sharded_param.size())
placements = sharded_param.placements
device_mesh = sharded_param.device_mesh
mesh_size = device_mesh.size()
else:
logical_shape = tuple(sharded_param.size())
placements = (Replicate(), )
device_mesh = None
mesh_size = 1

current_rank = dist.get_rank()
current_mesh_rank = rank_to_ep_fsdp_rank.get(current_rank, 0)
local_shape = _ep_expert_local_shape(logical_shape, placements, mesh_size, current_mesh_rank)
local_tensor = torch.empty(local_shape, device=device_type, dtype=source_dtype)

def _rank_local_shape(rank: int):
if rank not in rank_to_ep_fsdp_rank:
raise RuntimeError(f'Missing EP-FSDP rank mapping for global rank {rank}.')
return _ep_expert_local_shape(logical_shape, placements, mesh_size, rank_to_ep_fsdp_rank[rank])

if is_rank0:
if full_tensor.size(0) != num_experts:
raise RuntimeError(f"EP expert parameter '{param_name}' expects {num_experts} experts, "
f'but source state has shape {tuple(full_tensor.shape)}. '
'Rank0 must capture the full pre-EP state_dict before apply_expert_parallel().')
world_size = dist.get_world_size()
for rank in range(world_size):
if 0 not in rank_to_ep_rank or 0 not in rank_to_ep_fsdp_rank:
raise RuntimeError('Missing EP/EP-FSDP rank mapping for global rank 0.')
ep_rank = rank_to_ep_rank[0]
ep_fsdp_rank = rank_to_ep_fsdp_rank[0]
start = ep_rank * experts_per_rank
end = start + experts_per_rank
chunk = full_tensor[start:end].contiguous()
local_chunk = _shard_tensor_for_mesh_rank(chunk, placements, mesh_size, ep_fsdp_rank)
local_tensor.copy_(local_chunk.to(device_type))
del chunk, local_chunk

for rank in range(1, world_size):
if rank not in rank_to_ep_rank:
raise RuntimeError(f'Missing EP rank mapping for global rank {rank}.')
ep_rank = rank_to_ep_rank[rank]
ep_fsdp_rank = rank_to_ep_fsdp_rank[rank]
start = ep_rank * experts_per_rank
end = start + experts_per_rank
chunk = full_tensor[start:end].contiguous()
chunk_gpu = chunk.to(device_type)
if rank == 0:
local_tensor.copy_(chunk_gpu)
else:
dist.send(chunk_gpu, dst=rank)
local_chunk = _shard_tensor_for_mesh_rank(chunk, placements, mesh_size, ep_fsdp_rank)
chunk_gpu = local_chunk.to(device_type)
dist.broadcast(chunk_gpu, src=0)
torch_util.synchronize()
del chunk, local_chunk, chunk_gpu
else:
dist.recv(local_tensor, src=0)
world_size = dist.get_world_size()
for rank in range(1, world_size):
recv_tensor = torch.empty(_rank_local_shape(rank), device=device_type, dtype=source_dtype)
dist.broadcast(recv_tensor, src=0)
torch_util.synchronize()
Comment on lines +740 to +747
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Explicitly calling torch_util.synchronize() (which likely performs a device-wide synchronization) inside the loop after every broadcast will significantly degrade performance. Standard dist.broadcast calls on CUDA/NPU tensors are already synchronized with respect to the communication stream. Draining the GPU command pipeline in every iteration will make the initialization process extremely slow, especially for models with many layers.

if current_rank == rank:
local_tensor.copy_(recv_tensor)
del recv_tensor
Comment on lines +744 to +750
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The implementation of _scatter_ep_expert_tensor using dist.broadcast in a loop over all ranks is highly inefficient.

  1. Communication Complexity: This approach results in O(N^2) total communication volume (where N is the world size) because every rank participates in every broadcast, receiving data intended for every other rank and discarding it. For large clusters (e.g., 1024+ GPUs), this will be a massive bottleneck.
  2. Memory Overhead: Every rank performs a GPU allocation (torch.empty) in every iteration of the loop to match the shape of the broadcasted chunk. This adds significant allocation/deallocation overhead.

Consider using dist.scatter if the shards can be prepared on rank 0, or stick to point-to-point dist.send/dist.recv which only moves O(Weight) data in total. If send/recv is unstable, a hierarchical broadcast (to local rank 0s first) would be more appropriate.


if isinstance(sharded_param, DTensor):
return DTensor.from_local(
local_tensor,
device_mesh=device_mesh,
placements=placements,
run_check=False,
shape=logical_shape,
)
return local_tensor

for param_name, sharded_param in meta_sharded_sd.items():
Expand Down Expand Up @@ -687,7 +784,10 @@ def _scatter_ep_expert_tensor(param_name: str, full_tensor, sharded_param):
source_shape, device=device_type, dtype=source_dtype)

if is_ep_expert_param:
full_tensor = _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param)
sharded_sd[param_name] = _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param)
del full_tensor
torch_util.synchronize()
continue
else:
if tuple(sharded_param.size()) != tuple(source_shape):
raise RuntimeError(f"Parameter '{param_name}' shape mismatch before broadcast: "
Expand Down
Loading