
Commit 33d00d0

supports flux kontext with multiple input images (#173)

1 parent 9821529

7 files changed: 441 additions & 146 deletions

diffsynth_engine/conf/models/flux/flux_dit.json

Lines changed: 20 additions & 1 deletion
@@ -101,5 +101,24 @@
       "proj_mlp": "proj_in_besides_attn",
       "proj_out": "proj_out"
     }
-  }
+  },
+  "preferred_kontext_resolutions": [
+    [672, 1568],
+    [688, 1504],
+    [720, 1456],
+    [752, 1392],
+    [800, 1328],
+    [832, 1248],
+    [880, 1184],
+    [944, 1104],
+    [1024, 1024],
+    [1104, 944],
+    [1184, 880],
+    [1248, 832],
+    [1328, 800],
+    [1392, 752],
+    [1456, 720],
+    [1504, 688],
+    [1568, 672]
+  ]
 }
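
Each of the 17 preferred Kontext resolutions is close to one megapixel, so the table effectively enumerates aspect-ratio buckets from about 0.43 (672x1568) to 2.33 (1568x672). A minimal sketch of how a pipeline might snap an input image to this table (the helper below is illustrative, not this repository's API):

import math

# Pairs copied from the new "preferred_kontext_resolutions" entry above.
PREFERRED_KONTEXT_RESOLUTIONS = [
    (672, 1568), (688, 1504), (720, 1456), (752, 1392), (800, 1328),
    (832, 1248), (880, 1184), (944, 1104), (1024, 1024), (1104, 944),
    (1184, 880), (1248, 832), (1328, 800), (1392, 752), (1456, 720),
    (1504, 688), (1568, 672),
]

def nearest_kontext_resolution(width: int, height: int) -> tuple[int, int]:
    # Choose the bucket whose aspect ratio is closest to the input's;
    # comparing log-ratios treats wide and tall mismatches symmetrically.
    target = math.log(width / height)
    return min(
        PREFERRED_KONTEXT_RESOLUTIONS,
        key=lambda wh: abs(math.log(wh[0] / wh[1]) - target),
    )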

diffsynth_engine/conf/models/flux/flux_vae.json

Lines changed: 253 additions & 5 deletions
Large diffs are not rendered by default.

diffsynth_engine/models/flux/flux_controlnet.py

Lines changed: 9 additions & 11 deletions
@@ -119,18 +119,16 @@ def patchify(self, hidden_states):

     def forward(
         self,
-        hidden_states,
-        control_condition,
-        control_scale,
-        timestep,
-        prompt_emb,
-        pooled_prompt_emb,
-        guidance,
-        image_ids,
-        text_ids,
+        hidden_states: torch.Tensor,
+        control_condition: torch.Tensor,
+        control_scale: float,
+        timestep: torch.Tensor,
+        prompt_emb: torch.Tensor,
+        pooled_prompt_emb: torch.Tensor,
+        image_ids: torch.Tensor,
+        text_ids: torch.Tensor,
+        guidance: torch.Tensor,
     ):
-        hidden_states = self.patchify(hidden_states)
-        control_condition = self.patchify(control_condition)
         hidden_states = self.x_embedder(hidden_states) + self.controlnet_x_embedder(control_condition)
         condition = (
             self.time_embedder(timestep, hidden_states.dtype)
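
Both patchify calls were deleted from forward, and guidance moved to the end of the now-annotated signature: the ControlNet expects token sequences that were already patchified once upstream, matching the DiT change below. A rough sketch of the resulting calling convention (variable names and return unpacking are assumptions, not the repository's exact code):

# Assumed convention: patchify once, share the token sequences between
# the DiT and the ControlNet instead of patchifying in each module.
hidden_states = dit.patchify(latents)              # (B, seq_len, dim) tokens
control_condition = dit.patchify(control_latents)  # same token layout
double_out, single_out = controlnet(
    hidden_states, control_condition, control_scale,
    timestep, prompt_emb, pooled_prompt_emb,
    image_ids, text_ids, guidance,
)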

diffsynth_engine/models/flux/flux_dit.py

Lines changed: 17 additions & 21 deletions
@@ -2,7 +2,7 @@
 import torch
 import torch.nn as nn
 import numpy as np
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 from einops import rearrange

 from diffsynth_engine.models.basic.transformer_helper import (
@@ -245,7 +245,7 @@ def __init__(
         self.ff_a = nn.Sequential(
             nn.Linear(dim, dim * 4, device=device, dtype=dtype),
             nn.GELU(approximate="tanh"),
-            nn.Linear(dim * 4, dim, device=device, dtype=dtype)
+            nn.Linear(dim * 4, dim, device=device, dtype=dtype),
         )
         # Text
         self.norm_msa_b = AdaLayerNormZero(dim, device=device, dtype=dtype)
@@ -395,21 +395,19 @@ def prepare_image_ids(latents: torch.Tensor):

     def forward(
         self,
-        hidden_states,
-        timestep,
-        prompt_emb,
-        pooled_prompt_emb,
-        image_emb,
-        guidance,
-        text_ids,
-        image_ids=None,
-        controlnet_double_block_output=None,
-        controlnet_single_block_output=None,
+        hidden_states: torch.Tensor,
+        timestep: torch.Tensor,
+        prompt_emb: torch.Tensor,
+        pooled_prompt_emb: torch.Tensor,
+        image_ids: torch.Tensor,
+        text_ids: torch.Tensor,
+        guidance: torch.Tensor,
+        image_emb: torch.Tensor | None = None,
+        controlnet_double_block_output: List[torch.Tensor] | None = None,
+        controlnet_single_block_output: List[torch.Tensor] | None = None,
         **kwargs,
     ):
-        h, w = hidden_states.shape[-2:]
-        if image_ids is None:
-            image_ids = self.prepare_image_ids(hidden_states)
+        image_seq_len = hidden_states.shape[1]
         controlnet_double_block_output = (
             controlnet_double_block_output if controlnet_double_block_output is not None else ()
         )
@@ -428,10 +426,10 @@ def forward(
                 timestep,
                 prompt_emb,
                 pooled_prompt_emb,
-                image_emb,
-                guidance,
-                text_ids,
                 image_ids,
+                text_ids,
+                guidance,
+                image_emb,
                 *controlnet_double_block_output,
                 *controlnet_single_block_output,
             ),
@@ -448,7 +446,6 @@ def forward(
         rope_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
         text_rope_emb = rope_emb[:, :, : text_ids.size(1)]
         image_rope_emb = rope_emb[:, :, text_ids.size(1) :]
-        hidden_states = self.patchify(hidden_states)

         with sequence_parallel(
             (
@@ -489,9 +486,8 @@ def forward(
         hidden_states = hidden_states[:, prompt_emb.shape[1] :]
         hidden_states = self.final_norm_out(hidden_states, conditioning)
         hidden_states = self.final_proj_out(hidden_states)
-        (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(h * w // 4,))
+        (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(image_seq_len,))

-        hidden_states = self.unpatchify(hidden_states, h, w)
         (hidden_states,) = cfg_parallel_unshard((hidden_states,), use_cfg=use_cfg)
         return hidden_states
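
Moving patchify/unpatchify out of forward is the change that enables multiple Kontext input images: once reference-image tokens are concatenated after the noisy latent tokens, a single (h, w) pair no longer describes the sequence, so forward reads image_seq_len directly from the pre-patchified input. A minimal sketch of the assumed pipeline-side preparation (the helper and the image_ids indexing convention are assumptions, not this commit's exact code):

import torch

def build_kontext_sequence(dit, latents, kontext_latents):
    # Noisy latent tokens first; each reference image appends its own tokens,
    # distinguished by a per-image index written into image_ids.
    tokens, ids = [dit.patchify(latents)], [dit.prepare_image_ids(latents)]
    for i, ref in enumerate(kontext_latents, start=1):
        ref_ids = dit.prepare_image_ids(ref)
        ref_ids[..., 0] = i  # assumed convention: 0 marks the noisy latents
        tokens.append(dit.patchify(ref))
        ids.append(ref_ids)
    return torch.cat(tokens, dim=1), torch.cat(ids, dim=1)

The pipeline would then presumably keep only the leading noisy-latent tokens of the output before unpatchifying, since the reference-image tokens are conditioning only.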

diffsynth_engine/models/flux/flux_dit_fbcache.py

Lines changed: 17 additions & 21 deletions
@@ -1,6 +1,6 @@
 import torch
 import numpy as np
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional

 from diffsynth_engine.utils.gguf import gguf_inference
 from diffsynth_engine.utils.fp8_linear import fp8_inference
@@ -48,21 +48,19 @@ def refresh_cache_status(self, num_inference_steps):

     def forward(
         self,
-        hidden_states,
-        timestep,
-        prompt_emb,
-        pooled_prompt_emb,
-        image_emb,
-        guidance,
-        text_ids,
-        image_ids=None,
-        controlnet_double_block_output=None,
-        controlnet_single_block_output=None,
+        hidden_states: torch.Tensor,
+        timestep: torch.Tensor,
+        prompt_emb: torch.Tensor,
+        pooled_prompt_emb: torch.Tensor,
+        image_ids: torch.Tensor,
+        text_ids: torch.Tensor,
+        guidance: torch.Tensor,
+        image_emb: torch.Tensor | None = None,
+        controlnet_double_block_output: List[torch.Tensor] | None = None,
+        controlnet_single_block_output: List[torch.Tensor] | None = None,
         **kwargs,
     ):
-        h, w = hidden_states.shape[-2:]
-        if image_ids is None:
-            image_ids = self.prepare_image_ids(hidden_states)
+        image_seq_len = hidden_states.shape[1]
         controlnet_double_block_output = (
             controlnet_double_block_output if controlnet_double_block_output is not None else ()
         )
@@ -81,10 +79,10 @@ def forward(
             timestep,
             prompt_emb,
             pooled_prompt_emb,
-            image_emb,
-            guidance,
-            text_ids,
             image_ids,
+            text_ids,
+            guidance,
+            image_emb,
             *controlnet_double_block_output,
             *controlnet_single_block_output,
         ),
@@ -101,7 +99,6 @@ def forward(
         rope_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
         text_rope_emb = rope_emb[:, :, : text_ids.size(1)]
         image_rope_emb = rope_emb[:, :, text_ids.size(1) :]
-        hidden_states = self.patchify(hidden_states)

         with sequence_parallel(
             (
@@ -131,7 +128,7 @@ def forward(
         first_hidden_states_residual = hidden_states - original_hidden_states

         (first_hidden_states_residual,) = sequence_parallel_unshard(
-            (first_hidden_states_residual,), seq_dims=(1,), seq_lens=(h * w // 4,)
+            (first_hidden_states_residual,), seq_dims=(1,), seq_lens=(image_seq_len,)
         )

         if self.step_count == 0 or self.step_count == (self.num_inference_steps - 1):
@@ -172,9 +169,8 @@ def forward(

         hidden_states = self.final_norm_out(hidden_states, conditioning)
         hidden_states = self.final_proj_out(hidden_states)
-        (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(h * w // 4,))
+        (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(image_seq_len,))

-        hidden_states = self.unpatchify(hidden_states, h, w)
         (hidden_states,) = cfg_parallel_unshard((hidden_states,), use_cfg=use_cfg)

         return hidden_states
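
The FBCache variant mirrors the signature and image_seq_len changes above. For context, the first_hidden_states_residual it unshards drives the first-block-cache decision; an illustrative sketch of that general idea (not this file's exact code):

import torch

def should_reuse_cache(residual: torch.Tensor, prev_residual: torch.Tensor, threshold: float) -> bool:
    # If the first block's output residual barely moved since the previous
    # denoising step, the remaining blocks are assumed to produce a similar
    # update, so their cached output can be reused instead of recomputed.
    rel_change = (residual - prev_residual).abs().mean() / prev_residual.abs().mean()
    return rel_change.item() < threshold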

diffsynth_engine/models/flux/flux_vae.py

Lines changed: 19 additions & 1 deletion
@@ -25,11 +25,29 @@ def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
             new_state_dict[name_] = param
         return new_state_dict

+    def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        rename_dict = config["diffusers"]["rename_dict"]
+        new_state_dict = {}
+        for name, param in state_dict.items():
+            if name not in rename_dict:
+                continue
+            name_ = rename_dict[name]
+            if "transformer_blocks" in name_:
+                param = param.squeeze()
+            new_state_dict[name_] = param
+        return new_state_dict
+
     def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         assert self.has_decoder or self.has_encoder, "Either decoder or encoder must be present"
-        if "decoder.conv_in.weight" in state_dict or "encoder.conv_in.weight" in state_dict:
+        if "decoder.up.0.block.0.conv1.weight" in state_dict or "encoder.down.0.block.0.conv1.weight" in state_dict:
             state_dict = self._from_civitai(state_dict)
             logger.info("use civitai format state dict")
+        elif (
+            "decoder.up_blocks.0.resnets.0.conv1.weight" in state_dict
+            or "encoder.down_blocks.0.resnets.0.conv1.weight" in state_dict
+        ):
+            state_dict = self._from_diffusers(state_dict)
+            logger.info("use diffusers format state dict")
         else:
             logger.info("use diffsynth format state dict")
         return self._filter(state_dict)
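
convert now distinguishes three checkpoint layouts by probing format-specific keys: decoder.up.0.block.0.conv1.weight marks the civitai layout, decoder.up_blocks.0.resnets.0.conv1.weight marks a diffusers export, and anything else is treated as the native diffsynth format. The deeper keys replace the old decoder.conv_in.weight probe, likely because both layouts contain that key; the squeeze() on parameters mapped into transformer_blocks presumably drops trailing singleton convolution dimensions. A hypothetical usage sketch (the file path and converter variable are illustrative):

from safetensors.torch import load_file

state_dict = load_file("flux_vae_diffusers.safetensors")  # hypothetical path
state_dict = vae_converter.convert(state_dict)
# logs "use diffusers format state dict" and routes through _from_diffusers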
