Use correct values to update encoder KV Cache for streaming models #15323
MahmoudAshraf97 wants to merge 3 commits into NVIDIA-NeMo:main
Conversation
Signed-off-by: MahmoudAshraf97 <hassouna97.ma@gmail.com>
Force-pushed from f777ce4 to 07f9734
The linting failure is unrelated to this PR.
Signed-off-by: KunalDhawan <KunalDhawan@users.noreply.github.com>
Thanks for opening this PR, @MahmoudAshraf97, great catch! The changes look good to me. I've scheduled CI tests to make sure the updates don't break any existing pipelines, and I'm also running internal WER evaluations to assess the impact on accuracy and performance. It would be great if you could also share any benchmarks you have on WER and latency before vs. after these changes to help validate the improvements.
Hi @KunalDhawan, the impact of this PR is felt most with CTC models, or RNNT models on long files, where the effect of the wrong cache accumulates; the symptoms are an increase in deletion errors and missing chunks in the transcript. I also suggest adding tests to verify that the encoder output with cache is identical to the encoder output when the actual audio is passed as context.
@nithinraok @KunalDhawan, this is a gentle reminder, as this PR is blocking follow-up work related to caching.
Thanks for the PR @MahmoudAshraf97! I ran an evaluation using the nemotron-speech-streaming-en-0.6b model on hour-long audio files from the Earnings22 dataset using speech_to_text_cache_aware_streaming_infer.py with […]. My understanding is that this is expected for this model, since it uses […].

The fix would likely have an observable impact with […]. Could you share which model and […]?
Thank you for your reply. I verified the scripts we use for inference against the simulator script you shared, and I found that we were passing 160 ms as the pre-encoder context instead of 90 ms; that alone caused the WER to rise from 10% to 90% on an internal dataset. For future reference, this is how to stream an actual audio file without using the […]:

```python
import torch

import nemo.collections.asr as nemo_asr

model = nemo_asr.models.ASRModel.from_pretrained(
    "nvidia/nemotron-speech-streaming-en-0.6b"
)
model = model.eval()

# Read raw 16-bit PCM samples, skipping the WAV header
with open("test_clip.wav", "rb") as f:
    f.seek(40)
    audio_signal = (
        torch.frombuffer(f.read(), dtype=torch.int16).float().unsqueeze(0) / 32768
    )

batch_size = 1
attention_cache, conv_cache, attention_cache_len = (
    model.encoder.get_initial_cache_state(batch_size=batch_size)
)

outputs = []
# Audio samples covered by one encoder chunk
chunk_size = int(
    (model.encoder.att_context_size[1] + 1)
    * model.cfg.preprocessor.window_stride
    * model.cfg.encoder.subsampling_factor
    * model.cfg.preprocessor.sample_rate
)
# Audio samples of left context carried between consecutive chunks
chunk_overlap = int(
    1
    * model.cfg.preprocessor.window_stride
    * (model.cfg.encoder.subsampling_factor + 1)
    * model.cfg.preprocessor.sample_rate
)
# Pad so the first chunk has its overlap and the last chunk is full-sized
padded_audio = torch.nn.functional.pad(
    audio_signal, (chunk_overlap, chunk_size - audio_signal.shape[1] % chunk_size)
)

with torch.inference_mode():
    for i in range(chunk_overlap, padded_audio.shape[1], chunk_size):
        chunk = padded_audio[:, i - chunk_overlap : i + chunk_size].to(model.device)
        features, features_len = model.preprocessor(
            input_signal=chunk, length=torch.tensor([chunk.shape[1]]).to(model.device)
        )
        features = features[:, :, :-1]  # drop the trailing partial feature frame
        (
            encoded,
            encoded_len,
            attention_cache,
            conv_cache,
            attention_cache_len,
        ) = model.encoder.cache_aware_stream_step(
            processed_signal=features,
            processed_signal_length=features_len,
            cache_last_channel=attention_cache,
            cache_last_time=conv_cache,
            cache_last_channel_len=attention_cache_len,
            keep_all_outputs=False,
            drop_extra_pre_encoded=2,
            bypass_pre_encode=False,
        )
        outputs.append(encoded)

# Concatenate chunk outputs along the time axis and decode the full utterance
projected_encoder_output = torch.cat(outputs, dim=-1)
print(
    model.decoding.rnnt_decoder_predictions_tensor(
        encoder_output=projected_encoder_output,
        encoded_lengths=torch.tensor([projected_encoder_output.shape[2]]).to(
            model.device
        ),
        return_hypotheses=True,
        partial_hypotheses=None,
    )[0].text
)
```
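For intuition, the chunk sizing in the snippet above reduces to simple arithmetic. The config values below (window stride, subsampling factor, sample rate, right attention context) are illustrative assumptions for a typical cache-aware Conformer setup, not values read from the actual nemotron-speech-streaming-en-0.6b config:

```python
# Illustrative arithmetic for chunk_size and chunk_overlap above.
# All config values here are assumed for illustration only.
window_stride = 0.01        # seconds per feature frame (assumed)
subsampling_factor = 8      # encoder subsampling (assumed)
sample_rate = 16000         # Hz (assumed)
right_context = 13          # att_context_size[1], in encoder frames (assumed)

# One encoder frame spans subsampling_factor feature frames, i.e.
# window_stride * subsampling_factor seconds of raw audio.
samples_per_encoder_frame = int(window_stride * subsampling_factor * sample_rate)

# Each chunk covers (right_context + 1) encoder frames of raw audio.
chunk_size = (right_context + 1) * samples_per_encoder_frame
print(chunk_size)     # 17920 samples = 1.12 s under these assumptions

# The overlap carries (subsampling_factor + 1) feature frames of context.
chunk_overlap = int(window_stride * (subsampling_factor + 1) * sample_rate)
print(chunk_overlap)  # 1440 samples = 90 ms under these assumptions
```

Under these assumed values the overlap works out to 90 ms, which is exactly the pre-encoder context figure discussed above; passing 160 ms instead misaligns every chunk against the cache.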
Important

The **Update branch** button must only be pressed on very rare occasions. An outdated branch never blocks the merge of a PR. Please reach out to the automation team before pressing that button.
What does this PR do?
The cache update includes parts of the input that `streaming_post_process` discards as invalid, yet those parts were still being written into the cache. This PR uses the correct values to update the cache, which resulted in a much lower WER on our internal dataset.

Collection: ASR
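The failure mode can be sketched in a few lines of plain Python (this is an illustrative toy, not NeMo's actual cache code): when a sequence is padded, the cache must be filled from the tail of the *valid* frames, not from the tail of the padded tensor, otherwise padding frames leak into the context for the next chunk.

```python
# Toy sketch of the bug this PR fixes (illustrative, not NeMo code):
# the cache holds the last CACHE_SIZE frames of context for the next chunk.
CACHE_SIZE = 3

def update_cache_buggy(frames, valid_len):
    # Takes the tail of the padded sequence, so invalid padding frames
    # can end up in the cache.
    return frames[-CACHE_SIZE:]

def update_cache_fixed(frames, valid_len):
    # Slice to the valid region first, then take the tail.
    return frames[:valid_len][-CACHE_SIZE:]

# One sequence with 6 valid frames, padded to length 8 with zeros.
frames = [1, 2, 3, 4, 5, 6, 0, 0]
print(update_cache_buggy(frames, valid_len=6))  # [6, 0, 0] -> padding leaks in
print(update_cache_fixed(frames, valid_len=6))  # [4, 5, 6] -> correct context
```

On short utterances one bad tail barely matters, but on long streams the corrupted cache is fed back chunk after chunk, which matches the accumulating deletion errors described in the conversation above.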
PR Type:
Who can review?
@nithinraok