diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index 05c3af191..a76017e75 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -118,6 +118,7 @@ def __init__(
         state: AgentState | dict | None = None,
         hooks: list[HookProvider] | None = None,
         session_manager: SessionManager | None = None,
+        structured_output_prompt: str | None = None,
         tool_executor: ToolExecutor | None = None,
         retry_strategy: ModelRetryStrategy | None = None,
     ):
@@ -168,6 +169,11 @@ def __init__(
                 Defaults to None.
             session_manager: Manager for handling agent sessions including conversation history and state.
                 If provided, enables session-based persistence and state management.
+            structured_output_prompt: Custom prompt message used when forcing structured output.
+                When using structured output, if the model doesn't automatically use the output tool,
+                the agent sends a follow-up message to request structured formatting. This parameter
+                allows customizing that message.
+                Defaults to "You must format the previous response as structured output."
             tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.).
             retry_strategy: Strategy for retrying model calls on throttling or other transient errors.
                 Defaults to ModelRetryStrategy with max_attempts=6, initial_delay=4s, max_delay=240s.
@@ -181,6 +187,7 @@ def __init__(
         # initializing self._system_prompt for backwards compatibility
         self._system_prompt, self._system_prompt_content = self._initialize_system_prompt(system_prompt)
         self._default_structured_output_model = structured_output_model
+        self._structured_output_prompt = structured_output_prompt
         self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT)
         self.name = name or _DEFAULT_AGENT_NAME
         self.description = description
@@ -338,6 +345,7 @@ def __call__(
         *,
         invocation_state: dict[str, Any] | None = None,
         structured_output_model: type[BaseModel] | None = None,
+        structured_output_prompt: str | None = None,
         **kwargs: Any,
     ) -> AgentResult:
         """Process a natural language prompt through the agent's event loop.
@@ -356,6 +364,7 @@ def __call__(
                 - None: Use existing conversation history
             invocation_state: Additional parameters to pass through the event loop.
             structured_output_model: Pydantic model type(s) for structured output (overrides agent default).
+            structured_output_prompt: Custom prompt for forcing structured output (overrides agent default).
             **kwargs: Additional parameters to pass through the event loop.[Deprecating]

         Returns:
@@ -369,7 +378,11 @@ def __call__(
         """
         return run_async(
             lambda: self.invoke_async(
-                prompt, invocation_state=invocation_state, structured_output_model=structured_output_model, **kwargs
+                prompt,
+                invocation_state=invocation_state,
+                structured_output_model=structured_output_model,
+                structured_output_prompt=structured_output_prompt,
+                **kwargs,
             )
         )

@@ -379,6 +392,7 @@ async def invoke_async(
         *,
         invocation_state: dict[str, Any] | None = None,
         structured_output_model: type[BaseModel] | None = None,
+        structured_output_prompt: str | None = None,
         **kwargs: Any,
     ) -> AgentResult:
         """Process a natural language prompt through the agent's event loop.
@@ -397,6 +411,7 @@ async def invoke_async(
                 - None: Use existing conversation history
             invocation_state: Additional parameters to pass through the event loop.
             structured_output_model: Pydantic model type(s) for structured output (overrides agent default).
+            structured_output_prompt: Custom prompt for forcing structured output (overrides agent default).
             **kwargs: Additional parameters to pass through the event loop.[Deprecating]

         Returns:
@@ -408,7 +423,11 @@ async def invoke_async(
                 - state: The final state of the event loop
         """
         events = self.stream_async(
-            prompt, invocation_state=invocation_state, structured_output_model=structured_output_model, **kwargs
+            prompt,
+            invocation_state=invocation_state,
+            structured_output_model=structured_output_model,
+            structured_output_prompt=structured_output_prompt,
+            **kwargs,
         )
         async for event in events:
             _ = event
@@ -542,6 +561,7 @@ async def stream_async(
         *,
         invocation_state: dict[str, Any] | None = None,
         structured_output_model: type[BaseModel] | None = None,
+        structured_output_prompt: str | None = None,
         **kwargs: Any,
     ) -> AsyncIterator[Any]:
         """Process a natural language prompt and yield events as an async iterator.
@@ -560,6 +580,7 @@ async def stream_async(
                 - None: Use existing conversation history
             invocation_state: Additional parameters to pass through the event loop.
             structured_output_model: Pydantic model type(s) for structured output (overrides agent default).
+            structured_output_prompt: Custom prompt for forcing structured output (overrides agent default).
             **kwargs: Additional parameters to pass to the event loop.[Deprecating]

         Yields:
@@ -617,7 +638,7 @@ async def stream_async(

         with trace_api.use_span(self.trace_span):
             try:
-                events = self._run_loop(messages, merged_state, structured_output_model)
+                events = self._run_loop(messages, merged_state, structured_output_model, structured_output_prompt)

                 async for event in events:
                     event.prepare(invocation_state=merged_state)
@@ -645,6 +666,7 @@ async def _run_loop(
         messages: Messages,
         invocation_state: dict[str, Any],
         structured_output_model: type[BaseModel] | None = None,
+        structured_output_prompt: str | None = None,
     ) -> AsyncGenerator[TypedEvent, None]:
         """Execute the agent's event loop with the given message and parameters.

@@ -652,6 +674,7 @@ async def _run_loop(
             messages: The input messages to add to the conversation.
             invocation_state: Additional parameters to pass to the event loop.
             structured_output_model: Optional Pydantic model type for structured output.
+            structured_output_prompt: Optional custom prompt for forcing structured output.

         Yields:
             Events from the event loop cycle.
@@ -668,7 +691,8 @@ async def _run_loop(
             await self._append_messages(*messages)

             structured_output_context = StructuredOutputContext(
-                structured_output_model or self._default_structured_output_model
+                structured_output_model or self._default_structured_output_model,
+                structured_output_prompt=structured_output_prompt or self._structured_output_prompt,
             )

             # Execute the event loop cycle with retry logic for context limits
diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py
index 9fe645f80..3113ddb79 100644
--- a/src/strands/event_loop/event_loop.py
+++ b/src/strands/event_loop/event_loop.py
@@ -220,7 +220,7 @@ async def event_loop_cycle(
         structured_output_context.set_forced_mode()
         logger.debug("Forcing structured output tool")
         await agent._append_messages(
-            {"role": "user", "content": [{"text": "You must format the previous response as structured output."}]}
+            {"role": "user", "content": [{"text": structured_output_context.structured_output_prompt}]}
         )

         events = recurse_event_loop(
diff --git a/src/strands/tools/structured_output/__init__.py b/src/strands/tools/structured_output/__init__.py
index 777d5d846..a3a12d000 100644
--- a/src/strands/tools/structured_output/__init__.py
+++ b/src/strands/tools/structured_output/__init__.py
@@ -1,5 +1,6 @@
 """Structured output tools for the Strands Agents framework."""

+from ._structured_output_context import DEFAULT_STRUCTURED_OUTPUT_PROMPT
 from .structured_output_utils import convert_pydantic_to_tool_spec

-__all__ = ["convert_pydantic_to_tool_spec"]
+__all__ = ["convert_pydantic_to_tool_spec", "DEFAULT_STRUCTURED_OUTPUT_PROMPT"]
diff --git a/src/strands/tools/structured_output/_structured_output_context.py b/src/strands/tools/structured_output/_structured_output_context.py
index 2f8dd8ca0..9a5190d9d 100644
--- a/src/strands/tools/structured_output/_structured_output_context.py
+++ b/src/strands/tools/structured_output/_structured_output_context.py
@@ -13,15 +13,23 @@
 logger = logging.getLogger(__name__)

+DEFAULT_STRUCTURED_OUTPUT_PROMPT = "You must format the previous response as structured output."
+

 class StructuredOutputContext:
     """Per-invocation context for structured output execution."""

-    def __init__(self, structured_output_model: type[BaseModel] | None = None):
+    def __init__(
+        self,
+        structured_output_model: type[BaseModel] | None = None,
+        structured_output_prompt: str | None = None,
+    ):
         """Initialize a new structured output context.

         Args:
             structured_output_model: Optional Pydantic model type for structured output.
+            structured_output_prompt: Optional custom prompt message to use when forcing structured output.
+                Defaults to "You must format the previous response as structured output."
""" self.results: dict[str, BaseModel] = {} self.structured_output_model: type[BaseModel] | None = structured_output_model @@ -31,6 +39,7 @@ def __init__(self, structured_output_model: type[BaseModel] | None = None): self.tool_choice: ToolChoice | None = None self.stop_loop: bool = False self.expected_tool_name: str | None = None + self.structured_output_prompt: str = structured_output_prompt or DEFAULT_STRUCTURED_OUTPUT_PROMPT if structured_output_model: self.structured_output_tool = StructuredOutputTool(structured_output_model) diff --git a/tests/strands/agent/test_agent_structured_output.py b/tests/strands/agent/test_agent_structured_output.py index 7341c714e..6ab112048 100644 --- a/tests/strands/agent/test_agent_structured_output.py +++ b/tests/strands/agent/test_agent_structured_output.py @@ -411,3 +411,160 @@ async def mock_product_cycle(*args, **kwargs): mock_event_loop.side_effect = mock_product_cycle result2 = agent("Get product", structured_output_model=product_model) assert result2.structured_output is pm + + +class TestAgentStructuredOutputPrompt: + """Test Agent structured_output_prompt functionality.""" + + def test_agent_init_with_structured_output_prompt(self, user_model): + """Test that Agent can be initialized with a structured_output_prompt.""" + custom_prompt = "Please format your response using the schema." + agent = Agent(structured_output_model=user_model, structured_output_prompt=custom_prompt) + + assert agent._structured_output_prompt == custom_prompt + + def test_agent_init_without_structured_output_prompt(self): + """Test that Agent can be initialized without structured_output_prompt.""" + agent = Agent() + + assert agent._structured_output_prompt is None + + @patch("strands.agent.agent.event_loop_cycle") + def test_agent_call_with_default_structured_output_prompt( + self, mock_event_loop, user_model, mock_model, mock_metrics + ): + """Test Agent.__call__ uses default structured_output_prompt when not specified.""" + custom_prompt = "Use the output schema to format your response." + + async def mock_cycle(*args, **kwargs): + structured_output_context = kwargs.get("structured_output_context") + assert structured_output_context is not None + assert structured_output_context.structured_output_prompt == custom_prompt + + yield EventLoopStopEvent( + stop_reason="end_turn", + message={"role": "assistant", "content": [{"text": "Response"}]}, + metrics=mock_metrics, + request_state={}, + ) + + mock_event_loop.side_effect = mock_cycle + + # Create agent with default structured_output_prompt + agent = Agent( + model=mock_model, + structured_output_model=user_model, + structured_output_prompt=custom_prompt, + ) + agent("Get user info") + + mock_event_loop.assert_called_once() + + @patch("strands.agent.agent.event_loop_cycle") + def test_agent_call_override_default_structured_output_prompt( + self, mock_event_loop, user_model, mock_model, mock_metrics + ): + """Test that invocation-level structured_output_prompt overrides default.""" + default_prompt = "Default prompt for structured output." + override_prompt = "Override prompt for this specific call." 
+
+        async def mock_cycle(*args, **kwargs):
+            structured_output_context = kwargs.get("structured_output_context")
+            # Should use override_prompt, not the default
+            assert structured_output_context.structured_output_prompt == override_prompt
+
+            yield EventLoopStopEvent(
+                stop_reason="end_turn",
+                message={"role": "assistant", "content": [{"text": "Response"}]},
+                metrics=mock_metrics,
+                request_state={},
+            )
+
+        mock_event_loop.side_effect = mock_cycle
+
+        # Create agent with default prompt, but override at call time
+        agent = Agent(
+            model=mock_model,
+            structured_output_model=user_model,
+            structured_output_prompt=default_prompt,
+        )
+        agent("Get user info", structured_output_prompt=override_prompt)
+
+        mock_event_loop.assert_called_once()
+
+    @patch("strands.agent.agent.event_loop_cycle")
+    def test_agent_call_with_invocation_prompt_no_default(self, mock_event_loop, user_model, mock_model, mock_metrics):
+        """Test that invocation-level prompt works when no default is set."""
+        invocation_prompt = "Format as structured output now."
+
+        async def mock_cycle(*args, **kwargs):
+            structured_output_context = kwargs.get("structured_output_context")
+            assert structured_output_context.structured_output_prompt == invocation_prompt
+
+            yield EventLoopStopEvent(
+                stop_reason="end_turn",
+                message={"role": "assistant", "content": [{"text": "Response"}]},
+                metrics=mock_metrics,
+                request_state={},
+            )
+
+        mock_event_loop.side_effect = mock_cycle
+
+        # Create agent without default prompt
+        agent = Agent(model=mock_model, structured_output_model=user_model)
+        agent("Get user info", structured_output_prompt=invocation_prompt)
+
+        mock_event_loop.assert_called_once()
+
+    @pytest.mark.asyncio
+    @patch("strands.agent.agent.event_loop_cycle")
+    async def test_agent_invoke_async_with_structured_output_prompt(
+        self, mock_event_loop, user_model, mock_model, mock_metrics
+    ):
+        """Test Agent.invoke_async with structured_output_prompt."""
+        custom_prompt = "Async prompt for structured output."
+
+        async def mock_cycle(*args, **kwargs):
+            structured_output_context = kwargs.get("structured_output_context")
+            assert structured_output_context.structured_output_prompt == custom_prompt
+
+            yield EventLoopStopEvent(
+                stop_reason="end_turn",
+                message={"role": "assistant", "content": [{"text": "Response"}]},
+                metrics=mock_metrics,
+                request_state={},
+            )
+
+        mock_event_loop.side_effect = mock_cycle
+
+        agent = Agent(model=mock_model, structured_output_model=user_model)
+        await agent.invoke_async("Get user", structured_output_prompt=custom_prompt)
+
+        mock_event_loop.assert_called_once()
+
+    @pytest.mark.asyncio
+    @patch("strands.agent.agent.event_loop_cycle")
+    async def test_agent_stream_async_with_structured_output_prompt(
+        self, mock_event_loop, user_model, mock_model, mock_metrics
+    ):
+        """Test Agent.stream_async with structured_output_prompt."""
+        custom_prompt = "Stream async prompt for structured output."
+
+        async def mock_cycle(*args, **kwargs):
+            structured_output_context = kwargs.get("structured_output_context")
+            assert structured_output_context.structured_output_prompt == custom_prompt
+
+            yield EventLoopStopEvent(
+                stop_reason="end_turn",
+                message={"role": "assistant", "content": [{"text": "Response"}]},
+                metrics=mock_metrics,
+                request_state={},
+            )
+
+        mock_event_loop.side_effect = mock_cycle
+
+        agent = Agent(model=mock_model, structured_output_model=user_model)
+        async for _ in agent.stream_async("Get user", structured_output_prompt=custom_prompt):
+            pass
+
+        mock_event_loop.assert_called_once()
diff --git a/tests/strands/event_loop/test_event_loop_structured_output.py b/tests/strands/event_loop/test_event_loop_structured_output.py
index 23b7f3433..6f75d6083 100644
--- a/tests/strands/event_loop/test_event_loop_structured_output.py
+++ b/tests/strands/event_loop/test_event_loop_structured_output.py
@@ -8,7 +8,10 @@
 from strands.event_loop.event_loop import event_loop_cycle, recurse_event_loop
 from strands.telemetry.metrics import EventLoopMetrics
 from strands.tools.registry import ToolRegistry
-from strands.tools.structured_output._structured_output_context import StructuredOutputContext
+from strands.tools.structured_output._structured_output_context import (
+    DEFAULT_STRUCTURED_OUTPUT_PROMPT,
+    StructuredOutputContext,
+)
 from strands.types._events import EventLoopStopEvent, StructuredOutputEvent


@@ -190,6 +193,8 @@ async def test_event_loop_forces_structured_output_on_end_turn(
     mock_agent._append_messages.assert_called_once()
     args = mock_agent._append_messages.call_args[0][0]
     assert args["role"] == "user"
+    # Should use the default prompt
+    assert args["content"][0]["text"] == DEFAULT_STRUCTURED_OUTPUT_PROMPT

     # Should have called recurse_event_loop with the context
     mock_recurse.assert_called_once()
@@ -197,6 +202,55 @@ async def test_event_loop_forces_structured_output_on_end_turn(
     assert call_kwargs["structured_output_context"] == structured_output_context


+@pytest.mark.asyncio
+async def test_event_loop_forces_structured_output_with_custom_prompt(mock_agent, agenerator, alist):
+    """Test that event loop uses custom prompt when forcing structured output."""
+    custom_prompt = "Please format your response as structured data using the output schema."
+    structured_output_context = StructuredOutputContext(
+        structured_output_model=UserModel,
+        structured_output_prompt=custom_prompt,
+    )
+
+    # First call returns end_turn without using structured output tool
+    mock_agent.model.stream.side_effect = [
+        agenerator(
+            [
+                {"contentBlockDelta": {"delta": {"text": "Here is the user info"}}},
+                {"contentBlockStop": {}},
+                {"messageStop": {"stopReason": "end_turn"}},
+            ]
+        ),
+    ]
+
+    # Mock recurse_event_loop to return final result
+    with patch("strands.event_loop.event_loop.recurse_event_loop") as mock_recurse:
+        mock_stop_event = Mock()
+        mock_stop_event.stop = (
+            "end_turn",
+            {"role": "assistant", "content": [{"text": "Done"}]},
+            mock_agent.event_loop_metrics,
+            {},
+            None,
+            UserModel(name="John", age=30, email="john@example.com"),
+        )
+        mock_stop_event.__getitem__ = lambda self, key: {"stop": self.stop}[key]
+
+        mock_recurse.return_value = agenerator([mock_stop_event])
+
+        stream = event_loop_cycle(
+            agent=mock_agent,
+            invocation_state={},
+            structured_output_context=structured_output_context,
+        )
+        await alist(stream)
+
+        # Should have appended a message with the custom prompt
+        mock_agent._append_messages.assert_called_once()
+        args = mock_agent._append_messages.call_args[0][0]
+        assert args["role"] == "user"
+        assert args["content"][0]["text"] == custom_prompt
+
+
 @pytest.mark.asyncio
 async def test_structured_output_tool_execution_extracts_result(
     mock_agent, structured_output_context, agenerator, alist
diff --git a/tests/strands/tools/structured_output/test_structured_output_context.py b/tests/strands/tools/structured_output/test_structured_output_context.py
index 0f1c7ffff..6d75852d1 100644
--- a/tests/strands/tools/structured_output/test_structured_output_context.py
+++ b/tests/strands/tools/structured_output/test_structured_output_context.py
@@ -2,7 +2,10 @@
 from pydantic import BaseModel, Field

-from strands.tools.structured_output._structured_output_context import StructuredOutputContext
+from strands.tools.structured_output._structured_output_context import (
+    DEFAULT_STRUCTURED_OUTPUT_PROMPT,
+    StructuredOutputContext,
+)
 from strands.tools.structured_output.structured_output_tool import StructuredOutputTool


@@ -35,6 +38,7 @@ def test_initialization_with_structured_output_model(self):
         assert context.forced_mode is False
         assert context.tool_choice is None
         assert context.stop_loop is False
+        assert context.structured_output_prompt == DEFAULT_STRUCTURED_OUTPUT_PROMPT

     def test_initialization_without_structured_output_model(self):
         """Test initialization without a structured output model."""
@@ -47,6 +51,31 @@ def test_initialization_without_structured_output_model(self):
         assert context.forced_mode is False
         assert context.tool_choice is None
         assert context.stop_loop is False
+        assert context.structured_output_prompt == DEFAULT_STRUCTURED_OUTPUT_PROMPT
+
+    def test_initialization_with_custom_prompt(self):
+        """Test initialization with a custom structured output prompt."""
+        custom_prompt = "Please format your response using the output schema."
+        context = StructuredOutputContext(
+            structured_output_model=SampleModel,
+            structured_output_prompt=custom_prompt,
+        )
+
+        assert context.structured_output_model == SampleModel
+        assert context.structured_output_prompt == custom_prompt
+
+    def test_initialization_with_none_prompt_uses_default(self):
+        """Test that None prompt falls back to default."""
+        context = StructuredOutputContext(
+            structured_output_model=SampleModel,
+            structured_output_prompt=None,
+        )
+
+        assert context.structured_output_prompt == DEFAULT_STRUCTURED_OUTPUT_PROMPT
+
+    def test_default_prompt_constant_value(self):
+        """Test the default prompt constant has expected value."""
+        assert DEFAULT_STRUCTURED_OUTPUT_PROMPT == "You must format the previous response as structured output."

     def test_is_enabled_property(self):
         """Test the is_enabled property."""
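
Usage note (illustration, not part of the patch): a minimal sketch of how the new parameter is meant to be used, based on the Agent API changed above. It assumes `from strands import Agent` as the public import path and a hypothetical `UserProfile` Pydantic model; both are for illustration and do not come from this diff.

# Sketch only: structured_output_prompt at the agent level and per call.
from pydantic import BaseModel

from strands import Agent  # assumed public import path; adjust to your project


class UserProfile(BaseModel):
    """Hypothetical output schema used for illustration."""

    name: str
    age: int
    email: str


# Agent-level default: sent as a follow-up user message whenever the model
# ends its turn without calling the structured output tool.
agent = Agent(
    structured_output_model=UserProfile,
    structured_output_prompt="Rewrite your previous answer as a UserProfile object.",
)

# Per-invocation override: takes precedence over the agent-level default for this call only.
result = agent(
    "Extract the user's details from the conversation so far.",
    structured_output_prompt="Reformat your last answer using the UserProfile schema.",
)
print(result.structured_output)  # parsed UserProfile instance when the tool was used

When neither argument is supplied, the event loop falls back to DEFAULT_STRUCTURED_OUTPUT_PROMPT, which this diff also re-exports from strands.tools.structured_output.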
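
The precedence introduced above resolves through two `or` fallbacks: the per-call argument in `Agent._run_loop`, then the agent default, then the module constant in `StructuredOutputContext.__init__`. The helper below is purely illustrative (`resolve_prompt` is not a function in the library); it only restates that fallback order.

# Illustration of the fallback order; not library code.
DEFAULT_STRUCTURED_OUTPUT_PROMPT = "You must format the previous response as structured output."


def resolve_prompt(invocation_prompt: str | None, agent_default: str | None) -> str:
    # Mirrors the diff: _run_loop passes `structured_output_prompt or self._structured_output_prompt`,
    # and StructuredOutputContext then applies `structured_output_prompt or DEFAULT_STRUCTURED_OUTPUT_PROMPT`.
    return invocation_prompt or agent_default or DEFAULT_STRUCTURED_OUTPUT_PROMPT


assert resolve_prompt(None, None) == DEFAULT_STRUCTURED_OUTPUT_PROMPT
assert resolve_prompt(None, "agent default") == "agent default"
assert resolve_prompt("per-call override", "agent default") == "per-call override"
# Because the fallbacks use `or`, an empty string behaves like None and falls through.
assert resolve_prompt("", "agent default") == "agent default"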