From 0b900ff8a8c0c2b27f391ebe1225feb42b1bb159 Mon Sep 17 00:00:00 2001
From: imoc
Date: Fri, 1 Aug 2025 22:47:35 +0800
Subject: [PATCH] Inject <think> into responses when the chat template forces thinking

Adds logic to prepend '<think>' to the first streamed chunk and to all
final generations when the chat template ends with a '<think>' tag.
Adjusts token and offset accounting to remain consistent when the tag
is injected.
---
 endpoints/OAI/types/chat_completion.py |  2 +-
 endpoints/OAI/utils/chat_completion.py | 79 +++++++++++++++++++-------
 2 files changed, 59 insertions(+), 22 deletions(-)

diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py
index 52523149..d3b3a629 100644
--- a/endpoints/OAI/types/chat_completion.py
+++ b/endpoints/OAI/types/chat_completion.py
@@ -57,7 +57,7 @@ class ChatCompletionStreamChoice(BaseModel):
 class ChatCompletionRequest(CommonCompletionRequest):
     messages: List[ChatCompletionMessage]
     prompt_template: Optional[str] = None
-    add_generation_prompt: Optional[bool] = True
+    add_generation_prompt: Optional[bool] = None
     template_vars: Optional[dict] = Field(
         default={},
         validation_alias=AliasChoices("template_vars", "chat_template_kwargs"),
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index b559bb2b..cec4a37c 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -32,6 +32,29 @@
 from endpoints.OAI.utils.tools import ToolCallProcessor, TOOL_CALL_SCHEMA
 
 
+def should_add_generation_prompt(data: ChatCompletionRequest) -> bool:
+    """
+    Determines if a generation prompt should be added based on the request.
+    - Explicitly follows `data.add_generation_prompt` if set.
+    - Defaults to `False` if the last message is from the assistant, to avoid double prompts.
+    - Defaults to `True` otherwise.
+    """
+    if data.add_generation_prompt is not None:
+        return data.add_generation_prompt
+    if data.messages and data.messages[-1].role == "assistant":
+        return False
+    return True
+
+
+def preprocess_stream_chunk(data: dict, inject_thinking: bool, is_first_chunk: bool):
+    """Prepends '<think>' to the first chunk of a stream if needed."""
+    if inject_thinking and is_first_chunk:
+        updated = data.copy()
+        updated["text"] = "<think>" + updated.get("text", "")
+        return updated
+    return data
+
+
 def _create_response(
     request_id: str, generations: List[dict], model_name: Optional[str]
 ):
@@ -54,10 +77,10 @@ def _create_response(
             logprobs = unwrap(generation.get("logprobs"), [])
 
             collected_token_probs = []
-            for index, token in enumerate(token_probs.keys()):
+            for i, token in enumerate(token_probs.keys()):
                 top_logprobs = [
-                    ChatCompletionLogprob(token=token, logprob=logprob)
-                    for token, logprob in logprobs[index].items()
+                    ChatCompletionLogprob(token=t, logprob=lp)
+                    for t, lp in logprobs[i].items()
                 ]
 
                 collected_token_probs.append(
@@ -258,7 +281,7 @@ async def apply_chat_template(data: ChatCompletionRequest):
     try:
         data.template_vars.update(
             {
-                "add_generation_prompt": data.add_generation_prompt,
+                "add_generation_prompt": should_add_generation_prompt(data),
                 "tools": tools,
                 "functions": data.functions,
             }
@@ -324,6 +347,8 @@
     try:
         logger.info(f"Received chat completion streaming request {request.state.id}")
 
+        inject_thinking = "<think>" in prompt[-11:] and should_add_generation_prompt(data)
+
         for idx in range(0, data.n):
             task_gen_params = data.model_copy(deep=True)
             request_id = _parse_gen_request_id(data.n, request.state.id, idx)
@@ -342,8 +367,8 @@
 
             gen_tasks.append(gen_task)
 
-        # Text accumulation for tool calls
-        current_generation_text = ""
+        # Track which generation indices have already streamed their first chunk
+        seen_first_chunk_indices = set()
 
         # Consumer loop
         while True:
@@ -353,30 +378,36 @@
             generation = await gen_queue.get()
 
             # Handle options if a tool model is present
-            if tool_start:
-                if "stop_str" in generation:
-                    generations = await generate_tool_calls(
-                        prompt,
-                        embeddings,
-                        data,
-                        [generation],
-                        request,
-                    )
-
-                    # Only one generation present in this case
-                    generation = generations[0]
-                elif "text" in generation:
-                    current_generation_text += generation["text"]
+            if tool_start and "stop_str" in generation:
+                generations = await generate_tool_calls(
+                    prompt,
+                    embeddings,
+                    data,
+                    [generation],
+                    request,
+                )
+                # Only one generation present in this case
+                generation = generations[0]
 
             # Stream collector will push an exception to the queue if it fails
             if isinstance(generation, Exception):
                 raise generation
 
+            index = generation.get("index", 0)
+            is_first_for_this_index = index not in seen_first_chunk_indices
+
+            processed_generation = preprocess_stream_chunk(
+                generation, inject_thinking, is_first_for_this_index
+            )
+
             response = _create_stream_chunk(
-                request.state.id, generation, model_path.name
+                request.state.id, processed_generation, model_path.name
             )
             yield response.model_dump_json()
 
+            if is_first_for_this_index:
+                seen_first_chunk_indices.add(index)
+
             # Check if all tasks are completed
             if all(task.done() for task in gen_tasks) and gen_queue.empty():
                 # Send a usage chunk
@@ -442,6 +473,12 @@
             prompt, embeddings, data, generations, request
         )
 
+        # Prepend "<think>" after generation and tool calls are complete.
+        if "<think>" in prompt[-11:] and should_add_generation_prompt(data):
+            for gen in generations:
+                if "text" in gen:
+                    gen["text"] = "<think>" + gen["text"]
+
         response = _create_response(request.state.id, generations, model_path.name)
 
         logger.info(f"Finished chat completion request {request.state.id}")
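
Reviewer note, not part of the patch: a minimal behavior sketch for the two new helpers. It assumes should_add_generation_prompt and preprocess_stream_chunk import cleanly from endpoints.OAI.utils.chat_completion outside the running server, and it substitutes SimpleNamespace objects for the real ChatCompletionRequest/ChatCompletionMessage pydantic models, since the helpers only read add_generation_prompt, messages, and role.

from types import SimpleNamespace

from endpoints.OAI.utils.chat_completion import (
    preprocess_stream_chunk,
    should_add_generation_prompt,
)

# Stand-ins for ChatCompletionRequest: only the attributes the helpers read.
user_turn = SimpleNamespace(
    add_generation_prompt=None,
    messages=[SimpleNamespace(role="user", content="hi")],
)
assistant_turn = SimpleNamespace(
    add_generation_prompt=None,
    messages=[SimpleNamespace(role="assistant", content="partial answer")],
)
forced = SimpleNamespace(add_generation_prompt=True, messages=assistant_turn.messages)

assert should_add_generation_prompt(user_turn) is True        # default case
assert should_add_generation_prompt(assistant_turn) is False  # continuation, no double prompt
assert should_add_generation_prompt(forced) is True           # explicit request value wins

# "<think>" is prepended only to the first chunk of a stream; later chunks
# pass through untouched, as does everything when inject_thinking is False.
first = preprocess_stream_chunk({"text": "Let me reason."}, True, True)
later = preprocess_stream_chunk({"text": " more text"}, True, False)
assert first["text"] == "<think>Let me reason."
assert later["text"] == " more text"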
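
Likewise, a hedged sketch of the non-streaming path in generate_chat_completion: the rendered prompt below is hypothetical, but it mirrors the tail check and the prepend loop added by the patch.

# Hypothetical prompt rendered by a chat template that forces thinking by
# leaving an opened <think> tag at the end of the assistant turn.
prompt = "<|user|>\nWhy is the sky blue?\n<|assistant|>\n<think>\n"

# Only the last 11 characters are checked, so a trailing newline or a little
# whitespace after the tag still triggers injection.
inject_thinking = "<think>" in prompt[-11:]
assert inject_thinking

generations = [{"text": "Rayleigh scattering...", "finish_reason": "stop"}]
if inject_thinking:
    for gen in generations:
        if "text" in gen:
            gen["text"] = "<think>" + gen["text"]

assert generations[0]["text"].startswith("<think>")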