Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 41 additions & 8 deletions scripts/convert_sana_video_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer

from diffusers import (
AutoencoderKLLTX2Video,
AutoencoderKLWan,
DPMSolverMultistepScheduler,
FlowMatchEulerDiscreteScheduler,
Expand All @@ -24,7 +25,10 @@

CTX = init_empty_weights if is_accelerate_available else nullcontext

ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"]
ckpt_ids = [
"Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth",
"Efficient-Large-Model/SANA-Video_2B_720p/checkpoints/SANA_Video_2B_720p_LTXVAE.pth",
]
# https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py


Expand Down Expand Up @@ -92,12 +96,22 @@ def main(args):
if args.video_size == 480:
sample_size = 30 # Wan-VAE: 8xp2 downsample factor
patch_size = (1, 2, 2)
in_channels = 16
out_channels = 16
elif args.video_size == 720:
sample_size = 22 # Wan-VAE: 32xp1 downsample factor
sample_size = 22 # DC-AE-V: 32xp1 downsample factor
patch_size = (1, 1, 1)
in_channels = 32
out_channels = 32
else:
raise ValueError(f"Video size {args.video_size} is not supported.")

if args.vae_type == "ltx2":
sample_size = 22
patch_size = (1, 1, 1)
in_channels = 128
out_channels = 128

for depth in range(layer_num):
# Transformer blocks.
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
Expand Down Expand Up @@ -182,8 +196,8 @@ def main(args):
# Transformer
with CTX():
transformer_kwargs = {
"in_channels": 16,
"out_channels": 16,
"in_channels": in_channels,
"out_channels": out_channels,
"num_attention_heads": 20,
"attention_head_dim": 112,
"num_layers": 20,
Expand Down Expand Up @@ -235,9 +249,12 @@ def main(args):
else:
print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
# VAE
vae = AutoencoderKLWan.from_pretrained(
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
if args.vae_type == "ltx2":
vae_path = args.vae_path or "Lightricks/LTX-2"
vae = AutoencoderKLLTX2Video.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32)
else:
vae_path = args.vae_path or "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32)

# Text Encoder
text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
Expand Down Expand Up @@ -314,7 +331,23 @@ def main(args):
choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
help="Scheduler type to use.",
)
parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.")
parser.add_argument(
"--vae_type",
default="wan",
type=str,
choices=["wan", "ltx2"],
help="VAE type to use for saving full pipeline (ltx2 uses patchify 1x1x1).",
)
parser.add_argument(
"--vae_path",
default=None,
type=str,
required=False,
help="Optional VAE path or repo id. If not set, a default is used per VAE type.",
)
parser.add_argument(
"--task", default="t2v", type=str, required=True, choices=["t2v", "i2v"], help="Task to convert, t2v or i2v."
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
Expand Down
51 changes: 41 additions & 10 deletions src/diffusers/pipelines/sana_video/pipeline_sana_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,25 @@ def __init__(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)

self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
if getattr(self, "vae", None):
if hasattr(self.vae.config, "scale_factor_temporal"):
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
elif hasattr(self.vae.config, "temporal_compression_ratio"):
# LTX2 VAE uses temporal_compression_ratio
self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
else:
self.vae_scale_factor_temporal = getattr(self.vae, "temporal_compression_ratio", 4)

if hasattr(self.vae.config, "scale_factor_spatial"):
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
elif hasattr(self.vae.config, "spatial_compression_ratio"):
# LTX2 VAE uses spatial_compression_ratio
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
else:
self.vae_scale_factor_spatial = getattr(self.vae, "spatial_compression_ratio", 8)
else:
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8

self.vae_scale_factor = self.vae_scale_factor_spatial

Expand Down Expand Up @@ -985,14 +1002,28 @@ def __call__(
if is_torch_version(">=", "2.5.0")
else torch_accelerator_module.OutOfMemoryError
)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
if hasattr(self.vae.config, "latents_mean"):
latents_mean = torch.tensor(self.vae.config.latents_mean)
elif getattr(self.vae, "latents_mean", None) is not None:
latents_mean = self.vae.latents_mean
else:
latents_mean = torch.zeros(latents.shape[1], device=latents.device, dtype=latents.dtype)

if hasattr(self.vae.config, "latents_std"):
latents_std = torch.tensor(self.vae.config.latents_std)
elif getattr(self.vae, "latents_std", None) is not None:
latents_std = self.vae.latents_std
else:
latents_std = torch.ones(latents.shape[1], device=latents.device, dtype=latents.dtype)

z_dim = getattr(self.vae.config, "z_dim", None)
if z_dim is None:
z_dim = getattr(self.vae.config, "latent_channels", None)
if z_dim is None:
z_dim = latents.shape[1]

latents_mean = latents_mean.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents / latents_std + latents_mean
try:
video = self.vae.decode(latents, return_dict=False)[0]
Expand Down