From aefecde1fbac444bf6098bf3e90459f9fd44246c Mon Sep 17 00:00:00 2001 From: Arya Date: Sun, 15 Mar 2026 13:17:32 +0800 Subject: [PATCH] Add DiffSynth blockwise ControlNet support to QwenImageControlNetModel Extends the existing QwenImageControlNetModel with a `controlnet_block_type` config parameter to support the DiffSynth blockwise ControlNet architecture alongside the existing InstantX linear projection approach. The blockwise variant uses BlockWiseControlBlock modules (RMSNorm + MLP) that fuse base hidden states with control features at each transformer block, instead of the simple zero-initialized linear projections used by InstantX. Closes #12221 --- ...synth_blockwise_controlnet_to_diffusers.py | 122 ++++++++++++++++++ .../controlnets/controlnet_qwenimage.py | 51 +++++++- .../qwenimage/test_qwenimage_controlnet.py | 51 ++++++++ 3 files changed, 217 insertions(+), 7 deletions(-) create mode 100644 scripts/convert_diffsynth_blockwise_controlnet_to_diffusers.py diff --git a/scripts/convert_diffsynth_blockwise_controlnet_to_diffusers.py b/scripts/convert_diffsynth_blockwise_controlnet_to_diffusers.py new file mode 100644 index 000000000000..b76f5145ccd1 --- /dev/null +++ b/scripts/convert_diffsynth_blockwise_controlnet_to_diffusers.py @@ -0,0 +1,122 @@ +""" +A script to convert DiffSynth-Studio Blockwise ControlNet checkpoints to the Diffusers format. + +The DiffSynth checkpoints only contain the ControlNet-specific weights (controlnet_blocks + img_in). +The transformer backbone weights are loaded from the base Qwen-Image model. 
"""
Convert DiffSynth-Studio blockwise ControlNet checkpoints to the Diffusers format.

The DiffSynth checkpoints only contain the ControlNet-specific weights
(controlnet_blocks + img_in); the transformer backbone weights are loaded from
the base Qwen-Image model.

Example (HuggingFace repo):
    python scripts/convert_diffsynth_blockwise_controlnet_to_diffusers.py \
        --original_state_dict_repo_id "DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny" \
        --filename "model.safetensors" \
        --transformer_repo_id "Qwen/Qwen-Image" \
        --output_path "output/qwenimage-blockwise-controlnet-canny" \
        --dtype "bf16"

Example (local file):
    python scripts/convert_diffsynth_blockwise_controlnet_to_diffusers.py \
        --checkpoint_path "path/to/model.safetensors" \
        --transformer_repo_id "Qwen/Qwen-Image" \
        --output_path "output/qwenimage-blockwise-controlnet-canny" \
        --dtype "bf16"

Note:
    Available DiffSynth blockwise ControlNet checkpoints:
    - DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny
    - DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth
    - DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint
"""

import argparse

import torch


# --dtype flag -> torch dtype, kept in one place so parsing is trivially testable.
DTYPE_MAP = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "fp32": torch.float32,
}


def parse_dtype(name: str) -> torch.dtype:
    """Map a ``--dtype`` flag value to a torch dtype.

    Raises:
        ValueError: if ``name`` is not one of fp16 / bf16 / fp32.
    """
    try:
        return DTYPE_MAP[name]
    except KeyError:
        raise ValueError(f"Unsupported dtype: {name}. Must be one of: fp16, bf16, fp32") from None


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser for the conversion script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint_path", type=str, default=None, help="Path to local checkpoint file")
    parser.add_argument(
        "--original_state_dict_repo_id",
        type=str,
        default=None,
        help="HuggingFace repo ID for the blockwise checkpoint",
    )
    parser.add_argument("--filename", type=str, default="model.safetensors", help="Filename in the HF repo")
    parser.add_argument(
        "--transformer_repo_id",
        type=str,
        default="Qwen/Qwen-Image",
        help="HuggingFace repo ID for the base transformer model",
    )
    parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model")
    parser.add_argument(
        "--dtype", type=str, default="bf16", help="Data type for the converted model (fp16, bf16, or fp32)"
    )
    return parser


def load_original_checkpoint(args):
    """Load the DiffSynth state dict, either from the Hub or from a local file.

    Raises:
        ValueError: if neither a repo id nor a local checkpoint path is given.
    """
    # Imported lazily so the module can be imported (e.g. for testing) without
    # the heavyweight conversion dependencies installed.
    import safetensors.torch
    from huggingface_hub import hf_hub_download

    if args.original_state_dict_repo_id is not None:
        print(f"Downloading checkpoint from {args.original_state_dict_repo_id}/{args.filename}")
        ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
    elif args.checkpoint_path is not None:
        print(f"Loading checkpoint from local path: {args.checkpoint_path}")
        ckpt_path = args.checkpoint_path
    else:
        raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")

    return safetensors.torch.load_file(ckpt_path)


