diff --git a/README.md b/README.md index 671ec9b6..f98056bb 100644 --- a/README.md +++ b/README.md @@ -87,8 +87,7 @@ sh INSTALL_MEGATRON.sh | Training Type | Model Framework | Cookbook Path | | ------------------------------------ | --------------- | ----------------------------------------------------- | | FSDP finetuning | transformers | [Script](cookbook/transformers/fsdp2.py) | -| FSDP MoE finetuning | transformers | [Script](cookbook/transformers/fsdp2_moe.py) | -| EP FSDP MoE finetuning | transformers | [Script](cookbook/transformers/ep_fsdp_qwen3_moe.py) | +| EP FSDP2 LoRA finetuning | transformers | [Script](cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py) | | SP FSDP finetuning | transformers | [Script](cookbook/transformers/sp_fsdp_dense.py) | | pp/tp/cp finetuning | megatron | [Script](cookbook/megatron/tp.py) | | pp/tp/cp MoE finetuning | megatron | [Script](cookbook/megatron/tp_moe.py) | diff --git a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py new file mode 100644 index 00000000..0b33f6df --- /dev/null +++ b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py @@ -0,0 +1,148 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""EP + FSDP2 + LoRA SFT cookbook for DeepSeek-V4. + +Run on 4 GPUs: + torchrun --nproc-per-node=4 cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py +""" +import os +from pathlib import Path + +from peft import LoraConfig +from transformers import AutoConfig + +import twinkle +from twinkle import DeviceMesh, Platform, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import TransformersModel +from twinkle.preprocessor import SelfCognitionProcessor + +logger = get_logger() + +MODEL_ID = os.environ.get('DSV4_MODEL_ID', 'ms://deepseek-ai/DeepSeek-V4') +DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition') +TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'DeepseekV4Template') +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4')) +GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4')) +LOG_INTERVAL = GRAD_ACCUM_STEPS +LR = float(os.environ.get('LR', '1e-4')) +MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0')) +LORA_R = int(os.environ.get('LORA_R', '8')) +LORA_ALPHA = int(os.environ.get('LORA_ALPHA', '32')) +ENABLE_EP = os.environ.get('ENABLE_EP', '1') == '1' +OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output_dsv4') +RESUME_FROM_CHECKPOINT = os.environ.get('RESUME_FROM_CHECKPOINT') or None +RESUME_ONLY_MODEL = os.environ.get('RESUME_ONLY_MODEL', '0') == '1' +IGNORE_DATA_SKIP = os.environ.get('IGNORE_DATA_SKIP', '0') == '1' +ADAPTER_NAME = os.environ.get('ADAPTER_NAME', 'default') + +device_mesh = DeviceMesh.from_sizes( + fsdp_size=4, + dp_size=1, + ep_size=4, + device_type=Platform.get_platform().device_prefix(), +) +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + + +def _build_lora_config(enable_ep: bool): + if enable_ep: + return LoraConfig( + r=LORA_R, + lora_alpha=LORA_ALPHA, + target_modules='all-linear', + exclude_modules=['o_a_proj'], + target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'], + ) + # Expert weights are bare nn.Parameters. PEFT trains them through + # target_parameters/ParamWrapper, which dynamically parametrizes weights + # during forward. That is not stable with plain FSDP2, so non-EP mode uses + # regular module LoRA and does not train expert parameters. + return LoraConfig( + r=LORA_R, + lora_alpha=LORA_ALPHA, + exclude_modules=['o_a_proj'], + target_modules='all-linear', + ) + + +def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): + return model.save( + name=checkpoint_name, + output_dir=OUTPUT_DIR, + adapter_name=ADAPTER_NAME, + save_optimizer=True, + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], + ) + + +def train(): + config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True) + text_config = getattr(config, 'text_config', config) + if hasattr(text_config, 'use_cache'): + text_config.use_cache = False + + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID)) + dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID) + dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope')) + dataset.encode(batched=True) + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh) + + model = TransformersModel( + model_id=MODEL_ID, + config=config, + device_mesh=device_mesh, + strategy='native_fsdp', + memory_efficient_init=True, + fsdp_config={ + 'expert_parallel': { + 'enabled': ENABLE_EP, + 'router_dtype': 'fp32', + 'keep_router_logits': False, + } + }, + ) + lora_cfg = _build_lora_config(ENABLE_EP) + model.add_adapter_to_model(ADAPTER_NAME, lora_cfg, gradient_accumulation_steps=GRAD_ACCUM_STEPS) + model.set_optimizer('AdamW', lr=LR, foreach=False) + model.set_lr_scheduler( + scheduler_cls='CosineWarmupScheduler', + num_warmup_steps=5, + num_training_steps=len(dataloader), + ) + + if RESUME_FROM_CHECKPOINT: + checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() + kwargs = {} + if ADAPTER_NAME: + kwargs['adapter_name'] = ADAPTER_NAME + progress = model.resume_from_checkpoint( + str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs) + if not IGNORE_DATA_SKIP: + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + + logger.info(get_device_placement()) + logger.info(model.get_train_configs()) + logger.info( + f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}, ' + f'enable_ep={ENABLE_EP}, output_dir={OUTPUT_DIR}') + + optimizer_group = model.optimizer_group[ADAPTER_NAME] + for batch in dataloader: + if callable(batch): + batch = batch() + model.forward_backward(inputs=batch) + model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS) + cur_step = optimizer_group.cur_step + if cur_step > 0 and cur_step % LOG_INTERVAL == 0: + metric = model.calculate_metric(is_training=True) + if callable(metric): + metric = metric() + logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}') + + final_checkpoint = save_checkpoint(model, 'checkpoint-final', dataloader) + logger.info(f'Saved final adapter to {final_checkpoint}') + + +if __name__ == '__main__': + train() diff --git a/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.sh b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.sh new file mode 100644 index 00000000..37f0862a --- /dev/null +++ b/cookbook/transformers/ep_fsdp2_lora_deepseek_v4.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash +set -euo pipefail + +# EP + FSDP2 + LoRA training for DeepSeek-V4. +# ENABLE_EP=1 trains expert LoRA with target_parameters. +# ENABLE_EP=0 runs plain FSDP2 LoRA and does not train expert parameters. + +export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}" +export NPROC_PER_NODE="${NPROC_PER_NODE:-4}" +export ENABLE_EP="${ENABLE_EP:-1}" +export BATCH_SIZE="${BATCH_SIZE:-4}" +export GRAD_ACCUM_STEPS="${GRAD_ACCUM_STEPS:-4}" +export OUTPUT_DIR="${OUTPUT_DIR:-./output_dsv4}" + +torchrun --nproc-per-node="${NPROC_PER_NODE}" \ + cookbook/transformers/ep_fsdp2_lora_deepseek_v4.py diff --git a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py new file mode 100644 index 00000000..82a0e1a0 --- /dev/null +++ b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py @@ -0,0 +1,148 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +"""EP + FSDP2 + LoRA SFT cookbook for Qwen3.5-MoE. + +Run on 4 GPUs: + torchrun --nproc-per-node=4 cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py +""" +import os +from pathlib import Path + +from peft import LoraConfig +from transformers import AutoConfig + +import twinkle +from twinkle import DeviceMesh, Platform, get_device_placement, get_logger +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.model import TransformersModel +from twinkle.preprocessor import SelfCognitionProcessor + +logger = get_logger() + +MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.6-35B-A3B') +DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition') +TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'Qwen3_5Template') +BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4')) +GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4')) +LOG_INTERVAL = GRAD_ACCUM_STEPS +LR = float(os.environ.get('LR', '1e-4')) +MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0')) +LORA_R = int(os.environ.get('LORA_R', '8')) +LORA_ALPHA = int(os.environ.get('LORA_ALPHA', '32')) +ENABLE_EP = os.environ.get('ENABLE_EP', '1') == '1' +OUTPUT_DIR = os.environ.get('OUTPUT_DIR', './output') +RESUME_FROM_CHECKPOINT = os.environ.get('RESUME_FROM_CHECKPOINT') or None +RESUME_ONLY_MODEL = os.environ.get('RESUME_ONLY_MODEL', '0') == '1' +IGNORE_DATA_SKIP = os.environ.get('IGNORE_DATA_SKIP', '0') == '1' +ADAPTER_NAME = os.environ.get('ADAPTER_NAME', 'default') + +device_mesh = DeviceMesh.from_sizes( + fsdp_size=4, + dp_size=1, + ep_size=4, + device_type=Platform.get_platform().device_prefix(), +) +twinkle.initialize(mode='local', global_device_mesh=device_mesh) + + +def _build_lora_config(enable_ep: bool): + if enable_ep: + return LoraConfig( + r=LORA_R, + lora_alpha=LORA_ALPHA, + target_modules='all-linear', + target_parameters=['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'], + ) + # Expert weights are bare nn.Parameters. PEFT trains them through + # target_parameters/ParamWrapper, which dynamically parametrizes weights + # during forward. That is not stable with plain FSDP2, so non-EP mode uses + # regular module LoRA and does not train expert parameters. + return LoraConfig( + r=LORA_R, + lora_alpha=LORA_ALPHA, + target_modules='all-linear', + ) + + +def save_checkpoint(model: TransformersModel, checkpoint_name: str, dataloader: DataLoader): + return model.save( + name=checkpoint_name, + output_dir=OUTPUT_DIR, + adapter_name=ADAPTER_NAME, + save_optimizer=True, + consumed_train_samples=dataloader.get_state()['consumed_train_samples'], + ) + + +def train(): + config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True) + text_config = getattr(config, 'text_config', config) + if hasattr(text_config, 'use_cache'): + text_config.use_cache = False + + dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID)) + try: + dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID) + except ValueError: + dataset.set_template('Qwen3_5Template', model_id=MODEL_ID) + dataset.map(SelfCognitionProcessor('twinkle', 'ModelScope')) + dataset.encode(batched=True) + dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh) + + model = TransformersModel( + model_id=MODEL_ID, + config=config, + device_mesh=device_mesh, + strategy='native_fsdp', + fsdp_config={ + 'expert_parallel': { + 'enabled': ENABLE_EP, + 'router_dtype': 'fp32', + 'keep_router_logits': False, + } + }, + ) + lora_cfg = _build_lora_config(ENABLE_EP) + model.add_adapter_to_model(ADAPTER_NAME, lora_cfg, gradient_accumulation_steps=GRAD_ACCUM_STEPS) + model.set_optimizer('AdamW', lr=LR, foreach=False) + model.set_lr_scheduler( + scheduler_cls='CosineWarmupScheduler', + num_warmup_steps=5, + num_training_steps=len(dataloader), + ) + + if RESUME_FROM_CHECKPOINT: + checkpoint_path = Path(RESUME_FROM_CHECKPOINT).expanduser().resolve() + kwargs = {} + if ADAPTER_NAME: + kwargs['adapter_name'] = ADAPTER_NAME + progress = model.resume_from_checkpoint( + str(checkpoint_path), resume_only_model=RESUME_ONLY_MODEL, **kwargs) + if not IGNORE_DATA_SKIP: + dataloader.resume_from_checkpoint(progress['consumed_train_samples']) + + logger.info(get_device_placement()) + logger.info(model.get_train_configs()) + logger.info( + f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}, ' + f'enable_ep={ENABLE_EP}, output_dir={OUTPUT_DIR}') + + optimizer_group = model.optimizer_group[ADAPTER_NAME] + for batch in dataloader: + if callable(batch): + batch = batch() + model.forward_backward(inputs=batch) + model.clip_grad_and_step(max_grad_norm=MAX_GRAD_NORM, gradient_accumulation_steps=GRAD_ACCUM_STEPS) + cur_step = optimizer_group.cur_step + if cur_step > 0 and cur_step % LOG_INTERVAL == 0: + metric = model.calculate_metric(is_training=True) + if callable(metric): + metric = metric() + logger.info(f'Current is step {cur_step} of {len(dataloader)}, metric: {metric}') + + final_checkpoint = save_checkpoint(model, 'checkpoint-final', dataloader) + logger.info(f'Saved final adapter to {final_checkpoint}') + + +if __name__ == '__main__': + train() diff --git a/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.sh b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.sh new file mode 100644 index 00000000..6a3b9574 --- /dev/null +++ b/cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash +set -euo pipefail + +# EP + FSDP2 + LoRA training for Qwen3.5-MoE. +# ENABLE_EP=1 trains expert LoRA with target_parameters. +# ENABLE_EP=0 runs plain FSDP2 LoRA and does not train expert parameters. + +export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3}" +export NPROC_PER_NODE="${NPROC_PER_NODE:-4}" +export ENABLE_EP="${ENABLE_EP:-1}" +export BATCH_SIZE="${BATCH_SIZE:-4}" +export GRAD_ACCUM_STEPS="${GRAD_ACCUM_STEPS:-4}" +export OUTPUT_DIR="${OUTPUT_DIR:-./output_qwen3_5_moe}" + +torchrun --nproc-per-node="${NPROC_PER_NODE}" \ + cookbook/transformers/ep_fsdp2_lora_qwen3_5_moe.py diff --git a/cookbook/transformers/ep_fsdp_qwen3_moe.py b/cookbook/transformers/ep_fsdp_qwen3_moe.py deleted file mode 100644 index 11855fae..00000000 --- a/cookbook/transformers/ep_fsdp_qwen3_moe.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -import os -from transformers import AutoConfig - -import twinkle -from twinkle import DeviceMesh, Platform, get_device_placement, get_logger -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta -from twinkle.model import TransformersModel -from twinkle.preprocessor import SelfCognitionProcessor - -logger = get_logger() - -MODEL_ID = os.environ.get('QWEN3_MODEL_ID', 'ms://Qwen/Qwen3.5-4B') -DATASET_ID = os.environ.get('DATASET_ID', 'ms://swift/self-cognition') -TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'Qwen3_5Template') -_num_layers_env = os.environ.get('NUM_LAYERS') -NUM_LAYERS = int(_num_layers_env) if _num_layers_env is not None else None -BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4')) -GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4')) -LR = float(os.environ.get('LR', '1e-5')) -MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0')) -KEEP_ROUTER_LOGITS = os.environ.get('KEEP_ROUTER_LOGITS', '0') == '1' - -# 8 gpus, dp=1, fsdp=8 (data parallel), ep_size=8 (expert parallel) -device_mesh = DeviceMesh.from_sizes( - fsdp_size=8, - dp_size=1, - ep_size=8, - device_type=Platform.get_platform().device_prefix(), -) - -twinkle.initialize( - mode='local', - global_device_mesh=device_mesh, -) - - -def train(): - config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True) - if NUM_LAYERS is not None and hasattr(config, 'num_hidden_layers'): - config.num_hidden_layers = NUM_LAYERS - if hasattr(config, 'use_cache'): - config.use_cache = False - - dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(1000))) - try: - dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID) - except ValueError: - dataset.set_template('Qwen3_5Template', model_id=MODEL_ID) - - dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) - dataset.encode(batched=True) - dataloader = DataLoader( - dataset=dataset, - batch_size=BATCH_SIZE, - device_mesh=device_mesh, - ) - - model = TransformersModel( - model_id=MODEL_ID, - config=config, - device_mesh=device_mesh, - fsdp_config={ - 'expert_parallel': { - 'enabled': True, - 'router_dtype': 'fp32', - 'keep_router_logits': KEEP_ROUTER_LOGITS, - } - }, - ) - # Disable foreach to avoid DTensor mixed-type errors in EP runs. - model.set_optimizer('AdamW', lr=LR, foreach=False) - model.set_lr_scheduler( - scheduler_cls='CosineWarmupScheduler', - num_warmup_steps=5, - num_training_steps=len(dataloader), - ) - - logger.info(get_device_placement()) - logger.info(model.get_train_configs()) - logger.info( - f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}, ' - f'lr={LR:.2e}, max_grad_norm={MAX_GRAD_NORM}, ' - f'keep_router_logits={KEEP_ROUTER_LOGITS}') - - for step, batch in enumerate(dataloader): - if callable(batch): - batch = batch() - model.forward_backward(inputs=batch, gradient_accumulation_steps=GRAD_ACCUM_STEPS) - model.clip_grad_and_step( - max_grad_norm=MAX_GRAD_NORM, - gradient_accumulation_steps=GRAD_ACCUM_STEPS, - ) - - is_sync_step = ((step + 1) % GRAD_ACCUM_STEPS == 0) - if is_sync_step: - optimizer_step = (step + 1) // GRAD_ACCUM_STEPS - metric = model.calculate_metric(is_training=True) - if callable(metric): - metric = metric() - logger.info(f'Current optimizer_step {optimizer_step}, metric: {metric}') - if optimizer_step > 0 and optimizer_step % 50 == 0: - model.save(name=f'checkpoint-step-{optimizer_step}', output_dir='./output') - - model.save(name='checkpoint-final', output_dir='./output') - logger.info('Saved final checkpoint to ./output/checkpoint-final') - - -if __name__ == '__main__': - train() diff --git a/cookbook/transformers/ep_fsdp_qwen3_moe.sh b/cookbook/transformers/ep_fsdp_qwen3_moe.sh deleted file mode 100644 index cfc8a7cf..00000000 --- a/cookbook/transformers/ep_fsdp_qwen3_moe.sh +++ /dev/null @@ -1,7 +0,0 @@ -# EP + FSDP2 (Transformers MoE) example. -# With expert_parallel enabled, expert parameters are sharded across the EP dimension. -# Non-expert parameters are sharded by FSDP (across world_size). -# Officially validated scope: qwen3_moe_like models (for example, Qwen3-30B-A3B). -# Other MoE models may work if their MoE blocks expose: `experts` + `gate/router` + `top_k` (or `num_experts_per_tok`). -# EP runtime constraints: `num_experts % ep_world_size == 0`. -CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 ep_fsdp_qwen3_moe.py diff --git a/cookbook/transformers/fsdp2_moe.py b/cookbook/transformers/fsdp2_moe.py deleted file mode 100644 index a2965ee5..00000000 --- a/cookbook/transformers/fsdp2_moe.py +++ /dev/null @@ -1,95 +0,0 @@ -import os -from peft import LoraConfig -from tqdm import tqdm - -import twinkle -from twinkle import DeviceMesh, Platform, get_device_placement, get_logger -from twinkle.dataloader import DataLoader -from twinkle.dataset import Dataset, DatasetMeta -from twinkle.model import TransformersModel -from twinkle.preprocessor import SelfCognitionProcessor -from twinkle.utils.framework import Torch -from twinkle.kernel import apply_npu_patch - -# Construct a device_mesh, fsdp=4, dp=2 -device_mesh = DeviceMesh.from_sizes(fsdp_size=4, dp_size=2) -# use torchrun mode -twinkle.initialize(mode='local', global_device_mesh=device_mesh) - -logger = get_logger() - - -# npu patch -if Torch.is_npu_available(): - apply_npu_patch() - - -def eval(model): - # 100 Samples - dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100))) - dataset.set_template('Template', model_id='ms://Qwen/Qwen3-30B-A3B-Instruct-2507') - dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) - dataset.encode() - dataloader = DataLoader(dataset=dataset, batch_size=4) - for step, batch in tqdm(enumerate(dataloader)): - model.forward_only(inputs=batch) - model.calculate_loss() - metrics = model.calculate_metric(is_training=False) - return metrics - - -def train(): - # 1000 samples - dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000))) - # Set template to prepare encoding - dataset.set_template('Template', model_id='ms://Qwen/Qwen3-30B-A3B-Instruct-2507') - # Preprocess the dataset to standard format - dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区')) - # Encode dataset - dataset.encode() - # Global batch size = 4, for GPUs, so 1 sample per GPU - dataloader = DataLoader(dataset=dataset, batch_size=8) - # Use a TransformersModel, transformer_cls_names_to_wrap=Qwen3MoeSparseMoeBlock to avoid hang of fsdp2 - model = TransformersModel(model_id='ms://Qwen/Qwen3-30B-A3B-Instruct-2507', fsdp_config={'transformer_cls_names_to_wrap':['Qwen3MoeSparseMoeBlock']}) - # Patch MoE model to fix the hang bug, support transformers==4.* - model.apply_patch('ms://twinkle-kit/qwen3_moe_transformers4_patch') - lora_config = LoraConfig( - r=8, - lora_alpha=32, - target_modules='all-linear' - ) - - # Add a lora to model, with name `default` - # Comment this to use full-parameter training - model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=2) - # Add Optimizer for lora `default` - model.set_optimizer(optimizer_cls='AdamW', lr=1e-4) - # Add LRScheduler for lora `default` - model.set_lr_scheduler(scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader)) - logger.info(get_device_placement()) - # Print the training config - logger.info(model.get_train_configs()) - logger.info(f'Total steps: {len(dataloader)}') - loss_metric = 99.0 - # lora: 34G * 8 - for step, batch in enumerate(dataloader): - # Do forward and backward - model.forward_backward(inputs=batch) - # Step - model.clip_grad_and_step() - if step % 20 == 0: - # Print metric - metric = model.calculate_metric(is_training=True) - logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}') - if step > 0 and step % 40 == 0: - metrics = eval(model) - logger.info(f'Eval metric: {metrics}') - metrics['step'] = step - if loss_metric > float(metrics['loss']): - model.save(f'checkpoint-{step}') - loss_metric = float(metrics['loss']) - model.save(f'last-checkpoint') - - -if __name__ == '__main__': - train() diff --git a/cookbook/transformers/fsdp2_moe.sh b/cookbook/transformers/fsdp2_moe.sh deleted file mode 100644 index c496cd1d..00000000 --- a/cookbook/transformers/fsdp2_moe.sh +++ /dev/null @@ -1 +0,0 @@ -CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 fsdp2_moe.py diff --git a/cookbook/transformers/fsdp2_moe_npu.sh b/cookbook/transformers/fsdp2_moe_npu.sh deleted file mode 100644 index 349f9d0d..00000000 --- a/cookbook/transformers/fsdp2_moe_npu.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/usr/bin/env bash - -# CANN loading -source /usr/local/Ascend/ascend-toolkit/set_env.sh - -ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 fsdp2_moe.py diff --git "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/Template.md" "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/Template.md" index 6e77b415..364275b6 100644 --- "a/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/Template.md" +++ "b/docs/source_zh/\347\273\204\344\273\266/\346\250\241\346\235\277/Template.md" @@ -52,4 +52,5 @@ class Template: 目前模板关系较为简单: - Template类:纯文本模型通用 +- DeepseekV4Template类:DeepSeek V4 使用,重写了 chat template 编码逻辑,`encode_messages` 已内置在 twinkle 中 - Qwen3_5Template类:Qwen3.5多模态模型使用 diff --git a/src/twinkle/model/base.py b/src/twinkle/model/base.py index 4df53e99..303eaf74 100644 --- a/src/twinkle/model/base.py +++ b/src/twinkle/model/base.py @@ -1,6 +1,7 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import os from abc import ABC, abstractmethod +from datetime import timedelta from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union from twinkle import Platform, torch_util @@ -164,6 +165,7 @@ def _try_init_process_group(self): 'init_method': 'env://', 'rank': Platform.get_rank(), 'world_size': Platform.get_world_size(), + 'timeout': timedelta(seconds=int(os.environ.get('TWINKLE_DIST_TIMEOUT_SECONDS', '7200'))), } if self._should_bind_device_id_for_process_group(backend): init_kwargs['device_id'] = torch.device(Platform.get_local_device()) diff --git a/src/twinkle/model/transformers/moe/expert_parallel.py b/src/twinkle/model/transformers/moe/expert_parallel.py index 282f49c0..2b6e45ea 100644 --- a/src/twinkle/model/transformers/moe/expert_parallel.py +++ b/src/twinkle/model/transformers/moe/expert_parallel.py @@ -63,7 +63,7 @@ def apply_expert_parallel( ep_rank = ep_mesh.get_local_rank() specs = [] - for block in find_moe_blocks(model): + for _, block in find_moe_blocks_with_names(model): spec = shard_experts(block, ep_world_size, ep_rank, cfg) patch_forward(block, ep_group, ep_world_size, cfg) specs.append(spec) @@ -83,8 +83,12 @@ def _merge_config(config: dict[str, Any] | None) -> ExpertParallelConfig: def find_moe_blocks(model: nn.Module) -> Iterable[nn.Module]: + return [block for _, block in find_moe_blocks_with_names(model)] + + +def find_moe_blocks_with_names(model: nn.Module) -> Iterable[tuple[str, nn.Module]]: blocks = [] - for module in model.modules(): + for name, module in model.named_modules(): experts = getattr(module, 'experts', None) if experts is None: continue @@ -92,7 +96,7 @@ def find_moe_blocks(model: nn.Module) -> Iterable[nn.Module]: continue if not _get_gate(module): continue - blocks.append(module) + blocks.append((name, module)) return blocks diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py index 8d31291a..6fb84530 100644 --- a/src/twinkle/model/transformers/strategy/accelerate.py +++ b/src/twinkle/model/transformers/strategy/accelerate.py @@ -1,10 +1,154 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from typing import Any, Dict, Literal, Optional +import os +from datetime import timedelta +from typing import Any, Dict, Literal, Mapping, Optional from twinkle import DeviceMesh from .load_context import fsdp_pretrained_load_context +def _patch_accelerate_fsdp2_load_full_state_dict(): + """Allow Accelerate FSDP2 state-dict loading to handle unsharded buffers. + + Some Transformers models keep persistent buffers in `state_dict`. FSDP2 + shards parameters as DTensors, but those buffers can remain ordinary + tensors; older Accelerate versions assume every state-dict entry has + `device_mesh` and fail on such buffers. + """ + import accelerate.utils.fsdp_utils as fsdp_utils + import torch + import torch.distributed as dist + from torch.distributed.tensor import DTensor, Partial, Replicate, Shard, distribute_tensor + + if getattr(fsdp_utils.fsdp2_load_full_state_dict, '_twinkle_patched', False): + return + + original = fsdp_utils.fsdp2_load_full_state_dict + + def patched_fsdp2_load_full_state_dict(accelerator, model, full_sd, cpu_offload=False): + meta_sharded_sd = model.state_dict() + sharded_sd = {} + + def _infer_parameter_dtype(model, param_name, empty_param): + try: + old_param = model.get_parameter_or_buffer(param_name) + except AttributeError: + # Need this for LoRA, as some params are not registered as + # parameters/buffers but still appear in the state dict. + base_param_name, local_param_name = param_name.rsplit('.', 1) + submodule = model.get_submodule(base_param_name) + old_param = getattr(submodule, local_param_name) + + is_torch_e4m3fn_available = hasattr(torch, 'float8_e4m3fn') + is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn + casting_dtype = None + if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: + casting_dtype = old_param.dtype + return old_param is not None and old_param.is_contiguous(), casting_dtype + + def _cast_and_contiguous(tensor, to_contiguous, dtype): + if dtype is not None: + tensor = tensor.to(dtype=dtype) + if to_contiguous: + tensor = tensor.contiguous() + return tensor + + def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements): + if device_mesh.device_type == 'cuda': + return distribute_tensor(full_tensor, device_mesh, placements) + + local_tensor = full_tensor + for mesh_dim, placement in enumerate(placements): + if isinstance(placement, Shard): + # All ranks already received the full tensor via broadcast. + # Split locally to avoid distribute_tensor's scatter path, + # which is fragile on some torch_npu/HCCL versions. + local_tensor = placement._shard_tensor( + local_tensor, + device_mesh, + mesh_dim, + src_data_rank=None, + ) + elif isinstance(placement, Replicate): + continue + elif isinstance(placement, Partial): + raise NotImplementedError('FSDP2 full-state loading does not support Partial placements.') + else: + raise NotImplementedError(f'Unsupported DTensor placement: {placement}') + return DTensor.from_local( + local_tensor, + device_mesh=device_mesh, + placements=placements, + run_check=False, + shape=full_tensor.shape, + stride=full_tensor.stride(), + ) + + def _load_full_value(param_name, sharded_param): + if param_name not in full_sd: + raise KeyError( + f"Parameter '{param_name}' found in sharded model state dict but missing from full state dict. " + f'Full state dict has {len(full_sd)} keys, sharded has {len(meta_sharded_sd)} keys.') + full_value = full_sd[param_name].detach() + if isinstance(full_value, DTensor): + full_value = full_value.to_local() + device = sharded_param.device_mesh.device_type if isinstance(sharded_param, DTensor) else accelerator.device + return full_value.to(device).contiguous() + + def _tensor_debug(tensor): + if isinstance(tensor, DTensor): + return (f'type=DTensor shape={tuple(tensor.size())} dtype={tensor.dtype} ' + f'placements={tensor.placements} mesh={tensor.device_mesh}') + if hasattr(tensor, 'size') and hasattr(tensor, 'dtype'): + return f'type={type(tensor).__name__} shape={tuple(tensor.size())} dtype={tensor.dtype}' + return f'type={type(tensor).__name__}' + + for param_name, sharded_param in meta_sharded_sd.items(): + if isinstance(sharded_param, DTensor): + device_mesh = sharded_param.device_mesh + placements = sharded_param.placements + if accelerator.is_main_process: + full_param = _load_full_value(param_name, sharded_param) + else: + full_param = torch.empty( + sharded_param.size(), + device=device_mesh.device_type, + dtype=sharded_param.dtype, + ) + + dist.broadcast(full_param, src=0, group=dist.group.WORLD) + sharded_tensor = _dtensor_from_replicated_full_tensor(full_param, device_mesh, placements) + to_contiguous, casting_dtype = _infer_parameter_dtype(model, param_name, full_param) + sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype) + if cpu_offload: + sharded_tensor = sharded_tensor.to('cpu') + sharded_sd[param_name] = sharded_tensor + continue + + if accelerator.is_main_process: + full_value = _load_full_value(param_name, sharded_param) + else: + full_value = torch.empty( + sharded_param.size(), + device=accelerator.device, + dtype=sharded_param.dtype, + ) + + dist.broadcast(full_value, src=0, group=dist.group.WORLD) + to_contiguous, casting_dtype = _infer_parameter_dtype(model, param_name, full_value) + full_value = _cast_and_contiguous(full_value, to_contiguous, casting_dtype) + if cpu_offload: + full_value = full_value.to('cpu') + sharded_sd[param_name] = full_value + + model.load_state_dict(sharded_sd, assign=True) + return model + + patched_fsdp2_load_full_state_dict._twinkle_patched = True + patched_fsdp2_load_full_state_dict._twinkle_original = original + fsdp_utils.fsdp2_load_full_state_dict = patched_fsdp2_load_full_state_dict + + class AccelerateStrategy: """A training strategy that uses `accelerate` to wrap models. @@ -24,6 +168,9 @@ def __init__( memory_efficient_init: bool = False, ): from accelerate import Accelerator + from accelerate.utils import InitProcessGroupKwargs + + _patch_accelerate_fsdp2_load_full_state_dict() self.device_mesh = device_mesh self.mixed_precision = mixed_precision @@ -32,6 +179,9 @@ def __init__( fsdp_plugin = self._fsdp_config_from_device_mesh(device_mesh, fsdp_config, memory_efficient_init) kwargs_handlers = [] + kwargs_handlers.append( + InitProcessGroupKwargs( + timeout=timedelta(seconds=int(os.environ.get('TWINKLE_DIST_TIMEOUT_SECONDS', '7200'))))) if ddp_config is not None: from accelerate import DistributedDataParallelKwargs ddp_config = DistributedDataParallelKwargs(**ddp_config) @@ -47,6 +197,12 @@ def __init__( def pretrained_load_context(self): return fsdp_pretrained_load_context(self._memory_efficient_init and self.device_mesh is not None) + def capture_pre_ep_state_if_needed(self, model, *, enable_ep: bool) -> None: + return + + def prepare_adapter_config(self, config_or_dir, *, enable_ep: bool): + return config_or_dir + @staticmethod def _parallelism_config_from_device_mesh(device_mesh: DeviceMesh): # TODO should test with transformers v5.0 @@ -119,11 +275,17 @@ def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Di return fsdp_plugin def wrap_model(self, model, *args): - return self.accelerator.prepare(model, *args) + result = self.accelerator.prepare(model, *args) + return result def unwrap_model(self, model): return self.accelerator.unwrap_model(model, keep_torch_compile=False) + def load_peft_weights(self, model, adapter_weights: Mapping[str, Any], adapter_name: str) -> None: + from peft.utils import set_peft_model_state_dict + + set_peft_model_state_dict(model, adapter_weights, adapter_name=adapter_name) + def _get_fsdp_plugin(self): state = self.accelerator.state return state.fsdp_plugin if hasattr(state, 'fsdp_plugin') else None diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index 9e3bbad7..ef5666ce 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -4,14 +4,21 @@ from torch import nn from torch.distributed.device_mesh import DeviceMesh as TorchDeviceMesh from torch.distributed.fsdp import fully_shard -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, Set -from twinkle.utils import DeviceMesh, Platform, torch_util +from twinkle.utils import DeviceMesh, Platform, get_logger, torch_util +from twinkle.utils.torch_utils import clone_state_dict_to_cpu from .load_context import fsdp_pretrained_load_context if TYPE_CHECKING: from torch.distributed.fsdp import MixedPrecisionPolicy +logger = get_logger() + +LORA_STATE_KEY_MARKERS = ('lora_A', 'lora_B', 'lora_embedding') +PEFT_BASE_PREFIX = 'base_model.model.' +PEFT_BASE_LAYER_SEGMENT = 'base_layer' + class NativeFSDPStrategy: @@ -29,6 +36,8 @@ def __init__(self, self.enable_ep = enable_ep self.ep_fsdp_device_mesh = self._build_ep_fsdp_device_mesh(ep_size) if enable_ep else None self._rank0_pre_ep_full_state_dict = None + self._adapter_full_state_dict = None + self._pre_ep_state_captured = False def pretrained_load_context(self): # Native FSDP loads pretrained weights via rank0 broadcast during wrap_model(). @@ -39,9 +48,55 @@ def pretrained_load_context(self): def use_rank0_pretrained_broadcast(self) -> bool: return self._memory_efficient_init and self.device_mesh is not None + def capture_pre_ep_state_if_needed(self, model, *, enable_ep: bool) -> None: + if self._pre_ep_state_captured: + return + if not (enable_ep and self.use_rank0_pretrained_broadcast()): + return + is_rank0 = dist.is_available() and dist.is_initialized() and dist.get_rank() == 0 + self.set_rank0_pre_ep_full_state_dict(clone_state_dict_to_cpu(model.state_dict()) if is_rank0 else {}) + self._pre_ep_state_captured = True + + def prepare_adapter_config(self, config_or_dir, *, enable_ep: bool): + if not enable_ep: + return config_or_dir + + from peft import LoraConfig + + if not isinstance(config_or_dir, LoraConfig): + return config_or_dir + + target_params = getattr(config_or_dir, 'target_parameters', None) or [] + if target_params: + if getattr(config_or_dir, 'use_dora', False): + raise ValueError('PEFT ParamWrapper does not support use_dora=True with target_parameters; ' + 'disable DoRA when training expert parameters.') + if getattr(config_or_dir, 'lora_bias', False): + raise ValueError('PEFT ParamWrapper does not support lora_bias=True with target_parameters.') + if float(getattr(config_or_dir, 'lora_dropout', 0.0)) > 0.0: + raise ValueError('PEFT ParamWrapper does not support lora_dropout>0 with target_parameters.') + return config_or_dir + + config_or_dir.target_parameters = ['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'] + logger.info('EP+LoRA auto-filled target_parameters with ' + "['mlp.experts.gate_up_proj', 'mlp.experts.down_proj'].") + return config_or_dir + def set_rank0_pre_ep_full_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None: self._rank0_pre_ep_full_state_dict = state_dict + def set_adapter_full_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None: + self._adapter_full_state_dict = state_dict + + def load_peft_weights(self, model, adapter_weights: Mapping[str, torch.Tensor], adapter_name: str) -> None: + from peft.utils import set_peft_model_state_dict + + fsdp_world_size = self.device_mesh.fsdp_world_size if self.device_mesh is not None else 1 + if fsdp_world_size > 1: + _load_peft_weights_for_native_fsdp2(model, adapter_weights, adapter_name, self) + else: + set_peft_model_state_dict(model, adapter_weights, adapter_name=adapter_name) + def _build_ep_fsdp_device_mesh(self, ep_size: Optional[int] = None) -> Optional[TorchDeviceMesh]: if self.device_mesh is None: return None @@ -73,12 +128,16 @@ def wrap_model(self, model, optimizer=None): original_sd = None saved_buffers = None + adapter_source_sd = {} + adapter_full_sd = {} if use_meta: is_rank0 = (dist.get_rank() == 0) if ep_enabled and self._rank0_pre_ep_full_state_dict is not None: original_sd = self._rank0_pre_ep_full_state_dict if is_rank0 else {} else: original_sd = model.state_dict() if is_rank0 else {} + adapter_source_sd = _collect_adapter_source_state(model.state_dict()) + adapter_full_sd = self._adapter_full_state_dict if is_rank0 and self._adapter_full_state_dict else {} saved_buffers = _get_non_persistent_buffers(model) if is_rank0 else {} if is_rank0: model = model.to(torch.device('meta')) @@ -144,22 +203,24 @@ def wrap_model(self, model, optimizer=None): if use_meta: device_type = self.device_mesh.device_type or 'cuda' - if ep_enabled: - _broadcast_sharded_state_dict( - model, - original_sd or {}, - device_type=device_type, - expert_shard_specs=_collect_ep_expert_shard_specs(model), - rank_to_ep_rank=_build_rank_to_ep_rank(self.ep_fsdp_device_mesh), - ) - else: - _load_rank0_full_state_dict(model, original_sd or {}) + expert_shard_specs = _collect_ep_expert_shard_specs(model) if ep_enabled else {} + rank_to_ep_rank = _build_rank_to_ep_rank(self.ep_fsdp_device_mesh) if ep_enabled else {} + _broadcast_sharded_state_dict( + model, + original_sd or {}, + device_type=device_type, + expert_shard_specs=expert_shard_specs, + rank_to_ep_rank=rank_to_ep_rank, + adapter_source_sd=adapter_source_sd, + adapter_full_sd=adapter_full_sd, + ) + self._adapter_full_state_dict = None target_device = torch.device(device_type) _broadcast_non_persistent_buffers(model, saved_buffers or {}, device=target_device) if hasattr(model, 'tie_weights'): model.tie_weights() - if ep_enabled and layer_pairs: + if ep_enabled and layer_pairs and self.fsdp_config.get('manual_prefetch', False): _setup_manual_prefetch([lp[0] for lp in layer_pairs]) if ep_enabled: @@ -272,7 +333,7 @@ def get_full_state_dict(self, model) -> dict: local_full = local_full.contiguous().to(Platform.get_local_device()) gathered = [torch.empty_like(local_full) for _ in range(ep_world_size)] dist.all_gather(gathered, local_full, group=ep_group) - local_full = torch.cat(gathered, dim=0) + local_full = torch.cat(gathered, dim=_ep_expert_state_dict_gather_dim(name)) state_dict[name] = local_full.cpu() del gathered, local_full else: @@ -297,6 +358,12 @@ def _detect_ep_expert_names(model: nn.Module) -> Set[str]: return candidate_names & actual_param_names +def _ep_expert_state_dict_gather_dim(name: str) -> int: + if 'lora_B' in name: + return 1 + return 0 + + def _build_mp_policy(mixed_precision: str) -> 'MixedPrecisionPolicy': from torch.distributed.fsdp import MixedPrecisionPolicy if mixed_precision == 'bf16': @@ -563,18 +630,152 @@ def _rebind_optimizer(optimizer: torch.optim.Optimizer, model: nn.Module) -> tor return optimizer -def _load_rank0_full_state_dict(model: nn.Module, full_sd: dict) -> None: - """Load rank0 full weights into a sharded FSDP2 model via DCP broadcast.""" - from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict +def _is_lora_state_key(name: str) -> bool: + return any(marker in name for marker in LORA_STATE_KEY_MARKERS) - set_model_state_dict( - model=model, - model_state_dict=full_sd, - options=StateDictOptions( - full_state_dict=True, - broadcast_from_rank0=True, - ), - ) + +def _strip_peft_base_prefix(name: str) -> str: + while name.startswith(PEFT_BASE_PREFIX): + name = name[len(PEFT_BASE_PREFIX):] + return name + + +def _strip_base_layer_segments(name: str) -> str: + return '.'.join(segment for segment in name.split('.') if segment != PEFT_BASE_LAYER_SEGMENT) + + +def _source_key_candidates(param_name: str) -> List[str]: + stripped_prefix = _strip_peft_base_prefix(param_name) + candidates = [ + param_name, + stripped_prefix, + _strip_base_layer_segments(param_name), + _strip_base_layer_segments(stripped_prefix), + ] + deduped = [] + for candidate in candidates: + if candidate not in deduped: + deduped.append(candidate) + return deduped + + +def _resolve_full_state_source_key(param_name: str, source_state: Mapping[str, Any]) -> str: + if _is_lora_state_key(param_name): + raise KeyError(f"LoRA parameter '{param_name}' must be loaded from adapter source state.") + + candidates = _source_key_candidates(param_name) + for candidate in candidates: + if candidate in source_state: + return candidate + raise KeyError(f"Missing source metadata for parameter '{param_name}'. " + f'Tried source keys: {", ".join(candidates)}.') + + +def _collect_adapter_source_state(state_dict: Mapping[str, Any]) -> Dict[str, Any]: + adapter_state = {} + for name, tensor in state_dict.items(): + if not _is_lora_state_key(name) or not hasattr(tensor, 'detach'): + continue + if getattr(tensor, 'is_meta', False): + continue + adapter_state[name] = tensor.detach().cpu().clone() + return adapter_state + + +def _collect_state_metadata(state_dict: Mapping[str, Any]) -> Dict[str, tuple[tuple[int, ...], Any]]: + return { + name: (tuple(tensor.shape), tensor.dtype) + for name, tensor in state_dict.items() if hasattr(tensor, 'shape') and hasattr(tensor, 'dtype') + } + + +def _get_named_child(module, name: str): + if hasattr(module, name): + return getattr(module, name) + if name.isdigit() and hasattr(module, '__getitem__'): + try: + return module[int(name)] + except (IndexError, TypeError, KeyError): + return None + return None + + +def _split_for_ep_pre_distribute(model, model_key: str, value: torch.Tensor, ep_world_size: int, + ep_rank: int) -> torch.Tensor: + """Slice saved LoRA expert weights by EP rank before DTensor/FSDP placement.""" + if ep_world_size <= 1: + return value + + parent = model + matched = False + parts = model_key.split('.') + for i, part in enumerate(parts[:-1]): + parent = _get_named_child(parent, part) + if parent is None: + return value + next_seg = parts[i + 1] if i + 1 < len(parts) else None + if getattr(parent, '_ep_patched', False) and next_seg == 'experts': + matched = True + + if not matched: + return value + if 'lora_A' in model_key: + chunk = value.size(0) // ep_world_size + return value.narrow(0, ep_rank * chunk, chunk).contiguous() + if 'lora_B' in model_key: + chunk = value.size(1) // ep_world_size + return value.narrow(1, ep_rank * chunk, chunk).contiguous() + return value + + +def _has_param_wrapper_without_base_weight(model) -> bool: + for module in model.modules(): + if not hasattr(module, 'parameter_name'): + continue + get_base_layer = getattr(module, 'get_base_layer', None) + if get_base_layer is None: + continue + base_layer = get_base_layer() + if not hasattr(base_layer, 'weight'): + return True + return False + + +def _load_peft_weights_for_native_fsdp2(model, adapter_weights: Mapping[str, torch.Tensor], adapter_name: str, + strategy: NativeFSDPStrategy) -> None: + """Load PEFT adapter weights into a native FSDP2 model, including EP expert adapters.""" + from peft.utils import set_peft_model_state_dict + from torch.distributed.tensor import DTensor, distribute_tensor + + ep_fsdp_mesh = getattr(strategy, 'ep_fsdp_device_mesh', None) + ep_world_size = ep_fsdp_mesh['ep'].size() if ep_fsdp_mesh is not None else 1 + ep_rank = ep_fsdp_mesh['ep'].get_local_rank() if ep_world_size > 1 else 0 + + model_sd = model.state_dict() + converted_weights = {} + direct_weights = {} + full_adapter_source = {} + for key, value in adapter_weights.items(): + model_key = key + if f'.{adapter_name}.weight' not in model_key: + model_key = model_key.replace('.weight', f'.{adapter_name}.weight') + if model_key in model_sd: + param = model_sd[model_key] + full_adapter_source[model_key] = value.detach().cpu().clone() + value = _split_for_ep_pre_distribute(model, model_key, value, ep_world_size, ep_rank) + if isinstance(param, DTensor) and not isinstance(value, DTensor): + value = distribute_tensor(value.to(param.device), param.device_mesh, param.placements) + direct_weights[model_key] = value + converted_weights[key] = value + + set_adapter_full_state = getattr(strategy, 'set_adapter_full_state_dict', None) + if set_adapter_full_state is not None and full_adapter_source: + set_adapter_full_state(full_adapter_source) + + if _has_param_wrapper_without_base_weight(model): + model.load_state_dict(direct_weights, strict=False) + else: + set_peft_model_state_dict(model, converted_weights, adapter_name=adapter_name) def _broadcast_sharded_state_dict( @@ -583,6 +784,8 @@ def _broadcast_sharded_state_dict( device_type: str = 'cuda', expert_shard_specs: Optional[Dict[str, Dict[str, int]]] = None, rank_to_ep_rank: Optional[Dict[int, int]] = None, + adapter_source_sd: Optional[Dict[str, torch.Tensor]] = None, + adapter_full_sd: Optional[Dict[str, torch.Tensor]] = None, ) -> None: """Broadcast rank0 full state dict and materialize local FSDP2/EP shards.""" from torch.distributed.tensor import DTensor, Partial, Replicate, Shard @@ -592,16 +795,28 @@ def _broadcast_sharded_state_dict( is_rank0 = (dist.get_rank() == 0) expert_shard_specs = expert_shard_specs or {} rank_to_ep_rank = rank_to_ep_rank or {} - + adapter_source_sd = adapter_source_sd or {} + adapter_full_sd = adapter_full_sd or {} source_metadata = None + source_keys = None + adapter_metadata = None if is_rank0: - source_metadata = { - name: (tuple(tensor.shape), tensor.dtype) - for name, tensor in full_sd.items() if hasattr(tensor, 'shape') and hasattr(tensor, 'dtype') - } - metadata_holder = [source_metadata] + source_metadata = {} + source_keys = {} + for param_name in meta_sharded_sd: + if _is_lora_state_key(param_name): + continue + source_key = _resolve_full_state_source_key(param_name, full_sd) + source_tensor = full_sd[source_key] + if hasattr(source_tensor, 'shape') and hasattr(source_tensor, 'dtype'): + source_metadata[param_name] = (tuple(source_tensor.shape), source_tensor.dtype) + source_keys[param_name] = source_key + adapter_metadata = _collect_state_metadata({**adapter_source_sd, **adapter_full_sd}) + metadata_holder = [source_metadata, source_keys, adapter_metadata] dist.broadcast_object_list(metadata_holder, src=0) source_metadata = metadata_holder[0] or {} + source_keys = metadata_holder[1] or {} + adapter_metadata = metadata_holder[2] or {} def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements): local_tensor = full_tensor @@ -613,6 +828,9 @@ def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements): mesh_dim, src_data_rank=None, ) + # _shard_tensor may return a view into the replicated full + # tensor. Clone it so the final DTensor shard does not keep + # the full parameter storage alive after loading. local_tensor = local_tensor.contiguous().clone() elif isinstance(placement, Replicate): continue @@ -629,7 +847,52 @@ def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements): stride=full_tensor.stride(), ) - def _scatter_ep_expert_tensor(param_name: str, full_tensor, sharded_param): + def _broadcast_adapter_source_tensor(full_tensor, sharded_param): + if not isinstance(sharded_param, DTensor): + dist.broadcast(full_tensor, src=0) + return full_tensor + mesh = sharded_param.device_mesh.mesh + source_rank = int(mesh.flatten()[0].item()) + dist.broadcast(full_tensor, src=source_rank, group=sharded_param.device_mesh.get_group()) + return full_tensor + + def _scatter_ep_adapter_tensor(param_name, full_tensor, sharded_param): + local_shape = tuple(sharded_param.size()) + _, source_dtype = adapter_metadata[param_name] + local_tensor = torch.empty(local_shape, device=device_type, dtype=source_dtype) + + if is_rank0: + shard_dim = _ep_expert_state_dict_gather_dim(param_name) + local_dim = local_shape[shard_dim] + world_size = dist.get_world_size() + for rank in range(world_size): + if rank not in rank_to_ep_rank: + raise RuntimeError(f'Missing EP rank mapping for global rank {rank}.') + ep_rank = rank_to_ep_rank[rank] + start = ep_rank * local_dim + chunk = full_tensor.narrow(shard_dim, start, local_dim).contiguous().to(device_type) + if rank == 0: + local_tensor.copy_(chunk) + else: + dist.send(chunk, dst=rank) + else: + dist.recv(local_tensor, src=0) + + return local_tensor + + def _get_adapter_source(param_name): + if param_name in adapter_full_sd: + adapter_tensor = adapter_full_sd[param_name] + return adapter_tensor, tuple(adapter_tensor.shape), adapter_tensor.dtype + if param_name in adapter_source_sd: + adapter_tensor = adapter_source_sd[param_name] + return adapter_tensor, tuple(adapter_tensor.shape), adapter_tensor.dtype + if param_name in adapter_metadata: + source_shape, source_dtype = adapter_metadata[param_name] + return None, source_shape, source_dtype + raise KeyError(f"Missing adapter source state for parameter '{param_name}'.") + + def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param): spec = expert_shard_specs[param_name] experts_per_rank = spec['experts_per_rank'] num_experts = spec['num_experts'] @@ -663,19 +926,39 @@ def _scatter_ep_expert_tensor(param_name: str, full_tensor, sharded_param): return local_tensor for param_name, sharded_param in meta_sharded_sd.items(): - is_ep_expert_param = param_name in expert_shard_specs - if param_name not in source_metadata: - raise KeyError(f"Missing source metadata for parameter '{param_name}'.") - source_shape, source_dtype = source_metadata[param_name] - - if is_rank0: - if param_name not in full_sd: + shape = sharded_param.size() + is_adapter_param = _is_lora_state_key(param_name) + is_ep_expert_param = param_name in expert_shard_specs and not is_adapter_param + if is_adapter_param: + adapter_tensor, source_shape, source_dtype = _get_adapter_source(param_name) + else: + if param_name not in source_metadata: + raise KeyError(f"Missing source metadata for parameter '{param_name}'.") + source_shape, source_dtype = source_metadata[param_name] + is_ep_adapter_param = ( + is_adapter_param and param_name in expert_shard_specs and tuple(source_shape) != tuple(shape)) + + if is_adapter_param: + if adapter_tensor is not None: + full_tensor = adapter_tensor.detach() + if isinstance(full_tensor, DTensor): + full_tensor = full_tensor.to_local() + if not is_ep_adapter_param: + full_tensor = full_tensor.to(device_type) + else: + full_tensor = torch.empty(source_shape, device=device_type, dtype=source_dtype) + if not is_ep_adapter_param: + full_tensor = _broadcast_adapter_source_tensor(full_tensor, sharded_param) + elif is_rank0: + source_key = source_keys[param_name] + if source_key not in full_sd: raise KeyError( f"Parameter '{param_name}' found in sharded model state dict but missing from full state dict.") - full_param = full_sd[param_name] + full_param = full_sd[source_key] full_tensor = full_param.detach() if isinstance(full_tensor, DTensor): full_tensor = full_tensor.to_local() + # EP expert params: keep on CPU to avoid OOM; move chunks lazily in _scatter_ep_expert_tensor. if not is_ep_expert_param: full_tensor = full_tensor.to(device_type) if tuple(full_tensor.shape) != tuple(source_shape) or full_tensor.dtype != source_dtype: @@ -686,14 +969,16 @@ def _scatter_ep_expert_tensor(param_name: str, full_tensor, sharded_param): full_tensor = None if is_ep_expert_param else torch.empty( source_shape, device=device_type, dtype=source_dtype) - if is_ep_expert_param: + if is_ep_adapter_param: + full_tensor = _scatter_ep_adapter_tensor(param_name, full_tensor, sharded_param) + elif is_ep_expert_param: full_tensor = _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param) else: - if tuple(sharded_param.size()) != tuple(source_shape): + if tuple(shape) != tuple(source_shape): raise RuntimeError(f"Parameter '{param_name}' shape mismatch before broadcast: " - f'sharded logical shape={tuple(sharded_param.size())}, ' - f'source shape={source_shape}.') - dist.broadcast(full_tensor, src=0) + f'sharded logical shape={tuple(shape)}, source shape={source_shape}.') + if not is_adapter_param: + dist.broadcast(full_tensor, src=0) torch_util.synchronize() if isinstance(sharded_param, DTensor): diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index c2bf8c7e..a9a80b8f 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -13,7 +13,7 @@ from copy import copy from dataclasses import dataclass, field from peft import PeftConfig, PeftModel, get_peft_model -from peft.utils import load_peft_weights, set_peft_model_state_dict +from peft.utils import load_peft_weights from safetensors.torch import save_file from torch import GradScaler from torch.optim import Adam, AdamW, Optimizer @@ -42,7 +42,6 @@ from twinkle.utils import construct_class, get_logger, selective_log_softmax, torch_util from twinkle.utils.framework import Torch from twinkle.utils.grad_clip import normalize_and_clip_grad_norm -from twinkle.utils.torch_utils import clone_state_dict_to_cpu from twinkle.utils.transformers_utils import filter_from_config_kwargs logger = get_logger() @@ -287,11 +286,7 @@ def _not_encoded(inputs): def _lazy_wrap_model(self): if not self._model_wrapped: optimizer_groups = [og for og in self.optimizer_group.values() if og.optimizer is not None] - use_rank0_broadcast = getattr(self.strategy, 'use_rank0_pretrained_broadcast', lambda: False) - set_pre_ep_state = getattr(self.strategy, 'set_rank0_pre_ep_full_state_dict', None) - if self._enable_expert_parallel and use_rank0_broadcast() and set_pre_ep_state is not None: - is_rank0 = dist.is_available() and dist.is_initialized() and dist.get_rank() == 0 - set_pre_ep_state(clone_state_dict_to_cpu(self.model.state_dict()) if is_rank0 else {}) + self.strategy.capture_pre_ep_state_if_needed(self.model, enable_ep=self._enable_expert_parallel) self._maybe_apply_expert_parallel() self._ensure_sp_strategy() if self.sp_strategy is not None: @@ -912,11 +907,17 @@ def save(self, name: Optional[str] = None, output_dir: Optional[str] = None, int # Full model save processed_state_dict = self.strategy.get_full_state_dict(self.model) else: - # LoRA adapter save - state_dict = self.get_state_dict(adapter_name=adapter_name, **kwargs) - for key, value in state_dict.items(): - key = key.replace(f'.{adapter_name}.', '.') - processed_state_dict[key] = torch_util.to_local_tensor(value).cpu() + # LoRA adapter save (EP-aware via strategy.get_full_state_dict) + full_state = self.strategy.get_full_state_dict(self.model) + adapter_marker = '.lora_' + adapter_suffix = f'.{adapter_name}.' + for key, value in full_state.items(): + if adapter_marker not in key: + continue + if adapter_suffix not in key: + continue + normalized = key.replace(adapter_suffix, '.') + processed_state_dict[normalized] = value if isinstance(model, PeftModel): if Platform.is_master(): @@ -1015,28 +1016,7 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs): model = self.strategy.unwrap_model(self.model) if isinstance(model, PeftModel): adapter_weights = load_peft_weights(checkpoint_dir, device='cpu') - - def load_peft_weights_for_fsdp2(model, adapter_weights, adapter_name='default'): - from torch.distributed.tensor import DTensor, distribute_tensor - - model_sd = model.state_dict() - converted_weights = {} - for key, value in adapter_weights.items(): - model_key = key - if f'.{adapter_name}.weight' not in model_key: - model_key = model_key.replace('.weight', f'.{adapter_name}.weight') - if model_key in model_sd: - param = model_sd[model_key] - if isinstance(param, DTensor) and not isinstance(value, DTensor): - value = distribute_tensor(value.to(param.device), param.device_mesh, param.placements) - converted_weights[key] = value - - set_peft_model_state_dict(model, converted_weights, adapter_name=adapter_name) - - if self.device_mesh.fsdp_world_size > 1: - load_peft_weights_for_fsdp2(model, adapter_weights, adapter_name=adapter_name) - else: - set_peft_model_state_dict(model, adapter_weights, adapter_name=adapter_name) + self.strategy.load_peft_weights(model, adapter_weights, adapter_name) else: raise NotImplementedError @@ -1201,6 +1181,15 @@ def calculate_metric(self, is_training, **kwargs): def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str], **kwargs): assert adapter_name, 'Use a different adapter_name, current is empty.' unwrapped_model = self.strategy.unwrap_model(self.model) + if not isinstance(config_or_dir, str): + config_or_dir = self.strategy.prepare_adapter_config( + config_or_dir, enable_ep=getattr(self, '_enable_expert_parallel', False)) + + if getattr(self, '_enable_expert_parallel', False): + self.strategy.capture_pre_ep_state_if_needed(self.model, enable_ep=self._enable_expert_parallel) + self._maybe_apply_expert_parallel() + unwrapped_model = self.strategy.unwrap_model(self.model) + if isinstance(config_or_dir, str): config_or_dir = HubOperation.download_model(config_or_dir) _adapted_model = PeftModel.from_pretrained(