From 7f2922d373e9f5abc0404fe4ea2146c588557bd7 Mon Sep 17 00:00:00 2001 From: hlky Date: Thu, 7 May 2026 14:01:10 +0000 Subject: [PATCH] Fix Wan review issues --- .ai/pipelines.md | 2 + .../modular_pipelines/wan/before_denoise.py | 22 +-- .../modular_pipelines/wan/denoise.py | 129 ++++++++++---- .../modular_pipelines/wan/modular_pipeline.py | 8 +- .../pipelines/wan/image_processor.py | 21 ++- .../pipelines/wan/pipeline_wan_animate.py | 57 ++++-- .../pipelines/wan/pipeline_wan_i2v.py | 166 +++++++++++++----- .../pipelines/wan/pipeline_wan_vace.py | 16 +- .../pipelines/wan/pipeline_wan_video2video.py | 30 ++-- .../wan/test_modular_pipeline_wan.py | 137 +++++++++++++++ tests/pipelines/wan/test_wan_animate.py | 61 +++++++ .../pipelines/wan/test_wan_image_to_video.py | 119 +++++++++++++ tests/pipelines/wan/test_wan_vace.py | 41 +++++ .../pipelines/wan/test_wan_video_to_video.py | 19 ++ 14 files changed, 688 insertions(+), 140 deletions(-) diff --git a/.ai/pipelines.md b/.ai/pipelines.md index e107639cb24b..06902c27de76 100644 --- a/.ai/pipelines.md +++ b/.ai/pipelines.md @@ -60,3 +60,5 @@ When adding a new pipeline (or reviewing one), skim `pipeline_flux.py`, `pipelin 4. **Subclassing an existing pipeline for a variant.** Don't use an existing pipeline class (e.g. `FluxPipeline`) to override another (e.g. `FluxImg2ImgPipeline`) inside the core `src/` codebase. Each pipeline lives in its own file with its own class, even if it shares 90% of `__call__` with a sibling. Convention across diffusers — flux, sdxl, wan, qwenimage — is duplicated `__call__` between img2img / text2img / inpaint variants, not subclassing. Reuse private utilities (shared schedulers, prep functions) but not the pipeline class itself. 5. **Copying a method from another pipeline without `# Copied from`.** When you reuse a method like `encode_prompt`, `prepare_latents`, `check_inputs`, or `_prepare_latent_image_ids` from another pipeline, add a `# Copied from` annotation so `make fix-copies` keeps the two in sync. Forgetting it means future refactors to the source drift away from your copy silently — and reviewers waste time spotting near-identical code that should have been linked. The annotation grammar (decorator placement, rename syntax with `with old->new`, etc.) is implemented in [`utils/check_copies.py`](../utils/check_copies.py) — read it for the exact rules. + +6. **Partial batch expansion with `num_*_per_prompt`.** When a pipeline accepts `num_images_per_prompt`, `num_videos_per_prompt`, or precomputed conditioning tensors, every per-prompt input that reaches the denoising loop must be expanded to the same effective batch size in the same prompt order. Check prompt embeds, pooled embeds, image/video embeds, masks, control latents, image latents, guidance tensors, and any added conditioning. Prefer `torch.repeat_interleave(tensor, repeats=num_per_prompt, dim=0, output_size=tensor.shape[0] * num_per_prompt)` for contiguous per-prompt grouping, and validate precomputed tensors that are already batched instead of letting a later concat fail. Avoid `tensor.repeat((num_per_prompt, ...))` for batch expansion because it interleaves prompts (`[A, B] -> [A, B, A, B]`) rather than grouping them (`[A, A, B, B]`). 
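As an aside (not part of the patch): a minimal sketch of the ordering difference that guideline 6 above describes, using toy tensors purely for illustration.

```python
import torch

# Two prompts A and B, embedding dim kept at 1 for readability.
prompt_embeds = torch.tensor([[1.0], [2.0]])  # A -> 1.0, B -> 2.0
num_videos_per_prompt = 2

# repeat_interleave keeps per-prompt grouping: [A, A, B, B]
grouped = torch.repeat_interleave(
    prompt_embeds,
    repeats=num_videos_per_prompt,
    dim=0,
    output_size=prompt_embeds.shape[0] * num_videos_per_prompt,
)
print(grouped.squeeze(-1).tolist())  # [1.0, 1.0, 2.0, 2.0]

# tensor.repeat tiles the whole batch, interleaving prompts: [A, B, A, B]
tiled = prompt_embeds.repeat(num_videos_per_prompt, 1)
print(tiled.squeeze(-1).tolist())  # [1.0, 2.0, 1.0, 2.0]
```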
diff --git a/src/diffusers/modular_pipelines/wan/before_denoise.py b/src/diffusers/modular_pipelines/wan/before_denoise.py index 398b9665522c..942cff7a8974 100644 --- a/src/diffusers/modular_pipelines/wan/before_denoise.py +++ b/src/diffusers/modular_pipelines/wan/before_denoise.py @@ -223,7 +223,7 @@ def inputs(self) -> list[InputParam]: ] @property - def intermediate_outputs(self) -> list[str]: + def intermediate_outputs(self) -> list[OutputParam]: return [ OutputParam( "batch_size", @@ -252,22 +252,16 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe self.check_inputs(components, block_state) block_state.batch_size = block_state.prompt_embeds.shape[0] - block_state.dtype = block_state.prompt_embeds.dtype + block_state.dtype = components.transformer.dtype - _, seq_len, _ = block_state.prompt_embeds.shape - block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_videos_per_prompt, 1) - block_state.prompt_embeds = block_state.prompt_embeds.view( - block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1 - ) + block_state.prompt_embeds = block_state.prompt_embeds.repeat_interleave( + block_state.num_videos_per_prompt, dim=0 + ).to(block_state.dtype) if block_state.negative_prompt_embeds is not None: - _, seq_len, _ = block_state.negative_prompt_embeds.shape - block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat( - 1, block_state.num_videos_per_prompt, 1 - ) - block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view( - block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1 - ) + block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat_interleave( + block_state.num_videos_per_prompt, dim=0 + ).to(block_state.dtype) self.set_block_state(state, block_state) diff --git a/src/diffusers/modular_pipelines/wan/denoise.py b/src/diffusers/modular_pipelines/wan/denoise.py index 2f51f353012e..60dc5ced6bf3 100644 --- a/src/diffusers/modular_pipelines/wan/denoise.py +++ b/src/diffusers/modular_pipelines/wan/denoise.py @@ -27,7 +27,7 @@ ModularPipelineBlocks, PipelineState, ) -from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam +from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam from .modular_pipeline import WanModularPipeline @@ -54,17 +54,21 @@ def inputs(self) -> list[InputParam]: type_hint=torch.Tensor, description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.", ), - InputParam( - "dtype", - required=True, - type_hint=torch.dtype, - description="The dtype of the model inputs. Can be generated in input step.", - ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "latent_model_input", + type_hint=torch.Tensor, + description="Latents prepared as model hidden states for the denoiser.", + ) ] @torch.no_grad() def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): - block_state.latent_model_input = block_state.latents.to(block_state.dtype) + block_state.latent_model_input = block_state.latents return components, block_state @@ -94,19 +98,21 @@ def inputs(self) -> list[InputParam]: type_hint=torch.Tensor, description="The image condition latents to use for the denoising process. 
Can be generated in prepare_first_frame_latents/prepare_first_last_frame_latents step.", ), - InputParam( - "dtype", - required=True, - type_hint=torch.dtype, - description="The dtype of the model inputs. Can be generated in input step.", - ), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "latent_model_input", + type_hint=torch.Tensor, + description="Latents and image conditioning prepared as model hidden states for the denoiser.", + ) ] @torch.no_grad() def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): - block_state.latent_model_input = torch.cat( - [block_state.latents, block_state.image_condition_latents], dim=1 - ).to(block_state.dtype) + block_state.latent_model_input = torch.cat([block_state.latents, block_state.image_condition_latents], dim=1) return components, block_state @@ -165,6 +171,18 @@ def inputs(self) -> list[tuple[str, Any]]: type_hint=int, description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.", ), + InputParam( + "latent_model_input", + required=True, + type_hint=torch.Tensor, + description="Latents prepared as model hidden states for the denoiser.", + ), + InputParam( + "dtype", + required=True, + type_hint=torch.dtype, + description="The dtype of the model inputs. Can be generated in input step.", + ), ] guider_input_names = [] for value in self._guider_input_fields.values(): @@ -177,11 +195,22 @@ def inputs(self) -> list[tuple[str, Any]]: inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor)) return inputs + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "noise_pred", + type_hint=torch.Tensor, + description="The guided noise prediction for the current denoising step.", + ) + ] + @torch.no_grad() def __call__( self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor ) -> PipelineState: components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + dtype = block_state.dtype # The guider splits model inputs into separate batches for conditional/unconditional predictions. # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: @@ -198,7 +227,7 @@ def __call__( components.guider.prepare_models(components.transformer) cond_kwargs = guider_state_batch.as_dict() cond_kwargs = { - k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v + k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in cond_kwargs.items() if k in self._guider_input_fields.keys() } @@ -206,8 +235,8 @@ def __call__( # Predict the noise residual # store the noise_pred in guider_state_batch so that we can apply guidance across all batches guider_state_batch.noise_pred = components.transformer( - hidden_states=block_state.latent_model_input.to(block_state.dtype), - timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype), + hidden_states=block_state.latent_model_input.to(dtype), + timestep=t.expand(block_state.latent_model_input.shape[0]), attention_kwargs=block_state.attention_kwargs, return_dict=False, **cond_kwargs, @@ -292,6 +321,12 @@ def inputs(self) -> list[tuple[str, Any]]: type_hint=int, description="The number of inference steps to use for the denoising process. 
Can be generated in set_timesteps step.", ), + InputParam( + "latent_model_input", + required=True, + type_hint=torch.Tensor, + description="Latents prepared as model hidden states for the denoiser.", + ), ] guider_input_names = [] for value in self._guider_input_fields.values(): @@ -304,19 +339,30 @@ def inputs(self) -> list[tuple[str, Any]]: inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor)) return inputs + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [ + OutputParam( + "noise_pred", + type_hint=torch.Tensor, + description="The guided noise prediction for the current denoising step.", + ) + ] + @torch.no_grad() def __call__( self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor ) -> PipelineState: boundary_timestep = components.config.boundary_ratio * components.num_train_timesteps if t >= boundary_timestep: - block_state.current_model = components.transformer - block_state.guider = components.guider + current_model = components.transformer + guider = components.guider else: - block_state.current_model = components.transformer_2 - block_state.guider = components.guider_2 + current_model = components.transformer_2 + guider = components.guider_2 + dtype = current_model.dtype - block_state.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) # The guider splits model inputs into separate batches for conditional/unconditional predictions. # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}: @@ -326,31 +372,31 @@ def __call__( # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch # ] # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG). - guider_state = block_state.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields) + guider_state = guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields) # run the denoiser for each guidance batch for guider_state_batch in guider_state: - block_state.guider.prepare_models(block_state.current_model) + guider.prepare_models(current_model) cond_kwargs = guider_state_batch.as_dict() cond_kwargs = { - k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v + k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in cond_kwargs.items() if k in self._guider_input_fields.keys() } # Predict the noise residual # store the noise_pred in guider_state_batch so that we can apply guidance across all batches - guider_state_batch.noise_pred = block_state.current_model( - hidden_states=block_state.latent_model_input.to(block_state.dtype), - timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype), + guider_state_batch.noise_pred = current_model( + hidden_states=block_state.latent_model_input.to(dtype), + timestep=t.expand(block_state.latent_model_input.shape[0]), attention_kwargs=block_state.attention_kwargs, return_dict=False, **cond_kwargs, )[0] - block_state.guider.cleanup_models(block_state.current_model) + guider.cleanup_models(current_model) # Perform guidance - block_state.noise_pred = block_state.guider(guider_state)[0] + block_state.noise_pred = guider(guider_state)[0] return components, block_state @@ -372,6 +418,23 @@ def description(self) -> str: "object (e.g. 
`WanDenoiseLoopWrapper`)" ) + @property + def inputs(self) -> list[InputParam]: + return [ + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The latents to update with the scheduler step.", + ), + InputParam( + "noise_pred", + required=True, + type_hint=torch.Tensor, + description="The guided noise prediction for the current denoising step.", + ), + ] + @torch.no_grad() def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): # Perform scheduler step using the predicted output diff --git a/src/diffusers/modular_pipelines/wan/modular_pipeline.py b/src/diffusers/modular_pipelines/wan/modular_pipeline.py index 0e52026a51bf..d29cfb0f44c8 100644 --- a/src/diffusers/modular_pipelines/wan/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/wan/modular_pipeline.py @@ -69,15 +69,15 @@ def patch_size_spatial(self): @property def vae_scale_factor_spatial(self): vae_scale_factor = 8 - if hasattr(self, "vae") and self.vae is not None: - vae_scale_factor = 2 ** len(self.vae.temperal_downsample) + if getattr(self, "vae", None) is not None: + vae_scale_factor = self.vae.config.scale_factor_spatial return vae_scale_factor @property def vae_scale_factor_temporal(self): vae_scale_factor = 4 - if hasattr(self, "vae") and self.vae is not None: - vae_scale_factor = 2 ** sum(self.vae.temperal_downsample) + if getattr(self, "vae", None) is not None: + vae_scale_factor = self.vae.config.scale_factor_temporal return vae_scale_factor @property diff --git a/src/diffusers/pipelines/wan/image_processor.py b/src/diffusers/pipelines/wan/image_processor.py index fa18150fcc6e..c3ea470ccc4d 100644 --- a/src/diffusers/pipelines/wan/image_processor.py +++ b/src/diffusers/pipelines/wan/image_processor.py @@ -17,7 +17,6 @@ import PIL.Image import torch -from ...configuration_utils import register_to_config from ...image_processor import VaeImageProcessor from ...utils import PIL_INTERPOLATION @@ -53,7 +52,6 @@ class WanAnimateImageProcessor(VaeImageProcessor): if `None`, will default to filling with data from `image`. """ - @register_to_config def __init__( self, do_resize: bool = True, @@ -68,13 +66,18 @@ def __init__( do_convert_grayscale: bool = False, fill_color: str | float | tuple[float, ...] 
| None = 0, ): - super().__init__() - if do_convert_rgb and do_convert_grayscale: - raise ValueError( - "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`," - " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.", - " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`", - ) + super().__init__( + do_resize=do_resize, + vae_scale_factor=vae_scale_factor, + vae_latent_channels=vae_latent_channels, + resample=resample, + reducing_gap=reducing_gap, + do_normalize=do_normalize, + do_binarize=do_binarize, + do_convert_rgb=do_convert_rgb, + do_convert_grayscale=do_convert_grayscale, + ) + self.register_to_config(spatial_patch_size=spatial_patch_size, fill_color=fill_color) def _resize_and_fill( self, diff --git a/src/diffusers/pipelines/wan/pipeline_wan_animate.py b/src/diffusers/pipelines/wan/pipeline_wan_animate.py index 5806032c0142..e54680b29ff7 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_animate.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_animate.py @@ -384,17 +384,12 @@ def check_inputs( mode=None, prev_segment_conditioning_frames=None, ): - if image is not None and image_embeds is not None: - raise ValueError( - f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to" - " only forward one of the two." - ) - if image is None and image_embeds is None: - raise ValueError( - "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined." - ) + if image is None: + raise ValueError("`image` must be provided for Wan Animate generation.") if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") + if image_embeds is not None and (not isinstance(image_embeds, torch.Tensor) or image_embeds.ndim != 3): + raise ValueError("`image_embeds` has to be a 3D `torch.Tensor`.") if pose_video is None: raise ValueError("Provide `pose_video`. Cannot leave `pose_video` undefined.") if face_video is None: @@ -485,6 +480,30 @@ def get_i2v_mask( return mask_lat_size + def _expand_tensor_to_effective_batch( + self, + tensor: torch.Tensor, + batch_size: int, + num_videos_per_prompt: int, + tensor_name: str, + ) -> torch.Tensor: + target_batch_size = batch_size * num_videos_per_prompt + + if tensor.shape[0] == target_batch_size: + return tensor + + if tensor.shape[0] == 1: + repeat_by = target_batch_size + elif tensor.shape[0] == batch_size: + repeat_by = num_videos_per_prompt + else: + raise ValueError( + f"`{tensor_name}` batch size must be 1, `batch_size` ({batch_size}), or " + f"`batch_size * num_videos_per_prompt` ({target_batch_size}), but got {tensor.shape[0]}." 
+ ) + + return torch.repeat_interleave(tensor, repeats=repeat_by, dim=0, output_size=tensor.shape[0] * repeat_by) + def prepare_reference_image_latents( self, image: torch.Tensor, @@ -776,13 +795,13 @@ def __call__( prev_segment_conditioning_frames: int = 1, motion_encode_batch_size: int | None = None, guidance_scale: float = 1.0, - num_videos_per_prompt: int | None = 1, + num_videos_per_prompt: int = 1, generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.Tensor | None = None, prompt_embeds: torch.Tensor | None = None, negative_prompt_embeds: torch.Tensor | None = None, image_embeds: torch.Tensor | None = None, - output_type: str | None = "np", + output_type: str = "np", return_dict: bool = True, attention_kwargs: dict[str, Any] | None = None, callback_on_step_end: Callable[[int, int, None], PipelineCallback | MultiPipelineCallbacks] | None = None, @@ -841,10 +860,12 @@ def __call__( the text `prompt`, usually at the expense of lower image quality. By default, CFG is not used in Wan Animate inference. num_videos_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. + The number of videos to generate per prompt. When greater than 1, prompt, image, and conditioning + batches are expanded in per-prompt order. generator (`torch.Generator` or `list[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - generation deterministic. + generation deterministic. If a list is passed, it must match the effective batch size + (`batch_size * num_videos_per_prompt`). latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents @@ -856,10 +877,10 @@ def __call__( Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `negative_prompt` input argument. image_embeds (`torch.Tensor`, *optional*): - Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided, - image embeddings are generated from the `image` input argument. + Pre-generated CLIP image embeddings. If not provided, image embeddings are generated from the `image` + input argument. `image` is still required because it is used for VAE reference-image conditioning. output_type (`str`, *optional*, defaults to `"np"`): - The output format of the generated image. Choose between `PIL.Image` or `np.array`. + The output format of the generated video. Choose between `"np"`, `"pt"`, `"pil"`, or `"latent"`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. attention_kwargs (`dict`, *optional*): @@ -975,7 +996,9 @@ def __call__( # Get CLIP features from the reference image if image_embeds is None: image_embeds = self.encode_image(image, device) - image_embeds = image_embeds.repeat(batch_size * num_videos_per_prompt, 1, 1) + image_embeds = self._expand_tensor_to_effective_batch( + image_embeds, batch_size, num_videos_per_prompt, "image_embeds" + ) image_embeds = image_embeds.to(transformer_dtype) # 5. 
Encode conditioning videos (pose, face) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index f669e9b1d0ec..6384dc6dd87b 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -15,13 +15,12 @@ import html from typing import Any, Callable -import PIL import regex as re import torch from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...image_processor import PipelineImageInput +from ...image_processor import PipelineImageInput, is_valid_image_imagelist from ...loaders import WanLoraLoaderMixin from ...models import AutoencoderKLWan, WanTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -329,6 +328,71 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds + def _expand_tensor_to_effective_batch( + self, + tensor: torch.Tensor, + batch_size: int, + num_videos_per_prompt: int, + tensor_name: str, + ) -> torch.Tensor: + target_batch_size = batch_size * num_videos_per_prompt + + if tensor.shape[0] == target_batch_size: + return tensor + + if tensor.shape[0] == 1: + repeat_by = target_batch_size + elif tensor.shape[0] == batch_size: + repeat_by = num_videos_per_prompt + else: + raise ValueError( + f"`{tensor_name}` batch size must be 1, `batch_size` ({batch_size}), or " + f"`batch_size * num_videos_per_prompt` ({target_batch_size}), but got {tensor.shape[0]}." + ) + + return torch.repeat_interleave(tensor, repeats=repeat_by, dim=0, output_size=tensor.shape[0] * repeat_by) + + def _prepare_image_embeds( + self, + image: PipelineImageInput, + device: torch.device, + batch_size: int, + num_videos_per_prompt: int, + transformer_dtype: torch.dtype, + image_embeds: torch.Tensor | None = None, + last_image: PipelineImageInput | None = None, + ) -> torch.Tensor: + if image_embeds is None: + if last_image is None: + image_embeds = self.encode_image(image, device) + else: + image_embeds = self.encode_image(image, device) + last_image_embeds = self.encode_image(last_image, device) + if image_embeds.shape[0] != last_image_embeds.shape[0]: + raise ValueError( + "`image` and `last_image` must have matching batch sizes, but got " + f"{image_embeds.shape[0]} and {last_image_embeds.shape[0]}." + ) + image_embeds = torch.stack([image_embeds, last_image_embeds], dim=1).flatten(0, 1) + + if last_image is not None and self.transformer.condition_embedder.image_embedder.pos_embed is not None: + if image_embeds.shape[0] % 2 != 0: + raise ValueError( + "`image_embeds` batch size must be even when passing `last_image`, but got " + f"{image_embeds.shape[0]}." + ) + image_embeds = image_embeds.reshape(-1, 2, *image_embeds.shape[1:]) + image_embeds = self._expand_tensor_to_effective_batch( + image_embeds, batch_size, num_videos_per_prompt, "image_embeds" + ) + image_embeds = image_embeds.reshape(-1, *image_embeds.shape[2:]) + else: + image_embeds = self._expand_tensor_to_effective_batch( + image_embeds, batch_size, num_videos_per_prompt, "image_embeds" + ) + + return image_embeds.to(transformer_dtype) + def check_inputs( self, prompt, @@ -342,17 +406,16 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, guidance_scale_2=None, ): - if image is not None and image_embeds is not None: - raise ValueError( - f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to" - " only forward one of the two." 
- ) - if image is None and image_embeds is None: + if image is None: + raise ValueError("`image` must be provided for image-to-video generation.") + image_to_validate = list(image) if isinstance(image, tuple) else image + if isinstance(image_to_validate, list) and len(image_to_validate) == 0: + raise ValueError("`image` cannot be an empty list or tuple.") + if not is_valid_image_imagelist(image_to_validate): raise ValueError( - "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined." + f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image`, list or tuple but is" + f" {type(image)}" ) - if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image): - raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}") if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -388,7 +451,12 @@ def check_inputs( raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.") if self.config.boundary_ratio is not None and image_embeds is not None: - raise ValueError("Cannot forward `image_embeds` when the pipeline's `boundary_ratio` is not configured.") + raise ValueError("Cannot forward `image_embeds` when the pipeline's `boundary_ratio` is configured.") + if image_embeds is not None: + if not isinstance(image_embeds, torch.Tensor) or image_embeds.ndim != 3: + raise ValueError("`image_embeds` has to be a 3D `torch.Tensor`.") + if self.transformer is None or self.transformer.config.image_dim is None: + raise ValueError("`image_embeds` is only supported when `transformer` is configured with `image_dim`.") def prepare_latents( self, @@ -403,16 +471,18 @@ def prepare_latents( generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.Tensor | None = None, last_image: torch.Tensor | None = None, + num_videos_per_prompt: int = 1, ) -> tuple[torch.Tensor, torch.Tensor]: num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 latent_height = height // self.vae_scale_factor_spatial latent_width = width // self.vae_scale_factor_spatial + target_batch_size = batch_size * num_videos_per_prompt - shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) - if isinstance(generator, list) and len(generator) != batch_size: + shape = (target_batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + if isinstance(generator, list) and len(generator) != target_batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." + f" size of {target_batch_size}. Make sure the batch size matches the length of the generators." 
) if latents is None: @@ -446,14 +516,10 @@ def prepare_latents( latents.device, latents.dtype ) - if isinstance(generator, list): - latent_condition = [ - retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator - ] - latent_condition = torch.cat(latent_condition) - else: - latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") - latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") + latent_condition = self._expand_tensor_to_effective_batch( + latent_condition, batch_size, num_videos_per_prompt, "image" + ) latent_condition = latent_condition.to(dtype) latent_condition = (latent_condition - latents_mean) * latents_std @@ -465,7 +531,7 @@ def prepare_latents( first_frame_mask[:, :, 0] = 0 return latents, latent_condition, first_frame_mask - mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width) + mask_lat_size = torch.ones(target_batch_size, 1, num_frames, latent_height, latent_width) if last_image is None: mask_lat_size[:, :, list(range(1, num_frames))] = 0 @@ -474,7 +540,9 @@ def prepare_latents( first_frame_mask = mask_lat_size[:, :, 0:1] first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal) mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2) - mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width) + mask_lat_size = mask_lat_size.view( + target_batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width + ) mask_lat_size = mask_lat_size.transpose(1, 2) mask_lat_size = mask_lat_size.to(latent_condition.device) @@ -517,14 +585,14 @@ def __call__( num_inference_steps: int = 50, guidance_scale: float = 5.0, guidance_scale_2: float | None = None, - num_videos_per_prompt: int | None = 1, + num_videos_per_prompt: int = 1, generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.Tensor | None = None, prompt_embeds: torch.Tensor | None = None, negative_prompt_embeds: torch.Tensor | None = None, image_embeds: torch.Tensor | None = None, - last_image: torch.Tensor | None = None, - output_type: str | None = "np", + last_image: PipelineImageInput | None = None, + output_type: str = "np", return_dict: bool = True, attention_kwargs: dict[str, Any] | None = None, callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, @@ -536,7 +604,8 @@ def __call__( Args: image (`PipelineImageInput`): - The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`. + The input image to condition the generation on. Must be an image, a list or tuple of images, or a + `torch.Tensor`. prompt (`str` or `list[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. @@ -564,10 +633,12 @@ def __call__( `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2` and the pipeline's `boundary_ratio` are not None. num_videos_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. + The number of videos to generate per prompt. When greater than 1, prompt, image, and conditioning + batches are expanded in per-prompt order. 
generator (`torch.Generator` or `list[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - generation deterministic. + generation deterministic. If a list is passed, it must match the effective batch size + (`batch_size * num_videos_per_prompt`). latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents @@ -579,10 +650,13 @@ def __call__( Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `negative_prompt` input argument. image_embeds (`torch.Tensor`, *optional*): - Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided, - image embeddings are generated from the `image` input argument. + Pre-generated CLIP image embeddings. If not provided, image embeddings are generated from the `image` + input argument. `image` is still required because it is used for VAE conditioning. + last_image (`PipelineImageInput`, *optional*): + The input image to condition the final frame on. Must be an image, a list or tuple of images, or a + `torch.Tensor`. output_type (`str`, *optional*, defaults to `"np"`): - The output format of the generated image. Choose between `PIL.Image` or `np.array`. + The output format of the generated video. Choose between `"np"`, `"pt"`, `"pil"`, or `"latent"`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. attention_kwargs (`dict`, *optional*): @@ -614,6 +688,11 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + if isinstance(image, tuple): + image = list(image) + if isinstance(last_image, tuple): + last_image = list(last_image) + # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, @@ -690,13 +769,15 @@ def __call__( # only wan 2.1 i2v transformer accepts image_embeds if self.transformer is not None and self.transformer.config.image_dim is not None: - if image_embeds is None: - if last_image is None: - image_embeds = self.encode_image(image, device) - else: - image_embeds = self.encode_image([image, last_image], device) - image_embeds = image_embeds.repeat(batch_size, 1, 1) - image_embeds = image_embeds.to(transformer_dtype) + image_embeds = self._prepare_image_embeds( + image, + device, + batch_size, + num_videos_per_prompt, + transformer_dtype, + image_embeds, + last_image, + ) # 4. 
Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -712,7 +793,7 @@ def __call__( latents_outputs = self.prepare_latents( image, - batch_size * num_videos_per_prompt, + batch_size, num_channels_latents, height, width, @@ -722,6 +803,7 @@ def __call__( generator, latents, last_image, + num_videos_per_prompt=num_videos_per_prompt, ) if self.config.expand_timesteps: # wan 2.2 5b i2v use firt_frame_mask to mask timesteps diff --git a/src/diffusers/pipelines/wan/pipeline_wan_vace.py b/src/diffusers/pipelines/wan/pipeline_wan_vace.py index c016eec1b535..fae436efe13b 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_vace.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_vace.py @@ -196,8 +196,8 @@ def __init__( scheduler=scheduler, ) self.register_to_config(boundary_ratio=boundary_ratio) - self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 - self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds @@ -702,12 +702,12 @@ def __call__( num_inference_steps: int = 50, guidance_scale: float = 5.0, guidance_scale_2: float | None = None, - num_videos_per_prompt: int | None = 1, + num_videos_per_prompt: int = 1, generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.Tensor | None = None, prompt_embeds: torch.Tensor | None = None, negative_prompt_embeds: torch.Tensor | None = None, - output_type: str | None = "np", + output_type: str = "np", return_dict: bool = True, attention_kwargs: dict[str, Any] | None = None, callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, @@ -766,10 +766,10 @@ def __call__( `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2` and the pipeline's `boundary_ratio` are not None. num_videos_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. + The number of videos to generate per prompt. Currently only `1` is supported. generator (`torch.Generator` or `list[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - generation deterministic. + generation deterministic. If a list is passed, it must match the effective batch size. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents @@ -778,7 +778,7 @@ def __call__( Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. output_type (`str`, *optional*, defaults to `"np"`): - The output format of the generated image. Choose between `PIL.Image` or `np.array`. + The output format of the generated video. Choose between `"np"`, `"pt"`, `"pil"`, or `"latent"`. 
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. attention_kwargs (`dict`, *optional*): @@ -1017,8 +1017,8 @@ def __call__( self._current_timestep = None + latents = latents[:, :, num_reference_images:] if not output_type == "latent": - latents = latents[:, :, num_reference_images:] latents = latents.to(vae_dtype) latents_mean = ( torch.tensor(self.vae.config.latents_mean) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py index 3d7c5297f4c4..20c721aacbba 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py @@ -18,10 +18,10 @@ import regex as re import torch -from PIL import Image from transformers import AutoTokenizer, UMT5EncoderModel from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput from ...loaders import WanLoraLoaderMixin from ...models import AutoencoderKLWan, WanTransformer3DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -214,8 +214,8 @@ def __init__( scheduler=scheduler, ) - self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 - self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds @@ -397,7 +397,7 @@ def prepare_latents( width: int = 832, dtype: torch.dtype | None = None, device: torch.device | None = None, - generator: torch.Generator | None = None, + generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.Tensor | None = None, timestep: torch.Tensor | None = None, ): @@ -480,7 +480,7 @@ def attention_kwargs(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - video: list[Image.Image] = None, + video: PipelineImageInput | list[PipelineImageInput] | None = None, prompt: str | list[str] = None, negative_prompt: str | list[str] = None, height: int = 480, @@ -489,12 +489,12 @@ def __call__( timesteps: list[int] | None = None, guidance_scale: float = 5.0, strength: float = 0.8, - num_videos_per_prompt: int | None = 1, + num_videos_per_prompt: int = 1, generator: torch.Generator | list[torch.Generator] | None = None, latents: torch.Tensor | None = None, prompt_embeds: torch.Tensor | None = None, negative_prompt_embeds: torch.Tensor | None = None, - output_type: str | None = "np", + output_type: str = "np", return_dict: bool = True, attention_kwargs: dict[str, Any] | None = None, callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None, @@ -505,6 +505,9 @@ def __call__( The call function to the pipeline for generation. Args: + video (`PipelineImageInput` or `list[PipelineImageInput]`, *optional*): + The input video or videos to be used as a starting point for the generation. The video can be a list of + PIL images, a numpy array, or a torch tensor. prompt (`str` or `list[str]`, *optional*): The prompt or prompts to guide the image generation. 
If not defined, one has to pass `prompt_embeds` instead. @@ -512,8 +515,6 @@ def __call__( The height in pixels of the generated image. width (`int`, defaults to `832`): The width in pixels of the generated image. - num_frames (`int`, defaults to `81`): - The number of frames in the generated video. num_inference_steps (`int`, defaults to `50`): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. @@ -526,10 +527,10 @@ def __call__( strength (`float`, defaults to `0.8`): Higher strength leads to more differences between original image and generated video. num_videos_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. + The number of videos to generate per prompt. Currently only `1` is supported. generator (`torch.Generator` or `list[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - generation deterministic. + generation deterministic. If a list is passed, it must match the effective batch size. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents @@ -538,7 +539,7 @@ def __call__( Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. output_type (`str`, *optional*, defaults to `"np"`): - The output format of the generated image. Choose between `PIL.Image` or `np.array`. + The output format of the generated video. Choose between `"np"`, `"pt"`, `"pil"`, or `"latent"`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. attention_kwargs (`dict`, *optional*): @@ -572,7 +573,10 @@ def __call__( height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial - num_videos_per_prompt = 1 + if num_videos_per_prompt != 1: + raise ValueError( + "Generating multiple videos per prompt is not yet supported. This may be supported in the future." + ) # 1. Check inputs. Raise error if not correct self.check_inputs( diff --git a/tests/modular_pipelines/wan/test_modular_pipeline_wan.py b/tests/modular_pipelines/wan/test_modular_pipeline_wan.py index c5ed9613e40f..3f9404aa5a39 100644 --- a/tests/modular_pipelines/wan/test_modular_pipeline_wan.py +++ b/tests/modular_pipelines/wan/test_modular_pipeline_wan.py @@ -13,13 +13,63 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from types import SimpleNamespace + import pytest +import torch from diffusers.modular_pipelines import WanBlocks, WanModularPipeline +from diffusers.modular_pipelines.modular_pipeline import BlockState, PipelineState +from diffusers.modular_pipelines.wan.before_denoise import WanTextInputStep +from diffusers.modular_pipelines.wan.denoise import Wan22LoopDenoiser, WanLoopDenoiser from ..test_modular_pipelines_common import ModularPipelineTesterMixin +class _FakeGuider: + def set_state(self, step, num_inference_steps, timestep): + self.step = step + self.num_inference_steps = num_inference_steps + self.timestep = timestep + + def prepare_inputs_from_block_state(self, block_state, guider_input_fields): + batch = {} + for model_input_name, block_state_input_names in guider_input_fields.items(): + if isinstance(block_state_input_names, tuple): + block_state_input_name = block_state_input_names[0] + else: + block_state_input_name = block_state_input_names + batch[model_input_name] = getattr(block_state, block_state_input_name) + return [BlockState(**batch)] + + def prepare_models(self, model): + pass + + def cleanup_models(self, model): + pass + + def __call__(self, guider_state): + return [guider_state[0].noise_pred] + + +class _FakeTransformer: + def __init__(self, dtype): + self.dtype = dtype + self.calls = [] + + def __call__(self, hidden_states, timestep, attention_kwargs, return_dict, **kwargs): + self.calls.append( + { + "hidden_states": hidden_states, + "timestep": timestep, + "attention_kwargs": attention_kwargs, + "return_dict": return_dict, + **kwargs, + } + ) + return (torch.zeros_like(hidden_states),) + + class TestWanModularPipelineFast(ModularPipelineTesterMixin): pipeline_class = WanModularPipeline pipeline_blocks_class = WanBlocks @@ -47,3 +97,90 @@ def get_dummy_inputs(self, seed=0): @pytest.mark.skip(reason="num_videos_per_prompt") def test_num_images_per_prompt(self): pass + + def test_vae_scale_factors_use_config_values(self): + pipe = WanModularPipeline.__new__(WanModularPipeline) + + assert pipe.vae_scale_factor_spatial == 8 + assert pipe.vae_scale_factor_temporal == 4 + + pipe.vae = SimpleNamespace( + config=SimpleNamespace(scale_factor_spatial=16, scale_factor_temporal=2), + temperal_downsample=[True, True, False], + ) + + assert pipe.vae_scale_factor_spatial == 16 + assert pipe.vae_scale_factor_temporal == 2 + assert pipe.default_height == 960 + assert pipe.default_width == 1664 + assert pipe.default_num_frames == 41 + + def test_text_input_step_uses_transformer_dtype_and_repeat_interleave(self): + step = WanTextInputStep() + components = SimpleNamespace(transformer=SimpleNamespace(dtype=torch.bfloat16)) + prompt_embeds = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4) + negative_prompt_embeds = -prompt_embeds + state = PipelineState() + state.set("num_videos_per_prompt", 2) + state.set("prompt_embeds", prompt_embeds) + state.set("negative_prompt_embeds", negative_prompt_embeds) + + _, state = step(components, state) + + assert state.batch_size == 2 + assert state.dtype == torch.bfloat16 + torch.testing.assert_close(state.prompt_embeds, prompt_embeds.repeat_interleave(2, dim=0).to(torch.bfloat16)) + torch.testing.assert_close( + state.negative_prompt_embeds, negative_prompt_embeds.repeat_interleave(2, dim=0).to(torch.bfloat16) + ) + + def test_loop_denoiser_preserves_timestep_dtype(self): + transformer = _FakeTransformer(dtype=torch.bfloat16) + components = SimpleNamespace(transformer=transformer, guider=_FakeGuider()) + block_state = BlockState( + 
attention_kwargs=None, + dtype=torch.bfloat16, + latent_model_input=torch.ones(2, 4, dtype=torch.float32), + num_inference_steps=1, + prompt_embeds=torch.ones(2, 3, dtype=torch.float32), + negative_prompt_embeds=torch.zeros(2, 3, dtype=torch.float32), + ) + timestep = torch.tensor(999.1234, dtype=torch.float32) + + WanLoopDenoiser()(components, block_state, i=0, t=timestep) + + call = transformer.calls[0] + assert call["hidden_states"].dtype == torch.bfloat16 + assert call["encoder_hidden_states"].dtype == torch.bfloat16 + assert call["timestep"].dtype == torch.float32 + torch.testing.assert_close(call["timestep"], timestep.expand(2)) + + def test_wan22_loop_denoiser_uses_selected_transformer_dtype_and_preserves_timestep_dtype(self): + high_noise_transformer = _FakeTransformer(dtype=torch.bfloat16) + low_noise_transformer = _FakeTransformer(dtype=torch.float16) + components = SimpleNamespace( + config=SimpleNamespace(boundary_ratio=0.875), + num_train_timesteps=1000, + transformer=high_noise_transformer, + transformer_2=low_noise_transformer, + guider=_FakeGuider(), + guider_2=_FakeGuider(), + ) + block_state = BlockState( + attention_kwargs=None, + dtype=torch.bfloat16, + latent_model_input=torch.ones(2, 4, dtype=torch.float32), + num_inference_steps=1, + prompt_embeds=torch.ones(2, 3, dtype=torch.float32), + negative_prompt_embeds=torch.zeros(2, 3, dtype=torch.float32), + ) + timestep = torch.tensor(10.25, dtype=torch.float32) + + Wan22LoopDenoiser()(components, block_state, i=0, t=timestep) + + assert len(high_noise_transformer.calls) == 0 + call = low_noise_transformer.calls[0] + assert call["hidden_states"].dtype == torch.float16 + assert call["encoder_hidden_states"].dtype == torch.float16 + assert call["timestep"].dtype == torch.float32 + torch.testing.assert_close(call["timestep"], timestep.expand(2)) diff --git a/tests/pipelines/wan/test_wan_animate.py b/tests/pipelines/wan/test_wan_animate.py index 5d634fb71849..e2496e7e353e 100644 --- a/tests/pipelines/wan/test_wan_animate.py +++ b/tests/pipelines/wan/test_wan_animate.py @@ -33,6 +33,7 @@ WanAnimatePipeline, WanAnimateTransformer3DModel, ) +from diffusers.pipelines.wan.image_processor import WanAnimateImageProcessor from ...testing_utils import ( backend_empty_cache, @@ -208,6 +209,66 @@ def test_inference_replacement(self): video = pipe(**inputs).frames[0] self.assertEqual(video.shape, (17, 3, 16, 16)) + def test_num_videos_per_prompt_with_image_embeds(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs.update( + { + "prompt": ["dance monkey", "dance robot"], + "negative_prompt": ["negative", "negative"], + "num_videos_per_prompt": 2, + "num_inference_steps": 1, + } + ) + with torch.no_grad(): + image_embeds = pipe.encode_image(inputs["image"], device) + inputs["image_embeds"] = torch.cat([image_embeds, image_embeds + 1.0], dim=0) + + video = pipe(**inputs).frames + + self.assertEqual(video.shape, (4, 17, 3, 16, 16)) + + def test_image_embeds_still_requires_image(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + inputs = self.get_dummy_inputs("cpu") + inputs["image_embeds"] = torch.zeros(1, 1, pipe.transformer.config.image_dim) + inputs["image"] = None + + with self.assertRaisesRegex(ValueError, "`image` must be provided"): + pipe(**inputs) + + def test_wan_image_processor_preserves_config(self): + processor = 
WanAnimateImageProcessor( + do_resize=False, + vae_scale_factor=16, + vae_latent_channels=32, + spatial_patch_size=(1, 4), + resample="nearest", + reducing_gap=2, + do_normalize=False, + do_binarize=True, + do_convert_rgb=True, + fill_color=(1, 2, 3), + ) + + self.assertFalse(processor.config.do_resize) + self.assertEqual(processor.config.vae_scale_factor, 16) + self.assertEqual(processor.config.vae_latent_channels, 32) + self.assertEqual(processor.config.spatial_patch_size, (1, 4)) + self.assertEqual(processor.config.resample, "nearest") + self.assertEqual(processor.config.reducing_gap, 2) + self.assertFalse(processor.config.do_normalize) + self.assertTrue(processor.config.do_binarize) + self.assertTrue(processor.config.do_convert_rgb) + self.assertEqual(processor.config.fill_color, (1, 2, 3)) + @unittest.skip("Test not supported") def test_attention_slicing_forward_pass(self): pass diff --git a/tests/pipelines/wan/test_wan_image_to_video.py b/tests/pipelines/wan/test_wan_image_to_video.py index 7ed263abdcb5..11fbc62f6839 100644 --- a/tests/pipelines/wan/test_wan_image_to_video.py +++ b/tests/pipelines/wan/test_wan_image_to_video.py @@ -161,6 +161,125 @@ def test_inference(self): generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) + def test_num_videos_per_prompt_with_image_embeds(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs.update( + { + "prompt": ["dance monkey", "dance robot"], + "guidance_scale": 1.0, + "num_videos_per_prompt": 3, + "num_inference_steps": 1, + "num_frames": 1, + "output_type": "latent", + } + ) + with torch.no_grad(): + image_embeds = pipe.encode_image(inputs["image"], device) + inputs["image_embeds"] = torch.cat([image_embeds, image_embeds + 1.0], dim=0) + + video = pipe(**inputs).frames + + self.assertEqual(video.shape, (6, 16, 1, 2, 2)) + + def test_image_embeds_invalid_batch_size_raises(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs.update( + { + "prompt": ["dance monkey", "dance robot"], + "guidance_scale": 1.0, + "num_videos_per_prompt": 2, + "num_inference_steps": 1, + "num_frames": 1, + "output_type": "latent", + } + ) + with torch.no_grad(): + image_embeds = pipe.encode_image(inputs["image"], device) + inputs["image_embeds"] = image_embeds.repeat(3, 1, 1) + + with self.assertRaisesRegex(ValueError, "`image_embeds` batch size must be 1"): + pipe(**inputs) + + def test_check_inputs_accepts_image_lists_and_tuples(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + + pil_image = Image.new("RGB", (16, 16)) + np_image = np.zeros((16, 16, 3), dtype=np.float32) + torch_image = torch.zeros(3, 16, 16) + + valid_image_batches = [ + [pil_image, pil_image], + (pil_image, pil_image), + [np_image, np_image], + (np_image, np_image), + [torch_image, torch_image], + (torch_image, torch_image), + ] + for image in valid_image_batches: + with self.subTest(image_type=type(image[0]).__name__, container_type=type(image).__name__): + pipe.check_inputs(["dance monkey", "dance robot"], None, image, 16, 16) + + def test_tuple_image_input(self): + device = "cpu" + + components = 
self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + image = (Image.new("RGB", (16, 16)), Image.new("RGB", (16, 16))) + inputs = self.get_dummy_inputs(device) + inputs.update( + { + "image": image, + "prompt": ["dance monkey", "dance robot"], + "guidance_scale": 1.0, + "num_inference_steps": 1, + "num_frames": 1, + "output_type": "latent", + } + ) + + video = pipe(**inputs).frames + + self.assertEqual(video.shape, (2, 16, 1, 2, 2)) + + def test_last_image_embeds_expand_per_prompt_order(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.transformer.condition_embedder.image_embedder.pos_embed = torch.nn.Parameter(torch.zeros(1, 2, 4)) + + image_embeds = torch.tensor([0.0, 1.0, 10.0, 11.0]).reshape(4, 1, 1) + image_embeds = image_embeds.expand(4, 1, pipe.transformer.config.image_dim) + + image_embeds = pipe._prepare_image_embeds( + image=[Image.new("RGB", (16, 16)), Image.new("RGB", (16, 16))], + device=torch.device("cpu"), + batch_size=2, + num_videos_per_prompt=2, + transformer_dtype=torch.float32, + image_embeds=image_embeds, + last_image=[Image.new("RGB", (16, 16)), Image.new("RGB", (16, 16))], + ) + + expected = torch.tensor([0.0, 1.0, 0.0, 1.0, 10.0, 11.0, 10.0, 11.0]) + torch.testing.assert_close(image_embeds[:, 0, 0], expected) + @unittest.skip("Test not supported") def test_attention_slicing_forward_pass(self): pass diff --git a/tests/pipelines/wan/test_wan_vace.py b/tests/pipelines/wan/test_wan_vace.py index 53becce1685d..6fcd6301fbbc 100644 --- a/tests/pipelines/wan/test_wan_vace.py +++ b/tests/pipelines/wan/test_wan_vace.py @@ -193,6 +193,47 @@ def test_inference_with_multiple_reference_image(self): video_slice = [round(x, 5) for x in video_slice.tolist()] self.assertTrue(np.allclose(video_slice, expected_slice, atol=1e-3)) + def test_uses_vae_config_scale_factors(self): + components = self.get_dummy_components() + components["vae"].register_to_config(scale_factor_temporal=2, scale_factor_spatial=16) + + pipe = self.pipeline_class(**components) + + self.assertEqual(pipe.vae_scale_factor_temporal, 2) + self.assertEqual(pipe.vae_scale_factor_spatial, 16) + self.assertEqual(pipe.video_processor.config.vae_scale_factor, 16) + + def test_latent_output_trims_reference_latents(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["num_inference_steps"] = 1 + inputs["reference_images"] = Image.new("RGB", (16, 16)) + inputs["output_type"] = "latent" + + expected_latent_frames = (inputs["num_frames"] - 1) // pipe.vae_scale_factor_temporal + 1 + latent_height = inputs["height"] // pipe.vae_scale_factor_spatial + latent_width = inputs["width"] // pipe.vae_scale_factor_spatial + inputs["latents"] = torch.zeros( + 1, + pipe.transformer.config.in_channels, + expected_latent_frames + 1, + latent_height, + latent_width, + ) + + latents = pipe(**inputs).frames + + self.assertEqual( + latents.shape, + (1, pipe.transformer.config.in_channels, expected_latent_frames, latent_height, latent_width), + ) + @unittest.skip("Test not supported") def test_attention_slicing_forward_pass(self): pass diff --git a/tests/pipelines/wan/test_wan_video_to_video.py b/tests/pipelines/wan/test_wan_video_to_video.py index 3804e972b97f..ae859264d7f2 100644 --- a/tests/pipelines/wan/test_wan_video_to_video.py +++ 
b/tests/pipelines/wan/test_wan_video_to_video.py @@ -133,6 +133,25 @@ def test_inference(self): generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) + def test_uses_vae_config_scale_factors(self): + components = self.get_dummy_components() + components["vae"].register_to_config(scale_factor_temporal=2, scale_factor_spatial=16) + + pipe = self.pipeline_class(**components) + + self.assertEqual(pipe.vae_scale_factor_temporal, 2) + self.assertEqual(pipe.vae_scale_factor_spatial, 16) + self.assertEqual(pipe.video_processor.config.vae_scale_factor, 16) + + def test_num_videos_per_prompt_unsupported_raises(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + inputs = self.get_dummy_inputs("cpu") + inputs["num_videos_per_prompt"] = 2 + + with self.assertRaisesRegex(ValueError, "Generating multiple videos per prompt"): + pipe(**inputs) + @unittest.skip("Test not supported") def test_attention_slicing_forward_pass(self): pass
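As a companion to the `_expand_tensor_to_effective_batch` helper added in the Animate and I2V pipelines, here is a self-contained sketch of its contract. The function below is an illustrative restatement, not the library code itself: precomputed tensors may arrive with batch 1, `batch_size`, or the full effective batch, and are expanded in per-prompt order.

```python
import torch


def expand_to_effective_batch(tensor, batch_size, num_videos_per_prompt, name):
    """Accept batch 1, `batch_size`, or the full effective batch; expand per prompt."""
    target = batch_size * num_videos_per_prompt
    if tensor.shape[0] == target:
        return tensor
    if tensor.shape[0] == 1:
        repeat_by = target
    elif tensor.shape[0] == batch_size:
        repeat_by = num_videos_per_prompt
    else:
        raise ValueError(
            f"`{name}` batch size must be 1, {batch_size}, or {target}, but got {tensor.shape[0]}."
        )
    return torch.repeat_interleave(tensor, repeats=repeat_by, dim=0, output_size=tensor.shape[0] * repeat_by)


embeds = torch.tensor([[0.0], [1.0]])  # one embedding per prompt (A, B)
out = expand_to_effective_batch(embeds, batch_size=2, num_videos_per_prompt=3, name="image_embeds")
print(out.squeeze(-1).tolist())  # [0.0, 0.0, 0.0, 1.0, 1.0, 1.0] -- grouped per prompt, not interleaved
```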