Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
304 changes: 220 additions & 84 deletions src/strands/tools/_caller.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
"""Support direct tool calls through agent.
"""Direct tool call support.

This module provides the _DirectToolCall and _ToolCaller classes that enable direct tool invocation through the
agent.tool interface, including synchronous execution and streaming methods.

Example:
```
Expand All @@ -7,10 +10,15 @@
```
"""

import asyncio
import contextvars
import json
import logging
import queue
import random
import weakref
from collections.abc import Callable
from collections.abc import AsyncIterator, Iterator
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Any

from .._async import run_async
Expand All @@ -24,19 +32,34 @@
from ..agent import Agent
from ..experimental.bidi.agent import BidiAgent

logger = logging.getLogger(__name__)

class _ToolCaller:
"""Call tool as a function."""
# Sentinel to signal end of stream
_STREAM_END = object()

def __init__(self, agent: "Agent | BidiAgent") -> None:
"""Initialize instance.

class _DirectToolCall:
"""Callable wrapper for a single tool that provides streaming methods.

This class enables three execution modes for direct tool calls:
1. Synchronous: ``result = agent.tool.my_tool(x=5)``
2. Sync streaming: ``for event in agent.tool.my_tool.stream(x=5)``
3. Async streaming: ``async for event in agent.tool.my_tool.stream_async(x=5)``

Streaming methods do not acquire the invocation lock, do not record to message
history, and do not apply conversation management. They are designed for
observability and real-time progress monitoring.
"""

def __init__(self, agent: "Agent | BidiAgent", tool_name: str) -> None:
"""Initialize direct tool call.

Args:
agent: Agent reference that will accept tool results.
agent: Agent reference that owns the tools.
tool_name: Name of the tool to execute.
"""
# WARNING: Do not add any other member variables or methods as this could result in a name conflict with
# agent tools and thus break their execution.
self._agent_ref = weakref.ref(agent)
self._tool_name = tool_name

@property
def _agent(self) -> "Agent | BidiAgent":
Expand All @@ -46,104 +69,181 @@ def _agent(self) -> "Agent | BidiAgent":
raise ReferenceError("Agent has been garbage collected")
return agent

def __getattr__(self, name: str) -> Callable[..., Any]:
"""Call tool as a function.
def _prepare_tool_use(self, **kwargs: Any) -> tuple[ToolUse, list[ToolResult], dict[str, Any]]:
"""Prepare tool use request, results list, and invocation state.

Args:
**kwargs: Tool parameters.

Returns:
Tuple of (tool_use, tool_results, invocation_state).

Raises:
AttributeError: If tool doesn't exist.
"""
normalized_name = self._find_normalized_tool_name(self._tool_name)
tool_id = f"tooluse_{self._tool_name}_{random.randint(100000000, 999999999)}"
tool_use: ToolUse = {
"toolUseId": tool_id,
"name": normalized_name,
"input": kwargs.copy(),
}
return tool_use, [], kwargs

This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`).
def __call__(
self,
user_message_override: str | None = None,
record_direct_tool_call: bool | None = None,
**kwargs: Any,
) -> ToolResult:
"""Synchronous tool execution (existing behavior - backward compatible).

This method enables the method-style interface (e.g., ``agent.tool.tool_name(param="value")``).
It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing').

Args:
name: The name of the attribute (tool) being accessed.
user_message_override: Optional custom message to record.
record_direct_tool_call: Whether to record in message history.
**kwargs: Tool parameters.

Returns:
A function that when called will execute the named tool.
ToolResult from execution.

Raises:
AttributeError: If no tool with the given name exists or if multiple tools match the given name.
AttributeError: If tool doesn't exist.
RuntimeError: If called during interrupt.
ConcurrencyException: If invocation lock cannot be acquired.
"""
if self._agent._interrupt_state.activated:
raise RuntimeError("cannot directly call tool during interrupt")

if record_direct_tool_call is not None:
should_record_direct_tool_call = record_direct_tool_call
else:
should_record_direct_tool_call = self._agent.record_direct_tool_call

should_lock = should_record_direct_tool_call

def caller(
user_message_override: str | None = None,
record_direct_tool_call: bool | None = None,
**kwargs: Any,
) -> Any:
"""Call a tool directly by name.

