diff --git a/diffsynth_engine/configs/__init__.py b/diffsynth_engine/configs/__init__.py index 9474d9c..2761cef 100644 --- a/diffsynth_engine/configs/__init__.py +++ b/diffsynth_engine/configs/__init__.py @@ -1,7 +1,9 @@ from .base import PipelineConfig from .qwen_image import QwenImagePipelineConfig +from .wan import WanPipelineConfig __all__ = [ "PipelineConfig", "QwenImagePipelineConfig", + "WanPipelineConfig", ] diff --git a/diffsynth_engine/configs/wan.py b/diffsynth_engine/configs/wan.py new file mode 100644 index 0000000..6b437a4 --- /dev/null +++ b/diffsynth_engine/configs/wan.py @@ -0,0 +1,8 @@ +from dataclasses import dataclass + +from diffsynth_engine.configs.base import PipelineConfig + + +@dataclass +class WanPipelineConfig(PipelineConfig): + pass diff --git a/diffsynth_engine/models/base.py b/diffsynth_engine/models/base.py index 43b9067..3275389 100644 --- a/diffsynth_engine/models/base.py +++ b/diffsynth_engine/models/base.py @@ -15,6 +15,12 @@ class DiffusionModel(nn.Module, ConfigMixin): config_name = CONFIG_NAME + # This is identical to diffusers' ModelMixin._keep_in_fp32_modules. + _keep_in_fp32_modules: list[str] | None = None + + # ModelMixin._keys_to_ignore_on_load_unexpected. 
+ _keys_to_ignore_on_load_unexpected: list[str] | None = None + @classmethod def from_pretrained( cls, @@ -30,8 +36,30 @@ def from_pretrained( with init_empty_weights(): model = cls.from_config(config_dict) - # load model weights - state_dict = load_model_weights(model_path, subfolder, device, dtype) + # avoids precision loss + if dtype is not None and dtype != torch.float32 and cls._keep_in_fp32_modules: + state_dict = load_model_weights(model_path, subfolder, device, dtype=None) + for key in state_dict: + if any(m in key.split(".") for m in cls._keep_in_fp32_modules): + state_dict[key] = state_dict[key].to(device=device, dtype=torch.float32) + else: + state_dict[key] = state_dict[key].to(device=device, dtype=dtype) + else: + state_dict = load_model_weights(model_path, subfolder, device, dtype) + + # Filter out unexpected keys that the model explicitly ignores + if cls._keys_to_ignore_on_load_unexpected: + keys_to_remove = [ + key for key in state_dict if any(pattern in key for pattern in cls._keys_to_ignore_on_load_unexpected) + ] + for key in keys_to_remove: + del state_dict[key] + if keys_to_remove: + logger.info( + f"Dropped {len(keys_to_remove)} unexpected key(s) matching " + f"{cls._keys_to_ignore_on_load_unexpected} from state_dict." 
+ ) + model.load_state_dict(state_dict, strict=True, assign=True) model.to(device=device) return model diff --git a/diffsynth_engine/models/wan/__init__.py b/diffsynth_engine/models/wan/__init__.py new file mode 100644 index 0000000..4436a2e --- /dev/null +++ b/diffsynth_engine/models/wan/__init__.py @@ -0,0 +1,6 @@ +from .autoencoder_kl_wan import AutoencoderKLWan +from .transformer_wan import WanTransformer3DModel +from .transformer_wan_animate import WanAnimateTransformer3DModel +from .transformer_wan_vace import WanVACETransformer3DModel + +__all__ = ["AutoencoderKLWan", "WanTransformer3DModel", "WanAnimateTransformer3DModel", "WanVACETransformer3DModel"] diff --git a/diffsynth_engine/models/wan/autoencoder_kl_wan.py b/diffsynth_engine/models/wan/autoencoder_kl_wan.py new file mode 100644 index 0000000..40f72bf --- /dev/null +++ b/diffsynth_engine/models/wan/autoencoder_kl_wan.py @@ -0,0 +1,1455 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_wan.py + +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import register_to_config +from diffusers.models.activations import get_activation +from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution +from diffusers.models.modeling_outputs import AutoencoderKLOutput + +from diffsynth_engine.models.base import DiffusionModel +from diffsynth_engine.utils import logging + +logger = logging.get_logger(__name__) + +CACHE_T = 2 + + +class AvgDown3D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert in_channels * self.factor % out_channels == 0 + self.group_size = in_channels * self.factor // out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t + pad = (0, 0, 0, 0, pad_t, 0) + x = F.pad(x, pad) + batch_size, channels, time, height, width = x.shape + x = x.view( + batch_size, + channels, + time // self.factor_t, + self.factor_t, + height // self.factor_s, + self.factor_s, + width // self.factor_s, + self.factor_s, + ) + x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() + x = x.view( + batch_size, + channels * self.factor, + time // self.factor_t, + height // self.factor_s, + width // self.factor_s, + ) + x = x.view( + batch_size, + self.out_channels, + self.group_size, + time // self.factor_t, + height // self.factor_s, + width // self.factor_s, + ) + x = x.mean(dim=2) + return x + + +class DupUp3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + 
self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert out_channels * self.factor % in_channels == 0 + self.repeats = out_channels * self.factor // in_channels + + def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor: + x = x.repeat_interleave(self.repeats, dim=1) + x = x.view( + x.size(0), + self.out_channels, + self.factor_t, + self.factor_s, + self.factor_s, + x.size(2), + x.size(3), + x.size(4), + ) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + x = x.view( + x.size(0), + self.out_channels, + x.size(2) * self.factor_t, + x.size(4) * self.factor_s, + x.size(6) * self.factor_s, + ) + if first_chunk: + x = x[:, :, self.factor_t - 1 :, :, :] + return x + + +class WanCausalConv3d(nn.Conv3d): + """ + A custom 3D causal convolution layer with feature caching support. + + This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature + caching for efficient inference. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int, int]], + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + ) -> None: + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + # Set up causal padding + self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + return super().forward(x) + + +class WanRMS_norm(nn.Module): + """ + A custom RMS normalization layer. 
+ """ + + def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None: + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class WanUpsample(nn.Upsample): + """Perform upsampling while ensuring the output tensor has the same data type as the input.""" + + def forward(self, x): + return super().forward(x.float()).type_as(x) + + +class WanResample(nn.Module): + """ + A custom resampling module for 2D and 3D data. + """ + + def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None: + super().__init__() + self.dim = dim + self.mode = mode + + if upsample_out_dim is None: + upsample_out_dim = dim // 2 + + if mode == "upsample2d": + self.resample = nn.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, upsample_out_dim, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, upsample_out_dim, 3, padding=1), + ) + self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + batch_size, channels, time, height, 
width = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + cache_x = torch.cat( + [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2 + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(batch_size, 2, channels, time, height, width) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(batch_size, channels, time * 2, height, width) + time = x.shape[2] + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width) + x = self.resample(x) + x = x.view(batch_size, time, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + +class WanResidualBlock(nn.Module): + """ + A custom residual block module. 
+ """ + + def __init__( + self, + in_dim: int, + out_dim: int, + dropout: float = 0.0, + non_linearity: str = "silu", + ) -> None: + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.nonlinearity = get_activation(non_linearity) + + self.norm1 = WanRMS_norm(in_dim, images=False) + self.conv1 = WanCausalConv3d(in_dim, out_dim, 3, padding=1) + self.norm2 = WanRMS_norm(out_dim, images=False) + self.dropout = nn.Dropout(dropout) + self.conv2 = WanCausalConv3d(out_dim, out_dim, 3, padding=1) + self.conv_shortcut = WanCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.conv_shortcut(x) + + x = self.norm1(x) + x = self.nonlinearity(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + x = self.norm2(x) + x = self.nonlinearity(x) + x = self.dropout(x) + + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv2(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv2(x) + + return x + h + + +class WanAttentionBlock(nn.Module): + """ + Causal self-attention with a single head. 
+ """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + self.norm = WanRMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + def forward(self, x): + identity = x + batch_size, channels, time, height, width = x.size() + + x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width) + x = self.norm(x) + + qkv = self.to_qkv(x) + qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1) + qkv = qkv.permute(0, 1, 3, 2).contiguous() + query, key, value = qkv.chunk(3, dim=-1) + + x = F.scaled_dot_product_attention(query, key, value) + + x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width) + + x = self.proj(x) + + x = x.view(batch_size, time, channels, height, width) + x = x.permute(0, 2, 1, 3, 4) + + return x + identity + + +class WanMidBlock(nn.Module): + """ + Middle block for WanVAE encoder and decoder. + """ + + def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1): + super().__init__() + self.dim = dim + + resnets = [WanResidualBlock(dim, dim, dropout, non_linearity)] + attentions = [] + for _ in range(num_layers): + attentions.append(WanAttentionBlock(dim)) + resnets.append(WanResidualBlock(dim, dim, dropout, non_linearity)) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + x = self.resnets[0](x, feat_cache=feat_cache, feat_idx=feat_idx) + + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + x = attn(x) + x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx) + + return x + + +class WanResidualDownBlock(nn.Module): + def __init__(self, in_dim, out_dim, dropout, num_res_blocks, temperal_downsample=False, down_flag=False): + super().__init__() + + self.avg_shortcut = AvgDown3D( + in_dim, + out_dim, + factor_t=2 if temperal_downsample else 1, + factor_s=2 if down_flag else 1, + ) 
+ + resnets = [] + for _ in range(num_res_blocks): + resnets.append(WanResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + self.resnets = nn.ModuleList(resnets) + + if down_flag: + mode = "downsample3d" if temperal_downsample else "downsample2d" + self.downsampler = WanResample(out_dim, mode=mode) + else: + self.downsampler = None + + def forward(self, x, feat_cache=None, feat_idx=[0]): + x_copy = x.clone() + for resnet in self.resnets: + x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx) + if self.downsampler is not None: + x = self.downsampler(x, feat_cache=feat_cache, feat_idx=feat_idx) + + return x + self.avg_shortcut(x_copy) + + +class WanEncoder3d(nn.Module): + """ + A 3D encoder module for WanVAE. + """ + + def __init__( + self, + in_channels: int = 3, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + non_linearity: str = "silu", + is_residual: bool = False, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.nonlinearity = get_activation(non_linearity) + + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1) + + self.down_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + if is_residual: + self.down_blocks.append( + WanResidualDownBlock( + in_dim, + out_dim, + dropout, + num_res_blocks, + temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False, + down_flag=i != len(dim_mult) - 1, + ) + ) + else: + for _ in range(num_res_blocks): + self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + self.down_blocks.append(WanAttentionBlock(out_dim)) + in_dim = out_dim + + if i != len(dim_mult) - 1: + mode = 
"downsample3d" if temperal_downsample[i] else "downsample2d" + self.down_blocks.append(WanResample(out_dim, mode=mode)) + scale /= 2.0 + + self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1) + + self.norm_out = WanRMS_norm(out_dim, images=False) + self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + for layer in self.down_blocks: + if feat_cache is not None: + x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx) + else: + x = layer(x) + + x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx) + + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + + return x + + +class WanResidualUpBlock(nn.Module): + """ + A block that handles upsampling with residual shortcut for the WanVAE decoder (Wan 2.2). 
+ """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + temperal_upsample: bool = False, + up_flag: bool = False, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + if up_flag: + self.avg_shortcut = DupUp3D( + in_dim, + out_dim, + factor_t=2 if temperal_upsample else 1, + factor_s=2, + ) + else: + self.avg_shortcut = None + + resnets = [] + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity)) + current_dim = out_dim + self.resnets = nn.ModuleList(resnets) + + if up_flag: + upsample_mode = "upsample3d" if temperal_upsample else "upsample2d" + self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim) + else: + self.upsampler = None + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + x_copy = x.clone() + + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx) + else: + x = resnet(x) + + if self.upsampler is not None: + if feat_cache is not None: + x = self.upsampler(x, feat_cache=feat_cache, feat_idx=feat_idx) + else: + x = self.upsampler(x) + + if self.avg_shortcut is not None: + x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk) + + return x + + +class WanUpBlock(nn.Module): + """ + A block that handles upsampling for the WanVAE decoder (Wan 2.1). 
+ """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_res_blocks: int, + dropout: float = 0.0, + upsample_mode: Optional[str] = None, + non_linearity: str = "silu", + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + resnets = [] + current_dim = in_dim + for _ in range(num_res_blocks + 1): + resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity)) + current_dim = out_dim + self.resnets = nn.ModuleList(resnets) + + self.upsamplers = None + if upsample_mode is not None: + self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)]) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None): + for resnet in self.resnets: + if feat_cache is not None: + x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx) + else: + x = resnet(x) + + if self.upsamplers is not None: + if feat_cache is not None: + x = self.upsamplers[0](x, feat_cache=feat_cache, feat_idx=feat_idx) + else: + x = self.upsamplers[0](x) + return x + + +class WanDecoder3d(nn.Module): + """ + A 3D decoder module for WanVAE. 
+ """ + + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + non_linearity: str = "silu", + out_channels: int = 3, + is_residual: bool = False, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + self.nonlinearity = get_activation(non_linearity) + + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + + self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1) + + self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1) + + self.up_blocks = nn.ModuleList([]) + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + if i > 0 and not is_residual: + in_dim = in_dim // 2 + + up_flag = i != len(dim_mult) - 1 + upsample_mode = None + if up_flag and temperal_upsample[i]: + upsample_mode = "upsample3d" + elif up_flag: + upsample_mode = "upsample2d" + + if is_residual: + up_block = WanResidualUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + temperal_upsample=temperal_upsample[i] if up_flag else False, + up_flag=up_flag, + non_linearity=non_linearity, + ) + else: + up_block = WanUpBlock( + in_dim=in_dim, + out_dim=out_dim, + num_res_blocks=num_res_blocks, + dropout=dropout, + upsample_mode=upsample_mode, + non_linearity=non_linearity, + ) + self.up_blocks.append(up_block) + + self.norm_out = WanRMS_norm(out_dim, images=False) + self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], 
dim=2) + x = self.conv_in(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_in(x) + + x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx) + + for up_block in self.up_blocks: + x = up_block(x, feat_cache=feat_cache, feat_idx=feat_idx, first_chunk=first_chunk) + + x = self.norm_out(x) + x = self.nonlinearity(x) + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + x = self.conv_out(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv_out(x) + return x + + +def patchify(x, patch_size): + if patch_size == 1: + return x + + if x.dim() != 5: + raise ValueError(f"Invalid input shape: {x.shape}") + + batch_size, channels, frames, height, width = x.shape + if height % patch_size != 0 or width % patch_size != 0: + raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})") + + x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size) + x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous() + x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size) + return x + + +def unpatchify(x, patch_size): + if patch_size == 1: + return x + + if x.dim() != 5: + raise ValueError(f"Invalid input shape: {x.shape}") + + batch_size, c_patches, frames, height, width = x.shape + channels = c_patches // (patch_size * patch_size) + + x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width) + x = x.permute(0, 1, 4, 5, 3, 6, 2).contiguous() + x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size) + return x + + +class AutoencoderKLWan(DiffusionModel): + """ + A VAE model with KL loss for encoding videos into latents 
and decoding latent representations into videos. + Introduced in Wan 2.1. + + Adapted from diffusers AutoencoderKLWan with the following changes: + - Inherits DiffusionModel instead of ModelMixin + AutoencoderMixin + ConfigMixin + FromOriginalModelMixin + - Removes @apply_forward_hook decorators (accelerate integration hooks) + - Removes _supports_gradient_checkpointing, _group_offload_block_modules, _skip_keys class attributes + - Adds mask-based spatial tiled encode/decode from DiffSynth-Engine-main + - Keeps all init params for backward compatibility via ConfigMixin + + Deleted methods/attributes from original Mixins: + - FromOriginalModelMixin: provided from_single_file() for loading original checkpoints + - AutoencoderMixin: provided enable_tiling/slicing convenience methods (re-implemented directly) + - ModelMixin: provided from_pretrained/save_pretrained (DiffusionModel already has from_pretrained) + - @apply_forward_hook: accelerate forward hooks for automatic device movement + - _supports_gradient_checkpointing: diffusers gradient checkpointing flag + - _group_offload_block_modules: diffusers module group offload config + - _skip_keys: AlignDeviceHook skip keys for mutable state + """ + + @register_to_config + def __init__( + self, + base_dim: int = 96, + decoder_base_dim: Optional[int] = None, + z_dim: int = 16, + dim_mult: List[int] = [1, 2, 4, 4], + num_res_blocks: int = 2, + attn_scales: List[float] = [], + temperal_downsample: List[bool] = [False, True, True], + dropout: float = 0.0, + latents_mean: List[float] = [ + -0.7571, + -0.7089, + -0.9113, + 0.1075, + -0.1745, + 0.9653, + -0.1517, + 1.5508, + 0.4134, + -0.0715, + 0.5517, + -0.3632, + -0.1922, + -0.9497, + 0.2503, + -0.2921, + ], + latents_std: List[float] = [ + 2.8184, + 1.4541, + 2.3275, + 2.6558, + 1.2196, + 1.7708, + 2.6052, + 2.0743, + 3.2687, + 2.1526, + 2.8652, + 1.5579, + 1.6382, + 1.1253, + 2.8251, + 1.9160, + ], + is_residual: bool = False, + in_channels: int = 3, + out_channels: int 
= 3, + patch_size: Optional[int] = None, + scale_factor_temporal: Optional[int] = 4, + scale_factor_spatial: Optional[int] = 8, + ) -> None: + super().__init__() + + self.z_dim = z_dim + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + if decoder_base_dim is None: + decoder_base_dim = base_dim + + self.encoder = WanEncoder3d( + in_channels=in_channels, + dim=base_dim, + z_dim=z_dim * 2, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_downsample=temperal_downsample, + dropout=dropout, + is_residual=is_residual, + ) + self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1) + self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1) + + self.decoder = WanDecoder3d( + dim=decoder_base_dim, + z_dim=z_dim, + dim_mult=dim_mult, + num_res_blocks=num_res_blocks, + attn_scales=attn_scales, + temperal_upsample=self.temperal_upsample, + dropout=dropout, + out_channels=out_channels, + is_residual=is_residual, + ) + + self.spatial_compression_ratio = scale_factor_spatial + + self.use_slicing = False + self.use_tiling = False + + # The minimal tile height and width for spatial tiling to be used + self.tile_sample_min_height = 256 + self.tile_sample_min_width = 256 + + # The minimal distance between two spatial tiles + self.tile_sample_stride_height = 192 + self.tile_sample_stride_width = 192 + + # Precompute and cache conv counts for encoder and decoder for clear_cache speedup + self._cached_conv_counts = { + "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules()) + if self.decoder is not None + else 0, + "encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules()) + if self.encoder is not None + else 0, + } + + def enable_tiling( + self, + tile_sample_min_height: Optional[int] = None, + tile_sample_min_width: Optional[int] = None, + tile_sample_stride_height: Optional[int] = None, + tile_sample_stride_width: Optional[int] = None, + ) -> 
None: + """ + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.use_tiling = True + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width + self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height + self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width + + def disable_tiling(self) -> None: + """Disable tiled VAE decoding.""" + self.use_tiling = False + + def enable_slicing(self) -> None: + """Enable sliced VAE decoding (process one batch element at a time).""" + self.use_slicing = True + + def disable_slicing(self) -> None: + """Disable sliced VAE decoding.""" + self.use_slicing = False + + def clear_cache(self): + """Clear the feature cache for causal convolutions.""" + self._conv_num = self._cached_conv_counts["decoder"] + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + self._enc_conv_num = self._cached_conv_counts["encoder"] + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + def _encode(self, x: torch.Tensor): + _, _, num_frame, height, width = x.shape + + self.clear_cache() + if self.config.patch_size is not None: + x = patchify(x, patch_size=self.config.patch_size) + + if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + return self.tiled_encode(x) + + iter_ = 1 + (num_frame - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + 
feat_idx=self._enc_conv_idx, + ) + out = torch.cat([out, out_], 2) + + enc = self.quant_conv(out) + self.clear_cache() + return enc + + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True): + _, _, num_frame, height, width = z.shape + tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + + if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height): + return self.tiled_decode(z, return_dict=return_dict) + + self.clear_cache() + x = self.post_quant_conv(z) + for i in range(num_frame): + self._conv_idx = [0] + if i == 0: + out = self.decoder( + x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True + ) + else: + out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + + if self.config.patch_size is not None: + out = unpatchify(out, patch_size=self.config.patch_size) + + out = torch.clamp(out, min=-1.0, max=1.0) + + self.clear_cache() + if not return_dict: + return (out,) + + return DecoderOutput(sample=out) + + def decode(self, z: 
torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`. + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-2], b.shape[-2], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( + y / blend_extent + ) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[-1], b.shape[-1], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( + x / blend_extent + ) + return b + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + """Encode a batch of images using a tiled encoder (diffusers blend-based approach).""" + _, _, num_frames, height, width = x.shape + encode_spatial_compression_ratio = self.spatial_compression_ratio + if self.config.patch_size is not None: + assert encode_spatial_compression_ratio % self.config.patch_size == 0 + encode_spatial_compression_ratio = self.spatial_compression_ratio // self.config.patch_size + + latent_height = height // encode_spatial_compression_ratio + latent_width = width // encode_spatial_compression_ratio + + tile_latent_min_height = self.tile_sample_min_height // encode_spatial_compression_ratio + 
tile_latent_min_width = self.tile_sample_min_width // encode_spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // encode_spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // encode_spatial_compression_ratio + + blend_height = tile_latent_min_height - tile_latent_stride_height + blend_width = tile_latent_min_width - tile_latent_stride_width + + rows = [] + for i in range(0, height, self.tile_sample_stride_height): + row = [] + for j in range(0, width, self.tile_sample_stride_width): + self.clear_cache() + time = [] + frame_range = 1 + (num_frames - 1) // 4 + for k in range(frame_range): + self._enc_conv_idx = [0] + if k == 0: + tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width] + else: + tile = x[ + :, + :, + 1 + 4 * (k - 1) : 1 + 4 * k, + i : i + self.tile_sample_min_height, + j : j + self.tile_sample_min_width, + ] + tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx) + tile = self.quant_conv(tile) + time.append(tile) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + + enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width] + return enc + + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + """Decode a batch of images using a tiled decoder (diffusers blend-based approach).""" + _, _, num_frames, height, width = z.shape + sample_height = height * self.spatial_compression_ratio + sample_width = width * self.spatial_compression_ratio + + 
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio + tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio + tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio + tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio + tile_sample_stride_height = self.tile_sample_stride_height + tile_sample_stride_width = self.tile_sample_stride_width + if self.config.patch_size is not None: + sample_height = sample_height // self.config.patch_size + sample_width = sample_width // self.config.patch_size + tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size + tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size + blend_height = self.tile_sample_min_height // self.config.patch_size - tile_sample_stride_height + blend_width = self.tile_sample_min_width // self.config.patch_size - tile_sample_stride_width + else: + blend_height = self.tile_sample_min_height - tile_sample_stride_height + blend_width = self.tile_sample_min_width - tile_sample_stride_width + + rows = [] + for i in range(0, height, tile_latent_stride_height): + row = [] + for j in range(0, width, tile_latent_stride_width): + self.clear_cache() + time = [] + for k in range(num_frames): + self._conv_idx = [0] + tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width] + tile = self.post_quant_conv(tile) + decoded = self.decoder( + tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0) + ) + time.append(decoded) + row.append(torch.cat(time, dim=2)) + rows.append(row) + self.clear_cache() + + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_height) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_width) + result_row.append(tile[:, :, :, 
:tile_sample_stride_height, :tile_sample_stride_width]) + result_rows.append(torch.cat(result_row, dim=-1)) + dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + + if self.config.patch_size is not None: + dec = unpatchify(dec, patch_size=self.config.patch_size) + + dec = torch.clamp(dec, min=-1.0, max=1.0) + + if not return_dict: + return (dec,) + return DecoderOutput(sample=dec) + + # ---- Mask-based spatial tiled encode/decode (from DiffSynth-Engine-main) ---- + + def _build_1d_mask(self, length: int, left_bound: bool, right_bound: bool, border_width: int) -> torch.Tensor: + """Build a 1D linear ramp mask for tile blending.""" + mask = torch.ones((length,)) + if not left_bound: + mask[:border_width] = (torch.arange(border_width) + 1) / border_width + if not right_bound: + mask[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,)) + return mask + + def _build_spatial_mask( + self, + data: torch.Tensor, + is_bound: Tuple[bool, bool, bool, bool], + border_width: Tuple[int, int], + ) -> torch.Tensor: + """Build a 2D spatial mask for tile blending using linear ramps at borders.""" + _, _, _, spatial_height, spatial_width = data.shape + height_mask = self._build_1d_mask(spatial_height, is_bound[0], is_bound[1], border_width[0]) + width_mask = self._build_1d_mask(spatial_width, is_bound[2], is_bound[3], border_width[1]) + + height_mask = height_mask.unsqueeze(1).expand(spatial_height, spatial_width) + width_mask = width_mask.unsqueeze(0).expand(spatial_height, spatial_width) + + mask = torch.stack([height_mask, width_mask]).min(dim=0).values + mask = mask.reshape(1, 1, 1, spatial_height, spatial_width) + return mask + + def tiled_encode_with_mask( + self, + x: torch.Tensor, + tile_size: Tuple[int, int] = (256, 256), + tile_stride: Tuple[int, int] = (192, 192), + ) -> torch.Tensor: + """ + Encode using mask-weighted spatial tiling (from DiffSynth-Engine-main). 
+ + This approach uses smooth gradient masks at tile boundaries for blending, + which can produce better results than the simple blend approach. + + Args: + x: Input tensor [B, C, T, H, W]. + tile_size: (height, width) of each tile in pixel space. + tile_stride: (height, width) stride between tiles in pixel space. + + Returns: + Encoded latent tensor. + """ + _, _, num_frames, height, width = x.shape + size_h, size_w = tile_size + stride_h, stride_w = tile_stride + + if self.config.patch_size is not None: + x = patchify(x, patch_size=self.config.patch_size) + _, _, _, height, width = x.shape + + encode_spatial_compression_ratio = self.spatial_compression_ratio + if self.config.patch_size is not None: + encode_spatial_compression_ratio = self.spatial_compression_ratio // self.config.patch_size + size_h = size_h // self.config.patch_size + size_w = size_w // self.config.patch_size + stride_h = stride_h // self.config.patch_size + stride_w = stride_w // self.config.patch_size + + # Build tile tasks + tasks = [] + for h in range(0, height, stride_h): + if h - stride_h >= 0 and h - stride_h + size_h >= height: + continue + for w in range(0, width, stride_w): + if w - stride_w >= 0 and w - stride_w + size_w >= width: + continue + tasks.append((h, h + size_h, w, w + size_w)) + + out_t = 1 + (num_frames - 1) // 4 + latent_h = height // encode_spatial_compression_ratio + latent_w = width // encode_spatial_compression_ratio + + weight = torch.zeros((1, 1, out_t, latent_h, latent_w), dtype=x.dtype, device="cpu") + values = torch.zeros((1, self.z_dim, out_t, latent_h, latent_w), dtype=x.dtype, device="cpu") + + for h, h_end, w, w_end in tasks: + tile = x[:, :, :, h:h_end, w:w_end] + + self.clear_cache() + iter_ = 1 + (num_frames - 1) // 4 + for k in range(iter_): + self._enc_conv_idx = [0] + if k == 0: + enc_out = self.encoder( + tile[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx + ) + else: + enc_chunk = self.encoder( + tile[:, :, 1 + 4 * (k - 1) 
: 1 + 4 * k, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + enc_out = torch.cat([enc_out, enc_chunk], 2) + enc_out = self.quant_conv(enc_out) + + # Build posterior and sample + enc_out_cpu = enc_out.to("cpu") + + mask = self._build_spatial_mask( + enc_out_cpu, + is_bound=(h == 0, h_end >= height, w == 0, w_end >= width), + border_width=( + (size_h - stride_h) // encode_spatial_compression_ratio, + (size_w - stride_w) // encode_spatial_compression_ratio, + ), + ).to(dtype=x.dtype, device="cpu") + + target_h = h // encode_spatial_compression_ratio + target_w = w // encode_spatial_compression_ratio + values[ + :, + :, + :, + target_h : target_h + enc_out_cpu.shape[3], + target_w : target_w + enc_out_cpu.shape[4], + ] += enc_out_cpu * mask + weight[ + :, + :, + :, + target_h : target_h + enc_out_cpu.shape[3], + target_w : target_w + enc_out_cpu.shape[4], + ] += mask + + self.clear_cache() + result = values / weight + return result.to(x.device) + + def tiled_decode_with_mask( + self, + z: torch.Tensor, + tile_size: Tuple[int, int] = (32, 32), + tile_stride: Tuple[int, int] = (24, 24), + ) -> torch.Tensor: + """ + Decode using mask-weighted spatial tiling (from DiffSynth-Engine-main). + + This approach uses smooth gradient masks at tile boundaries for blending, + which can produce better results than the simple blend approach. + + Args: + z: Input latent tensor [B, C, T, H, W]. + tile_size: (height, width) of each tile in latent space. + tile_stride: (height, width) stride between tiles in latent space. + + Returns: + Decoded video tensor, clamped to [-1, 1]. 
+ """ + _, _, num_frames, height, width = z.shape + size_h, size_w = tile_size + stride_h, stride_w = tile_stride + + upsampling_factor = self.spatial_compression_ratio + + # Build tile tasks + tasks = [] + for h in range(0, height, stride_h): + if h - stride_h >= 0 and h - stride_h + size_h >= height: + continue + for w in range(0, width, stride_w): + if w - stride_w >= 0 and w - stride_w + size_w >= width: + continue + tasks.append((h, h + size_h, w, w + size_w)) + + out_t = num_frames * 4 - 3 + out_channels = self.config.out_channels + out_h = height * upsampling_factor + out_w = width * upsampling_factor + if self.config.patch_size is not None: + out_h = out_h // self.config.patch_size + out_w = out_w // self.config.patch_size + + weight = torch.zeros((1, 1, out_t, out_h, out_w), dtype=z.dtype, device="cpu") + values = torch.zeros((1, out_channels, out_t, out_h, out_w), dtype=z.dtype, device="cpu") + + for h, h_end, w, w_end in tasks: + tile_z = z[:, :, :, h:h_end, w:w_end] + + self.clear_cache() + tile_x = self.post_quant_conv(tile_z) + for k in range(num_frames): + self._conv_idx = [0] + if k == 0: + dec_out = self.decoder( + tile_x[:, :, k : k + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + first_chunk=True, + ) + else: + dec_chunk = self.decoder( + tile_x[:, :, k : k + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + ) + dec_out = torch.cat([dec_out, dec_chunk], 2) + + if self.config.patch_size is not None: + dec_out = unpatchify(dec_out, patch_size=self.config.patch_size) + + dec_out_cpu = dec_out.to("cpu") + + effective_upsampling = upsampling_factor + if self.config.patch_size is not None: + effective_upsampling = upsampling_factor // self.config.patch_size + + mask = self._build_spatial_mask( + dec_out_cpu, + is_bound=(h == 0, h_end >= height, w == 0, w_end >= width), + border_width=( + (size_h - stride_h) * effective_upsampling, + (size_w - stride_w) * effective_upsampling, + ), + ).to(dtype=z.dtype, 
device="cpu") + + target_h = h * effective_upsampling + target_w = w * effective_upsampling + values[ + :, + :, + :, + target_h : target_h + dec_out_cpu.shape[3], + target_w : target_w + dec_out_cpu.shape[4], + ] += dec_out_cpu * mask + weight[ + :, + :, + :, + target_h : target_h + dec_out_cpu.shape[3], + target_w : target_w + dec_out_cpu.shape[4], + ] += mask + + self.clear_cache() + result = values / weight + result = result.float().clamp(-1, 1) + return result.to(z.device) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + """ + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior or use the mode. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + generator (`torch.Generator`, *optional*): + Random number generator for sampling. + """ + posterior = self.encode(sample).latent_dist + + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, return_dict=return_dict) + return dec diff --git a/diffsynth_engine/models/wan/transformer_wan.py b/diffsynth_engine/models/wan/transformer_wan.py new file mode 100644 index 0000000..009f74e --- /dev/null +++ b/diffsynth_engine/models/wan/transformer_wan.py @@ -0,0 +1,546 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_wan.py + +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +from diffusers.configuration_utils import register_to_config +from diffusers.models.attention import FeedForward +from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import FP32LayerNorm + +from diffsynth_engine.distributed.utils import sequence_parallel_shard, sequence_parallel_unshard +from diffsynth_engine.forward_context import get_forward_context +from diffsynth_engine.layers.attention import USPAttention +from diffsynth_engine.models.base import DiffusionModel +from diffsynth_engine.utils import logging + +logger = logging.get_logger(__name__) + + +def apply_wan_rotary_emb( + hidden_states: torch.Tensor, + freqs_cos: torch.Tensor, + freqs_sin: torch.Tensor, +) -> torch.Tensor: + """Apply rotary positional embeddings to hidden states in the Wan-specific format.""" + x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1) + cos = freqs_cos[..., 0::2] + sin = freqs_sin[..., 1::2] + out = torch.empty_like(hidden_states) + out[..., 0::2] = x1 * cos - x2 * sin + out[..., 1::2] = x1 * sin + x2 * cos + return out.type_as(hidden_states) + + +class WanAttention(nn.Module): + """ + Simplified attention module for Wan, using USPAttention instead of the processor pattern. 
+ """ + + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + eps: float = 1e-5, + dropout: float = 0.0, + added_kv_proj_dim: Optional[int] = None, + cross_attention_dim_head: Optional[int] = None, + ): + super().__init__() + + self.inner_dim = dim_head * heads + self.heads = heads + self.added_kv_proj_dim = added_kv_proj_dim + self.cross_attention_dim_head = cross_attention_dim_head + self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads + + self.to_q = nn.Linear(dim, self.inner_dim, bias=True) + self.to_k = nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_v = nn.Linear(dim, self.kv_inner_dim, bias=True) + # Keep as ModuleList to match diffusers weight names (to_out.0.weight, to_out.0.bias) + self.to_out = nn.ModuleList( + [ + nn.Linear(self.inner_dim, dim, bias=True), + nn.Dropout(dropout), + ] + ) + self.norm_q = nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + self.norm_k = nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True) + + self.add_k_proj = self.add_v_proj = None + if added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) + self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True) + self.norm_added_k = nn.RMSNorm(dim_head * heads, eps=eps) + + self.is_cross_attention = cross_attention_dim_head is not None + + # USPAttention for attention computation, attn_type from ForwardContext + forward_context = get_forward_context() + self.usp_attn = USPAttention( + num_heads=heads, + head_size=dim_head, + attn_type=forward_context.attn_type, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + # Handle I2V: split image and text from encoder_hidden_states + encoder_hidden_states_img = None + 
if self.add_k_proj is not None and encoder_hidden_states is not None: + # 512 is the context length of the text encoder, hardcoded for now + image_context_length = encoder_hidden_states.shape[1] - 512 + encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] + encoder_hidden_states = encoder_hidden_states[:, image_context_length:] + + # QKV projections + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + query = self.to_q(hidden_states) + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + # QK normalization + query = self.norm_q(query) + key = self.norm_k(key) + + # Reshape: [B, S, H*D] -> [B, S, H, D] + query = query.unflatten(2, (self.heads, -1)) + key = key.unflatten(2, (self.heads, -1)) + value = value.unflatten(2, (self.heads, -1)) + + # Apply rotary embeddings (only for self-attention) + if rotary_emb is not None: + query = apply_wan_rotary_emb(query, *rotary_emb) + key = apply_wan_rotary_emb(key, *rotary_emb) + + # I2V: compute attention with image encoder hidden states + hidden_states_img = None + if encoder_hidden_states_img is not None: + key_img = self.add_k_proj(encoder_hidden_states_img) + value_img = self.add_v_proj(encoder_hidden_states_img) + key_img = self.norm_added_k(key_img) + + key_img = key_img.unflatten(2, (self.heads, -1)) + value_img = value_img.unflatten(2, (self.heads, -1)) + + hidden_states_img = self.usp_attn(query, key_img, value_img) + hidden_states_img = hidden_states_img.flatten(2, 3) + hidden_states_img = hidden_states_img.type_as(query) + + # Main attention + hidden_states = self.usp_attn(query, key, value, attn_mask=attention_mask) + hidden_states = hidden_states.flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + # Combine I2V attention output + if hidden_states_img is not None: + hidden_states = hidden_states + hidden_states_img + + # Output projection + hidden_states = self.to_out[0](hidden_states) + hidden_states = 
self.to_out[1](hidden_states) + return hidden_states + + +class WanImageEmbedding(nn.Module): + def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None): + super().__init__() + + self.norm1 = FP32LayerNorm(in_features) + self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu") + self.norm2 = FP32LayerNorm(out_features) + if pos_embed_seq_len is not None: + self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features)) + else: + self.pos_embed = None + + def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor: + if self.pos_embed is not None: + batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape + encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim) + encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed + + hidden_states = self.norm1(encoder_hidden_states_image) + hidden_states = self.ff(hidden_states) + hidden_states = self.norm2(hidden_states) + return hidden_states + + +class WanTimeTextImageEmbedding(nn.Module): + def __init__( + self, + dim: int, + time_freq_dim: int, + time_proj_dim: int, + text_embed_dim: int, + image_embed_dim: Optional[int] = None, + pos_embed_seq_len: Optional[int] = None, + ): + super().__init__() + + self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim) + self.act_fn = nn.SiLU() + self.time_proj = nn.Linear(dim, time_proj_dim) + self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh") + + self.image_embedder = None + if image_embed_dim is not None: + self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len) + + def forward( + self, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + timestep_seq_len: 
Optional[int] = None, + ): + timestep = self.timesteps_proj(timestep) + if timestep_seq_len is not None: + timestep = timestep.unflatten(0, (-1, timestep_seq_len)) + + # Compute time embedding in fp32 to avoid bfloat16 precision loss + with torch.amp.autocast(device_type=timestep.device.type, dtype=torch.float32): + timestep = timestep.float() + temb = self.time_embedder(timestep) + timestep_proj = self.time_proj(self.act_fn(temb)) + timestep_proj = timestep_proj.type_as(encoder_hidden_states) + temb = temb.type_as(encoder_hidden_states) + + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + if encoder_hidden_states_image is not None: + encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image) + + return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image + + +class WanRotaryPosEmbed(nn.Module): + def __init__( + self, + attention_head_dim: int, + patch_size: Tuple[int, int, int], + max_seq_len: int, + theta: float = 10000.0, + ): + super().__init__() + + self.attention_head_dim = attention_head_dim + self.patch_size = patch_size + self.max_seq_len = max_seq_len + + h_dim = w_dim = 2 * (attention_head_dim // 6) + t_dim = attention_head_dim - h_dim - w_dim + + self.t_dim = t_dim + self.h_dim = h_dim + self.w_dim = w_dim + + # Force CPU initialization to avoid issues with meta device + with torch.device("cpu"): + freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + + freqs_cos = [] + freqs_sin = [] + + for dim in [t_dim, h_dim, w_dim]: + freq_cos, freq_sin = get_1d_rotary_pos_embed( + dim, + max_seq_len, + theta, + use_real=True, + repeat_interleave_real=True, + freqs_dtype=freqs_dtype, + ) + freqs_cos.append(freq_cos) + freqs_sin.append(freq_sin) + + self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False) + self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False) + + def forward(self, hidden_states: torch.Tensor) -> 
Tuple[torch.Tensor, torch.Tensor]: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.patch_size + ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w + + split_sizes = [self.t_dim, self.h_dim, self.w_dim] + + freqs_cos = self.freqs_cos.split(split_sizes, dim=1) + freqs_sin = self.freqs_sin.split(split_sizes, dim=1) + + freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) + freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) + + freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1) + freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1) + freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1) + + freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) + freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1) + + return freqs_cos, freqs_sin + + +class WanTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + ffn_dim: int, + num_heads: int, + qk_norm: str = "rms_norm_across_heads", + cross_attn_norm: bool = False, + eps: float = 1e-6, + added_kv_proj_dim: Optional[int] = None, + ): + super().__init__() + + # 1. Self-attention + self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False) + self.attn1 = WanAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + cross_attention_dim_head=None, + ) + + # 2. Cross-attention + self.attn2 = WanAttention( + dim=dim, + heads=num_heads, + dim_head=dim // num_heads, + eps=eps, + added_kv_proj_dim=added_kv_proj_dim, + cross_attention_dim_head=dim // num_heads, + ) + self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + + # 3. 
Feed-forward + self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate") + self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False) + + self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + rotary_emb: torch.Tensor, + ) -> torch.Tensor: + if temb.ndim == 4: + # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table.unsqueeze(0) + temb.float() + ).chunk(6, dim=2) + shift_msa = shift_msa.squeeze(2) + scale_msa = scale_msa.squeeze(2) + gate_msa = gate_msa.squeeze(2) + c_shift_msa = c_shift_msa.squeeze(2) + c_scale_msa = c_scale_msa.squeeze(2) + c_gate_msa = c_gate_msa.squeeze(2) + else: + # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table + temb.float() + ).chunk(6, dim=1) + + # 1. Self-attention + norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) + attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb) + hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states) + + # 2. Cross-attention + norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None) + hidden_states = hidden_states + attn_output + + # 3. 
class WanTransformer3DModel(DiffusionModel):
    """
    A Transformer model for video-like data used in the Wan model.
    """

    # Keep precision-sensitive submodules in fp32 (consumed by DiffusionModel.from_pretrained).
    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
    # Checkpoint keys matching these substrings are dropped on load.
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]

    @register_to_config
    def __init__(
        self,
        patch_size: Tuple[int, ...] = (1, 2, 2),
        num_attention_heads: int = 40,
        attention_head_dim: int = 128,
        in_channels: int = 16,
        out_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 13824,
        num_layers: int = 40,
        cross_attn_norm: bool = True,
        qk_norm: Optional[str] = "rms_norm_across_heads",
        eps: float = 1e-6,
        image_dim: Optional[int] = None,
        added_kv_proj_dim: Optional[int] = None,
        rope_max_seq_len: int = 1024,
        pos_embed_seq_len: Optional[int] = None,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels

        # 1. Patch & position embedding: Conv3d with stride == kernel == patch_size
        # turns the (C, T, H, W) video latent into non-overlapping patch tokens.
        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
        self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)

        # 2. Condition embeddings (time / text / optional image)
        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,  # 6 modulation vectors per block (shift/scale/gate x2)
            text_embed_dim=text_dim,
            image_embed_dim=image_dim,
            pos_embed_seq_len=pos_embed_seq_len,
        )

        # 3. Transformer blocks
        self.blocks = nn.ModuleList(
            [
                WanTransformerBlock(
                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output norm & projection back to per-patch pixel/latent values
        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """Denoising forward pass.

        Args:
            hidden_states: Noisy video latents of shape (B, C, T, H, W).
            timestep: Either (B,) or (B, seq_len) — the 2-D form is the
                Wan 2.2 ti2v per-token timestep layout.
            encoder_hidden_states: Text embeddings.
            encoder_hidden_states_image: Optional image embeddings, prepended
                to the text context when present.
            return_dict: Return a Transformer2DModelOutput instead of a tuple.
        """
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.config.patch_size
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w

        # RoPE is computed on the pre-patchified latent (it knows the patch grid).
        rotary_emb = self.rope(hidden_states)

        # (B, C, T, H, W) -> (B, S, inner_dim) token sequence
        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        # Save original sequence length for unshard
        original_seq_len = hidden_states.shape[1]

        # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
        if timestep.ndim == 2:
            ts_seq_len = timestep.shape[1]
            timestep = timestep.flatten()  # batch_size * seq_len
        else:
            ts_seq_len = None

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
        )
        # Split the 6*inner_dim projection into the 6 per-block modulation vectors.
        if ts_seq_len is not None:
            # batch_size, seq_len, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(2, (6, -1))
        else:
            # batch_size, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(1, (6, -1))

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        # Sequence parallel shard: split tokens and RoPE tables along the
        # sequence dim across SP ranks; unshard restores them after the blocks.
        rotary_emb_cos, rotary_emb_sin = rotary_emb
        hidden_states, rotary_emb_cos, rotary_emb_sin = sequence_parallel_shard(
            [hidden_states, rotary_emb_cos, rotary_emb_sin],
            seq_dims=[1, 1, 1],
        )
        rotary_emb = (rotary_emb_cos, rotary_emb_sin)

        # 4. Transformer blocks
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.blocks:
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
                )
        else:
            for block in self.blocks:
                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)

        # Sequence parallel unshard
        (hidden_states,) = sequence_parallel_unshard([hidden_states], seq_dims=[1], seq_lens=[original_seq_len])

        # 5. Output norm, projection & unpatchify
        if temb.ndim == 3:
            # batch_size, seq_len, inner_dim (wan 2.2 ti2v): per-token shift/scale
            shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
            shift = shift.squeeze(2)
            scale = scale.squeeze(2)
        else:
            # batch_size, inner_dim: one shift/scale per sample
            shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)

        shift = shift.to(hidden_states.device)
        scale = scale.to(hidden_states.device)

        # Norm in fp32, modulate, then cast back to the working dtype.
        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
        hidden_states = self.proj_out(hidden_states)

        # Unpatchify: (B, S, p_t*p_h*p_w*C_out) -> (B, C_out, T, H, W)
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
file mode 100644 index 0000000..38dd90b --- /dev/null +++ b/diffsynth_engine/models/wan/transformer_wan_animate.py @@ -0,0 +1,695 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_wan_animate.py + +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import register_to_config +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.models.normalization import FP32LayerNorm + +from diffsynth_engine.distributed.utils import sequence_parallel_shard, sequence_parallel_unshard +from diffsynth_engine.forward_context import get_forward_context +from diffsynth_engine.layers.attention import USPAttention +from diffsynth_engine.models.base import DiffusionModel +from diffsynth_engine.models.wan.transformer_wan import ( + WanRotaryPosEmbed, + WanTimeTextImageEmbedding, + WanTransformerBlock, +) +from diffsynth_engine.utils import logging + +logger = logging.get_logger(__name__) + +WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES = { + "4": 512, + "8": 512, + "16": 512, + "32": 512, + "64": 256, + "128": 128, + "256": 64, + "512": 32, + "1024": 16, +} + + +class FusedLeakyReLU(nn.Module): + """Fused LeakyRelu with scale factor and channel-wise bias.""" + 
+ def __init__(self, negative_slope: float = 0.2, scale: float = 2**0.5, bias_channels: Optional[int] = None): + super().__init__() + self.negative_slope = negative_slope + self.scale = scale + self.channels = bias_channels + + if self.channels is not None: + self.bias = nn.Parameter(torch.zeros(self.channels)) + else: + self.bias = None + + def forward(self, hidden_states: torch.Tensor, channel_dim: int = 1) -> torch.Tensor: + if self.bias is not None: + expanded_shape = [1] * hidden_states.ndim + expanded_shape[channel_dim] = self.bias.shape[0] + bias = self.bias.reshape(*expanded_shape) + hidden_states = hidden_states + bias + return F.leaky_relu(hidden_states, self.negative_slope) * self.scale + + +class MotionConv2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + bias: bool = True, + blur_kernel: Optional[Tuple[int, ...]] = None, + blur_upsample_factor: int = 1, + use_activation: bool = True, + ): + super().__init__() + self.use_activation = use_activation + self.in_channels = in_channels + + # Handle blurring (applying a FIR filter with the given kernel) if available + self.blur = False + if blur_kernel is not None: + padding_amount = (len(blur_kernel) - stride) + (kernel_size - 1) + self.blur_padding = ((padding_amount + 1) // 2, padding_amount // 2) + + kernel = torch.tensor(blur_kernel) + if kernel.ndim == 1: + kernel = kernel[None, :] * kernel[:, None] + kernel = kernel / kernel.sum() + if blur_upsample_factor > 1: + kernel = kernel * (blur_upsample_factor**2) + self.register_buffer("blur_kernel", kernel, persistent=False) + self.blur = True + + self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size)) + self.scale = 1 / math.sqrt(in_channels * kernel_size**2) + + self.stride = stride + self.padding = padding + + if bias and not self.use_activation: + self.bias = nn.Parameter(torch.zeros(out_channels)) + else: + self.bias = None + + 
if self.use_activation: + self.act_fn = FusedLeakyReLU(bias_channels=out_channels) + else: + self.act_fn = None + + def forward(self, hidden_states: torch.Tensor, channel_dim: int = 1) -> torch.Tensor: + if self.blur: + expanded_kernel = self.blur_kernel[None, None, :, :].expand(self.in_channels, 1, -1, -1) + hidden_states = hidden_states.to(expanded_kernel.dtype) + hidden_states = F.conv2d(hidden_states, expanded_kernel, padding=self.blur_padding, groups=self.in_channels) + + hidden_states = hidden_states.to(self.weight.dtype) + hidden_states = F.conv2d( + hidden_states, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding + ) + + if self.use_activation: + hidden_states = self.act_fn(hidden_states, channel_dim=channel_dim) + return hidden_states + + +class MotionLinear(nn.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + bias: bool = True, + use_activation: bool = False, + ): + super().__init__() + self.use_activation = use_activation + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim)) + self.scale = 1 / math.sqrt(in_dim) + + if bias and not self.use_activation: + self.bias = nn.Parameter(torch.zeros(out_dim)) + else: + self.bias = None + + if self.use_activation: + self.act_fn = FusedLeakyReLU(bias_channels=out_dim) + else: + self.act_fn = None + + def forward(self, input_tensor: torch.Tensor, channel_dim: int = 1) -> torch.Tensor: + output = F.linear(input_tensor, self.weight * self.scale, bias=self.bias) + if self.use_activation: + output = self.act_fn(output, channel_dim=channel_dim) + return output + + +class MotionEncoderResBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + kernel_size_skip: int = 1, + blur_kernel: Tuple[int, ...] 
= (1, 3, 3, 1), + downsample_factor: int = 2, + ): + super().__init__() + self.downsample_factor = downsample_factor + + self.conv1 = MotionConv2d( + in_channels, + in_channels, + kernel_size, + stride=1, + padding=kernel_size // 2, + use_activation=True, + ) + self.conv2 = MotionConv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=self.downsample_factor, + padding=0, + blur_kernel=blur_kernel, + use_activation=True, + ) + self.conv_skip = MotionConv2d( + in_channels, + out_channels, + kernel_size=kernel_size_skip, + stride=self.downsample_factor, + padding=0, + bias=False, + blur_kernel=blur_kernel, + use_activation=False, + ) + + def forward(self, hidden_states: torch.Tensor, channel_dim: int = 1) -> torch.Tensor: + residual = self.conv1(hidden_states, channel_dim) + residual = self.conv2(residual, channel_dim) + skip = self.conv_skip(hidden_states, channel_dim) + return (residual + skip) / math.sqrt(2) + + +class WanAnimateMotionEncoder(nn.Module): + def __init__( + self, + size: int = 512, + style_dim: int = 512, + motion_dim: int = 20, + out_dim: int = 512, + motion_blocks: int = 5, + channels: Optional[Dict[str, int]] = None, + ): + super().__init__() + self.size = size + + if channels is None: + channels = WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES + + self.conv_in = MotionConv2d(3, channels[str(size)], 1, use_activation=True) + + self.res_blocks = nn.ModuleList() + in_channels = channels[str(size)] + log_size = int(math.log(size, 2)) + for i in range(log_size, 2, -1): + out_channels = channels[str(2 ** (i - 1))] + self.res_blocks.append(MotionEncoderResBlock(in_channels, out_channels)) + in_channels = out_channels + + self.conv_out = MotionConv2d(in_channels, style_dim, 4, padding=0, bias=False, use_activation=False) + + linears = [MotionLinear(style_dim, style_dim) for _ in range(motion_blocks - 1)] + linears.append(MotionLinear(style_dim, motion_dim)) + self.motion_network = nn.ModuleList(linears) + + self.motion_synthesis_weight = 
nn.Parameter(torch.randn(out_dim, motion_dim)) + + def forward(self, face_image: torch.Tensor, channel_dim: int = 1) -> torch.Tensor: + if (face_image.shape[-2] != self.size) or (face_image.shape[-1] != self.size): + raise ValueError( + f"Face pixel values has resolution ({face_image.shape[-1]}, {face_image.shape[-2]}) but is expected" + f" to have resolution ({self.size}, {self.size})" + ) + + face_image = self.conv_in(face_image, channel_dim) + for block in self.res_blocks: + face_image = block(face_image, channel_dim) + face_image = self.conv_out(face_image, channel_dim) + motion_feat = face_image.squeeze(-1).squeeze(-1) + + for linear_layer in self.motion_network: + motion_feat = linear_layer(motion_feat, channel_dim=channel_dim) + + # Motion synthesis via Linear Motion Decomposition + weight = self.motion_synthesis_weight + 1e-8 + original_motion_dtype = motion_feat.dtype + motion_feat = motion_feat.to(torch.float32) + weight = weight.to(torch.float32) + + orthogonal_basis = torch.linalg.qr(weight)[0].to(device=motion_feat.device) + + motion_feat_diag = torch.diag_embed(motion_feat) + motion_decomposition = torch.matmul(motion_feat_diag, orthogonal_basis.T) + motion_vec = torch.sum(motion_decomposition, dim=1) + + motion_vec = motion_vec.to(dtype=original_motion_dtype) + return motion_vec + + +class WanAnimateFaceEncoder(nn.Module): + def __init__( + self, + in_dim: int, + out_dim: int, + hidden_dim: int = 1024, + num_heads: int = 4, + kernel_size: int = 3, + eps: float = 1e-6, + pad_mode: str = "replicate", + ): + super().__init__() + self.num_heads = num_heads + self.time_causal_padding = (kernel_size - 1, 0) + self.pad_mode = pad_mode + + self.act = nn.SiLU() + + self.conv1_local = nn.Conv1d(in_dim, hidden_dim * num_heads, kernel_size=kernel_size, stride=1) + self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size, stride=2) + self.conv3 = nn.Conv1d(hidden_dim, hidden_dim, kernel_size, stride=2) + + self.norm1 = nn.LayerNorm(hidden_dim, eps, 
elementwise_affine=False) + self.norm2 = nn.LayerNorm(hidden_dim, eps, elementwise_affine=False) + self.norm3 = nn.LayerNorm(hidden_dim, eps, elementwise_affine=False) + + self.out_proj = nn.Linear(hidden_dim, out_dim) + + self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, out_dim)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size = hidden_states.shape[0] + + # Reshape to channels-first to apply causal Conv1d over frame dim + hidden_states = hidden_states.permute(0, 2, 1) + hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode) + hidden_states = self.conv1_local(hidden_states) + hidden_states = hidden_states.unflatten(1, (self.num_heads, -1)).flatten(0, 1) + hidden_states = hidden_states.permute(0, 2, 1) + hidden_states = self.norm1(hidden_states) + hidden_states = self.act(hidden_states) + + hidden_states = hidden_states.permute(0, 2, 1) + hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode) + hidden_states = self.conv2(hidden_states) + hidden_states = hidden_states.permute(0, 2, 1) + hidden_states = self.norm2(hidden_states) + hidden_states = self.act(hidden_states) + + hidden_states = hidden_states.permute(0, 2, 1) + hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode) + hidden_states = self.conv3(hidden_states) + hidden_states = hidden_states.permute(0, 2, 1) + hidden_states = self.norm3(hidden_states) + hidden_states = self.act(hidden_states) + + hidden_states = self.out_proj(hidden_states) + # [B * N, T, C_out] --> [B, T, N, C_out] + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) + + padding = self.padding_tokens.repeat(batch_size, hidden_states.shape[1], 1, 1).to(device=hidden_states.device) + hidden_states = torch.cat([hidden_states, padding], dim=-2) + + return hidden_states + + +class WanAnimateFaceBlockCrossAttention(nn.Module): + """ + Temporally-aligned cross attention with the face motion signal 
in the Wan Animate Face Blocks. + + This is a simplified version that directly implements the attention logic using USPAttention, + instead of the processor pattern used in diffusers. + """ + + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + eps: float = 1e-6, + cross_attention_dim_head: Optional[int] = None, + ): + super().__init__() + self.inner_dim = dim_head * heads + self.heads = heads + self.cross_attention_dim_head = cross_attention_dim_head + self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads + + # Pre-Attention Norms for hidden_states (video latents) and encoder_hidden_states (motion vector) + self.pre_norm_q = nn.LayerNorm(dim, eps, elementwise_affine=False) + self.pre_norm_kv = nn.LayerNorm(dim, eps, elementwise_affine=False) + + # QKV and Output Projections + self.to_q = nn.Linear(dim, self.inner_dim, bias=True) + self.to_k = nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_v = nn.Linear(dim, self.kv_inner_dim, bias=True) + self.to_out = nn.Linear(self.inner_dim, dim, bias=True) + + # QK Norm (applied after reshape, so over dim_head rather than dim_head * heads) + self.norm_q = nn.RMSNorm(dim_head, eps=eps, elementwise_affine=True) + self.norm_k = nn.RMSNorm(dim_head, eps=eps, elementwise_affine=True) + + # USPAttention for attention computation + forward_context = get_forward_context() + self.usp_attn = USPAttention( + num_heads=heads, + head_size=dim_head, + attn_type=forward_context.attn_type, + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # encoder_hidden_states corresponds to the motion vec + # attention_mask corresponds to the motion mask (if any) + hidden_states = self.pre_norm_q(hidden_states) + encoder_hidden_states = self.pre_norm_kv(encoder_hidden_states) + + # B --> batch_size, T --> reduced inference segment 
class WanAnimateTransformer3DModel(DiffusionModel):
    """
    A Transformer model for video-like data used in the WanAnimate model,
    supporting character animation and replacement.
    """

    # Precision-sensitive modules kept in fp32 at load time; the QR-based
    # motion synthesis weight is included because its decomposition is
    # numerically sensitive.
    _keep_in_fp32_modules = [
        "time_embedder",
        "scale_shift_table",
        "norm1",
        "norm2",
        "norm3",
        "motion_synthesis_weight",
    ]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]

    @register_to_config
    def __init__(
        self,
        patch_size: Tuple[int] = (1, 2, 2),
        num_attention_heads: int = 40,
        attention_head_dim: int = 128,
        in_channels: Optional[int] = 36,
        latent_channels: Optional[int] = 16,
        out_channels: Optional[int] = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 13824,
        num_layers: int = 40,
        cross_attn_norm: bool = True,
        qk_norm: Optional[str] = "rms_norm_across_heads",
        eps: float = 1e-6,
        image_dim: Optional[int] = 1280,
        added_kv_proj_dim: Optional[int] = None,
        rope_max_seq_len: int = 1024,
        pos_embed_seq_len: Optional[int] = None,
        motion_encoder_channel_sizes: Optional[Dict[str, int]] = None,
        motion_encoder_size: int = 512,
        motion_style_dim: int = 512,
        motion_dim: int = 20,
        motion_encoder_dim: int = 512,
        face_encoder_hidden_dim: int = 1024,
        face_encoder_num_heads: int = 4,
        inject_face_latents_blocks: int = 5,
        motion_encoder_batch_size: int = 8,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        # Allow either only in_channels or only latent_channels to be set for convenience.
        # The invariant in_channels == 2 * latent_channels + 4 reflects the input
        # layout: noisy latents + conditioning latents + 4 extra channels.
        if in_channels is None and latent_channels is not None:
            in_channels = 2 * latent_channels + 4
        elif in_channels is not None and latent_channels is None:
            latent_channels = (in_channels - 4) // 2
        elif in_channels is not None and latent_channels is not None:
            assert in_channels == 2 * latent_channels + 4, "in_channels should be 2 * latent_channels + 4"
        else:
            raise ValueError("At least one of `in_channels` and `latent_channels` must be supplied.")
        out_channels = out_channels or latent_channels

        # 1. Patch & position embedding (separate embedder for the pose stream)
        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
        self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
        self.pose_patch_embedding = nn.Conv3d(latent_channels, inner_dim, kernel_size=patch_size, stride=patch_size)

        # 2. Condition embeddings
        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            image_embed_dim=image_dim,
            pos_embed_seq_len=pos_embed_seq_len,
        )

        # 3. Motion encoder: face pixels -> per-frame motion vectors
        self.motion_encoder = WanAnimateMotionEncoder(
            size=motion_encoder_size,
            style_dim=motion_style_dim,
            motion_dim=motion_dim,
            out_dim=motion_encoder_dim,
            channels=motion_encoder_channel_sizes,
        )

        # 4. Face encoder: motion vectors -> multi-head face tokens
        self.face_encoder = WanAnimateFaceEncoder(
            in_dim=motion_encoder_dim,
            out_dim=inner_dim,
            hidden_dim=face_encoder_hidden_dim,
            num_heads=face_encoder_num_heads,
        )

        # 5. Transformer blocks
        self.blocks = nn.ModuleList(
            [
                WanTransformerBlock(
                    dim=inner_dim,
                    ffn_dim=ffn_dim,
                    num_heads=num_attention_heads,
                    qk_norm=qk_norm,
                    cross_attn_norm=cross_attn_norm,
                    eps=eps,
                    added_kv_proj_dim=added_kv_proj_dim,
                )
                for _ in range(num_layers)
            ]
        )

        # 6. Face adapter (applied after every inject_face_latents_blocks-th block)
        self.face_adapter = nn.ModuleList(
            [
                WanAnimateFaceBlockCrossAttention(
                    dim=inner_dim,
                    heads=num_attention_heads,
                    dim_head=inner_dim // num_attention_heads,
                    eps=eps,
                    cross_attention_dim_head=inner_dim // num_attention_heads,
                )
                for _ in range(num_layers // inject_face_latents_blocks)
            ]
        )

        # 7. Output norm & projection
        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        pose_hidden_states: Optional[torch.Tensor] = None,
        face_pixel_values: Optional[torch.Tensor] = None,
        motion_encode_batch_size: Optional[int] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """
        Forward pass of Wan2.2-Animate transformer model.

        Args:
            hidden_states: Input noisy video latents of shape (B, 2C + 4, T + 1, H, W).
            timestep: The current timestep in the denoising loop.
            encoder_hidden_states: Text embeddings from the text encoder.
            encoder_hidden_states_image: CLIP visual features of the reference (character) image.
            pose_hidden_states: Pose video latents of shape (B, C, T, H, W).
            face_pixel_values: Face video in pixel space of shape (B, C', S, H', W').
            motion_encode_batch_size: Batch size for batched encoding of the face video via the motion encoder.
            return_dict: Whether to return the output as a dict or tuple.
        """
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.config.patch_size
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w

        # 1. Rotary position embedding
        rotary_emb = self.rope(hidden_states)

        # 2. Patch embedding
        hidden_states = self.patch_embedding(hidden_states)
        pose_hidden_states = self.pose_patch_embedding(pose_hidden_states)
        # Add pose embeddings to hidden states (skip the first conditioning frame)
        hidden_states[:, :, 1:] = hidden_states[:, :, 1:] + pose_hidden_states
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        # Save original sequence length for unshard
        original_seq_len = hidden_states.shape[1]

        # 3. Condition embeddings (time, text, image)
        # Wan Animate is based on Wan 2.1 and thus uses Wan 2.1's timestep logic
        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=None
        )

        # batch_size, 6, inner_dim
        timestep_proj = timestep_proj.unflatten(1, (6, -1))

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        # 4. Get motion features from the face video
        batch_size, face_channels, num_face_frames, face_height, face_width = face_pixel_values.shape
        # Rearrange from (B, C, T, H, W) to (B*T, C, H, W)
        face_pixel_values = face_pixel_values.permute(0, 2, 1, 3, 4).reshape(-1, face_channels, face_height, face_width)

        # Batched motion encoder inference to trade off speed for memory
        motion_encode_batch_size = motion_encode_batch_size or self.config.motion_encoder_batch_size
        face_batches = torch.split(face_pixel_values, motion_encode_batch_size)
        motion_vec_batches = []
        for face_batch in face_batches:
            motion_vec_batch = self.motion_encoder(face_batch)
            motion_vec_batches.append(motion_vec_batch)
        motion_vec = torch.cat(motion_vec_batches)
        motion_vec = motion_vec.view(batch_size, num_face_frames, -1)

        # Get face features from the motion vector
        motion_vec = self.face_encoder(motion_vec)

        # Add padding at the beginning (prepend zeros for the conditioning frame)
        pad_face = torch.zeros_like(motion_vec[:, :1])
        motion_vec = torch.cat([pad_face, motion_vec], dim=1)

        # 5. Sequence parallel shard (tokens and RoPE tables split along seq dim)
        rotary_emb_cos, rotary_emb_sin = rotary_emb
        hidden_states, rotary_emb_cos, rotary_emb_sin = sequence_parallel_shard(
            [hidden_states, rotary_emb_cos, rotary_emb_sin],
            seq_dims=[1, 1, 1],
        )
        rotary_emb = (rotary_emb_cos, rotary_emb_sin)

        # 6. Transformer blocks with face adapter integration
        for block_idx, block in enumerate(self.blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
                )
            else:
                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)

            # Face adapter integration: apply after every inject_face_latents_blocks-th block
            # (fires at block 0, inject, 2*inject, ...; residual addition).
            if block_idx % self.config.inject_face_latents_blocks == 0:
                face_adapter_block_idx = block_idx // self.config.inject_face_latents_blocks
                face_adapter_output = self.face_adapter[face_adapter_block_idx](hidden_states, motion_vec)
                face_adapter_output = face_adapter_output.to(device=hidden_states.device)
                hidden_states = face_adapter_output + hidden_states

        # 7. Sequence parallel unshard
        (hidden_states,) = sequence_parallel_unshard([hidden_states], seq_dims=[1], seq_lens=[original_seq_len])

        # 8. Output norm, projection & unpatchify
        shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)

        shift = shift.to(hidden_states.device)
        scale = scale.to(hidden_states.device)

        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
        hidden_states = self.proj_out(hidden_states)

        # Unpatchify back to (B, C_out, T, H, W)
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
class WanVACETransformerBlock(nn.Module):
    """
    VACE control branch Transformer block for the Wan model.

    This block mirrors the structure of WanTransformerBlock but adds input/output projection
    layers for the VACE control signal injection.
    """

    def __init__(
        self,
        dim: int,
        ffn_dim: int,
        num_heads: int,
        qk_norm: str = "rms_norm_across_heads",
        cross_attn_norm: bool = False,
        eps: float = 1e-6,
        added_kv_proj_dim: Optional[int] = None,
        apply_input_projection: bool = False,
        apply_output_projection: bool = False,
    ):
        super().__init__()

        # NOTE(review): `qk_norm` is accepted but not forwarded to WanAttention
        # below — confirm WanAttention's default matches the intended behavior.

        # 1. Input projection (only for the first VACE block, layer 0)
        self.proj_in = None
        if apply_input_projection:
            self.proj_in = nn.Linear(dim, dim)

        # 2. Self-attention
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
        self.attn1 = WanAttention(
            dim=dim,
            heads=num_heads,
            dim_head=dim // num_heads,
            eps=eps,
            cross_attention_dim_head=None,
        )

        # 3. Cross-attention
        self.attn2 = WanAttention(
            dim=dim,
            heads=num_heads,
            dim_head=dim // num_heads,
            eps=eps,
            added_kv_proj_dim=added_kv_proj_dim,
            cross_attention_dim_head=dim // num_heads,
        )
        self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()

        # 4. Feed-forward
        self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)

        # 5. Output projection (produces the conditioning signal injected into
        # the main backbone at this VACE layer)
        self.proj_out = None
        if apply_output_projection:
            self.proj_out = nn.Linear(dim, dim)

        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        control_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        rotary_emb: torch.Tensor,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Run one VACE block.

        Returns:
            (conditioning_states, control_hidden_states): the projected
            conditioning signal (None when this block has no output
            projection) and the updated control stream.
        """
        # Only the first VACE block projects its input and folds in the
        # backbone hidden states.
        if self.proj_in is not None:
            control_hidden_states = self.proj_in(control_hidden_states)
            control_hidden_states = control_hidden_states + hidden_states

        # Six modulation vectors derived from the timestep embedding.
        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
            self.scale_shift_table.to(temb.device) + temb.float()
        ).chunk(6, dim=1)

        # 1. Self-attention (norm in fp32, modulate, cast back)
        norm_hidden_states = (self.norm1(control_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(
            control_hidden_states
        )
        attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
        control_hidden_states = (control_hidden_states.float() + attn_output * gate_msa).type_as(control_hidden_states)

        # 2. Cross-attention (no modulation on this branch)
        norm_hidden_states = self.norm2(control_hidden_states.float()).type_as(control_hidden_states)
        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
        control_hidden_states = control_hidden_states + attn_output

        # 3. Feed-forward
        norm_hidden_states = (self.norm3(control_hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
            control_hidden_states
        )
        ff_output = self.ffn(norm_hidden_states)
        control_hidden_states = (control_hidden_states.float() + ff_output.float() * c_gate_msa).type_as(
            control_hidden_states
        )

        conditioning_states = None
        if self.proj_out is not None:
            conditioning_states = self.proj_out(control_hidden_states)

        return conditioning_states, control_hidden_states
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels for added key and value projections. + rope_max_seq_len (`int`, defaults to `1024`): + Maximum sequence length for rotary position embeddings. + pos_embed_seq_len (`int`, *optional*, defaults to `None`): + Positional embedding sequence length. + vace_layers (`List[int]`, defaults to `[0, 5, 10, 15, 20, 25, 30, 35]`): + Layer indices where VACE control signals are injected. + vace_in_channels (`int`, defaults to `96`): + Number of input channels for the VACE patch embedding. + """ + + _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"] + _keys_to_ignore_on_load_unexpected = ["norm_added_q"] + + @register_to_config + def __init__( + self, + patch_size: Tuple[int, ...] = (1, 2, 2), + num_attention_heads: int = 40, + attention_head_dim: int = 128, + in_channels: int = 16, + out_channels: int = 16, + text_dim: int = 4096, + freq_dim: int = 256, + ffn_dim: int = 13824, + num_layers: int = 40, + cross_attn_norm: bool = True, + qk_norm: Optional[str] = "rms_norm_across_heads", + eps: float = 1e-6, + image_dim: Optional[int] = None, + added_kv_proj_dim: Optional[int] = None, + rope_max_seq_len: int = 1024, + pos_embed_seq_len: Optional[int] = None, + vace_layers: List[int] = [0, 5, 10, 15, 20, 25, 30, 35], + vace_in_channels: int = 96, + ) -> None: + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + out_channels = out_channels or in_channels + + if max(vace_layers) >= num_layers: + raise ValueError(f"VACE layers {vace_layers} exceed the number of transformer layers {num_layers}.") + if 0 not in vace_layers: + raise ValueError("VACE layers must include layer 0.") + + # 1. 
Patch & position embedding + self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len) + self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + self.vace_patch_embedding = nn.Conv3d(vace_in_channels, inner_dim, kernel_size=patch_size, stride=patch_size) + + # 2. Condition embeddings + self.condition_embedder = WanTimeTextImageEmbedding( + dim=inner_dim, + time_freq_dim=freq_dim, + time_proj_dim=inner_dim * 6, + text_embed_dim=text_dim, + image_embed_dim=image_dim, + pos_embed_seq_len=pos_embed_seq_len, + ) + + # 3. Transformer blocks (main backbone) + self.blocks = nn.ModuleList( + [ + WanTransformerBlock( + inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim + ) + for _ in range(num_layers) + ] + ) + + # 4. VACE control blocks + self.vace_blocks = nn.ModuleList( + [ + WanVACETransformerBlock( + inner_dim, + ffn_dim, + num_attention_heads, + qk_norm, + cross_attn_norm, + eps, + added_kv_proj_dim, + apply_input_projection=i == 0, + apply_output_projection=True, + ) + for i in range(len(vace_layers)) + ] + ) + + # 5. 
Output norm & projection + self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False) + self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size)) + self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + control_hidden_states: torch.Tensor = None, + control_hidden_states_scale: torch.Tensor = None, + return_dict: bool = True, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + if control_hidden_states_scale is None: + control_hidden_states_scale = control_hidden_states.new_ones(len(self.config.vace_layers)) + control_hidden_states_scale = torch.unbind(control_hidden_states_scale) + if len(control_hidden_states_scale) != len(self.config.vace_layers): + raise ValueError( + f"Length of `control_hidden_states_scale` {len(control_hidden_states_scale)} should be " + f"equal to {len(self.config.vace_layers)}." + ) + + # 1. Rotary position embedding + rotary_emb = self.rope(hidden_states) + + # 2. 
Patch embedding + hidden_states = self.patch_embedding(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + control_hidden_states = self.vace_patch_embedding(control_hidden_states) + control_hidden_states = control_hidden_states.flatten(2).transpose(1, 2) + control_hidden_states_padding = control_hidden_states.new_zeros( + batch_size, hidden_states.size(1) - control_hidden_states.size(1), control_hidden_states.size(2) + ) + control_hidden_states = torch.cat([control_hidden_states, control_hidden_states_padding], dim=1) + + # Save original sequence length for unshard + original_seq_len = hidden_states.shape[1] + + # 3. Time embedding + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, encoder_hidden_states, encoder_hidden_states_image + ) + timestep_proj = timestep_proj.unflatten(1, (6, -1)) + + # 4. Image embedding + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + # 5. Sequence parallel shard + rotary_emb_cos, rotary_emb_sin = rotary_emb + hidden_states, control_hidden_states, rotary_emb_cos, rotary_emb_sin = sequence_parallel_shard( + [hidden_states, control_hidden_states, rotary_emb_cos, rotary_emb_sin], + seq_dims=[1, 1, 1, 1], + ) + rotary_emb = (rotary_emb_cos, rotary_emb_sin) + + # 6. VACE control blocks (prepare control hints) + control_hidden_states_list = [] + for i, block in enumerate(self.vace_blocks): + conditioning_states, control_hidden_states = block( + hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb + ) + control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i])) + control_hidden_states_list = control_hidden_states_list[::-1] + + # 7. 
Main transformer blocks with VACE injection + for i, block in enumerate(self.blocks): + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + if i in self.config.vace_layers: + control_hint, scale = control_hidden_states_list.pop() + hidden_states = hidden_states + control_hint * scale + + # 8. Sequence parallel unshard + (hidden_states,) = sequence_parallel_unshard([hidden_states], seq_dims=[1], seq_lens=[original_seq_len]) + + # 9. Output norm, projection & unpatchify + shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1) + + shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/diffsynth_engine/pipelines/wan/__init__.py b/diffsynth_engine/pipelines/wan/__init__.py new file mode 100644 index 0000000..8186b95 --- /dev/null +++ b/diffsynth_engine/pipelines/wan/__init__.py @@ -0,0 +1,11 @@ +from .pipeline_wan_animate import WanAnimatePipeline +from .pipeline_wan_i2v import WanImageToVideoPipeline +from .pipeline_wan_t2v import WanTextToVideoPipeline +from .pipeline_wan_vace import WanVACEPipeline + +__all__ = [ + "WanTextToVideoPipeline", + "WanImageToVideoPipeline", + "WanAnimatePipeline", + "WanVACEPipeline", +] diff --git a/diffsynth_engine/pipelines/wan/pipeline_wan_animate.py b/diffsynth_engine/pipelines/wan/pipeline_wan_animate.py new file mode 100644 index 0000000..d4fbf0a --- /dev/null +++ 
b/diffsynth_engine/pipelines/wan/pipeline_wan_animate.py @@ -0,0 +1,1310 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan_animate.py + +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import json +import os +from copy import deepcopy +from typing import Any, Callable, Dict, List, Optional, Union + +import PIL +import regex as re +import torch +from accelerate import init_empty_weights +from diffusers.pipelines.wan.image_processor import WanAnimateImageProcessor +from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput +from diffusers.schedulers import UniPCMultistepScheduler +from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel + +from diffsynth_engine.configs.wan import WanPipelineConfig +from diffsynth_engine.distributed.parallel_state import get_cfg_group, model_parallel_is_initialized +from diffsynth_engine.forward_context import set_forward_context +from diffsynth_engine.layers.attention import get_attn_backend +from diffsynth_engine.models.wan import AutoencoderKLWan, WanAnimateTransformer3DModel +from diffsynth_engine.pipelines.base import Pipeline +from diffsynth_engine.utils 
import logging +from diffsynth_engine.utils.load_utils import load_model_weights + +logger = logging.get_logger(__name__) + + +def basic_clean(text): + try: + import ftfy + + text = ftfy.fix_text(text) + except ImportError: + pass + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class WanAnimatePipeline(Pipeline): + r""" + Pipeline for unified character animation and replacement using Wan-Animate, + adapted for DiffSynth-Engine. + + Supports two modes: + 1. **Animation mode**: Generates a video of the character image that mimics the human motion + in the input pose and face videos. + 2. **Replacement mode**: Replaces a character in a background video with the provided character + image, using the pose and face videos for motion control. + + Args: + pipeline_config (`WanPipelineConfig`): + Configuration for the pipeline. + tokenizer (`AutoTokenizer`): + Tokenizer from T5, specifically the google/umt5-xxl variant. + text_encoder (`UMT5EncoderModel`): + T5 text encoder, specifically the google/umt5-xxl variant. + image_encoder (`CLIPVisionModel`): + CLIP vision model for encoding input images. + image_processor (`CLIPImageProcessor`): + CLIP image processor for preprocessing input images. 
+ vae (`AutoencoderKLWan`): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + scheduler (`UniPCMultistepScheduler`): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + transformer (`WanAnimateTransformer3DModel`): + Conditional Transformer to denoise the input latents. + """ + + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + pipeline_config: WanPipelineConfig, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKLWan, + scheduler: UniPCMultistepScheduler, + image_processor: CLIPImageProcessor, + image_encoder: CLIPVisionModel, + transformer: WanAnimateTransformer3DModel, + ): + super().__init__(pipeline_config) + + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.vae = vae + self.image_encoder = image_encoder + self.transformer = transformer + self.scheduler = scheduler + self.image_processor = image_processor + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if self.vae is not None else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if self.vae is not None else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + self.video_processor_for_mask = VideoProcessor( + vae_scale_factor=self.vae_scale_factor_spatial, do_normalize=False, do_convert_grayscale=True + ) + + spatial_patch_size = self.transformer.config.patch_size[-2:] if self.transformer is not None else (2, 2) + self.vae_image_processor = WanAnimateImageProcessor( + vae_scale_factor=self.vae_scale_factor_spatial, + spatial_patch_size=spatial_patch_size, + resample="bilinear", + fill_color=0, + ) + + head_dim = self.transformer.config.attention_head_dim + self.attn_backend = get_attn_backend( + head_size=head_dim, + attn_type=pipeline_config.attn_type, + ) + + @classmethod + def from_pretrained(cls, 
model_path_or_config: str | WanPipelineConfig): + """ + Load a WanAnimatePipeline from a pretrained model path or config. + + Args: + model_path_or_config: Either a string path to the model directory or a WanPipelineConfig instance. + + Returns: + WanAnimatePipeline: The loaded pipeline. + """ + if isinstance(model_path_or_config, str): + pipeline_config = WanPipelineConfig(model_path=model_path_or_config) + else: + pipeline_config = model_path_or_config + + if not os.path.exists(pipeline_config.model_path): + raise FileNotFoundError(f"Model path not found: {pipeline_config.model_path}") + + # Load transformer + transformer = cls.init_transformer(pipeline_config) + + # Load scheduler - auto-detect scheduler class from config + scheduler_config_path = os.path.join(pipeline_config.model_path, "scheduler", SCHEDULER_CONFIG_NAME) + scheduler_cls = UniPCMultistepScheduler + if os.path.exists(scheduler_config_path): + with open(scheduler_config_path, "r") as f: + scheduler_config_dict = json.load(f) + class_name = scheduler_config_dict.get("_class_name", None) + if class_name is not None: + try: + from diffusers import schedulers as schedulers_module + + scheduler_cls = getattr(schedulers_module, class_name) + logger.info(f"Using scheduler class from config: {class_name}") + except AttributeError: + logger.warning( + f"Scheduler class '{class_name}' not found in diffusers.schedulers, " + f"falling back to UniPCMultistepScheduler" + ) + scheduler = scheduler_cls.from_pretrained( + pipeline_config.model_path, + subfolder="scheduler", + ) + + # Load VAE + vae = cls.init_vae(pipeline_config) + + # Load text encoder + text_encoder = cls.init_text_encoder(pipeline_config) + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained( + pipeline_config.model_path, + subfolder="tokenizer", + ) + + # Load image encoder + image_encoder = cls.init_image_encoder(pipeline_config) + + # Load image processor + image_processor = None + image_processor_path = 
os.path.join(pipeline_config.model_path, "image_processor") + if os.path.isdir(image_processor_path): + image_processor = CLIPImageProcessor.from_pretrained( + pipeline_config.model_path, + subfolder="image_processor", + ) + logger.info("Loaded image_processor from `image_processor` subfolder.") + + return cls( + pipeline_config=pipeline_config, + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + image_encoder=image_encoder, + image_processor=image_processor, + transformer=transformer, + scheduler=scheduler, + ) + + @staticmethod + def init_transformer( + pipeline_config: WanPipelineConfig, empty_weights: bool = False, subfolder: str = "transformer" + ): + logger.info(f"Initializing transformer from subfolder={subfolder}...") + with set_forward_context(attn_type=pipeline_config.attn_type): + if empty_weights: + with init_empty_weights(): + config_dict = WanAnimateTransformer3DModel.load_config( + pipeline_config.model_path, + subfolder=subfolder, + local_files_only=True, + ) + model = WanAnimateTransformer3DModel.from_config(config_dict) + else: + model = WanAnimateTransformer3DModel.from_pretrained( + pipeline_config.model_path, + subfolder=subfolder, + device=pipeline_config.device, + dtype=pipeline_config.model_dtype, + ) + return model + + @staticmethod + def init_text_encoder(pipeline_config: WanPipelineConfig, empty_weights: bool = False): + logger.info("Initializing text encoder...") + if empty_weights: + with init_empty_weights(): + model = UMT5EncoderModel.from_pretrained( + pipeline_config.model_path, + subfolder="text_encoder", + local_files_only=True, + ) + return model + + state_dict = load_model_weights( + pipeline_config.model_path, + subfolder="text_encoder", + device=pipeline_config.device, + dtype=pipeline_config.text_encoder_dtype, + ) + with init_empty_weights(): + model = UMT5EncoderModel.from_pretrained( + pipeline_config.model_path, + subfolder="text_encoder", + local_files_only=True, + ) + + if "shared.weight" in state_dict and 
"encoder.embed_tokens.weight" not in state_dict: + state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"] + + model.load_state_dict(state_dict, strict=False, assign=True) + model.to(device=pipeline_config.device) + return model + + @staticmethod + def init_vae(pipeline_config: WanPipelineConfig, empty_weights: bool = False): + logger.info("Initializing VAE...") + if empty_weights: + with init_empty_weights(): + config_dict = AutoencoderKLWan.load_config( + pipeline_config.model_path, + subfolder="vae", + local_files_only=True, + ) + model = AutoencoderKLWan.from_config(config_dict) + return model + + model = AutoencoderKLWan.from_pretrained( + pipeline_config.model_path, + subfolder="vae", + device=pipeline_config.device, + dtype=pipeline_config.vae_dtype, + ) + return model + + @staticmethod + def init_image_encoder(pipeline_config: WanPipelineConfig, empty_weights: bool = False): + logger.info("Initializing image encoder...") + image_encoder_path = os.path.join(pipeline_config.model_path, "image_encoder") + if not os.path.isdir(image_encoder_path): + logger.warning(f"image_encoder subfolder not found in {pipeline_config.model_path}. 
Skipping.") + return None + + if empty_weights: + with init_empty_weights(): + model = CLIPVisionModel.from_pretrained( + pipeline_config.model_path, + subfolder="image_encoder", + local_files_only=True, + ) + return model + + model = CLIPVisionModel.from_pretrained( + pipeline_config.model_path, + subfolder="image_encoder", + dtype=torch.float32, + ) + model.to(device=pipeline_config.device) + return model + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self.device + dtype = dtype or self.pipeline_config.text_encoder_dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # Duplicate text embeddings for each generation per prompt + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_image( + self, + image, + device: Optional[torch.device] = None, + ): + device = device or self.device + image = 
self.image_processor(images=image, return_tensors="pt").to(device) + image_embeds = self.image_encoder(**image, output_hidden_states=True) + return image_embeds.hidden_states[-2] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. + max_sequence_length (`int`, *optional*, defaults to 226): + Maximum sequence length for the text encoder. 
+ device (`torch.device`, *optional*): + torch device + dtype (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self.device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + image, + pose_video, + face_video, + background_video, + mask_video, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + image_embeds=None, + callback_on_step_end_tensor_inputs=None, + mode=None, + prev_segment_conditioning_frames=None, + ): + if image is not None and image_embeds is not None: + raise ValueError( + f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. 
Please make sure to" + " only forward one of the two." + ) + if image is None and image_embeds is None: + raise ValueError( + "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined." + ) + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") + if pose_video is None: + raise ValueError("Provide `pose_video`. Cannot leave `pose_video` undefined.") + if face_video is None: + raise ValueError("Provide `face_video`. Cannot leave `face_video` undefined.") + if not isinstance(pose_video, list) or not isinstance(face_video, list): + raise ValueError("`pose_video` and `face_video` must be lists of PIL images.") + if len(pose_video) == 0 or len(face_video) == 0: + raise ValueError("`pose_video` and `face_video` must contain at least one frame.") + if mode == "replace" and (background_video is None or mask_video is None): + raise ValueError( + "Provide `background_video` and `mask_video`. Cannot leave both `background_video` and `mask_video`" + " undefined when mode is `replace`." 
+ ) + if mode == "replace" and (not isinstance(background_video, list) or not isinstance(mask_video, list)): + raise ValueError("`background_video` and `mask_video` must be lists of PIL images when mode is `replace`.") + + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found" + f" {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if mode is not None and (not isinstance(mode, str) or mode not in ("animate", "replace")): + raise ValueError( + f"`mode` has to be of type `str` and in ('animate', 'replace') but its type is {type(mode)} and value is {mode}" + ) + + if prev_segment_conditioning_frames is not None and ( + not isinstance(prev_segment_conditioning_frames, int) or prev_segment_conditioning_frames not in (1, 5) + ): + raise ValueError( + f"`prev_segment_conditioning_frames` has to be of type `int` and 1 or 5 but its type is" + f" {type(prev_segment_conditioning_frames)} and value is {prev_segment_conditioning_frames}" + ) + + def get_i2v_mask( + self, + batch_size: int, + latent_t: int, + latent_h: int, + latent_w: int, + mask_len: int = 1, + mask_pixel_values: Optional[torch.Tensor] = None, + dtype: Optional[torch.dtype] = None, + device: Union[str, torch.device] = "cuda", + ) -> torch.Tensor: + # mask_pixel_values shape (if supplied): [B, C = 1, T, latent_h, latent_w] + if mask_pixel_values is None: + mask_lat_size = torch.zeros( + batch_size, 1, (latent_t - 1) * 4 + 1, latent_h, latent_w, dtype=dtype, device=device + ) + else: + mask_lat_size = mask_pixel_values.clone().to(device=device, dtype=dtype) + mask_lat_size[:, :, :mask_len] = 1 + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:]], dim=2) + mask_lat_size = mask_lat_size.view( + batch_size, -1, self.vae_scale_factor_temporal, latent_h, 
latent_w + ).transpose(1, 2) + + return mask_lat_size + + def prepare_reference_image_latents( + self, + image: torch.Tensor, + batch_size: int = 1, + sample_mode: str = "argmax", + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + # image shape: (B, C, H, W) or (B, C, T, H, W) + dtype = dtype or next(self.vae.parameters()).dtype + if image.ndim == 4: + image = image.unsqueeze(2) + + _, _, _, height, width = image.shape + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + image = image.to(device=device, dtype=dtype) + if isinstance(generator, list): + ref_image_latents = [ + retrieve_latents(self.vae.encode(image), generator=g, sample_mode=sample_mode) for g in generator + ] + ref_image_latents = torch.cat(ref_image_latents) + else: + ref_image_latents = retrieve_latents(self.vae.encode(image), generator, sample_mode) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(ref_image_latents.device, ref_image_latents.dtype) + ) + latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + ref_image_latents.device, ref_image_latents.dtype + ) + ref_image_latents = (ref_image_latents - latents_mean) * latents_recip_std + + if ref_image_latents.shape[0] == 1 and batch_size > 1: + ref_image_latents = ref_image_latents.expand(batch_size, -1, -1, -1, -1) + + reference_image_mask = self.get_i2v_mask(batch_size, 1, latent_height, latent_width, 1, None, dtype, device) + reference_image_latents = torch.cat([reference_image_mask, ref_image_latents], dim=1) + + return reference_image_latents + + def prepare_prev_segment_cond_latents( + self, + prev_segment_cond_video: Optional[torch.Tensor] = None, + background_video: Optional[torch.Tensor] = None, + mask_video: 
def prepare_prev_segment_cond_latents(
    self,
    prev_segment_cond_video: Optional[torch.Tensor] = None,
    background_video: Optional[torch.Tensor] = None,
    mask_video: Optional[torch.Tensor] = None,
    batch_size: int = 1,
    segment_frame_length: int = 77,
    start_frame: int = 0,
    height: int = 720,
    width: int = 1280,
    prev_segment_cond_frames: int = 1,
    task: str = "animate",
    interpolation_mode: str = "bicubic",
    sample_mode: str = "argmax",
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Build the conditioning latents that carry the previous segment (and,
    in "replace" mode, the background) into the current denoising segment.

    The conditioning video is assembled as [prev-segment frames | filler],
    VAE-encoded, standardized, and concatenated with an i2v mask on dim 1.

    Args:
        prev_segment_cond_video: Trailing frames of the previously decoded
            segment; ``None`` for the first segment (zeros / background used).
        background_video: Background pixels for ``task == "replace"``.
        mask_video: Foreground mask video for ``task == "replace"``.
        batch_size: Effective batch size.
        segment_frame_length: Total pixel-frame length of one segment.
        start_frame: First frame index of this segment; 0 means no previous
            segment exists yet, so the mask marks nothing as given.
        height, width: Target pixel resolution.
        prev_segment_cond_frames: Number of overlap frames from the previous
            segment used as conditioning.
        task: "animate" (zero filler) or "replace" (background filler + mask).
        interpolation_mode: Resize mode when the cond video resolution differs.
        sample_mode: VAE latent retrieval mode ("argmax" or "sample").
        generator: Optional RNG(s) for stochastic encoding.
        dtype: Working dtype; defaults to the VAE parameter dtype.
        device: Device for created tensors.

    Returns:
        Tensor of [mask, standardized latents] concatenated along dim 1.
    """
    dtype = dtype or next(self.vae.parameters()).dtype
    if prev_segment_cond_video is None:
        # First segment: either seed from the background (replace) or zeros.
        if task == "replace":
            prev_segment_cond_video = background_video[:, :, :prev_segment_cond_frames].to(dtype)
        else:
            cond_frames_shape = (batch_size, 3, prev_segment_cond_frames, height, width)
            prev_segment_cond_video = torch.zeros(cond_frames_shape, dtype=dtype, device=device)

    data_batch_size, channels, _, segment_height, segment_width = prev_segment_cond_video.shape
    num_latent_frames = (segment_frame_length - 1) // self.vae_scale_factor_temporal + 1
    latent_height = height // self.vae_scale_factor_spatial
    latent_width = width // self.vae_scale_factor_spatial
    if segment_height != height or segment_width != width:
        logger.info(
            f"Interpolating prev segment cond video from ({segment_width}, {segment_height}) to ({width}, {height})"
        )
        # Fold time into batch for 2D interpolation, then unfold back.
        prev_segment_cond_video = prev_segment_cond_video.transpose(1, 2).flatten(0, 1)
        prev_segment_cond_video = torch.nn.functional.interpolate(
            prev_segment_cond_video, size=(height, width), mode=interpolation_mode
        )
        prev_segment_cond_video = prev_segment_cond_video.unflatten(0, (batch_size, -1)).transpose(1, 2)

    # Frames after the conditioning overlap: background pixels in replace
    # mode, zeros otherwise.
    if task == "replace":
        remaining_segment = background_video[:, :, prev_segment_cond_frames:].to(dtype)
    else:
        remaining_segment_frames = segment_frame_length - prev_segment_cond_frames
        remaining_segment = torch.zeros(
            batch_size, channels, remaining_segment_frames, height, width, dtype=dtype, device=device
        )

    prev_segment_cond_video = prev_segment_cond_video.to(dtype=dtype)
    full_segment_cond_video = torch.cat([prev_segment_cond_video, remaining_segment], dim=2)

    if isinstance(generator, list):
        if data_batch_size == len(generator):
            # One generator per batch element.
            prev_segment_cond_latents = [
                retrieve_latents(self.vae.encode(full_segment_cond_video[i].unsqueeze(0)), g, sample_mode)
                for i, g in enumerate(generator)
            ]
        elif data_batch_size == 1:
            # Single item encoded once per generator.
            prev_segment_cond_latents = [
                retrieve_latents(self.vae.encode(full_segment_cond_video), g, sample_mode) for g in generator
            ]
        else:
            raise ValueError(
                f"The batch size of the prev segment video should be either {len(generator)} or 1 but is"
                f" {data_batch_size}"
            )
        prev_segment_cond_latents = torch.cat(prev_segment_cond_latents)
    else:
        prev_segment_cond_latents = retrieve_latents(
            self.vae.encode(full_segment_cond_video), generator, sample_mode
        )

    # Standardize latents with the VAE statistics, matching the other
    # latent-preparation helpers in this pipeline.
    latents_mean = (
        torch.tensor(self.vae.config.latents_mean)
        .view(1, self.vae.config.z_dim, 1, 1, 1)
        .to(prev_segment_cond_latents.device, prev_segment_cond_latents.dtype)
    )
    latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
        prev_segment_cond_latents.device, prev_segment_cond_latents.dtype
    )
    prev_segment_cond_latents = (prev_segment_cond_latents - latents_mean) * latents_recip_std

    if task == "replace":
        # Invert the mask (condition on background, generate foreground),
        # then downsample it to latent resolution frame by frame.
        mask_video = 1 - mask_video
        mask_video = mask_video.permute(0, 2, 1, 3, 4)
        mask_video = mask_video.flatten(0, 1)
        mask_video = torch.nn.functional.interpolate(mask_video, size=(latent_height, latent_width), mode="nearest")
        mask_pixel_values = mask_video.unflatten(0, (batch_size, -1))
        mask_pixel_values = mask_pixel_values.permute(0, 2, 1, 3, 4)
    else:
        mask_pixel_values = None
    prev_segment_cond_mask = self.get_i2v_mask(
        batch_size,
        num_latent_frames,
        latent_height,
        latent_width,
        # Only mark overlap frames as "given" after the first segment.
        mask_len=prev_segment_cond_frames if start_frame > 0 else 0,
        mask_pixel_values=mask_pixel_values,
        dtype=dtype,
        device=device,
    )

    prev_segment_cond_latents = torch.cat([prev_segment_cond_mask, prev_segment_cond_latents], dim=1)
    return prev_segment_cond_latents


def prepare_pose_latents(
    self,
    pose_video: torch.Tensor,
    batch_size: int = 1,
    sample_mode: str = "argmax",
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """VAE-encode and standardize the pose driving video for one segment.

    Args:
        pose_video: Preprocessed pose video tensor for the current segment.
        batch_size: Target batch size a singleton batch is expanded to.
        sample_mode: VAE latent retrieval mode ("argmax" or "sample").
        generator: Optional RNG(s) for stochastic encoding.
        dtype: Working dtype; defaults to the VAE parameter dtype.
        device: Device the pose video is moved to before encoding.

    Returns:
        Standardized pose latents, expanded to ``batch_size`` if needed.
    """
    pose_video = pose_video.to(
        device=device, dtype=dtype if dtype is not None else next(self.vae.parameters()).dtype
    )
    if isinstance(generator, list):
        pose_latents = [
            retrieve_latents(self.vae.encode(pose_video), generator=g, sample_mode=sample_mode) for g in generator
        ]
        pose_latents = torch.cat(pose_latents)
    else:
        pose_latents = retrieve_latents(self.vae.encode(pose_video), generator, sample_mode)

    # Same standardization as the other latent helpers.
    latents_mean = (
        torch.tensor(self.vae.config.latents_mean)
        .view(1, self.vae.config.z_dim, 1, 1, 1)
        .to(pose_latents.device, pose_latents.dtype)
    )
    latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
        pose_latents.device, pose_latents.dtype
    )
    pose_latents = (pose_latents - latents_mean) * latents_recip_std
    if pose_latents.shape[0] == 1 and batch_size > 1:
        pose_latents = pose_latents.expand(batch_size, -1, -1, -1, -1)
    return pose_latents
def prepare_latents(
    self,
    batch_size: int,
    num_channels_latents: int = 16,
    height: int = 720,
    width: int = 1280,
    num_frames: int = 77,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Sample (or adopt) the initial noisy latents for one segment.

    Returns latents of shape (B, C, T_latent + 1, H/8, W/8); the extra
    latent frame corresponds to the reference-image conditioning frame.

    Raises:
        ValueError: If a generator list length does not match ``batch_size``.
    """
    num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
    latent_height = height // self.vae_scale_factor_spatial
    latent_width = width // self.vae_scale_factor_spatial

    # +1 for the conditioning frame
    shape = (batch_size, num_channels_latents, num_latent_frames + 1, latent_height, latent_width)
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is None:
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
        # Caller-provided latents are only moved, never re-sampled.
        latents = latents.to(device=device, dtype=dtype)

    return latents


def pad_video_frames(self, frames: List[Any], num_target_frames: int) -> List[Any]:
    """
    Pads an array-like video `frames` to `num_target_frames` using a "reflect"-like strategy.

    Example: pad_video_frames([1, 2, 3, 4, 5], 10) -> [1, 2, 3, 4, 5, 4, 3, 2, 1, 2]

    Frames are deep-copied so the padded list never aliases the input.
    """
    # Bug fix: with a single input frame the ping-pong index walk below
    # steps to index 1 without ever hitting the flip condition
    # (len(frames) - 1 == 0 is checked *after* incrementing), raising
    # IndexError. A one-frame video simply repeats its frame.
    if len(frames) == 1:
        return [deepcopy(frames[0]) for _ in range(num_target_frames)]

    idx = 0
    flip = False
    target_frames = []
    while len(target_frames) < num_target_frames:
        target_frames.append(deepcopy(frames[idx]))
        # Ping-pong: walk forward until the last frame, then backward to the
        # first, and so on ("reflect" padding without repeating endpoints).
        if flip:
            idx -= 1
        else:
            idx += 1
        if idx == 0 or idx == len(frames) - 1:
            flip = not flip

    return target_frames


@property
def guidance_scale(self):
    # CFG scale captured at the start of __call__.
    return self._guidance_scale


@property
def do_classifier_free_guidance(self):
    # CFG is only worth the extra forward pass when scale > 1.
    return self._guidance_scale > 1


@property
def num_timesteps(self):
    return self._num_timesteps


@property
def current_timestep(self):
    # None outside the denoising loop.
    return self._current_timestep


@property
def interrupt(self):
    return self._interrupt


@property
def attention_kwargs(self):
    return self._attention_kwargs


def _build_attn_metadata(self, attn_params):
    """Build backend-specific attention metadata from ``attn_params``.

    Returns None when no attention parameters are configured.
    """
    if attn_params is None:
        return None

    builder_cls = self.attn_backend.get_builder_cls()
    builder = builder_cls()
    attn_params_dict = attn_params.to_dict()
    attn_metadata = builder.build(**attn_params_dict)
    return attn_metadata
def _predict_noise_with_cfg(
    self,
    latent_model_input: torch.Tensor,
    timestep: torch.Tensor,
    prompt_embeds: torch.Tensor,
    negative_prompt_embeds: torch.Tensor,
    image_embeds: Optional[torch.Tensor],
    pose_latents: torch.Tensor,
    face_video_segment: torch.Tensor,
    motion_encode_batch_size: Optional[int],
    attn_metadata,
    apply_cfg: bool,
    guidance_scale: float,
    use_cfg_parallel: bool,
    batch_size: int,
):
    """
    Predict noise with optional classifier-free guidance and CFG parallelism.

    For Wan Animate, the unconditional pass blanks out the face video (sets all pixels to -1)
    to remove face conditioning.

    Args:
        latent_model_input: The model input (latents concatenated with reference/conditioning latents).
        timestep: Current timestep tensor.
        prompt_embeds: Positive prompt embeddings tensor.
        negative_prompt_embeds: Negative prompt embeddings tensor.
        image_embeds: Image embeddings tensor for cross-attention.
        pose_latents: Pose video latents.
        face_video_segment: Face video segment in pixel space.
        motion_encode_batch_size: Batch size for batched motion encoding.
        attn_metadata: Attention metadata for set_forward_context.
        apply_cfg: Whether to apply classifier-free guidance this step.
        guidance_scale: The CFG scale factor.
        use_cfg_parallel: Whether to use CFG parallelism across devices.
        batch_size: The actual batch size.

    Returns:
        noise_pred: The predicted noise tensor.
    """
    if not apply_cfg:
        # No guidance: a single conditional forward pass is enough.
        with set_forward_context(attn_metadata=attn_metadata):
            noise_pred = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                encoder_hidden_states_image=image_embeds,
                pose_hidden_states=pose_latents,
                face_pixel_values=face_video_segment,
                motion_encode_batch_size=motion_encode_batch_size,
                return_dict=False,
            )[0]
        return noise_pred.float()

    # CFG mode
    cfg_group, cfg_rank = None, None
    if use_cfg_parallel:
        if not model_parallel_is_initialized():
            raise RuntimeError("Model parallel groups must be initialized when use_cfg_parallel=True")
        cfg_group = get_cfg_group()
        cfg_rank = cfg_group.rank_in_group

    # Zero-initialized buffers: in CFG-parallel mode each rank fills only
    # one of the two, and the later all_reduce (sum) merges them.
    noise_pred_pos = torch.zeros_like(latent_model_input, dtype=torch.float32)
    noise_pred_neg = torch.zeros_like(latent_model_input, dtype=torch.float32)

    # Positive prompt forward pass (conditional).
    # Runs on every device when not parallel, otherwise only on rank 0.
    if not (use_cfg_parallel and cfg_rank != 0):
        with set_forward_context(attn_metadata=attn_metadata):
            noise_pred_pos = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=prompt_embeds,
                encoder_hidden_states_image=image_embeds,
                pose_hidden_states=pose_latents,
                face_pixel_values=face_video_segment,
                motion_encode_batch_size=motion_encode_batch_size,
                return_dict=False,
            )[0].float()

    # Negative prompt forward pass (unconditional) - blank out face
    # (all pixels set to -1, the preprocessing lower bound).
    face_pixel_values_uncond = face_video_segment * 0 - 1
    # Runs on every device when not parallel, otherwise only on non-zero ranks.
    if not use_cfg_parallel or cfg_rank != 0:
        with set_forward_context(attn_metadata=attn_metadata):
            noise_pred_neg = self.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=negative_prompt_embeds,
                encoder_hidden_states_image=image_embeds,
                pose_hidden_states=pose_latents,
                face_pixel_values=face_pixel_values_uncond,
                motion_encode_batch_size=motion_encode_batch_size,
                return_dict=False,
            )[0].float()

    # All-reduce for CFG parallel: each buffer is non-zero on exactly one
    # rank, so the sum reconstructs both predictions everywhere.
    if use_cfg_parallel:
        noise_pred_pos = cfg_group.all_reduce(noise_pred_pos)
        noise_pred_neg = cfg_group.all_reduce(noise_pred_neg)

    # Apply CFG
    noise_pred = noise_pred_neg + guidance_scale * (noise_pred_pos - noise_pred_neg)
    return noise_pred
@torch.no_grad()
def __call__(
    self,
    image,
    pose_video: List[PIL.Image.Image],
    face_video: List[PIL.Image.Image],
    background_video: Optional[List[PIL.Image.Image]] = None,
    mask_video: Optional[List[PIL.Image.Image]] = None,
    prompt: Union[str, List[str]] = None,
    negative_prompt: Union[str, List[str]] = None,
    height: int = 720,
    width: int = 1280,
    segment_frame_length: int = 77,
    num_inference_steps: int = 20,
    mode: str = "animate",
    prev_segment_conditioning_frames: int = 1,
    motion_encode_batch_size: Optional[int] = None,
    guidance_scale: float = 1.0,
    num_videos_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.Tensor] = None,
    prompt_embeds: Optional[torch.Tensor] = None,
    negative_prompt_embeds: Optional[torch.Tensor] = None,
    image_embeds: Optional[torch.Tensor] = None,
    output_type: Optional[str] = "np",
    return_dict: bool = True,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    max_sequence_length: int = 512,
):
    r"""
    The call function to the pipeline for generation.

    The full video is generated segment by segment: each segment is denoised
    conditioned on the reference image, the pose/face driving videos, and the
    trailing frames of the previously decoded segment.

    Args:
        image: The input character image to condition the generation on.
        pose_video (`List[PIL.Image.Image]`): The input pose video.
        face_video (`List[PIL.Image.Image]`): The input face video.
        background_video (`List[PIL.Image.Image]`, *optional*): Background video for replace mode.
        mask_video (`List[PIL.Image.Image]`, *optional*): Mask video for replace mode.
        prompt (`str` or `List[str]`, *optional*): The prompt(s) to guide generation.
        negative_prompt (`str` or `List[str]`, *optional*): The negative prompt(s).
        height (`int`, defaults to `720`): The height of the generated video.
        width (`int`, defaults to `1280`): The width of the generated video.
        segment_frame_length (`int`, defaults to `77`): Frames per generated segment.
        num_inference_steps (`int`, defaults to `20`): Number of denoising steps.
        mode (`str`, defaults to `"animate"`): `"animate"` or `"replace"`.
        prev_segment_conditioning_frames (`int`, defaults to `1`): Frames from previous segment for conditioning.
        motion_encode_batch_size (`int`, *optional*): Batch size for motion encoding.
        guidance_scale (`float`, defaults to `1.0`): CFG scale.
        num_videos_per_prompt (`int`, *optional*, defaults to 1): Videos per prompt.
        generator: Random generator(s) for deterministic generation.
        latents (`torch.Tensor`, *optional*): Pre-generated noisy latents.
        prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings.
        negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings.
        image_embeds (`torch.Tensor`, *optional*): Pre-generated image embeddings.
        output_type (`str`, *optional*, defaults to `"np"`): Output format.
        return_dict (`bool`, *optional*, defaults to `True`): Whether to return a `WanPipelineOutput`.
        attention_kwargs (`dict`, *optional*): Kwargs for attention.
        callback_on_step_end (`Callable`, *optional*): Callback at end of each denoising step.
        callback_on_step_end_tensor_inputs (`List`, *optional*): Tensor inputs for callback.
        max_sequence_length (`int`, defaults to `512`): Max sequence length for text encoder.

    Returns:
        `WanPipelineOutput` or `tuple`: Generated video frames.
    """
    # NOTE(review): the mutable default `["latents"]` is read-only here
    # (diffusers convention), but confirm no caller mutates it.
    # 1. Check inputs
    self.check_inputs(
        prompt,
        negative_prompt,
        image,
        pose_video,
        face_video,
        background_video,
        mask_video,
        height,
        width,
        prompt_embeds,
        negative_prompt_embeds,
        image_embeds,
        callback_on_step_end_tensor_inputs,
        mode,
        prev_segment_conditioning_frames,
    )

    # Segment length must satisfy (length - 1) % temporal_scale == 0 for the
    # video VAE; round down to the nearest valid value otherwise.
    if segment_frame_length % self.vae_scale_factor_temporal != 1:
        logger.warning(
            f"`segment_frame_length - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the"
            f" nearest number."
        )
        segment_frame_length = (
            segment_frame_length // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
        )
        segment_frame_length = max(segment_frame_length, 1)

    self._guidance_scale = guidance_scale
    self._attention_kwargs = attention_kwargs
    self._current_timestep = None
    self._interrupt = False

    device = self.device

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # Compute segment layout: each segment contributes
    # `effective_segment_length` new frames (the rest overlap with the
    # previous segment); pad the driving videos so the last segment is full.
    cond_video_frames = len(pose_video)
    effective_segment_length = segment_frame_length - prev_segment_conditioning_frames
    last_segment_frames = (cond_video_frames - prev_segment_conditioning_frames) % effective_segment_length
    if last_segment_frames == 0:
        num_padding_frames = 0
    else:
        num_padding_frames = effective_segment_length - last_segment_frames
    num_target_frames = cond_video_frames + num_padding_frames
    num_segments = num_target_frames // effective_segment_length

    # 3. Encode input prompt
    prompt_embeds, negative_prompt_embeds = self.encode_prompt(
        prompt=prompt,
        negative_prompt=negative_prompt,
        do_classifier_free_guidance=self.do_classifier_free_guidance,
        num_videos_per_prompt=num_videos_per_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        max_sequence_length=max_sequence_length,
        device=device,
    )

    transformer_dtype = self.pipeline_config.model_dtype
    prompt_embeds = prompt_embeds.to(transformer_dtype)
    if negative_prompt_embeds is not None:
        negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

    # 4. Preprocess and encode the reference (character) image
    image_height, image_width = self.vae_image_processor.get_default_height_width(image)
    if image_height != height or image_width != width:
        logger.warning(f"Reshaping reference image from ({image_width}, {image_height}) to ({width}, {height})")
    image_pixels = self.vae_image_processor.preprocess(image, height=height, width=width, resize_mode="fill").to(
        device, dtype=torch.float32
    )

    # Get CLIP features from the reference image
    if image_embeds is None:
        image_embeds = self.encode_image(image, device)
    image_embeds = image_embeds.repeat(batch_size * num_videos_per_prompt, 1, 1)
    image_embeds = image_embeds.to(transformer_dtype)

    # 5. Encode conditioning videos (pose, face)
    pose_video = self.pad_video_frames(pose_video, num_target_frames)
    face_video = self.pad_video_frames(face_video, num_target_frames)

    pose_video_width, pose_video_height = pose_video[0].size
    if pose_video_height != height or pose_video_width != width:
        logger.warning(
            f"Reshaping pose video from ({pose_video_width}, {pose_video_height}) to ({width}, {height})"
        )
    pose_video = self.video_processor.preprocess_video(pose_video, height=height, width=width).to(
        device, dtype=torch.float32
    )

    # The face video is resized to the motion encoder's fixed input size,
    # not to the output resolution.
    face_video_width, face_video_height = face_video[0].size
    expected_face_size = self.transformer.config.motion_encoder_size
    if face_video_width != expected_face_size or face_video_height != expected_face_size:
        logger.warning(
            f"Reshaping face video from ({face_video_width}, {face_video_height}) to ({expected_face_size},"
            f" {expected_face_size})"
        )
    face_video = self.video_processor.preprocess_video(
        face_video, height=expected_face_size, width=expected_face_size
    ).to(device, dtype=torch.float32)

    if mode == "replace":
        background_video = self.pad_video_frames(background_video, num_target_frames)
        mask_video = self.pad_video_frames(mask_video, num_target_frames)

        background_video = self.video_processor.preprocess_video(background_video, height=height, width=width).to(
            device, dtype=torch.float32
        )
        mask_video = self.video_processor_for_mask.preprocess_video(mask_video, height=height, width=width).to(
            device, dtype=torch.float32
        )

    # 6. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 7. Prepare latent variables which stay constant for all inference segments
    num_channels_latents = self.vae.config.z_dim

    # Get VAE-encoded latents of the reference (character) image
    reference_image_latents = self.prepare_reference_image_latents(
        image_pixels, batch_size * num_videos_per_prompt, generator=generator, device=device
    )

    # 8. Loop over video inference segments
    start = 0
    end = segment_frame_length
    all_out_frames = []
    out_frames = None
    actual_batch_size = batch_size * num_videos_per_prompt

    for _ in range(num_segments):
        assert start + prev_segment_conditioning_frames < cond_video_frames

        # Sample noisy latents for the current inference segment.
        # User-provided `latents` only seed the first segment.
        latents = self.prepare_latents(
            actual_batch_size,
            num_channels_latents=num_channels_latents,
            height=height,
            width=width,
            num_frames=segment_frame_length,
            dtype=torch.float32,
            device=device,
            generator=generator,
            latents=latents if start == 0 else None,
        )

        pose_video_segment = pose_video[:, :, start:end]
        face_video_segment = face_video[:, :, start:end]

        face_video_segment = face_video_segment.expand(actual_batch_size, -1, -1, -1, -1)
        face_video_segment = face_video_segment.to(dtype=transformer_dtype)

        if start > 0:
            # Trailing frames of the last decoded segment condition this one.
            prev_segment_cond_video = out_frames[:, :, -prev_segment_conditioning_frames:].clone().detach()
        else:
            prev_segment_cond_video = None

        if mode == "replace":
            background_video_segment = background_video[:, :, start:end]
            mask_video_segment = mask_video[:, :, start:end]

            background_video_segment = background_video_segment.expand(actual_batch_size, -1, -1, -1, -1)
            mask_video_segment = mask_video_segment.expand(actual_batch_size, -1, -1, -1, -1)
        else:
            background_video_segment = None
            mask_video_segment = None

        pose_latents = self.prepare_pose_latents(
            pose_video_segment, actual_batch_size, generator=generator, device=device
        )
        pose_latents = pose_latents.to(dtype=transformer_dtype)

        prev_segment_cond_latents = self.prepare_prev_segment_cond_latents(
            prev_segment_cond_video,
            background_video=background_video_segment,
            mask_video=mask_video_segment,
            batch_size=actual_batch_size,
            segment_frame_length=segment_frame_length,
            start_frame=start,
            height=height,
            width=width,
            prev_segment_cond_frames=prev_segment_conditioning_frames,
            task=mode,
            generator=generator,
            device=device,
        )

        # Concatenate the reference latents in the frame dimension
        reference_latents = torch.cat([reference_image_latents, prev_segment_cond_latents], dim=2)

        # 8.1 Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t

                # Concatenate the reference image + prev segment conditioning in the channel dim
                latent_model_input = torch.cat([latents, reference_latents], dim=1).to(transformer_dtype)
                timestep = t.expand(latents.shape[0])

                attn_metadata = self._build_attn_metadata(self.pipeline_config.attn_params)

                noise_pred = self._predict_noise_with_cfg(
                    latent_model_input=latent_model_input,
                    timestep=timestep,
                    prompt_embeds=prompt_embeds,
                    negative_prompt_embeds=negative_prompt_embeds,
                    image_embeds=image_embeds,
                    pose_latents=pose_latents,
                    face_video_segment=face_video_segment,
                    motion_encode_batch_size=motion_encode_batch_size,
                    attn_metadata=attn_metadata,
                    apply_cfg=self.do_classifier_free_guidance,
                    guidance_scale=guidance_scale,
                    use_cfg_parallel=self.pipeline_config.use_cfg_parallel,
                    batch_size=actual_batch_size,
                )

                # Compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    # Callbacks may replace these tensors mid-loop.
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        latents = latents.to(self.pipeline_config.vae_dtype)
        # Destandardize latents in preparation for Wan VAE decoding
        latents_mean = (
            torch.tensor(self.vae.config.latents_mean)
            .view(1, self.vae.config.z_dim, 1, 1, 1)
            .to(latents.device, latents.dtype)
        )
        latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(
            1, self.vae.config.z_dim, 1, 1, 1
        ).to(latents.device, latents.dtype)
        latents = latents / latents_recip_std + latents_mean
        # Skip the first latent frame (used for conditioning)
        out_frames = self.vae.decode(latents[:, :, 1:], return_dict=False)[0]

        if start > 0:
            # Drop the overlap frames already emitted by the previous segment.
            out_frames = out_frames[:, :, prev_segment_conditioning_frames:]
        all_out_frames.append(out_frames)

        start += effective_segment_length
        end += effective_segment_length

        # Reset scheduler timesteps / state for next denoising loop
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

    self._current_timestep = None
    assert start + prev_segment_conditioning_frames >= cond_video_frames

    if not output_type == "latent":
        # Trim the padding frames added for the last segment.
        video = torch.cat(all_out_frames, dim=2)[:, :, :cond_video_frames]
        video = self.video_processor.postprocess_video(video, output_type=output_type)
    else:
        video = latents

    if not return_dict:
        return (video,)

    return WanPipelineOutput(frames=video)
