From 989df363b783ffc7a7c93098c312bc8c0ee112b2 Mon Sep 17 00:00:00 2001 From: Giles Odigwe Date: Tue, 27 Jan 2026 14:51:40 -0800 Subject: [PATCH 1/4] Add core types and agents unit tests to improve coverage (#3356) --- .../packages/core/tests/core/test_agents.py | 170 +++++++++++ python/packages/core/tests/core/test_types.py | 276 ++++++++++++++++++ 2 files changed, 446 insertions(+) diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index a331f6f75c..703aac8997 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -6,6 +6,7 @@ from unittest.mock import AsyncMock, MagicMock from uuid import uuid4 +import pytest from pytest import raises from agent_framework import ( @@ -695,3 +696,172 @@ async def capturing_inner( # Verify the client received tool_choice="auto" from agent-level assert len(captured_options) >= 1 assert captured_options[0]["tool_choice"] == "auto" + + +# region Test _merge_options + + +def test_merge_options_basic(): + """Test _merge_options merges two dicts with override precedence.""" + from agent_framework._agents import _merge_options + + base = {"key1": "value1", "key2": "value2"} + override = {"key2": "new_value2", "key3": "value3"} + + result = _merge_options(base, override) + + assert result["key1"] == "value1" + assert result["key2"] == "new_value2" + assert result["key3"] == "value3" + + +def test_merge_options_none_values_ignored(): + """Test _merge_options ignores None values in override.""" + from agent_framework._agents import _merge_options + + base = {"key1": "value1"} + override = {"key1": None, "key2": "value2"} + + result = _merge_options(base, override) + + assert result["key1"] == "value1" # None didn't override + assert result["key2"] == "value2" + + +def test_merge_options_tools_combined(): + """Test _merge_options combines tool lists without duplicates.""" + from agent_framework._agents import _merge_options + + class MockTool: + def __init__(self, name): + self.name = name + + tool1 = MockTool("tool1") + tool2 = MockTool("tool2") + tool3 = MockTool("tool1") # Duplicate name + + base = {"tools": [tool1]} + override = {"tools": [tool2, tool3]} + + result = _merge_options(base, override) + + # Should have tool1 and tool2, but not duplicate tool3 + assert len(result["tools"]) == 2 + tool_names = [t.name for t in result["tools"]] + assert "tool1" in tool_names + assert "tool2" in tool_names + + +def test_merge_options_logit_bias_merged(): + """Test _merge_options merges logit_bias dicts.""" + from agent_framework._agents import _merge_options + + base = {"logit_bias": {"token1": 1.0}} + override = {"logit_bias": {"token2": 2.0}} + + result = _merge_options(base, override) + + assert result["logit_bias"]["token1"] == 1.0 + assert result["logit_bias"]["token2"] == 2.0 + + +def test_merge_options_metadata_merged(): + """Test _merge_options merges metadata dicts.""" + from agent_framework._agents import _merge_options + + base = {"metadata": {"key1": "value1"}} + override = {"metadata": {"key2": "value2"}} + + result = _merge_options(base, override) + + assert result["metadata"]["key1"] == "value1" + assert result["metadata"]["key2"] == "value2" + + +def test_merge_options_instructions_concatenated(): + """Test _merge_options concatenates instructions.""" + from agent_framework._agents import _merge_options + + base = {"instructions": "First instruction."} + override = {"instructions": "Second instruction."} + + result = _merge_options(base, override) + + assert "First instruction." in result["instructions"] + assert "Second instruction." in result["instructions"] + assert "\n" in result["instructions"] + + +# endregion + + +# region Test _sanitize_agent_name + + +def test_sanitize_agent_name_none(): + """Test _sanitize_agent_name returns None for None input.""" + from agent_framework._agents import _sanitize_agent_name + + assert _sanitize_agent_name(None) is None + + +def test_sanitize_agent_name_valid(): + """Test _sanitize_agent_name returns valid names unchanged.""" + from agent_framework._agents import _sanitize_agent_name + + assert _sanitize_agent_name("valid_name") == "valid_name" + assert _sanitize_agent_name("ValidName123") == "ValidName123" + + +def test_sanitize_agent_name_replaces_invalid_chars(): + """Test _sanitize_agent_name replaces invalid characters.""" + from agent_framework._agents import _sanitize_agent_name + + result = _sanitize_agent_name("Agent Name!") + # Should replace spaces and special chars with underscores + assert " " not in result + assert "!" not in result + + +# endregion + + +# region Test AgentProtocol.get_new_thread + + +@pytest.mark.asyncio +async def test_agent_get_new_thread(chat_client_base, ai_function_tool): + """Test that get_new_thread returns a new AgentThread.""" + agent = ChatAgent(chat_client=chat_client_base, tools=[ai_function_tool]) + + thread = agent.get_new_thread() + + assert thread is not None + assert isinstance(thread, AgentThread) + + +# endregion + + +# region Test ChatAgent initialization edge cases + + +@pytest.mark.asyncio +async def test_chat_agent_raises_with_both_conversation_id_and_store(): + """Test ChatAgent raises error with both conversation_id and chat_message_store_factory.""" + from unittest.mock import MagicMock + + from agent_framework.exceptions import AgentInitializationError + + mock_client = MagicMock() + mock_store_factory = MagicMock() + + with pytest.raises(AgentInitializationError, match="Cannot specify both"): + ChatAgent( + chat_client=mock_client, + default_options={"conversation_id": "test_id"}, + chat_message_store_factory=mock_store_factory, + ) + + +# endregion diff --git a/python/packages/core/tests/core/test_types.py b/python/packages/core/tests/core/test_types.py index 3e5317fdae..b9f6c645ad 100644 --- a/python/packages/core/tests/core/test_types.py +++ b/python/packages/core/tests/core/test_types.py @@ -2280,3 +2280,279 @@ def __init__(self): # endregion + + +# region Test Content._add_usage_content + + +def test_content_add_usage_content(): + """Test adding two usage content instances combines their usage details.""" + usage1 = Content( + type="usage", + usage_details={"input_token_count": 100, "output_token_count": 50}, + raw_representation="raw1", + ) + usage2 = Content( + type="usage", + usage_details={"input_token_count": 200, "output_token_count": 100}, + raw_representation="raw2", + ) + + result = usage1 + usage2 + + assert result.type == "usage" + assert result.usage_details["input_token_count"] == 300 + assert result.usage_details["output_token_count"] == 150 + # Raw representations should be combined + assert isinstance(result.raw_representation, list) + assert "raw1" in result.raw_representation + assert "raw2" in result.raw_representation + + +def test_content_add_usage_content_with_none_raw_representation(): + """Test adding usage content when one has None raw_representation.""" + usage1 = Content( + type="usage", + usage_details={"input_token_count": 100}, + raw_representation=None, + ) + usage2 = Content( + type="usage", + usage_details={"output_token_count": 50}, + raw_representation="raw2", + ) + + result = usage1 + usage2 + + assert result.raw_representation == "raw2" + + +def test_content_add_usage_content_non_integer_values(): + """Test adding usage content with non-integer values.""" + usage1 = Content( + type="usage", + usage_details={"model": "gpt-4", "count": 10}, + ) + usage2 = Content( + type="usage", + usage_details={"model": "gpt-3.5", "count": 20}, + ) + + result = usage1 + usage2 + + # Non-integer "model" should take first non-None value + assert result.usage_details["model"] == "gpt-4" + # Integer "count" should be summed + assert result.usage_details["count"] == 30 + + +# endregion + + +# region Test Content.has_top_level_media_type + + +def test_content_has_top_level_media_type(): + """Test has_top_level_media_type returns correct boolean.""" + image = Content(type="uri", uri="https://example.com/image.png", media_type="image/png") + + assert image.has_top_level_media_type("image") is True + assert image.has_top_level_media_type("IMAGE") is True # Case insensitive + assert image.has_top_level_media_type("audio") is False + + +def test_content_has_top_level_media_type_no_slash(): + """Test has_top_level_media_type when media_type has no slash.""" + content = Content(type="data", media_type="text") + + assert content.has_top_level_media_type("text") is True + + +def test_content_has_top_level_media_type_raises_without_media_type(): + """Test has_top_level_media_type raises ContentError when no media_type.""" + content = Content(type="text", text="hello") + + with raises(ContentError, match="no media_type found"): + content.has_top_level_media_type("text") + + +# endregion + + +# region Test Content.parse_arguments + + +def test_content_parse_arguments_none(): + """Test parse_arguments returns None when arguments is None.""" + content = Content(type="function_call", call_id="1", name="test", arguments=None) + + assert content.parse_arguments() is None + + +def test_content_parse_arguments_empty_string(): + """Test parse_arguments returns empty dict for empty string.""" + content = Content(type="function_call", call_id="1", name="test", arguments="") + + assert content.parse_arguments() == {} + + +def test_content_parse_arguments_valid_json(): + """Test parse_arguments parses valid JSON string.""" + content = Content(type="function_call", call_id="1", name="test", arguments='{"key": "value"}') + + result = content.parse_arguments() + assert result == {"key": "value"} + + +def test_content_parse_arguments_non_dict_json(): + """Test parse_arguments wraps non-dict JSON in 'raw' key.""" + content = Content(type="function_call", call_id="1", name="test", arguments='"just a string"') + + result = content.parse_arguments() + # The JSON is parsed, and if it's not a dict, wrapped in 'raw' + assert result == {"raw": "just a string"} + + +def test_content_parse_arguments_invalid_json(): + """Test parse_arguments wraps invalid JSON in 'raw' key.""" + content = Content(type="function_call", call_id="1", name="test", arguments="not json at all") + + result = content.parse_arguments() + assert result == {"raw": "not json at all"} + + +def test_content_parse_arguments_dict_passthrough(): + """Test parse_arguments passes through dict arguments.""" + args = {"key": "value", "num": 42} + content = Content(type="function_call", call_id="1", name="test", arguments=args) + + result = content.parse_arguments() + assert result == args + + +# endregion + + +# region Test _get_data_bytes_as_str + + +def test_get_data_bytes_as_str_non_data_uri(): + """Test _get_data_bytes_as_str returns None for non-data URIs.""" + from agent_framework._types import _get_data_bytes_as_str + + content = Content(type="uri", uri="https://example.com/image.png") + assert _get_data_bytes_as_str(content) is None + + +def test_get_data_bytes_as_str_no_base64(): + """Test _get_data_bytes_as_str raises for non-base64 data URI.""" + from agent_framework._types import _get_data_bytes_as_str + + content = Content(type="uri", uri="data:text/plain,hello") + with raises(ContentError, match="base64 encoding"): + _get_data_bytes_as_str(content) + + +def test_get_data_bytes_as_str_valid(): + """Test _get_data_bytes_as_str extracts base64 data.""" + from agent_framework._types import _get_data_bytes_as_str + + data = base64.b64encode(b"hello").decode() + content = Content(type="uri", uri=f"data:text/plain;base64,{data}") + result = _get_data_bytes_as_str(content) + assert result == data + + +# endregion + + +# region Test _get_data_bytes + + +def test_get_data_bytes_decodes_base64(): + """Test _get_data_bytes decodes base64 data correctly.""" + from agent_framework._types import _get_data_bytes + + original = b"hello world" + data = base64.b64encode(original).decode() + content = Content(type="uri", uri=f"data:text/plain;base64,{data}") + + result = _get_data_bytes(content) + assert result == original + + +def test_get_data_bytes_invalid_base64(): + """Test _get_data_bytes raises for invalid base64.""" + from agent_framework._types import _get_data_bytes + + content = Content(type="uri", uri="data:text/plain;base64,!!invalid!!") + with raises(ContentError, match="Failed to decode"): + _get_data_bytes(content) + + +# endregion + + +# region Test _parse_content_list + + +def test_parse_content_list_with_content_objects(): + """Test _parse_content_list passes through Content objects.""" + from agent_framework._types import _parse_content_list + + content = Content(type="text", text="hello") + result = _parse_content_list([content]) + + assert len(result) == 1 + assert result[0] is content + + +def test_parse_content_list_with_dicts(): + """Test _parse_content_list converts dicts to Content.""" + from agent_framework._types import _parse_content_list + + result = _parse_content_list([{"type": "text", "text": "hello"}]) + + assert len(result) == 1 + assert result[0].type == "text" + assert result[0].text == "hello" + + +def test_parse_content_list_with_mixed_valid_invalid(): + """Test _parse_content_list handles a mix of valid and Content objects.""" + from agent_framework._types import _parse_content_list + + content = Content(type="text", text="hello") + # Pass a mix of Content object and dict + result = _parse_content_list([content, {"type": "text", "text": "world"}]) + + assert len(result) == 2 + assert result[0].text == "hello" + assert result[1].text == "world" + + +# endregion + + +# region Test _validate_uri + + +def test_validate_uri_known_schema(): + """Test _validate_uri accepts known URI schemas.""" + from agent_framework._types import _validate_uri + + result = _validate_uri("https://example.com/file.txt", "text/plain") + assert result.get("uri") == "https://example.com/file.txt" + + +def test_validate_uri_data_uri(): + """Test _validate_uri handles data URIs.""" + from agent_framework._types import _validate_uri + + data = base64.b64encode(b"test").decode() + uri = f"data:text/plain;base64,{data}" + result = _validate_uri(uri, None) + assert "uri" in result + + +# endregion From eb41d7b75a44428cdcdc8d802f00bdeeedf3e10d Mon Sep 17 00:00:00 2001 From: Giles Odigwe Date: Tue, 27 Jan 2026 15:27:08 -0800 Subject: [PATCH 2/4] Address PR comments: move imports to top, fix naming (schema->scheme) --- .../packages/core/tests/core/test_agents.py | 29 +------- python/packages/core/tests/core/test_types.py | 67 +++++-------------- 2 files changed, 18 insertions(+), 78 deletions(-) diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index 703aac8997..16f2181957 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -18,6 +18,7 @@ ChatClientProtocol, ChatMessage, ChatMessageStore, + ChatOptions, ChatResponse, Content, Context, @@ -26,8 +27,9 @@ Role, ai_function, ) +from agent_framework._agents import _merge_options, _sanitize_agent_name from agent_framework._mcp import MCPTool -from agent_framework.exceptions import AgentExecutionException +from agent_framework.exceptions import AgentExecutionException, AgentInitializationError def test_agent_thread_type(agent_thread: AgentThread) -> None: @@ -634,8 +636,6 @@ async def test_chat_agent_tool_choice_agent_level_used_when_run_level_not_specif chat_client_base: Any, ai_function_tool: Any ) -> None: """Verify that agent-level tool_choice is used when run() doesn't specify one.""" - from agent_framework import ChatOptions - captured_options: list[ChatOptions] = [] original_inner = chat_client_base._inner_get_response @@ -669,8 +669,6 @@ async def test_chat_agent_tool_choice_none_at_run_preserves_agent_level( chat_client_base: Any, ai_function_tool: Any ) -> None: """Verify that tool_choice=None at run() uses agent-level default.""" - from agent_framework import ChatOptions - captured_options: list[ChatOptions] = [] original_inner = chat_client_base._inner_get_response @@ -703,8 +701,6 @@ async def capturing_inner( def test_merge_options_basic(): """Test _merge_options merges two dicts with override precedence.""" - from agent_framework._agents import _merge_options - base = {"key1": "value1", "key2": "value2"} override = {"key2": "new_value2", "key3": "value3"} @@ -717,8 +713,6 @@ def test_merge_options_basic(): def test_merge_options_none_values_ignored(): """Test _merge_options ignores None values in override.""" - from agent_framework._agents import _merge_options - base = {"key1": "value1"} override = {"key1": None, "key2": "value2"} @@ -730,7 +724,6 @@ def test_merge_options_none_values_ignored(): def test_merge_options_tools_combined(): """Test _merge_options combines tool lists without duplicates.""" - from agent_framework._agents import _merge_options class MockTool: def __init__(self, name): @@ -754,8 +747,6 @@ def __init__(self, name): def test_merge_options_logit_bias_merged(): """Test _merge_options merges logit_bias dicts.""" - from agent_framework._agents import _merge_options - base = {"logit_bias": {"token1": 1.0}} override = {"logit_bias": {"token2": 2.0}} @@ -767,8 +758,6 @@ def test_merge_options_logit_bias_merged(): def test_merge_options_metadata_merged(): """Test _merge_options merges metadata dicts.""" - from agent_framework._agents import _merge_options - base = {"metadata": {"key1": "value1"}} override = {"metadata": {"key2": "value2"}} @@ -780,8 +769,6 @@ def test_merge_options_metadata_merged(): def test_merge_options_instructions_concatenated(): """Test _merge_options concatenates instructions.""" - from agent_framework._agents import _merge_options - base = {"instructions": "First instruction."} override = {"instructions": "Second instruction."} @@ -800,23 +787,17 @@ def test_merge_options_instructions_concatenated(): def test_sanitize_agent_name_none(): """Test _sanitize_agent_name returns None for None input.""" - from agent_framework._agents import _sanitize_agent_name - assert _sanitize_agent_name(None) is None def test_sanitize_agent_name_valid(): """Test _sanitize_agent_name returns valid names unchanged.""" - from agent_framework._agents import _sanitize_agent_name - assert _sanitize_agent_name("valid_name") == "valid_name" assert _sanitize_agent_name("ValidName123") == "ValidName123" def test_sanitize_agent_name_replaces_invalid_chars(): """Test _sanitize_agent_name replaces invalid characters.""" - from agent_framework._agents import _sanitize_agent_name - result = _sanitize_agent_name("Agent Name!") # Should replace spaces and special chars with underscores assert " " not in result @@ -849,10 +830,6 @@ async def test_agent_get_new_thread(chat_client_base, ai_function_tool): @pytest.mark.asyncio async def test_chat_agent_raises_with_both_conversation_id_and_store(): """Test ChatAgent raises error with both conversation_id and chat_message_store_factory.""" - from unittest.mock import MagicMock - - from agent_framework.exceptions import AgentInitializationError - mock_client = MagicMock() mock_store_factory = MagicMock() diff --git a/python/packages/core/tests/core/test_types.py b/python/packages/core/tests/core/test_types.py index b9f6c645ad..df3d9b55ab 100644 --- a/python/packages/core/tests/core/test_types.py +++ b/python/packages/core/tests/core/test_types.py @@ -2,11 +2,12 @@ import base64 from collections.abc import AsyncIterable +from dataclasses import dataclass from datetime import datetime, timezone -from typing import Any +from typing import Any, Literal import pytest -from pydantic import BaseModel +from pydantic import BaseModel, Field, ValidationError from pytest import fixture, mark, raises from agent_framework import ( @@ -29,6 +30,14 @@ merge_chat_options, prepare_function_call_results, ) +from agent_framework._types import ( + _get_data_bytes, + _get_data_bytes_as_str, + _parse_content_list, + _validate_uri, + add_usage_details, + validate_tool_mode, +) from agent_framework.exceptions import ContentError @@ -439,8 +448,6 @@ def test_usage_details(): def test_usage_details_addition(): - from agent_framework._types import add_usage_details - usage1 = UsageDetails( input_token_count=5, output_token_count=10, @@ -478,8 +485,6 @@ def test_usage_details_additional_counts(): def test_usage_details_add_with_none_and_type_errors(): - from agent_framework._types import add_usage_details - u = UsageDetails(input_token_count=1) # add_usage_details with None returns the non-None value v = add_usage_details(u, None) @@ -665,9 +670,6 @@ def test_chat_response_with_format_init(): def test_chat_response_value_raises_on_invalid_schema(): """Test that value property raises ValidationError with field constraint details.""" - from typing import Literal - - from pydantic import Field, ValidationError class StrictSchema(BaseModel): id: Literal[5] @@ -689,9 +691,6 @@ class StrictSchema(BaseModel): def test_chat_response_try_parse_value_returns_none_on_invalid(): """Test that try_parse_value returns None on validation failure with Field constraints.""" - from typing import Literal - - from pydantic import Field class StrictSchema(BaseModel): id: Literal[5] @@ -707,7 +706,6 @@ class StrictSchema(BaseModel): def test_chat_response_try_parse_value_returns_value_on_success(): """Test that try_parse_value returns parsed value when all constraints pass.""" - from pydantic import Field class MySchema(BaseModel): name: str = Field(min_length=3) @@ -724,9 +722,6 @@ class MySchema(BaseModel): def test_agent_response_value_raises_on_invalid_schema(): """Test that AgentResponse.value property raises ValidationError with field constraint details.""" - from typing import Literal - - from pydantic import Field, ValidationError class StrictSchema(BaseModel): id: Literal[5] @@ -748,9 +743,6 @@ class StrictSchema(BaseModel): def test_agent_response_try_parse_value_returns_none_on_invalid(): """Test that AgentResponse.try_parse_value returns None on Field constraint failure.""" - from typing import Literal - - from pydantic import Field class StrictSchema(BaseModel): id: Literal[5] @@ -766,7 +758,6 @@ class StrictSchema(BaseModel): def test_agent_response_try_parse_value_returns_value_on_success(): """Test that AgentResponse.try_parse_value returns parsed value when all constraints pass.""" - from pydantic import Field class MySchema(BaseModel): name: str = Field(min_length=3) @@ -990,8 +981,6 @@ def test_chat_options_init() -> None: def test_chat_options_tool_choice_validation(): """Test validate_tool_mode utility function.""" - from agent_framework._types import validate_tool_mode - # Valid string values assert validate_tool_mode("auto") == {"mode": "auto"} assert validate_tool_mode("required") == {"mode": "required"} @@ -1019,8 +1008,6 @@ def test_chat_options_tool_choice_validation(): def test_chat_options_merge(ai_function_tool, ai_tool) -> None: """Test merge_chat_options utility function.""" - from agent_framework import merge_chat_options - options1: ChatOptions = { "model_id": "gpt-4o", "tools": [ai_function_tool], @@ -1502,8 +1489,6 @@ def test_comprehensive_to_dict_exclude_options(): def test_usage_details_iadd_edge_cases(): """Test UsageDetails addition with edge cases for better coverage.""" - from agent_framework._types import add_usage_details - # Test with None values u1 = UsageDetails(input_token_count=None, output_token_count=5, custom1=10) u2 = UsageDetails(input_token_count=3, output_token_count=None, custom2=20) @@ -2235,7 +2220,6 @@ def test_prepare_function_call_results_nested_pydantic_model(): def test_prepare_function_call_results_text_content_single(): """Test that objects with text attribute (like MCP TextContent) are properly handled.""" - from dataclasses import dataclass @dataclass class MockTextContent: @@ -2251,7 +2235,6 @@ class MockTextContent: def test_prepare_function_call_results_text_content_multiple(): """Test that multiple TextContent-like objects are serialized correctly.""" - from dataclasses import dataclass @dataclass class MockTextContent: @@ -2438,16 +2421,12 @@ def test_content_parse_arguments_dict_passthrough(): def test_get_data_bytes_as_str_non_data_uri(): """Test _get_data_bytes_as_str returns None for non-data URIs.""" - from agent_framework._types import _get_data_bytes_as_str - content = Content(type="uri", uri="https://example.com/image.png") assert _get_data_bytes_as_str(content) is None def test_get_data_bytes_as_str_no_base64(): """Test _get_data_bytes_as_str raises for non-base64 data URI.""" - from agent_framework._types import _get_data_bytes_as_str - content = Content(type="uri", uri="data:text/plain,hello") with raises(ContentError, match="base64 encoding"): _get_data_bytes_as_str(content) @@ -2455,8 +2434,6 @@ def test_get_data_bytes_as_str_no_base64(): def test_get_data_bytes_as_str_valid(): """Test _get_data_bytes_as_str extracts base64 data.""" - from agent_framework._types import _get_data_bytes_as_str - data = base64.b64encode(b"hello").decode() content = Content(type="uri", uri=f"data:text/plain;base64,{data}") result = _get_data_bytes_as_str(content) @@ -2471,8 +2448,6 @@ def test_get_data_bytes_as_str_valid(): def test_get_data_bytes_decodes_base64(): """Test _get_data_bytes decodes base64 data correctly.""" - from agent_framework._types import _get_data_bytes - original = b"hello world" data = base64.b64encode(original).decode() content = Content(type="uri", uri=f"data:text/plain;base64,{data}") @@ -2483,8 +2458,6 @@ def test_get_data_bytes_decodes_base64(): def test_get_data_bytes_invalid_base64(): """Test _get_data_bytes raises for invalid base64.""" - from agent_framework._types import _get_data_bytes - content = Content(type="uri", uri="data:text/plain;base64,!!invalid!!") with raises(ContentError, match="Failed to decode"): _get_data_bytes(content) @@ -2498,8 +2471,6 @@ def test_get_data_bytes_invalid_base64(): def test_parse_content_list_with_content_objects(): """Test _parse_content_list passes through Content objects.""" - from agent_framework._types import _parse_content_list - content = Content(type="text", text="hello") result = _parse_content_list([content]) @@ -2509,8 +2480,6 @@ def test_parse_content_list_with_content_objects(): def test_parse_content_list_with_dicts(): """Test _parse_content_list converts dicts to Content.""" - from agent_framework._types import _parse_content_list - result = _parse_content_list([{"type": "text", "text": "hello"}]) assert len(result) == 1 @@ -2518,10 +2487,8 @@ def test_parse_content_list_with_dicts(): assert result[0].text == "hello" -def test_parse_content_list_with_mixed_valid_invalid(): - """Test _parse_content_list handles a mix of valid and Content objects.""" - from agent_framework._types import _parse_content_list - +def test_parse_content_list_with_mixed_content_and_dict(): + """Test _parse_content_list handles a mix of Content objects and dicts.""" content = Content(type="text", text="hello") # Pass a mix of Content object and dict result = _parse_content_list([content, {"type": "text", "text": "world"}]) @@ -2537,18 +2504,14 @@ def test_parse_content_list_with_mixed_valid_invalid(): # region Test _validate_uri -def test_validate_uri_known_schema(): - """Test _validate_uri accepts known URI schemas.""" - from agent_framework._types import _validate_uri - +def test_validate_uri_known_scheme(): + """Test _validate_uri accepts known URI schemes.""" result = _validate_uri("https://example.com/file.txt", "text/plain") assert result.get("uri") == "https://example.com/file.txt" def test_validate_uri_data_uri(): """Test _validate_uri handles data URIs.""" - from agent_framework._types import _validate_uri - data = base64.b64encode(b"test").decode() uri = f"data:text/plain;base64,{data}" result = _validate_uri(uri, None) From 53a64defff7ce4d2e5520c348294bea3ddeb7fca Mon Sep 17 00:00:00 2001 From: Giles Odigwe Date: Wed, 28 Jan 2026 12:03:47 -0800 Subject: [PATCH 3/4] fix failing test --- python/packages/core/tests/core/test_agents.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index d5587058bd..cee6f91934 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -807,9 +807,9 @@ def test_sanitize_agent_name_replaces_invalid_chars(): @pytest.mark.asyncio -async def test_agent_get_new_thread(chat_client_base, ai_function_tool): +async def test_agent_get_new_thread(chat_client_base, tool_tool): """Test that get_new_thread returns a new AgentThread.""" - agent = ChatAgent(chat_client=chat_client_base, tools=[ai_function_tool]) + agent = ChatAgent(chat_client=chat_client_base, tools=[tool_tool]) thread = agent.get_new_thread() From 4745a70bd383ea8f24252a302e55fd326243024c Mon Sep 17 00:00:00 2001 From: Giles Odigwe Date: Thu, 29 Jan 2026 11:37:10 -0800 Subject: [PATCH 4/4] improved agents coverage --- .../packages/core/tests/core/test_agents.py | 143 +++++++++++++++++- 1 file changed, 141 insertions(+), 2 deletions(-) diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index cee6f91934..1f4d1cadce 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -25,6 +25,7 @@ ContextProvider, HostedCodeInterpreterTool, Role, + ToolProtocol, tool, ) from agent_framework._agents import _merge_options, _sanitize_agent_name @@ -803,11 +804,11 @@ def test_sanitize_agent_name_replaces_invalid_chars(): # endregion -# region Test AgentProtocol.get_new_thread +# region Test AgentProtocol.get_new_thread and deserialize_thread @pytest.mark.asyncio -async def test_agent_get_new_thread(chat_client_base, tool_tool): +async def test_agent_get_new_thread(chat_client_base: ChatClientProtocol, tool_tool: ToolProtocol): """Test that get_new_thread returns a new AgentThread.""" agent = ChatAgent(chat_client=chat_client_base, tools=[tool_tool]) @@ -817,6 +818,61 @@ async def test_agent_get_new_thread(chat_client_base, tool_tool): assert isinstance(thread, AgentThread) +@pytest.mark.asyncio +async def test_agent_get_new_thread_with_context_provider( + chat_client_base: ChatClientProtocol, tool_tool: ToolProtocol +): + """Test that get_new_thread passes context_provider to the thread.""" + + class TestContextProvider(ContextProvider): + async def invoking(self, messages, **kwargs): + return Context() + + provider = TestContextProvider() + agent = ChatAgent(chat_client=chat_client_base, tools=[tool_tool], context_provider=provider) + + thread = agent.get_new_thread() + + assert thread is not None + assert thread.context_provider is provider + + +@pytest.mark.asyncio +async def test_agent_get_new_thread_with_service_thread_id( + chat_client_base: ChatClientProtocol, tool_tool: ToolProtocol +): + """Test that get_new_thread passes kwargs like service_thread_id to the thread.""" + agent = ChatAgent(chat_client=chat_client_base, tools=[tool_tool]) + + thread = agent.get_new_thread(service_thread_id="test-thread-123") + + assert thread is not None + assert thread.service_thread_id == "test-thread-123" + + +@pytest.mark.asyncio +async def test_agent_deserialize_thread(chat_client_base: ChatClientProtocol, tool_tool: ToolProtocol): + """Test deserialize_thread restores a thread from serialized state.""" + agent = ChatAgent(chat_client=chat_client_base, tools=[tool_tool]) + + # Create serialized thread state with messages + serialized_state = { + "service_thread_id": None, + "chat_message_store_state": { + "messages": [{"role": "user", "text": "Hello"}], + }, + } + + thread = await agent.deserialize_thread(serialized_state) + + assert thread is not None + assert isinstance(thread, AgentThread) + assert thread.message_store is not None + messages = await thread.message_store.list_messages() + assert len(messages) == 1 + assert messages[0].text == "Hello" + + # endregion @@ -837,4 +893,87 @@ async def test_chat_agent_raises_with_both_conversation_id_and_store(): ) +def test_chat_agent_calls_update_agent_name_on_client(): + """Test that ChatAgent calls _update_agent_name_and_description on client if available.""" + mock_client = MagicMock() + mock_client._update_agent_name_and_description = MagicMock() + + ChatAgent( + chat_client=mock_client, + name="TestAgent", + description="Test description", + ) + + mock_client._update_agent_name_and_description.assert_called_once_with("TestAgent", "Test description") + + +@pytest.mark.asyncio +async def test_chat_agent_context_provider_adds_tools_when_agent_has_none(chat_client_base: ChatClientProtocol): + """Test that context provider tools are used when agent has no default tools.""" + + @tool + def context_tool(text: str) -> str: + """A tool provided by context.""" + return text + + class ToolContextProvider(ContextProvider): + async def invoking(self, messages, **kwargs): + return Context(tools=[context_tool]) + + provider = ToolContextProvider() + agent = ChatAgent(chat_client=chat_client_base, context_provider=provider) + + # Agent starts with empty tools list + assert agent.default_options.get("tools") == [] + + # Run the agent and verify context tools are added + _, options, _ = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage] + thread=None, input_messages=[ChatMessage(role=Role.USER, text="Hello")] + ) + + # The context tools should now be in the options + assert options.get("tools") is not None + assert len(options["tools"]) == 1 + + +@pytest.mark.asyncio +async def test_chat_agent_context_provider_adds_instructions_when_agent_has_none(chat_client_base: ChatClientProtocol): + """Test that context provider instructions are used when agent has no default instructions.""" + + class InstructionContextProvider(ContextProvider): + async def invoking(self, messages, **kwargs): + return Context(instructions="Context-provided instructions") + + provider = InstructionContextProvider() + agent = ChatAgent(chat_client=chat_client_base, context_provider=provider) + + # Verify agent has no default instructions + assert agent.default_options.get("instructions") is None + + # Run the agent and verify context instructions are available + _, options, _ = await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage] + thread=None, input_messages=[ChatMessage(role=Role.USER, text="Hello")] + ) + + # The context instructions should now be in the options + assert options.get("instructions") == "Context-provided instructions" + + +@pytest.mark.asyncio +async def test_chat_agent_raises_on_conversation_id_mismatch(chat_client_base: ChatClientProtocol): + """Test that ChatAgent raises when thread and agent have different conversation IDs.""" + agent = ChatAgent( + chat_client=chat_client_base, + default_options={"conversation_id": "agent-conversation-id"}, + ) + + # Create a thread with a different service_thread_id + thread = AgentThread(service_thread_id="different-thread-id") + + with pytest.raises(AgentExecutionException, match="conversation_id set on the agent is different"): + await agent._prepare_thread_and_messages( # type: ignore[reportPrivateUsage] + thread=thread, input_messages=[ChatMessage(role=Role.USER, text="Hello")] + ) + + # endregion