From 5d8a8bf99d60c1eac9c9d823ac7d4ef6976c45ac Mon Sep 17 00:00:00 2001 From: Genno Date: Mon, 11 May 2026 07:26:09 -0700 Subject: [PATCH 1/2] fix(wan): WanModel.patchify flatten + return grid size tuple Fixes #1063. WanModel.forward (and the parallel usp_dit_forward in diffsynth/utils/xfuser/xdit_context_parallel.py) both unpack patchify's return value as `x, (f, h, w) = self.patchify(x)`, but patchify only returns `x` -- and that `x` is a 5D (B, dim, f, h, w) tensor straight out of Conv3d, not the (B, f*h*w, dim) sequence the transformer blocks expect. Net result: every Wan training run crashes at the first forward call with `ValueError: not enough values to unpack (expected 2, got 1)`. The block loop never executes. The fix: extract (f, h, w) from the Conv3d output shape, flatten the 3D spatial-temporal grid into a token sequence (B, f*h*w, dim) ready for the transformer blocks, return (x, (f, h, w)) to match what forward and unpatchify expect. Verified end-to-end with a synthetic 8-layer WanModel on the same forward path (input shape (1, 16, 4, 8, 8) -> output shape (1, 16, 4, 8, 8) round-trip), forward + backward both complete, loss gradient propagates through all blocks. --- diffsynth/models/wan_video_dit.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 52f607e10..2a5ebba3a 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -498,7 +498,13 @@ def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch y_camera = self.control_adapter(control_camera_latents_input) x = [u + v for u, v in zip(x, y_camera)] x = x[0].unsqueeze(0) - return x + # After Conv3d the tensor is (B, dim, f, h, w). The forward expects + # (x, (f, h, w)) where x has been flattened to (B, f*h*w, dim) ready + # for the transformer blocks; the grid size tuple drives the RoPE + # frequency cat and the unpatchify rearrange at the end. + f, h, w = x.shape[-3], x.shape[-2], x.shape[-1] + x = x.flatten(2).transpose(1, 2) + return x, (f, h, w) def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): return rearrange( From ecb6f5341c490c33771e6458f40feb3dc55f0b22 Mon Sep 17 00:00:00 2001 From: Genno Date: Mon, 11 May 2026 08:34:24 -0700 Subject: [PATCH 2/2] fix(wan): apply patchify fix in consumers, keep signature MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per @xiuyuan18 review: the bug should be fixed in forward (and the other tuple-unpacking consumers), not by changing patchify's return signature. This matches the official Wan 2.1 implementation, which also flattens x in forward, not in patchify. Reverts patchify back to returning just `x` (post-Conv3d tensor shape (B, dim, f, h, w)). Adds the flatten + (f, h, w) extraction at each tuple-unpacking consumer: - wan_video_dit.py: WanModel.forward - xdit_context_parallel.py: usp_dit_forward - wan_video.py: model_fn_wans2v (two sites: main x and ref_latents) WanToDance callers in wan_video.py (lines 1412/1414) continue to work unchanged because they always wanted the raw post-Conv3d tensor and do their own flattening downstream. Pre-existing S2V double-embedding pattern at model_fn_wans2v (caller pre-runs patch_embedding, then patchify re-runs patch_embedding internally) is preserved as-is — out of scope for this PR. Refs #1063 --- diffsynth/models/wan_video_dit.py | 14 +++++--------- diffsynth/pipelines/wan_video.py | 8 ++++++-- diffsynth/utils/xfuser/xdit_context_parallel.py | 6 ++++-- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 2a5ebba3a..43d8cd1de 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -498,13 +498,7 @@ def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch y_camera = self.control_adapter(control_camera_latents_input) x = [u + v for u, v in zip(x, y_camera)] x = x[0].unsqueeze(0) - # After Conv3d the tensor is (B, dim, f, h, w). The forward expects - # (x, (f, h, w)) where x has been flattened to (B, f*h*w, dim) ready - # for the transformer blocks; the grid size tuple drives the RoPE - # frequency cat and the unpatchify rearrange at the end. - f, h, w = x.shape[-3], x.shape[-2], x.shape[-1] - x = x.flatten(2).transpose(1, 2) - return x, (f, h, w) + return x def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): return rearrange( @@ -533,8 +527,10 @@ def forward(self, clip_embdding = self.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) - x, (f, h, w) = self.patchify(x) - + x = self.patchify(x) + f, h, w = x.shape[-3], x.shape[-2], x.shape[-1] + x = x.flatten(2).transpose(1, 2) + freqs = torch.cat([ self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index c1e4dfb3a..b96d20ee0 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -1658,11 +1658,15 @@ def model_fn_wans2v( # x and s2v_pose_latents s2v_pose_latents = torch.zeros_like(x) if s2v_pose_latents is None else s2v_pose_latents - x, (f, h, w) = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(s2v_pose_latents)) + x = dit.patchify(dit.patch_embedding(x) + dit.cond_encoder(s2v_pose_latents)) + f, h, w = x.shape[-3], x.shape[-2], x.shape[-1] + x = x.flatten(2).transpose(1, 2) seq_len_x = seq_len_x_global = x.shape[1] # global used for unified sequence parallel # reference image - ref_latents, (rf, rh, rw) = dit.patchify(dit.patch_embedding(origin_ref_latents)) + ref_latents = dit.patchify(dit.patch_embedding(origin_ref_latents)) + rf, rh, rw = ref_latents.shape[-3], ref_latents.shape[-2], ref_latents.shape[-1] + ref_latents = ref_latents.flatten(2).transpose(1, 2) grid_sizes = dit.get_grid_sizes((f, h, w), (rf, rh, rw)) x = torch.cat([x, ref_latents], dim=1) # mask diff --git a/diffsynth/utils/xfuser/xdit_context_parallel.py b/diffsynth/utils/xfuser/xdit_context_parallel.py index abf0f3fef..40b18821a 100644 --- a/diffsynth/utils/xfuser/xdit_context_parallel.py +++ b/diffsynth/utils/xfuser/xdit_context_parallel.py @@ -81,8 +81,10 @@ def usp_dit_forward(self, clip_embdding = self.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) - x, (f, h, w) = self.patchify(x) - + x = self.patchify(x) + f, h, w = x.shape[-3], x.shape[-2], x.shape[-1] + x = x.flatten(2).transpose(1, 2) + freqs = torch.cat([ self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),