
support repetition ngram logits processor #4288

Merged
lvhan028 merged 9 commits into InternLM:main from grimoire:builtin-ngram
Mar 11, 2026

Conversation


@grimoire grimoire commented Jan 23, 2026

Repetition stopping is implemented as a logits processor. If we turned it into an engine-level feature, the implementation would be much easier and the performance would be better.

Copilot AI review requested due to automatic review settings January 23, 2026 10:20

Copilot AI left a comment


Pull request overview

Adds support for an n-gram-based logits processor (intended to force generation of a stop token once repeated n-grams exceed a threshold), wiring new ngram_size / ngram_threshold parameters through sampling inputs and adding a unit test.

Changes:

  • Add n-gram matching + _filter_ngram_ into FusedLogitsProcessor.
  • Plumb ngram_size / ngram_threshold through GenerationConfig → SamplingParam → SamplingInputs, including new generated-token-history gathering.
  • Add a unit test for _filter_ngram_.
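The intended stopping rule can be sketched in plain Python (a hedged, single-sequence illustration; the helper names `count_trailing_ngram_repeats` and `should_force_stop` are hypothetical, and the PR's actual `_filter_ngram_` operates on batched GPU tensors):

```python
def count_trailing_ngram_repeats(token_ids, n):
    """Count how often the trailing n tokens appear as a contiguous
    n-gram earlier in the sequence (hypothetical helper, not the
    PR's implementation)."""
    if n <= 0 or len(token_ids) < n:
        return 0
    tail = tuple(token_ids[-n:])
    return sum(1 for i in range(len(token_ids) - n)
               if tuple(token_ids[i:i + n]) == tail)


def should_force_stop(token_ids, ngram_size, ngram_threshold):
    """Decide whether to force the stop token: true once the trailing
    n-gram has already occurred at least `ngram_threshold` times."""
    if ngram_size <= 0:
        return False
    return count_trailing_ngram_repeats(token_ids, ngram_size) >= ngram_threshold
```

When this returns true, the processor masks every logit to -inf except the stop token, so the sampler can only emit the stop token on the next step.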

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 7 comments.

File Description
tests/pytorch/engine/test_logits_process.py Adds coverage for the new n-gram filtering behavior.
lmdeploy/pytorch/engine/logits_process.py Implements n-gram matching/filtering and integrates it into fused logits processing.
lmdeploy/pytorch/strategies/ar/sampling.py Gathers per-request n-gram params and generated-token history for GPU-side processing.
lmdeploy/pytorch/strategies/dllm/sampling.py Repeats new sampling attributes across DLLM blocks and expands generated-id history.
lmdeploy/pytorch/messages.py Adds ngram_size / ngram_threshold to SamplingParam to carry runtime settings.
lmdeploy/messages.py Adds ngram_size / ngram_threshold to user-facing GenerationConfig.


Comment on lines +129 to +131
def test_filter_ngram():
from lmdeploy.pytorch.engine.logits_process import _filter_ngram_


Copilot AI Jan 23, 2026


PR description appears to still be the default template (Motivation/Modification/etc. are not filled in). Please add a brief motivation and summarize the intended behavior of the new n-gram logits processor (including how ngram_size/ngram_threshold interact with stop_words) to make review and downstream usage clearer.

Comment on lines +167 to +179
):
    """Filter ngram."""
    if stop_words is None or stop_words.numel() == 0:
        return scores
    # use first stop words
    _, found = ngram(generated_ids, n, threshold, max_n, same_n)
    stop_words = stop_words[:, 0]
    # fill all scores with -inf
    scores.masked_fill_(found[:, None], -float('inf'))
    # set stop words to 0
    stop_scores = scores.gather(1, stop_words[:, None])
    stop_scores.masked_fill_(found[:, None], 0)
    scores.scatter_(1, stop_words[:, None], stop_scores)

Copilot AI Jan 23, 2026


_filter_ngram_ always uses stop_words[:, 0] without considering stop_mask. When some sequences have no stop words, __get_bad_words pads their row with 0s and a false mask; this function will still treat token 0 as a stop word and can force generation of token 0 when found is true. Pass/use stop_mask (select the first valid stop word per batch, and skip batches with none) or precompute a per-sequence stop token id when enabling n-gram stopping.

Suggested change (original, then proposed replacement):

):
    """Filter ngram."""
    if stop_words is None or stop_words.numel() == 0:
        return scores
    # use first stop words
    _, found = ngram(generated_ids, n, threshold, max_n, same_n)
    stop_words = stop_words[:, 0]
    # fill all scores with -inf
    scores.masked_fill_(found[:, None], -float('inf'))
    # set stop words to 0
    stop_scores = scores.gather(1, stop_words[:, None])
    stop_scores.masked_fill_(found[:, None], 0)
    scores.scatter_(1, stop_words[:, None], stop_scores)

Proposed:

    stop_mask: torch.Tensor | None = None,
):
    """Filter ngram."""
    if stop_words is None or stop_words.numel() == 0:
        return scores
    # determine which sequences have ngram matches
    _, found = ngram(generated_ids, n, threshold, max_n, same_n)
    batch_size = scores.size(0)
    if stop_mask is not None:
        # has_stop indicates which batch elements have at least one valid stop word
        if stop_mask.numel() == 0:
            return scores
        has_stop = stop_mask.any(dim=1)
        if not has_stop.any():
            # no sequences have valid stop words, nothing to do
            return scores
        # compute index of first valid stop word per batch element
        num_stops = stop_mask.size(1)
        idxs = torch.arange(num_stops, device=stop_mask.device, dtype=torch.long).unsqueeze(0).expand_as(stop_mask)
        idxs = idxs.masked_fill(~stop_mask, num_stops)
        first_idxs = idxs.argmin(dim=1)
        batch_indices = torch.arange(batch_size, device=stop_words.device, dtype=torch.long)
        stop_tokens = stop_words[batch_indices, first_idxs]
        # only apply forcing where both an ngram is found and a valid stop word exists
        valid_found = found & has_stop
        if not valid_found.any():
            return scores
        scores.masked_fill_(valid_found[:, None], -float('inf'))
        stop_tokens_exp = stop_tokens[:, None]
        stop_scores = scores.gather(1, stop_tokens_exp)
        stop_scores.masked_fill_(valid_found[:, None], 0)
        scores.scatter_(1, stop_tokens_exp, stop_scores)
    else:
        # fallback: use the first stop word in each row, as originally implemented
        stop_tokens = stop_words[:, 0]
        # fill all scores with -inf where an ngram is found
        scores.masked_fill_(found[:, None], -float('inf'))
        # set stop word scores to 0
        stop_tokens_exp = stop_tokens[:, None]
        stop_scores = scores.gather(1, stop_tokens_exp)
        stop_scores.masked_fill_(found[:, None], 0)
        scores.scatter_(1, stop_tokens_exp, stop_scores)

@grimoire grimoire changed the title support ngram logits processor support repetition ngram logits processor Jan 26, 2026
@lvhan028 lvhan028 requested a review from RunningLeon February 11, 2026 11:08
@lvhan028 lvhan028 added the enhancement New feature or request label Feb 11, 2026
@lvhan028 lvhan028 mentioned this pull request Feb 28, 2026
repetition_ngram_size: int = 0
repetition_ngram_threshold: int = 0
repetition_ngram_window_size: int = 1024

Collaborator


I suggest keeping repetition_ngram_window_size as an internal constant rather than exposing it in the public API. A fixed value of 1024 should be sufficient for most use cases.

Collaborator


Based on this PR, my understanding is that for any sequence configured with n-gram constraints, the engine checks for duplicates within the sliding window of the last max_ngram_window_size tokens (i.e., token_ids[-max_ngram_window_size:]). Once the repetition frequency exceeds the threshold, the engine prevents further token generation. Please correct me if I'm mistaken.
Since this behavior heavily depends on the max_ngram_window_size setting, I'm considering an alternative approach: perform the n-gram check only every max_ngram_window_size steps. Specifically, we could trigger the check when the length of generated_token_ids is divisible by max_ngram_window_size (i.e., len % size == 0). What do you think?

Collaborator Author


perform the n-gram check only every max_ngram_window_size steps. Specifically, we could trigger the check when the length of generated_token_ids is divisible by max_ngram_window_size (i.e., len % size == 0).

Hard to do this in batch.
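For intuition, the sliding-window semantics under discussion can be sketched per sequence (a hedged illustration; `repeats_in_window` is a hypothetical helper, and the real implementation batches this check across sequences on the GPU, which is why a per-sequence `len % size == 0` trigger is awkward):

```python
def repeats_in_window(token_ids, n, window_size=1024):
    """Count repeats of the trailing n-gram, restricted to the last
    `window_size` tokens. Older repetitions fall out of scope as
    generation proceeds, so the window size directly bounds what
    can trigger the stop."""
    window = token_ids[-window_size:]
    if len(window) < n:
        return 0
    tail = tuple(window[-n:])
    return sum(1 for i in range(len(window) - n)
               if tuple(window[i:i + n]) == tail)
```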

normed_n_ids = torch.cat([normed_n_ids, ones_ids], dim=0)

# check cos distance & check vector length
match_norm = torch.conv1d(normed_ids.unsqueeze(0), normed_n_ids.unsqueeze(1), groups=batch_size * 2)[0]
Collaborator


brilliant!
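The conv1d trick quoted above can be illustrated for a single sequence with plain NumPy correlation (a hedged sketch, not the PR's batched implementation: a length-n window w equals the tail n-gram t exactly when dot(w, t) equals ||t||² and ||w||² equals ||t||², since then ||w − t||² = 0; the batched version generalizes this with grouped torch.conv1d over normalized ids):

```python
import numpy as np

def ngram_match_positions(ids, n):
    """Return positions where the trailing n-gram recurs earlier in
    the sequence, found via correlation rather than explicit loops.
    (Hypothetical single-sequence analogue of the batched conv1d
    "cos distance & vector length" check in the PR.)"""
    ids = np.asarray(ids, dtype=np.float64)
    tail = ids[-n:]
    tail_sq = float(tail @ tail)
    # dot product of every length-n window with the tail
    dots = np.correlate(ids, tail, mode='valid')
    # squared norm of every length-n window
    win_sq = np.correlate(ids * ids, np.ones(n), mode='valid')
    # a window matches iff both equalities hold (exact for integer ids)
    match = (dots == tail_sq) & (win_sq == tail_sq)
    # drop the final window, which is the tail matching itself
    return np.flatnonzero(match[:-1]).tolist()
```

The payoff of phrasing the match as a correlation is that the whole scan becomes one vectorized kernel call instead of a Python loop, which is what makes the batched GPU version cheap.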


@RunningLeon RunningLeon left a comment


LGTM

@lvhan028 lvhan028 merged commit 17ed9e5 into InternLM:main Mar 11, 2026
5 checks passed

Labels

enhancement New feature or request


4 participants