From be0064c24a8af10714226a633a7cd18716066b30 Mon Sep 17 00:00:00 2001 From: jackyYang6 Date: Thu, 19 Mar 2026 15:21:46 +0800 Subject: [PATCH 1/3] [Bugfix] Align thinking_budget behavior with ERNIE reasoning flow --- docs/features/thinking_budget.md | 40 +- docs/zh/features/thinking_budget.md | 31 +- fastdeploy/input/ernie4_5_processor.py | 18 + .../ernie4_5_vl_processor.py | 9 + fastdeploy/input/text_processor.py | 176 ++++---- fastdeploy/input/v1/ernie4_5_processor.py | 21 +- .../ernie4_5_vl_processor.py | 9 + fastdeploy/input/v1/text_processor.py | 176 ++++---- .../logits_processor/thinking_budget.py | 73 ++-- tests/model_executor/test_thinking_budget.py | 404 ++++++++++++++---- 10 files changed, 670 insertions(+), 287 deletions(-) diff --git a/docs/features/thinking_budget.md b/docs/features/thinking_budget.md index 2dbaf30e367..79581497031 100644 --- a/docs/features/thinking_budget.md +++ b/docs/features/thinking_budget.md @@ -2,7 +2,9 @@ ## Overview -`ThinkingBudgetLogitsProcessor` limits the number of tokens generated inside the ` ... ` segment. When the budget is reached, it forces a line break token and then the `` token to terminate the thinking section. +`ThinkingBudgetLogitsProcessor` limits the number of tokens generated inside the ` ... ` +segment. When the budget is reached, it terminates thinking by forcing ``. If +`think_stop_sentence` is configured, it forces the custom sentence first and then ``. ## When to Use @@ -11,19 +13,22 @@ ## How It Works -1. **CPU precompute (DataProcessor)**: when a request includes `thinking_budget`, the prompt token ids are scanned to determine whether thinking has started, whether it already ended, and how many tokens are already inside the thinking section. +1. **Request-side precompute (DataProcessor)**: when a request includes `thinking_budget`, the prompt token ids are scanned to determine whether thinking has started, whether it already ended, and how many tokens are already inside the thinking section. 2. 
**Per-step update**: during decoding, the processor tracks `last_token_id` and `tokens_after_start`. -3. **Budget enforcement**: once the budget is reached, it forces a line break and then the thinking end token. +3. **Budget enforcement**: once the budget is reached, it forces `` directly. If `think_stop_sentence` + is configured, it forces that sentence first and then ``. ## Requirements -- The model must provide valid token ids for `think_start_id`, `think_end_id`, and `line_break_id` (via `ModelConfig`). -- If any of these ids are invalid, the processor is disabled and `thinking_budget` will not take effect. +- The model must provide valid token ids for `think_start_id` and `think_end_id` (via `ModelConfig`). +- If either of these ids is invalid, the processor is disabled and `thinking_budget` will not take effect. ## Request Parameters -- `thinking_budget` (int, required to enable): maximum number of tokens after `` before forced termination. -- `think_stop_sentence` (string, optional): a stop sentence that will be tokenized on the CPU side and enforced near the budget boundary. +- `thinking_budget` (int, required to enable): maximum number of decode-time tokens after `` before forced + termination. +- `think_stop_sentence` (string, optional): a literal custom sentence that will be tokenized on the request side + and enforced near the budget boundary. ## Operator-Level vs LogitsProcessor @@ -41,16 +46,25 @@ FastDeploy has two ways to limit thinking length: In short: - If you only need a hard cap on thinking length, prefer `reasoning_max_tokens`. -- If you need custom behavior (for example, injecting custom sentence tokens), use `ThinkingBudgetLogitsProcessor`. +- If you need custom behavior (for example, inserting a custom sentence before ``), use + `ThinkingBudgetLogitsProcessor`. ## Practical guidance `reasoning_max_tokens` and `thinking_budget` are not mutually exclusive in current implementation. 
If both are configured for the same request, both constraints can take effect, and whichever triggers first will end the thinking phase. -- To use **operator-level-only** behavior: this is request-level config only. Set `enable_thinking=true` and `reasoning_max_tokens` in request, and do not set `thinking_budget`. -- To use **logits-processor-only** behavior (especially with `think_stop_sentence`): this requires service-level + request-level config. Start service with `--logits-processors ThinkingBudgetLogitsProcessor`, and set `thinking_budget` (and optional `think_stop_sentence`) in `logits_processors_args`; leave `reasoning_max_tokens` unset. -- Avoid enabling both for strict custom sentence insertion requirements, because operator-level termination may cut the custom sentence path earlier. +- To use **operator-level-only** behavior: this is request-level config only. Set + `enable_thinking=true` and `reasoning_max_tokens` in request, and do not set `thinking_budget`. +- To use **logits-processor-only** behavior (especially with `think_stop_sentence`): this requires + service-level + request-level config. Start service with `--logits-processors ThinkingBudgetLogitsProcessor`, + and set `thinking_budget` (and optional `think_stop_sentence`) in `logits_processors_args`; leave + `reasoning_max_tokens` unset. +- `thinking_budget` itself does not require `enable_thinking=true`. +- If an ERNIE chat template already appends `` in the prompt, `thinking_budget` should still take effect; it + does not require the model to emit another `` during decoding. +- Avoid enabling both for strict custom sentence insertion requirements, because operator-level + termination may cut the custom sentence path earlier. ## Online Usage @@ -120,4 +134,6 @@ print(outputs[0].outputs.text) ## Performance Note -This processor runs `update_state` and `apply` on every decode step. 
If you only need a hard thinking-length cap and care most about throughput, consider the operator-level reasoning-length controls instead of per-step logits processing. +This processor runs `update_state` and `apply` on every decode step. If you only need a hard +thinking-length cap and care most about throughput, consider the operator-level reasoning-length +controls instead of per-step logits processing. diff --git a/docs/zh/features/thinking_budget.md b/docs/zh/features/thinking_budget.md index bf614a24c0b..70e58d0c789 100644 --- a/docs/zh/features/thinking_budget.md +++ b/docs/zh/features/thinking_budget.md @@ -2,7 +2,9 @@ ## 概述 -`ThinkingBudgetLogitsProcessor` 用于限制 ` ... ` 区间的生成长度。当预算达到阈值时,会强制生成换行符 token,再强制生成 ``,从而结束思考段。 +`ThinkingBudgetLogitsProcessor` 用于限制 ` ... ` 区间的生成长度。当预算达到阈值时, +会直接强制生成 `` 来结束思考段;如果配置了 `think_stop_sentence`,则会先强制输出该自定义 +文案,再输出 ``。 ## 适用场景 @@ -11,19 +13,20 @@ ## 工作原理 -1. **CPU 侧预计算(DataProcessor)**:当请求中包含 `thinking_budget`,会基于 prompt 的 token ids 计算是否已进入思考段、是否已结束,以及已有的思考长度。 +1. **请求侧预计算(DataProcessor)**:当请求中包含 `thinking_budget`,会基于 prompt 的 token ids 计算是否已进入思考段、是否已结束,以及已有的思考长度。 2. **每步更新**:解码过程中跟踪 `last_token_id` 与 `tokens_after_start`。 -3. **预算约束**:达到预算后,依次强制换行符与思考结束 token。 +3. 
**预算约束**:达到预算后,默认直接强制 ``;如果配置了 `think_stop_sentence`,则先逐 token + 强制输出该文案,再输出 ``。 ## 前置要求 -- 模型需提供有效的 `think_start_id`、`think_end_id`、`line_break_id`(来自 `ModelConfig`)。 -- 若任意 id 无效,处理器会禁用,`thinking_budget` 不生效。 +- 模型需提供有效的 `think_start_id`、`think_end_id`(来自 `ModelConfig`)。 +- 若其中任意 id 无效,处理器会禁用,`thinking_budget` 不生效。 ## 请求参数 -- `thinking_budget`(int,启用所需):`` 之后允许的最大 token 数。 -- `think_stop_sentence`(string,可选):CPU 侧会将该字符串编码为 token ids,并在预算边界附近强制输出。 +- `thinking_budget`(int,启用所需):`` 之后允许的最大 decode 阶段 token 数。 +- `think_stop_sentence`(string,可选):按字面串编码的自定义终止文案,并在预算边界附近强制输出。 ## 算子级限制 vs LogitsProcessor @@ -35,21 +38,27 @@ FastDeploy 当前有两种思考长度控制方式: - 适合“只限制思考长度”的简单场景。 - **`ThinkingBudgetLogitsProcessor`**(`logits_processors_args.thinking_budget`): - 由每步 Python 侧 logits 处理实现。 - - 支持更灵活的行为,例如 `think_stop_sentence`(在结束前插入自定义话术)。 + - 支持更灵活的行为,例如 `think_stop_sentence`。 - 相比算子级限制,在高并发下通常有更高开销。 可按以下原则选择: - 仅需限制思考长度:优先用 `reasoning_max_tokens`。 -- 需要更灵活控制(如插入自定义话术):使用 `ThinkingBudgetLogitsProcessor`。 +- 需要更灵活控制(如在 `` 前插入自定义话术):使用 `ThinkingBudgetLogitsProcessor`。 ## 建议实践 当前实现中,`reasoning_max_tokens` 与 `thinking_budget` 不是互斥关系。 同一请求如果同时配置,两套约束都可能生效,谁先触发就先结束思考段。 -- **只用算子级限制**:这是请求级配置。仅在请求中设置 `enable_thinking=true` + `reasoning_max_tokens`,不要传 `thinking_budget`。 -- **只用 LogitsProcessor**(尤其要用 `think_stop_sentence`):这是“服务启动 + 请求参数”两级配置。服务启动时必须加 `--logits-processors ThinkingBudgetLogitsProcessor`,并在请求里通过 `logits_processors_args` 传 `thinking_budget`(以及可选的 `think_stop_sentence`);同时不要设置 `reasoning_max_tokens`。 +- **只用算子级限制**:这是请求级配置。仅在请求中设置 `enable_thinking=true` + `reasoning_max_tokens`, + 不要传 `thinking_budget`。 +- **只用 LogitsProcessor**(尤其要用 `think_stop_sentence`):这是“服务启动 + 请求参数”两级配置。 + 服务启动时必须加 `--logits-processors ThinkingBudgetLogitsProcessor`,并在请求里通过 + `logits_processors_args` 传 `thinking_budget`(以及可选的 `think_stop_sentence`);同时不要设置 + `reasoning_max_tokens`。 +- `thinking_budget` 本身不依赖 `enable_thinking=true`。 +- 如果 ERNIE 的 chat template 已经在 prompt 里拼入 ``,`thinking_budget` 
也应正常生效,不要求模型在 decode 阶段再次输出 ``。 - 如果业务要求“必须完整插入自定义话术”,不建议与算子级限制同时开启,否则可能被算子级提前截断。 ## 在线使用 diff --git a/fastdeploy/input/ernie4_5_processor.py b/fastdeploy/input/ernie4_5_processor.py index c89b2ae2fb9..26ca88ef482 100644 --- a/fastdeploy/input/ernie4_5_processor.py +++ b/fastdeploy/input/ernie4_5_processor.py @@ -107,6 +107,11 @@ def process_request(self, request, max_model_len=None, **kwargs): bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids) request["bad_words_token_ids"] = bad_words_token_ids + logits_processors_args = self._prepare_think_stop_sentence( + request.get("logits_processors_args") or {}, max_model_len + ) + request["logits_processors_args"] = logits_processors_args + # processing prompt_token_ids if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0: if request.prompt is not None: @@ -137,6 +142,10 @@ def process_request(self, request, max_model_len=None, **kwargs): # truncate prompts that exceed the length limit if max_model_len is not None and len(request.prompt_token_ids) > max_model_len: request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1] + logits_processors_args = self._update_thinking_prompt_state( + request.prompt_token_ids, request.get("logits_processors_args") or {} + ) + request["logits_processors_args"] = logits_processors_args max_tokens = max_model_len - len(request.prompt_token_ids) if request.get("max_tokens") is None: request.set("max_tokens", max(1, max_tokens)) @@ -189,6 +198,11 @@ def process_request_dict(self, request, max_model_len=None): bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids) request["bad_words_token_ids"] = bad_words_token_ids + logits_processors_args = self._prepare_think_stop_sentence( + request.get("logits_processors_args") or {}, max_model_len + ) + request["logits_processors_args"] = logits_processors_args + # processing prompt_token_ids if not request.get("prompt_token_ids"): if request.get("prompt"): @@ -225,6 +239,10 
@@ def process_request_dict(self, request, max_model_len=None): # truncate prompts that exceed the length limit if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len: request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1] + logits_processors_args = self._update_thinking_prompt_state( + request["prompt_token_ids"], request.get("logits_processors_args") or {} + ) + request["logits_processors_args"] = logits_processors_args max_tokens = max_model_len - len(request["prompt_token_ids"]) if request.get("max_tokens") is None: request["max_tokens"] = max(1, max_tokens) diff --git a/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py b/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py index 318cf3c41fb..4a2b73c4f2c 100644 --- a/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py +++ b/fastdeploy/input/ernie4_5_vl_processor/ernie4_5_vl_processor.py @@ -221,6 +221,11 @@ def process_request_dict(self, request, max_model_len=None): bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids) request["bad_words_token_ids"] = bad_words_token_ids + logits_processors_args = self._prepare_think_stop_sentence( + request.get("logits_processors_args") or {}, max_model_len + ) + request["logits_processors_args"] = logits_processors_args + if request.get("prompt_token_ids"): messages = request.get("messages") if messages: @@ -268,6 +273,10 @@ def process_request_dict(self, request, max_model_len=None): # 截断超过长度限制的prompt if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len: request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1] + logits_processors_args = self._update_thinking_prompt_state( + request["prompt_token_ids"], request.get("logits_processors_args") or {} + ) + request["logits_processors_args"] = logits_processors_args max_tokens = max_model_len - len(request["prompt_token_ids"]) if request.get("max_tokens") is None: diff 
--git a/fastdeploy/input/text_processor.py b/fastdeploy/input/text_processor.py index 05d24697c90..43c01516aa5 100644 --- a/fastdeploy/input/text_processor.py +++ b/fastdeploy/input/text_processor.py @@ -120,6 +120,9 @@ def encode_with_cache(self, text, max_model_len=None, add_special_tokens=False): """ Encode text into token ids with a small LRU cache. """ + if not hasattr(self, "_tokenize_cache"): + self._tokenize_cache = OrderedDict() + self._tokenize_cache_capacity = getattr(self, "_tokenize_cache_capacity", 128) key = (text, bool(add_special_tokens)) cached = self._tokenize_cache.get(key) if cached is not None: @@ -135,6 +138,25 @@ def encode_with_cache(self, text, max_model_len=None, add_special_tokens=False): self._tokenize_cache.popitem(last=False) return token_ids + def _encode_literal_text_with_cache(self, text): + if not hasattr(self, "_tokenize_cache"): + self._tokenize_cache = OrderedDict() + self._tokenize_cache_capacity = getattr(self, "_tokenize_cache_capacity", 128) + key = ("literal_text", text) + cached = self._tokenize_cache.get(key) + if cached is not None: + self._tokenize_cache.move_to_end(key) + return cached + token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text)) + if hasattr(token_ids, "tolist"): + token_ids = token_ids.tolist() + elif not isinstance(token_ids, list): + token_ids = list(token_ids) + self._tokenize_cache[key] = token_ids + if len(self._tokenize_cache) > self._tokenize_cache_capacity: + self._tokenize_cache.popitem(last=False) + return token_ids + def messages2ids(self, messages): """ Convert multi-turn messages into ID sequences. 
@@ -147,6 +169,77 @@ def messages2ids(self, messages): """ raise NotImplementedError + def _get_think_token_ids(self): + think_token_ids = getattr(self, "_think_token_ids", None) + if think_token_ids is not None: + return think_token_ids + tokenizer = getattr(self, "tokenizer", None) + vocab = tokenizer.get_vocab() if tokenizer is not None else {} + think_start_id = vocab.get("", -1) + think_end_id = vocab.get("", -1) + self._think_token_ids = (think_start_id, think_end_id) + return self._think_token_ids + + def _prepare_think_stop_sentence(self, logits_processors_args, max_model_len=None): + if not isinstance(logits_processors_args, dict): + return logits_processors_args + think_stop_sentence = logits_processors_args.get("think_stop_sentence") + if isinstance(think_stop_sentence, str) and think_stop_sentence: + sentence_token_ids = self._encode_literal_text_with_cache(think_stop_sentence) + logits_processors_args["think_stop_sentence_token_ids"] = sentence_token_ids + logits_processors_args.pop("think_stop_sentence", None) + return logits_processors_args + + def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args): + if not isinstance(logits_processors_args, dict): + return logits_processors_args + thinking_budget = logits_processors_args.get("thinking_budget") + if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0: + return logits_processors_args + if logits_processors_args.get("think_prompt_checked"): + return logits_processors_args + if prompt_token_ids is None: + return logits_processors_args + token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids) + if token_len == 0: + return logits_processors_args + think_start_id, think_end_id = self._get_think_token_ids() + if think_start_id < 0 or think_end_id < 0: + return logits_processors_args + + if hasattr(prompt_token_ids, "tolist"): + token_list = prompt_token_ids.tolist() + else: + token_list = list(prompt_token_ids) + + started = 
False + ended = False + tokens_after_start = 0 + last_token_id = None + in_thinking = False + for token_id in token_list: + if token_id == think_start_id: + started = True + ended = False + in_thinking = True + elif token_id == think_end_id and in_thinking: + ended = True + in_thinking = False + if started and token_list: + # Align with operator-level reasoning_max_tokens: prompt-side tokens + # inside do not consume thinking budget. + last_token_id = int(token_list[-1]) + + logits_processors_args["think_prompt_checked"] = True + logits_processors_args["think_prompt_started"] = started + logits_processors_args["think_prompt_ended"] = ended + logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start + if last_token_id is not None: + logits_processors_args["think_prompt_last_token_id"] = last_token_id + else: + logits_processors_args.pop("think_prompt_last_token_id", None) + return logits_processors_args + def ids2tokens(self, token_id, task_id=None): """ token ids to strings @@ -235,65 +328,6 @@ def __init__(self, model_name_or_path, reasoning_parser_obj=None, tool_parser_ob self.reasoning_parser = reasoning_parser_obj(self.tokenizer) self.tokenizer.pad_token_id = self.pad_token_id - self._think_token_ids = None - - def _get_think_token_ids(self): - if self._think_token_ids is not None: - return self._think_token_ids - vocab = self.tokenizer.get_vocab() - think_start_id = vocab.get("", -1) - think_end_id = vocab.get("", -1) - self._think_token_ids = (think_start_id, think_end_id) - return self._think_token_ids - - def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args): - if not isinstance(logits_processors_args, dict): - return logits_processors_args - thinking_budget = logits_processors_args.get("thinking_budget") - if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0: - return logits_processors_args - if logits_processors_args.get("think_prompt_checked"): - return 
logits_processors_args - if prompt_token_ids is None: - return logits_processors_args - token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids) - if token_len == 0: - return logits_processors_args - think_start_id, think_end_id = self._get_think_token_ids() - if think_start_id < 0 or think_end_id < 0: - return logits_processors_args - - if hasattr(prompt_token_ids, "tolist"): - token_list = prompt_token_ids.tolist() - else: - token_list = list(prompt_token_ids) - - started = think_start_id in token_list - ended = False - tokens_after_start = 0 - last_token_id = None - if started: - start_pos = token_list.index(think_start_id) - tokens_after = token_list[start_pos + 1 :] - if think_end_id in tokens_after: - end_pos = tokens_after.index(think_end_id) - tokens_after_start = end_pos + 1 - ended = True - else: - tokens_after_start = len(tokens_after) - if token_list: - last_token_id = int(token_list[-1]) - - logits_processors_args["think_prompt_checked"] = True - logits_processors_args["think_prompt_started"] = started - logits_processors_args["think_prompt_ended"] = ended - logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start - if last_token_id is not None: - logits_processors_args["think_prompt_last_token_id"] = last_token_id - else: - logits_processors_args.pop("think_prompt_last_token_id", None) - return logits_processors_args - def process_request(self, request, max_model_len=None, **kwargs): """ Preprocess the request @@ -320,14 +354,10 @@ def process_request(self, request, max_model_len=None, **kwargs): bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids) request["bad_words_token_ids"] = bad_words_token_ids - logits_processors_args = request.get("logits_processors_args") or {} - think_stop_sentence = logits_processors_args.get("think_stop_sentence") - if isinstance(think_stop_sentence, str) and think_stop_sentence: - newline_token_ids = self.encode_with_cache("\n", max_model_len, 
add_special_tokens=False) - sentence_token_ids = self.encode_with_cache(think_stop_sentence, max_model_len, add_special_tokens=False) - logits_processors_args["think_stop_sentence_token_ids"] = newline_token_ids + sentence_token_ids - logits_processors_args.pop("think_stop_sentence", None) - request["logits_processors_args"] = logits_processors_args + logits_processors_args = self._prepare_think_stop_sentence( + request.get("logits_processors_args") or {}, max_model_len + ) + request["logits_processors_args"] = logits_processors_args # processing prompt_token_ids if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0: @@ -425,14 +455,10 @@ def process_request_dict(self, request, max_model_len=None, **kwargs): bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids) request["bad_words_token_ids"] = bad_words_token_ids - logits_processors_args = request.get("logits_processors_args") or {} - think_stop_sentence = logits_processors_args.get("think_stop_sentence") - if isinstance(think_stop_sentence, str) and think_stop_sentence: - newline_token_ids = self.encode_with_cache("\n", max_model_len, add_special_tokens=False) - sentence_token_ids = self.encode_with_cache(think_stop_sentence, max_model_len, add_special_tokens=False) - logits_processors_args["think_stop_sentence_token_ids"] = newline_token_ids + sentence_token_ids - logits_processors_args.pop("think_stop_sentence", None) - request["logits_processors_args"] = logits_processors_args + logits_processors_args = self._prepare_think_stop_sentence( + request.get("logits_processors_args") or {}, max_model_len + ) + request["logits_processors_args"] = logits_processors_args # processing prompt_token_ids if not request.get("prompt_token_ids"): diff --git a/fastdeploy/input/v1/ernie4_5_processor.py b/fastdeploy/input/v1/ernie4_5_processor.py index 3b4aee7588a..6b99ce37be0 100644 --- a/fastdeploy/input/v1/ernie4_5_processor.py +++ b/fastdeploy/input/v1/ernie4_5_processor.py @@ -107,6 
+107,11 @@ def process_request(self, request, max_model_len=None, **kwargs): bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids) request["bad_words_token_ids"] = bad_words_token_ids + logits_processors_args = self._prepare_think_stop_sentence( + request.get("logits_processors_args") or {}, max_model_len + ) + request["logits_processors_args"] = logits_processors_args + # processing prompt_token_ids if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0: if request.prompt is not None: @@ -123,7 +128,8 @@ def process_request(self, request, max_model_len=None, **kwargs): token_ids = self.tokenizer.convert_tokens_to_ids(tokens) request.prompt_token_ids = token_ids data_processor_logger.debug( - f"request_ids: {request.request_id}, prompt: {prompt}, tokens: {tokens}, token_ids: {token_ids}" + f"request_ids: {request.request_id}, prompt: {prompt}, " + f"tokens: {tokens}, token_ids: {token_ids}" ) elif request.messages is not None: task = request.to_dict() @@ -145,6 +151,10 @@ def process_request(self, request, max_model_len=None, **kwargs): # truncate prompts that exceed the length limit if max_model_len is not None and len(request.prompt_token_ids) > max_model_len: request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1] + logits_processors_args = self._update_thinking_prompt_state( + request.prompt_token_ids, request.get("logits_processors_args") or {} + ) + request["logits_processors_args"] = logits_processors_args max_tokens = max_model_len - len(request.prompt_token_ids) if request.get("max_tokens") is None: request.set("max_tokens", max(1, max_tokens)) @@ -199,6 +209,11 @@ def process_request_dict(self, request, max_model_len=None, **kwargs): bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids) request.sampling_params.bad_words_token_ids = bad_words_token_ids + logits_processors_args = self._prepare_think_stop_sentence( + getattr(request.sampling_params, "logits_processors_args", 
None) or {}, max_model_len + ) + request.sampling_params.logits_processors_args = logits_processors_args + # processing prompt_token_ids if not request.prompt_token_ids: if request.prompt: @@ -239,6 +254,10 @@ def process_request_dict(self, request, max_model_len=None, **kwargs): # truncate prompts that exceed the length limit if max_model_len is not None and len(request.prompt_token_ids) > max_model_len: request.prompt_token_ids = request.prompt_token_ids[: max_model_len - 1] + logits_processors_args = self._update_thinking_prompt_state( + request.prompt_token_ids, getattr(request.sampling_params, "logits_processors_args", None) or {} + ) + request.sampling_params.logits_processors_args = logits_processors_args max_tokens = max_model_len - len(request.prompt_token_ids) if getattr(request.sampling_params, "max_tokens", None) is None: request.sampling_params.max_tokens = max(1, max_tokens) diff --git a/fastdeploy/input/v1/ernie4_5_vl_processor/ernie4_5_vl_processor.py b/fastdeploy/input/v1/ernie4_5_vl_processor/ernie4_5_vl_processor.py index 3fc704b4b1e..f1499e3d505 100644 --- a/fastdeploy/input/v1/ernie4_5_vl_processor/ernie4_5_vl_processor.py +++ b/fastdeploy/input/v1/ernie4_5_vl_processor/ernie4_5_vl_processor.py @@ -216,6 +216,11 @@ def process_request_dict(self, request, max_model_len=None, **kwargs): bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids) request.sampling_params.bad_words_token_ids = bad_words_token_ids + logits_processors_args = self._prepare_think_stop_sentence( + getattr(request.sampling_params, "logits_processors_args", None) or {}, max_model_len + ) + request.sampling_params.logits_processors_args = logits_processors_args + if request.prompt_token_ids: messages = request.messages if messages: @@ -267,6 +272,10 @@ def process_request_dict(self, request, max_model_len=None, **kwargs): # 截断超过长度限制的prompt if max_model_len is not None and len(request.prompt_token_ids) > max_model_len: request.prompt_token_ids = 
request.prompt_token_ids[: max_model_len - 1] + logits_processors_args = self._update_thinking_prompt_state( + request.prompt_token_ids, getattr(request.sampling_params, "logits_processors_args", None) or {} + ) + request.sampling_params.logits_processors_args = logits_processors_args max_tokens = max_model_len - len(request.prompt_token_ids) if getattr(request.sampling_params, "max_tokens", None) is None: diff --git a/fastdeploy/input/v1/text_processor.py b/fastdeploy/input/v1/text_processor.py index 4b164a12861..93eca6bff97 100644 --- a/fastdeploy/input/v1/text_processor.py +++ b/fastdeploy/input/v1/text_processor.py @@ -115,6 +115,9 @@ def encode_with_cache(self, text, max_model_len=None, add_special_tokens=False): """ Encode text into token ids with a small LRU cache. """ + if not hasattr(self, "_tokenize_cache"): + self._tokenize_cache = OrderedDict() + self._tokenize_cache_capacity = getattr(self, "_tokenize_cache_capacity", 128) key = (text, bool(add_special_tokens)) cached = self._tokenize_cache.get(key) if cached is not None: @@ -130,6 +133,25 @@ def encode_with_cache(self, text, max_model_len=None, add_special_tokens=False): self._tokenize_cache.popitem(last=False) return token_ids + def _encode_literal_text_with_cache(self, text): + if not hasattr(self, "_tokenize_cache"): + self._tokenize_cache = OrderedDict() + self._tokenize_cache_capacity = getattr(self, "_tokenize_cache_capacity", 128) + key = ("literal_text", text) + cached = self._tokenize_cache.get(key) + if cached is not None: + self._tokenize_cache.move_to_end(key) + return cached + token_ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text)) + if hasattr(token_ids, "tolist"): + token_ids = token_ids.tolist() + elif not isinstance(token_ids, list): + token_ids = list(token_ids) + self._tokenize_cache[key] = token_ids + if len(self._tokenize_cache) > self._tokenize_cache_capacity: + self._tokenize_cache.popitem(last=False) + return token_ids + def messages2ids(self, messages): 
""" Convert multi-turn messages into ID sequences. @@ -142,6 +164,77 @@ def messages2ids(self, messages): """ raise NotImplementedError + def _get_think_token_ids(self): + think_token_ids = getattr(self, "_think_token_ids", None) + if think_token_ids is not None: + return think_token_ids + tokenizer = getattr(self, "tokenizer", None) + vocab = tokenizer.get_vocab() if tokenizer is not None else {} + think_start_id = vocab.get("", -1) + think_end_id = vocab.get("", -1) + self._think_token_ids = (think_start_id, think_end_id) + return self._think_token_ids + + def _prepare_think_stop_sentence(self, logits_processors_args, max_model_len=None): + if not isinstance(logits_processors_args, dict): + return logits_processors_args + think_stop_sentence = logits_processors_args.get("think_stop_sentence") + if isinstance(think_stop_sentence, str) and think_stop_sentence: + sentence_token_ids = self._encode_literal_text_with_cache(think_stop_sentence) + logits_processors_args["think_stop_sentence_token_ids"] = sentence_token_ids + logits_processors_args.pop("think_stop_sentence", None) + return logits_processors_args + + def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args): + if not isinstance(logits_processors_args, dict): + return logits_processors_args + thinking_budget = logits_processors_args.get("thinking_budget") + if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0: + return logits_processors_args + if logits_processors_args.get("think_prompt_checked"): + return logits_processors_args + if prompt_token_ids is None: + return logits_processors_args + token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids) + if token_len == 0: + return logits_processors_args + think_start_id, think_end_id = self._get_think_token_ids() + if think_start_id < 0 or think_end_id < 0: + return logits_processors_args + + if hasattr(prompt_token_ids, "tolist"): + token_list = prompt_token_ids.tolist() + else: 
+ token_list = list(prompt_token_ids) + + started = False + ended = False + tokens_after_start = 0 + last_token_id = None + in_thinking = False + for token_id in token_list: + if token_id == think_start_id: + started = True + ended = False + in_thinking = True + elif token_id == think_end_id and in_thinking: + ended = True + in_thinking = False + if started and token_list: + # Align with operator-level reasoning_max_tokens: prompt-side tokens + # inside do not consume thinking budget. + last_token_id = int(token_list[-1]) + + logits_processors_args["think_prompt_checked"] = True + logits_processors_args["think_prompt_started"] = started + logits_processors_args["think_prompt_ended"] = ended + logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start + if last_token_id is not None: + logits_processors_args["think_prompt_last_token_id"] = last_token_id + else: + logits_processors_args.pop("think_prompt_last_token_id", None) + return logits_processors_args + def ids2tokens(self, token_id, task_id=None): """ token ids to strings @@ -221,65 +314,6 @@ def __init__(self, model_name_or_path, reasoning_parser_obj=None, tool_parser_ob self.reasoning_parser = reasoning_parser_obj(self.tokenizer) self.tokenizer.pad_token_id = self.pad_token_id - self._think_token_ids = None - - def _get_think_token_ids(self): - if self._think_token_ids is not None: - return self._think_token_ids - vocab = self.tokenizer.get_vocab() - think_start_id = vocab.get("", -1) - think_end_id = vocab.get("", -1) - self._think_token_ids = (think_start_id, think_end_id) - return self._think_token_ids - - def _update_thinking_prompt_state(self, prompt_token_ids, logits_processors_args): - if not isinstance(logits_processors_args, dict): - return logits_processors_args - thinking_budget = logits_processors_args.get("thinking_budget") - if thinking_budget is None or not isinstance(thinking_budget, int) or thinking_budget < 0: - return logits_processors_args - if 
logits_processors_args.get("think_prompt_checked"): - return logits_processors_args - if prompt_token_ids is None: - return logits_processors_args - token_len = getattr(prompt_token_ids, "size", None) or len(prompt_token_ids) - if token_len == 0: - return logits_processors_args - think_start_id, think_end_id = self._get_think_token_ids() - if think_start_id < 0 or think_end_id < 0: - return logits_processors_args - - if hasattr(prompt_token_ids, "tolist"): - token_list = prompt_token_ids.tolist() - else: - token_list = list(prompt_token_ids) - - started = think_start_id in token_list - ended = False - tokens_after_start = 0 - last_token_id = None - if started: - start_pos = token_list.index(think_start_id) - tokens_after = token_list[start_pos + 1 :] - if think_end_id in tokens_after: - end_pos = tokens_after.index(think_end_id) - tokens_after_start = end_pos + 1 - ended = True - else: - tokens_after_start = len(tokens_after) - if token_list: - last_token_id = int(token_list[-1]) - - logits_processors_args["think_prompt_checked"] = True - logits_processors_args["think_prompt_started"] = started - logits_processors_args["think_prompt_ended"] = ended - logits_processors_args["think_prompt_tokens_after_start"] = tokens_after_start - if last_token_id is not None: - logits_processors_args["think_prompt_last_token_id"] = last_token_id - else: - logits_processors_args.pop("think_prompt_last_token_id", None) - return logits_processors_args - def process_request(self, request, max_model_len=None, **kwargs): """ Preprocess the request @@ -306,14 +340,10 @@ def process_request(self, request, max_model_len=None, **kwargs): bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids) request["bad_words_token_ids"] = bad_words_token_ids - logits_processors_args = request.get("logits_processors_args") or {} - think_stop_sentence = logits_processors_args.get("think_stop_sentence") - if isinstance(think_stop_sentence, str) and think_stop_sentence: - newline_token_ids 
= self.encode_with_cache("\n", max_model_len, add_special_tokens=False) - sentence_token_ids = self.encode_with_cache(think_stop_sentence, max_model_len, add_special_tokens=False) - logits_processors_args["think_stop_sentence_token_ids"] = newline_token_ids + sentence_token_ids - logits_processors_args.pop("think_stop_sentence", None) - request["logits_processors_args"] = logits_processors_args + logits_processors_args = self._prepare_think_stop_sentence( + request.get("logits_processors_args") or {}, max_model_len + ) + request["logits_processors_args"] = logits_processors_args # processing prompt_token_ids if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0: @@ -409,14 +439,10 @@ def process_request_dict(self, request, max_model_len=None, **kwargs): bad_words_token_ids = self.update_bad_words(bad_words, bad_words_token_ids) request.sampling_params.bad_words_token_ids = bad_words_token_ids - logits_processors_args = getattr(request.sampling_params, "logits_processors_args", None) or {} - think_stop_sentence = logits_processors_args.get("think_stop_sentence") - if isinstance(think_stop_sentence, str) and think_stop_sentence: - newline_token_ids = self.encode_with_cache("\n", max_model_len, add_special_tokens=False) - sentence_token_ids = self.encode_with_cache(think_stop_sentence, max_model_len, add_special_tokens=False) - logits_processors_args["think_stop_sentence_token_ids"] = newline_token_ids + sentence_token_ids - logits_processors_args.pop("think_stop_sentence", None) - request.sampling_params.logits_processors_args = logits_processors_args + logits_processors_args = self._prepare_think_stop_sentence( + getattr(request.sampling_params, "logits_processors_args", None) or {}, max_model_len + ) + request.sampling_params.logits_processors_args = logits_processors_args # processing prompt_token_ids if not request.prompt_token_ids: diff --git a/fastdeploy/model_executor/logits_processor/thinking_budget.py 
b/fastdeploy/model_executor/logits_processor/thinking_budget.py index 891a8ddd25a..b8493c1756d 100644 --- a/fastdeploy/model_executor/logits_processor/thinking_budget.py +++ b/fastdeploy/model_executor/logits_processor/thinking_budget.py @@ -42,14 +42,15 @@ class _ThinkingState: class ThinkingBudgetLogitsProcessor(LogitsProcessor): """Limit the number of tokens generated in the thinking phase. - The processor tracks per-request thinking state and forces a newline token - when the budget is reached, followed by the thinking end token on the next step. + The processor tracks per-request thinking state and forces the thinking end + token when the budget is reached. If a stop sentence is configured, the + processor emits the stop sentence first and then the thinking end token. Request-specific configuration is provided via logits_processors_args: {"thinking_budget": } - Requires model_config to provide think_start_id, think_end_id, and line_break_id. - If any of these are missing or invalid (-1), the processor will be disabled. + Requires model_config to provide think_start_id and think_end_id. If any of + these are missing or invalid (-1), the processor will be disabled. 
    """

    def __init__(self, fd_config: FDConfig) -> None:
@@ -61,20 +62,33 @@ def __init__(self, fd_config: FDConfig) -> None:
         self.think_start_token_id = think_start_id if isinstance(think_start_id, int) and think_start_id >= 0 else -1
         self.think_end_token_id = think_end_id if isinstance(think_end_id, int) and think_end_id >= 0 else -1
         self.line_break_token_id = line_break_id if isinstance(line_break_id, int) and line_break_id >= 0 else -1
-        self._enabled = (
-            self.think_start_token_id >= 0 and self.think_end_token_id >= 0 and self.line_break_token_id >= 0
-        )
+        self._enabled = self.think_start_token_id >= 0 and self.think_end_token_id >= 0
         if not self._enabled:
             logger.warning(
                 "ThinkingBudgetLogitsProcessor disabled: missing token ids "
-                f"(think_start={think_start_id}, think_end={think_end_id}, line_break={line_break_id}). "
-                "Ensure model vocab contains <think>, </think> tokens and line_break_id is configured."
+                f"(think_start={think_start_id}, think_end={think_end_id}). "
+                "Ensure model vocab contains <think> and </think> tokens."
) self._states: Dict[str, _ThinkingState] = {} self._active_req_ids: list[str] = [] self._active_budgets: list[int] = [] self._active_slots: list[int] = [] + def _scan_prompt_state(self, prompt_slice: list[int]) -> tuple[bool, bool, int, Optional[int]]: + started = False + ended = False + in_thinking = False + for token_id in prompt_slice: + if token_id == self.think_start_token_id: + started = True + ended = False + in_thinking = True + elif token_id == self.think_end_token_id and in_thinking: + ended = True + in_thinking = False + last_token_id = int(prompt_slice[-1]) if started and prompt_slice else None + return started, ended, 0, last_token_id + def update_state(self, share_inputs: dict) -> None: if not self._enabled: return @@ -82,6 +96,7 @@ def update_state(self, share_inputs: dict) -> None: req_ids = share_inputs["req_ids"] logits_processors_args = share_inputs["logits_processors_args"] prompt_ids = share_inputs.get("prompt_ids") + token_ids_all = share_inputs.get("token_ids_all") prompt_lens = share_inputs.get("prompt_lens") pre_ids = share_inputs.get("pre_ids") next_tokens = share_inputs.get("next_tokens") @@ -153,6 +168,8 @@ def update_state(self, share_inputs: dict) -> None: for idx, slot_id in enumerate(candidate_slots): next_token_by_slot[slot_id] = int(next_sel[idx]) + prompt_source = prompt_ids if prompt_ids is not None else token_ids_all + for idx, slot_id in enumerate(candidate_slots): req_id = candidate_req_ids[idx] logit_proc_args = candidate_args[idx] @@ -190,25 +207,28 @@ def update_state(self, share_inputs: dict) -> None: state.current_step_idx = current_step_idx if not state.started and not state.prompt_checked: - if prompt_ids is not None and prompt_lens is not None: + if prompt_source is not None and prompt_lens is not None: if prompt_lens_np is not None: prompt_len = int(prompt_lens_np[slot_id]) else: prompt_len = int(prompt_lens[slot_id]) - prompt_slice = prompt_ids[slot_id, :prompt_len] - prompt_slice = prompt_slice.numpy().tolist() - 
if self.think_start_token_id in prompt_slice: + prompt_slice = prompt_source[slot_id, :prompt_len] + if hasattr(prompt_slice, "numpy"): + prompt_slice = prompt_slice.numpy().tolist() + elif hasattr(prompt_slice, "tolist"): + prompt_slice = prompt_slice.tolist() + else: + prompt_slice = list(prompt_slice) + if prompt_ids is None: + prompt_slice = [int(token_id) for token_id in prompt_slice if int(token_id) >= 0] + prompt_started, prompt_ended, prompt_tokens_after_start, prompt_last_token_id = ( + self._scan_prompt_state(prompt_slice) + ) + if prompt_started: state.started = True - start_pos = prompt_slice.index(self.think_start_token_id) - tokens_after = prompt_slice[start_pos + 1 :] - if self.think_end_token_id in tokens_after: - end_pos = tokens_after.index(self.think_end_token_id) - state.tokens_after_start = end_pos + 1 - state.ended = True - else: - state.tokens_after_start = len(tokens_after) - if prompt_slice: - state.last_token_id = int(prompt_slice[-1]) + state.ended = prompt_ended + state.tokens_after_start = prompt_tokens_after_start + state.last_token_id = prompt_last_token_id if current_step_idx is not None and state.last_step_idx is None: state.last_step_idx = current_step_idx state.prompt_checked = True @@ -304,13 +324,6 @@ def apply(self, logits: paddle.Tensor) -> paddle.Tensor: if state.tokens_after_start < budget: continue - if state.last_token_id != self.line_break_token_id: - logits[slot_id, :] = -float("inf") - logits[slot_id, self.line_break_token_id] = 0.0 - state.last_token_id = self.line_break_token_id - state.last_step_idx = state.current_step_idx - continue - logits[slot_id, :] = -float("inf") logits[slot_id, self.think_end_token_id] = 0.0 state.last_token_id = self.think_end_token_id diff --git a/tests/model_executor/test_thinking_budget.py b/tests/model_executor/test_thinking_budget.py index 8476ff2f675..c720162d0d1 100644 --- a/tests/model_executor/test_thinking_budget.py +++ b/tests/model_executor/test_thinking_budget.py @@ -20,7 
+20,19 @@ from fastdeploy.engine import engine as engine_module from fastdeploy.engine.args_utils import EngineArgs # Import EngineArgs from fastdeploy.engine.sampling_params import SamplingParams +from fastdeploy.input.ernie4_5_processor import ( + Ernie4_5Processor as ErnieTextDataProcessor, +) +from fastdeploy.input.ernie4_5_vl_processor import ( + Ernie4_5_VLProcessor as ErnieVLDataProcessor, +) from fastdeploy.input.text_processor import DataProcessor as TextDataProcessor +from fastdeploy.input.v1.ernie4_5_processor import ( + Ernie4_5Processor as V1ErnieTextDataProcessor, +) +from fastdeploy.input.v1.ernie4_5_vl_processor.ernie4_5_vl_processor import ( + Ernie4_5_VLProcessor as V1ErnieVLDataProcessor, +) from fastdeploy.input.v1.text_processor import DataProcessor as V1TextDataProcessor from fastdeploy.model_executor.logits_processor import ThinkingBudgetLogitsProcessor from fastdeploy.scheduler import SchedulerConfig @@ -52,6 +64,7 @@ def __init__(self, fd_config, max_num_seqs=4): "req_ids": [f"req_{i}" for i in range(max_num_seqs)], "logits_processors_args": [{} for _ in range(max_num_seqs)], "prompt_ids": paddle.to_tensor(np.zeros((max_num_seqs, 10), dtype=np.int64)), # Max prompt len 10 + "token_ids_all": paddle.to_tensor(np.full((max_num_seqs, self.max_model_len), -1, dtype=np.int64)), "prompt_lens": paddle.to_tensor(np.zeros((max_num_seqs, 1), dtype=np.int64)), "pre_ids": paddle.to_tensor(np.full((max_num_seqs, self.max_model_len), -1, dtype=np.int64)), "step_idx": paddle.to_tensor(np.zeros((max_num_seqs, 1), dtype=np.int64)), @@ -73,6 +86,12 @@ def update_request_state(self, slot_id, req: MockRequest, pre_id=None, set_next_ self.share_inputs["prompt_ids"][slot_id, : len(req.prompt_ids)] = paddle.to_tensor( req.prompt_ids, dtype=paddle.int64 ) + self.share_inputs["token_ids_all"][slot_id, :] = paddle.to_tensor( + np.full((self.max_model_len,), -1, dtype=np.int64), dtype=paddle.int64 + ) + self.share_inputs["token_ids_all"][slot_id, : 
len(req.prompt_ids)] = paddle.to_tensor( + req.prompt_ids, dtype=paddle.int64 + ) self.share_inputs["prompt_lens"][slot_id] = paddle.to_tensor(len(req.prompt_ids), dtype=paddle.int64) if req.sampling_params.logits_processors_args: self.share_inputs["logits_processors_args"][slot_id] = req.sampling_params.logits_processors_args @@ -152,7 +171,7 @@ def _get_initial_logits(self, batch_size): return logits def test_thinking_budget_not_reached(self): - # Scenario: Thinking budget is 5, but only 3 tokens are generated in thinking phase + # Scenario: Thinking budget is 5, and prompt-side tokens after do not consume budget. req_id = "test_req_1" prompt_ids = [1, 2, THINKING_START_TOKEN_ID, 3, 4, 5] sampling_params = SamplingParams(logits_processors_args={"thinking_budget": 5}) @@ -167,9 +186,9 @@ def test_thinking_budget_not_reached(self): processor.update_state(mock_runner.share_inputs) self.assertTrue(processor._states[req_id].started) self.assertFalse(processor._states[req_id].ended) - self.assertEqual(processor._states[req_id].tokens_after_start, 3) # (3, 4, 5) after THINKING_START + self.assertEqual(processor._states[req_id].tokens_after_start, 0) - # Step 2: Simulate one generation step (budget 5, generated 3 -> 4) + # Step 2: Simulate one generation step (budget 5, generated 0 -> 1) logits = self._get_initial_logits(1) processor.update_state(mock_runner.share_inputs) # Update state before apply processed_logits = processor.apply(logits) @@ -181,9 +200,9 @@ def test_thinking_budget_not_reached(self): mock_runner.update_request_state(0, mock_req, pre_id=0) # Update last token processor.update_state(mock_runner.share_inputs) # Update state after generating token - self.assertEqual(processor._states[req_id].tokens_after_start, 4) # (3, 4, 5, 0) + self.assertEqual(processor._states[req_id].tokens_after_start, 1) - # Step 3: Simulate another generation step (budget 5, generated 4 -> 5) + # Step 3: Simulate another generation step (budget 5, generated 1 -> 2) logits = 
self._get_initial_logits(1) processor.update_state(mock_runner.share_inputs) # Update state before apply processed_logits = processor.apply(logits) @@ -191,14 +210,14 @@ def test_thinking_budget_not_reached(self): mock_runner.update_request_state(0, mock_req, pre_id=0) # Update last token processor.update_state(mock_runner.share_inputs) # Update state after generating token - self.assertEqual(processor._states[req_id].tokens_after_start, 5) # (3, 4, 5, 0, 0) + self.assertEqual(processor._states[req_id].tokens_after_start, 2) # LogitsProcessor should still not restrict as NEW_LINE_TOKEN is not yet last token - def test_thinking_budget_reached_forces_newline(self): - # Scenario: Budget is 3, after 3 tokens, it should force newline, then thinking_end + def test_thinking_budget_reached_forces_think_end(self): + # Scenario: Budget is 3 and only decode-time tokens count toward the budget. req_id = "test_req_2" - prompt_ids = [1, 2, THINKING_START_TOKEN_ID, 3] # Initial 1 token after start + prompt_ids = [1, 2, THINKING_START_TOKEN_ID, 3] sampling_params = SamplingParams(logits_processors_args={"thinking_budget": 3}) mock_req = MockRequest(req_id, prompt_ids, sampling_params) @@ -206,16 +225,15 @@ def test_thinking_budget_reached_forces_newline(self): mock_runner.update_request_state(0, mock_req, pre_id=prompt_ids[-1]) processor = ThinkingBudgetLogitsProcessor(self.fd_config) - line_break_id = processor.line_break_token_id think_end_id = processor.think_end_token_id - # Step 1: Initial state update (1 token after start) + # Step 1: Initial state update (prompt-side tokens do not count) processor.update_state(mock_runner.share_inputs) - self.assertEqual(processor._states[req_id].tokens_after_start, 1) + self.assertEqual(processor._states[req_id].tokens_after_start, 0) self.assertFalse(processor._states[req_id].ended) self.assertEqual(processor._states[req_id].last_token_id, 3) - # Step 2: Generate 2nd token (budget 3, generated 1 -> 2) + # Step 2: Generate 1st decode token 
(budget 3, generated 0 -> 1) logits = self._get_initial_logits(1) processor.update_state(mock_runner.share_inputs) processed_logits = processor.apply(logits) @@ -223,10 +241,10 @@ def test_thinking_budget_reached_forces_newline(self): self.assertEqual(next_token, 0) # Normal generation mock_runner.update_request_state(0, mock_req, pre_id=next_token) processor.update_state(mock_runner.share_inputs) - self.assertEqual(processor._states[req_id].tokens_after_start, 2) + self.assertEqual(processor._states[req_id].tokens_after_start, 1) self.assertEqual(processor._states[req_id].last_token_id, 0) - # Step 3: Generate 3rd token (budget 3, generated 2 -> 3). Next step should force NEW_LINE. + # Step 3: Generate 2nd decode token (budget 3, generated 1 -> 2) logits = self._get_initial_logits(1) processor.update_state(mock_runner.share_inputs) processed_logits = processor.apply(logits) @@ -234,34 +252,20 @@ def test_thinking_budget_reached_forces_newline(self): self.assertEqual(next_token, 0) # Normal generation mock_runner.update_request_state(0, mock_req, pre_id=next_token) processor.update_state(mock_runner.share_inputs) - self.assertEqual(processor._states[req_id].tokens_after_start, 3) # Budget is now met + self.assertEqual(processor._states[req_id].tokens_after_start, 2) self.assertEqual(processor._states[req_id].last_token_id, 0) - # Step 4: Budget reached, last token not NEW_LINE. Should force NEW_LINE_TOKEN_ID. + # Step 4: Generate 3rd decode token (budget 3, generated 2 -> 3). 
logits = self._get_initial_logits(1) processor.update_state(mock_runner.share_inputs) processed_logits = processor.apply(logits) - - # Verify all other logits are -inf, only NEW_LINE_TOKEN_ID is 0.0 - # Use <= for comparison because paddle.full creates -10.0 and comparison can have precision issues - other_logits = paddle.concat( - [ - processed_logits[0, :line_break_id], - processed_logits[0, line_break_id + 1 : VOCAB_SIZE], - ], - axis=0, - ) - self.assertTrue(paddle.all(other_logits <= -10.0).item() or paddle.all(other_logits == -float("inf")).item()) - self.assertEqual(processed_logits[0, line_break_id].item(), 0.0) - next_token = mock_runner.generate_next_token(processed_logits)[0] - self.assertEqual(next_token, line_break_id) # Forces NEW_LINE + self.assertEqual(next_token, 0) mock_runner.update_request_state(0, mock_req, pre_id=next_token) processor.update_state(mock_runner.share_inputs) - self.assertEqual(processor._states[req_id].tokens_after_start, 4) # Still increment - self.assertEqual(processor._states[req_id].last_token_id, line_break_id) + self.assertEqual(processor._states[req_id].tokens_after_start, 3) - # Step 5: Last token is NEW_LINE. Should force THINKING_END_TOKEN_ID. + # Step 5: Budget reached, should force THINKING_END_TOKEN_ID directly. logits = self._get_initial_logits(1) processor.update_state(mock_runner.share_inputs) processed_logits = processor.apply(logits) @@ -324,9 +328,9 @@ def test_thinking_budget_stop_sentence_forces_tokens(self): self.assertEqual(next_token, stop_sentence_token_ids[0]) def test_thinking_budget_no_stop_sentence_defaults(self): - # Scenario: No stop sentence, budget reached should force newline. + # Scenario: No stop sentence, budget reached should force thinking_end directly. 
req_id = "test_req_no_stop_sentence" - prompt_ids = [THINKING_START_TOKEN_ID, 42] # 1 token after start + prompt_ids = [THINKING_START_TOKEN_ID, 42] sampling_params = SamplingParams(logits_processors_args={"thinking_budget": 1}) mock_req = MockRequest(req_id, prompt_ids, sampling_params) @@ -334,35 +338,42 @@ def test_thinking_budget_no_stop_sentence_defaults(self): mock_runner.update_request_state(0, mock_req, pre_id=prompt_ids[-1]) processor = ThinkingBudgetLogitsProcessor(self.fd_config) - line_break_id = processor.line_break_token_id processor.update_state(mock_runner.share_inputs) logits = self._get_initial_logits(1) processed_logits = processor.apply(logits) next_token = mock_runner.generate_next_token(processed_logits)[0] - self.assertEqual(next_token, line_break_id) + self.assertEqual(next_token, 0) + mock_runner.update_request_state(0, mock_req, pre_id=next_token) + processor.update_state(mock_runner.share_inputs) + + logits = self._get_initial_logits(1) + processor.update_state(mock_runner.share_inputs) + processed_logits = processor.apply(logits) + next_token = mock_runner.generate_next_token(processed_logits)[0] + self.assertEqual(next_token, THINKING_END_TOKEN_ID) def test_thinking_budget_uses_config_token_ids(self): # Scenario: Processor should use token ids from model config. self.fd_config.model_config.think_start_id = 123 self.fd_config.model_config.think_end_id = 124 - self.fd_config.model_config.line_break_id = 125 + self.fd_config.model_config.line_break_id = -1 processor = ThinkingBudgetLogitsProcessor(self.fd_config) self.assertEqual(processor.think_start_token_id, 123) self.assertEqual(processor.think_end_token_id, 124) - self.assertEqual(processor.line_break_token_id, 125) + self.assertEqual(processor.line_break_token_id, -1) self.assertTrue(processor._enabled) def test_thinking_budget_disabled_when_token_ids_missing(self): # Scenario: Processor should be disabled when token ids are not configured. 
self.fd_config.model_config.think_start_id = -1 self.fd_config.model_config.think_end_id = -1 - self.fd_config.model_config.line_break_id = -1 + self.fd_config.model_config.line_break_id = NEW_LINE_TOKEN_ID processor = ThinkingBudgetLogitsProcessor(self.fd_config) self.assertFalse(processor._enabled) self.assertEqual(processor.think_start_token_id, -1) self.assertEqual(processor.think_end_token_id, -1) - self.assertEqual(processor.line_break_token_id, -1) + self.assertEqual(processor.line_break_token_id, NEW_LINE_TOKEN_ID) # update_state and apply should be no-op when disabled mock_runner = MockModelRunner(self.fd_config, max_num_seqs=1) @@ -493,6 +504,28 @@ def test_thinking_budget_prompt_state_from_args(self): self.assertEqual(state.tokens_after_start, 2) self.assertEqual(state.last_token_id, 99) + def test_thinking_budget_prompt_state_from_token_ids_all_fallback(self): + req_id = "req_gpu_fallback" + mock_runner = MockModelRunner(self.fd_config, max_num_seqs=1) + mock_runner.share_inputs["req_ids"][0] = req_id + mock_runner.share_inputs["logits_processors_args"][0] = {"thinking_budget": 3} + mock_runner.share_inputs["prompt_ids"] = None + mock_runner.share_inputs["token_ids_all"][0, :4] = paddle.to_tensor( + [1, THINKING_START_TOKEN_ID, 2, 3], dtype=paddle.int64 + ) + mock_runner.share_inputs["prompt_lens"][0, 0] = paddle.to_tensor(4, dtype=paddle.int64) + mock_runner.share_inputs["next_tokens"][0, 0] = paddle.to_tensor(-1, dtype=paddle.int64) + + processor = ThinkingBudgetLogitsProcessor(self.fd_config) + processor.update_state(mock_runner.share_inputs) + + state = processor._states[req_id] + self.assertTrue(state.prompt_checked) + self.assertTrue(state.started) + self.assertFalse(state.ended) + self.assertEqual(state.tokens_after_start, 0) + self.assertEqual(state.last_token_id, 3) + def test_thinking_budget_not_configured(self): # Scenario: Processor is active, but request does not provide thinking_budget req_id = "test_req_3" @@ -532,20 +565,7 @@ def 
test_thinking_budget_zero(self): self.assertEqual(processor._states[req_id].tokens_after_start, 0) # No tokens after start yet self.assertEqual(processor._states[req_id].last_token_id, THINKING_START_TOKEN_ID) - # Step 1: Budget 0 reached, last token is THINKING_START. Should force NEW_LINE. - logits = self._get_initial_logits(1) - processor.update_state(mock_runner.share_inputs) - processed_logits = processor.apply(logits) - - self.assertEqual(processed_logits[0, NEW_LINE_TOKEN_ID].item(), 0.0) - next_token = mock_runner.generate_next_token(processed_logits)[0] - self.assertEqual(next_token, NEW_LINE_TOKEN_ID) - mock_runner.update_request_state(0, mock_req, pre_id=next_token) - processor.update_state(mock_runner.share_inputs) - self.assertEqual(processor._states[req_id].tokens_after_start, 1) # Still increments - self.assertEqual(processor._states[req_id].last_token_id, NEW_LINE_TOKEN_ID) - - # Step 2: Last token is NEW_LINE. Should force THINKING_END_TOKEN_ID. + # Step 1: Budget 0 reached, last token is THINKING_START. Should force THINKING_END. 
logits = self._get_initial_logits(1) processor.update_state(mock_runner.share_inputs) processed_logits = processor.apply(logits) @@ -572,7 +592,7 @@ def test_thinking_end_in_prompt(self): processor.update_state(mock_runner.share_inputs) self.assertTrue(processor._states[req_id].started) self.assertTrue(processor._states[req_id].ended) - self.assertEqual(processor._states[req_id].tokens_after_start, 2) # Tokens 2, THINKING_END after start + self.assertEqual(processor._states[req_id].tokens_after_start, 0) logits = self._get_initial_logits(1) processor.update_state(mock_runner.share_inputs) @@ -582,7 +602,7 @@ def test_thinking_end_in_prompt(self): def test_multiple_requests(self): # Scenario: Multiple requests with different thinking states req_id_1 = "req_a" - prompt_ids_1 = [THINKING_START_TOKEN_ID, 10, 11] # budget 2, last_token=11, tokens_after_start=2 + prompt_ids_1 = [THINKING_START_TOKEN_ID, 10, 11] sampling_params_1 = SamplingParams(logits_processors_args={"thinking_budget": 2}) mock_req_1 = MockRequest(req_id_1, prompt_ids_1, sampling_params_1) @@ -592,7 +612,7 @@ def test_multiple_requests(self): mock_req_2 = MockRequest(req_id_2, prompt_ids_2, sampling_params_2) req_id_3 = "req_c" - prompt_ids_3 = [THINKING_START_TOKEN_ID, 30] # budget 1, last_token=30, tokens_after_start=1 + prompt_ids_3 = [THINKING_START_TOKEN_ID, 30] sampling_params_3 = SamplingParams(logits_processors_args={"thinking_budget": 1}) mock_req_3 = MockRequest(req_id_3, prompt_ids_3, sampling_params_3) @@ -609,14 +629,14 @@ def test_multiple_requests(self): # Verify initial states self.assertTrue(processor._states[req_id_1].started) self.assertFalse(processor._states[req_id_1].ended) - self.assertEqual(processor._states[req_id_1].tokens_after_start, 2) + self.assertEqual(processor._states[req_id_1].tokens_after_start, 0) self.assertEqual(processor._states[req_id_1].last_token_id, 11) self.assertNotIn(req_id_2, processor._states) # No budget specified for req_2 
self.assertTrue(processor._states[req_id_3].started) self.assertFalse(processor._states[req_id_3].ended) - self.assertEqual(processor._states[req_id_3].tokens_after_start, 1) + self.assertEqual(processor._states[req_id_3].tokens_after_start, 0) self.assertEqual(processor._states[req_id_3].last_token_id, 30) # Simulate logits for the batch @@ -624,16 +644,14 @@ def test_multiple_requests(self): processor.update_state(mock_runner.share_inputs) # Ensure state is updated before apply processed_batch_logits = processor.apply(batch_logits) - # Req 1: budget 2, tokens_after_start 2. Should force NEW_LINE (last_token_id is 11, not NEW_LINE) - self.assertEqual(processed_batch_logits[0, NEW_LINE_TOKEN_ID].item(), 0.0) - self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), NEW_LINE_TOKEN_ID) + # Req 1: prompt-side content does not consume budget, so first step is normal generation. + self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), 0) # Req 2: No thinking budget, normal generation self.assertEqual(paddle.argmax(processed_batch_logits[1], axis=-1).item(), 0) - # Req 3: budget 1, tokens_after_start 1. Should force NEW_LINE (last_token_id is 30, not NEW_LINE) - self.assertEqual(processed_batch_logits[2, NEW_LINE_TOKEN_ID].item(), 0.0) - self.assertEqual(paddle.argmax(processed_batch_logits[2], axis=-1).item(), NEW_LINE_TOKEN_ID) + # Req 3: prompt-side content does not consume budget, so first step is normal generation. 
+ self.assertEqual(paddle.argmax(processed_batch_logits[2], axis=-1).item(), 0) # Simulate generating next tokens and updating state next_tokens = mock_runner.generate_next_token(processed_batch_logits) @@ -643,20 +661,25 @@ def test_multiple_requests(self): processor.update_state(mock_runner.share_inputs) # Verify updated states for next step - self.assertEqual(processor._states[req_id_1].last_token_id, NEW_LINE_TOKEN_ID) - self.assertEqual(processor._states[req_id_3].last_token_id, NEW_LINE_TOKEN_ID) + self.assertEqual(processor._states[req_id_1].last_token_id, 0) + self.assertEqual(processor._states[req_id_3].last_token_id, 0) + self.assertEqual(processor._states[req_id_1].tokens_after_start, 1) + self.assertEqual(processor._states[req_id_3].tokens_after_start, 1) + self.assertFalse(processor._states[req_id_1].ended) + self.assertFalse(processor._states[req_id_3].ended) batch_logits = self._get_initial_logits(3) processor.update_state(mock_runner.share_inputs) processed_batch_logits = processor.apply(batch_logits) - # Req 1: last token was NEW_LINE. Should force THINKING_END - self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), THINKING_END_TOKEN_ID) + # Req 1: budget 2, tokens_after_start 1. Still normal generation. + self.assertEqual(paddle.argmax(processed_batch_logits[0], axis=-1).item(), 0) # Req 2: Still normal generation self.assertEqual(paddle.argmax(processed_batch_logits[1], axis=-1).item(), 0) - # Req 3: last token was NEW_LINE. Should force THINKING_END + # Req 3: budget 1 reached after one generated token, should now force THINKING_END. 
+ self.assertEqual(processed_batch_logits[2, THINKING_END_TOKEN_ID].item(), 0.0) self.assertEqual(paddle.argmax(processed_batch_logits[2], axis=-1).item(), THINKING_END_TOKEN_ID) @@ -724,7 +747,7 @@ def test_update_thinking_prompt_state_from_text_processor(self): self.assertTrue(updated["think_prompt_checked"]) self.assertTrue(updated["think_prompt_started"]) self.assertTrue(updated["think_prompt_ended"]) - self.assertEqual(updated["think_prompt_tokens_after_start"], 2) + self.assertEqual(updated["think_prompt_tokens_after_start"], 0) self.assertEqual(updated["think_prompt_last_token_id"], 3) def test_v1_process_request_missing_logits_processors_args(self): @@ -833,6 +856,58 @@ def _text2ids(text, max_model_len=None, add_special_tokens=False): self.assertEqual(processor.encode_with_cache("iter"), [41, 42]) self.assertNotIn(("np", False), processor._tokenize_cache) + def test_text_encode_with_cache_lazy_init(self): + processor = TextDataProcessor.__new__(TextDataProcessor) + call_counter = {"count": 0} + + def _text2ids(text, max_model_len=None, add_special_tokens=False): + call_counter["count"] += 1 + return np.array([51, 52], dtype=np.int64) + + processor.text2ids = _text2ids + + self.assertFalse(hasattr(processor, "_tokenize_cache")) + self.assertEqual(processor.encode_with_cache("lazy"), [51, 52]) + self.assertTrue(hasattr(processor, "_tokenize_cache")) + self.assertEqual(processor.encode_with_cache("lazy"), [51, 52]) + self.assertEqual(call_counter["count"], 1) + + def test_v1_encode_with_cache_lazy_init(self): + processor = V1TextDataProcessor.__new__(V1TextDataProcessor) + call_counter = {"count": 0} + + def _text2ids(text, max_model_len=None, add_special_tokens=False): + call_counter["count"] += 1 + return np.array([61, 62], dtype=np.int64) + + processor.text2ids = _text2ids + + self.assertFalse(hasattr(processor, "_tokenize_cache")) + self.assertEqual(processor.encode_with_cache("lazy"), [61, 62]) + self.assertTrue(hasattr(processor, "_tokenize_cache")) + 
self.assertEqual(processor.encode_with_cache("lazy"), [61, 62]) + self.assertEqual(call_counter["count"], 1) + + def test_ernie_encode_literal_text_with_cache(self): + processor = ErnieTextDataProcessor.__new__(ErnieTextDataProcessor) + processor.tokenizer = SimpleNamespace( + tokenize=lambda text: ["token_a", "token_b"], + convert_tokens_to_ids=lambda tokens: [71, 72], + ) + + self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [71, 72]) + self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [71, 72]) + + def test_v1_ernie_encode_literal_text_with_cache(self): + processor = V1ErnieTextDataProcessor.__new__(V1ErnieTextDataProcessor) + processor.tokenizer = SimpleNamespace( + tokenize=lambda text: ["token_c", "token_d"], + convert_tokens_to_ids=lambda tokens: [81, 82], + ) + + self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [81, 82]) + self.assertEqual(processor._encode_literal_text_with_cache("fallback"), [81, 82]) + def test_text_update_thinking_prompt_state_branches(self): processor = TextDataProcessor.__new__(TextDataProcessor) processor._think_token_ids = None @@ -868,7 +943,7 @@ def test_text_update_thinking_prompt_state_branches(self): ) self.assertTrue(with_start_no_end["think_prompt_started"]) self.assertFalse(with_start_no_end["think_prompt_ended"]) - self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 2) + self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 0) self.assertEqual(with_start_no_end["think_prompt_last_token_id"], 3) # 命中 _get_think_token_ids 的缓存分支 @@ -891,7 +966,7 @@ def test_v1_update_thinking_prompt_state_branches(self): ) self.assertTrue(with_start_no_end["think_prompt_started"]) self.assertFalse(with_start_no_end["think_prompt_ended"]) - self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 2) + self.assertEqual(with_start_no_end["think_prompt_tokens_after_start"], 0) 
self.assertEqual(with_start_no_end["think_prompt_last_token_id"], 3) # 命中 _get_think_token_ids 的缓存分支 @@ -903,7 +978,7 @@ def test_text_process_request_think_stop_sentence(self): processor.eos_token_ids = [1] processor.update_stop_seq = lambda *args, **kwargs: None processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids - processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [101, 102] + processor._encode_literal_text_with_cache = lambda text: [101, 102] processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args processor.reasoning_parser = None @@ -924,7 +999,7 @@ def test_text_process_request_think_stop_sentence(self): processed = processor.process_request(request, max_model_len=16) self.assertEqual( processed.logits_processors_args.get("think_stop_sentence_token_ids"), - [23, 101, 102], + [101, 102], ) self.assertNotIn("think_stop_sentence", processed.logits_processors_args) @@ -934,7 +1009,7 @@ def test_text_process_request_dict_think_stop_sentence(self): processor.eos_token_ids = [1] processor.update_stop_seq = lambda *args, **kwargs: None processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids - processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [201, 202] + processor._encode_literal_text_with_cache = lambda text: [201, 202] processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args processor.reasoning_parser = None @@ -955,7 +1030,7 @@ def test_text_process_request_dict_think_stop_sentence(self): processed = processor.process_request_dict(request, max_model_len=16) self.assertEqual( processed["logits_processors_args"].get("think_stop_sentence_token_ids"), - [23, 201, 202], + [201, 202], ) self.assertNotIn("think_stop_sentence", processed["logits_processors_args"]) @@ -965,7 +1040,7 @@ def test_v1_process_request_think_stop_sentence(self): processor.eos_token_ids = [1] processor.update_stop_seq 
= lambda *args, **kwargs: None processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids - processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [301, 302] + processor._encode_literal_text_with_cache = lambda text: [301, 302] processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args processor.reasoning_parser = None @@ -986,7 +1061,7 @@ def test_v1_process_request_think_stop_sentence(self): processed = processor.process_request(request, max_model_len=16) self.assertEqual( processed.logits_processors_args.get("think_stop_sentence_token_ids"), - [23, 301, 302], + [301, 302], ) self.assertNotIn("think_stop_sentence", processed.logits_processors_args) @@ -996,7 +1071,7 @@ def test_v1_process_request_dict_think_stop_sentence(self): processor.eos_token_ids = [1] processor.update_stop_seq = lambda *args, **kwargs: None processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids - processor.encode_with_cache = lambda text, *args, **kwargs: [23] if text == "\n" else [401, 402] + processor._encode_literal_text_with_cache = lambda text: [401, 402] processor._update_thinking_prompt_state = lambda prompt_token_ids, args: args processor.reasoning_parser = None @@ -1023,10 +1098,173 @@ def test_v1_process_request_dict_think_stop_sentence(self): processed = processor.process_request_dict(request, max_model_len=16) self.assertEqual( processed.sampling_params.logits_processors_args.get("think_stop_sentence_token_ids"), - [23, 401, 402], + [401, 402], ) self.assertNotIn("think_stop_sentence", processed.sampling_params.logits_processors_args) + def test_ernie_process_request_prepares_thinking_budget_args(self): + processor = ErnieTextDataProcessor.__new__(ErnieTextDataProcessor) + processor._apply_default_parameters = lambda request: request + processor.eos_token_ids = [1] + processor.update_stop_seq = lambda *args, **kwargs: None + processor.update_bad_words = lambda 
bad_words, bad_words_token_ids: bad_words_token_ids + processor._encode_literal_text_with_cache = lambda text: [501, 502] + processor.tokenizer = DummyTokenizerForTextProcessor() + processor.reasoning_parser = None + + request = DummyRequestV1( + request_id="req_ernie_text", + eos_token_ids=[1], + prompt_token_ids=[1, THINKING_START_TOKEN_ID, 2], + prompt=None, + messages=None, + logits_processors_args={"thinking_budget": 20, "think_stop_sentence": "done"}, + bad_words=None, + bad_words_token_ids=None, + max_tokens=1, + temperature=1.0, + top_p=0.9, + response_max_tokens=None, + enable_thinking=True, + ) + with patch("fastdeploy.input.ernie4_5_processor.process_stop_token_ids", lambda *args, **kwargs: None): + processed = processor.process_request(request, max_model_len=16) + + self.assertEqual(processed.logits_processors_args["think_stop_sentence_token_ids"], [501, 502]) + self.assertTrue(processed.logits_processors_args["think_prompt_started"]) + self.assertFalse(processed.logits_processors_args["think_prompt_ended"]) + self.assertEqual(processed.logits_processors_args["think_prompt_tokens_after_start"], 0) + + def test_v1_ernie_process_request_dict_prepares_thinking_budget_args(self): + processor = V1ErnieTextDataProcessor.__new__(V1ErnieTextDataProcessor) + processor._apply_default_parameters = lambda request: request + processor.eos_token_ids = [1] + processor.update_stop_seq = lambda *args, **kwargs: None + processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids + processor._encode_literal_text_with_cache = lambda text: [601, 602] + processor.tokenizer = DummyTokenizerForTextProcessor() + processor.reasoning_parser = None + + request = DummyRequestV1( + request_id="req_v1_ernie_text", + eos_token_ids=[1], + prompt_token_ids=[1, THINKING_START_TOKEN_ID, 2], + prompt=None, + messages=None, + chat_template_kwargs=None, + enable_thinking=True, + sampling_params=SimpleNamespace( + bad_words=None, + bad_words_token_ids=None, + 
max_tokens=1, + temperature=1.0, + top_p=0.9, + repetition_penalty=1.0, + frequency_penalty=0.0, + presence_penalty=0.0, + response_max_tokens=None, + n=1, + logits_processors_args={"thinking_budget": 20, "think_stop_sentence": "done"}, + ), + ) + with patch("fastdeploy.input.v1.ernie4_5_processor.process_stop_token_ids", lambda *args, **kwargs: None): + processed = processor.process_request_dict(request, max_model_len=16) + + self.assertEqual(processed.sampling_params.logits_processors_args["think_stop_sentence_token_ids"], [601, 602]) + self.assertTrue(processed.sampling_params.logits_processors_args["think_prompt_started"]) + self.assertFalse(processed.sampling_params.logits_processors_args["think_prompt_ended"]) + self.assertEqual(processed.sampling_params.logits_processors_args["think_prompt_tokens_after_start"], 0) + + def test_ernie_vl_process_request_dict_prepares_thinking_budget_args(self): + processor = ErnieVLDataProcessor.__new__(ErnieVLDataProcessor) + processor._apply_default_parameters = lambda request: request + processor.eos_token_ids = [1] + processor.update_stop_seq = lambda *args, **kwargs: None + processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids + processor._encode_literal_text_with_cache = lambda text: [701, 702] + processor.tokenizer = DummyTokenizerForTextProcessor() + processor.reasoning_parser = None + processor._check_mm_limits = lambda *args, **kwargs: None + processor.append_completion_tokens = lambda *args, **kwargs: None + processor.pack_outputs = lambda outs: outs + processor.ernie4_5_processor = SimpleNamespace( + request2ids=lambda request: {"input_ids": np.array([1, THINKING_START_TOKEN_ID, 2], dtype=np.int64)} + ) + + request = { + "request_id": "req_ernie_vl", + "eos_token_ids": [1], + "messages": [{"role": "user", "content": "hi"}], + "bad_words": None, + "bad_words_token_ids": None, + "logits_processors_args": {"thinking_budget": 20, "think_stop_sentence": "done"}, + "max_tokens": 1, + 
"top_p": 0.9, + "response_max_tokens": None, + } + with patch( + "fastdeploy.input.ernie4_5_vl_processor.ernie4_5_vl_processor.process_stop_token_ids", + lambda *args, **kwargs: None, + ): + processed = processor.process_request_dict(request, max_model_len=16) + + self.assertEqual(processed["logits_processors_args"]["think_stop_sentence_token_ids"], [701, 702]) + self.assertTrue(processed["logits_processors_args"]["think_prompt_started"]) + self.assertFalse(processed["logits_processors_args"]["think_prompt_ended"]) + self.assertEqual(processed["logits_processors_args"]["think_prompt_tokens_after_start"], 0) + + def test_v1_ernie_vl_process_request_dict_prepares_thinking_budget_args(self): + processor = V1ErnieVLDataProcessor.__new__(V1ErnieVLDataProcessor) + processor._apply_default_parameters = lambda request: request + processor.eos_token_ids = [1] + processor.update_stop_seq = lambda *args, **kwargs: None + processor.update_bad_words = lambda bad_words, bad_words_token_ids: bad_words_token_ids + processor._encode_literal_text_with_cache = lambda text: [801, 802] + processor.tokenizer = DummyTokenizerForTextProcessor() + processor.reasoning_parser = None + processor._check_mm_limits = lambda *args, **kwargs: None + processor.append_completion_tokens = lambda *args, **kwargs: None + processor.pack_outputs = lambda outs: outs + processor.ernie4_5_processor = SimpleNamespace( + request2ids=lambda request: {"input_ids": np.array([1, THINKING_START_TOKEN_ID, 2], dtype=np.int64)} + ) + + request = DummyRequestV1( + request_id="req_v1_ernie_vl", + eos_token_ids=[1], + prompt_token_ids=None, + prompt=None, + messages=[{"role": "user", "content": "hi"}], + chat_template_kwargs=None, + enable_thinking=True, + completion_token_ids=None, + multimodal_data=None, + sampling_params=SimpleNamespace( + bad_words=None, + bad_words_token_ids=None, + max_tokens=1, + temperature=1.0, + top_p=0.9, + repetition_penalty=1.0, + frequency_penalty=0.0, + presence_penalty=0.0, + 
response_max_tokens=None, + reasoning_max_tokens=None, + n=1, + logits_processors_args={"thinking_budget": 20, "think_stop_sentence": "done"}, + ), + ) + with patch( + "fastdeploy.input.v1.ernie4_5_vl_processor.ernie4_5_vl_processor.process_stop_token_ids", + lambda *args, **kwargs: None, + ): + processed = processor.process_request_dict(request, max_model_len=16) + + self.assertEqual(processed.sampling_params.logits_processors_args["think_stop_sentence_token_ids"], [801, 802]) + self.assertTrue(processed.sampling_params.logits_processors_args["think_prompt_started"]) + self.assertFalse(processed.sampling_params.logits_processors_args["think_prompt_ended"]) + self.assertEqual(processed.sampling_params.logits_processors_args["think_prompt_tokens_after_start"], 0) + if __name__ == "__main__": unittest.main() From 5de1cb5b11e54335dd53790797a0f96eda0d8a29 Mon Sep 17 00:00:00 2001 From: jackyYang6 Date: Fri, 20 Mar 2026 10:56:14 +0800 Subject: [PATCH 2/3] [Docs] Fix thinking_budget markdown formatting --- docs/features/thinking_budget.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/features/thinking_budget.md b/docs/features/thinking_budget.md index 79581497031..c236cb4fcb4 100644 --- a/docs/features/thinking_budget.md +++ b/docs/features/thinking_budget.md @@ -26,7 +26,7 @@ segment. When the budget is reached, it terminates thinking by forcing ` ## Request Parameters - `thinking_budget` (int, required to enable): maximum number of decode-time tokens after `` before forced - termination. + termination. - `think_stop_sentence` (string, optional): a literal custom sentence that will be tokenized on the request side and enforced near the budget boundary. 
From 9387e742fc8d19c5a66b4ebf2fee9f4045e2e6aa Mon Sep 17 00:00:00 2001 From: jackyYang6 Date: Fri, 20 Mar 2026 13:23:29 +0800 Subject: [PATCH 3/3] [Test] Align ernie thinking budget test with process_request_dict --- tests/model_executor/test_thinking_budget.py | 42 ++++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/tests/model_executor/test_thinking_budget.py b/tests/model_executor/test_thinking_budget.py index be574dcbfa7..8ba9319ff7d 100644 --- a/tests/model_executor/test_thinking_budget.py +++ b/tests/model_executor/test_thinking_budget.py @@ -1071,7 +1071,7 @@ def test_v1_process_request_dict_think_stop_sentence(self): ) self.assertNotIn("think_stop_sentence", processed.sampling_params.logits_processors_args) - def test_ernie_process_request_prepares_thinking_budget_args(self): + def test_ernie_process_request_dict_prepares_thinking_budget_args(self): processor = ErnieTextDataProcessor.__new__(ErnieTextDataProcessor) processor._apply_default_parameters = lambda request: request processor.eos_token_ids = [1] @@ -1081,28 +1081,28 @@ def test_ernie_process_request_prepares_thinking_budget_args(self): processor.tokenizer = DummyTokenizerForTextProcessor() processor.reasoning_parser = None - request = DummyRequestV1( - request_id="req_ernie_text", - eos_token_ids=[1], - prompt_token_ids=[1, THINKING_START_TOKEN_ID, 2], - prompt=None, - messages=None, - logits_processors_args={"thinking_budget": 20, "think_stop_sentence": "done"}, - bad_words=None, - bad_words_token_ids=None, - max_tokens=1, - temperature=1.0, - top_p=0.9, - response_max_tokens=None, - enable_thinking=True, - ) + request = { + "request_id": "req_ernie_text", + "eos_token_ids": [1], + "prompt_token_ids": [1, THINKING_START_TOKEN_ID, 2], + "prompt": None, + "messages": None, + "logits_processors_args": {"thinking_budget": 20, "think_stop_sentence": "done"}, + "bad_words": None, + "bad_words_token_ids": None, + "max_tokens": 1, + "temperature": 1.0, + "top_p": 0.9, + 
"response_max_tokens": None, + "enable_thinking": True, + } with patch("fastdeploy.input.ernie4_5_processor.process_stop_token_ids", lambda *args, **kwargs: None): - processed = processor.process_request(request, max_model_len=16) + processed = processor.process_request_dict(request, max_model_len=16) - self.assertEqual(processed.logits_processors_args["think_stop_sentence_token_ids"], [501, 502]) - self.assertTrue(processed.logits_processors_args["think_prompt_started"]) - self.assertFalse(processed.logits_processors_args["think_prompt_ended"]) - self.assertEqual(processed.logits_processors_args["think_prompt_tokens_after_start"], 0) + self.assertEqual(processed["logits_processors_args"]["think_stop_sentence_token_ids"], [501, 502]) + self.assertTrue(processed["logits_processors_args"]["think_prompt_started"]) + self.assertFalse(processed["logits_processors_args"]["think_prompt_ended"]) + self.assertEqual(processed["logits_processors_args"]["think_prompt_tokens_after_start"], 0) def test_v1_ernie_process_request_dict_prepares_thinking_budget_args(self): processor = V1ErnieTextDataProcessor.__new__(V1ErnieTextDataProcessor)