From 09271558c59da92c8c6a27e47c069b5ec3c0a079 Mon Sep 17 00:00:00 2001 From: Emaan Khan Date: Sun, 22 Mar 2026 16:08:56 -0700 Subject: [PATCH] feat: add streaming to direct tool calls --- src/strands/tools/_caller.py | 304 +++++++++++----- tests/strands/tools/test_direct_tool_call.py | 356 +++++++++++++++++++ tests_integ/test_tool_streaming.py | 79 ++++ 3 files changed, 655 insertions(+), 84 deletions(-) create mode 100644 tests/strands/tools/test_direct_tool_call.py create mode 100644 tests_integ/test_tool_streaming.py diff --git a/src/strands/tools/_caller.py b/src/strands/tools/_caller.py index 0b5408f35..a70fd2e7a 100644 --- a/src/strands/tools/_caller.py +++ b/src/strands/tools/_caller.py @@ -1,4 +1,7 @@ -"""Support direct tool calls through agent. +"""Direct tool call support. + +This module provides the _DirectToolCall and _ToolCaller classes that enable direct tool invocation through the +agent.tool interface, including synchronous execution and streaming methods. Example: ``` @@ -7,10 +10,15 @@ ``` """ +import asyncio +import contextvars import json +import logging +import queue import random import weakref -from collections.abc import Callable +from collections.abc import AsyncIterator, Iterator +from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, Any from .._async import run_async @@ -24,19 +32,34 @@ from ..agent import Agent from ..experimental.bidi.agent import BidiAgent +logger = logging.getLogger(__name__) -class _ToolCaller: - """Call tool as a function.""" +# Sentinel to signal end of stream +_STREAM_END = object() - def __init__(self, agent: "Agent | BidiAgent") -> None: - """Initialize instance. + +class _DirectToolCall: + """Callable wrapper for a single tool that provides streaming methods. + + This class enables three execution modes for direct tool calls: + 1. Synchronous: ``result = agent.tool.my_tool(x=5)`` + 2. Sync streaming: ``for event in agent.tool.my_tool.stream(x=5)`` + 3. 
Async streaming: ``async for event in agent.tool.my_tool.stream_async(x=5)`` + + Streaming methods do not acquire the invocation lock, do not record to message + history, and do not apply conversation management. They are designed for + observability and real-time progress monitoring. + """ + + def __init__(self, agent: "Agent | BidiAgent", tool_name: str) -> None: + """Initialize direct tool call. Args: - agent: Agent reference that will accept tool results. + agent: Agent reference that owns the tools. + tool_name: Name of the tool to execute. """ - # WARNING: Do not add any other member variables or methods as this could result in a name conflict with - # agent tools and thus break their execution. self._agent_ref = weakref.ref(agent) + self._tool_name = tool_name @property def _agent(self) -> "Agent | BidiAgent": @@ -46,104 +69,181 @@ def _agent(self) -> "Agent | BidiAgent": raise ReferenceError("Agent has been garbage collected") return agent - def __getattr__(self, name: str) -> Callable[..., Any]: - """Call tool as a function. + def _prepare_tool_use(self, **kwargs: Any) -> tuple[ToolUse, list[ToolResult], dict[str, Any]]: + """Prepare tool use request, results list, and invocation state. + + Args: + **kwargs: Tool parameters. + + Returns: + Tuple of (tool_use, tool_results, invocation_state). + + Raises: + AttributeError: If tool doesn't exist. + """ + normalized_name = self._find_normalized_tool_name(self._tool_name) + tool_id = f"tooluse_{self._tool_name}_{random.randint(100000000, 999999999)}" + tool_use: ToolUse = { + "toolUseId": tool_id, + "name": normalized_name, + "input": kwargs.copy(), + } + return tool_use, [], kwargs - This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`). + def __call__( + self, + user_message_override: str | None = None, + record_direct_tool_call: bool | None = None, + **kwargs: Any, + ) -> ToolResult: + """Synchronous tool execution (existing behavior - backward compatible). 
+ + This method enables the method-style interface (e.g., ``agent.tool.tool_name(param="value")``). It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing'). Args: - name: The name of the attribute (tool) being accessed. + user_message_override: Optional custom message to record. + record_direct_tool_call: Whether to record in message history. + **kwargs: Tool parameters. Returns: - A function that when called will execute the named tool. + ToolResult from execution. Raises: - AttributeError: If no tool with the given name exists or if multiple tools match the given name. + AttributeError: If tool doesn't exist. + RuntimeError: If called during interrupt. + ConcurrencyException: If invocation lock cannot be acquired. """ + if self._agent._interrupt_state.activated: + raise RuntimeError("cannot directly call tool during interrupt") + + if record_direct_tool_call is not None: + should_record_direct_tool_call = record_direct_tool_call + else: + should_record_direct_tool_call = self._agent.record_direct_tool_call + + should_lock = should_record_direct_tool_call - def caller( - user_message_override: str | None = None, - record_direct_tool_call: bool | None = None, - **kwargs: Any, - ) -> Any: - """Call a tool directly by name. - - Args: - user_message_override: Optional custom message to record instead of default - record_direct_tool_call: Whether to record direct tool calls in message history. Overrides class - attribute if provided. - **kwargs: Keyword arguments to pass to the tool. - - Returns: - The result returned by the tool. - - Raises: - AttributeError: If the tool doesn't exist. 
- """ - if self._agent._interrupt_state.activated: - raise RuntimeError("cannot directly call tool during interrupt") - - if record_direct_tool_call is not None: - should_record_direct_tool_call = record_direct_tool_call - else: - should_record_direct_tool_call = self._agent.record_direct_tool_call - - should_lock = should_record_direct_tool_call - - from ..agent import Agent # Locally imported to avoid circular reference - - acquired_lock = ( - should_lock - and isinstance(self._agent, Agent) - and self._agent._invocation_lock.acquire_lock(blocking=False) + from ..agent import Agent # Locally imported to avoid circular reference + + acquired_lock = ( + should_lock and isinstance(self._agent, Agent) and self._agent._invocation_lock.acquire_lock(blocking=False) + ) + if should_lock and not acquired_lock: + raise ConcurrencyException( + "Direct tool call cannot be made while the agent is in the middle of an invocation. " + "Set record_direct_tool_call=False to allow direct tool calls during agent invocation." ) - if should_lock and not acquired_lock: - raise ConcurrencyException( - "Direct tool call cannot be made while the agent is in the middle of an invocation. " - "Set record_direct_tool_call=False to allow direct tool calls during agent invocation." 
- ) - try: - normalized_name = self._find_normalized_tool_name(name) + try: + tool_use, tool_results, invocation_state = self._prepare_tool_use(**kwargs) - # Create unique tool ID and set up the tool request - tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}" - tool_use: ToolUse = { - "toolUseId": tool_id, - "name": normalized_name, - "input": kwargs.copy(), - } - tool_results: list[ToolResult] = [] - invocation_state = kwargs + async def acall() -> ToolResult: + async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): + if isinstance(event, ToolInterruptEvent): + self._agent._interrupt_state.deactivate() + raise RuntimeError("cannot raise interrupt in direct tool call") - async def acall() -> ToolResult: - async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): - if isinstance(event, ToolInterruptEvent): - self._agent._interrupt_state.deactivate() - raise RuntimeError("cannot raise interrupt in direct tool call") + tool_result = tool_results[0] - tool_result = tool_results[0] + if should_record_direct_tool_call: + await self._record_tool_execution(tool_use, tool_result, user_message_override) - if should_record_direct_tool_call: - # Create a record of this tool execution in the message history - await self._record_tool_execution(tool_use, tool_result, user_message_override) + return tool_result - return tool_result + tool_result = run_async(acall) - tool_result = run_async(acall) + # TODO: https://github.com/strands-agents/sdk-python/issues/1311 + if isinstance(self._agent, Agent): + self._agent.conversation_manager.apply_management(self._agent) - # TODO: https://github.com/strands-agents/sdk-python/issues/1311 - if isinstance(self._agent, Agent): - self._agent.conversation_manager.apply_management(self._agent) + return tool_result - return tool_result + finally: + if acquired_lock and isinstance(self._agent, Agent): + self._agent._invocation_lock.release() + + def 
stream(self, **kwargs: Any) -> Iterator[Any]: + """Synchronous streaming of tool execution events. + + Bridges async-to-sync streaming using a background thread and queue, yielding + events in real-time as they are produced by the tool. + + This method does not acquire the invocation lock, does not record to message + history, and does not apply conversation management. + Args: + **kwargs: Tool parameters. + + Yields: + Tool execution events in real-time. + + Raises: + AttributeError: If tool doesn't exist. + RuntimeError: If called during interrupt. + """ + # Fast-fail before spinning up a thread; stream_async also checks but this avoids unnecessary overhead + if self._agent._interrupt_state.activated: + raise RuntimeError("cannot directly call tool during interrupt") + + event_queue: queue.Queue[Any] = queue.Queue() + + async def _produce() -> None: + try: + async for event in self.stream_async(**kwargs): + event_queue.put(event) + except BaseException: + # Re-raise to propagate via future.result(); the sentinel must still be placed + # on the queue so the main thread unblocks before checking the future + raise finally: - if acquired_lock and isinstance(self._agent, Agent): - self._agent._invocation_lock.release() + event_queue.put(_STREAM_END) + + context = contextvars.copy_context() + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(context.run, asyncio.run, _produce()) + + while True: + item = event_queue.get() + if item is _STREAM_END: + break + yield item - return caller + # Propagates any exception from the producer thread + future.result() + + async def stream_async(self, **kwargs: Any) -> AsyncIterator[Any]: + """Asynchronous streaming of tool execution events. + + Yields events directly from tool execution without recording to message + history. Designed for observability and real-time progress monitoring. 
+ + This method does not acquire the invocation lock, does not record to message + history, and does not apply conversation management. It can be used concurrently + with agent invocations. + + Args: + **kwargs: Tool parameters. + + Yields: + Tool execution events from ToolExecutor._stream(). + + Raises: + AttributeError: If tool doesn't exist. + RuntimeError: If called during interrupt. + """ + if self._agent._interrupt_state.activated: + raise RuntimeError("cannot directly call tool during interrupt") + + tool_use, tool_results, invocation_state = self._prepare_tool_use(**kwargs) + + logger.debug("tool_name=<%s>, streaming=<True> | executing tool stream", tool_use["name"]) + + async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state): + if isinstance(event, ToolInterruptEvent): + self._agent._interrupt_state.deactivate() + raise RuntimeError("cannot raise interrupt in direct tool call") + yield event def _find_normalized_tool_name(self, name: str) -> str: """Lookup the tool represented by name, replacing characters with underscores as necessary.""" @@ -246,3 +346,39 @@ def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: di properties = tool_spec["inputSchema"]["json"]["properties"] return {k: v for k, v in input_params.items() if k in properties} + + +class _ToolCaller: + """Call tool as a function.""" + + def __init__(self, agent: "Agent | BidiAgent") -> None: + """Initialize instance. + + Args: + agent: Agent reference that will accept tool results. + """ + # WARNING: Do not add any other member variables or methods as this could result in a name conflict with + # agent tools and thus break their execution. 
+ self._agent_ref = weakref.ref(agent) + + @property + def _agent(self) -> "Agent | BidiAgent": + """Return the agent, raising ReferenceError if it has been garbage collected.""" + agent = self._agent_ref() + if agent is None: + raise ReferenceError("Agent has been garbage collected") + return agent + + def __getattr__(self, name: str) -> _DirectToolCall: + """Return direct tool call with streaming methods. + + This method enables the tool calling interface by returning a callable + object that provides both synchronous execution and streaming methods. + + Args: + name: Tool name. + + Returns: + Direct tool call instance. + """ + return _DirectToolCall(self._agent, name) diff --git a/tests/strands/tools/test_direct_tool_call.py b/tests/strands/tools/test_direct_tool_call.py new file mode 100644 index 000000000..3fed489f2 --- /dev/null +++ b/tests/strands/tools/test_direct_tool_call.py @@ -0,0 +1,356 @@ +"""Unit tests for _DirectToolCall.""" + +import gc +import unittest.mock +import weakref + +import pytest + +from strands import Agent, tool +from strands.types._events import ToolResultEvent + + +class TestDirectToolCall: + """Test _DirectToolCall class.""" + + def test_is_callable(self): + """Test direct tool call is callable.""" + + @tool + def test_tool(x: int) -> int: + return x * 2 + + agent = Agent(tools=[test_tool]) + direct_call = agent.tool.test_tool + + result = direct_call(x=5) + assert result["status"] == "success" + + def test_has_streaming_methods(self): + """Test direct tool call has stream and stream_async methods.""" + + @tool + def test_tool(x: int) -> int: + return x + + agent = Agent(tools=[test_tool]) + direct_call = agent.tool.test_tool + + assert hasattr(direct_call, "stream") + assert hasattr(direct_call, "stream_async") + assert callable(direct_call.stream) + assert callable(direct_call.stream_async) + + def test_weakref_prevents_circular_reference(self): + """Test weakref prevents agent from leaking.""" + + @tool + def test_tool(x: int) -> 
int: + return x + + gc.disable() + try: + agent = Agent(tools=[test_tool]) + _ = agent.tool.test_tool + ref = weakref.ref(agent) + + del agent + + if ref() is not None: + gc.collect() + + assert ref() is None + finally: + gc.enable() + + def test_weakref_raises_on_deleted_agent(self): + """Test accessing _agent property raises ReferenceError when agent deleted.""" + + @tool + def test_tool(x: int) -> int: + return x + + agent = Agent(tools=[test_tool]) + direct_call = agent.tool.test_tool + + del agent + gc.collect() + + with pytest.raises(ReferenceError, match="Agent has been garbage collected"): + _ = direct_call._agent + + def test_find_normalized_tool_name_with_underscores(self): + """Test tool name normalization replaces underscores with hyphens.""" + + @tool(name="my-tool") + def my_tool(x: int) -> int: + return x + + agent = Agent(tools=[my_tool]) + direct_call = agent.tool.my_tool + + result = direct_call(x=5) + assert result["status"] == "success" + + def test_find_normalized_tool_name_not_found(self): + """Test non-existent tool raises AttributeError.""" + + @tool + def test_tool(x: int) -> int: + return x + + agent = Agent(tools=[test_tool]) + + with pytest.raises(AttributeError, match="Tool 'nonexistent' not found"): + _ = agent.tool.nonexistent(x=5) + + def test_find_normalized_tool_name_exact_match(self): + """Test _find_normalized_tool_name with exact match.""" + + @tool(name="exact_name") + def exact_tool(x: int) -> int: + return x + + agent = Agent(tools=[exact_tool]) + direct_call = agent.tool.exact_name + + normalized = direct_call._find_normalized_tool_name("exact_name") + assert normalized == "exact_name" + + +class TestDirectToolCallStreamAsync: + """Test _DirectToolCall.stream_async().""" + + @pytest.mark.asyncio + async def test_basic(self, alist): + """Test basic async streaming from direct tool call.""" + + @tool + def test_tool(x: int) -> int: + return x * 2 + + agent = Agent(tools=[test_tool]) + + events = await 
alist(agent.tool.test_tool.stream_async(x=5)) + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert len(result_events) == 1 + assert result_events[0]["tool_result"]["status"] == "success" + assert "10" in result_events[0]["tool_result"]["content"][0]["text"] + + @pytest.mark.asyncio + async def test_with_error(self, alist): + """Test async streaming handles tool errors.""" + + @tool + def error_tool(should_fail: bool) -> str: + if should_fail: + raise ValueError("Test error") + return "success" + + agent = Agent(tools=[error_tool]) + + events = await alist(agent.tool.error_tool.stream_async(should_fail=True)) + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert len(result_events) == 1 + assert result_events[0]["tool_result"]["status"] == "error" + assert "Test error" in result_events[0]["tool_result"]["content"][0]["text"] + + @pytest.mark.asyncio + async def test_tool_not_found(self): + """Test stream_async with non-existent tool.""" + + @tool + def test_tool(x: int) -> int: + return x + + agent = Agent(tools=[test_tool]) + + with pytest.raises(AttributeError, match="Tool 'nonexistent' not found"): + async for _ in agent.tool.nonexistent.stream_async(x=5): + pass + + @pytest.mark.asyncio + async def test_with_interrupt_state_raises(self): + """Test stream_async raises when interrupt state activated.""" + + @tool + def test_tool(x: int) -> int: + return x + + agent = Agent(tools=[test_tool]) + agent._interrupt_state.activate() + + with pytest.raises(RuntimeError, match="cannot directly call tool during interrupt"): + async for _event in agent.tool.test_tool.stream_async(x=5): + pass + + @pytest.mark.asyncio + async def test_with_interrupt_from_tool(self): + """Test stream_async when tool raises interrupt.""" + + @tool(context=True) + def interrupt_tool(tool_context): + tool_context.interrupt("test", reason="test") + return "unreachable" + + agent = Agent(tools=[interrupt_tool]) + + with 
pytest.raises(RuntimeError, match="cannot raise interrupt in direct tool call"): + async for _ in agent.tool.interrupt_tool.stream_async(): + pass + + assert not agent._interrupt_state.activated + + @pytest.mark.asyncio + async def test_after_agent_deleted(self): + """Test stream_async raises when agent is garbage collected.""" + + @tool + def test_tool(x: int) -> int: + return x + + agent = Agent(tools=[test_tool]) + direct_call = agent.tool.test_tool + + del agent + gc.collect() + + with pytest.raises(ReferenceError, match="Agent has been garbage collected"): + async for _ in direct_call.stream_async(x=5): + pass + + @pytest.mark.asyncio + async def test_with_normalized_name(self, alist): + """Test stream_async with underscore to hyphen normalization.""" + + @tool(name="my-tool") + def my_tool(x: int) -> int: + return x * 3 + + agent = Agent(tools=[my_tool]) + + events = await alist(agent.tool.my_tool.stream_async(x=4)) + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert len(result_events) == 1 + assert "12" in result_events[0]["tool_result"]["content"][0]["text"] + + @pytest.mark.asyncio + async def test_generates_unique_tool_ids(self, alist): + """Test stream_async generates unique tool use IDs.""" + + @tool + def test_tool(x: int) -> int: + return x + + agent = Agent(tools=[test_tool]) + + with unittest.mock.patch("strands.tools._caller.random.randint") as mock_randint: + mock_randint.side_effect = [111, 222, 333] + + tool_use_ids = [] + for _ in range(3): + events = await alist(agent.tool.test_tool.stream_async(x=1)) + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + tool_use_ids.append(result_events[0]["tool_result"]["toolUseId"]) + + assert len(set(tool_use_ids)) == 3 + assert "tooluse_test_tool_111" in tool_use_ids + assert "tooluse_test_tool_222" in tool_use_ids + assert "tooluse_test_tool_333" in tool_use_ids + + +class TestDirectToolCallStream: + """Test _DirectToolCall.stream() (sync streaming via 
thread+queue bridge).""" + + def test_basic(self): + """Test synchronous streaming from direct tool call.""" + + @tool + def test_tool(x: int) -> int: + return x * 2 + + agent = Agent(tools=[test_tool]) + + events = list(agent.tool.test_tool.stream(x=5)) + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert len(result_events) == 1 + assert result_events[0]["tool_result"]["status"] == "success" + assert "10" in result_events[0]["tool_result"]["content"][0]["text"] + + def test_with_error(self): + """Test sync streaming handles tool errors.""" + + @tool + def error_tool(should_fail: bool) -> str: + if should_fail: + raise ValueError("Sync error") + return "success" + + agent = Agent(tools=[error_tool]) + + events = list(agent.tool.error_tool.stream(should_fail=True)) + + result_events = [e for e in events if isinstance(e, ToolResultEvent)] + assert len(result_events) == 1 + assert result_events[0]["tool_result"]["status"] == "error" + assert "Sync error" in result_events[0]["tool_result"]["content"][0]["text"] + + def test_tool_not_found(self): + """Test stream with non-existent tool.""" + + @tool + def test_tool(x: int) -> int: + return x + + agent = Agent(tools=[test_tool]) + + with pytest.raises(AttributeError, match="Tool 'nonexistent' not found"): + list(agent.tool.nonexistent.stream(x=5)) + + def test_with_interrupt_state_raises(self): + """Test stream raises when interrupt state activated.""" + + @tool + def test_tool(x: int) -> int: + return x + + agent = Agent(tools=[test_tool]) + agent._interrupt_state.activate() + + with pytest.raises(RuntimeError, match="cannot directly call tool during interrupt"): + list(agent.tool.test_tool.stream(x=5)) + + def test_with_interrupt_from_tool(self): + """Test stream when tool raises interrupt.""" + + @tool(context=True) + def interrupt_tool(tool_context): + tool_context.interrupt("test", reason="test") + return "unreachable" + + agent = Agent(tools=[interrupt_tool]) + + with 
pytest.raises(RuntimeError, match="cannot raise interrupt in direct tool call"): + list(agent.tool.interrupt_tool.stream()) + + assert not agent._interrupt_state.activated + + def test_after_agent_deleted(self): + """Test stream raises when agent is garbage collected.""" + + @tool + def test_tool(x: int) -> int: + return x + + agent = Agent(tools=[test_tool]) + direct_call = agent.tool.test_tool + + del agent + gc.collect() + + with pytest.raises(ReferenceError, match="Agent has been garbage collected"): + list(direct_call.stream(x=5)) diff --git a/tests_integ/test_tool_streaming.py b/tests_integ/test_tool_streaming.py new file mode 100644 index 000000000..a64fe4eef --- /dev/null +++ b/tests_integ/test_tool_streaming.py @@ -0,0 +1,79 @@ +"""Integration tests for direct tool call streaming (Issue #1436).""" + +import pytest + +from strands import Agent, tool + + +@tool +def simple_tool(value: int) -> int: + """Simple tool for testing.""" + return value * 2 + + +class TestToolStreaming: + """Test tool streaming methods.""" + + @pytest.mark.asyncio + async def test_async_streaming(self): + """Test async streaming captures events.""" + agent = Agent(tools=[simple_tool]) + events = [] + + async for event in agent.tool.simple_tool.stream_async(value=5): + events.append(event) + + assert len(events) > 0 + assert any(e.get("type") == "tool_result" for e in events) + + def test_sync_streaming(self): + """Test sync streaming captures events.""" + agent = Agent(tools=[simple_tool]) + events = [] + + for event in agent.tool.simple_tool.stream(value=5): + events.append(event) + + assert len(events) > 0 + assert any(e.get("type") == "tool_result" for e in events) + + def test_backward_compatibility(self): + """Test existing sync API unchanged.""" + agent = Agent(tools=[simple_tool]) + result = agent.tool.simple_tool(value=5) + + assert result["status"] == "success" + assert result["content"][0]["text"] == "10" + + @pytest.mark.asyncio + async def test_tool_not_found(self): + 
"""Test non-existent tool raises AttributeError.""" + agent = Agent(tools=[]) + + with pytest.raises(AttributeError, match="Tool 'fake' not found"): + async for _event in agent.tool.fake.stream_async(): + pass + + @pytest.mark.asyncio + async def test_tool_error_captured_in_result(self): + """Test tool errors are captured in tool_result events.""" + + @tool + def error_tool() -> str: + raise ValueError("Test error") + + agent = Agent(tools=[error_tool]) + events = [] + + async for event in agent.tool.error_tool.stream_async(): + events.append(event) + + # Should have at least one event + assert len(events) > 0 + + # Final event should be tool_result with error status + final_event = events[-1] + assert final_event.get("type") == "tool_result" + tool_result = final_event.get("tool_result", {}) + assert tool_result.get("status") == "error" + assert "Test error" in tool_result.get("content", [{}])[0].get("text", "")