@@ -400,8 +400,8 @@ class StableDiffusionGGML {
400400 offload_params_to_cpu,
401401 tensor_storage_map);
402402 diffusion_model = std::make_shared<MMDiTModel>(backend,
403- offload_params_to_cpu,
404- tensor_storage_map);
403+ offload_params_to_cpu,
404+ tensor_storage_map);
405405 } else if (sd_version_is_flux (version)) {
406406 bool is_chroma = false ;
407407 for (auto pair : tensor_storage_map) {
@@ -461,10 +461,10 @@ class StableDiffusionGGML {
461461 1 ,
462462 true );
463463 diffusion_model = std::make_shared<WanModel>(backend,
464- offload_params_to_cpu,
465- tensor_storage_map,
466- " model.diffusion_model" ,
467- version);
464+ offload_params_to_cpu,
465+ tensor_storage_map,
466+ " model.diffusion_model" ,
467+ version);
468468 if (strlen (SAFE_STR (sd_ctx_params->high_noise_diffusion_model_path )) > 0 ) {
469469 high_noise_diffusion_model = std::make_shared<WanModel>(backend,
470470 offload_params_to_cpu,
@@ -564,14 +564,27 @@ class StableDiffusionGGML {
564564 }
565565
566566 if (sd_version_is_wan (version) || sd_version_is_qwen_image (version)) {
567- first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
568- offload_params_to_cpu,
569- tensor_storage_map,
570- " first_stage_model" ,
571- vae_decode_only,
572- version);
573- first_stage_model->alloc_params_buffer ();
574- first_stage_model->get_param_tensors (tensors, " first_stage_model" );
567+ if (!use_tiny_autoencoder) {
568+ first_stage_model = std::make_shared<WAN::WanVAERunner>(vae_backend,
569+ offload_params_to_cpu,
570+ tensor_storage_map,
571+ " first_stage_model" ,
572+ vae_decode_only,
573+ version);
574+ first_stage_model->alloc_params_buffer ();
575+ first_stage_model->get_param_tensors (tensors, " first_stage_model" );
576+ } else {
577+ tae_first_stage = std::make_shared<TinyVideoAutoEncoder>(vae_backend,
578+ offload_params_to_cpu,
579+ tensor_storage_map,
580+ " decoder" ,
581+ vae_decode_only,
582+ version);
583+ if (sd_ctx_params->vae_conv_direct ) {
584+ LOG_INFO (" Using Conv2d direct in the tae model" );
585+ tae_first_stage->set_conv2d_direct_enabled (true );
586+ }
587+ }
575588 } else if (version == VERSION_CHROMA_RADIANCE) {
576589 first_stage_model = std::make_shared<FakeVAE>(vae_backend,
577590 offload_params_to_cpu);
@@ -598,14 +611,13 @@ class StableDiffusionGGML {
598611 }
599612 first_stage_model->alloc_params_buffer ();
600613 first_stage_model->get_param_tensors (tensors, " first_stage_model" );
601- }
602- if (use_tiny_autoencoder) {
603- tae_first_stage = std::make_shared<TinyAutoEncoder>(vae_backend,
604- offload_params_to_cpu,
605- tensor_storage_map,
606- " decoder.layers" ,
607- vae_decode_only,
608- version);
614+ } else if (use_tiny_autoencoder) {
615+ tae_first_stage = std::make_shared<TinyImageAutoEncoder>(vae_backend,
616+ offload_params_to_cpu,
617+ tensor_storage_map,
618+ " decoder.layers" ,
619+ vae_decode_only,
620+ version);
609621 if (sd_ctx_params->vae_conv_direct ) {
610622 LOG_INFO (" Using Conv2d direct in the tae model" );
611623 tae_first_stage->set_conv2d_direct_enabled (true );
@@ -726,13 +738,16 @@ class StableDiffusionGGML {
726738 unet_params_mem_size += high_noise_diffusion_model->get_params_buffer_size ();
727739 }
728740 size_t vae_params_mem_size = 0 ;
741+ LOG_DEBUG (" Here" );
729742 if (!use_tiny_autoencoder || sd_ctx_params->tae_preview_only ) {
730743 vae_params_mem_size = first_stage_model->get_params_buffer_size ();
731744 }
732745 if (use_tiny_autoencoder) {
746+ LOG_DEBUG (" Here" );
733747 if (!tae_first_stage->load_from_file (taesd_path, n_threads)) {
734748 return false ;
735749 }
750+ LOG_DEBUG (" Here" );
736751 vae_params_mem_size = tae_first_stage->get_params_buffer_size ();
737752 }
738753 size_t control_net_params_mem_size = 0 ;
0 commit comments