Skip to content

Commit e747395

Browse files
committed
Merge leejet#1222 (improve handling of VAE decode failures)
2 parents 4e1a030 + 6aa3b9d commit e747395

4 files changed

Lines changed: 48 additions & 24 deletions

File tree

examples/cli/main.cpp

Lines changed: 21 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -409,7 +409,7 @@ bool save_results(const SDCliParams& cli_params,
409409
auto write_image = [&](const fs::path& path, int idx) {
410410
const sd_image_t& img = results[idx];
411411
if (!img.data)
412-
return;
412+
return false;
413413

414414
std::string params = get_image_params(cli_params, ctx_params, gen_params, gen_params.seed + idx);
415415
int ok = 0;
@@ -419,8 +419,11 @@ bool save_results(const SDCliParams& cli_params,
419419
ok = stbi_write_png(path.string().c_str(), img.width, img.height, img.channel, img.data, 0, params.c_str());
420420
}
421421
LOG_INFO("save result image %d to '%s' (%s)", idx, path.string().c_str(), ok ? "success" : "failure");
422+
return ok != 0;
422423
};
423424

425+
int sucessful_reults = 0;
426+
424427
if (std::regex_search(cli_params.output_path, format_specifier_regex)) {
425428
if (!is_jpg && ext_lower != ".png")
426429
ext = ".png";
@@ -429,19 +432,26 @@ bool save_results(const SDCliParams& cli_params,
429432

430433
for (int i = 0; i < num_results; ++i) {
431434
fs::path img_path = format_frame_idx(pattern.string(), output_begin_idx + i);
432-
write_image(img_path, i);
435+
if (write_image(img_path, i)) {
436+
sucessful_reults++;
437+
}
433438
}
434-
return true;
439+
LOG_INFO("%d/%d images saved", sucessful_reults, num_results);
440+
return sucessful_reults != 0;
435441
}
436442

437443
if (cli_params.mode == VID_GEN && num_results > 1) {
438444
if (ext_lower != ".avi")
439445
ext = ".avi";
440446
fs::path video_path = base_path;
441447
video_path += ext;
442-
create_mjpg_avi_from_sd_images(video_path.string().c_str(), results, num_results, gen_params.fps);
443-
LOG_INFO("save result MJPG AVI video to '%s'", video_path.string().c_str());
444-
return true;
448+
if (create_mjpg_avi_from_sd_images(video_path.string().c_str(), results, num_results, gen_params.fps) == 0) {
449+
LOG_INFO("save result MJPG AVI video to '%s'", video_path.string().c_str());
450+
return true;
451+
} else {
452+
LOG_ERROR("Failed to save result MPG AVI video to '%s'", video_path.string().c_str());
453+
return false;
454+
}
445455
}
446456

447457
if (!is_jpg && ext_lower != ".png")
@@ -453,10 +463,12 @@ bool save_results(const SDCliParams& cli_params,
453463
img_path += "_" + std::to_string(output_begin_idx + i);
454464
}
455465
img_path += ext;
456-
write_image(img_path, i);
466+
if (write_image(img_path, i)) {
467+
sucessful_reults++;
468+
}
457469
}
458-
459-
return true;
470+
LOG_INFO("%d/%d images saved", sucessful_reults, num_results);
471+
return sucessful_reults != 0;
460472
}
461473

