Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,17 @@ NVIDIA Model Optimizer Changelog
**Bug Fixes**

- ONNX Runtime dependency upgraded to 1.24 to solve missing graph outputs when using the TensorRT Execution Provider.
- Fix broken ``AttentionModuleMixin`` import in the diffusers quantization example (path changed from
``modelopt.torch.quantization.plugins.diffusers`` to ``modelopt.torch.quantization.plugins.diffusion.diffusers``
in 0.42.0).

**New Features**

- Add Z-Image (``Tongyi-MAI/Z-Image``) and Z-Image-Turbo quantization support in the diffusers quantization
example. Supports FP8, INT8, and INT4 PTQ on the NextDiT/Lumina2 transformer backbone via the new
``--model zimage`` and ``--model zimage-turbo`` CLI options.
- Add ``ZImageTransformer2DModel`` dummy input generation in ``generate_diffusion_dummy_inputs()``, resolving
a QKV fusion warning during HF checkpoint export for Z-Image.
- Users no longer need to manually register MoE modules to ensure expert calibration coverage in the PTQ workflow.
- ``hf_ptq.py`` now saves the quantization summary and the MoE expert token count table to the export directory.
- Add ``--moe_calib_experts_ratio`` flag in ``hf_ptq.py`` to specify the ratio of experts to calibrate during the forward pass, improving expert coverage during calibration. Defaults to all experts.
Expand Down
30 changes: 30 additions & 0 deletions examples/diffusers/quantization/models_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@
LTXConditionPipeline,
StableDiffusion3Pipeline,
WanPipeline,
ZImagePipeline,
)
from utils import (
filter_func_default,
filter_func_flux_dev,
filter_func_ltx_video,
filter_func_wan_video,
filter_func_zimage,
)


Expand All @@ -46,6 +48,8 @@ class ModelType(str, Enum):
LTX2 = "ltx-2"
WAN22_T2V_14b = "wan2.2-t2v-14b"
WAN22_T2V_5b = "wan2.2-t2v-5b"
ZIMAGE = "zimage"
ZIMAGE_TURBO = "zimage-turbo"


def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
Expand All @@ -69,6 +73,8 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
ModelType.LTX2: filter_func_ltx_video,
ModelType.WAN22_T2V_14b: filter_func_wan_video,
ModelType.WAN22_T2V_5b: filter_func_wan_video,
ModelType.ZIMAGE: filter_func_zimage,
ModelType.ZIMAGE_TURBO: filter_func_zimage,
}

return filter_func_map.get(model_type, filter_func_default)
Expand All @@ -86,6 +92,8 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
ModelType.LTX2: "Lightricks/LTX-2",
ModelType.WAN22_T2V_14b: "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
ModelType.WAN22_T2V_5b: "Wan-AI/Wan2.2-TI2V-5B-Diffusers",
ModelType.ZIMAGE: "Tongyi-MAI/Z-Image",
ModelType.ZIMAGE_TURBO: "Tongyi-MAI/Z-Image-Turbo",
}

MODEL_PIPELINE: dict[ModelType, type[DiffusionPipeline] | None] = {
Expand All @@ -99,6 +107,8 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
ModelType.LTX2: None,
ModelType.WAN22_T2V_14b: WanPipeline,
ModelType.WAN22_T2V_5b: WanPipeline,
ModelType.ZIMAGE: ZImagePipeline,
ModelType.ZIMAGE_TURBO: ZImagePipeline,
}

# Shared dataset configurations
Expand Down Expand Up @@ -141,6 +151,17 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
"dataset": _OPENVID_DATASET,
}

# Shared Z-Image calibration defaults: quantize the "transformer" backbone and
# run 512x512 inference with the base-model guidance settings.
_ZIMAGE_BASE_CONFIG: dict[str, Any] = dict(
    backbone="transformer",
    dataset=_SD_PROMPTS_DATASET,
    inference_extra_args=dict(
        height=512,
        width=512,
        guidance_scale=4.0,
        cfg_normalization=False,
    ),
)

