[Feature] EP support LoRA SFT #198

Open
kevssim wants to merge 42 commits into modelscope:main from meichangsu1:ep_lora

@kevssim (Collaborator) commented May 20, 2026

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

EP support LoRA SFT

meichangsu1 and others added 30 commits May 11, 2026 14:52
- Add DeepseekV4Template class that overrides chat template encoding logic for DeepSeek V4
- Export DeepseekV4Template from template module
- Update documentation to describe the new template class and its purpose
Co-authored-by: Copilot <copilot@github.com>
Remove the `_ep_debug` function and all its calls throughout the expert parallel module to clean up production code. The debug logging was gated behind the `TWINKLE_EP_DEBUG` environment variable but added unnecessary complexity and performance overhead in normal operation. Also remove unused imports (`os`, `time`) and the `_ep_debug_name` attribute assignment on blocks.
The sync_after_backward feature and its associated configuration have been removed as they are no longer needed for maintaining EP/FSDP collective ordering. This simplifies the codebase by eliminating unused synchronization logic.
…ECONDS

Add support for customizing the distributed process group timeout through the environment variable `TWINKLE_DIST_TIMEOUT_SECONDS`, defaulting to 7200 seconds (2 hours). This prevents timeout errors in long-running distributed training jobs.
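
A minimal sketch of how such an override might be read, assuming only what the commit message states (the variable name and the 7200-second default). The helper name `dist_timeout_from_env` is hypothetical and is not taken from this PR; the resulting `timedelta` is the type that `torch.distributed.init_process_group` accepts for its `timeout` argument.

```python
import os
from datetime import timedelta

DEFAULT_TIMEOUT_SECONDS = 7200  # 2 hours, the default named in the commit message


def dist_timeout_from_env(env=None) -> timedelta:
    """Hypothetical helper: read TWINKLE_DIST_TIMEOUT_SECONDS or fall back to the default."""
    env = os.environ if env is None else env
    raw = env.get("TWINKLE_DIST_TIMEOUT_SECONDS", "").strip()
    if not raw:
        return timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)
    seconds = int(raw)  # raises ValueError on non-numeric input
    if seconds <= 0:
        raise ValueError("TWINKLE_DIST_TIMEOUT_SECONDS must be positive")
    return timedelta(seconds=seconds)
```

Rejecting non-positive values is a choice made for this sketch; the PR may handle invalid input differently.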
…ffers

Some Transformers models keep persistent buffers in state_dict that are not DTensors. The original accelerate FSDP2 load function assumed all entries have device_mesh, causing failures. This patch adds a monkey-patch to handle both DTensor parameters and regular tensor buffers during state dict loading.
When loading full state dict in FSDP2, the parameter might not exist in the model (e.g., when using tied weights or shared parameters). This change adds a try-except to handle AttributeError when getting the parameter, and only applies contiguous conversion when the parameter exists and is contiguous.
…e FSDP configuration

Add helper functions for parsing boolean, torch dtype, and lora target modules from environment variables. Introduce new config options for low CPU memory usage, memory efficient init, model dtype, and flexible FSDP/DP/EP sizing. Improve accelerate FSDP2 state dict loading with parameter dtype inference and contiguous casting.
Add debug logging to FSDP2 state dict loading and model preparation, controlled by TWINKLE_FSDP_DEBUG environment variable. Also optimize by returning early with original implementation when running on CUDA devices.
…d state loading traces

- Replace logger.info with print in barrier_if_distributed for immediate output
- Add train_debug function with rank and local_rank info for training step debugging
- Add debug logs for forward_backward and clip_grad_and_step in first 2 steps
- Enhance fsdp2_load_full_state_dict patching with detailed per-parameter state loading traces
- Rename `train_debug` to `debug_print` for clarity
- Add timestamp to debug messages for better tracing
- Support writing debug logs to file via `TWINKLE_DEBUG_DIR` env var
- Improve tensor debug info in FSDP2 state loading
- Remove redundant accelerate FSDP2 delegation path
Co-authored-by: Copilot <copilot@github.com>
- Fix accelerate FSDP2 patch to use `distribute_tensor` for CUDA devices
- Refactor native FSDP to handle rank0-only loading and broadcast
- Add EP expert shard specs collection and rank mapping for state dict broadcast
- Fix non-persistent buffer handling to use broadcast instead of restore
…elism

- Fix indentation error in `base.py` for `device_id` assignment
- Add `_native_fsdp_debug` utility function for FSDP debugging
- Implement pre-EP full state dict capture to avoid redundant state_dict calls
- Add debug logging for expert scatter operations in FSDP
Add `_ep_trace` function and debug logging to MoE expert parallelism operations for better debugging and performance analysis. The trace logs all-to-all and all-gather operations with tensor shapes and split sizes, controlled by `TWINKLE_EP_DEBUG` or `TWINKLE_FSDP_DEBUG` environment variables.
Add a `tag` parameter to `_AllToAll` and `debug_tag` to token permutation functions to enhance traceability of all-to-all operations in both forward and backward passes. This enables better debugging by identifying specific all-to-all calls in trace logs.
- Return `permuted_tokens` instead of `torch.empty_like` in EP forward to maintain backward all-to-all path
- Add `_apply_gate` support for custom gating in MoE expert computation
- Broadcast source metadata (shapes/dtypes) from rank 0 in FSDP to ensure correct dtype for EP expert tensors
- Validate source metadata consistency to prevent silent dtype/shape mismatches during state dict broadcast
Remove the `_ep_trace` debug logging function and all its calls from expert parallelism utilities. This cleanup eliminates verbose debug output controlled by `TWINKLE_EP_DEBUG` environment variable, simplifying the code and removing unnecessary I/O overhead in production. Also removes the `tag` parameter from `_AllToAll` and `all_to_all` functions that was only used for debug tracing.
Move expert parameter tensors to GPU lazily during chunk scattering instead of loading all at once, preventing OOM when handling large expert parameters.

@gemini-code-assist (Bot, Contributor) left a comment

Code Review

This pull request adds support for DeepSeek-V4 and Qwen3.5-MoE models, implementing custom chat-template encoding and Expert Parallel (EP) with LoRA SFT. Key changes include FSDP2 state-dict loading patches in the accelerate strategy, expert parameter broadcasting in NativeFSDPStrategy, and new cookbooks. Feedback highlights high-severity memory pressure risks when using all_gather for expert parameters, suggesting dist.gather instead. Other improvements include addressing the shadowing of the built-in eval function, correcting non-idiomatic static method usage, and making target parameters configurable to support diverse MoE architectures.

Comment thread src/twinkle/model/transformers/strategy/native_fsdp.py
Comment thread cookbook/transformers/deepseek_v4.py
Comment thread src/twinkle/model/transformers/transformers.py Outdated
Comment thread src/twinkle/model/transformers/transformers.py Outdated
Comment thread src/twinkle/model/transformers/transformers.py Outdated
@kevssim marked this pull request as ready for review May 21, 2026 09:20
@kevssim changed the title from "[Feature] EP support LoRA" to "[Feature] EP support LoRA SFT" May 21, 2026

@kevssim (Collaborator, Author) commented May 21, 2026

/gemini review

@gemini-code-assist (Bot, Contributor) left a comment

Code Review

This pull request introduces support for Expert Parallelism (EP) combined with FSDP2 and LoRA, adding dedicated cookbooks for DeepSeek-V4 and Qwen3.5-MoE. Key architectural changes include refactoring NativeFSDPStrategy and AccelerateStrategy to support EP-aware LoRA weight sharding, gathering, and state-dict broadcasting, alongside a new configurable timeout for distributed initialization. Review feedback identifies inconsistencies in the sharding and gathering dimensions for LoRA weights under EP, suggesting they should align with the expert dimension (dimension 0). Additionally, a robustness improvement was suggested for the AccelerateStrategy state-dict loading logic to prevent potential errors when parsing parameter names.

Comment on lines +361 to +364
def _ep_expert_state_dict_gather_dim(name: str) -> int:
    if 'lora_B' in name:
        return 1
    return 0

high

The logic for gathering lora_B on dimension 1 while gathering lora_A and base weights on dimension 0 is inconsistent for Expert Parallelism (EP). In EP, weights are typically sharded along the expert dimension (dimension 0 for both A and B if they are 3D tensors [num_experts, r, in] and [num_experts, out, r]).

If the intent is instead to implement LoRA Rank Parallelism (sharding the rank r for 2D LoRA weights [r, in] and [out, r]), then an all_reduce (sum) is required in the forward pass to combine the partial results from each rank, which appears to be missing in the MoE forward patch.

If these are per-expert LoRA weights, they should be sharded and gathered along the expert dimension (dim 0).
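
The rank-parallel case can be checked with a small numeric sketch. It uses plain Python lists rather than tensors so it runs without dependencies, and all names and values are illustrative, not taken from the PR. Splitting `A` on dim 0 and `B` on dim 1 shards the rank `r`; each rank then produces only a partial LoRA delta, and the partial results must be summed (the role an `all_reduce` would play) to recover the full output.

```python
def matvec(matrix, vector):
    """Multiply a list-of-rows matrix by a vector."""
    return [sum(a * b for a, b in zip(row, vector)) for row in matrix]


def lora_delta(A, B, x):
    """Compute B @ (A @ x), the LoRA update applied to input x."""
    return matvec(B, matvec(A, x))


A = [[1.0, 2.0, 0.0], [0.0, 1.0, 1.0]]  # shape [r=2, in=3]
B = [[1.0, 0.0], [2.0, 1.0]]            # shape [out=2, r=2]
x = [1.0, 1.0, 1.0]

full = lora_delta(A, B, x)

world_size = 2
chunk = len(A) // world_size
partials = []
for rank in range(world_size):
    lo, hi = rank * chunk, (rank + 1) * chunk
    A_shard = A[lo:hi]                    # slice rows of A (dim 0)
    B_shard = [row[lo:hi] for row in B]   # slice columns of B (dim 1)
    partials.append(lora_delta(A_shard, B_shard, x))

# Element-wise sum over ranks stands in for an all_reduce(SUM).
reduced = [sum(values) for values in zip(*partials)]
```

Here no single rank's partial output equals `full`; only `reduced` does, which is why the missing reduction matters if the dim-0/dim-1 split is intentional.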

Suggested change
- def _ep_expert_state_dict_gather_dim(name: str) -> int:
-     if 'lora_B' in name:
-         return 1
-     return 0
+ def _ep_expert_state_dict_gather_dim(name: str) -> int:
+     # For Expert Parallelism, we typically shard along the expert dimension (dim 0).
+     # If LoRA weights are per-expert, they should also be gathered on dim 0.
+     return 0

Comment on lines +722 to +727
if 'lora_A' in model_key:
    chunk = value.size(0) // ep_world_size
    return value.narrow(0, ep_rank * chunk, chunk).contiguous()
if 'lora_B' in model_key:
    chunk = value.size(1) // ep_world_size
    return value.narrow(1, ep_rank * chunk, chunk).contiguous()

high

Similar to the gathering logic, sharding lora_B on dimension 1 while sharding lora_A on dimension 0 is problematic for Expert Parallelism. If the LoRA weights are per-expert, they should both be sliced along the expert dimension (dimension 0). If they are shared and you are sharding the rank r, please verify that the forward pass includes the necessary synchronization (e.g., all_reduce).

Suggested change
- if 'lora_A' in model_key:
-     chunk = value.size(0) // ep_world_size
-     return value.narrow(0, ep_rank * chunk, chunk).contiguous()
- if 'lora_B' in model_key:
-     chunk = value.size(1) // ep_world_size
-     return value.narrow(1, ep_rank * chunk, chunk).contiguous()
+ if 'lora_A' in model_key or 'lora_B' in model_key:
+     chunk = value.size(0) // ep_world_size
+     return value.narrow(0, ep_rank * chunk, chunk).contiguous()
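
For the per-expert case, a dependency-free sketch of slicing and re-gathering along the expert dimension follows. Nested lists stand in for 3D tensors shaped `[num_experts, r, in]` and `[num_experts, out, r]`; `shard_experts` and `gather_experts` are hypothetical names, and the divisibility check is an addition for this sketch (the integer division in the snippet above would otherwise silently drop trailing experts when the count does not divide evenly).

```python
def shard_experts(stacked, ep_rank, ep_world_size):
    """Return the contiguous block of experts owned by ep_rank (slice on dim 0)."""
    num_experts = len(stacked)
    if num_experts % ep_world_size != 0:
        raise ValueError("num_experts must be divisible by ep_world_size")
    chunk = num_experts // ep_world_size
    return stacked[ep_rank * chunk:(ep_rank + 1) * chunk]


def gather_experts(shards):
    """Concatenate per-rank shards back along the expert dimension (dim 0)."""
    return [expert for shard in shards for expert in shard]


# Four experts; each lora_A entry is [r=1, in=2], each lora_B entry is [out=2, r=1].
lora_A = [[[float(e), 0.0]] for e in range(4)]
lora_B = [[[float(e)], [1.0]] for e in range(4)]

ep_world_size = 2
a_shards = [shard_experts(lora_A, rank, ep_world_size) for rank in range(ep_world_size)]
b_shards = [shard_experts(lora_B, rank, ep_world_size) for rank in range(ep_world_size)]
```

With both matrices sliced on dim 0, each rank holds matching `A` and `B` blocks for the same experts, and gathering on dim 0 restores the original stacks.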

except AttributeError:
    # Need this for LoRA, as some params are not registered as
    # parameters/buffers but still appear in the state dict.
    base_param_name, local_param_name = param_name.rsplit('.', 1)

medium

The rsplit('.', 1) call assumes that param_name always contains at least one dot. While this is generally true for state dict keys in Transformers models, it's safer to handle cases where it might not, to avoid a ValueError during unpacking.

Suggested change
- base_param_name, local_param_name = param_name.rsplit('.', 1)
+ if '.' in param_name:
+     base_param_name, local_param_name = param_name.rsplit('.', 1)
+     submodule = model.get_submodule(base_param_name)
+     old_param = getattr(submodule, local_param_name)
+ else:
+     old_param = getattr(model, param_name)
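
The failure mode is easy to reproduce in isolation. The sketch below uses a hypothetical helper, `split_param_name`, to show the guarded split and an equivalent one-liner based on `str.rpartition`, which returns an empty prefix instead of raising when no dot is present.

```python
def split_param_name(param_name):
    """Hypothetical helper: return (base, local); base is '' for a top-level name."""
    if "." in param_name:
        base, local = param_name.rsplit(".", 1)
        return base, local
    return "", param_name


def split_param_name_rpartition(param_name):
    """Same result via rpartition, which never raises on a dot-free name."""
    base, _, local = param_name.rpartition(".")
    return base, local


def unguarded_split_fails(param_name):
    """Return True if the unguarded two-target unpacking raises ValueError."""
    try:
        base, local = param_name.rsplit(".", 1)
    except ValueError:
        return True
    return False
```

An empty base name pairs naturally with `torch.nn.Module.get_submodule`, which returns the module itself for an empty target string, so the two branches of the suggestion could also be collapsed if preferred.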