462474
int main(int argc, const char* argv[]) {

ggml_extend.hpp

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -767,7 +767,7 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_silu_act(ggml_context* ctx, ggml_tensor*
767767
return x;
768768
}
769769

770-
typedef std::function<void(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;
770+
typedef std::function<bool(ggml_tensor*, ggml_tensor*, bool)> on_tile_process;
771771

772772
__STATIC_INLINE__ void sd_tiling_calc_tiles(int& num_tiles_dim,
773773
float& tile_overlap_factor_dim,
@@ -918,12 +918,15 @@ __STATIC_INLINE__ void sd_tiling_non_square(ggml_tensor* input,
918918

919919
int64_t t1 = ggml_time_ms();
920920
ggml_ext_tensor_split_2d(input, input_tile, x_in, y_in);
921-
on_processing(input_tile, output_tile, false);
922-
ggml_ext_tensor_merge_2d(output_tile, output, x_out, y_out, overlap_x_out, overlap_y_out, dx, dy);
921+
if (on_processing(input_tile, output_tile, false)) {
922+
ggml_ext_tensor_merge_2d(output_tile, output, x_out, y_out, overlap_x_out, overlap_y_out, dx, dy);
923923

924-
int64_t t2 = ggml_time_ms();
925-
last_time = (t2 - t1) / 1000.0f;
926-
pretty_progress(tile_count, num_tiles, last_time);
924+
int64_t t2 = ggml_time_ms();
925+
last_time = (t2 - t1) / 1000.0f;
926+
pretty_progress(tile_count, num_tiles, last_time);
927+
} else {
928+
LOG_ERROR("Failed to process patch %d at (%d, %d)", tile_count, x, y);
929+
}
927930
tile_count++;
928931
}
929932
last_x = false;

stable-diffusion.cpp

Lines changed: 17 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1558,7 +1558,7 @@ class StableDiffusionGGML {
15581558
if (vae_tiling_params.enabled) {
15591559
// split latent in 32x32 tiles and compute in several steps
15601560
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
1561-
first_stage_model->compute(n_threads, in, true, &out, nullptr);
1561+
return first_stage_model->compute(n_threads, in, true, &out, nullptr);
15621562
};
15631563
silent_tiling(latents, result, get_vae_scale_factor(), 32, 0.5f, on_tiling);
15641564

@@ -1577,7 +1577,7 @@ class StableDiffusionGGML {
15771577
if (vae_tiling_params.enabled) {
15781578
// split latent in 64x64 tiles and compute in several steps
15791579
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
1580-
tae_first_stage->compute(n_threads, in, true, &out, nullptr);
1580+
return tae_first_stage->compute(n_threads, in, true, &out, nullptr);
15811581
};
15821582
silent_tiling(latents, result, get_vae_scale_factor(), 64, 0.5f, on_tiling);
15831583
} else {
@@ -2546,7 +2546,7 @@ class StableDiffusionGGML {
25462546
LOG_DEBUG("VAE Tile size: %dx%d", tile_size_x, tile_size_y);
25472547

25482548
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
2549-
first_stage_model->compute(n_threads, in, false, &out, work_ctx);
2549+
return first_stage_model->compute(n_threads, in, false, &out, work_ctx);
25502550
};
25512551
sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, on_tiling);
25522552
} else {
@@ -2557,7 +2557,7 @@ class StableDiffusionGGML {
25572557
if (vae_tiling_params.enabled && !encode_video) {
25582558
// split latent in 32x32 tiles and compute in several steps
25592559
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
2560-
tae_first_stage->compute(n_threads, in, false, &out, nullptr);
2560+
return tae_first_stage->compute(n_threads, in, false, &out, nullptr);
25612561
};
25622562
sd_tiling(x, result, vae_scale_factor, 64, 0.5f, on_tiling);
25632563
} else {
@@ -2675,23 +2675,31 @@ class StableDiffusionGGML {
26752675

26762676
// split latent in 32x32 tiles and compute in several steps
26772677
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
2678-
first_stage_model->compute(n_threads, in, true, &out, nullptr);
2678+
return first_stage_model->compute(n_threads, in, true, &out, nullptr);
26792679
};
26802680
sd_tiling_non_square(x, result, vae_scale_factor, tile_size_x, tile_size_y, tile_overlap, on_tiling);
26812681
} else {
2682-
first_stage_model->compute(n_threads, x, true, &result, work_ctx);
2682+
if(!first_stage_model->compute(n_threads, x, true, &result, work_ctx)){
2683+
LOG_ERROR("Failed to decode latetnts");
2684+
first_stage_model->free_compute_buffer();
2685+
return nullptr;
2686+
}
26832687
}
26842688
first_stage_model->free_compute_buffer();
26852689
process_vae_output_tensor(result);
26862690
} else {
26872691
if (vae_tiling_params.enabled) {
26882692
// split latent in 64x64 tiles and compute in several steps
26892693
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
2690-
tae_first_stage->compute(n_threads, in, true, &out);
2694+
return tae_first_stage->compute(n_threads, in, true, &out);
26912695
};
26922696
sd_tiling(x, result, vae_scale_factor, 64, 0.5f, on_tiling);
26932697
} else {
2694-
tae_first_stage->compute(n_threads, x, true, &result);
2698+
if(!tae_first_stage->compute(n_threads, x, true, &result)){
2699+
LOG_ERROR("Failed to decode latetnts");
2700+
tae_first_stage->free_compute_buffer();
2701+
return nullptr;
2702+
}
26952703
}
26962704
tae_first_stage->free_compute_buffer();
26972705
}
@@ -3461,6 +3469,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
34613469
ggml_free(work_ctx);
34623470
return nullptr;
34633471
}
3472+
memset(result_images, 0, batch_count * sizeof(sd_image_t));
34643473

34653474
for (size_t i = 0; i < decoded_images.size(); i++) {
34663475
result_images[i].width = width;

upscaler.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,7 @@ struct UpscalerGGML {
8989

9090
ggml_tensor* upscaled = ggml_new_tensor_4d(upscale_ctx, GGML_TYPE_F32, output_width, output_height, 3, 1);
9191
auto on_tiling = [&](ggml_tensor* in, ggml_tensor* out, bool init) {
92-
esrgan_upscaler->compute(n_threads, in, &out);
92+
return esrgan_upscaler->compute(n_threads, in, &out);
9393
};
9494
int64_t t0 = ggml_time_ms();
9595
sd_tiling(input_image_tensor, upscaled, esrgan_upscaler->scale, esrgan_upscaler->tile_size, 0.25f, on_tiling);

0 commit comments

Comments
 (0)