Skip to content

Commit 197b0c8

Browse files
committed
Add support for Wan2.1 TAEHV decoding
1 parent 8f05f5b commit 197b0c8

File tree

2 files changed

+349
-29
lines changed

2 files changed

+349
-29
lines changed

stable-diffusion.cpp

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -400,8 +400,8 @@ class StableDiffusionGGML {
400400
offload_params_to_cpu,
401401
tensor_storage_map);
402402
diffusion_model = std::make_shared<MMDiTModel>(backend,
403-
offload_params_to_cpu,
404-
tensor_storage_map);
403+
offload_params_to_cpu,
404+
tensor_storage_map);
405405
} else if (sd_version_is_flux(version)) {
406406
bool is_chroma = false;
407407
for (auto pair : tensor_storage_map) {
@@ -461,10 +461,10 @@ class StableDiffusionGGML {
461461
1,
462462
true);
463463
diffusion_model = std::make_shared<WanModel>(backend,
464-
offload_params_to_cpu,
465-
tensor_storage_map,
466-
"model.diffusion_model",
467-
version);
464+
offload_params_to_cpu,
465+
tensor_storage_map,
466+
"model.diffusion_model",
467+
version);
468468
if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) {
469469
high_noise_diffusion_model = std::make_shared<WanModel>(backend,
470470
offload_params_to_cpu,
@@ -564,14 +564,27 @@ class StableDiffusionGGML {
564564
}
565565

566566
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
567-
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
568-
offload_params_to_cpu,
569-
tensor_storage_map,
570-
"first_stage_model",
571-
vae_decode_only,
572-
version);
573-
first_stage_model->alloc_params_buffer();
574-
first_stage_model->get_param_tensors(tensors, "first_stage_model");
567+
if (!use_tiny_autoencoder) {
568+
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
569+
offload_params_to_cpu,
570+
tensor_storage_map,
571+
"first_stage_model",
572+
vae_decode_only,
573+
version);
574+
first_stage_model->alloc_params_buffer();
575+
first_stage_model->get_param_tensors(tensors, "first_stage_model");
576+
} else {
577+
tae_first_stage = std::make_shared<TinyVideoAutoEncoder>(vae_backend,
578+
offload_params_to_cpu,
579+
tensor_storage_map,
580+
"decoder",
581+
vae_decode_only,
582+
version);
583+
if (sd_ctx_params->vae_conv_direct) {
584+
LOG_INFO("Using Conv2d direct in the tae model");
585+
tae_first_stage->set_conv2d_direct_enabled(true);
586+
}
587+
}
575588
} else if (version == VERSION_CHROMA_RADIANCE) {
576589
first_stage_model = std::make_shared<FakeVAE>(vae_backend,
577590
offload_params_to_cpu);
@@ -598,14 +611,13 @@ class StableDiffusionGGML {
598611
}
599612
first_stage_model->alloc_params_buffer();
600613
first_stage_model->get_param_tensors(tensors, "first_stage_model");
601-
}
602-
if (use_tiny_autoencoder) {
603-
tae_first_stage = std::make_shared<TinyAutoEncoder>(vae_backend,
604-
offload_params_to_cpu,
605-
tensor_storage_map,
606-
"decoder.layers",
607-
vae_decode_only,
608-
version);
614+
} else if (use_tiny_autoencoder) {
615+
tae_first_stage = std::make_shared<TinyImageAutoEncoder>(vae_backend,
616+
offload_params_to_cpu,
617+
tensor_storage_map,
618+
"decoder.layers",
619+
vae_decode_only,
620+
version);
609621
if (sd_ctx_params->vae_conv_direct) {
610622
LOG_INFO("Using Conv2d direct in the tae model");
611623
tae_first_stage->set_conv2d_direct_enabled(true);
@@ -726,13 +738,16 @@ class StableDiffusionGGML {
726738
unet_params_mem_size += high_noise_diffusion_model->get_params_buffer_size();
727739
}
728740
size_t vae_params_mem_size = 0;
741+
LOG_DEBUG("Here");
729742
if (!use_tiny_autoencoder || sd_ctx_params->tae_preview_only) {
730743
vae_params_mem_size = first_stage_model->get_params_buffer_size();
731744
}
732745
if (use_tiny_autoencoder) {
746+
LOG_DEBUG("Here");
733747
if (!tae_first_stage->load_from_file(taesd_path, n_threads)) {
734748
return false;
735749
}
750+
LOG_DEBUG("Here");
736751
vae_params_mem_size = tae_first_stage->get_params_buffer_size();
737752
}
738753
size_t control_net_params_mem_size = 0;

0 commit comments

Comments
 (0)