def main(args):
    """Run the conversion: build a blockwise ControlNet and save it in Diffusers format."""
    # Imported lazily for the same reason as in load_original_checkpoint.
    from diffusers import QwenImageControlNetModel, QwenImageTransformer2DModel

    dtype = parse_dtype(args.dtype)

    # Load base transformer; it supplies the backbone weights for the ControlNet.
    print(f"Loading base transformer from {args.transformer_repo_id}...")
    transformer = QwenImageTransformer2DModel.from_pretrained(
        args.transformer_repo_id, subfolder="transformer", torch_dtype=dtype
    )

    # Create controlnet from transformer (copies backbone weights).
    print("Creating blockwise ControlNet from transformer...")
    controlnet = QwenImageControlNetModel.from_transformer(
        transformer,
        num_layers=transformer.config.num_layers,
        attention_head_dim=transformer.config.attention_head_dim,
        num_attention_heads=transformer.config.num_attention_heads,
        controlnet_block_type="blockwise",
    )

    # Load DiffSynth blockwise weights (controlnet_blocks + img_in only).
    original_ckpt = load_original_checkpoint(args)
    missing, unexpected = controlnet.load_state_dict(original_ckpt, strict=False)

    # Verify: only transformer backbone keys should be missing, no unexpected keys.
    print(f"Missing keys (expected - backbone from transformer): {len(missing)}")
    print(f"Unexpected keys (should be 0): {len(unexpected)}")
    if unexpected:
        print(f"WARNING: Unexpected keys found: {unexpected}")

    # Free transformer memory before serializing the ControlNet.
    del transformer

    print(f"Saving blockwise ControlNet in Diffusers format to {args.output_path}")
    controlnet.to(dtype).save_pretrained(args.output_path)
    print("Done!")


if __name__ == "__main__":
    # Arguments are parsed here (not at import time) so the module is importable.
    main(build_parser().parse_args())
+ """ + + def __init__(self, dim: int = 3072): + super().__init__() + self.x_rms = RMSNorm(dim, eps=1e-6) + self.y_rms = RMSNorm(dim, eps=1e-6) + self.input_proj = nn.Linear(dim, dim) + self.act = nn.GELU() + self.output_proj = zero_module(nn.Linear(dim, dim)) + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + x, y = self.x_rms(x), self.y_rms(y) + x = self.input_proj(x + y) + x = self.act(x) + x = self.output_proj(x) + return x + + @dataclass class QwenImageControlNetOutput(BaseOutput): controlnet_block_samples: tuple[torch.Tensor] @@ -65,6 +89,7 @@ def __init__( joint_attention_dim: int = 3584, axes_dims_rope: tuple[int, int, int] = (16, 56, 56), extra_condition_channels: int = 0, # for controlnet-inpainting + controlnet_block_type: str = "linear", ): super().__init__() self.out_channels = out_channels or in_channels @@ -92,8 +117,12 @@ def __init__( # controlnet_blocks self.controlnet_blocks = nn.ModuleList([]) - for _ in range(len(self.transformer_blocks)): - self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim))) + if controlnet_block_type == "blockwise": + for _ in range(len(self.transformer_blocks)): + self.controlnet_blocks.append(BlockWiseControlBlock(self.inner_dim)) + else: + for _ in range(len(self.transformer_blocks)): + self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim))) self.controlnet_x_embedder = zero_module( torch.nn.Linear(in_channels + extra_condition_channels, self.inner_dim) ) @@ -109,12 +138,14 @@ def from_transformer( num_attention_heads: int = 24, load_weights_from_transformer=True, extra_condition_channels: int = 0, + controlnet_block_type: str = "linear", ): config = dict(transformer.config) config["num_layers"] = num_layers config["attention_head_dim"] = attention_head_dim config["num_attention_heads"] = num_attention_heads config["extra_condition_channels"] = extra_condition_channels + config["controlnet_block_type"] = controlnet_block_type controlnet 
= cls.from_config(config) @@ -190,7 +221,8 @@ def forward( hidden_states = self.img_in(hidden_states) # add - hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond) + controlnet_cond_embed = self.controlnet_x_embedder(controlnet_cond) + hidden_states = hidden_states + controlnet_cond_embed temb = self.time_text_embed(timestep, hidden_states) @@ -240,9 +272,14 @@ def forward( # controlnet block controlnet_block_samples = () - for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks): - block_sample = controlnet_block(block_sample) - controlnet_block_samples = controlnet_block_samples + (block_sample,) + if self.config.controlnet_block_type == "blockwise": + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks): + block_sample = controlnet_block(block_sample, controlnet_cond_embed) + controlnet_block_samples = controlnet_block_samples + (block_sample,) + else: + for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks): + block_sample = controlnet_block(block_sample) + controlnet_block_samples = controlnet_block_samples + (block_sample,) # scaling controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples] diff --git a/tests/pipelines/qwenimage/test_qwenimage_controlnet.py b/tests/pipelines/qwenimage/test_qwenimage_controlnet.py index 59a2dd497184..26e126e8c3bb 100644 --- a/tests/pipelines/qwenimage/test_qwenimage_controlnet.py +++ b/tests/pipelines/qwenimage/test_qwenimage_controlnet.py @@ -295,6 +295,57 @@ def test_attention_slicing_forward_pass( def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1) + def test_qwen_blockwise_controlnet(self): + device = "cpu" + components = self.get_dummy_components() + + # Replace controlnet with blockwise variant + torch.manual_seed(0) + blockwise_controlnet = QwenImageControlNetModel( + patch_size=2, + in_channels=16, + 
out_channels=4, + num_layers=2, + attention_head_dim=16, + num_attention_heads=3, + joint_attention_dim=16, + axes_dims_rope=(8, 4, 4), + controlnet_block_type="blockwise", + ) + components["controlnet"] = blockwise_controlnet + + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + generated_image = image[0] + self.assertEqual(generated_image.shape, (3, 32, 32)) + + def test_qwen_blockwise_controlnet_from_transformer(self): + device = "cpu" + components = self.get_dummy_components() + transformer = components["transformer"] + + blockwise_controlnet = QwenImageControlNetModel.from_transformer( + transformer, + num_layers=2, + attention_head_dim=16, + num_attention_heads=3, + controlnet_block_type="blockwise", + ) + components["controlnet"] = blockwise_controlnet + + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + generated_image = image[0] + self.assertEqual(generated_image.shape, (3, 32, 32)) + def test_vae_tiling(self, expected_diff_max: float = 0.2): generator_device = "cpu" components = self.get_dummy_components()