b/diffsynth_engine/pipelines/wan/pipeline_wan_i2v.py @@ -0,0 +1,1055 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan_i2v.py + +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import json +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import PIL +import regex as re +import torch +from accelerate import init_empty_weights +from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel + +from diffsynth_engine.configs.wan import WanPipelineConfig +from diffsynth_engine.distributed.parallel_state import get_cfg_group, model_parallel_is_initialized +from diffsynth_engine.forward_context import set_forward_context +from diffsynth_engine.layers.attention import get_attn_backend +from diffsynth_engine.models.wan import AutoencoderKLWan, WanTransformer3DModel +from diffsynth_engine.pipelines.base import Pipeline +from diffsynth_engine.utils import logging +from diffsynth_engine.utils.load_utils import load_model_weights + +logger = 
def basic_clean(text):
    """Repair mojibake (via optional ftfy) and doubly-unescape HTML entities."""
    try:
        import ftfy

        text = ftfy.fix_text(text)
    except ImportError:
        # ftfy is an optional dependency; skip mojibake repair if absent.
        pass
    # Two unescape passes handle double-encoded entities like "&amp;amp;".
    return html.unescape(html.unescape(text)).strip()


def whitespace_clean(text):
    """Collapse every whitespace run into a single space and trim the ends."""
    collapsed = re.sub(r"\s+", " ", text)
    return collapsed.strip()


