Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 46 additions & 2 deletions python/packages/core/agent_framework/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,7 @@ def __init__(
**kwargs,
)
self.func = func
self._instance = None # Store the instance for bound methods
self.input_model = self._resolve_input_model(input_model)
self.approval_mode = approval_mode or "never_require"
if max_invocations is not None and max_invocations < 1:
Expand All @@ -630,8 +631,42 @@ def __init__(
@property
def declaration_only(self) -> bool:
"""Indicate whether the function is declaration only (i.e., has no implementation)."""
# Check for explicit _declaration_only attribute first (used in tests)
if hasattr(self, "_declaration_only") and self._declaration_only:
return True
return self.func is None

def __get__(self, obj: Any, objtype: type | None = None) -> "AIFunction[ArgsT, ReturnT]":
"""Implement the descriptor protocol to support bound methods.

When an AIFunction is accessed as an attribute of a class instance,
this method is called to bind the instance to the function.

Args:
obj: The instance that owns the descriptor, or None for class access.
objtype: The type that owns the descriptor.

Returns:
A new AIFunction with the instance bound to the wrapped function.
"""
if obj is None:
# Accessed from the class, not an instance
return self

# Check if the wrapped function is a method (has 'self' parameter)
if self.func is not None:
sig = inspect.signature(self.func)
params = list(sig.parameters.keys())
if params and params[0] in {"self", "cls"}:
# Create a new AIFunction with the bound method
import copy

bound_func = copy.copy(self)
bound_func._instance = obj
return bound_func

return self

def _resolve_input_model(self, input_model: type[ArgsT] | Mapping[str, Any] | None) -> type[ArgsT]:
"""Resolve the input model for the function."""
if input_model is None:
Expand All @@ -646,7 +681,7 @@ def _resolve_input_model(self, input_model: type[ArgsT] | Mapping[str, Any] | No

def __call__(self, *args: Any, **kwargs: Any) -> ReturnT | Awaitable[ReturnT]:
"""Call the wrapped function with the provided arguments."""
if self.func is None:
if self.declaration_only:
raise ToolException(f"Function '{self.name}' is declaration only and cannot be invoked.")
if self.max_invocations is not None and self.invocation_count >= self.max_invocations:
raise ToolException(
Expand All @@ -662,7 +697,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnT | Awaitable[ReturnT]:
)
self.invocation_count += 1
try:
return self.func(*args, **kwargs)
# If we have a bound instance, call the function with self
if self._instance is not None:
return self.func(self._instance, *args, **kwargs)
return self.func(*args, **kwargs) # type:ignore[misc]
except Exception:
self.invocation_exception_count += 1
raise
Expand Down Expand Up @@ -858,6 +896,12 @@ def _parse_annotation(annotation: Any) -> Any:

def _create_input_model_from_func(func: Callable[..., Any], name: str) -> type[BaseModel]:
"""Create a Pydantic model from a function's signature."""
# Unwrap AIFunction objects to get the underlying function
from agent_framework._tools import AIFunction

if isinstance(func, AIFunction):
func = func.func # type: ignore[assignment]

sig = inspect.signature(func)
fields = {
pname: (
Expand Down
35 changes: 18 additions & 17 deletions python/packages/core/tests/core/test_function_invocation_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1305,26 +1305,20 @@ def success_func(arg1: str) -> str:
assert success_result.result == "Success value1"


async def test_declaration_only_tool_not_executed(chat_client_base: ChatClientProtocol):
"""Test that declaration_only tools are not executed."""
exec_counter = 0

@ai_function(name="declaration_func")
def declaration_func_inner(arg1: str) -> str:
nonlocal exec_counter
exec_counter += 1
return f"Result {arg1}"

# Create a new AIFunction with declaration_only set
async def test_declaration_only_tool(chat_client_base: ChatClientProtocol):
"""Test that declaration_only tools without implementation (func=None) are not executed."""
from agent_framework import AIFunction

# Create a truly declaration-only function with no implementation
declaration_func = AIFunction(
name="declaration_func",
func=declaration_func_inner,
additional_properties={"declaration_only": True},
func=None,
description="A declaration-only function for testing",
input_model={"type": "object", "properties": {"arg1": {"type": "string"}}, "required": ["arg1"]},
)
# Set declaration_only on the instance
object.__setattr__(declaration_func, "_declaration_only", True)

# Verify it's marked as declaration_only
assert declaration_func.declaration_only is True

chat_client_base.run_responses = [
ChatResponse(
Expand All @@ -1338,8 +1332,6 @@ def declaration_func_inner(arg1: str) -> str:

response = await chat_client_base.get_response("hello", tool_choice="auto", tools=[declaration_func])

# Function should NOT be executed
assert exec_counter == 0
# Should have the function call in messages but not a result
function_calls = [
content
Expand All @@ -1349,6 +1341,15 @@ def declaration_func_inner(arg1: str) -> str:
]
assert len(function_calls) >= 1

# Should not have a function result
function_results = [
content
for msg in response.messages
for content in msg.contents
if isinstance(content, FunctionResultContent) and content.call_id == "1"
]
assert len(function_results) == 0


async def test_multiple_function_calls_parallel_execution(chat_client_base: ChatClientProtocol):
"""Test that multiple function calls are executed in parallel."""
Expand Down
130 changes: 130 additions & 0 deletions python/packages/core/tests/core/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,136 @@ async def async_test_tool(x: int, y: int) -> int:
assert (await async_test_tool(1, 2)) == 3


def test_ai_function_decorator_in_class():
"""Test the ai_function decorator."""

class my_tools:
@ai_function(name="test_tool", description="A test tool")
def test_tool(self, x: int, y: int) -> int:
"""A simple function that adds two numbers."""
return x + y

test_tool = my_tools().test_tool

assert isinstance(test_tool, ToolProtocol)
assert isinstance(test_tool, AIFunction)
assert test_tool.name == "test_tool"
assert test_tool.description == "A test tool"
assert test_tool.parameters() == {
"properties": {"x": {"title": "X", "type": "integer"}, "y": {"title": "Y", "type": "integer"}},
"required": ["x", "y"],
"title": "test_tool_input",
"type": "object",
}
assert test_tool(1, 2) == 3


async def test_ai_function_decorator_shared_state():
"""Test that decorated methods maintain shared state across multiple calls and tool usage."""

class StatefulCounter:
"""A class that maintains a counter and provides decorated methods to interact with it."""

def __init__(self, initial_value: int = 0):
self.counter = initial_value
self.operation_log: list[str] = []

@ai_function(name="increment", description="Increment the counter")
def increment(self, amount: int) -> str:
"""Increment the counter by the given amount."""
self.counter += amount
self.operation_log.append(f"increment({amount})")
return f"Counter incremented by {amount}. New value: {self.counter}"

@ai_function(name="get_value", description="Get the current counter value")
def get_value(self) -> str:
"""Get the current counter value."""
self.operation_log.append("get_value()")
return f"Current counter value: {self.counter}"

@ai_function(name="multiply", description="Multiply the counter")
def multiply(self, factor: int) -> str:
"""Multiply the counter by the given factor."""
self.counter *= factor
self.operation_log.append(f"multiply({factor})")
return f"Counter multiplied by {factor}. New value: {self.counter}"

# Create a single instance with shared state
counter_instance = StatefulCounter(initial_value=10)

# Get the decorated methods - these will be used by different "agents" or tools
increment_tool = counter_instance.increment
get_value_tool = counter_instance.get_value
multiply_tool = counter_instance.multiply

# Verify they are AIFunction instances
assert isinstance(increment_tool, AIFunction)
assert isinstance(get_value_tool, AIFunction)
assert isinstance(multiply_tool, AIFunction)

# Tool 1 (increment) is used
result1 = increment_tool(5)
assert result1 == "Counter incremented by 5. New value: 15"
assert counter_instance.counter == 15

# Tool 2 (get_value) sees the state change from tool 1
result2 = get_value_tool()
assert result2 == "Current counter value: 15"
assert counter_instance.counter == 15

# Tool 3 (multiply) modifies the shared state
result3 = multiply_tool(3)
assert result3 == "Counter multiplied by 3. New value: 45"
assert counter_instance.counter == 45

# Tool 2 (get_value) sees the state change from tool 3
result4 = get_value_tool()
assert result4 == "Current counter value: 45"
assert counter_instance.counter == 45

# Tool 1 (increment) sees the current state and modifies it
result5 = increment_tool(10)
assert result5 == "Counter incremented by 10. New value: 55"
assert counter_instance.counter == 55

# Verify the operation log shows all operations in order
assert counter_instance.operation_log == [
"increment(5)",
"get_value()",
"multiply(3)",
"get_value()",
"increment(10)",
]

# Verify the parameters don't include 'self'
assert increment_tool.parameters() == {
"properties": {"amount": {"title": "Amount", "type": "integer"}},
"required": ["amount"],
"title": "increment_input",
"type": "object",
}
assert multiply_tool.parameters() == {
"properties": {"factor": {"title": "Factor", "type": "integer"}},
"required": ["factor"],
"title": "multiply_input",
"type": "object",
}
assert get_value_tool.parameters() == {
"properties": {},
"title": "get_value_input",
"type": "object",
}

# Test with invoke method as well (simulating agent execution)
result6 = await increment_tool.invoke(amount=5)
assert result6 == "Counter incremented by 5. New value: 60"
assert counter_instance.counter == 60

result7 = await get_value_tool.invoke()
assert result7 == "Current counter value: 60"
assert counter_instance.counter == 60


async def test_ai_function_invoke_telemetry_enabled(span_exporter: InMemorySpanExporter):
"""Test the ai_function invoke method with telemetry enabled."""

Expand Down