diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 47b7b55b36..591d255490 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -5,7 +5,7 @@ import sys from collections.abc import AsyncIterable, Awaitable, Callable, MutableMapping, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack -from copy import copy +from copy import deepcopy from itertools import chain from typing import Any, ClassVar, Literal, Protocol, TypeVar, cast, runtime_checkable from uuid import uuid4 @@ -1245,7 +1245,7 @@ async def _prepare_thread_and_messages( Raises: AgentExecutionException: If the conversation IDs on the thread and agent don't match. """ - chat_options = copy(self.chat_options) if self.chat_options else ChatOptions() + chat_options = deepcopy(self.chat_options) if self.chat_options else ChatOptions() thread = thread or self.get_new_thread() if thread.service_thread_id and thread.context_provider: await thread.context_provider.thread_created(thread.service_thread_id) diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index 34ade839bf..12297fe82a 100644 --- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -3173,6 +3173,40 @@ def __init__( self.top_p = top_p self.user = user + def __deepcopy__(self, memo: dict[int, Any]) -> "ChatOptions": + """Create a runtime-safe copy without deep-copying tool instances.""" + clone = type(self).__new__(type(self)) + memo[id(self)] = clone + for key, value in self.__dict__.items(): + if key == "_tools": + setattr(clone, key, list(value) if value is not None else None) + continue + if key in {"logit_bias", "metadata", "additional_properties"}: + setattr(clone, key, self._safe_deepcopy_mapping(value, memo)) + continue + setattr(clone, key, self._safe_deepcopy_value(value, memo)) + return clone + + @staticmethod + def _safe_deepcopy_mapping( + value: MutableMapping[str, Any] | None, memo: dict[int, Any] + ) -> MutableMapping[str, Any] | None: + """Deep copy helper that falls back to a shallow copy for problematic mappings.""" + if value is None: + return None + try: + return deepcopy(value, memo) # type: ignore[arg-type] + except Exception: + return dict(value) + + @staticmethod + def _safe_deepcopy_value(value: Any, memo: dict[int, Any]) -> Any: + """Deep copy helper that avoids failing on non-copyable instances.""" + try: + return deepcopy(value, memo) + except Exception: + return value + @property def tools(self) -> list[ToolProtocol | MutableMapping[str, Any]] | None: """Return the tools that are specified.""" diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index 7d36debf1c..77d5911865 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -115,6 +115,26 @@ async def test_chat_client_agent_prepare_thread_and_messages(chat_client: ChatCl assert result_messages[1].text == "Test" +async def test_prepare_thread_does_not_mutate_agent_chat_options(chat_client: ChatClientProtocol) -> None: + tool = HostedCodeInterpreterTool() + agent = ChatAgent(chat_client=chat_client, tools=[tool]) + + assert agent.chat_options.tools is not None + base_tools = agent.chat_options.tools + thread = agent.get_new_thread() + + _, prepared_chat_options, _ = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage] + thread=thread, + input_messages=[ChatMessage(role=Role.USER, text="Test")], + ) + + assert prepared_chat_options.tools is not None + assert base_tools is not prepared_chat_options.tools + + prepared_chat_options.tools.append(HostedCodeInterpreterTool()) # type: ignore[arg-type] + assert len(agent.chat_options.tools) == 1 + + async def test_chat_client_agent_update_thread_id(chat_client_base: ChatClientProtocol) -> None: mock_response = ChatResponse( messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent("test response")])], diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index 266aaad3f6..5a0ec5a773 100644 --- a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -3,6 +3,7 @@ import pytest from agent_framework import ( + ChatAgent, ChatClientProtocol, ChatMessage, ChatOptions, @@ -127,6 +128,148 @@ def ai_func(arg1: str) -> str: assert exec_counter == 1 +async def test_function_invocation_inside_aiohttp_server(chat_client_base: ChatClientProtocol): + import aiohttp + from aiohttp import web + + exec_counter = 0 + + @ai_function(name="start_todo_investigation") + def ai_func(user_query: str) -> str: + nonlocal exec_counter + exec_counter += 1 + return f"Investigated {user_query}" + + chat_client_base.run_responses = [ + ChatResponse( + messages=ChatMessage( + role="assistant", + contents=[ + FunctionCallContent( + call_id="1", + name="start_todo_investigation", + arguments='{"user_query": "issue"}', + ) + ], + ) + ), + ChatResponse(messages=ChatMessage(role="assistant", text="done")), + ] + + agent = ChatAgent(chat_client=chat_client_base, tools=[ai_func]) + + async def handler(request: web.Request) -> web.Response: + thread = agent.get_new_thread() + result = await agent.run("Fix issue", thread=thread) + return web.Response(text=result.text or "") + + app = web.Application() + app.add_routes([web.post("/run", handler)]) + + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", 0) + await site.start() + try: + port = site._server.sockets[0].getsockname()[1] + async with aiohttp.ClientSession() as session, session.post(f"http://127.0.0.1:{port}/run") as response: + assert response.status == 200 + await response.text() + finally: + await runner.cleanup() + + assert exec_counter == 1 + + +async def test_function_invocation_in_threaded_aiohttp_app(chat_client_base: ChatClientProtocol): + import asyncio + import threading + from queue import Queue + + import aiohttp + from aiohttp import web + + exec_counter = 0 + + @ai_function(name="start_threaded_investigation") + def ai_func(user_query: str) -> str: + nonlocal exec_counter + exec_counter += 1 + return f"Threaded {user_query}" + + chat_client_base.run_responses = [ + ChatResponse( + messages=ChatMessage( + role="assistant", + contents=[ + FunctionCallContent( + call_id="thread-1", + name="start_threaded_investigation", + arguments='{"user_query": "issue"}', + ) + ], + ) + ), + ChatResponse(messages=ChatMessage(role="assistant", text="done")), + ] + + agent = ChatAgent(chat_client=chat_client_base, tools=[ai_func]) + + ready_event = threading.Event() + port_queue: Queue[int] = Queue() + shutdown_queue: Queue[tuple[asyncio.AbstractEventLoop, asyncio.Event]] = Queue() + + async def init_app() -> web.Application: + async def handler(request: web.Request) -> web.Response: + thread = agent.get_new_thread() + result = await agent.run("Fix issue", thread=thread) + return web.Response(text=result.text or "") + + app = web.Application() + app.add_routes([web.post("/run", handler)]) + return app + + def server_thread() -> None: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + async def runner_main() -> None: + app = await init_app() + runner = web.AppRunner(app) + await runner.setup() + site = web.TCPSite(runner, "127.0.0.1", 0) + await site.start() + shutdown_event = asyncio.Event() + shutdown_queue.put((loop, shutdown_event)) + port = site._server.sockets[0].getsockname()[1] + port_queue.put(port) + ready_event.set() + try: + await shutdown_event.wait() + finally: + await runner.cleanup() + + try: + loop.run_until_complete(runner_main()) + finally: + loop.close() + + thread = threading.Thread(target=server_thread, daemon=True) + thread.start() + ready_event.wait(timeout=5) + assert ready_event.is_set() + loop_ref, shutdown_event = shutdown_queue.get(timeout=2) + port = port_queue.get(timeout=2) + + async with aiohttp.ClientSession() as session, session.post(f"http://127.0.0.1:{port}/run") as response: + assert response.status == 200 + await response.text() + + loop_ref.call_soon_threadsafe(shutdown_event.set) + thread.join(timeout=5) + assert exec_counter == 1 + + @pytest.mark.parametrize( "approval_required,num_functions", [