99#include " Lexicon.hpp"
1010#include < ax_sys_api.h>
1111#include " AudioFile.h"
12- #include " SolaProcessor.h"
1312#include " Lexicon.hpp"
1413
1514#include < signal.h>
@@ -253,14 +252,16 @@ class llm_task {
253252 }
254253 return false ;
255254 }
255+
256+ // Convert text to phonemes and tones
256257 std::vector<int > phones_bef, tones_bef;
257258 lexicon_->convert (msg_str, phones_bef, tones_bef);
258- // Add blank between words
259- auto phones = intersperse (phones_bef , 0 );
260- auto tones = intersperse (tones_bef, 0 );
261- int phone_len = phones. size ( );
262- int MELOTTS_LANG_IDS = MELOTTS_LANG_IDS_MAP[mode_config_. mode ];
263- std::vector< int > langids (phone_len, MELOTTS_LANG_IDS);
259+ auto phones = intersperse (phones_bef, 0 );
260+ auto tones = intersperse (tones_bef , 0 );
261+ int phone_len = phones. size ( );
262+ std::vector< int > langids ( phone_len, 3 );
263+
264+ // Run the encoder to generate hidden representations
264265 auto encoder_output =
265266 encoder_->Run (phones, tones, langids, g_matrix, mode_config_.noise_scale , mode_config_.noise_scale_w ,
266267 mode_config_.get_length_scale (), mode_config_.sdp_ratio );
@@ -269,66 +270,256 @@ class llm_task {
269270 auto zp_info = encoder_output.at (0 ).GetTensorTypeAndShapeInfo ();
270271 auto zp_shape = zp_info.GetShape ();
271272
272- // Decoder parameters setup
273- int zp_size = decoder_->GetInputSize (0 ) / sizeof (float );
274- int dec_len = zp_size / zp_shape[1 ];
275- int audio_slice_len = decoder_->GetOutputSize (0 ) / sizeof (float );
276- const int pad_frames = 16 ;
273+ // Calculate decoder parameters
274+ int zp_size = decoder_->GetInputSize (0 ) / sizeof (float );
275+ int dec_len = zp_size / zp_shape[1 ];
276+ int audio_slice_len = decoder_->GetOutputSize (0 ) / sizeof (float );
277+
278+ const int pad_frames = 24 ;
277279 const int samples_per_frame = 512 ;
278- const int effective_frames = dec_len - 2 * pad_frames;
280+
281+ const int effective_frames = dec_len - 2 * pad_frames;
282+
279283 int dec_slice_num =
280284 static_cast <int >(std::ceil (static_cast <double >(zp_shape[2 ]) / static_cast <double >(effective_frames)));
281- SolaProcessor sola (pad_frames, samples_per_frame);
285+
286+ // SOLA parameters setup
287+ const int sola_buffer_frame = pad_frames * samples_per_frame; // Overlap buffer length
288+ const int sola_search_frame = pad_frames * samples_per_frame; // Search window length
289+ const int block_frame = (dec_len - 2 * pad_frames) * samples_per_frame; // Effective block length
290+
291+ // Create fade-in/fade-out windows for smooth transitions
292+ std::vector<float > fade_in_window (sola_buffer_frame);
293+ std::vector<float > fade_out_window (sola_buffer_frame);
294+
295+ for (int i = 0 ; i < sola_buffer_frame; i++) {
296+ fade_in_window[i] = static_cast <float >(i) / sola_buffer_frame;
297+ fade_out_window[i] = 1 .0f - fade_in_window[i];
298+ }
299+
300+ // Initialize SOLA buffer
301+ std::vector<float > sola_buffer (sola_buffer_frame, 0 .0f );
302+ bool first_frame = true ;
303+
282304 std::vector<float > pcmlist;
283305
306+ // Main decoding loop - process each slice
284307 for (int i = 0 ; i < dec_slice_num; i++) {
308+ // Calculate start position for current batch input
285309 int input_start = i * effective_frames;
310+ // Consider forward padding, but ensure non-negative
286311 if (i > 0 ) {
287312 input_start -= pad_frames;
288313 }
289- input_start = std::max (0 , input_start);
314+ input_start = std::max (0 , input_start);
315+
316+ // Actual input length
290317 int actual_len = std::min (dec_len, static_cast <int >(zp_shape[2 ] - input_start));
318+
319+ // Calculate effective output range (frame level)
320+ int output_start_frame, output_end_frame;
321+
322+ if (i == 0 ) {
323+ // First frame: skip padding at beginning
324+ output_start_frame = 0 ;
325+ output_end_frame = effective_frames - 1 ;
326+ } else if (i == dec_slice_num - 1 ) {
327+ // Last frame: calculate from current segment start
328+ output_start_frame = i * effective_frames;
329+ // Last frame extends to encoder's maximum output length
330+ output_end_frame = static_cast <int >(zp_shape[2 ]) - 1 ;
331+ } else {
332+ // Middle frames: standard calculation
333+ output_start_frame = i * effective_frames;
334+ output_end_frame = (i + 1 ) * effective_frames - 1 ;
335+ }
336+ // Prepare decoder input, initialize all to zero
291337 std::vector<float > zp (zp_size, 0 );
292338
339+ // Copy data to decoder input
293340 for (int n = 0 ; n < zp_shape[1 ]; n++) {
294341 int copy_size = std::min (actual_len, static_cast <int >(zp_shape[2 ] - input_start));
295342 if (copy_size > 0 ) {
296343 memcpy (zp.data () + n * dec_len, zp_data + n * zp_shape[2 ] + input_start,
297344 sizeof (float ) * copy_size);
298345 }
299346 }
347+
300348 // Run decoder
301349 std::vector<float > decoder_output (audio_slice_len);
302350 decoder_->SetInput (zp.data (), 0 );
303351 decoder_->SetInput (g_matrix.data (), 1 );
352+
304353 if (0 != decoder_->Run ()) {
354+ SLOGI (" Inference #%d: decoding failed" , i + 1 );
305355 throw std::string (" decoder_ RunSync error" );
306356 }
357+
307358 decoder_->GetOutput (decoder_output.data (), 0 );
308- std::vector<float > processed_output = sola.ProcessFrame (decoder_output, i, dec_slice_num, actual_len);
309359
310- pcmlist.insert (pcmlist.end (), processed_output.begin (), processed_output.end ());
360+ // === SOLA Processing Logic ===
361+ if (first_frame) {
362+ // Special handling for first frame - should not skip initial content
363+ // First frame starts directly from decoder output without skipping
364+ int audio_start = 0 ; // Start from beginning, don't skip pad_frames
365+
366+ // Calculate data length for first frame
367+ // First frame should preserve complete decoder output, only reserving sola_buffer_frame at the end
368+ // for next frame alignment
369+ int audio_len = decoder_output.size () - sola_buffer_frame;
370+
371+ // Boundary check
372+ audio_len = std::max (0 , audio_len); // Ensure non-negative
373+
374+ // Add first frame data
375+ if (audio_len > 0 ) {
376+ pcmlist.insert (pcmlist.end (), decoder_output.begin () + audio_start,
377+ decoder_output.begin () + audio_start + audio_len);
378+ }
379+
380+ // Save sola_buffer_frame length from the end to SOLA buffer for next frame alignment
381+ int buffer_start = audio_len;
382+
383+ // Ensure sufficient data is available for copying
384+ if (buffer_start + sola_buffer_frame <= decoder_output.size ()) {
385+ std::copy (decoder_output.begin () + buffer_start,
386+ decoder_output.begin () + buffer_start + sola_buffer_frame, sola_buffer.begin ());
387+ } else {
388+ // Possible case: first frame data is shorter than sola_buffer_frame
389+ int available = static_cast <int >(decoder_output.size () - buffer_start);
390+ if (available > 0 ) {
391+ std::copy (decoder_output.begin () + buffer_start, decoder_output.end (), sola_buffer.begin ());
392+ // Fill with zeros
393+ std::fill (sola_buffer.begin () + available, sola_buffer.end (), 0 .0f );
394+ } else {
395+ // Completely insufficient data, fill all with zeros
396+ std::fill (sola_buffer.begin (), sola_buffer.end (), 0 .0f );
397+ }
398+ }
399+
400+ first_frame = false ;
401+
402+ } else {
403+ // Non-first frame: SOLA alignment required
404+ int audio_start = pad_frames * samples_per_frame;
405+
406+ // 1. Prepare search window - beginning portion of current frame
407+ std::vector<float > search_window (sola_buffer_frame + sola_search_frame);
408+ std::copy (decoder_output.begin () + audio_start,
409+ decoder_output.begin () + audio_start + search_window.size (), search_window.begin ());
410+
411+ // 2. Find best alignment point (calculate cross-correlation)
412+ int best_offset = 0 ;
413+ float best_correlation = -1.0 ;
414+
415+ for (int offset = 0 ; offset <= sola_search_frame; offset++) {
416+ float correlation = 0.0 ;
417+ float energy = 0.0 ;
418+
419+ for (int j = 0 ; j < sola_buffer_frame; j++) {
420+ correlation += sola_buffer[j] * search_window[j + offset];
421+ energy += search_window[j + offset] * search_window[j + offset];
422+ }
423+
424+ // Normalize correlation (avoid division by zero)
425+ float normalized_correlation = (energy > 1e-8 ) ? correlation / std::sqrt (energy) : 0 .0f ;
426+
427+ if (normalized_correlation > best_correlation) {
428+ best_correlation = normalized_correlation;
429+ best_offset = offset;
430+ }
431+ }
432+
433+ // 3. Apply alignment offset
434+ int aligned_start = audio_start + best_offset;
435+
436+ // 4. Smooth transition processing (crossfade in alignment region)
437+ std::vector<float > crossfade_region (sola_buffer_frame);
438+
439+ for (int j = 0 ; j < sola_buffer_frame; j++) {
440+ // Apply fade-in/fade-out window functions
441+ crossfade_region[j] =
442+ decoder_output[aligned_start + j] * fade_in_window[j] + sola_buffer[j] * fade_out_window[j];
443+ }
444+
445+ // 5. Add crossfade region to output
446+ pcmlist.insert (pcmlist.end (), crossfade_region.begin (), crossfade_region.end ());
447+
448+ int remaining_start = aligned_start + sola_buffer_frame;
449+
450+ if (i == dec_slice_num - 1 ) {
451+ int total_expected_samples = audio_len * samples_per_frame / 512 ;
452+
453+ int processed_samples = static_cast <int >(pcmlist.size ());
454+
455+ int remaining_needed = total_expected_samples - processed_samples;
456+ remaining_needed = std::max (0 , remaining_needed);
457+
458+ int remaining_len =
459+ std::min (remaining_needed, static_cast <int >(decoder_output.size () - remaining_start));
460+
461+ if (remaining_len > 0 ) {
462+ pcmlist.insert (pcmlist.end (), decoder_output.begin () + remaining_start,
463+ decoder_output.begin () + remaining_start + remaining_len);
464+ }
465+
466+ } else {
467+ int remaining_len = (dec_len - 2 * pad_frames) * samples_per_frame - sola_buffer_frame;
468+
469+ remaining_len =
470+ std::min (remaining_len, static_cast <int >(decoder_output.size () - remaining_start));
471+
472+ if (remaining_len > 0 ) {
473+ pcmlist.insert (pcmlist.end (), decoder_output.begin () + remaining_start,
474+ decoder_output.begin () + remaining_start + remaining_len);
475+ }
476+
477+ int buffer_start = remaining_start + remaining_len;
478+
479+ if (buffer_start + sola_buffer_frame <= decoder_output.size ()) {
480+ std::copy (decoder_output.begin () + buffer_start,
481+ decoder_output.begin () + buffer_start + sola_buffer_frame, sola_buffer.begin ());
482+ } else {
483+ int avail = static_cast <int >(decoder_output.size () - buffer_start);
484+ if (avail > 0 ) {
485+ std::copy (decoder_output.begin () + buffer_start, decoder_output.end (),
486+ sola_buffer.begin ());
487+ }
488+ std::fill (sola_buffer.begin () + avail, sola_buffer.end (), 0 .0f );
489+ }
490+ }
491+ }
492+ }
493+
494+ if (pcmlist.size () > audio_len) {
495+ pcmlist.resize (audio_len);
311496 }
312497
313- double src_ratio = (mode_config_.audio_rate * 1 .0f ) / (mode_config_.mode_rate * 1 .0f );
498+ // Post-processing: resample and convert to int16
499+ double src_ratio =
500+ static_cast <double >(mode_config_.audio_rate ) / static_cast <double >(mode_config_.mode_rate );
314501 std::vector<float > tmp_pcm ((pcmlist.size () * src_ratio + 1 ));
315502 int len;
503+
316504 resample_audio (pcmlist.data (), pcmlist.size (), tmp_pcm.data (), &len, src_ratio);
317505
318506 // Convert to 16-bit PCM
319507 wav_pcm_data.reserve (len);
320508 std::transform (tmp_pcm.begin (), tmp_pcm.begin () + len, std::back_inserter (wav_pcm_data),
321- [](const auto val) { return ( int16_t ) (val * INT16_MAX); });
509+ [](const auto val) { return static_cast < int16_t > (val * INT16_MAX); });
322510
323- // Call callback function with output
324- if (out_callback_)
325- out_callback_ (std::string ((char *)wav_pcm_data.data (), wav_pcm_data.size () * sizeof (int16_t )), finish);
511+ // Call the output callback function with the result
512+ if (out_callback_) {
513+ out_callback_ (
514+ std::string (reinterpret_cast <char *>(wav_pcm_data.data ()), wav_pcm_data.size () * sizeof (int16_t )),
515+ finish);
516+ }
326517
327518 } catch (const std::exception &e) {
328519 SLOGI (" TTS processing exception: %s" , e.what ());
329520 return true ;
330521 } catch (...) {
331- SLOGI (" TTS processing encountered unknown exception" );
522+ SLOGI (" TTS processing encountered an unknown exception" );
332523 return true ;
333524 }
334525 return false ;
0 commit comments