Args:
user_message_override: Optional custom message to record instead of default
record_direct_tool_call: Whether to record direct tool calls in message history. Overrides class
attribute if provided.
**kwargs: Keyword arguments to pass to the tool.

Returns:
The result returned by the tool.

Raises:
AttributeError: If the tool doesn't exist.
"""
if self._agent._interrupt_state.activated:
raise RuntimeError("cannot directly call tool during interrupt")

if record_direct_tool_call is not None:
should_record_direct_tool_call = record_direct_tool_call
else:
should_record_direct_tool_call = self._agent.record_direct_tool_call

should_lock = should_record_direct_tool_call

from ..agent import Agent # Locally imported to avoid circular reference

acquired_lock = (
should_lock
and isinstance(self._agent, Agent)
and self._agent._invocation_lock.acquire_lock(blocking=False)
from ..agent import Agent # Locally imported to avoid circular reference

acquired_lock = (
should_lock and isinstance(self._agent, Agent) and self._agent._invocation_lock.acquire_lock(blocking=False)
)
if should_lock and not acquired_lock:
raise ConcurrencyException(
"Direct tool call cannot be made while the agent is in the middle of an invocation. "
"Set record_direct_tool_call=False to allow direct tool calls during agent invocation."
)
if should_lock and not acquired_lock:
raise ConcurrencyException(
"Direct tool call cannot be made while the agent is in the middle of an invocation. "
"Set record_direct_tool_call=False to allow direct tool calls during agent invocation."
)

try:
normalized_name = self._find_normalized_tool_name(name)
try:
tool_use, tool_results, invocation_state = self._prepare_tool_use(**kwargs)

# Create unique tool ID and set up the tool request
tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}"
tool_use: ToolUse = {
"toolUseId": tool_id,
"name": normalized_name,
"input": kwargs.copy(),
}
tool_results: list[ToolResult] = []
invocation_state = kwargs
async def acall() -> ToolResult:
async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state):
if isinstance(event, ToolInterruptEvent):
self._agent._interrupt_state.deactivate()
raise RuntimeError("cannot raise interrupt in direct tool call")

async def acall() -> ToolResult:
async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state):
if isinstance(event, ToolInterruptEvent):
self._agent._interrupt_state.deactivate()
raise RuntimeError("cannot raise interrupt in direct tool call")
tool_result = tool_results[0]

tool_result = tool_results[0]
if should_record_direct_tool_call:
await self._record_tool_execution(tool_use, tool_result, user_message_override)

if should_record_direct_tool_call:
# Create a record of this tool execution in the message history
await self._record_tool_execution(tool_use, tool_result, user_message_override)
return tool_result

return tool_result
tool_result = run_async(acall)

tool_result = run_async(acall)
# TODO: https://github.com/strands-agents/sdk-python/issues/1311
if isinstance(self._agent, Agent):
self._agent.conversation_manager.apply_management(self._agent)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: The __call__ method applies conversation_manager.apply_management() after execution (line 134), but stream_async() does not. While this may be intentional since streaming doesn't record to history, if a user mixes streaming and non-streaming calls, the conversation management behavior could become inconsistent.

Suggestion: Document this behavioral difference explicitly, or keep the two paths symmetric.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is intentional, streaming doesn't add anything to message history, so there's nothing for the conversation manager to act on. Added a note about this in both stream() and stream_async() docstrings to make it clear.


# TODO: https://github.com/strands-agents/sdk-python/issues/1311
if isinstance(self._agent, Agent):
self._agent.conversation_manager.apply_management(self._agent)
return tool_result

return tool_result
finally:
if acquired_lock and isinstance(self._agent, Agent):
self._agent._invocation_lock.release()

def stream(self, **kwargs: Any) -> Iterator[Any]:
"""Synchronous streaming of tool execution events.

Bridges async-to-sync streaming using a background thread and queue, yielding
events in real-time as they are produced by the tool.