# Model-specific default arguments for calibration
MODEL_DEFAULTS: dict[ModelType, dict[str, Any]] = {
ModelType.SDXL_BASE: _SDXL_BASE_CONFIG,
Expand Down Expand Up @@ -207,6 +228,15 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
),
},
},
ModelType.ZIMAGE: _ZIMAGE_BASE_CONFIG,
ModelType.ZIMAGE_TURBO: {
**_ZIMAGE_BASE_CONFIG,
"inference_extra_args": {
**_ZIMAGE_BASE_CONFIG["inference_extra_args"],
"guidance_scale": 1.0,
# num_inference_steps intentionally omitted — pass --n-steps 8 on the CLI
},
},
}


Expand Down
20 changes: 20 additions & 0 deletions examples/diffusers/quantization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,26 @@ def filter_func_wan_video(name: str) -> bool:
return pattern.match(name) is not None


def filter_func_zimage(name: str) -> bool:
    """Filter function for Z-Image (NextDiT / S3-DiT backbone).

    Args:
        name: Fully qualified parameter/module name from the transformer.

    Returns:
        True for layers that should NOT be quantized.

    Skips: patch embedder, timestep/caption embedding, final projection, norms.
    Quantizes: all JointAttention (qkv, out) and FFN (w1, w2, w3) linears.
    """
    pattern = re.compile(
        r".*("
        r"x_embedder"  # patch embedding: image patches -> dim
        r"|final_layer"  # output projection: dim -> patch_size^2 x channels
        r"|time_caption_embed"  # timestep + caption conditioning MLP
        r"|cap_embedder"  # caption projection (if present)
        r"|norm"  # all RMSNorm / LayerNorm weight tensors
        r"|pos_embed"  # any fixed positional embeddings
        r").*"
    )
    return pattern.match(name) is not None


def load_calib_prompts(
batch_size,
calib_data_path: str | Path = "Gustavosta/Stable-Diffusion-Prompts",
Expand Down
30 changes: 30 additions & 0 deletions modelopt/torch/export/diffusers_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ def _is_model_type(module_path: str, class_name: str, fallback: bool) -> bool:
"UNet2DConditionModel",
"unet" in model_class_name.lower(),
)
is_zimage = _is_model_type(
"diffusers.models.transformers",
"ZImageTransformer2DModel",
model_class_name == "ZImageTransformer2DModel",
)

cfg = getattr(model, "config", None)

Expand Down Expand Up @@ -274,6 +279,30 @@ def _wan_inputs() -> dict[str, torch.Tensor]:
"return_dict": False,
}

def _zimage_inputs() -> dict[str, torch.Tensor]:
    """Build minimal dummy inputs for ZImageTransformer2DModel (NextDiT).

    The backbone takes 3D hidden_states (batch, seq_len, in_channels),
    a per-sample timestep, and text-encoder hidden states.
    """
    channels = getattr(cfg, "in_channels", 16)
    # Prefer encoder_hidden_size; fall back to caption_channels, then the
    # Qwen3-4B encoder hidden size (2560).
    text_dim = getattr(cfg, "encoder_hidden_size", getattr(cfg, "caption_channels", 2560))

    # Keep sequences tiny: a 4x4 patch grid and 8 text tokens.
    image_tokens = 16
    text_tokens = 8

    image_latents = torch.randn(
        batch_size, image_tokens, channels, device=device, dtype=dtype
    )
    timestep = torch.tensor([0.5], device=device, dtype=dtype).expand(batch_size)
    text_states = torch.randn(
        batch_size, text_tokens, text_dim, device=device, dtype=dtype
    )
    return {
        "hidden_states": image_latents,
        "timestep": timestep,
        "encoder_hidden_states": text_states,
        "return_dict": False,
    }

def _generic_transformer_inputs() -> dict[str, torch.Tensor] | None:
# Try generic transformer handling for other model types
# Check if model has common transformer attributes
Expand Down Expand Up @@ -318,6 +347,7 @@ def _generic_transformer_inputs() -> dict[str, torch.Tensor] | None:
("dit", is_dit, _dit_inputs),
("wan", is_wan, _wan_inputs),
("unet", is_unet, _unet_inputs),
("zimage", is_zimage, _zimage_inputs),
]

