-
Notifications
You must be signed in to change notification settings - Fork 690
support repetition ngram logits processor #4288
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d942dfd
0f0b8df
d981a3c
9ea903e
a84e80a
7b51a05
7c5ac3d
3f1225f
d422d8b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -95,6 +95,8 @@ class GenerationConfig: | |
| } | ||
|
|
||
| logits_processors: Custom logit processors. | ||
| repetition_ngram_size: The size of n-grams to consider for repetition early stop. | ||
| repetition_ngram_threshold: The number of times an n-gram must be repeated to trigger early stop. | ||
| """ | ||
|
|
||
| n: int = 1 | ||
|
|
@@ -129,6 +131,10 @@ class GenerationConfig: | |
| # router replay | ||
| return_routed_experts: bool = False | ||
|
|
||
| # ngram, generation would stop if latest [size] tokens are repeated for [threshold] times | ||
| repetition_ngram_size: int = 0 | ||
| repetition_ngram_threshold: int = 0 | ||
|
grimoire marked this conversation as resolved.
|
||
|
|
||
|
grimoire marked this conversation as resolved.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I suggest keeping
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Based on this PR, my understanding is that for any sequence configured with n-gram constraints, the engine checks for duplicates within the sliding window of the last
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Hard to do this in batch. |
||
| def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer): | ||
| """Convert stop_words/bad_words to ids and append the ids to | ||
| stop_token_ids/bad_token_ids.""" | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,11 +1,14 @@ | ||
| # Copyright (c) OpenMMLab. All rights reserved. | ||
| import asyncio | ||
| from dataclasses import dataclass, fields | ||
| from typing import Any, Dict, List, Optional, Tuple | ||
| from functools import lru_cache | ||
| from typing import Any | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
| from lmdeploy.messages import LogitsProcessor | ||
| from lmdeploy.pytorch import envs | ||
|
|
||
| from ..messages import SchedulerSequence | ||
| from .guided_process import GuidedDecodingManager | ||
|
|
@@ -53,7 +56,7 @@ def _process_bad_words_(scores: torch.Tensor, | |
| return scores | ||
|
|
||
|
|
||
| def _process_repetition_penalty_(scores: torch.Tensor, input_ids: torch.LongTensor, penalty: torch.Tensor): | ||
| def _process_repetition_penalty_(scores: torch.Tensor, input_ids: torch.Tensor, penalty: torch.Tensor): | ||
| """Process repetition penalty.""" | ||
| score = torch.gather(scores, 1, input_ids) | ||
| penalty = penalty.to(score.dtype) | ||
|
|
@@ -92,6 +95,131 @@ def _filter_minp_sorted_(scores: torch.Tensor, minp: torch.Tensor, filter_value: | |
| return scores | ||
|
|
||
|
|
||
| @lru_cache | ||
| def _ngram_one(dtype: torch.dtype, device: torch.device, fill: int = 1): | ||
| return torch.ones(fill, dtype=dtype, device=device) | ||
|
|
||
|
|
||
| def ngram( | ||
| token_ids: torch.Tensor, | ||
| n: torch.Tensor | None, | ||
| threshold: torch.Tensor, | ||
| max_n: int, | ||
| max_window_size: int, | ||
| ): | ||
| """Compute n-gram matches between sliding windows and a target sequence. | ||
|
|
||
| For each batch, performs cosine similarity checking between: | ||
| - All sliding windows of length `max_n` from the full sequence | ||
| - The last `max_n` tokens of the sequence (target window) | ||
|
|
||
| A match is counted when both: | ||
| 1. Cosine similarity ≈ 1 (normalized vectors match) | ||
| 2. Vector lengths match (preventing zero/normalization artifacts) | ||
|
|
||
| Parameters | ||
| ---------- | ||
| token_ids : torch.Tensor | ||
| Input token IDs of shape (batch_size, seq_len). | ||
| Values are typically ≥0 (0 may represent padding/special tokens). | ||
| n : torch.Tensor | ||
| Effective n-gram length for each batch element, shape (batch_size,). | ||
| When `same_n=False`, positions beyond `n` in the last `max_n` tokens are masked. | ||
| threshold : torch.Tensor | ||
| Minimum number of matching windows required for validity, shape (batch_size,). | ||
| max_n : int | ||
| Maximum n-gram length (window size for matching). | ||
| max_window_size: int | ||
| Maximum window size for matching. | ||
|
|
||
| Returns | ||
| ------- | ||
| matched_mask : torch.Tensor | ||
| Boolean mask of shape (batch_size, seq_len - max_n + 1) indicating | ||
| which sliding windows match the target n-gram. | ||
| found : torch.Tensor | ||
| Boolean tensor of shape (batch_size,) indicating whether each batch | ||
| element has at least `threshold` matches. | ||
| """ | ||
|
|
||
| batch_size, seq_len = token_ids.size() | ||
| if seq_len < max_n: | ||
| # Not enough tokens to form a single n-gram | ||
| matched_mask = torch.zeros((batch_size, 0), dtype=torch.bool, device=token_ids.device) | ||
| found = torch.zeros((batch_size, ), dtype=torch.bool, device=token_ids.device) | ||
| return matched_mask, found | ||
| # token_ids could be 0, so we add 2 to avoid div 0 | ||
| token_ids = (token_ids + 2).to(torch.float32).log2() | ||
|
|
||
| # Trim to max_window_size | ||
| if seq_len >= max_window_size: | ||
| token_ids = token_ids[:, -max_window_size:] | ||
| max_window_size = token_ids.size(1) | ||
|
|
||
| # normalize ids | ||
| # we set n=None when all sequences share the same n value. Read lmdeploy/pytorch/strategies/ar/sampling.py for more details | ||
| same_n = n is None | ||
|
lvhan028 marked this conversation as resolved.
|
||
| norm = token_ids[:, -max_n:] | ||
| if not same_n: | ||
| # fill 0 for n < max_n | ||
| mask = torch.arange(max_n, device=token_ids.device).unsqueeze(0) >= (max_n - n.unsqueeze(1)) | ||
| norm = norm * mask.to(torch.float32) | ||
| norm = norm.norm(2, dim=-1, keepdim=True) | ||
| normed_ids = token_ids / norm | ||
|
|
||
| # concatenate p1 and p2 so we can check distance and vector length in one conv1d | ||
| normed_n_ids = normed_ids[:, -max_n:] | ||
| normed_ids_p2 = normed_ids * normed_ids | ||
| ones_ids = torch.ones_like(normed_n_ids) | ||
| if not same_n: | ||
| # fill 0 for n < max_n | ||
| normed_n_ids = normed_n_ids * mask.to(torch.float32) | ||
| ones_ids = ones_ids * mask.to(torch.float32) | ||
| normed_ids = torch.cat([normed_ids, normed_ids_p2], dim=0) | ||
| normed_n_ids = torch.cat([normed_n_ids, ones_ids], dim=0) | ||
|
|
||
| # check cos distance & check vector length | ||
| match_norm = torch.conv1d(normed_ids.unsqueeze(0), normed_n_ids.unsqueeze(1), groups=batch_size * 2)[0] | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. brilliant! |
||
| match_norm, match_ones = match_norm.chunk(2, dim=0) | ||
|
|
||
| # both match result should be close to 1 | ||
| one_tensor = _ngram_one(dtype=match_norm.dtype, device=match_norm.device, fill=1) | ||
| matched_mask = match_norm.isclose(one_tensor) & match_ones.isclose(one_tensor) | ||
|
|
||
| # threshold | ||
| count = matched_mask.sum(-1) | ||
| found = (count >= threshold) & (threshold > 0) | ||
|
|
||
| return matched_mask, found | ||
|
|
||
|
|
||
| def _filter_repetition_ngram_( | ||
| scores: torch.Tensor, | ||
| stop_words: torch.Tensor, | ||
| generated_ids: torch.Tensor, | ||
| n: torch.Tensor | None, | ||
| threshold: torch.Tensor, | ||
| max_n: int, | ||
| max_ngram_window_size: int, | ||
| ): | ||
| """Filter ngram. | ||
|
|
||
| if generated ngram found, set all scores -inf, and set stop words to 0. We assume that stop words always exist. | ||
| """ | ||
| if stop_words is None or stop_words.numel() == 0: | ||
| return scores | ||
| # use first stop words | ||
| _, found = ngram(generated_ids, n, threshold, max_n, max_ngram_window_size) | ||
| stop_words = stop_words[:, 0] | ||
| # fill all scores -inf | ||
| scores.masked_fill_(found[:, None], -float('inf')) | ||
| # set stop words to 0 | ||
| stop_scores = scores.gather(1, stop_words[:, None]) | ||
| stop_scores.masked_fill_(found[:, None], 0) | ||
| scores.scatter_(1, stop_words[:, None], stop_scores) | ||
| return scores | ||
|
|
||
|
|
||
| def _multinomial_sampling(scores: torch.Tensor, | ||
| seeds: torch.LongTensor, | ||
| offsets: torch.LongTensor, | ||
|
|
@@ -101,14 +229,14 @@ def _multinomial_sampling(scores: torch.Tensor, | |
| return multinomial_sampling(scores, seeds, offsets, indices) | ||
|
|
||
|
|
||
| SeqList = List[SchedulerSequence] | ||
| SeqList = list[SchedulerSequence] | ||
|
|
||
|
|
||
| @dataclass | ||
| class SamplingInputsDelta: | ||
| num_ignore_eos: torch.Tensor = None | ||
| random_offsets: torch.Tensor = None | ||
| all_ids: Optional[torch.Tensor] = None | ||
| all_ids: None | torch.Tensor = None | ||
|
|
||
|
|
||
| @dataclass | ||
|
|
@@ -126,18 +254,28 @@ class SamplingInputs: | |
| random_offsets: torch.Tensor = None | ||
| max_top_k: int = 1 | ||
| min_top_p: float = 1.0 | ||
| response_formats: Tuple[str] = () | ||
| logits_processors: List[List[LogitsProcessor]] = None | ||
| max_num_logprobs: Optional[int] = None | ||
| all_ids: Optional[torch.Tensor] = None | ||
| response_formats: list[str, ...] = () | ||
| logits_processors: list[list[LogitsProcessor]] = None | ||
| max_num_logprobs: None | int = None | ||
| all_ids: None | torch.Tensor = None | ||
| num_ignore_eos: torch.Tensor = None | ||
| batch_size: int = 0 | ||
| session_ctx: Optional[List[Dict[str, Any]]] = None | ||
| session_to_cleanup: Optional[List[int]] = None | ||
| session_ctx: None | list[dict[str, Any]] = None | ||
| session_to_cleanup: None | list[int] = None | ||
| # for repetition_penalty and ngram | ||
| generated_ids: torch.Tensor | None = None | ||
| generated_ids_cpu: np.ndarray | None = None | ||
|
|
||
| # n gram | ||
| repetition_ngram_size: torch.Tensor | None = None | ||
| repetition_ngram_threshold: torch.Tensor | None = None | ||
| max_repetition_ngram_size: int = 0 | ||
|
|
||
| def to_device(self, device: str, non_blocking: bool = False): | ||
| """To device.""" | ||
| out_dict = dict() | ||
| if self.generated_ids is None and self.generated_ids_cpu is not None: | ||
| self.generated_ids = torch.from_numpy(self.generated_ids_cpu.copy()) | ||
| for f in fields(self): | ||
| k = f.name | ||
| v = getattr(self, k) | ||
|
|
@@ -192,8 +330,8 @@ class FusedLogitsProcessor: | |
| def __init__( | ||
| self, | ||
| sampling_inputs: SamplingInputs, | ||
| logprobs_mode: Optional[str] = None, | ||
| guided_decoding_manager: Optional[GuidedDecodingManager] = None, | ||
| logprobs_mode: None | str = None, | ||
| guided_decoding_manager: None | GuidedDecodingManager = None, | ||
| ): | ||
| self.sampling_inputs: SamplingInputs = sampling_inputs | ||
| self.logprobs_mode = logprobs_mode | ||
|
|
@@ -213,18 +351,18 @@ async def _wait_stream_once(self): | |
| if not stream.query(): | ||
| await asyncio.sleep(0) | ||
|
|
||
| async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor: | ||
| async def __call__(self, scores: torch.Tensor) -> torch.Tensor: | ||
| r""" | ||
| Args: | ||
| scores (torch.FloatTensor): | ||
| scores (torch.Tensor): | ||
| Prediction scores of a language modeling head. | ||
| These can be logits for each vocabulary when not using | ||
| beam search or log softmax for each vocabulary token | ||
| when using beam search | ||
|
|
||
|
|
||
| Return: | ||
| torch.FloatTensor: The processed prediction scores. | ||
| torch.Tensor: The processed prediction scores. | ||
|
|
||
| """ | ||
|
|
||
|
|
@@ -262,7 +400,23 @@ async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor: | |
|
|
||
| repetition_penalty = sampling_inputs.repetition_penalty | ||
| if repetition_penalty is not None: | ||
| scores = _process_repetition_penalty_(scores, all_ids, repetition_penalty) | ||
| generated_ids = sampling_inputs.generated_ids | ||
| scores = _process_repetition_penalty_(scores, generated_ids, repetition_penalty) | ||
|
grimoire marked this conversation as resolved.
|
||
|
|
||
| if sampling_inputs.max_repetition_ngram_size > 0: | ||
| generated_ids = sampling_inputs.generated_ids | ||
| assert generated_ids is not None | ||
| assert sampling_inputs.repetition_ngram_threshold is not None | ||
| max_repetition_ngram_window_size = envs.repetition_window_size | ||
| scores = _filter_repetition_ngram_( | ||
| scores, | ||
| sampling_inputs.stop_words, | ||
| generated_ids, | ||
| sampling_inputs.repetition_ngram_size, | ||
| sampling_inputs.repetition_ngram_threshold, | ||
| sampling_inputs.max_repetition_ngram_size, | ||
| max_repetition_ngram_window_size, | ||
| ) | ||
|
|
||
| temperature = sampling_inputs.temperature | ||
| if temperature is not None: | ||
|
|
@@ -346,7 +500,7 @@ def compute_logprobs(self, raw_logprobs: torch.Tensor, token_ids: torch.LongTens | |
|
|
||
| return logprobs, indices.to(torch.int32) | ||
|
|
||
| def cleanup_sessions(self, session_ids: List[int]): | ||
| def cleanup_sessions(self, session_ids: list[int]): | ||
| if self.guided_decoding_manager: | ||
| for session_id in session_ids: | ||
| self.guided_decoding_manager.remove_processor(session_id) | ||
Uh oh!
There was an error while loading. Please reload this page.