This method does not acquire the invocation lock, does not record to message
history, and does not apply conversation management.

Args:
**kwargs: Tool parameters.

Yields:
Tool execution events in real-time.

Raises:
AttributeError: If tool doesn't exist.
RuntimeError: If called during interrupt.
"""
# Fast-fail before spinning up a thread; stream_async also checks but this avoids unnecessary overhead
if self._agent._interrupt_state.activated:
raise RuntimeError("cannot directly call tool during interrupt")

event_queue: queue.Queue[Any] = queue.Queue()

async def _produce() -> None:
try:
async for event in self.stream_async(**kwargs):
event_queue.put(event)
except BaseException:
# Re-raise to propagate via future.result(); the sentinel must still be placed
# on the queue so the main thread unblocks before checking the future
raise
finally:
if acquired_lock and isinstance(self._agent, Agent):
self._agent._invocation_lock.release()
event_queue.put(_STREAM_END)

context = contextvars.copy_context()
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(context.run, asyncio.run, _produce())

while True:
item = event_queue.get()
if item is _STREAM_END:
break
yield item

return caller
# Propagates any exception from the producer thread
future.result()

async def stream_async(self, **kwargs: Any) -> AsyncIterator[Any]:
"""Asynchronous streaming of tool execution events.

Yields events directly from tool execution without recording to message
history. Designed for observability and real-time progress monitoring.

This method does not acquire the invocation lock, does not record to message
history, and does not apply conversation management. It can be used concurrently
with agent invocations.

Args:
**kwargs: Tool parameters.

Yields:
Tool execution events from ToolExecutor._stream().

Raises:
AttributeError: If tool doesn't exist.
RuntimeError: If called during interrupt.
"""
if self._agent._interrupt_state.activated:
raise RuntimeError("cannot directly call tool during interrupt")
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: stream_async() doesn't check the invocation lock or record_direct_tool_call, while __call__() does. This means streaming tool calls could execute concurrently with agent invocations even when record_direct_tool_call is True (the default), creating potential race conditions with shared agent state.

Suggestion: Either add the same concurrency guard from __call__(), or document why streaming calls intentionally skip it (e.g., because they don't record to message history). If the intent is that streaming never records, make this an explicit design decision in the docstring.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is by design, streaming is meant for observability only, it doesn't touch message history or agent state, so there's no shared mutation to guard against. Made this explicit in the docstrings now so it's clearly an intentional choice rather than an oversight.


tool_use, tool_results, invocation_state = self._prepare_tool_use(**kwargs)

logger.debug("tool_name=<%s>, streaming=<True> | executing tool stream", tool_use["name"])

async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state):
if isinstance(event, ToolInterruptEvent):
self._agent._interrupt_state.deactivate()
raise RuntimeError("cannot raise interrupt in direct tool call")
yield event

def _find_normalized_tool_name(self, name: str) -> str:
"""Lookup the tool represented by name, replacing characters with underscores as necessary."""
Expand Down Expand Up @@ -246,3 +346,39 @@ def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: di

properties = tool_spec["inputSchema"]["json"]["properties"]
return {k: v for k, v in input_params.items() if k in properties}


class _ToolCaller:
"""Call tool as a function."""

def __init__(self, agent: "Agent | BidiAgent") -> None:
"""Initialize instance.

Args:
agent: Agent reference that will accept tool results.
"""
# WARNING: Do not add any other member variables or methods as this could result in a name conflict with
# agent tools and thus break their execution.
self._agent_ref = weakref.ref(agent)

@property
def _agent(self) -> "Agent | BidiAgent":
"""Return the agent, raising ReferenceError if it has been garbage collected."""
agent = self._agent_ref()
if agent is None:
raise ReferenceError("Agent has been garbage collected")
return agent

def __getattr__(self, name: str) -> _DirectToolCall:
"""Return direct tool call with streaming methods.

This method enables the tool calling interface by returning a callable
object that provides both synchronous execution and streaming methods.

Args:
name: Tool name.

Returns:
Direct tool call instance.
"""
return _DirectToolCall(self._agent, name)
Loading