From 6b57108a6300c5163a0c138d7bb78aa2988279d8 Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Tue, 28 Apr 2026 23:17:57 +0800
Subject: [PATCH 01/40] docs: add DeepseekV4Template for DeepSeek V4 chat
template encoding
- Add DeepseekV4Template class that overrides chat template encoding logic for DeepSeek V4
- Export DeepseekV4Template from template module
- Update documentation to describe the new template class and its purpose
---
cookbook/transformers/deepseek_v4.py | 124 ++++
.../\346\250\241\346\235\277/Template.md" | 1 +
src/twinkle/template/__init__.py | 1 +
src/twinkle/template/deepseek_v4.py | 138 +++++
src/twinkle/template/deepseek_v4_encoding.py | 570 ++++++++++++++++++
5 files changed, 834 insertions(+)
create mode 100644 cookbook/transformers/deepseek_v4.py
create mode 100644 src/twinkle/template/deepseek_v4.py
create mode 100644 src/twinkle/template/deepseek_v4_encoding.py
diff --git a/cookbook/transformers/deepseek_v4.py b/cookbook/transformers/deepseek_v4.py
new file mode 100644
index 00000000..b7f03a13
--- /dev/null
+++ b/cookbook/transformers/deepseek_v4.py
@@ -0,0 +1,124 @@
+import os
+
+import twinkle
+from peft import LoraConfig
+from transformers import AutoConfig
+from twinkle import DeviceMesh, get_device_placement, get_logger
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.model import TransformersModel
+from twinkle.preprocessor import SelfCognitionProcessor
+
+logger = get_logger()
+
+MODEL_ID = os.environ.get('MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4-flash-bfa16')
+DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition')
+TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template')
+OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output')
+
+_num_layers_env = os.environ.get('NUM_LAYERS')
+NUM_LAYERS = int(_num_layers_env) if _num_layers_env is not None else None
+
+BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '2'))
+GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '2'))
+LR = float(os.environ.get('LR', '1e-4'))
+MAX_STEPS = int(os.environ.get('MAX_STEPS', '0'))
+SAVE_STEPS = int(os.environ.get('SAVE_STEPS', '50'))
+USE_LORA = os.environ.get('USE_LORA', '1') == '1'
+IGNORE_MISMATCHED_SIZES = os.environ.get('IGNORE_MISMATCHED_SIZES', '1') == '1'
+LORA_TARGET_MODULES = os.environ.get(
+ 'LORA_TARGET_MODULES',
+ 'wq_a,wq_b,wkv,wgate,gate_proj,up_proj,down_proj',
+)
+
+device_mesh = DeviceMesh.from_sizes(fsdp_size=2)
+
+twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+
+
+def create_dataset(data_slice=None):
+ dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=data_slice or range(1000)))
+ dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID)
+ dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
+ dataset.encode(batched=True)
+ return dataset
+
+
+def eval(model):
+ dataset = create_dataset(data_slice=range(100))
+ dataloader = DataLoader(dataset=dataset, batch_size=max(1, BATCH_SIZE // 2))
+ for _, batch in enumerate(dataloader):
+ model.forward_only(inputs=batch, adapter_name='default')
+ model.calculate_loss(adapter_name='default')
+ return model.calculate_metric(is_training=False, adapter_name='default')
+
+
+def train():
+ dataset = create_dataset()
+ dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)
+
+ config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
+ if NUM_LAYERS is not None and hasattr(config, 'num_hidden_layers'):
+ config.num_hidden_layers = NUM_LAYERS
+ if hasattr(config, 'use_cache'):
+ config.use_cache = False
+
+ model = TransformersModel(
+ model_id=MODEL_ID,
+ config=config,
+ device_mesh=device_mesh,
+ ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES,
+ )
+
+ if USE_LORA:
+ lora_target_modules = [name.strip() for name in LORA_TARGET_MODULES.split(',') if name.strip()]
+ lora_config = LoraConfig(r=8, lora_alpha=32, target_modules=lora_target_modules)
+ model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
+
+ model.set_template(TEMPLATE_ID, model_id=MODEL_ID, adapter_name='default')
+ model.set_optimizer('AdamW', lr=LR, adapter_name='default')
+ model.set_lr_scheduler(
+ scheduler_cls='CosineWarmupScheduler',
+ num_warmup_steps=5,
+ num_training_steps=len(dataloader),
+ adapter_name='default',
+ )
+
+ logger.info(get_device_placement())
+ logger.info(model.get_train_configs(adapter_name='default'))
+ logger.info(
+ f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, '
+ f'grad_accum={GRAD_ACCUM_STEPS}, lr={LR:.2e}, use_lora={USE_LORA}, '
+ f'num_layers={NUM_LAYERS}, ignore_mismatched_sizes={IGNORE_MISMATCHED_SIZES}, '
+ f'lora_target_modules={LORA_TARGET_MODULES}')
+
+ best_loss = float('inf')
+ for step, batch in enumerate(dataloader):
+ if MAX_STEPS and step >= MAX_STEPS:
+ break
+ model.forward_backward(
+ inputs=batch,
+ adapter_name='default',
+ )
+ model.clip_grad_and_step(
+ adapter_name='default',
+ gradient_accumulation_steps=GRAD_ACCUM_STEPS,
+ )
+
+ if step % 20 == 0:
+ metric = model.calculate_metric(is_training=True, adapter_name='default')
+ logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
+
+ if step > 0 and step % SAVE_STEPS == 0:
+ metrics = eval(model)
+ logger.info(f'Eval metric: {metrics}')
+ loss = float(metrics['loss'])
+ if loss < best_loss:
+ model.save(name=f'checkpoint-{step}', output_dir=OUTPUT_DIR, adapter_name='default')
+ best_loss = loss
+
+ model.save(name='last-checkpoint', output_dir=OUTPUT_DIR, adapter_name='default')
+
+
+if __name__ == '__main__':
+ train()
diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/Template.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/Template.md"
index 6e77b415..364275b6 100644
--- "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/Template.md"
+++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/Template.md"
@@ -52,4 +52,5 @@ class Template:
目前模板关系较为简单:
- Template类:纯文本模型通用
+- DeepseekV4Template类:DeepSeek V4 使用,重写了 chat template 编码逻辑,`encode_messages` 已内置在 twinkle 中
- Qwen3_5Template类:Qwen3.5多模态模型使用
diff --git a/src/twinkle/template/__init__.py b/src/twinkle/template/__init__.py
index 324ce7ac..6c4bdddd 100644
--- a/src/twinkle/template/__init__.py
+++ b/src/twinkle/template/__init__.py
@@ -1,3 +1,4 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from .base import Template
+from .deepseek_v4 import DeepseekV4Template
from .qwen3_5_vl import Qwen3_5Template
diff --git a/src/twinkle/template/deepseek_v4.py b/src/twinkle/template/deepseek_v4.py
new file mode 100644
index 00000000..7b33dd8a
--- /dev/null
+++ b/src/twinkle/template/deepseek_v4.py
@@ -0,0 +1,138 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import copy
+from typing import Any, Literal, Optional
+
+import torch
+from transformers import AutoConfig, PreTrainedTokenizerFast
+
+from twinkle.hub import HubOperation
+
+from .base import Template
+from .deepseek_v4_encoding import encode_messages
+
+
+def get_deepseek_v4_tokenizer(tokenizer):
+ """Wrap a HF tokenizer with DeepSeek V4's custom chat-template encoder."""
+ dsv4_tokenizer = copy.copy(tokenizer)
+
+ added_vocab = tokenizer.get_added_vocab()
+ added_vocab_size = len(added_vocab)
+ tokenizer_vocab_size = tokenizer.vocab_size
+
+ class _DeepseekV4Tokenizer(tokenizer.__class__): # type: ignore[misc, valid-type]
+
+ def apply_chat_template(
+ self,
+ messages,
+ tools: list[dict[str, Any]] | None = None,
+ **kwargs,
+ ):
+ thinking = kwargs.get('thinking', False)
+ enable_thinking = kwargs.get('enable_thinking', False)
+ thinking = thinking or enable_thinking
+ thinking_mode = 'thinking' if thinking else 'chat'
+
+ conversation = kwargs.get('conversation', messages)
+ messages = conversation.copy()
+ if tools:
+ messages.insert(0, {'role': 'system'})
+ messages[0]['tools'] = tools
+
+ reasoning_effort = kwargs.get('reasoning_effort')
+ if reasoning_effort not in ('max', 'high'):
+ reasoning_effort = None
+
+ prompt_str = encode_messages(
+ messages,
+ thinking_mode=thinking_mode,
+ drop_thinking=kwargs.get('drop_thinking', True),
+ reasoning_effort=reasoning_effort,
+ )
+
+ tokenize = kwargs.get('tokenize', True)
+ return_dict = kwargs.get('return_dict', False)
+ return_tensors = kwargs.get('return_tensors')
+
+ if not tokenize:
+ return {'prompt': prompt_str} if return_dict else prompt_str
+
+ tokenizer_kwargs = {
+ key: kwargs[key]
+ for key in ('truncation', 'max_length')
+ if key in kwargs
+ }
+ input_ids = self.encode(
+ prompt_str,
+ add_special_tokens=False,
+ **tokenizer_kwargs,
+ )
+
+ if not return_dict and return_tensors is None:
+ return input_ids
+
+ attention_mask = [1] * len(input_ids)
+ if return_tensors == 'pt':
+ input_ids = torch.tensor([input_ids], dtype=torch.long)
+ attention_mask = torch.tensor([attention_mask], dtype=torch.long)
+ elif return_tensors is not None:
+ raise ValueError(f'Unsupported return_tensors: {return_tensors}')
+
+ encoded = {
+ 'input_ids': input_ids,
+ 'attention_mask': attention_mask,
+ }
+ if kwargs.get('return_assistant_tokens_mask', False):
+ # Fall back to round-by-round labeling in Template by omitting
+ # assistant_masks support from this custom tokenizer wrapper.
+ pass
+ if return_dict:
+ return encoded
+ return encoded['input_ids']
+
+ def num_special_tokens_to_add(self) -> int:
+ return len(self.encode(''))
+
+ def __len__(self) -> int:
+ return tokenizer_vocab_size + added_vocab_size
+
+ def get_added_vocab(self) -> dict[str, int]:
+ return added_vocab.copy()
+
+ _DeepseekV4Tokenizer.__name__ = f'DSV4{tokenizer.__class__.__name__}'
+ dsv4_tokenizer.__class__ = _DeepseekV4Tokenizer
+ return dsv4_tokenizer
+
+
+class DeepseekV4Template(Template):
+
+ def __init__(
+ self,
+ model_id: str,
+ use_chat_template: bool = True,
+ max_length: Optional[int] = 8192,
+ truncation_strategy: Literal['raise', 'left', 'right', 'split'] = 'raise',
+ default_system: Optional[str] = None,
+ enable_thinking: bool = True,
+ **kwargs,
+ ):
+ model_id = HubOperation.download_model(model_id, ignore_model=True)
+ base_tokenizer = PreTrainedTokenizerFast.from_pretrained(model_id, **kwargs)
+ self.processor = get_deepseek_v4_tokenizer(base_tokenizer)
+ self.config = AutoConfig.from_pretrained(model_id, **kwargs)
+
+ self.use_chat_template = use_chat_template
+ self.max_length = max_length
+ self.enable_thinking = enable_thinking
+ self.truncation_strategy = truncation_strategy
+ self.default_system = default_system
+ self._test_support_assistant_tokens_mask()
+ self.pre_pipeline = [
+ self._add_default_system,
+ self._to_standard_reasoning_content,
+ self._build_standard_messages,
+ ]
+ self.post_pipeline = [
+ self._check_max_length,
+ self._add_attention_fields,
+ self._roll_labels,
+ ]
diff --git a/src/twinkle/template/deepseek_v4_encoding.py b/src/twinkle/template/deepseek_v4_encoding.py
new file mode 100644
index 00000000..5f2b06ae
--- /dev/null
+++ b/src/twinkle/template/deepseek_v4_encoding.py
@@ -0,0 +1,570 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+# ruff: noqa
+# fmt: off
+
+"""
+DeepSeek-V4 Encoding
+
+A self-contained implementation for encoding/decoding DeepSeek-V4 chat messages
+with tool calling, thinking mode, and quick instruction task support.
+"""
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+import copy
+import json
+
+import regex as re
+
+bos_token: str = "<|begin▁of▁sentence|>"
+eos_token: str = "<|end▁of▁sentence|>"
+thinking_start_token: str = ""
+thinking_end_token: str = ""
+dsml_token: str = "|DSML|"
+
+USER_SP_TOKEN = "<|User|>"
+ASSISTANT_SP_TOKEN = "<|Assistant|>"
+LATEST_REMINDER_SP_TOKEN = "<|latest_reminder|>"
+
+DS_TASK_SP_TOKENS = {
+ "action": "<|action|>",
+ "query": "<|query|>",
+ "authority": "<|authority|>",
+ "domain": "<|domain|>",
+ "title": "<|title|>",
+ "read_url": "<|read_url|>",
+}
+VALID_TASKS = set(DS_TASK_SP_TOKENS.keys())
+
+system_msg_template: str = "{content}"
+user_msg_template: str = "{content}"
+latest_reminder_msg_template: str = "{content}"
+assistant_msg_template: str = "{reasoning}{content}{tool_calls}" + eos_token
+assistant_msg_wo_eos_template: str = "{reasoning}{content}{tool_calls}"
+thinking_template: str = "{reasoning}"
+
+response_format_template: str = (
+ "## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}"
+)
+tool_call_template: str = (
+ "<{dsml_token}invoke name=\"{name}\">\n{arguments}\n{dsml_token}invoke>"
+)
+tool_calls_template = (
+ "<{dsml_token}{tc_block_name}>\n{tool_calls}\n{dsml_token}{tc_block_name}>"
+)
+tool_calls_block_name: str = "tool_calls"
+
+tool_output_template: str = (
+ "{content}"
+)
+
+REASONING_EFFORT_MAX = (
+ "Reasoning Effort: Absolute maximum with no shortcuts permitted.\n"
+ "You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.\n"
+ "Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.\n\n"
+)
+
+TOOLS_TEMPLATE = """## Tools
+
+You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<{dsml_token}tool_calls>" block like the following:
+
+<{dsml_token}tool_calls>
+<{dsml_token}invoke name="$TOOL_NAME">
+<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE{dsml_token}parameter>
+...
+{dsml_token}invoke>
+<{dsml_token}invoke name="$TOOL_NAME2">
+...
+{dsml_token}invoke>
+{dsml_token}tool_calls>
+
+String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
+
+If thinking_mode is enabled (triggered by {thinking_start_token}), you MUST output your complete reasoning inside {thinking_start_token}...{thinking_end_token} BEFORE any tool calls or final response.
+
+Otherwise, output directly after {thinking_end_token} with tool calls or final response.
+
+### Available Tool Schemas
+
+{tool_schemas}
+
+You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
+"""
+
+
+def to_json(value: Any) -> str:
+ try:
+ return json.dumps(value, ensure_ascii=False)
+ except Exception:
+ return json.dumps(value, ensure_ascii=True)
+
+
+def tools_from_openai_format(tools):
+ return [tool["function"] for tool in tools]
+
+
+def tool_calls_from_openai_format(tool_calls):
+ return [
+ {
+ "name": tool_call["function"]["name"],
+ "arguments": tool_call["function"]["arguments"],
+ }
+ for tool_call in tool_calls
+ ]
+
+
+def tool_calls_to_openai_format(tool_calls):
+ return [
+ {
+ "type": "function",
+ "function": {
+ "name": tool_call["name"],
+ "arguments": tool_call["arguments"],
+ }
+ }
+ for tool_call in tool_calls
+ ]
+
+
+def encode_arguments_to_dsml(tool_call: Dict[str, Any]) -> str:
+ p_dsml_template = '<{dsml_token}parameter name="{key}" string="{is_str}">{value}{dsml_token}parameter>'
+ p_dsml_strs = []
+
+ if isinstance(tool_call["arguments"], str):
+ arguments = json.loads(tool_call["arguments"])
+ else:
+ arguments = tool_call["arguments"]
+
+ for k, v in arguments.items():
+ p_dsml_str = p_dsml_template.format(
+ dsml_token=dsml_token,
+ key=k,
+ is_str="true" if isinstance(v, str) else "false",
+ value=v if isinstance(v, str) else to_json(v),
+ )
+ p_dsml_strs.append(p_dsml_str)
+
+ return "\n".join(p_dsml_strs)
+
+
+def decode_dsml_to_arguments(tool_name: str, tool_args: Dict[str, Tuple[str, str]]) -> Dict[str, str]:
+ def _decode_value(key: str, value: str, string: str):
+ if string == "true":
+ value = to_json(value)
+ return f"{to_json(key)}: {value}"
+
+ tool_args_json = "{" + ", ".join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]) + "}"
+ return dict(name=tool_name, arguments=tool_args_json)
+
+
+def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str:
+ tools_json = [to_json(t) for t in tools]
+
+ return TOOLS_TEMPLATE.format(
+ tool_schemas="\n".join(tools_json),
+ dsml_token=dsml_token,
+ thinking_start_token=thinking_start_token,
+ thinking_end_token=thinking_end_token,
+ )
+
+
+def find_last_user_index(messages: List[Dict[str, Any]]) -> int:
+ last_user_index = -1
+ for idx in range(len(messages) - 1, -1, -1):
+ if messages[idx].get("role") in ["user", "developer"]:
+ last_user_index = idx
+ break
+ return last_user_index
+
+
+def render_message(index: int, messages: List[Dict[str, Any]], thinking_mode: str, drop_thinking: bool = True, reasoning_effort: Optional[str] = None) -> str:
+ assert 0 <= index < len(messages)
+ assert thinking_mode in ["chat", "thinking"], f"Invalid thinking_mode `{thinking_mode}`"
+
+ prompt = ""
+ msg = messages[index]
+ last_user_idx = find_last_user_index(messages)
+
+ role = msg.get("role")
+ content = msg.get("content")
+ tools = msg.get("tools")
+ response_format = msg.get("response_format")
+ tool_calls = msg.get("tool_calls")
+ reasoning = msg.get("reasoning")
+ wo_eos = msg.get("wo_eos", False)
+
+ if tools:
+ tools = tools_from_openai_format(tools)
+ if tool_calls:
+ tool_calls = tool_calls_from_openai_format(tool_calls)
+
+ assert reasoning_effort in ['max', None, 'high'], f"Invalid reasoning effort: {reasoning_effort}"
+ if index == 0 and thinking_mode == "thinking" and reasoning_effort == 'max':
+ prompt += REASONING_EFFORT_MAX
+
+ if role == "system":
+ prompt += system_msg_template.format(content=content or "")
+ if tools:
+ prompt += "\n\n" + render_tools(tools)
+ if response_format:
+ prompt += "\n\n" + response_format_template.format(schema=to_json(response_format))
+
+ elif role == "developer":
+ assert content, f"Invalid message for role `{role}`: {msg}"
+
+ content_developer = USER_SP_TOKEN
+ content_developer += content
+
+ if tools:
+ content_developer += "\n\n" + render_tools(tools)
+ if response_format:
+ content_developer += "\n\n" + response_format_template.format(schema=to_json(response_format))
+
+ prompt += user_msg_template.format(content=content_developer)
+
+ elif role == "user":
+ prompt += USER_SP_TOKEN
+
+ content_blocks = msg.get("content_blocks")
+ if content_blocks:
+ parts = []
+ for block in content_blocks:
+ block_type = block.get("type")
+ if block_type == "text":
+ parts.append(block.get("text", ""))
+ elif block_type == "tool_result":
+ tool_content = block.get("content", "")
+ if isinstance(tool_content, list):
+ text_parts = []
+ for b in tool_content:
+ if b.get("type") == "text":
+ text_parts.append(b.get("text", ""))
+ else:
+ text_parts.append(f"[Unsupported {b.get('type')}]")
+ tool_content = "\n\n".join(text_parts)
+ parts.append(tool_output_template.format(content=tool_content))
+ else:
+ parts.append(f"[Unsupported {block_type}]")
+ prompt += "\n\n".join(parts)
+ else:
+ prompt += content or ""
+
+ elif role == "latest_reminder":
+ prompt += LATEST_REMINDER_SP_TOKEN + latest_reminder_msg_template.format(content=content)
+
+ elif role == "tool":
+ raise NotImplementedError("deepseek_v4 merges tool messages into user; please preprocess with merge_tool_messages()")
+
+ elif role == "assistant":
+ thinking_part = ""
+ tc_content = ""
+
+ if tool_calls:
+ tc_list = [
+ tool_call_template.format(
+ dsml_token=dsml_token,
+ name=tc.get("name"),
+ arguments=encode_arguments_to_dsml(tc)
+ )
+ for tc in tool_calls
+ ]
+ tc_content += '\n\n' + tool_calls_template.format(
+ dsml_token=dsml_token,
+ tool_calls="\n".join(tc_list),
+ tc_block_name=tool_calls_block_name,
+ )
+
+ summary_content = content or ""
+ reasoning = reasoning or ""
+
+ prev_has_task = index - 1 >= 0 and messages[index - 1].get("task") is not None
+
+ if thinking_mode == "thinking" and not prev_has_task:
+ if not drop_thinking or index > last_user_idx:
+ thinking_part = thinking_template.format(reasoning=reasoning) + thinking_end_token
+ else:
+ thinking_part = ""
+
+ if wo_eos:
+ prompt += assistant_msg_wo_eos_template.format(
+ reasoning=thinking_part,
+ content=summary_content,
+ tool_calls=tc_content,
+ )
+ else:
+ prompt += assistant_msg_template.format(
+ reasoning=thinking_part,
+ content=summary_content,
+ tool_calls=tc_content,
+ )
+ else:
+ raise NotImplementedError(f"Unknown role: {role}")
+
+ if index + 1 < len(messages) and messages[index + 1].get("role") not in ["assistant", "latest_reminder"]:
+ return prompt
+
+ task = messages[index].get("task")
+ if task is not None:
+ assert task in VALID_TASKS, f"Invalid task: '{task}'. Valid tasks are: {list(VALID_TASKS)}"
+ task_sp_token = DS_TASK_SP_TOKENS[task]
+
+ if task != "action":
+ prompt += task_sp_token
+ else:
+ prompt += ASSISTANT_SP_TOKEN
+ prompt += thinking_end_token if thinking_mode != "thinking" else thinking_start_token
+ prompt += task_sp_token
+
+ elif messages[index].get("role") in ["user", "developer"]:
+ prompt += ASSISTANT_SP_TOKEN
+ if not drop_thinking and thinking_mode == "thinking":
+ prompt += thinking_start_token
+ elif drop_thinking and thinking_mode == "thinking" and index >= last_user_idx:
+ prompt += thinking_start_token
+ else:
+ prompt += thinking_end_token
+
+ return prompt
+
+
+def merge_tool_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ merged: List[Dict[str, Any]] = []
+
+ for msg in messages:
+ msg = copy.deepcopy(msg)
+ role = msg.get("role")
+
+ if role == "tool":
+ tool_block = {
+ "type": "tool_result",
+ "tool_use_id": msg.get("tool_call_id", ""),
+ "content": msg.get("content", ""),
+ }
+ if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1]:
+ merged[-1]["content_blocks"].append(tool_block)
+ else:
+ merged.append({
+ "role": "user",
+ "content_blocks": [tool_block],
+ })
+ elif role == "user":
+ text_block = {"type": "text", "text": msg.get("content", "")}
+ if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1] and merged[-1].get("task") is None:
+ merged[-1]["content_blocks"].append(text_block)
+ else:
+ new_msg = {
+ "role": "user",
+ "content": msg.get("content", ""),
+ "content_blocks": [text_block],
+ }
+ for key in ("task", "wo_eos", "mask"):
+ if key in msg:
+ new_msg[key] = msg[key]
+ merged.append(new_msg)
+ else:
+ merged.append(msg)
+
+ return merged
+
+
+def sort_tool_results_by_call_order(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ last_tool_call_order: Dict[str, int] = {}
+
+ for msg in messages:
+ role = msg.get("role")
+ if role == "assistant" and msg.get("tool_calls"):
+ last_tool_call_order = {}
+ for idx, tc in enumerate(msg["tool_calls"]):
+ tc_id = tc.get("id") or tc.get("function", {}).get("id", "")
+ if tc_id:
+ last_tool_call_order[tc_id] = idx
+
+ elif role == "user" and msg.get("content_blocks"):
+ tool_blocks = [b for b in msg["content_blocks"] if b.get("type") == "tool_result"]
+ if len(tool_blocks) > 1 and last_tool_call_order:
+ sorted_blocks = sorted(
+ tool_blocks,
+ key=lambda b: last_tool_call_order.get(b.get("tool_use_id", ""), 0)
+ )
+ sorted_idx = 0
+ new_blocks = []
+ for block in msg["content_blocks"]:
+ if block.get("type") == "tool_result":
+ new_blocks.append(sorted_blocks[sorted_idx])
+ sorted_idx += 1
+ else:
+ new_blocks.append(block)
+ msg["content_blocks"] = new_blocks
+
+ return messages
+
+
+def encode_messages(
+ messages: List[Dict[str, Any]],
+ thinking_mode: str,
+ context: Optional[List[Dict[str, Any]]] = None,
+ drop_thinking: bool = True,
+ add_default_bos_token: bool = True,
+ reasoning_effort: Optional[str] = None,
+) -> str:
+ context = context if context else []
+
+ messages = merge_tool_messages(messages)
+ messages = sort_tool_results_by_call_order(context + messages)[len(context):]
+ if context:
+ context = merge_tool_messages(context)
+ context = sort_tool_results_by_call_order(context)
+
+ full_messages = context + messages
+
+ prompt = bos_token if add_default_bos_token and len(context) == 0 else ""
+
+ effective_drop_thinking = drop_thinking
+ if any(m.get("tools") for m in full_messages):
+ effective_drop_thinking = False
+
+ if thinking_mode == "thinking" and effective_drop_thinking:
+ full_messages = _drop_thinking_messages(full_messages)
+ num_to_render = len(full_messages) - len(_drop_thinking_messages(context))
+ context_len = len(full_messages) - num_to_render
+ else:
+ num_to_render = len(messages)
+ context_len = len(context)
+
+ for idx in range(num_to_render):
+ prompt += render_message(
+ idx + context_len,
+ full_messages,
+ thinking_mode=thinking_mode,
+ drop_thinking=effective_drop_thinking,
+ reasoning_effort=reasoning_effort,
+ )
+
+ return prompt
+
+
+def _drop_thinking_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ last_user_idx = find_last_user_index(messages)
+ result = []
+ keep_roles = {"user", "system", "tool", "latest_reminder", "direct_search_results"}
+
+ for idx, msg in enumerate(messages):
+ role = msg.get("role")
+ if role in keep_roles or idx >= last_user_idx:
+ result.append(msg)
+ elif role == "assistant":
+ msg = copy.copy(msg)
+ msg.pop("reasoning", None)
+ result.append(msg)
+
+ return result
+
+
+def _read_until_stop(index: int, text: str, stop: List[str]) -> Tuple[int, str, Optional[str]]:
+ min_pos = len(text)
+ matched_stop = None
+
+ for s in stop:
+ pos = text.find(s, index)
+ if pos != -1 and pos < min_pos:
+ min_pos = pos
+ matched_stop = s
+
+ if matched_stop:
+ content = text[index:min_pos]
+ return min_pos + len(matched_stop), content, matched_stop
+ else:
+ content = text[index:]
+ return len(text), content, None
+
+
+def parse_tool_calls(index: int, text: str) -> Tuple[int, Optional[str], List[Dict[str, str]]]:
+ tool_calls: List[Dict[str, Any]] = []
+ stop_token = None
+ tool_calls_end_token = f"{dsml_token}{tool_calls_block_name}>"
+
+ while index < len(text):
+ index, content_before, stop_token = _read_until_stop(index, text, [f"<{dsml_token}invoke", tool_calls_end_token])
+ if content_before != ">\n":
+ raise ValueError(f"Tool call format error: expected '>\\n' but got '{content_before}'")
+
+ if stop_token == tool_calls_end_token:
+ break
+
+ if stop_token is None:
+ raise ValueError("Missing special token in tool calls")
+
+ index, tool_name_content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"{dsml_token}invoke"])
+
+ p_tool_name = re.findall(r'^\s*name="(.*?)">\n$', tool_name_content, flags=re.DOTALL)
+ if len(p_tool_name) != 1:
+ raise ValueError(f"Tool name format error: '{tool_name_content}'")
+ tool_name = p_tool_name[0]
+
+ tool_args: Dict[str, Tuple[str, str]] = {}
+ while stop_token == f"<{dsml_token}parameter":
+ index, param_content, stop_token = _read_until_stop(index, text, [f"/{dsml_token}parameter"])
+
+ param_kv = re.findall(r'^ name="(.*?)" string="(true|false)">(.*?)<$', param_content, flags=re.DOTALL)
+ if len(param_kv) != 1:
+ raise ValueError(f"Parameter format error: '{param_content}'")
+ param_name, string, param_value = param_kv[0]
+
+ if param_name in tool_args:
+ raise ValueError(f"Duplicate parameter name: '{param_name}'")
+ tool_args[param_name] = (param_value, string)
+
+ index, content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"{dsml_token}invoke"])
+ if content != ">\n":
+ raise ValueError(f"Parameter format error: expected '>\\n' but got '{content}'")
+
+ tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args)
+ tool_calls.append(tool_call)
+
+ return index, stop_token, tool_calls
+
+
+def parse_message_from_completion_text(text: str, thinking_mode: str) -> Dict[str, Any]:
+ summary_content, reasoning = "", ""
+ tool_calls: List[Dict[str, str]] = []
+ index, stop_token = 0, None
+ tool_calls_start_token = f"\n\n<{dsml_token}{tool_calls_block_name}"
+
+ is_thinking = thinking_mode == "thinking"
+ is_tool_calling = False
+
+ if is_thinking:
+ index, content_delta, stop_token = _read_until_stop(index, text, [thinking_end_token, tool_calls_start_token])
+ reasoning = content_delta
+ if stop_token != thinking_end_token:
+ raise ValueError("Invalid thinking format: missing ")
+
+ index, content_delta, stop_token = _read_until_stop(index, text, [eos_token, tool_calls_start_token])
+ summary_content = content_delta
+ if stop_token == tool_calls_start_token:
+ is_tool_calling = True
+ else:
+ if stop_token != eos_token:
+ raise ValueError("Invalid format: missing EOS token")
+
+ if is_tool_calling:
+ index, stop_token, tool_calls = parse_tool_calls(index, text)
+
+ index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token])
+ if tool_ends_text:
+ raise ValueError("Unexpected content after tool calls")
+
+ if len(text) != index or stop_token not in [eos_token, None]:
+ raise ValueError("Unexpected content at end")
+
+ for sp_token in [bos_token, eos_token, thinking_start_token, thinking_end_token, dsml_token]:
+ if sp_token in summary_content or sp_token in reasoning:
+ raise ValueError(f"Unexpected special token '{sp_token}' in content")
+
+ return {
+ "role": "assistant",
+ "content": summary_content,
+ "reasoning": reasoning,
+ "tool_calls": tool_calls_to_openai_format(tool_calls)
+ }
+
+# fmt: on
From a3d201bd11c7d6fc8f34b3b5a935caf8e904813f Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Wed, 29 Apr 2026 15:14:12 +0800
Subject: [PATCH 02/40] wip
---
cookbook/transformers/deepseek_v4.py | 53 ++++++++++-
src/twinkle/model/base.py | 2 +-
.../model/transformers/moe/ep_utils.py | 51 +++++-----
.../model/transformers/moe/expert_parallel.py | 92 ++++++++++++++++---
.../transformers/strategy/native_fsdp.py | 4 +-
.../model/transformers/transformers.py | 35 +++++++
6 files changed, 194 insertions(+), 43 deletions(-)
diff --git a/cookbook/transformers/deepseek_v4.py b/cookbook/transformers/deepseek_v4.py
index b7f03a13..7edfa5a8 100644
--- a/cookbook/transformers/deepseek_v4.py
+++ b/cookbook/transformers/deepseek_v4.py
@@ -3,7 +3,7 @@
import twinkle
from peft import LoraConfig
from transformers import AutoConfig
-from twinkle import DeviceMesh, get_device_placement, get_logger
+from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import TransformersModel
@@ -26,16 +26,47 @@
SAVE_STEPS = int(os.environ.get('SAVE_STEPS', '50'))
USE_LORA = os.environ.get('USE_LORA', '1') == '1'
IGNORE_MISMATCHED_SIZES = os.environ.get('IGNORE_MISMATCHED_SIZES', '1') == '1'
+GRADIENT_CHECKPOINTING = os.environ.get('GRADIENT_CHECKPOINTING', '1') == '1'
+RESHARD_AFTER_FORWARD = os.environ.get('RESHARD_AFTER_FORWARD', '1') == '1'
LORA_TARGET_MODULES = os.environ.get(
'LORA_TARGET_MODULES',
'wq_a,wq_b,wkv,wgate,gate_proj,up_proj,down_proj',
)
-device_mesh = DeviceMesh.from_sizes(fsdp_size=2)
+device_mesh = DeviceMesh.from_sizes(
+ fsdp_size=2,
+ dp_size=1,
+ ep_size=2,
+ device_type=Platform.get_platform().device_prefix(),
+)
twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+def log_expert_parallel_status(model):
+ logger.info(
+ f'EP flags: enabled={getattr(model, "_enable_expert_parallel", None)}, '
+ f'applied={getattr(model, "_expert_parallel_applied", None)}')
+ raw_model = model.strategy.unwrap_model(model.model)
+ found = False
+ for name, module in raw_model.named_modules():
+ if not hasattr(module, '_ep_patched'):
+ continue
+ found = True
+ logger.info(
+ 'EP block %s: patched=%s rank=%s/%s local_experts=[%s, %s) experts_per_rank=%s',
+ name,
+ getattr(module, '_ep_patched', None),
+ getattr(module, '_ep_rank', None),
+ getattr(module, '_ep_world_size', None),
+ getattr(module, '_ep_local_start', None),
+ getattr(module, '_ep_local_end', None),
+ getattr(module, '_ep_experts_per_rank', None),
+ )
+ if not found:
+ logger.info('No EP-patched MoE blocks found on the wrapped model.')
+
+
def create_dataset(data_slice=None):
dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=data_slice or range(1000)))
dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID)
@@ -67,7 +98,16 @@ def train():
model_id=MODEL_ID,
config=config,
device_mesh=device_mesh,
+ strategy="native_fsdp",
ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES,
+ fsdp_config={
+ 'reshard_after_forward': RESHARD_AFTER_FORWARD,
+ 'expert_parallel': {
+ 'enabled': True,
+ 'router_dtype': 'fp32',
+ 'keep_router_logits': False,
+ }
+ },
)
if USE_LORA:
@@ -75,8 +115,11 @@ def train():
lora_config = LoraConfig(r=8, lora_alpha=32, target_modules=lora_target_modules)
model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
+ if not GRADIENT_CHECKPOINTING:
+ model.model.gradient_checkpointing_disable()
+
model.set_template(TEMPLATE_ID, model_id=MODEL_ID, adapter_name='default')
- model.set_optimizer('AdamW', lr=LR, adapter_name='default')
+ model.set_optimizer('AdamW', lr=LR, foreach=False, adapter_name='default')
model.set_lr_scheduler(
scheduler_cls='CosineWarmupScheduler',
num_warmup_steps=5,
@@ -90,6 +133,8 @@ def train():
f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, '
f'grad_accum={GRAD_ACCUM_STEPS}, lr={LR:.2e}, use_lora={USE_LORA}, '
f'num_layers={NUM_LAYERS}, ignore_mismatched_sizes={IGNORE_MISMATCHED_SIZES}, '
+ f'gradient_checkpointing={GRADIENT_CHECKPOINTING}, '
+ f'reshard_after_forward={RESHARD_AFTER_FORWARD}, '
f'lora_target_modules={LORA_TARGET_MODULES}')
best_loss = float('inf')
@@ -104,6 +149,8 @@ def train():
adapter_name='default',
gradient_accumulation_steps=GRAD_ACCUM_STEPS,
)
+ if step == 0:
+ log_expert_parallel_status(model)
if step % 20 == 0:
metric = model.calculate_metric(is_training=True, adapter_name='default')
diff --git a/src/twinkle/model/base.py b/src/twinkle/model/base.py
index 4df53e99..e1267eb3 100644
--- a/src/twinkle/model/base.py
+++ b/src/twinkle/model/base.py
@@ -143,7 +143,7 @@ def upload_to_hub(self,
HubOperation.push_to_hub(repo_id=hub_model_id, folder_path=checkpoint_dir, token=hub_token, private=True)
def _should_bind_device_id_for_process_group(self, backend: str) -> bool:
- return backend in ('nccl', 'hccl')
+ return backend =="hccl"
def _try_init_process_group(self):
import torch
diff --git a/src/twinkle/model/transformers/moe/ep_utils.py b/src/twinkle/model/transformers/moe/ep_utils.py
index 94118448..365fa4eb 100644
--- a/src/twinkle/model/transformers/moe/ep_utils.py
+++ b/src/twinkle/model/transformers/moe/ep_utils.py
@@ -96,24 +96,22 @@ def all_to_all_async(group, input, output_split_size, input_split_size):
# ========================== moe_utils ==========================
-def permute(tokens: torch.Tensor, routing_map: torch.Tensor):
+def permute(tokens: torch.Tensor, expert_mask: torch.Tensor):
"""
Permutes the tokens according to the routing map.
Args:
tokens (torch.Tensor): The input token tensor, [num_tokens, hidden_dim].
- routing_map (torch.Tensor): The sparse token to expert mapping, [num_experts, tokens].
+ expert_mask (torch.Tensor): The sparse token to expert mapping, [num_experts, top_k, num_tokens].
"""
num_tokens, _ = tokens.shape
- num_experts = routing_map.shape[0]
-
- # mask [num_tokens, num_experts] -> [num_experts, num_tokens]
- routing_map = routing_map.bool()
+ num_experts = expert_mask.shape[0]
+ expert_mask = expert_mask.bool()
# Create a dense expert-to-token mapping from the sparse token-to-expert mapping
- token_indices = torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1)
- sorted_indices = token_indices.masked_select(routing_map)
+ token_indices = torch.arange(num_tokens, device=expert_mask.device).view(1, 1, num_tokens).expand_as(expert_mask)
+ sorted_indices = token_indices.masked_select(expert_mask)
# use the mapping to permute the tokens
permuted_input = tokens.index_select(0, sorted_indices)
@@ -226,6 +224,7 @@ def preprocess(
def token_pre_all2all(
hidden_states: torch.Tensor,
expert_mask: torch.Tensor,
+ routing_weights: torch.Tensor,
num_experts: int,
input_splits: torch.Tensor,
output_splits: torch.Tensor,
@@ -235,9 +234,10 @@ def token_pre_all2all(
hidden_dim = hidden_states.size(-1)
hidden_states = hidden_states.reshape(-1, hidden_dim)
org_hidden_states_shape = hidden_states.shape
- routing_map = expert_mask.sum(dim=1)
+ routing_map = expert_mask.sum(dim=1).bool()
- local_permuted_hidden_states, local_input_permutation_mapping = permute(hidden_states, routing_map)
+ local_permuted_hidden_states, local_input_permutation_mapping = permute(hidden_states, expert_mask)
+ local_assignment_weights = routing_weights.T.contiguous().masked_select(expert_mask.bool())
global_permuted_hidden_states = all_to_all(ep_group, local_permuted_hidden_states, output_splits, input_splits)
@@ -250,18 +250,21 @@ def token_pre_all2all(
permute_order,
)
- return global_permuted_hidden_states, routing_map, local_input_permutation_mapping, org_hidden_states_shape
+ return (
+ global_permuted_hidden_states,
+ local_input_permutation_mapping,
+ local_assignment_weights,
+ org_hidden_states_shape,
+ )
def tokens_post_all2all(
expert_outputs: torch.Tensor,
- routing_weights: torch.Tensor,
- selected_experts: torch.Tensor,
+ local_assignment_weights: torch.Tensor,
num_experts: int,
input_splits: torch.Tensor,
output_splits: torch.Tensor,
num_global_tokens_per_local_expert: torch.Tensor,
- routing_map: torch.Tensor,
local_input_permutation_mapping: torch.Tensor,
org_hidden_states_shape: torch.Size,
ep_group: Optional[dist.ProcessGroup] = None,
@@ -276,16 +279,12 @@ def tokens_post_all2all(
)
unpermute_outputs = all_to_all(ep_group, expert_outputs, input_splits, output_splits)
-
- # [tokens, experts]
- weights_idx = generate_weights_idx(routing_weights, selected_experts, num_experts)
-
- unpermute_outputs = unpermute(
- unpermute_outputs,
- weights_idx,
- org_hidden_states_shape,
- local_input_permutation_mapping,
- routing_map,
+ weighted_outputs = unpermute_outputs * local_assignment_weights.unsqueeze(-1)
+ hidden_dim = org_hidden_states_shape[-1]
+ final_outputs = torch.zeros(org_hidden_states_shape, device=weighted_outputs.device, dtype=weighted_outputs.dtype)
+ final_outputs.scatter_add_(
+ 0,
+ local_input_permutation_mapping.unsqueeze(1).expand(-1, hidden_dim),
+ weighted_outputs,
)
-
- return unpermute_outputs
+ return final_outputs
diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py
index c9ee0fa0..dceb0f7a 100644
--- a/src/twinkle/model/transformers/moe/expert_parallel.py
+++ b/src/twinkle/model/transformers/moe/expert_parallel.py
@@ -1,6 +1,9 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from __future__ import annotations
+import inspect
+import os
+import time
import torch
import torch.distributed as dist
import torch.nn.functional as F
@@ -18,6 +21,7 @@ class ExpertParallelConfig:
router_dtype: str = 'fp32'
keep_router_logits: bool = True
ignore_shared_experts: bool = False
+ sync_after_backward: bool = True # consumed by TransformersModel to keep EP/FSDP collectives ordered
ep_size: int | None = None # consumed by TransformersModel, not used in expert_parallel logic
@@ -62,7 +66,8 @@ def apply_expert_parallel(
ep_rank = ep_mesh.get_local_rank()
specs = []
- for block in find_moe_blocks(model):
+ for block_name, block in iter_moe_blocks(model):
+ block._ep_debug_name = block_name
spec = shard_experts(block, ep_world_size, ep_rank, cfg)
patch_forward(block, ep_group, ep_world_size, cfg)
specs.append(spec)
@@ -82,8 +87,12 @@ def _merge_config(config: dict[str, Any] | None) -> ExpertParallelConfig:
def find_moe_blocks(model: nn.Module) -> Iterable[nn.Module]:
+ return [block for _, block in iter_moe_blocks(model)]
+
+
+def iter_moe_blocks(model: nn.Module) -> Iterable[tuple[str, nn.Module]]:
blocks = []
- for module in model.modules():
+ for name, module in model.named_modules():
experts = getattr(module, 'experts', None)
if experts is None:
continue
@@ -91,7 +100,7 @@ def find_moe_blocks(model: nn.Module) -> Iterable[nn.Module]:
continue
if not _get_gate(module):
continue
- blocks.append(module)
+ blocks.append((name, module))
return blocks
@@ -187,6 +196,8 @@ def patch_forward(
raise ValueError('MoE block must define top_k/num_experts_per_tok.')
orig_forward = block.forward
+ return_annotation = inspect.signature(orig_forward).return_annotation
+ returns_router_logits = return_annotation in (tuple, Tuple[torch.Tensor, torch.Tensor | None],)
num_experts = block._ep_num_experts
experts_per_rank = block._ep_experts_per_rank
is_tensor_experts = block._ep_tensor_experts
@@ -198,8 +209,10 @@ def patch_forward(
_install_ep_forward(block.experts, experts_per_rank)
def forward(hidden_states: torch.Tensor, *args, **kwargs):
- if args or kwargs:
- raise RuntimeError('Expert parallel patch only supports forward(hidden_states).')
+ if args:
+ raise RuntimeError('Expert parallel patch only supports keyword-only extra args for MoE blocks.')
+
+ _ep_debug(block, ep_group, 'enter', hidden_states_shape=tuple(hidden_states.shape))
orig_shape = hidden_states.shape
if hidden_states.ndim == 3:
@@ -218,6 +231,7 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs):
top_k=top_k,
router_dtype=_get_router_dtype(cfg.router_dtype, hidden_states_2d.dtype),
norm_topk_prob=getattr(block, 'norm_topk_prob', False),
+ **kwargs,
)
# Keep routing weights in activation dtype before unpermute weighting.
if routing_weights.dtype != hidden_states_2d.dtype:
@@ -228,28 +242,52 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs):
selected_experts, num_classes=num_experts).permute(2, 1, 0) # [num_experts, top_k, num_tokens]
# 1. preprocess: compute splits and token counts
+ _ep_debug(
+ block,
+ ep_group,
+ 'before_preprocess',
+ num_tokens=hidden_states_2d.shape[0],
+ selected_shape=tuple(selected_experts.shape),
+ expert_mask_sum=int(expert_mask.sum().item()),
+ )
(
input_splits,
output_splits,
num_global_tokens_per_local_expert,
num_global_sum_tokens_per_local_expert,
) = preprocess(expert_mask, num_experts, ep_group)
+ _ep_debug(
+ block,
+ ep_group,
+ 'after_preprocess',
+ input_splits=input_splits,
+ output_splits=output_splits,
+ local_expert_tokens=num_global_sum_tokens_per_local_expert.tolist(),
+ )
# 2. token_pre_all2all: permute → all_to_all → sort_chunks
+ _ep_debug(block, ep_group, 'before_token_pre_all2all')
(
global_permuted_hidden_states,
- routing_map,
local_input_permutation_mapping,
+ local_assignment_weights,
org_hidden_states_shape,
) = token_pre_all2all(
hidden_states_2d,
expert_mask,
+ routing_weights,
num_experts,
input_splits,
output_splits,
num_global_tokens_per_local_expert,
ep_group,
)
+ _ep_debug(
+ block,
+ ep_group,
+ 'after_token_pre_all2all',
+ permuted_shape=tuple(global_permuted_hidden_states.shape),
+ )
# 3. expert_compute: call experts via nn.Module.__call__ so FSDP2 hooks fire.
# For tensor experts: block.experts(permuted_tokens, counts, experts_per_rank)
@@ -270,28 +308,38 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs):
)
# 4. tokens_post_all2all: sort_chunks → all_to_all → unpermute (with routing weight)
+ _ep_debug(block, ep_group, 'before_tokens_post_all2all', expert_outputs_shape=tuple(expert_outputs.shape))
final_hidden = tokens_post_all2all(
expert_outputs,
- routing_weights,
- selected_experts,
+ local_assignment_weights,
num_experts,
input_splits,
output_splits,
num_global_tokens_per_local_expert,
- routing_map,
local_input_permutation_mapping,
org_hidden_states_shape,
ep_group,
)
+ _ep_debug(block, ep_group, 'after_tokens_post_all2all', final_shape=tuple(final_hidden.shape))
+ _ep_debug(block, ep_group, 'before_shared_expert')
+ shared_start = time.perf_counter()
shared_out = _maybe_run_shared_expert(block, hidden_states_2d, cfg)
if shared_out is not None:
final_hidden = final_hidden + shared_out
+ _ep_debug(
+ block,
+ ep_group,
+ 'after_shared_expert',
+ has_shared=shared_out is not None,
+ elapsed=f'{time.perf_counter() - shared_start:.3f}s',
+ )
if len(orig_shape) == 3:
final_hidden = final_hidden.view(batch_size, seq_len, hidden_dim)
- if cfg.keep_router_logits:
+ _ep_debug(block, ep_group, 'exit', output_shape=tuple(final_hidden.shape))
+ if cfg.keep_router_logits and returns_router_logits:
return final_hidden, router_logits
return final_hidden
@@ -300,6 +348,22 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs):
block._ep_patched = True
+def _ep_debug(block: nn.Module, ep_group: dist.ProcessGroup, event: str, **kwargs) -> None:
+ if os.environ.get('TWINKLE_EP_DEBUG', '0') != '1':
+ return
+ rank = dist.get_rank() if dist.is_initialized() else -1
+ ep_rank = dist.get_rank(ep_group) if dist.is_initialized() else -1
+ ep_world_size = dist.get_world_size(ep_group) if dist.is_initialized() else -1
+ block_name = getattr(block, '_ep_debug_name', block.__class__.__name__)
+ ts = f'{time.time():.3f}'
+ extras = ' '.join(f'{key}={value}' for key, value in kwargs.items())
+ print(
+ f'[twinkle-ep-debug] ts={ts} rank={rank} ep_rank={ep_rank}/{ep_world_size} '
+ f'block={block_name} event={event} {extras}',
+ flush=True,
+ )
+
+
def _install_ep_forward(experts_mod: nn.Module, experts_per_rank: int) -> None:
if getattr(experts_mod, '_ep_forward_installed', False):
return
@@ -399,6 +463,8 @@ def _maybe_run_shared_expert(block: nn.Module, hidden_states_2d: torch.Tensor, c
if cfg.ignore_shared_experts:
return None
shared = getattr(block, 'shared_expert', None)
+ if shared is None:
+ shared = getattr(block, 'shared_experts', None)
if shared is None:
return None
return _run_module_with_casting(shared, hidden_states_2d)
@@ -487,8 +553,12 @@ def _run_router(
top_k: int,
router_dtype: torch.dtype,
norm_topk_prob: bool,
+ **kwargs,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- gate_out = gate(hidden_states)
+ gate_kwargs = {}
+ if 'input_ids' in kwargs:
+ gate_kwargs['input_ids'] = kwargs['input_ids']
+ gate_out = gate(hidden_states, **gate_kwargs)
if isinstance(gate_out, tuple) and len(gate_out) >= 3:
router_logits, routing_weights, selected_experts = gate_out[:3]
return router_logits, routing_weights, selected_experts
diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py
index ad675006..9993454b 100644
--- a/src/twinkle/model/transformers/strategy/native_fsdp.py
+++ b/src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -96,7 +96,7 @@ def wrap_model(self, model, optimizer=None):
for layer_mod, experts_mod in layer_pairs:
layer_mod._fsdp_modules = []
- if experts_mod is not None and ep_fsdp_mesh_1d is not None:
+ if experts_mod is not None and ep_fsdp_mesh_1d is not None and ep_fsdp_mesh_1d.size() > 1:
from torch.distributed.tensor import Shard
ep_mp_policy = _build_ep_mp_policy(mp_policy)
@@ -140,7 +140,7 @@ def wrap_model(self, model, optimizer=None):
if hasattr(model, 'tie_weights'):
model.tie_weights()
- if ep_enabled and layer_pairs:
+ if ep_enabled and layer_pairs and self.fsdp_config.get('manual_prefetch', False):
_setup_manual_prefetch([lp[0] for lp in layer_pairs])
if ep_enabled:
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 56cbec17..5de0beb3 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -7,6 +7,7 @@
import random
import re
import threading
+import time
import torch
import torch.distributed as dist
import transformers
@@ -46,6 +47,15 @@
logger = get_logger()
+def _ep_debug_event(event: str, **kwargs) -> None:
+ if os.environ.get('TWINKLE_EP_DEBUG', '0') != '1':
+ return
+ rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else -1
+ ts = f'{time.time():.3f}'
+ extras = ' '.join(f'{key}={value}' for key, value in kwargs.items())
+ print(f'[twinkle-ep-debug] ts={ts} rank={rank} model_event={event} {extras}', flush=True)
+
+
@dataclass
class OptimizerGroup(BaseOptimizerGroup):
"""Optimizer group for Transformers training."""
@@ -539,6 +549,7 @@ def backward(self, **kwargs):
optimizer_config = self.optimizer_group[adapter_name]
loss_value = optimizer_config.train_status.loss_value
assert loss_value is not None, 'Do forwarding and calculating loss before backward'
+ _ep_debug_event('backward_enter', adapter_name=adapter_name)
scaler = optimizer_config.scaler
if scaler is None and self.mixed_precision == 'fp16':
# Auto set a grad scaler
@@ -559,6 +570,9 @@ def backward(self, **kwargs):
else:
loss_value.backward()
+ _ep_debug_event('backward_after_loss_backward', adapter_name=adapter_name)
+ self._sync_after_backward_if_needed()
+ _ep_debug_event('backward_exit', adapter_name=adapter_name)
optimizer_config.train_status.loss_value = None
@remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
@@ -576,11 +590,24 @@ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Tr
The output of the model forward.
"""
outputs = self.forward(inputs=inputs, **kwargs)
+ _ep_debug_event('forward_backward_after_forward')
loss = self.calculate_loss(**kwargs)
+ _ep_debug_event('forward_backward_after_calculate_loss', loss=loss)
outputs['loss'] = loss
self.backward(**kwargs)
+ _ep_debug_event('forward_backward_exit')
return outputs
+ def _sync_after_backward_if_needed(self) -> None:
+ if not getattr(self, '_enable_expert_parallel', False):
+ return
+ expert_parallel_config = getattr(self, '_expert_parallel_config', None) or {}
+ if not expert_parallel_config.get('sync_after_backward', True):
+ return
+ torch_util.synchronize()
+ if dist.is_available() and dist.is_initialized():
+ dist.barrier()
+
@remote_function()
def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs):
""" Clip the gradient norm
@@ -595,7 +622,9 @@ def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs):
"""
adapter_name = kwargs.pop('adapter_name', self._get_default_group())
optimizer_config = self.optimizer_group[adapter_name]
+ _ep_debug_event('clip_grad_norm_enter', adapter_name=adapter_name)
if not optimizer_config.do_grad_sync(kwargs.get('gradient_accumulation_steps')):
+ _ep_debug_event('clip_grad_norm_skip_no_sync', adapter_name=adapter_name)
return
optimizer = optimizer_config.optimizer
@@ -628,14 +657,20 @@ def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs):
)
optimizer_config._last_grad_norm = grad_norm
optimizer_config.train_status.num_tokens = 0
+ _ep_debug_event('clip_grad_norm_exit', adapter_name=adapter_name, grad_norm=grad_norm)
return grad_norm
@remote_function(dispatch='all')
def clip_grad_and_step(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs):
+ _ep_debug_event('clip_grad_and_step_enter')
self.clip_grad_norm(max_grad_norm, norm_type, **kwargs)
+ _ep_debug_event('clip_grad_and_step_after_clip')
self.step(**kwargs)
+ _ep_debug_event('clip_grad_and_step_after_step')
self.zero_grad(**kwargs)
+ _ep_debug_event('clip_grad_and_step_after_zero_grad')
self.lr_step(**kwargs)
+ _ep_debug_event('clip_grad_and_step_exit')
def _create_param_group(self,
adapter_name: str,
From 6b3df6f27e1effb4b90eb96dc3096266f7c8087c Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Wed, 29 Apr 2026 16:30:08 +0800
Subject: [PATCH 03/40] wip
Co-authored-by: Copilot
---
cookbook/transformers/deepseek_v4.py | 10 +++++-----
src/twinkle/model/transformers/moe/expert_parallel.py | 10 +++++++++-
2 files changed, 14 insertions(+), 6 deletions(-)
diff --git a/cookbook/transformers/deepseek_v4.py b/cookbook/transformers/deepseek_v4.py
index 7edfa5a8..fb553131 100644
--- a/cookbook/transformers/deepseek_v4.py
+++ b/cookbook/transformers/deepseek_v4.py
@@ -16,10 +16,10 @@
TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template')
OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output')
-_num_layers_env = os.environ.get('NUM_LAYERS')
+_num_layers_env = os.environ.get('NUM_LAYERS','4')
NUM_LAYERS = int(_num_layers_env) if _num_layers_env is not None else None
-BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '2'))
+BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4'))
GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '2'))
LR = float(os.environ.get('LR', '1e-4'))
MAX_STEPS = int(os.environ.get('MAX_STEPS', '0'))
@@ -34,9 +34,9 @@
)
device_mesh = DeviceMesh.from_sizes(
- fsdp_size=2,
+ fsdp_size=4,
dp_size=1,
- ep_size=2,
+ ep_size=4,
device_type=Platform.get_platform().device_prefix(),
)
@@ -77,7 +77,7 @@ def create_dataset(data_slice=None):
def eval(model):
dataset = create_dataset(data_slice=range(100))
- dataloader = DataLoader(dataset=dataset, batch_size=max(1, BATCH_SIZE // 2))
+ dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)
for _, batch in enumerate(dataloader):
model.forward_only(inputs=batch, adapter_name='default')
model.calculate_loss(adapter_name='default')
diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py
index dceb0f7a..d57f63f0 100644
--- a/src/twinkle/model/transformers/moe/expert_parallel.py
+++ b/src/twinkle/model/transformers/moe/expert_parallel.py
@@ -556,7 +556,7 @@ def _run_router(
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
gate_kwargs = {}
- if 'input_ids' in kwargs:
+ if 'input_ids' in kwargs and _module_forward_accepts_kwarg(gate, 'input_ids'):
gate_kwargs['input_ids'] = kwargs['input_ids']
gate_out = gate(hidden_states, **gate_kwargs)
if isinstance(gate_out, tuple) and len(gate_out) >= 3:
@@ -569,3 +569,11 @@ def _run_router(
if norm_topk_prob:
routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
return router_logits, routing_weights, selected_experts
+
+
+def _module_forward_accepts_kwarg(module: nn.Module, kwarg: str) -> bool:
+ signature = inspect.signature(module.forward)
+ for param in signature.parameters.values():
+ if param.kind == inspect.Parameter.VAR_KEYWORD:
+ return True
+ return kwarg in signature.parameters
From 229d2e0bb3410ef1b0875dc9384374dde547f3e0 Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Thu, 7 May 2026 23:55:12 +0800
Subject: [PATCH 04/40] refactor: remove debug logging from expert parallel
implementation
Remove the `_ep_debug` function and all its calls throughout the expert parallel module to clean up production code. The debug logging was gated behind the `TWINKLE_EP_DEBUG` environment variable but added unnecessary complexity and performance overhead in normal operation. Also remove unused imports (`os`, `time`) and the `_ep_debug_name` attribute assignment on blocks.
---
cookbook/transformers/deepseek_v4.py | 2 +-
src/twinkle/model/base.py | 2 +-
.../model/transformers/moe/ep_utils.py | 2 -
.../model/transformers/moe/expert_parallel.py | 71 +---
.../model/transformers/transformers.py | 24 --
src/twinkle/template/deepseek_v4.py | 10 +-
src/twinkle/template/deepseek_v4_encoding.py | 305 +++++++++---------
7 files changed, 163 insertions(+), 253 deletions(-)
diff --git a/cookbook/transformers/deepseek_v4.py b/cookbook/transformers/deepseek_v4.py
index fb553131..6e6f2075 100644
--- a/cookbook/transformers/deepseek_v4.py
+++ b/cookbook/transformers/deepseek_v4.py
@@ -98,7 +98,7 @@ def train():
model_id=MODEL_ID,
config=config,
device_mesh=device_mesh,
- strategy="native_fsdp",
+ strategy='native_fsdp',
ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES,
fsdp_config={
'reshard_after_forward': RESHARD_AFTER_FORWARD,
diff --git a/src/twinkle/model/base.py b/src/twinkle/model/base.py
index e1267eb3..e9c980b9 100644
--- a/src/twinkle/model/base.py
+++ b/src/twinkle/model/base.py
@@ -143,7 +143,7 @@ def upload_to_hub(self,
HubOperation.push_to_hub(repo_id=hub_model_id, folder_path=checkpoint_dir, token=hub_token, private=True)
def _should_bind_device_id_for_process_group(self, backend: str) -> bool:
- return backend =="hccl"
+ return backend == 'hccl'
def _try_init_process_group(self):
import torch
diff --git a/src/twinkle/model/transformers/moe/ep_utils.py b/src/twinkle/model/transformers/moe/ep_utils.py
index 365fa4eb..f64c1796 100644
--- a/src/twinkle/model/transformers/moe/ep_utils.py
+++ b/src/twinkle/model/transformers/moe/ep_utils.py
@@ -106,7 +106,6 @@ def permute(tokens: torch.Tensor, expert_mask: torch.Tensor):
"""
num_tokens, _ = tokens.shape
- num_experts = expert_mask.shape[0]
expert_mask = expert_mask.bool()
# Create a dense expert-to-token mapping from the sparse token-to-expert mapping
@@ -234,7 +233,6 @@ def token_pre_all2all(
hidden_dim = hidden_states.size(-1)
hidden_states = hidden_states.reshape(-1, hidden_dim)
org_hidden_states_shape = hidden_states.shape
- routing_map = expert_mask.sum(dim=1).bool()
local_permuted_hidden_states, local_input_permutation_mapping = permute(hidden_states, expert_mask)
local_assignment_weights = routing_weights.T.contiguous().masked_select(expert_mask.bool())
diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py
index d57f63f0..79c256ff 100644
--- a/src/twinkle/model/transformers/moe/expert_parallel.py
+++ b/src/twinkle/model/transformers/moe/expert_parallel.py
@@ -2,8 +2,6 @@
from __future__ import annotations
import inspect
-import os
-import time
import torch
import torch.distributed as dist
import torch.nn.functional as F
@@ -66,8 +64,7 @@ def apply_expert_parallel(
ep_rank = ep_mesh.get_local_rank()
specs = []
- for block_name, block in iter_moe_blocks(model):
- block._ep_debug_name = block_name
+ for block in find_moe_blocks(model):
spec = shard_experts(block, ep_world_size, ep_rank, cfg)
patch_forward(block, ep_group, ep_world_size, cfg)
specs.append(spec)
@@ -87,12 +84,8 @@ def _merge_config(config: dict[str, Any] | None) -> ExpertParallelConfig:
def find_moe_blocks(model: nn.Module) -> Iterable[nn.Module]:
- return [block for _, block in iter_moe_blocks(model)]
-
-
-def iter_moe_blocks(model: nn.Module) -> Iterable[tuple[str, nn.Module]]:
blocks = []
- for name, module in model.named_modules():
+ for module in model.modules():
experts = getattr(module, 'experts', None)
if experts is None:
continue
@@ -100,7 +93,7 @@ def iter_moe_blocks(model: nn.Module) -> Iterable[tuple[str, nn.Module]]:
continue
if not _get_gate(module):
continue
- blocks.append((name, module))
+ blocks.append(module)
return blocks
@@ -197,7 +190,10 @@ def patch_forward(
orig_forward = block.forward
return_annotation = inspect.signature(orig_forward).return_annotation
- returns_router_logits = return_annotation in (tuple, Tuple[torch.Tensor, torch.Tensor | None],)
+ returns_router_logits = return_annotation in (
+ tuple,
+ Tuple[torch.Tensor, torch.Tensor | None],
+ )
num_experts = block._ep_num_experts
experts_per_rank = block._ep_experts_per_rank
is_tensor_experts = block._ep_tensor_experts
@@ -212,8 +208,6 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs):
if args:
raise RuntimeError('Expert parallel patch only supports keyword-only extra args for MoE blocks.')
- _ep_debug(block, ep_group, 'enter', hidden_states_shape=tuple(hidden_states.shape))
-
orig_shape = hidden_states.shape
if hidden_states.ndim == 3:
batch_size, seq_len, hidden_dim = hidden_states.shape
@@ -242,31 +236,14 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs):
selected_experts, num_classes=num_experts).permute(2, 1, 0) # [num_experts, top_k, num_tokens]
# 1. preprocess: compute splits and token counts
- _ep_debug(
- block,
- ep_group,
- 'before_preprocess',
- num_tokens=hidden_states_2d.shape[0],
- selected_shape=tuple(selected_experts.shape),
- expert_mask_sum=int(expert_mask.sum().item()),
- )
(
input_splits,
output_splits,
num_global_tokens_per_local_expert,
num_global_sum_tokens_per_local_expert,
) = preprocess(expert_mask, num_experts, ep_group)
- _ep_debug(
- block,
- ep_group,
- 'after_preprocess',
- input_splits=input_splits,
- output_splits=output_splits,
- local_expert_tokens=num_global_sum_tokens_per_local_expert.tolist(),
- )
# 2. token_pre_all2all: permute → all_to_all → sort_chunks
- _ep_debug(block, ep_group, 'before_token_pre_all2all')
(
global_permuted_hidden_states,
local_input_permutation_mapping,
@@ -282,12 +259,6 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs):
num_global_tokens_per_local_expert,
ep_group,
)
- _ep_debug(
- block,
- ep_group,
- 'after_token_pre_all2all',
- permuted_shape=tuple(global_permuted_hidden_states.shape),
- )
# 3. expert_compute: call experts via nn.Module.__call__ so FSDP2 hooks fire.
# For tensor experts: block.experts(permuted_tokens, counts, experts_per_rank)
@@ -308,7 +279,6 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs):
)
# 4. tokens_post_all2all: sort_chunks → all_to_all → unpermute (with routing weight)
- _ep_debug(block, ep_group, 'before_tokens_post_all2all', expert_outputs_shape=tuple(expert_outputs.shape))
final_hidden = tokens_post_all2all(
expert_outputs,
local_assignment_weights,
@@ -320,25 +290,14 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs):
org_hidden_states_shape,
ep_group,
)
- _ep_debug(block, ep_group, 'after_tokens_post_all2all', final_shape=tuple(final_hidden.shape))
- _ep_debug(block, ep_group, 'before_shared_expert')
- shared_start = time.perf_counter()
shared_out = _maybe_run_shared_expert(block, hidden_states_2d, cfg)
if shared_out is not None:
final_hidden = final_hidden + shared_out
- _ep_debug(
- block,
- ep_group,
- 'after_shared_expert',
- has_shared=shared_out is not None,
- elapsed=f'{time.perf_counter() - shared_start:.3f}s',
- )
if len(orig_shape) == 3:
final_hidden = final_hidden.view(batch_size, seq_len, hidden_dim)
- _ep_debug(block, ep_group, 'exit', output_shape=tuple(final_hidden.shape))
if cfg.keep_router_logits and returns_router_logits:
return final_hidden, router_logits
return final_hidden
@@ -348,22 +307,6 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs):
block._ep_patched = True
-def _ep_debug(block: nn.Module, ep_group: dist.ProcessGroup, event: str, **kwargs) -> None:
- if os.environ.get('TWINKLE_EP_DEBUG', '0') != '1':
- return
- rank = dist.get_rank() if dist.is_initialized() else -1
- ep_rank = dist.get_rank(ep_group) if dist.is_initialized() else -1
- ep_world_size = dist.get_world_size(ep_group) if dist.is_initialized() else -1
- block_name = getattr(block, '_ep_debug_name', block.__class__.__name__)
- ts = f'{time.time():.3f}'
- extras = ' '.join(f'{key}={value}' for key, value in kwargs.items())
- print(
- f'[twinkle-ep-debug] ts={ts} rank={rank} ep_rank={ep_rank}/{ep_world_size} '
- f'block={block_name} event={event} {extras}',
- flush=True,
- )
-
-
def _install_ep_forward(experts_mod: nn.Module, experts_per_rank: int) -> None:
if getattr(experts_mod, '_ep_forward_installed', False):
return
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 5de0beb3..108e9bef 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -7,7 +7,6 @@
import random
import re
import threading
-import time
import torch
import torch.distributed as dist
import transformers
@@ -47,15 +46,6 @@
logger = get_logger()
-def _ep_debug_event(event: str, **kwargs) -> None:
- if os.environ.get('TWINKLE_EP_DEBUG', '0') != '1':
- return
- rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else -1
- ts = f'{time.time():.3f}'
- extras = ' '.join(f'{key}={value}' for key, value in kwargs.items())
- print(f'[twinkle-ep-debug] ts={ts} rank={rank} model_event={event} {extras}', flush=True)
-
-
@dataclass
class OptimizerGroup(BaseOptimizerGroup):
"""Optimizer group for Transformers training."""
@@ -549,7 +539,6 @@ def backward(self, **kwargs):
optimizer_config = self.optimizer_group[adapter_name]
loss_value = optimizer_config.train_status.loss_value
assert loss_value is not None, 'Do forwarding and calculating loss before backward'
- _ep_debug_event('backward_enter', adapter_name=adapter_name)
scaler = optimizer_config.scaler
if scaler is None and self.mixed_precision == 'fp16':
# Auto set a grad scaler
@@ -570,9 +559,7 @@ def backward(self, **kwargs):
else:
loss_value.backward()
- _ep_debug_event('backward_after_loss_backward', adapter_name=adapter_name)
self._sync_after_backward_if_needed()
- _ep_debug_event('backward_exit', adapter_name=adapter_name)
optimizer_config.train_status.loss_value = None
@remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
@@ -590,12 +577,9 @@ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Tr
The output of the model forward.
"""
outputs = self.forward(inputs=inputs, **kwargs)
- _ep_debug_event('forward_backward_after_forward')
loss = self.calculate_loss(**kwargs)
- _ep_debug_event('forward_backward_after_calculate_loss', loss=loss)
outputs['loss'] = loss
self.backward(**kwargs)
- _ep_debug_event('forward_backward_exit')
return outputs
def _sync_after_backward_if_needed(self) -> None:
@@ -622,9 +606,7 @@ def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs):
"""
adapter_name = kwargs.pop('adapter_name', self._get_default_group())
optimizer_config = self.optimizer_group[adapter_name]
- _ep_debug_event('clip_grad_norm_enter', adapter_name=adapter_name)
if not optimizer_config.do_grad_sync(kwargs.get('gradient_accumulation_steps')):
- _ep_debug_event('clip_grad_norm_skip_no_sync', adapter_name=adapter_name)
return
optimizer = optimizer_config.optimizer
@@ -657,20 +639,14 @@ def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs):
)
optimizer_config._last_grad_norm = grad_norm
optimizer_config.train_status.num_tokens = 0
- _ep_debug_event('clip_grad_norm_exit', adapter_name=adapter_name, grad_norm=grad_norm)
return grad_norm
@remote_function(dispatch='all')
def clip_grad_and_step(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs):
- _ep_debug_event('clip_grad_and_step_enter')
self.clip_grad_norm(max_grad_norm, norm_type, **kwargs)
- _ep_debug_event('clip_grad_and_step_after_clip')
self.step(**kwargs)
- _ep_debug_event('clip_grad_and_step_after_step')
self.zero_grad(**kwargs)
- _ep_debug_event('clip_grad_and_step_after_zero_grad')
self.lr_step(**kwargs)
- _ep_debug_event('clip_grad_and_step_exit')
def _create_param_group(self,
adapter_name: str,
diff --git a/src/twinkle/template/deepseek_v4.py b/src/twinkle/template/deepseek_v4.py
index 7b33dd8a..3d8a6de0 100644
--- a/src/twinkle/template/deepseek_v4.py
+++ b/src/twinkle/template/deepseek_v4.py
@@ -1,12 +1,10 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import copy
-from typing import Any, Literal, Optional
-
import torch
from transformers import AutoConfig, PreTrainedTokenizerFast
+from typing import Any, Literal, Optional
from twinkle.hub import HubOperation
-
from .base import Template
from .deepseek_v4_encoding import encode_messages
@@ -56,11 +54,7 @@ def apply_chat_template(
if not tokenize:
return {'prompt': prompt_str} if return_dict else prompt_str
- tokenizer_kwargs = {
- key: kwargs[key]
- for key in ('truncation', 'max_length')
- if key in kwargs
- }
+ tokenizer_kwargs = {key: kwargs[key] for key in ('truncation', 'max_length') if key in kwargs}
input_ids = self.encode(
prompt_str,
add_special_tokens=False,
diff --git a/src/twinkle/template/deepseek_v4_encoding.py b/src/twinkle/template/deepseek_v4_encoding.py
index 5f2b06ae..2fa4cf8f 100644
--- a/src/twinkle/template/deepseek_v4_encoding.py
+++ b/src/twinkle/template/deepseek_v4_encoding.py
@@ -9,58 +9,57 @@
with tool calling, thinking mode, and quick instruction task support.
"""
-from typing import Any, Dict, List, Optional, Tuple, Union
import copy
import json
-
import regex as re
+from typing import Any, Dict, List, Optional, Tuple, Union
-bos_token: str = "<|begin▁of▁sentence|>"
-eos_token: str = "<|end▁of▁sentence|>"
-thinking_start_token: str = ""
-thinking_end_token: str = ""
-dsml_token: str = "|DSML|"
+bos_token: str = '<|begin▁of▁sentence|>'
+eos_token: str = '<|end▁of▁sentence|>'
+thinking_start_token: str = ''
+thinking_end_token: str = ''
+dsml_token: str = '|DSML|'
-USER_SP_TOKEN = "<|User|>"
-ASSISTANT_SP_TOKEN = "<|Assistant|>"
-LATEST_REMINDER_SP_TOKEN = "<|latest_reminder|>"
+USER_SP_TOKEN = '<|User|>'
+ASSISTANT_SP_TOKEN = '<|Assistant|>'
+LATEST_REMINDER_SP_TOKEN = '<|latest_reminder|>'
DS_TASK_SP_TOKENS = {
- "action": "<|action|>",
- "query": "<|query|>",
- "authority": "<|authority|>",
- "domain": "<|domain|>",
- "title": "<|title|>",
- "read_url": "<|read_url|>",
+ 'action': '<|action|>',
+ 'query': '<|query|>',
+ 'authority': '<|authority|>',
+ 'domain': '<|domain|>',
+ 'title': '<|title|>',
+ 'read_url': '<|read_url|>',
}
VALID_TASKS = set(DS_TASK_SP_TOKENS.keys())
-system_msg_template: str = "{content}"
-user_msg_template: str = "{content}"
-latest_reminder_msg_template: str = "{content}"
-assistant_msg_template: str = "{reasoning}{content}{tool_calls}" + eos_token
-assistant_msg_wo_eos_template: str = "{reasoning}{content}{tool_calls}"
-thinking_template: str = "{reasoning}"
+system_msg_template: str = '{content}'
+user_msg_template: str = '{content}'
+latest_reminder_msg_template: str = '{content}'
+assistant_msg_template: str = '{reasoning}{content}{tool_calls}' + eos_token
+assistant_msg_wo_eos_template: str = '{reasoning}{content}{tool_calls}'
+thinking_template: str = '{reasoning}'
response_format_template: str = (
- "## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}"
+ '## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}'
)
tool_call_template: str = (
"<{dsml_token}invoke name=\"{name}\">\n{arguments}\n{dsml_token}invoke>"
)
tool_calls_template = (
- "<{dsml_token}{tc_block_name}>\n{tool_calls}\n{dsml_token}{tc_block_name}>"
+ '<{dsml_token}{tc_block_name}>\n{tool_calls}\n{dsml_token}{tc_block_name}>'
)
-tool_calls_block_name: str = "tool_calls"
+tool_calls_block_name: str = 'tool_calls'
tool_output_template: str = (
- "{content}"
+ '{content}'
)
REASONING_EFFORT_MAX = (
- "Reasoning Effort: Absolute maximum with no shortcuts permitted.\n"
- "You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.\n"
- "Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.\n\n"
+ 'Reasoning Effort: Absolute maximum with no shortcuts permitted.\n'
+ 'You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.\n'
+ 'Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.\n\n'
)
TOOLS_TEMPLATE = """## Tools
@@ -99,14 +98,14 @@ def to_json(value: Any) -> str:
def tools_from_openai_format(tools):
- return [tool["function"] for tool in tools]
+ return [tool['function'] for tool in tools]
def tool_calls_from_openai_format(tool_calls):
return [
{
- "name": tool_call["function"]["name"],
- "arguments": tool_call["function"]["arguments"],
+ 'name': tool_call['function']['name'],
+ 'arguments': tool_call['function']['arguments'],
}
for tool_call in tool_calls
]
@@ -115,10 +114,10 @@ def tool_calls_from_openai_format(tool_calls):
def tool_calls_to_openai_format(tool_calls):
return [
{
- "type": "function",
- "function": {
- "name": tool_call["name"],
- "arguments": tool_call["arguments"],
+ 'type': 'function',
+ 'function': {
+ 'name': tool_call['name'],
+ 'arguments': tool_call['arguments'],
}
}
for tool_call in tool_calls
@@ -129,30 +128,30 @@ def encode_arguments_to_dsml(tool_call: Dict[str, Any]) -> str:
p_dsml_template = '<{dsml_token}parameter name="{key}" string="{is_str}">{value}{dsml_token}parameter>'
p_dsml_strs = []
- if isinstance(tool_call["arguments"], str):
- arguments = json.loads(tool_call["arguments"])
+ if isinstance(tool_call['arguments'], str):
+ arguments = json.loads(tool_call['arguments'])
else:
- arguments = tool_call["arguments"]
+ arguments = tool_call['arguments']
for k, v in arguments.items():
p_dsml_str = p_dsml_template.format(
dsml_token=dsml_token,
key=k,
- is_str="true" if isinstance(v, str) else "false",
+ is_str='true' if isinstance(v, str) else 'false',
value=v if isinstance(v, str) else to_json(v),
)
p_dsml_strs.append(p_dsml_str)
- return "\n".join(p_dsml_strs)
+ return '\n'.join(p_dsml_strs)
def decode_dsml_to_arguments(tool_name: str, tool_args: Dict[str, Tuple[str, str]]) -> Dict[str, str]:
def _decode_value(key: str, value: str, string: str):
- if string == "true":
+ if string == 'true':
value = to_json(value)
- return f"{to_json(key)}: {value}"
+ return f'{to_json(key)}: {value}'
- tool_args_json = "{" + ", ".join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]) + "}"
+ tool_args_json = '{' + ', '.join([_decode_value(k, v, string=is_str) for k, (v, is_str) in tool_args.items()]) + '}'
return dict(name=tool_name, arguments=tool_args_json)
@@ -160,7 +159,7 @@ def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str:
tools_json = [to_json(t) for t in tools]
return TOOLS_TEMPLATE.format(
- tool_schemas="\n".join(tools_json),
+ tool_schemas='\n'.join(tools_json),
dsml_token=dsml_token,
thinking_start_token=thinking_start_token,
thinking_end_token=thinking_end_token,
@@ -170,7 +169,7 @@ def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str:
def find_last_user_index(messages: List[Dict[str, Any]]) -> int:
last_user_index = -1
for idx in range(len(messages) - 1, -1, -1):
- if messages[idx].get("role") in ["user", "developer"]:
+ if messages[idx].get('role') in ['user', 'developer']:
last_user_index = idx
break
return last_user_index
@@ -178,111 +177,111 @@ def find_last_user_index(messages: List[Dict[str, Any]]) -> int:
def render_message(index: int, messages: List[Dict[str, Any]], thinking_mode: str, drop_thinking: bool = True, reasoning_effort: Optional[str] = None) -> str:
assert 0 <= index < len(messages)
- assert thinking_mode in ["chat", "thinking"], f"Invalid thinking_mode `{thinking_mode}`"
+ assert thinking_mode in ['chat', 'thinking'], f'Invalid thinking_mode `{thinking_mode}`'
- prompt = ""
+ prompt = ''
msg = messages[index]
last_user_idx = find_last_user_index(messages)
- role = msg.get("role")
- content = msg.get("content")
- tools = msg.get("tools")
- response_format = msg.get("response_format")
- tool_calls = msg.get("tool_calls")
- reasoning = msg.get("reasoning")
- wo_eos = msg.get("wo_eos", False)
+ role = msg.get('role')
+ content = msg.get('content')
+ tools = msg.get('tools')
+ response_format = msg.get('response_format')
+ tool_calls = msg.get('tool_calls')
+ reasoning = msg.get('reasoning')
+ wo_eos = msg.get('wo_eos', False)
if tools:
tools = tools_from_openai_format(tools)
if tool_calls:
tool_calls = tool_calls_from_openai_format(tool_calls)
- assert reasoning_effort in ['max', None, 'high'], f"Invalid reasoning effort: {reasoning_effort}"
- if index == 0 and thinking_mode == "thinking" and reasoning_effort == 'max':
+ assert reasoning_effort in ['max', None, 'high'], f'Invalid reasoning effort: {reasoning_effort}'
+ if index == 0 and thinking_mode == 'thinking' and reasoning_effort == 'max':
prompt += REASONING_EFFORT_MAX
- if role == "system":
- prompt += system_msg_template.format(content=content or "")
+ if role == 'system':
+ prompt += system_msg_template.format(content=content or '')
if tools:
- prompt += "\n\n" + render_tools(tools)
+ prompt += '\n\n' + render_tools(tools)
if response_format:
- prompt += "\n\n" + response_format_template.format(schema=to_json(response_format))
+ prompt += '\n\n' + response_format_template.format(schema=to_json(response_format))
- elif role == "developer":
- assert content, f"Invalid message for role `{role}`: {msg}"
+ elif role == 'developer':
+ assert content, f'Invalid message for role `{role}`: {msg}'
content_developer = USER_SP_TOKEN
content_developer += content
if tools:
- content_developer += "\n\n" + render_tools(tools)
+ content_developer += '\n\n' + render_tools(tools)
if response_format:
- content_developer += "\n\n" + response_format_template.format(schema=to_json(response_format))
+ content_developer += '\n\n' + response_format_template.format(schema=to_json(response_format))
prompt += user_msg_template.format(content=content_developer)
- elif role == "user":
+ elif role == 'user':
prompt += USER_SP_TOKEN
- content_blocks = msg.get("content_blocks")
+ content_blocks = msg.get('content_blocks')
if content_blocks:
parts = []
for block in content_blocks:
- block_type = block.get("type")
- if block_type == "text":
- parts.append(block.get("text", ""))
- elif block_type == "tool_result":
- tool_content = block.get("content", "")
+ block_type = block.get('type')
+ if block_type == 'text':
+ parts.append(block.get('text', ''))
+ elif block_type == 'tool_result':
+ tool_content = block.get('content', '')
if isinstance(tool_content, list):
text_parts = []
for b in tool_content:
- if b.get("type") == "text":
- text_parts.append(b.get("text", ""))
+ if b.get('type') == 'text':
+ text_parts.append(b.get('text', ''))
else:
text_parts.append(f"[Unsupported {b.get('type')}]")
- tool_content = "\n\n".join(text_parts)
+ tool_content = '\n\n'.join(text_parts)
parts.append(tool_output_template.format(content=tool_content))
else:
- parts.append(f"[Unsupported {block_type}]")
- prompt += "\n\n".join(parts)
+ parts.append(f'[Unsupported {block_type}]')
+ prompt += '\n\n'.join(parts)
else:
- prompt += content or ""
+ prompt += content or ''
- elif role == "latest_reminder":
+ elif role == 'latest_reminder':
prompt += LATEST_REMINDER_SP_TOKEN + latest_reminder_msg_template.format(content=content)
- elif role == "tool":
- raise NotImplementedError("deepseek_v4 merges tool messages into user; please preprocess with merge_tool_messages()")
+ elif role == 'tool':
+ raise NotImplementedError('deepseek_v4 merges tool messages into user; please preprocess with merge_tool_messages()')
- elif role == "assistant":
- thinking_part = ""
- tc_content = ""
+ elif role == 'assistant':
+ thinking_part = ''
+ tc_content = ''
if tool_calls:
tc_list = [
tool_call_template.format(
dsml_token=dsml_token,
- name=tc.get("name"),
+ name=tc.get('name'),
arguments=encode_arguments_to_dsml(tc)
)
for tc in tool_calls
]
tc_content += '\n\n' + tool_calls_template.format(
dsml_token=dsml_token,
- tool_calls="\n".join(tc_list),
+ tool_calls='\n'.join(tc_list),
tc_block_name=tool_calls_block_name,
)
- summary_content = content or ""
- reasoning = reasoning or ""
+ summary_content = content or ''
+ reasoning = reasoning or ''
- prev_has_task = index - 1 >= 0 and messages[index - 1].get("task") is not None
+ prev_has_task = index - 1 >= 0 and messages[index - 1].get('task') is not None
- if thinking_mode == "thinking" and not prev_has_task:
+ if thinking_mode == 'thinking' and not prev_has_task:
if not drop_thinking or index > last_user_idx:
thinking_part = thinking_template.format(reasoning=reasoning) + thinking_end_token
else:
- thinking_part = ""
+ thinking_part = ''
if wo_eos:
prompt += assistant_msg_wo_eos_template.format(
@@ -297,28 +296,28 @@ def render_message(index: int, messages: List[Dict[str, Any]], thinking_mode: st
tool_calls=tc_content,
)
else:
- raise NotImplementedError(f"Unknown role: {role}")
+ raise NotImplementedError(f'Unknown role: {role}')
- if index + 1 < len(messages) and messages[index + 1].get("role") not in ["assistant", "latest_reminder"]:
+ if index + 1 < len(messages) and messages[index + 1].get('role') not in ['assistant', 'latest_reminder']:
return prompt
- task = messages[index].get("task")
+ task = messages[index].get('task')
if task is not None:
assert task in VALID_TASKS, f"Invalid task: '{task}'. Valid tasks are: {list(VALID_TASKS)}"
task_sp_token = DS_TASK_SP_TOKENS[task]
- if task != "action":
+ if task != 'action':
prompt += task_sp_token
else:
prompt += ASSISTANT_SP_TOKEN
- prompt += thinking_end_token if thinking_mode != "thinking" else thinking_start_token
+ prompt += thinking_end_token if thinking_mode != 'thinking' else thinking_start_token
prompt += task_sp_token
- elif messages[index].get("role") in ["user", "developer"]:
+ elif messages[index].get('role') in ['user', 'developer']:
prompt += ASSISTANT_SP_TOKEN
- if not drop_thinking and thinking_mode == "thinking":
+ if not drop_thinking and thinking_mode == 'thinking':
prompt += thinking_start_token
- elif drop_thinking and thinking_mode == "thinking" and index >= last_user_idx:
+ elif drop_thinking and thinking_mode == 'thinking' and index >= last_user_idx:
prompt += thinking_start_token
else:
prompt += thinking_end_token
@@ -331,32 +330,32 @@ def merge_tool_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
for msg in messages:
msg = copy.deepcopy(msg)
- role = msg.get("role")
+ role = msg.get('role')
- if role == "tool":
+ if role == 'tool':
tool_block = {
- "type": "tool_result",
- "tool_use_id": msg.get("tool_call_id", ""),
- "content": msg.get("content", ""),
+ 'type': 'tool_result',
+ 'tool_use_id': msg.get('tool_call_id', ''),
+ 'content': msg.get('content', ''),
}
- if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1]:
- merged[-1]["content_blocks"].append(tool_block)
+ if merged and merged[-1].get('role') == 'user' and 'content_blocks' in merged[-1]:
+ merged[-1]['content_blocks'].append(tool_block)
else:
merged.append({
- "role": "user",
- "content_blocks": [tool_block],
+ 'role': 'user',
+ 'content_blocks': [tool_block],
})
- elif role == "user":
- text_block = {"type": "text", "text": msg.get("content", "")}
- if merged and merged[-1].get("role") == "user" and "content_blocks" in merged[-1] and merged[-1].get("task") is None:
- merged[-1]["content_blocks"].append(text_block)
+ elif role == 'user':
+ text_block = {'type': 'text', 'text': msg.get('content', '')}
+ if merged and merged[-1].get('role') == 'user' and 'content_blocks' in merged[-1] and merged[-1].get('task') is None:
+ merged[-1]['content_blocks'].append(text_block)
else:
new_msg = {
- "role": "user",
- "content": msg.get("content", ""),
- "content_blocks": [text_block],
+ 'role': 'user',
+ 'content': msg.get('content', ''),
+ 'content_blocks': [text_block],
}
- for key in ("task", "wo_eos", "mask"):
+ for key in ('task', 'wo_eos', 'mask'):
if key in msg:
new_msg[key] = msg[key]
merged.append(new_msg)
@@ -370,30 +369,30 @@ def sort_tool_results_by_call_order(messages: List[Dict[str, Any]]) -> List[Dict
last_tool_call_order: Dict[str, int] = {}
for msg in messages:
- role = msg.get("role")
- if role == "assistant" and msg.get("tool_calls"):
+ role = msg.get('role')
+ if role == 'assistant' and msg.get('tool_calls'):
last_tool_call_order = {}
- for idx, tc in enumerate(msg["tool_calls"]):
- tc_id = tc.get("id") or tc.get("function", {}).get("id", "")
+ for idx, tc in enumerate(msg['tool_calls']):
+ tc_id = tc.get('id') or tc.get('function', {}).get('id', '')
if tc_id:
last_tool_call_order[tc_id] = idx
- elif role == "user" and msg.get("content_blocks"):
- tool_blocks = [b for b in msg["content_blocks"] if b.get("type") == "tool_result"]
+ elif role == 'user' and msg.get('content_blocks'):
+ tool_blocks = [b for b in msg['content_blocks'] if b.get('type') == 'tool_result']
if len(tool_blocks) > 1 and last_tool_call_order:
sorted_blocks = sorted(
tool_blocks,
- key=lambda b: last_tool_call_order.get(b.get("tool_use_id", ""), 0)
+ key=lambda b: last_tool_call_order.get(b.get('tool_use_id', ''), 0)
)
sorted_idx = 0
new_blocks = []
- for block in msg["content_blocks"]:
- if block.get("type") == "tool_result":
+ for block in msg['content_blocks']:
+ if block.get('type') == 'tool_result':
new_blocks.append(sorted_blocks[sorted_idx])
sorted_idx += 1
else:
new_blocks.append(block)
- msg["content_blocks"] = new_blocks
+ msg['content_blocks'] = new_blocks
return messages
@@ -416,13 +415,13 @@ def encode_messages(
full_messages = context + messages
- prompt = bos_token if add_default_bos_token and len(context) == 0 else ""
+ prompt = bos_token if add_default_bos_token and len(context) == 0 else ''
effective_drop_thinking = drop_thinking
- if any(m.get("tools") for m in full_messages):
+ if any(m.get('tools') for m in full_messages):
effective_drop_thinking = False
- if thinking_mode == "thinking" and effective_drop_thinking:
+ if thinking_mode == 'thinking' and effective_drop_thinking:
full_messages = _drop_thinking_messages(full_messages)
num_to_render = len(full_messages) - len(_drop_thinking_messages(context))
context_len = len(full_messages) - num_to_render
@@ -445,15 +444,15 @@ def encode_messages(
def _drop_thinking_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
last_user_idx = find_last_user_index(messages)
result = []
- keep_roles = {"user", "system", "tool", "latest_reminder", "direct_search_results"}
+ keep_roles = {'user', 'system', 'tool', 'latest_reminder', 'direct_search_results'}
for idx, msg in enumerate(messages):
- role = msg.get("role")
+ role = msg.get('role')
if role in keep_roles or idx >= last_user_idx:
result.append(msg)
- elif role == "assistant":
+ elif role == 'assistant':
msg = copy.copy(msg)
- msg.pop("reasoning", None)
+ msg.pop('reasoning', None)
result.append(msg)
return result
@@ -480,20 +479,20 @@ def _read_until_stop(index: int, text: str, stop: List[str]) -> Tuple[int, str,
def parse_tool_calls(index: int, text: str) -> Tuple[int, Optional[str], List[Dict[str, str]]]:
tool_calls: List[Dict[str, Any]] = []
stop_token = None
- tool_calls_end_token = f"{dsml_token}{tool_calls_block_name}>"
+ tool_calls_end_token = f'{dsml_token}{tool_calls_block_name}>'
while index < len(text):
- index, content_before, stop_token = _read_until_stop(index, text, [f"<{dsml_token}invoke", tool_calls_end_token])
- if content_before != ">\n":
+ index, content_before, stop_token = _read_until_stop(index, text, [f'<{dsml_token}invoke', tool_calls_end_token])
+ if content_before != '>\n':
raise ValueError(f"Tool call format error: expected '>\\n' but got '{content_before}'")
if stop_token == tool_calls_end_token:
break
if stop_token is None:
- raise ValueError("Missing special token in tool calls")
+ raise ValueError('Missing special token in tool calls')
- index, tool_name_content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"{dsml_token}invoke"])
+ index, tool_name_content, stop_token = _read_until_stop(index, text, [f'<{dsml_token}parameter', f'{dsml_token}invoke'])
p_tool_name = re.findall(r'^\s*name="(.*?)">\n$', tool_name_content, flags=re.DOTALL)
if len(p_tool_name) != 1:
@@ -501,8 +500,8 @@ def parse_tool_calls(index: int, text: str) -> Tuple[int, Optional[str], List[Di
tool_name = p_tool_name[0]
tool_args: Dict[str, Tuple[str, str]] = {}
- while stop_token == f"<{dsml_token}parameter":
- index, param_content, stop_token = _read_until_stop(index, text, [f"/{dsml_token}parameter"])
+ while stop_token == f'<{dsml_token}parameter':
+ index, param_content, stop_token = _read_until_stop(index, text, [f'/{dsml_token}parameter'])
param_kv = re.findall(r'^ name="(.*?)" string="(true|false)">(.*?)<$', param_content, flags=re.DOTALL)
if len(param_kv) != 1:
@@ -513,8 +512,8 @@ def parse_tool_calls(index: int, text: str) -> Tuple[int, Optional[str], List[Di
raise ValueError(f"Duplicate parameter name: '{param_name}'")
tool_args[param_name] = (param_value, string)
- index, content, stop_token = _read_until_stop(index, text, [f"<{dsml_token}parameter", f"{dsml_token}invoke"])
- if content != ">\n":
+ index, content, stop_token = _read_until_stop(index, text, [f'<{dsml_token}parameter', f'{dsml_token}invoke'])
+ if content != '>\n':
raise ValueError(f"Parameter format error: expected '>\\n' but got '{content}'")
tool_call = decode_dsml_to_arguments(tool_name=tool_name, tool_args=tool_args)
@@ -524,19 +523,19 @@ def parse_tool_calls(index: int, text: str) -> Tuple[int, Optional[str], List[Di
def parse_message_from_completion_text(text: str, thinking_mode: str) -> Dict[str, Any]:
- summary_content, reasoning = "", ""
+ summary_content, reasoning = '', ''
tool_calls: List[Dict[str, str]] = []
index, stop_token = 0, None
- tool_calls_start_token = f"\n\n<{dsml_token}{tool_calls_block_name}"
+ tool_calls_start_token = f'\n\n<{dsml_token}{tool_calls_block_name}'
- is_thinking = thinking_mode == "thinking"
+ is_thinking = thinking_mode == 'thinking'
is_tool_calling = False
if is_thinking:
index, content_delta, stop_token = _read_until_stop(index, text, [thinking_end_token, tool_calls_start_token])
reasoning = content_delta
if stop_token != thinking_end_token:
- raise ValueError("Invalid thinking format: missing ")
+ raise ValueError('Invalid thinking format: missing ')
index, content_delta, stop_token = _read_until_stop(index, text, [eos_token, tool_calls_start_token])
summary_content = content_delta
@@ -544,27 +543,27 @@ def parse_message_from_completion_text(text: str, thinking_mode: str) -> Dict[st
is_tool_calling = True
else:
if stop_token != eos_token:
- raise ValueError("Invalid format: missing EOS token")
+ raise ValueError('Invalid format: missing EOS token')
if is_tool_calling:
index, stop_token, tool_calls = parse_tool_calls(index, text)
index, tool_ends_text, stop_token = _read_until_stop(index, text, [eos_token])
if tool_ends_text:
- raise ValueError("Unexpected content after tool calls")
+ raise ValueError('Unexpected content after tool calls')
if len(text) != index or stop_token not in [eos_token, None]:
- raise ValueError("Unexpected content at end")
+ raise ValueError('Unexpected content at end')
for sp_token in [bos_token, eos_token, thinking_start_token, thinking_end_token, dsml_token]:
if sp_token in summary_content or sp_token in reasoning:
raise ValueError(f"Unexpected special token '{sp_token}' in content")
return {
- "role": "assistant",
- "content": summary_content,
- "reasoning": reasoning,
- "tool_calls": tool_calls_to_openai_format(tool_calls)
+ 'role': 'assistant',
+ 'content': summary_content,
+ 'reasoning': reasoning,
+ 'tool_calls': tool_calls_to_openai_format(tool_calls)
}
# fmt: on
From 0e00035bf1ff7bca6d904e2e87cdebf808d585ec Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Fri, 8 May 2026 10:18:14 +0800
Subject: [PATCH 05/40] refactor: remove sync_after_backward mechanism from
expert parallel
The sync_after_backward feature and its associated configuration have been removed as they are no longer needed for maintaining EP/FSDP collective ordering. This simplifies the codebase by eliminating unused synchronization logic.
---
.../model/transformers/moe/expert_parallel.py | 2 +-
.../model/transformers/transformers.py | 20 +++++++++----------
2 files changed, 11 insertions(+), 11 deletions(-)
diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py
index 79c256ff..a0c3cc78 100644
--- a/src/twinkle/model/transformers/moe/expert_parallel.py
+++ b/src/twinkle/model/transformers/moe/expert_parallel.py
@@ -19,7 +19,7 @@ class ExpertParallelConfig:
router_dtype: str = 'fp32'
keep_router_logits: bool = True
ignore_shared_experts: bool = False
- sync_after_backward: bool = True # consumed by TransformersModel to keep EP/FSDP collectives ordered
+ # sync_after_backward: bool = True # consumed by TransformersModel to keep EP/FSDP collectives ordered
ep_size: int | None = None # consumed by TransformersModel, not used in expert_parallel logic
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 108e9bef..4d062384 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -559,7 +559,7 @@ def backward(self, **kwargs):
else:
loss_value.backward()
- self._sync_after_backward_if_needed()
+ # self._sync_after_backward_if_needed()
optimizer_config.train_status.loss_value = None
@remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
@@ -582,15 +582,15 @@ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Tr
self.backward(**kwargs)
return outputs
- def _sync_after_backward_if_needed(self) -> None:
- if not getattr(self, '_enable_expert_parallel', False):
- return
- expert_parallel_config = getattr(self, '_expert_parallel_config', None) or {}
- if not expert_parallel_config.get('sync_after_backward', True):
- return
- torch_util.synchronize()
- if dist.is_available() and dist.is_initialized():
- dist.barrier()
+ # def _sync_after_backward_if_needed(self) -> None:
+ # if not getattr(self, '_enable_expert_parallel', False):
+ # return
+ # expert_parallel_config = getattr(self, '_expert_parallel_config', None) or {}
+ # if not expert_parallel_config.get('sync_after_backward', True):
+ # return
+ # torch_util.synchronize()
+ # if dist.is_available() and dist.is_initialized():
+ # dist.barrier()
@remote_function()
def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs):
From 03289b53b1206cdb7405016b0b8c6e861aae9e7c Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Mon, 11 May 2026 09:54:13 +0800
Subject: [PATCH 06/40] feat: add configurable distributed timeout via
TWINKLE_DIST_TIMEOUT_SECONDS
Add support for customizing the distributed process group timeout through the environment variable `TWINKLE_DIST_TIMEOUT_SECONDS`, defaulting to 7200 seconds (2 hours). This prevents timeout errors in long-running distributed training jobs.
---
src/twinkle/model/base.py | 2 ++
src/twinkle/model/transformers/strategy/accelerate.py | 6 ++++++
2 files changed, 8 insertions(+)
diff --git a/src/twinkle/model/base.py b/src/twinkle/model/base.py
index e9c980b9..ec7b0228 100644
--- a/src/twinkle/model/base.py
+++ b/src/twinkle/model/base.py
@@ -1,6 +1,7 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import os
from abc import ABC, abstractmethod
+from datetime import timedelta
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
from twinkle import Platform, torch_util
@@ -164,6 +165,7 @@ def _try_init_process_group(self):
'init_method': 'env://',
'rank': Platform.get_rank(),
'world_size': Platform.get_world_size(),
+ 'timeout': timedelta(seconds=int(os.environ.get('TWINKLE_DIST_TIMEOUT_SECONDS', '7200'))),
}
if self._should_bind_device_id_for_process_group(backend):
init_kwargs['device_id'] = torch.device(Platform.get_local_device())
diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py
index 8d31291a..1e4161a9 100644
--- a/src/twinkle/model/transformers/strategy/accelerate.py
+++ b/src/twinkle/model/transformers/strategy/accelerate.py
@@ -1,4 +1,6 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
+import os
+from datetime import timedelta
from typing import Any, Dict, Literal, Optional
from twinkle import DeviceMesh
@@ -24,6 +26,7 @@ def __init__(
memory_efficient_init: bool = False,
):
from accelerate import Accelerator
+ from accelerate.utils import InitProcessGroupKwargs
self.device_mesh = device_mesh
self.mixed_precision = mixed_precision
@@ -32,6 +35,9 @@ def __init__(
fsdp_plugin = self._fsdp_config_from_device_mesh(device_mesh, fsdp_config, memory_efficient_init)
kwargs_handlers = []
+ kwargs_handlers.append(
+ InitProcessGroupKwargs(timeout=timedelta(seconds=int(os.environ.get('TWINKLE_DIST_TIMEOUT_SECONDS', '7200'))))
+ )
if ddp_config is not None:
from accelerate import DistributedDataParallelKwargs
ddp_config = DistributedDataParallelKwargs(**ddp_config)
From 26433b8404f064cd57e3b1473df71538a4c2312f Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Mon, 11 May 2026 11:21:19 +0800
Subject: [PATCH 07/40] fix: patch accelerate FSDP2 state dict loading to
handle unsharded buffers
Some Transformers models keep persistent buffers in state_dict that are not DTensors. The original accelerate FSDP2 load function assumed all entries have device_mesh, causing failures. This patch adds a monkey-patch to handle both DTensor parameters and regular tensor buffers during state dict loading.
---
.../model/transformers/strategy/accelerate.py | 76 +++++++++++++++++++
1 file changed, 76 insertions(+)
diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py
index 1e4161a9..3d3dbd9e 100644
--- a/src/twinkle/model/transformers/strategy/accelerate.py
+++ b/src/twinkle/model/transformers/strategy/accelerate.py
@@ -7,6 +7,80 @@
from .load_context import fsdp_pretrained_load_context
+def _patch_accelerate_fsdp2_load_full_state_dict():
+ """Allow Accelerate FSDP2 state-dict loading to handle unsharded buffers.
+
+ Some Transformers models keep persistent buffers in `state_dict`. FSDP2
+ shards parameters as DTensors, but those buffers can remain ordinary
+ tensors; older Accelerate versions assume every state-dict entry has
+ `device_mesh` and fail on such buffers.
+ """
+ import torch
+ import torch.distributed as dist
+ import accelerate.utils.fsdp_utils as fsdp_utils
+ from torch.distributed.tensor import DTensor, distribute_tensor
+
+ if getattr(fsdp_utils.fsdp2_load_full_state_dict, '_twinkle_patched', False):
+ return
+
+ original = fsdp_utils.fsdp2_load_full_state_dict
+
+ def patched_fsdp2_load_full_state_dict(accelerator, model, full_sd, cpu_offload=False):
+ meta_sharded_sd = model.state_dict()
+ sharded_sd = {}
+
+ for param_name, sharded_param in meta_sharded_sd.items():
+ if isinstance(sharded_param, DTensor):
+ device_mesh = sharded_param.device_mesh
+ current_placement = sharded_param.placements
+
+ if accelerator.is_main_process:
+ full_param = full_sd[param_name].detach().to(accelerator.device)
+ if full_param.is_floating_point():
+ old_param = model.get_parameter_or_buffer(param_name)
+ full_param = full_param.to(old_param.dtype)
+ if old_param.is_contiguous():
+ full_param = full_param.contiguous()
+ else:
+ full_param = torch.empty(
+ sharded_param.size(),
+ device=accelerator.device,
+ dtype=sharded_param.dtype,
+ )
+
+ dist.broadcast(full_param, src=0)
+ sharded_param = distribute_tensor(full_param, device_mesh, current_placement)
+ if cpu_offload:
+ sharded_param = sharded_param.cpu()
+ sharded_sd[param_name] = sharded_param
+ continue
+
+ if accelerator.is_main_process:
+ full_value = full_sd[param_name]
+ if isinstance(full_value, DTensor):
+ full_value = full_value.to_local()
+ full_value = full_value.detach().to(accelerator.device)
+ if full_value.is_floating_point():
+ full_value = full_value.to(sharded_param.dtype)
+ else:
+ full_value = torch.empty(
+ sharded_param.size(),
+ device=accelerator.device,
+ dtype=sharded_param.dtype,
+ )
+
+ dist.broadcast(full_value, src=0)
+ if cpu_offload:
+ full_value = full_value.cpu()
+ sharded_sd[param_name] = full_value
+
+ model.load_state_dict(sharded_sd, assign=True)
+
+ patched_fsdp2_load_full_state_dict._twinkle_patched = True
+ patched_fsdp2_load_full_state_dict._twinkle_original = original
+ fsdp_utils.fsdp2_load_full_state_dict = patched_fsdp2_load_full_state_dict
+
+
class AccelerateStrategy:
"""A training strategy that uses `accelerate` to wrap models.
@@ -28,6 +102,8 @@ def __init__(
from accelerate import Accelerator
from accelerate.utils import InitProcessGroupKwargs
+ _patch_accelerate_fsdp2_load_full_state_dict()
+
self.device_mesh = device_mesh
self.mixed_precision = mixed_precision
self._memory_efficient_init = memory_efficient_init
From 9e7efb0f97eeb824b67f03cbf2bdf2896dd5be22 Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Mon, 11 May 2026 11:43:47 +0800
Subject: [PATCH 08/40] fix: handle missing parameter in FSDP2 load state dict
patch
When loading full state dict in FSDP2, the parameter might not exist in the model (e.g., when using tied weights or shared parameters). This change adds a try-except to handle AttributeError when getting the parameter, and only applies contiguous conversion when the parameter exists and is contiguous.
---
src/twinkle/model/transformers/strategy/accelerate.py | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py
index 3d3dbd9e..4049ab95 100644
--- a/src/twinkle/model/transformers/strategy/accelerate.py
+++ b/src/twinkle/model/transformers/strategy/accelerate.py
@@ -37,9 +37,12 @@ def patched_fsdp2_load_full_state_dict(accelerator, model, full_sd, cpu_offload=
if accelerator.is_main_process:
full_param = full_sd[param_name].detach().to(accelerator.device)
if full_param.is_floating_point():
- old_param = model.get_parameter_or_buffer(param_name)
- full_param = full_param.to(old_param.dtype)
- if old_param.is_contiguous():
+ full_param = full_param.to(sharded_param.dtype)
+ try:
+ old_param = model.get_parameter_or_buffer(param_name)
+ except AttributeError:
+ old_param = None
+ if old_param is not None and old_param.is_contiguous():
full_param = full_param.contiguous()
else:
full_param = torch.empty(
From 4464251937572e6e65b1c8727f8059922281f0de Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Mon, 11 May 2026 12:47:01 +0800
Subject: [PATCH 09/40] feat(cookbook): add environment variable parsing
utilities and enhance FSDP configuration
Add helper functions for parsing boolean, torch dtype, and lora target modules from environment variables. Introduce new config options for low CPU memory usage, memory efficient init, model dtype, and flexible FSDP/DP/EP sizing. Improve accelerate FSDP2 state dict loading with parameter dtype inference and contiguous casting.
---
.../model/transformers/strategy/accelerate.py | 104 +++++++++++++-----
1 file changed, 79 insertions(+), 25 deletions(-)
diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py
index 4049ab95..de6b68e3 100644
--- a/src/twinkle/model/transformers/strategy/accelerate.py
+++ b/src/twinkle/model/transformers/strategy/accelerate.py
@@ -18,7 +18,7 @@ def _patch_accelerate_fsdp2_load_full_state_dict():
import torch
import torch.distributed as dist
import accelerate.utils.fsdp_utils as fsdp_utils
- from torch.distributed.tensor import DTensor, distribute_tensor
+ from torch.distributed.tensor import DTensor, Partial, Replicate, Shard
if getattr(fsdp_utils.fsdp2_load_full_state_dict, '_twinkle_patched', False):
return
@@ -29,42 +29,93 @@ def patched_fsdp2_load_full_state_dict(accelerator, model, full_sd, cpu_offload=
meta_sharded_sd = model.state_dict()
sharded_sd = {}
+ def _infer_parameter_dtype(model, param_name, empty_param):
+ try:
+ old_param = model.get_parameter_or_buffer(param_name)
+ except AttributeError:
+ # Need this for LoRA, as some params are not registered as
+ # parameters/buffers but still appear in the state dict.
+ base_param_name, local_param_name = param_name.rsplit('.', 1)
+ submodule = model.get_submodule(base_param_name)
+ old_param = getattr(submodule, local_param_name)
+
+ is_torch_e4m3fn_available = hasattr(torch, 'float8_e4m3fn')
+ is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
+ casting_dtype = None
+ if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
+ casting_dtype = old_param.dtype
+ return old_param is not None and old_param.is_contiguous(), casting_dtype
+
+ def _cast_and_contiguous(tensor, to_contiguous, dtype):
+ if dtype is not None:
+ tensor = tensor.to(dtype=dtype)
+ if to_contiguous:
+ tensor = tensor.contiguous()
+ return tensor
+
+ def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements):
+ local_tensor = full_tensor
+ for mesh_dim, placement in enumerate(placements):
+ if isinstance(placement, Shard):
+ # All ranks already received the full tensor via broadcast.
+ # Split locally to avoid distribute_tensor's scatter path,
+ # which is fragile on some torch_npu/HCCL versions.
+ local_tensor = placement._shard_tensor(
+ local_tensor,
+ device_mesh,
+ mesh_dim,
+ src_data_rank=None,
+ )
+ elif isinstance(placement, Replicate):
+ continue
+ elif isinstance(placement, Partial):
+ raise NotImplementedError('FSDP2 full-state loading does not support Partial placements.')
+ else:
+ raise NotImplementedError(f'Unsupported DTensor placement: {placement}')
+ return DTensor.from_local(
+ local_tensor,
+ device_mesh=device_mesh,
+ placements=placements,
+ run_check=False,
+ shape=full_tensor.shape,
+ stride=full_tensor.stride(),
+ )
+
+ def _load_full_value(param_name, sharded_param):
+ if param_name not in full_sd:
+ raise KeyError(
+ f"Parameter '{param_name}' found in sharded model state dict but missing from full state dict. "
+ f'Full state dict has {len(full_sd)} keys, sharded has {len(meta_sharded_sd)} keys.')
+ full_value = full_sd[param_name].detach()
+ if isinstance(full_value, DTensor):
+ full_value = full_value.to_local()
+ device = sharded_param.device_mesh.device_type if isinstance(sharded_param, DTensor) else accelerator.device
+ return full_value.to(device).contiguous()
+
for param_name, sharded_param in meta_sharded_sd.items():
if isinstance(sharded_param, DTensor):
device_mesh = sharded_param.device_mesh
- current_placement = sharded_param.placements
-
+ placements = sharded_param.placements
if accelerator.is_main_process:
- full_param = full_sd[param_name].detach().to(accelerator.device)
- if full_param.is_floating_point():
- full_param = full_param.to(sharded_param.dtype)
- try:
- old_param = model.get_parameter_or_buffer(param_name)
- except AttributeError:
- old_param = None
- if old_param is not None and old_param.is_contiguous():
- full_param = full_param.contiguous()
+ full_param = _load_full_value(param_name, sharded_param)
else:
full_param = torch.empty(
sharded_param.size(),
- device=accelerator.device,
+ device=device_mesh.device_type,
dtype=sharded_param.dtype,
)
- dist.broadcast(full_param, src=0)
- sharded_param = distribute_tensor(full_param, device_mesh, current_placement)
+ dist.broadcast(full_param, src=0, group=dist.group.WORLD)
+ sharded_tensor = _dtensor_from_replicated_full_tensor(full_param, device_mesh, placements)
+ to_contiguous, casting_dtype = _infer_parameter_dtype(model, param_name, full_param)
+ sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)
if cpu_offload:
- sharded_param = sharded_param.cpu()
- sharded_sd[param_name] = sharded_param
+ sharded_tensor = sharded_tensor.to('cpu')
+ sharded_sd[param_name] = sharded_tensor
continue
if accelerator.is_main_process:
- full_value = full_sd[param_name]
- if isinstance(full_value, DTensor):
- full_value = full_value.to_local()
- full_value = full_value.detach().to(accelerator.device)
- if full_value.is_floating_point():
- full_value = full_value.to(sharded_param.dtype)
+ full_value = _load_full_value(param_name, sharded_param)
else:
full_value = torch.empty(
sharded_param.size(),
@@ -72,12 +123,15 @@ def patched_fsdp2_load_full_state_dict(accelerator, model, full_sd, cpu_offload=
dtype=sharded_param.dtype,
)
- dist.broadcast(full_value, src=0)
+ dist.broadcast(full_value, src=0, group=dist.group.WORLD)
+ to_contiguous, casting_dtype = _infer_parameter_dtype(model, param_name, full_value)
+ full_value = _cast_and_contiguous(full_value, to_contiguous, casting_dtype)
if cpu_offload:
- full_value = full_value.cpu()
+ full_value = full_value.to('cpu')
sharded_sd[param_name] = full_value
model.load_state_dict(sharded_sd, assign=True)
+ return model
patched_fsdp2_load_full_state_dict._twinkle_patched = True
patched_fsdp2_load_full_state_dict._twinkle_original = original
From 715df925e3aafae4fb61bb14f35f5d29e2340a5d Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Mon, 11 May 2026 14:36:57 +0800
Subject: [PATCH 10/40] feat: add FSDP debug logging and early return for CUDA
devices
Add debug logging to FSDP2 state dict loading and model preparation, controlled by TWINKLE_FSDP_DEBUG environment variable. Also optimize by returning early with original implementation when running on CUDA devices.
---
.../model/transformers/strategy/accelerate.py | 23 ++++++++++++++++++-
1 file changed, 22 insertions(+), 1 deletion(-)
diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py
index de6b68e3..34198ef7 100644
--- a/src/twinkle/model/transformers/strategy/accelerate.py
+++ b/src/twinkle/model/transformers/strategy/accelerate.py
@@ -26,6 +26,12 @@ def _patch_accelerate_fsdp2_load_full_state_dict():
original = fsdp_utils.fsdp2_load_full_state_dict
def patched_fsdp2_load_full_state_dict(accelerator, model, full_sd, cpu_offload=False):
+ _fsdp_debug(f'enter fsdp2_load_full_state_dict device={accelerator.device}')
+ if accelerator.device.type == 'cuda':
+ result = original(accelerator, model, full_sd, cpu_offload=cpu_offload)
+ _fsdp_debug('exit original fsdp2_load_full_state_dict')
+ return result
+
meta_sharded_sd = model.state_dict()
sharded_sd = {}
@@ -131,6 +137,7 @@ def _load_full_value(param_name, sharded_param):
sharded_sd[param_name] = full_value
model.load_state_dict(sharded_sd, assign=True)
+ _fsdp_debug('exit patched fsdp2_load_full_state_dict')
return model
patched_fsdp2_load_full_state_dict._twinkle_patched = True
@@ -138,6 +145,17 @@ def _load_full_value(param_name, sharded_param):
fsdp_utils.fsdp2_load_full_state_dict = patched_fsdp2_load_full_state_dict
+def _fsdp_debug(message: str) -> None:
+ if os.environ.get('TWINKLE_FSDP_DEBUG', '0') != '1':
+ return
+ try:
+ import torch.distributed as dist
+ rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
+ except Exception:
+ rank = 0
+ print(f'[twinkle-fsdp-debug][rank{rank}] {message}', flush=True)
+
+
class AccelerateStrategy:
"""A training strategy that uses `accelerate` to wrap models.
@@ -258,7 +276,10 @@ def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Di
return fsdp_plugin
def wrap_model(self, model, *args):
- return self.accelerator.prepare(model, *args)
+ _fsdp_debug('enter accelerator.prepare')
+ result = self.accelerator.prepare(model, *args)
+ _fsdp_debug('exit accelerator.prepare')
+ return result
def unwrap_model(self, model):
return self.accelerator.unwrap_model(model, keep_torch_compile=False)
From 7ef149905ec30ea5bcd280a346599666d90f36f9 Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Mon, 11 May 2026 14:51:52 +0800
Subject: [PATCH 11/40] wip
---
cookbook/transformers/deepseek_v4.py | 34 ++++++++++++++++++++--------
1 file changed, 24 insertions(+), 10 deletions(-)
diff --git a/cookbook/transformers/deepseek_v4.py b/cookbook/transformers/deepseek_v4.py
index 6e6f2075..13f75721 100644
--- a/cookbook/transformers/deepseek_v4.py
+++ b/cookbook/transformers/deepseek_v4.py
@@ -1,5 +1,6 @@
import os
+import torch.distributed as dist
import twinkle
from peft import LoraConfig
from transformers import AutoConfig
@@ -16,7 +17,7 @@
TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template')
OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output')
-_num_layers_env = os.environ.get('NUM_LAYERS','4')
+_num_layers_env = os.environ.get('NUM_LAYERS','1')
NUM_LAYERS = int(_num_layers_env) if _num_layers_env is not None else None
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4'))
@@ -43,6 +44,16 @@
twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+def barrier_if_distributed(stage: str):
+ if not (dist.is_available() and dist.is_initialized()):
+ return
+ if os.environ.get('TWINKLE_FSDP_DEBUG', '0') == '1':
+ logger.info(f'[rank{dist.get_rank()}] before barrier: {stage}')
+ dist.barrier()
+ if os.environ.get('TWINKLE_FSDP_DEBUG', '0') == '1':
+ logger.info(f'[rank{dist.get_rank()}] after barrier: {stage}')
+
+
def log_expert_parallel_status(model):
logger.info(
f'EP flags: enabled={getattr(model, "_enable_expert_parallel", None)}, '
@@ -98,16 +109,17 @@ def train():
model_id=MODEL_ID,
config=config,
device_mesh=device_mesh,
- strategy='native_fsdp',
+ strategy="accelerate",
+ memory_efficient_init=True,
ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES,
- fsdp_config={
- 'reshard_after_forward': RESHARD_AFTER_FORWARD,
- 'expert_parallel': {
- 'enabled': True,
- 'router_dtype': 'fp32',
- 'keep_router_logits': False,
- }
- },
+ # fsdp_config={
+ # 'reshard_after_forward': RESHARD_AFTER_FORWARD,
+ # 'expert_parallel': {
+ # 'enabled': True,
+ # 'router_dtype': 'fp32',
+ # 'keep_router_logits': False,
+ # }
+ # },
)
if USE_LORA:
@@ -137,6 +149,8 @@ def train():
f'reshard_after_forward={RESHARD_AFTER_FORWARD}, '
f'lora_target_modules={LORA_TARGET_MODULES}')
+ barrier_if_distributed('before first train step')
+
best_loss = float('inf')
for step, batch in enumerate(dataloader):
if MAX_STEPS and step >= MAX_STEPS:
From d087a0e54035c51e6f5011042b5479487b01f3d4 Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Mon, 11 May 2026 15:23:16 +0800
Subject: [PATCH 12/40] feat: enhance FSDP debug logging with train_debug
utility and detailed state loading traces
- Replace logger.info with print in barrier_if_distributed for immediate output
- Add train_debug function with rank and local_rank info for training step debugging
- Add debug logs for forward_backward and clip_grad_and_step in first 2 steps
- Enhance fsdp2_load_full_state_dict patching with detailed per-parameter state loading traces
---
cookbook/transformers/deepseek_v4.py | 18 ++++++++++++++++--
.../model/transformers/strategy/accelerate.py | 15 +++++++++++++--
2 files changed, 29 insertions(+), 4 deletions(-)
diff --git a/cookbook/transformers/deepseek_v4.py b/cookbook/transformers/deepseek_v4.py
index 13f75721..c0eb4934 100644
--- a/cookbook/transformers/deepseek_v4.py
+++ b/cookbook/transformers/deepseek_v4.py
@@ -48,10 +48,18 @@ def barrier_if_distributed(stage: str):
if not (dist.is_available() and dist.is_initialized()):
return
if os.environ.get('TWINKLE_FSDP_DEBUG', '0') == '1':
- logger.info(f'[rank{dist.get_rank()}] before barrier: {stage}')
+ print(f'[twinkle-train-debug][rank{dist.get_rank()}] before barrier: {stage}', flush=True)
dist.barrier()
if os.environ.get('TWINKLE_FSDP_DEBUG', '0') == '1':
- logger.info(f'[rank{dist.get_rank()}] after barrier: {stage}')
+ print(f'[twinkle-train-debug][rank{dist.get_rank()}] after barrier: {stage}', flush=True)
+
+
+def train_debug(message: str):
+ if os.environ.get('TWINKLE_FSDP_DEBUG', '0') != '1':
+ return
+ rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else Platform.get_rank()
+ local_rank = Platform.get_local_rank()
+ print(f'[twinkle-train-debug][rank{rank} local_rank={local_rank}] {message}', flush=True)
def log_expert_parallel_status(model):
@@ -155,14 +163,20 @@ def train():
for step, batch in enumerate(dataloader):
if MAX_STEPS and step >= MAX_STEPS:
break
+ if step < 2:
+ train_debug(f'step={step} before forward_backward batch_keys={list(batch.keys())}')
model.forward_backward(
inputs=batch,
adapter_name='default',
)
+ if step < 2:
+ train_debug(f'step={step} after forward_backward')
model.clip_grad_and_step(
adapter_name='default',
gradient_accumulation_steps=GRAD_ACCUM_STEPS,
)
+ if step < 2:
+ train_debug(f'step={step} after clip_grad_and_step')
if step == 0:
log_expert_parallel_status(model)
diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py
index 34198ef7..9b1aef8a 100644
--- a/src/twinkle/model/transformers/strategy/accelerate.py
+++ b/src/twinkle/model/transformers/strategy/accelerate.py
@@ -26,14 +26,18 @@ def _patch_accelerate_fsdp2_load_full_state_dict():
original = fsdp_utils.fsdp2_load_full_state_dict
def patched_fsdp2_load_full_state_dict(accelerator, model, full_sd, cpu_offload=False):
- _fsdp_debug(f'enter fsdp2_load_full_state_dict device={accelerator.device}')
+ _fsdp_debug(
+ f'enter fsdp2_load_full_state_dict device={accelerator.device} '
+ f'full_sd_keys={len(full_sd) if full_sd is not None else "None"}')
if accelerator.device.type == 'cuda':
+ _fsdp_debug('delegating to original accelerate fsdp2_load_full_state_dict')
result = original(accelerator, model, full_sd, cpu_offload=cpu_offload)
_fsdp_debug('exit original fsdp2_load_full_state_dict')
return result
meta_sharded_sd = model.state_dict()
sharded_sd = {}
+ _fsdp_debug(f'patched fsdp2 meta_sharded_keys={len(meta_sharded_sd)}')
def _infer_parameter_dtype(model, param_name, empty_param):
try:
@@ -99,6 +103,7 @@ def _load_full_value(param_name, sharded_param):
return full_value.to(device).contiguous()
for param_name, sharded_param in meta_sharded_sd.items():
+ _fsdp_debug(f'load state entry start: {param_name}')
if isinstance(sharded_param, DTensor):
device_mesh = sharded_param.device_mesh
placements = sharded_param.placements
@@ -112,7 +117,9 @@ def _load_full_value(param_name, sharded_param):
)
dist.broadcast(full_param, src=0, group=dist.group.WORLD)
+ _fsdp_debug(f'broadcast done: {param_name}')
sharded_tensor = _dtensor_from_replicated_full_tensor(full_param, device_mesh, placements)
+ _fsdp_debug(f'local shard done: {param_name}')
to_contiguous, casting_dtype = _infer_parameter_dtype(model, param_name, full_param)
sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)
if cpu_offload:
@@ -130,6 +137,7 @@ def _load_full_value(param_name, sharded_param):
)
dist.broadcast(full_value, src=0, group=dist.group.WORLD)
+ _fsdp_debug(f'broadcast done: {param_name}')
to_contiguous, casting_dtype = _infer_parameter_dtype(model, param_name, full_value)
full_value = _cast_and_contiguous(full_value, to_contiguous, casting_dtype)
if cpu_offload:
@@ -151,9 +159,12 @@ def _fsdp_debug(message: str) -> None:
try:
import torch.distributed as dist
rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
+ world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
except Exception:
rank = 0
- print(f'[twinkle-fsdp-debug][rank{rank}] {message}', flush=True)
+ world_size = 1
+ local_rank = os.environ.get('LOCAL_RANK', '?')
+ print(f'[twinkle-fsdp-debug][rank{rank}/{world_size} local_rank={local_rank}] {message}', flush=True)
class AccelerateStrategy:
From d206828ca8a7a05a287e07fc4279ffbc4e57b9f3 Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Mon, 11 May 2026 15:44:48 +0800
Subject: [PATCH 13/40] wip
---
cookbook/transformers/deepseek_v4.py | 10 +++++++++-
1 file changed, 9 insertions(+), 1 deletion(-)
diff --git a/cookbook/transformers/deepseek_v4.py b/cookbook/transformers/deepseek_v4.py
index c0eb4934..10514c4b 100644
--- a/cookbook/transformers/deepseek_v4.py
+++ b/cookbook/transformers/deepseek_v4.py
@@ -62,6 +62,14 @@ def train_debug(message: str):
print(f'[twinkle-train-debug][rank{rank} local_rank={local_rank}] {message}', flush=True)
+def describe_batch(batch):
+ if isinstance(batch, dict):
+ return f'dict_keys={list(batch.keys())}'
+ if isinstance(batch, (list, tuple)):
+ return f'{type(batch).__name__}[len={len(batch)}]'
+ return type(batch).__name__
+
+
def log_expert_parallel_status(model):
logger.info(
f'EP flags: enabled={getattr(model, "_enable_expert_parallel", None)}, '
@@ -164,7 +172,7 @@ def train():
if MAX_STEPS and step >= MAX_STEPS:
break
if step < 2:
- train_debug(f'step={step} before forward_backward batch_keys={list(batch.keys())}')
+ train_debug(f'step={step} before forward_backward batch={describe_batch(batch)}')
model.forward_backward(
inputs=batch,
adapter_name='default',
From cb86c20d4d560e56855cf49888de2d7b8a8772dc Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Mon, 11 May 2026 15:50:08 +0800
Subject: [PATCH 14/40] wip
---
cookbook/transformers/deepseek_v4.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/cookbook/transformers/deepseek_v4.py b/cookbook/transformers/deepseek_v4.py
index 10514c4b..105f8b26 100644
--- a/cookbook/transformers/deepseek_v4.py
+++ b/cookbook/transformers/deepseek_v4.py
@@ -49,7 +49,10 @@ def barrier_if_distributed(stage: str):
return
if os.environ.get('TWINKLE_FSDP_DEBUG', '0') == '1':
print(f'[twinkle-train-debug][rank{dist.get_rank()}] before barrier: {stage}', flush=True)
- dist.barrier()
+ if dist.get_backend() == 'nccl':
+ dist.barrier(device_ids=[Platform.get_local_rank()])
+ else:
+ dist.barrier()
if os.environ.get('TWINKLE_FSDP_DEBUG', '0') == '1':
print(f'[twinkle-train-debug][rank{dist.get_rank()}] after barrier: {stage}', flush=True)
From 694379e5d94ff3f411eb9d669e613839ccb4546a Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Mon, 11 May 2026 16:09:00 +0800
Subject: [PATCH 15/40] refactor(debug): enhance debug logging with timestamps
and file output
- Rename `train_debug` to `debug_print` for clarity
- Add timestamp to debug messages for better tracing
- Support writing debug logs to file via `TWINKLE_DEBUG_DIR` env var
- Improve tensor debug info in FSDP2 state loading
- Remove redundant accelerate FSDP2 delegation path
---
cookbook/transformers/deepseek_v4.py | 22 +++++++++----
.../model/transformers/strategy/accelerate.py | 32 +++++++++++++++----
2 files changed, 40 insertions(+), 14 deletions(-)
diff --git a/cookbook/transformers/deepseek_v4.py b/cookbook/transformers/deepseek_v4.py
index 105f8b26..4f438c08 100644
--- a/cookbook/transformers/deepseek_v4.py
+++ b/cookbook/transformers/deepseek_v4.py
@@ -1,4 +1,5 @@
import os
+import time
import torch.distributed as dist
import twinkle
@@ -48,21 +49,28 @@ def barrier_if_distributed(stage: str):
if not (dist.is_available() and dist.is_initialized()):
return
if os.environ.get('TWINKLE_FSDP_DEBUG', '0') == '1':
- print(f'[twinkle-train-debug][rank{dist.get_rank()}] before barrier: {stage}', flush=True)
+ debug_print(f'before barrier: {stage}')
if dist.get_backend() == 'nccl':
dist.barrier(device_ids=[Platform.get_local_rank()])
else:
dist.barrier()
if os.environ.get('TWINKLE_FSDP_DEBUG', '0') == '1':
- print(f'[twinkle-train-debug][rank{dist.get_rank()}] after barrier: {stage}', flush=True)
+ debug_print(f'after barrier: {stage}')
-def train_debug(message: str):
+def debug_print(message: str):
if os.environ.get('TWINKLE_FSDP_DEBUG', '0') != '1':
return
rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else Platform.get_rank()
local_rank = Platform.get_local_rank()
- print(f'[twinkle-train-debug][rank{rank} local_rank={local_rank}] {message}', flush=True)
+ timestamp = time.time()
+ text = f'[twinkle-train-debug][time={timestamp:.6f} rank{rank} local_rank={local_rank}] {message}'
+ print(text, flush=True)
+ debug_dir = os.environ.get('TWINKLE_DEBUG_DIR')
+ if debug_dir:
+ os.makedirs(debug_dir, exist_ok=True)
+ with open(os.path.join(debug_dir, f'train_rank{rank}.log'), 'a', encoding='utf-8') as f:
+ f.write(text + '\n')
def describe_batch(batch):
@@ -175,19 +183,19 @@ def train():
if MAX_STEPS and step >= MAX_STEPS:
break
if step < 2:
- train_debug(f'step={step} before forward_backward batch={describe_batch(batch)}')
+ debug_print(f'step={step} before forward_backward batch={describe_batch(batch)}')
model.forward_backward(
inputs=batch,
adapter_name='default',
)
if step < 2:
- train_debug(f'step={step} after forward_backward')
+ debug_print(f'step={step} after forward_backward')
model.clip_grad_and_step(
adapter_name='default',
gradient_accumulation_steps=GRAD_ACCUM_STEPS,
)
if step < 2:
- train_debug(f'step={step} after clip_grad_and_step')
+ debug_print(f'step={step} after clip_grad_and_step')
if step == 0:
log_expert_parallel_status(model)
diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py
index 9b1aef8a..d4ec28ac 100644
--- a/src/twinkle/model/transformers/strategy/accelerate.py
+++ b/src/twinkle/model/transformers/strategy/accelerate.py
@@ -1,5 +1,6 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import os
+import time
from datetime import timedelta
from typing import Any, Dict, Literal, Optional
@@ -29,11 +30,6 @@ def patched_fsdp2_load_full_state_dict(accelerator, model, full_sd, cpu_offload=
_fsdp_debug(
f'enter fsdp2_load_full_state_dict device={accelerator.device} '
f'full_sd_keys={len(full_sd) if full_sd is not None else "None"}')
- if accelerator.device.type == 'cuda':
- _fsdp_debug('delegating to original accelerate fsdp2_load_full_state_dict')
- result = original(accelerator, model, full_sd, cpu_offload=cpu_offload)
- _fsdp_debug('exit original fsdp2_load_full_state_dict')
- return result
meta_sharded_sd = model.state_dict()
sharded_sd = {}
@@ -102,8 +98,23 @@ def _load_full_value(param_name, sharded_param):
device = sharded_param.device_mesh.device_type if isinstance(sharded_param, DTensor) else accelerator.device
return full_value.to(device).contiguous()
+ def _tensor_debug(tensor):
+ if isinstance(tensor, DTensor):
+ return (
+ f'type=DTensor shape={tuple(tensor.size())} dtype={tensor.dtype} '
+ f'placements={tensor.placements} mesh={tensor.device_mesh}')
+ if hasattr(tensor, 'size') and hasattr(tensor, 'dtype'):
+ return f'type={type(tensor).__name__} shape={tuple(tensor.size())} dtype={tensor.dtype}'
+ return f'type={type(tensor).__name__}'
+
for param_name, sharded_param in meta_sharded_sd.items():
- _fsdp_debug(f'load state entry start: {param_name}')
+ _fsdp_debug(f'load state entry start: {param_name} {_tensor_debug(sharded_param)}')
+ if accelerator.is_main_process:
+ full_value = full_sd.get(param_name)
+ if full_value is None:
+ _fsdp_debug(f'full state entry missing: {param_name}')
+ else:
+ _fsdp_debug(f'full state entry: {param_name} {_tensor_debug(full_value)}')
if isinstance(sharded_param, DTensor):
device_mesh = sharded_param.device_mesh
placements = sharded_param.placements
@@ -164,7 +175,14 @@ def _fsdp_debug(message: str) -> None:
rank = 0
world_size = 1
local_rank = os.environ.get('LOCAL_RANK', '?')
- print(f'[twinkle-fsdp-debug][rank{rank}/{world_size} local_rank={local_rank}] {message}', flush=True)
+ timestamp = time.time()
+ text = f'[twinkle-fsdp-debug][time={timestamp:.6f} rank{rank}/{world_size} local_rank={local_rank}] {message}'
+ print(text, flush=True)
+ debug_dir = os.environ.get('TWINKLE_DEBUG_DIR')
+ if debug_dir:
+ os.makedirs(debug_dir, exist_ok=True)
+ with open(os.path.join(debug_dir, f'fsdp_rank{rank}.log'), 'a', encoding='utf-8') as f:
+ f.write(text + '\n')
class AccelerateStrategy:
From 43bad86d58b37987c5f81b7c3199524af391837c Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Mon, 11 May 2026 18:27:22 +0800
Subject: [PATCH 16/40] wip
Co-authored-by: Copilot
---
cookbook/transformers/deepseek_v4.py | 15 ------
src/twinkle/model/base.py | 2 +-
.../model/transformers/transformers.py | 54 +++++++++++++++++++
3 files changed, 55 insertions(+), 16 deletions(-)
diff --git a/cookbook/transformers/deepseek_v4.py b/cookbook/transformers/deepseek_v4.py
index 4f438c08..1f38b9fc 100644
--- a/cookbook/transformers/deepseek_v4.py
+++ b/cookbook/transformers/deepseek_v4.py
@@ -45,19 +45,6 @@
twinkle.initialize(mode='local', global_device_mesh=device_mesh)
-def barrier_if_distributed(stage: str):
- if not (dist.is_available() and dist.is_initialized()):
- return
- if os.environ.get('TWINKLE_FSDP_DEBUG', '0') == '1':
- debug_print(f'before barrier: {stage}')
- if dist.get_backend() == 'nccl':
- dist.barrier(device_ids=[Platform.get_local_rank()])
- else:
- dist.barrier()
- if os.environ.get('TWINKLE_FSDP_DEBUG', '0') == '1':
- debug_print(f'after barrier: {stage}')
-
-
def debug_print(message: str):
if os.environ.get('TWINKLE_FSDP_DEBUG', '0') != '1':
return
@@ -176,8 +163,6 @@ def train():
f'reshard_after_forward={RESHARD_AFTER_FORWARD}, '
f'lora_target_modules={LORA_TARGET_MODULES}')
- barrier_if_distributed('before first train step')
-
best_loss = float('inf')
for step, batch in enumerate(dataloader):
if MAX_STEPS and step >= MAX_STEPS:
diff --git a/src/twinkle/model/base.py b/src/twinkle/model/base.py
index ec7b0228..303eaf74 100644
--- a/src/twinkle/model/base.py
+++ b/src/twinkle/model/base.py
@@ -144,7 +144,7 @@ def upload_to_hub(self,
HubOperation.push_to_hub(repo_id=hub_model_id, folder_path=checkpoint_dir, token=hub_token, private=True)
def _should_bind_device_id_for_process_group(self, backend: str) -> bool:
- return backend == 'hccl'
+ return backend in ('nccl', 'hccl')
def _try_init_process_group(self):
import torch
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 4d062384..7e64856c 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -7,6 +7,7 @@
import random
import re
import threading
+import time
import torch
import torch.distributed as dist
import transformers
@@ -46,6 +47,25 @@
logger = get_logger()
+def _twinkle_fsdp_debug(message: str) -> None:
+ if os.environ.get('TWINKLE_FSDP_DEBUG', '0') != '1':
+ return
+ try:
+ rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else int(os.environ.get('RANK', 0))
+ world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
+ except Exception:
+ rank = int(os.environ.get('RANK', 0))
+ world_size = 1
+ local_rank = os.environ.get('LOCAL_RANK', '?')
+ text = f'[twinkle-model-debug][time={time.time():.6f} rank{rank}/{world_size} local_rank={local_rank}] {message}'
+ print(text, flush=True)
+ debug_dir = os.environ.get('TWINKLE_DEBUG_DIR')
+ if debug_dir:
+ os.makedirs(debug_dir, exist_ok=True)
+ with open(os.path.join(debug_dir, f'model_rank{rank}.log'), 'a', encoding='utf-8') as f:
+ f.write(text + '\n')
+
+
@dataclass
class OptimizerGroup(BaseOptimizerGroup):
"""Optimizer group for Transformers training."""
@@ -163,7 +183,9 @@ def __init__(
memory_efficient_init: bool = False,
**kwargs):
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
+ _twinkle_fsdp_debug('TransformersModel init before process_group')
self._try_init_process_group()
+ _twinkle_fsdp_debug('TransformersModel init after process_group')
super(PreTrainedModel, self).__init__()
# The Default tokenizer will be used to save with a model if no template was set.
self._default_tokenizer = None
@@ -173,9 +195,14 @@ def __init__(
self._ddp_config = ddp_config or {}
self._memory_efficient_init = memory_efficient_init
self._decide_strategy(strategy)
+ _twinkle_fsdp_debug(
+ f'TransformersModel strategy decided strategy={strategy} '
+ f'memory_efficient_init={memory_efficient_init}')
self.grad_scaler_config = grad_scaler_config
if model_id is not None:
+ _twinkle_fsdp_debug(f'before HubOperation.download_model model_id={model_id}')
model_id = HubOperation.download_model(model_id)
+ _twinkle_fsdp_debug(f'after HubOperation.download_model model_id={model_id}')
self.model_id = model_id
self.tokenizer_id = kwargs.get('tokenizer_id', self.model_id)
if config is None:
@@ -190,12 +217,20 @@ def __init__(
if isinstance(model_cls, str):
model_cls = getattr(transformers, model_cls)
if model_id is None:
+ _twinkle_fsdp_debug('before model_cls.from_config')
self.model = model_cls.from_config(self.hf_config, **kwargs)
+ _twinkle_fsdp_debug('after model_cls.from_config')
else:
# Trigger transformers' FSDP-aware loading: meta-device init + rank-0-only weight load.
+ _twinkle_fsdp_debug('before pretrained_load_context')
with self.strategy.pretrained_load_context():
+ _twinkle_fsdp_debug('before model_cls.from_pretrained')
self.model = model_cls.from_pretrained(model_id, config=self.hf_config, **kwargs)
+ _twinkle_fsdp_debug('after model_cls.from_pretrained')
+ _twinkle_fsdp_debug('after pretrained_load_context')
+ _twinkle_fsdp_debug('before gradient_checkpointing_enable')
self.model.gradient_checkpointing_enable()
+ _twinkle_fsdp_debug('after gradient_checkpointing_enable')
self.sp_strategy = None
self._model_wrapped = False
self.optimizer_group: Dict[str, OptimizerGroup] = {
@@ -265,27 +300,40 @@ def _not_encoded(inputs):
def _lazy_wrap_model(self):
if not self._model_wrapped:
+ _twinkle_fsdp_debug('enter _lazy_wrap_model')
optimizer_groups = [og for og in self.optimizer_group.values() if og.optimizer is not None]
+ _twinkle_fsdp_debug(f'_lazy_wrap_model optimizer_groups={len(optimizer_groups)}')
+ _twinkle_fsdp_debug('before _maybe_apply_expert_parallel')
self._maybe_apply_expert_parallel()
+ _twinkle_fsdp_debug('after _maybe_apply_expert_parallel')
+ _twinkle_fsdp_debug('before _ensure_sp_strategy')
self._ensure_sp_strategy()
+ _twinkle_fsdp_debug('after _ensure_sp_strategy')
if self.sp_strategy is not None:
+ _twinkle_fsdp_debug('before sp_strategy.initialize')
self.sp_strategy.initialize()
+ _twinkle_fsdp_debug('after sp_strategy.initialize')
if len(optimizer_groups) == 1:
optimizer_group = optimizer_groups[0]
optimizer = optimizer_group.optimizer
assert optimizer is not None
+ _twinkle_fsdp_debug('before strategy.wrap_model with optimizer')
self.model, optimizer = self.strategy.wrap_model(self.model, optimizer)
+ _twinkle_fsdp_debug('after strategy.wrap_model with optimizer')
optimizer_group.optimizer = optimizer
self.register_mm_forward_hook(optimizer_group)
else:
# maybe forward_only, no optimizer_group available
+ _twinkle_fsdp_debug('before strategy.wrap_model without optimizer')
result = self.strategy.wrap_model(self.model)
+ _twinkle_fsdp_debug('after strategy.wrap_model without optimizer')
if isinstance(result, tuple):
self.model = result[0]
else:
self.model = result
self._model_wrapped = True
+ _twinkle_fsdp_debug('exit _lazy_wrap_model')
def register_mm_forward_hook(self, optimizer_group: OptimizerGroup):
model = self.strategy.unwrap_model(self.model)
@@ -355,7 +403,9 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec
temperature = float(kwargs.pop('temperature', 1.0))
return_logits = kwargs.pop('return_logits', False)
optimizer_config = self.optimizer_group[adapter_name]
+ _twinkle_fsdp_debug('forward before _lazy_wrap_model')
self._lazy_wrap_model()
+ _twinkle_fsdp_debug('forward after _lazy_wrap_model')
if not inputs:
raise ValueError('inputs empty, check your DataLoader outputs')
self.model.train()
@@ -576,10 +626,14 @@ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Tr
Returns:
The output of the model forward.
"""
+ _twinkle_fsdp_debug('forward_backward enter')
outputs = self.forward(inputs=inputs, **kwargs)
+ _twinkle_fsdp_debug('forward_backward after forward')
loss = self.calculate_loss(**kwargs)
+ _twinkle_fsdp_debug('forward_backward after calculate_loss')
outputs['loss'] = loss
self.backward(**kwargs)
+ _twinkle_fsdp_debug('forward_backward after backward')
return outputs
# def _sync_after_backward_if_needed(self) -> None:
From 860516942117db452775eec07138f4739e6c829f Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Mon, 11 May 2026 19:38:28 +0800
Subject: [PATCH 17/40] fix: improve FSDP2 and native FSDP pretrained loading
with EP support
- Fix accelerate FSDP2 patch to use `distribute_tensor` for CUDA devices
- Refactor native FSDP to handle rank0-only loading and broadcast
- Add EP expert shard specs collection and rank mapping for state dict broadcast
- Fix non-persistent buffer handling to use broadcast instead of restore
---
cookbook/transformers/deepseek_v4.py | 18 +-
src/twinkle/model/base.py | 2 +-
.../model/transformers/strategy/accelerate.py | 5 +-
.../transformers/strategy/native_fsdp.py | 174 +++++++++++++++---
.../model/transformers/transformers.py | 49 +++++
5 files changed, 214 insertions(+), 34 deletions(-)
diff --git a/cookbook/transformers/deepseek_v4.py b/cookbook/transformers/deepseek_v4.py
index 1f38b9fc..3e10d44c 100644
--- a/cookbook/transformers/deepseek_v4.py
+++ b/cookbook/transformers/deepseek_v4.py
@@ -123,17 +123,17 @@ def train():
model_id=MODEL_ID,
config=config,
device_mesh=device_mesh,
- strategy="accelerate",
+ strategy='native_fsdp',
memory_efficient_init=True,
ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES,
- # fsdp_config={
- # 'reshard_after_forward': RESHARD_AFTER_FORWARD,
- # 'expert_parallel': {
- # 'enabled': True,
- # 'router_dtype': 'fp32',
- # 'keep_router_logits': False,
- # }
- # },
+ fsdp_config={
+ 'reshard_after_forward': RESHARD_AFTER_FORWARD,
+ 'expert_parallel': {
+ 'enabled': True,
+ 'router_dtype': 'fp32',
+ 'keep_router_logits': False,
+ },
+ },
)
if USE_LORA:
diff --git a/src/twinkle/model/base.py b/src/twinkle/model/base.py
index 303eaf74..37073e39 100644
--- a/src/twinkle/model/base.py
+++ b/src/twinkle/model/base.py
@@ -168,7 +168,7 @@ def _try_init_process_group(self):
'timeout': timedelta(seconds=int(os.environ.get('TWINKLE_DIST_TIMEOUT_SECONDS', '7200'))),
}
if self._should_bind_device_id_for_process_group(backend):
- init_kwargs['device_id'] = torch.device(Platform.get_local_device())
+ ['device_id'] = torch.device(Platform.get_local_device())
dist.init_process_group(**init_kwargs)
if backend == 'hccl':
default_pg = dist.distributed_c10d._get_default_group()
diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py
index d4ec28ac..353661fb 100644
--- a/src/twinkle/model/transformers/strategy/accelerate.py
+++ b/src/twinkle/model/transformers/strategy/accelerate.py
@@ -19,7 +19,7 @@ def _patch_accelerate_fsdp2_load_full_state_dict():
import torch
import torch.distributed as dist
import accelerate.utils.fsdp_utils as fsdp_utils
- from torch.distributed.tensor import DTensor, Partial, Replicate, Shard
+ from torch.distributed.tensor import DTensor, Partial, Replicate, Shard, distribute_tensor
if getattr(fsdp_utils.fsdp2_load_full_state_dict, '_twinkle_patched', False):
return
@@ -60,6 +60,9 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype):
return tensor
def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements):
+ if device_mesh.device_type == 'cuda':
+ return distribute_tensor(full_tensor, device_mesh, placements)
+
local_tensor = full_tensor
for mesh_dim, placement in enumerate(placements):
if isinstance(placement, Shard):
diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py
index 9993454b..856d5dfa 100644
--- a/src/twinkle/model/transformers/strategy/native_fsdp.py
+++ b/src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -30,7 +30,13 @@ def __init__(self,
self.ep_fsdp_device_mesh = self._build_ep_fsdp_device_mesh(ep_size) if enable_ep else None
def pretrained_load_context(self):
- return fsdp_pretrained_load_context(self._memory_efficient_init and self.device_mesh is not None)
+ # Native FSDP handles rank0-load itself. Do not enable Transformers'
+ # FSDP efficient-loading env here, because some versions have an
+ # unmatched non-rank0 barrier in from_pretrained().
+ return fsdp_pretrained_load_context(False)
+
+ def use_rank0_pretrained_broadcast(self) -> bool:
+ return self._memory_efficient_init and self.device_mesh is not None
def _build_ep_fsdp_device_mesh(self, ep_size: Optional[int] = None) -> Optional[TorchDeviceMesh]:
if self.device_mesh is None:
@@ -60,16 +66,18 @@ def wrap_model(self, model, optimizer=None):
_unbind_optimizer_params(optimizer)
# EP path requires experts on a real device, incompatible with meta-device flow.
- use_meta = self._memory_efficient_init and not ep_enabled
+ use_meta = self.use_rank0_pretrained_broadcast()
original_sd = None
saved_buffers = None
if use_meta:
- original_sd = model.state_dict()
- saved_buffers = _get_non_persistent_buffers(model)
- model = model.to(torch.device('meta'))
- if hasattr(model, 'tie_weights'):
- model.tie_weights()
+ is_rank0 = (dist.get_rank() == 0)
+ original_sd = model.state_dict() if is_rank0 else {}
+ saved_buffers = _get_non_persistent_buffers(model) if is_rank0 else {}
+ if is_rank0:
+ model = model.to(torch.device('meta'))
+ if hasattr(model, 'tie_weights'):
+ model.tie_weights()
if ep_enabled:
_ensure_moe_patched_if_needed(model, self.ep_fsdp_device_mesh)
@@ -129,14 +137,17 @@ def wrap_model(self, model, optimizer=None):
if use_meta:
device_type = self.device_mesh.device_type or 'cuda'
- is_rank0 = (dist.get_rank() == 0)
+ expert_shard_specs = _collect_ep_expert_shard_specs(model) if ep_enabled else {}
+ rank_to_ep_rank = _build_rank_to_ep_rank(self.ep_fsdp_device_mesh) if ep_enabled else {}
_broadcast_sharded_state_dict(
model,
- original_sd if is_rank0 else {},
+ original_sd,
device_type=device_type,
+ expert_shard_specs=expert_shard_specs,
+ rank_to_ep_rank=rank_to_ep_rank,
)
target_device = torch.device(device_type)
- _restore_non_persistent_buffers(model, saved_buffers, device=target_device)
+ _broadcast_non_persistent_buffers(model, saved_buffers or {}, device=target_device)
if hasattr(model, 'tie_weights'):
model.tie_weights()
@@ -383,6 +394,37 @@ def _collect_ep_experts_map(model: nn.Module) -> Dict[str, nn.Module]:
return experts_map
+def _collect_ep_expert_shard_specs(model: nn.Module) -> Dict[str, Dict[str, int]]:
+ """Collect state-dict names that are sharded by EP and their local expert range."""
+ specs = {}
+ for fqn, module in model.named_modules():
+ if not getattr(module, '_ep_patched', False):
+ continue
+ experts = getattr(module, 'experts', None)
+ if experts is None:
+ continue
+ experts_prefix = f'{fqn}.experts.' if fqn else 'experts.'
+ for pname, _ in experts.named_parameters():
+ specs[experts_prefix + pname] = {
+ 'num_experts': int(module._ep_num_experts),
+ 'experts_per_rank': int(module._ep_experts_per_rank),
+ }
+ return specs
+
+
+def _build_rank_to_ep_rank(ep_fsdp_device_mesh: Optional[TorchDeviceMesh]) -> Dict[int, int]:
+ if ep_fsdp_device_mesh is None:
+ return {}
+ mesh = ep_fsdp_device_mesh.mesh
+ if hasattr(mesh, 'detach'):
+ mesh = mesh.detach().cpu().numpy()
+ rank_to_ep_rank = {}
+ for ep_rank in range(mesh.shape[0]):
+ for rank in mesh[ep_rank].flatten().tolist():
+ rank_to_ep_rank[int(rank)] = int(ep_rank)
+ return rank_to_ep_rank
+
+
def _find_experts_in_layer(layer_mod: nn.Module, experts_map: Dict[str, nn.Module]) -> Optional[nn.Module]:
"""Find the experts module inside a decoder layer, if any."""
for module in layer_mod.modules():
@@ -416,11 +458,18 @@ def _place_ep_experts_on_local_device(model: nn.Module, ep_fsdp_device_mesh: Opt
continue
experts = getattr(module, 'experts', None)
if experts is not None:
- experts.to(local_device)
+ _move_module_to_device_or_empty(experts, local_device)
if getattr(module, '_ep_ignore_shared_experts', False):
shared = getattr(module, 'shared_expert', None)
if shared is not None:
- shared.to(local_device)
+ _move_module_to_device_or_empty(shared, local_device)
+
+
+def _move_module_to_device_or_empty(module: nn.Module, device: torch.device) -> None:
+ if any(param.is_meta for param in module.parameters(recurse=True)):
+ module.to_empty(device=device)
+ else:
+ module.to(device)
def _ensure_moe_patched_if_needed(model: nn.Module, ep_fsdp_device_mesh: Optional[TorchDeviceMesh]) -> None:
@@ -477,32 +526,99 @@ def _broadcast_sharded_state_dict(
model: nn.Module,
full_sd: dict,
device_type: str = 'cuda',
+ expert_shard_specs: Optional[Dict[str, Dict[str, int]]] = None,
+ rank_to_ep_rank: Optional[Dict[int, int]] = None,
) -> None:
- """Broadcast full state dict from rank 0 and materialise local shards via distribute_tensor."""
- from torch.distributed.tensor import DTensor, distribute_tensor
+ """Broadcast full state dict from rank 0 and materialize local FSDP2 shards."""
+ from torch.distributed.tensor import DTensor, Partial, Replicate, Shard
meta_sharded_sd = model.state_dict()
sharded_sd = {}
is_rank0 = (dist.get_rank() == 0)
+ expert_shard_specs = expert_shard_specs or {}
+ rank_to_ep_rank = rank_to_ep_rank or {}
+
+ def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements):
+ local_tensor = full_tensor
+ for mesh_dim, placement in enumerate(placements):
+ if isinstance(placement, Shard):
+ local_tensor = placement._shard_tensor(
+ local_tensor,
+ device_mesh,
+ mesh_dim,
+ src_data_rank=None,
+ )
+ elif isinstance(placement, Replicate):
+ continue
+ elif isinstance(placement, Partial):
+ raise NotImplementedError('Native FSDP2 full-state loading does not support Partial placements.')
+ else:
+ raise NotImplementedError(f'Unsupported DTensor placement: {placement}')
+ return DTensor.from_local(
+ local_tensor,
+ device_mesh=device_mesh,
+ placements=placements,
+ run_check=False,
+ shape=full_tensor.shape,
+ stride=full_tensor.stride(),
+ )
+
+ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param):
+ spec = expert_shard_specs[param_name]
+ experts_per_rank = spec['experts_per_rank']
+ num_experts = spec['num_experts']
+ local_shape = tuple(sharded_param.size())
+ local_tensor = torch.empty(local_shape, device=device_type, dtype=sharded_param.dtype)
+
+ scatter_list = None
+ if is_rank0:
+ if full_tensor.size(0) != num_experts:
+ raise RuntimeError(
+ f"EP expert parameter '{param_name}' expects {num_experts} experts, "
+ f'but full state has shape {tuple(full_tensor.shape)}.')
+ scatter_list = []
+ world_size = dist.get_world_size()
+ for rank in range(world_size):
+ if rank not in rank_to_ep_rank:
+ raise RuntimeError(f'Missing EP rank mapping for global rank {rank}.')
+ ep_rank = rank_to_ep_rank[rank]
+ start = ep_rank * experts_per_rank
+ end = start + experts_per_rank
+ scatter_list.append(full_tensor[start:end].contiguous())
+
+ dist.scatter(local_tensor, scatter_list=scatter_list, src=0)
+ return local_tensor
for param_name, sharded_param in meta_sharded_sd.items():
shape = sharded_param.size()
dtype = sharded_param.dtype
+ is_ep_expert_param = param_name in expert_shard_specs
if is_rank0:
+ if param_name not in full_sd:
+ raise KeyError(
+ f"Parameter '{param_name}' found in sharded model state dict but missing from full state dict.")
full_param = full_sd[param_name]
full_tensor = full_param.detach().to(device_type)
if isinstance(full_tensor, DTensor):
full_tensor = full_tensor.to_local()
else:
- full_tensor = torch.empty(shape, device=device_type, dtype=dtype)
+ full_tensor = None if is_ep_expert_param else torch.empty(shape, device=device_type, dtype=dtype)
- dist.broadcast(full_tensor, src=0)
+ if is_ep_expert_param:
+ full_tensor = _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param)
+ else:
+ dist.broadcast(full_tensor, src=0)
torch_util.synchronize()
- device_mesh = sharded_param.device_mesh
- placements = sharded_param.placements
- sharded_tensor = distribute_tensor(full_tensor, device_mesh, placements)
+ if isinstance(sharded_param, DTensor):
+ sharded_tensor = _dtensor_from_replicated_full_tensor(
+ full_tensor,
+ sharded_param.device_mesh,
+ sharded_param.placements,
+ )
+ else:
+ sharded_tensor = full_tensor
del full_tensor
sharded_sd[param_name] = sharded_tensor
@@ -529,14 +645,26 @@ def _unbind_optimizer_params(optimizer: torch.optim.Optimizer) -> None:
group['params'][i] = torch.empty(1, dtype=param.dtype, device=param.device)
-def _restore_non_persistent_buffers(
+def _broadcast_non_persistent_buffers(
model: nn.Module,
saved_buffers: Dict[str, torch.Tensor],
device: torch.device,
) -> None:
- """Re-register non-persistent buffers saved before to('meta')."""
- for fqn, buf_tensor in saved_buffers.items():
- buf_tensor = buf_tensor.to(device)
+ """Broadcast rank0 non-persistent buffers and re-register them on all ranks."""
+ is_rank0 = (dist.get_rank() == 0)
+ metadata = None
+ if is_rank0:
+ metadata = [(name, tuple(tensor.shape), tensor.dtype) for name, tensor in saved_buffers.items()]
+ metadata_holder = [metadata]
+ dist.broadcast_object_list(metadata_holder, src=0)
+ metadata = metadata_holder[0] or []
+
+ for fqn, shape, dtype in metadata:
+ if is_rank0:
+ buf_tensor = saved_buffers[fqn].to(device)
+ else:
+ buf_tensor = torch.empty(shape, device=device, dtype=dtype)
+ dist.broadcast(buf_tensor, src=0)
if '.' in fqn:
parent_fqn, local_name = fqn.rsplit('.', 1)
parent = model.get_submodule(parent_fqn)
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 7e64856c..351c6543 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -66,6 +66,33 @@ def _twinkle_fsdp_debug(message: str) -> None:
f.write(text + '\n')
+def _filter_from_config_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
+ load_only_keys = {
+ 'cache_dir',
+ 'device_map',
+ 'force_download',
+ 'ignore_mismatched_sizes',
+ 'local_files_only',
+ 'low_cpu_mem_usage',
+ 'max_memory',
+ 'offload_buffers',
+ 'offload_folder',
+ 'offload_state_dict',
+ 'output_loading_info',
+ 'proxies',
+ 'resume_download',
+ 'revision',
+ 'state_dict',
+ 'subfolder',
+ 'token',
+ 'tokenizer_id',
+ 'trust_remote_code',
+ 'use_safetensors',
+ 'weights_only',
+ }
+ return {key: value for key, value in kwargs.items() if key not in load_only_keys}
+
+
@dataclass
class OptimizerGroup(BaseOptimizerGroup):
"""Optimizer group for Transformers training."""
@@ -220,6 +247,10 @@ def __init__(
_twinkle_fsdp_debug('before model_cls.from_config')
self.model = model_cls.from_config(self.hf_config, **kwargs)
_twinkle_fsdp_debug('after model_cls.from_config')
+ elif self._should_init_empty_pretrained_model_on_this_rank():
+ _twinkle_fsdp_debug('before empty model_cls.from_config for rank0 broadcast')
+ self.model = self._init_empty_model_from_config(model_cls, **kwargs)
+ _twinkle_fsdp_debug('after empty model_cls.from_config for rank0 broadcast')
else:
# Trigger transformers' FSDP-aware loading: meta-device init + rank-0-only weight load.
_twinkle_fsdp_debug('before pretrained_load_context')
@@ -239,6 +270,24 @@ def __init__(
self.optimizer_group[_default_adapter_name].adapter_name = _default_adapter_name
self.active_group = _default_adapter_name
+ def _should_init_empty_pretrained_model_on_this_rank(self) -> bool:
+ use_rank0_broadcast = getattr(self.strategy, 'use_rank0_pretrained_broadcast', lambda: False)
+ return bool(
+ use_rank0_broadcast()
+ and dist.is_available()
+ and dist.is_initialized()
+ and dist.get_rank() != 0)
+
+ def _init_empty_model_from_config(self, model_cls, **kwargs):
+ from accelerate import init_empty_weights
+
+ config_kwargs = _filter_from_config_kwargs(kwargs)
+ with init_empty_weights(include_buffers=False):
+ model = model_cls.from_config(self.hf_config, **config_kwargs)
+ if hasattr(model, 'tie_weights'):
+ model.tie_weights()
+ return model
+
def _decide_strategy(self, strategy: Literal['accelerate', 'native_fsdp']):
self._expert_parallel_config = self._fsdp_config.pop('expert_parallel', None)
self._enable_expert_parallel = self._should_enable_expert_parallel(self._expert_parallel_config,
From f7f6179a7fe4320f73bbb99931c952e55f904d27 Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Mon, 11 May 2026 20:09:09 +0800
Subject: [PATCH 18/40] fix: correct indentation and add debug logging for FSDP
expert parallelism
- Fix indentation error in `base.py` for `device_id` assignment
- Add `_native_fsdp_debug` utility function for FSDP debugging
- Implement pre-EP full state dict capture to avoid redundant state_dict calls
- Add debug logging for expert scatter operations in FSDP
---
src/twinkle/model/base.py | 2 +-
.../transformers/strategy/native_fsdp.py | 39 ++++++++++++++++++-
.../model/transformers/transformers.py | 22 ++++++++++-
3 files changed, 59 insertions(+), 4 deletions(-)
diff --git a/src/twinkle/model/base.py b/src/twinkle/model/base.py
index 37073e39..303eaf74 100644
--- a/src/twinkle/model/base.py
+++ b/src/twinkle/model/base.py
@@ -168,7 +168,7 @@ def _try_init_process_group(self):
'timeout': timedelta(seconds=int(os.environ.get('TWINKLE_DIST_TIMEOUT_SECONDS', '7200'))),
}
if self._should_bind_device_id_for_process_group(backend):
- ['device_id'] = torch.device(Platform.get_local_device())
+ init_kwargs['device_id'] = torch.device(Platform.get_local_device())
dist.init_process_group(**init_kwargs)
if backend == 'hccl':
default_pg = dist.distributed_c10d._get_default_group()
diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py
index 856d5dfa..80728caa 100644
--- a/src/twinkle/model/transformers/strategy/native_fsdp.py
+++ b/src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -1,6 +1,8 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import torch
import torch.distributed as dist
+import os
+import time
from torch import nn
from torch.distributed.device_mesh import DeviceMesh as TorchDeviceMesh
from torch.distributed.fsdp import fully_shard
@@ -13,6 +15,21 @@
from torch.distributed.fsdp import MixedPrecisionPolicy
+def _native_fsdp_debug(message: str) -> None:
+ if os.environ.get('TWINKLE_FSDP_DEBUG', '0') != '1':
+ return
+ rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else int(os.environ.get('RANK', 0))
+ world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
+ local_rank = os.environ.get('LOCAL_RANK', '?')
+ text = f'[twinkle-native-fsdp-debug][time={time.time():.6f} rank{rank}/{world_size} local_rank={local_rank}] {message}'
+ print(text, flush=True)
+ debug_dir = os.environ.get('TWINKLE_DEBUG_DIR')
+ if debug_dir:
+ os.makedirs(debug_dir, exist_ok=True)
+ with open(os.path.join(debug_dir, f'native_fsdp_rank{rank}.log'), 'a', encoding='utf-8') as f:
+ f.write(text + '\n')
+
+
class NativeFSDPStrategy:
def __init__(self,
@@ -28,6 +45,7 @@ def __init__(self,
self._memory_efficient_init = memory_efficient_init
self.enable_ep = enable_ep
self.ep_fsdp_device_mesh = self._build_ep_fsdp_device_mesh(ep_size) if enable_ep else None
+ self._rank0_pre_ep_full_state_dict = None
def pretrained_load_context(self):
# Native FSDP handles rank0-load itself. Do not enable Transformers'
@@ -38,6 +56,9 @@ def pretrained_load_context(self):
def use_rank0_pretrained_broadcast(self) -> bool:
return self._memory_efficient_init and self.device_mesh is not None
+ def set_rank0_pre_ep_full_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
+ self._rank0_pre_ep_full_state_dict = state_dict
+
def _build_ep_fsdp_device_mesh(self, ep_size: Optional[int] = None) -> Optional[TorchDeviceMesh]:
if self.device_mesh is None:
return None
@@ -72,7 +93,13 @@ def wrap_model(self, model, optimizer=None):
saved_buffers = None
if use_meta:
is_rank0 = (dist.get_rank() == 0)
- original_sd = model.state_dict() if is_rank0 else {}
+ if ep_enabled and self._rank0_pre_ep_full_state_dict is not None:
+ _native_fsdp_debug(
+ f'use captured pre-EP full state_dict keys={len(self._rank0_pre_ep_full_state_dict)}')
+ original_sd = self._rank0_pre_ep_full_state_dict if is_rank0 else {}
+ else:
+ _native_fsdp_debug('use current model state_dict as rank0 broadcast source')
+ original_sd = model.state_dict() if is_rank0 else {}
saved_buffers = _get_non_persistent_buffers(model) if is_rank0 else {}
if is_rank0:
model = model.to(torch.device('meta'))
@@ -569,13 +596,20 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param):
num_experts = spec['num_experts']
local_shape = tuple(sharded_param.size())
local_tensor = torch.empty(local_shape, device=device_type, dtype=sharded_param.dtype)
+ _native_fsdp_debug(
+ f'EP expert scatter start name={param_name} local_shape={local_shape} '
+ f'num_experts={num_experts} experts_per_rank={experts_per_rank}')
scatter_list = None
if is_rank0:
+ _native_fsdp_debug(
+ f'EP expert scatter source name={param_name} source_shape={tuple(full_tensor.shape)} '
+ f'source_dtype={full_tensor.dtype}')
if full_tensor.size(0) != num_experts:
raise RuntimeError(
f"EP expert parameter '{param_name}' expects {num_experts} experts, "
- f'but full state has shape {tuple(full_tensor.shape)}.')
+ f'but source state has shape {tuple(full_tensor.shape)}. '
+ 'Rank0 must capture the full pre-EP state_dict before apply_expert_parallel().')
scatter_list = []
world_size = dist.get_world_size()
for rank in range(world_size):
@@ -587,6 +621,7 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param):
scatter_list.append(full_tensor[start:end].contiguous())
dist.scatter(local_tensor, scatter_list=scatter_list, src=0)
+ _native_fsdp_debug(f'EP expert scatter done name={param_name}')
return local_tensor
for param_name, sharded_param in meta_sharded_sd.items():
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 351c6543..40233403 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -93,6 +93,16 @@ def _filter_from_config_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
return {key: value for key, value in kwargs.items() if key not in load_only_keys}
+def _clone_state_dict_to_cpu(state_dict: Dict[str, Any]) -> Dict[str, Any]:
+ cloned = {}
+ for key, value in state_dict.items():
+ if hasattr(value, 'detach'):
+ cloned[key] = value.detach().cpu().clone()
+ else:
+ cloned[key] = value
+ return cloned
+
+
@dataclass
class OptimizerGroup(BaseOptimizerGroup):
"""Optimizer group for Transformers training."""
@@ -283,7 +293,10 @@ def _init_empty_model_from_config(self, model_cls, **kwargs):
config_kwargs = _filter_from_config_kwargs(kwargs)
with init_empty_weights(include_buffers=False):
- model = model_cls.from_config(self.hf_config, **config_kwargs)
+ if hasattr(model_cls, 'from_config'):
+ model = model_cls.from_config(self.hf_config, **config_kwargs)
+ else:
+ model = model_cls._from_config(self.hf_config, **config_kwargs)
if hasattr(model, 'tie_weights'):
model.tie_weights()
return model
@@ -352,6 +365,13 @@ def _lazy_wrap_model(self):
_twinkle_fsdp_debug('enter _lazy_wrap_model')
optimizer_groups = [og for og in self.optimizer_group.values() if og.optimizer is not None]
_twinkle_fsdp_debug(f'_lazy_wrap_model optimizer_groups={len(optimizer_groups)}')
+ use_rank0_broadcast = getattr(self.strategy, 'use_rank0_pretrained_broadcast', lambda: False)
+ set_pre_ep_state = getattr(self.strategy, 'set_rank0_pre_ep_full_state_dict', None)
+ if self._enable_expert_parallel and use_rank0_broadcast() and set_pre_ep_state is not None:
+ is_rank0 = dist.is_available() and dist.is_initialized() and dist.get_rank() == 0
+ _twinkle_fsdp_debug('before capture pre-EP full state_dict for rank0 broadcast')
+ set_pre_ep_state(_clone_state_dict_to_cpu(self.model.state_dict()) if is_rank0 else {})
+ _twinkle_fsdp_debug('after capture pre-EP full state_dict for rank0 broadcast')
_twinkle_fsdp_debug('before _maybe_apply_expert_parallel')
self._maybe_apply_expert_parallel()
_twinkle_fsdp_debug('after _maybe_apply_expert_parallel')
From 291f34215284c0806267cdcb41e88ced4bff00ae Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Mon, 11 May 2026 21:25:38 +0800
Subject: [PATCH 19/40] feat(moe): add expert parallelism debug tracing
Add `_ep_trace` function and debug logging to MoE expert parallelism operations for better debugging and performance analysis. The trace logs all-to-all and all-gather operations with tensor shapes and split sizes, controlled by `TWINKLE_EP_DEBUG` or `TWINKLE_FSDP_DEBUG` environment variables.
---
.../model/transformers/moe/ep_utils.py | 68 +++++++++++++++++++
.../model/transformers/moe/expert_parallel.py | 57 +++++++++++++++-
.../model/transformers/transformers.py | 14 ++++
3 files changed, 136 insertions(+), 3 deletions(-)
diff --git a/src/twinkle/model/transformers/moe/ep_utils.py b/src/twinkle/model/transformers/moe/ep_utils.py
index f64c1796..1afd378e 100644
--- a/src/twinkle/model/transformers/moe/ep_utils.py
+++ b/src/twinkle/model/transformers/moe/ep_utils.py
@@ -6,9 +6,39 @@
import torch
import torch.distributed as dist
+import os
+import time
from typing import Optional
+def _ep_trace(message: str, group: Optional[dist.ProcessGroup] = None) -> None:
+ if os.environ.get('TWINKLE_EP_DEBUG', os.environ.get('TWINKLE_FSDP_DEBUG', '0')) != '1':
+ return
+ rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else int(os.environ.get('RANK', 0))
+ world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
+ if group is not None:
+ try:
+ ep_rank = dist.get_rank(group)
+ ep_world_size = dist.get_world_size(group)
+ except Exception:
+ ep_rank = '?'
+ ep_world_size = '?'
+ else:
+ ep_rank = '?'
+ ep_world_size = '?'
+ local_rank = os.environ.get('LOCAL_RANK', '?')
+ text = (
+ f'[twinkle-ep-trace][time={time.time():.6f} rank{rank}/{world_size} '
+ f'local_rank={local_rank} ep_rank={ep_rank}/{ep_world_size}] {message}'
+ )
+ print(text, flush=True)
+ debug_dir = os.environ.get('TWINKLE_DEBUG_DIR')
+ if debug_dir:
+ os.makedirs(debug_dir, exist_ok=True)
+ with open(os.path.join(debug_dir, f'ep_rank{rank}.log'), 'a', encoding='utf-8') as f:
+ f.write(text + '\n')
+
+
# ========================== comm ==========================
class _AllToAll(torch.autograd.Function):
@@ -29,6 +59,11 @@ def forward(ctx, group, input, output_split_sizes, input_split_sizes):
output = torch.empty_like(input)
else:
output = torch.empty(size=(sum(output_split_sizes), input.size(1)), dtype=input.dtype, device=input.device)
+ _ep_trace(
+ f'all_to_all forward before input_shape={tuple(input.shape)} output_shape={tuple(output.shape)} '
+ f'input_splits={input_split_sizes} output_splits={output_split_sizes}',
+ group,
+ )
dist.all_to_all_single(
output,
input,
@@ -36,10 +71,16 @@ def forward(ctx, group, input, output_split_sizes, input_split_sizes):
input_split_sizes=input_split_sizes,
group=group,
)
+ _ep_trace('all_to_all forward after', group)
return output
@staticmethod
def backward(ctx, *grad_output):
+ _ep_trace(
+ f'all_to_all backward before grad_shape={tuple(grad_output[0].shape)} '
+ f'input_splits={ctx.output_split_sizes} output_splits={ctx.input_split_sizes}',
+ ctx.group,
+ )
return (
None,
_AllToAll.apply(ctx.group, *grad_output, ctx.input_split_sizes, ctx.output_split_sizes),
@@ -201,7 +242,16 @@ def preprocess(
dtype=num_local_tokens_per_expert.dtype,
device=num_local_tokens_per_expert.device,
)
+ _ep_trace(
+ f'preprocess before all_gather local_tokens_shape={tuple(num_local_tokens_per_expert.shape)} '
+ f'input_splits={input_splits}',
+ ep_group,
+ )
dist.all_gather_into_tensor(num_global_tokens_per_expert, num_local_tokens_per_expert, group=ep_group)
+ _ep_trace(
+ f'preprocess after all_gather global_tokens_shape={tuple(num_global_tokens_per_expert.shape)}',
+ ep_group,
+ )
# [ep_size, num_local_experts]
start_idx, end_idx = rank * num_local_experts, (rank + 1) * num_local_experts
@@ -237,7 +287,16 @@ def token_pre_all2all(
local_permuted_hidden_states, local_input_permutation_mapping = permute(hidden_states, expert_mask)
local_assignment_weights = routing_weights.T.contiguous().masked_select(expert_mask.bool())
+ _ep_trace(
+ f'token_pre_all2all before all_to_all local_permuted_shape={tuple(local_permuted_hidden_states.shape)} '
+ f'input_splits={input_splits} output_splits={output_splits}',
+ ep_group,
+ )
global_permuted_hidden_states = all_to_all(ep_group, local_permuted_hidden_states, output_splits, input_splits)
+ _ep_trace(
+ f'token_pre_all2all after all_to_all global_permuted_shape={tuple(global_permuted_hidden_states.shape)}',
+ ep_group,
+ )
# group tokens together by expert
num_local_experts = num_experts // ep_group.size()
@@ -276,7 +335,16 @@ def tokens_post_all2all(
unpermute_order,
)
+ _ep_trace(
+ f'tokens_post_all2all before all_to_all expert_outputs_shape={tuple(expert_outputs.shape)} '
+ f'input_splits={input_splits} output_splits={output_splits}',
+ ep_group,
+ )
unpermute_outputs = all_to_all(ep_group, expert_outputs, input_splits, output_splits)
+ _ep_trace(
+ f'tokens_post_all2all after all_to_all unpermute_outputs_shape={tuple(unpermute_outputs.shape)}',
+ ep_group,
+ )
weighted_outputs = unpermute_outputs * local_assignment_weights.unsqueeze(-1)
hidden_dim = org_hidden_states_shape[-1]
final_outputs = torch.zeros(org_hidden_states_shape, device=weighted_outputs.device, dtype=weighted_outputs.dtype)
diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py
index a0c3cc78..8d188e86 100644
--- a/src/twinkle/model/transformers/moe/expert_parallel.py
+++ b/src/twinkle/model/transformers/moe/expert_parallel.py
@@ -2,6 +2,8 @@
from __future__ import annotations
import inspect
+import os
+import time
import torch
import torch.distributed as dist
import torch.nn.functional as F
@@ -13,6 +15,27 @@
from twinkle.utils import DeviceMesh
+def _ep_block_trace(block: nn.Module, message: str) -> None:
+ if os.environ.get('TWINKLE_EP_DEBUG', os.environ.get('TWINKLE_FSDP_DEBUG', '0')) != '1':
+ return
+ rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else int(os.environ.get('RANK', 0))
+ world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
+ local_rank = os.environ.get('LOCAL_RANK', '?')
+ block_name = getattr(block, '_ep_debug_name', type(block).__name__)
+ ep_rank = getattr(block, '_ep_rank', '?')
+ ep_world_size = getattr(block, '_ep_world_size', '?')
+ text = (
+ f'[twinkle-ep-block][time={time.time():.6f} rank{rank}/{world_size} local_rank={local_rank} '
+ f'ep_rank={ep_rank}/{ep_world_size} block={block_name}] {message}'
+ )
+ print(text, flush=True)
+ debug_dir = os.environ.get('TWINKLE_DEBUG_DIR')
+ if debug_dir:
+ os.makedirs(debug_dir, exist_ok=True)
+ with open(os.path.join(debug_dir, f'ep_block_rank{rank}.log'), 'a', encoding='utf-8') as f:
+ f.write(text + '\n')
+
+
@dataclass
class ExpertParallelConfig:
enabled: bool = True
@@ -64,7 +87,8 @@ def apply_expert_parallel(
ep_rank = ep_mesh.get_local_rank()
specs = []
- for block in find_moe_blocks(model):
+ for block_name, block in find_moe_blocks_with_names(model):
+ block._ep_debug_name = block_name
spec = shard_experts(block, ep_world_size, ep_rank, cfg)
patch_forward(block, ep_group, ep_world_size, cfg)
specs.append(spec)
@@ -84,8 +108,12 @@ def _merge_config(config: dict[str, Any] | None) -> ExpertParallelConfig:
def find_moe_blocks(model: nn.Module) -> Iterable[nn.Module]:
+ return [block for _, block in find_moe_blocks_with_names(model)]
+
+
+def find_moe_blocks_with_names(model: nn.Module) -> Iterable[tuple[str, nn.Module]]:
blocks = []
- for module in model.modules():
+ for name, module in model.named_modules():
experts = getattr(module, 'experts', None)
if experts is None:
continue
@@ -93,7 +121,7 @@ def find_moe_blocks(model: nn.Module) -> Iterable[nn.Module]:
continue
if not _get_gate(module):
continue
- blocks.append(module)
+ blocks.append((name, module))
return blocks
@@ -205,6 +233,7 @@ def patch_forward(
_install_ep_forward(block.experts, experts_per_rank)
def forward(hidden_states: torch.Tensor, *args, **kwargs):
+ _ep_block_trace(block, f'forward enter hidden_shape={tuple(hidden_states.shape)}')
if args:
raise RuntimeError('Expert parallel patch only supports keyword-only extra args for MoE blocks.')
@@ -230,20 +259,32 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs):
# Keep routing weights in activation dtype before unpermute weighting.
if routing_weights.dtype != hidden_states_2d.dtype:
routing_weights = routing_weights.to(hidden_states_2d.dtype)
+ _ep_block_trace(
+ block,
+ f'after router hidden_2d_shape={tuple(hidden_states_2d.shape)} '
+ f'routing_shape={tuple(routing_weights.shape)} selected_shape={tuple(selected_experts.shape)}',
+ )
# Build expert_mask: [num_experts, top_k, num_tokens]
expert_mask = torch.nn.functional.one_hot(
selected_experts, num_classes=num_experts).permute(2, 1, 0) # [num_experts, top_k, num_tokens]
# 1. preprocess: compute splits and token counts
+ _ep_block_trace(block, f'before preprocess expert_mask_sum={int(expert_mask.sum().item())}')
(
input_splits,
output_splits,
num_global_tokens_per_local_expert,
num_global_sum_tokens_per_local_expert,
) = preprocess(expert_mask, num_experts, ep_group)
+ _ep_block_trace(
+ block,
+ f'after preprocess input_splits={input_splits} output_splits={output_splits} '
+ f'local_expert_tokens={num_global_sum_tokens_per_local_expert.tolist()}',
+ )
# 2. token_pre_all2all: permute → all_to_all → sort_chunks
+ _ep_block_trace(block, 'before token_pre_all2all')
(
global_permuted_hidden_states,
local_input_permutation_mapping,
@@ -259,26 +300,32 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs):
num_global_tokens_per_local_expert,
ep_group,
)
+ _ep_block_trace(block, f'after token_pre_all2all global_shape={tuple(global_permuted_hidden_states.shape)}')
# 3. expert_compute: call experts via nn.Module.__call__ so FSDP2 hooks fire.
# For tensor experts: block.experts(permuted_tokens, counts, experts_per_rank)
# → FSDP2 pre-forward unshard → ep_forward → FSDP2 post-forward reshard
# For ModuleList experts: _run_local_experts calls each expert[i](...) via __call__.
if is_tensor_experts:
+ _ep_block_trace(block, 'before tensor expert compute')
expert_outputs = block.experts(
global_permuted_hidden_states,
num_global_sum_tokens_per_local_expert,
experts_per_rank,
)
+ _ep_block_trace(block, f'after tensor expert compute output_shape={tuple(expert_outputs.shape)}')
else:
+ _ep_block_trace(block, 'before modulelist expert compute')
expert_outputs = _run_local_experts(
block,
global_permuted_hidden_states,
num_global_sum_tokens_per_local_expert,
experts_per_rank,
)
+ _ep_block_trace(block, f'after modulelist expert compute output_shape={tuple(expert_outputs.shape)}')
# 4. tokens_post_all2all: sort_chunks → all_to_all → unpermute (with routing weight)
+ _ep_block_trace(block, 'before tokens_post_all2all')
final_hidden = tokens_post_all2all(
expert_outputs,
local_assignment_weights,
@@ -290,16 +337,20 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs):
org_hidden_states_shape,
ep_group,
)
+ _ep_block_trace(block, f'after tokens_post_all2all final_2d_shape={tuple(final_hidden.shape)}')
shared_out = _maybe_run_shared_expert(block, hidden_states_2d, cfg)
if shared_out is not None:
final_hidden = final_hidden + shared_out
+ _ep_block_trace(block, 'after shared expert')
if len(orig_shape) == 3:
final_hidden = final_hidden.view(batch_size, seq_len, hidden_dim)
if cfg.keep_router_logits and returns_router_logits:
+ _ep_block_trace(block, 'forward exit with router logits')
return final_hidden, router_logits
+ _ep_block_trace(block, f'forward exit final_shape={tuple(final_hidden.shape)}')
return final_hidden
block._ep_original_forward = orig_forward
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 40233403..8a8e6e91 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -658,28 +658,42 @@ def backward(self, **kwargs):
optimizer_config = self.optimizer_group[adapter_name]
loss_value = optimizer_config.train_status.loss_value
assert loss_value is not None, 'Do forwarding and calculating loss before backward'
+ _twinkle_fsdp_debug(
+ f'backward enter adapter={adapter_name} loss_shape={tuple(loss_value.shape)} '
+ f'loss_dtype={loss_value.dtype} loss_device={loss_value.device}')
scaler = optimizer_config.scaler
if scaler is None and self.mixed_precision == 'fp16':
# Auto set a grad scaler
+ _twinkle_fsdp_debug('backward before set_grad_scaler')
self.set_grad_scaler(adapter_name=adapter_name)
scaler = optimizer_config.scaler
+ _twinkle_fsdp_debug('backward after set_grad_scaler')
optimizer_config.cur_step += 1
should_sync = optimizer_config.do_grad_sync()
+ _twinkle_fsdp_debug(f'backward cur_step={optimizer_config.cur_step} should_sync={should_sync}')
import contextlib
no_sync_ctx = contextlib.nullcontext()
if not should_sync and hasattr(self.model, 'no_sync'):
+ _twinkle_fsdp_debug('backward using model.no_sync')
no_sync_ctx = self.model.no_sync()
+ _twinkle_fsdp_debug('backward before no_sync_ctx')
with no_sync_ctx:
if scaler is not None:
+ _twinkle_fsdp_debug('backward before scaler backward')
scaler.scale(loss_value).backward()
+ _twinkle_fsdp_debug('backward after scaler backward')
else:
+ _twinkle_fsdp_debug('backward before loss.backward')
loss_value.backward()
+ _twinkle_fsdp_debug('backward after loss.backward')
+ _twinkle_fsdp_debug('backward after no_sync_ctx')
# self._sync_after_backward_if_needed()
optimizer_config.train_status.loss_value = None
+ _twinkle_fsdp_debug('backward exit')
@remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]],
From e44eb8e1ffa198e36301a7fe3b5daaf32231bc9a Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Mon, 11 May 2026 21:44:50 +0800
Subject: [PATCH 20/40] feat(moe): add debug tag to all-to-all operations for
improved tracing
Add a `tag` parameter to `_AllToAll` and `debug_tag` to token permutation functions to enhance traceability of all-to-all operations in both forward and backward passes. This enables better debugging by identifying specific all-to-all calls in trace logs.
---
.../model/transformers/moe/ep_utils.py | 52 ++++++++++++++-----
.../model/transformers/moe/expert_parallel.py | 2 +
2 files changed, 41 insertions(+), 13 deletions(-)
diff --git a/src/twinkle/model/transformers/moe/ep_utils.py b/src/twinkle/model/transformers/moe/ep_utils.py
index 1afd378e..7a36666f 100644
--- a/src/twinkle/model/transformers/moe/ep_utils.py
+++ b/src/twinkle/model/transformers/moe/ep_utils.py
@@ -43,10 +43,11 @@ def _ep_trace(message: str, group: Optional[dist.ProcessGroup] = None) -> None:
class _AllToAll(torch.autograd.Function):
@staticmethod
- def forward(ctx, group, input, output_split_sizes, input_split_sizes):
+ def forward(ctx, group, input, output_split_sizes, input_split_sizes, tag):
ctx.group = group
ctx.output_split_sizes = output_split_sizes
ctx.input_split_sizes = input_split_sizes
+ ctx.tag = tag
world_size = dist.get_world_size(group=group)
@@ -60,7 +61,7 @@ def forward(ctx, group, input, output_split_sizes, input_split_sizes):
else:
output = torch.empty(size=(sum(output_split_sizes), input.size(1)), dtype=input.dtype, device=input.device)
_ep_trace(
- f'all_to_all forward before input_shape={tuple(input.shape)} output_shape={tuple(output.shape)} '
+ f'all_to_all forward before tag={tag} input_shape={tuple(input.shape)} output_shape={tuple(output.shape)} '
f'input_splits={input_split_sizes} output_splits={output_split_sizes}',
group,
)
@@ -71,19 +72,26 @@ def forward(ctx, group, input, output_split_sizes, input_split_sizes):
input_split_sizes=input_split_sizes,
group=group,
)
- _ep_trace('all_to_all forward after', group)
+ _ep_trace(f'all_to_all forward after tag={tag}', group)
return output
@staticmethod
def backward(ctx, *grad_output):
_ep_trace(
- f'all_to_all backward before grad_shape={tuple(grad_output[0].shape)} '
+ f'all_to_all backward before tag={ctx.tag} grad_shape={tuple(grad_output[0].shape)} '
f'input_splits={ctx.output_split_sizes} output_splits={ctx.input_split_sizes}',
ctx.group,
)
return (
None,
- _AllToAll.apply(ctx.group, *grad_output, ctx.input_split_sizes, ctx.output_split_sizes),
+ _AllToAll.apply(
+ ctx.group,
+ *grad_output,
+ ctx.input_split_sizes,
+ ctx.output_split_sizes,
+ f'{ctx.tag}.backward',
+ ),
+ None,
None,
None,
)
@@ -128,8 +136,8 @@ def backward(ctx, grad_output, grad_async_handle):
)
-def all_to_all(group, input, output_split_size=None, input_split_size=None):
- return _AllToAll.apply(group, input, output_split_size, input_split_size)
+def all_to_all(group, input, output_split_size=None, input_split_size=None, tag: str = ''):
+ return _AllToAll.apply(group, input, output_split_size, input_split_size, tag)
def all_to_all_async(group, input, output_split_size, input_split_size):
@@ -279,6 +287,7 @@ def token_pre_all2all(
output_splits: torch.Tensor,
num_global_tokens_per_local_expert: torch.Tensor,
ep_group: Optional[dist.ProcessGroup] = None,
+ debug_tag: str = '',
) -> torch.Tensor:
hidden_dim = hidden_states.size(-1)
hidden_states = hidden_states.reshape(-1, hidden_dim)
@@ -288,13 +297,21 @@ def token_pre_all2all(
local_assignment_weights = routing_weights.T.contiguous().masked_select(expert_mask.bool())
_ep_trace(
- f'token_pre_all2all before all_to_all local_permuted_shape={tuple(local_permuted_hidden_states.shape)} '
+ f'token_pre_all2all before all_to_all tag={debug_tag} '
+ f'local_permuted_shape={tuple(local_permuted_hidden_states.shape)} '
f'input_splits={input_splits} output_splits={output_splits}',
ep_group,
)
- global_permuted_hidden_states = all_to_all(ep_group, local_permuted_hidden_states, output_splits, input_splits)
+ global_permuted_hidden_states = all_to_all(
+ ep_group,
+ local_permuted_hidden_states,
+ output_splits,
+ input_splits,
+ tag=f'{debug_tag}.token_pre_all2all',
+ )
_ep_trace(
- f'token_pre_all2all after all_to_all global_permuted_shape={tuple(global_permuted_hidden_states.shape)}',
+ f'token_pre_all2all after all_to_all tag={debug_tag} '
+ f'global_permuted_shape={tuple(global_permuted_hidden_states.shape)}',
ep_group,
)
@@ -325,6 +342,7 @@ def tokens_post_all2all(
local_input_permutation_mapping: torch.Tensor,
org_hidden_states_shape: torch.Size,
ep_group: Optional[dist.ProcessGroup] = None,
+ debug_tag: str = '',
) -> torch.Tensor:
# group tokens together by expert
num_local_experts = num_experts // ep_group.size()
@@ -336,13 +354,21 @@ def tokens_post_all2all(
)
_ep_trace(
- f'tokens_post_all2all before all_to_all expert_outputs_shape={tuple(expert_outputs.shape)} '
+ f'tokens_post_all2all before all_to_all tag={debug_tag} '
+ f'expert_outputs_shape={tuple(expert_outputs.shape)} '
f'input_splits={input_splits} output_splits={output_splits}',
ep_group,
)
- unpermute_outputs = all_to_all(ep_group, expert_outputs, input_splits, output_splits)
+ unpermute_outputs = all_to_all(
+ ep_group,
+ expert_outputs,
+ input_splits,
+ output_splits,
+ tag=f'{debug_tag}.tokens_post_all2all',
+ )
_ep_trace(
- f'tokens_post_all2all after all_to_all unpermute_outputs_shape={tuple(unpermute_outputs.shape)}',
+ f'tokens_post_all2all after all_to_all tag={debug_tag} '
+ f'unpermute_outputs_shape={tuple(unpermute_outputs.shape)}',
ep_group,
)
weighted_outputs = unpermute_outputs * local_assignment_weights.unsqueeze(-1)
diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py
index 8d188e86..6bdb88fb 100644
--- a/src/twinkle/model/transformers/moe/expert_parallel.py
+++ b/src/twinkle/model/transformers/moe/expert_parallel.py
@@ -299,6 +299,7 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs):
output_splits,
num_global_tokens_per_local_expert,
ep_group,
+ debug_tag=getattr(block, '_ep_debug_name', block.__class__.__name__),
)
_ep_block_trace(block, f'after token_pre_all2all global_shape={tuple(global_permuted_hidden_states.shape)}')
@@ -336,6 +337,7 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs):
local_input_permutation_mapping,
org_hidden_states_shape,
ep_group,
+ debug_tag=getattr(block, '_ep_debug_name', block.__class__.__name__),
)
_ep_block_trace(block, f'after tokens_post_all2all final_2d_shape={tuple(final_hidden.shape)}')
From c4a9bd8623bcc403815f1428ec127566cb9b44d2 Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Mon, 11 May 2026 22:16:13 +0800
Subject: [PATCH 21/40] fix: preserve autograd graph and fix dtype mismatch in
EP and FSDP
- Return `permuted_tokens` instead of `torch.empty_like` in EP forward to maintain backward all-to-all path
- Add `_apply_gate` support for custom gating in MoE expert computation
- Broadcast source metadata (shapes/dtypes) from rank 0 in FSDP to ensure correct dtype for EP expert tensors
- Validate source metadata consistency to prevent silent dtype/shape mismatches during state dict broadcast
---
.../model/transformers/moe/expert_parallel.py | 17 ++++++++---
.../transformers/strategy/native_fsdp.py | 30 +++++++++++++++++--
2 files changed, 40 insertions(+), 7 deletions(-)
diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py
index 6bdb88fb..86e3c603 100644
--- a/src/twinkle/model/transformers/moe/expert_parallel.py
+++ b/src/twinkle/model/transformers/moe/expert_parallel.py
@@ -371,7 +371,10 @@ def ep_forward(
experts_per_rank: int,
) -> torch.Tensor:
if permuted_tokens.numel() == 0:
- return torch.empty_like(permuted_tokens)
+ # Preserve the autograd edge to token_pre_all2all. Returning a new
+ # empty tensor can make this rank skip the matching backward
+ # all-to-all, causing EP collective order divergence.
+ return permuted_tokens
input_dtype = permuted_tokens.dtype
@@ -393,8 +396,12 @@ def ep_forward(
compute_dtype = gate_up.dtype
if expert_in.dtype != compute_dtype:
expert_in = expert_in.to(compute_dtype)
- gate, up = F.linear(expert_in, gate_up).chunk(2, dim=-1)
- out = self.act_fn(gate) * up
+ gate_up_out = F.linear(expert_in, gate_up)
+ if hasattr(self, '_apply_gate'):
+ out = self._apply_gate(gate_up_out)
+ else:
+ gate, up = gate_up_out.chunk(2, dim=-1)
+ out = self.act_fn(gate) * up
out = F.linear(out, down)
if out.dtype != input_dtype:
@@ -493,7 +500,9 @@ def _run_local_experts(
that happens in unpermute.
"""
if permuted_tokens.numel() == 0:
- return torch.empty_like(permuted_tokens)
+ # Keep the backward path through token_pre_all2all even when this EP
+ # rank owns no routed tokens for the current block.
+ return permuted_tokens
input_dtype = permuted_tokens.dtype
experts = block.experts
diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py
index 80728caa..e7806fe7 100644
--- a/src/twinkle/model/transformers/strategy/native_fsdp.py
+++ b/src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -564,6 +564,16 @@ def _broadcast_sharded_state_dict(
is_rank0 = (dist.get_rank() == 0)
expert_shard_specs = expert_shard_specs or {}
rank_to_ep_rank = rank_to_ep_rank or {}
+ source_metadata = None
+ if is_rank0:
+ source_metadata = {
+ name: (tuple(tensor.shape), tensor.dtype)
+ for name, tensor in full_sd.items()
+ if hasattr(tensor, 'shape') and hasattr(tensor, 'dtype')
+ }
+ metadata_holder = [source_metadata]
+ dist.broadcast_object_list(metadata_holder, src=0)
+ source_metadata = metadata_holder[0] or {}
def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements):
local_tensor = full_tensor
@@ -595,7 +605,10 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param):
experts_per_rank = spec['experts_per_rank']
num_experts = spec['num_experts']
local_shape = tuple(sharded_param.size())
- local_tensor = torch.empty(local_shape, device=device_type, dtype=sharded_param.dtype)
+ if param_name not in source_metadata:
+ raise KeyError(f"Missing source metadata for EP expert parameter '{param_name}'.")
+ _, source_dtype = source_metadata[param_name]
+ local_tensor = torch.empty(local_shape, device=device_type, dtype=source_dtype)
_native_fsdp_debug(
f'EP expert scatter start name={param_name} local_shape={local_shape} '
f'num_experts={num_experts} experts_per_rank={experts_per_rank}')
@@ -626,8 +639,10 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param):
for param_name, sharded_param in meta_sharded_sd.items():
shape = sharded_param.size()
- dtype = sharded_param.dtype
is_ep_expert_param = param_name in expert_shard_specs
+ if param_name not in source_metadata:
+ raise KeyError(f"Missing source metadata for parameter '{param_name}'.")
+ source_shape, source_dtype = source_metadata[param_name]
if is_rank0:
if param_name not in full_sd:
@@ -637,12 +652,21 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param):
full_tensor = full_param.detach().to(device_type)
if isinstance(full_tensor, DTensor):
full_tensor = full_tensor.to_local()
+ if tuple(full_tensor.shape) != tuple(source_shape) or full_tensor.dtype != source_dtype:
+ raise RuntimeError(
+ f"Source metadata mismatch for '{param_name}': "
+ f'actual shape={tuple(full_tensor.shape)} dtype={full_tensor.dtype}, '
+ f'expected shape={source_shape} dtype={source_dtype}.')
else:
- full_tensor = None if is_ep_expert_param else torch.empty(shape, device=device_type, dtype=dtype)
+ full_tensor = None if is_ep_expert_param else torch.empty(source_shape, device=device_type, dtype=source_dtype)
if is_ep_expert_param:
full_tensor = _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param)
else:
+ if tuple(shape) != tuple(source_shape):
+ raise RuntimeError(
+ f"Parameter '{param_name}' shape mismatch before broadcast: "
+ f'sharded logical shape={tuple(shape)}, source shape={source_shape}.')
dist.broadcast(full_tensor, src=0)
torch_util.synchronize()
From dce440ab2656eeb05a84f90f6c941a7219728378 Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Tue, 12 May 2026 00:27:56 +0800
Subject: [PATCH 22/40] wi p
---
src/twinkle/model/transformers/transformers.py | 15 ---------------
1 file changed, 15 deletions(-)
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 8a8e6e91..010d18cb 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -362,47 +362,32 @@ def _not_encoded(inputs):
def _lazy_wrap_model(self):
if not self._model_wrapped:
- _twinkle_fsdp_debug('enter _lazy_wrap_model')
optimizer_groups = [og for og in self.optimizer_group.values() if og.optimizer is not None]
- _twinkle_fsdp_debug(f'_lazy_wrap_model optimizer_groups={len(optimizer_groups)}')
use_rank0_broadcast = getattr(self.strategy, 'use_rank0_pretrained_broadcast', lambda: False)
set_pre_ep_state = getattr(self.strategy, 'set_rank0_pre_ep_full_state_dict', None)
if self._enable_expert_parallel and use_rank0_broadcast() and set_pre_ep_state is not None:
is_rank0 = dist.is_available() and dist.is_initialized() and dist.get_rank() == 0
- _twinkle_fsdp_debug('before capture pre-EP full state_dict for rank0 broadcast')
set_pre_ep_state(_clone_state_dict_to_cpu(self.model.state_dict()) if is_rank0 else {})
- _twinkle_fsdp_debug('after capture pre-EP full state_dict for rank0 broadcast')
- _twinkle_fsdp_debug('before _maybe_apply_expert_parallel')
self._maybe_apply_expert_parallel()
- _twinkle_fsdp_debug('after _maybe_apply_expert_parallel')
- _twinkle_fsdp_debug('before _ensure_sp_strategy')
self._ensure_sp_strategy()
- _twinkle_fsdp_debug('after _ensure_sp_strategy')
if self.sp_strategy is not None:
- _twinkle_fsdp_debug('before sp_strategy.initialize')
self.sp_strategy.initialize()
- _twinkle_fsdp_debug('after sp_strategy.initialize')
if len(optimizer_groups) == 1:
optimizer_group = optimizer_groups[0]
optimizer = optimizer_group.optimizer
assert optimizer is not None
- _twinkle_fsdp_debug('before strategy.wrap_model with optimizer')
self.model, optimizer = self.strategy.wrap_model(self.model, optimizer)
- _twinkle_fsdp_debug('after strategy.wrap_model with optimizer')
optimizer_group.optimizer = optimizer
self.register_mm_forward_hook(optimizer_group)
else:
# maybe forward_only, no optimizer_group available
- _twinkle_fsdp_debug('before strategy.wrap_model without optimizer')
result = self.strategy.wrap_model(self.model)
- _twinkle_fsdp_debug('after strategy.wrap_model without optimizer')
if isinstance(result, tuple):
self.model = result[0]
else:
self.model = result
self._model_wrapped = True
- _twinkle_fsdp_debug('exit _lazy_wrap_model')
def register_mm_forward_hook(self, optimizer_group: OptimizerGroup):
model = self.strategy.unwrap_model(self.model)
From ee339b3ec4cbf70d0ad813ca619306a388d12660 Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Tue, 12 May 2026 09:22:48 +0800
Subject: [PATCH 23/40] refactor(moe): remove debug tracing from EP utilities
Remove the `_ep_trace` debug logging function and all its calls from expert parallelism utilities. This cleanup eliminates verbose debug output controlled by `TWINKLE_EP_DEBUG` environment variable, simplifying the code and removing unnecessary I/O overhead in production. Also removes the `tag` parameter from `_AllToAll` and `all_to_all` functions that was only used for debug tracing.
---
.../model/transformers/moe/ep_utils.py | 106 +-----------------
.../model/transformers/moe/expert_parallel.py | 52 +--------
.../transformers/strategy/native_fsdp.py | 31 +----
3 files changed, 11 insertions(+), 178 deletions(-)
diff --git a/src/twinkle/model/transformers/moe/ep_utils.py b/src/twinkle/model/transformers/moe/ep_utils.py
index 7a36666f..f64c1796 100644
--- a/src/twinkle/model/transformers/moe/ep_utils.py
+++ b/src/twinkle/model/transformers/moe/ep_utils.py
@@ -6,48 +6,17 @@
import torch
import torch.distributed as dist
-import os
-import time
from typing import Optional
-def _ep_trace(message: str, group: Optional[dist.ProcessGroup] = None) -> None:
- if os.environ.get('TWINKLE_EP_DEBUG', os.environ.get('TWINKLE_FSDP_DEBUG', '0')) != '1':
- return
- rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else int(os.environ.get('RANK', 0))
- world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
- if group is not None:
- try:
- ep_rank = dist.get_rank(group)
- ep_world_size = dist.get_world_size(group)
- except Exception:
- ep_rank = '?'
- ep_world_size = '?'
- else:
- ep_rank = '?'
- ep_world_size = '?'
- local_rank = os.environ.get('LOCAL_RANK', '?')
- text = (
- f'[twinkle-ep-trace][time={time.time():.6f} rank{rank}/{world_size} '
- f'local_rank={local_rank} ep_rank={ep_rank}/{ep_world_size}] {message}'
- )
- print(text, flush=True)
- debug_dir = os.environ.get('TWINKLE_DEBUG_DIR')
- if debug_dir:
- os.makedirs(debug_dir, exist_ok=True)
- with open(os.path.join(debug_dir, f'ep_rank{rank}.log'), 'a', encoding='utf-8') as f:
- f.write(text + '\n')
-
-
# ========================== comm ==========================
class _AllToAll(torch.autograd.Function):
@staticmethod
- def forward(ctx, group, input, output_split_sizes, input_split_sizes, tag):
+ def forward(ctx, group, input, output_split_sizes, input_split_sizes):
ctx.group = group
ctx.output_split_sizes = output_split_sizes
ctx.input_split_sizes = input_split_sizes
- ctx.tag = tag
world_size = dist.get_world_size(group=group)
@@ -60,11 +29,6 @@ def forward(ctx, group, input, output_split_sizes, input_split_sizes, tag):
output = torch.empty_like(input)
else:
output = torch.empty(size=(sum(output_split_sizes), input.size(1)), dtype=input.dtype, device=input.device)
- _ep_trace(
- f'all_to_all forward before tag={tag} input_shape={tuple(input.shape)} output_shape={tuple(output.shape)} '
- f'input_splits={input_split_sizes} output_splits={output_split_sizes}',
- group,
- )
dist.all_to_all_single(
output,
input,
@@ -72,26 +36,13 @@ def forward(ctx, group, input, output_split_sizes, input_split_sizes, tag):
input_split_sizes=input_split_sizes,
group=group,
)
- _ep_trace(f'all_to_all forward after tag={tag}', group)
return output
@staticmethod
def backward(ctx, *grad_output):
- _ep_trace(
- f'all_to_all backward before tag={ctx.tag} grad_shape={tuple(grad_output[0].shape)} '
- f'input_splits={ctx.output_split_sizes} output_splits={ctx.input_split_sizes}',
- ctx.group,
- )
return (
None,
- _AllToAll.apply(
- ctx.group,
- *grad_output,
- ctx.input_split_sizes,
- ctx.output_split_sizes,
- f'{ctx.tag}.backward',
- ),
- None,
+ _AllToAll.apply(ctx.group, *grad_output, ctx.input_split_sizes, ctx.output_split_sizes),
None,
None,
)
@@ -136,8 +87,8 @@ def backward(ctx, grad_output, grad_async_handle):
)
-def all_to_all(group, input, output_split_size=None, input_split_size=None, tag: str = ''):
- return _AllToAll.apply(group, input, output_split_size, input_split_size, tag)
+def all_to_all(group, input, output_split_size=None, input_split_size=None):
+ return _AllToAll.apply(group, input, output_split_size, input_split_size)
def all_to_all_async(group, input, output_split_size, input_split_size):
@@ -250,16 +201,7 @@ def preprocess(
dtype=num_local_tokens_per_expert.dtype,
device=num_local_tokens_per_expert.device,
)
- _ep_trace(
- f'preprocess before all_gather local_tokens_shape={tuple(num_local_tokens_per_expert.shape)} '
- f'input_splits={input_splits}',
- ep_group,
- )
dist.all_gather_into_tensor(num_global_tokens_per_expert, num_local_tokens_per_expert, group=ep_group)
- _ep_trace(
- f'preprocess after all_gather global_tokens_shape={tuple(num_global_tokens_per_expert.shape)}',
- ep_group,
- )
# [ep_size, num_local_experts]
start_idx, end_idx = rank * num_local_experts, (rank + 1) * num_local_experts
@@ -287,7 +229,6 @@ def token_pre_all2all(
output_splits: torch.Tensor,
num_global_tokens_per_local_expert: torch.Tensor,
ep_group: Optional[dist.ProcessGroup] = None,
- debug_tag: str = '',
) -> torch.Tensor:
hidden_dim = hidden_states.size(-1)
hidden_states = hidden_states.reshape(-1, hidden_dim)
@@ -296,24 +237,7 @@ def token_pre_all2all(
local_permuted_hidden_states, local_input_permutation_mapping = permute(hidden_states, expert_mask)
local_assignment_weights = routing_weights.T.contiguous().masked_select(expert_mask.bool())
- _ep_trace(
- f'token_pre_all2all before all_to_all tag={debug_tag} '
- f'local_permuted_shape={tuple(local_permuted_hidden_states.shape)} '
- f'input_splits={input_splits} output_splits={output_splits}',
- ep_group,
- )
- global_permuted_hidden_states = all_to_all(
- ep_group,
- local_permuted_hidden_states,
- output_splits,
- input_splits,
- tag=f'{debug_tag}.token_pre_all2all',
- )
- _ep_trace(
- f'token_pre_all2all after all_to_all tag={debug_tag} '
- f'global_permuted_shape={tuple(global_permuted_hidden_states.shape)}',
- ep_group,
- )
+ global_permuted_hidden_states = all_to_all(ep_group, local_permuted_hidden_states, output_splits, input_splits)
# group tokens together by expert
num_local_experts = num_experts // ep_group.size()
@@ -342,7 +266,6 @@ def tokens_post_all2all(
local_input_permutation_mapping: torch.Tensor,
org_hidden_states_shape: torch.Size,
ep_group: Optional[dist.ProcessGroup] = None,
- debug_tag: str = '',
) -> torch.Tensor:
# group tokens together by expert
num_local_experts = num_experts // ep_group.size()
@@ -353,24 +276,7 @@ def tokens_post_all2all(
unpermute_order,
)
- _ep_trace(
- f'tokens_post_all2all before all_to_all tag={debug_tag} '
- f'expert_outputs_shape={tuple(expert_outputs.shape)} '
- f'input_splits={input_splits} output_splits={output_splits}',
- ep_group,
- )
- unpermute_outputs = all_to_all(
- ep_group,
- expert_outputs,
- input_splits,
- output_splits,
- tag=f'{debug_tag}.tokens_post_all2all',
- )
- _ep_trace(
- f'tokens_post_all2all after all_to_all tag={debug_tag} '
- f'unpermute_outputs_shape={tuple(unpermute_outputs.shape)}',
- ep_group,
- )
+ unpermute_outputs = all_to_all(ep_group, expert_outputs, input_splits, output_splits)
weighted_outputs = unpermute_outputs * local_assignment_weights.unsqueeze(-1)
hidden_dim = org_hidden_states_shape[-1]
final_outputs = torch.zeros(org_hidden_states_shape, device=weighted_outputs.device, dtype=weighted_outputs.dtype)
diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py
index 86e3c603..29158bfa 100644
--- a/src/twinkle/model/transformers/moe/expert_parallel.py
+++ b/src/twinkle/model/transformers/moe/expert_parallel.py
@@ -2,8 +2,6 @@
from __future__ import annotations
import inspect
-import os
-import time
import torch
import torch.distributed as dist
import torch.nn.functional as F
@@ -15,27 +13,6 @@
from twinkle.utils import DeviceMesh
-def _ep_block_trace(block: nn.Module, message: str) -> None:
- if os.environ.get('TWINKLE_EP_DEBUG', os.environ.get('TWINKLE_FSDP_DEBUG', '0')) != '1':
- return
- rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else int(os.environ.get('RANK', 0))
- world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
- local_rank = os.environ.get('LOCAL_RANK', '?')
- block_name = getattr(block, '_ep_debug_name', type(block).__name__)
- ep_rank = getattr(block, '_ep_rank', '?')
- ep_world_size = getattr(block, '_ep_world_size', '?')
- text = (
- f'[twinkle-ep-block][time={time.time():.6f} rank{rank}/{world_size} local_rank={local_rank} '
- f'ep_rank={ep_rank}/{ep_world_size} block={block_name}] {message}'
- )
- print(text, flush=True)
- debug_dir = os.environ.get('TWINKLE_DEBUG_DIR')
- if debug_dir:
- os.makedirs(debug_dir, exist_ok=True)
- with open(os.path.join(debug_dir, f'ep_block_rank{rank}.log'), 'a', encoding='utf-8') as f:
- f.write(text + '\n')
-
-
@dataclass
class ExpertParallelConfig:
enabled: bool = True
@@ -87,8 +64,7 @@ def apply_expert_parallel(
ep_rank = ep_mesh.get_local_rank()
specs = []
- for block_name, block in find_moe_blocks_with_names(model):
- block._ep_debug_name = block_name
+ for _, block in find_moe_blocks_with_names(model):
spec = shard_experts(block, ep_world_size, ep_rank, cfg)
patch_forward(block, ep_group, ep_world_size, cfg)
specs.append(spec)
@@ -233,7 +209,6 @@ def patch_forward(
_install_ep_forward(block.experts, experts_per_rank)
def forward(hidden_states: torch.Tensor, *args, **kwargs):
- _ep_block_trace(block, f'forward enter hidden_shape={tuple(hidden_states.shape)}')
if args:
raise RuntimeError('Expert parallel patch only supports keyword-only extra args for MoE blocks.')
@@ -259,32 +234,19 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs):
# Keep routing weights in activation dtype before unpermute weighting.
if routing_weights.dtype != hidden_states_2d.dtype:
routing_weights = routing_weights.to(hidden_states_2d.dtype)
- _ep_block_trace(
- block,
- f'after router hidden_2d_shape={tuple(hidden_states_2d.shape)} '
- f'routing_shape={tuple(routing_weights.shape)} selected_shape={tuple(selected_experts.shape)}',
- )
-
# Build expert_mask: [num_experts, top_k, num_tokens]
expert_mask = torch.nn.functional.one_hot(
selected_experts, num_classes=num_experts).permute(2, 1, 0) # [num_experts, top_k, num_tokens]
# 1. preprocess: compute splits and token counts
- _ep_block_trace(block, f'before preprocess expert_mask_sum={int(expert_mask.sum().item())}')
(
input_splits,
output_splits,
num_global_tokens_per_local_expert,
num_global_sum_tokens_per_local_expert,
) = preprocess(expert_mask, num_experts, ep_group)
- _ep_block_trace(
- block,
- f'after preprocess input_splits={input_splits} output_splits={output_splits} '
- f'local_expert_tokens={num_global_sum_tokens_per_local_expert.tolist()}',
- )
# 2. token_pre_all2all: permute → all_to_all → sort_chunks
- _ep_block_trace(block, 'before token_pre_all2all')
(
global_permuted_hidden_states,
local_input_permutation_mapping,
@@ -299,34 +261,27 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs):
output_splits,
num_global_tokens_per_local_expert,
ep_group,
- debug_tag=getattr(block, '_ep_debug_name', block.__class__.__name__),
)
- _ep_block_trace(block, f'after token_pre_all2all global_shape={tuple(global_permuted_hidden_states.shape)}')
# 3. expert_compute: call experts via nn.Module.__call__ so FSDP2 hooks fire.
# For tensor experts: block.experts(permuted_tokens, counts, experts_per_rank)
# → FSDP2 pre-forward unshard → ep_forward → FSDP2 post-forward reshard
# For ModuleList experts: _run_local_experts calls each expert[i](...) via __call__.
if is_tensor_experts:
- _ep_block_trace(block, 'before tensor expert compute')
expert_outputs = block.experts(
global_permuted_hidden_states,
num_global_sum_tokens_per_local_expert,
experts_per_rank,
)
- _ep_block_trace(block, f'after tensor expert compute output_shape={tuple(expert_outputs.shape)}')
else:
- _ep_block_trace(block, 'before modulelist expert compute')
expert_outputs = _run_local_experts(
block,
global_permuted_hidden_states,
num_global_sum_tokens_per_local_expert,
experts_per_rank,
)
- _ep_block_trace(block, f'after modulelist expert compute output_shape={tuple(expert_outputs.shape)}')
# 4. tokens_post_all2all: sort_chunks → all_to_all → unpermute (with routing weight)
- _ep_block_trace(block, 'before tokens_post_all2all')
final_hidden = tokens_post_all2all(
expert_outputs,
local_assignment_weights,
@@ -337,22 +292,17 @@ def forward(hidden_states: torch.Tensor, *args, **kwargs):
local_input_permutation_mapping,
org_hidden_states_shape,
ep_group,
- debug_tag=getattr(block, '_ep_debug_name', block.__class__.__name__),
)
- _ep_block_trace(block, f'after tokens_post_all2all final_2d_shape={tuple(final_hidden.shape)}')
shared_out = _maybe_run_shared_expert(block, hidden_states_2d, cfg)
if shared_out is not None:
final_hidden = final_hidden + shared_out
- _ep_block_trace(block, 'after shared expert')
if len(orig_shape) == 3:
final_hidden = final_hidden.view(batch_size, seq_len, hidden_dim)
if cfg.keep_router_logits and returns_router_logits:
- _ep_block_trace(block, 'forward exit with router logits')
return final_hidden, router_logits
- _ep_block_trace(block, f'forward exit final_shape={tuple(final_hidden.shape)}')
return final_hidden
block._ep_original_forward = orig_forward
diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py
index e7806fe7..d6b6242a 100644
--- a/src/twinkle/model/transformers/strategy/native_fsdp.py
+++ b/src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -1,8 +1,6 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import torch
import torch.distributed as dist
-import os
-import time
from torch import nn
from torch.distributed.device_mesh import DeviceMesh as TorchDeviceMesh
from torch.distributed.fsdp import fully_shard
@@ -15,21 +13,6 @@
from torch.distributed.fsdp import MixedPrecisionPolicy
-def _native_fsdp_debug(message: str) -> None:
- if os.environ.get('TWINKLE_FSDP_DEBUG', '0') != '1':
- return
- rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else int(os.environ.get('RANK', 0))
- world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
- local_rank = os.environ.get('LOCAL_RANK', '?')
- text = f'[twinkle-native-fsdp-debug][time={time.time():.6f} rank{rank}/{world_size} local_rank={local_rank}] {message}'
- print(text, flush=True)
- debug_dir = os.environ.get('TWINKLE_DEBUG_DIR')
- if debug_dir:
- os.makedirs(debug_dir, exist_ok=True)
- with open(os.path.join(debug_dir, f'native_fsdp_rank{rank}.log'), 'a', encoding='utf-8') as f:
- f.write(text + '\n')
-
-
class NativeFSDPStrategy:
def __init__(self,
@@ -94,11 +77,8 @@ def wrap_model(self, model, optimizer=None):
if use_meta:
is_rank0 = (dist.get_rank() == 0)
if ep_enabled and self._rank0_pre_ep_full_state_dict is not None:
- _native_fsdp_debug(
- f'use captured pre-EP full state_dict keys={len(self._rank0_pre_ep_full_state_dict)}')
original_sd = self._rank0_pre_ep_full_state_dict if is_rank0 else {}
else:
- _native_fsdp_debug('use current model state_dict as rank0 broadcast source')
original_sd = model.state_dict() if is_rank0 else {}
saved_buffers = _get_non_persistent_buffers(model) if is_rank0 else {}
if is_rank0:
@@ -585,6 +565,10 @@ def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements):
mesh_dim,
src_data_rank=None,
)
+ # _shard_tensor may return a view into the replicated full
+ # tensor. Clone it so the final DTensor shard does not keep
+ # the full parameter storage alive after loading.
+ local_tensor = local_tensor.contiguous().clone()
elif isinstance(placement, Replicate):
continue
elif isinstance(placement, Partial):
@@ -609,15 +593,9 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param):
raise KeyError(f"Missing source metadata for EP expert parameter '{param_name}'.")
_, source_dtype = source_metadata[param_name]
local_tensor = torch.empty(local_shape, device=device_type, dtype=source_dtype)
- _native_fsdp_debug(
- f'EP expert scatter start name={param_name} local_shape={local_shape} '
- f'num_experts={num_experts} experts_per_rank={experts_per_rank}')
scatter_list = None
if is_rank0:
- _native_fsdp_debug(
- f'EP expert scatter source name={param_name} source_shape={tuple(full_tensor.shape)} '
- f'source_dtype={full_tensor.dtype}')
if full_tensor.size(0) != num_experts:
raise RuntimeError(
f"EP expert parameter '{param_name}' expects {num_experts} experts, "
@@ -634,7 +612,6 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param):
scatter_list.append(full_tensor[start:end].contiguous())
dist.scatter(local_tensor, scatter_list=scatter_list, src=0)
- _native_fsdp_debug(f'EP expert scatter done name={param_name}')
return local_tensor
for param_name, sharded_param in meta_sharded_sd.items():
From b0ef503de1e2a0adf9d6e2ca62048e81b1ac2767 Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Tue, 12 May 2026 11:49:56 +0800
Subject: [PATCH 24/40] wip
---
.../model/transformers/strategy/native_fsdp.py | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py
index d6b6242a..6977fc21 100644
--- a/src/twinkle/model/transformers/strategy/native_fsdp.py
+++ b/src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -594,14 +594,12 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param):
_, source_dtype = source_metadata[param_name]
local_tensor = torch.empty(local_shape, device=device_type, dtype=source_dtype)
- scatter_list = None
if is_rank0:
if full_tensor.size(0) != num_experts:
raise RuntimeError(
f"EP expert parameter '{param_name}' expects {num_experts} experts, "
f'but source state has shape {tuple(full_tensor.shape)}. '
'Rank0 must capture the full pre-EP state_dict before apply_expert_parallel().')
- scatter_list = []
world_size = dist.get_world_size()
for rank in range(world_size):
if rank not in rank_to_ep_rank:
@@ -609,9 +607,14 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param):
ep_rank = rank_to_ep_rank[rank]
start = ep_rank * experts_per_rank
end = start + experts_per_rank
- scatter_list.append(full_tensor[start:end].contiguous())
+ chunk = full_tensor[start:end].contiguous()
+ if rank == 0:
+ local_tensor.copy_(chunk)
+ else:
+ dist.send(chunk, dst=rank)
+ else:
+ dist.recv(local_tensor, src=0)
- dist.scatter(local_tensor, scatter_list=scatter_list, src=0)
return local_tensor
for param_name, sharded_param in meta_sharded_sd.items():
From fcfea943a045583bc6ec81274b657918e3cd8dc8 Mon Sep 17 00:00:00 2001
From: meichangsu1 <1484603386@qq.com>
Date: Tue, 12 May 2026 14:07:04 +0800
Subject: [PATCH 25/40] fix: optimize memory usage in FSDP expert parameter
broadcasting
Move expert parameter tensors to GPU lazily during chunk scattering instead of loading all at once, preventing OOM when handling large expert parameters.
---
src/twinkle/model/transformers/strategy/native_fsdp.py | 10 +++++++---
1 file changed, 7 insertions(+), 3 deletions(-)
diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py
index 6977fc21..bffc43a6 100644
--- a/src/twinkle/model/transformers/strategy/native_fsdp.py
+++ b/src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -608,10 +608,11 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param):
start = ep_rank * experts_per_rank
end = start + experts_per_rank
chunk = full_tensor[start:end].contiguous()
+ chunk_gpu = chunk.to(device_type)
if rank == 0:
- local_tensor.copy_(chunk)
+ local_tensor.copy_(chunk_gpu)
else:
- dist.send(chunk, dst=rank)
+ dist.send(chunk_gpu, dst=rank)
else:
dist.recv(local_tensor, src=0)
@@ -629,9 +630,12 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param):
raise KeyError(
f"Parameter '{param_name}' found in sharded model state dict but missing from full state dict.")
full_param = full_sd[param_name]
- full_tensor = full_param.detach().to(device_type)
+ full_tensor = full_param.detach()
if isinstance(full_tensor, DTensor):
full_tensor = full_tensor.to_local()
+ # EP expert params: keep on CPU to avoid OOM; move chunks lazily in _scatter_ep_expert_tensor.
+ if not is_ep_expert_param:
+ full_tensor = full_tensor.to(device_type)
if tuple(full_tensor.shape) != tuple(source_shape) or full_tensor.dtype != source_dtype:
raise RuntimeError(
f"Source metadata mismatch for '{param_name}': "
From b616a6378f88491f435276bf1d43f7c3267b2215 Mon Sep 17 00:00:00 2001
From: weikaiwen
Date: Tue, 12 May 2026 16:06:43 +0800
Subject: [PATCH 26/40] fix native fsdp
---
.../transformers/strategy/native_fsdp.py | 58 +++++++++++++------
1 file changed, 41 insertions(+), 17 deletions(-)
diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py
index bffc43a6..ef893b6a 100644
--- a/src/twinkle/model/transformers/strategy/native_fsdp.py
+++ b/src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -4,7 +4,7 @@
from torch import nn
from torch.distributed.device_mesh import DeviceMesh as TorchDeviceMesh
from torch.distributed.fsdp import fully_shard
-from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Set
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set
from twinkle.utils import DeviceMesh, Platform, torch_util
from .load_context import fsdp_pretrained_load_context
@@ -340,16 +340,43 @@ def _build_fsdp_mesh(device_mesh: DeviceMesh) -> Optional[TorchDeviceMesh]:
return TorchDeviceMesh(device_mesh.device_type, flat_mesh, mesh_dim_names=('fsdp', ))
-def _get_decoder_layers(model: nn.Module) -> Optional[nn.ModuleList]:
+def _get_decoder_layers(model: nn.Module) -> Optional[List[nn.Module]]:
+ no_split_modules = _get_no_split_module_names(model)
+ if no_split_modules:
+ layers = [
+ module for module in model.modules()
+ if module is not model and module.__class__.__name__ in no_split_modules
+ ]
+ if layers:
+ return layers
+
inner_model = getattr(model, 'model', None)
if inner_model is not None:
inner_layers = getattr(inner_model, 'layers', None)
if isinstance(inner_layers, nn.ModuleList):
- return inner_layers
+ return list(inner_layers)
return None
+def _get_no_split_module_names(model: nn.Module) -> Set[str]:
+ names = _normalize_no_split_modules(getattr(model, '_no_split_modules', None))
+ if names:
+ return names
+
+ for module in model.modules():
+ names.update(_normalize_no_split_modules(getattr(module, '_no_split_modules', None)))
+ return names
+
+
+def _normalize_no_split_modules(value) -> Set[str]:
+ if value is None:
+ return set()
+ if isinstance(value, str):
+ return {value}
+ return set(value)
+
+
def _collect_expert_params(model: nn.Module) -> Optional[Set[nn.Parameter]]:
ignored: Set[nn.Parameter] = set()
ep_patched = False
@@ -548,8 +575,7 @@ def _broadcast_sharded_state_dict(
if is_rank0:
source_metadata = {
name: (tuple(tensor.shape), tensor.dtype)
- for name, tensor in full_sd.items()
- if hasattr(tensor, 'shape') and hasattr(tensor, 'dtype')
+ for name, tensor in full_sd.items() if hasattr(tensor, 'shape') and hasattr(tensor, 'dtype')
}
metadata_holder = [source_metadata]
dist.broadcast_object_list(metadata_holder, src=0)
@@ -596,10 +622,9 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param):
if is_rank0:
if full_tensor.size(0) != num_experts:
- raise RuntimeError(
- f"EP expert parameter '{param_name}' expects {num_experts} experts, "
- f'but source state has shape {tuple(full_tensor.shape)}. '
- 'Rank0 must capture the full pre-EP state_dict before apply_expert_parallel().')
+ raise RuntimeError(f"EP expert parameter '{param_name}' expects {num_experts} experts, "
+ f'but source state has shape {tuple(full_tensor.shape)}. '
+ 'Rank0 must capture the full pre-EP state_dict before apply_expert_parallel().')
world_size = dist.get_world_size()
for rank in range(world_size):
if rank not in rank_to_ep_rank:
@@ -637,20 +662,19 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param):
if not is_ep_expert_param:
full_tensor = full_tensor.to(device_type)
if tuple(full_tensor.shape) != tuple(source_shape) or full_tensor.dtype != source_dtype:
- raise RuntimeError(
- f"Source metadata mismatch for '{param_name}': "
- f'actual shape={tuple(full_tensor.shape)} dtype={full_tensor.dtype}, '
- f'expected shape={source_shape} dtype={source_dtype}.')
+ raise RuntimeError(f"Source metadata mismatch for '{param_name}': "
+ f'actual shape={tuple(full_tensor.shape)} dtype={full_tensor.dtype}, '
+ f'expected shape={source_shape} dtype={source_dtype}.')
else:
- full_tensor = None if is_ep_expert_param else torch.empty(source_shape, device=device_type, dtype=source_dtype)
+ full_tensor = None if is_ep_expert_param else torch.empty(
+ source_shape, device=device_type, dtype=source_dtype)
if is_ep_expert_param:
full_tensor = _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param)
else:
if tuple(shape) != tuple(source_shape):
- raise RuntimeError(
- f"Parameter '{param_name}' shape mismatch before broadcast: "
- f'sharded logical shape={tuple(shape)}, source shape={source_shape}.')
+ raise RuntimeError(f"Parameter '{param_name}' shape mismatch before broadcast: "
+ f'sharded logical shape={tuple(shape)}, source shape={source_shape}.')
dist.broadcast(full_tensor, src=0)
torch_util.synchronize()
From 9494caa6d4d3a1845fb8bc5dd982b449cf5a9fa7 Mon Sep 17 00:00:00 2001
From: weikaiwen
Date: Fri, 15 May 2026 14:27:12 +0800
Subject: [PATCH 27/40] fix(ep): fall back to plural shared_experts for dsv4
compatibility
---
.../transformers/strategy/native_fsdp.py | 4 ++
.../ep_lora/test_shared_experts_fallback.py | 44 +++++++++++++++++++
2 files changed, 48 insertions(+)
create mode 100644 tests/scripts/ep_lora/test_shared_experts_fallback.py
diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py
index ef893b6a..40864d2d 100644
--- a/src/twinkle/model/transformers/strategy/native_fsdp.py
+++ b/src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -393,6 +393,8 @@ def _collect_expert_params(model: nn.Module) -> Optional[Set[nn.Parameter]]:
if getattr(module, '_ep_ignore_shared_experts', False) and getattr(module, '_ep_patched', False):
ep_patched = True
shared = getattr(module, 'shared_expert', None)
+ if shared is None:
+ shared = getattr(module, 'shared_experts', None)
if shared is not None:
ignored.update(shared.parameters())
@@ -495,6 +497,8 @@ def _place_ep_experts_on_local_device(model: nn.Module, ep_fsdp_device_mesh: Opt
_move_module_to_device_or_empty(experts, local_device)
if getattr(module, '_ep_ignore_shared_experts', False):
shared = getattr(module, 'shared_expert', None)
+ if shared is None:
+ shared = getattr(module, 'shared_experts', None)
if shared is not None:
_move_module_to_device_or_empty(shared, local_device)
diff --git a/tests/scripts/ep_lora/test_shared_experts_fallback.py b/tests/scripts/ep_lora/test_shared_experts_fallback.py
new file mode 100644
index 00000000..350976bf
--- /dev/null
+++ b/tests/scripts/ep_lora/test_shared_experts_fallback.py
@@ -0,0 +1,44 @@
+"""Lightweight unit tests for dsv4 shared_experts fallback."""
+import torch
+import torch.nn as nn
+
+from twinkle.model.transformers.strategy.native_fsdp import _collect_expert_params
+
+
+def _make_block(use_plural_shared: bool, ignore_shared: bool) -> nn.Module:
+ block = nn.Module()
+ experts = nn.Module()
+ experts.gate_up_proj = nn.Parameter(torch.randn(2, 4, 8))
+ experts.down_proj = nn.Parameter(torch.randn(2, 4, 4))
+ block.experts = experts
+ shared = nn.Linear(4, 4)
+ if use_plural_shared:
+ block.shared_experts = shared
+ else:
+ block.shared_expert = shared
+ block._ep_patched = True
+ block._ep_ignore_shared_experts = ignore_shared
+ parent = nn.Module()
+ parent.block = block
+ return parent
+
+
+def test_singular_shared_expert_collected():
+ parent = _make_block(use_plural_shared=False, ignore_shared=True)
+ ignored = _collect_expert_params(parent)
+ expected_count = 2 + 2
+ assert ignored is not None and len(ignored) == expected_count
+
+
+def test_plural_shared_experts_collected():
+ parent = _make_block(use_plural_shared=True, ignore_shared=True)
+ ignored = _collect_expert_params(parent)
+ expected_count = 2 + 2
+ assert ignored is not None and len(ignored) == expected_count, (
+ '_collect_expert_params should fall back to shared_experts (plural) for dsv4')
+
+
+def test_no_ignore_shared_only_collects_experts():
+ parent = _make_block(use_plural_shared=True, ignore_shared=False)
+ ignored = _collect_expert_params(parent)
+ assert ignored is not None and len(ignored) == 2
From c71a238f8e0087ac6e8b2f0c65db5db7444edaf9 Mon Sep 17 00:00:00 2001
From: weikaiwen
Date: Fri, 15 May 2026 14:31:33 +0800
Subject: [PATCH 28/40] test(ep-lora): add P0 spike for ParamWrapper + FSDP2
DTensor compatibility
---
.../ep_lora/spike_register_parametrization.py | 99 +++++++++++++++++++
1 file changed, 99 insertions(+)
create mode 100644 tests/scripts/ep_lora/spike_register_parametrization.py
diff --git a/tests/scripts/ep_lora/spike_register_parametrization.py b/tests/scripts/ep_lora/spike_register_parametrization.py
new file mode 100644
index 00000000..41042f4a
--- /dev/null
+++ b/tests/scripts/ep_lora/spike_register_parametrization.py
@@ -0,0 +1,99 @@
+"""P0 spike: verify PEFT ParamWrapper + FSDP2 DTensor compatibility on Qwen3.5-MoE.
+
+Run on 4 GPUs:
+ torchrun --nproc-per-node=4 tests/scripts/ep_lora/spike_register_parametrization.py
+"""
+import os
+
+import torch
+import torch.distributed as dist
+from peft import LoraConfig
+from transformers import AutoConfig
+
+import twinkle
+from twinkle import DeviceMesh, Platform, get_logger
+from twinkle.model import TransformersModel
+
+logger = get_logger()
+
+MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
+
+
+def main():
+ device_mesh = DeviceMesh.from_sizes(
+ fsdp_size=4,
+ dp_size=1,
+ ep_size=2,
+ device_type=Platform.get_platform().device_prefix(),
+ )
+ twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+
+ config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
+ config.num_hidden_layers = 2
+ config.hidden_size = 128
+ config.intermediate_size = 256
+ config.moe_intermediate_size = 64
+ config.num_experts = 4
+ config.num_experts_per_tok = 2
+ config.use_cache = False
+
+ model = TransformersModel(
+ model_id=MODEL_ID,
+ config=config,
+ device_mesh=device_mesh,
+ fsdp_config={'expert_parallel': {'enabled': True, 'router_dtype': 'fp32'}},
+ )
+
+ lora_cfg = LoraConfig(
+ r=8,
+ lora_alpha=32,
+ target_modules='all-linear',
+ target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'],
+ )
+ model.add_adapter_to_model('default', lora_cfg)
+ model.set_optimizer('AdamW', lr=1e-4, foreach=False)
+
+ rank = dist.get_rank() if dist.is_initialized() else 0
+ torch.manual_seed(42 + rank)
+ batch = {
+ 'input_ids': torch.randint(0, config.vocab_size, (2, 16), device=Platform.get_local_device()),
+ 'labels': torch.randint(0, config.vocab_size, (2, 16), device=Platform.get_local_device()),
+ 'attention_mask': torch.ones(2, 16, dtype=torch.long, device=Platform.get_local_device()),
+ }
+
+ model.forward_backward(inputs=batch, gradient_accumulation_steps=1)
+ metric = model.calculate_metric(is_training=True)
+ if callable(metric):
+ metric = metric()
+ loss = metric['loss'] if isinstance(metric, dict) and 'loss' in metric else metric
+ logger.info(f'spike loss (rank {rank}): {loss}')
+ assert torch.is_tensor(loss) or isinstance(loss, float), 'loss should be a scalar'
+ loss_val = float(loss)
+ assert torch.isfinite(torch.tensor(loss_val)), f'loss not finite: {loss_val}'
+
+ unwrapped = model.strategy.unwrap_model(model.model)
+ lora_a_seen = 0
+ lora_b_seen = 0
+ base_grads = []
+ for name, param in unwrapped.named_parameters():
+ if 'experts' not in name:
+ continue
+ if 'lora_A' in name and param.grad is not None:
+ lora_a_seen += 1
+ assert param.grad.abs().sum().item() > 0, f'{name} grad is zero'
+ if 'lora_B' in name and param.grad is not None:
+ lora_b_seen += 1
+ if 'base_layer.gate_up_proj' in name or 'base_layer.down_proj' in name:
+ base_grads.append((name, param.grad))
+
+ logger.info(f'lora_A grads seen: {lora_a_seen}, lora_B grads seen: {lora_b_seen}')
+ assert lora_a_seen > 0, 'no lora_A grads observed under experts subtree'
+ for name, grad in base_grads:
+ assert grad is None, f'{name} should be frozen but has grad'
+
+ if rank == 0:
+ logger.info('SPIKE PASSED: PEFT ParamWrapper works with FSDP2 DTensor for routing experts.')
+
+
+if __name__ == '__main__':
+ main()
From 05e95ec7a7bb2124cd99c810196267c5c85b482b Mon Sep 17 00:00:00 2001
From: weikaiwen
Date: Fri, 15 May 2026 14:34:24 +0800
Subject: [PATCH 29/40] feat(ep-lora): trigger EP slicing before PEFT adapter
patching
---
.../model/transformers/transformers.py | 67 +++++++++++++++++--
.../ep_lora/test_patch_adapter_assertions.py | 39 +++++++++++
.../ep_lora/test_patch_adapter_ep_ordering.py | 61 +++++++++++++++++
3 files changed, 162 insertions(+), 5 deletions(-)
create mode 100644 tests/scripts/ep_lora/test_patch_adapter_assertions.py
create mode 100644 tests/scripts/ep_lora/test_patch_adapter_ep_ordering.py
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 010d18cb..bfe12ade 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -360,14 +360,22 @@ def _not_encoded(inputs):
assert isinstance(inputs, dict)
return 'input_ids' not in inputs and 'input_embedding' not in inputs
+ def _capture_rank0_pre_ep_state_if_needed(self):
+ """Capture rank0 pre-EP full state_dict for memory_efficient_init broadcast."""
+ if getattr(self, '_pre_ep_state_captured', False):
+ return
+ use_rank0_broadcast = getattr(self.strategy, 'use_rank0_pretrained_broadcast', lambda: False)
+ set_pre_ep_state = getattr(self.strategy, 'set_rank0_pre_ep_full_state_dict', None)
+ if not (self._enable_expert_parallel and use_rank0_broadcast() and set_pre_ep_state is not None):
+ return
+ is_rank0 = dist.is_available() and dist.is_initialized() and dist.get_rank() == 0
+ set_pre_ep_state(_clone_state_dict_to_cpu(self.model.state_dict()) if is_rank0 else {})
+ self._pre_ep_state_captured = True
+
def _lazy_wrap_model(self):
if not self._model_wrapped:
optimizer_groups = [og for og in self.optimizer_group.values() if og.optimizer is not None]
- use_rank0_broadcast = getattr(self.strategy, 'use_rank0_pretrained_broadcast', lambda: False)
- set_pre_ep_state = getattr(self.strategy, 'set_rank0_pre_ep_full_state_dict', None)
- if self._enable_expert_parallel and use_rank0_broadcast() and set_pre_ep_state is not None:
- is_rank0 = dist.is_available() and dist.is_initialized() and dist.get_rank() == 0
- set_pre_ep_state(_clone_state_dict_to_cpu(self.model.state_dict()) if is_rank0 else {})
+ self._capture_rank0_pre_ep_state_if_needed()
self._maybe_apply_expert_parallel()
self._ensure_sp_strategy()
if self.sp_strategy is not None:
@@ -1293,9 +1301,58 @@ def calculate_metric(self, is_training, **kwargs):
optimizer_config = self.optimizer_group[adapter_name]
return optimizer_config.calculate_metrics(is_training)
+ @staticmethod
+ def _validate_ep_lora_config(self, lora_config) -> None:
+ from peft import LoraConfig
+
+ if not getattr(self, '_enable_expert_parallel', False):
+ return
+ if not isinstance(self.strategy, NativeFSDPStrategy):
+ raise RuntimeError(
+ 'EP + LoRA requires strategy=native_fsdp; '
+ f'got {type(self.strategy).__name__}.')
+ if not isinstance(lora_config, LoraConfig):
+ return
+ target_params = getattr(lora_config, 'target_parameters', None) or []
+ if target_params:
+ if getattr(lora_config, 'use_dora', False):
+ raise ValueError(
+ 'PEFT ParamWrapper does not support use_dora=True with target_parameters; '
+ 'disable DoRA when training expert parameters.')
+ if getattr(lora_config, 'lora_bias', False):
+ raise ValueError(
+ 'PEFT ParamWrapper does not support lora_bias=True with target_parameters.')
+ if float(getattr(lora_config, 'lora_dropout', 0.0)) > 0.0:
+ raise ValueError(
+ 'PEFT ParamWrapper does not support lora_dropout>0 with target_parameters.')
+
+ @staticmethod
+ def _maybe_autofill_target_parameters(lora_config, enable_ep: bool):
+ from peft import LoraConfig
+
+ if not enable_ep or not isinstance(lora_config, LoraConfig):
+ return lora_config
+ target_params = getattr(lora_config, 'target_parameters', None) or []
+ if not target_params:
+ lora_config.target_parameters = ['mlp.experts.gate_up_proj', 'mlp.experts.down_proj']
+ logger.info(
+ "EP+LoRA auto-filled target_parameters with "
+ "['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'].")
+ return lora_config
+
def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str], **kwargs):
assert adapter_name, 'Use a different adapter_name, current is empty.'
unwrapped_model = self.strategy.unwrap_model(self.model)
+ if not isinstance(config_or_dir, str):
+ self._validate_ep_lora_config(self, config_or_dir)
+ config_or_dir = self._maybe_autofill_target_parameters(
+ config_or_dir, enable_ep=getattr(self, '_enable_expert_parallel', False))
+
+ if getattr(self, '_enable_expert_parallel', False):
+ self._capture_rank0_pre_ep_state_if_needed()
+ self._maybe_apply_expert_parallel()
+ unwrapped_model = self.strategy.unwrap_model(self.model)
+
if isinstance(config_or_dir, str):
config_or_dir = HubOperation.download_model(config_or_dir)
_adapted_model = PeftModel.from_pretrained(
diff --git a/tests/scripts/ep_lora/test_patch_adapter_assertions.py b/tests/scripts/ep_lora/test_patch_adapter_assertions.py
new file mode 100644
index 00000000..ed778f0c
--- /dev/null
+++ b/tests/scripts/ep_lora/test_patch_adapter_assertions.py
@@ -0,0 +1,39 @@
+"""Unit tests for _patch_adapter entry assertions and default target_parameters."""
+from unittest import mock
+
+import pytest
+from peft import LoraConfig
+
+
+def test_lora_dropout_with_target_parameters_raises():
+ from twinkle.model.transformers.transformers import TransformersModel
+ from twinkle.model.transformers.strategy.native_fsdp import NativeFSDPStrategy
+
+ cfg = LoraConfig(
+ r=4,
+ target_modules=['q_proj'],
+ target_parameters=['mlp.experts.gate_up_proj'],
+ lora_dropout=0.1,
+ )
+ instance = mock.MagicMock(spec=TransformersModel)
+ instance._enable_expert_parallel = True
+ instance.strategy = NativeFSDPStrategy(enable_ep=True)
+
+ with pytest.raises(ValueError, match='ParamWrapper'):
+ TransformersModel._validate_ep_lora_config(instance, cfg)
+
+
+def test_no_target_parameters_autofilled():
+ from twinkle.model.transformers.transformers import TransformersModel
+
+ cfg = LoraConfig(r=4, target_modules='all-linear')
+ cfg = TransformersModel._maybe_autofill_target_parameters(cfg, enable_ep=True)
+ assert cfg.target_parameters == ['mlp.experts.gate_up_proj', 'mlp.experts.down_proj']
+
+
+def test_target_parameters_passthrough_when_set():
+ from twinkle.model.transformers.transformers import TransformersModel
+
+ cfg = LoraConfig(r=4, target_modules='all-linear', target_parameters=['custom.param'])
+ cfg2 = TransformersModel._maybe_autofill_target_parameters(cfg, enable_ep=True)
+ assert cfg2.target_parameters == ['custom.param']
diff --git a/tests/scripts/ep_lora/test_patch_adapter_ep_ordering.py b/tests/scripts/ep_lora/test_patch_adapter_ep_ordering.py
new file mode 100644
index 00000000..e69f7a98
--- /dev/null
+++ b/tests/scripts/ep_lora/test_patch_adapter_ep_ordering.py
@@ -0,0 +1,61 @@
+"""Verify _patch_adapter triggers EP slicing before get_peft_model.
+
+Run on 4 GPUs:
+ torchrun --nproc-per-node=4 tests/scripts/ep_lora/test_patch_adapter_ep_ordering.py
+"""
+import os
+
+from peft import LoraConfig
+from transformers import AutoConfig
+
+import twinkle
+from twinkle import DeviceMesh, Platform, get_logger
+from twinkle.model import TransformersModel
+
+logger = get_logger()
+MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
+
+
+def main():
+ device_mesh = DeviceMesh.from_sizes(
+ fsdp_size=4,
+ dp_size=1,
+ ep_size=2,
+ device_type=Platform.get_platform().device_prefix(),
+ )
+ twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+ config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
+ config.num_hidden_layers = 2
+ config.num_experts = 4
+ config.use_cache = False
+
+ model = TransformersModel(
+ model_id=MODEL_ID,
+ config=config,
+ device_mesh=device_mesh,
+ fsdp_config={'expert_parallel': {'enabled': True}},
+ )
+ assert not getattr(model, '_expert_parallel_applied', False)
+
+ model.add_adapter_to_model(
+ 'default',
+ LoraConfig(
+ r=4,
+ target_modules='all-linear',
+ target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'],
+ ),
+ )
+ assert getattr(model, '_expert_parallel_applied', False), (
+ '_patch_adapter must trigger EP slicing before PEFT wrap')
+
+ unwrapped = model.strategy.unwrap_model(model.model)
+ lora_a_shapes = [(n, p.shape) for n, p in unwrapped.named_parameters() if 'experts' in n and 'lora_A' in n]
+ assert lora_a_shapes, 'no lora_A weight under experts subtree'
+ for name, shape in lora_a_shapes:
+ assert shape[0] == 4 * 2, f'{name} shape {tuple(shape)}; expected r*E_local=8'
+
+ logger.info('ORDERING TEST PASSED')
+
+
+if __name__ == '__main__':
+ main()
From 494198ddf506bacb00ad142f9020a5fa7f665c43 Mon Sep 17 00:00:00 2001
From: weikaiwen
Date: Fri, 15 May 2026 14:36:31 +0800
Subject: [PATCH 30/40] feat(ep-lora): add EP-aware PEFT save and load handling
---
.../model/transformers/transformers.py | 60 +++++++++++++++++--
1 file changed, 55 insertions(+), 5 deletions(-)
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index bfe12ade..92354c6b 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -103,6 +103,45 @@ def _clone_state_dict_to_cpu(state_dict: Dict[str, Any]) -> Dict[str, Any]:
return cloned
+def _get_named_child(module, name: str):
+ if hasattr(module, name):
+ return getattr(module, name)
+ if name.isdigit() and hasattr(module, '__getitem__'):
+ try:
+ return module[int(name)]
+ except (IndexError, TypeError, KeyError):
+ return None
+ return None
+
+
+def _split_for_ep_pre_distribute(model, model_key: str, value: torch.Tensor, ep_world_size: int,
+ ep_rank: int) -> torch.Tensor:
+ """Slice saved LoRA expert weights by EP rank before DTensor/FSDP placement."""
+ if ep_world_size <= 1:
+ return value
+
+ parts = model_key.split('.')
+ parent = model
+ matched = False
+ for i, part in enumerate(parts[:-1]):
+ parent = _get_named_child(parent, part)
+ if parent is None:
+ return value
+ next_seg = parts[i + 1] if i + 1 < len(parts) else None
+ if getattr(parent, '_ep_patched', False) and next_seg == 'experts':
+ matched = True
+
+ if not matched:
+ return value
+ if 'lora_A' in model_key:
+ chunk = value.size(0) // ep_world_size
+ return value.narrow(0, ep_rank * chunk, chunk).contiguous()
+ if 'lora_B' in model_key:
+ chunk = value.size(1) // ep_world_size
+ return value.narrow(1, ep_rank * chunk, chunk).contiguous()
+ return value
+
+
@dataclass
class OptimizerGroup(BaseOptimizerGroup):
"""Optimizer group for Transformers training."""
@@ -1017,11 +1056,17 @@ def save(self, name: Optional[str] = None, output_dir: Optional[str] = None, int
# Full model save
processed_state_dict = self.strategy.get_full_state_dict(self.model)
else:
- # LoRA adapter save
- state_dict = self.get_state_dict(adapter_name=adapter_name, **kwargs)
- for key, value in state_dict.items():
- key = key.replace(f'.{adapter_name}.', '.')
- processed_state_dict[key] = torch_util.to_local_tensor(value).cpu()
+ # LoRA adapter save (EP-aware via strategy.get_full_state_dict)
+ full_state = self.strategy.get_full_state_dict(self.model)
+ adapter_marker = '.lora_'
+ adapter_suffix = f'.{adapter_name}.'
+ for key, value in full_state.items():
+ if adapter_marker not in key:
+ continue
+ if adapter_suffix not in key:
+ continue
+ normalized = key.replace(adapter_suffix, '.')
+ processed_state_dict[normalized] = value
if isinstance(model, PeftModel):
if Platform.is_master():
@@ -1124,6 +1169,10 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs):
def load_peft_weights_for_fsdp2(model, adapter_weights, adapter_name='default'):
from torch.distributed.tensor import DTensor, distribute_tensor
+ ep_fsdp_mesh = getattr(self.strategy, 'ep_fsdp_device_mesh', None)
+ ep_world_size = ep_fsdp_mesh['ep'].size() if ep_fsdp_mesh is not None else 1
+ ep_rank = ep_fsdp_mesh['ep'].get_local_rank() if ep_world_size > 1 else 0
+
model_sd = model.state_dict()
converted_weights = {}
for key, value in adapter_weights.items():
@@ -1132,6 +1181,7 @@ def load_peft_weights_for_fsdp2(model, adapter_weights, adapter_name='default'):
model_key = model_key.replace('.weight', f'.{adapter_name}.weight')
if model_key in model_sd:
param = model_sd[model_key]
+ value = _split_for_ep_pre_distribute(model, model_key, value, ep_world_size, ep_rank)
if isinstance(param, DTensor) and not isinstance(value, DTensor):
value = distribute_tensor(value.to(param.device), param.device_mesh, param.placements)
converted_weights[key] = value
From 3f7f4347d07b41fd6a21c4e1ec0b545699162df3 Mon Sep 17 00:00:00 2001
From: weikaiwen
Date: Fri, 15 May 2026 14:40:00 +0800
Subject: [PATCH 31/40] test(ep-lora): add EP LoRA cookbooks and validation
scripts
---
cookbook/transformers/ep_lora_deepseek_v4.py | 95 ++++++++++++++++
cookbook/transformers/ep_lora_qwen3_5_moe.py | 101 +++++++++++++++++
.../scripts/ep_lora/loss_curve_qwen3_5_moe.py | 87 ++++++++++++++
tests/scripts/ep_lora/save_load_dsv4.py | 107 ++++++++++++++++++
.../scripts/ep_lora/save_load_qwen3_5_moe.py | 103 +++++++++++++++++
.../spike_register_parametrization_dsv4.py | 92 +++++++++++++++
6 files changed, 585 insertions(+)
create mode 100644 cookbook/transformers/ep_lora_deepseek_v4.py
create mode 100644 cookbook/transformers/ep_lora_qwen3_5_moe.py
create mode 100644 tests/scripts/ep_lora/loss_curve_qwen3_5_moe.py
create mode 100644 tests/scripts/ep_lora/save_load_dsv4.py
create mode 100644 tests/scripts/ep_lora/save_load_qwen3_5_moe.py
create mode 100644 tests/scripts/ep_lora/spike_register_parametrization_dsv4.py
diff --git a/cookbook/transformers/ep_lora_deepseek_v4.py b/cookbook/transformers/ep_lora_deepseek_v4.py
new file mode 100644
index 00000000..3edac2af
--- /dev/null
+++ b/cookbook/transformers/ep_lora_deepseek_v4.py
@@ -0,0 +1,95 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""EP + LoRA SFT cookbook for DeepSeek-V4.
+
+Run on 4 GPUs:
+ torchrun --nproc-per-node=4 cookbook/transformers/ep_lora_deepseek_v4.py
+"""
+import os
+
+from peft import LoraConfig
+from transformers import AutoConfig
+
+import twinkle
+from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.model import TransformersModel
+from twinkle.preprocessor import SelfCognitionProcessor
+
+logger = get_logger()
+
+MODEL_ID = os.environ.get('DSV4_MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4')
+DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition')
+TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template')
+NUM_LAYERS = int(os.environ.get('NUM_LAYERS', '2'))
+BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '2'))
+GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4'))
+LR = float(os.environ.get('LR', '1e-4'))
+MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0'))
+LORA_R = int(os.environ.get('LORA_R', '8'))
+LORA_ALPHA = int(os.environ.get('LORA_ALPHA', '32'))
+NUM_STEPS_LIMIT = int(os.environ.get('NUM_STEPS_LIMIT', '0'))
+
+device_mesh = DeviceMesh.from_sizes(
+ fsdp_size=4,
+ dp_size=1,
+ ep_size=2,
+ device_type=Platform.get_platform().device_prefix(),
+)
+twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+
+
+def train():
+ config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
+ config.num_hidden_layers = NUM_LAYERS
+ if hasattr(config, 'use_cache'):
+ config.use_cache = False
+
+ dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(500)))
+ dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID)
+ dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope'))
+ dataset.encode(batched=True)
+ dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh)
+
+ model = TransformersModel(
+ model_id=MODEL_ID,
+ config=config,
+ device_mesh=device_mesh,
+ fsdp_config={'expert_parallel': {'enabled': True, 'router_dtype': 'fp32'}},
+ )
+ lora_cfg = LoraConfig(
+ r=LORA_R,
+ lora_alpha=LORA_ALPHA,
+ target_modules='all-linear',
+ target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'],
+ )
+ model.add_adapter_to_model('default', lora_cfg)
+ model.set_optimizer('AdamW', lr=LR, foreach=False)
+ model.set_lr_scheduler(
+ scheduler_cls='CosineWarmupScheduler',
+ num_warmup_steps=5,
+ num_training_steps=len(dataloader),
+ )
+
+ logger.info(get_device_placement())
+ logger.info(model.get_train_configs())
+
+ for step, batch in enumerate(dataloader):
+ if NUM_STEPS_LIMIT and step >= NUM_STEPS_LIMIT:
+ break
+ if callable(batch):
+ batch = batch()
+ model.forward_backward(inputs=batch, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
+ model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
+ if (step + 1) % GRAD_ACCUM_STEPS == 0:
+ optimizer_step = (step + 1) // GRAD_ACCUM_STEPS
+ metric = model.calculate_metric(is_training=True)
+ if callable(metric):
+ metric = metric()
+ logger.info(f'optimizer_step {optimizer_step}, metric: {metric}')
+
+ model.save(name='checkpoint-final', output_dir='./output_dsv4')
+
+
+if __name__ == '__main__':
+ train()
diff --git a/cookbook/transformers/ep_lora_qwen3_5_moe.py b/cookbook/transformers/ep_lora_qwen3_5_moe.py
new file mode 100644
index 00000000..e10db392
--- /dev/null
+++ b/cookbook/transformers/ep_lora_qwen3_5_moe.py
@@ -0,0 +1,101 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""EP + LoRA SFT cookbook for Qwen3.5-MoE.
+
+Run on 4 GPUs:
+ torchrun --nproc-per-node=4 cookbook/transformers/ep_lora_qwen3_5_moe.py
+"""
+import os
+
+from peft import LoraConfig
+from transformers import AutoConfig
+
+import twinkle
+from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.model import TransformersModel
+from twinkle.preprocessor import SelfCognitionProcessor
+
+logger = get_logger()
+
+MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
+DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition')
+TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'Qwen3_5Template')
+_num_layers_env = os.environ.get('NUM_LAYERS')
+NUM_LAYERS = int(_num_layers_env) if _num_layers_env is not None else None
+BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4'))
+GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4'))
+LR = float(os.environ.get('LR', '1e-4'))
+MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0'))
+LORA_R = int(os.environ.get('LORA_R', '8'))
+LORA_ALPHA = int(os.environ.get('LORA_ALPHA', '32'))
+NUM_STEPS_LIMIT = int(os.environ.get('NUM_STEPS_LIMIT', '0'))
+
+device_mesh = DeviceMesh.from_sizes(
+ fsdp_size=4,
+ dp_size=1,
+ ep_size=2,
+ device_type=Platform.get_platform().device_prefix(),
+)
+twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+
+
+def train():
+ config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
+ if NUM_LAYERS is not None and hasattr(config, 'num_hidden_layers'):
+ config.num_hidden_layers = NUM_LAYERS
+ if hasattr(config, 'use_cache'):
+ config.use_cache = False
+
+ dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(1000)))
+ try:
+ dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID)
+ except ValueError:
+ dataset.set_template('Qwen3_5Template', model_id=MODEL_ID)
+ dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope'))
+ dataset.encode(batched=True)
+ dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh)
+
+ model = TransformersModel(
+ model_id=MODEL_ID,
+ config=config,
+ device_mesh=device_mesh,
+ fsdp_config={'expert_parallel': {'enabled': True, 'router_dtype': 'fp32'}},
+ )
+ lora_cfg = LoraConfig(
+ r=LORA_R,
+ lora_alpha=LORA_ALPHA,
+ target_modules='all-linear',
+ target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'],
+ )
+ model.add_adapter_to_model('default', lora_cfg)
+ model.set_optimizer('AdamW', lr=LR, foreach=False)
+ model.set_lr_scheduler(
+ scheduler_cls='CosineWarmupScheduler',
+ num_warmup_steps=5,
+ num_training_steps=len(dataloader),
+ )
+
+ logger.info(get_device_placement())
+ logger.info(model.get_train_configs())
+
+ for step, batch in enumerate(dataloader):
+ if NUM_STEPS_LIMIT and step >= NUM_STEPS_LIMIT:
+ break
+ if callable(batch):
+ batch = batch()
+ model.forward_backward(inputs=batch, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
+ model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
+ if (step + 1) % GRAD_ACCUM_STEPS == 0:
+ optimizer_step = (step + 1) // GRAD_ACCUM_STEPS
+ metric = model.calculate_metric(is_training=True)
+ if callable(metric):
+ metric = metric()
+ logger.info(f'optimizer_step {optimizer_step}, metric: {metric}')
+
+ model.save(name='checkpoint-final', output_dir='./output')
+ logger.info('Saved final adapter to ./output/checkpoint-final')
+
+
+if __name__ == '__main__':
+ train()
diff --git a/tests/scripts/ep_lora/loss_curve_qwen3_5_moe.py b/tests/scripts/ep_lora/loss_curve_qwen3_5_moe.py
new file mode 100644
index 00000000..29216460
--- /dev/null
+++ b/tests/scripts/ep_lora/loss_curve_qwen3_5_moe.py
@@ -0,0 +1,87 @@
+"""P2: EP+LoRA loss should decrease over 200 steps on self-cognition SFT.
+
+Run on 4 GPUs:
+ torchrun --nproc-per-node=4 tests/scripts/ep_lora/loss_curve_qwen3_5_moe.py
+"""
+import os
+
+import torch.distributed as dist
+from peft import LoraConfig
+from transformers import AutoConfig
+
+import twinkle
+from twinkle import DeviceMesh, Platform, get_logger
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.model import TransformersModel
+from twinkle.preprocessor import SelfCognitionProcessor
+
+logger = get_logger()
+MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
+DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition')
+TARGET_RATIO = float(os.environ.get('TARGET_RATIO', '0.7'))
+NUM_STEPS = int(os.environ.get('NUM_STEPS', '200'))
+
+
+def main():
+ device_mesh = DeviceMesh.from_sizes(
+ fsdp_size=4,
+ dp_size=1,
+ ep_size=2,
+ device_type=Platform.get_platform().device_prefix(),
+ )
+ twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+ config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
+ config.num_hidden_layers = 4
+ config.use_cache = False
+
+ dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(1000)))
+ dataset.set_template('Qwen3_5Template', model_id=MODEL_ID)
+ dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope'))
+ dataset.encode(batched=True)
+ dataloader = DataLoader(dataset=dataset, batch_size=4, device_mesh=device_mesh)
+
+ model = TransformersModel(
+ model_id=MODEL_ID,
+ config=config,
+ device_mesh=device_mesh,
+ fsdp_config={'expert_parallel': {'enabled': True, 'router_dtype': 'fp32'}},
+ )
+ model.add_adapter_to_model(
+ 'default',
+ LoraConfig(
+ r=8,
+ lora_alpha=32,
+ target_modules='all-linear',
+ target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'],
+ ),
+ )
+ model.set_optimizer('AdamW', lr=1e-4, foreach=False)
+ model.set_lr_scheduler('CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=NUM_STEPS)
+
+ losses = []
+ for step, batch in enumerate(dataloader):
+ if step >= NUM_STEPS:
+ break
+ if callable(batch):
+ batch = batch()
+ model.forward_backward(inputs=batch, gradient_accumulation_steps=1)
+ model.clip_grad_and_step(max_grad_norm=1.0, gradient_accumulation_steps=1)
+ metric = model.calculate_metric(is_training=True)
+ if callable(metric):
+ metric = metric()
+ loss = metric['loss'] if isinstance(metric, dict) and 'loss' in metric else metric
+ losses.append(float(loss))
+
+ if dist.get_rank() == 0:
+ head = sum(losses[:10]) / 10
+ tail = sum(losses[-10:]) / 10
+ ratio = tail / head if head > 0 else float('inf')
+ logger.info(f'head_avg={head:.4f}, tail_avg={tail:.4f}, ratio={ratio:.3f}')
+ assert ratio < TARGET_RATIO, (
+ f'loss did not decrease enough; tail/head ratio {ratio:.3f} >= {TARGET_RATIO}')
+ logger.info('LOSS CURVE TEST PASSED')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tests/scripts/ep_lora/save_load_dsv4.py b/tests/scripts/ep_lora/save_load_dsv4.py
new file mode 100644
index 00000000..8a6af0c8
--- /dev/null
+++ b/tests/scripts/ep_lora/save_load_dsv4.py
@@ -0,0 +1,107 @@
+"""P3: EP+LoRA save-reload numeric consistency for DeepSeek-V4.
+
+Run on 4 GPUs:
+ torchrun --nproc-per-node=4 tests/scripts/ep_lora/save_load_dsv4.py
+"""
+import os
+import shutil
+import tempfile
+
+import torch
+import torch.distributed as dist
+from peft import LoraConfig
+from transformers import AutoConfig
+
+import twinkle
+from twinkle import DeviceMesh, Platform, get_logger
+from twinkle.model import TransformersModel
+
+logger = get_logger()
+MODEL_ID = os.environ.get('DSV4_MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4')
+TOL = float(os.environ.get('TOL', '1e-3'))
+
+
+def build_model():
+ config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
+ config.num_hidden_layers = 2
+ if hasattr(config, 'n_routed_experts'):
+ config.n_routed_experts = 4
+ if hasattr(config, 'num_experts_per_tok'):
+ config.num_experts_per_tok = 2
+ config.use_cache = False
+ device_mesh = DeviceMesh.current()
+ model = TransformersModel(
+ model_id=MODEL_ID,
+ config=config,
+ device_mesh=device_mesh,
+ fsdp_config={'expert_parallel': {'enabled': True, 'router_dtype': 'fp32'}},
+ )
+ lora_cfg = LoraConfig(
+ r=8,
+ lora_alpha=32,
+ target_modules='all-linear',
+ target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'],
+ )
+ model.add_adapter_to_model('default', lora_cfg)
+ model.set_optimizer('AdamW', lr=1e-4, foreach=False)
+ return model, config
+
+
+def fixed_batch(config):
+ torch.manual_seed(0)
+ return {
+ 'input_ids': torch.randint(0, config.vocab_size, (2, 32), device=Platform.get_local_device()),
+ 'labels': torch.randint(0, config.vocab_size, (2, 32), device=Platform.get_local_device()),
+ 'attention_mask': torch.ones(2, 32, dtype=torch.long, device=Platform.get_local_device()),
+ }
+
+
+def compute_loss(model, batch):
+ model.forward_backward(inputs=batch, gradient_accumulation_steps=1)
+ metric = model.calculate_metric(is_training=True)
+ if callable(metric):
+ metric = metric()
+ loss = metric['loss'] if isinstance(metric, dict) and 'loss' in metric else metric
+ return float(loss)
+
+
+def main():
+ device_mesh = DeviceMesh.from_sizes(
+ fsdp_size=4,
+ dp_size=1,
+ ep_size=2,
+ device_type=Platform.get_platform().device_prefix(),
+ )
+ twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+
+ model_a, config = build_model()
+ batch = fixed_batch(config)
+ model_a.forward_backward(inputs=batch, gradient_accumulation_steps=1)
+ model_a.clip_grad_and_step(max_grad_norm=1.0, gradient_accumulation_steps=1)
+ loss_before = compute_loss(model_a, batch)
+
+ tmp_root = tempfile.mkdtemp(prefix='ep_lora_dsv4_save_')
+ if dist.get_rank() == 0:
+ logger.info(f'Save root: {tmp_root}')
+ model_a.save(name='ckpt', output_dir=tmp_root)
+ dist.barrier()
+
+ del model_a
+ torch.cuda.empty_cache()
+
+ model_b, _ = build_model()
+ model_b.load(name='ckpt', output_dir=tmp_root)
+ loss_after = compute_loss(model_b, batch)
+
+ diff = abs(loss_after - loss_before)
+ if dist.get_rank() == 0:
+ logger.info(f'loss_before={loss_before:.6f}, loss_after={loss_after:.6f}, diff={diff:.6e}')
+ assert diff < TOL, f'save/reload loss drift {diff:.3e} > tol {TOL:.1e}'
+
+ if dist.get_rank() == 0:
+ shutil.rmtree(tmp_root, ignore_errors=True)
+ logger.info('SAVE/LOAD TEST PASSED')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tests/scripts/ep_lora/save_load_qwen3_5_moe.py b/tests/scripts/ep_lora/save_load_qwen3_5_moe.py
new file mode 100644
index 00000000..fb2d03ae
--- /dev/null
+++ b/tests/scripts/ep_lora/save_load_qwen3_5_moe.py
@@ -0,0 +1,103 @@
+"""P1: EP+LoRA save-reload numeric consistency for Qwen3.5-MoE.
+
+Run on 4 GPUs:
+ torchrun --nproc-per-node=4 tests/scripts/ep_lora/save_load_qwen3_5_moe.py
+"""
+import os
+import shutil
+import tempfile
+
+import torch
+import torch.distributed as dist
+from peft import LoraConfig
+from transformers import AutoConfig
+
+import twinkle
+from twinkle import DeviceMesh, Platform, get_logger
+from twinkle.model import TransformersModel
+
+logger = get_logger()
+MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
+TOL = float(os.environ.get('TOL', '1e-3'))
+
+
+def build_model():
+ config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
+ config.num_hidden_layers = 2
+ config.use_cache = False
+ device_mesh = DeviceMesh.current()
+ model = TransformersModel(
+ model_id=MODEL_ID,
+ config=config,
+ device_mesh=device_mesh,
+ fsdp_config={'expert_parallel': {'enabled': True, 'router_dtype': 'fp32'}},
+ )
+ lora_cfg = LoraConfig(
+ r=8,
+ lora_alpha=32,
+ target_modules='all-linear',
+ target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'],
+ )
+ model.add_adapter_to_model('default', lora_cfg)
+ model.set_optimizer('AdamW', lr=1e-4, foreach=False)
+ return model, config
+
+
+def fixed_batch(config):
+ torch.manual_seed(0)
+ return {
+ 'input_ids': torch.randint(0, config.vocab_size, (2, 32), device=Platform.get_local_device()),
+ 'labels': torch.randint(0, config.vocab_size, (2, 32), device=Platform.get_local_device()),
+ 'attention_mask': torch.ones(2, 32, dtype=torch.long, device=Platform.get_local_device()),
+ }
+
+
+def compute_loss(model, batch):
+ model.forward_backward(inputs=batch, gradient_accumulation_steps=1)
+ metric = model.calculate_metric(is_training=True)
+ if callable(metric):
+ metric = metric()
+ loss = metric['loss'] if isinstance(metric, dict) and 'loss' in metric else metric
+ return float(loss)
+
+
+def main():
+ device_mesh = DeviceMesh.from_sizes(
+ fsdp_size=4,
+ dp_size=1,
+ ep_size=2,
+ device_type=Platform.get_platform().device_prefix(),
+ )
+ twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+
+ model_a, config = build_model()
+ batch = fixed_batch(config)
+ model_a.forward_backward(inputs=batch, gradient_accumulation_steps=1)
+ model_a.clip_grad_and_step(max_grad_norm=1.0, gradient_accumulation_steps=1)
+ loss_before = compute_loss(model_a, batch)
+
+ tmp_root = tempfile.mkdtemp(prefix='ep_lora_save_')
+ if dist.get_rank() == 0:
+ logger.info(f'Save root: {tmp_root}')
+ model_a.save(name='ckpt', output_dir=tmp_root)
+ dist.barrier()
+
+ del model_a
+ torch.cuda.empty_cache()
+
+ model_b, _ = build_model()
+ model_b.load(name='ckpt', output_dir=tmp_root)
+ loss_after = compute_loss(model_b, batch)
+
+ diff = abs(loss_after - loss_before)
+ if dist.get_rank() == 0:
+ logger.info(f'loss_before={loss_before:.6f}, loss_after={loss_after:.6f}, diff={diff:.6e}')
+ assert diff < TOL, f'save/reload loss drift {diff:.3e} > tol {TOL:.1e}'
+
+ if dist.get_rank() == 0:
+ shutil.rmtree(tmp_root, ignore_errors=True)
+ logger.info('SAVE/LOAD TEST PASSED')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tests/scripts/ep_lora/spike_register_parametrization_dsv4.py b/tests/scripts/ep_lora/spike_register_parametrization_dsv4.py
new file mode 100644
index 00000000..dd4ea03d
--- /dev/null
+++ b/tests/scripts/ep_lora/spike_register_parametrization_dsv4.py
@@ -0,0 +1,92 @@
+"""P3 spike: verify PEFT ParamWrapper + FSDP2 DTensor compatibility on DeepSeek-V4.
+
+Run on 4 GPUs:
+ torchrun --nproc-per-node=4 tests/scripts/ep_lora/spike_register_parametrization_dsv4.py
+"""
+import os
+
+import torch
+import torch.distributed as dist
+from peft import LoraConfig
+from transformers import AutoConfig
+
+import twinkle
+from twinkle import DeviceMesh, Platform, get_logger
+from twinkle.model import TransformersModel
+
+logger = get_logger()
+MODEL_ID = os.environ.get('DSV4_MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4')
+
+
+def main():
+ device_mesh = DeviceMesh.from_sizes(
+ fsdp_size=4,
+ dp_size=1,
+ ep_size=2,
+ device_type=Platform.get_platform().device_prefix(),
+ )
+ twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+
+ config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
+ config.num_hidden_layers = 2
+ if hasattr(config, 'n_routed_experts'):
+ config.n_routed_experts = 4
+ if hasattr(config, 'num_experts_per_tok'):
+ config.num_experts_per_tok = 2
+ config.use_cache = False
+
+ model = TransformersModel(
+ model_id=MODEL_ID,
+ config=config,
+ device_mesh=device_mesh,
+ fsdp_config={'expert_parallel': {'enabled': True, 'router_dtype': 'fp32'}},
+ )
+ model.add_adapter_to_model(
+ 'default',
+ LoraConfig(
+ r=8,
+ lora_alpha=32,
+ target_modules='all-linear',
+ target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'],
+ ),
+ )
+ model.set_optimizer('AdamW', lr=1e-4, foreach=False)
+
+ rank = dist.get_rank() if dist.is_initialized() else 0
+ torch.manual_seed(42 + rank)
+ batch = {
+ 'input_ids': torch.randint(0, config.vocab_size, (2, 16), device=Platform.get_local_device()),
+ 'labels': torch.randint(0, config.vocab_size, (2, 16), device=Platform.get_local_device()),
+ 'attention_mask': torch.ones(2, 16, dtype=torch.long, device=Platform.get_local_device()),
+ }
+
+ model.forward_backward(inputs=batch, gradient_accumulation_steps=1)
+ metric = model.calculate_metric(is_training=True)
+ if callable(metric):
+ metric = metric()
+ loss = metric['loss'] if isinstance(metric, dict) and 'loss' in metric else metric
+ loss_val = float(loss)
+ assert torch.isfinite(torch.tensor(loss_val)), f'loss not finite: {loss_val}'
+
+ unwrapped = model.strategy.unwrap_model(model.model)
+ lora_a_seen = 0
+ base_grads = []
+ for name, param in unwrapped.named_parameters():
+ if 'experts' not in name:
+ continue
+ if 'lora_A' in name and param.grad is not None:
+ lora_a_seen += 1
+ assert param.grad.abs().sum().item() > 0, f'{name} grad is zero'
+ if 'base_layer.gate_up_proj' in name or 'base_layer.down_proj' in name:
+ base_grads.append((name, param.grad))
+
+ assert lora_a_seen > 0, 'no lora_A grads observed under experts subtree'
+ for name, grad in base_grads:
+ assert grad is None, f'{name} should be frozen but has grad'
+
+ if rank == 0:
+ logger.info('SPIKE PASSED: PEFT ParamWrapper works with FSDP2 DTensor for dsv4 routing experts.')
+
+
+if __name__ == '__main__':
+ main()
From 0cbb0fca18d471a62b0d675929ac98cf0acfcd92 Mon Sep 17 00:00:00 2001
From: weikaiwen
Date: Fri, 15 May 2026 15:33:00 +0800
Subject: [PATCH 32/40] test(ep-lora): handle nested text configs in validation
scripts
---
cookbook/transformers/ep_lora_deepseek_v4.py | 11 ++++++--
cookbook/transformers/ep_lora_qwen3_5_moe.py | 13 ++++++---
.../scripts/ep_lora/ep_lora_config_helpers.py | 12 ++++++++
.../scripts/ep_lora/loss_curve_qwen3_5_moe.py | 5 ++--
tests/scripts/ep_lora/save_load_dsv4.py | 20 +++++++------
.../scripts/ep_lora/save_load_qwen3_5_moe.py | 10 ++++---
.../ep_lora/spike_register_parametrization.py | 24 ++++++++++------
.../spike_register_parametrization_dsv4.py | 20 +++++++------
tests/scripts/ep_lora/test_config_helpers.py | 28 +++++++++++++++++++
.../ep_lora/test_patch_adapter_ep_ordering.py | 6 ++--
10 files changed, 108 insertions(+), 41 deletions(-)
create mode 100644 tests/scripts/ep_lora/ep_lora_config_helpers.py
create mode 100644 tests/scripts/ep_lora/test_config_helpers.py
diff --git a/cookbook/transformers/ep_lora_deepseek_v4.py b/cookbook/transformers/ep_lora_deepseek_v4.py
index 3edac2af..9d035a48 100644
--- a/cookbook/transformers/ep_lora_deepseek_v4.py
+++ b/cookbook/transformers/ep_lora_deepseek_v4.py
@@ -39,11 +39,16 @@
twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+def _get_text_config(config):
+ return getattr(config, 'text_config', config)
+
+
def train():
config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
- config.num_hidden_layers = NUM_LAYERS
- if hasattr(config, 'use_cache'):
- config.use_cache = False
+ text_config = _get_text_config(config)
+ text_config.num_hidden_layers = NUM_LAYERS
+ if hasattr(text_config, 'use_cache'):
+ text_config.use_cache = False
dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(500)))
dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID)
diff --git a/cookbook/transformers/ep_lora_qwen3_5_moe.py b/cookbook/transformers/ep_lora_qwen3_5_moe.py
index e10db392..e8bfebc2 100644
--- a/cookbook/transformers/ep_lora_qwen3_5_moe.py
+++ b/cookbook/transformers/ep_lora_qwen3_5_moe.py
@@ -40,12 +40,17 @@
twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+def _get_text_config(config):
+ return getattr(config, 'text_config', config)
+
+
def train():
config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
- if NUM_LAYERS is not None and hasattr(config, 'num_hidden_layers'):
- config.num_hidden_layers = NUM_LAYERS
- if hasattr(config, 'use_cache'):
- config.use_cache = False
+ text_config = _get_text_config(config)
+ if NUM_LAYERS is not None and hasattr(text_config, 'num_hidden_layers'):
+ text_config.num_hidden_layers = NUM_LAYERS
+ if hasattr(text_config, 'use_cache'):
+ text_config.use_cache = False
dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(1000)))
try:
diff --git a/tests/scripts/ep_lora/ep_lora_config_helpers.py b/tests/scripts/ep_lora/ep_lora_config_helpers.py
new file mode 100644
index 00000000..d36d9ae1
--- /dev/null
+++ b/tests/scripts/ep_lora/ep_lora_config_helpers.py
@@ -0,0 +1,12 @@
+def get_text_config(config):
+ return getattr(config, 'text_config', config)
+
+
+def get_vocab_size(config):
+ return get_text_config(config).vocab_size
+
+
+def set_text_config_attrs(config, **attrs):
+ text_config = get_text_config(config)
+ for name, value in attrs.items():
+ setattr(text_config, name, value)
diff --git a/tests/scripts/ep_lora/loss_curve_qwen3_5_moe.py b/tests/scripts/ep_lora/loss_curve_qwen3_5_moe.py
index 29216460..358ac335 100644
--- a/tests/scripts/ep_lora/loss_curve_qwen3_5_moe.py
+++ b/tests/scripts/ep_lora/loss_curve_qwen3_5_moe.py
@@ -16,6 +16,8 @@
from twinkle.model import TransformersModel
from twinkle.preprocessor import SelfCognitionProcessor
+from ep_lora_config_helpers import set_text_config_attrs
+
logger = get_logger()
MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition')
@@ -32,8 +34,7 @@ def main():
)
twinkle.initialize(mode='local', global_device_mesh=device_mesh)
config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
- config.num_hidden_layers = 4
- config.use_cache = False
+ set_text_config_attrs(config, num_hidden_layers=4, use_cache=False)
dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(1000)))
dataset.set_template('Qwen3_5Template', model_id=MODEL_ID)
diff --git a/tests/scripts/ep_lora/save_load_dsv4.py b/tests/scripts/ep_lora/save_load_dsv4.py
index 8a6af0c8..a76ac4db 100644
--- a/tests/scripts/ep_lora/save_load_dsv4.py
+++ b/tests/scripts/ep_lora/save_load_dsv4.py
@@ -16,6 +16,8 @@
from twinkle import DeviceMesh, Platform, get_logger
from twinkle.model import TransformersModel
+from ep_lora_config_helpers import get_vocab_size, set_text_config_attrs
+
logger = get_logger()
MODEL_ID = os.environ.get('DSV4_MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4')
TOL = float(os.environ.get('TOL', '1e-3'))
@@ -23,12 +25,13 @@
def build_model():
config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
- config.num_hidden_layers = 2
- if hasattr(config, 'n_routed_experts'):
- config.n_routed_experts = 4
- if hasattr(config, 'num_experts_per_tok'):
- config.num_experts_per_tok = 2
- config.use_cache = False
+ text_config = getattr(config, 'text_config', config)
+ attrs = {'num_hidden_layers': 2, 'use_cache': False}
+ if hasattr(text_config, 'n_routed_experts'):
+ attrs['n_routed_experts'] = 4
+ if hasattr(text_config, 'num_experts_per_tok'):
+ attrs['num_experts_per_tok'] = 2
+ set_text_config_attrs(config, **attrs)
device_mesh = DeviceMesh.current()
model = TransformersModel(
model_id=MODEL_ID,
@@ -49,9 +52,10 @@ def build_model():
def fixed_batch(config):
torch.manual_seed(0)
+ vocab_size = get_vocab_size(config)
return {
- 'input_ids': torch.randint(0, config.vocab_size, (2, 32), device=Platform.get_local_device()),
- 'labels': torch.randint(0, config.vocab_size, (2, 32), device=Platform.get_local_device()),
+ 'input_ids': torch.randint(0, vocab_size, (2, 32), device=Platform.get_local_device()),
+ 'labels': torch.randint(0, vocab_size, (2, 32), device=Platform.get_local_device()),
'attention_mask': torch.ones(2, 32, dtype=torch.long, device=Platform.get_local_device()),
}
diff --git a/tests/scripts/ep_lora/save_load_qwen3_5_moe.py b/tests/scripts/ep_lora/save_load_qwen3_5_moe.py
index fb2d03ae..dbec27ae 100644
--- a/tests/scripts/ep_lora/save_load_qwen3_5_moe.py
+++ b/tests/scripts/ep_lora/save_load_qwen3_5_moe.py
@@ -16,6 +16,8 @@
from twinkle import DeviceMesh, Platform, get_logger
from twinkle.model import TransformersModel
+from ep_lora_config_helpers import get_vocab_size, set_text_config_attrs
+
logger = get_logger()
MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
TOL = float(os.environ.get('TOL', '1e-3'))
@@ -23,8 +25,7 @@
def build_model():
config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
- config.num_hidden_layers = 2
- config.use_cache = False
+ set_text_config_attrs(config, num_hidden_layers=2, use_cache=False)
device_mesh = DeviceMesh.current()
model = TransformersModel(
model_id=MODEL_ID,
@@ -45,9 +46,10 @@ def build_model():
def fixed_batch(config):
torch.manual_seed(0)
+ vocab_size = get_vocab_size(config)
return {
- 'input_ids': torch.randint(0, config.vocab_size, (2, 32), device=Platform.get_local_device()),
- 'labels': torch.randint(0, config.vocab_size, (2, 32), device=Platform.get_local_device()),
+ 'input_ids': torch.randint(0, vocab_size, (2, 32), device=Platform.get_local_device()),
+ 'labels': torch.randint(0, vocab_size, (2, 32), device=Platform.get_local_device()),
'attention_mask': torch.ones(2, 32, dtype=torch.long, device=Platform.get_local_device()),
}
diff --git a/tests/scripts/ep_lora/spike_register_parametrization.py b/tests/scripts/ep_lora/spike_register_parametrization.py
index 41042f4a..de0aa0fe 100644
--- a/tests/scripts/ep_lora/spike_register_parametrization.py
+++ b/tests/scripts/ep_lora/spike_register_parametrization.py
@@ -14,6 +14,8 @@
from twinkle import DeviceMesh, Platform, get_logger
from twinkle.model import TransformersModel
+from ep_lora_config_helpers import get_vocab_size, set_text_config_attrs
+
logger = get_logger()
MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
@@ -29,13 +31,16 @@ def main():
twinkle.initialize(mode='local', global_device_mesh=device_mesh)
config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
- config.num_hidden_layers = 2
- config.hidden_size = 128
- config.intermediate_size = 256
- config.moe_intermediate_size = 64
- config.num_experts = 4
- config.num_experts_per_tok = 2
- config.use_cache = False
+ set_text_config_attrs(
+ config,
+ num_hidden_layers=2,
+ hidden_size=128,
+ intermediate_size=256,
+ moe_intermediate_size=64,
+ num_experts=4,
+ num_experts_per_tok=2,
+ use_cache=False,
+ )
model = TransformersModel(
model_id=MODEL_ID,
@@ -55,9 +60,10 @@ def main():
rank = dist.get_rank() if dist.is_initialized() else 0
torch.manual_seed(42 + rank)
+ vocab_size = get_vocab_size(config)
batch = {
- 'input_ids': torch.randint(0, config.vocab_size, (2, 16), device=Platform.get_local_device()),
- 'labels': torch.randint(0, config.vocab_size, (2, 16), device=Platform.get_local_device()),
+ 'input_ids': torch.randint(0, vocab_size, (2, 16), device=Platform.get_local_device()),
+ 'labels': torch.randint(0, vocab_size, (2, 16), device=Platform.get_local_device()),
'attention_mask': torch.ones(2, 16, dtype=torch.long, device=Platform.get_local_device()),
}
diff --git a/tests/scripts/ep_lora/spike_register_parametrization_dsv4.py b/tests/scripts/ep_lora/spike_register_parametrization_dsv4.py
index dd4ea03d..dc9bf707 100644
--- a/tests/scripts/ep_lora/spike_register_parametrization_dsv4.py
+++ b/tests/scripts/ep_lora/spike_register_parametrization_dsv4.py
@@ -14,6 +14,8 @@
from twinkle import DeviceMesh, Platform, get_logger
from twinkle.model import TransformersModel
+from ep_lora_config_helpers import get_vocab_size, set_text_config_attrs
+
logger = get_logger()
MODEL_ID = os.environ.get('DSV4_MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4')
@@ -28,12 +30,13 @@ def main():
twinkle.initialize(mode='local', global_device_mesh=device_mesh)
config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
- config.num_hidden_layers = 2
- if hasattr(config, 'n_routed_experts'):
- config.n_routed_experts = 4
- if hasattr(config, 'num_experts_per_tok'):
- config.num_experts_per_tok = 2
- config.use_cache = False
+ text_config = getattr(config, 'text_config', config)
+ attrs = {'num_hidden_layers': 2, 'use_cache': False}
+ if hasattr(text_config, 'n_routed_experts'):
+ attrs['n_routed_experts'] = 4
+ if hasattr(text_config, 'num_experts_per_tok'):
+ attrs['num_experts_per_tok'] = 2
+ set_text_config_attrs(config, **attrs)
model = TransformersModel(
model_id=MODEL_ID,
@@ -54,9 +57,10 @@ def main():
rank = dist.get_rank() if dist.is_initialized() else 0
torch.manual_seed(42 + rank)
+ vocab_size = get_vocab_size(config)
batch = {
- 'input_ids': torch.randint(0, config.vocab_size, (2, 16), device=Platform.get_local_device()),
- 'labels': torch.randint(0, config.vocab_size, (2, 16), device=Platform.get_local_device()),
+ 'input_ids': torch.randint(0, vocab_size, (2, 16), device=Platform.get_local_device()),
+ 'labels': torch.randint(0, vocab_size, (2, 16), device=Platform.get_local_device()),
'attention_mask': torch.ones(2, 16, dtype=torch.long, device=Platform.get_local_device()),
}
diff --git a/tests/scripts/ep_lora/test_config_helpers.py b/tests/scripts/ep_lora/test_config_helpers.py
new file mode 100644
index 00000000..528cb125
--- /dev/null
+++ b/tests/scripts/ep_lora/test_config_helpers.py
@@ -0,0 +1,28 @@
+from types import SimpleNamespace
+
+from ep_lora_config_helpers import get_text_config, get_vocab_size, set_text_config_attrs
+
+
+def test_get_vocab_size_prefers_nested_text_config():
+ config = SimpleNamespace(text_config=SimpleNamespace(vocab_size=248320))
+
+ assert get_vocab_size(config) == 248320
+
+
+def test_set_text_config_attrs_updates_nested_text_config():
+ text_config = SimpleNamespace(num_hidden_layers=40, use_cache=True)
+ config = SimpleNamespace(text_config=text_config)
+
+ set_text_config_attrs(config, num_hidden_layers=2, use_cache=False)
+
+ assert get_text_config(config).num_hidden_layers == 2
+ assert get_text_config(config).use_cache is False
+
+
+def test_set_text_config_attrs_falls_back_to_top_level_config():
+ config = SimpleNamespace(vocab_size=1024)
+
+ set_text_config_attrs(config, num_hidden_layers=2, use_cache=False)
+
+ assert config.num_hidden_layers == 2
+ assert config.use_cache is False
diff --git a/tests/scripts/ep_lora/test_patch_adapter_ep_ordering.py b/tests/scripts/ep_lora/test_patch_adapter_ep_ordering.py
index e69f7a98..3011d6ed 100644
--- a/tests/scripts/ep_lora/test_patch_adapter_ep_ordering.py
+++ b/tests/scripts/ep_lora/test_patch_adapter_ep_ordering.py
@@ -12,6 +12,8 @@
from twinkle import DeviceMesh, Platform, get_logger
from twinkle.model import TransformersModel
+from ep_lora_config_helpers import set_text_config_attrs
+
logger = get_logger()
MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
@@ -25,9 +27,7 @@ def main():
)
twinkle.initialize(mode='local', global_device_mesh=device_mesh)
config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
- config.num_hidden_layers = 2
- config.num_experts = 4
- config.use_cache = False
+ set_text_config_attrs(config, num_hidden_layers=2, num_experts=4, use_cache=False)
model = TransformersModel(
model_id=MODEL_ID,
From 6ac03b2e586e1ac0d31ad6a6ae8e0b3fe35ec1ce Mon Sep 17 00:00:00 2001
From: weikaiwen
Date: Wed, 20 May 2026 11:10:24 +0800
Subject: [PATCH 33/40] delete
---
.../scripts/ep_lora/ep_lora_config_helpers.py | 12 --
.../scripts/ep_lora/loss_curve_qwen3_5_moe.py | 88 --------------
tests/scripts/ep_lora/save_load_dsv4.py | 111 ------------------
.../scripts/ep_lora/save_load_qwen3_5_moe.py | 105 -----------------
.../ep_lora/spike_register_parametrization.py | 105 -----------------
.../spike_register_parametrization_dsv4.py | 96 ---------------
tests/scripts/ep_lora/test_config_helpers.py | 28 -----
.../ep_lora/test_patch_adapter_assertions.py | 39 ------
.../ep_lora/test_patch_adapter_ep_ordering.py | 61 ----------
.../ep_lora/test_shared_experts_fallback.py | 44 -------
10 files changed, 689 deletions(-)
delete mode 100644 tests/scripts/ep_lora/ep_lora_config_helpers.py
delete mode 100644 tests/scripts/ep_lora/loss_curve_qwen3_5_moe.py
delete mode 100644 tests/scripts/ep_lora/save_load_dsv4.py
delete mode 100644 tests/scripts/ep_lora/save_load_qwen3_5_moe.py
delete mode 100644 tests/scripts/ep_lora/spike_register_parametrization.py
delete mode 100644 tests/scripts/ep_lora/spike_register_parametrization_dsv4.py
delete mode 100644 tests/scripts/ep_lora/test_config_helpers.py
delete mode 100644 tests/scripts/ep_lora/test_patch_adapter_assertions.py
delete mode 100644 tests/scripts/ep_lora/test_patch_adapter_ep_ordering.py
delete mode 100644 tests/scripts/ep_lora/test_shared_experts_fallback.py
diff --git a/tests/scripts/ep_lora/ep_lora_config_helpers.py b/tests/scripts/ep_lora/ep_lora_config_helpers.py
deleted file mode 100644
index d36d9ae1..00000000
--- a/tests/scripts/ep_lora/ep_lora_config_helpers.py
+++ /dev/null
@@ -1,12 +0,0 @@
-def get_text_config(config):
- return getattr(config, 'text_config', config)
-
-
-def get_vocab_size(config):
- return get_text_config(config).vocab_size
-
-
-def set_text_config_attrs(config, **attrs):
- text_config = get_text_config(config)
- for name, value in attrs.items():
- setattr(text_config, name, value)
diff --git a/tests/scripts/ep_lora/loss_curve_qwen3_5_moe.py b/tests/scripts/ep_lora/loss_curve_qwen3_5_moe.py
deleted file mode 100644
index 358ac335..00000000
--- a/tests/scripts/ep_lora/loss_curve_qwen3_5_moe.py
+++ /dev/null
@@ -1,88 +0,0 @@
-"""P2: EP+LoRA loss should decrease over 200 steps on self-cognition SFT.
-
-Run on 4 GPUs:
- torchrun --nproc-per-node=4 tests/scripts/ep_lora/loss_curve_qwen3_5_moe.py
-"""
-import os
-
-import torch.distributed as dist
-from peft import LoraConfig
-from transformers import AutoConfig
-
-import twinkle
-from twinkle import DeviceMesh, Platform, get_logger
-from twinkle.dataloader import DataLoader
-from twinkle.dataset import Dataset, DatasetMeta
-from twinkle.model import TransformersModel
-from twinkle.preprocessor import SelfCognitionProcessor
-
-from ep_lora_config_helpers import set_text_config_attrs
-
-logger = get_logger()
-MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
-DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition')
-TARGET_RATIO = float(os.environ.get('TARGET_RATIO', '0.7'))
-NUM_STEPS = int(os.environ.get('NUM_STEPS', '200'))
-
-
-def main():
- device_mesh = DeviceMesh.from_sizes(
- fsdp_size=4,
- dp_size=1,
- ep_size=2,
- device_type=Platform.get_platform().device_prefix(),
- )
- twinkle.initialize(mode='local', global_device_mesh=device_mesh)
- config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
- set_text_config_attrs(config, num_hidden_layers=4, use_cache=False)
-
- dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(1000)))
- dataset.set_template('Qwen3_5Template', model_id=MODEL_ID)
- dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope'))
- dataset.encode(batched=True)
- dataloader = DataLoader(dataset=dataset, batch_size=4, device_mesh=device_mesh)
-
- model = TransformersModel(
- model_id=MODEL_ID,
- config=config,
- device_mesh=device_mesh,
- fsdp_config={'expert_parallel': {'enabled': True, 'router_dtype': 'fp32'}},
- )
- model.add_adapter_to_model(
- 'default',
- LoraConfig(
- r=8,
- lora_alpha=32,
- target_modules='all-linear',
- target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'],
- ),
- )
- model.set_optimizer('AdamW', lr=1e-4, foreach=False)
- model.set_lr_scheduler('CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=NUM_STEPS)
-
- losses = []
- for step, batch in enumerate(dataloader):
- if step >= NUM_STEPS:
- break
- if callable(batch):
- batch = batch()
- model.forward_backward(inputs=batch, gradient_accumulation_steps=1)
- model.clip_grad_and_step(max_grad_norm=1.0, gradient_accumulation_steps=1)
- metric = model.calculate_metric(is_training=True)
- if callable(metric):
- metric = metric()
- loss = metric['loss'] if isinstance(metric, dict) and 'loss' in metric else metric
- losses.append(float(loss))
-
- if dist.get_rank() == 0:
- head = sum(losses[:10]) / 10
- tail = sum(losses[-10:]) / 10
- ratio = tail / head if head > 0 else float('inf')
- logger.info(f'head_avg={head:.4f}, tail_avg={tail:.4f}, ratio={ratio:.3f}')
- assert ratio < TARGET_RATIO, (
- f'loss did not decrease enough; tail/head ratio {ratio:.3f} >= {TARGET_RATIO}')
- logger.info('LOSS CURVE TEST PASSED')
-
-
-if __name__ == '__main__':
- main()
diff --git a/tests/scripts/ep_lora/save_load_dsv4.py b/tests/scripts/ep_lora/save_load_dsv4.py
deleted file mode 100644
index a76ac4db..00000000
--- a/tests/scripts/ep_lora/save_load_dsv4.py
+++ /dev/null
@@ -1,111 +0,0 @@
-"""P3: EP+LoRA save-reload numeric consistency for DeepSeek-V4.
-
-Run on 4 GPUs:
- torchrun --nproc-per-node=4 tests/scripts/ep_lora/save_load_dsv4.py
-"""
-import os
-import shutil
-import tempfile
-
-import torch
-import torch.distributed as dist
-from peft import LoraConfig
-from transformers import AutoConfig
-
-import twinkle
-from twinkle import DeviceMesh, Platform, get_logger
-from twinkle.model import TransformersModel
-
-from ep_lora_config_helpers import get_vocab_size, set_text_config_attrs
-
-logger = get_logger()
-MODEL_ID = os.environ.get('DSV4_MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4')
-TOL = float(os.environ.get('TOL', '1e-3'))
-
-
-def build_model():
- config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
- text_config = getattr(config, 'text_config', config)
- attrs = {'num_hidden_layers': 2, 'use_cache': False}
- if hasattr(text_config, 'n_routed_experts'):
- attrs['n_routed_experts'] = 4
- if hasattr(text_config, 'num_experts_per_tok'):
- attrs['num_experts_per_tok'] = 2
- set_text_config_attrs(config, **attrs)
- device_mesh = DeviceMesh.current()
- model = TransformersModel(
- model_id=MODEL_ID,
- config=config,
- device_mesh=device_mesh,
- fsdp_config={'expert_parallel': {'enabled': True, 'router_dtype': 'fp32'}},
- )
- lora_cfg = LoraConfig(
- r=8,
- lora_alpha=32,
- target_modules='all-linear',
- target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'],
- )
- model.add_adapter_to_model('default', lora_cfg)
- model.set_optimizer('AdamW', lr=1e-4, foreach=False)
- return model, config
-
-
-def fixed_batch(config):
- torch.manual_seed(0)
- vocab_size = get_vocab_size(config)
- return {
- 'input_ids': torch.randint(0, vocab_size, (2, 32), device=Platform.get_local_device()),
- 'labels': torch.randint(0, vocab_size, (2, 32), device=Platform.get_local_device()),
- 'attention_mask': torch.ones(2, 32, dtype=torch.long, device=Platform.get_local_device()),
- }
-
-
-def compute_loss(model, batch):
- model.forward_backward(inputs=batch, gradient_accumulation_steps=1)
- metric = model.calculate_metric(is_training=True)
- if callable(metric):
- metric = metric()
- loss = metric['loss'] if isinstance(metric, dict) and 'loss' in metric else metric
- return float(loss)
-
-
-def main():
- device_mesh = DeviceMesh.from_sizes(
- fsdp_size=4,
- dp_size=1,
- ep_size=2,
- device_type=Platform.get_platform().device_prefix(),
- )
- twinkle.initialize(mode='local', global_device_mesh=device_mesh)
-
- model_a, config = build_model()
- batch = fixed_batch(config)
- model_a.forward_backward(inputs=batch, gradient_accumulation_steps=1)
- model_a.clip_grad_and_step(max_grad_norm=1.0, gradient_accumulation_steps=1)
- loss_before = compute_loss(model_a, batch)
-
- tmp_root = tempfile.mkdtemp(prefix='ep_lora_dsv4_save_')
- if dist.get_rank() == 0:
- logger.info(f'Save root: {tmp_root}')
- model_a.save(name='ckpt', output_dir=tmp_root)
- dist.barrier()
-
- del model_a
- torch.cuda.empty_cache()
-
- model_b, _ = build_model()
- model_b.load(name='ckpt', output_dir=tmp_root)
- loss_after = compute_loss(model_b, batch)
-
- diff = abs(loss_after - loss_before)
- if dist.get_rank() == 0:
- logger.info(f'loss_before={loss_before:.6f}, loss_after={loss_after:.6f}, diff={diff:.6e}')
- assert diff < TOL, f'save/reload loss drift {diff:.3e} > tol {TOL:.1e}'
-
- if dist.get_rank() == 0:
- shutil.rmtree(tmp_root, ignore_errors=True)
- logger.info('SAVE/LOAD TEST PASSED')
-
-
-if __name__ == '__main__':
- main()
diff --git a/tests/scripts/ep_lora/save_load_qwen3_5_moe.py b/tests/scripts/ep_lora/save_load_qwen3_5_moe.py
deleted file mode 100644
index dbec27ae..00000000
--- a/tests/scripts/ep_lora/save_load_qwen3_5_moe.py
+++ /dev/null
@@ -1,105 +0,0 @@
-"""P1: EP+LoRA save-reload numeric consistency for Qwen3.5-MoE.
-
-Run on 4 GPUs:
- torchrun --nproc-per-node=4 tests/scripts/ep_lora/save_load_qwen3_5_moe.py
-"""
-import os
-import shutil
-import tempfile
-
-import torch
-import torch.distributed as dist
-from peft import LoraConfig
-from transformers import AutoConfig
-
-import twinkle
-from twinkle import DeviceMesh, Platform, get_logger
-from twinkle.model import TransformersModel
-
-from ep_lora_config_helpers import get_vocab_size, set_text_config_attrs
-
-logger = get_logger()
-MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
-TOL = float(os.environ.get('TOL', '1e-3'))
-
-
-def build_model():
- config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
- set_text_config_attrs(config, num_hidden_layers=2, use_cache=False)
- device_mesh = DeviceMesh.current()
- model = TransformersModel(
- model_id=MODEL_ID,
- config=config,
- device_mesh=device_mesh,
- fsdp_config={'expert_parallel': {'enabled': True, 'router_dtype': 'fp32'}},
- )
- lora_cfg = LoraConfig(
- r=8,
- lora_alpha=32,
- target_modules='all-linear',
- target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'],
- )
- model.add_adapter_to_model('default', lora_cfg)
- model.set_optimizer('AdamW', lr=1e-4, foreach=False)
- return model, config
-
-
-def fixed_batch(config):
- torch.manual_seed(0)
- vocab_size = get_vocab_size(config)
- return {
- 'input_ids': torch.randint(0, vocab_size, (2, 32), device=Platform.get_local_device()),
- 'labels': torch.randint(0, vocab_size, (2, 32), device=Platform.get_local_device()),
- 'attention_mask': torch.ones(2, 32, dtype=torch.long, device=Platform.get_local_device()),
- }
-
-
-def compute_loss(model, batch):
- model.forward_backward(inputs=batch, gradient_accumulation_steps=1)
- metric = model.calculate_metric(is_training=True)
- if callable(metric):
- metric = metric()
- loss = metric['loss'] if isinstance(metric, dict) and 'loss' in metric else metric
- return float(loss)
-
-
-def main():
- device_mesh = DeviceMesh.from_sizes(
- fsdp_size=4,
- dp_size=1,
- ep_size=2,
- device_type=Platform.get_platform().device_prefix(),
- )
- twinkle.initialize(mode='local', global_device_mesh=device_mesh)
-
- model_a, config = build_model()
- batch = fixed_batch(config)
- model_a.forward_backward(inputs=batch, gradient_accumulation_steps=1)
- model_a.clip_grad_and_step(max_grad_norm=1.0, gradient_accumulation_steps=1)
- loss_before = compute_loss(model_a, batch)
-
- tmp_root = tempfile.mkdtemp(prefix='ep_lora_save_')
- if dist.get_rank() == 0:
- logger.info(f'Save root: {tmp_root}')
- model_a.save(name='ckpt', output_dir=tmp_root)
- dist.barrier()
-
- del model_a
- torch.cuda.empty_cache()
-
- model_b, _ = build_model()
- model_b.load(name='ckpt', output_dir=tmp_root)
- loss_after = compute_loss(model_b, batch)
-
- diff = abs(loss_after - loss_before)
- if dist.get_rank() == 0:
- logger.info(f'loss_before={loss_before:.6f}, loss_after={loss_after:.6f}, diff={diff:.6e}')
- assert diff < TOL, f'save/reload loss drift {diff:.3e} > tol {TOL:.1e}'
-
- if dist.get_rank() == 0:
- shutil.rmtree(tmp_root, ignore_errors=True)
- logger.info('SAVE/LOAD TEST PASSED')
-
-
-if __name__ == '__main__':
- main()
diff --git a/tests/scripts/ep_lora/spike_register_parametrization.py b/tests/scripts/ep_lora/spike_register_parametrization.py
deleted file mode 100644
index de0aa0fe..00000000
--- a/tests/scripts/ep_lora/spike_register_parametrization.py
+++ /dev/null
@@ -1,105 +0,0 @@
-"""P0 spike: verify PEFT ParamWrapper + FSDP2 DTensor compatibility on Qwen3.5-MoE.
-
-Run on 4 GPUs:
- torchrun --nproc-per-node=4 tests/scripts/ep_lora/spike_register_parametrization.py
-"""
-import os
-
-import torch
-import torch.distributed as dist
-from peft import LoraConfig
-from transformers import AutoConfig
-
-import twinkle
-from twinkle import DeviceMesh, Platform, get_logger
-from twinkle.model import TransformersModel
-
-from ep_lora_config_helpers import get_vocab_size, set_text_config_attrs
-
-logger = get_logger()
-
-MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
-
-
-def main():
- device_mesh = DeviceMesh.from_sizes(
- fsdp_size=4,
- dp_size=1,
- ep_size=2,
- device_type=Platform.get_platform().device_prefix(),
- )
- twinkle.initialize(mode='local', global_device_mesh=device_mesh)
-
- config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
- set_text_config_attrs(
- config,
- num_hidden_layers=2,
- hidden_size=128,
- intermediate_size=256,
- moe_intermediate_size=64,
- num_experts=4,
- num_experts_per_tok=2,
- use_cache=False,
- )
-
- model = TransformersModel(
- model_id=MODEL_ID,
- config=config,
- device_mesh=device_mesh,
- fsdp_config={'expert_parallel': {'enabled': True, 'router_dtype': 'fp32'}},
- )
-
- lora_cfg = LoraConfig(
- r=8,
- lora_alpha=32,
- target_modules='all-linear',
- target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'],
- )
- model.add_adapter_to_model('default', lora_cfg)
- model.set_optimizer('AdamW', lr=1e-4, foreach=False)
-
- rank = dist.get_rank() if dist.is_initialized() else 0
- torch.manual_seed(42 + rank)
- vocab_size = get_vocab_size(config)
- batch = {
- 'input_ids': torch.randint(0, vocab_size, (2, 16), device=Platform.get_local_device()),
- 'labels': torch.randint(0, vocab_size, (2, 16), device=Platform.get_local_device()),
- 'attention_mask': torch.ones(2, 16, dtype=torch.long, device=Platform.get_local_device()),
- }
-
- model.forward_backward(inputs=batch, gradient_accumulation_steps=1)
- metric = model.calculate_metric(is_training=True)
- if callable(metric):
- metric = metric()
- loss = metric['loss'] if isinstance(metric, dict) and 'loss' in metric else metric
- logger.info(f'spike loss (rank {rank}): {loss}')
- assert torch.is_tensor(loss) or isinstance(loss, float), 'loss should be a scalar'
- loss_val = float(loss)
- assert torch.isfinite(torch.tensor(loss_val)), f'loss not finite: {loss_val}'
-
- unwrapped = model.strategy.unwrap_model(model.model)
- lora_a_seen = 0
- lora_b_seen = 0
- base_grads = []
- for name, param in unwrapped.named_parameters():
- if 'experts' not in name:
- continue
- if 'lora_A' in name and param.grad is not None:
- lora_a_seen += 1
- assert param.grad.abs().sum().item() > 0, f'{name} grad is zero'
- if 'lora_B' in name and param.grad is not None:
- lora_b_seen += 1
- if 'base_layer.gate_up_proj' in name or 'base_layer.down_proj' in name:
- base_grads.append((name, param.grad))
-
- logger.info(f'lora_A grads seen: {lora_a_seen}, lora_B grads seen: {lora_b_seen}')
- assert lora_a_seen > 0, 'no lora_A grads observed under experts subtree'
- for name, grad in base_grads:
- assert grad is None, f'{name} should be frozen but has grad'
-
- if rank == 0:
- logger.info('SPIKE PASSED: PEFT ParamWrapper works with FSDP2 DTensor for routing experts.')
-
-
-if __name__ == '__main__':
- main()
diff --git a/tests/scripts/ep_lora/spike_register_parametrization_dsv4.py b/tests/scripts/ep_lora/spike_register_parametrization_dsv4.py
deleted file mode 100644
index dc9bf707..00000000
--- a/tests/scripts/ep_lora/spike_register_parametrization_dsv4.py
+++ /dev/null
@@ -1,96 +0,0 @@
-"""P3 spike: verify PEFT ParamWrapper + FSDP2 DTensor compatibility on DeepSeek-V4.
-
-Run on 4 GPUs:
- torchrun --nproc-per-node=4 tests/scripts/ep_lora/spike_register_parametrization_dsv4.py
-"""
-import os
-
-import torch
-import torch.distributed as dist
-from peft import LoraConfig
-from transformers import AutoConfig
-
-import twinkle
-from twinkle import DeviceMesh, Platform, get_logger
-from twinkle.model import TransformersModel
-
-from ep_lora_config_helpers import get_vocab_size, set_text_config_attrs
-
-logger = get_logger()
-MODEL_ID = os.environ.get('DSV4_MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4')
-
-
-def main():
- device_mesh = DeviceMesh.from_sizes(
- fsdp_size=4,
- dp_size=1,
- ep_size=2,
- device_type=Platform.get_platform().device_prefix(),
- )
- twinkle.initialize(mode='local', global_device_mesh=device_mesh)
-
- config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
- text_config = getattr(config, 'text_config', config)
- attrs = {'num_hidden_layers': 2, 'use_cache': False}
- if hasattr(text_config, 'n_routed_experts'):
- attrs['n_routed_experts'] = 4
- if hasattr(text_config, 'num_experts_per_tok'):
- attrs['num_experts_per_tok'] = 2
- set_text_config_attrs(config, **attrs)
-
- model = TransformersModel(
- model_id=MODEL_ID,
- config=config,
- device_mesh=device_mesh,
- fsdp_config={'expert_parallel': {'enabled': True, 'router_dtype': 'fp32'}},
- )
- model.add_adapter_to_model(
- 'default',
- LoraConfig(
- r=8,
- lora_alpha=32,
- target_modules='all-linear',
- target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'],
- ),
- )
- model.set_optimizer('AdamW', lr=1e-4, foreach=False)
-
- rank = dist.get_rank() if dist.is_initialized() else 0
- torch.manual_seed(42 + rank)
- vocab_size = get_vocab_size(config)
- batch = {
- 'input_ids': torch.randint(0, vocab_size, (2, 16), device=Platform.get_local_device()),
- 'labels': torch.randint(0, vocab_size, (2, 16), device=Platform.get_local_device()),
- 'attention_mask': torch.ones(2, 16, dtype=torch.long, device=Platform.get_local_device()),
- }
-
- model.forward_backward(inputs=batch, gradient_accumulation_steps=1)
- metric = model.calculate_metric(is_training=True)
- if callable(metric):
- metric = metric()
- loss = metric['loss'] if isinstance(metric, dict) and 'loss' in metric else metric
- loss_val = float(loss)
- assert torch.isfinite(torch.tensor(loss_val)), f'loss not finite: {loss_val}'
-
- unwrapped = model.strategy.unwrap_model(model.model)
- lora_a_seen = 0
- base_grads = []
- for name, param in unwrapped.named_parameters():
- if 'experts' not in name:
- continue
- if 'lora_A' in name and param.grad is not None:
- lora_a_seen += 1
- assert param.grad.abs().sum().item() > 0, f'{name} grad is zero'
- if 'base_layer.gate_up_proj' in name or 'base_layer.down_proj' in name:
- base_grads.append((name, param.grad))
-
- assert lora_a_seen > 0, 'no lora_A grads observed under experts subtree'
- for name, grad in base_grads:
- assert grad is None, f'{name} should be frozen but has grad'
-
- if rank == 0:
- logger.info('SPIKE PASSED: PEFT ParamWrapper works with FSDP2 DTensor for dsv4 routing experts.')
-
-
-if __name__ == '__main__':
- main()
diff --git a/tests/scripts/ep_lora/test_config_helpers.py b/tests/scripts/ep_lora/test_config_helpers.py
deleted file mode 100644
index 528cb125..00000000
--- a/tests/scripts/ep_lora/test_config_helpers.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from types import SimpleNamespace
-
-from ep_lora_config_helpers import get_text_config, get_vocab_size, set_text_config_attrs
-
-
-def test_get_vocab_size_prefers_nested_text_config():
- config = SimpleNamespace(text_config=SimpleNamespace(vocab_size=248320))
-
- assert get_vocab_size(config) == 248320
-
-
-def test_set_text_config_attrs_updates_nested_text_config():
- text_config = SimpleNamespace(num_hidden_layers=40, use_cache=True)
- config = SimpleNamespace(text_config=text_config)
-
- set_text_config_attrs(config, num_hidden_layers=2, use_cache=False)
-
- assert get_text_config(config).num_hidden_layers == 2
- assert get_text_config(config).use_cache is False
-
-
-def test_set_text_config_attrs_falls_back_to_top_level_config():
- config = SimpleNamespace(vocab_size=1024)
-
- set_text_config_attrs(config, num_hidden_layers=2, use_cache=False)
-
- assert config.num_hidden_layers == 2
- assert config.use_cache is False
diff --git a/tests/scripts/ep_lora/test_patch_adapter_assertions.py b/tests/scripts/ep_lora/test_patch_adapter_assertions.py
deleted file mode 100644
index ed778f0c..00000000
--- a/tests/scripts/ep_lora/test_patch_adapter_assertions.py
+++ /dev/null
@@ -1,39 +0,0 @@
-"""Unit tests for _patch_adapter entry assertions and default target_parameters."""
-from unittest import mock
-
-import pytest
-from peft import LoraConfig
-
-
-def test_lora_dropout_with_target_parameters_raises():
- from twinkle.model.transformers.transformers import TransformersModel
- from twinkle.model.transformers.strategy.native_fsdp import NativeFSDPStrategy
-
- cfg = LoraConfig(
- r=4,
- target_modules=['q_proj'],
- target_parameters=['mlp.experts.gate_up_proj'],
- lora_dropout=0.1,
- )
- instance = mock.MagicMock(spec=TransformersModel)
- instance._enable_expert_parallel = True
- instance.strategy = NativeFSDPStrategy(enable_ep=True)
-
- with pytest.raises(ValueError, match='ParamWrapper'):
- TransformersModel._validate_ep_lora_config(instance, cfg)
-
-
-def test_no_target_parameters_autofilled():
- from twinkle.model.transformers.transformers import TransformersModel
-
- cfg = LoraConfig(r=4, target_modules='all-linear')
- cfg = TransformersModel._maybe_autofill_target_parameters(cfg, enable_ep=True)
- assert cfg.target_parameters == ['mlp.experts.gate_up_proj', 'mlp.experts.down_proj']
-
-
-def test_target_parameters_passthrough_when_set():
- from twinkle.model.transformers.transformers import TransformersModel
-
- cfg = LoraConfig(r=4, target_modules='all-linear', target_parameters=['custom.param'])
- cfg2 = TransformersModel._maybe_autofill_target_parameters(cfg, enable_ep=True)
- assert cfg2.target_parameters == ['custom.param']
diff --git a/tests/scripts/ep_lora/test_patch_adapter_ep_ordering.py b/tests/scripts/ep_lora/test_patch_adapter_ep_ordering.py
deleted file mode 100644
index 3011d6ed..00000000
--- a/tests/scripts/ep_lora/test_patch_adapter_ep_ordering.py
+++ /dev/null
@@ -1,61 +0,0 @@
-"""Verify _patch_adapter triggers EP slicing before get_peft_model.
-
-Run on 4 GPUs:
- torchrun --nproc-per-node=4 tests/scripts/ep_lora/test_patch_adapter_ep_ordering.py
-"""
-import os
-
-from peft import LoraConfig
-from transformers import AutoConfig
-
-import twinkle
-from twinkle import DeviceMesh, Platform, get_logger
-from twinkle.model import TransformersModel
-
-from ep_lora_config_helpers import set_text_config_attrs
-
-logger = get_logger()
-MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
-
-
-def main():
- device_mesh = DeviceMesh.from_sizes(
- fsdp_size=4,
- dp_size=1,
- ep_size=2,
- device_type=Platform.get_platform().device_prefix(),
- )
- twinkle.initialize(mode='local', global_device_mesh=device_mesh)
- config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
- set_text_config_attrs(config, num_hidden_layers=2, num_experts=4, use_cache=False)
-
- model = TransformersModel(
- model_id=MODEL_ID,
- config=config,
- device_mesh=device_mesh,
- fsdp_config={'expert_parallel': {'enabled': True}},
- )
- assert not getattr(model, '_expert_parallel_applied', False)
-
- model.add_adapter_to_model(
- 'default',
- LoraConfig(
- r=4,
- target_modules='all-linear',
- target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'],
- ),
- )
- assert getattr(model, '_expert_parallel_applied', False), (
- '_patch_adapter must trigger EP slicing before PEFT wrap')
-
- unwrapped = model.strategy.unwrap_model(model.model)
- lora_a_shapes = [(n, p.shape) for n, p in unwrapped.named_parameters() if 'experts' in n and 'lora_A' in n]
- assert lora_a_shapes, 'no lora_A weight under experts subtree'
- for name, shape in lora_a_shapes:
- assert shape[0] == 4 * 2, f'{name} shape {tuple(shape)}; expected r*E_local=8'
-
- logger.info('ORDERING TEST PASSED')
-
-
-if __name__ == '__main__':
- main()
diff --git a/tests/scripts/ep_lora/test_shared_experts_fallback.py b/tests/scripts/ep_lora/test_shared_experts_fallback.py
deleted file mode 100644
index 350976bf..00000000
--- a/tests/scripts/ep_lora/test_shared_experts_fallback.py
+++ /dev/null
@@ -1,44 +0,0 @@
-"""Lightweight unit tests for dsv4 shared_experts fallback."""
-import torch
-import torch.nn as nn
-
-from twinkle.model.transformers.strategy.native_fsdp import _collect_expert_params
-
-
-def _make_block(use_plural_shared: bool, ignore_shared: bool) -> nn.Module:
- block = nn.Module()
- experts = nn.Module()
- experts.gate_up_proj = nn.Parameter(torch.randn(2, 4, 8))
- experts.down_proj = nn.Parameter(torch.randn(2, 4, 4))
- block.experts = experts
- shared = nn.Linear(4, 4)
- if use_plural_shared:
- block.shared_experts = shared
- else:
- block.shared_expert = shared
- block._ep_patched = True
- block._ep_ignore_shared_experts = ignore_shared
- parent = nn.Module()
- parent.block = block
- return parent
-
-
-def test_singular_shared_expert_collected():
- parent = _make_block(use_plural_shared=False, ignore_shared=True)
- ignored = _collect_expert_params(parent)
- expected_count = 2 + 2
- assert ignored is not None and len(ignored) == expected_count
-
-
-def test_plural_shared_experts_collected():
- parent = _make_block(use_plural_shared=True, ignore_shared=True)
- ignored = _collect_expert_params(parent)
- expected_count = 2 + 2
- assert ignored is not None and len(ignored) == expected_count, (
- '_collect_expert_params should fall back to shared_experts (plural) for dsv4')
-
-
-def test_no_ignore_shared_only_collects_experts():
- parent = _make_block(use_plural_shared=True, ignore_shared=False)
- ignored = _collect_expert_params(parent)
- assert ignored is not None and len(ignored) == 2
From d795c777feaa2935f0fb00471349076df1b6af06 Mon Sep 17 00:00:00 2001
From: weikaiwen
Date: Wed, 20 May 2026 11:11:21 +0800
Subject: [PATCH 34/40] fix
---
.../transformers/strategy/native_fsdp.py | 8 +++++++-
.../model/transformers/transformers.py | 20 ++++++++++++++++++-
2 files changed, 26 insertions(+), 2 deletions(-)
diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py
index 40864d2d..df9c3b47 100644
--- a/src/twinkle/model/transformers/strategy/native_fsdp.py
+++ b/src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -271,7 +271,7 @@ def get_full_state_dict(self, model) -> dict:
local_full = local_full.contiguous().to(Platform.get_local_device())
gathered = [torch.empty_like(local_full) for _ in range(ep_world_size)]
dist.all_gather(gathered, local_full, group=ep_group)
- local_full = torch.cat(gathered, dim=0)
+ local_full = torch.cat(gathered, dim=_ep_expert_state_dict_gather_dim(name))
state_dict[name] = local_full.cpu()
del gathered, local_full
else:
@@ -296,6 +296,12 @@ def _detect_ep_expert_names(model: nn.Module) -> Set[str]:
return candidate_names & actual_param_names
+def _ep_expert_state_dict_gather_dim(name: str) -> int:
+ if 'lora_B' in name:
+ return 1
+ return 0
+
+
def _build_mp_policy(mixed_precision: str) -> 'MixedPrecisionPolicy':
from torch.distributed.fsdp import MixedPrecisionPolicy
if mixed_precision == 'bf16':
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 92354c6b..6e2abe60 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -142,6 +142,19 @@ def _split_for_ep_pre_distribute(model, model_key: str, value: torch.Tensor, ep_
return value
+def _has_param_wrapper_without_base_weight(model) -> bool:
+ for module in model.modules():
+ if not hasattr(module, 'parameter_name'):
+ continue
+ get_base_layer = getattr(module, 'get_base_layer', None)
+ if get_base_layer is None:
+ continue
+ base_layer = get_base_layer()
+ if not hasattr(base_layer, 'weight'):
+ return True
+ return False
+
+
@dataclass
class OptimizerGroup(BaseOptimizerGroup):
"""Optimizer group for Transformers training."""
@@ -1175,6 +1188,7 @@ def load_peft_weights_for_fsdp2(model, adapter_weights, adapter_name='default'):
model_sd = model.state_dict()
converted_weights = {}
+ direct_weights = {}
for key, value in adapter_weights.items():
model_key = key
if f'.{adapter_name}.weight' not in model_key:
@@ -1184,9 +1198,13 @@ def load_peft_weights_for_fsdp2(model, adapter_weights, adapter_name='default'):
value = _split_for_ep_pre_distribute(model, model_key, value, ep_world_size, ep_rank)
if isinstance(param, DTensor) and not isinstance(value, DTensor):
value = distribute_tensor(value.to(param.device), param.device_mesh, param.placements)
+ direct_weights[model_key] = value
converted_weights[key] = value
- set_peft_model_state_dict(model, converted_weights, adapter_name=adapter_name)
+ if _has_param_wrapper_without_base_weight(model):
+ model.load_state_dict(direct_weights, strict=False)
+ else:
+ set_peft_model_state_dict(model, converted_weights, adapter_name=adapter_name)
if self.device_mesh.fsdp_world_size > 1:
load_peft_weights_for_fsdp2(model, adapter_weights, adapter_name=adapter_name)
From 0c478575fef39535fa2936b6832fd0a9c2346cad Mon Sep 17 00:00:00 2001
From: weikaiwen
Date: Thu, 21 May 2026 10:33:43 +0800
Subject: [PATCH 35/40] wip
---
.../transformers/strategy/native_fsdp.py | 128 ++++++++++++++++--
.../strategy/native_fsdp_state.py | 67 +++++++++
.../model/transformers/transformers.py | 6 +
3 files changed, 186 insertions(+), 15 deletions(-)
create mode 100644 src/twinkle/model/transformers/strategy/native_fsdp_state.py
diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py
index df9c3b47..2024be9b 100644
--- a/src/twinkle/model/transformers/strategy/native_fsdp.py
+++ b/src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -8,6 +8,12 @@
from twinkle.utils import DeviceMesh, Platform, torch_util
from .load_context import fsdp_pretrained_load_context
+from .native_fsdp_state import (
+ _collect_adapter_source_state,
+ _collect_state_metadata,
+ _is_lora_state_key,
+ _resolve_full_state_source_key,
+)
if TYPE_CHECKING:
from torch.distributed.fsdp import MixedPrecisionPolicy
@@ -29,6 +35,7 @@ def __init__(self,
self.enable_ep = enable_ep
self.ep_fsdp_device_mesh = self._build_ep_fsdp_device_mesh(ep_size) if enable_ep else None
self._rank0_pre_ep_full_state_dict = None
+ self._adapter_full_state_dict = None
def pretrained_load_context(self):
# Native FSDP handles rank0-load itself. Do not enable Transformers'
@@ -42,6 +49,9 @@ def use_rank0_pretrained_broadcast(self) -> bool:
def set_rank0_pre_ep_full_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
self._rank0_pre_ep_full_state_dict = state_dict
+ def set_adapter_full_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
+ self._adapter_full_state_dict = state_dict
+
def _build_ep_fsdp_device_mesh(self, ep_size: Optional[int] = None) -> Optional[TorchDeviceMesh]:
if self.device_mesh is None:
return None
@@ -74,12 +84,16 @@ def wrap_model(self, model, optimizer=None):
original_sd = None
saved_buffers = None
+ adapter_source_sd = {}
+ adapter_full_sd = {}
if use_meta:
is_rank0 = (dist.get_rank() == 0)
if ep_enabled and self._rank0_pre_ep_full_state_dict is not None:
original_sd = self._rank0_pre_ep_full_state_dict if is_rank0 else {}
else:
original_sd = model.state_dict() if is_rank0 else {}
+ adapter_source_sd = _collect_adapter_source_state(model.state_dict())
+ adapter_full_sd = self._adapter_full_state_dict if is_rank0 and self._adapter_full_state_dict else {}
saved_buffers = _get_non_persistent_buffers(model) if is_rank0 else {}
if is_rank0:
model = model.to(torch.device('meta'))
@@ -152,7 +166,10 @@ def wrap_model(self, model, optimizer=None):
device_type=device_type,
expert_shard_specs=expert_shard_specs,
rank_to_ep_rank=rank_to_ep_rank,
+ adapter_source_sd=adapter_source_sd,
+ adapter_full_sd=adapter_full_sd,
)
+ self._adapter_full_state_dict = None
target_device = torch.device(device_type)
_broadcast_non_persistent_buffers(model, saved_buffers or {}, device=target_device)
if hasattr(model, 'tie_weights'):
@@ -572,6 +589,8 @@ def _broadcast_sharded_state_dict(
device_type: str = 'cuda',
expert_shard_specs: Optional[Dict[str, Dict[str, int]]] = None,
rank_to_ep_rank: Optional[Dict[int, int]] = None,
+ adapter_source_sd: Optional[Dict[str, torch.Tensor]] = None,
+ adapter_full_sd: Optional[Dict[str, torch.Tensor]] = None,
) -> None:
"""Broadcast full state dict from rank 0 and materialize local FSDP2 shards."""
from torch.distributed.tensor import DTensor, Partial, Replicate, Shard
@@ -581,15 +600,28 @@ def _broadcast_sharded_state_dict(
is_rank0 = (dist.get_rank() == 0)
expert_shard_specs = expert_shard_specs or {}
rank_to_ep_rank = rank_to_ep_rank or {}
+ adapter_source_sd = adapter_source_sd or {}
+ adapter_full_sd = adapter_full_sd or {}
source_metadata = None
+ source_keys = None
+ adapter_metadata = None
if is_rank0:
- source_metadata = {
- name: (tuple(tensor.shape), tensor.dtype)
- for name, tensor in full_sd.items() if hasattr(tensor, 'shape') and hasattr(tensor, 'dtype')
- }
- metadata_holder = [source_metadata]
+ source_metadata = {}
+ source_keys = {}
+ for param_name in meta_sharded_sd:
+ if _is_lora_state_key(param_name):
+ continue
+ source_key = _resolve_full_state_source_key(param_name, full_sd)
+ source_tensor = full_sd[source_key]
+ if hasattr(source_tensor, 'shape') and hasattr(source_tensor, 'dtype'):
+ source_metadata[param_name] = (tuple(source_tensor.shape), source_tensor.dtype)
+ source_keys[param_name] = source_key
+ adapter_metadata = _collect_state_metadata({**adapter_source_sd, **adapter_full_sd})
+ metadata_holder = [source_metadata, source_keys, adapter_metadata]
dist.broadcast_object_list(metadata_holder, src=0)
source_metadata = metadata_holder[0] or {}
+ source_keys = metadata_holder[1] or {}
+ adapter_metadata = metadata_holder[2] or {}
def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements):
local_tensor = full_tensor
@@ -620,6 +652,51 @@ def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements):
stride=full_tensor.stride(),
)
+ def _broadcast_adapter_source_tensor(full_tensor, sharded_param):
+ if not isinstance(sharded_param, DTensor):
+ dist.broadcast(full_tensor, src=0)
+ return full_tensor
+ mesh = sharded_param.device_mesh.mesh
+ source_rank = int(mesh.flatten()[0].item())
+ dist.broadcast(full_tensor, src=source_rank, group=sharded_param.device_mesh.get_group())
+ return full_tensor
+
+ def _scatter_ep_adapter_tensor(param_name, full_tensor, sharded_param):
+ local_shape = tuple(sharded_param.size())
+ _, source_dtype = adapter_metadata[param_name]
+ local_tensor = torch.empty(local_shape, device=device_type, dtype=source_dtype)
+
+ if is_rank0:
+ shard_dim = _ep_expert_state_dict_gather_dim(param_name)
+ local_dim = local_shape[shard_dim]
+ world_size = dist.get_world_size()
+ for rank in range(world_size):
+ if rank not in rank_to_ep_rank:
+ raise RuntimeError(f'Missing EP rank mapping for global rank {rank}.')
+ ep_rank = rank_to_ep_rank[rank]
+ start = ep_rank * local_dim
+ chunk = full_tensor.narrow(shard_dim, start, local_dim).contiguous().to(device_type)
+ if rank == 0:
+ local_tensor.copy_(chunk)
+ else:
+ dist.send(chunk, dst=rank)
+ else:
+ dist.recv(local_tensor, src=0)
+
+ return local_tensor
+
+ def _get_adapter_source(param_name):
+ if param_name in adapter_full_sd:
+ adapter_tensor = adapter_full_sd[param_name]
+ return adapter_tensor, tuple(adapter_tensor.shape), adapter_tensor.dtype
+ if param_name in adapter_source_sd:
+ adapter_tensor = adapter_source_sd[param_name]
+ return adapter_tensor, tuple(adapter_tensor.shape), adapter_tensor.dtype
+ if param_name in adapter_metadata:
+ source_shape, source_dtype = adapter_metadata[param_name]
+ return None, source_shape, source_dtype
+ raise KeyError(f"Missing adapter source state for parameter '{param_name}'.")
+
def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param):
spec = expert_shard_specs[param_name]
experts_per_rank = spec['experts_per_rank']
@@ -655,16 +732,34 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param):
for param_name, sharded_param in meta_sharded_sd.items():
shape = sharded_param.size()
- is_ep_expert_param = param_name in expert_shard_specs
- if param_name not in source_metadata:
- raise KeyError(f"Missing source metadata for parameter '{param_name}'.")
- source_shape, source_dtype = source_metadata[param_name]
-
- if is_rank0:
- if param_name not in full_sd:
+ is_adapter_param = _is_lora_state_key(param_name)
+ is_ep_expert_param = param_name in expert_shard_specs and not is_adapter_param
+ if is_adapter_param:
+ adapter_tensor, source_shape, source_dtype = _get_adapter_source(param_name)
+ else:
+ if param_name not in source_metadata:
+ raise KeyError(f"Missing source metadata for parameter '{param_name}'.")
+ source_shape, source_dtype = source_metadata[param_name]
+ is_ep_adapter_param = (
+ is_adapter_param and param_name in expert_shard_specs and tuple(source_shape) != tuple(shape))
+
+ if is_adapter_param:
+ if adapter_tensor is not None:
+ full_tensor = adapter_tensor.detach()
+ if isinstance(full_tensor, DTensor):
+ full_tensor = full_tensor.to_local()
+ if not is_ep_adapter_param:
+ full_tensor = full_tensor.to(device_type)
+ else:
+ full_tensor = torch.empty(source_shape, device=device_type, dtype=source_dtype)
+ if not is_ep_adapter_param:
+ full_tensor = _broadcast_adapter_source_tensor(full_tensor, sharded_param)
+ elif is_rank0:
+ source_key = source_keys[param_name]
+ if source_key not in full_sd:
raise KeyError(
f"Parameter '{param_name}' found in sharded model state dict but missing from full state dict.")
- full_param = full_sd[param_name]
+ full_param = full_sd[source_key]
full_tensor = full_param.detach()
if isinstance(full_tensor, DTensor):
full_tensor = full_tensor.to_local()
@@ -679,13 +774,16 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param):
full_tensor = None if is_ep_expert_param else torch.empty(
source_shape, device=device_type, dtype=source_dtype)
- if is_ep_expert_param:
+ if is_ep_adapter_param:
+ full_tensor = _scatter_ep_adapter_tensor(param_name, full_tensor, sharded_param)
+ elif is_ep_expert_param:
full_tensor = _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param)
else:
if tuple(shape) != tuple(source_shape):
raise RuntimeError(f"Parameter '{param_name}' shape mismatch before broadcast: "
f'sharded logical shape={tuple(shape)}, source shape={source_shape}.')
- dist.broadcast(full_tensor, src=0)
+ if not is_adapter_param:
+ dist.broadcast(full_tensor, src=0)
torch_util.synchronize()
if isinstance(sharded_param, DTensor):
diff --git a/src/twinkle/model/transformers/strategy/native_fsdp_state.py b/src/twinkle/model/transformers/strategy/native_fsdp_state.py
new file mode 100644
index 00000000..dc62d951
--- /dev/null
+++ b/src/twinkle/model/transformers/strategy/native_fsdp_state.py
@@ -0,0 +1,67 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+from typing import Any, Dict, Mapping
+
+LORA_STATE_KEY_MARKERS = ('lora_A', 'lora_B', 'lora_embedding')
+PEFT_BASE_PREFIX = 'base_model.model.'
+PEFT_BASE_LAYER_SEGMENT = 'base_layer'
+
+
+def _is_lora_state_key(name: str) -> bool:
+ return any(marker in name for marker in LORA_STATE_KEY_MARKERS)
+
+
+def _strip_peft_base_prefix(name: str) -> str:
+ while name.startswith(PEFT_BASE_PREFIX):
+ name = name[len(PEFT_BASE_PREFIX):]
+ return name
+
+
+def _strip_base_layer_segments(name: str) -> str:
+ return '.'.join(segment for segment in name.split('.') if segment != PEFT_BASE_LAYER_SEGMENT)
+
+
+def _source_key_candidates(param_name: str) -> list[str]:
+ stripped_prefix = _strip_peft_base_prefix(param_name)
+ candidates = [
+ param_name,
+ stripped_prefix,
+ _strip_base_layer_segments(param_name),
+ _strip_base_layer_segments(stripped_prefix),
+ ]
+ deduped = []
+ for candidate in candidates:
+ if candidate not in deduped:
+ deduped.append(candidate)
+ return deduped
+
+
+def _resolve_full_state_source_key(param_name: str, source_state: Mapping[str, Any]) -> str:
+ if _is_lora_state_key(param_name):
+ raise KeyError(f"LoRA parameter '{param_name}' must be loaded from adapter source state.")
+
+ candidates = _source_key_candidates(param_name)
+ for candidate in candidates:
+ if candidate in source_state:
+ return candidate
+ raise KeyError(
+ f"Missing source metadata for parameter '{param_name}'. "
+ f'Tried source keys: {", ".join(candidates)}.')
+
+
+def _collect_adapter_source_state(state_dict: Mapping[str, Any]) -> Dict[str, Any]:
+ adapter_state = {}
+ for name, tensor in state_dict.items():
+ if not _is_lora_state_key(name) or not hasattr(tensor, 'detach'):
+ continue
+ if getattr(tensor, 'is_meta', False):
+ continue
+ adapter_state[name] = tensor.detach().cpu().clone()
+ return adapter_state
+
+
+def _collect_state_metadata(state_dict: Mapping[str, Any]) -> Dict[str, tuple[tuple[int, ...], Any]]:
+ return {
+ name: (tuple(tensor.shape), tensor.dtype)
+ for name, tensor in state_dict.items()
+ if hasattr(tensor, 'shape') and hasattr(tensor, 'dtype')
+ }
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 6e2abe60..1e6ba4bd 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -1189,18 +1189,24 @@ def load_peft_weights_for_fsdp2(model, adapter_weights, adapter_name='default'):
model_sd = model.state_dict()
converted_weights = {}
direct_weights = {}
+ full_adapter_source = {}
for key, value in adapter_weights.items():
model_key = key
if f'.{adapter_name}.weight' not in model_key:
model_key = model_key.replace('.weight', f'.{adapter_name}.weight')
if model_key in model_sd:
param = model_sd[model_key]
+ full_adapter_source[model_key] = value.detach().cpu().clone()
value = _split_for_ep_pre_distribute(model, model_key, value, ep_world_size, ep_rank)
if isinstance(param, DTensor) and not isinstance(value, DTensor):
value = distribute_tensor(value.to(param.device), param.device_mesh, param.placements)
direct_weights[model_key] = value
converted_weights[key] = value
+ set_adapter_full_state = getattr(self.strategy, 'set_adapter_full_state_dict', None)
+ if set_adapter_full_state is not None and full_adapter_source:
+ set_adapter_full_state(full_adapter_source)
+
if _has_param_wrapper_without_base_weight(model):
model.load_state_dict(direct_weights, strict=False)
else:
From 77dc0991da774b3a071a8009fd7cafd250e38f49 Mon Sep 17 00:00:00 2001
From: weikaiwen
Date: Thu, 21 May 2026 14:25:23 +0800
Subject: [PATCH 36/40] cookboook
---
README.md | 3 +-
cookbook/transformers/deepseek_v4.py | 203 ------------------
.../transformers/ep_fsdp2_lora_deepseek_v4.py | 166 ++++++++++++++
.../transformers/ep_fsdp2_lora_deepseek_v4.sh | 18 ++
.../transformers/ep_fsdp2_lora_qwen3_5_moe.py | 158 ++++++++++++++
.../transformers/ep_fsdp2_lora_qwen3_5_moe.sh | 18 ++
cookbook/transformers/ep_fsdp_qwen3_moe.py | 111 ----------
cookbook/transformers/ep_fsdp_qwen3_moe.sh | 7 -
cookbook/transformers/ep_lora_deepseek_v4.py | 100 ---------
cookbook/transformers/ep_lora_qwen3_5_moe.py | 106 ---------
cookbook/transformers/fsdp2_moe.py | 95 --------
cookbook/transformers/fsdp2_moe.sh | 1 -
cookbook/transformers/fsdp2_moe_npu.sh | 6 -
13 files changed, 361 insertions(+), 631 deletions(-)
delete mode 100644 cookbook/transformers/deepseek_v4.py
create mode 100644 cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py
create mode 100644 cookbook/transformers/ep_fsdp2_lora_deepseek_v4.sh
create mode 100644 cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py
create mode 100644 cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.sh
delete mode 100644 cookbook/transformers/ep_fsdp_qwen3_moe.py
delete mode 100644 cookbook/transformers/ep_fsdp_qwen3_moe.sh
delete mode 100644 cookbook/transformers/ep_lora_deepseek_v4.py
delete mode 100644 cookbook/transformers/ep_lora_qwen3_5_moe.py
delete mode 100644 cookbook/transformers/fsdp2_moe.py
delete mode 100644 cookbook/transformers/fsdp2_moe.sh
delete mode 100644 cookbook/transformers/fsdp2_moe_npu.sh
diff --git a/README.md b/README.md
index 5adaa24d..d5a2a1b3 100644
--- a/README.md
+++ b/README.md
@@ -92,8 +92,7 @@ sh INSTALL_MEGATRON.sh
| Training Type | Model Framework | Cookbook Path |
| ------------------------------------ | --------------- | ----------------------------------------------------- |
| FSDP finetuning | transformers | [Script](cookbook/transformers/fsdp2.py) |
-| FSDP MoE finetuning | transformers | [Script](cookbook/transformers/fsdp2_moe.py) |
-| EP FSDP MoE finetuning | transformers | [Script](cookbook/transformers/ep_fsdp_qwen3_moe.py) |
+| EP FSDP2 LoRA finetuning | transformers | [Script](cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py) |
| SP FSDP finetuning | transformers | [Script](cookbook/transformers/sp_fsdp_dense.py) |
| pp/tp/cp finetuning | megatron | [Script](cookbook/megatron/tp.py) |
| pp/tp/cp MoE finetuning | megatron | [Script](cookbook/megatron/tp_moe.py) |
diff --git a/cookbook/transformers/deepseek_v4.py b/cookbook/transformers/deepseek_v4.py
deleted file mode 100644
index 3e10d44c..00000000
--- a/cookbook/transformers/deepseek_v4.py
+++ /dev/null
@@ -1,203 +0,0 @@
-import os
-import time
-
-import torch.distributed as dist
-import twinkle
-from peft import LoraConfig
-from transformers import AutoConfig
-from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
-from twinkle.dataloader import DataLoader
-from twinkle.dataset import Dataset, DatasetMeta
-from twinkle.model import TransformersModel
-from twinkle.preprocessor import SelfCognitionProcessor
-
-logger = get_logger()
-
-MODEL_ID = os.environ.get('MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4-flash-bfa16')
-DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition')
-TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template')
-OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output')
-
-_num_layers_env = os.environ.get('NUM_LAYERS','1')
-NUM_LAYERS = int(_num_layers_env) if _num_layers_env is not None else None
-
-BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4'))
-GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '2'))
-LR = float(os.environ.get('LR', '1e-4'))
-MAX_STEPS = int(os.environ.get('MAX_STEPS', '0'))
-SAVE_STEPS = int(os.environ.get('SAVE_STEPS', '50'))
-USE_LORA = os.environ.get('USE_LORA', '1') == '1'
-IGNORE_MISMATCHED_SIZES = os.environ.get('IGNORE_MISMATCHED_SIZES', '1') == '1'
-GRADIENT_CHECKPOINTING = os.environ.get('GRADIENT_CHECKPOINTING', '1') == '1'
-RESHARD_AFTER_FORWARD = os.environ.get('RESHARD_AFTER_FORWARD', '1') == '1'
-LORA_TARGET_MODULES = os.environ.get(
- 'LORA_TARGET_MODULES',
- 'wq_a,wq_b,wkv,wgate,gate_proj,up_proj,down_proj',
-)
-
-device_mesh = DeviceMesh.from_sizes(
- fsdp_size=4,
- dp_size=1,
- ep_size=4,
- device_type=Platform.get_platform().device_prefix(),
-)
-
-twinkle.initialize(mode='local', global_device_mesh=device_mesh)
-
-
-def debug_print(message: str):
- if os.environ.get('TWINKLE_FSDP_DEBUG', '0') != '1':
- return
- rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else Platform.get_rank()
- local_rank = Platform.get_local_rank()
- timestamp = time.time()
- text = f'[twinkle-train-debug][time={timestamp:.6f} rank{rank} local_rank={local_rank}] {message}'
- print(text, flush=True)
- debug_dir = os.environ.get('TWINKLE_DEBUG_DIR')
- if debug_dir:
- os.makedirs(debug_dir, exist_ok=True)
- with open(os.path.join(debug_dir, f'train_rank{rank}.log'), 'a', encoding='utf-8') as f:
- f.write(text + '\n')
-
-
-def describe_batch(batch):
- if isinstance(batch, dict):
- return f'dict_keys={list(batch.keys())}'
- if isinstance(batch, (list, tuple)):
- return f'{type(batch).__name__}[len={len(batch)}]'
- return type(batch).__name__
-
-
-def log_expert_parallel_status(model):
- logger.info(
- f'EP flags: enabled={getattr(model, "_enable_expert_parallel", None)}, '
- f'applied={getattr(model, "_expert_parallel_applied", None)}')
- raw_model = model.strategy.unwrap_model(model.model)
- found = False
- for name, module in raw_model.named_modules():
- if not hasattr(module, '_ep_patched'):
- continue
- found = True
- logger.info(
- 'EP block %s: patched=%s rank=%s/%s local_experts=[%s, %s) experts_per_rank=%s',
- name,
- getattr(module, '_ep_patched', None),
- getattr(module, '_ep_rank', None),
- getattr(module, '_ep_world_size', None),
- getattr(module, '_ep_local_start', None),
- getattr(module, '_ep_local_end', None),
- getattr(module, '_ep_experts_per_rank', None),
- )
- if not found:
- logger.info('No EP-patched MoE blocks found on the wrapped model.')
-
-
-def create_dataset(data_slice=None):
- dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=data_slice or range(1000)))
- dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID)
- dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
- dataset.encode(batched=True)
- return dataset
-
-
-def eval(model):
- dataset = create_dataset(data_slice=range(100))
- dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)
- for _, batch in enumerate(dataloader):
- model.forward_only(inputs=batch, adapter_name='default')
- model.calculate_loss(adapter_name='default')
- return model.calculate_metric(is_training=False, adapter_name='default')
-
-
-def train():
- dataset = create_dataset()
- dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE)
-
- config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
- if NUM_LAYERS is not None and hasattr(config, 'num_hidden_layers'):
- config.num_hidden_layers = NUM_LAYERS
- if hasattr(config, 'use_cache'):
- config.use_cache = False
-
- model = TransformersModel(
- model_id=MODEL_ID,
- config=config,
- device_mesh=device_mesh,
- strategy='native_fsdp',
- memory_efficient_init=True,
- ignore_mismatched_sizes=IGNORE_MISMATCHED_SIZES,
- fsdp_config={
- 'reshard_after_forward': RESHARD_AFTER_FORWARD,
- 'expert_parallel': {
- 'enabled': True,
- 'router_dtype': 'fp32',
- 'keep_router_logits': False,
- },
- },
- )
-
- if USE_LORA:
- lora_target_modules = [name.strip() for name in LORA_TARGET_MODULES.split(',') if name.strip()]
- lora_config = LoraConfig(r=8, lora_alpha=32, target_modules=lora_target_modules)
- model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
-
- if not GRADIENT_CHECKPOINTING:
- model.model.gradient_checkpointing_disable()
-
- model.set_template(TEMPLATE_ID, model_id=MODEL_ID, adapter_name='default')
- model.set_optimizer('AdamW', lr=LR, foreach=False, adapter_name='default')
- model.set_lr_scheduler(
- scheduler_cls='CosineWarmupScheduler',
- num_warmup_steps=5,
- num_training_steps=len(dataloader),
- adapter_name='default',
- )
-
- logger.info(get_device_placement())
- logger.info(model.get_train_configs(adapter_name='default'))
- logger.info(
- f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, '
- f'grad_accum={GRAD_ACCUM_STEPS}, lr={LR:.2e}, use_lora={USE_LORA}, '
- f'num_layers={NUM_LAYERS}, ignore_mismatched_sizes={IGNORE_MISMATCHED_SIZES}, '
- f'gradient_checkpointing={GRADIENT_CHECKPOINTING}, '
- f'reshard_after_forward={RESHARD_AFTER_FORWARD}, '
- f'lora_target_modules={LORA_TARGET_MODULES}')
-
- best_loss = float('inf')
- for step, batch in enumerate(dataloader):
- if MAX_STEPS and step >= MAX_STEPS:
- break
- if step < 2:
- debug_print(f'step={step} before forward_backward batch={describe_batch(batch)}')
- model.forward_backward(
- inputs=batch,
- adapter_name='default',
- )
- if step < 2:
- debug_print(f'step={step} after forward_backward')
- model.clip_grad_and_step(
- adapter_name='default',
- gradient_accumulation_steps=GRAD_ACCUM_STEPS,
- )
- if step < 2:
- debug_print(f'step={step} after clip_grad_and_step')
- if step == 0:
- log_expert_parallel_status(model)
-
- if step % 20 == 0:
- metric = model.calculate_metric(is_training=True, adapter_name='default')
- logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
-
- if step > 0 and step % SAVE_STEPS == 0:
- metrics = eval(model)
- logger.info(f'Eval metric: {metrics}')
- loss = float(metrics['loss'])
- if loss < best_loss:
- model.save(name=f'checkpoint-{step}', output_dir=OUTPUT_DIR, adapter_name='default')
- best_loss = loss
-
- model.save(name='last-checkpoint', output_dir=OUTPUT_DIR, adapter_name='default')
-
-
-if __name__ == '__main__':
- train()
diff --git a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py
new file mode 100644
index 00000000..93819bef
--- /dev/null
+++ b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py
@@ -0,0 +1,166 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""EP + FSDP2 + LoRA SFT cookbook for DeepSeek-V4.
+
+Run on 4 GPUs:
+ torchrun --nproc-per-node=4 cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py
+"""
+import os
+from pathlib import Path
+
+from peft import LoraConfig
+from transformers import AutoConfig
+
+import twinkle
+from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.model import TransformersModel
+from twinkle.preprocessor import SelfCognitionProcessor
+
+logger = get_logger()
+
+MODEL_ID = os.environ.get('DSV4_MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4')
+DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition')
+TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template')
+NUM_LAYERS = int(os.environ.get('NUM_LAYERS', '2'))
+BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4'))
+GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4'))
+LR = float(os.environ.get('LR', '1e-4'))
+MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0'))
+LORA_R = int(os.environ.get('LORA_R', '8'))
+LORA_ALPHA = int(os.environ.get('LORA_ALPHA', '32'))
+ENABLE_EP = os.environ.get('ENABLE_EP', '1') == '1'
+SAVE_STEPS = int(os.environ.get('SAVE_STEPS', '0'))
+OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output_dsv4')
+RESUME_FROM_CHECKPOINT = os.environ.get('RESUME_FROM_CHECKPOINT') or None
+RESUME_ONLY_MODEL = os.environ.get('RESUME_ONLY_MODEL', '0') == '1'
+IGNORE_DATA_SKIP = os.environ.get('IGNORE_DATA_SKIP', '0') == '1'
+ADAPTER_NAME = os.environ.get('ADAPTER_NAME', 'default')
+
+device_mesh = DeviceMesh.from_sizes(
+ fsdp_size=4,
+ dp_size=1,
+ ep_size=2,
+ device_type=Platform.get_platform().device_prefix(),
+)
+twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+
+
+def _get_text_config(config):
+ return getattr(config, 'text_config', config)
+
+
+def _configure_smoke_config(config):
+ text_config = _get_text_config(config)
+ old_num_hidden_layers = getattr(text_config, 'num_hidden_layers', NUM_LAYERS)
+ text_config.num_hidden_layers = NUM_LAYERS
+ if hasattr(text_config, 'use_cache'):
+ text_config.use_cache = False
+ if hasattr(text_config, 'num_hash_layers'):
+ text_config.num_hash_layers = min(text_config.num_hash_layers, NUM_LAYERS)
+ if hasattr(text_config, 'compress_ratios'):
+ extra_entries = max(len(text_config.compress_ratios) - old_num_hidden_layers, 0)
+ keep = min(len(text_config.compress_ratios), NUM_LAYERS + extra_entries)
+ text_config.compress_ratios = list(text_config.compress_ratios[:keep])
+
+
+def _build_lora_config(enable_ep: bool):
+ if enable_ep:
+ return LoraConfig(
+ r=LORA_R,
+ lora_alpha=LORA_ALPHA,
+ target_modules='all-linear',
+ exclude_modules=['o_a_proj'],
+ target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'],
+ )
+ # Expert weights are bare nn.Parameters. PEFT trains them through
+ # target_parameters/ParamWrapper, which dynamically parametrizes weights
+ # during forward. That is not stable with plain FSDP2, so non-EP mode uses
+ # regular module LoRA and does not train expert parameters.
+ return LoraConfig(
+ r=LORA_R,
+ lora_alpha=LORA_ALPHA,
+ target_modules='all-linear',
+ )
+
+
+def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader):
+ model.save(
+ name=checkpoint_name,
+ output_dir=OUTPUT_DIR,
+ adapter_name=ADAPTER_NAME,
+ save_optimizer=True,
+ consumed_train_samples=dataloader.get_state()['consumed_train_samples'],
+ )
+
+
+def train():
+ config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
+ _configure_smoke_config(config)
+
+ dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(500)))
+ dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID)
+ dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope'))
+ dataset.encode(batched=True)
+ dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh)
+
+ model = TransformersModel(
+ model_id=MODEL_ID,
+ config=config,
+ device_mesh=device_mesh,
+ strategy='native_fsdp',
+ memory_efficient_init=True,
+ fsdp_config={
+ 'expert_parallel': {
+ 'enabled': ENABLE_EP,
+ 'router_dtype': 'fp32',
+ 'keep_router_logits': False,
+ }
+ },
+ )
+ lora_cfg = _build_lora_config(ENABLE_EP)
+ model.add_adapter_to_model(ADAPTER_NAME, lora_cfg, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
+ model.set_optimizer('AdamW', lr=LR, foreach=False)
+ model.set_lr_scheduler(
+ scheduler_cls='CosineWarmupScheduler',
+ num_warmup_steps=5,
+ num_training_steps=len(dataloader),
+ )
+
+ if RESUME_FROM_CHECKPOINT:
+ checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve()
+ kwargs = {}
+ if ADAPTER_NAME:
+ kwargs['adapter_name'] = ADAPTER_NAME
+ progress = model.resume_from_checkpoint(
+ str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs)
+ if not IGNORE_DATA_SKIP:
+ dataloader.resume_from_checkpoint(progress['consumed_train_samples'])
+
+ logger.info(get_device_placement())
+ logger.info(model.get_train_configs())
+ logger.info(
+ f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}, '
+ f'num_layers={NUM_LAYERS}, enable_ep={ENABLE_EP}, save_steps={SAVE_STEPS}, output_dir={OUTPUT_DIR}')
+
+ optimizer_group = model.optimizer_group[ADAPTER_NAME]
+ for step, batch in enumerate(dataloader):
+ if callable(batch):
+ batch = batch()
+ model.forward_backward(inputs=batch)
+ model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
+ cur_step = optimizer_group.cur_step
+ if cur_step > 0 and (step + 1) % GRAD_ACCUM_STEPS == 0:
+ metric = model.calculate_metric(is_training=True)
+ if callable(metric):
+ metric = metric()
+ logger.info(f'optimizer_step {cur_step}, metric: {metric}')
+ if SAVE_STEPS and cur_step % SAVE_STEPS == 0:
+ save_checkpoint(model, f'checkpoint-{cur_step}', dataloader)
+
+ save_checkpoint(model, 'checkpoint-final', dataloader)
+ logger.info(f'Saved final adapter to {OUTPUT_DIR}/checkpoint-final')
+
+
+if __name__ == '__main__':
+ train()
diff --git a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.sh b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.sh
new file mode 100644
index 00000000..ce64b176
--- /dev/null
+++ b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.sh
@@ -0,0 +1,18 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# EP + FSDP2 + LoRA training for DeepSeek-V4.
+# ENABLE_EP=1 trains expert LoRA with target_parameters.
+# ENABLE_EP=0 runs plain FSDP2 LoRA and does not train expert parameters.
+
+export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}"
+export NPROC_PER_NODE="${NPROC_PER_NODE:-4}"
+export ENABLE_EP="${ENABLE_EP:-1}"
+export NUM_LAYERS="${NUM_LAYERS:-2}"
+export BATCH_SIZE="${BATCH_SIZE:-4}"
+export GRAD_ACCUM_STEPS="${GRAD_ACCUM_STEPS:-4}"
+export SAVE_STEPS="${SAVE_STEPS:-0}"
+export OUTPUT_DIR="${OUTPUT_DIR:-./output_dsv4}"
+
+torchrun --nproc-per-node="${NPROC_PER_NODE}" \
+ cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py
diff --git a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py
new file mode 100644
index 00000000..0ab64ae6
--- /dev/null
+++ b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py
@@ -0,0 +1,158 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""EP + FSDP2 + LoRA SFT cookbook for Qwen3.5-MoE.
+
+Run on 4 GPUs:
+ torchrun --nproc-per-node=4 cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py
+"""
+import os
+from pathlib import Path
+
+from peft import LoraConfig
+from transformers import AutoConfig
+
+import twinkle
+from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
+from twinkle.dataloader import DataLoader
+from twinkle.dataset import Dataset, DatasetMeta
+from twinkle.model import TransformersModel
+from twinkle.preprocessor import SelfCognitionProcessor
+
+logger = get_logger()
+
+MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
+DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition')
+TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'Qwen3_5Template')
+_num_layers_env = os.environ.get('NUM_LAYERS')
+NUM_LAYERS = int(_num_layers_env) if _num_layers_env is not None else None
+BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4'))
+GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4'))
+LR = float(os.environ.get('LR', '1e-4'))
+MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0'))
+LORA_R = int(os.environ.get('LORA_R', '8'))
+LORA_ALPHA = int(os.environ.get('LORA_ALPHA', '32'))
+ENABLE_EP = os.environ.get('ENABLE_EP', '1') == '1'
+SAVE_STEPS = int(os.environ.get('SAVE_STEPS', '0'))
+OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output')
+RESUME_FROM_CHECKPOINT = os.environ.get('RESUME_FROM_CHECKPOINT') or None
+RESUME_ONLY_MODEL = os.environ.get('RESUME_ONLY_MODEL', '0') == '1'
+IGNORE_DATA_SKIP = os.environ.get('IGNORE_DATA_SKIP', '0') == '1'
+ADAPTER_NAME = os.environ.get('ADAPTER_NAME', 'default')
+
+device_mesh = DeviceMesh.from_sizes(
+ fsdp_size=4,
+ dp_size=1,
+ ep_size=2,
+ device_type=Platform.get_platform().device_prefix(),
+)
+twinkle.initialize(mode='local', global_device_mesh=device_mesh)
+
+
+def _get_text_config(config):
+ return getattr(config, 'text_config', config)
+
+
+def _build_lora_config(enable_ep: bool):
+ if enable_ep:
+ return LoraConfig(
+ r=LORA_R,
+ lora_alpha=LORA_ALPHA,
+ target_modules='all-linear',
+ target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'],
+ )
+ # Expert weights are bare nn.Parameters. PEFT trains them through
+ # target_parameters/ParamWrapper, which dynamically parametrizes weights
+ # during forward. That is not stable with plain FSDP2, so non-EP mode uses
+ # regular module LoRA and does not train expert parameters.
+ return LoraConfig(
+ r=LORA_R,
+ lora_alpha=LORA_ALPHA,
+ target_modules='all-linear',
+ )
+
+
+def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader):
+ model.save(
+ name=checkpoint_name,
+ output_dir=OUTPUT_DIR,
+ adapter_name=ADAPTER_NAME,
+ save_optimizer=True,
+ consumed_train_samples=dataloader.get_state()['consumed_train_samples'],
+ )
+
+
+def train():
+ config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
+ text_config = _get_text_config(config)
+ if NUM_LAYERS is not None and hasattr(text_config, 'num_hidden_layers'):
+ text_config.num_hidden_layers = NUM_LAYERS
+ if hasattr(text_config, 'use_cache'):
+ text_config.use_cache = False
+
+ dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(1000)))
+ try:
+ dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID)
+ except ValueError:
+ dataset.set_template('Qwen3_5Template', model_id=MODEL_ID)
+ dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope'))
+ dataset.encode(batched=True)
+ dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh)
+
+ model = TransformersModel(
+ model_id=MODEL_ID,
+ config=config,
+ device_mesh=device_mesh,
+ strategy='native_fsdp',
+ fsdp_config={
+ 'expert_parallel': {
+ 'enabled': ENABLE_EP,
+ 'router_dtype': 'fp32',
+ 'keep_router_logits': False,
+ }
+ },
+ )
+ lora_cfg = _build_lora_config(ENABLE_EP)
+ model.add_adapter_to_model(ADAPTER_NAME, lora_cfg, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
+ model.set_optimizer('AdamW', lr=LR, foreach=False)
+ model.set_lr_scheduler(
+ scheduler_cls='CosineWarmupScheduler',
+ num_warmup_steps=5,
+ num_training_steps=len(dataloader),
+ )
+
+ if RESUME_FROM_CHECKPOINT:
+ checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve()
+ kwargs = {}
+ if ADAPTER_NAME:
+ kwargs['adapter_name'] = ADAPTER_NAME
+ progress = model.resume_from_checkpoint(
+ str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs)
+ if not IGNORE_DATA_SKIP:
+ dataloader.resume_from_checkpoint(progress['consumed_train_samples'])
+
+ logger.info(get_device_placement())
+ logger.info(model.get_train_configs())
+ logger.info(
+ f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}, '
+ f'num_layers={NUM_LAYERS}, enable_ep={ENABLE_EP}, save_steps={SAVE_STEPS}, output_dir={OUTPUT_DIR}')
+
+ optimizer_group = model.optimizer_group[ADAPTER_NAME]
+ for step, batch in enumerate(dataloader):
+ if callable(batch):
+ batch = batch()
+ model.forward_backward(inputs=batch)
+ model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
+ cur_step = optimizer_group.cur_step
+ if cur_step > 0 and (step + 1) % GRAD_ACCUM_STEPS == 0:
+ metric = model.calculate_metric(is_training=True)
+ if callable(metric):
+ metric = metric()
+ logger.info(f'optimizer_step {cur_step}, metric: {metric}')
+ if SAVE_STEPS and cur_step % SAVE_STEPS == 0:
+ save_checkpoint(model, f'checkpoint-{cur_step}', dataloader)
+
+ save_checkpoint(model, 'checkpoint-final', dataloader)
+ logger.info(f'Saved final adapter to {OUTPUT_DIR}/checkpoint-final')
+
+
+if __name__ == '__main__':
+ train()
diff --git a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.sh b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.sh
new file mode 100644
index 00000000..760b5e99
--- /dev/null
+++ b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.sh
@@ -0,0 +1,18 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# EP + FSDP2 + LoRA training for Qwen3.5-MoE.
+# ENABLE_EP=1 trains expert LoRA with target_parameters.
+# ENABLE_EP=0 runs plain FSDP2 LoRA and does not train expert parameters.
+
+export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}"
+export NPROC_PER_NODE="${NPROC_PER_NODE:-4}"
+export ENABLE_EP="${ENABLE_EP:-1}"
+export NUM_LAYERS="${NUM_LAYERS:-2}"
+export BATCH_SIZE="${BATCH_SIZE:-4}"
+export GRAD_ACCUM_STEPS="${GRAD_ACCUM_STEPS:-4}"
+export SAVE_STEPS="${SAVE_STEPS:-0}"
+export OUTPUT_DIR="${OUTPUT_DIR:-./output_qwen3_5_moe}"
+
+torchrun --nproc-per-node="${NPROC_PER_NODE}" \
+ cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py
diff --git a/cookbook/transformers/ep_fsdp_qwen3_moe.py b/cookbook/transformers/ep_fsdp_qwen3_moe.py
deleted file mode 100644
index 11855fae..00000000
--- a/cookbook/transformers/ep_fsdp_qwen3_moe.py
+++ /dev/null
@@ -1,111 +0,0 @@
-# Copyright (c) ModelScope Contributors. All rights reserved.
-import os
-from transformers import AutoConfig
-
-import twinkle
-from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
-from twinkle.dataloader import DataLoader
-from twinkle.dataset import Dataset, DatasetMeta
-from twinkle.model import TransformersModel
-from twinkle.preprocessor import SelfCognitionProcessor
-
-logger = get_logger()
-
-MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
-DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition')
-TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'Qwen3_5Template')
-_num_layers_env = os.environ.get('NUM_LAYERS')
-NUM_LAYERS = int(_num_layers_env) if _num_layers_env is not None else None
-BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4'))
-GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4'))
-LR = float(os.environ.get('LR', '1e-5'))
-MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0'))
-KEEP_ROUTER_LOGITS = os.environ.get('KEEP_ROUTER_LOGITS', '0') == '1'
-
-# 8 gpus, dp=1, fsdp=8 (data parallel), ep_size=8 (expert parallel)
-device_mesh = DeviceMesh.from_sizes(
- fsdp_size=8,
- dp_size=1,
- ep_size=8,
- device_type=Platform.get_platform().device_prefix(),
-)
-
-twinkle.initialize(
- mode='local',
- global_device_mesh=device_mesh,
-)
-
-
-def train():
- config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
- if NUM_LAYERS is not None and hasattr(config, 'num_hidden_layers'):
- config.num_hidden_layers = NUM_LAYERS
- if hasattr(config, 'use_cache'):
- config.use_cache = False
-
- dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(1000)))
- try:
- dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID)
- except ValueError:
- dataset.set_template('Qwen3_5Template', model_id=MODEL_ID)
-
- dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
- dataset.encode(batched=True)
- dataloader = DataLoader(
- dataset=dataset,
- batch_size=BATCH_SIZE,
- device_mesh=device_mesh,
- )
-
- model = TransformersModel(
- model_id=MODEL_ID,
- config=config,
- device_mesh=device_mesh,
- fsdp_config={
- 'expert_parallel': {
- 'enabled': True,
- 'router_dtype': 'fp32',
- 'keep_router_logits': KEEP_ROUTER_LOGITS,
- }
- },
- )
- # Disable foreach to avoid DTensor mixed-type errors in EP runs.
- model.set_optimizer('AdamW', lr=LR, foreach=False)
- model.set_lr_scheduler(
- scheduler_cls='CosineWarmupScheduler',
- num_warmup_steps=5,
- num_training_steps=len(dataloader),
- )
-
- logger.info(get_device_placement())
- logger.info(model.get_train_configs())
- logger.info(
- f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}, '
- f'lr={LR:.2e}, max_grad_norm={MAX_GRAD_NORM}, '
- f'keep_router_logits={KEEP_ROUTER_LOGITS}')
-
- for step, batch in enumerate(dataloader):
- if callable(batch):
- batch = batch()
- model.forward_backward(inputs=batch, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
- model.clip_grad_and_step(
- max_grad_norm=MAX_GRAD_NORM,
- gradient_accumulation_steps=GRAD_ACCUM_STEPS,
- )
-
- is_sync_step = ((step + 1) % GRAD_ACCUM_STEPS == 0)
- if is_sync_step:
- optimizer_step = (step + 1) // GRAD_ACCUM_STEPS
- metric = model.calculate_metric(is_training=True)
- if callable(metric):
- metric = metric()
- logger.info(f'Current optimizer_step {optimizer_step}, metric: {metric}')
- if optimizer_step > 0 and optimizer_step % 50 == 0:
- model.save(name=f'checkpoint-step-{optimizer_step}', output_dir='./output')
-
- model.save(name='checkpoint-final', output_dir='./output')
- logger.info('Saved final checkpoint to ./output/checkpoint-final')
-
-
-if __name__ == '__main__':
- train()
diff --git a/cookbook/transformers/ep_fsdp_qwen3_moe.sh b/cookbook/transformers/ep_fsdp_qwen3_moe.sh
deleted file mode 100644
index cfc8a7cf..00000000
--- a/cookbook/transformers/ep_fsdp_qwen3_moe.sh
+++ /dev/null
@@ -1,7 +0,0 @@
-# EP + FSDP2 (Transformers MoE) example.
-# With expert_parallel enabled, expert parameters are sharded across the EP dimension.
-# Non-expert parameters are sharded by FSDP (across world_size).
-# Officially validated scope: qwen3_moe_like models (for example, Qwen3-30B-A3B).
-# Other MoE models may work if their MoE blocks expose: `experts` + `gate/router` + `top_k` (or `num_experts_per_tok`).
-# EP runtime constraints: `num_experts % ep_world_size == 0`.
-CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 ep_fsdp_qwen3_moe.py
diff --git a/cookbook/transformers/ep_lora_deepseek_v4.py b/cookbook/transformers/ep_lora_deepseek_v4.py
deleted file mode 100644
index 9d035a48..00000000
--- a/cookbook/transformers/ep_lora_deepseek_v4.py
+++ /dev/null
@@ -1,100 +0,0 @@
-# Copyright (c) ModelScope Contributors. All rights reserved.
-"""EP + LoRA SFT cookbook for DeepSeek-V4.
-
-Run on 4 GPUs:
- torchrun --nproc-per-node=4 cookbook/transformers/ep_lora_deepseek_v4.py
-"""
-import os
-
-from peft import LoraConfig
-from transformers import AutoConfig
-
-import twinkle
-from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
-from twinkle.dataloader import DataLoader
-from twinkle.dataset import Dataset, DatasetMeta
-from twinkle.model import TransformersModel
-from twinkle.preprocessor import SelfCognitionProcessor
-
-logger = get_logger()
-
-MODEL_ID = os.environ.get('DSV4_MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4')
-DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition')
-TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template')
-NUM_LAYERS = int(os.environ.get('NUM_LAYERS', '2'))
-BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '2'))
-GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4'))
-LR = float(os.environ.get('LR', '1e-4'))
-MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0'))
-LORA_R = int(os.environ.get('LORA_R', '8'))
-LORA_ALPHA = int(os.environ.get('LORA_ALPHA', '32'))
-NUM_STEPS_LIMIT = int(os.environ.get('NUM_STEPS_LIMIT', '0'))
-
-device_mesh = DeviceMesh.from_sizes(
- fsdp_size=4,
- dp_size=1,
- ep_size=2,
- device_type=Platform.get_platform().device_prefix(),
-)
-twinkle.initialize(mode='local', global_device_mesh=device_mesh)
-
-
-def _get_text_config(config):
- return getattr(config, 'text_config', config)
-
-
-def train():
- config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
- text_config = _get_text_config(config)
- text_config.num_hidden_layers = NUM_LAYERS
- if hasattr(text_config, 'use_cache'):
- text_config.use_cache = False
-
- dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(500)))
- dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID)
- dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope'))
- dataset.encode(batched=True)
- dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh)
-
- model = TransformersModel(
- model_id=MODEL_ID,
- config=config,
- device_mesh=device_mesh,
- fsdp_config={'expert_parallel': {'enabled': True, 'router_dtype': 'fp32'}},
- )
- lora_cfg = LoraConfig(
- r=LORA_R,
- lora_alpha=LORA_ALPHA,
- target_modules='all-linear',
- target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'],
- )
- model.add_adapter_to_model('default', lora_cfg)
- model.set_optimizer('AdamW', lr=LR, foreach=False)
- model.set_lr_scheduler(
- scheduler_cls='CosineWarmupScheduler',
- num_warmup_steps=5,
- num_training_steps=len(dataloader),
- )
-
- logger.info(get_device_placement())
- logger.info(model.get_train_configs())
-
- for step, batch in enumerate(dataloader):
- if NUM_STEPS_LIMIT and step >= NUM_STEPS_LIMIT:
- break
- if callable(batch):
- batch = batch()
- model.forward_backward(inputs=batch, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
- model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
- if (step + 1) % GRAD_ACCUM_STEPS == 0:
- optimizer_step = (step + 1) // GRAD_ACCUM_STEPS
- metric = model.calculate_metric(is_training=True)
- if callable(metric):
- metric = metric()
- logger.info(f'optimizer_step {optimizer_step}, metric: {metric}')
-
- model.save(name='checkpoint-final', output_dir='./output_dsv4')
-
-
-if __name__ == '__main__':
- train()
diff --git a/cookbook/transformers/ep_lora_qwen3_5_moe.py b/cookbook/transformers/ep_lora_qwen3_5_moe.py
deleted file mode 100644
index e8bfebc2..00000000
--- a/cookbook/transformers/ep_lora_qwen3_5_moe.py
+++ /dev/null
@@ -1,106 +0,0 @@
-# Copyright (c) ModelScope Contributors. All rights reserved.
-"""EP + LoRA SFT cookbook for Qwen3.5-MoE.
-
-Run on 4 GPUs:
- torchrun --nproc-per-node=4 cookbook/transformers/ep_lora_qwen3_5_moe.py
-"""
-import os
-
-from peft import LoraConfig
-from transformers import AutoConfig
-
-import twinkle
-from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
-from twinkle.dataloader import DataLoader
-from twinkle.dataset import Dataset, DatasetMeta
-from twinkle.model import TransformersModel
-from twinkle.preprocessor import SelfCognitionProcessor
-
-logger = get_logger()
-
-MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
-DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition')
-TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'Qwen3_5Template')
-_num_layers_env = os.environ.get('NUM_LAYERS')
-NUM_LAYERS = int(_num_layers_env) if _num_layers_env is not None else None
-BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4'))
-GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4'))
-LR = float(os.environ.get('LR', '1e-4'))
-MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0'))
-LORA_R = int(os.environ.get('LORA_R', '8'))
-LORA_ALPHA = int(os.environ.get('LORA_ALPHA', '32'))
-NUM_STEPS_LIMIT = int(os.environ.get('NUM_STEPS_LIMIT', '0'))
-
-device_mesh = DeviceMesh.from_sizes(
- fsdp_size=4,
- dp_size=1,
- ep_size=2,
- device_type=Platform.get_platform().device_prefix(),
-)
-twinkle.initialize(mode='local', global_device_mesh=device_mesh)
-
-
-def _get_text_config(config):
- return getattr(config, 'text_config', config)
-
-
-def train():
- config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
- text_config = _get_text_config(config)
- if NUM_LAYERS is not None and hasattr(text_config, 'num_hidden_layers'):
- text_config.num_hidden_layers = NUM_LAYERS
- if hasattr(text_config, 'use_cache'):
- text_config.use_cache = False
-
- dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(1000)))
- try:
- dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID)
- except ValueError:
- dataset.set_template('Qwen3_5Template', model_id=MODEL_ID)
- dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope'))
- dataset.encode(batched=True)
- dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh)
-
- model = TransformersModel(
- model_id=MODEL_ID,
- config=config,
- device_mesh=device_mesh,
- fsdp_config={'expert_parallel': {'enabled': True, 'router_dtype': 'fp32'}},
- )
- lora_cfg = LoraConfig(
- r=LORA_R,
- lora_alpha=LORA_ALPHA,
- target_modules='all-linear',
- target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'],
- )
- model.add_adapter_to_model('default', lora_cfg)
- model.set_optimizer('AdamW', lr=LR, foreach=False)
- model.set_lr_scheduler(
- scheduler_cls='CosineWarmupScheduler',
- num_warmup_steps=5,
- num_training_steps=len(dataloader),
- )
-
- logger.info(get_device_placement())
- logger.info(model.get_train_configs())
-
- for step, batch in enumerate(dataloader):
- if NUM_STEPS_LIMIT and step >= NUM_STEPS_LIMIT:
- break
- if callable(batch):
- batch = batch()
- model.forward_backward(inputs=batch, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
- model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
- if (step + 1) % GRAD_ACCUM_STEPS == 0:
- optimizer_step = (step + 1) // GRAD_ACCUM_STEPS
- metric = model.calculate_metric(is_training=True)
- if callable(metric):
- metric = metric()
- logger.info(f'optimizer_step {optimizer_step}, metric: {metric}')
-
- model.save(name='checkpoint-final', output_dir='./output')
- logger.info('Saved final adapter to ./output/checkpoint-final')
-
-
-if __name__ == '__main__':
- train()
diff --git a/cookbook/transformers/fsdp2_moe.py b/cookbook/transformers/fsdp2_moe.py
deleted file mode 100644
index a2965ee5..00000000
--- a/cookbook/transformers/fsdp2_moe.py
+++ /dev/null
@@ -1,95 +0,0 @@
-import os
-from peft import LoraConfig
-from tqdm import tqdm
-
-import twinkle
-from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
-from twinkle.dataloader import DataLoader
-from twinkle.dataset import Dataset, DatasetMeta
-from twinkle.model import TransformersModel
-from twinkle.preprocessor import SelfCognitionProcessor
-from twinkle.utils.framework import Torch
-from twinkle.kernel import apply_npu_patch
-
-# Construct a device_mesh, fsdp=4, dp=2
-device_mesh = DeviceMesh.from_sizes(fsdp_size=4, dp_size=2)
-# use torchrun mode
-twinkle.initialize(mode='local', global_device_mesh=device_mesh)
-
-logger = get_logger()
-
-
-# npu patch
-if Torch.is_npu_available():
- apply_npu_patch()
-
-
-def eval(model):
- # 100 Samples
- dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100)))
- dataset.set_template('Template', model_id='ms://Qwen/Qwen3-30B-A3B-Instruct-2507')
- dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
- dataset.encode()
- dataloader = DataLoader(dataset=dataset, batch_size=4)
- for step, batch in tqdm(enumerate(dataloader)):
- model.forward_only(inputs=batch)
- model.calculate_loss()
- metrics = model.calculate_metric(is_training=False)
- return metrics
-
-
-def train():
- # 1000 samples
- dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
- # Set template to prepare encoding
- dataset.set_template('Template', model_id='ms://Qwen/Qwen3-30B-A3B-Instruct-2507')
- # Preprocess the dataset to standard format
- dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
- # Encode dataset
- dataset.encode()
- # Global batch size = 4, for GPUs, so 1 sample per GPU
- dataloader = DataLoader(dataset=dataset, batch_size=8)
- # Use a TransformersModel, transformer_cls_names_to_wrap=Qwen3MoeSparseMoeBlock to avoid hang of fsdp2
- model = TransformersModel(model_id='ms://Qwen/Qwen3-30B-A3B-Instruct-2507', fsdp_config={'transformer_cls_names_to_wrap':['Qwen3MoeSparseMoeBlock']})
- # Patch MoE model to fix the hang bug, support transformers==4.*
- model.apply_patch('ms://twinkle-kit/qwen3_moe_transformers4_patch')
- lora_config = LoraConfig(
- r=8,
- lora_alpha=32,
- target_modules='all-linear'
- )
-
- # Add a lora to model, with name `default`
- # Comment this to use full-parameter training
- model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2)
- # Add Optimizer for lora `default`
- model.set_optimizer(optimizer_cls='AdamW', lr=1e-4)
- # Add LRScheduler for lora `default`
- model.set_lr_scheduler(scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader))
- logger.info(get_device_placement())
- # Print the training config
- logger.info(model.get_train_configs())
- logger.info(f'Total steps: {len(dataloader)}')
- loss_metric = 99.0
- # lora: 34G * 8
- for step, batch in enumerate(dataloader):
- # Do forward and backward
- model.forward_backward(inputs=batch)
- # Step
- model.clip_grad_and_step()
- if step % 20 == 0:
- # Print metric
- metric = model.calculate_metric(is_training=True)
- logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
- if step > 0 and step % 40 == 0:
- metrics = eval(model)
- logger.info(f'Eval metric: {metrics}')
- metrics['step'] = step
- if loss_metric > float(metrics['loss']):
- model.save(f'checkpoint-{step}')
- loss_metric = float(metrics['loss'])
- model.save(f'last-checkpoint')
-
-
-if __name__ == '__main__':
- train()
diff --git a/cookbook/transformers/fsdp2_moe.sh b/cookbook/transformers/fsdp2_moe.sh
deleted file mode 100644
index c496cd1d..00000000
--- a/cookbook/transformers/fsdp2_moe.sh
+++ /dev/null
@@ -1 +0,0 @@
-CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 fsdp2_moe.py
diff --git a/cookbook/transformers/fsdp2_moe_npu.sh b/cookbook/transformers/fsdp2_moe_npu.sh
deleted file mode 100644
index 349f9d0d..00000000
--- a/cookbook/transformers/fsdp2_moe_npu.sh
+++ /dev/null
@@ -1,6 +0,0 @@
-#!/usr/bin/env bash
-
-# CANN loading
-source /usr/local/Ascend/ascend-toolkit/set_env.sh
-
-ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 fsdp2_moe.py
From cf0b79f10c1498e11c100c72698f3c3ec842ab70 Mon Sep 17 00:00:00 2001
From: weikaiwen
Date: Thu, 21 May 2026 14:58:51 +0800
Subject: [PATCH 37/40] lint
---
.../model/transformers/strategy/accelerate.py | 51 ++---------
.../transformers/strategy/native_fsdp.py | 8 +-
.../strategy/native_fsdp_state.py | 8 +-
.../model/transformers/transformers.py | 84 ++-----------------
4 files changed, 19 insertions(+), 132 deletions(-)
diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py
index 353661fb..3d2a0639 100644
--- a/src/twinkle/model/transformers/strategy/accelerate.py
+++ b/src/twinkle/model/transformers/strategy/accelerate.py
@@ -1,6 +1,5 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import os
-import time
from datetime import timedelta
from typing import Any, Dict, Literal, Optional
@@ -16,9 +15,9 @@ def _patch_accelerate_fsdp2_load_full_state_dict():
tensors; older Accelerate versions assume every state-dict entry has
`device_mesh` and fail on such buffers.
"""
+ import accelerate.utils.fsdp_utils as fsdp_utils
import torch
import torch.distributed as dist
- import accelerate.utils.fsdp_utils as fsdp_utils
from torch.distributed.tensor import DTensor, Partial, Replicate, Shard, distribute_tensor
if getattr(fsdp_utils.fsdp2_load_full_state_dict, '_twinkle_patched', False):
@@ -27,13 +26,8 @@ def _patch_accelerate_fsdp2_load_full_state_dict():
original = fsdp_utils.fsdp2_load_full_state_dict
def patched_fsdp2_load_full_state_dict(accelerator, model, full_sd, cpu_offload=False):
- _fsdp_debug(
- f'enter fsdp2_load_full_state_dict device={accelerator.device} '
- f'full_sd_keys={len(full_sd) if full_sd is not None else "None"}')
-
meta_sharded_sd = model.state_dict()
sharded_sd = {}
- _fsdp_debug(f'patched fsdp2 meta_sharded_keys={len(meta_sharded_sd)}')
def _infer_parameter_dtype(model, param_name, empty_param):
try:
@@ -103,21 +97,13 @@ def _load_full_value(param_name, sharded_param):
def _tensor_debug(tensor):
if isinstance(tensor, DTensor):
- return (
- f'type=DTensor shape={tuple(tensor.size())} dtype={tensor.dtype} '
- f'placements={tensor.placements} mesh={tensor.device_mesh}')
+ return (f'type=DTensor shape={tuple(tensor.size())} dtype={tensor.dtype} '
+ f'placements={tensor.placements} mesh={tensor.device_mesh}')
if hasattr(tensor, 'size') and hasattr(tensor, 'dtype'):
return f'type={type(tensor).__name__} shape={tuple(tensor.size())} dtype={tensor.dtype}'
return f'type={type(tensor).__name__}'
for param_name, sharded_param in meta_sharded_sd.items():
- _fsdp_debug(f'load state entry start: {param_name} {_tensor_debug(sharded_param)}')
- if accelerator.is_main_process:
- full_value = full_sd.get(param_name)
- if full_value is None:
- _fsdp_debug(f'full state entry missing: {param_name}')
- else:
- _fsdp_debug(f'full state entry: {param_name} {_tensor_debug(full_value)}')
if isinstance(sharded_param, DTensor):
device_mesh = sharded_param.device_mesh
placements = sharded_param.placements
@@ -131,9 +117,7 @@ def _tensor_debug(tensor):
)
dist.broadcast(full_param, src=0, group=dist.group.WORLD)
- _fsdp_debug(f'broadcast done: {param_name}')
sharded_tensor = _dtensor_from_replicated_full_tensor(full_param, device_mesh, placements)
- _fsdp_debug(f'local shard done: {param_name}')
to_contiguous, casting_dtype = _infer_parameter_dtype(model, param_name, full_param)
sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)
if cpu_offload:
@@ -151,7 +135,6 @@ def _tensor_debug(tensor):
)
dist.broadcast(full_value, src=0, group=dist.group.WORLD)
- _fsdp_debug(f'broadcast done: {param_name}')
to_contiguous, casting_dtype = _infer_parameter_dtype(model, param_name, full_value)
full_value = _cast_and_contiguous(full_value, to_contiguous, casting_dtype)
if cpu_offload:
@@ -159,7 +142,6 @@ def _tensor_debug(tensor):
sharded_sd[param_name] = full_value
model.load_state_dict(sharded_sd, assign=True)
- _fsdp_debug('exit patched fsdp2_load_full_state_dict')
return model
patched_fsdp2_load_full_state_dict._twinkle_patched = True
@@ -167,27 +149,6 @@ def _tensor_debug(tensor):
fsdp_utils.fsdp2_load_full_state_dict = patched_fsdp2_load_full_state_dict
-def _fsdp_debug(message: str) -> None:
- if os.environ.get('TWINKLE_FSDP_DEBUG', '0') != '1':
- return
- try:
- import torch.distributed as dist
- rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
- world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
- except Exception:
- rank = 0
- world_size = 1
- local_rank = os.environ.get('LOCAL_RANK', '?')
- timestamp = time.time()
- text = f'[twinkle-fsdp-debug][time={timestamp:.6f} rank{rank}/{world_size} local_rank={local_rank}] {message}'
- print(text, flush=True)
- debug_dir = os.environ.get('TWINKLE_DEBUG_DIR')
- if debug_dir:
- os.makedirs(debug_dir, exist_ok=True)
- with open(os.path.join(debug_dir, f'fsdp_rank{rank}.log'), 'a', encoding='utf-8') as f:
- f.write(text + '\n')
-
-
class AccelerateStrategy:
"""A training strategy that uses `accelerate` to wrap models.
@@ -219,8 +180,8 @@ def __init__(
kwargs_handlers = []
kwargs_handlers.append(
- InitProcessGroupKwargs(timeout=timedelta(seconds=int(os.environ.get('TWINKLE_DIST_TIMEOUT_SECONDS', '7200'))))
- )
+ InitProcessGroupKwargs(
+ timeout=timedelta(seconds=int(os.environ.get('TWINKLE_DIST_TIMEOUT_SECONDS', '7200')))))
if ddp_config is not None:
from accelerate import DistributedDataParallelKwargs
ddp_config = DistributedDataParallelKwargs(**ddp_config)
@@ -308,9 +269,7 @@ def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Di
return fsdp_plugin
def wrap_model(self, model, *args):
- _fsdp_debug('enter accelerator.prepare')
result = self.accelerator.prepare(model, *args)
- _fsdp_debug('exit accelerator.prepare')
return result
def unwrap_model(self, model):
diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py
index 669d6970..ee571a06 100644
--- a/src/twinkle/model/transformers/strategy/native_fsdp.py
+++ b/src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -8,12 +8,8 @@
from twinkle.utils import DeviceMesh, Platform, torch_util
from .load_context import fsdp_pretrained_load_context
-from .native_fsdp_state import (
- _collect_adapter_source_state,
- _collect_state_metadata,
- _is_lora_state_key,
- _resolve_full_state_source_key,
-)
+from .native_fsdp_state import (_collect_adapter_source_state, _collect_state_metadata, _is_lora_state_key,
+ _resolve_full_state_source_key)
if TYPE_CHECKING:
from torch.distributed.fsdp import MixedPrecisionPolicy
diff --git a/src/twinkle/model/transformers/strategy/native_fsdp_state.py b/src/twinkle/model/transformers/strategy/native_fsdp_state.py
index dc62d951..bfb30163 100644
--- a/src/twinkle/model/transformers/strategy/native_fsdp_state.py
+++ b/src/twinkle/model/transformers/strategy/native_fsdp_state.py
@@ -43,9 +43,8 @@ def _resolve_full_state_source_key(param_name: str, source_state: Mapping[str, A
for candidate in candidates:
if candidate in source_state:
return candidate
- raise KeyError(
- f"Missing source metadata for parameter '{param_name}'. "
- f'Tried source keys: {", ".join(candidates)}.')
+ raise KeyError(f"Missing source metadata for parameter '{param_name}'. "
+ f'Tried source keys: {", ".join(candidates)}.')
def _collect_adapter_source_state(state_dict: Mapping[str, Any]) -> Dict[str, Any]:
@@ -62,6 +61,5 @@ def _collect_adapter_source_state(state_dict: Mapping[str, Any]) -> Dict[str, An
def _collect_state_metadata(state_dict: Mapping[str, Any]) -> Dict[str, tuple[tuple[int, ...], Any]]:
return {
name: (tuple(tensor.shape), tensor.dtype)
- for name, tensor in state_dict.items()
- if hasattr(tensor, 'shape') and hasattr(tensor, 'dtype')
+ for name, tensor in state_dict.items() if hasattr(tensor, 'shape') and hasattr(tensor, 'dtype')
}
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 4035083d..a4247858 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -7,7 +7,6 @@
import random
import re
import threading
-import time
import torch
import torch.distributed as dist
import transformers
@@ -49,25 +48,6 @@
logger = get_logger()
-def _twinkle_fsdp_debug(message: str) -> None:
- if os.environ.get('TWINKLE_FSDP_DEBUG', '0') != '1':
- return
- try:
- rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else int(os.environ.get('RANK', 0))
- world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
- except Exception:
- rank = int(os.environ.get('RANK', 0))
- world_size = 1
- local_rank = os.environ.get('LOCAL_RANK', '?')
- text = f'[twinkle-model-debug][time={time.time():.6f} rank{rank}/{world_size} local_rank={local_rank}] {message}'
- print(text, flush=True)
- debug_dir = os.environ.get('TWINKLE_DEBUG_DIR')
- if debug_dir:
- os.makedirs(debug_dir, exist_ok=True)
- with open(os.path.join(debug_dir, f'model_rank{rank}.log'), 'a', encoding='utf-8') as f:
- f.write(text + '\n')
-
-
def _get_named_child(module, name: str):
if hasattr(module, name):
return getattr(module, name)
@@ -237,9 +217,7 @@ def __init__(
memory_efficient_init: bool = False,
**kwargs):
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
- _twinkle_fsdp_debug('TransformersModel init before process_group')
self._try_init_process_group()
- _twinkle_fsdp_debug('TransformersModel init after process_group')
super(PreTrainedModel, self).__init__()
# The Default tokenizer will be used to save with a model if no template was set.
self._default_tokenizer = None
@@ -249,14 +227,9 @@ def __init__(
self._ddp_config = ddp_config or {}
self._memory_efficient_init = memory_efficient_init
self._decide_strategy(strategy)
- _twinkle_fsdp_debug(
- f'TransformersModel strategy decided strategy={strategy} '
- f'memory_efficient_init={memory_efficient_init}')
self.grad_scaler_config = grad_scaler_config
if model_id is not None:
- _twinkle_fsdp_debug(f'before HubOperation.download_model model_id={model_id}')
model_id = HubOperation.download_model(model_id)
- _twinkle_fsdp_debug(f'after HubOperation.download_model model_id={model_id}')
self.model_id = model_id
self.tokenizer_id = kwargs.get('tokenizer_id', self.model_id)
if config is None:
@@ -271,24 +244,14 @@ def __init__(
if isinstance(model_cls, str):
model_cls = getattr(transformers, model_cls)
if model_id is None:
- _twinkle_fsdp_debug('before model_cls.from_config')
self.model = model_cls.from_config(self.hf_config, **kwargs)
- _twinkle_fsdp_debug('after model_cls.from_config')
elif self._should_init_empty_pretrained_model_on_this_rank():
- _twinkle_fsdp_debug('before empty model_cls.from_config for rank0 broadcast')
self.model = self._init_empty_model_from_config(model_cls, **kwargs)
- _twinkle_fsdp_debug('after empty model_cls.from_config for rank0 broadcast')
else:
# Trigger transformers' FSDP-aware loading: meta-device init + rank-0-only weight load.
- _twinkle_fsdp_debug('before pretrained_load_context')
with self.strategy.pretrained_load_context():
- _twinkle_fsdp_debug('before model_cls.from_pretrained')
self.model = model_cls.from_pretrained(model_id, config=self.hf_config, **kwargs)
- _twinkle_fsdp_debug('after model_cls.from_pretrained')
- _twinkle_fsdp_debug('after pretrained_load_context')
- _twinkle_fsdp_debug('before gradient_checkpointing_enable')
self.model.gradient_checkpointing_enable()
- _twinkle_fsdp_debug('after gradient_checkpointing_enable')
self.sp_strategy = None
self._model_wrapped = False
self.optimizer_group: Dict[str, OptimizerGroup] = {
@@ -299,11 +262,7 @@ def __init__(
def _should_init_empty_pretrained_model_on_this_rank(self) -> bool:
use_rank0_broadcast = getattr(self.strategy, 'use_rank0_pretrained_broadcast', lambda: False)
- return bool(
- use_rank0_broadcast()
- and dist.is_available()
- and dist.is_initialized()
- and dist.get_rank() != 0)
+ return bool(use_rank0_broadcast() and dist.is_available() and dist.is_initialized() and dist.get_rank() != 0)
def _init_empty_model_from_config(self, model_cls, **kwargs):
from accelerate import init_empty_weights
@@ -482,9 +441,7 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec
temperature = float(kwargs.pop('temperature', 1.0))
return_logits = kwargs.pop('return_logits', False)
optimizer_config = self.optimizer_group[adapter_name]
- _twinkle_fsdp_debug('forward before _lazy_wrap_model')
self._lazy_wrap_model()
- _twinkle_fsdp_debug('forward after _lazy_wrap_model')
if not inputs:
raise ValueError('inputs empty, check your DataLoader outputs')
self.model.train()
@@ -678,42 +635,28 @@ def backward(self, **kwargs):
optimizer_config = self.optimizer_group[adapter_name]
loss_value = optimizer_config.train_status.loss_value
assert loss_value is not None, 'Do forwarding and calculating loss before backward'
- _twinkle_fsdp_debug(
- f'backward enter adapter={adapter_name} loss_shape={tuple(loss_value.shape)} '
- f'loss_dtype={loss_value.dtype} loss_device={loss_value.device}')
scaler = optimizer_config.scaler
if scaler is None and self.mixed_precision == 'fp16':
# Auto set a grad scaler
- _twinkle_fsdp_debug('backward before set_grad_scaler')
self.set_grad_scaler(adapter_name=adapter_name)
scaler = optimizer_config.scaler
- _twinkle_fsdp_debug('backward after set_grad_scaler')
optimizer_config.cur_step += 1
should_sync = optimizer_config.do_grad_sync()
- _twinkle_fsdp_debug(f'backward cur_step={optimizer_config.cur_step} should_sync={should_sync}')
import contextlib
no_sync_ctx = contextlib.nullcontext()
if not should_sync and hasattr(self.model, 'no_sync'):
- _twinkle_fsdp_debug('backward using model.no_sync')
no_sync_ctx = self.model.no_sync()
- _twinkle_fsdp_debug('backward before no_sync_ctx')
with no_sync_ctx:
if scaler is not None:
- _twinkle_fsdp_debug('backward before scaler backward')
scaler.scale(loss_value).backward()
- _twinkle_fsdp_debug('backward after scaler backward')
else:
- _twinkle_fsdp_debug('backward before loss.backward')
loss_value.backward()
- _twinkle_fsdp_debug('backward after loss.backward')
- _twinkle_fsdp_debug('backward after no_sync_ctx')
# self._sync_after_backward_if_needed()
optimizer_config.train_status.loss_value = None
- _twinkle_fsdp_debug('backward exit')
@remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]],
@@ -729,14 +672,10 @@ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Tr
Returns:
The output of the model forward.
"""
- _twinkle_fsdp_debug('forward_backward enter')
outputs = self.forward(inputs=inputs, **kwargs)
- _twinkle_fsdp_debug('forward_backward after forward')
loss = self.calculate_loss(**kwargs)
- _twinkle_fsdp_debug('forward_backward after calculate_loss')
outputs['loss'] = loss
self.backward(**kwargs)
- _twinkle_fsdp_debug('forward_backward after backward')
return outputs
# def _sync_after_backward_if_needed(self) -> None:
@@ -1359,23 +1298,19 @@ def _validate_ep_lora_config(self, lora_config) -> None:
if not getattr(self, '_enable_expert_parallel', False):
return
if not isinstance(self.strategy, NativeFSDPStrategy):
- raise RuntimeError(
- 'EP + LoRA requires strategy=native_fsdp; '
- f'got {type(self.strategy).__name__}.')
+ raise RuntimeError('EP + LoRA requires strategy=native_fsdp; '
+ f'got {type(self.strategy).__name__}.')
if not isinstance(lora_config, LoraConfig):
return
target_params = getattr(lora_config, 'target_parameters', None) or []
if target_params:
if getattr(lora_config, 'use_dora', False):
- raise ValueError(
- 'PEFT ParamWrapper does not support use_dora=True with target_parameters; '
- 'disable DoRA when training expert parameters.')
+ raise ValueError('PEFT ParamWrapper does not support use_dora=True with target_parameters; '
+ 'disable DoRA when training expert parameters.')
if getattr(lora_config, 'lora_bias', False):
- raise ValueError(
- 'PEFT ParamWrapper does not support lora_bias=True with target_parameters.')
+ raise ValueError('PEFT ParamWrapper does not support lora_bias=True with target_parameters.')
if float(getattr(lora_config, 'lora_dropout', 0.0)) > 0.0:
- raise ValueError(
- 'PEFT ParamWrapper does not support lora_dropout>0 with target_parameters.')
+ raise ValueError('PEFT ParamWrapper does not support lora_dropout>0 with target_parameters.')
@staticmethod
def _maybe_autofill_target_parameters(lora_config, enable_ep: bool):
@@ -1386,9 +1321,8 @@ def _maybe_autofill_target_parameters(lora_config, enable_ep: bool):
target_params = getattr(lora_config, 'target_parameters', None) or []
if not target_params:
lora_config.target_parameters = ['mlp.experts.gate_up_proj', 'mlp.experts.down_proj']
- logger.info(
- "EP+LoRA auto-filled target_parameters with "
- "['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'].")
+ logger.info('EP+LoRA auto-filled target_parameters with '
+ "['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'].")
return lora_config
def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str], **kwargs):
From 0ed0e8ff6c5328855ad817041b8009252c177837 Mon Sep 17 00:00:00 2001
From: weikaiwen
Date: Thu, 21 May 2026 15:25:59 +0800
Subject: [PATCH 38/40] WIP
---
.../model/transformers/strategy/accelerate.py | 7 +-
.../transformers/strategy/native_fsdp.py | 165 +++++++++++++++++-
.../strategy/native_fsdp_state.py | 65 -------
.../model/transformers/transformers.py | 93 +---------
4 files changed, 170 insertions(+), 160 deletions(-)
delete mode 100644 src/twinkle/model/transformers/strategy/native_fsdp_state.py
diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py
index 3d2a0639..ec1aef0a 100644
--- a/src/twinkle/model/transformers/strategy/accelerate.py
+++ b/src/twinkle/model/transformers/strategy/accelerate.py
@@ -1,7 +1,7 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import os
from datetime import timedelta
-from typing import Any, Dict, Literal, Optional
+from typing import Any, Dict, Literal, Mapping, Optional
from twinkle import DeviceMesh
from .load_context import fsdp_pretrained_load_context
@@ -275,6 +275,11 @@ def wrap_model(self, model, *args):
def unwrap_model(self, model):
return self.accelerator.unwrap_model(model, keep_torch_compile=False)
+ def load_peft_weights(self, model, adapter_weights: Mapping[str, Any], adapter_name: str) -> None:
+ from peft.utils import set_peft_model_state_dict
+
+ set_peft_model_state_dict(model, adapter_weights, adapter_name=adapter_name)
+
def _get_fsdp_plugin(self):
state = self.accelerator.state
return state.fsdp_plugin if hasattr(state, 'fsdp_plugin') else None
diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py
index ee571a06..d57483ed 100644
--- a/src/twinkle/model/transformers/strategy/native_fsdp.py
+++ b/src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -4,16 +4,18 @@
from torch import nn
from torch.distributed.device_mesh import DeviceMesh as TorchDeviceMesh
from torch.distributed.fsdp import fully_shard
-from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, Set
from twinkle.utils import DeviceMesh, Platform, torch_util
from .load_context import fsdp_pretrained_load_context
-from .native_fsdp_state import (_collect_adapter_source_state, _collect_state_metadata, _is_lora_state_key,
- _resolve_full_state_source_key)
if TYPE_CHECKING:
from torch.distributed.fsdp import MixedPrecisionPolicy
+LORA_STATE_KEY_MARKERS = ('lora_A', 'lora_B', 'lora_embedding')
+PEFT_BASE_PREFIX = 'base_model.model.'
+PEFT_BASE_LAYER_SEGMENT = 'base_layer'
+
class NativeFSDPStrategy:
@@ -48,6 +50,15 @@ def set_rank0_pre_ep_full_state_dict(self, state_dict: Dict[str, torch.Tensor])
def set_adapter_full_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
self._adapter_full_state_dict = state_dict
+ def load_peft_weights(self, model, adapter_weights: Mapping[str, torch.Tensor], adapter_name: str) -> None:
+ from peft.utils import set_peft_model_state_dict
+
+ fsdp_world_size = self.device_mesh.fsdp_world_size if self.device_mesh is not None else 1
+ if fsdp_world_size > 1:
+ _load_peft_weights_for_native_fsdp2(model, adapter_weights, adapter_name, self)
+ else:
+ set_peft_model_state_dict(model, adapter_weights, adapter_name=adapter_name)
+
def _build_ep_fsdp_device_mesh(self, ep_size: Optional[int] = None) -> Optional[TorchDeviceMesh]:
if self.device_mesh is None:
return None
@@ -581,6 +592,154 @@ def _rebind_optimizer(optimizer: torch.optim.Optimizer, model: nn.Module) -> tor
return optimizer
+def _is_lora_state_key(name: str) -> bool:
+ return any(marker in name for marker in LORA_STATE_KEY_MARKERS)
+
+
+def _strip_peft_base_prefix(name: str) -> str:
+ while name.startswith(PEFT_BASE_PREFIX):
+ name = name[len(PEFT_BASE_PREFIX):]
+ return name
+
+
+def _strip_base_layer_segments(name: str) -> str:
+ return '.'.join(segment for segment in name.split('.') if segment != PEFT_BASE_LAYER_SEGMENT)
+
+
+def _source_key_candidates(param_name: str) -> List[str]:
+ stripped_prefix = _strip_peft_base_prefix(param_name)
+ candidates = [
+ param_name,
+ stripped_prefix,
+ _strip_base_layer_segments(param_name),
+ _strip_base_layer_segments(stripped_prefix),
+ ]
+ deduped = []
+ for candidate in candidates:
+ if candidate not in deduped:
+ deduped.append(candidate)
+ return deduped
+
+
+def _resolve_full_state_source_key(param_name: str, source_state: Mapping[str, Any]) -> str:
+ if _is_lora_state_key(param_name):
+ raise KeyError(f"LoRA parameter '{param_name}' must be loaded from adapter source state.")
+
+ candidates = _source_key_candidates(param_name)
+ for candidate in candidates:
+ if candidate in source_state:
+ return candidate
+ raise KeyError(f"Missing source metadata for parameter '{param_name}'. "
+ f'Tried source keys: {", ".join(candidates)}.')
+
+
+def _collect_adapter_source_state(state_dict: Mapping[str, Any]) -> Dict[str, Any]:
+ adapter_state = {}
+ for name, tensor in state_dict.items():
+ if not _is_lora_state_key(name) or not hasattr(tensor, 'detach'):
+ continue
+ if getattr(tensor, 'is_meta', False):
+ continue
+ adapter_state[name] = tensor.detach().cpu().clone()
+ return adapter_state
+
+
+def _collect_state_metadata(state_dict: Mapping[str, Any]) -> Dict[str, tuple[tuple[int, ...], Any]]:
+ return {
+ name: (tuple(tensor.shape), tensor.dtype)
+ for name, tensor in state_dict.items() if hasattr(tensor, 'shape') and hasattr(tensor, 'dtype')
+ }
+
+
+def _get_named_child(module, name: str):
+ if hasattr(module, name):
+ return getattr(module, name)
+ if name.isdigit() and hasattr(module, '__getitem__'):
+ try:
+ return module[int(name)]
+ except (IndexError, TypeError, KeyError):
+ return None
+ return None
+
+
+def _split_for_ep_pre_distribute(model, model_key: str, value: torch.Tensor, ep_world_size: int,
+ ep_rank: int) -> torch.Tensor:
+ """Slice saved LoRA expert weights by EP rank before DTensor/FSDP placement."""
+ if ep_world_size <= 1:
+ return value
+
+ parent = model
+ matched = False
+ parts = model_key.split('.')
+ for i, part in enumerate(parts[:-1]):
+ parent = _get_named_child(parent, part)
+ if parent is None:
+ return value
+ next_seg = parts[i + 1] if i + 1 < len(parts) else None
+ if getattr(parent, '_ep_patched', False) and next_seg == 'experts':
+ matched = True
+
+ if not matched:
+ return value
+ if 'lora_A' in model_key:
+ chunk = value.size(0) // ep_world_size
+ return value.narrow(0, ep_rank * chunk, chunk).contiguous()
+ if 'lora_B' in model_key:
+ chunk = value.size(1) // ep_world_size
+ return value.narrow(1, ep_rank * chunk, chunk).contiguous()
+ return value
+
+
+def _has_param_wrapper_without_base_weight(model) -> bool:
+ for module in model.modules():
+ if not hasattr(module, 'parameter_name'):
+ continue
+ get_base_layer = getattr(module, 'get_base_layer', None)
+ if get_base_layer is None:
+ continue
+ base_layer = get_base_layer()
+ if not hasattr(base_layer, 'weight'):
+ return True
+ return False
+
+
+def _load_peft_weights_for_native_fsdp2(model, adapter_weights: Mapping[str, torch.Tensor], adapter_name: str,
+ strategy: NativeFSDPStrategy) -> None:
+ """Load PEFT adapter weights into a native FSDP2 model, including EP expert adapters."""
+ from peft.utils import set_peft_model_state_dict
+ from torch.distributed.tensor import DTensor, distribute_tensor
+
+ ep_fsdp_mesh = getattr(strategy, 'ep_fsdp_device_mesh', None)
+ ep_world_size = ep_fsdp_mesh['ep'].size() if ep_fsdp_mesh is not None else 1
+ ep_rank = ep_fsdp_mesh['ep'].get_local_rank() if ep_world_size > 1 else 0
+
+ model_sd = model.state_dict()
+ converted_weights = {}
+ direct_weights = {}
+ full_adapter_source = {}
+ for key, value in adapter_weights.items():
+ model_key = key
+ if f'.{adapter_name}.weight' not in model_key:
+ model_key = model_key.replace('.weight', f'.{adapter_name}.weight')
+ if model_key in model_sd:
+ param = model_sd[model_key]
+ full_adapter_source[model_key] = value.detach().cpu().clone()
+ value = _split_for_ep_pre_distribute(model, model_key, value, ep_world_size, ep_rank)
+ if isinstance(param, DTensor) and not isinstance(value, DTensor):
+ value = distribute_tensor(value.to(param.device), param.device_mesh, param.placements)
+ direct_weights[model_key] = value
+ converted_weights[key] = value
+
+ set_adapter_full_state = getattr(strategy, 'set_adapter_full_state_dict', None)
+ if set_adapter_full_state is not None and full_adapter_source:
+ set_adapter_full_state(full_adapter_source)
+
+ if _has_param_wrapper_without_base_weight(model):
+ model.load_state_dict(direct_weights, strict=False)
+ else:
+ set_peft_model_state_dict(model, converted_weights, adapter_name=adapter_name)
+
+
def _broadcast_sharded_state_dict(
model: nn.Module,
full_sd: dict,
diff --git a/src/twinkle/model/transformers/strategy/native_fsdp_state.py b/src/twinkle/model/transformers/strategy/native_fsdp_state.py
deleted file mode 100644
index bfb30163..00000000
--- a/src/twinkle/model/transformers/strategy/native_fsdp_state.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# Copyright (c) ModelScope Contributors. All rights reserved.
-from typing import Any, Dict, Mapping
-
-LORA_STATE_KEY_MARKERS = ('lora_A', 'lora_B', 'lora_embedding')
-PEFT_BASE_PREFIX = 'base_model.model.'
-PEFT_BASE_LAYER_SEGMENT = 'base_layer'
-
-
-def _is_lora_state_key(name: str) -> bool:
- return any(marker in name for marker in LORA_STATE_KEY_MARKERS)
-
-
-def _strip_peft_base_prefix(name: str) -> str:
- while name.startswith(PEFT_BASE_PREFIX):
- name = name[len(PEFT_BASE_PREFIX):]
- return name
-
-
-def _strip_base_layer_segments(name: str) -> str:
- return '.'.join(segment for segment in name.split('.') if segment != PEFT_BASE_LAYER_SEGMENT)
-
-
-def _source_key_candidates(param_name: str) -> list[str]:
- stripped_prefix = _strip_peft_base_prefix(param_name)
- candidates = [
- param_name,
- stripped_prefix,
- _strip_base_layer_segments(param_name),
- _strip_base_layer_segments(stripped_prefix),
- ]
- deduped = []
- for candidate in candidates:
- if candidate not in deduped:
- deduped.append(candidate)
- return deduped
-
-
-def _resolve_full_state_source_key(param_name: str, source_state: Mapping[str, Any]) -> str:
- if _is_lora_state_key(param_name):
- raise KeyError(f"LoRA parameter '{param_name}' must be loaded from adapter source state.")
-
- candidates = _source_key_candidates(param_name)
- for candidate in candidates:
- if candidate in source_state:
- return candidate
- raise KeyError(f"Missing source metadata for parameter '{param_name}'. "
- f'Tried source keys: {", ".join(candidates)}.')
-
-
-def _collect_adapter_source_state(state_dict: Mapping[str, Any]) -> Dict[str, Any]:
- adapter_state = {}
- for name, tensor in state_dict.items():
- if not _is_lora_state_key(name) or not hasattr(tensor, 'detach'):
- continue
- if getattr(tensor, 'is_meta', False):
- continue
- adapter_state[name] = tensor.detach().cpu().clone()
- return adapter_state
-
-
-def _collect_state_metadata(state_dict: Mapping[str, Any]) -> Dict[str, tuple[tuple[int, ...], Any]]:
- return {
- name: (tuple(tensor.shape), tensor.dtype)
- for name, tensor in state_dict.items() if hasattr(tensor, 'shape') and hasattr(tensor, 'dtype')
- }
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index a4247858..50f14d6b 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -13,7 +13,7 @@
from copy import copy
from dataclasses import dataclass, field
from peft import PeftConfig, PeftModel, get_peft_model
-from peft.utils import load_peft_weights, set_peft_model_state_dict
+from peft.utils import load_peft_weights
from safetensors.torch import save_file
from torch import GradScaler
from torch.optim import Adam, AdamW, Optimizer
@@ -48,58 +48,6 @@
logger = get_logger()
-def _get_named_child(module, name: str):
- if hasattr(module, name):
- return getattr(module, name)
- if name.isdigit() and hasattr(module, '__getitem__'):
- try:
- return module[int(name)]
- except (IndexError, TypeError, KeyError):
- return None
- return None
-
-
-def _split_for_ep_pre_distribute(model, model_key: str, value: torch.Tensor, ep_world_size: int,
- ep_rank: int) -> torch.Tensor:
- """Slice saved LoRA expert weights by EP rank before DTensor/FSDP placement."""
- if ep_world_size <= 1:
- return value
-
- parts = model_key.split('.')
- parent = model
- matched = False
- for i, part in enumerate(parts[:-1]):
- parent = _get_named_child(parent, part)
- if parent is None:
- return value
- next_seg = parts[i + 1] if i + 1 < len(parts) else None
- if getattr(parent, '_ep_patched', False) and next_seg == 'experts':
- matched = True
-
- if not matched:
- return value
- if 'lora_A' in model_key:
- chunk = value.size(0) // ep_world_size
- return value.narrow(0, ep_rank * chunk, chunk).contiguous()
- if 'lora_B' in model_key:
- chunk = value.size(1) // ep_world_size
- return value.narrow(1, ep_rank * chunk, chunk).contiguous()
- return value
-
-
-def _has_param_wrapper_without_base_weight(model) -> bool:
- for module in model.modules():
- if not hasattr(module, 'parameter_name'):
- continue
- get_base_layer = getattr(module, 'get_base_layer', None)
- if get_base_layer is None:
- continue
- base_layer = get_base_layer()
- if not hasattr(base_layer, 'weight'):
- return True
- return False
-
-
@dataclass
class OptimizerGroup(BaseOptimizerGroup):
"""Optimizer group for Transformers training."""
@@ -1092,44 +1040,7 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs):
model = self.strategy.unwrap_model(self.model)
if isinstance(model, PeftModel):
adapter_weights = load_peft_weights(checkpoint_dir, device='cpu')
-
- def load_peft_weights_for_fsdp2(model, adapter_weights, adapter_name='default'):
- from torch.distributed.tensor import DTensor, distribute_tensor
-
- ep_fsdp_mesh = getattr(self.strategy, 'ep_fsdp_device_mesh', None)
- ep_world_size = ep_fsdp_mesh['ep'].size() if ep_fsdp_mesh is not None else 1
- ep_rank = ep_fsdp_mesh['ep'].get_local_rank() if ep_world_size > 1 else 0
-
- model_sd = model.state_dict()
- converted_weights = {}
- direct_weights = {}
- full_adapter_source = {}
- for key, value in adapter_weights.items():
- model_key = key
- if f'.{adapter_name}.weight' not in model_key:
- model_key = model_key.replace('.weight', f'.{adapter_name}.weight')
- if model_key in model_sd:
- param = model_sd[model_key]
- full_adapter_source[model_key] = value.detach().cpu().clone()
- value = _split_for_ep_pre_distribute(model, model_key, value, ep_world_size, ep_rank)
- if isinstance(param, DTensor) and not isinstance(value, DTensor):
- value = distribute_tensor(value.to(param.device), param.device_mesh, param.placements)
- direct_weights[model_key] = value
- converted_weights[key] = value
-
- set_adapter_full_state = getattr(self.strategy, 'set_adapter_full_state_dict', None)
- if set_adapter_full_state is not None and full_adapter_source:
- set_adapter_full_state(full_adapter_source)
-
- if _has_param_wrapper_without_base_weight(model):
- model.load_state_dict(direct_weights, strict=False)
- else:
- set_peft_model_state_dict(model, converted_weights, adapter_name=adapter_name)
-
- if self.device_mesh.fsdp_world_size > 1:
- load_peft_weights_for_fsdp2(model, adapter_weights, adapter_name=adapter_name)
- else:
- set_peft_model_state_dict(model, adapter_weights, adapter_name=adapter_name)
+ self.strategy.load_peft_weights(model, adapter_weights, adapter_name)
else:
raise NotImplementedError
From 442262dc1271769d55c71a2727f35c330e7c6d9c Mon Sep 17 00:00:00 2001
From: weikaiwen
Date: Thu, 21 May 2026 15:38:57 +0800
Subject: [PATCH 39/40] wip
---
.../model/transformers/moe/expert_parallel.py | 1 -
.../model/transformers/strategy/accelerate.py | 6 ++
.../transformers/strategy/native_fsdp.py | 40 +++++++++++-
.../model/transformers/transformers.py | 65 +------------------
4 files changed, 48 insertions(+), 64 deletions(-)
diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py
index 29158bfa..2b6e45ea 100644
--- a/src/twinkle/model/transformers/moe/expert_parallel.py
+++ b/src/twinkle/model/transformers/moe/expert_parallel.py
@@ -19,7 +19,6 @@ class ExpertParallelConfig:
router_dtype: str = 'fp32'
keep_router_logits: bool = True
ignore_shared_experts: bool = False
- # sync_after_backward: bool = True # consumed by TransformersModel to keep EP/FSDP collectives ordered
ep_size: int | None = None # consumed by TransformersModel, not used in expert_parallel logic
diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py
index ec1aef0a..6fb84530 100644
--- a/src/twinkle/model/transformers/strategy/accelerate.py
+++ b/src/twinkle/model/transformers/strategy/accelerate.py
@@ -197,6 +197,12 @@ def __init__(
def pretrained_load_context(self):
return fsdp_pretrained_load_context(self._memory_efficient_init and self.device_mesh is not None)
+ def capture_pre_ep_state_if_needed(self, model, *, enable_ep: bool) -> None:
+ return
+
+ def prepare_adapter_config(self, config_or_dir, *, enable_ep: bool):
+ return config_or_dir
+
@staticmethod
def _parallelism_config_from_device_mesh(device_mesh: DeviceMesh):
# TODO should test with transformers v5.0
diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py
index d57483ed..ef5666ce 100644
--- a/src/twinkle/model/transformers/strategy/native_fsdp.py
+++ b/src/twinkle/model/transformers/strategy/native_fsdp.py
@@ -6,12 +6,15 @@
from torch.distributed.fsdp import fully_shard
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, Set
-from twinkle.utils import DeviceMesh, Platform, torch_util
+from twinkle.utils import DeviceMesh, Platform, get_logger, torch_util
+from twinkle.utils.torch_utils import clone_state_dict_to_cpu
from .load_context import fsdp_pretrained_load_context
if TYPE_CHECKING:
from torch.distributed.fsdp import MixedPrecisionPolicy
+logger = get_logger()
+
LORA_STATE_KEY_MARKERS = ('lora_A', 'lora_B', 'lora_embedding')
PEFT_BASE_PREFIX = 'base_model.model.'
PEFT_BASE_LAYER_SEGMENT = 'base_layer'
@@ -34,6 +37,7 @@ def __init__(self,
self.ep_fsdp_device_mesh = self._build_ep_fsdp_device_mesh(ep_size) if enable_ep else None
self._rank0_pre_ep_full_state_dict = None
self._adapter_full_state_dict = None
+ self._pre_ep_state_captured = False
def pretrained_load_context(self):
# Native FSDP loads pretrained weights via rank0 broadcast during wrap_model().
@@ -44,6 +48,40 @@ def pretrained_load_context(self):
def use_rank0_pretrained_broadcast(self) -> bool:
return self._memory_efficient_init and self.device_mesh is not None
+ def capture_pre_ep_state_if_needed(self, model, *, enable_ep: bool) -> None:
+ if self._pre_ep_state_captured:
+ return
+ if not (enable_ep and self.use_rank0_pretrained_broadcast()):
+ return
+ is_rank0 = dist.is_available() and dist.is_initialized() and dist.get_rank() == 0
+ self.set_rank0_pre_ep_full_state_dict(clone_state_dict_to_cpu(model.state_dict()) if is_rank0 else {})
+ self._pre_ep_state_captured = True
+
+ def prepare_adapter_config(self, config_or_dir, *, enable_ep: bool):
+ if not enable_ep:
+ return config_or_dir
+
+ from peft import LoraConfig
+
+ if not isinstance(config_or_dir, LoraConfig):
+ return config_or_dir
+
+ target_params = getattr(config_or_dir, 'target_parameters', None) or []
+ if target_params:
+ if getattr(config_or_dir, 'use_dora', False):
+ raise ValueError('PEFT ParamWrapper does not support use_dora=True with target_parameters; '
+ 'disable DoRA when training expert parameters.')
+ if getattr(config_or_dir, 'lora_bias', False):
+ raise ValueError('PEFT ParamWrapper does not support lora_bias=True with target_parameters.')
+ if float(getattr(config_or_dir, 'lora_dropout', 0.0)) > 0.0:
+ raise ValueError('PEFT ParamWrapper does not support lora_dropout>0 with target_parameters.')
+ return config_or_dir
+
+ config_or_dir.target_parameters = ['mlp.experts.gate_up_proj', 'mlp.experts.down_proj']
+ logger.info('EP+LoRA auto-filled target_parameters with '
+ "['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'].")
+ return config_or_dir
+
def set_rank0_pre_ep_full_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None:
self._rank0_pre_ep_full_state_dict = state_dict
diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py
index 50f14d6b..a9a80b8f 100644
--- a/src/twinkle/model/transformers/transformers.py
+++ b/src/twinkle/model/transformers/transformers.py
@@ -42,7 +42,6 @@
from twinkle.utils import construct_class, get_logger, selective_log_softmax, torch_util
from twinkle.utils.framework import Torch
from twinkle.utils.grad_clip import normalize_and_clip_grad_norm
-from twinkle.utils.torch_utils import clone_state_dict_to_cpu
from twinkle.utils.transformers_utils import filter_from_config_kwargs
logger = get_logger()
@@ -284,22 +283,10 @@ def _not_encoded(inputs):
assert isinstance(inputs, dict)
return 'input_ids' not in inputs and 'input_embedding' not in inputs
- def _capture_rank0_pre_ep_state_if_needed(self):
- """Capture rank0 pre-EP full state_dict for memory_efficient_init broadcast."""
- if getattr(self, '_pre_ep_state_captured', False):
- return
- use_rank0_broadcast = getattr(self.strategy, 'use_rank0_pretrained_broadcast', lambda: False)
- set_pre_ep_state = getattr(self.strategy, 'set_rank0_pre_ep_full_state_dict', None)
- if not (self._enable_expert_parallel and use_rank0_broadcast() and set_pre_ep_state is not None):
- return
- is_rank0 = dist.is_available() and dist.is_initialized() and dist.get_rank() == 0
- set_pre_ep_state(clone_state_dict_to_cpu(self.model.state_dict()) if is_rank0 else {})
- self._pre_ep_state_captured = True
-
def _lazy_wrap_model(self):
if not self._model_wrapped:
optimizer_groups = [og for og in self.optimizer_group.values() if og.optimizer is not None]
- self._capture_rank0_pre_ep_state_if_needed()
+ self.strategy.capture_pre_ep_state_if_needed(self.model, enable_ep=self._enable_expert_parallel)
self._maybe_apply_expert_parallel()
self._ensure_sp_strategy()
if self.sp_strategy is not None:
@@ -603,7 +590,6 @@ def backward(self, **kwargs):
else:
loss_value.backward()
- # self._sync_after_backward_if_needed()
optimizer_config.train_status.loss_value = None
@remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
@@ -626,16 +612,6 @@ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Tr
self.backward(**kwargs)
return outputs
- # def _sync_after_backward_if_needed(self) -> None:
- # if not getattr(self, '_enable_expert_parallel', False):
- # return
- # expert_parallel_config = getattr(self, '_expert_parallel_config', None) or {}
- # if not expert_parallel_config.get('sync_after_backward', True):
- # return
- # torch_util.synchronize()
- # if dist.is_available() and dist.is_initialized():
- # dist.barrier()
-
@remote_function()
def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs):
""" Clip the gradient norm
@@ -1202,50 +1178,15 @@ def calculate_metric(self, is_training, **kwargs):
optimizer_config = self.optimizer_group[adapter_name]
return optimizer_config.calculate_metrics(is_training)
- @staticmethod
- def _validate_ep_lora_config(self, lora_config) -> None:
- from peft import LoraConfig
-
- if not getattr(self, '_enable_expert_parallel', False):
- return
- if not isinstance(self.strategy, NativeFSDPStrategy):
- raise RuntimeError('EP + LoRA requires strategy=native_fsdp; '
- f'got {type(self.strategy).__name__}.')
- if not isinstance(lora_config, LoraConfig):
- return
- target_params = getattr(lora_config, 'target_parameters', None) or []
- if target_params:
- if getattr(lora_config, 'use_dora', False):
- raise ValueError('PEFT ParamWrapper does not support use_dora=True with target_parameters; '
- 'disable DoRA when training expert parameters.')
- if getattr(lora_config, 'lora_bias', False):
- raise ValueError('PEFT ParamWrapper does not support lora_bias=True with target_parameters.')
- if float(getattr(lora_config, 'lora_dropout', 0.0)) > 0.0:
- raise ValueError('PEFT ParamWrapper does not support lora_dropout>0 with target_parameters.')
-
- @staticmethod
- def _maybe_autofill_target_parameters(lora_config, enable_ep: bool):
- from peft import LoraConfig
-
- if not enable_ep or not isinstance(lora_config, LoraConfig):
- return lora_config
- target_params = getattr(lora_config, 'target_parameters', None) or []
- if not target_params:
- lora_config.target_parameters = ['mlp.experts.gate_up_proj', 'mlp.experts.down_proj']
- logger.info('EP+LoRA auto-filled target_parameters with '
- "['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'].")
- return lora_config
-
def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str], **kwargs):
assert adapter_name, 'Use a different adapter_name, current is empty.'
unwrapped_model = self.strategy.unwrap_model(self.model)
if not isinstance(config_or_dir, str):
- self._validate_ep_lora_config(self, config_or_dir)
- config_or_dir = self._maybe_autofill_target_parameters(
+ config_or_dir = self.strategy.prepare_adapter_config(
config_or_dir, enable_ep=getattr(self, '_enable_expert_parallel', False))
if getattr(self, '_enable_expert_parallel', False):
- self._capture_rank0_pre_ep_state_if_needed()
+ self.strategy.capture_pre_ep_state_if_needed(self.model, enable_ep=self._enable_expert_parallel)
self._maybe_apply_expert_parallel()
unwrapped_model = self.strategy.unwrap_model(self.model)
From 9e70d586053c5517dc386719f1c4d6d4da928d6d Mon Sep 17 00:00:00 2001
From: weikaiwen
Date: Thu, 21 May 2026 17:39:41 +0800
Subject: [PATCH 40/40] cookbook
---
.../transformers/ep_fsdp2_lora_deepseek_v4.py | 46 ++++++-------------
.../transformers/ep_fsdp2_lora_deepseek_v4.sh | 2 -
.../transformers/ep_fsdp2_lora_qwen3_5_moe.py | 34 +++++---------
.../transformers/ep_fsdp2_lora_qwen3_5_moe.sh | 2 -
4 files changed, 26 insertions(+), 58 deletions(-)
diff --git a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py
index 93819bef..0b33f6df 100644
--- a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py
+++ b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py
@@ -22,15 +22,14 @@
MODEL_ID = os.environ.get('DSV4_MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4')
DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition')
TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template')
-NUM_LAYERS = int(os.environ.get('NUM_LAYERS', '2'))
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4'))
GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4'))
+LOG_INTERVAL = GRAD_ACCUM_STEPS
LR = float(os.environ.get('LR', '1e-4'))
MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0'))
LORA_R = int(os.environ.get('LORA_R', '8'))
LORA_ALPHA = int(os.environ.get('LORA_ALPHA', '32'))
ENABLE_EP = os.environ.get('ENABLE_EP', '1') == '1'
-SAVE_STEPS = int(os.environ.get('SAVE_STEPS', '0'))
OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output_dsv4')
RESUME_FROM_CHECKPOINT = os.environ.get('RESUME_FROM_CHECKPOINT') or None
RESUME_ONLY_MODEL = os.environ.get('RESUME_ONLY_MODEL', '0') == '1'
@@ -40,30 +39,12 @@
device_mesh = DeviceMesh.from_sizes(
fsdp_size=4,
dp_size=1,
- ep_size=2,
+ ep_size=4,
device_type=Platform.get_platform().device_prefix(),
)
twinkle.initialize(mode='local', global_device_mesh=device_mesh)
-def _get_text_config(config):
- return getattr(config, 'text_config', config)
-
-
-def _configure_smoke_config(config):
- text_config = _get_text_config(config)
- old_num_hidden_layers = getattr(text_config, 'num_hidden_layers', NUM_LAYERS)
- text_config.num_hidden_layers = NUM_LAYERS
- if hasattr(text_config, 'use_cache'):
- text_config.use_cache = False
- if hasattr(text_config, 'num_hash_layers'):
- text_config.num_hash_layers = min(text_config.num_hash_layers, NUM_LAYERS)
- if hasattr(text_config, 'compress_ratios'):
- extra_entries = max(len(text_config.compress_ratios) - old_num_hidden_layers, 0)
- keep = min(len(text_config.compress_ratios), NUM_LAYERS + extra_entries)
- text_config.compress_ratios = list(text_config.compress_ratios[:keep])
-
-
def _build_lora_config(enable_ep: bool):
if enable_ep:
return LoraConfig(
@@ -80,12 +61,13 @@ def _build_lora_config(enable_ep: bool):
return LoraConfig(
r=LORA_R,
lora_alpha=LORA_ALPHA,
+ exclude_modules=['o_a_proj'],
target_modules='all-linear',
)
def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader):
- model.save(
+ return model.save(
name=checkpoint_name,
output_dir=OUTPUT_DIR,
adapter_name=ADAPTER_NAME,
@@ -96,9 +78,11 @@ def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader:
def train():
config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
- _configure_smoke_config(config)
+ text_config = getattr(config, 'text_config', config)
+ if hasattr(text_config, 'use_cache'):
+ text_config.use_cache = False
- dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(500)))
+ dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID))
dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID)
dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope'))
dataset.encode(batched=True)
@@ -141,25 +125,23 @@ def train():
logger.info(model.get_train_configs())
logger.info(
f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}, '
- f'num_layers={NUM_LAYERS}, enable_ep={ENABLE_EP}, save_steps={SAVE_STEPS}, output_dir={OUTPUT_DIR}')
+ f'enable_ep={ENABLE_EP}, output_dir={OUTPUT_DIR}')
optimizer_group = model.optimizer_group[ADAPTER_NAME]
- for step, batch in enumerate(dataloader):
+ for batch in dataloader:
if callable(batch):
batch = batch()
model.forward_backward(inputs=batch)
model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
cur_step = optimizer_group.cur_step
- if cur_step > 0 and (step + 1) % GRAD_ACCUM_STEPS == 0:
+ if cur_step > 0 and cur_step % LOG_INTERVAL == 0:
metric = model.calculate_metric(is_training=True)
if callable(metric):
metric = metric()
- logger.info(f'optimizer_step {cur_step}, metric: {metric}')
- if SAVE_STEPS and cur_step % SAVE_STEPS == 0:
- save_checkpoint(model, f'checkpoint-{cur_step}', dataloader)
+ logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}')
- save_checkpoint(model, 'checkpoint-final', dataloader)
- logger.info(f'Saved final adapter to {OUTPUT_DIR}/checkpoint-final')
+ final_checkpoint = save_checkpoint(model, 'checkpoint-final', dataloader)
+ logger.info(f'Saved final adapter to {final_checkpoint}')
if __name__ == '__main__':
diff --git a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.sh b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.sh
index ce64b176..37f0862a 100644
--- a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.sh
+++ b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.sh
@@ -8,10 +8,8 @@ set -euo pipefail
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}"
export NPROC_PER_NODE="${NPROC_PER_NODE:-4}"
export ENABLE_EP="${ENABLE_EP:-1}"
-export NUM_LAYERS="${NUM_LAYERS:-2}"
export BATCH_SIZE="${BATCH_SIZE:-4}"
export GRAD_ACCUM_STEPS="${GRAD_ACCUM_STEPS:-4}"
-export SAVE_STEPS="${SAVE_STEPS:-0}"
export OUTPUT_DIR="${OUTPUT_DIR:-./output_dsv4}"
torchrun --nproc-per-node="${NPROC_PER_NODE}" \
diff --git a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py
index 0ab64ae6..82a0e1a0 100644
--- a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py
+++ b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py
@@ -19,19 +19,17 @@
logger = get_logger()
-MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.5-4B')
+MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.6-35B-A3B')
DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition')
TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'Qwen3_5Template')
-_num_layers_env = os.environ.get('NUM_LAYERS')
-NUM_LAYERS = int(_num_layers_env) if _num_layers_env is not None else None
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4'))
GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4'))
+LOG_INTERVAL = GRAD_ACCUM_STEPS
LR = float(os.environ.get('LR', '1e-4'))
MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0'))
LORA_R = int(os.environ.get('LORA_R', '8'))
LORA_ALPHA = int(os.environ.get('LORA_ALPHA', '32'))
ENABLE_EP = os.environ.get('ENABLE_EP', '1') == '1'
-SAVE_STEPS = int(os.environ.get('SAVE_STEPS', '0'))
OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output')
RESUME_FROM_CHECKPOINT = os.environ.get('RESUME_FROM_CHECKPOINT') or None
RESUME_ONLY_MODEL = os.environ.get('RESUME_ONLY_MODEL', '0') == '1'
@@ -41,16 +39,12 @@
device_mesh = DeviceMesh.from_sizes(
fsdp_size=4,
dp_size=1,
- ep_size=2,
+ ep_size=4,
device_type=Platform.get_platform().device_prefix(),
)
twinkle.initialize(mode='local', global_device_mesh=device_mesh)
-def _get_text_config(config):
- return getattr(config, 'text_config', config)
-
-
def _build_lora_config(enable_ep: bool):
if enable_ep:
return LoraConfig(
@@ -71,7 +65,7 @@ def _build_lora_config(enable_ep: bool):
def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader):
- model.save(
+ return model.save(
name=checkpoint_name,
output_dir=OUTPUT_DIR,
adapter_name=ADAPTER_NAME,
@@ -82,13 +76,11 @@ def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader:
def train():
config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
- text_config = _get_text_config(config)
- if NUM_LAYERS is not None and hasattr(text_config, 'num_hidden_layers'):
- text_config.num_hidden_layers = NUM_LAYERS
+ text_config = getattr(config, 'text_config', config)
if hasattr(text_config, 'use_cache'):
text_config.use_cache = False
- dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(1000)))
+ dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID))
try:
dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID)
except ValueError:
@@ -133,25 +125,23 @@ def train():
logger.info(model.get_train_configs())
logger.info(
f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}, '
- f'num_layers={NUM_LAYERS}, enable_ep={ENABLE_EP}, save_steps={SAVE_STEPS}, output_dir={OUTPUT_DIR}')
+ f'enable_ep={ENABLE_EP}, output_dir={OUTPUT_DIR}')
optimizer_group = model.optimizer_group[ADAPTER_NAME]
- for step, batch in enumerate(dataloader):
+ for batch in dataloader:
if callable(batch):
batch = batch()
model.forward_backward(inputs=batch)
model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
cur_step = optimizer_group.cur_step
- if cur_step > 0 and (step + 1) % GRAD_ACCUM_STEPS == 0:
+ if cur_step > 0 and cur_step % LOG_INTERVAL == 0:
metric = model.calculate_metric(is_training=True)
if callable(metric):
metric = metric()
- logger.info(f'optimizer_step {cur_step}, metric: {metric}')
- if SAVE_STEPS and cur_step % SAVE_STEPS == 0:
- save_checkpoint(model, f'checkpoint-{cur_step}', dataloader)
+ logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}')
- save_checkpoint(model, 'checkpoint-final', dataloader)
- logger.info(f'Saved final adapter to {OUTPUT_DIR}/checkpoint-final')
+ final_checkpoint = save_checkpoint(model, 'checkpoint-final', dataloader)
+ logger.info(f'Saved final adapter to {final_checkpoint}')
if __name__ == '__main__':
diff --git a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.sh b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.sh
index 760b5e99..6a3b9574 100644
--- a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.sh
+++ b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.sh
@@ -8,10 +8,8 @@ set -euo pipefail
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}"
export NPROC_PER_NODE="${NPROC_PER_NODE:-4}"
export ENABLE_EP="${ENABLE_EP:-1}"
-export NUM_LAYERS="${NUM_LAYERS:-2}"
export BATCH_SIZE="${BATCH_SIZE:-4}"
export GRAD_ACCUM_STEPS="${GRAD_ACCUM_STEPS:-4}"
-export SAVE_STEPS="${SAVE_STEPS:-0}"
export OUTPUT_DIR="${OUTPUT_DIR:-./output_qwen3_5_moe}"
torchrun --nproc-per-node="${NPROC_PER_NODE}" \