Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
state: AgentState | dict | None = None,
hooks: list[HookProvider] | None = None,
session_manager: SessionManager | None = None,
structured_output_prompt: str | None = None,
tool_executor: ToolExecutor | None = None,
retry_strategy: ModelRetryStrategy | None = None,
):
Expand Down Expand Up @@ -168,6 +169,11 @@ def __init__(
Defaults to None.
session_manager: Manager for handling agent sessions including conversation history and state.
If provided, enables session-based persistence and state management.
structured_output_prompt: Custom prompt message used when forcing structured output.
When using structured output, if the model doesn't automatically use the output tool,
the agent sends a follow-up message to request structured formatting. This parameter
allows customizing that message.
Defaults to "You must format the previous response as structured output."
tool_executor: Definition of tool execution strategy (e.g., sequential, concurrent, etc.).
retry_strategy: Strategy for retrying model calls on throttling or other transient errors.
Defaults to ModelRetryStrategy with max_attempts=6, initial_delay=4s, max_delay=240s.
Expand All @@ -181,6 +187,7 @@ def __init__(
# initializing self._system_prompt for backwards compatibility
self._system_prompt, self._system_prompt_content = self._initialize_system_prompt(system_prompt)
self._default_structured_output_model = structured_output_model
self._structured_output_prompt = structured_output_prompt
self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT)
self.name = name or _DEFAULT_AGENT_NAME
self.description = description
Expand Down Expand Up @@ -338,6 +345,7 @@ def __call__(
*,
invocation_state: dict[str, Any] | None = None,
structured_output_model: type[BaseModel] | None = None,
structured_output_prompt: str | None = None,
**kwargs: Any,
) -> AgentResult:
"""Process a natural language prompt through the agent's event loop.
Expand All @@ -356,6 +364,7 @@ def __call__(
- None: Use existing conversation history
invocation_state: Additional parameters to pass through the event loop.
structured_output_model: Pydantic model type(s) for structured output (overrides agent default).
structured_output_prompt: Custom prompt for forcing structured output (overrides agent default).
**kwargs: Additional parameters to pass through the event loop.[Deprecating]

Returns:
Expand All @@ -369,7 +378,11 @@ def __call__(
"""
return run_async(
lambda: self.invoke_async(
prompt, invocation_state=invocation_state, structured_output_model=structured_output_model, **kwargs
prompt,
invocation_state=invocation_state,
structured_output_model=structured_output_model,
structured_output_prompt=structured_output_prompt,
**kwargs,
)
)

Expand All @@ -379,6 +392,7 @@ async def invoke_async(
*,
invocation_state: dict[str, Any] | None = None,
structured_output_model: type[BaseModel] | None = None,
structured_output_prompt: str | None = None,
**kwargs: Any,
) -> AgentResult:
"""Process a natural language prompt through the agent's event loop.
Expand All @@ -397,6 +411,7 @@ async def invoke_async(
- None: Use existing conversation history
invocation_state: Additional parameters to pass through the event loop.
structured_output_model: Pydantic model type(s) for structured output (overrides agent default).
structured_output_prompt: Custom prompt for forcing structured output (overrides agent default).
**kwargs: Additional parameters to pass through the event loop.[Deprecating]

Returns:
Expand All @@ -408,7 +423,11 @@ async def invoke_async(
- state: The final state of the event loop
"""
events = self.stream_async(
prompt, invocation_state=invocation_state, structured_output_model=structured_output_model, **kwargs
prompt,
invocation_state=invocation_state,
structured_output_model=structured_output_model,
structured_output_prompt=structured_output_prompt,
**kwargs,
)
async for event in events:
_ = event
Expand Down Expand Up @@ -542,6 +561,7 @@ async def stream_async(
*,
invocation_state: dict[str, Any] | None = None,
structured_output_model: type[BaseModel] | None = None,
structured_output_prompt: str | None = None,
**kwargs: Any,
) -> AsyncIterator[Any]:
"""Process a natural language prompt and yield events as an async iterator.
Expand All @@ -560,6 +580,7 @@ async def stream_async(
- None: Use existing conversation history
invocation_state: Additional parameters to pass through the event loop.
structured_output_model: Pydantic model type(s) for structured output (overrides agent default).
structured_output_prompt: Custom prompt for forcing structured output (overrides agent default).
**kwargs: Additional parameters to pass to the event loop.[Deprecating]

Yields:
Expand Down Expand Up @@ -617,7 +638,7 @@ async def stream_async(

with trace_api.use_span(self.trace_span):
try:
events = self._run_loop(messages, merged_state, structured_output_model)
events = self._run_loop(messages, merged_state, structured_output_model, structured_output_prompt)

async for event in events:
event.prepare(invocation_state=merged_state)
Expand Down Expand Up @@ -645,13 +666,15 @@ async def _run_loop(
messages: Messages,
invocation_state: dict[str, Any],
structured_output_model: type[BaseModel] | None = None,
structured_output_prompt: str | None = None,
) -> AsyncGenerator[TypedEvent, None]:
"""Execute the agent's event loop with the given message and parameters.

Args:
messages: The input messages to add to the conversation.
invocation_state: Additional parameters to pass to the event loop.
structured_output_model: Optional Pydantic model type for structured output.
structured_output_prompt: Optional custom prompt for forcing structured output.

Yields:
Events from the event loop cycle.
Expand All @@ -668,7 +691,8 @@ async def _run_loop(
await self._append_messages(*messages)

structured_output_context = StructuredOutputContext(
structured_output_model or self._default_structured_output_model
structured_output_model or self._default_structured_output_model,
structured_output_prompt=structured_output_prompt or self._structured_output_prompt,
)

# Execute the event loop cycle with retry logic for context limits
Expand Down
2 changes: 1 addition & 1 deletion src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ async def event_loop_cycle(
structured_output_context.set_forced_mode()
logger.debug("Forcing structured output tool")
await agent._append_messages(
{"role": "user", "content": [{"text": "You must format the previous response as structured output."}]}
{"role": "user", "content": [{"text": structured_output_context.structured_output_prompt}]}
)

events = recurse_event_loop(
Expand Down
3 changes: 2 additions & 1 deletion src/strands/tools/structured_output/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Structured output tools for the Strands Agents framework."""

from ._structured_output_context import DEFAULT_STRUCTURED_OUTPUT_PROMPT
from .structured_output_utils import convert_pydantic_to_tool_spec

__all__ = ["convert_pydantic_to_tool_spec"]
__all__ = ["convert_pydantic_to_tool_spec", "DEFAULT_STRUCTURED_OUTPUT_PROMPT"]
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,23 @@

logger = logging.getLogger(__name__)

DEFAULT_STRUCTURED_OUTPUT_PROMPT = "You must format the previous response as structured output."

Comment on lines +16 to +17
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can the user override this instead of us having to plumb this all the way through from the agent initialization?


class StructuredOutputContext:
"""Per-invocation context for structured output execution."""

def __init__(self, structured_output_model: type[BaseModel] | None = None):
def __init__(
self,
structured_output_model: type[BaseModel] | None = None,
structured_output_prompt: str | None = None,
):
"""Initialize a new structured output context.

Args:
structured_output_model: Optional Pydantic model type for structured output.
structured_output_prompt: Optional custom prompt message to use when forcing structured output.
Defaults to "You must format the previous response as structured output."
"""
self.results: dict[str, BaseModel] = {}
self.structured_output_model: type[BaseModel] | None = structured_output_model
Expand All @@ -31,6 +39,7 @@ def __init__(self, structured_output_model: type[BaseModel] | None = None):
self.tool_choice: ToolChoice | None = None
self.stop_loop: bool = False
self.expected_tool_name: str | None = None
self.structured_output_prompt: str = structured_output_prompt or DEFAULT_STRUCTURED_OUTPUT_PROMPT

if structured_output_model:
self.structured_output_tool = StructuredOutputTool(structured_output_model)
Expand Down
157 changes: 157 additions & 0 deletions tests/strands/agent/test_agent_structured_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,3 +411,160 @@ async def mock_product_cycle(*args, **kwargs):
mock_event_loop.side_effect = mock_product_cycle
result2 = agent("Get product", structured_output_model=product_model)
assert result2.structured_output is pm


class TestAgentStructuredOutputPrompt:
"""Test Agent structured_output_prompt functionality."""

def test_agent_init_with_structured_output_prompt(self, user_model):
"""Test that Agent can be initialized with a structured_output_prompt."""
custom_prompt = "Please format your response using the schema."
agent = Agent(structured_output_model=user_model, structured_output_prompt=custom_prompt)

assert agent._structured_output_prompt == custom_prompt

def test_agent_init_without_structured_output_prompt(self):
"""Test that Agent can be initialized without structured_output_prompt."""
agent = Agent()

assert agent._structured_output_prompt is None

@patch("strands.agent.agent.event_loop_cycle")
def test_agent_call_with_default_structured_output_prompt(
self, mock_event_loop, user_model, mock_model, mock_metrics
):
"""Test Agent.__call__ uses default structured_output_prompt when not specified."""
custom_prompt = "Use the output schema to format your response."

async def mock_cycle(*args, **kwargs):
structured_output_context = kwargs.get("structured_output_context")
assert structured_output_context is not None
assert structured_output_context.structured_output_prompt == custom_prompt

yield EventLoopStopEvent(
stop_reason="end_turn",
message={"role": "assistant", "content": [{"text": "Response"}]},
metrics=mock_metrics,
request_state={},
)

mock_event_loop.side_effect = mock_cycle

# Create agent with default structured_output_prompt
agent = Agent(
model=mock_model,
structured_output_model=user_model,
structured_output_prompt=custom_prompt,
)
agent("Get user info")

mock_event_loop.assert_called_once()

@patch("strands.agent.agent.event_loop_cycle")
def test_agent_call_override_default_structured_output_prompt(
self, mock_event_loop, user_model, mock_model, mock_metrics
):
"""Test that invocation-level structured_output_prompt overrides default."""
default_prompt = "Default prompt for structured output."
override_prompt = "Override prompt for this specific call."

async def mock_cycle(*args, **kwargs):
structured_output_context = kwargs.get("structured_output_context")
# Should use override_prompt, not the default
assert structured_output_context.structured_output_prompt == override_prompt

yield EventLoopStopEvent(
stop_reason="end_turn",
message={"role": "assistant", "content": [{"text": "Response"}]},
metrics=mock_metrics,
request_state={},
)

mock_event_loop.side_effect = mock_cycle

# Create agent with default prompt, but override at call time
agent = Agent(
model=mock_model,
structured_output_model=user_model,
structured_output_prompt=default_prompt,
)
agent("Get user info", structured_output_prompt=override_prompt)

mock_event_loop.assert_called_once()

@patch("strands.agent.agent.event_loop_cycle")
def test_agent_call_with_invocation_prompt_no_default(self, mock_event_loop, user_model, mock_model, mock_metrics):
"""Test that invocation-level prompt works when no default is set."""
invocation_prompt = "Format as structured output now."

async def mock_cycle(*args, **kwargs):
structured_output_context = kwargs.get("structured_output_context")
assert structured_output_context.structured_output_prompt == invocation_prompt

yield EventLoopStopEvent(
stop_reason="end_turn",
message={"role": "assistant", "content": [{"text": "Response"}]},
metrics=mock_metrics,
request_state={},
)

mock_event_loop.side_effect = mock_cycle

# Create agent without default prompt
agent = Agent(model=mock_model, structured_output_model=user_model)
agent("Get user info", structured_output_prompt=invocation_prompt)

mock_event_loop.assert_called_once()

@pytest.mark.asyncio
@patch("strands.agent.agent.event_loop_cycle")
async def test_agent_invoke_async_with_structured_output_prompt(
self, mock_event_loop, user_model, mock_model, mock_metrics
):
"""Test Agent.invoke_async with structured_output_prompt."""
custom_prompt = "Async prompt for structured output."

async def mock_cycle(*args, **kwargs):
structured_output_context = kwargs.get("structured_output_context")
assert structured_output_context.structured_output_prompt == custom_prompt

yield EventLoopStopEvent(
stop_reason="end_turn",
message={"role": "assistant", "content": [{"text": "Response"}]},
metrics=mock_metrics,
request_state={},
)

mock_event_loop.side_effect = mock_cycle

agent = Agent(model=mock_model, structured_output_model=user_model)
await agent.invoke_async("Get user", structured_output_prompt=custom_prompt)

mock_event_loop.assert_called_once()

@pytest.mark.asyncio
@patch("strands.agent.agent.event_loop_cycle")
async def test_agent_stream_async_with_structured_output_prompt(
self, mock_event_loop, user_model, mock_model, mock_metrics
):
"""Test Agent.stream_async with structured_output_prompt."""
custom_prompt = "Stream async prompt for structured output."

async def mock_cycle(*args, **kwargs):
structured_output_context = kwargs.get("structured_output_context")
assert structured_output_context.structured_output_prompt == custom_prompt

yield EventLoopStopEvent(
stop_reason="end_turn",
message={"role": "assistant", "content": [{"text": "Response"}]},
metrics=mock_metrics,
request_state={},
)

mock_event_loop.side_effect = mock_cycle

agent = Agent(model=mock_model, structured_output_model=user_model)
async for _ in agent.stream_async("Get user", structured_output_prompt=custom_prompt):
pass

mock_event_loop.assert_called_once()
Loading
Loading