22import torch
33import torch .nn as nn
44import numpy as np
5- from typing import Any , Dict , Optional
5+ from typing import Any , Dict , List , Optional
66from einops import rearrange
77
88from diffsynth_engine .models .basic .transformer_helper import (
@@ -245,7 +245,7 @@ def __init__(
245245 self .ff_a = nn .Sequential (
246246 nn .Linear (dim , dim * 4 , device = device , dtype = dtype ),
247247 nn .GELU (approximate = "tanh" ),
248- nn .Linear (dim * 4 , dim , device = device , dtype = dtype )
248+ nn .Linear (dim * 4 , dim , device = device , dtype = dtype ),
249249 )
250250 # Text
251251 self .norm_msa_b = AdaLayerNormZero (dim , device = device , dtype = dtype )
@@ -395,21 +395,19 @@ def prepare_image_ids(latents: torch.Tensor):
395395
396396 def forward (
397397 self ,
398- hidden_states ,
399- timestep ,
400- prompt_emb ,
401- pooled_prompt_emb ,
402- image_emb ,
403- guidance ,
404- text_ids ,
405- image_ids = None ,
406- controlnet_double_block_output = None ,
407- controlnet_single_block_output = None ,
398+ hidden_states : torch . Tensor ,
399+ timestep : torch . Tensor ,
400+ prompt_emb : torch . Tensor ,
401+ pooled_prompt_emb : torch . Tensor ,
402+ image_ids : torch . Tensor ,
403+ text_ids : torch . Tensor ,
404+ guidance : torch . Tensor ,
405+ image_emb : torch . Tensor | None = None ,
406+ controlnet_double_block_output : List [ torch . Tensor ] | None = None ,
407+ controlnet_single_block_output : List [ torch . Tensor ] | None = None ,
408408 ** kwargs ,
409409 ):
410- h , w = hidden_states .shape [- 2 :]
411- if image_ids is None :
412- image_ids = self .prepare_image_ids (hidden_states )
410+ image_seq_len = hidden_states .shape [1 ]
413411 controlnet_double_block_output = (
414412 controlnet_double_block_output if controlnet_double_block_output is not None else ()
415413 )
@@ -428,10 +426,10 @@ def forward(
428426 timestep ,
429427 prompt_emb ,
430428 pooled_prompt_emb ,
431- image_emb ,
432- guidance ,
433- text_ids ,
434429 image_ids ,
430+ text_ids ,
431+ guidance ,
432+ image_emb ,
435433 * controlnet_double_block_output ,
436434 * controlnet_single_block_output ,
437435 ),
@@ -448,7 +446,6 @@ def forward(
448446 rope_emb = self .pos_embedder (torch .cat ((text_ids , image_ids ), dim = 1 ))
449447 text_rope_emb = rope_emb [:, :, : text_ids .size (1 )]
450448 image_rope_emb = rope_emb [:, :, text_ids .size (1 ) :]
451- hidden_states = self .patchify (hidden_states )
452449
453450 with sequence_parallel (
454451 (
@@ -489,9 +486,8 @@ def forward(
489486 hidden_states = hidden_states [:, prompt_emb .shape [1 ] :]
490487 hidden_states = self .final_norm_out (hidden_states , conditioning )
491488 hidden_states = self .final_proj_out (hidden_states )
492- (hidden_states ,) = sequence_parallel_unshard ((hidden_states ,), seq_dims = (1 ,), seq_lens = (h * w // 4 ,))
489+ (hidden_states ,) = sequence_parallel_unshard ((hidden_states ,), seq_dims = (1 ,), seq_lens = (image_seq_len ,))
493490
494- hidden_states = self .unpatchify (hidden_states , h , w )
495491 (hidden_states ,) = cfg_parallel_unshard ((hidden_states ,), use_cfg = use_cfg )
496492 return hidden_states
497493
0 commit comments