From cd639f9eb25e58d48a3d6c3a2628d0af01785016 Mon Sep 17 00:00:00 2001 From: junsong Date: Mon, 2 Mar 2026 02:00:33 -0800 Subject: [PATCH 1/2] add ltx2 vae in sana-video; --- .../sana_video/pipeline_sana_video.py | 51 +++++++++++++++---- 1 file changed, 41 insertions(+), 10 deletions(-) diff --git a/src/diffusers/pipelines/sana_video/pipeline_sana_video.py b/src/diffusers/pipelines/sana_video/pipeline_sana_video.py index 8b44dfc1143c..c8ffa176430d 100644 --- a/src/diffusers/pipelines/sana_video/pipeline_sana_video.py +++ b/src/diffusers/pipelines/sana_video/pipeline_sana_video.py @@ -223,8 +223,25 @@ def __init__( tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler ) - self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4 - self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8 + if getattr(self, "vae", None): + if hasattr(self.vae.config, "scale_factor_temporal"): + self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal + elif hasattr(self.vae.config, "temporal_compression_ratio"): + # LTX2 VAE uses temporal_compression_ratio + self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio + else: + self.vae_scale_factor_temporal = getattr(self.vae, "temporal_compression_ratio", 4) + + if hasattr(self.vae.config, "scale_factor_spatial"): + self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial + elif hasattr(self.vae.config, "spatial_compression_ratio"): + # LTX2 VAE uses spatial_compression_ratio + self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio + else: + self.vae_scale_factor_spatial = getattr(self.vae, "spatial_compression_ratio", 8) + else: + self.vae_scale_factor_temporal = 4 + self.vae_scale_factor_spatial = 8 self.vae_scale_factor = self.vae_scale_factor_spatial @@ -985,14 +1002,28 @@ def __call__( if is_torch_version(">=", "2.5.0") else torch_accelerator_module.OutOfMemoryError ) - latents_mean = ( - torch.tensor(self.vae.config.latents_mean) - .view(1, self.vae.config.z_dim, 1, 1, 1) - .to(latents.device, latents.dtype) - ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype - ) + if hasattr(self.vae.config, "latents_mean"): + latents_mean = torch.tensor(self.vae.config.latents_mean) + elif getattr(self.vae, "latents_mean", None) is not None: + latents_mean = self.vae.latents_mean + else: + latents_mean = torch.zeros(latents.shape[1], device=latents.device, dtype=latents.dtype) + + if hasattr(self.vae.config, "latents_std"): + latents_std = torch.tensor(self.vae.config.latents_std) + elif getattr(self.vae, "latents_std", None) is not None: + latents_std = self.vae.latents_std + else: + latents_std = torch.ones(latents.shape[1], device=latents.device, dtype=latents.dtype) + + z_dim = getattr(self.vae.config, "z_dim", None) + if z_dim is None: + z_dim = getattr(self.vae.config, "latent_channels", None) + if z_dim is None: + z_dim = latents.shape[1] + + latents_mean = latents_mean.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype) + latents_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype) latents = latents / latents_std + latents_mean try: video = self.vae.decode(latents, return_dict=False)[0] From c03b7398c682fe3e96f2e68ac0b9aa11509bd40b Mon Sep 17 00:00:00 2001 From: junsong Date: Sun, 8 Mar 2026 21:58:43 -0700 Subject: [PATCH 2/2] add ltx vae in conversion script; --- scripts/convert_sana_video_to_diffusers.py | 49 ++++++++++++++++++---- 1 file changed, 41 insertions(+), 8 deletions(-) diff --git a/scripts/convert_sana_video_to_diffusers.py b/scripts/convert_sana_video_to_diffusers.py index a939a06cbd46..c6be52d455b8 100644 --- a/scripts/convert_sana_video_to_diffusers.py +++ b/scripts/convert_sana_video_to_diffusers.py @@ -12,6 +12,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from diffusers import ( + AutoencoderKLLTX2Video, AutoencoderKLWan, DPMSolverMultistepScheduler, FlowMatchEulerDiscreteScheduler, @@ -24,7 +25,10 @@ CTX = init_empty_weights if is_accelerate_available else nullcontext -ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"] +ckpt_ids = [ + "Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth", + "Efficient-Large-Model/SANA-Video_2B_720p/checkpoints/SANA_Video_2B_720p_LTXVAE.pth", +] # https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py @@ -92,12 +96,22 @@ def main(args): if args.video_size == 480: sample_size = 30 # Wan-VAE: 8xp2 downsample factor patch_size = (1, 2, 2) + in_channels = 16 + out_channels = 16 elif args.video_size == 720: - sample_size = 22 # Wan-VAE: 32xp1 downsample factor + sample_size = 22 # DC-AE-V: 32xp1 downsample factor patch_size = (1, 1, 1) + in_channels = 32 + out_channels = 32 else: raise ValueError(f"Video size {args.video_size} is not supported.") + if args.vae_type == "ltx2": + sample_size = 22 + patch_size = (1, 1, 1) + in_channels = 128 + out_channels = 128 + for depth in range(layer_num): # Transformer blocks. converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop( @@ -182,8 +196,8 @@ def main(args): # Transformer with CTX(): transformer_kwargs = { - "in_channels": 16, - "out_channels": 16, + "in_channels": in_channels, + "out_channels": out_channels, "num_attention_heads": 20, "attention_head_dim": 112, "num_layers": 20, @@ -235,9 +249,12 @@ def main(args): else: print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"])) # VAE - vae = AutoencoderKLWan.from_pretrained( - "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32 - ) + if args.vae_type == "ltx2": + vae_path = args.vae_path or "Lightricks/LTX-2" + vae = AutoencoderKLLTX2Video.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32) + else: + vae_path = args.vae_path or "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" + vae = AutoencoderKLWan.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32) # Text Encoder text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it" @@ -314,7 +331,23 @@ def main(args): choices=["flow-dpm_solver", "flow-euler", "uni-pc"], help="Scheduler type to use.", ) - parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.") + parser.add_argument( + "--vae_type", + default="wan", + type=str, + choices=["wan", "ltx2"], + help="VAE type to use for saving full pipeline (ltx2 uses patchify 1x1x1).", + ) + parser.add_argument( + "--vae_path", + default=None, + type=str, + required=False, + help="Optional VAE path or repo id. If not set, a default is used per VAE type.", + ) + parser.add_argument( + "--task", default="t2v", type=str, required=True, choices=["t2v", "i2v"], help="Task to convert, t2v or i2v." + ) parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.") parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.") parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")