for _, matches, build_inputs in model_input_builders:
Expand Down
6 changes: 6 additions & 0 deletions tests/_test_utils/examples/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,9 @@ def _select_path(remote_id: str, local_id: str) -> str:
remote_id="hf-internal-testing/tiny-sd3-pipe",
local_id="stabilityai/stable-diffusion-3-medium-diffusers",
)

# No tiny test model exists on hf-internal-testing; runs only when MODELOPT_LOCAL_MODEL_ROOT is set
ZIMAGE_PATH = _select_path(
    remote_id="Tongyi-MAI/Z-Image",
    local_id="Tongyi-MAI/Z-Image",
)
64 changes: 63 additions & 1 deletion tests/examples/diffusers/test_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import NamedTuple

import pytest
from _test_utils.examples.models import FLUX_SCHNELL_PATH, SD3_PATH, SDXL_1_0_PATH
from _test_utils.examples.models import FLUX_SCHNELL_PATH, SD3_PATH, SDXL_1_0_PATH, ZIMAGE_PATH
from _test_utils.examples.run_command import run_example_command
from _test_utils.torch.misc import minimum_sm

Expand Down Expand Up @@ -134,12 +134,27 @@ def inference(self, tmp_path: Path) -> None:
quant_algo="smoothquant",
collect_method="min-mean",
),
pytest.param(
DiffuserModel(
name="zimage",
path=ZIMAGE_PATH,
dtype="BFloat16",
format_type="int8",
quant_algo="smoothquant",
collect_method="min-mean",
),
marks=pytest.mark.skipif(
not __import__("os").getenv("MODELOPT_LOCAL_MODEL_ROOT"),
reason="Requires MODELOPT_LOCAL_MODEL_ROOT — no tiny Z-Image pipe on hf-internal-testing",
),
),
],
ids=[
"flux_schnell_bf16_int8_smoothquant_3.0_min_mean",
"sd3_medium_fp16_int8_smoothquant_3.0_min_mean",
"sdxl_1.0_fp16_fp8_max_3.0_default",
"sdxl_1.0_fp16_int8_smoothquant_3.0_min_mean",
"zimage_bf16_int8_smoothquant_min_mean",
],
)
def test_diffusers_quantization(
Expand Down Expand Up @@ -187,3 +202,50 @@ def test_diffusion_trt_torch(
if torch_compile:
cmd_args.append("--torch-compile")
run_example_command(cmd_args, "diffusers/quantization")


@pytest.mark.parametrize(
("name", "expected"),
[
# Layers that should be skipped (filter returns True)
("transformer.x_embedder.weight", True),
("transformer.final_layer.linear.weight", True),
("transformer.time_caption_embed.mlp.0.weight", True),
("transformer.cap_embedder.weight", True),
("transformer.layers.0.norm1.weight", True),
("transformer.pos_embed", True),
# Layers that should be quantized (filter returns False)
("transformer.layers.0.attention.wq.weight", False),
("transformer.layers.0.attention.wk.weight", False),
("transformer.layers.0.attention.wv.weight", False),
("transformer.layers.0.attention.wo.weight", False),
("transformer.layers.0.feed_forward.w1.weight", False),
("transformer.layers.0.feed_forward.w2.weight", False),
("transformer.layers.0.feed_forward.w3.weight", False),
("transformer.layers.47.feed_forward.w1.weight", False),
],
ids=[
"skip_x_embedder",
"skip_final_layer",
"skip_time_caption_embed",
"skip_cap_embedder",
"skip_norm",
"skip_pos_embed",
"quant_attn_wq",
"quant_attn_wk",
"quant_attn_wv",
"quant_attn_wo",
"quant_ffn_w1",
"quant_ffn_w2",
"quant_ffn_w3",
"quant_ffn_w1_last_layer",
],
)
def test_filter_func_zimage(name: str, expected: bool) -> None:
    """filter_func_zimage must skip conditioning/norm layers and quantize attention/FFN linears."""
    import sys

    example_dir = str(Path(__file__).parents[3] / "examples/diffusers/quantization")
    # Guard: this runs once per parametrized case; without the check, each of
    # the 14 cases would prepend a duplicate entry to sys.path.
    if example_dir not in sys.path:
        sys.path.insert(0, example_dir)
    from utils import filter_func_zimage

    assert filter_func_zimage(name) == expected