Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/packages/core/agent_framework/_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys
from collections.abc import AsyncIterable, Awaitable, Callable, MutableMapping, Sequence
from contextlib import AbstractAsyncContextManager, AsyncExitStack
from copy import copy
from copy import deepcopy
from itertools import chain
from typing import Any, ClassVar, Literal, Protocol, TypeVar, cast, runtime_checkable
from uuid import uuid4
Expand Down Expand Up @@ -1245,7 +1245,7 @@ async def _prepare_thread_and_messages(
Raises:
AgentExecutionException: If the conversation IDs on the thread and agent don't match.
"""
chat_options = copy(self.chat_options) if self.chat_options else ChatOptions()
chat_options = deepcopy(self.chat_options) if self.chat_options else ChatOptions()
thread = thread or self.get_new_thread()
if thread.service_thread_id and thread.context_provider:
await thread.context_provider.thread_created(thread.service_thread_id)
Expand Down
34 changes: 34 additions & 0 deletions python/packages/core/agent_framework/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3173,6 +3173,40 @@ def __init__(
self.top_p = top_p
self.user = user

def __deepcopy__(self, memo: dict[int, Any]) -> "ChatOptions":
"""Create a runtime-safe copy without deep-copying tool instances."""
clone = type(self).__new__(type(self))
memo[id(self)] = clone
for key, value in self.__dict__.items():
if key == "_tools":
setattr(clone, key, list(value) if value is not None else None)
continue
if key in {"logit_bias", "metadata", "additional_properties"}:
setattr(clone, key, self._safe_deepcopy_mapping(value, memo))
continue
setattr(clone, key, self._safe_deepcopy_value(value, memo))
return clone

@staticmethod
def _safe_deepcopy_mapping(
value: MutableMapping[str, Any] | None, memo: dict[int, Any]
) -> MutableMapping[str, Any] | None:
"""Deep copy helper that falls back to a shallow copy for problematic mappings."""
if value is None:
return None
try:
return deepcopy(value, memo) # type: ignore[arg-type]
except Exception:
return dict(value)

@staticmethod
def _safe_deepcopy_value(value: Any, memo: dict[int, Any]) -> Any:
"""Deep copy helper that avoids failing on non-copyable instances."""
try:
return deepcopy(value, memo)
except Exception:
return value

@property
def tools(self) -> list[ToolProtocol | MutableMapping[str, Any]] | None:
"""Return the tools that are specified."""
Expand Down
20 changes: 20 additions & 0 deletions python/packages/core/tests/core/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,26 @@ async def test_chat_client_agent_prepare_thread_and_messages(chat_client: ChatCl
assert result_messages[1].text == "Test"


async def test_prepare_thread_does_not_mutate_agent_chat_options(chat_client: ChatClientProtocol) -> None:
tool = HostedCodeInterpreterTool()
agent = ChatAgent(chat_client=chat_client, tools=[tool])

assert agent.chat_options.tools is not None
base_tools = agent.chat_options.tools
thread = agent.get_new_thread()

_, prepared_chat_options, _ = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage]
thread=thread,
input_messages=[ChatMessage(role=Role.USER, text="Test")],
)

assert prepared_chat_options.tools is not None
assert base_tools is not prepared_chat_options.tools

prepared_chat_options.tools.append(HostedCodeInterpreterTool()) # type: ignore[arg-type]
assert len(agent.chat_options.tools) == 1


async def test_chat_client_agent_update_thread_id(chat_client_base: ChatClientProtocol) -> None:
mock_response = ChatResponse(
messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent("test response")])],
Expand Down
143 changes: 143 additions & 0 deletions python/packages/core/tests/core/test_function_invocation_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

from agent_framework import (
ChatAgent,
ChatClientProtocol,
ChatMessage,
ChatOptions,
Expand Down Expand Up @@ -127,6 +128,148 @@ def ai_func(arg1: str) -> str:
assert exec_counter == 1


async def test_function_invocation_inside_aiohttp_server(chat_client_base: ChatClientProtocol):
import aiohttp
from aiohttp import web

exec_counter = 0

@ai_function(name="start_todo_investigation")
def ai_func(user_query: str) -> str:
nonlocal exec_counter
exec_counter += 1
return f"Investigated {user_query}"

chat_client_base.run_responses = [
ChatResponse(
messages=ChatMessage(
role="assistant",
contents=[
FunctionCallContent(
call_id="1",
name="start_todo_investigation",
arguments='{"user_query": "issue"}',
)
],
)
),
ChatResponse(messages=ChatMessage(role="assistant", text="done")),
]

agent = ChatAgent(chat_client=chat_client_base, tools=[ai_func])

async def handler(request: web.Request) -> web.Response:
thread = agent.get_new_thread()
result = await agent.run("Fix issue", thread=thread)
return web.Response(text=result.text or "")

app = web.Application()
app.add_routes([web.post("/run", handler)])

runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", 0)
await site.start()
try:
port = site._server.sockets[0].getsockname()[1]
async with aiohttp.ClientSession() as session, session.post(f"http://127.0.0.1:{port}/run") as response:
assert response.status == 200
await response.text()
finally:
await runner.cleanup()

assert exec_counter == 1


async def test_function_invocation_in_threaded_aiohttp_app(chat_client_base: ChatClientProtocol):
import asyncio
import threading
from queue import Queue

import aiohttp
from aiohttp import web

exec_counter = 0

@ai_function(name="start_threaded_investigation")
def ai_func(user_query: str) -> str:
nonlocal exec_counter
exec_counter += 1
return f"Threaded {user_query}"

chat_client_base.run_responses = [
ChatResponse(
messages=ChatMessage(
role="assistant",
contents=[
FunctionCallContent(
call_id="thread-1",
name="start_threaded_investigation",
arguments='{"user_query": "issue"}',
)
],
)
),
ChatResponse(messages=ChatMessage(role="assistant", text="done")),
]

agent = ChatAgent(chat_client=chat_client_base, tools=[ai_func])

ready_event = threading.Event()
port_queue: Queue[int] = Queue()
shutdown_queue: Queue[tuple[asyncio.AbstractEventLoop, asyncio.Event]] = Queue()

async def init_app() -> web.Application:
async def handler(request: web.Request) -> web.Response:
thread = agent.get_new_thread()
result = await agent.run("Fix issue", thread=thread)
return web.Response(text=result.text or "")

app = web.Application()
app.add_routes([web.post("/run", handler)])
return app

def server_thread() -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

async def runner_main() -> None:
app = await init_app()
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", 0)
await site.start()
shutdown_event = asyncio.Event()
shutdown_queue.put((loop, shutdown_event))
port = site._server.sockets[0].getsockname()[1]
port_queue.put(port)
ready_event.set()
try:
await shutdown_event.wait()
finally:
await runner.cleanup()

try:
loop.run_until_complete(runner_main())
finally:
loop.close()

thread = threading.Thread(target=server_thread, daemon=True)
thread.start()
ready_event.wait(timeout=5)
assert ready_event.is_set()
loop_ref, shutdown_event = shutdown_queue.get(timeout=2)
port = port_queue.get(timeout=2)

async with aiohttp.ClientSession() as session, session.post(f"http://127.0.0.1:{port}/run") as response:
assert response.status == 200
await response.text()

loop_ref.call_soon_threadsafe(shutdown_event.set)
thread.join(timeout=5)
assert exec_counter == 1


@pytest.mark.parametrize(
"approval_required,num_functions",
[
Expand Down