Skip to content

Use correct values to update encoder KV Cache for streaming models#15323

Closed
MahmoudAshraf97 wants to merge 3 commits intoNVIDIA-NeMo:mainfrom
MahmoudAshraf97:fix_cache
Closed

Use correct values to update encoder KV Cache for streaming models#15323
MahmoudAshraf97 wants to merge 3 commits intoNVIDIA-NeMo:mainfrom
MahmoudAshraf97:fix_cache

Conversation

@MahmoudAshraf97
Copy link
Copy Markdown
Contributor

Important

The Update branch button must only be pressed in very rare occassions.
An outdated branch is never blocking the merge of a PR.
Please reach out to the automation team before pressing that button.

What does this PR do ?

The cache update includes invalid parts of the input that are discarded in streaming_post_process as invalid, but are included in the cache updates, this PR uses the correct values to update the cache, this PR resulted in much lower WER on our internal dataset

Collection: ASR

PR Type:

  • New Feature
  • Bugfix
  • Documentation

Who can review?

@nithinraok

@github-actions github-actions Bot added the ASR label Jan 27, 2026
Signed-off-by: MahmoudAshraf97 <hassouna97.ma@gmail.com>
@MahmoudAshraf97
Copy link
Copy Markdown
Contributor Author

linting failure is unrelated to this PR

@chtruong814 chtruong814 added the needs-follow-up Issue needs follow-up label Jan 30, 2026
Signed-off-by: KunalDhawan <KunalDhawan@users.noreply.github.com>
@KunalDhawan
Copy link
Copy Markdown
Collaborator

Thanks for opening this PR, @MahmoudAshraf97, great catch! The changes look good to me. I’ve scheduled CI tests to make sure the updates don’t break any existing pipelines, and I’m also running internal WER evaluations to assess the impact on accuracy and performance.

It would be great if you could also share any benchmarks you have on WER and latency before vs. after these changes to help validate the improvements.

@chtruong814 chtruong814 removed the needs-follow-up Issue needs follow-up label Feb 7, 2026
@MahmoudAshraf97
Copy link
Copy Markdown
Contributor Author

Hi @KunalDhawan , the impact of this PR is felt the most with ctc models or rnnt models with long files, where the wrong cache effect starts to accumulate, the symptoms are increase in deletion errors and missing chunks from the transcript

I also suggest adding tests to verify that the encoder output with cache is identical to the encoder output with the actual audio passed as a context

@chtruong814 chtruong814 added the needs-follow-up Issue needs follow-up label Feb 9, 2026
@MahmoudAshraf97
Copy link
Copy Markdown
Contributor Author

@nithinraok @KunalDhawan this is a gentle reminder as this PR is blocking follow up work related to caching

@chtruong814 chtruong814 added needs-follow-up Issue needs follow-up and removed needs-follow-up Issue needs follow-up labels Feb 20, 2026
@KunalDhawan
Copy link
Copy Markdown
Collaborator

Thanks for the PR @MahmoudAshraf97! I ran an evaluation using the nemotron-speech-streaming-en-0.6b model on hour-long audio files from the Earnings22 dataset using speech_to_text_cache_aware_streaming_infer.py with att_context_size=[70,13] and both batch_size=1 and batch_size=4 to verify behavior under batched inference. In all cases, the predictions from main and this branch were byte-for-byte identical.

My understanding is that this is expected for this model, since it uses att_context_style: chunked_limited, which sets cache_drop_size = 0. In that case, the old and new cache update logic should be functionally equivalent:

  1. MHA: query.shape[1] - 0 == valid_query_length - 0, since the entire chunk is valid in chunked_limited (i.e., no frames are discarded by streaming_post_process, so valid_query_length == query.shape[1]).
  2. CausalConv1D: similarly, when cache_drop_size = 0 and _right_padding = 0 (causal convolution), the cache slicing behavior remains unchanged.

