From 5c321ab138344545e2bc69cbd684d1b8715e5310 Mon Sep 17 00:00:00 2001
From: Vasiliy Kuznetsov
Date: Thu, 12 Mar 2026 19:58:37 +0000
Subject: [PATCH] [not for land yet]: improve cuda graph support for Qwen-Image

Summary:

Very brief writeup as I'm about to head out for the day:

1. We want to enable cuda graphs for qwen-image + nvfp4 at small batch
   sizes, because without cuda graphs we are bottlenecked on cpu ops.
2. To make cuda graphs work, we need to change the modeling code a bit
   to match the cuda graph requirements.

There is a cleaner way to do this change repo-wide without having to
change each model's modeling code; for now this is just a quick hack to
demonstrate performance + accuracy.

Test Plan:

Use a modified version of @sayakpaul's script:
https://gist.github.com/vkuzo/acac22c62404c89db2dcf195a64543db

Then, run it and see nvfp4 + bsz 1 time on qwen-image improve by ~1.6x,
from 9.5s to 5.9s:

```
// baseline
(pt_nightly) dev@gpu-dev-6c281422:~/tmp$ python 20260212_diffuser_nvfp4.py --compile True --torch_compile_mode reduce-overhead
...
======================================================================
SUMMARY
======================================================================
Quantization: None
Compile: True
Batch size: 1
Latency: 7.461s
Peak Memory: 62.21 GB

// nvfp4 dynamic, torch.compile default
(pt_nightly) dev@gpu-dev-6c281422:~/tmp$ python 20260212_diffuser_nvfp4.py --compile True --quant dynamic --use_filter_fn True
...
======================================================================
SUMMARY
======================================================================
Quantization: dynamic
Compile: True
Batch size: 1
Latency: 9.536s
Peak Memory: 52.45 GB
======================================================================

// nvfp4 dynamic, torch.compile reduce-overhead (for cuda graphs)
(pt_nightly) dev@gpu-dev-6c281422:~/tmp$ python 20260212_diffuser_nvfp4.py --compile True --quant dynamic --use_filter_fn True --torch_compile_mode reduce-overhead
...
======================================================================
SUMMARY
======================================================================
Quantization: dynamic
Compile: True
Batch size: 1
Latency: 5.936s
Peak Memory: 52.45 GB
======================================================================
```
---
 src/diffusers/models/transformers/transformer_qwenimage.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
index a54cb3b8e092..0c887596109d 100644
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -951,8 +951,8 @@ def forward(
             else:
                 encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
+                    hidden_states=hidden_states.clone(),
+                    encoder_hidden_states=encoder_hidden_states.clone(),
                     encoder_hidden_states_mask=None,  # Don't pass (using attention_mask instead)
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
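
Context on why the `.clone()` calls matter: under `torch.compile(mode="reduce-overhead")`, CUDA graph replay reuses the same output buffers across iterations, so a block whose outputs are fed back in as the next iteration's inputs (as `hidden_states` and `encoder_hidden_states` are in the Qwen-Image transformer loop) can end up aliasing its own input memory; cloning the inputs gives each invocation fresh storage. A minimal CPU-only sketch of the call pattern — the `block` function here is a hypothetical stand-in, not the real transformer block:

```python
import torch

# Hypothetical stand-in for a transformer block: takes hidden_states and
# encoder_hidden_states and returns updated (encoder_hidden_states,
# hidden_states), mirroring the call shape in transformer_qwenimage.py.
def block(hidden_states, encoder_hidden_states):
    return encoder_hidden_states + 1.0, hidden_states + 1.0

hidden_states = torch.zeros(2, 4)
encoder_hidden_states = torch.zeros(2, 4)

for _ in range(3):
    # .clone() hands the block fresh input tensors each iteration, so a
    # replayed CUDA graph's output buffer is never also its own input buffer.
    encoder_hidden_states, hidden_states = block(
        hidden_states=hidden_states.clone(),
        encoder_hidden_states=encoder_hidden_states.clone(),
    )

print(hidden_states.sum().item())  # 24.0: 2*4 elements, each incremented 3 times
```

The extra clones cost one copy per block per step, which is why the commit message notes a cleaner repo-wide mechanism would be preferable to editing each model's code.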