def prompt_clean(text):
    """Fully normalize a prompt: entity/mojibake cleanup, then whitespace collapse."""
    return whitespace_clean(basic_clean(text))


def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Pull latents out of a VAE encoder output.

    When the output exposes a ``latent_dist``, either sample from it
    (``sample_mode == "sample"``) or take its mode (``"argmax"``); otherwise
    fall back to a plain ``latents`` attribute.

    Raises:
        AttributeError: If neither access path is available.
    """
    has_dist = hasattr(encoder_output, "latent_dist")
    if has_dist and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    if has_dist and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
# Tensor inputs that step-end callbacks are allowed to read and replace.
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

def __init__(
    self,
    pipeline_config: WanPipelineConfig,
    tokenizer: AutoTokenizer,
    text_encoder: UMT5EncoderModel,
    vae: AutoencoderKLWan,
    scheduler: FlowMatchEulerDiscreteScheduler,
    image_processor: Optional[CLIPImageProcessor] = None,
    image_encoder: Optional[CLIPVisionModel] = None,
    transformer: Optional[WanTransformer3DModel] = None,
    transformer_2: Optional[WanTransformer3DModel] = None,
    boundary_ratio: Optional[float] = None,
    expand_timesteps: bool = False,
):
    """Wire up all pipeline components and derive VAE/attention settings.

    Args:
        pipeline_config: Engine-level pipeline configuration.
        tokenizer: UMT5 tokenizer.
        text_encoder: UMT5 text encoder.
        vae: Wan video VAE.
        scheduler: Flow-matching scheduler.
        image_processor: Optional CLIP image processor.
        image_encoder: Optional CLIP vision encoder.
        transformer: Denoising transformer (high-noise stage when two-stage).
        transformer_2: Optional low-noise-stage transformer.
        boundary_ratio: Optional timestep ratio switching between the two
            transformers; None disables two-stage denoising.
        expand_timesteps: Whether to expand timesteps (Wan2.2 ti2v models).
    """
    super().__init__(pipeline_config)

    self.tokenizer = tokenizer
    self.text_encoder = text_encoder
    self.vae = vae
    self.image_encoder = image_encoder
    self.image_processor = image_processor
    self.transformer = transformer
    self.transformer_2 = transformer_2
    self.scheduler = scheduler
    self.boundary_ratio = boundary_ratio
    self.expand_timesteps = expand_timesteps

    # Fall back to Wan's standard scale factors (4x temporal, 8x spatial)
    # when no VAE is attached.
    self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if self.vae is not None else 4
    self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if self.vae is not None else 8
    self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

    # At least one transformer must be provided; read the head dim from
    # whichever is present to configure the attention backend.
    active_transformer = transformer if transformer is not None else transformer_2
    head_dim = active_transformer.config.attention_head_dim
    self.attn_backend = get_attn_backend(
        head_size=head_dim,
        attn_type=pipeline_config.attn_type,
    )
@classmethod
def from_pretrained(cls, model_path_or_config: str | WanPipelineConfig):
    """
    Load a WanImageToVideoPipeline from a pretrained model path or config.

    Reads `model_index.json` for pipeline-level options (boundary_ratio,
    expand_timesteps, optional transformer_2), then loads every component
    from its conventional subfolder.

    Args:
        model_path_or_config: Either a string path to the model directory or a WanPipelineConfig instance.

    Returns:
        WanImageToVideoPipeline: The loaded pipeline.

    Raises:
        FileNotFoundError: If the model path does not exist.
    """
    if isinstance(model_path_or_config, str):
        pipeline_config = WanPipelineConfig(model_path=model_path_or_config)
    else:
        pipeline_config = model_path_or_config

    if not os.path.exists(pipeline_config.model_path):
        raise FileNotFoundError(f"Model path not found: {pipeline_config.model_path}")

    # Load model_index.json to read pipeline-level config and component declarations.
    model_index_path = os.path.join(pipeline_config.model_path, "model_index.json")
    model_index = {}
    boundary_ratio = None
    expand_timesteps = False
    if os.path.exists(model_index_path):
        with open(model_index_path, "r") as f:
            model_index = json.load(f)
        boundary_ratio = model_index.get("boundary_ratio", None)
        expand_timesteps = model_index.get("expand_timesteps", False)
        if boundary_ratio is not None:
            logger.info(f"Loaded boundary_ratio={boundary_ratio} from model_index.json")
        if expand_timesteps:
            logger.info(f"Loaded expand_timesteps={expand_timesteps} from model_index.json")

    # Load transformer (subfolder defaults to "transformer")
    transformer = cls.init_transformer(pipeline_config)

    # Load transformer_2 if declared in model_index.json; tolerate a
    # declaration without a matching subfolder on disk.
    transformer_2 = None
    if "transformer_2" in model_index and model_index["transformer_2"] is not None:
        transformer_2_subfolder = "transformer_2"
        if os.path.isdir(os.path.join(pipeline_config.model_path, transformer_2_subfolder)):
            transformer_2 = cls.init_transformer(pipeline_config, subfolder=transformer_2_subfolder)
            logger.info(
                f"Loaded transformer_2 from `{transformer_2_subfolder}` subfolder of {pipeline_config.model_path}."
            )
        else:
            logger.warning(
                f"transformer_2 declared in model_index.json but subfolder "
                f"'{transformer_2_subfolder}' not found in {pipeline_config.model_path}. Skipping."
            )

    # Load scheduler - auto-detect scheduler class from config, falling back
    # to FlowMatchEulerDiscreteScheduler when the declared class is unknown.
    scheduler_config_path = os.path.join(pipeline_config.model_path, "scheduler", SCHEDULER_CONFIG_NAME)
    scheduler_cls = FlowMatchEulerDiscreteScheduler
    if os.path.exists(scheduler_config_path):
        with open(scheduler_config_path, "r") as f:
            scheduler_config_dict = json.load(f)
        class_name = scheduler_config_dict.get("_class_name", None)
        if class_name is not None:
            try:
                from diffusers import schedulers as schedulers_module

                scheduler_cls = getattr(schedulers_module, class_name)
                logger.info(f"Using scheduler class from config: {class_name}")
            except AttributeError:
                logger.warning(
                    f"Scheduler class '{class_name}' not found in diffusers.schedulers, "
                    f"falling back to FlowMatchEulerDiscreteScheduler"
                )
    scheduler = scheduler_cls.from_pretrained(
        pipeline_config.model_path,
        subfolder="scheduler",
    )

    # Load VAE
    vae = cls.init_vae(pipeline_config)

    # Load text encoder
    text_encoder = cls.init_text_encoder(pipeline_config)

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        pipeline_config.model_path,
        subfolder="tokenizer",
    )

    # Load image encoder
    image_encoder = cls.init_image_encoder(pipeline_config)

    # Load image processor (optional component; only present for i2v models)
    image_processor = None
    image_processor_path = os.path.join(pipeline_config.model_path, "image_processor")
    if os.path.isdir(image_processor_path):
        image_processor = CLIPImageProcessor.from_pretrained(
            pipeline_config.model_path,
            subfolder="image_processor",
        )
        logger.info("Loaded image_processor from `image_processor` subfolder.")

    return cls(
        pipeline_config=pipeline_config,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        image_encoder=image_encoder,
        image_processor=image_processor,
        transformer=transformer,
        transformer_2=transformer_2,
        scheduler=scheduler,
        boundary_ratio=boundary_ratio,
        expand_timesteps=expand_timesteps,
    )

@staticmethod
def init_transformer(
    pipeline_config: WanPipelineConfig, empty_weights: bool = False, subfolder: str = "transformer"
):
    """Construct the Wan transformer, either with real weights or as an
    empty-weight shell (meta tensors) for deferred loading.

    Args:
        pipeline_config: Pipeline config providing model path, device, dtype.
        empty_weights: When True, build the model on the meta device from its
            config only (no weights read from disk).
        subfolder: Model subfolder ("transformer" or "transformer_2").
    """
    logger.info(f"Initializing transformer from subfolder={subfolder}...")
    # The attention type must be in the forward context while the model is
    # built so its attention layers pick the right backend.
    with set_forward_context(attn_type=pipeline_config.attn_type):
        if empty_weights:
            with init_empty_weights():
                config_dict = WanTransformer3DModel.load_config(
                    pipeline_config.model_path,
                    subfolder=subfolder,
                    local_files_only=True,
                )
                model = WanTransformer3DModel.from_config(config_dict)
        else:
            model = WanTransformer3DModel.from_pretrained(
                pipeline_config.model_path,
                subfolder=subfolder,
                device=pipeline_config.device,
                dtype=pipeline_config.model_dtype,
            )
    return model
+ if "shared.weight" in state_dict and "encoder.embed_tokens.weight" not in state_dict: + state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"] + + model.load_state_dict(state_dict, strict=False, assign=True) + model.to(device=pipeline_config.device) + return model + + @staticmethod + def init_vae(pipeline_config: WanPipelineConfig, empty_weights: bool = False): + logger.info("Initializing VAE...") + if empty_weights: + with init_empty_weights(): + config_dict = AutoencoderKLWan.load_config( + pipeline_config.model_path, + subfolder="vae", + local_files_only=True, + ) + model = AutoencoderKLWan.from_config(config_dict) + return model + + model = AutoencoderKLWan.from_pretrained( + pipeline_config.model_path, + subfolder="vae", + device=pipeline_config.device, + dtype=pipeline_config.vae_dtype, + ) + return model + + @staticmethod + def init_image_encoder(pipeline_config: WanPipelineConfig, empty_weights: bool = False): + logger.info("Initializing image encoder...") + image_encoder_path = os.path.join(pipeline_config.model_path, "image_encoder") + if not os.path.isdir(image_encoder_path): + logger.warning(f"image_encoder subfolder not found in {pipeline_config.model_path}. 
Skipping.") + return None + + if empty_weights: + with init_empty_weights(): + model = CLIPVisionModel.from_pretrained( + pipeline_config.model_path, + subfolder="image_encoder", + local_files_only=True, + ) + return model + + model = CLIPVisionModel.from_pretrained( + pipeline_config.model_path, + subfolder="image_encoder", + dtype=torch.float32, + ) + model.to(device=pipeline_config.device) + return model + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self.device + dtype = dtype or self.pipeline_config.text_encoder_dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_image( + self, + image, + device: Optional[torch.device] = None, + ): + device = device or self.device + image = 
self.image_processor(images=image, return_tensors="pt").to(device) + image_embeds = self.image_encoder(**image, output_hidden_states=True) + return image_embeds.hidden_states[-2] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. + max_sequence_length (`int`, *optional*, defaults to 226): + Maximum sequence length for the text encoder. 
+ device (`torch.device`, *optional*): + torch device + dtype (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self.device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + image_embeds=None, + callback_on_step_end_tensor_inputs=None, + guidance_scale_2=None, + ): + if image is not None and image_embeds is not None: + raise ValueError( + f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to" + " only forward one of the two." 
+ ) + if image is None and image_embeds is None: + raise ValueError( + "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined." + ) + if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): + raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " + f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if self.boundary_ratio is None and guidance_scale_2 is not None: + raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.") + + if self.boundary_ratio is not None and image_embeds is not None: + raise ValueError("Cannot forward `image_embeds` when the pipeline's `boundary_ratio` is not configured.") + + def prepare_latents( + self, + image, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + last_image: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, ...]: + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height = height // self.vae_scale_factor_spatial + latent_width = width // self.vae_scale_factor_spatial + + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + image = image.unsqueeze(2) # [batch_size, channels, 1, height, width] + + if self.expand_timesteps: + video_condition = image + elif last_image is None: + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2 + ) + else: + last_image = last_image.unsqueeze(2) + video_condition = torch.cat( + [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image], + dim=2, + ) + video_condition = video_condition.to(device=device, dtype=self.pipeline_config.vae_dtype) + + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + + if isinstance(generator, list): + latent_condition = [ + retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator + ] + latent_condition = torch.cat(latent_condition) + else: + latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") + latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + + latent_condition = latent_condition.to(dtype) + latent_condition = (latent_condition - latents_mean) * latents_std + + if self.expand_timesteps: + first_frame_mask = torch.ones( + 1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device + ) + first_frame_mask[:, :, 0] = 0 + return latents, latent_condition, first_frame_mask + + mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) + + if last_image is None: + mask_lat_size[:, :, list(range(1, num_frames))] = 0 + else: + mask_lat_size[:, :, list(range(1, num_frames - 1))] = 
0 + first_frame_mask = mask_lat_size[:, :, 0:1] + first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) + mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) + mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width) + mask_lat_size = mask_lat_size.transpose(1, 2) + mask_lat_size = mask_lat_size.to(latent_condition.device) + + return latents, torch.concat([mask_lat_size, latent_condition], dim=1) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + def _build_attn_metadata(self, attn_params): + if attn_params is None: + return None + + builder_cls = self.attn_backend.get_builder_cls() + builder = builder_cls() + attn_params_dict = attn_params.to_dict() + attn_metadata = builder.build(**attn_params_dict) + return attn_metadata + + def _predict_noise_with_cfg( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, + prompt_embeds: torch.Tensor, + negative_prompt_embeds: torch.Tensor, + image_embeds: Optional[torch.Tensor], + attn_metadata, + apply_cfg: bool, + guidance_scale: float, + use_cfg_parallel: bool, + batch_size: int, + model: Optional[WanTransformer3DModel] = None, + ): + """ + Predict noise with optional classifier-free guidance and CFG parallelism. + + Args: + latent_model_input: The model input (latents or latents + condition). + timestep: Current timestep tensor. + prompt_embeds: Positive prompt embeddings tensor. + negative_prompt_embeds: Negative prompt embeddings tensor. 
+ image_embeds: Image embeddings tensor for I2V cross-attention. + attn_metadata: Attention metadata for set_forward_context. + apply_cfg: Whether to apply classifier-free guidance this step. + guidance_scale: The CFG scale factor. + use_cfg_parallel: Whether to use CFG parallelism across devices. + batch_size: The actual batch size. + model: The transformer model to use. If None, defaults to self.transformer. + + Returns: + noise_pred: The predicted noise tensor. + """ + if model is None: + model = self.transformer + + if not apply_cfg: + with set_forward_context(attn_metadata=attn_metadata): + noise_pred = model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=image_embeds, + return_dict=False, + )[0] + return noise_pred.float() + + # CFG mode + cfg_group, cfg_rank = None, None + if use_cfg_parallel: + if not model_parallel_is_initialized(): + raise RuntimeError("Model parallel groups must be initialized when use_cfg_parallel=True") + cfg_group = get_cfg_group() + cfg_rank = cfg_group.rank_in_group + + noise_pred_pos = torch.zeros_like(latent_model_input, dtype=torch.float32) + noise_pred_neg = torch.zeros_like(latent_model_input, dtype=torch.float32) + + # Positive prompt forward pass + if not (use_cfg_parallel and cfg_rank != 0): + with set_forward_context(attn_metadata=attn_metadata): + noise_pred_pos = model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_hidden_states_image=image_embeds, + return_dict=False, + )[0].float() + + # Negative prompt forward pass + if not use_cfg_parallel or cfg_rank != 0: + with set_forward_context(attn_metadata=attn_metadata): + noise_pred_neg = model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_image=image_embeds, + return_dict=False, + )[0].float() + + # All-reduce for CFG parallel + if use_cfg_parallel: + 
noise_pred_pos = cfg_group.all_reduce(noise_pred_pos) + noise_pred_neg = cfg_group.all_reduce(noise_pred_neg) + + # Apply CFG + noise_pred = noise_pred_neg + guidance_scale * (noise_pred_pos - noise_pred_neg) + return noise_pred + + @torch.no_grad() + def __call__( + self, + image, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + guidance_scale_2: Optional[float] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + image_embeds: Optional[torch.Tensor] = None, + last_image: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + image (`PipelineImageInput`): + The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to avoid during video generation. + height (`int`, defaults to `480`): + The height in pixels of the generated video. + width (`int`, defaults to `832`): + The width in pixels of the generated video. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. 
+ guidance_scale (`float`, defaults to `5.0`): + Guidance scale for classifier-free guidance. + guidance_scale_2 (`float`, *optional*, defaults to `None`): + Guidance scale for the low-noise stage when `boundary_ratio` is set. If `None` and + `boundary_ratio` is not None, uses the same value as `guidance_scale`. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + Random generator(s) for deterministic generation. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. + image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings. + last_image (`torch.Tensor`, *optional*): + Optional last frame image for video generation with start and end frames. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a `WanPipelineOutput` instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + Kwargs passed to the attention processor. + callback_on_step_end (`Callable`, *optional*): + A function called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List`, *optional*): + Tensor inputs for the callback function. + max_sequence_length (`int`, defaults to `512`): + Maximum sequence length for the text encoder. + + Returns: + `WanPipelineOutput` or `tuple`: Generated video frames. + """ + # 1. 
Check inputs + self.check_inputs( + prompt, + negative_prompt, + image, + height, + width, + prompt_embeds, + negative_prompt_embeds, + image_embeds, + callback_on_step_end_tensor_inputs, + guidance_scale_2, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. " + "Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + patch_size = ( + self.transformer.config.patch_size if self.transformer is not None else self.transformer_2.config.patch_size + ) + h_multiple_of = self.vae_scale_factor_spatial * patch_size[1] + w_multiple_of = self.vae_scale_factor_spatial * patch_size[2] + calc_height = height // h_multiple_of * h_multiple_of + calc_width = width // w_multiple_of * w_multiple_of + if height != calc_height or width != calc_width: + logger.warning( + f"`height` and `width` must be multiples of ({h_multiple_of}, {w_multiple_of}) for proper " + f"patchification. Adjusting ({height}, {width}) -> ({calc_height}, {calc_width})." + ) + height, width = calc_height, calc_width + + if self.boundary_ratio is not None and guidance_scale_2 is None: + guidance_scale_2 = guidance_scale + + self._guidance_scale = guidance_scale + self._guidance_scale_2 = guidance_scale_2 + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self.device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. 
Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + # Encode image embedding + transformer_dtype = self.pipeline_config.model_dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # only wan 2.1 i2v transformer accepts image_embeds + if self.transformer is not None and self.transformer.config.image_dim is not None: + if image_embeds is None: + if last_image is None: + image_embeds = self.encode_image(image, device) + else: + image_embeds = self.encode_image([image, last_image], device) + image_embeds = image_embeds.repeat(batch_size, 1, 1) + image_embeds = image_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.vae.config.z_dim + image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32) + if last_image is not None: + last_image = self.video_processor.preprocess(last_image, height=height, width=width).to( + device, dtype=torch.float32 + ) + + latents_outputs = self.prepare_latents( + image, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + last_image, + ) + if self.expand_timesteps: + latents, condition, first_frame_mask = latents_outputs + else: + latents, condition = latents_outputs + + # 6. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + if self.boundary_ratio is not None: + boundary_timestep = self.boundary_ratio * self.scheduler.config.num_train_timesteps + else: + boundary_timestep = None + + actual_batch_size = batch_size * num_videos_per_prompt + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + # Determine current model and guidance scale based on boundary_ratio + if boundary_timestep is None or t >= boundary_timestep: + current_model = self.transformer + current_guidance_scale = guidance_scale + else: + current_model = self.transformer_2 + current_guidance_scale = guidance_scale_2 + + if self.expand_timesteps: + latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents + latent_model_input = latent_model_input.to(transformer_dtype) + + temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten() + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + else: + latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + attn_metadata = self._build_attn_metadata(self.pipeline_config.attn_params) + + noise_pred = self._predict_noise_with_cfg( + latent_model_input=latent_model_input, + timestep=timestep, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + image_embeds=image_embeds, + attn_metadata=attn_metadata, + apply_cfg=self.do_classifier_free_guidance, + guidance_scale=current_guidance_scale, + use_cfg_parallel=self.pipeline_config.use_cfg_parallel, + batch_size=actual_batch_size, + model=current_model, + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k 
in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + self._current_timestep = None + + if self.expand_timesteps: + latents = (1 - first_frame_mask) * condition + first_frame_mask * latents + + if not output_type == "latent": + latents = latents.to(self.pipeline_config.vae_dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) diff --git a/diffsynth_engine/pipelines/wan/pipeline_wan_t2v.py b/diffsynth_engine/pipelines/wan/pipeline_wan_t2v.py new file mode 100644 index 0000000..d42d876 --- /dev/null +++ b/diffsynth_engine/pipelines/wan/pipeline_wan_t2v.py @@ -0,0 +1,881 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py + +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import os +from typing import Any, Callable, Dict, List, Optional, Union + +import regex as re +import torch +from accelerate import init_empty_weights +from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from transformers import AutoTokenizer, UMT5EncoderModel + +from diffsynth_engine.configs.wan import WanPipelineConfig +from diffsynth_engine.distributed.parallel_state import get_cfg_group, model_parallel_is_initialized +from diffsynth_engine.forward_context import set_forward_context +from diffsynth_engine.layers.attention import get_attn_backend +from diffsynth_engine.models.wan import AutoencoderKLWan, WanTransformer3DModel +from diffsynth_engine.pipelines.base import Pipeline +from diffsynth_engine.utils import logging +from diffsynth_engine.utils.load_utils import load_model_weights + +logger = logging.get_logger(__name__) + + +def basic_clean(text): + try: + import ftfy + + text = ftfy.fix_text(text) + except ImportError: + pass + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + + +class WanTextToVideoPipeline(Pipeline): + r""" + Pipeline for text-to-video generation 
using Wan, adapted for DiffSynth-Engine. + + Args: + pipeline_config (`WanPipelineConfig`): + Configuration for the pipeline. + tokenizer (`AutoTokenizer`): + Tokenizer from T5, specifically the google/umt5-xxl variant. + text_encoder (`UMT5EncoderModel`): + T5 text encoder, specifically the google/umt5-xxl variant. + vae (`AutoencoderKLWan`): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + scheduler (`FlowMatchEulerDiscreteScheduler`): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + transformer (`WanTransformer3DModel`, *optional*): + Conditional Transformer to denoise the input latents. + transformer_2 (`WanTransformer3DModel`, *optional*): + Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables + two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise + stages. If not provided, only `transformer` is used. + boundary_ratio (`float`, *optional*, defaults to `None`): + Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising. + The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided, + `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps < + boundary_timestep. If `None`, only `transformer` is used for the entire denoising process. + expand_timesteps (`bool`, defaults to `False`): + Whether to expand timesteps for Wan2.2 ti2v models. 
+ """ + + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + pipeline_config: WanPipelineConfig, + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + vae: AutoencoderKLWan, + scheduler: FlowMatchEulerDiscreteScheduler, + transformer: Optional[WanTransformer3DModel] = None, + transformer_2: Optional[WanTransformer3DModel] = None, + boundary_ratio: Optional[float] = None, + expand_timesteps: bool = False, + ): + super().__init__(pipeline_config) + + self.tokenizer = tokenizer + self.text_encoder = text_encoder + self.vae = vae + self.transformer = transformer + self.transformer_2 = transformer_2 + self.scheduler = scheduler + self.boundary_ratio = boundary_ratio + self.expand_timesteps = expand_timesteps + + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if self.vae is not None else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if self.vae is not None else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + active_transformer = transformer if transformer is not None else transformer_2 + head_dim = active_transformer.config.attention_head_dim + self.attn_backend = get_attn_backend( + head_size=head_dim, + attn_type=pipeline_config.attn_type, + ) + + @classmethod + def from_pretrained(cls, model_path_or_config: str | WanPipelineConfig): + """ + Load a WanTextToVideoPipeline from a pretrained model path or config. + + Args: + model_path_or_config: Either a string path to the model directory or a WanPipelineConfig instance. + + Returns: + WanTextToVideoPipeline: The loaded pipeline. 
+ """ + if isinstance(model_path_or_config, str): + pipeline_config = WanPipelineConfig(model_path=model_path_or_config) + else: + pipeline_config = model_path_or_config + + if not os.path.exists(pipeline_config.model_path): + raise FileNotFoundError(f"Model path not found: {pipeline_config.model_path}") + + # Load model_index.json to read pipeline-level config and component declarations. + import json + + model_index_path = os.path.join(pipeline_config.model_path, "model_index.json") + model_index = {} + boundary_ratio = None + expand_timesteps = False + if os.path.exists(model_index_path): + with open(model_index_path, "r") as f: + model_index = json.load(f) + boundary_ratio = model_index.get("boundary_ratio", None) + expand_timesteps = model_index.get("expand_timesteps", False) + if boundary_ratio is not None: + logger.info(f"Loaded boundary_ratio={boundary_ratio} from model_index.json") + if expand_timesteps: + logger.info(f"Loaded expand_timesteps={expand_timesteps} from model_index.json") + + # Load transformer (subfolder defaults to "transformer") + transformer = cls.init_transformer(pipeline_config) + + # Load transformer_2 if declared in model_index.json. + transformer_2 = None + if "transformer_2" in model_index and model_index["transformer_2"] is not None: + transformer_2_subfolder = "transformer_2" + if os.path.isdir(os.path.join(pipeline_config.model_path, transformer_2_subfolder)): + transformer_2 = cls.init_transformer(pipeline_config, subfolder=transformer_2_subfolder) + logger.info( + f"Loaded transformer_2 from `{transformer_2_subfolder}` subfolder of {pipeline_config.model_path}." + ) + else: + logger.warning( + f"transformer_2 declared in model_index.json but subfolder " + f"'{transformer_2_subfolder}' not found in {pipeline_config.model_path}. Skipping." 
+ ) + + # Load scheduler - auto-detect scheduler class from config, matching diffusers behavior + scheduler_config_path = os.path.join(pipeline_config.model_path, "scheduler", SCHEDULER_CONFIG_NAME) + scheduler_cls = FlowMatchEulerDiscreteScheduler # default fallback + if os.path.exists(scheduler_config_path): + with open(scheduler_config_path, "r") as f: + scheduler_config_dict = json.load(f) + class_name = scheduler_config_dict.get("_class_name", None) + if class_name is not None: + try: + from diffusers import schedulers as schedulers_module + + scheduler_cls = getattr(schedulers_module, class_name) + logger.info(f"Using scheduler class from config: {class_name}") + except AttributeError: + logger.warning( + f"Scheduler class '{class_name}' not found in diffusers.schedulers, " + f"falling back to FlowMatchEulerDiscreteScheduler" + ) + scheduler = scheduler_cls.from_pretrained( + pipeline_config.model_path, + subfolder="scheduler", + ) + + # Load VAE + vae = cls.init_vae(pipeline_config) + + # Load text encoder + text_encoder = cls.init_text_encoder(pipeline_config) + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained( + pipeline_config.model_path, + subfolder="tokenizer", + ) + + return cls( + pipeline_config=pipeline_config, + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + transformer_2=transformer_2, + scheduler=scheduler, + boundary_ratio=boundary_ratio, + expand_timesteps=expand_timesteps, + ) + + @staticmethod + def init_transformer( + pipeline_config: WanPipelineConfig, empty_weights: bool = False, subfolder: str = "transformer" + ): + logger.info(f"Initializing transformer from subfolder={subfolder}...") + with set_forward_context(attn_type=pipeline_config.attn_type): + if empty_weights: + with init_empty_weights(): + config_dict = WanTransformer3DModel.load_config( + pipeline_config.model_path, + subfolder=subfolder, + local_files_only=True, + ) + model = 
WanTransformer3DModel.from_config(config_dict) + else: + model = WanTransformer3DModel.from_pretrained( + pipeline_config.model_path, + subfolder=subfolder, + device=pipeline_config.device, + dtype=pipeline_config.model_dtype, + ) + return model + + @staticmethod + def init_text_encoder(pipeline_config: WanPipelineConfig, empty_weights: bool = False): + logger.info("Initializing text encoder...") + if empty_weights: + with init_empty_weights(): + model = UMT5EncoderModel.from_pretrained( + pipeline_config.model_path, + subfolder="text_encoder", + local_files_only=True, + ) + return model + + state_dict = load_model_weights( + pipeline_config.model_path, + subfolder="text_encoder", + device=pipeline_config.device, + dtype=pipeline_config.text_encoder_dtype, + ) + with init_empty_weights(): + model = UMT5EncoderModel.from_pretrained( + pipeline_config.model_path, + subfolder="text_encoder", + local_files_only=True, + ) + + if "shared.weight" in state_dict and "encoder.embed_tokens.weight" not in state_dict: + state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"] + + model.load_state_dict(state_dict, strict=False, assign=True) + model.to(device=pipeline_config.device) + return model + + @staticmethod + def init_vae(pipeline_config: WanPipelineConfig, empty_weights: bool = False): + logger.info("Initializing VAE...") + if empty_weights: + with init_empty_weights(): + config_dict = AutoencoderKLWan.load_config( + pipeline_config.model_path, + subfolder="vae", + local_files_only=True, + ) + model = AutoencoderKLWan.from_config(config_dict) + return model + + model = AutoencoderKLWan.from_pretrained( + pipeline_config.model_path, + subfolder="vae", + device=pipeline_config.device, + dtype=pipeline_config.vae_dtype, + ) + return model + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = 
None,
+    ):
+        device = device or self.device
+        dtype = dtype or self.pipeline_config.text_encoder_dtype
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        prompt = [prompt_clean(u) for u in prompt]
+        batch_size = len(prompt)
+
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            add_special_tokens=True,
+            return_attention_mask=True,
+            return_tensors="pt",
+        )
+        text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+        seq_lens = mask.gt(0).sum(dim=1).long()
+
+        prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+        prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+        prompt_embeds = torch.stack(
+            [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+        )
+
+        # duplicate text embeddings for each generation per prompt, using an MPS-friendly method
+        _, seq_len, _ = prompt_embeds.shape
+        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+        return prompt_embeds
+
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        do_classifier_free_guidance: bool = True,
+        num_videos_per_prompt: int = 1,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        max_sequence_length: int = 226,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the video generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. + max_sequence_length (`int`, *optional*, defaults to 226): + Maximum sequence length for the text encoder. + device (`torch.device`, *optional*): + torch device + dtype (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self.device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + guidance_scale_2=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " + f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if self.boundary_ratio is None and guidance_scale_2 is not None: + raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.") + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + def _build_attn_metadata(self, attn_params): + if attn_params is None: + return None + + builder_cls = self.attn_backend.get_builder_cls() + builder = builder_cls() + attn_params_dict = attn_params.to_dict() + attn_metadata = builder.build(**attn_params_dict) + return attn_metadata + + def _predict_noise_with_cfg( + self, + latents: torch.Tensor, + timestep: torch.Tensor, + prompt_embeds: torch.Tensor, + negative_prompt_embeds: torch.Tensor, + attn_metadata, + apply_cfg: bool, + guidance_scale: float, + use_cfg_parallel: bool, + batch_size: int, + model: Optional[WanTransformer3DModel] = None, + ): + """ + Predict noise with optional classifier-free guidance and CFG parallelism. + + Args: + latents: Current noisy latents, shape (batch, channels, frames, height, width). + timestep: Current timestep tensor, shape (batch,). + prompt_embeds: Positive prompt embeddings tensor. + negative_prompt_embeds: Negative prompt embeddings tensor. + attn_metadata: Attention metadata for set_forward_context. + apply_cfg: Whether to apply classifier-free guidance this step. + guidance_scale: The CFG scale factor. + use_cfg_parallel: Whether to use CFG parallelism across devices. + batch_size: The actual batch size. + model: The transformer model to use. If None, defaults to self.transformer. + + Returns: + noise_pred: The predicted noise tensor. 
+ """ + if model is None: + model = self.transformer + + transformer_dtype = self.pipeline_config.model_dtype + + if not apply_cfg: + latent_model_input = latents.to(transformer_dtype) + with set_forward_context(attn_metadata=attn_metadata): + noise_pred = model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + return_dict=False, + )[0] + return noise_pred.float() + + # CFG mode + cfg_group, cfg_rank = None, None + if use_cfg_parallel: + if not model_parallel_is_initialized(): + raise RuntimeError("Model parallel groups must be initialized when use_cfg_parallel=True") + cfg_group = get_cfg_group() + cfg_rank = cfg_group.rank_in_group + + latent_model_input = latents.to(transformer_dtype) + + noise_pred_pos = torch.zeros_like(latents, dtype=torch.float32) + noise_pred_neg = torch.zeros_like(latents, dtype=torch.float32) + + # Positive prompt forward pass + if not (use_cfg_parallel and cfg_rank != 0): + with set_forward_context(attn_metadata=attn_metadata): + noise_pred_pos = model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + return_dict=False, + )[0].float() + + # Negative prompt forward pass + if not use_cfg_parallel or cfg_rank != 0: + with set_forward_context(attn_metadata=attn_metadata): + noise_pred_neg = model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + return_dict=False, + )[0].float() + + # All-reduce for CFG parallel + if use_cfg_parallel: + noise_pred_pos = cfg_group.all_reduce(noise_pred_pos) + noise_pred_neg = cfg_group.all_reduce(noise_pred_neg) + + # Apply CFG + noise_pred = noise_pred_neg + guidance_scale * (noise_pred_pos - noise_pred_neg) + return noise_pred + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, 
+ guidance_scale: float = 5.0, + guidance_scale_2: Optional[float] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to avoid during video generation. + height (`int`, defaults to `480`): + The height in pixels of the generated video. + width (`int`, defaults to `832`): + The width in pixels of the generated video. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale for classifier-free guidance. + guidance_scale_2 (`float`, *optional*, defaults to `None`): + Guidance scale for the low-noise stage when `boundary_ratio` is set. If `None` and + `boundary_ratio` is not None, uses the same value as `guidance_scale`. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + Random generator(s) for deterministic generation. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. 
+ negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a `WanPipelineOutput` instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + Kwargs passed to the attention processor. + callback_on_step_end (`Callable`, *optional*): + A function called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List`, *optional*): + Tensor inputs for the callback function. + max_sequence_length (`int`, defaults to `512`): + Maximum sequence length for the text encoder. + + Returns: + `WanPipelineOutput` or `tuple`: Generated video frames. + """ + # 1. Check inputs + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + guidance_scale_2, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. " + "Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + patch_size = ( + self.transformer.config.patch_size if self.transformer is not None else self.transformer_2.config.patch_size + ) + h_multiple_of = self.vae_scale_factor_spatial * patch_size[1] + w_multiple_of = self.vae_scale_factor_spatial * patch_size[2] + calc_height = height // h_multiple_of * h_multiple_of + calc_width = width // w_multiple_of * w_multiple_of + if height != calc_height or width != calc_width: + logger.warning( + f"`height` and `width` must be multiples of ({h_multiple_of}, {w_multiple_of}) for proper " + f"patchification. Adjusting ({height}, {width}) -> ({calc_height}, {calc_width})." 
+ ) + height, width = calc_height, calc_width + + if self.boundary_ratio is not None and guidance_scale_2 is None: + guidance_scale_2 = guidance_scale + + self._guidance_scale = guidance_scale + self._guidance_scale_2 = guidance_scale_2 + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self.device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.pipeline_config.model_dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = ( + self.transformer.config.in_channels + if self.transformer is not None + else self.transformer_2.config.in_channels + ) + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + mask = torch.ones(latents.shape, dtype=torch.float32, device=device) + + actual_batch_size = batch_size * num_videos_per_prompt + + # 6. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + if self.boundary_ratio is not None: + boundary_timestep = self.boundary_ratio * self.scheduler.config.num_train_timesteps + else: + boundary_timestep = None + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + # Determine current model and guidance scale based on boundary_ratio + if boundary_timestep is None or t >= boundary_timestep: + # wan2.1 or high-noise stage in wan2.2 + current_model = self.transformer + current_guidance_scale = guidance_scale + else: + # low-noise stage in wan2.2 + current_model = self.transformer_2 + current_guidance_scale = guidance_scale_2 + + if self.expand_timesteps: + # Wan2.2 timestep expansion: seq_len = num_latent_frames * latent_height//2 * latent_width//2 + temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten() + timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1) + else: + timestep = t.expand(latents.shape[0]) + + attn_metadata = self._build_attn_metadata(self.pipeline_config.attn_params) + + noise_pred = self._predict_noise_with_cfg( + latents=latents, + timestep=timestep, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + attn_metadata=attn_metadata, + apply_cfg=self.do_classifier_free_guidance, + guidance_scale=current_guidance_scale, + use_cfg_parallel=self.pipeline_config.use_cfg_parallel, + batch_size=actual_batch_size, + model=current_model, + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", 
latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.pipeline_config.vae_dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) diff --git a/diffsynth_engine/pipelines/wan/pipeline_wan_vace.py b/diffsynth_engine/pipelines/wan/pipeline_wan_vace.py new file mode 100644 index 0000000..967a3b4 --- /dev/null +++ b/diffsynth_engine/pipelines/wan/pipeline_wan_vace.py @@ -0,0 +1,1124 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan_vace.py + +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

import html
import json
import os
from typing import Any, Callable, Dict, List, Optional, Union

import PIL.Image
import regex as re
import torch
from accelerate import init_empty_weights
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from transformers import AutoTokenizer, UMT5EncoderModel

from diffsynth_engine.configs.wan import WanPipelineConfig
from diffsynth_engine.distributed.parallel_state import get_cfg_group, model_parallel_is_initialized
from diffsynth_engine.forward_context import set_forward_context
from diffsynth_engine.layers.attention import get_attn_backend
from diffsynth_engine.models.wan import AutoencoderKLWan, WanVACETransformer3DModel
from diffsynth_engine.pipelines.base import Pipeline
from diffsynth_engine.utils import logging
from diffsynth_engine.utils.load_utils import load_model_weights

logger = logging.get_logger(__name__)


def basic_clean(text):
    """Repair mojibake (when ftfy is available) and unescape doubly HTML-escaped text."""
    try:
        import ftfy

        text = ftfy.fix_text(text)
    except ImportError:
        # ftfy is an optional dependency; skip mojibake repair when it is missing.
        pass
    # Unescape twice to handle strings that were HTML-escaped two times.
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    """Collapse every run of whitespace to a single space and strip the ends."""
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


def prompt_clean(text):
    """Normalize a prompt: fix encoding, unescape HTML entities, collapse whitespace."""
    text = whitespace_clean(basic_clean(text))
    return text


def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Extract latents from a VAE encoder output.

    Supports outputs exposing a `latent_dist` (sampled or mode/argmax) or a plain
    `latents` attribute; raises `AttributeError` for anything else.
    """
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


class WanVACEPipeline(Pipeline):
    r"""
    Pipeline for controllable video generation using Wan VACE, adapted for DiffSynth-Engine.

    Args:
        pipeline_config (`WanPipelineConfig`):
            Configuration for the pipeline.
        tokenizer (`AutoTokenizer`):
            Tokenizer from T5.
        text_encoder (`UMT5EncoderModel`):
            T5 text encoder.
        vae (`AutoencoderKLWan`):
            VAE Model to encode and decode videos.
        scheduler (`FlowMatchEulerDiscreteScheduler`):
            Scheduler for denoising.
        transformer (`WanVACETransformer3DModel`, *optional*):
            Transformer for high-noise stage denoising.
        transformer_2 (`WanVACETransformer3DModel`, *optional*):
            Transformer for low-noise stage denoising.
        boundary_ratio (`float`, *optional*):
            Ratio for switching between transformers in two-stage denoising.
    """

    # Tensor names a step-end callback is allowed to read/override.
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
        self,
        pipeline_config: WanPipelineConfig,
        tokenizer: AutoTokenizer,
        text_encoder: UMT5EncoderModel,
        vae: AutoencoderKLWan,
        scheduler: FlowMatchEulerDiscreteScheduler,
        transformer: Optional[WanVACETransformer3DModel] = None,
        transformer_2: Optional[WanVACETransformer3DModel] = None,
        boundary_ratio: Optional[float] = None,
    ):
        super().__init__(pipeline_config)

        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.vae = vae
        self.transformer = transformer
        self.transformer_2 = transformer_2
        self.scheduler = scheduler
        self.boundary_ratio = boundary_ratio

        # Fall back to the Wan defaults (4x temporal, 8x spatial) when no VAE is given.
        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if self.vae is not None else 4
        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if self.vae is not None else 8
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

        # Either transformer may be absent; size the attention backend from the one present.
        active_transformer = transformer if transformer is not None else transformer_2
        head_dim = active_transformer.config.attention_head_dim
        self.attn_backend = get_attn_backend(
            head_size=head_dim,
            attn_type=pipeline_config.attn_type,
        )

    @classmethod
    def from_pretrained(cls, model_path_or_config: str | WanPipelineConfig):
        """Build the pipeline from a local diffusers-style model directory.

        Accepts either a filesystem path or a ready `WanPipelineConfig`. Reads
        `model_index.json` (when present) for `boundary_ratio` and an optional
        `transformer_2`, and resolves the scheduler class from its config,
        falling back to `FlowMatchEulerDiscreteScheduler`.
        """
        if isinstance(model_path_or_config, str):
            pipeline_config = WanPipelineConfig(model_path=model_path_or_config)
        else:
            pipeline_config = model_path_or_config

        if not os.path.exists(pipeline_config.model_path):
            raise FileNotFoundError(f"Model path not found: {pipeline_config.model_path}")

        model_index_path = os.path.join(pipeline_config.model_path, "model_index.json")
        model_index = {}
        boundary_ratio = None
        if os.path.exists(model_index_path):
            with open(model_index_path, "r") as f:
                model_index = json.load(f)
            boundary_ratio = model_index.get("boundary_ratio", None)
            if boundary_ratio is not None:
                logger.info(f"Loaded boundary_ratio={boundary_ratio} from model_index.json")

        transformer = cls.init_transformer(pipeline_config)

        # Optional low-noise-stage transformer, only when declared AND present on disk.
        transformer_2 = None
        if "transformer_2" in model_index and model_index["transformer_2"] is not None:
            transformer_2_subfolder = "transformer_2"
            if os.path.isdir(os.path.join(pipeline_config.model_path, transformer_2_subfolder)):
                transformer_2 = cls.init_transformer(pipeline_config, subfolder=transformer_2_subfolder)
                logger.info(
                    f"Loaded transformer_2 from `{transformer_2_subfolder}` subfolder of {pipeline_config.model_path}."
                )
            else:
                logger.warning(
                    f"transformer_2 declared in model_index.json but subfolder "
                    f"'{transformer_2_subfolder}' not found in {pipeline_config.model_path}. Skipping."
                )

        scheduler_config_path = os.path.join(pipeline_config.model_path, "scheduler", SCHEDULER_CONFIG_NAME)
        scheduler_cls = FlowMatchEulerDiscreteScheduler
        if os.path.exists(scheduler_config_path):
            with open(scheduler_config_path, "r") as f:
                scheduler_config_dict = json.load(f)
            class_name = scheduler_config_dict.get("_class_name", None)
            if class_name is not None:
                try:
                    from diffusers import schedulers as schedulers_module

                    scheduler_cls = getattr(schedulers_module, class_name)
                    logger.info(f"Using scheduler class from config: {class_name}")
                except AttributeError:
                    logger.warning(
                        f"Scheduler class '{class_name}' not found in diffusers.schedulers, "
                        f"falling back to FlowMatchEulerDiscreteScheduler"
                    )
        scheduler = scheduler_cls.from_pretrained(pipeline_config.model_path, subfolder="scheduler")

        vae = cls.init_vae(pipeline_config)
        text_encoder = cls.init_text_encoder(pipeline_config)
        tokenizer = AutoTokenizer.from_pretrained(pipeline_config.model_path, subfolder="tokenizer")

        return cls(
            pipeline_config=pipeline_config,
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            vae=vae,
            transformer=transformer,
            transformer_2=transformer_2,
            scheduler=scheduler,
            boundary_ratio=boundary_ratio,
        )

    @staticmethod
    def init_transformer(
        pipeline_config: WanPipelineConfig, empty_weights: bool = False, subfolder: str = "transformer"
    ):
        """Instantiate the VACE transformer, optionally with empty (meta) weights."""
        logger.info(f"Initializing VACE transformer from subfolder={subfolder}...")
        with set_forward_context(attn_type=pipeline_config.attn_type):
            if empty_weights:
                with init_empty_weights():
                    config_dict = WanVACETransformer3DModel.load_config(
                        pipeline_config.model_path,
                        subfolder=subfolder,
                        local_files_only=True,
                    )
                    model = WanVACETransformer3DModel.from_config(config_dict)
            else:
                model = WanVACETransformer3DModel.from_pretrained(
                    pipeline_config.model_path,
                    subfolder=subfolder,
                    device=pipeline_config.device,
                    dtype=pipeline_config.model_dtype,
                )
        return model
@staticmethod + def init_text_encoder(pipeline_config: WanPipelineConfig, empty_weights: bool = False): + logger.info("Initializing text encoder...") + if empty_weights: + with init_empty_weights(): + model = UMT5EncoderModel.from_pretrained( + pipeline_config.model_path, + subfolder="text_encoder", + local_files_only=True, + ) + return model + + state_dict = load_model_weights( + pipeline_config.model_path, + subfolder="text_encoder", + device=pipeline_config.device, + dtype=pipeline_config.text_encoder_dtype, + ) + with init_empty_weights(): + model = UMT5EncoderModel.from_pretrained( + pipeline_config.model_path, + subfolder="text_encoder", + local_files_only=True, + ) + + if "shared.weight" in state_dict and "encoder.embed_tokens.weight" not in state_dict: + state_dict["encoder.embed_tokens.weight"] = state_dict["shared.weight"] + + model.load_state_dict(state_dict, strict=False, assign=True) + model.to(device=pipeline_config.device) + return model + + @staticmethod + def init_vae(pipeline_config: WanPipelineConfig, empty_weights: bool = False): + logger.info("Initializing VAE...") + if empty_weights: + with init_empty_weights(): + config_dict = AutoencoderKLWan.load_config( + pipeline_config.model_path, + subfolder="vae", + local_files_only=True, + ) + model = AutoencoderKLWan.from_config(config_dict) + return model + + model = AutoencoderKLWan.from_pretrained( + pipeline_config.model_path, + subfolder="vae", + device=pipeline_config.device, + dtype=pipeline_config.vae_dtype, + ) + return model + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self.device + dtype = dtype or self.pipeline_config.text_encoder_dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(u) for u in prompt] + batch_size = len(prompt) + + text_inputs = 
self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask + seq_lens = mask.gt(0).sum(dim=1).long() + + prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self.device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is 
not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + video=None, + mask=None, + reference_images=None, + guidance_scale_2=None, + ): + if self.transformer is not None: + base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1] + elif self.transformer_2 is not None: + base = self.vae_scale_factor_spatial * self.transformer_2.config.patch_size[1] + else: + raise ValueError( + "`transformer` or `transformer_2` component must be set in order to run inference with this pipeline" + ) + + if height % base != 0 or width % base != 0: + raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found " + f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if self.boundary_ratio is None and guidance_scale_2 is not None: + raise ValueError("`guidance_scale_2` is only supported when 
the pipeline's `boundary_ratio` is not None.") + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: " + f"{negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + if video is not None: + if mask is not None: + if len(video) != len(mask): + raise ValueError( + f"Length of `video` {len(video)} and `mask` {len(mask)} do not match. Please make sure that" + " they have the same length." 
+ ) + if reference_images is not None: + is_pil_image = isinstance(reference_images, PIL.Image.Image) + is_list_of_pil_images = isinstance(reference_images, list) and all( + isinstance(ref_img, PIL.Image.Image) for ref_img in reference_images + ) + is_list_of_list_of_pil_images = isinstance(reference_images, list) and all( + isinstance(ref_img, list) and all(isinstance(r, PIL.Image.Image) for r in ref_img) + for ref_img in reference_images + ) + if not (is_pil_image or is_list_of_pil_images or is_list_of_list_of_pil_images): + raise ValueError( + "`reference_images` has to be of type `PIL.Image.Image` or `list` of `PIL.Image.Image`, or " + f"`list` of `list` of `PIL.Image.Image`, but is {type(reference_images)}" + ) + if is_list_of_list_of_pil_images and len(reference_images) != 1: + raise ValueError( + "The pipeline only supports generating one video at a time. When passing a list " + "of list of reference images, please make sure to only pass one inner list." + ) + elif mask is not None: + raise ValueError("`mask` can only be passed if `video` is passed as well.") + + def preprocess_conditions( + self, + video: Optional[List] = None, + mask: Optional[List] = None, + reference_images: Optional[Union[PIL.Image.Image, List[PIL.Image.Image], List[List[PIL.Image.Image]]]] = None, + batch_size: int = 1, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + if video is not None: + base = self.vae_scale_factor_spatial * ( + self.transformer.config.patch_size[1] + if self.transformer is not None + else self.transformer_2.config.patch_size[1] + ) + video_height, video_width = self.video_processor.get_default_height_width(video[0]) + + if video_height * video_width > height * width: + scale = min(width / video_width, height / video_height) + video_height, video_width = int(video_height * scale), int(video_width * scale) + + if video_height % base != 0 or video_width % base != 
0: + logger.warning( + f"Video height and width should be divisible by {base}, but got {video_height} and {video_width}." + ) + video_height = (video_height // base) * base + video_width = (video_width // base) * base + + assert video_height * video_width <= height * width + + video = self.video_processor.preprocess_video(video, video_height, video_width) + image_size = (video_height, video_width) + else: + video = torch.zeros(batch_size, 3, num_frames, height, width, dtype=dtype, device=device) + image_size = (height, width) + + if mask is not None: + mask = self.video_processor.preprocess_video(mask, image_size[0], image_size[1]) + mask = torch.clamp((mask + 1) / 2, min=0, max=1) + else: + mask = torch.ones_like(video) + + video = video.to(dtype=dtype, device=device) + mask = mask.to(dtype=dtype, device=device) + + # Normalize reference_images to list of list format + if reference_images is None or isinstance(reference_images, PIL.Image.Image): + reference_images = [[reference_images] for _ in range(video.shape[0])] + elif isinstance(reference_images, (list, tuple)) and isinstance(next(iter(reference_images)), PIL.Image.Image): + reference_images = [reference_images] + elif ( + isinstance(reference_images, (list, tuple)) + and isinstance(next(iter(reference_images)), list) + and isinstance(next(iter(reference_images[0])), PIL.Image.Image) + ): + reference_images = reference_images + else: + raise ValueError( + "`reference_images` has to be of type `PIL.Image.Image` or `list` of `PIL.Image.Image`, or " + f"`list` of `list` of `PIL.Image.Image`, but is {type(reference_images)}" + ) + + if video.shape[0] != len(reference_images): + raise ValueError( + f"Batch size of `video` {video.shape[0]} and length of `reference_images` " + f"{len(reference_images)} does not match." 
+ ) + + ref_images_lengths = [len(batch) for batch in reference_images] + if any(length != ref_images_lengths[0] for length in ref_images_lengths): + raise ValueError( + f"All batches of `reference_images` should have the same length, but got {ref_images_lengths}." + ) + + reference_images_preprocessed = [] + for reference_images_batch in reference_images: + preprocessed_images = [] + for image in reference_images_batch: + if image is None: + continue + image = self.video_processor.preprocess(image, None, None) + img_height, img_width = image.shape[-2:] + scale = min(image_size[0] / img_height, image_size[1] / img_width) + new_height, new_width = int(img_height * scale), int(img_width * scale) + resized_image = torch.nn.functional.interpolate( + image, size=(new_height, new_width), mode="bilinear", align_corners=False + ).squeeze(0) + top = (image_size[0] - new_height) // 2 + left = (image_size[1] - new_width) // 2 + canvas = torch.ones(3, *image_size, device=device, dtype=dtype) + canvas[:, top : top + new_height, left : left + new_width] = resized_image + preprocessed_images.append(canvas) + reference_images_preprocessed.append(preprocessed_images) + + return video, mask, reference_images_preprocessed + + def prepare_video_latents( + self, + video: torch.Tensor, + mask: torch.Tensor, + reference_images: Optional[List[List[torch.Tensor]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + device = device or self.device + + if isinstance(generator, list): + raise ValueError("Passing a list of generators is not yet supported.") + + if reference_images is None: + reference_images = [[None] for _ in range(video.shape[0])] + else: + if video.shape[0] != len(reference_images): + raise ValueError( + f"Batch size of `video` {video.shape[0]} and length of `reference_images` " + f"{len(reference_images)} does not match." 
+ ) + + if video.shape[0] != 1: + raise ValueError("Generating with more than one video is not yet supported.") + + vae_dtype = self.pipeline_config.vae_dtype + video = video.to(dtype=vae_dtype) + + latents_mean = torch.tensor(self.vae.config.latents_mean, device=device, dtype=torch.float32).view( + 1, self.vae.config.z_dim, 1, 1, 1 + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=device, dtype=torch.float32).view( + 1, self.vae.config.z_dim, 1, 1, 1 + ) + + if mask is None: + latents = retrieve_latents(self.vae.encode(video), generator, sample_mode="argmax").unbind(0) + latents = ((latents.float() - latents_mean) * latents_std).to(vae_dtype) + else: + mask = torch.where(mask > 0.5, 1.0, 0.0).to(dtype=vae_dtype) + inactive = video * (1 - mask) + reactive = video * mask + inactive = retrieve_latents(self.vae.encode(inactive), generator, sample_mode="argmax") + reactive = retrieve_latents(self.vae.encode(reactive), generator, sample_mode="argmax") + inactive = ((inactive.float() - latents_mean) * latents_std).to(vae_dtype) + reactive = ((reactive.float() - latents_mean) * latents_std).to(vae_dtype) + latents = torch.cat([inactive, reactive], dim=1) + + latent_list = [] + for latent, reference_images_batch in zip(latents, reference_images): + for reference_image in reference_images_batch: + assert reference_image.ndim == 3 + reference_image = reference_image.to(dtype=vae_dtype) + reference_image = reference_image[None, :, None, :, :] + reference_latent = retrieve_latents(self.vae.encode(reference_image), generator, sample_mode="argmax") + reference_latent = ((reference_latent.float() - latents_mean) * latents_std).to(vae_dtype) + reference_latent = reference_latent.squeeze(0) + reference_latent = torch.cat([reference_latent, torch.zeros_like(reference_latent)], dim=0) + latent = torch.cat([reference_latent.squeeze(0), latent], dim=1) + latent_list.append(latent) + return torch.stack(latent_list) + + def prepare_masks( + self, + mask: 
torch.Tensor, + reference_images: Optional[List[List[torch.Tensor]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + ) -> torch.Tensor: + if isinstance(generator, list): + raise ValueError("Passing a list of generators is not yet supported.") + + if reference_images is None: + reference_images = [[None] for _ in range(mask.shape[0])] + else: + if mask.shape[0] != len(reference_images): + raise ValueError( + f"Batch size of `mask` {mask.shape[0]} and length of `reference_images` " + f"{len(reference_images)} does not match." + ) + + if mask.shape[0] != 1: + raise ValueError("Generating with more than one video is not yet supported.") + + transformer_patch_size = ( + self.transformer.config.patch_size[1] + if self.transformer is not None + else self.transformer_2.config.patch_size[1] + ) + + mask_list = [] + for mask_, reference_images_batch in zip(mask, reference_images): + num_channels, num_frames, height, width = mask_.shape + new_num_frames = (num_frames + self.vae_scale_factor_temporal - 1) // self.vae_scale_factor_temporal + new_height = height // (self.vae_scale_factor_spatial * transformer_patch_size) * transformer_patch_size + new_width = width // (self.vae_scale_factor_spatial * transformer_patch_size) * transformer_patch_size + mask_ = mask_[0, :, :, :] + mask_ = mask_.view( + num_frames, new_height, self.vae_scale_factor_spatial, new_width, self.vae_scale_factor_spatial + ) + mask_ = mask_.permute(2, 4, 0, 1, 3).flatten(0, 1) + mask_ = torch.nn.functional.interpolate( + mask_.unsqueeze(0), size=(new_num_frames, new_height, new_width), mode="nearest-exact" + ).squeeze(0) + num_ref_images = len(reference_images_batch) + if num_ref_images > 0: + mask_padding = torch.zeros_like(mask_[:, :num_ref_images, :, :]) + mask_ = torch.cat([mask_padding, mask_], dim=1) + mask_list.append(mask_) + return torch.stack(mask_list) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, 
+ width: int = 832, + num_frames: int = 81, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_channels_latents, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @property + def attention_kwargs(self): + return self._attention_kwargs + + def _build_attn_metadata(self, attn_params): + if attn_params is None: + return None + + builder_cls = self.attn_backend.get_builder_cls() + builder = builder_cls() + attn_params_dict = attn_params.to_dict() + attn_metadata = builder.build(**attn_params_dict) + return attn_metadata + + def _predict_noise_with_cfg( + self, + latent_model_input: torch.Tensor, + timestep: torch.Tensor, + prompt_embeds: torch.Tensor, + negative_prompt_embeds: torch.Tensor, + control_hidden_states: torch.Tensor, + control_hidden_states_scale: torch.Tensor, + attn_metadata, + apply_cfg: bool, 
+ guidance_scale: float, + use_cfg_parallel: bool, + batch_size: int, + model: Optional[WanVACETransformer3DModel] = None, + ): + """ + Predict noise with optional classifier-free guidance and CFG parallelism. + + Args: + latent_model_input: The model input latents. + timestep: Current timestep tensor. + prompt_embeds: Positive prompt embeddings tensor. + negative_prompt_embeds: Negative prompt embeddings tensor. + control_hidden_states: VACE conditioning latents. + control_hidden_states_scale: Per-layer scale for VACE conditioning. + attn_metadata: Attention metadata for set_forward_context. + apply_cfg: Whether to apply classifier-free guidance this step. + guidance_scale: The CFG scale factor. + use_cfg_parallel: Whether to use CFG parallelism across devices. + batch_size: The actual batch size. + model: The transformer model to use. If None, defaults to self.transformer. + + Returns: + noise_pred: The predicted noise tensor. + """ + if model is None: + model = self.transformer + + if not apply_cfg: + with set_forward_context(attn_metadata=attn_metadata): + noise_pred = model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + control_hidden_states=control_hidden_states, + control_hidden_states_scale=control_hidden_states_scale, + return_dict=False, + )[0] + return noise_pred.float() + + # CFG mode + cfg_group, cfg_rank = None, None + if use_cfg_parallel: + if not model_parallel_is_initialized(): + raise RuntimeError("Model parallel groups must be initialized when use_cfg_parallel=True") + cfg_group = get_cfg_group() + cfg_rank = cfg_group.rank_in_group + + noise_pred_pos = torch.zeros_like(latent_model_input, dtype=torch.float32) + noise_pred_neg = torch.zeros_like(latent_model_input, dtype=torch.float32) + + # Positive prompt forward pass + if not (use_cfg_parallel and cfg_rank != 0): + with set_forward_context(attn_metadata=attn_metadata): + noise_pred_pos = model( + hidden_states=latent_model_input, + 
timestep=timestep, + encoder_hidden_states=prompt_embeds, + control_hidden_states=control_hidden_states, + control_hidden_states_scale=control_hidden_states_scale, + return_dict=False, + )[0].float() + + # Negative prompt forward pass + if not use_cfg_parallel or cfg_rank != 0: + with set_forward_context(attn_metadata=attn_metadata): + noise_pred_neg = model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + control_hidden_states=control_hidden_states, + control_hidden_states_scale=control_hidden_states_scale, + return_dict=False, + )[0].float() + + # All-reduce for CFG parallel + if use_cfg_parallel: + noise_pred_pos = cfg_group.all_reduce(noise_pred_pos) + noise_pred_neg = cfg_group.all_reduce(noise_pred_neg) + + # Apply CFG + noise_pred = noise_pred_neg + guidance_scale * (noise_pred_pos - noise_pred_neg) + return noise_pred + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + video: Optional[List] = None, + mask: Optional[List] = None, + reference_images: Optional[List] = None, + conditioning_scale: Union[float, List[float], torch.Tensor] = 1.0, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + guidance_scale_2: Optional[float] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Union[Callable[[int, int, Dict], None]]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for 
generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. + video (`List[PIL.Image.Image]`, *optional*): + The input video frames for conditioning. + mask (`List[PIL.Image.Image]`, *optional*): + The input mask defining conditioning vs generation regions. + reference_images (`List[PIL.Image.Image]`, *optional*): + Reference images for extra conditioning. + conditioning_scale (`float`, `List[float]`, `torch.Tensor`, defaults to `1.0`): + The conditioning scale for VACE control layers. + height (`int`, defaults to `480`): + The height in pixels of the generated video. + width (`int`, defaults to `832`): + The width in pixels of the generated video. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale for classifier-free guidance. + guidance_scale_2 (`float`, *optional*): + Guidance scale for the low-noise stage transformer. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + Random generator(s) for deterministic generation. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. + output_type (`str`, *optional*, defaults to `"np"`): + The output format of the generated video. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a `WanPipelineOutput` instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + Kwargs passed to the attention processor. 
+ callback_on_step_end (`Callable`, *optional*): + A function called at the end of each denoising step. + callback_on_step_end_tensor_inputs (`List`, *optional*): + Tensor inputs for the callback function. + max_sequence_length (`int`, defaults to `512`): + Maximum sequence length for the text encoder. + + Returns: + `WanPipelineOutput` or `tuple`: Generated video frames. + """ + # 1. Check inputs + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + video, + mask, + reference_images, + guidance_scale_2, + ) + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. " + "Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + if self.boundary_ratio is not None and guidance_scale_2 is None: + guidance_scale_2 = guidance_scale + + self._guidance_scale = guidance_scale + self._guidance_scale_2 = guidance_scale_2 + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self.device + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + transformer_dtype = self.pipeline_config.model_dtype + + vace_layers = ( + self.transformer.config.vace_layers + if self.transformer is not None + else self.transformer_2.config.vace_layers + ) + if isinstance(conditioning_scale, (int, float)): + conditioning_scale = [conditioning_scale] * len(vace_layers) + if isinstance(conditioning_scale, list): + if len(conditioning_scale) != len(vace_layers): + raise ValueError( + f"Length of `conditioning_scale` {len(conditioning_scale)} does not match " + f"number of layers {len(vace_layers)}." + ) + conditioning_scale = torch.tensor(conditioning_scale) + if isinstance(conditioning_scale, torch.Tensor): + if conditioning_scale.size(0) != len(vace_layers): + raise ValueError( + f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match " + f"number of layers {len(vace_layers)}." + ) + conditioning_scale = conditioning_scale.to(device=device, dtype=transformer_dtype) + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. 
Prepare latent variables + video, mask, reference_images = self.preprocess_conditions( + video, + mask, + reference_images, + batch_size, + height, + width, + num_frames, + torch.float32, + device, + ) + num_reference_images = len(reference_images[0]) + + conditioning_latents = self.prepare_video_latents(video, mask, reference_images, generator, device) + mask = self.prepare_masks(mask, reference_images, generator) + conditioning_latents = torch.cat([conditioning_latents, mask], dim=1) + conditioning_latents = conditioning_latents.to(transformer_dtype) + + num_channels_latents = ( + self.transformer.config.in_channels + if self.transformer is not None + else self.transformer_2.config.in_channels + ) + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames + num_reference_images * self.vae_scale_factor_temporal, + torch.float32, + device, + generator, + latents, + ) + + if conditioning_latents.shape[2] != latents.shape[2]: + logger.warning( + "The number of frames in the conditioning latents does not match the number of frames " + "to be generated. Generation quality may be affected." + ) + + # 6. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + if self.boundary_ratio is not None: + boundary_timestep = self.boundary_ratio * self.scheduler.config.num_train_timesteps + else: + boundary_timestep = None + + actual_batch_size = batch_size * num_videos_per_prompt + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + + if boundary_timestep is None or t >= boundary_timestep: + current_model = self.transformer + current_guidance_scale = guidance_scale + else: + current_model = self.transformer_2 + current_guidance_scale = guidance_scale_2 + + latent_model_input = latents.to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + attn_metadata = self._build_attn_metadata(self.pipeline_config.attn_params) + + noise_pred = self._predict_noise_with_cfg( + latent_model_input=latent_model_input, + timestep=timestep, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + control_hidden_states=conditioning_latents, + control_hidden_states_scale=conditioning_scale, + attn_metadata=attn_metadata, + apply_cfg=self.do_classifier_free_guidance, + guidance_scale=current_guidance_scale, + use_cfg_parallel=self.pipeline_config.use_cfg_parallel, + batch_size=actual_batch_size, + model=current_model, + ) + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 
or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + self._current_timestep = None + + if not output_type == "latent": + latents = latents[:, :, num_reference_images:] + latents = latents.to(self.pipeline_config.vae_dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) diff --git a/diffsynth_engine/utils/video.py b/diffsynth_engine/utils/video.py index ce00a36..5eceebf 100644 --- a/diffsynth_engine/utils/video.py +++ b/diffsynth_engine/utils/video.py @@ -35,11 +35,15 @@ def save_video(frames, save_path, fps=15): elif save_path.endswith(".mp4"): codec = "libx264" - frames = [np.array(img) for img in frames] + converted_frames = [] + for img in frames: + arr = np.array(img) + if arr.dtype != np.uint8: + arr = np.clip(arr * 255.0, 0, 255).astype(np.uint8) + converted_frames.append(arr) - # 使用 imageio 写入 .webm 文件 with iio.imopen(save_path, "w", plugin="FFMPEG") as writer: - writer.write(frames, fps=fps, codec=codec) + writer.write(converted_frames, fps=fps, codec=codec) def read_n_frames( diff --git a/examples/input/wan_22_animate_face.mp4 b/examples/input/wan_22_animate_face.mp4 new file mode 100644 index 0000000..622d6df Binary files /dev/null and b/examples/input/wan_22_animate_face.mp4 differ diff --git a/examples/input/wan_22_animate_input.png b/examples/input/wan_22_animate_input.png new file mode 100644 index 0000000..a1b85f7 Binary files /dev/null and 
b/examples/input/wan_22_animate_input.png differ diff --git a/examples/input/wan_22_animate_pose.mp4 b/examples/input/wan_22_animate_pose.mp4 new file mode 100644 index 0000000..af36f5e Binary files /dev/null and b/examples/input/wan_22_animate_pose.mp4 differ diff --git a/examples/input/wan_22_i2v_input.png b/examples/input/wan_22_i2v_input.png new file mode 100644 index 0000000..6a558eb Binary files /dev/null and b/examples/input/wan_22_i2v_input.png differ diff --git a/examples/input/wan_vace_first_frame.png b/examples/input/wan_vace_first_frame.png new file mode 100644 index 0000000..032cd5c Binary files /dev/null and b/examples/input/wan_vace_first_frame.png differ diff --git a/examples/input/wan_vace_last_frame.png b/examples/input/wan_vace_last_frame.png new file mode 100644 index 0000000..83ac8c5 Binary files /dev/null and b/examples/input/wan_vace_last_frame.png differ diff --git a/examples/wan/wan_22_animate.py b/examples/wan/wan_22_animate.py new file mode 100644 index 0000000..c41fc5d --- /dev/null +++ b/examples/wan/wan_22_animate.py @@ -0,0 +1,60 @@ +import torch +from diffusers.utils import export_to_video, load_video +from PIL import Image + +from diffsynth_engine.pipelines.wan import WanAnimatePipeline +from diffsynth_engine.utils.download import fetch_model + +if __name__ == "__main__": + model_path = fetch_model("Wan-AI/Wan2.2-Animate-14B-Diffusers") + pipe = WanAnimatePipeline.from_pretrained(model_path) + + # Load the reference character image + image = Image.open("examples/input/wan_22_animate_input.png") + + # Load pose and face conditioning videos (preprocessed from a reference video) + pose_video = load_video("examples/input/wan_22_animate_pose.mp4") + face_video = load_video("examples/input/wan_22_animate_face.mp4") + + prompt = "People in the video are doing actions." 
+ + # ---- Animate mode ---- + video = pipe( + image=image, + pose_video=pose_video, + face_video=face_video, + prompt=prompt, + mode="animate", + segment_frame_length=77, + prev_segment_conditioning_frames=1, + guidance_scale=1.0, + num_inference_steps=20, + generator=torch.Generator(device="cpu").manual_seed(42), + ) + + export_to_video(video.frames[0], "animated_output.mp4", fps=30) + + # ---- Replace mode (optional) ---- + # In replace mode, an additional background_video and mask_video are required. + # background_video: the original video whose character will be replaced. + # mask_video: grayscale masks indicating the region to replace (white = replace). + # + # background_video = load_video("examples/input/wan_22_animate_background.mp4") + # mask_video = load_video("examples/input/wan_22_animate_mask.mp4") + # + # video_replace = pipe( + # image=image, + # pose_video=pose_video, + # face_video=face_video, + # background_video=background_video, + # mask_video=mask_video, + # prompt=prompt, + # mode="replace", + # segment_frame_length=77, + # prev_segment_conditioning_frames=1, + # guidance_scale=1.0, + # num_inference_steps=20, + # generator=torch.Generator(device="cpu").manual_seed(42), + # ) + # + # export_to_video(video_replace.frames[0], "animated_output_replace.mp4", fps=30) diff --git a/examples/wan/wan_22_image_to_video.py b/examples/wan/wan_22_image_to_video.py new file mode 100644 index 0000000..ac50229 --- /dev/null +++ b/examples/wan/wan_22_image_to_video.py @@ -0,0 +1,36 @@ +import numpy as np +import torch +from diffusers.utils import export_to_video +from PIL import Image + +from diffsynth_engine.pipelines.wan import WanImageToVideoPipeline +from diffsynth_engine.utils.download import fetch_model + +if __name__ == "__main__": + model_path = fetch_model("Wan-AI/Wan2.2-I2V-A14B-Diffusers") + pipe = WanImageToVideoPipeline.from_pretrained(model_path) + + image = Image.open("examples/input/wan_22_i2v_input.png") + max_area = 480 * 832 + aspect_ratio 
= image.height / image.width + mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + image = image.resize((width, height)) + + prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." + negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + + video = pipe( + image=image, + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=81, + guidance_scale=3.5, + num_inference_steps=40, + generator=torch.Generator(device="cpu").manual_seed(42), + ) + + export_to_video(video.frames[0], "wan_22_i2v.mp4", fps=16) diff --git a/examples/wan/wan_22_text_to_video.py b/examples/wan/wan_22_text_to_video.py new file mode 100644 index 0000000..9e62928 --- /dev/null +++ b/examples/wan/wan_22_text_to_video.py @@ -0,0 +1,23 @@ +import torch +from diffusers.utils import export_to_video + +from diffsynth_engine.pipelines.wan import WanTextToVideoPipeline +from diffsynth_engine.utils.download import fetch_model + +if __name__ == "__main__": + model_path = fetch_model("Wan-AI/Wan2.2-T2V-A14B-Diffusers") + pipe = WanTextToVideoPipeline.from_pretrained(model_path) + + video = pipe( + prompt="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + 
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + num_frames=81, + width=1280, + height=720, + guidance_scale=4.0, + guidance_scale_2=3.0, + num_inference_steps=40, + generator=torch.Generator(device="cpu").manual_seed(42), + ) + + export_to_video(video.frames[0], "wan_22_t2v.mp4", fps=16) diff --git a/examples/wan/wan_vace.py b/examples/wan/wan_vace.py new file mode 100644 index 0000000..eaf4648 --- /dev/null +++ b/examples/wan/wan_vace.py @@ -0,0 +1,73 @@ +import PIL.Image +import torch +from diffusers.schedulers import UniPCMultistepScheduler +from diffusers.utils import export_to_video, load_image + +from diffsynth_engine.pipelines.wan import WanVACEPipeline +from diffsynth_engine.utils.download import fetch_model + + +def prepare_video_and_mask( + first_img: PIL.Image.Image, + last_img: PIL.Image.Image, + height: int, + width: int, + num_frames: int, +): + first_img = first_img.resize((width, height)) + last_img = last_img.resize((width, height)) + frames = [first_img] + frames.extend([PIL.Image.new("RGB", (width, height), (128, 128, 128))] * (num_frames - 2)) + frames.append(last_img) + mask_black = PIL.Image.new("L", (width, height), 0) + mask_white = PIL.Image.new("L", (width, height), 255) + mask = [mask_black, *[mask_white] * (num_frames - 2), mask_black] + return frames, mask + + +if __name__ == "__main__": + model_path = fetch_model("Wan-AI/Wan2.1-VACE-14B-diffusers") + pipe = WanVACEPipeline.from_pretrained(model_path) + + # Set flow_shift to 5.0 for 720P (use 3.0 for 480P) + flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P + pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift) + + # Load the first and last frame images + first_frame = load_image("examples/input/wan_vace_first_frame.png") + last_frame = load_image("examples/input/wan_vace_last_frame.png") + + prompt = ( + "CG animation style, a small 
blue bird takes off from the ground, flapping its wings. " + "The bird's feathers are delicate, with a unique pattern on its chest. " + "The background shows a blue sky with white clouds under bright sunshine. " + "The camera follows the bird upward, capturing its flight and the vastness of the sky " + "from a close-up, low-angle perspective." + ) + negative_prompt = ( + "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, " + "images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, " + "incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, " + "misshapen limbs, fused fingers, still picture, messy background, three legs, many people " + "in the background, walking backwards" + ) + + height = 512 + width = 512 + num_frames = 81 + video, mask = prepare_video_and_mask(first_frame, last_frame, height, width, num_frames) + + output = pipe( + video=video, + mask=mask, + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=30, + guidance_scale=5.0, + generator=torch.Generator().manual_seed(42), + ) + + export_to_video(output.frames[0], "wan_vace_output.mp4", fps=16) diff --git a/tests/common/test_case.py b/tests/common/test_case.py index aa97998..13528ff 100644 --- a/tests/common/test_case.py +++ b/tests/common/test_case.py @@ -9,7 +9,7 @@ from diffsynth_engine.utils.load_utils import load_file from diffsynth_engine.utils.video import VideoReader, load_video, save_video -from tests.common.utils import compute_normalized_ssim +from tests.common.utils import compute_normalized_ssim, compute_video_ms_ssim TEST_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # test flags @@ -109,3 +109,22 @@ def assertVideoEqualAndSaveFailed( name = expect_video_path.split("/")[-1] self.save_video(input_video, name, fps=fps) raise e + + def assertVideoMsSsimEqual(self, input_video: 
List[Image.Image], expect_video: List[Image.Image], threshold=0.95): + ms_ssim_score = compute_video_ms_ssim(input_video, expect_video) + self.assertGreaterEqual(ms_ssim_score, threshold) + + def assertVideoMsSsimEqualAndSaveFailed( + self, input_video: List[Image.Image], expect_video_path: str, threshold=0.95, fps: int = 15 + ): + """ + 比较input_video和testdata/expect/{name}的MS-SSIM相似度,如果失败则保存input_video到当前工作目录 + """ + try: + expect_video = self.get_expect_video(expect_video_path) + expect_frames = [expect_video[i] for i in range(len(expect_video))] + self.assertVideoMsSsimEqual(input_video, expect_frames, threshold=threshold) + except Exception as e: + name = expect_video_path.split("/")[-1] + self.save_video(input_video, name, fps=fps) + raise e diff --git a/tests/common/utils.py b/tests/common/utils.py index f4d5ba3..00f2e01 100644 --- a/tests/common/utils.py +++ b/tests/common/utils.py @@ -1,4 +1,7 @@ +from typing import List + import numpy as np +import torch from PIL import Image from skimage.metrics import structural_similarity @@ -14,3 +17,41 @@ def compute_normalized_ssim(image1: Image.Image, image2: Image.Image): ssim_normalized = (ssim + 1) / 2 return ssim_normalized + + +def compute_video_ms_ssim( + input_frames: List[Image.Image], + expect_frames: List[Image.Image], +) -> float: + """Compute the mean MS-SSIM score between two frame sequences. + + Each frame is converted to a ``[1, C, H, W]`` float tensor in ``[0, 1]`` + and scored with ``MultiScaleStructuralSimilarityIndexMeasure``. The + returned value is the average MS-SSIM across all frame pairs. 
+ """ + from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure + + ms_ssim_metric = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0) + + scores: List[float] = [] + for pred_frame, target_frame in zip(input_frames, expect_frames): + pred_array = np.array(pred_frame).astype(np.float32) + target_array = np.array(target_frame).astype(np.float32) + + # Normalize to [0, 1]: only divide by 255 when the data is in uint8 range + if pred_array.max() > 1.0: + pred_array = pred_array / 255.0 + if target_array.max() > 1.0: + target_array = target_array / 255.0 + + pred_tensor = torch.from_numpy(pred_array) + target_tensor = torch.from_numpy(target_array) + + # [H, W, C] -> [1, C, H, W] + pred_tensor = pred_tensor.permute(2, 0, 1).unsqueeze(0) + target_tensor = target_tensor.permute(2, 0, 1).unsqueeze(0) + + score = ms_ssim_metric(pred_tensor, target_tensor) + scores.append(score.item()) + + return float(np.mean(scores)) diff --git a/tests/data/expect/wan/wan_22_animate.mp4 b/tests/data/expect/wan/wan_22_animate.mp4 new file mode 100644 index 0000000..6f5e809 Binary files /dev/null and b/tests/data/expect/wan/wan_22_animate.mp4 differ diff --git a/tests/data/expect/wan/wan_22_i2v.mp4 b/tests/data/expect/wan/wan_22_i2v.mp4 new file mode 100644 index 0000000..4c064a3 Binary files /dev/null and b/tests/data/expect/wan/wan_22_i2v.mp4 differ diff --git a/tests/data/expect/wan/wan_22_t2v.mp4 b/tests/data/expect/wan/wan_22_t2v.mp4 new file mode 100644 index 0000000..976ded2 Binary files /dev/null and b/tests/data/expect/wan/wan_22_t2v.mp4 differ diff --git a/tests/data/expect/wan/wan_vace.mp4 b/tests/data/expect/wan/wan_vace.mp4 new file mode 100644 index 0000000..94779cc Binary files /dev/null and b/tests/data/expect/wan/wan_vace.mp4 differ diff --git a/tests/data/input/wan_22_animate_face.mp4 b/tests/data/input/wan_22_animate_face.mp4 new file mode 100644 index 0000000..622d6df Binary files /dev/null and b/tests/data/input/wan_22_animate_face.mp4 differ 
diff --git a/tests/data/input/wan_22_animate_input.png b/tests/data/input/wan_22_animate_input.png new file mode 100644 index 0000000..a1b85f7 Binary files /dev/null and b/tests/data/input/wan_22_animate_input.png differ diff --git a/tests/data/input/wan_22_animate_pose.mp4 b/tests/data/input/wan_22_animate_pose.mp4 new file mode 100644 index 0000000..af36f5e Binary files /dev/null and b/tests/data/input/wan_22_animate_pose.mp4 differ diff --git a/tests/data/input/wan_22_i2v_input.png b/tests/data/input/wan_22_i2v_input.png new file mode 100644 index 0000000..6a558eb Binary files /dev/null and b/tests/data/input/wan_22_i2v_input.png differ diff --git a/tests/data/input/wan_vace_first_frame.png b/tests/data/input/wan_vace_first_frame.png new file mode 100644 index 0000000..032cd5c Binary files /dev/null and b/tests/data/input/wan_vace_first_frame.png differ diff --git a/tests/data/input/wan_vace_last_frame.png b/tests/data/input/wan_vace_last_frame.png new file mode 100644 index 0000000..83ac8c5 Binary files /dev/null and b/tests/data/input/wan_vace_last_frame.png differ diff --git a/tests/test_pipelines/test_wan_22_animate.py b/tests/test_pipelines/test_wan_22_animate.py new file mode 100644 index 0000000..b5f9f6e --- /dev/null +++ b/tests/test_pipelines/test_wan_22_animate.py @@ -0,0 +1,48 @@ +import unittest + +import torch + +from diffsynth_engine.pipelines.wan import WanAnimatePipeline +from diffsynth_engine.utils.download import fetch_model +from tests.common.test_case import VideoTestCase + + +class TestWan22AnimatePipeline(VideoTestCase): + @classmethod + def setUpClass(cls): + model_path = fetch_model("Wan-AI/Wan2.2-Animate-14B-Diffusers") + cls.pipe = WanAnimatePipeline.from_pretrained(model_path) + + @classmethod + def tearDownClass(cls): + del cls.pipe + + def test_animate(self): + image = self.get_input_image("wan_22_animate_input.png") + + pose_video_reader = self.get_input_video("wan_22_animate_pose.mp4") + face_video_reader = 
self.get_input_video("wan_22_animate_face.mp4") + pose_video = [pose_video_reader[i] for i in range(len(pose_video_reader))] + face_video = [face_video_reader[i] for i in range(len(face_video_reader))] + + prompt = "People in the video are doing actions." + + video = self.pipe( + image=image, + pose_video=pose_video, + face_video=face_video, + prompt=prompt, + mode="animate", + segment_frame_length=77, + prev_segment_conditioning_frames=1, + guidance_scale=1.0, + num_inference_steps=20, + generator=torch.Generator(device="cpu").manual_seed(42), + ) + + output_frames = video.frames[0] + self.assertVideoMsSsimEqualAndSaveFailed(output_frames, "wan/wan_22_animate.mp4", threshold=0.98, fps=30) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_pipelines/test_wan_22_image_to_video.py b/tests/test_pipelines/test_wan_22_image_to_video.py new file mode 100644 index 0000000..86d0960 --- /dev/null +++ b/tests/test_pipelines/test_wan_22_image_to_video.py @@ -0,0 +1,50 @@ +import unittest + +import numpy as np +import torch + +from diffsynth_engine.pipelines.wan import WanImageToVideoPipeline +from diffsynth_engine.utils.download import fetch_model +from tests.common.test_case import VideoTestCase + + +class TestWan22ImageToVideoPipeline(VideoTestCase): + @classmethod + def setUpClass(cls): + model_path = fetch_model("Wan-AI/Wan2.2-I2V-A14B-Diffusers") + cls.pipe = WanImageToVideoPipeline.from_pretrained(model_path) + + @classmethod + def tearDownClass(cls): + del cls.pipe + + def test_image_to_video(self): + image = self.get_input_image("wan_22_i2v_input.png") + max_area = 480 * 832 + aspect_ratio = image.height / image.width + mod_value = self.pipe.vae_scale_factor_spatial * self.pipe.transformer.config.patch_size[1] + height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value + width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value + image = image.resize((width, height)) + + prompt = "Summer beach vacation style, a 
white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." + negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + + video = self.pipe( + image=image, + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=81, + guidance_scale=3.5, + num_inference_steps=40, + generator=torch.Generator(device="cpu").manual_seed(42), + ) + + output_frames = video.frames[0] + self.assertVideoMsSsimEqualAndSaveFailed(output_frames, "wan/wan_22_i2v.mp4", threshold=0.98, fps=16) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_pipelines/test_wan_22_text_to_video.py b/tests/test_pipelines/test_wan_22_text_to_video.py new file mode 100644 index 0000000..1b5a962 --- /dev/null +++ b/tests/test_pipelines/test_wan_22_text_to_video.py @@ -0,0 +1,43 @@ +import unittest + +import torch + +from diffsynth_engine.pipelines.wan import WanTextToVideoPipeline +from diffsynth_engine.utils.download import fetch_model +from tests.common.test_case import VideoTestCase + + +class TestWan22TextToVideoPipeline(VideoTestCase): + @classmethod + def setUpClass(cls): + model_path = fetch_model("Wan-AI/Wan2.2-T2V-A14B-Diffusers") + cls.pipe = WanTextToVideoPipeline.from_pretrained(model_path) + + @classmethod + def tearDownClass(cls): + del cls.pipe + + def test_text_to_video(self): + prompt = ( + "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." 
+ ) + negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + + video = self.pipe( + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=81, + width=1280, + height=720, + guidance_scale=4.0, + guidance_scale_2=3.0, + num_inference_steps=40, + generator=torch.Generator(device="cpu").manual_seed(42), + ) + + output_frames = video.frames[0] + self.assertVideoMsSsimEqualAndSaveFailed(output_frames, "wan/wan_22_t2v.mp4", threshold=0.98, fps=16) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_pipelines/test_wan_vace.py b/tests/test_pipelines/test_wan_vace.py new file mode 100644 index 0000000..c795942 --- /dev/null +++ b/tests/test_pipelines/test_wan_vace.py @@ -0,0 +1,84 @@ +import unittest + +import PIL.Image +import torch +from diffusers.schedulers import UniPCMultistepScheduler + +from diffsynth_engine.pipelines.wan import WanVACEPipeline +from diffsynth_engine.utils.download import fetch_model +from tests.common.test_case import VideoTestCase + + +def prepare_video_and_mask( + first_img: PIL.Image.Image, + last_img: PIL.Image.Image, + height: int, + width: int, + num_frames: int, +): + first_img = first_img.resize((width, height)) + last_img = last_img.resize((width, height)) + frames = [first_img] + frames.extend([PIL.Image.new("RGB", (width, height), (128, 128, 128))] * (num_frames - 2)) + frames.append(last_img) + mask_black = PIL.Image.new("L", (width, height), 0) + mask_white = PIL.Image.new("L", (width, height), 255) + mask = [mask_black, *[mask_white] * (num_frames - 2), mask_black] + return frames, mask + + +class TestWanVACEPipeline(VideoTestCase): + @classmethod + def setUpClass(cls): + model_path = fetch_model("Wan-AI/Wan2.1-VACE-14B-diffusers") + cls.pipe = WanVACEPipeline.from_pretrained(model_path) + flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P + cls.pipe.scheduler = 
UniPCMultistepScheduler.from_config(cls.pipe.scheduler.config, flow_shift=flow_shift) + + @classmethod + def tearDownClass(cls): + del cls.pipe + + def test_vace(self): + first_frame = self.get_input_image("wan_vace_first_frame.png") + last_frame = self.get_input_image("wan_vace_last_frame.png") + + prompt = ( + "CG animation style, a small blue bird takes off from the ground, flapping its wings. " + "The bird's feathers are delicate, with a unique pattern on its chest. " + "The background shows a blue sky with white clouds under bright sunshine. " + "The camera follows the bird upward, capturing its flight and the vastness of the sky " + "from a close-up, low-angle perspective." + ) + negative_prompt = ( + "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, " + "images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, " + "incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, " + "misshapen limbs, fused fingers, still picture, messy background, three legs, many people " + "in the background, walking backwards" + ) + + height = 512 + width = 512 + num_frames = 81 + video, mask = prepare_video_and_mask(first_frame, last_frame, height, width, num_frames) + + result = self.pipe( + video=video, + mask=mask, + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + num_frames=num_frames, + num_inference_steps=30, + guidance_scale=5.0, + generator=torch.Generator(device="cpu").manual_seed(42), + ) + + output_frames = result.frames[0] + self.assertVideoMsSsimEqualAndSaveFailed(output_frames, "wan/wan_vace.mp4", threshold=0.93, fps=16) + + +if __name__ == "__main__": + unittest.main()