The fix would likely have an observable impact with att_context_style="regular", where cache_drop_size = lookahead_steps > 0 and query.shape[1] > valid_out_len due to the extra lookahead frames that are discarded in streaming_post_process. However, cache-aware streaming models like Nemotron Speech are trained with att_context_style: chunked_limited, which is the most relevant deployment configuration for these models.

Could you share which model and att_context_style configuration you used for the internal evaluation that showed improved WER, along with the magnitude of the improvement and any observed latency impact? Please let me know if I’m overlooking anything in my setup or analysis.

@chtruong814 chtruong814 removed the needs-follow-up Issue needs follow-up label Mar 3, 2026
@MahmoudAshraf97
Copy link
Copy Markdown
Contributor Author

Thank you for your reply, I verified the scripts we use for inference against the simulator script you shared, and I found that we are passing 160ms as pre-encoder context instead of 90ms, that alone caused the WER to rise from 10% on an internal dataset to 90%.
I tried running the nemotron model under the same conditions and managed to reproduce the error, albeit to a much lesser extent because it has a larger left context (70) while our model uses 20 so the effect is more noticeable

for future reference this is how to stream an actual audio file without using the CacheAwareStreamingAudioBuffer as it assumes that you have the whole audio file, and it's not usable if you receive audio in realtime

import nemo.collections.asr as nemo_asr
import torch

model = nemo_asr.models.ASRModel.from_pretrained(
    "nvidia/nemotron-speech-streaming-en-0.6b"
)
model = model.eval()

with open("test_clip.wav", "rb") as f:
    f.seek(40)
    audio_signal = (
        torch.frombuffer(f.read(), dtype=torch.int16).float().unsqueeze(0) / 32768
    )

batch_size = 1

attention_cache, conv_cache, attention_cache_len = (
    model.encoder.get_initial_cache_state(batch_size=batch_size)
)
outputs = []
chunk_size = int(
    (model.encoder.att_context_size[1] + 1)
    * model.cfg.preprocessor.window_stride
    * model.cfg.encoder.subsampling_factor
    * model.cfg.preprocessor.sample_rate
)
chunk_overlap = int(
    1
    * model.cfg.preprocessor.window_stride
    * (model.cfg.encoder.subsampling_factor + 1)
    * model.cfg.preprocessor.sample_rate
)
padded_audio = torch.nn.functional.pad(
    audio_signal, (chunk_overlap, chunk_size - audio_signal.shape[1] % chunk_size)
)
with torch.inference_mode():
    for i in range(chunk_overlap, padded_audio.shape[1], chunk_size):
        chunk = padded_audio[:, i - chunk_overlap : i + chunk_size].to(model.device)
        features, features_len = model.preprocessor(
            input_signal=chunk, length=torch.tensor([chunk.shape[1]]).to(model.device)
        )
        features = features[:, :, :-1]

        (
            encoded,
            encoded_len,
            attention_cache,
            conv_cache,
            attention_cache_len,
        ) = model.encoder.cache_aware_stream_step(
            processed_signal=features,
            processed_signal_length=features_len,
            cache_last_channel=attention_cache,
            cache_last_time=conv_cache,
            cache_last_channel_len=attention_cache_len,
            keep_all_outputs=False,
            drop_extra_pre_encoded=2,
            bypass_pre_encode=False,
        )

        outputs.append(encoded)
    projected_encoder_output = torch.cat(outputs, dim=-1)

model.decoding.rnnt_decoder_predictions_tensor(
    encoder_output=projected_encoder_output,
    encoded_lengths=torch.tensor([projected_encoder_output.shape[2]]).to(model.device),
    return_hypotheses=True,
    partial_hypotheses=None,
)[0].text

@chtruong814 chtruong814 added the needs-follow-up Issue needs follow-up label Mar 7, 2026
@chtruong814 chtruong814 removed the needs-follow-up Issue needs follow-up label Apr 6, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants