Merged
6 changes: 6 additions & 0 deletions lmdeploy/messages.py
@@ -95,6 +95,8 @@ class GenerationConfig:
}

logits_processors: Custom logit processors.
repetition_ngram_size: The size of n-grams to consider for repetition early stop.
repetition_ngram_threshold: The number of times an n-gram must be repeated to trigger early stop.
"""

n: int = 1
@@ -129,6 +131,10 @@ class GenerationConfig:
# router replay
return_routed_experts: bool = False

# ngram: generation stops early if the latest [size]-token n-gram has repeated [threshold] times
repetition_ngram_size: int = 0
repetition_ngram_threshold: int = 0
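
For context, a minimal usage sketch of the new knobs from the public API. The model path and numbers are illustrative placeholders, and the sketch assumes the standard `pipeline` entry point with the PyTorch engine (which is where this PR wires the check in):

```python
from lmdeploy import GenerationConfig, PytorchEngineConfig, pipeline

# Placeholder model; the new knobs are honored by the PyTorch engine.
pipe = pipeline('internlm/internlm2_5-7b-chat',
                backend_config=PytorchEngineConfig())
gen_config = GenerationConfig(
    max_new_tokens=1024,
    repetition_ngram_size=32,      # size of the trailing n-gram to watch
    repetition_ngram_threshold=8,  # repeats needed to trigger early stop
)
print(pipe(['Hello'], gen_config=gen_config))
```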

Collaborator:
I suggest keeping repetition_ngram_window_size as an internal constant rather than exposing it in the public API. A fixed value of 1024 should be sufficient for most use cases.

Collaborator:

Based on this PR, my understanding is that for any sequence configured with n-gram constraints, the engine checks for duplicates within the sliding window of the last max_ngram_window_size tokens (i.e., token_ids[-max_ngram_window_size:]). Once the repetition frequency exceeds the threshold, the engine prevents further token generation. Please correct me if I'm mistaken.
Since this behavior heavily depends on the max_ngram_window_size setting, I'm considering an alternative approach: perform the n-gram check only every max_ngram_window_size steps. Specifically, we could trigger the check when the length of generated_token_ids is divisible by max_ngram_window_size (i.e., len % size == 0). What do you think?

Collaborator (Author):

> perform the n-gram check only every max_ngram_window_size steps. Specifically, we could trigger the check when the length of generated_token_ids is divisible by max_ngram_window_size (i.e., len % size == 0).

Hard to do this in batch.

def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer):
"""Convert stop_words/bad_sords to ids and append the ids to
stop_token_ids/bad_token_ids."""
188 changes: 171 additions & 17 deletions lmdeploy/pytorch/engine/logits_process.py
@@ -1,11 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional, Tuple
from functools import lru_cache
from typing import Any

import numpy as np
import torch

from lmdeploy.messages import LogitsProcessor
from lmdeploy.pytorch import envs

from ..messages import SchedulerSequence
from .guided_process import GuidedDecodingManager
@@ -53,7 +56,7 @@ def _process_bad_words_(scores: torch.Tensor,
return scores


def _process_repetition_penalty_(scores: torch.Tensor, input_ids: torch.LongTensor, penalty: torch.Tensor):
def _process_repetition_penalty_(scores: torch.Tensor, input_ids: torch.Tensor, penalty: torch.Tensor):
"""Process repetition penalty."""
score = torch.gather(scores, 1, input_ids)
penalty = penalty.to(score.dtype)
@@ -92,6 +95,131 @@ def _filter_minp_sorted_(scores: torch.Tensor, minp: torch.Tensor, filter_value:
return scores


@lru_cache
def _ngram_one(dtype: torch.dtype, device: torch.device, fill: int = 1):
return torch.ones(fill, dtype=dtype, device=device)


def ngram(
token_ids: torch.Tensor,
n: torch.Tensor | None,
threshold: torch.Tensor,
max_n: int,
max_window_size: int,
):
"""Compute n-gram matches between sliding windows and a target sequence.

For each batch element, performs a cosine-similarity check between:
- All sliding windows of length `max_n` from the full sequence
- The last `max_n` tokens of the sequence (target window)

A match is counted when both:
1. Cosine similarity ≈ 1 (normalized vectors match)
2. Vector lengths match (preventing zero/normalization artifacts)

Parameters
----------
token_ids : torch.Tensor
Input token IDs of shape (batch_size, seq_len).
Values are typically ≥0 (0 may represent padding/special tokens).
n : torch.Tensor or None
Effective n-gram length for each batch element, shape (batch_size,).
None when all batch elements share the same length; otherwise positions beyond `n` in the last `max_n` tokens are masked.
threshold : torch.Tensor
Minimum number of matching windows required for validity, shape (batch_size,).
max_n : int
Maximum n-gram length (the sliding-window length used for matching).
max_window_size : int
Maximum number of trailing tokens to search for repeats.

Returns
-------
matched_mask : torch.Tensor
Boolean mask of shape (batch_size, seq_len - max_n + 1) indicating
which sliding windows match the target n-gram.
found : torch.Tensor
Boolean tensor of shape (batch_size,) indicating whether each batch
element has at least `threshold` matches.
"""

batch_size, seq_len = token_ids.size()
if seq_len < max_n:
# Not enough tokens to form a single n-gram
matched_mask = torch.zeros((batch_size, 0), dtype=torch.bool, device=token_ids.device)
found = torch.zeros((batch_size, ), dtype=torch.bool, device=token_ids.device)
return matched_mask, found
# token_ids can be 0; add 2 so log2(token_ids + 2) >= 1 and the norm is never 0
token_ids = (token_ids + 2).to(torch.float32).log2()

# Trim to max_window_size
if seq_len >= max_window_size:
token_ids = token_ids[:, -max_window_size:]
max_window_size = token_ids.size(1)

# normalize ids
# n is None when every batch element shares the same value; see lmdeploy/pytorch/strategies/ar/sampling.py for details
same_n = n is None
norm = token_ids[:, -max_n:]
if not same_n:
# fill 0 for n < max_n
mask = torch.arange(max_n, device=token_ids.device).unsqueeze(0) >= (max_n - n.unsqueeze(1))
norm = norm * mask.to(torch.float32)
norm = norm.norm(2, dim=-1, keepdim=True)
normed_ids = token_ids / norm

# concatenate p1 and p2 so we can check cosine distance and vector length in one conv1d
normed_n_ids = normed_ids[:, -max_n:]
normed_ids_p2 = normed_ids * normed_ids
ones_ids = torch.ones_like(normed_n_ids)
if not same_n:
# fill 0 for n < max_n
normed_n_ids = normed_n_ids * mask.to(torch.float32)
ones_ids = ones_ids * mask.to(torch.float32)
normed_ids = torch.cat([normed_ids, normed_ids_p2], dim=0)
normed_n_ids = torch.cat([normed_n_ids, ones_ids], dim=0)

# check cos distance & check vector length
match_norm = torch.conv1d(normed_ids.unsqueeze(0), normed_n_ids.unsqueeze(1), groups=batch_size * 2)[0]
Collaborator: brilliant!
match_norm, match_ones = match_norm.chunk(2, dim=0)

# both match results should be close to 1
one_tensor = _ngram_one(dtype=match_norm.dtype, device=match_norm.device, fill=1)
matched_mask = match_norm.isclose(one_tensor) & match_ones.isclose(one_tensor)

# threshold
count = matched_mask.sum(-1)
found = (count >= threshold) & (threshold > 0)

return matched_mask, found
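
A quick way to see why both channels are needed: the single conv1d computes, for every sliding window, its dot product with the norm-scaled target n-gram (cosine check) and its squared length (length check), and a window counts as a match only when both equal 1. A minimal standalone sketch with toy values, mirroring the arithmetic above rather than importing the engine:

```python
import torch

# Sequence ends in [1, 2, 3], which also occurs at offsets 0 and 3.
token_ids = torch.tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3]])
max_n = 3

x = (token_ids + 2).float().log2()   # shift by 2: log2 >= 1, never 0
target = x[:, -max_n:]               # last max_n tokens
x = x / target.norm(2, dim=-1, keepdim=True)
target_n = x[:, -max_n:]

# channel 1: dot of each window with the target (cosine check)
dots = torch.conv1d(x.unsqueeze(0), target_n.unsqueeze(1))[0]
# channel 2: squared length of each window (length check)
sqs = torch.conv1d((x * x).unsqueeze(0), torch.ones_like(target_n).unsqueeze(1))[0]

one = torch.ones(1)
matched = dots.isclose(one) & sqs.isclose(one)
print(matched.int())  # tensor([[1, 0, 0, 1, 0, 0, 1]])
print(matched.sum())  # tensor(3) -> found if threshold <= 3
```

Note that a permuted window such as [2, 3, 1] has the same squared length as the target, so it passes the length check but fails the cosine check; only exact windows pass both.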


def _filter_repetition_ngram_(
scores: torch.Tensor,
stop_words: torch.Tensor,
generated_ids: torch.Tensor,
n: torch.Tensor | None,
threshold: torch.Tensor,
max_n: int,
max_ngram_window_size: int,
):
"""Filter ngram.

if generated ngram found, set all scores -inf, and set stop words to 0. We assume that stop words always exist.
"""
if stop_words is None or stop_words.numel() == 0:
return scores
# use the first stop word
_, found = ngram(generated_ids, n, threshold, max_n, max_ngram_window_size)
stop_words = stop_words[:, 0]
# fill all scores with -inf
scores.masked_fill_(found[:, None], -float('inf'))
# set the stop word's score to 0
stop_scores = scores.gather(1, stop_words[:, None])
stop_scores.masked_fill_(found[:, None], 0)
scores.scatter_(1, stop_words[:, None], stop_scores)
return scores
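
To see what the fill-and-scatter does, here is a toy run with hypothetical values: after the filter, every logit of a flagged row is -inf except its first stop word, which sits at 0, so argmax or any softmax sample must emit the stop word.

```python
import torch

scores = torch.randn(2, 5)         # batch of 2, vocab of 5
stop_words = torch.tensor([4, 4])  # first stop word per row
found = torch.tensor([True, False])  # row 0 tripped the ngram check

scores.masked_fill_(found[:, None], -float('inf'))  # row 0: all -inf
stop_scores = scores.gather(1, stop_words[:, None])
stop_scores.masked_fill_(found[:, None], 0)         # rescue the stop word
scores.scatter_(1, stop_words[:, None], stop_scores)

print(scores[0].argmax())  # tensor(4): row 0 can only emit the stop word
print(scores[1])           # row 1 is untouched
```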


def _multinomial_sampling(scores: torch.Tensor,
seeds: torch.LongTensor,
offsets: torch.LongTensor,
@@ -101,14 +229,14 @@ def _multinomial_sampling(scores: torch.Tensor,
return multinomial_sampling(scores, seeds, offsets, indices)


SeqList = List[SchedulerSequence]
SeqList = list[SchedulerSequence]


@dataclass
class SamplingInputsDelta:
num_ignore_eos: torch.Tensor = None
random_offsets: torch.Tensor = None
all_ids: Optional[torch.Tensor] = None
all_ids: None | torch.Tensor = None


@dataclass
@@ -126,18 +254,28 @@ class SamplingInputs:
random_offsets: torch.Tensor = None
max_top_k: int = 1
min_top_p: float = 1.0
response_formats: Tuple[str] = ()
logits_processors: List[List[LogitsProcessor]] = None
max_num_logprobs: Optional[int] = None
all_ids: Optional[torch.Tensor] = None
response_formats: tuple[str, ...] = ()
logits_processors: list[list[LogitsProcessor]] = None
max_num_logprobs: None | int = None
all_ids: None | torch.Tensor = None
num_ignore_eos: torch.Tensor = None
batch_size: int = 0
session_ctx: Optional[List[Dict[str, Any]]] = None
session_to_cleanup: Optional[List[int]] = None
session_ctx: None | list[dict[str, Any]] = None
session_to_cleanup: None | list[int] = None
# for repetition_penalty and ngram
generated_ids: torch.Tensor | None = None
generated_ids_cpu: np.ndarray | None = None

# n-gram
repetition_ngram_size: torch.Tensor | None = None
repetition_ngram_threshold: torch.Tensor | None = None
max_repetition_ngram_size: int = 0

def to_device(self, device: str, non_blocking: bool = False):
"""To device."""
out_dict = dict()
if self.generated_ids is None and self.generated_ids_cpu is not None:
self.generated_ids = torch.from_numpy(self.generated_ids_cpu.copy())
for f in fields(self):
k = f.name
v = getattr(self, k)
@@ -192,8 +330,8 @@ class FusedLogitsProcessor:
def __init__(
self,
sampling_inputs: SamplingInputs,
logprobs_mode: Optional[str] = None,
guided_decoding_manager: Optional[GuidedDecodingManager] = None,
logprobs_mode: None | str = None,
guided_decoding_manager: None | GuidedDecodingManager = None,
):
self.sampling_inputs: SamplingInputs = sampling_inputs
self.logprobs_mode = logprobs_mode
@@ -213,18 +351,18 @@ async def _wait_stream_once(self):
if not stream.query():
await asyncio.sleep(0)

async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:
async def __call__(self, scores: torch.Tensor) -> torch.Tensor:
r"""
Args:
scores (torch.FloatTensor):
scores (torch.Tensor):
Prediction scores of a language modeling head.
These can be logits for each vocabulary token when not
using beam search, or log softmax for each vocabulary
token when using beam search.


Return:
torch.FloatTensor: The processed prediction scores.
torch.Tensor: The processed prediction scores.

"""

@@ -262,7 +400,23 @@ async def __call__(self, scores: torch.FloatTensor) -> torch.FloatTensor:

repetition_penalty = sampling_inputs.repetition_penalty
if repetition_penalty is not None:
scores = _process_repetition_penalty_(scores, all_ids, repetition_penalty)
generated_ids = sampling_inputs.generated_ids
scores = _process_repetition_penalty_(scores, generated_ids, repetition_penalty)

if sampling_inputs.max_repetition_ngram_size > 0:
generated_ids = sampling_inputs.generated_ids
assert generated_ids is not None
assert sampling_inputs.repetition_ngram_threshold is not None
max_repetition_ngram_window_size = envs.repetition_window_size
scores = _filter_repetition_ngram_(
scores,
sampling_inputs.stop_words,
generated_ids,
sampling_inputs.repetition_ngram_size,
sampling_inputs.repetition_ngram_threshold,
sampling_inputs.max_repetition_ngram_size,
max_repetition_ngram_window_size,
)

temperature = sampling_inputs.temperature
if temperature is not None:
@@ -346,7 +500,7 @@ def compute_logprobs(self, raw_logprobs: torch.Tensor, token_ids: torch.LongTensor

return logprobs, indices.to(torch.int32)

def cleanup_sessions(self, session_ids: List[int]):
def cleanup_sessions(self, session_ids: list[int]):
if self.guided_decoding_manager:
for session_id in session_ids:
self.guided_decoding_manager.remove_processor(session_id)
3 changes: 3 additions & 0 deletions lmdeploy/pytorch/envs.py
@@ -157,6 +157,9 @@ def _patched_get_env(
# model format
scale_fmt = os.getenv('LMDEPLOY_SCALE_FMT', None)

# repetition check
repetition_window_size = env_to_int('LMDEPLOY_REPETITION_WINDOW_SIZE', 1024)

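
The window size is therefore tunable without touching the public API, per the review suggestion above. A sketch of overriding it, assuming the variable is read once at import time (as the module-level assignment implies); 2048 is an illustrative value:

```python
import os

# Must be set before lmdeploy.pytorch.envs is first imported.
os.environ['LMDEPLOY_REPETITION_WINDOW_SIZE'] = '2048'

from lmdeploy.pytorch import envs
print(envs.repetition_window_size)  # 2048
```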

def get_all_envs():
"""Get all environment variables."""
22 changes: 14 additions & 8 deletions lmdeploy/pytorch/messages.py
@@ -2,7 +2,7 @@
import enum
from collections import defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List

import numpy as np
import torch
@@ -56,13 +56,17 @@ class SamplingParam:
bad_words: List[int] = field(default_factory=list)
max_new_tokens: int = 512
min_new_tokens: int = 0
response_format: Optional[str] = None
logits_processors: Optional[List[LogitsProcessor]] = None
response_format: None | str = None
logits_processors: None | List[LogitsProcessor] = None
out_logits: bool = False
out_last_hidden_states: bool = False
num_logprobs: int = -1
return_routed_experts: bool = False

# ngram
repetition_ngram_size: int = 0
repetition_ngram_threshold: int = 0

@classmethod
def from_gen_config(cls, gen_config: GenerationConfig):
"""From gen config."""
@@ -144,6 +148,8 @@ def from_gen_config(cls, gen_config: GenerationConfig):
out_logits=(output_logits is not None),
num_logprobs=logprobs,
return_routed_experts=gen_config.return_routed_experts,
repetition_ngram_size=gen_config.repetition_ngram_size,
repetition_ngram_threshold=gen_config.repetition_ngram_threshold,
)


@@ -262,7 +268,7 @@ def add_sequence(self,
adapter_name: str = None,
multimodals: MultiModalInputs = None,
input_embeddings: List[InputEmbeddings] = None,
migration_request: Optional[MigrationRequest] = None,
migration_request: None | MigrationRequest = None,
resp_cache: bool = False,
preserve_cache: bool = False) -> 'SchedulerSequence':
"""Add a new message."""
@@ -604,7 +610,7 @@ class SchedulerSequence:
model_meta: Dict[str, Any] = None

# For Disaggregation
migration_request: Optional[MigrationRequest] = None
migration_request: None | MigrationRequest = None
resp_cache: bool = False
preserve_cache: bool = False

@@ -698,7 +704,7 @@ def routed_experts(self) -> np.ndarray:
else:
return None

def append_routed_experts(self, routed_experts: Union[Tensor, np.ndarray]):
def append_routed_experts(self, routed_experts: Tensor | np.ndarray):
"""Append routed experts."""
if not self.return_routed_experts:
return
@@ -756,7 +762,7 @@ def logits(self):
"""Get logits."""
return self.all_logits.get_logits()

def append_logits(self, logits: Union[Tensor, np.ndarray]):
def append_logits(self, logits: Tensor | np.ndarray):
"""Append logits."""
if not self.return_logits:
return
@@ -776,7 +782,7 @@ def get_input_multimodals(self):
def record_event(
self,
event_type: EventType,
timestamp: Optional[float] = None,
timestamp: None | float = None,
) -> None:
self.engine_events.append(EngineEvent.new_event(event_type, timestamp))
