Skip to content

Commit fe1f6e0

Browse files
committed
Add support for Wan2.1 TAEHV decoding
1 parent 8f6c5c2 commit fe1f6e0

File tree

2 files changed: +360 -43 lines changed

2 files changed: +360 -43 lines changed

stable-diffusion.cpp

Lines changed: 48 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -346,8 +346,8 @@ class StableDiffusionGGML {
346346
offload_params_to_cpu,
347347
tensor_storage_map);
348348
diffusion_model = std::make_shared<MMDiTModel>(backend,
349-
offload_params_to_cpu,
350-
tensor_storage_map);
349+
offload_params_to_cpu,
350+
tensor_storage_map);
351351
} else if (sd_version_is_flux(version)) {
352352
bool is_chroma = false;
353353
for (auto pair : tensor_storage_map) {
@@ -389,10 +389,10 @@ class StableDiffusionGGML {
389389
1,
390390
true);
391391
diffusion_model = std::make_shared<WanModel>(backend,
392-
offload_params_to_cpu,
393-
tensor_storage_map,
394-
"model.diffusion_model",
395-
version);
392+
offload_params_to_cpu,
393+
tensor_storage_map,
394+
"model.diffusion_model",
395+
version);
396396
if (strlen(SAFE_STR(sd_ctx_params->high_noise_diffusion_model_path)) > 0) {
397397
high_noise_diffusion_model = std::make_shared<WanModel>(backend,
398398
offload_params_to_cpu,
@@ -418,10 +418,10 @@ class StableDiffusionGGML {
418418
"",
419419
enable_vision);
420420
diffusion_model = std::make_shared<QwenImageModel>(backend,
421-
offload_params_to_cpu,
422-
tensor_storage_map,
423-
"model.diffusion_model",
424-
version);
421+
offload_params_to_cpu,
422+
tensor_storage_map,
423+
"model.diffusion_model",
424+
version);
425425
} else { // SD1.x SD2.x SDXL
426426
if (strstr(SAFE_STR(sd_ctx_params->photo_maker_path), "v2")) {
427427
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend,
@@ -475,14 +475,23 @@ class StableDiffusionGGML {
475475
}
476476

477477
if (sd_version_is_wan(version) || sd_version_is_qwen_image(version)) {
478-
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
479-
offload_params_to_cpu,
480-
tensor_storage_map,
481-
"first_stage_model",
482-
vae_decode_only,
483-
version);
484-
first_stage_model->alloc_params_buffer();
485-
first_stage_model->get_param_tensors(tensors, "first_stage_model");
478+
if (!use_tiny_autoencoder) {
479+
first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
480+
offload_params_to_cpu,
481+
tensor_storage_map,
482+
"first_stage_model",
483+
vae_decode_only,
484+
version);
485+
first_stage_model->alloc_params_buffer();
486+
first_stage_model->get_param_tensors(tensors, "first_stage_model");
487+
} else {
488+
tae_first_stage = std::make_shared<TinyVideoAutoEncoder>(vae_backend,
489+
offload_params_to_cpu,
490+
tensor_storage_map,
491+
"decoder",
492+
vae_decode_only,
493+
version);
494+
}
486495
} else if (version == VERSION_CHROMA_RADIANCE) {
487496
first_stage_model = std::make_shared<FakeVAE>(vae_backend,
488497
offload_params_to_cpu);
@@ -510,12 +519,12 @@ class StableDiffusionGGML {
510519
first_stage_model->alloc_params_buffer();
511520
first_stage_model->get_param_tensors(tensors, "first_stage_model");
512521
} else {
513-
tae_first_stage = std::make_shared<TinyAutoEncoder>(vae_backend,
514-
offload_params_to_cpu,
515-
tensor_storage_map,
516-
"decoder.layers",
517-
vae_decode_only,
518-
version);
522+
tae_first_stage = std::make_shared<TinyImageAutoEncoder>(vae_backend,
523+
offload_params_to_cpu,
524+
tensor_storage_map,
525+
"decoder.layers",
526+
vae_decode_only,
527+
version);
519528
if (sd_ctx_params->vae_conv_direct) {
520529
LOG_INFO("Using Conv2d direct in the tae model");
521530
tae_first_stage->set_conv2d_direct_enabled(true);
@@ -625,12 +634,15 @@ class StableDiffusionGGML {
625634
unet_params_mem_size += high_noise_diffusion_model->get_params_buffer_size();
626635
}
627636
size_t vae_params_mem_size = 0;
637+
LOG_DEBUG("Here");
628638
if (!use_tiny_autoencoder) {
629639
vae_params_mem_size = first_stage_model->get_params_buffer_size();
630640
} else {
641+
LOG_DEBUG("Here");
631642
if (!tae_first_stage->load_from_file(taesd_path, n_threads)) {
632643
return false;
633644
}
645+
LOG_DEBUG("Here");
634646
vae_params_mem_size = tae_first_stage->get_params_buffer_size();
635647
}
636648
size_t control_net_params_mem_size = 0;
@@ -1428,12 +1440,12 @@ class StableDiffusionGGML {
14281440
-0.0313f, -0.1649f, 0.0117f, 0.0723f, -0.2839f, -0.2083f, -0.0520f, 0.3748f,
14291441
0.0152f, 0.1957f, 0.1433f, -0.2944f, 0.3573f, -0.0548f, -0.1681f, -0.0667f};
14301442
latents_std_vec = {
1431-
0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
1432-
0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
1433-
0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
1434-
0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
1435-
0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
1436-
0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
1443+
0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
1444+
0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
1445+
0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
1446+
0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
1447+
0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
1448+
0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
14371449
}
14381450
for (int i = 0; i < latent->ne[3]; i++) {
14391451
float mean = latents_mean_vec[i];
@@ -1474,12 +1486,12 @@ class StableDiffusionGGML {
14741486
-0.0313f, -0.1649f, 0.0117f, 0.0723f, -0.2839f, -0.2083f, -0.0520f, 0.3748f,
14751487
0.0152f, 0.1957f, 0.1433f, -0.2944f, 0.3573f, -0.0548f, -0.1681f, -0.0667f};
14761488
latents_std_vec = {
1477-
0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
1478-
0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
1479-
0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
1480-
0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
1481-
0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
1482-
0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
1489+
0.4765f, 1.0364f, 0.4514f, 1.1677f, 0.5313f, 0.4990f, 0.4818f, 0.5013f,
1490+
0.8158f, 1.0344f, 0.5894f, 1.0901f, 0.6885f, 0.6165f, 0.8454f, 0.4978f,
1491+
0.5759f, 0.3523f, 0.7135f, 0.6804f, 0.5833f, 1.4146f, 0.8986f, 0.5659f,
1492+
0.7069f, 0.5338f, 0.4889f, 0.4917f, 0.4069f, 0.4999f, 0.6866f, 0.4093f,
1493+
0.5709f, 0.6065f, 0.6415f, 0.4944f, 0.5726f, 1.2042f, 0.5458f, 1.6887f,
1494+
0.3971f, 1.0600f, 0.3943f, 0.5537f, 0.5444f, 0.4089f, 0.7468f, 0.7744f};
14831495
}
14841496
for (int i = 0; i < latent->ne[3]; i++) {
14851497
float mean = latents_mean_vec[i];

0 commit comments

Comments (0)