support repetition ngram logits processor#4288
Conversation
Pull request overview
Adds support for an n-gram-based logits processor (intended to force generation of a stop token once repeated n-grams exceed a threshold), wiring new ngram_size / ngram_threshold parameters through sampling inputs and adding a unit test.
Changes:
- Add n-gram matching + `_filter_ngram_` into `FusedLogitsProcessor`.
- Plumb `ngram_size`/`ngram_threshold` through `GenerationConfig` → `SamplingParam` → `SamplingInputs`, including new generated-token history gathering.
- Add a unit test for `_filter_ngram_`.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| tests/pytorch/engine/test_logits_process.py | Adds coverage for the new n-gram filtering behavior. |
| lmdeploy/pytorch/engine/logits_process.py | Implements n-gram matching/filtering and integrates it into fused logits processing. |
| lmdeploy/pytorch/strategies/ar/sampling.py | Gathers per-request n-gram params and generated-token history for GPU-side processing. |
| lmdeploy/pytorch/strategies/dllm/sampling.py | Repeats new sampling attributes across DLLM blocks and expands generated-id history. |
| lmdeploy/pytorch/messages.py | Adds ngram_size / ngram_threshold to SamplingParam to carry runtime settings. |
| lmdeploy/messages.py | Adds ngram_size / ngram_threshold to user-facing GenerationConfig. |
```python
def test_filter_ngram():
    from lmdeploy.pytorch.engine.logits_process import _filter_ngram_
```
The PR description still appears to be the default template (Motivation/Modification/etc. are not filled in). Please add a brief motivation and summarize the intended behavior of the new n-gram logits processor (including how `ngram_size`/`ngram_threshold` interact with `stop_words`) to make review and downstream usage clearer.
```python
):
    """Filter ngram."""
    if stop_words is None or stop_words.numel() == 0:
        return scores
    # use first stop words
    _, found = ngram(generated_ids, n, threshold, max_n, same_n)
    stop_words = stop_words[:, 0]
    # fill all scores -inf
    scores.masked_fill_(found[:, None], -float('inf'))
    # set stop words to 0
    stop_scores = scores.gather(1, stop_words[:, None])
    stop_scores.masked_fill_(found[:, None], 0)
    scores.scatter_(1, stop_words[:, None], stop_scores)
```
`_filter_ngram_` always uses `stop_words[:, 0]` without considering `stop_mask`. When some sequences have no stop words, `__get_bad_words` pads their row with 0s and a false mask; this function will then treat token 0 as a stop word and can force generation of token 0 whenever `found` is true. Pass and use `stop_mask` (select the first valid stop word per batch element, and skip batches with none), or precompute a per-sequence stop token id when enabling n-gram stopping.
Suggested change:

```python
    stop_mask: torch.Tensor | None = None,
):
    """Filter ngram."""
    if stop_words is None or stop_words.numel() == 0:
        return scores
    # determine which sequences have ngram matches
    _, found = ngram(generated_ids, n, threshold, max_n, same_n)
    batch_size = scores.size(0)
    if stop_mask is not None:
        # has_stop indicates which batch elements have at least one valid stop word
        if stop_mask.numel() == 0:
            return scores
        has_stop = stop_mask.any(dim=1)
        if not has_stop.any():
            # no sequences have valid stop words, nothing to do
            return scores
        # compute index of first valid stop word per batch element
        num_stops = stop_mask.size(1)
        idxs = torch.arange(num_stops, device=stop_mask.device, dtype=torch.long).unsqueeze(0).expand_as(stop_mask)
        idxs = idxs.masked_fill(~stop_mask, num_stops)
        first_idxs = idxs.argmin(dim=1)
        batch_indices = torch.arange(batch_size, device=stop_words.device, dtype=torch.long)
        stop_tokens = stop_words[batch_indices, first_idxs]
        # only apply forcing where both an ngram is found and a valid stop word exists
        valid_found = found & has_stop
        if not valid_found.any():
            return scores
        scores.masked_fill_(valid_found[:, None], -float('inf'))
        stop_tokens_exp = stop_tokens[:, None]
        stop_scores = scores.gather(1, stop_tokens_exp)
        stop_scores.masked_fill_(valid_found[:, None], 0)
        scores.scatter_(1, stop_tokens_exp, stop_scores)
    else:
        # fallback: use the first stop word in each row, as originally implemented
        stop_tokens = stop_words[:, 0]
        # fill all scores -inf where an ngram is found
        scores.masked_fill_(found[:, None], -float('inf'))
        # set stop word scores to 0
        stop_tokens_exp = stop_tokens[:, None]
        stop_scores = scores.gather(1, stop_tokens_exp)
        stop_scores.masked_fill_(found[:, None], 0)
        scores.scatter_(1, stop_tokens_exp, stop_scores)
```
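The core of the suggestion — pick each row's first *valid* stop token and only force rows that actually have one — can be illustrated framework-free. This is a hedged sketch of the masking logic, not the PR's tensor implementation; the helper name `first_valid_stop` is illustrative.

```python
def first_valid_stop(stop_words, stop_mask):
    """Per row, return (token_id, has_stop): the first stop word whose
    mask entry is True, or (None, False) when the row has none.

    Rows padded by __get_bad_words carry token 0 with a False mask;
    without checking the mask they would wrongly be forced to token 0.
    """
    out = []
    for words, mask in zip(stop_words, stop_mask):
        idx = next((i for i, m in enumerate(mask) if m), None)
        out.append((words[idx], True) if idx is not None else (None, False))
    return out

# Row 0 has a real stop word (42); row 1 is all padding and must be skipped.
pairs = first_valid_stop([[42, 7], [0, 0]], [[True, True], [False, False]])
```

With this selection in hand, the -inf fill and the zeroing of the stop-token score are applied only where both an n-gram match and a valid stop word exist.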
```python
repetition_ngram_size: int = 0
repetition_ngram_threshold: int = 0
repetition_ngram_window_size: int = 1024
```
I suggest keeping `repetition_ngram_window_size` as an internal constant rather than exposing it in the public API; a fixed value of 1024 should be sufficient for most use cases.
Based on this PR, my understanding is that for any sequence configured with n-gram constraints, the engine checks for duplicates within the sliding window of the last max_ngram_window_size tokens (i.e., token_ids[-max_ngram_window_size:]). Once the repetition frequency exceeds the threshold, the engine prevents further token generation. Please correct me if I'm mistaken.
Since this behavior heavily depends on the `max_ngram_window_size` setting, I'm considering an alternative approach: perform the n-gram check only every `max_ngram_window_size` steps. Specifically, we could trigger the check when the length of `generated_token_ids` is divisible by `max_ngram_window_size` (i.e., `len % size == 0`). What do you think?
> perform the n-gram check only every max_ngram_window_size steps. Specifically, we could trigger the check when the length of generated_token_ids is divisible by max_ngram_window_size (i.e., len % size == 0).
Hard to do this in batch.
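To illustrate why: sequences in one batch have different generated lengths, so a periodic "check every `window_size` steps" condition becomes a per-row boolean mask rather than a single branch — some row is likely to trigger at almost every step, so the batched kernel still runs. A hedged, framework-free sketch (the helper name is hypothetical):

```python
def periodic_check_mask(generated_lens, window_size):
    """Per sequence, True iff its generated length is a nonzero
    multiple of window_size -- i.e. this row would trigger the check
    on this step under the proposed scheme."""
    return [n > 0 and n % window_size == 0 for n in generated_lens]

# Different rows trigger at different steps, so the check cannot be
# gated by a single batch-wide condition.
mask = periodic_check_mask([1024, 1000, 2048, 512], 1024)
```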
```python
normed_n_ids = torch.cat([normed_n_ids, ones_ids], dim=0)

# check cos distance & check vector length
match_norm = torch.conv1d(normed_ids.unsqueeze(0), normed_n_ids.unsqueeze(1), groups=batch_size * 2)[0]
```
Repetition stopping is implemented as a logits processor here. If we turned it into an engine-level feature, the implementation would be much easier and the performance would be better.