From 65827d156d15b512d568a7816221a3ed3dde48da Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Mon, 17 Nov 2025 10:40:34 +0100 Subject: [PATCH 1/3] Fixes Python: @ai_function doesn't properly handle 'self' param Fixes #1343 --- .../packages/core/agent_framework/_tools.py | 35 +++++ python/packages/core/tests/core/test_tools.py | 130 ++++++++++++++++++ 2 files changed, 165 insertions(+) diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 6edd258e15..a786fbe4fa 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -614,6 +614,7 @@ def __init__( **kwargs, ) self.func = func + self._instance = None # Store the instance for bound methods self.input_model = self._resolve_input_model(input_model) self.approval_mode = approval_mode or "never_require" if max_invocations is not None and max_invocations < 1: @@ -632,6 +633,37 @@ def declaration_only(self) -> bool: """Indicate whether the function is declaration only (i.e., has no implementation).""" return self.func is None + def __get__(self, obj: Any, objtype: type | None = None) -> "AIFunction[ArgsT, ReturnT]": + """Implement the descriptor protocol to support bound methods. + + When an AIFunction is accessed as an attribute of a class instance, + this method is called to bind the instance to the function. + + Args: + obj: The instance that owns the descriptor, or None for class access. + objtype: The type that owns the descriptor. + + Returns: + A new AIFunction with the instance bound to the wrapped function. + """ + if obj is None: + # Accessed from the class, not an instance + return self + + # Check if the wrapped function is a method (has 'self' parameter) + if self.func is not None: + sig = inspect.signature(self.func) + params = list(sig.parameters.keys()) + if params and params[0] in {"self", "cls"}: + # Create a new AIFunction with the bound method + import copy + + bound_func = copy.copy(self) + bound_func._instance = obj + return bound_func + + return self + def _resolve_input_model(self, input_model: type[ArgsT] | Mapping[str, Any] | None) -> type[ArgsT]: """Resolve the input model for the function.""" if input_model is None: @@ -662,6 +694,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnT | Awaitable[ReturnT]: ) self.invocation_count += 1 try: + # If we have a bound instance, call the function with self + if self._instance is not None: + return self.func(self._instance, *args, **kwargs) return self.func(*args, **kwargs) except Exception: self.invocation_exception_count += 1 diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index acd9157363..47b129cc60 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -104,6 +104,136 @@ async def async_test_tool(x: int, y: int) -> int: assert (await async_test_tool(1, 2)) == 3 +def test_ai_function_decorator_in_class(): + """Test the ai_function decorator.""" + + class my_tools: + @ai_function(name="test_tool", description="A test tool") + def test_tool(self, x: int, y: int) -> int: + """A simple function that adds two numbers.""" + return x + y + + test_tool = my_tools().test_tool + + assert isinstance(test_tool, ToolProtocol) + assert isinstance(test_tool, AIFunction) + assert test_tool.name == "test_tool" + assert test_tool.description == "A test tool" + assert test_tool.parameters() == { + "properties": {"x": {"title": "X", "type": "integer"}, "y": {"title": "Y", "type": "integer"}}, + "required": ["x", "y"], + "title": "test_tool_input", + "type": "object", + } + assert test_tool(1, 2) == 3 + + +async def test_ai_function_decorator_shared_state(): + """Test that decorated methods maintain shared state across multiple calls and tool usage.""" + + class StatefulCounter: + """A class that maintains a counter and provides decorated methods to interact with it.""" + + def __init__(self, initial_value: int = 0): + self.counter = initial_value + self.operation_log: list[str] = [] + + @ai_function(name="increment", description="Increment the counter") + def increment(self, amount: int) -> str: + """Increment the counter by the given amount.""" + self.counter += amount + self.operation_log.append(f"increment({amount})") + return f"Counter incremented by {amount}. New value: {self.counter}" + + @ai_function(name="get_value", description="Get the current counter value") + def get_value(self) -> str: + """Get the current counter value.""" + self.operation_log.append("get_value()") + return f"Current counter value: {self.counter}" + + @ai_function(name="multiply", description="Multiply the counter") + def multiply(self, factor: int) -> str: + """Multiply the counter by the given factor.""" + self.counter *= factor + self.operation_log.append(f"multiply({factor})") + return f"Counter multiplied by {factor}. New value: {self.counter}" + + # Create a single instance with shared state + counter_instance = StatefulCounter(initial_value=10) + + # Get the decorated methods - these will be used by different "agents" or tools + increment_tool = counter_instance.increment + get_value_tool = counter_instance.get_value + multiply_tool = counter_instance.multiply + + # Verify they are AIFunction instances + assert isinstance(increment_tool, AIFunction) + assert isinstance(get_value_tool, AIFunction) + assert isinstance(multiply_tool, AIFunction) + + # Tool 1 (increment) is used + result1 = increment_tool(5) + assert result1 == "Counter incremented by 5. New value: 15" + assert counter_instance.counter == 15 + + # Tool 2 (get_value) sees the state change from tool 1 + result2 = get_value_tool() + assert result2 == "Current counter value: 15" + assert counter_instance.counter == 15 + + # Tool 3 (multiply) modifies the shared state + result3 = multiply_tool(3) + assert result3 == "Counter multiplied by 3. New value: 45" + assert counter_instance.counter == 45 + + # Tool 2 (get_value) sees the state change from tool 3 + result4 = get_value_tool() + assert result4 == "Current counter value: 45" + assert counter_instance.counter == 45 + + # Tool 1 (increment) sees the current state and modifies it + result5 = increment_tool(10) + assert result5 == "Counter incremented by 10. New value: 55" + assert counter_instance.counter == 55 + + # Verify the operation log shows all operations in order + assert counter_instance.operation_log == [ + "increment(5)", + "get_value()", + "multiply(3)", + "get_value()", + "increment(10)", + ] + + # Verify the parameters don't include 'self' + assert increment_tool.parameters() == { + "properties": {"amount": {"title": "Amount", "type": "integer"}}, + "required": ["amount"], + "title": "increment_input", + "type": "object", + } + assert multiply_tool.parameters() == { + "properties": {"factor": {"title": "Factor", "type": "integer"}}, + "required": ["factor"], + "title": "multiply_input", + "type": "object", + } + assert get_value_tool.parameters() == { + "properties": {}, + "title": "get_value_input", + "type": "object", + } + + # Test with invoke method as well (simulating agent execution) + result6 = await increment_tool.invoke(amount=5) + assert result6 == "Counter incremented by 5. New value: 60" + assert counter_instance.counter == 60 + + result7 = await get_value_tool.invoke() + assert result7 == "Current counter value: 60" + assert counter_instance.counter == 60 + + async def test_ai_function_invoke_telemetry_enabled(span_exporter: InMemorySpanExporter): """Test the ai_function invoke method with telemetry enabled.""" From c62fa3eea39e9c4c4bb36b31c539a5bc3d917367 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Mon, 17 Nov 2025 10:51:29 +0100 Subject: [PATCH 2/3] fix for declaration only funcs --- .../packages/core/agent_framework/_tools.py | 11 +++++- .../core/test_function_invocation_logic.py | 35 ++++++++++--------- 2 files changed, 28 insertions(+), 18 deletions(-) diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index a786fbe4fa..6b20f8086a 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -631,6 +631,9 @@ def __init__( @property def declaration_only(self) -> bool: """Indicate whether the function is declaration only (i.e., has no implementation).""" + # Check for explicit _declaration_only attribute first (used in tests) + if hasattr(self, "_declaration_only") and self._declaration_only: + return True return self.func is None def __get__(self, obj: Any, objtype: type | None = None) -> "AIFunction[ArgsT, ReturnT]": @@ -678,7 +681,7 @@ def _resolve_input_model(self, input_model: type[ArgsT] | Mapping[str, Any] | No def __call__(self, *args: Any, **kwargs: Any) -> ReturnT | Awaitable[ReturnT]: """Call the wrapped function with the provided arguments.""" - if self.func is None: + if self.declaration_only: raise ToolException(f"Function '{self.name}' is declaration only and cannot be invoked.") if self.max_invocations is not None and self.invocation_count >= self.max_invocations: raise ToolException( @@ -893,6 +896,12 @@ def _parse_annotation(annotation: Any) -> Any: def _create_input_model_from_func(func: Callable[..., Any], name: str) -> type[BaseModel]: """Create a Pydantic model from a function's signature.""" + # Unwrap AIFunction objects to get the underlying function + from agent_framework._tools import AIFunction + + if isinstance(func, AIFunction): + func = func.func # type: ignore[assignment] + sig = inspect.signature(func) fields = { pname: ( diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index 77b95d98a2..266aaad3f6 100644 --- a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -1305,26 +1305,20 @@ def success_func(arg1: str) -> str: assert success_result.result == "Success value1" -async def test_declaration_only_tool_not_executed(chat_client_base: ChatClientProtocol): - """Test that declaration_only tools are not executed.""" - exec_counter = 0 - - @ai_function(name="declaration_func") - def declaration_func_inner(arg1: str) -> str: - nonlocal exec_counter - exec_counter += 1 - return f"Result {arg1}" - - # Create a new AIFunction with declaration_only set +async def test_declaration_only_tool(chat_client_base: ChatClientProtocol): + """Test that declaration_only tools without implementation (func=None) are not executed.""" from agent_framework import AIFunction + # Create a truly declaration-only function with no implementation declaration_func = AIFunction( name="declaration_func", - func=declaration_func_inner, - additional_properties={"declaration_only": True}, + func=None, + description="A declaration-only function for testing", + input_model={"type": "object", "properties": {"arg1": {"type": "string"}}, "required": ["arg1"]}, ) - # Set declaration_only on the instance - object.__setattr__(declaration_func, "_declaration_only", True) + + # Verify it's marked as declaration_only + assert declaration_func.declaration_only is True chat_client_base.run_responses = [ ChatResponse( @@ -1338,8 +1332,6 @@ def declaration_func_inner(arg1: str) -> str: response = await chat_client_base.get_response("hello", tool_choice="auto", tools=[declaration_func]) - # Function should NOT be executed - assert exec_counter == 0 # Should have the function call in messages but not a result function_calls = [ content @@ -1349,6 +1341,15 @@ def declaration_func_inner(arg1: str) -> str: ] assert len(function_calls) >= 1 + # Should not have a function result + function_results = [ + content + for msg in response.messages + for content in msg.contents + if isinstance(content, FunctionResultContent) and content.call_id == "1" + ] + assert len(function_results) == 0 + async def test_multiple_function_calls_parallel_execution(chat_client_base: ChatClientProtocol): """Test that multiple function calls are executed in parallel.""" From e07dc751f81c750d161f00271da612dbf854da18 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Mon, 17 Nov 2025 10:56:47 +0100 Subject: [PATCH 3/3] fix mypy --- python/packages/core/agent_framework/_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 6b20f8086a..17accb066f 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -700,7 +700,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnT | Awaitable[ReturnT]: # If we have a bound instance, call the function with self if self._instance is not None: return self.func(self._instance, *args, **kwargs) - return self.func(*args, **kwargs) + return self.func(*args, **kwargs) # type:ignore[misc] except Exception: self.invocation_exception_count += 1 raise