@@ -798,7 +798,7 @@ class StableDiffusionGGML {
798798 SDCondition id_cond,
799799 sd_slg_params_t slg_params = {NULL , 0 , 0 , 0 , 0 },
800800 sd_apg_params_t apg_params = {1 , 0 , 0 },
801- ggml_tensor* noise_mask = nullptr ) {
801+ ggml_tensor* noise_mask = nullptr ) {
802802 std::vector<int > skip_layers (slg_params.skip_layers , slg_params.skip_layers + slg_params.skip_layers_count );
803803
804804 LOG_DEBUG (" Sample" );
@@ -959,39 +959,41 @@ class StableDiffusionGGML {
959959 float diff_norm = 0 ;
960960 float cond_norm_sq = 0 ;
961961 float dot = 0 ;
962- for (int i = 0 ; i < ne_elements; i++) {
963- float delta = positive_data[i] - negative_data[i];
964- if (apg_params.momentum != 0 ) {
965- delta += apg_params.momentum * apg_momentum_buffer[i];
966- apg_momentum_buffer[i] = delta;
962+ if (has_unconditioned) {
963+ for (int i = 0 ; i < ne_elements; i++) {
964+ float delta = positive_data[i] - negative_data[i];
965+ if (apg_params.momentum != 0 ) {
966+ delta += apg_params.momentum * apg_momentum_buffer[i];
967+ apg_momentum_buffer[i] = delta;
968+ }
969+ if (apg_params.norm_treshold > 0 ) {
970+ diff_norm += delta * delta;
971+ }
972+ if (apg_params.eta != 1 .0f ) {
973+ cond_norm_sq += positive_data[i] * positive_data[i];
974+ dot += positive_data[i] * delta;
975+ }
976+ deltas[i] = delta;
967977 }
968978 if (apg_params.norm_treshold > 0 ) {
969- diff_norm += delta * delta;
979+ diff_norm = std::sqrtf (diff_norm);
980+ apg_scale_factor = std::min (1 .0f , apg_params.norm_treshold / diff_norm);
970981 }
971982 if (apg_params.eta != 1 .0f ) {
972- cond_norm_sq += positive_data[i] * positive_data[i];
973- dot += positive_data[i] * delta;
983+ dot *= apg_scale_factor;
984+ // pre-normalize (avoids one square root and ne_elements extra divs)
985+ dot /= cond_norm_sq;
974986 }
975- deltas[i] = delta;
976- }
977- if (apg_params.norm_treshold > 0 ) {
978- diff_norm = std::sqrtf (diff_norm);
979- apg_scale_factor = std::min (1 .0f , apg_params.norm_treshold / diff_norm);
980- }
981- if (apg_params.eta != 1 .0f ) {
982- dot *= apg_scale_factor;
983- // pre-normalize (avoids one square root and ne_elements extra divs)
984- dot /= cond_norm_sq;
985- }
986987
987- for (int i = 0 ; i < ne_elements; i++) {
988- deltas[i] *= apg_scale_factor;
989- if (apg_params.eta != 1 .0f ) {
990- float apg_parallel = dot * positive_data[i];
991- float apg_orthogonal = deltas[i] - apg_parallel;
988+ for (int i = 0 ; i < ne_elements; i++) {
989+ deltas[i] *= apg_scale_factor;
990+ if (apg_params.eta != 1 .0f ) {
991+ float apg_parallel = dot * positive_data[i];
992+ float apg_orthogonal = deltas[i] - apg_parallel;
992993
993- // tweak deltas
994- deltas[i] = apg_orthogonal + apg_params.eta * apg_parallel;
994+ // tweak deltas
995+ deltas[i] = apg_orthogonal + apg_params.eta * apg_parallel;
996+ }
995997 }
996998 }
997999
0 commit comments