2 changes: 2 additions & 0 deletions .ai/pipelines.md
@@ -60,3 +60,5 @@ When adding a new pipeline (or reviewing one), skim `pipeline_flux.py`, `pipelin
 4. **Subclassing an existing pipeline for a variant.** Don't subclass an existing pipeline class (e.g. `FluxPipeline`) to implement a variant (e.g. `FluxImg2ImgPipeline`) inside the core `src/` codebase. Each pipeline lives in its own file with its own class, even if it shares 90% of `__call__` with a sibling. The convention across diffusers (flux, sdxl, wan, qwenimage) is duplicated `__call__` between img2img / text2img / inpaint variants, not subclassing. Reuse private utilities (shared schedulers, prep functions) but not the pipeline class itself.
 
 5. **Copying a method from another pipeline without `# Copied from`.** When you reuse a method like `encode_prompt`, `prepare_latents`, `check_inputs`, or `_prepare_latent_image_ids` from another pipeline, add a `# Copied from` annotation so `make fix-copies` keeps the two in sync. Forgetting it means future refactors to the source silently drift away from your copy, and reviewers waste time spotting near-identical code that should have been linked. The annotation grammar (decorator placement, rename syntax with `with old->new`, etc.) is implemented in [`utils/check_copies.py`](../utils/check_copies.py); read it for the exact rules, and see the sketch below.
+
+6. **Partial batch expansion with `num_*_per_prompt`.** When a pipeline accepts `num_images_per_prompt`, `num_videos_per_prompt`, or precomputed conditioning tensors, every per-prompt input that reaches the denoising loop must be expanded to the same effective batch size in the same prompt order. Check prompt embeds, pooled embeds, image/video embeds, masks, control latents, image latents, guidance tensors, and any added conditioning. Prefer `torch.repeat_interleave(tensor, repeats=num_per_prompt, dim=0, output_size=tensor.shape[0] * num_per_prompt)` for contiguous per-prompt grouping, and validate precomputed tensors that are already batched instead of letting a later concat fail. Avoid `tensor.repeat((num_per_prompt, ...))` for batch expansion because it interleaves prompts (`[A, B] -> [A, B, A, B]`) rather than grouping them (`[A, A, B, B]`); both orderings are compared in the sketch below.
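For item 5, a representative annotation; the source path and rename below are illustrative, not from this PR, and the exact grammar is whatever `utils/check_copies.py` enforces:

```python
# Illustrative source/target names only.
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt with num_images_per_prompt->num_videos_per_prompt
def encode_prompt(self, prompt, num_videos_per_prompt: int = 1):
    ...
```

`make fix-copies` then regenerates this body from the source method whenever the source changes, applying the rename.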
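For item 6, a minimal comparison of the two orderings:

```python
import torch

emb = torch.tensor([[1.0], [2.0]])  # a batch of two prompt embeddings: A, B
n = 2  # num_images_per_prompt

grouped = emb.repeat_interleave(n, dim=0, output_size=emb.shape[0] * n)
tiled = emb.repeat(n, 1)

print(grouped.squeeze().tolist())  # [1.0, 1.0, 2.0, 2.0] -> grouped [A, A, B, B]
print(tiled.squeeze().tolist())    # [1.0, 2.0, 1.0, 2.0] -> interleaved [A, B, A, B]
```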
22 changes: 8 additions & 14 deletions src/diffusers/modular_pipelines/wan/before_denoise.py
@@ -223,7 +223,7 @@ def inputs(self) -> list[InputParam]:
         ]
 
     @property
-    def intermediate_outputs(self) -> list[str]:
+    def intermediate_outputs(self) -> list[OutputParam]:
         return [
             OutputParam(
                 "batch_size",
@@ -252,22 +252,16 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe
         self.check_inputs(components, block_state)
 
         block_state.batch_size = block_state.prompt_embeds.shape[0]
-        block_state.dtype = block_state.prompt_embeds.dtype
+        block_state.dtype = components.transformer.dtype
 
-        _, seq_len, _ = block_state.prompt_embeds.shape
-        block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_videos_per_prompt, 1)
-        block_state.prompt_embeds = block_state.prompt_embeds.view(
-            block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1
-        )
+        block_state.prompt_embeds = block_state.prompt_embeds.repeat_interleave(
+            block_state.num_videos_per_prompt, dim=0
+        ).to(block_state.dtype)
 
         if block_state.negative_prompt_embeds is not None:
-            _, seq_len, _ = block_state.negative_prompt_embeds.shape
-            block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
-                1, block_state.num_videos_per_prompt, 1
-            )
-            block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
-                block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1
-            )
+            block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat_interleave(
+                block_state.num_videos_per_prompt, dim=0
+            ).to(block_state.dtype)
 
         self.set_block_state(state, block_state)
 
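The `repeat(1, n, 1)` + `view` pattern deleted above and the `repeat_interleave` that replaces it are equivalent, both grouping the copies per prompt; a quick check with arbitrary sizes:

```python
import torch

batch, seq_len, dim, n = 2, 3, 4, 2  # n plays the role of num_videos_per_prompt
emb = torch.randn(batch, seq_len, dim)

# Old pattern: tile along the sequence axis, then fold the copies back into the batch.
old = emb.repeat(1, n, 1).view(batch * n, seq_len, -1)
# New pattern: expand the batch axis directly.
new = emb.repeat_interleave(n, dim=0)

assert torch.equal(old, new)  # both yield [A, A, B, B] ordering
```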
129 changes: 96 additions & 33 deletions src/diffusers/modular_pipelines/wan/denoise.py
@@ -27,7 +27,7 @@
     ModularPipelineBlocks,
     PipelineState,
 )
-from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam
+from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
 from .modular_pipeline import WanModularPipeline
 
 
@@ -54,17 +54,21 @@ def inputs(self) -> list[InputParam]:
                 type_hint=torch.Tensor,
                 description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
             ),
-            InputParam(
-                "dtype",
-                required=True,
-                type_hint=torch.dtype,
-                description="The dtype of the model inputs. Can be generated in input step.",
-            ),
         ]
 
+    @property
+    def intermediate_outputs(self) -> list[OutputParam]:
+        return [
+            OutputParam(
+                "latent_model_input",
+                type_hint=torch.Tensor,
+                description="Latents prepared as model hidden states for the denoiser.",
+            )
+        ]
+
     @torch.no_grad()
     def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
-        block_state.latent_model_input = block_state.latents.to(block_state.dtype)
+        block_state.latent_model_input = block_state.latents
         return components, block_state
 
 
@@ -94,19 +98,21 @@ def inputs(self) -> list[InputParam]:
                 type_hint=torch.Tensor,
                 description="The image condition latents to use for the denoising process. Can be generated in prepare_first_frame_latents/prepare_first_last_frame_latents step.",
             ),
-            InputParam(
-                "dtype",
-                required=True,
-                type_hint=torch.dtype,
-                description="The dtype of the model inputs. Can be generated in input step.",
-            ),
         ]
 
+    @property
+    def intermediate_outputs(self) -> list[OutputParam]:
+        return [
+            OutputParam(
+                "latent_model_input",
+                type_hint=torch.Tensor,
+                description="Latents and image conditioning prepared as model hidden states for the denoiser.",
+            )
+        ]
+
     @torch.no_grad()
     def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
-        block_state.latent_model_input = torch.cat(
-            [block_state.latents, block_state.image_condition_latents], dim=1
-        ).to(block_state.dtype)
+        block_state.latent_model_input = torch.cat([block_state.latents, block_state.image_condition_latents], dim=1)
         return components, block_state
 
 
@@ -165,6 +171,18 @@ def inputs(self) -> list[tuple[str, Any]]:
                 type_hint=int,
                 description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
             ),
+            InputParam(
+                "latent_model_input",
+                required=True,
+                type_hint=torch.Tensor,
+                description="Latents prepared as model hidden states for the denoiser.",
+            ),
+            InputParam(
+                "dtype",
+                required=True,
+                type_hint=torch.dtype,
+                description="The dtype of the model inputs. Can be generated in input step.",
+            ),
         ]
         guider_input_names = []
         for value in self._guider_input_fields.values():
@@ -177,11 +195,22 @@ def inputs(self) -> list[tuple[str, Any]]:
             inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor))
         return inputs
 
+    @property
+    def intermediate_outputs(self) -> list[OutputParam]:
+        return [
+            OutputParam(
+                "noise_pred",
+                type_hint=torch.Tensor,
+                description="The guided noise prediction for the current denoising step.",
+            )
+        ]
+
     @torch.no_grad()
     def __call__(
         self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
     ) -> PipelineState:
         components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
+        dtype = block_state.dtype
 
         # The guider splits model inputs into separate batches for conditional/unconditional predictions.
         # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
Expand All @@ -198,16 +227,16 @@ def __call__(
components.guider.prepare_models(components.transformer)
cond_kwargs = guider_state_batch.as_dict()
cond_kwargs = {
k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v
k: v.to(dtype) if isinstance(v, torch.Tensor) else v
for k, v in cond_kwargs.items()
if k in self._guider_input_fields.keys()
}

# Predict the noise residual
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
guider_state_batch.noise_pred = components.transformer(
hidden_states=block_state.latent_model_input.to(block_state.dtype),
timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype),
hidden_states=block_state.latent_model_input.to(dtype),
timestep=t.expand(block_state.latent_model_input.shape[0]),
attention_kwargs=block_state.attention_kwargs,
return_dict=False,
**cond_kwargs,
@@ -292,6 +321,12 @@ def inputs(self) -> list[tuple[str, Any]]:
                 type_hint=int,
                 description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
             ),
+            InputParam(
+                "latent_model_input",
+                required=True,
+                type_hint=torch.Tensor,
+                description="Latents prepared as model hidden states for the denoiser.",
+            ),
         ]
         guider_input_names = []
         for value in self._guider_input_fields.values():
@@ -304,19 +339,30 @@ def inputs(self) -> list[tuple[str, Any]]:
             inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor))
         return inputs
 
+    @property
+    def intermediate_outputs(self) -> list[OutputParam]:
+        return [
+            OutputParam(
+                "noise_pred",
+                type_hint=torch.Tensor,
+                description="The guided noise prediction for the current denoising step.",
+            )
+        ]
+
     @torch.no_grad()
     def __call__(
         self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
     ) -> PipelineState:
         boundary_timestep = components.config.boundary_ratio * components.num_train_timesteps
         if t >= boundary_timestep:
-            block_state.current_model = components.transformer
-            block_state.guider = components.guider
+            current_model = components.transformer
+            guider = components.guider
         else:
-            block_state.current_model = components.transformer_2
-            block_state.guider = components.guider_2
+            current_model = components.transformer_2
+            guider = components.guider_2
+        dtype = current_model.dtype
 
-        block_state.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
+        guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
 
         # The guider splits model inputs into separate batches for conditional/unconditional predictions.
         # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
@@ -326,31 +372,31 @@ def __call__(
         #     {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"},  # unconditional batch
         # ]
         # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
-        guider_state = block_state.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields)
+        guider_state = guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields)
 
         # run the denoiser for each guidance batch
         for guider_state_batch in guider_state:
-            block_state.guider.prepare_models(block_state.current_model)
+            guider.prepare_models(current_model)
             cond_kwargs = guider_state_batch.as_dict()
             cond_kwargs = {
-                k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v
+                k: v.to(dtype) if isinstance(v, torch.Tensor) else v
                 for k, v in cond_kwargs.items()
                 if k in self._guider_input_fields.keys()
             }
 
             # Predict the noise residual
             # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
-            guider_state_batch.noise_pred = block_state.current_model(
-                hidden_states=block_state.latent_model_input.to(block_state.dtype),
-                timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype),
+            guider_state_batch.noise_pred = current_model(
+                hidden_states=block_state.latent_model_input.to(dtype),
+                timestep=t.expand(block_state.latent_model_input.shape[0]),
                 attention_kwargs=block_state.attention_kwargs,
                 return_dict=False,
                 **cond_kwargs,
             )[0]
-            block_state.guider.cleanup_models(block_state.current_model)
+            guider.cleanup_models(current_model)
 
         # Perform guidance
-        block_state.noise_pred = block_state.guider(guider_state)[0]
+        block_state.noise_pred = guider(guider_state)[0]
 
         return components, block_state
 
@@ -372,6 +418,23 @@ def description(self) -> str:
             "object (e.g. `WanDenoiseLoopWrapper`)"
         )
 
+    @property
+    def inputs(self) -> list[InputParam]:
+        return [
+            InputParam(
+                "latents",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The latents to update with the scheduler step.",
+            ),
+            InputParam(
+                "noise_pred",
+                required=True,
+                type_hint=torch.Tensor,
+                description="The guided noise prediction for the current denoising step.",
+            ),
+        ]
+
     @torch.no_grad()
     def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
         # Perform scheduler step using the predicted output
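The common thread in this file: loop blocks now declare `latent_model_input` and `noise_pred` as explicit `InputParam`/`OutputParam` entries instead of passing them through undeclared `block_state` attributes, so the chain can be checked up front. A minimal, self-contained sketch of that declared-IO idea (hypothetical classes, not the actual modular-pipeline API):

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class State:
    values: dict[str, Any] = field(default_factory=dict)


class Block:
    inputs: list[str] = []
    outputs: list[str] = []

    def __call__(self, state: State) -> None:
        raise NotImplementedError


class PrepareLatentInput(Block):
    inputs = ["latents"]
    outputs = ["latent_model_input"]

    def __call__(self, state: State) -> None:
        # Pass-through, mirroring the before-denoiser block above.
        state.values["latent_model_input"] = state.values["latents"]


def validate_chain(blocks: list[Block], available: set[str]) -> None:
    # Fail fast if a block declares an input nothing upstream produces.
    for block in blocks:
        missing = [name for name in block.inputs if name not in available]
        if missing:
            raise ValueError(f"{type(block).__name__} is missing inputs: {missing}")
        available.update(block.outputs)


validate_chain([PrepareLatentInput()], available={"latents"})  # passes
```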
8 changes: 4 additions & 4 deletions src/diffusers/modular_pipelines/wan/modular_pipeline.py
@@ -69,15 +69,15 @@ def patch_size_spatial(self):
     @property
     def vae_scale_factor_spatial(self):
         vae_scale_factor = 8
-        if hasattr(self, "vae") and self.vae is not None:
-            vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
+        if getattr(self, "vae", None) is not None:
+            vae_scale_factor = self.vae.config.scale_factor_spatial
         return vae_scale_factor
 
     @property
     def vae_scale_factor_temporal(self):
         vae_scale_factor = 4
-        if hasattr(self, "vae") and self.vae is not None:
-            vae_scale_factor = 2 ** sum(self.vae.temperal_downsample)
+        if getattr(self, "vae", None) is not None:
+            vae_scale_factor = self.vae.config.scale_factor_temporal
         return vae_scale_factor
 
     @property
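Two small changes above: `hasattr(self, "vae") and self.vae is not None` collapses to `getattr(self, "vae", None) is not None`, and the scale factors are read from the VAE config instead of being recomputed from `temperal_downsample`. A sketch of the fallback shape (`SimpleNamespace` is a hypothetical stand-in for the real config object):

```python
from types import SimpleNamespace


def vae_scale_factor_spatial(vae) -> int:
    # Fall back to the historical default when no VAE is loaded.
    if getattr(vae, "config", None) is not None:
        return vae.config.scale_factor_spatial
    return 8


vae = SimpleNamespace(config=SimpleNamespace(scale_factor_spatial=16))
assert vae_scale_factor_spatial(vae) == 16
assert vae_scale_factor_spatial(None) == 8
```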
21 changes: 12 additions & 9 deletions src/diffusers/pipelines/wan/image_processor.py
@@ -17,7 +17,6 @@
 import PIL.Image
 import torch
 
-from ...configuration_utils import register_to_config
 from ...image_processor import VaeImageProcessor
 from ...utils import PIL_INTERPOLATION
 
@@ -53,7 +52,6 @@ class WanAnimateImageProcessor(VaeImageProcessor):
             if `None`, will default to filling with data from `image`.
     """
 
-    @register_to_config
     def __init__(
         self,
         do_resize: bool = True,
@@ -68,13 +66,18 @@ def __init__(
         do_convert_grayscale: bool = False,
         fill_color: str | float | tuple[float, ...] | None = 0,
     ):
-        super().__init__()
-        if do_convert_rgb and do_convert_grayscale:
-            raise ValueError(
-                "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
-                " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
-                " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
-            )
+        super().__init__(
+            do_resize=do_resize,
+            vae_scale_factor=vae_scale_factor,
+            vae_latent_channels=vae_latent_channels,
+            resample=resample,
+            reducing_gap=reducing_gap,
+            do_normalize=do_normalize,
+            do_binarize=do_binarize,
+            do_convert_rgb=do_convert_rgb,
+            do_convert_grayscale=do_convert_grayscale,
+        )
+        self.register_to_config(spatial_patch_size=spatial_patch_size, fill_color=fill_color)
 
     def _resize_and_fill(
         self,
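The `__init__` rework drops the `@register_to_config` decorator: shared kwargs are forwarded to the parent `VaeImageProcessor.__init__` (which registers them and performs the RGB/grayscale validation, so the duplicated `raise` goes away), and only the subclass-specific extras are registered by hand. A toy version of that split, with stand-ins rather than the real `ConfigMixin` machinery:

```python
class ConfigBase:
    def __init__(self):
        self._config: dict = {}

    def register_to_config(self, **kwargs) -> None:
        self._config.update(kwargs)


class Parent(ConfigBase):
    def __init__(self, do_resize: bool = True):
        super().__init__()
        self.register_to_config(do_resize=do_resize)


class Child(Parent):
    def __init__(self, do_resize: bool = True, spatial_patch_size: int = 16):
        # Forward shared args so the parent records them once...
        super().__init__(do_resize=do_resize)
        # ...and register only the child-specific extras here.
        self.register_to_config(spatial_patch_size=spatial_patch_size)


child = Child(spatial_patch_size=8)
assert child._config == {"do_resize": True, "spatial_patch_size": 8}
```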