From 6450f719e8b1b852f57d2ee3d0bbbc0342c105bc Mon Sep 17 00:00:00 2001 From: danielp Date: Mon, 22 Dec 2025 10:44:24 -0300 Subject: [PATCH] feat(gemini): add thought signature preservation for thinking models --- src/strands/event_loop/streaming.py | 4 ++ src/strands/models/gemini.py | 81 +++++++++++++++++-------- src/strands/types/content.py | 4 +- src/strands/types/tools.py | 2 + tests/strands/models/test_gemini.py | 91 ++++++++++++++++++++--------- 5 files changed, 128 insertions(+), 54 deletions(-) diff --git a/src/strands/event_loop/streaming.py b/src/strands/event_loop/streaming.py index 954633807..b157f740e 100644 --- a/src/strands/event_loop/streaming.py +++ b/src/strands/event_loop/streaming.py @@ -186,6 +186,8 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> dict[str, Any]: current_tool_use["toolUseId"] = tool_use_data["toolUseId"] current_tool_use["name"] = tool_use_data["name"] current_tool_use["input"] = "" + if "reasoningSignature" in tool_use_data: + current_tool_use["reasoningSignature"] = tool_use_data["reasoningSignature"] return current_tool_use @@ -286,6 +288,8 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]: name=tool_use_name, input=current_tool_use["input"], ) + if "reasoningSignature" in current_tool_use: + tool_use["reasoningSignature"] = current_tool_use["reasoningSignature"] content.append({"toolUse": tool_use}) state["current_tool_use"] = {} diff --git a/src/strands/models/gemini.py b/src/strands/models/gemini.py index 6a6535999..c5074799b 100644 --- a/src/strands/models/gemini.py +++ b/src/strands/models/gemini.py @@ -3,6 +3,7 @@ - Docs: https://ai.google.dev/api """ +import base64 import json import logging import mimetypes @@ -135,7 +136,9 @@ def _get_client(self) -> genai.Client: return genai.Client(**self.client_args) def _format_request_content_part( - self, content: ContentBlock, tool_use_id_to_name: dict[str, str] + self, + content: ContentBlock, + tool_use_id_to_name: dict[str, str], ) -> genai.types.Part: """Format content block into a Gemini part instance. @@ -173,7 +176,7 @@ def _format_request_content_part( return genai.types.Part( text=content["reasoningContent"]["reasoningText"]["text"], thought=True, - thought_signature=thought_signature.encode("utf-8") if thought_signature else None, + thought_signature=base64.b64decode(thought_signature) if thought_signature else None, ) if "text" in content: @@ -202,14 +205,18 @@ def _format_request_content_part( ) if "toolUse" in content: - tool_use_id_to_name[content["toolUse"]["toolUseId"]] = content["toolUse"]["name"] + tool_use_id = content["toolUse"]["toolUseId"] + tool_use_id_to_name[tool_use_id] = content["toolUse"]["name"] + + reasoning_signature = content["toolUse"].get("reasoningSignature") return genai.types.Part( function_call=genai.types.FunctionCall( args=content["toolUse"]["input"], - id=content["toolUse"]["toolUseId"], + id=tool_use_id, name=content["toolUse"]["name"], ), + thought_signature=base64.b64decode(reasoning_signature) if reasoning_signature else None, ) raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") @@ -259,20 +266,27 @@ def _format_request_tools(self, tool_specs: list[ToolSpec] | None) -> list[genai Return: Gemini tool list. """ - tools = [ - genai.types.Tool( - function_declarations=[ - genai.types.FunctionDeclaration( - description=tool_spec["description"], - name=tool_spec["name"], - parameters_json_schema=tool_spec["inputSchema"]["json"], - ) - for tool_spec in tool_specs or [] - ], - ), - ] + tools = [] + + # Only add function declarations tool if there are tool specs + if tool_specs: + tools.append( + genai.types.Tool( + function_declarations=[ + genai.types.FunctionDeclaration( + description=tool_spec["description"], + name=tool_spec["name"], + parameters_json_schema=tool_spec["inputSchema"]["json"], + ) + for tool_spec in tool_specs + ], + ), + ) + + # Add any Gemini-specific tools if self.config.get("gemini_tools"): tools.extend(self.config["gemini_tools"]) + return tools def _format_request_config( @@ -293,11 +307,19 @@ def _format_request_config( Returns: Gemini request config. """ - return genai.types.GenerateContentConfig( - system_instruction=system_prompt, - tools=self._format_request_tools(tool_specs), + tools = self._format_request_tools(tool_specs) + + # Build config kwargs, only including tools if there are any + config_kwargs = { + "system_instruction": system_prompt, **(params or {}), - ) + } + + # Only include tools parameter if there are actual tools to pass + if tools: + config_kwargs["tools"] = tools + + return genai.types.GenerateContentConfig(**config_kwargs) def _format_request( self, @@ -349,13 +371,18 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: # Use Gemini's provided ID or generate one if missing tool_use_id = function_call.id or f"tooluse_{secrets.token_urlsafe(16)}" + tool_use_start: dict[str, Any] = { + "name": function_call.name, + "toolUseId": tool_use_id, + } + if event["data"].thought_signature: + tool_use_start["reasoningSignature"] = base64.b64encode( + event["data"].thought_signature + ).decode("ascii") return { "contentBlockStart": { "start": { - "toolUse": { - "name": function_call.name, - "toolUseId": tool_use_id, - }, + "toolUse": tool_use_start, # type: ignore[typeddict-item] }, }, } @@ -379,7 +406,11 @@ def _format_chunk(self, event: dict[str, Any]) -> StreamEvent: "reasoningContent": { "text": event["data"].text, **( - {"signature": event["data"].thought_signature.decode("utf-8")} + { + "signature": base64.b64encode(event["data"].thought_signature).decode( + "ascii" + ) + } if event["data"].thought_signature else {} ), diff --git a/src/strands/types/content.py b/src/strands/types/content.py index d75dbb87f..2b0714bee 100644 --- a/src/strands/types/content.py +++ b/src/strands/types/content.py @@ -8,7 +8,7 @@ from typing import Literal -from typing_extensions import TypedDict +from typing_extensions import NotRequired, TypedDict from .citations import CitationsContentBlock from .media import DocumentContent, ImageContent, VideoContent @@ -129,10 +129,12 @@ class ContentBlockStartToolUse(TypedDict): Attributes: name: The name of the tool that the model is requesting to use. toolUseId: The ID for the tool request. + reasoningSignature: Token that ties the model's reasoning to this tool call. """ name: str toolUseId: str + reasoningSignature: NotRequired[str] class ContentBlockStart(TypedDict, total=False): diff --git a/src/strands/types/tools.py b/src/strands/types/tools.py index 6fc0d703c..088c83bdb 100644 --- a/src/strands/types/tools.py +++ b/src/strands/types/tools.py @@ -58,11 +58,13 @@ class ToolUse(TypedDict): Can be any JSON-serializable type. name: The name of the tool to invoke. toolUseId: A unique identifier for this specific tool use request. + reasoningSignature: Token that ties the model's reasoning to this tool call. """ input: Any name: str toolUseId: str + reasoningSignature: NotRequired[str] class ToolResultContent(TypedDict, total=False): diff --git a/tests/strands/models/test_gemini.py b/tests/strands/models/test_gemini.py index d62c5a7c8..a9112332a 100644 --- a/tests/strands/models/test_gemini.py +++ b/tests/strands/models/test_gemini.py @@ -84,7 +84,7 @@ async def test_stream_request_default(gemini_client, model, messages, model_id): await anext(model.stream(messages)) exp_request = { - "config": {"tools": [{"function_declarations": []}]}, + "config": {}, "contents": [{"parts": [{"text": "test"}], "role": "user"}], "model": model_id, } @@ -99,7 +99,6 @@ async def test_stream_request_with_params(gemini_client, model, messages, model_ exp_request = { "config": { - "tools": [{"function_declarations": []}], "temperature": 1, }, "contents": [{"parts": [{"text": "test"}], "role": "user"}], @@ -113,7 +112,7 @@ async def test_stream_request_with_system_prompt(gemini_client, model, messages, await anext(model.stream(messages, system_prompt=system_prompt)) exp_request = { - "config": {"system_instruction": system_prompt, "tools": [{"function_declarations": []}]}, + "config": {"system_instruction": system_prompt}, "contents": [{"parts": [{"text": "test"}], "role": "user"}], "model": model_id, } @@ -146,9 +145,7 @@ async def test_stream_request_with_document(content, formatted_part, gemini_clie await anext(model.stream(messages)) exp_request = { - "config": { - "tools": [{"function_declarations": []}], - }, + "config": {}, "contents": [{"parts": [formatted_part], "role": "user"}], "model": model_id, } @@ -173,9 +170,7 @@ async def test_stream_request_with_image(gemini_client, model, model_id): await anext(model.stream(messages)) exp_request = { - "config": { - "tools": [{"function_declarations": []}], - }, + "config": {}, "contents": [ { "parts": [ @@ -203,7 +198,7 @@ async def test_stream_request_with_reasoning(gemini_client, model, model_id): { "reasoningContent": { "reasoningText": { - "signature": "abc", + "signature": "YWJj", # base64 of "abc" "text": "reasoning_text", }, }, @@ -214,9 +209,7 @@ async def test_stream_request_with_reasoning(gemini_client, model, model_id): await anext(model.stream(messages)) exp_request = { - "config": { - "tools": [{"function_declarations": []}], - }, + "config": {}, "contents": [ { "parts": [ @@ -260,6 +253,7 @@ async def test_stream_request_with_tool_spec(gemini_client, model, model_id, too @pytest.mark.asyncio async def test_stream_request_with_tool_use(gemini_client, model, model_id): + """Test toolUse with reasoningSignature is sent as function_call with thought_signature.""" messages = [ { "role": "assistant", @@ -269,6 +263,7 @@ async def test_stream_request_with_tool_use(gemini_client, model, model_id): "toolUseId": "c1", "name": "calculator", "input": {"expression": "2+2"}, + "reasoningSignature": "YWJj", # base64 of "abc" }, }, ], @@ -277,9 +272,49 @@ async def test_stream_request_with_tool_use(gemini_client, model, model_id): await anext(model.stream(messages)) exp_request = { - "config": { - "tools": [{"function_declarations": []}], + "config": {}, + "contents": [ + { + "parts": [ + { + "function_call": { + "args": {"expression": "2+2"}, + "id": "c1", + "name": "calculator", + }, + "thought_signature": "YWJj", + }, + ], + "role": "model", + }, + ], + "model": model_id, + } + gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request) + + +@pytest.mark.asyncio +async def test_stream_request_with_tool_use_no_reasoning_signature(gemini_client, model, model_id): + """Test toolUse without reasoningSignature is sent as function_call without thought_signature.""" + messages = [ + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "c1", + "name": "calculator", + "input": {"expression": "2+2"}, + # No reasoningSignature + }, + }, + ], }, + ] + await anext(model.stream(messages)) + + exp_request = { + "config": {}, "contents": [ { "parts": [ @@ -289,6 +324,7 @@ async def test_stream_request_with_tool_use(gemini_client, model, model_id): "id": "c1", "name": "calculator", }, + # thought_signature omitted when None (Gemini SDK behavior) }, ], "role": "model", @@ -327,9 +363,7 @@ async def test_stream_request_with_tool_results(gemini_client, model, model_id): await anext(model.stream(messages)) exp_request = { - "config": { - "tools": [{"function_declarations": []}], - }, + "config": {}, "contents": [ { "parts": [ @@ -391,9 +425,7 @@ async def test_stream_request_with_tool_results_preserving_name(gemini_client, m await anext(model.stream(messages)) exp_request = { - "config": { - "tools": [{"function_declarations": []}], - }, + "config": {}, "contents": [ { "parts": [ @@ -403,6 +435,7 @@ async def test_stream_request_with_tool_results_preserving_name(gemini_client, m "id": "t1", "name": "tool_1", }, + # thought_signature omitted when None (Gemini SDK behavior) }, ], "role": "model", @@ -436,9 +469,7 @@ async def test_stream_request_with_empty_content(gemini_client, model, model_id) await anext(model.stream(messages)) exp_request = { - "config": { - "tools": [{"function_declarations": []}], - }, + "config": {}, "contents": [{"parts": [], "role": "user"}], "model": model_id, } @@ -560,10 +591,11 @@ async def test_stream_response_reasoning(gemini_client, model, messages, agenera ) tru_chunks = await alist(model.stream(messages)) + # signature is base64 encoded: b"abc" -> "YWJj" exp_chunks = [ {"messageStart": {"role": "assistant"}}, {"contentBlockStart": {"start": {}}}, - {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "abc", "text": "test reason"}}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "YWJj", "text": "test reason"}}}}, {"contentBlockStop": {}}, {"messageStop": {"stopReason": "end_turn"}}, {"metadata": {"usage": {"inputTokens": 1, "outputTokens": 2, "totalTokens": 3}, "metrics": {"latencyMs": 0}}}, @@ -622,7 +654,11 @@ async def test_stream_response_reasoning_and_text(gemini_client, model, messages exp_chunks = [ {"messageStart": {"role": "assistant"}}, {"contentBlockStart": {"start": {}}}, - {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "sig1", "text": "thinking about math"}}}}, + { + "contentBlockDelta": { + "delta": {"reasoningContent": {"signature": "c2lnMQ==", "text": "thinking about math"}} + } + }, {"contentBlockStop": {}}, {"contentBlockStart": {"start": {}}}, {"contentBlockDelta": {"delta": {"text": "2 + 2 = 4"}}}, @@ -754,7 +790,6 @@ async def test_structured_output(gemini_client, model, messages, model_id, weath exp_request = { "config": { - "tools": [{"function_declarations": []}], "response_mime_type": "application/json", "response_schema": weather_output.model_json_schema(), }, @@ -806,10 +841,10 @@ async def test_stream_request_with_gemini_tools(gemini_client, messages, model_i await anext(model.stream(messages)) + # When only gemini_tools are provided (no tool_specs), only gemini_tools are included exp_request = { "config": { "tools": [ - {"function_declarations": []}, {"google_search": {}}, ] },