@@ -123,9 +123,13 @@ struct SDParams {
123123 int upscale_repeats = 1 ;
124124
125125 std::vector<int > skip_layers = {7 , 8 , 9 };
126- float slg_scale = 0 .;
127- float skip_layer_start = 0.01 ;
128- float skip_layer_end = 0.2 ;
126+ float slg_scale = 0 .0f ;
127+ float skip_layer_start = 0 .01f ;
128+ float skip_layer_end = 0 .2f ;
129+
130+ float apg_eta = 1 .0f ;
131+ float apg_momentum = 0 .0f ;
132+ float apg_norm_treshold = 0 .0f ;
129133};
130134
131135void print_params (SDParams params) {
@@ -207,6 +211,9 @@ void print_usage(int argc, const char* argv[]) {
207211 printf (" -p, --prompt [PROMPT] the prompt to render\n " );
208212 printf (" -n, --negative-prompt PROMPT the negative prompt (default: \"\" )\n " );
209213 printf (" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n " );
214+ printf (" --apg-eta VALUE parallel projected guidance scale for APG (default: 1.0, recommended: between 0 and 1)\n " );
215+ printf (" --apg-momentum VALUE CFG update direction momentum for APG (default: 0, recommended: around -0.5)\n " );
216+ printf (" --apg-nt, --apg-rescale VALUE CFG update direction norm threshold for APG (default: 0 = disabled, recommended: 4-15)\n " );
210217 printf (" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n " );
211218 printf (" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n " );
212219 printf (" --skip-layers LAYERS Layers to skip for SLG steps: (default: [7,8,9])\n " );
@@ -616,6 +623,24 @@ void parse_args(int argc, const char** argv, SDParams& params) {
616623 break ;
617624 }
618625 params.skip_layer_end = std::stof (argv[i]);
626+ } else if (arg == " --apg-eta" ) {
627+ if (++i >= argc) {
628+ invalid_arg = true ;
629+ break ;
630+ }
631+ params.apg_eta = std::stof (argv[i]);
632+ } else if (arg == " --apg-momentum" ) {
633+ if (++i >= argc) {
634+ invalid_arg = true ;
635+ break ;
636+ }
637+ params.apg_momentum = std::stof (argv[i]);
638+ } else if (arg == " --apg-nt" || arg == " --apg-rescale" ) {
639+ if (++i >= argc) {
640+ invalid_arg = true ;
641+ break ;
642+ }
643+ params.apg_norm_treshold = std::stof (argv[i]);
619644 } else {
620645 fprintf (stderr, " error: unknown argument: %s\n " , arg.c_str ());
621646 print_usage (argc, argv);
@@ -953,7 +978,9 @@ int main(int argc, const char* argv[]) {
953978 params.slg_scale ,
954979 params.skip_layer_start ,
955980 params.skip_layer_end },
956- sd_apg_params_t {1 , 0 , 0 });
981+ sd_apg_params_t {params.apg_eta ,
982+ params.apg_momentum ,
983+ params.apg_norm_treshold });
957984 } else {
958985 sd_image_t input_image = {(uint32_t )params.width ,
959986 (uint32_t )params.height ,
@@ -1022,7 +1049,9 @@ int main(int argc, const char* argv[]) {
10221049 params.slg_scale ,
10231050 params.skip_layer_start ,
10241051 params.skip_layer_end },
1025- sd_apg_params_t {1 , 0 , 0 });
1052+ sd_apg_params_t {params.apg_eta ,
1053+ params.apg_momentum ,
1054+ params.apg_norm_treshold });
10261055 }
10271056 }
10281057
0 commit comments