Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
296 changes: 291 additions & 5 deletions python/packages/core/tests/core/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from unittest.mock import AsyncMock, MagicMock
from uuid import uuid4

import pytest
from pytest import raises

from agent_framework import (
Expand All @@ -17,16 +18,19 @@
ChatClientProtocol,
ChatMessage,
ChatMessageStore,
ChatOptions,
ChatResponse,
Content,
Context,
ContextProvider,
HostedCodeInterpreterTool,
Role,
ToolProtocol,
tool,
)
from agent_framework._agents import _merge_options, _sanitize_agent_name
from agent_framework._mcp import MCPTool
from agent_framework.exceptions import AgentExecutionException
from agent_framework.exceptions import AgentExecutionException, AgentInitializationError


def test_agent_thread_type(agent_thread: AgentThread) -> None:
Expand Down Expand Up @@ -631,8 +635,6 @@ async def test_chat_agent_tool_choice_agent_level_used_when_run_level_not_specif
chat_client_base: Any, tool_tool: Any
) -> None:
"""Verify that agent-level tool_choice is used when run() doesn't specify one."""
from agent_framework import ChatOptions

captured_options: list[ChatOptions] = []

original_inner = chat_client_base._inner_get_response
Expand Down Expand Up @@ -664,8 +666,6 @@ async def capturing_inner(

async def test_chat_agent_tool_choice_none_at_run_preserves_agent_level(chat_client_base: Any, tool_tool: Any) -> None:
"""Verify that tool_choice=None at run() uses agent-level default."""
from agent_framework import ChatOptions

captured_options: list[ChatOptions] = []

original_inner = chat_client_base._inner_get_response
Expand All @@ -691,3 +691,289 @@ async def capturing_inner(
# Verify the client received tool_choice="auto" from agent-level
assert len(captured_options) >= 1
assert captured_options[0]["tool_choice"] == "auto"


# region Test _merge_options


def test_merge_options_basic() -> None:
    """Override entries win on key clashes; keys unique to either side survive."""
    defaults = {"key1": "value1", "key2": "value2"}
    overrides = {"key2": "new_value2", "key3": "value3"}

    merged = _merge_options(defaults, overrides)

    # key1 comes only from the base, key3 only from the override, key2 from both.
    expected = {"key1": "value1", "key2": "new_value2", "key3": "value3"}
    for key, value in expected.items():
        assert merged[key] == value


def test_merge_options_none_values_ignored() -> None:
    """A None value in the override must not replace the base value."""
    merged = _merge_options(
        {"key1": "value1"},
        {"key1": None, "key2": "value2"},
    )

    # The override supplied None for key1, so the base value has to survive.
    assert merged["key1"] == "value1"
    # Non-None override entries are still applied.
    assert merged["key2"] == "value2"


def test_merge_options_tools_combined() -> None:
    """Tool lists from both sides are combined and a repeated tool name is not added twice."""

    class NamedTool:
        """Minimal stand-in that exposes only a ``name`` attribute."""

        def __init__(self, name):
            self.name = name

    first, second, repeat_of_first = (NamedTool(n) for n in ("tool1", "tool2", "tool1"))

    merged = _merge_options({"tools": [first]}, {"tools": [second, repeat_of_first]})

    # Three tools went in but two share the name "tool1", so exactly two remain.
    merged_tools = merged["tools"]
    assert len(merged_tools) == 2
    assert {t.name for t in merged_tools} == {"tool1", "tool2"}


def test_merge_options_logit_bias_merged() -> None:
    """``logit_bias`` entries from base and override both end up in the result."""
    merged = _merge_options(
        {"logit_bias": {"token1": 1.0}},
        {"logit_bias": {"token2": 2.0}},
    )

    bias = merged["logit_bias"]
    assert bias["token1"] == 1.0
    assert bias["token2"] == 2.0


def test_merge_options_metadata_merged() -> None:
    """``metadata`` entries from base and override both end up in the result."""
    merged = _merge_options(
        {"metadata": {"key1": "value1"}},
        {"metadata": {"key2": "value2"}},
    )

    metadata = merged["metadata"]
    assert metadata["key1"] == "value1"
    assert metadata["key2"] == "value2"


def test_merge_options_instructions_concatenated() -> None:
    """Both instruction strings appear in the merged value, along with a newline."""
    merged = _merge_options(
        {"instructions": "First instruction."},
        {"instructions": "Second instruction."},
    )

    combined = merged["instructions"]
    # Checked in the same order as before: base text, override text, then the newline.
    for fragment in ("First instruction.", "Second instruction.", "\n"):
        assert fragment in combined


# endregion


# region Test _sanitize_agent_name


def test_sanitize_agent_name_none() -> None:
    """A None name is handed back as None."""
    sanitized = _sanitize_agent_name(None)

    assert sanitized is None


def test_sanitize_agent_name_valid() -> None:
    """Names that are already valid come back unchanged."""
    for already_valid in ("valid_name", "ValidName123"):
        assert _sanitize_agent_name(already_valid) == already_valid


def test_sanitize_agent_name_replaces_invalid_chars() -> None:
    """Test _sanitize_agent_name removes spaces and punctuation from the name.

    The input ``"Agent Name!"`` contains a space and an exclamation mark;
    neither may be present in the sanitized result.
    """
    result = _sanitize_agent_name("Agent Name!")

    # The helper can return None (it does for None input), so narrow the value
    # first: otherwise the membership checks below would raise TypeError instead
    # of reporting a clean assertion failure.
    assert result is not None
    assert " " not in result
    assert "!" not in result


# endregion


# region Test AgentProtocol.get_new_thread and deserialize_thread


@pytest.mark.asyncio
async def test_agent_get_new_thread(chat_client_base: ChatClientProtocol, tool_tool: ToolProtocol) -> None:
    """``get_new_thread`` hands back an ``AgentThread`` instance."""
    new_thread = ChatAgent(chat_client=chat_client_base, tools=[tool_tool]).get_new_thread()

    assert new_thread is not None
    assert isinstance(new_thread, AgentThread)


@pytest.mark.asyncio
async def test_agent_get_new_thread_with_context_provider(
    chat_client_base: ChatClientProtocol, tool_tool: ToolProtocol
) -> None:
    """The thread from ``get_new_thread`` carries the agent's context provider instance."""

    class StubContextProvider(ContextProvider):
        """Provider whose ``invoking`` hook returns an empty ``Context``."""

        async def invoking(self, messages, **kwargs):
            return Context()

    stub_provider = StubContextProvider()
    agent = ChatAgent(
        chat_client=chat_client_base,
        tools=[tool_tool],
        context_provider=stub_provider,
    )

    new_thread = agent.get_new_thread()

    assert new_thread is not None
    # Identity, not equality: the very same provider object must be attached.
    assert new_thread.context_provider is stub_provider


@pytest.mark.asyncio
async def test_agent_get_new_thread_with_service_thread_id(
    chat_client_base: ChatClientProtocol, tool_tool: ToolProtocol
) -> None:
    """A ``service_thread_id`` keyword given to ``get_new_thread`` shows up on the thread."""
    agent = ChatAgent(chat_client=chat_client_base, tools=[tool_tool])

    new_thread = agent.get_new_thread(service_thread_id="test-thread-123")

    assert new_thread is not None
    assert new_thread.service_thread_id == "test-thread-123"


@pytest.mark.asyncio
async def test_agent_deserialize_thread(chat_client_base: ChatClientProtocol, tool_tool: ToolProtocol) -> None:
    """``deserialize_thread`` rebuilds a thread, including its stored message, from a state dict."""
    # Serialized form: no service thread id, and a message store holding one user message.
    state = {
        "service_thread_id": None,
        "chat_message_store_state": {"messages": [{"role": "user", "text": "Hello"}]},
    }
    agent = ChatAgent(chat_client=chat_client_base, tools=[tool_tool])

    restored = await agent.deserialize_thread(state)

    assert restored is not None
    assert isinstance(restored, AgentThread)
    assert restored.message_store is not None

    stored = await restored.message_store.list_messages()
    assert len(stored) == 1
    assert stored[0].text == "Hello"


# endregion


# region Test ChatAgent initialization edge cases


def test_chat_agent_raises_with_both_conversation_id_and_store() -> None:
    """Test ChatAgent raises error with both conversation_id and chat_message_store_factory.

    Passing a ``conversation_id`` in ``default_options`` together with a
    ``chat_message_store_factory`` is expected to fail at construction time with
    ``AgentInitializationError`` whose message matches "Cannot specify both".

    The test is synchronous on purpose: nothing here is awaited, so it does not
    need an event loop or the asyncio marker.
    """
    mock_client = MagicMock()
    mock_store_factory = MagicMock()

    # Only the constructor call sits inside the context manager, so the expected
    # exception can only come from ChatAgent initialization.
    with pytest.raises(AgentInitializationError, match="Cannot specify both"):
        ChatAgent(
            chat_client=mock_client,
            default_options={"conversation_id": "test_id"},
            chat_message_store_factory=mock_store_factory,
        )


def test_chat_agent_calls_update_agent_name_on_client() -> None:
    """Constructing a ChatAgent invokes the client's ``_update_agent_name_and_description`` hook once."""
    client = MagicMock()
    name_hook = MagicMock()
    client._update_agent_name_and_description = name_hook

    ChatAgent(chat_client=client, name="TestAgent", description="Test description")

    # The hook must receive the name and description positionally, exactly once.
    name_hook.assert_called_once_with("TestAgent", "Test description")


@pytest.mark.asyncio
async def test_chat_agent_context_provider_adds_tools_when_agent_has_none(
    chat_client_base: ChatClientProtocol,
) -> None:
    """Tools returned by the context provider reach the prepared options when the agent has none."""

    @tool
    def context_tool(text: str) -> str:
        """A tool provided by context."""
        return text

    class ToolSupplyingProvider(ContextProvider):
        """Provider whose ``invoking`` hook returns a ``Context`` holding ``context_tool``."""

        async def invoking(self, messages, **kwargs):
            return Context(tools=[context_tool])

    agent = ChatAgent(chat_client=chat_client_base, context_provider=ToolSupplyingProvider())

    # Precondition: no tools were configured on the agent itself.
    assert agent.default_options.get("tools") == []

    # Drive the preparation step directly and look only at the options it produces.
    user_message = ChatMessage(role=Role.USER, text="Hello")
    _thread, prepared_options, _messages = await agent._prepare_thread_and_messages(  # type: ignore[reportPrivateUsage]
        thread=None, input_messages=[user_message]
    )

    # Exactly one tool is expected in the prepared options.
    assert prepared_options.get("tools") is not None
    assert len(prepared_options["tools"]) == 1


@pytest.mark.asyncio
async def test_chat_agent_context_provider_adds_instructions_when_agent_has_none(
    chat_client_base: ChatClientProtocol,
) -> None:
    """Instructions returned by the context provider reach the prepared options when the agent has none."""

    class InstructionSupplyingProvider(ContextProvider):
        """Provider whose ``invoking`` hook returns a ``Context`` carrying instructions."""

        async def invoking(self, messages, **kwargs):
            return Context(instructions="Context-provided instructions")

    agent = ChatAgent(chat_client=chat_client_base, context_provider=InstructionSupplyingProvider())

    # Precondition: the agent was built without instructions of its own.
    assert agent.default_options.get("instructions") is None

    # Drive the preparation step directly and look only at the options it produces.
    user_message = ChatMessage(role=Role.USER, text="Hello")
    _thread, prepared_options, _messages = await agent._prepare_thread_and_messages(  # type: ignore[reportPrivateUsage]
        thread=None, input_messages=[user_message]
    )

    # The provider's text must be the instructions in the prepared options.
    assert prepared_options.get("instructions") == "Context-provided instructions"


@pytest.mark.asyncio
async def test_chat_agent_raises_on_conversation_id_mismatch(chat_client_base: ChatClientProtocol) -> None:
    """Preparing a run fails when the thread's service id differs from the agent's conversation id."""
    # The thread's service_thread_id deliberately differs from the agent's conversation_id.
    mismatched_thread = AgentThread(service_thread_id="different-thread-id")
    agent = ChatAgent(
        chat_client=chat_client_base,
        default_options={"conversation_id": "agent-conversation-id"},
    )
    user_message = ChatMessage(role=Role.USER, text="Hello")

    with pytest.raises(AgentExecutionException, match="conversation_id set on the agent is different"):
        await agent._prepare_thread_and_messages(  # type: ignore[reportPrivateUsage]
            thread=mismatched_thread, input_messages=[user_message]
        )


# endregion
Loading
Loading