diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py index a3971ef16..aa03f0f7e 100644 --- a/examples/wanvideo/model_training/train.py +++ b/examples/wanvideo/model_training/train.py @@ -3,6 +3,7 @@ from diffsynth.core.data.operators import LoadVideo, LoadAudio, ImageCropAndResize, ToAbsolutePath from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig from diffsynth.diffusion import * +from wan_dual_gpu_diffsynth import enable_wan_dual_gpu, is_dual_gpu_enabled os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -171,10 +172,53 @@ def wan_parser(): fp8_models=args.fp8_models, offload_models=args.offload_models, task=args.task, - device="cpu" if args.initialize_model_on_cpu else accelerator.device, + # Dual-GPU: force CPU load so the 14B bf16 transformer doesn't + # pre-allocate on cuda:0 before we get a chance to split it. + device="cpu" if (args.initialize_model_on_cpu or is_dual_gpu_enabled()) else accelerator.device, max_timestep_boundary=args.max_timestep_boundary, min_timestep_boundary=args.min_timestep_boundary, ) + + # Dual-GPU split: gated on DIFFSYNTH_DUAL_GPU env var. Distributes pipe.dit + # across cuda:0 and cuda:1 at the blocks midpoint. Called AFTER LoRA + # injection (which happened inside WanTrainingModule.__init__) so PEFT + # LoRA params follow the base layer's device automatically when + # block.to() runs. + if is_dual_gpu_enabled(): + for _i in range(torch.cuda.device_count()): + _free, _total = torch.cuda.mem_get_info(_i) + print(f"[wan-dual-gpu] pre-distribute cuda:{_i} free={_free/1024**3:.1f}G / {_total/1024**3:.1f}G") + # fp8 weight-only quant on CPU BEFORE distribute. Wan 2.2 14B is + # ~28 GB bf16 -- fits one 32 GB card, but with video activations + # at 480x832x49 frames plus gradient checkpointing the dual split + # gives breathing room. Filter excludes LoRA's lora_A/lora_B + # submodules -- quantizing them strips requires_grad and breaks + # backward (see also the FLUX.2 helper for the same pattern). + import gc as _gc + from torchao.quantization import quantize_, Float8WeightOnlyConfig + def _quant_filter(module, name): + if not isinstance(module, torch.nn.Linear): + return False + return "lora_A" not in name and "lora_B" not in name + quantize_(model.pipe.dit, Float8WeightOnlyConfig(), filter_fn=_quant_filter) + _gc.collect() + print("[wan-dual-gpu] fp8 weight-only quant done (LoRA params unmodified)") + enable_wan_dual_gpu(model.pipe.dit) + # Pin pipe.device to cuda:0 for input transfer (forward() calls + # transfer_data_to_device(inputs, pipe.device, ...)). Scaffold + # (patch_embedding, text_embedding) lives on cuda:0. + model.pipe.device = torch.device("cuda:0") + # Note: this Wan dual-GPU path depends on PR #1434 (flux2 dual-GPU) + # for the runner.py env-gated skip of model.to(accelerator.device). + # Without that change, runner.py will move the split DiT back to a + # single device and undo the distribute. The launcher's + # DIFFSYNTH_DUAL_GPU=true env var is read by both this helper's + # is_dual_gpu_enabled() and runner.py's gate, so no re-broadcast + # is needed here once both PRs are applied. + for _i in range(torch.cuda.device_count()): + _free, _total = torch.cuda.mem_get_info(_i) + print(f"[wan-dual-gpu] post-distribute cuda:{_i} free={_free/1024**3:.1f}G / {_total/1024**3:.1f}G") + model_logger = ModelLogger( args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt, diff --git a/examples/wanvideo/model_training/wan_dual_gpu_diffsynth.py b/examples/wanvideo/model_training/wan_dual_gpu_diffsynth.py new file mode 100644 index 000000000..af98853fb --- /dev/null +++ b/examples/wanvideo/model_training/wan_dual_gpu_diffsynth.py @@ -0,0 +1,172 @@ +"""Dual-GPU model-parallel for Wan video DiT in DiffSynth-Studio. + +Drop-in helper for DiffSynth-Studio +(https://github.com/modelscope/DiffSynth-Studio) ``examples/wanvideo/model_training/`` +that splits the ``WanModel`` transformer across two CUDA devices at the +``blocks`` midpoint. Enables Wan 2.1 / 2.2 LoRA training on pairs of +24+ GB consumer GPUs (2× RTX 3090, 2× RTX 4090, 2× RTX 5090) — useful +for the 14B variants where the fp8-quantized weights fit on a single +32 GB card but video activations push training OOM at 480×832×49 +frames + gradient checkpointing. + +Companion to the FLUX.2 dual-GPU helper at +``examples/flux2/model_training/flux2_dual_gpu_diffsynth.py``. The +shape is similar but simpler — Wan has one block type +(``DiTBlock``) instead of FLUX.2's double + single stream split, so +the helper only registers per-block hooks across one boundary. + +Env vars: + DIFFSYNTH_DUAL_GPU=true enable dual-GPU path + DIFFSYNTH_DUAL_GPU_SPLIT_AT=15 override split index + (default: num_blocks // 2) + +Usage in DiffSynth-Studio's training script +(``examples/wanvideo/model_training/train.py``):: + + from wan_dual_gpu_diffsynth import enable_wan_dual_gpu + + # ...build training module normally... + training_module = WanTrainingModule(...) + + # Activate the split (env-gated; no-op when DIFFSYNTH_DUAL_GPU is unset). + # Call AFTER LoRA injection (switch_pipe_to_training_mode) so PEFT + # has wrapped target modules. The split places the wrapped modules + # and PEFT LoRA params follow the base layer's device automatically. + enable_wan_dual_gpu(training_module.pipe.dit) + +Launch with ``--num_processes=1`` — this is *model* parallelism (both +GPUs cooperate on a single training step), not data parallelism (which +would spawn one process per GPU). +""" +from __future__ import annotations + +import os +from typing import Any + +import torch +import torch.nn as nn + + +# ─── Env-gated public surface ─────────────────────────────────────────────── + +def is_dual_gpu_enabled() -> bool: + """True iff ``DIFFSYNTH_DUAL_GPU=true`` in the environment.""" + return os.getenv("DIFFSYNTH_DUAL_GPU", "false").lower() == "true" + + +def get_split_at(num_blocks: int) -> int: + """Block split index. Override via ``DIFFSYNTH_DUAL_GPU_SPLIT_AT``.""" + override = os.getenv("DIFFSYNTH_DUAL_GPU_SPLIT_AT") + if override is not None: + return int(override) + return num_blocks // 2 + + +def enable_wan_dual_gpu(dit: nn.Module) -> nn.Module: + """Distribute the WanModel DiT across cuda:0 and cuda:1. + + When ``DIFFSYNTH_DUAL_GPU`` is unset this is a no-op pass-through. + + Call after the WanModel is loaded and after PEFT LoRA injection has + run. PEFT places LoRA params on the base layer's device automatically, + so calling this function after LoRA injection puts everything in the + right place. + + Returns the (in-place modified) dit. + """ + if not is_dual_gpu_enabled(): + return dit + + if torch.cuda.device_count() < 2: + raise RuntimeError( + f"DIFFSYNTH_DUAL_GPU=true requires ≥2 CUDA devices, found " + f"{torch.cuda.device_count()}." + ) + + num_blocks = len(dit.blocks) + split_at = get_split_at(num_blocks) + if not 0 < split_at < num_blocks: + raise RuntimeError( + f"DIFFSYNTH_DUAL_GPU_SPLIT_AT={split_at} out of range " + f"(dit has {num_blocks} blocks)." + ) + + cuda0 = torch.device("cuda:0") + cuda1 = torch.device("cuda:1") + + # Place pre-block scaffolding + first half of blocks + output head on + # cuda:0. WanModel uses self.freqs (RoPE precomputed table) as a + # plain tensor attribute, not a registered buffer/parameter -- move + # it explicitly so the .to(device) calls below don't miss it. Same + # for any optional .img_emb / .ref_conv / .control_adapter modules + # that some Wan variants add. + dit.patch_embedding.to(cuda0) + dit.text_embedding.to(cuda0) + dit.time_embedding.to(cuda0) + dit.time_projection.to(cuda0) + for block in dit.blocks[:split_at]: + block.to(cuda0) + for block in dit.blocks[split_at:]: + block.to(cuda1) + dit.head.to(cuda0) + if hasattr(dit, "img_emb") and dit.img_emb is not None: + dit.img_emb.to(cuda0) + if hasattr(dit, "ref_conv") and dit.ref_conv is not None: + dit.ref_conv.to(cuda0) + if hasattr(dit, "control_adapter") and dit.control_adapter is not None: + dit.control_adapter.to(cuda0) + + # WanModel.freqs is a plain tuple of CPU tensors (not a registered + # buffer); .to() on the module doesn't move it. Push it to cuda:0 + # since the patchify/freq-concat step happens there. + if hasattr(dit, "freqs"): + if isinstance(dit.freqs, (tuple, list)): + dit.freqs = tuple(f.to(cuda0) if torch.is_tensor(f) else f for f in dit.freqs) + elif torch.is_tensor(dit.freqs): + dit.freqs = dit.freqs.to(cuda0) + + # Per-block hook on every cuda:1 block. WanModel.forward passes + # loop-level constants (context, t_mod, freqs) positionally to each + # block; a boundary-only hook bridges only the first block's inputs, + # subsequent blocks receive the cuda:0 originals and crash with a + # device-mismatch error. + for block in dit.blocks[split_at:]: + block.register_forward_pre_hook( + _make_device_bridge_hook(cuda1), with_kwargs=True + ) + # Bridge activations back to cuda:0 for the head + unpatchify. + dit.head.register_forward_pre_hook( + _make_device_bridge_hook(cuda0), with_kwargs=True + ) + + dit._wan_dual_gpu_split_at = split_at + return dit + + +# ─── Internals ────────────────────────────────────────────────────────────── + +def _move_to_device(obj: Any, device: torch.device) -> Any: + """Recursively move tensors in nested tuple/list/dict to ``device``. + + No-op when a tensor is already on ``device`` (``Tensor.to`` is an + identity operation in that case). + """ + if torch.is_tensor(obj): + return obj.to(device) if obj.device != device else obj + if isinstance(obj, tuple): + return tuple(_move_to_device(x, device) for x in obj) + if isinstance(obj, list): + return [_move_to_device(x, device) for x in obj] + if isinstance(obj, dict): + return {k: _move_to_device(v, device) for k, v in obj.items()} + return obj + + +def _make_device_bridge_hook(target_device: torch.device): + """Forward pre-hook that moves all tensor inputs to ``target_device``.""" + def hook(module, args, kwargs): + return ( + _move_to_device(args, target_device), + _move_to_device(kwargs, target_device), + ) + return hook