Skip to content

Commit ff38ea1

Browse files
committed
main: add apg support
1 parent 102a9ea commit ff38ea1

File tree

1 file changed

+34
-5
lines changed

1 file changed

+34
-5
lines changed

examples/cli/main.cpp

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,13 @@ struct SDParams {
123123
int upscale_repeats = 1;
124124

125125
std::vector<int> skip_layers = {7, 8, 9};
126-
float slg_scale = 0.;
127-
float skip_layer_start = 0.01;
128-
float skip_layer_end = 0.2;
126+
float slg_scale = 0.0f;
127+
float skip_layer_start = 0.01f;
128+
float skip_layer_end = 0.2f;
129+
130+
float apg_eta = 1.0f;
131+
float apg_momentum = 0.0f;
132+
float apg_norm_treshold = 0.0f;
129133
};
130134

131135
void print_params(SDParams params) {
@@ -207,6 +211,9 @@ void print_usage(int argc, const char* argv[]) {
207211
printf(" -p, --prompt [PROMPT] the prompt to render\n");
208212
printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n");
209213
printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n");
214+
printf(" --apg-eta VALUE parallel projected guidance scale for APG (default: 1.0, recommended: between 0 and 1)\n");
215+
printf(" --apg-momentum VALUE CFG update direction momentum for APG (default: 0, recommended: around -0.5)\n");
216+
printf(" --apg-nt, --apg-rescale VALUE CFG update direction norm threshold for APG (default: 0 = disabled, recommended: 4-15)\n");
210217
printf(" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n");
211218
printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n");
212219
printf(" --skip-layers LAYERS Layers to skip for SLG steps: (default: [7,8,9])\n");
@@ -616,6 +623,24 @@ void parse_args(int argc, const char** argv, SDParams& params) {
616623
break;
617624
}
618625
params.skip_layer_end = std::stof(argv[i]);
626+
} else if (arg == "--apg-eta") {
627+
if (++i >= argc) {
628+
invalid_arg = true;
629+
break;
630+
}
631+
params.apg_eta = std::stof(argv[i]);
632+
} else if (arg == "--apg-momentum") {
633+
if (++i >= argc) {
634+
invalid_arg = true;
635+
break;
636+
}
637+
params.apg_momentum = std::stof(argv[i]);
638+
} else if (arg == "--apg-nt" || arg == "--apg-rescale") {
639+
if (++i >= argc) {
640+
invalid_arg = true;
641+
break;
642+
}
643+
params.apg_norm_treshold = std::stof(argv[i]);
619644
} else {
620645
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
621646
print_usage(argc, argv);
@@ -953,7 +978,9 @@ int main(int argc, const char* argv[]) {
953978
params.slg_scale,
954979
params.skip_layer_start,
955980
params.skip_layer_end},
956-
sd_apg_params_t{1, 0, 0});
981+
sd_apg_params_t{params.apg_eta,
982+
params.apg_momentum,
983+
params.apg_norm_treshold});
957984
} else {
958985
sd_image_t input_image = {(uint32_t)params.width,
959986
(uint32_t)params.height,
@@ -1022,7 +1049,9 @@ int main(int argc, const char* argv[]) {
10221049
params.slg_scale,
10231050
params.skip_layer_start,
10241051
params.skip_layer_end},
1025-
sd_apg_params_t{1, 0, 0});
1052+
sd_apg_params_t{params.apg_eta,
1053+
params.apg_momentum,
1054+
params.apg_norm_treshold});
10261055
}
10271056
}
10281057

0 commit comments

Comments
 (0)