diff --git a/docs/features/thinking_budget.md b/docs/features/thinking_budget.md
index 2dbaf30e367..c236cb4fcb4 100644
--- a/docs/features/thinking_budget.md
+++ b/docs/features/thinking_budget.md
@@ -2,7 +2,9 @@
## Overview
-`ThinkingBudgetLogitsProcessor` limits the number of tokens generated inside the ` ... ` segment. When the budget is reached, it forces a line break token and then the `` token to terminate the thinking section.
+`ThinkingBudgetLogitsProcessor` limits the number of tokens generated inside the ` ... `
+segment. When the budget is reached, it terminates thinking by forcing ``. If
+`think_stop_sentence` is configured, it forces the custom sentence first and then ``.
## When to Use
@@ -11,19 +13,22 @@
## How It Works
-1. **CPU precompute (DataProcessor)**: when a request includes `thinking_budget`, the prompt token ids are scanned to determine whether thinking has started, whether it already ended, and how many tokens are already inside the thinking section.
+1. **Request-side precompute (DataProcessor)**: when a request includes `thinking_budget`, the prompt token ids are scanned to determine whether thinking has started, whether it already ended, and how many tokens are already inside the thinking section.
2. **Per-step update**: during decoding, the processor tracks `last_token_id` and `tokens_after_start`.
-3. **Budget enforcement**: once the budget is reached, it forces a line break and then the thinking end token.
+3. **Budget enforcement**: once the budget is reached, it forces `` directly. If `think_stop_sentence`
+ is configured, it forces that sentence first and then ``.
## Requirements
-- The model must provide valid token ids for `think_start_id`, `think_end_id`, and `line_break_id` (via `ModelConfig`).
-- If any of these ids are invalid, the processor is disabled and `thinking_budget` will not take effect.
+- The model must provide valid token ids for `think_start_id` and `think_end_id` (via `ModelConfig`).
+- If either of these ids is invalid, the processor is disabled and `thinking_budget` will not take effect.
## Request Parameters
-- `thinking_budget` (int, required to enable): maximum number of tokens after `` before forced termination.
-- `think_stop_sentence` (string, optional): a stop sentence that will be tokenized on the CPU side and enforced near the budget boundary.
+- `thinking_budget` (int, required to enable): maximum number of decode-time tokens after `` before forced
+ termination.
+- `think_stop_sentence` (string, optional): a literal custom sentence that will be tokenized on the request side
+ and enforced near the budget boundary.
## Operator-Level vs LogitsProcessor
@@ -41,16 +46,25 @@ FastDeploy has two ways to limit thinking length:
In short:
- If you only need a hard cap on thinking length, prefer `reasoning_max_tokens`.
-- If you need custom behavior (for example, injecting custom sentence tokens), use `ThinkingBudgetLogitsProcessor`.
+- If you need custom behavior (for example, inserting a custom sentence before ``), use
+ `ThinkingBudgetLogitsProcessor`.
## Practical guidance
`reasoning_max_tokens` and `thinking_budget` are not mutually exclusive in current implementation.
If both are configured for the same request, both constraints can take effect, and whichever triggers first will end the thinking phase.
-- To use **operator-level-only** behavior: this is request-level config only. Set `enable_thinking=true` and `reasoning_max_tokens` in request, and do not set `thinking_budget`.
-- To use **logits-processor-only** behavior (especially with `think_stop_sentence`): this requires service-level + request-level config. Start service with `--logits-processors ThinkingBudgetLogitsProcessor`, and set `thinking_budget` (and optional `think_stop_sentence`) in `logits_processors_args`; leave `reasoning_max_tokens` unset.
-- Avoid enabling both for strict custom sentence insertion requirements, because operator-level termination may cut the custom sentence path earlier.
+- To use **operator-level-only** behavior: this is request-level config only. Set
+ `enable_thinking=true` and `reasoning_max_tokens` in request, and do not set `thinking_budget`.
+- To use **logits-processor-only** behavior (especially with `think_stop_sentence`): this requires
+ service-level + request-level config. Start service with `--logits-processors ThinkingBudgetLogitsProcessor`,
+ and set `thinking_budget` (and optional `think_stop_sentence`) in `logits_processors_args`; leave
+ `reasoning_max_tokens` unset.
+- `thinking_budget` itself does not require `enable_thinking=true`.
+- If an ERNIE chat template already appends `` in the prompt, `thinking_budget` should still take effect; it
+ does not require the model to emit another `` during decoding.
+- Avoid enabling both for strict custom sentence insertion requirements, because operator-level
+ termination may cut the custom sentence path earlier.
## Online Usage
@@ -120,4 +134,6 @@ print(outputs[0].outputs.text)
## Performance Note
-This processor runs `update_state` and `apply` on every decode step. If you only need a hard thinking-length cap and care most about throughput, consider the operator-level reasoning-length controls instead of per-step logits processing.
+This processor runs `update_state` and `apply` on every decode step. If you only need a hard
+thinking-length cap and care most about throughput, consider the operator-level reasoning-length
+controls instead of per-step logits processing.
diff --git a/docs/zh/features/thinking_budget.md b/docs/zh/features/thinking_budget.md
index bf614a24c0b..70e58d0c789 100644
--- a/docs/zh/features/thinking_budget.md
+++ b/docs/zh/features/thinking_budget.md
@@ -2,7 +2,9 @@
## 概述
-`ThinkingBudgetLogitsProcessor` 用于限制 ` ... ` 区间的生成长度。当预算达到阈值时,会强制生成换行符 token,再强制生成 ``,从而结束思考段。
+`ThinkingBudgetLogitsProcessor` 用于限制 ` ... ` 区间的生成长度。当预算达到阈值时,
+会直接强制生成 `` 来结束思考段;如果配置了 `think_stop_sentence`,则会先强制输出该自定义
+文案,再输出 ``。
## 适用场景
@@ -11,19 +13,20 @@
## 工作原理
-1. **CPU 侧预计算(DataProcessor)**:当请求中包含 `thinking_budget`,会基于 prompt 的 token ids 计算是否已进入思考段、是否已结束,以及已有的思考长度。
+1. **请求侧预计算(DataProcessor)**:当请求中包含 `thinking_budget`,会基于 prompt 的 token ids 计算是否已进入思考段、是否已结束,以及已有的思考长度。
2. **每步更新**:解码过程中跟踪 `last_token_id` 与 `tokens_after_start`。
-3. **预算约束**:达到预算后,依次强制换行符与思考结束 token。
+3. **预算约束**:达到预算后,默认直接强制 ``;如果配置了 `think_stop_sentence`,则先逐 token
+ 强制输出该文案,再输出 ``。
## 前置要求
-- 模型需提供有效的 `think_start_id`、`think_end_id`、`line_break_id`(来自 `ModelConfig`)。
-- 若任意 id 无效,处理器会禁用,`thinking_budget` 不生效。
+- 模型需提供有效的 `think_start_id`、`think_end_id`(来自 `ModelConfig`)。
+- 若其中任意 id 无效,处理器会禁用,`thinking_budget` 不生效。
## 请求参数
-- `thinking_budget`(int,启用所需):`` 之后允许的最大 token 数。
-- `think_stop_sentence`(string,可选):CPU 侧会将该字符串编码为 token ids,并在预算边界附近强制输出。
+- `thinking_budget`(int,启用所需):`` 之后允许的最大 decode 阶段 token 数。
+- `think_stop_sentence`(string,可选):按字面串编码的自定义终止文案,并在预算边界附近强制输出。
## 算子级限制 vs LogitsProcessor
@@ -35,21 +38,27 @@ FastDeploy 当前有两种思考长度控制方式:
- 适合“只限制思考长度”的简单场景。
- **`ThinkingBudgetLogitsProcessor`**(`logits_processors_args.thinking_budget`):
- 由每步 Python 侧 logits 处理实现。
- - 支持更灵活的行为,例如 `think_stop_sentence`(在结束前插入自定义话术)。
+ - 支持更灵活的行为,例如 `think_stop_sentence`。
- 相比算子级限制,在高并发下通常有更高开销。
可按以下原则选择:
- 仅需限制思考长度:优先用 `reasoning_max_tokens`。
-- 需要更灵活控制(如插入自定义话术):使用 `ThinkingBudgetLogitsProcessor`。
+- 需要更灵活控制(如在 `` 前插入自定义话术):使用 `ThinkingBudgetLogitsProcessor`。
## 建议实践
当前实现中,`reasoning_max_tokens` 与 `thinking_budget` 不是互斥关系。
同一请求如果同时配置,两套约束都可能生效,谁先触发就先结束思考段。
-- **只用算子级限制**:这是请求级配置。仅在请求中设置 `enable_thinking=true` + `reasoning_max_tokens`,不要传 `thinking_budget`。
-- **只用 LogitsProcessor**(尤其要用 `think_stop_sentence`):这是“服务启动 + 请求参数”两级配置。服务启动时必须加 `--logits-processors ThinkingBudgetLogitsProcessor`,并在请求里通过 `logits_processors_args` 传 `thinking_budget`(以及可选的 `think_stop_sentence`);同时不要设置 `reasoning_max_tokens`。
+- **只用算子级限制**:这是请求级配置。仅在请求中设置 `enable_thinking=true` + `reasoning_max_tokens`,
+ 不要传 `thinking_budget`。
+- **只用 LogitsProcessor**(尤其要用 `think_stop_sentence`):这是“服务启动 + 请求参数”两级配置。
+ 服务启动时必须加 `--logits-processors ThinkingBudgetLogitsProcessor`,并在请求里通过
+ `logits_processors_args` 传 `thinking_budget`(以及可选的 `think_stop_sentence`);同时不要设置
+ `reasoning_max_tokens`。
+- `thinking_budget` 本身不依赖 `enable_thinking=true`。
+- 如果 ERNIE 的 chat template 已经在 prompt 里拼入 ``,`thinking_budget` 也应正常生效,不要求模型在 decode 阶段再次输出 ``。
- 如果业务要求“必须完整插入自定义话术”,不建议与算子级限制同时开启,否则可能被算子级提前截断。
## 在线使用
diff --git a/fastdeploy/input/ernie4_5_processor.py b/fastdeploy/input/ernie4_5_processor.py
index 127f7aeccab..6559477c119 100644
--- a/fastdeploy/input/ernie4_5_processor.py
+++ b/fastdeploy/input/ernie4_5_processor.py
@@ -107,6 +107,11 @@ def process_request_dict(self, request, max_model_len=None):
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
+ logits_processors_args = self._prepare_think_stop_sentence(
+ request.get("logits_processors_args") or {}, max_model_len
+ )
+ request["logits_processors_args"] = logits_processors_args
+
# processing prompt_token_ids
if not request.get("prompt_token_ids"):
if request.get("prompt"):
@@ -143,6 +148,10 @@ def process_request_dict(self, request, max_model_len=None):
# truncate prompts that exceed the length limit
if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
+ logits_processors_args = self._update_thinking_prompt_state(
+ request["prompt_token_ids"], request.get("logits_processors_args") or {}
+ )
+ request["logits_processors_args"] = logits_processors_args
max_tokens = max_model_len - len(request["prompt_token_ids"])
if request.get("max_tokens") is None:
request["max_tokens"] = max(1, max_tokens)
diff --git a/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py b/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py
index 7633c2ecf45..f42b35a8eef 100644
--- a/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py
+++ b/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py
@@ -210,6 +210,11 @@ def process_request_dict(self, request, max_model_len=None):
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
+ logits_processors_args = self._prepare_think_stop_sentence(
+ request.get("logits_processors_args") or {}, max_model_len
+ )
+ request["logits_processors_args"] = logits_processors_args
+
if request.get("prompt_token_ids"):
messages = request.get("messages")
if messages:
@@ -257,6 +262,10 @@ def process_request_dict(self, request, max_model_len=None):
# 截断超过长度限制的prompt
if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
+ logits_processors_args = self._update_thinking_prompt_state(
+ request["prompt_token_ids"], request.get("logits_processors_args") or {}
+ )
+ request["logits_processors_args"] = logits_processors_args
max_tokens = max_model_len - len(request["prompt_token_ids"])
if request.get("max_tokens") is None:
diff --git a/fastdeploy/input/text_processor.py b/fastdeploy/input/text_processor.py
index fe0e16e1f49..5c974734324 100644
--- a/fastdeploy/input/text_processor.py
+++ b/fastdeploy/input/text_processor.py
@@ -92,6 +92,9 @@ def encode_with_cache(self, text, max_model_len=None, add_special_tokens=False):
"""
Encode text into token ids with a small LRU cache.
"""
+ if not hasattr(self, "_tokenize_cache"):
+ self._tokenize_cache = OrderedDict()
+ self._tokenize_cache_capacity = getattr(self, "_tokenize_cache_capacity", 128)
key = (text, bool(add_special_tokens))
cached = self._tokenize_cache.get(key)
if cached is not None:
@@ -107,6 +110,25 @@ def encode_with_cache(self, text, max_model_len=None, add_special_tokens=False):
self._tokenize_cache.popitem(last=False)
return token_ids
+ def _encode_literal_text_with_cache(self, text):
+ if not hasattr(self, "_tokenize_cache"):
+ self._tokenize_cache = OrderedDict()
+ self._tokenize_cache_capacity = getattr(self, "_tokenize_cache_capacity", 128)
+ key = ("literal_text", text)
+ cached = self._tokenize_cache.get(key)
+ if cached is not None:
+ self._tokenize_cache.move_to_end(key)
+ return cached
+ token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
+ if hasattr(token_ids, "tolist"):
+ token_ids = token_ids.tolist()
+ elif not isinstance(token_ids, list):
+ token_ids = list(token_ids)
+ self._tokenize_cache[key] = token_ids
+ if len(self._tokenize_cache) > self._tokenize_cache_capacity:
+ self._tokenize_cache.popitem(last=False)
+ return token_ids
+
def messages2ids(self, messages):
"""
Convert multi-turn messages into ID sequences.
@@ -119,6 +141,77 @@ def messages2ids(self, messages):
"""
raise NotImplementedError
+ def _get_think_token_ids(self):
+ think_token_ids = getattr(self, "_think_token_ids", None)
+ if think_token_ids is not None:
+ return think_token_ids
+ tokenizer = getattr(self, "tokenizer", None)
+ vocab = tokenizer.get_vocab() if tokenizer is not None else {}
+ think_start_id = vocab.get("", -1)
+ think_end_id = vocab.get("", -1)
+ self._think_token_ids = (think_start_id, think_end_id)
+ return self._think_token_ids
+
+ def _prepare_think_stop_sentence(self, logits_processors_args, max_model_len=None):
+ if not isinstance(logits_processors_args, dict):
+ return logits_processors_args
+ think_stop_sentence = logits_processors_args.get("think_stop_sentence")
+ if isinstance(think_stop_sentence, str) and think_stop_sentence:
+ sentence_token_ids = self._encode_literal_text_with_cache(think_stop_sentence)
+ logits_processors_args["think_stop_sentence_token_ids"] = sentence_token_ids
+ logits_processors_args.pop("think_stop_sentence", None)
+ return logits_processors_args
+
+ def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args):
+ if not isinstance(logits_processors_args, dict):
+ return logits_processors_args
+ thinking_budget = logits_processors_args.get("thinking_budget")
+ if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0:
+ return logits_processors_args
+ if logits_processors_args.get("think_prompt_checked"):
+ return logits_processors_args
+ if prompt_token_ids is None:
+ return logits_processors_args
+ token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids)
+ if token_len == 0:
+ return logits_processors_args
+ think_start_id, think_end_id = self._get_think_token_ids()
+ if think_start_id < 0 or think_end_id < 0:
+ return logits_processors_args
+
+ if hasattr(prompt_token_ids, "tolist"):
+ token_list = prompt_token_ids.tolist()
+ else:
+ token_list = list(prompt_token_ids)
+
+ started = False
+ ended = False
+ tokens_after_start = 0
+ last_token_id = None
+ in_thinking = False
+ for token_id in token_list:
+ if token_id == think_start_id:
+ started = True
+ ended = False
+ in_thinking = True
+ elif token_id == think_end_id and in_thinking:
+ ended = True
+ in_thinking = False
+ if started and token_list:
+ # Align with operator-level reasoning_max_tokens: prompt-side tokens
+ # inside do not consume thinking budget.
+ last_token_id = int(token_list[-1])
+
+ logits_processors_args["think_prompt_checked"] = True
+ logits_processors_args["think_prompt_started"] = started
+ logits_processors_args["think_prompt_ended"] = ended
+ logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start
+ if last_token_id is not None:
+ logits_processors_args["think_prompt_last_token_id"] = last_token_id
+ else:
+ logits_processors_args.pop("think_prompt_last_token_id", None)
+ return logits_processors_args
+
def ids2tokens(self, token_id, task_id=None):
"""
token ids to strings
@@ -207,65 +300,6 @@ def __init__(self, model_name_or_path, reasoning_parser_obj=None, tool_parser_ob
self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
self.tokenizer.pad_token_id = self.pad_token_id
- self._think_token_ids = None
-
- def _get_think_token_ids(self):
- if self._think_token_ids is not None:
- return self._think_token_ids
- vocab = self.tokenizer.get_vocab()
- think_start_id = vocab.get("", -1)
- think_end_id = vocab.get("", -1)
- self._think_token_ids = (think_start_id, think_end_id)
- return self._think_token_ids
-
- def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args):
- if not isinstance(logits_processors_args, dict):
- return logits_processors_args
- thinking_budget = logits_processors_args.get("thinking_budget")
- if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0:
- return logits_processors_args
- if logits_processors_args.get("think_prompt_checked"):
- return logits_processors_args
- if prompt_token_ids is None:
- return logits_processors_args
- token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids)
- if token_len == 0:
- return logits_processors_args
- think_start_id, think_end_id = self._get_think_token_ids()
- if think_start_id < 0 or think_end_id < 0:
- return logits_processors_args
-
- if hasattr(prompt_token_ids, "tolist"):
- token_list = prompt_token_ids.tolist()
- else:
- token_list = list(prompt_token_ids)
-
- started = think_start_id in token_list
- ended = False
- tokens_after_start = 0
- last_token_id = None
- if started:
- start_pos = token_list.index(think_start_id)
- tokens_after = token_list[start_pos + 1 :]
- if think_end_id in tokens_after:
- end_pos = tokens_after.index(think_end_id)
- tokens_after_start = end_pos + 1
- ended = True
- else:
- tokens_after_start = len(tokens_after)
- if token_list:
- last_token_id = int(token_list[-1])
-
- logits_processors_args["think_prompt_checked"] = True
- logits_processors_args["think_prompt_started"] = started
- logits_processors_args["think_prompt_ended"] = ended
- logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start
- if last_token_id is not None:
- logits_processors_args["think_prompt_last_token_id"] = last_token_id
- else:
- logits_processors_args.pop("think_prompt_last_token_id", None)
- return logits_processors_args
-
def process_request_dict(self, request, max_model_len=None, **kwargs):
"""
Preprocess the request
@@ -292,14 +326,10 @@ def process_request_dict(self, request, max_model_len=None, **kwargs):
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
- logits_processors_args = request.get("logits_processors_args") or {}
- think_stop_sentence = logits_processors_args.get("think_stop_sentence")
- if isinstance(think_stop_sentence, str) and think_stop_sentence:
- newline_token_ids = self.encode_with_cache("\n", max_model_len, add_special_tokens=False)
- sentence_token_ids = self.encode_with_cache(think_stop_sentence, max_model_len, add_special_tokens=False)
- logits_processors_args["think_stop_sentence_token_ids"] = newline_token_ids + sentence_token_ids
- logits_processors_args.pop("think_stop_sentence", None)
- request["logits_processors_args"] = logits_processors_args
+ logits_processors_args = self._prepare_think_stop_sentence(
+ request.get("logits_processors_args") or {}, max_model_len
+ )
+ request["logits_processors_args"] = logits_processors_args
# processing prompt_token_ids
if not request.get("prompt_token_ids"):
diff --git a/fastdeploy/input/v1/ernie4_5_processor.py b/fastdeploy/input/v1/ernie4_5_processor.py
index 31741d22ac2..f6545dc068a 100644
--- a/fastdeploy/input/v1/ernie4_5_processor.py
+++ b/fastdeploy/input/v1/ernie4_5_processor.py
@@ -107,6 +107,11 @@ def process_request(self, request, max_model_len=None, **kwargs):
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
+ logits_processors_args = self._prepare_think_stop_sentence(
+ request.get("logits_processors_args") or {}, max_model_len
+ )
+ request["logits_processors_args"] = logits_processors_args
+
# processing prompt_token_ids
if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
if request.prompt is not None:
@@ -123,7 +128,8 @@ def process_request(self, request, max_model_len=None, **kwargs):
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
request.prompt_token_ids = token_ids
data_processor_logger.debug(
- f"request_ids: {request.request_id}, prompt: {prompt}, tokens: {tokens}, token_ids: {token_ids}"
+ f"request_ids: {request.request_id}, prompt: {prompt}, "
+ f"tokens: {tokens}, token_ids: {token_ids}"
)
elif request.messages is not None:
task = request.to_dict()
@@ -145,6 +151,10 @@ def process_request(self, request, max_model_len=None, **kwargs):
# truncate prompts that exceed the length limit
if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
+ logits_processors_args = self._update_thinking_prompt_state(
+ request.prompt_token_ids, request.get("logits_processors_args") or {}
+ )
+ request["logits_processors_args"] = logits_processors_args
max_tokens = max_model_len - len(request.prompt_token_ids)
if request.get("max_tokens") is None:
request.set("max_tokens", max(1, max_tokens))
@@ -201,6 +211,11 @@ def process_request_dict(self, request, max_model_len=None, **kwargs):
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request.sampling_params.bad_words_token_ids = bad_words_token_ids
+ logits_processors_args = self._prepare_think_stop_sentence(
+ getattr(request.sampling_params, "logits_processors_args", None) or {}, max_model_len
+ )
+ request.sampling_params.logits_processors_args = logits_processors_args
+
# processing prompt_token_ids
if not request.prompt_token_ids:
if request.prompt:
@@ -241,6 +256,10 @@ def process_request_dict(self, request, max_model_len=None, **kwargs):
# truncate prompts that exceed the length limit
if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
+ logits_processors_args = self._update_thinking_prompt_state(
+ request.prompt_token_ids, getattr(request.sampling_params, "logits_processors_args", None) or {}
+ )
+ request.sampling_params.logits_processors_args = logits_processors_args
max_tokens = max_model_len - len(request.prompt_token_ids)
if getattr(request.sampling_params, "max_tokens", None) is None:
request.sampling_params.max_tokens = max(1, max_tokens)
diff --git a/fastdeploy/input/v1/ernie4_5_vl_processor/ernie4_5_vl_processor.py b/fastdeploy/input/v1/ernie4_5_vl_processor/ernie4_5_vl_processor.py
index 700b5f51bbc..bae80b60a96 100644
--- a/fastdeploy/input/v1/ernie4_5_vl_processor/ernie4_5_vl_processor.py
+++ b/fastdeploy/input/v1/ernie4_5_vl_processor/ernie4_5_vl_processor.py
@@ -216,6 +216,11 @@ def process_request_dict(self, request, max_model_len=None, **kwargs):
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request.sampling_params.bad_words_token_ids = bad_words_token_ids
+ logits_processors_args = self._prepare_think_stop_sentence(
+ getattr(request.sampling_params, "logits_processors_args", None) or {}, max_model_len
+ )
+ request.sampling_params.logits_processors_args = logits_processors_args
+
if request.prompt_token_ids:
messages = request.messages
if messages:
@@ -267,6 +272,10 @@ def process_request_dict(self, request, max_model_len=None, **kwargs):
# 截断超过长度限制的prompt
if max_model_len is not None and len(request.prompt_token_ids) > max_model_len:
request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1]
+ logits_processors_args = self._update_thinking_prompt_state(
+ request.prompt_token_ids, getattr(request.sampling_params, "logits_processors_args", None) or {}
+ )
+ request.sampling_params.logits_processors_args = logits_processors_args
max_tokens = max_model_len - len(request.prompt_token_ids)
if getattr(request.sampling_params, "max_tokens", None) is None:
diff --git a/fastdeploy/input/v1/text_processor.py b/fastdeploy/input/v1/text_processor.py
index 88610c007b6..f83a0e0f12e 100644
--- a/fastdeploy/input/v1/text_processor.py
+++ b/fastdeploy/input/v1/text_processor.py
@@ -115,6 +115,9 @@ def encode_with_cache(self, text, max_model_len=None, add_special_tokens=False):
"""
Encode text into token ids with a small LRU cache.
"""
+ if not hasattr(self, "_tokenize_cache"):
+ self._tokenize_cache = OrderedDict()
+ self._tokenize_cache_capacity = getattr(self, "_tokenize_cache_capacity", 128)
key = (text, bool(add_special_tokens))
cached = self._tokenize_cache.get(key)
if cached is not None:
@@ -130,6 +133,25 @@ def encode_with_cache(self, text, max_model_len=None, add_special_tokens=False):
self._tokenize_cache.popitem(last=False)
return token_ids
+ def _encode_literal_text_with_cache(self, text):
+ if not hasattr(self, "_tokenize_cache"):
+ self._tokenize_cache = OrderedDict()
+ self._tokenize_cache_capacity = getattr(self, "_tokenize_cache_capacity", 128)
+ key = ("literal_text", text)
+ cached = self._tokenize_cache.get(key)
+ if cached is not None:
+ self._tokenize_cache.move_to_end(key)
+ return cached
+ token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
+ if hasattr(token_ids, "tolist"):
+ token_ids = token_ids.tolist()
+ elif not isinstance(token_ids, list):
+ token_ids = list(token_ids)
+ self._tokenize_cache[key] = token_ids
+ if len(self._tokenize_cache) > self._tokenize_cache_capacity:
+ self._tokenize_cache.popitem(last=False)
+ return token_ids
+
def messages2ids(self, messages):
"""
Convert multi-turn messages into ID sequences.
@@ -142,6 +164,77 @@ def messages2ids(self, messages):
"""
raise NotImplementedError
+ def _get_think_token_ids(self):
+ think_token_ids = getattr(self, "_think_token_ids", None)
+ if think_token_ids is not None:
+ return think_token_ids
+ tokenizer = getattr(self, "tokenizer", None)
+ vocab = tokenizer.get_vocab() if tokenizer is not None else {}
+ think_start_id = vocab.get("", -1)
+ think_end_id = vocab.get("", -1)
+ self._think_token_ids = (think_start_id, think_end_id)
+ return self._think_token_ids
+
+ def _prepare_think_stop_sentence(self, logits_processors_args, max_model_len=None):
+ if not isinstance(logits_processors_args, dict):
+ return logits_processors_args
+ think_stop_sentence = logits_processors_args.get("think_stop_sentence")
+ if isinstance(think_stop_sentence, str) and think_stop_sentence:
+ sentence_token_ids = self._encode_literal_text_with_cache(think_stop_sentence)
+ logits_processors_args["think_stop_sentence_token_ids"] = sentence_token_ids
+ logits_processors_args.pop("think_stop_sentence", None)
+ return logits_processors_args
+
+ def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args):
+ if not isinstance(logits_processors_args, dict):
+ return logits_processors_args
+ thinking_budget = logits_processors_args.get("thinking_budget")
+ if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0:
+ return logits_processors_args
+ if logits_processors_args.get("think_prompt_checked"):
+ return logits_processors_args
+ if prompt_token_ids is None:
+ return logits_processors_args
+ token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids)
+ if token_len == 0:
+ return logits_processors_args
+ think_start_id, think_end_id = self._get_think_token_ids()
+ if think_start_id < 0 or think_end_id < 0:
+ return logits_processors_args
+
+ if hasattr(prompt_token_ids, "tolist"):
+ token_list = prompt_token_ids.tolist()
+ else:
+ token_list = list(prompt_token_ids)
+
+ started = False
+ ended = False
+ tokens_after_start = 0
+ last_token_id = None
+ in_thinking = False
+ for token_id in token_list:
+ if token_id == think_start_id:
+ started = True
+ ended = False
+ in_thinking = True
+ elif token_id == think_end_id and in_thinking:
+ ended = True
+ in_thinking = False
+ if started and token_list:
+ # Align with operator-level reasoning_max_tokens: prompt-side tokens
+ # inside do not consume thinking budget.
+ last_token_id = int(token_list[-1])
+
+ logits_processors_args["think_prompt_checked"] = True
+ logits_processors_args["think_prompt_started"] = started
+ logits_processors_args["think_prompt_ended"] = ended
+ logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start
+ if last_token_id is not None:
+ logits_processors_args["think_prompt_last_token_id"] = last_token_id
+ else:
+ logits_processors_args.pop("think_prompt_last_token_id", None)
+ return logits_processors_args
+
def ids2tokens(self, token_id, task_id=None):
"""
token ids to strings
@@ -221,65 +314,6 @@ def __init__(self, model_name_or_path, reasoning_parser_obj=None, tool_parser_ob
self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
self.tokenizer.pad_token_id = self.pad_token_id
- self._think_token_ids = None
-
- def _get_think_token_ids(self):
- if self._think_token_ids is not None:
- return self._think_token_ids
- vocab = self.tokenizer.get_vocab()
- think_start_id = vocab.get("", -1)
- think_end_id = vocab.get("", -1)
- self._think_token_ids = (think_start_id, think_end_id)
- return self._think_token_ids
-
- def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args):
- if not isinstance(logits_processors_args, dict):
- return logits_processors_args
- thinking_budget = logits_processors_args.get("thinking_budget")
- if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0:
- return logits_processors_args
- if logits_processors_args.get("think_prompt_checked"):
- return logits_processors_args
- if prompt_token_ids is None:
- return logits_processors_args
- token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids)
- if token_len == 0:
- return logits_processors_args
- think_start_id, think_end_id = self._get_think_token_ids()
- if think_start_id < 0 or think_end_id < 0:
- return logits_processors_args
-
- if hasattr(prompt_token_ids, "tolist"):
- token_list = prompt_token_ids.tolist()
- else:
- token_list = list(prompt_token_ids)
-
- started = think_start_id in token_list
- ended = False
- tokens_after_start = 0
- last_token_id = None
- if started:
- start_pos = token_list.index(think_start_id)
- tokens_after = token_list[start_pos + 1 :]
- if think_end_id in tokens_after:
- end_pos = tokens_after.index(think_end_id)
- tokens_after_start = end_pos + 1
- ended = True
- else:
- tokens_after_start = len(tokens_after)
- if token_list:
- last_token_id = int(token_list[-1])
-
- logits_processors_args["think_prompt_checked"] = True
- logits_processors_args["think_prompt_started"] = started
- logits_processors_args["think_prompt_ended"] = ended
- logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start
- if last_token_id is not None:
- logits_processors_args["think_prompt_last_token_id"] = last_token_id
- else:
- logits_processors_args.pop("think_prompt_last_token_id", None)
- return logits_processors_args
-
def process_request(self, request, max_model_len=None, **kwargs):
"""
Preprocess the request
@@ -306,14 +340,10 @@ def process_request(self, request, max_model_len=None, **kwargs):
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request["bad_words_token_ids"] = bad_words_token_ids
- logits_processors_args = request.get("logits_processors_args") or {}
- think_stop_sentence = logits_processors_args.get("think_stop_sentence")
- if isinstance(think_stop_sentence, str) and think_stop_sentence:
- newline_token_ids = self.encode_with_cache("\n", max_model_len, add_special_tokens=False)
- sentence_token_ids = self.encode_with_cache(think_stop_sentence, max_model_len, add_special_tokens=False)
- logits_processors_args["think_stop_sentence_token_ids"] = newline_token_ids + sentence_token_ids
- logits_processors_args.pop("think_stop_sentence", None)
- request["logits_processors_args"] = logits_processors_args
+ logits_processors_args = self._prepare_think_stop_sentence(
+ request.get("logits_processors_args") or {}, max_model_len
+ )
+ request["logits_processors_args"] = logits_processors_args
# processing prompt_token_ids
if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
@@ -411,14 +441,10 @@ def process_request_dict(self, request, max_model_len=None, **kwargs):
bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids)
request.sampling_params.bad_words_token_ids = bad_words_token_ids
- logits_processors_args = getattr(request.sampling_params, "logits_processors_args", None) or {}
- think_stop_sentence = logits_processors_args.get("think_stop_sentence")
- if isinstance(think_stop_sentence, str) and think_stop_sentence:
- newline_token_ids = self.encode_with_cache("\n", max_model_len, add_special_tokens=False)
- sentence_token_ids = self.encode_with_cache(think_stop_sentence, max_model_len, add_special_tokens=False)
- logits_processors_args["think_stop_sentence_token_ids"] = newline_token_ids + sentence_token_ids
- logits_processors_args.pop("think_stop_sentence", None)
- request.sampling_params.logits_processors_args = logits_processors_args
+ logits_processors_args = self._prepare_think_stop_sentence(
+ getattr(request.sampling_params, "logits_processors_args", None) or {}, max_model_len
+ )
+ request.sampling_params.logits_processors_args = logits_processors_args
# processing prompt_token_ids
if not request.prompt_token_ids:
diff --git a/fastdeploy/model_executor/logits_processor/thinking_budget.py b/fastdeploy/model_executor/logits_processor/thinking_budget.py
index 891a8ddd25a..b8493c1756d 100644
--- a/fastdeploy/model_executor/logits_processor/thinking_budget.py
+++ b/fastdeploy/model_executor/logits_processor/thinking_budget.py
@@ -42,14 +42,15 @@ class _ThinkingState:
class ThinkingBudgetLogitsProcessor(LogitsProcessor):
"""Limit the number of tokens generated in the thinking phase.
- The processor tracks per-request thinking state and forces a newline token
- when the budget is reached, followed by the thinking end token on the next step.
+ The processor tracks per-request thinking state and forces the thinking end
+ token when the budget is reached. If a stop sentence is configured, the
+ processor emits the stop sentence first and then the thinking end token.
Request-specific configuration is provided via logits_processors_args:
{"thinking_budget": }
- Requires model_config to provide think_start_id, think_end_id, and line_break_id.
- If any of these are missing or invalid (-1), the processor will be disabled.
+ Requires model_config to provide think_start_id and think_end_id. If any of
+ these are missing or invalid (-1), the processor will be disabled.
"""
def __init__(self, fd_config: FDConfig) -> None:
@@ -61,20 +62,33 @@ def __init__(self, fd_config: FDConfig) -> None:
self.think_start_token_id = think_start_id if isinstance(think_start_id, int) and think_start_id >= 0 else -1
self.think_end_token_id = think_end_id if isinstance(think_end_id, int) and think_end_id >= 0 else -1
self.line_break_token_id = line_break_id if isinstance(line_break_id, int) and line_break_id >= 0 else -1
- self._enabled = (
- self.think_start_token_id >= 0 and self.think_end_token_id >= 0 and self.line_break_token_id >= 0
- )
+ self._enabled = self.think_start_token_id >= 0 and self.think_end_token_id >= 0
if not self._enabled:
logger.warning(
"ThinkingBudgetLogitsProcessor disabled: missing token ids "
- f"(think_start={think_start_id}, think_end={think_end_id}, line_break={line_break_id}). "
- "Ensure model vocab contains , tokens and line_break_id is configured."
+ f"(think_start={think_start_id}, think_end={think_end_id}). "
+ "Ensure model vocab contains and tokens."
)
self._states: Dict[str, _ThinkingState] = {}
self._active_req_ids: list[str] = []
self._active_budgets: list[int] = []
self._active_slots: list[int] = []
+ def _scan_prompt_state(self, prompt_slice: list[int]) -> tuple[bool, bool, int, Optional[int]]:
+ started = False
+ ended = False
+ in_thinking = False
+ for token_id in prompt_slice:
+ if token_id == self.think_start_token_id:
+ started = True
+ ended = False
+ in_thinking = True
+ elif token_id == self.think_end_token_id and in_thinking:
+ ended = True
+ in_thinking = False
+ last_token_id = int(prompt_slice[-1]) if started and prompt_slice else None
+ return started, ended, 0, last_token_id
+
def update_state(self, share_inputs: dict) -> None:
if not self._enabled:
return
@@ -82,6 +96,7 @@ def update_state(self, share_inputs: dict) -> None:
req_ids = share_inputs["req_ids"]
logits_processors_args = share_inputs["logits_processors_args"]
prompt_ids = share_inputs.get("prompt_ids")
+ token_ids_all = share_inputs.get("token_ids_all")
prompt_lens = share_inputs.get("prompt_lens")
pre_ids = share_inputs.get("pre_ids")
next_tokens = share_inputs.get("next_tokens")
@@ -153,6 +168,8 @@ def update_state(self, share_inputs: dict) -> None:
for idx, slot_id in enumerate(candidate_slots):
next_token_by_slot[slot_id] = int(next_sel[idx])
+ prompt_source = prompt_ids if prompt_ids is not None else token_ids_all
+
for idx, slot_id in enumerate(candidate_slots):
req_id = candidate_req_ids[idx]
logit_proc_args = candidate_args[idx]
@@ -190,25 +207,28 @@ def update_state(self, share_inputs: dict) -> None:
state.current_step_idx = current_step_idx
if not state.started and not state.prompt_checked:
- if prompt_ids is not None and prompt_lens is not None:
+ if prompt_source is not None and prompt_lens is not None:
if prompt_lens_np is not None:
prompt_len = int(prompt_lens_np[slot_id])
else:
prompt_len = int(prompt_lens[slot_id])
- prompt_slice = prompt_ids[slot_id, :prompt_len]
- prompt_slice = prompt_slice.numpy().tolist()
- if self.think_start_token_id in prompt_slice:
+ prompt_slice = prompt_source[slot_id, :prompt_len]
+ if hasattr(prompt_slice, "numpy"):
+ prompt_slice = prompt_slice.numpy().tolist()
+ elif hasattr(prompt_slice, "tolist"):
+ prompt_slice = prompt_slice.tolist()
+ else:
+ prompt_slice = list(prompt_slice)
+ if prompt_ids is None:
+ prompt_slice = [int(token_id) for token_id in prompt_slice if int(token_id) >= 0]
+ prompt_started, prompt_ended, prompt_tokens_after_start, prompt_last_token_id = (
+ self._scan_prompt_state(prompt_slice)
+ )
+ if prompt_started:
state.started = True
- start_pos = prompt_slice.index(self.think_start_token_id)
- tokens_after = prompt_slice[start_pos + 1 :]
- if self.think_end_token_id in tokens_after:
- end_pos = tokens_after.index(self.think_end_token_id)
- state.tokens_after_start = end_pos + 1
- state.ended = True
- else:
- state.tokens_after_start = len(tokens_after)
- if prompt_slice:
- state.last_token_id = int(prompt_slice[-1])
+ state.ended = prompt_ended
+ state.tokens_after_start = prompt_tokens_after_start
+ state.last_token_id = prompt_last_token_id
if current_step_idx is not None and state.last_step_idx is None:
state.last_step_idx = current_step_idx
state.prompt_checked = True
@@ -304,13 +324,6 @@ def apply(self, logits: paddle.Tensor) -> paddle.Tensor:
if state.tokens_after_start < budget:
continue
- if state.last_token_id != self.line_break_token_id:
- logits[slot_id, :] = -float("inf")
- logits[slot_id, self.line_break_token_id] = 0.0
- state.last_token_id = self.line_break_token_id
- state.last_step_idx = state.current_step_idx
- continue
-
logits[slot_id, :] = -float("inf")
logits[slot_id, self.think_end_token_id] = 0.0
state.last_token_id = self.think_end_token_id
diff --git a/tests/model_executor/test_thinking_budget.py b/tests/model_executor/test_thinking_budget.py
index 7307092d19f..8ba9319ff7d 100644
--- a/tests/model_executor/test_thinking_budget.py
+++ b/tests/model_executor/test_thinking_budget.py
@@ -20,7 +20,19 @@
from fastdeploy.engine import engine as engine_module
from fastdeploy.engine.args_utils import EngineArgs # Import EngineArgs
from fastdeploy.engine.sampling_params import SamplingParams
+from fastdeploy.input.ernie4_5_processor import (
+ Ernie4_5Processor as ErnieTextDataProcessor,
+)
+from fastdeploy.input.ernie4_5_vl_processor import (
+ Ernie4_5_VLProcessor as ErnieVLDataProcessor,
+)
from fastdeploy.input.text_processor import DataProcessor as TextDataProcessor
+from fastdeploy.input.v1.ernie4_5_processor import (
+ Ernie4_5Processor as V1ErnieTextDataProcessor,
+)
+from fastdeploy.input.v1.ernie4_5_vl_processor.ernie4_5_vl_processor import (
+ Ernie4_5_VLProcessor as V1ErnieVLDataProcessor,
+)
from fastdeploy.input.v1.text_processor import DataProcessor as V1TextDataProcessor
from fastdeploy.model_executor.logits_processor import ThinkingBudgetLogitsProcessor
from fastdeploy.scheduler import SchedulerConfig
@@ -52,6 +64,7 @@ def __init__(self, fd_config, max_num_seqs=4):
"req_ids": [f"req_{i}" for i in range(max_num_seqs)],
"logits_processors_args": [{} for _ in range(max_num_seqs)],
"prompt_ids": paddle.to_tensor(np.zeros((max_num_seqs, 10), dtype=np.int64)), # Max prompt len 10
+ "token_ids_all": paddle.to_tensor(np.full((max_num_seqs, self.max_model_len), -1, dtype=np.int64)),
"prompt_lens": paddle.to_tensor(np.zeros((max_num_seqs, 1), dtype=np.int64)),
"pre_ids": paddle.to_tensor(np.full((max_num_seqs, self.max_model_len), -1, dtype=np.int64)),
"step_idx": paddle.to_tensor(np.zeros((max_num_seqs, 1), dtype=np.int64)),
@@ -73,6 +86,12 @@ def update_request_state(self, slot_id, req: MockRequest, pre_id=None, set_next_
self.share_inputs["prompt_ids"][slot_id, : len(req.prompt_ids)] = paddle.to_tensor(
req.prompt_ids, dtype=paddle.int64
)
+ self.share_inputs["token_ids_all"][slot_id, :] = paddle.to_tensor(
+ np.full((self.max_model_len,), -1, dtype=np.int64), dtype=paddle.int64
+ )
+ self.share_inputs["token_ids_all"][slot_id, : len(req.prompt_ids)] = paddle.to_tensor(
+ req.prompt_ids, dtype=paddle.int64
+ )
self.share_inputs["prompt_lens"][slot_id] = paddle.to_tensor(len(req.prompt_ids), dtype=paddle.int64)
if req.sampling_params.logits_processors_args:
self.share_inputs["logits_processors_args"][slot_id] = req.sampling_params.logits_processors_args
@@ -152,7 +171,7 @@ def _get_initial_logits(self, batch_size):
return logits
def test_thinking_budget_not_reached(self):
- # Scenario: Thinking budget is 5, but only 3 tokens are generated in thinking phase
+ # Scenario: Thinking budget is 5, and prompt-side tokens after do not consume budget.
req_id = "test_req_1"
prompt_ids = [1, 2, THINKING_START_TOKEN_ID, 3, 4, 5]
sampling_params = SamplingParams(logits_processors_args={"thinking_budget": 5})
@@ -167,9 +186,9 @@ def test_thinking_budget_not_reached(self):
processor.update_state(mock_runner.share_inputs)
self.assertTrue(processor._states[req_id].started)
self.assertFalse(processor._states[req_id].ended)
- self.assertEqual(processor._states[req_id].tokens_after_start, 3) # (3, 4, 5) after THINKING_START
+ self.assertEqual(processor._states[req_id].tokens_after_start, 0)
- # Step 2: Simulate one generation step (budget 5, generated 3 -> 4)
+ # Step 2: Simulate one generation step (budget 5, generated 0 -> 1)
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs) # Update state before apply
processed_logits = processor.apply(logits)
@@ -181,9 +200,9 @@ def test_thinking_budget_not_reached(self):
mock_runner.update_request_state(0, mock_req, pre_id=0) # Update last token
processor.update_state(mock_runner.share_inputs) # Update state after generating token
- self.assertEqual(processor._states[req_id].tokens_after_start, 4) # (3, 4, 5, 0)
+ self.assertEqual(processor._states[req_id].tokens_after_start, 1)
- # Step 3: Simulate another generation step (budget 5, generated 4 -> 5)
+ # Step 3: Simulate another generation step (budget 5, generated 1 -> 2)
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs) # Update state before apply
processed_logits = processor.apply(logits)
@@ -191,14 +210,14 @@ def test_thinking_budget_not_reached(self):
mock_runner.update_request_state(0, mock_req, pre_id=0) # Update last token
processor.update_state(mock_runner.share_inputs) # Update state after generating token
- self.assertEqual(processor._states[req_id].tokens_after_start, 5) # (3, 4, 5, 0, 0)
+ self.assertEqual(processor._states[req_id].tokens_after_start, 2)
# LogitsProcessor should still not restrict as NEW_LINE_TOKEN is not yet last token
- def test_thinking_budget_reached_forces_newline(self):
- # Scenario: Budget is 3, after 3 tokens, it should force newline, then thinking_end
+ def test_thinking_budget_reached_forces_think_end(self):
+ # Scenario: Budget is 3 and only decode-time tokens count toward the budget.
req_id = "test_req_2"
- prompt_ids = [1, 2, THINKING_START_TOKEN_ID, 3] # Initial 1 token after start
+ prompt_ids = [1, 2, THINKING_START_TOKEN_ID, 3]
sampling_params = SamplingParams(logits_processors_args={"thinking_budget": 3})
mock_req = MockRequest(req_id, prompt_ids, sampling_params)
@@ -206,16 +225,15 @@ def test_thinking_budget_reached_forces_newline(self):
mock_runner.update_request_state(0, mock_req, pre_id=prompt_ids[-1])
processor = ThinkingBudgetLogitsProcessor(self.fd_config)
- line_break_id = processor.line_break_token_id
think_end_id = processor.think_end_token_id
- # Step 1: Initial state update (1 token after start)
+ # Step 1: Initial state update (prompt-side tokens do not count)
processor.update_state(mock_runner.share_inputs)
- self.assertEqual(processor._states[req_id].tokens_after_start, 1)
+ self.assertEqual(processor._states[req_id].tokens_after_start, 0)
self.assertFalse(processor._states[req_id].ended)
self.assertEqual(processor._states[req_id].last_token_id, 3)
- # Step 2: Generate 2nd token (budget 3, generated 1 -> 2)
+ # Step 2: Generate 1st decode token (budget 3, generated 0 -> 1)
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs)
processed_logits = processor.apply(logits)
@@ -223,10 +241,10 @@ def test_thinking_budget_reached_forces_newline(self):
self.assertEqual(next_token, 0) # Normal generation
mock_runner.update_request_state(0, mock_req, pre_id=next_token)
processor.update_state(mock_runner.share_inputs)
- self.assertEqual(processor._states[req_id].tokens_after_start, 2)
+ self.assertEqual(processor._states[req_id].tokens_after_start, 1)
self.assertEqual(processor._states[req_id].last_token_id, 0)
- # Step 3: Generate 3rd token (budget 3, generated 2 -> 3). Next step should force NEW_LINE.
+ # Step 3: Generate 2nd decode token (budget 3, generated 1 -> 2)
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs)
processed_logits = processor.apply(logits)
@@ -234,34 +252,20 @@ def test_thinking_budget_reached_forces_newline(self):
self.assertEqual(next_token, 0) # Normal generation
mock_runner.update_request_state(0, mock_req, pre_id=next_token)
processor.update_state(mock_runner.share_inputs)
- self.assertEqual(processor._states[req_id].tokens_after_start, 3) # Budget is now met
+ self.assertEqual(processor._states[req_id].tokens_after_start, 2)
self.assertEqual(processor._states[req_id].last_token_id, 0)
- # Step 4: Budget reached, last token not NEW_LINE. Should force NEW_LINE_TOKEN_ID.
+ # Step 4: Generate 3rd decode token (budget 3, generated 2 -> 3).
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs)
processed_logits = processor.apply(logits)
-
- # Verify all other logits are -inf, only NEW_LINE_TOKEN_ID is 0.0
- # Use <= for comparison because paddle.full creates -10.0 and comparison can have precision issues
- other_logits = paddle.concat(
- [
- processed_logits[0, :line_break_id],
- processed_logits[0, line_break_id + 1 : VOCAB_SIZE],
- ],
- axis=0,
- )
- self.assertTrue(paddle.all(other_logits <= -10.0).item() or paddle.all(other_logits == -float("inf")).item())
- self.assertEqual(processed_logits[0, line_break_id].item(), 0.0)
-
next_token = mock_runner.generate_next_token(processed_logits)[0]
- self.assertEqual(next_token, line_break_id) # Forces NEW_LINE
+ self.assertEqual(next_token, 0)
mock_runner.update_request_state(0, mock_req, pre_id=next_token)
processor.update_state(mock_runner.share_inputs)
- self.assertEqual(processor._states[req_id].tokens_after_start, 4) # Still increment
- self.assertEqual(processor._states[req_id].last_token_id, line_break_id)
+ self.assertEqual(processor._states[req_id].tokens_after_start, 3)
- # Step 5: Last token is NEW_LINE. Should force THINKING_END_TOKEN_ID.
+ # Step 5: Budget reached, should force THINKING_END_TOKEN_ID directly.
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs)
processed_logits = processor.apply(logits)
@@ -324,9 +328,9 @@ def test_thinking_budget_stop_sentence_forces_tokens(self):
self.assertEqual(next_token, stop_sentence_token_ids[0])
def test_thinking_budget_no_stop_sentence_defaults(self):
- # Scenario: No stop sentence, budget reached should force newline.
+ # Scenario: No stop sentence, budget reached should force thinking_end directly.
req_id = "test_req_no_stop_sentence"
- prompt_ids = [THINKING_START_TOKEN_ID, 42] # 1 token after start
+ prompt_ids = [THINKING_START_TOKEN_ID, 42]
sampling_params = SamplingParams(logits_processors_args={"thinking_budget": 1})
mock_req = MockRequest(req_id, prompt_ids, sampling_params)
@@ -334,35 +338,42 @@ def test_thinking_budget_no_stop_sentence_defaults(self):
mock_runner.update_request_state(0, mock_req, pre_id=prompt_ids[-1])
processor = ThinkingBudgetLogitsProcessor(self.fd_config)
- line_break_id = processor.line_break_token_id
processor.update_state(mock_runner.share_inputs)
logits = self._get_initial_logits(1)
processed_logits = processor.apply(logits)
next_token = mock_runner.generate_next_token(processed_logits)[0]
- self.assertEqual(next_token, line_break_id)
+ self.assertEqual(next_token, 0)
+ mock_runner.update_request_state(0, mock_req, pre_id=next_token)
+ processor.update_state(mock_runner.share_inputs)
+
+ logits = self._get_initial_logits(1)
+ processor.update_state(mock_runner.share_inputs)
+ processed_logits = processor.apply(logits)
+ next_token = mock_runner.generate_next_token(processed_logits)[0]
+ self.assertEqual(next_token, THINKING_END_TOKEN_ID)
def test_thinking_budget_uses_config_token_ids(self):
# Scenario: Processor should use token ids from model config.
self.fd_config.model_config.think_start_id = 123
self.fd_config.model_config.think_end_id = 124
- self.fd_config.model_config.line_break_id = 125
+ self.fd_config.model_config.line_break_id = -1
processor = ThinkingBudgetLogitsProcessor(self.fd_config)
self.assertEqual(processor.think_start_token_id, 123)
self.assertEqual(processor.think_end_token_id, 124)
- self.assertEqual(processor.line_break_token_id, 125)
+ self.assertEqual(processor.line_break_token_id, -1)
self.assertTrue(processor._enabled)
def test_thinking_budget_disabled_when_token_ids_missing(self):
# Scenario: Processor should be disabled when token ids are not configured.
self.fd_config.model_config.think_start_id = -1
self.fd_config.model_config.think_end_id = -1
- self.fd_config.model_config.line_break_id = -1
+ self.fd_config.model_config.line_break_id = NEW_LINE_TOKEN_ID
processor = ThinkingBudgetLogitsProcessor(self.fd_config)
self.assertFalse(processor._enabled)
self.assertEqual(processor.think_start_token_id, -1)
self.assertEqual(processor.think_end_token_id, -1)
- self.assertEqual(processor.line_break_token_id, -1)
+ self.assertEqual(processor.line_break_token_id, NEW_LINE_TOKEN_ID)
# update_state and apply should be no-op when disabled
mock_runner = MockModelRunner(self.fd_config, max_num_seqs=1)
@@ -493,6 +504,28 @@ def test_thinking_budget_prompt_state_from_args(self):
self.assertEqual(state.tokens_after_start, 2)
self.assertEqual(state.last_token_id, 99)
+ def test_thinking_budget_prompt_state_from_token_ids_all_fallback(self):
+ req_id = "req_gpu_fallback"
+ mock_runner = MockModelRunner(self.fd_config, max_num_seqs=1)
+ mock_runner.share_inputs["req_ids"][0] = req_id
+ mock_runner.share_inputs["logits_processors_args"][0] = {"thinking_budget": 3}
+ mock_runner.share_inputs["prompt_ids"] = None
+ mock_runner.share_inputs["token_ids_all"][0, :4] = paddle.to_tensor(
+ [1, THINKING_START_TOKEN_ID, 2, 3], dtype=paddle.int64
+ )
+ mock_runner.share_inputs["prompt_lens"][0, 0] = paddle.to_tensor(4, dtype=paddle.int64)
+ mock_runner.share_inputs["next_tokens"][0, 0] = paddle.to_tensor(-1, dtype=paddle.int64)
+
+ processor = ThinkingBudgetLogitsProcessor(self.fd_config)
+ processor.update_state(mock_runner.share_inputs)
+
+ state = processor._states[req_id]
+ self.assertTrue(state.prompt_checked)
+ self.assertTrue(state.started)
+ self.assertFalse(state.ended)
+ self.assertEqual(state.tokens_after_start, 0)
+ self.assertEqual(state.last_token_id, 3)
+
def test_thinking_budget_not_configured(self):
# Scenario: Processor is active, but request does not provide thinking_budget
req_id = "test_req_3"
@@ -532,20 +565,7 @@ def test_thinking_budget_zero(self):
self.assertEqual(processor._states[req_id].tokens_after_start, 0) # No tokens after start yet
self.assertEqual(processor._states[req_id].last_token_id, THINKING_START_TOKEN_ID)
- # Step 1: Budget 0 reached, last token is THINKING_START. Should force NEW_LINE.
- logits = self._get_initial_logits(1)
- processor.update_state(mock_runner.share_inputs)
- processed_logits = processor.apply(logits)
-
- self.assertEqual(processed_logits[0, NEW_LINE_TOKEN_ID].item(), 0.0)
- next_token = mock_runner.generate_next_token(processed_logits)[0]
- self.assertEqual(next_token, NEW_LINE_TOKEN_ID)
- mock_runner.update_request_state(0, mock_req, pre_id=next_token)
- processor.update_state(mock_runner.share_inputs)
- self.assertEqual(processor._states[req_id].tokens_after_start, 1) # Still increments
- self.assertEqual(processor._states[req_id].last_token_id, NEW_LINE_TOKEN_ID)
-
- # Step 2: Last token is NEW_LINE. Should force THINKING_END_TOKEN_ID.
+ # Step 1: Budget 0 reached, last token is THINKING_START. Should force THINKING_END.
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs)
processed_logits = processor.apply(logits)
@@ -572,7 +592,7 @@ def test_thinking_end_in_prompt(self):
processor.update_state(mock_runner.share_inputs)
self.assertTrue(processor._states[req_id].started)
self.assertTrue(processor._states[req_id].ended)
- self.assertEqual(processor._states[req_id].tokens_after_start, 2) # Tokens 2, THINKING_END after start
+ self.assertEqual(processor._states[req_id].tokens_after_start, 0)
logits = self._get_initial_logits(1)
processor.update_state(mock_runner.share_inputs)
@@ -582,7 +602,7 @@ def test_thinking_end_in_prompt(self):
def test_multiple_requests(self):
# Scenario: Multiple requests with different thinking states
req_id_1 = "req_a"
- prompt_ids_1 = [THINKING_START_TOKEN_ID, 10, 11] # budget 2, last_token=11, tokens_after_start=2
+ prompt_ids_1 = [THINKING_START_TOKEN_ID, 10, 11]
sampling_params_1 = SamplingParams(logits_processors_args={"thinking_budget": 2})
mock_req_1 = MockRequest(req_id_1, prompt_ids_1, sampling_params_1)
@@ -592,7 +612,7 @@ def test_multiple_requests(self):
mock_req_2 = MockRequest(req_id_2, prompt_ids_2, sampling_params_2)
req_id_3 = "req_c"
- prompt_ids_3 = [THINKING_START_TOKEN_ID, 30] # budget 1, last_token=30, tokens_after_start=1
+ prompt_ids_3 = [THINKING_START_TOKEN_ID, 30]
sampling_params_3 = SamplingParams(logits_processors_args={"thinking_budget": 1})
mock_req_3 = MockRequest(req_id_3, prompt_ids_3, sampling_params_3)
@@ -609,14 +629,14 @@ def test_multiple_requests(self):
# Verify initial states
self.assertTrue(processor._states[req_id_1].started)
self.assertFalse(processor._states[req_id_1].ended)
- self.assertEqual(processor._states[req_id_1].tokens_after_start, 2)
+ self.assertEqual(processor._states[req_id_1].tokens_after_start, 0)
self.assertEqual(processor._states[req_id_1].last_token_id, 11)
self.assertNotIn(req_id_2, processor._states) # No budget specified for req_2
self.assertTrue(processor._states[req_id_3].started)
self.assertFalse(processor._states[req_id_3].ended)
- self.assertEqual(processor._states[req_id_3].tokens_after_start, 1)
+ self.assertEqual(processor._states[req_id_3].tokens_after_start, 0)
self.assertEqual(processor._states[req_id_3].last_token_id, 30)
# Simulate logits for the batch
@@ -624,16 +644,14 @@ def test_multiple_requests(self):
processor.update_state(mock_runner.share_inputs) # Ensure state is updated before apply
processed_batch_logits = processor.apply(batch_logits)
- # Req 1: budget 2, tokens_after_start 2. Should force NEW_LINE (last_token_id is 11, not NEW_LINE)
- self.assertEqual(processed_batch_logits[0, NEW_LINE_TOKEN_ID].item(), 0.0)
- self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), NEW_LINE_TOKEN_ID)
+ # Req 1: prompt-side content does not consume budget, so first step is normal generation.
+ self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), 0)
# Req 2: No thinking budget, normal generation
self.assertEqual(paddle.argmax(processed_batch_logits[1], axis=-1).item(), 0)
- # Req 3: budget 1, tokens_after_start 1. Should force NEW_LINE (last_token_id is 30, not NEW_LINE)
- self.assertEqual(processed_batch_logits[2, NEW_LINE_TOKEN_ID].item(), 0.0)
- self.assertEqual(paddle.argmax(processed_batch_logits[2], axis=-1).item(), NEW_LINE_TOKEN_ID)
+ # Req 3: prompt-side content does not consume budget, so first step is normal generation.
+ self.assertEqual(paddle.argmax(processed_batch_logits[2], axis=-1).item(), 0)
# Simulate generating next tokens and updating state
next_tokens = mock_runner.generate_next_token(processed_batch_logits)
@@ -643,20 +661,25 @@ def test_multiple_requests(self):
processor.update_state(mock_runner.share_inputs)
# Verify updated states for next step
- self.assertEqual(processor._states[req_id_1].last_token_id, NEW_LINE_TOKEN_ID)
- self.assertEqual(processor._states[req_id_3].last_token_id, NEW_LINE_TOKEN_ID)
+ self.assertEqual(processor._states[req_id_1].last_token_id, 0)
+ self.assertEqual(processor._states[req_id_3].last_token_id, 0)
+ self.assertEqual(processor._states[req_id_1].tokens_after_start, 1)
+ self.assertEqual(processor._states[req_id_3].tokens_after_start, 1)
+ self.assertFalse(processor._states[req_id_1].ended)
+ self.assertFalse(processor._states[req_id_3].ended)
batch_logits = self._get_initial_logits(3)
processor.update_state(mock_runner.share_inputs)
processed_batch_logits = processor.apply(batch_logits)
- # Req 1: last token was NEW_LINE. Should force THINKING_END
- self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), THINKING_END_TOKEN_ID)
+ # Req 1: budget 2, tokens_after_start 1. Still normal generation.
+ self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), 0)
# Req 2: Still normal generation
self.assertEqual(paddle.argmax(processed_batch_logits[1], axis=-1).item(), 0)
- # Req 3: last token was NEW_LINE. Should force THINKING_END
+ # Req 3: budget 1 reached after one generated token, should now force THINKING_END.
+ self.assertEqual(processed_batch_logits[2, THINKING_END_TOKEN_ID].item(), 0.0)
self.assertEqual(paddle.argmax(processed_batch_logits[2], axis=-1).item(), THINKING_END_TOKEN_ID)
@@ -724,7 +747,7 @@ def test_update_thinking_prompt_state_from_text_processor(self):
self.assertTrue(updated["think_prompt_checked"])
self.assertTrue(updated["think_prompt_started"])
self.assertTrue(updated["think_prompt_ended"])
- self.assertEqual(updated["think_prompt_tokens_after_start"], 2)
+ self.assertEqual(updated["think_prompt_tokens_after_start"], 0)
self.assertEqual(updated["think_prompt_last_token_id"], 3)
def test_v1_process_request_missing_logits_processors_args(self):
@@ -833,6 +856,58 @@ def _text2ids(text, max_model_len=None, add_special_tokens=False):
self.assertEqual(processor.encode_with_cache("iter"), [41, 42])
self.assertNotIn(("np", False), processor._tokenize_cache)
+ def test_text_encode_with_cache_lazy_init(self):
+ processor = TextDataProcessor.__new__(TextDataProcessor)
+ call_counter = {"count": 0}
+
+ def _text2ids(text, max_model_len=None, add_special_tokens=False):
+ call_counter["count"] += 1
+ return np.array([51, 52], dtype=np.int64)
+
+ processor.text2ids = _text2ids
+
+ self.assertFalse(hasattr(processor, "_tokenize_cache"))
+ self.assertEqual(processor.encode_with_cache("lazy"), [51, 52])
+ self.assertTrue(hasattr(processor, "_tokenize_cache"))
+ self.assertEqual(processor.encode_with_cache("lazy"), [51, 52])
+ self.assertEqual(call_counter["count"], 1)
+
+ def test_v1_encode_with_cache_lazy_init(self):
+ processor = V1TextDataProcessor.__new__(V1TextDataProcessor)
+ call_counter = {"count": 0}
+
+ def _text2ids(text, max_model_len=None, add_special_tokens=False):
+ call_counter["count"] += 1
+ return np.array([61, 62], dtype=np.int64)
+
+ processor.text2ids = _text2ids
+
+ self.assertFalse(hasattr(processor, "_tokenize_cache"))
+ self.assertEqual(processor.encode_with_cache("lazy"), [61, 62])
+ self.assertTrue(hasattr(processor, "_tokenize_cache"))
+ self.assertEqual(processor.encode_with_cache("lazy"), [61, 62])
+ self.assertEqual(call_counter["count"], 1)
+
+ def test_ernie_encode_literal_text_with_cache(self):
+ processor = ErnieTextDataProcessor.__new__(ErnieTextDataProcessor)
+ processor.tokenizer = SimpleNamespace(
+ tokenize=lambda text: ["token_a", "token_b"],
+ convert_tokens_to_ids=lambda tokens: [71, 72],
+ )
+
+ self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [71, 72])
+ self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [71, 72])
+
+ def test_v1_ernie_encode_literal_text_with_cache(self):
+ processor = V1ErnieTextDataProcessor.__new__(V1ErnieTextDataProcessor)
+ processor.tokenizer = SimpleNamespace(
+ tokenize=lambda text: ["token_c", "token_d"],
+ convert_tokens_to_ids=lambda tokens: [81, 82],
+ )
+
+ self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [81, 82])
+ self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [81, 82])
+
def test_text_update_thinking_prompt_state_branches(self):
processor = TextDataProcessor.__new__(TextDataProcessor)
processor._think_token_ids = None
@@ -868,7 +943,7 @@ def test_text_update_thinking_prompt_state_branches(self):
)
self.assertTrue(with_start_no_end["think_prompt_started"])
self.assertFalse(with_start_no_end["think_prompt_ended"])
- self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 2)
+ self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 0)
self.assertEqual(with_start_no_end["think_prompt_last_token_id"], 3)
# 命中 _get_think_token_ids 的缓存分支
@@ -891,7 +966,7 @@ def test_v1_update_thinking_prompt_state_branches(self):
)
self.assertTrue(with_start_no_end["think_prompt_started"])
self.assertFalse(with_start_no_end["think_prompt_ended"])
- self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 2)
+ self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 0)
self.assertEqual(with_start_no_end["think_prompt_last_token_id"], 3)
# 命中 _get_think_token_ids 的缓存分支
@@ -903,7 +978,7 @@ def test_text_process_request_dict_think_stop_sentence(self):
processor.eos_token_ids = [1]
processor.update_stop_seq = lambda *args, **kwargs: None
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
- processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [201, 202]
+ processor._encode_literal_text_with_cache = lambda text: [201, 202]
processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args
processor.reasoning_parser = None
@@ -924,7 +999,7 @@ def test_text_process_request_dict_think_stop_sentence(self):
processed = processor.process_request_dict(request, max_model_len=16)
self.assertEqual(
processed["logits_processors_args"].get("think_stop_sentence_token_ids"),
- [23, 201, 202],
+ [201, 202],
)
self.assertNotIn("think_stop_sentence", processed["logits_processors_args"])
@@ -934,7 +1009,7 @@ def test_v1_process_request_think_stop_sentence(self):
processor.eos_token_ids = [1]
processor.update_stop_seq = lambda *args, **kwargs: None
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
- processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [301, 302]
+ processor._encode_literal_text_with_cache = lambda text: [301, 302]
processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args
processor.reasoning_parser = None
@@ -955,7 +1030,7 @@ def test_v1_process_request_think_stop_sentence(self):
processed = processor.process_request(request, max_model_len=16)
self.assertEqual(
processed.logits_processors_args.get("think_stop_sentence_token_ids"),
- [23, 301, 302],
+ [301, 302],
)
self.assertNotIn("think_stop_sentence", processed.logits_processors_args)
@@ -965,7 +1040,7 @@ def test_v1_process_request_dict_think_stop_sentence(self):
processor.eos_token_ids = [1]
processor.update_stop_seq = lambda *args, **kwargs: None
processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
- processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [401, 402]
+ processor._encode_literal_text_with_cache = lambda text: [401, 402]
processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args
processor.reasoning_parser = None
@@ -992,10 +1067,173 @@ def test_v1_process_request_dict_think_stop_sentence(self):
processed = processor.process_request_dict(request, max_model_len=16)
self.assertEqual(
processed.sampling_params.logits_processors_args.get("think_stop_sentence_token_ids"),
- [23, 401, 402],
+ [401, 402],
)
self.assertNotIn("think_stop_sentence", processed.sampling_params.logits_processors_args)
+ def test_ernie_process_request_dict_prepares_thinking_budget_args(self):
+ processor = ErnieTextDataProcessor.__new__(ErnieTextDataProcessor)
+ processor._apply_default_parameters = lambda request: request
+ processor.eos_token_ids = [1]
+ processor.update_stop_seq = lambda *args, **kwargs: None
+ processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
+ processor._encode_literal_text_with_cache = lambda text: [501, 502]
+ processor.tokenizer = DummyTokenizerForTextProcessor()
+ processor.reasoning_parser = None
+
+ request = {
+ "request_id": "req_ernie_text",
+ "eos_token_ids": [1],
+ "prompt_token_ids": [1, THINKING_START_TOKEN_ID, 2],
+ "prompt": None,
+ "messages": None,
+ "logits_processors_args": {"thinking_budget": 20, "think_stop_sentence": "done"},
+ "bad_words": None,
+ "bad_words_token_ids": None,
+ "max_tokens": 1,
+ "temperature": 1.0,
+ "top_p": 0.9,
+ "response_max_tokens": None,
+ "enable_thinking": True,
+ }
+ with patch("fastdeploy.input.ernie4_5_processor.process_stop_token_ids", lambda *args, **kwargs: None):
+ processed = processor.process_request_dict(request, max_model_len=16)
+
+ self.assertEqual(processed["logits_processors_args"]["think_stop_sentence_token_ids"], [501, 502])
+ self.assertTrue(processed["logits_processors_args"]["think_prompt_started"])
+ self.assertFalse(processed["logits_processors_args"]["think_prompt_ended"])
+ self.assertEqual(processed["logits_processors_args"]["think_prompt_tokens_after_start"], 0)
+
+ def test_v1_ernie_process_request_dict_prepares_thinking_budget_args(self):
+ processor = V1ErnieTextDataProcessor.__new__(V1ErnieTextDataProcessor)
+ processor._apply_default_parameters = lambda request: request
+ processor.eos_token_ids = [1]
+ processor.update_stop_seq = lambda *args, **kwargs: None
+ processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
+ processor._encode_literal_text_with_cache = lambda text: [601, 602]
+ processor.tokenizer = DummyTokenizerForTextProcessor()
+ processor.reasoning_parser = None
+
+ request = DummyRequestV1(
+ request_id="req_v1_ernie_text",
+ eos_token_ids=[1],
+ prompt_token_ids=[1, THINKING_START_TOKEN_ID, 2],
+ prompt=None,
+ messages=None,
+ chat_template_kwargs=None,
+ enable_thinking=True,
+ sampling_params=SimpleNamespace(
+ bad_words=None,
+ bad_words_token_ids=None,
+ max_tokens=1,
+ temperature=1.0,
+ top_p=0.9,
+ repetition_penalty=1.0,
+ frequency_penalty=0.0,
+ presence_penalty=0.0,
+ response_max_tokens=None,
+ n=1,
+ logits_processors_args={"thinking_budget": 20, "think_stop_sentence": "done"},
+ ),
+ )
+ with patch("fastdeploy.input.v1.ernie4_5_processor.process_stop_token_ids", lambda *args, **kwargs: None):
+ processed = processor.process_request_dict(request, max_model_len=16)
+
+ self.assertEqual(processed.sampling_params.logits_processors_args["think_stop_sentence_token_ids"], [601, 602])
+ self.assertTrue(processed.sampling_params.logits_processors_args["think_prompt_started"])
+ self.assertFalse(processed.sampling_params.logits_processors_args["think_prompt_ended"])
+ self.assertEqual(processed.sampling_params.logits_processors_args["think_prompt_tokens_after_start"], 0)
+
+ def test_ernie_vl_process_request_dict_prepares_thinking_budget_args(self):
+ processor = ErnieVLDataProcessor.__new__(ErnieVLDataProcessor)
+ processor._apply_default_parameters = lambda request: request
+ processor.eos_token_ids = [1]
+ processor.update_stop_seq = lambda *args, **kwargs: None
+ processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
+ processor._encode_literal_text_with_cache = lambda text: [701, 702]
+ processor.tokenizer = DummyTokenizerForTextProcessor()
+ processor.reasoning_parser = None
+ processor._check_mm_limits = lambda *args, **kwargs: None
+ processor.append_completion_tokens = lambda *args, **kwargs: None
+ processor.pack_outputs = lambda outs: outs
+ processor.ernie4_5_processor = SimpleNamespace(
+ request2ids=lambda request: {"input_ids": np.array([1, THINKING_START_TOKEN_ID, 2], dtype=np.int64)}
+ )
+
+ request = {
+ "request_id": "req_ernie_vl",
+ "eos_token_ids": [1],
+ "messages": [{"role": "user", "content": "hi"}],
+ "bad_words": None,
+ "bad_words_token_ids": None,
+ "logits_processors_args": {"thinking_budget": 20, "think_stop_sentence": "done"},
+ "max_tokens": 1,
+ "top_p": 0.9,
+ "response_max_tokens": None,
+ }
+ with patch(
+ "fastdeploy.input.ernie4_5_vl_processor.ernie4_5_vl_processor.process_stop_token_ids",
+ lambda *args, **kwargs: None,
+ ):
+ processed = processor.process_request_dict(request, max_model_len=16)
+
+ self.assertEqual(processed["logits_processors_args"]["think_stop_sentence_token_ids"], [701, 702])
+ self.assertTrue(processed["logits_processors_args"]["think_prompt_started"])
+ self.assertFalse(processed["logits_processors_args"]["think_prompt_ended"])
+ self.assertEqual(processed["logits_processors_args"]["think_prompt_tokens_after_start"], 0)
+
+ def test_v1_ernie_vl_process_request_dict_prepares_thinking_budget_args(self):
+ processor = V1ErnieVLDataProcessor.__new__(V1ErnieVLDataProcessor)
+ processor._apply_default_parameters = lambda request: request
+ processor.eos_token_ids = [1]
+ processor.update_stop_seq = lambda *args, **kwargs: None
+ processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids
+ processor._encode_literal_text_with_cache = lambda text: [801, 802]
+ processor.tokenizer = DummyTokenizerForTextProcessor()
+ processor.reasoning_parser = None
+ processor._check_mm_limits = lambda *args, **kwargs: None
+ processor.append_completion_tokens = lambda *args, **kwargs: None
+ processor.pack_outputs = lambda outs: outs
+ processor.ernie4_5_processor = SimpleNamespace(
+ request2ids=lambda request: {"input_ids": np.array([1, THINKING_START_TOKEN_ID, 2], dtype=np.int64)}
+ )
+
+ request = DummyRequestV1(
+ request_id="req_v1_ernie_vl",
+ eos_token_ids=[1],
+ prompt_token_ids=None,
+ prompt=None,
+ messages=[{"role": "user", "content": "hi"}],
+ chat_template_kwargs=None,
+ enable_thinking=True,
+ completion_token_ids=None,
+ multimodal_data=None,
+ sampling_params=SimpleNamespace(
+ bad_words=None,
+ bad_words_token_ids=None,
+ max_tokens=1,
+ temperature=1.0,
+ top_p=0.9,
+ repetition_penalty=1.0,
+ frequency_penalty=0.0,
+ presence_penalty=0.0,
+ response_max_tokens=None,
+ reasoning_max_tokens=None,
+ n=1,
+ logits_processors_args={"thinking_budget": 20, "think_stop_sentence": "done"},
+ ),
+ )
+ with patch(
+ "fastdeploy.input.v1.ernie4_5_vl_processor.ernie4_5_vl_processor.process_stop_token_ids",
+ lambda *args, **kwargs: None,
+ ):
+ processed = processor.process_request_dict(request, max_model_len=16)
+
+ self.assertEqual(processed.sampling_params.logits_processors_args["think_stop_sentence_token_ids"], [801, 802])
+ self.assertTrue(processed.sampling_params.logits_processors_args["think_prompt_started"])
+ self.assertFalse(processed.sampling_params.logits_processors_args["think_prompt_ended"])
+ self.assertEqual(processed.sampling_params.logits_processors_args["think_prompt_tokens_after_start"], 0)
+
if __name__ == "__main__":
unittest.main()