-
Notifications
You must be signed in to change notification settings - Fork 773
feat: add streaming to direct tool calls #1955
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,7 @@ | ||
| """Support direct tool calls through agent. | ||
| """Direct tool call support. | ||
|
|
||
| This module provides the _DirectToolCall and _ToolCaller classes that enable direct tool invocation through the | ||
| agent.tool interface, including synchronous execution and streaming methods. | ||
|
|
||
| Example: | ||
| ``` | ||
|
|
@@ -7,10 +10,15 @@ | |
| ``` | ||
| """ | ||
|
|
||
| import asyncio | ||
| import contextvars | ||
| import json | ||
| import logging | ||
| import queue | ||
| import random | ||
| import weakref | ||
| from collections.abc import Callable | ||
| from collections.abc import AsyncIterator, Iterator | ||
| from concurrent.futures import ThreadPoolExecutor | ||
| from typing import TYPE_CHECKING, Any | ||
|
|
||
| from .._async import run_async | ||
|
|
@@ -24,19 +32,34 @@ | |
| from ..agent import Agent | ||
| from ..experimental.bidi.agent import BidiAgent | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| class _ToolCaller: | ||
| """Call tool as a function.""" | ||
| # Sentinel to signal end of stream | ||
| _STREAM_END = object() | ||
|
|
||
| def __init__(self, agent: "Agent | BidiAgent") -> None: | ||
| """Initialize instance. | ||
|
|
||
| class _DirectToolCall: | ||
| """Callable wrapper for a single tool that provides streaming methods. | ||
|
|
||
| This class enables three execution modes for direct tool calls: | ||
| 1. Synchronous: ``result = agent.tool.my_tool(x=5)`` | ||
| 2. Sync streaming: ``for event in agent.tool.my_tool.stream(x=5)`` | ||
| 3. Async streaming: ``async for event in agent.tool.my_tool.stream_async(x=5)`` | ||
|
|
||
| Streaming methods do not acquire the invocation lock, do not record to message | ||
| history, and do not apply conversation management. They are designed for | ||
| observability and real-time progress monitoring. | ||
| """ | ||
|
|
||
| def __init__(self, agent: "Agent | BidiAgent", tool_name: str) -> None: | ||
| """Initialize direct tool call. | ||
|
|
||
| Args: | ||
| agent: Agent reference that will accept tool results. | ||
| agent: Agent reference that owns the tools. | ||
| tool_name: Name of the tool to execute. | ||
| """ | ||
| # WARNING: Do not add any other member variables or methods as this could result in a name conflict with | ||
| # agent tools and thus break their execution. | ||
| self._agent_ref = weakref.ref(agent) | ||
| self._tool_name = tool_name | ||
|
|
||
| @property | ||
| def _agent(self) -> "Agent | BidiAgent": | ||
|
|
@@ -46,104 +69,181 @@ def _agent(self) -> "Agent | BidiAgent": | |
| raise ReferenceError("Agent has been garbage collected") | ||
| return agent | ||
|
|
||
| def __getattr__(self, name: str) -> Callable[..., Any]: | ||
| """Call tool as a function. | ||
| def _prepare_tool_use(self, **kwargs: Any) -> tuple[ToolUse, list[ToolResult], dict[str, Any]]: | ||
| """Prepare tool use request, results list, and invocation state. | ||
|
|
||
| Args: | ||
| **kwargs: Tool parameters. | ||
|
|
||
| Returns: | ||
| Tuple of (tool_use, tool_results, invocation_state). | ||
|
|
||
| Raises: | ||
| AttributeError: If tool doesn't exist. | ||
| """ | ||
| normalized_name = self._find_normalized_tool_name(self._tool_name) | ||
| tool_id = f"tooluse_{self._tool_name}_{random.randint(100000000, 999999999)}" | ||
| tool_use: ToolUse = { | ||
| "toolUseId": tool_id, | ||
| "name": normalized_name, | ||
| "input": kwargs.copy(), | ||
| } | ||
| return tool_use, [], kwargs | ||
|
|
||
| This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`). | ||
| def __call__( | ||
| self, | ||
| user_message_override: str | None = None, | ||
| record_direct_tool_call: bool | None = None, | ||
| **kwargs: Any, | ||
| ) -> ToolResult: | ||
| """Synchronous tool execution (existing behavior - backward compatible). | ||
|
|
||
| This method enables the method-style interface (e.g., ``agent.tool.tool_name(param="value")``). | ||
| It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing'). | ||
|
|
||
| Args: | ||
| name: The name of the attribute (tool) being accessed. | ||
| user_message_override: Optional custom message to record. | ||
| record_direct_tool_call: Whether to record in message history. | ||
| **kwargs: Tool parameters. | ||
|
|
||
| Returns: | ||
| A function that when called will execute the named tool. | ||
| ToolResult from execution. | ||
|
|
||
| Raises: | ||
| AttributeError: If no tool with the given name exists or if multiple tools match the given name. | ||
| AttributeError: If tool doesn't exist. | ||
| RuntimeError: If called during interrupt. | ||
| ConcurrencyException: If invocation lock cannot be acquired. | ||
| """ | ||
| if self._agent._interrupt_state.activated: | ||
| raise RuntimeError("cannot directly call tool during interrupt") | ||
|
|
||
| if record_direct_tool_call is not None: | ||
| should_record_direct_tool_call = record_direct_tool_call | ||
| else: | ||
| should_record_direct_tool_call = self._agent.record_direct_tool_call | ||
|
|
||
| should_lock = should_record_direct_tool_call | ||
|
|
||
| def caller( | ||
| user_message_override: str | None = None, | ||
| record_direct_tool_call: bool | None = None, | ||
| **kwargs: Any, | ||
| ) -> Any: | ||
| """Call a tool directly by name. | ||
|
|
||
| Args: | ||
| user_message_override: Optional custom message to record instead of default | ||
| record_direct_tool_call: Whether to record direct tool calls in message history. Overrides class | ||
| attribute if provided. | ||
| **kwargs: Keyword arguments to pass to the tool. | ||
|
|
||
| Returns: | ||
| The result returned by the tool. | ||
|
|
||
| Raises: | ||
| AttributeError: If the tool doesn't exist. | ||
| """ | ||
| if self._agent._interrupt_state.activated: | ||
| raise RuntimeError("cannot directly call tool during interrupt") | ||
|
|
||
| if record_direct_tool_call is not None: | ||
| should_record_direct_tool_call = record_direct_tool_call | ||
| else: | ||
| should_record_direct_tool_call = self._agent.record_direct_tool_call | ||
|
|
||
| should_lock = should_record_direct_tool_call | ||
|
|
||
| from ..agent import Agent # Locally imported to avoid circular reference | ||
|
|
||
| acquired_lock = ( | ||
| should_lock | ||
| and isinstance(self._agent, Agent) | ||
| and self._agent._invocation_lock.acquire_lock(blocking=False) | ||
| from ..agent import Agent # Locally imported to avoid circular reference | ||
|
|
||
| acquired_lock = ( | ||
| should_lock and isinstance(self._agent, Agent) and self._agent._invocation_lock.acquire_lock(blocking=False) | ||
| ) | ||
| if should_lock and not acquired_lock: | ||
| raise ConcurrencyException( | ||
| "Direct tool call cannot be made while the agent is in the middle of an invocation. " | ||
| "Set record_direct_tool_call=False to allow direct tool calls during agent invocation." | ||
| ) | ||
| if should_lock and not acquired_lock: | ||
| raise ConcurrencyException( | ||
| "Direct tool call cannot be made while the agent is in the middle of an invocation. " | ||
| "Set record_direct_tool_call=False to allow direct tool calls during agent invocation." | ||
| ) | ||
|
|
||
| try: | ||
| normalized_name = self._find_normalized_tool_name(name) | ||
| try: | ||
| tool_use, tool_results, invocation_state = self._prepare_tool_use(**kwargs) | ||
|
|
||
| # Create unique tool ID and set up the tool request | ||
| tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}" | ||
| tool_use: ToolUse = { | ||
| "toolUseId": tool_id, | ||
| "name": normalized_name, | ||
| "input": kwargs.copy(), | ||
| } | ||
| tool_results: list[ToolResult] = [] | ||
| invocation_state = kwargs | ||
| async def acall() -> ToolResult: | ||
| async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): | ||
| if isinstance(event, ToolInterruptEvent): | ||
| self._agent._interrupt_state.deactivate() | ||
| raise RuntimeError("cannot raise interrupt in direct tool call") | ||
|
|
||
| async def acall() -> ToolResult: | ||
| async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): | ||
| if isinstance(event, ToolInterruptEvent): | ||
| self._agent._interrupt_state.deactivate() | ||
| raise RuntimeError("cannot raise interrupt in direct tool call") | ||
| tool_result = tool_results[0] | ||
|
|
||
| tool_result = tool_results[0] | ||
| if should_record_direct_tool_call: | ||
| await self._record_tool_execution(tool_use, tool_result, user_message_override) | ||
|
|
||
| if should_record_direct_tool_call: | ||
| # Create a record of this tool execution in the message history | ||
| await self._record_tool_execution(tool_use, tool_result, user_message_override) | ||
| return tool_result | ||
|
|
||
| return tool_result | ||
| tool_result = run_async(acall) | ||
|
|
||
| tool_result = run_async(acall) | ||
| # TODO: https://github.com/strands-agents/sdk-python/issues/1311 | ||
| if isinstance(self._agent, Agent): | ||
| self._agent.conversation_manager.apply_management(self._agent) | ||
|
|
||
| # TODO: https://github.com/strands-agents/sdk-python/issues/1311 | ||
| if isinstance(self._agent, Agent): | ||
| self._agent.conversation_manager.apply_management(self._agent) | ||
| return tool_result | ||
|
|
||
| return tool_result | ||
| finally: | ||
| if acquired_lock and isinstance(self._agent, Agent): | ||
| self._agent._invocation_lock.release() | ||
|
|
||
| def stream(self, **kwargs: Any) -> Iterator[Any]: | ||
| """Synchronous streaming of tool execution events. | ||
|
|
||
| Bridges async-to-sync streaming using a background thread and queue, yielding | ||
| events in real-time as they are produced by the tool. | ||
|
|
||
| This method does not acquire the invocation lock, does not record to message | ||
| history, and does not apply conversation management. | ||
|
|
||
| Args: | ||
| **kwargs: Tool parameters. | ||
|
|
||
| Yields: | ||
| Tool execution events in real-time. | ||
|
|
||
| Raises: | ||
| AttributeError: If tool doesn't exist. | ||
| RuntimeError: If called during interrupt. | ||
| """ | ||
| # Fast-fail before spinning up a thread; stream_async also checks but this avoids unnecessary overhead | ||
| if self._agent._interrupt_state.activated: | ||
| raise RuntimeError("cannot directly call tool during interrupt") | ||
|
|
||
| event_queue: queue.Queue[Any] = queue.Queue() | ||
|
|
||
| async def _produce() -> None: | ||
| try: | ||
| async for event in self.stream_async(**kwargs): | ||
| event_queue.put(event) | ||
| except BaseException: | ||
| # Re-raise to propagate via future.result(); the sentinel must still be placed | ||
| # on the queue so the main thread unblocks before checking the future | ||
| raise | ||
| finally: | ||
| if acquired_lock and isinstance(self._agent, Agent): | ||
| self._agent._invocation_lock.release() | ||
| event_queue.put(_STREAM_END) | ||
|
|
||
| context = contextvars.copy_context() | ||
| with ThreadPoolExecutor(max_workers=1) as executor: | ||
| future = executor.submit(context.run, asyncio.run, _produce()) | ||
|
|
||
| while True: | ||
| item = event_queue.get() | ||
| if item is _STREAM_END: | ||
| break | ||
| yield item | ||
|
|
||
| return caller | ||
| # Propagates any exception from the producer thread | ||
| future.result() | ||
|
|
||
| async def stream_async(self, **kwargs: Any) -> AsyncIterator[Any]: | ||
| """Asynchronous streaming of tool execution events. | ||
|
|
||
| Yields events directly from tool execution without recording to message | ||
| history. Designed for observability and real-time progress monitoring. | ||
|
|
||
| This method does not acquire the invocation lock, does not record to message | ||
| history, and does not apply conversation management. It can be used concurrently | ||
| with agent invocations. | ||
|
|
||
| Args: | ||
| **kwargs: Tool parameters. | ||
|
|
||
| Yields: | ||
| Tool execution events from ToolExecutor._stream(). | ||
|
|
||
| Raises: | ||
| AttributeError: If tool doesn't exist. | ||
| RuntimeError: If called during interrupt. | ||
| """ | ||
| if self._agent._interrupt_state.activated: | ||
| raise RuntimeError("cannot directly call tool during interrupt") | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Issue: Suggestion: Either add the same concurrency guard from
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is by design, streaming is meant for observability only, it doesn't touch message history or agent state, so there's no shared mutation to guard against. Made this explicit in the docstrings now so it's clearly an intentional choice rather than an oversight. |
||
|
|
||
| tool_use, tool_results, invocation_state = self._prepare_tool_use(**kwargs) | ||
|
|
||
| logger.debug("tool_name=<%s>, streaming=<True> | executing tool stream", tool_use["name"]) | ||
|
|
||
| async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): | ||
| if isinstance(event, ToolInterruptEvent): | ||
| self._agent._interrupt_state.deactivate() | ||
| raise RuntimeError("cannot raise interrupt in direct tool call") | ||
| yield event | ||
|
|
||
| def _find_normalized_tool_name(self, name: str) -> str: | ||
| """Lookup the tool represented by name, replacing characters with underscores as necessary.""" | ||
|
|
@@ -246,3 +346,39 @@ def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: di | |
|
|
||
| properties = tool_spec["inputSchema"]["json"]["properties"] | ||
| return {k: v for k, v in input_params.items() if k in properties} | ||
|
|
||
|
|
||
| class _ToolCaller: | ||
| """Call tool as a function.""" | ||
|
|
||
| def __init__(self, agent: "Agent | BidiAgent") -> None: | ||
| """Initialize instance. | ||
|
|
||
| Args: | ||
| agent: Agent reference that will accept tool results. | ||
| """ | ||
| # WARNING: Do not add any other member variables or methods as this could result in a name conflict with | ||
| # agent tools and thus break their execution. | ||
| self._agent_ref = weakref.ref(agent) | ||
|
|
||
| @property | ||
| def _agent(self) -> "Agent | BidiAgent": | ||
| """Return the agent, raising ReferenceError if it has been garbage collected.""" | ||
| agent = self._agent_ref() | ||
| if agent is None: | ||
| raise ReferenceError("Agent has been garbage collected") | ||
| return agent | ||
|
|
||
| def __getattr__(self, name: str) -> _DirectToolCall: | ||
| """Return direct tool call with streaming methods. | ||
|
|
||
| This method enables the tool calling interface by returning a callable | ||
| object that provides both synchronous execution and streaming methods. | ||
|
|
||
| Args: | ||
| name: Tool name. | ||
|
|
||
| Returns: | ||
| Direct tool call instance. | ||
| """ | ||
| return _DirectToolCall(self._agent, name) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Issue: The
__call__method appliesconversation_manager.apply_management()after execution (line 134), butstream_async()does not. While this may be intentional since streaming doesn't record to history, if a user mixes streaming and non-streaming calls, the conversation management behavior could become inconsistent.Suggestion: Document this behavioral difference explicitly, or keep the two paths symmetric.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is intentional, streaming doesn't add anything to message history, so there's nothing for the conversation manager to act on. Added a note about this in both
stream()andstream_async()docstrings to make it clear.