Skip to content

Commit 31dce8f

Browse files
authored
support flux.2-klein (#227)
* support flux.2-klein * rename Flux2 to Flux2Klein * rename flux2_image to flux2_klein_image * fix committed issues * fix committed issues
1 parent 8e14013 commit 31dce8f

22 files changed

Lines changed: 4049 additions & 1 deletion

diffsynth_engine/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
QwenImagePipelineConfig,
88
HunyuanPipelineConfig,
99
ZImagePipelineConfig,
10+
Flux2KleinPipelineConfig,
1011
SDStateDicts,
1112
SDXLStateDicts,
1213
FluxStateDicts,
1314
WanStateDicts,
1415
QwenImageStateDicts,
1516
ZImageStateDicts,
17+
Flux2StateDicts,
1618
AttnImpl,
1719
SpargeAttentionParams,
1820
VideoSparseAttentionParams,
@@ -26,6 +28,7 @@
2628
SDImagePipeline,
2729
SDXLImagePipeline,
2830
FluxImagePipeline,
31+
Flux2KleinPipeline,
2932
WanVideoPipeline,
3033
WanDMDPipeline,
3134
QwenImagePipeline,
@@ -59,12 +62,14 @@
5962
"QwenImagePipelineConfig",
6063
"HunyuanPipelineConfig",
6164
"ZImagePipelineConfig",
65+
"Flux2KleinPipelineConfig",
6266
"SDStateDicts",
6367
"SDXLStateDicts",
6468
"FluxStateDicts",
6569
"WanStateDicts",
6670
"QwenImageStateDicts",
6771
"ZImageStateDicts",
72+
"Flux2StateDicts",
6873
"AttnImpl",
6974
"SpargeAttentionParams",
7075
"VideoSparseAttentionParams",
@@ -78,6 +83,7 @@
7883
"SDXLImagePipeline",
7984
"SDXLControlNetUnion",
8085
"FluxImagePipeline",
86+
"Flux2KleinPipeline",
8187
"FluxControlNet",
8288
"FluxIPAdapter",
8389
"FluxRedux",
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
{
2+
"architectures": [
3+
"Qwen3ForCausalLM"
4+
],
5+
"attention_bias": false,
6+
"attention_dropout": 0.0,
7+
"bos_token_id": 151643,
8+
"dtype": "bfloat16",
9+
"eos_token_id": 151645,
10+
"head_dim": 128,
11+
"hidden_act": "silu",
12+
"hidden_size": 4096,
13+
"initializer_range": 0.02,
14+
"intermediate_size": 12288,
15+
"layer_types": [
16+
"full_attention",
17+
"full_attention",
18+
"full_attention",
19+
"full_attention",
20+
"full_attention",
21+
"full_attention",
22+
"full_attention",
23+
"full_attention",
24+
"full_attention",
25+
"full_attention",
26+
"full_attention",
27+
"full_attention",
28+
"full_attention",
29+
"full_attention",
30+
"full_attention",
31+
"full_attention",
32+
"full_attention",
33+
"full_attention",
34+
"full_attention",
35+
"full_attention",
36+
"full_attention",
37+
"full_attention",
38+
"full_attention",
39+
"full_attention",
40+
"full_attention",
41+
"full_attention",
42+
"full_attention",
43+
"full_attention",
44+
"full_attention",
45+
"full_attention",
46+
"full_attention",
47+
"full_attention",
48+
"full_attention",
49+
"full_attention",
50+
"full_attention",
51+
"full_attention"
52+
],
53+
"max_position_embeddings": 40960,
54+
"max_window_layers": 36,
55+
"model_type": "qwen3",
56+
"num_attention_heads": 32,
57+
"num_hidden_layers": 36,
58+
"num_key_value_heads": 8,
59+
"rms_norm_eps": 1e-06,
60+
"rope_scaling": null,
61+
"rope_theta": 1000000,
62+
"sliding_window": null,
63+
"tie_word_embeddings": false,
64+
"transformers_version": "4.56.1",
65+
"use_cache": true,
66+
"use_sliding_window": false,
67+
"vocab_size": 151936
68+
}

diffsynth_engine/configs/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
QwenImagePipelineConfig,
1212
HunyuanPipelineConfig,
1313
ZImagePipelineConfig,
14+
Flux2KleinPipelineConfig,
1415
BaseStateDicts,
1516
SDStateDicts,
1617
SDXLStateDicts,
@@ -19,6 +20,7 @@
1920
WanS2VStateDicts,
2021
QwenImageStateDicts,
2122
ZImageStateDicts,
23+
Flux2StateDicts,
2224
AttnImpl,
2325
SpargeAttentionParams,
2426
VideoSparseAttentionParams,
@@ -44,6 +46,7 @@
4446
"QwenImagePipelineConfig",
4547
"HunyuanPipelineConfig",
4648
"ZImagePipelineConfig",
49+
"Flux2KleinPipelineConfig",
4750
"BaseStateDicts",
4851
"SDStateDicts",
4952
"SDXLStateDicts",
@@ -52,6 +55,7 @@
5255
"WanS2VStateDicts",
5356
"QwenImageStateDicts",
5457
"ZImageStateDicts",
58+
"Flux2StateDicts",
5559
"AttnImpl",
5660
"SpargeAttentionParams",
5761
"VideoSparseAttentionParams",

diffsynth_engine/configs/pipeline.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from enum import Enum
44
from dataclasses import dataclass, field
55
from typing import List, Dict, Tuple, Optional
6+
from typing_extensions import Literal
67

78
from diffsynth_engine.configs.controlnet import ControlType
89

@@ -339,6 +340,47 @@ def __post_init__(self):
339340
init_parallel_config(self)
340341

341342

343+
@dataclass
344+
class Flux2KleinPipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfig, BaseConfig):
345+
model_path: str | os.PathLike | List[str | os.PathLike]
346+
model_dtype: torch.dtype = torch.bfloat16
347+
vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
348+
vae_dtype: torch.dtype = torch.bfloat16
349+
encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
350+
encoder_dtype: torch.dtype = torch.bfloat16
351+
image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None
352+
image_encoder_dtype: torch.dtype = torch.bfloat16
353+
model_size: Literal["4B", "9B"] = "4B"
354+
355+
@classmethod
356+
def basic_config(
357+
cls,
358+
model_path: str | os.PathLike | List[str | os.PathLike],
359+
encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
360+
vae_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
361+
image_encoder_path: Optional[str | os.PathLike | List[str | os.PathLike]] = None,
362+
device: str = "cuda",
363+
parallelism: int = 1,
364+
offload_mode: Optional[str] = None,
365+
offload_to_disk: bool = False,
366+
) -> "Flux2KleinPipelineConfig":
367+
return cls(
368+
model_path=model_path,
369+
device=device,
370+
encoder_path=encoder_path,
371+
vae_path=vae_path,
372+
image_encoder_path=image_encoder_path,
373+
parallelism=parallelism,
374+
use_cfg_parallel=True if parallelism > 1 else False,
375+
use_fsdp=True if parallelism > 1 else False,
376+
offload_mode=offload_mode,
377+
offload_to_disk=offload_to_disk,
378+
)
379+
380+
def __post_init__(self):
381+
init_parallel_config(self)
382+
383+
342384
@dataclass
343385
class BaseStateDicts:
344386
pass
@@ -398,7 +440,14 @@ class ZImageStateDicts:
398440
image_encoder: Optional[Dict[str, torch.Tensor]] = None
399441

400442

401-
def init_parallel_config(config: FluxPipelineConfig | QwenImagePipelineConfig | WanPipelineConfig | ZImagePipelineConfig):
443+
@dataclass
444+
class Flux2StateDicts:
445+
model: Dict[str, torch.Tensor]
446+
vae: Dict[str, torch.Tensor]
447+
encoder: Dict[str, torch.Tensor]
448+
449+
450+
def init_parallel_config(config: FluxPipelineConfig | QwenImagePipelineConfig | WanPipelineConfig | ZImagePipelineConfig | Flux2KleinPipelineConfig):
402451
assert config.parallelism in (1, 2, 4, 8), "parallelism must be 1, 2, 4 or 8"
403452
config.batch_cfg = True if config.parallelism > 1 and config.use_cfg_parallel else config.batch_cfg
404453

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from .flux2_dit import Flux2DiT
2+
from .flux2_vae import Flux2VAE
3+
4+
__all__ = [
5+
"Flux2DiT",
6+
"Flux2VAE",
7+
]

0 commit comments

Comments
 (0)