From 12db37be8a8dbd1a0545540f2422dc9ce5c397df Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Fri, 16 Jan 2026 15:44:42 +0100 Subject: [PATCH 1/7] ported Content to a new model --- .../a2a/agent_framework_a2a/_agent.py | 19 +- python/packages/a2a/tests/test_a2a_agent.py | 32 +- .../ag-ui/agent_framework_ag_ui/_client.py | 43 +- .../_event_converters.py | 15 +- .../ag-ui/agent_framework_ag_ui/_events.py | 41 +- .../_message_adapters.py | 83 +- .../_orchestration/_helpers.py | 27 +- .../_orchestration/_state_manager.py | 4 +- .../agent_framework_ag_ui/_orchestrators.py | 35 +- .../ag-ui/agent_framework_ag_ui/_types.py | 3 +- .../agents/task_steps_agent.py | 7 +- .../packages/ag-ui/tests/test_ag_ui_client.py | 28 +- .../tests/test_agent_wrapper_comprehensive.py | 52 +- .../tests/test_backend_tool_rendering.py | 16 +- .../ag-ui/tests/test_document_writer_flow.py | 20 +- python/packages/ag-ui/tests/test_endpoint.py | 4 +- .../ag-ui/tests/test_events_comprehensive.py | 121 +- .../ag-ui/tests/test_human_in_the_loop.py | 22 +- .../ag-ui/tests/test_message_adapters.py | 66 +- .../ag-ui/tests/test_message_hygiene.py | 10 +- .../ag-ui/tests/test_orchestrators.py | 10 +- .../tests/test_orchestrators_coverage.py | 88 +- .../packages/ag-ui/tests/test_shared_state.py | 4 +- .../ag-ui/tests/test_state_manager.py | 4 +- .../ag-ui/tests/test_structured_output.py | 14 +- .../packages/ag-ui/tests/utils_test_ag_ui.py | 4 +- .../agent_framework_anthropic/_chat_client.py | 173 +- .../anthropic/tests/test_anthropic_client.py | 43 +- .../agent_framework_azure_ai/_chat_client.py | 183 +- .../agent_framework_azure_ai/_client.py | 3 +- .../tests/test_azure_ai_agent_client.py | 183 +- .../azure-ai/tests/test_azure_ai_client.py | 20 +- .../_durable_agent_state.py | 110 +- .../_entities.py | 4 +- .../packages/azurefunctions/tests/test_app.py | 4 +- .../azurefunctions/tests/test_entities.py | 10 +- .../agent_framework_bedrock/_chat_client.py | 88 +- .../bedrock/tests/test_bedrock_client.py | 
10 +- .../bedrock/tests/test_bedrock_settings.py | 20 +- .../agent_framework_chatkit/_converter.py | 30 +- .../agent_framework_chatkit/_streaming.py | 4 +- .../packages/chatkit/tests/test_converter.py | 22 +- .../packages/chatkit/tests/test_streaming.py | 15 +- .../agent_framework_copilotstudio/_agent.py | 4 +- .../copilotstudio/tests/test_copilot_agent.py | 19 +- .../packages/core/agent_framework/_agents.py | 4 +- python/packages/core/agent_framework/_mcp.py | 102 +- .../packages/core/agent_framework/_tools.py | 196 +- .../packages/core/agent_framework/_types.py | 2982 ++++++----------- .../core/agent_framework/_workflows/_agent.py | 52 +- .../_workflows/_agent_executor.py | 16 +- .../_workflows/_orchestrator_helpers.py | 4 +- .../agent_framework/azure/_chat_client.py | 15 +- .../core/agent_framework/observability.py | 16 +- .../openai/_assistants_client.py | 49 +- .../agent_framework/openai/_chat_client.py | 63 +- .../openai/_responses_client.py | 207 +- .../azure/test_azure_assistants_client.py | 5 +- .../tests/azure/test_azure_chat_client.py | 15 +- .../azure/test_azure_responses_client.py | 6 +- python/packages/core/tests/core/conftest.py | 10 +- .../packages/core/tests/core/test_agents.py | 17 +- .../core/test_as_tool_kwargs_propagation.py | 8 +- .../core/test_function_invocation_logic.py | 444 +-- .../test_kwargs_propagation_to_ai_function.py | 21 +- python/packages/core/tests/core/test_mcp.py | 81 +- .../core/tests/core/test_middleware.py | 50 +- .../core/test_middleware_context_result.py | 14 +- .../tests/core/test_middleware_with_agent.py | 66 +- .../tests/core/test_middleware_with_chat.py | 6 +- python/packages/core/tests/core/test_tools.py | 205 +- python/packages/core/tests/core/test_types.py | 869 +++-- .../openai/test_openai_assistants_client.py | 37 +- .../tests/openai/test_openai_chat_client.py | 37 +- .../openai/test_openai_responses_client.py | 131 +- .../core/tests/test_observability_datetime.py | 4 +- .../tests/workflow/test_agent_executor.py | 
4 +- .../test_agent_executor_tool_calls.py | 55 +- .../tests/workflow/test_full_conversation.py | 6 +- .../core/tests/workflow/test_group_chat.py | 8 +- .../core/tests/workflow/test_handoff.py | 13 +- .../core/tests/workflow/test_magentic.py | 8 +- .../core/tests/workflow/test_sequential.py | 4 +- .../core/tests/workflow/test_workflow.py | 4 +- .../tests/workflow/test_workflow_agent.py | 119 +- .../tests/workflow/test_workflow_kwargs.py | 4 +- .../agent_framework_declarative/_loader.py | 6 +- .../_workflows/_executors_agents.py | 19 +- .../agent_framework_devui/_conversations.py | 2 +- .../devui/agent_framework_devui/_executor.py | 32 +- .../devui/agent_framework_devui/_mapper.py | 54 +- .../frontend/src/types/agent-framework.ts | 14 +- .../devui/frontend/src/types/index.ts | 2 +- .../devui/tests/test_cleanup_hooks.py | 8 +- python/packages/devui/tests/test_discovery.py | 6 +- python/packages/devui/tests/test_execution.py | 6 +- python/packages/devui/tests/test_helpers.py | 32 +- python/packages/devui/tests/test_mapper.py | 39 +- .../devui/tests/test_multimodal_workflow.py | 12 +- .../agent_framework_lab_lightning/__init__.py | 5 +- .../lab/lightning/tests/test_lightning.py | 2 +- .../_message_utils.py | 4 +- .../lab/tau2/tests/test_message_utils.py | 64 +- .../lab/tau2/tests/test_sliding_window.py | 50 +- .../lab/tau2/tests/test_tau2_utils.py | 36 +- .../mem0/tests/test_mem0_context_provider.py | 7 +- .../agent_framework_ollama/_chat_client.py | 33 +- .../ollama/tests/test_ollama_chat_client.py | 29 +- .../tests/test_redis_chat_message_store.py | 4 +- .../agents/anthropic/anthropic_advanced.py | 2 +- .../agents/anthropic/anthropic_foundry.py | 2 +- .../agents/anthropic/anthropic_skills.py | 2 +- ...i_with_code_interpreter_file_generation.py | 13 +- .../azure_ai_with_azure_ai_search.py | 8 +- .../azure_ai_with_bing_grounding_citations.py | 10 +- .../agents/ollama/ollama_chat_multimodal.py | 6 +- .../openai_responses_client_reasoning.py | 2 +- 
.../multimodal_input/README.md | 4 +- .../multimodal_input/azure_chat_multimodal.py | 7 +- .../azure_responses_multimodal.py | 11 +- .../openai_chat_multimodal.py | 15 +- .../workflow_as_agent_reflection_pattern.py | 4 +- python/uv.lock | 16 +- 123 files changed, 3722 insertions(+), 4520 deletions(-) diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index 5f8c27b182..84cc8fb20c 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -31,11 +31,8 @@ AgentThread, BaseAgent, ChatMessage, - Contents, - DataContent, + Content, Role, - TextContent, - UriContent, normalize_messages, prepend_agent_framework_to_user_agent, ) @@ -362,19 +359,19 @@ def _prepare_message_for_a2a(self, message: ChatMessage) -> A2AMessage: metadata=cast(dict[str, Any], message.additional_properties), ) - def _parse_contents_from_a2a(self, parts: Sequence[A2APart]) -> list[Contents]: - """Parse A2A Parts into Agent Framework Contents. + def _parse_contents_from_a2a(self, parts: Sequence[A2APart]) -> list[Content]: + """Parse A2A Parts into Agent Framework Content. Transforms A2A protocol Parts into framework-native Content objects, handling text, file (URI/bytes), and data parts with metadata preservation. 
""" - contents: list[Contents] = [] + contents: list[Content] = [] for part in parts: inner_part = part.root match inner_part.kind: case "text": contents.append( - TextContent( + Content.from_text( text=inner_part.text, additional_properties=inner_part.metadata, raw_representation=inner_part, @@ -383,7 +380,7 @@ def _parse_contents_from_a2a(self, parts: Sequence[A2APart]) -> list[Contents]: case "file": if isinstance(inner_part.file, FileWithUri): contents.append( - UriContent( + Content.from_uri( uri=inner_part.file.uri, media_type=inner_part.file.mime_type or "", additional_properties=inner_part.metadata, @@ -392,7 +389,7 @@ def _parse_contents_from_a2a(self, parts: Sequence[A2APart]) -> list[Contents]: ) elif isinstance(inner_part.file, FileWithBytes): contents.append( - DataContent( + Content.from_data( data=base64.b64decode(inner_part.file.bytes), media_type=inner_part.file.mime_type or "", additional_properties=inner_part.metadata, @@ -401,7 +398,7 @@ def _parse_contents_from_a2a(self, parts: Sequence[A2APart]) -> list[Contents]: ) case "data": contents.append( - TextContent( + Content.from_text( text=json.dumps(inner_part.data), additional_properties=inner_part.metadata, raw_representation=inner_part, diff --git a/python/packages/a2a/tests/test_a2a_agent.py b/python/packages/a2a/tests/test_a2a_agent.py index 5d77345b20..eca97b2ac6 100644 --- a/python/packages/a2a/tests/test_a2a_agent.py +++ b/python/packages/a2a/tests/test_a2a_agent.py @@ -24,12 +24,8 @@ AgentResponse, AgentResponseUpdate, ChatMessage, - DataContent, - ErrorContent, - HostedFileContent, + Content, Role, - TextContent, - UriContent, ) from agent_framework.a2a import A2AAgent from pytest import fixture, raises @@ -289,8 +285,8 @@ def test_parse_contents_from_a2a_conversion(a2a_agent: A2AAgent) -> None: # Verify conversion assert len(contents) == 2 - assert isinstance(contents[0], TextContent) - assert isinstance(contents[1], TextContent) + assert contents[0].type == "text" + assert 
contents[1].type == "text" assert contents[0].text == "First part" assert contents[1].text == "Second part" @@ -299,7 +295,7 @@ def test_prepare_message_for_a2a_with_error_content(a2a_agent: A2AAgent) -> None """Test _prepare_message_for_a2a with ErrorContent.""" # Create ChatMessage with ErrorContent - error_content = ErrorContent(message="Test error message") + error_content = Content.from_error(message="Test error message") message = ChatMessage(role=Role.USER, contents=[error_content]) # Convert to A2A message @@ -314,7 +310,7 @@ def test_prepare_message_for_a2a_with_uri_content(a2a_agent: A2AAgent) -> None: """Test _prepare_message_for_a2a with UriContent.""" # Create ChatMessage with UriContent - uri_content = UriContent(uri="http://example.com/file.pdf", media_type="application/pdf") + uri_content = Content.from_uri(uri="http://example.com/file.pdf", media_type="application/pdf") message = ChatMessage(role=Role.USER, contents=[uri_content]) # Convert to A2A message @@ -330,7 +326,7 @@ def test_prepare_message_for_a2a_with_data_content(a2a_agent: A2AAgent) -> None: """Test _prepare_message_for_a2a with DataContent.""" # Create ChatMessage with DataContent (base64 data URI) - data_content = DataContent(uri="data:text/plain;base64,SGVsbG8gV29ybGQ=", media_type="text/plain") + data_content = Content.from_uri(uri="data:text/plain;base64,SGVsbG8gV29ybGQ=", media_type="text/plain") message = ChatMessage(role=Role.USER, contents=[data_content]) # Convert to A2A message @@ -368,7 +364,7 @@ async def test_run_stream_with_message_response(a2a_agent: A2AAgent, mock_a2a_cl assert len(updates[0].contents) == 1 content = updates[0].contents[0] - assert isinstance(content, TextContent) + assert content.type == "text" assert content.text == "Streaming response from agent!" 
assert updates[0].response_id == "msg-stream-123" @@ -414,10 +410,10 @@ def test_prepare_message_for_a2a_with_multiple_contents() -> None: message = ChatMessage( role=Role.USER, contents=[ - TextContent(text="Here's the analysis:"), - DataContent(data=b"binary data", media_type="application/octet-stream"), - UriContent(uri="https://example.com/image.png", media_type="image/png"), - TextContent(text='{"structured": "data"}'), + Content.from_text(text="Here's the analysis:"), + Content.from_data(data=b"binary data", media_type="application/octet-stream"), + Content.from_uri(uri="https://example.com/image.png", media_type="image/png"), + Content.from_text(text='{"structured": "data"}'), ], ) @@ -445,7 +441,7 @@ def test_parse_contents_from_a2a_with_data_part() -> None: assert len(contents) == 1 - assert isinstance(contents[0], TextContent) + assert contents[0].type == "text" assert contents[0].text == '{"key": "value", "number": 42}' assert contents[0].additional_properties == {"source": "test"} @@ -470,7 +466,7 @@ def test_prepare_message_for_a2a_with_hosted_file() -> None: # Create message with hosted file content message = ChatMessage( role=Role.USER, - contents=[HostedFileContent(file_id="hosted://storage/document.pdf")], + contents=[Content.from_hosted_file(file_id="hosted://storage/document.pdf")], ) result = agent._prepare_message_for_a2a(message) # noqa: SLF001 @@ -507,7 +503,7 @@ def test_parse_contents_from_a2a_with_hosted_file_uri() -> None: assert len(contents) == 1 - assert isinstance(contents[0], UriContent) + assert contents[0].type == "uri" assert contents[0].uri == "hosted://storage/document.pdf" assert contents[0].media_type == "" # Converted None to empty string diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_client.py b/python/packages/ag-ui/agent_framework_ag_ui/_client.py index e31036803c..3da099b3b8 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_client.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_client.py @@ -17,12 
+17,10 @@ ChatMessage, ChatResponse, ChatResponseUpdate, - DataContent, - FunctionCallContent, + Content, + use_chat_middleware, + use_function_invocation, ) -from agent_framework._middleware import use_chat_middleware -from agent_framework._tools import use_function_invocation -from agent_framework._types import BaseContent, Contents from agent_framework.observability import use_instrumentation from ._event_converters import AGUIEventConverter @@ -53,26 +51,11 @@ logger: logging.Logger = logging.getLogger(__name__) -class ServerFunctionCallContent(BaseContent): - """Wrapper for server function calls to prevent client re-execution. - - All function calls from the remote server are server-side executions. - This wrapper prevents @use_function_invocation from trying to execute them again. - """ - - function_call_content: FunctionCallContent - - def __init__(self, function_call_content: FunctionCallContent) -> None: - """Initialize with the function call content.""" - super().__init__(type="server_function_call") - self.function_call_content = function_call_content - - -def _unwrap_server_function_call_contents(contents: MutableSequence[Contents | dict[str, Any]]) -> None: - """Replace ServerFunctionCallContent instances with their underlying call content.""" +def _unwrap_server_function_call_contents(contents: MutableSequence[Content | dict[str, Any]]) -> None: + """Replace server_function_call instances with their underlying call content.""" for idx, content in enumerate(contents): - if isinstance(content, ServerFunctionCallContent): - contents[idx] = content.function_call_content # type: ignore[assignment] + if content.type == "server_function_call": + contents[idx] = content.function_call # type: ignore[assignment] TBaseChatClient = TypeVar("TBaseChatClient", bound=type[BaseChatClient[Any]]) @@ -93,7 +76,7 @@ def _apply_server_function_call_unwrap(chat_client: TBaseChatClient) -> TBaseCha @wraps(original_get_streaming_response) async def streaming_wrapper(self, 
*args: Any, **kwargs: Any) -> AsyncIterable[ChatResponseUpdate]: async for update in original_get_streaming_response(self, *args, **kwargs): - _unwrap_server_function_call_contents(cast(MutableSequence[Contents | dict[str, Any]], update.contents)) + _unwrap_server_function_call_contents(cast(MutableSequence[Content | dict[str, Any]], update.contents)) yield update chat_client.get_streaming_response = streaming_wrapper # type: ignore[assignment] @@ -105,9 +88,7 @@ async def response_wrapper(self, *args: Any, **kwargs: Any) -> ChatResponse: response = await original_get_response(self, *args, **kwargs) if response.messages: for message in response.messages: - _unwrap_server_function_call_contents( - cast(MutableSequence[Contents | dict[str, Any]], message.contents) - ) + _unwrap_server_function_call_contents(cast(MutableSequence[Content | dict[str, Any]], message.contents)) return response chat_client.get_response = response_wrapper # type: ignore[assignment] @@ -289,7 +270,7 @@ def _extract_state_from_messages( last_message = messages[-1] for content in last_message.contents: - if isinstance(content, DataContent) and content.media_type == "application/json": + if isinstance(content, Content) and content.type == "data" and content.media_type == "application/json": try: uri = content.uri if uri.startswith("data:application/json;base64,"): @@ -433,7 +414,7 @@ async def _inner_get_streaming_response( ) # Distinguish client vs server tools for i, content in enumerate(update.contents): - if isinstance(content, FunctionCallContent): + if content.type == "function_call": logger.debug( f"[AGUIChatClient] Function call: {content.name}, in client_tool_set: {content.name in client_tool_set}" ) @@ -446,6 +427,6 @@ async def _inner_get_streaming_response( # Server tool - wrap so @use_function_invocation ignores it logger.debug(f"[AGUIChatClient] Wrapping server tool: {content.name}") self._register_server_tool_placeholder(content.name) - update.contents[i] = 
ServerFunctionCallContent(content) # type: ignore + update.contents[i] = Content(type="server_function_call", function_call=content) # type: ignore yield update diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_event_converters.py b/python/packages/ag-ui/agent_framework_ag_ui/_event_converters.py index 0f485739c9..bd2d989f2a 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_event_converters.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_event_converters.py @@ -6,12 +6,9 @@ from agent_framework import ( ChatResponseUpdate, - ErrorContent, + Content, FinishReason, - FunctionCallContent, - FunctionResultContent, Role, - TextContent, ) @@ -117,7 +114,7 @@ def _handle_text_message_content(self, event: dict[str, Any]) -> ChatResponseUpd return ChatResponseUpdate( role=Role.ASSISTANT, message_id=self.current_message_id, - contents=[TextContent(text=delta)], + contents=[Content.from_text(text=delta)], ) def _handle_text_message_end(self, event: dict[str, Any]) -> ChatResponseUpdate | None: @@ -133,7 +130,7 @@ def _handle_tool_call_start(self, event: dict[str, Any]) -> ChatResponseUpdate: return ChatResponseUpdate( role=Role.ASSISTANT, contents=[ - FunctionCallContent( + Content.from_function_call( call_id=self.current_tool_call_id or "", name=self.current_tool_name or "", arguments="", @@ -149,7 +146,7 @@ def _handle_tool_call_args(self, event: dict[str, Any]) -> ChatResponseUpdate: return ChatResponseUpdate( role=Role.ASSISTANT, contents=[ - FunctionCallContent( + Content.from_function_call( call_id=self.current_tool_call_id or "", name=self.current_tool_name or "", arguments=delta, @@ -170,7 +167,7 @@ def _handle_tool_call_result(self, event: dict[str, Any]) -> ChatResponseUpdate: return ChatResponseUpdate( role=Role.TOOL, contents=[ - FunctionResultContent( + Content.from_function_result( call_id=tool_call_id, result=result, ) @@ -197,7 +194,7 @@ def _handle_run_error(self, event: dict[str, Any]) -> ChatResponseUpdate: role=Role.ASSISTANT, 
finish_reason=FinishReason.CONTENT_FILTER, contents=[ - ErrorContent( + Content.from_error( message=error_message, error_code="RUN_ERROR", ) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_events.py b/python/packages/ag-ui/agent_framework_ag_ui/_events.py index ddf3ebba01..bff1780170 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_events.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_events.py @@ -25,10 +25,7 @@ ) from agent_framework import ( AgentResponseUpdate, - FunctionApprovalRequestContent, - FunctionCallContent, - FunctionResultContent, - TextContent, + Content, prepare_function_call_results, ) @@ -96,18 +93,20 @@ async def from_agent_run_update(self, update: AgentResponseUpdate) -> list[BaseE logger.info(f"Processing AgentRunUpdate with {len(update.contents)} content items") for idx, content in enumerate(update.contents): logger.info(f" Content {idx}: type={type(content).__name__}") - if isinstance(content, TextContent): - events.extend(self._handle_text_content(content)) - elif isinstance(content, FunctionCallContent): - events.extend(self._handle_function_call_content(content)) - elif isinstance(content, FunctionResultContent): - events.extend(self._handle_function_result_content(content)) - elif isinstance(content, FunctionApprovalRequestContent): - events.extend(self._handle_function_approval_request_content(content)) - + match content.type: + case "text": + events.extend(self._handle_text_content(content)) + case "function_call": + events.extend(self._handle_function_call_content(content)) + case "function_result": + events.extend(self._handle_function_result_content(content)) + case "function_approval_request": + events.extend(self._handle_function_approval_request_content(content)) + case _: + logger.warning(f" Unsupported content type: {content.type}, skipping.") return events - def _handle_text_content(self, content: TextContent) -> list[BaseEvent]: + def _handle_text_content(self, content: Content) -> list[BaseEvent]: 
events: list[BaseEvent] = [] logger.info(f" TextContent found: length={len(content.text)}") logger.info( @@ -150,14 +149,14 @@ def _handle_text_content(self, content: TextContent) -> list[BaseEvent]: events.append(event) return events - def _handle_function_call_content(self, content: FunctionCallContent) -> list[BaseEvent]: + def _handle_function_call_content(self, content: Content) -> list[BaseEvent]: events: list[BaseEvent] = [] if content.name: logger.debug(f"Tool call: {content.name} (call_id: {content.call_id})") if not content.name and not content.call_id and not self.current_tool_call_name: args_length = len(str(content.arguments)) if content.arguments else 0 - logger.warning(f"FunctionCallContent missing name and call_id. args_length={args_length}") + logger.warning(f"Content missing name and call_id. args_length={args_length}") tool_call_id = self._coalesce_tool_call_id(content) # Only emit ToolCallStartEvent once per tool call (when it's a new tool call) @@ -190,7 +189,7 @@ def _handle_function_call_content(self, content: FunctionCallContent) -> list[Ba return events - def _coalesce_tool_call_id(self, content: FunctionCallContent) -> str: + def _coalesce_tool_call_id(self, content: Content) -> str: if content.call_id: return content.call_id if self.current_tool_call_id: @@ -286,7 +285,7 @@ def _emit_predictive_state_deltas(self, argument_chunk: str) -> list[BaseEvent]: self.pending_state_updates[state_key] = state_value return events - def _handle_function_result_content(self, content: FunctionResultContent) -> list[BaseEvent]: + def _handle_function_result_content(self, content: Content) -> list[BaseEvent]: events: list[BaseEvent] = [] if content.call_id: end_event = ToolCallEndEvent( @@ -367,7 +366,7 @@ def _emit_state_snapshot_and_confirmation(self) -> list[BaseEvent]: self.current_tool_call_name = None return events - def _emit_confirm_changes_tool_call(self, function_call: FunctionCallContent | None = None) -> list[BaseEvent]: + def 
_emit_confirm_changes_tool_call(self, function_call: Content | None = None) -> list[BaseEvent]: """Emit a confirm_changes tool call for Dojo UI compatibility. Args: @@ -419,7 +418,7 @@ def _emit_confirm_changes_tool_call(self, function_call: FunctionCallContent | N logger.info("Set flag to stop run after confirm_changes") return events - def _emit_function_approval_tool_call(self, function_call: FunctionCallContent) -> list[BaseEvent]: + def _emit_function_approval_tool_call(self, function_call: Content) -> list[BaseEvent]: """Emit a tool call that can drive UI approval for function requests.""" tool_call_name = "confirm_changes" if self.approval_tool_name and self.approval_tool_name != function_call.name: @@ -462,7 +461,7 @@ def _emit_function_approval_tool_call(self, function_call: FunctionCallContent) logger.info("Set flag to stop run after confirm_changes") return events - def _handle_function_approval_request_content(self, content: FunctionApprovalRequestContent) -> list[BaseEvent]: + def _handle_function_approval_request_content(self, content: Content) -> list[BaseEvent]: events: list[BaseEvent] = [] logger.info("=== FUNCTION APPROVAL REQUEST ===") logger.info(f" Function: {content.function_call.name}") diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py index 1ff858e9f5..b46a7cd288 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py @@ -8,11 +8,8 @@ from agent_framework import ( ChatMessage, - FunctionApprovalResponseContent, - FunctionCallContent, - FunctionResultContent, + Content, Role, - TextContent, prepare_function_call_results, ) @@ -40,11 +37,11 @@ def _sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: tool_ids = { str(content.call_id) for content in msg.contents or [] - if isinstance(content, FunctionCallContent) and content.call_id + if 
content.type == "function_call" and content.call_id } confirm_changes_call = None for content in msg.contents or []: - if isinstance(content, FunctionCallContent) and content.name == "confirm_changes": + if content.type == "function_call" and content.name == "confirm_changes": confirm_changes_call = content break @@ -59,7 +56,7 @@ def _sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: approval_call_ids: set[str] = set() approval_accepted: bool | None = None for content in msg.contents or []: - if type(content) is FunctionApprovalResponseContent: + if content.type == "function_approval_response": if content.function_call and content.function_call.call_id: approval_call_ids.add(str(content.function_call.call_id)) if approval_accepted is None: @@ -79,7 +76,7 @@ def _sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: synthetic_result = ChatMessage( role="tool", contents=[ - FunctionResultContent( + Content.from_function_result( call_id=pending_confirm_changes_id, result="Confirmed" if approval_accepted else "Rejected", ) @@ -93,7 +90,7 @@ def _sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: if pending_confirm_changes_id: user_text = "" for content in msg.contents or []: - if isinstance(content, TextContent): + if content.type == "text": user_text = content.text break @@ -106,7 +103,7 @@ def _sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: synthetic_result = ChatMessage( role="tool", contents=[ - FunctionResultContent( + Content.from_function_result( call_id=pending_confirm_changes_id, result="Confirmed" if parsed.get("accepted") else "Rejected", ) @@ -130,7 +127,7 @@ def _sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: synthetic_result = ChatMessage( role="tool", contents=[ - FunctionResultContent( + Content.from_function_result( call_id=pending_call_id, result="Tool execution skipped - user provided follow-up message", ) @@ -149,7 +146,7 @@ def 
_sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: continue keep = False for content in msg.contents or []: - if isinstance(content, FunctionResultContent): + if content.type == "function_result" and content.call_id: call_id = str(content.call_id) if call_id in pending_tool_call_ids: keep = True @@ -175,7 +172,7 @@ def _deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]: for idx, msg in enumerate(messages): role_value = get_role_value(msg) - if role_value == "tool" and msg.contents and isinstance(msg.contents[0], FunctionResultContent): + if role_value == "tool" and msg.contents and msg.contents[0].type == "function_result": call_id = str(msg.contents[0].call_id) key: Any = (role_value, call_id) @@ -184,7 +181,7 @@ def _deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]: existing_msg = unique_messages[existing_idx] existing_result = None - if existing_msg.contents and isinstance(existing_msg.contents[0], FunctionResultContent): + if existing_msg.contents and existing_msg.contents[0].type == "function_result": existing_result = existing_msg.contents[0].result new_result = msg.contents[0].result @@ -198,11 +195,9 @@ def _deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]: seen_keys[key] = len(unique_messages) unique_messages.append(msg) - elif ( - role_value == "assistant" and msg.contents and any(isinstance(c, FunctionCallContent) for c in msg.contents) - ): + elif role_value == "assistant" and msg.contents and any(c.type == "function_call" for c in msg.contents): tool_call_ids = tuple( - sorted(str(c.call_id) for c in msg.contents if isinstance(c, FunctionCallContent) and c.call_id) + sorted(str(c.call_id) for c in msg.contents if c.type == "function_call" and c.call_id) ) key = (role_value, tool_call_ids) @@ -275,15 +270,14 @@ def _update_tool_call_arguments( function_payload_dict["arguments"] = modified_args return - def _find_matching_func_call(call_id: str) -> FunctionCallContent 
| None: + def _find_matching_func_call(call_id: str) -> Content | None: for prev_msg in result: role_val = prev_msg.role.value if hasattr(prev_msg.role, "value") else str(prev_msg.role) if role_val != "assistant": continue for content in prev_msg.contents or []: - if isinstance(content, FunctionCallContent): - if content.call_id == call_id and content.name != "confirm_changes": - return content + if content.type == "function_call" and content.call_id == call_id and content.name != "confirm_changes": + return content return None def _parse_arguments(arguments: Any) -> dict[str, Any] | None: @@ -301,9 +295,9 @@ def _resolve_approval_call_id(tool_call_id: str, parsed_payload: dict[str, Any] continue direct_call = None confirm_call = None - sibling_calls: list[FunctionCallContent] = [] + sibling_calls: list[Content] = [] for content in prev_msg.contents or []: - if not isinstance(content, FunctionCallContent): + if content.type != "function_call": continue if content.call_id == tool_call_id: direct_call = content @@ -407,7 +401,7 @@ def _filter_modified_args( if not ( (m.role.value if hasattr(m.role, "value") else str(m.role)) == "tool" and any( - isinstance(c, FunctionResultContent) and c.call_id == approval_call_id + c.type == "function_result" and c.call_id == approval_call_id for c in (m.contents or []) ) ) @@ -465,7 +459,7 @@ def _filter_modified_args( matching_func_call.arguments = updated_args _update_tool_call_arguments(messages, str(approval_call_id), merged_args) # Create a new FunctionCallContent with the modified arguments - func_call_for_approval = FunctionCallContent( + func_call_for_approval = Content.from_function_call( call_id=matching_func_call.call_id, name=matching_func_call.name, arguments=json.dumps(filtered_args), @@ -476,7 +470,7 @@ def _filter_modified_args( func_call_for_approval = matching_func_call # Create FunctionApprovalResponseContent for the agent framework - approval_response = FunctionApprovalResponseContent( + approval_response = 
Content.from_function_approval_response( approved=accepted, id=str(approval_call_id), function_call=func_call_for_approval, @@ -491,7 +485,7 @@ def _filter_modified_args( # Keep the old behavior for backwards compatibility chat_msg = ChatMessage( role=Role.USER, - contents=[TextContent(text=approval_payload_text)], + contents=[Content.from_text(text=approval_payload_text)], additional_properties={"is_tool_result": True, "tool_call_id": str(tool_call_id or "")}, ) if "id" in msg: @@ -511,7 +505,7 @@ def _filter_modified_args( func_result = str(result_content) chat_msg = ChatMessage( role=Role.TOOL, - contents=[FunctionResultContent(call_id=str(tool_call_id), result=func_result)], + contents=[Content.from_function_result(call_id=str(tool_call_id), result=func_result)], ) if "id" in msg: chat_msg.message_id = msg["id"] @@ -527,21 +521,21 @@ def _filter_modified_args( chat_msg = ChatMessage( role=Role.TOOL, - contents=[FunctionResultContent(call_id=str(tool_call_id), result=result_content)], + contents=[Content.from_function_result(call_id=str(tool_call_id), result=result_content)], ) if "id" in msg: chat_msg.message_id = msg["id"] result.append(chat_msg) continue - # If assistant message includes tool calls, convert to FunctionCallContent(s) + # If assistant message includes tool calls, convert to Content.from_function_call(s) tool_calls = msg.get("tool_calls") or msg.get("toolCalls") if tool_calls: contents: list[Any] = [] # Include any assistant text content if present content_text = msg.get("content") if isinstance(content_text, str) and content_text: - contents.append(TextContent(text=content_text)) + contents.append(Content.from_text(text=content_text)) # Convert each tool call entry for tc in tool_calls: if not isinstance(tc, dict): @@ -558,7 +552,7 @@ def _filter_modified_args( arguments = func_dict.get("arguments") contents.append( - FunctionCallContent( + Content.from_function_call( call_id=call_id, name=name, arguments=arguments, @@ -580,14 +574,14 @@ def 
_filter_modified_args( approval_contents: list[Any] = [] for approval in msg["function_approvals"]: # Create FunctionCallContent with the modified arguments - func_call = FunctionCallContent( + func_call = Content.from_function_call( call_id=approval.get("call_id", ""), name=approval.get("name", ""), arguments=approval.get("arguments", {}), ) # Create the approval response - approval_response = FunctionApprovalResponseContent( + approval_response = Content.from_function_approval_response( approved=approval.get("approved", True), id=approval.get("id", ""), function_call=func_call, @@ -599,9 +593,9 @@ def _filter_modified_args( # Regular text message content = msg.get("content", "") if isinstance(content, str): - chat_msg = ChatMessage(role=role, contents=[TextContent(text=content)]) + chat_msg = ChatMessage(role=role, contents=[Content.from_text(text=content)]) else: - chat_msg = ChatMessage(role=role, contents=[TextContent(text=str(content))]) + chat_msg = ChatMessage(role=role, contents=[Content.from_text(text=str(content))]) if "id" in msg: chat_msg.message_id = msg["id"] @@ -652,9 +646,9 @@ def agent_framework_messages_to_agui(messages: list[ChatMessage] | list[dict[str tool_result_call_id: str | None = None for content in msg.contents: - if isinstance(content, TextContent): + if content.type == "text": content_text += content.text - elif isinstance(content, FunctionCallContent): + elif content.type == "function_call": tool_calls.append( { "id": content.call_id, @@ -665,7 +659,7 @@ def agent_framework_messages_to_agui(messages: list[ChatMessage] | list[dict[str }, } ) - elif isinstance(content, FunctionResultContent): + elif content.type == "function_result": # Tool result content - extract call_id and result tool_result_call_id = content.call_id # Serialize result to string using core utility @@ -702,8 +696,13 @@ def extract_text_from_contents(contents: list[Any]) -> str: """ text_parts: list[str] = [] for content in contents: - if isinstance(content, 
TextContent): - text_parts.append(content.text) + if type_ := getattr(content, "type", None): + if type_ == "text_reasoning": + continue + if text := getattr(content, "text", None): + text_parts.append(text) + continue + # TODO (moonbox3): should this handle both text and text_reasoning? elif hasattr(content, "text"): text_parts.append(content.text) return "".join(text_parts) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py index ebf6ef6f57..b327192367 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py @@ -9,10 +9,7 @@ from ag_ui.core import StateSnapshotEvent from agent_framework import ( ChatMessage, - FunctionApprovalResponseContent, - FunctionCallContent, - FunctionResultContent, - TextContent, + Content, ) from .._utils import get_role_value, safe_json_parse @@ -37,9 +34,9 @@ def pending_tool_call_ids(messages: list[ChatMessage]) -> set[str]: resolved_ids: set[str] = set() for msg in messages: for content in msg.contents: - if isinstance(content, FunctionCallContent) and content.call_id: + if content.type == "function_call" and content.call_id: pending_ids.add(str(content.call_id)) - elif isinstance(content, FunctionResultContent) and content.call_id: + elif content.type == "function_result" and content.call_id: resolved_ids.add(str(content.call_id)) return pending_ids - resolved_ids @@ -56,7 +53,7 @@ def is_state_context_message(message: ChatMessage) -> bool: if get_role_value(message) != "system": return False for content in message.contents: - if isinstance(content, TextContent) and content.text.startswith("Current state of the application:"): + if content.type == "text" and content.text.startswith("Current state of the application:"): return True return False @@ -139,7 +136,7 @@ def tool_calls_match_state( if get_role_value(msg) != 
"assistant": continue for content in msg.contents: - if isinstance(content, FunctionCallContent) and content.name == tool_name: + if content.type == "function_call" and content.name == tool_name: tool_args = safe_json_parse(content.arguments) break if tool_args is not None: @@ -287,7 +284,7 @@ def collect_approved_state_snapshots( if get_role_value(msg) != "user": continue for content in msg.contents: - if type(content) is FunctionApprovalResponseContent: + if content.type == "function_approval_response": if not content.function_call or not content.approved: continue parsed_args = content.function_call.parse_arguments() @@ -319,7 +316,7 @@ def collect_approved_state_snapshots( return events -def latest_approval_response(messages: list[ChatMessage]) -> FunctionApprovalResponseContent | None: +def latest_approval_response(messages: list[ChatMessage]) -> Content | None: """Get the latest approval response from messages. Args: @@ -332,12 +329,12 @@ def latest_approval_response(messages: list[ChatMessage]) -> FunctionApprovalRes return None last_message = messages[-1] for content in last_message.contents: - if type(content) is FunctionApprovalResponseContent: + if content.type == "function_approval_response": return content return None -def approval_steps(approval: FunctionApprovalResponseContent) -> list[Any]: +def approval_steps(approval: Content) -> list[Any]: """Extract steps from an approval response. 
Args: @@ -346,9 +343,7 @@ def approval_steps(approval: FunctionApprovalResponseContent) -> list[Any]: Returns: List of steps, or empty list if none """ - state_args: Any | None = None - if approval.additional_properties: - state_args = approval.additional_properties.get("ag_ui_state_args") + state_args = approval.additional_properties.get("ag_ui_state_args", None) if isinstance(state_args, dict): steps = state_args.get("steps") if isinstance(steps, list): @@ -365,7 +360,7 @@ def approval_steps(approval: FunctionApprovalResponseContent) -> list[Any]: def is_step_based_approval( - approval: FunctionApprovalResponseContent, + approval: Content, predict_state_config: dict[str, dict[str, str]] | None, ) -> bool: """Check if an approval is step-based. diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_state_manager.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_state_manager.py index 7d8a23d84c..05cc55228d 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_state_manager.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_state_manager.py @@ -6,7 +6,7 @@ from typing import Any from ag_ui.core import CustomEvent, EventType -from agent_framework import ChatMessage, TextContent +from agent_framework import ChatMessage, Content class StateManager: @@ -71,7 +71,7 @@ def state_context_message(self, is_new_user_turn: bool, conversation_has_tool_ca return ChatMessage( role="system", contents=[ - TextContent( + Content.from_text( text=( "Current state of the application:\n" f"{state_json}\n\n" diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py index b5566f0aec..6ac93810db 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py @@ -25,13 +25,11 @@ AgentProtocol, AgentThread, ChatAgent, - FunctionCallContent, - FunctionResultContent, 
- TextContent, + Content, + FunctionInvocationConfiguration, ) from agent_framework._middleware import extract_and_merge_function_middleware from agent_framework._tools import ( - FunctionInvocationConfiguration, _collect_approval_responses, # type: ignore _replace_approval_contents_with_results, # type: ignore _try_execute_function_calls, # type: ignore @@ -285,7 +283,7 @@ async def run( last_message = context.last_message if last_message: for content in last_message.contents: - if isinstance(content, TextContent): + if content.type == "text": tool_content_text = content.text break @@ -441,25 +439,24 @@ async def run( logger.info(f" Message {i}: role={role}, id={msg_id}") if hasattr(msg, "contents") and msg.contents: for j, content in enumerate(msg.contents): - content_type = type(content).__name__ - if isinstance(content, TextContent): - logger.debug(" Content %s: %s - text_length=%s", j, content_type, len(content.text)) - elif isinstance(content, FunctionCallContent): + if content.type == "text": + logger.debug(" Content %s: %s - text_length=%s", j, content.type, len(content.text)) + elif content.type == "function_call": arg_length = len(str(content.arguments)) if content.arguments else 0 logger.debug( - " Content %s: %s - %s args_length=%s", j, content_type, content.name, arg_length + " Content %s: %s - %s args_length=%s", j, content.type, content.name, arg_length ) - elif isinstance(content, FunctionResultContent): + elif content.type == "function_result": result_preview = type(content.result).__name__ if content.result is not None else "None" logger.debug( " Content %s: %s - call_id=%s, result_type=%s", j, - content_type, + content.type, content.call_id, result_preview, ) else: - logger.debug(f" Content {j}: {content_type}") + logger.debug(f" Content {j}: {content.type}") pending_tool_calls: list[dict[str, Any]] = [] tool_calls_by_id: dict[str, dict[str, Any]] = {} @@ -536,16 +533,14 @@ async def _resolve_approval_responses( logger.error("Failed to execute 
approved tool calls; injecting error results.") approved_function_results = [] - normalized_results: list[FunctionResultContent] = [] + normalized_results: list[Content] = [] for idx, approval in enumerate(approved_responses): - if idx < len(approved_function_results) and isinstance( - approved_function_results[idx], FunctionResultContent - ): + if idx < len(approved_function_results) and approved_function_results[idx].type == "function_result": normalized_results.append(approved_function_results[idx]) continue call_id = approval.function_call.call_id or approval.id normalized_results.append( - FunctionResultContent(call_id=call_id, result="Error: Tool call invocation failed.") + Content.from_function_result(call_id=call_id, result="Error: Tool call invocation failed.") ) _replace_approval_contents_with_results(messages, fcc_todo, normalized_results) # type: ignore @@ -661,8 +656,8 @@ def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnap if all_updates is not None: all_updates.append(update) if event_bridge.current_message_id is None and update.contents: - has_tool_call = any(isinstance(content, FunctionCallContent) for content in update.contents) - has_text = any(isinstance(content, TextContent) for content in update.contents) + has_tool_call = any(content.type == "function_call" for content in update.contents) + has_text = any(content.type == "text" for content in update.contents) if has_tool_call and not has_text: tool_message_id = generate_event_id() event_bridge.current_message_id = tool_message_id diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_types.py b/python/packages/ag-ui/agent_framework_ag_ui/_types.py index 226abae692..f88dceb78b 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_types.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_types.py @@ -6,6 +6,7 @@ from typing import Any, TypedDict from agent_framework import ChatOptions +from pydantic import BaseModel, Field if sys.version_info >= (3, 13): 
from typing import TypeVar @@ -19,8 +20,6 @@ "RunMetadata", ] -from pydantic import BaseModel, Field - class PredictStateConfig(TypedDict): """Configuration for predictive state updates.""" diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/task_steps_agent.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/task_steps_agent.py index 572df2720b..9a4acf4319 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/task_steps_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/task_steps_agent.py @@ -18,7 +18,7 @@ TextMessageStartEvent, ToolCallStartEvent, ) -from agent_framework import ChatAgent, ChatClientProtocol, ai_function +from agent_framework import ChatAgent, ChatClientProtocol, ChatMessage, Content, ai_function from agent_framework.ag_ui import AgentFrameworkAgent from pydantic import BaseModel, Field @@ -221,7 +221,6 @@ async def run_agent(self, input_data: dict[str, Any]) -> AsyncGenerator[Any, Non chat_client = chat_agent.chat_client # type: ignore # Build messages for summary call - from agent_framework._types import ChatMessage, TextContent original_messages = input_data.get("messages", []) @@ -234,7 +233,7 @@ async def run_agent(self, input_data: dict[str, Any]) -> AsyncGenerator[Any, Non messages.append( ChatMessage( role=msg.get("role", "user"), - contents=[TextContent(text=content_str)], + contents=[Content.from_text(text=content_str)], ) ) elif isinstance(msg, ChatMessage): @@ -245,7 +244,7 @@ async def run_agent(self, input_data: dict[str, Any]) -> AsyncGenerator[Any, Non ChatMessage( role="user", contents=[ - TextContent( + Content.from_text( text="The steps have been successfully executed. Provide a brief one-sentence summary." 
) ], diff --git a/python/packages/ag-ui/tests/test_ag_ui_client.py b/python/packages/ag-ui/tests/test_ag_ui_client.py index bc1cc6d711..b05810972e 100644 --- a/python/packages/ag-ui/tests/test_ag_ui_client.py +++ b/python/packages/ag-ui/tests/test_ag_ui_client.py @@ -11,14 +11,13 @@ ChatOptions, ChatResponse, ChatResponseUpdate, - FunctionCallContent, + Content, Role, - TextContent, ai_function, ) from pytest import MonkeyPatch -from agent_framework_ag_ui._client import AGUIChatClient, ServerFunctionCallContent +from agent_framework_ag_ui._client import AGUIChatClient from agent_framework_ag_ui._http_service import AGUIHttpService @@ -96,13 +95,11 @@ async def test_extract_state_from_messages_with_state(self) -> None: state_json = json.dumps(state_data) state_b64 = base64.b64encode(state_json.encode("utf-8")).decode("utf-8") - from agent_framework import DataContent - messages = [ ChatMessage(role="user", text="Hello"), ChatMessage( role="user", - contents=[DataContent(uri=f"data:application/json;base64,{state_b64}")], + contents=[Content.from_uri(uri=f"data:application/json;base64,{state_b64}")], ), ] @@ -121,12 +118,10 @@ async def test_extract_state_invalid_json(self) -> None: invalid_json = "not valid json" state_b64 = base64.b64encode(invalid_json.encode("utf-8")).decode("utf-8") - from agent_framework import DataContent - messages = [ ChatMessage( role="user", - contents=[DataContent(uri=f"data:application/json;base64,{state_b64}")], + contents=[Content.from_uri(uri=f"data:application/json;base64,{state_b64}")], ), ] @@ -200,8 +195,8 @@ async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str first_content = updates[1].contents[0] second_content = updates[2].contents[0] - assert isinstance(first_content, TextContent) - assert isinstance(second_content, TextContent) + assert first_content.type == "text" + assert second_content.type == "text" assert first_content.text == "Hello" assert second_content.text == " world" @@ -294,13 +289,12 
@@ async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str updates.append(update) function_calls = [ - content for update in updates for content in update.contents if isinstance(content, FunctionCallContent) + content for update in updates for content in update.contents if content.type == "function_call" ] assert function_calls assert function_calls[0].name == "get_time_zone" - assert not any( - isinstance(content, ServerFunctionCallContent) for update in updates for content in update.contents - ) + + assert not any(content.type == "server_function_call" for update in updates for content in update.contents) async def test_server_tool_calls_not_executed_locally(self, monkeypatch: MonkeyPatch) -> None: """Server tools should not trigger local function invocation even when client tools exist.""" @@ -343,13 +337,11 @@ async def test_state_transmission(self, monkeypatch: MonkeyPatch) -> None: state_json = json.dumps(state_data) state_b64 = base64.b64encode(state_json.encode("utf-8")).decode("utf-8") - from agent_framework import DataContent - messages = [ ChatMessage(role="user", text="Hello"), ChatMessage( role="user", - contents=[DataContent(uri=f"data:application/json;base64,{state_b64}")], + contents=[Content.from_uri(uri=f"data:application/json;base64,{state_b64}")], ), ] diff --git a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py index f919c00a56..8b708b3ac7 100644 --- a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py @@ -9,7 +9,7 @@ from typing import Any import pytest -from agent_framework import ChatAgent, ChatMessage, ChatOptions, ChatResponseUpdate, TextContent +from agent_framework import ChatAgent, ChatMessage, ChatOptions, ChatResponseUpdate, Content from pydantic import BaseModel sys.path.insert(0, str(Path(__file__).parent)) @@ -23,7 +23,7 @@ async def 
test_agent_initialization_basic(): async def stream_fn( messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")]) agent = ChatAgent[ChatOptions]( chat_client=StreamingChatClientStub(stream_fn), @@ -45,7 +45,7 @@ async def test_agent_initialization_with_state_schema(): async def stream_fn( messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")]) agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) state_schema: dict[str, dict[str, Any]] = {"document": {"type": "string"}} @@ -61,7 +61,7 @@ async def test_agent_initialization_with_predict_state_config(): async def stream_fn( messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")]) agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) predict_config = {"document": {"tool": "write_doc", "tool_argument": "content"}} @@ -77,7 +77,7 @@ async def test_agent_initialization_with_pydantic_state_schema(): async def stream_fn( messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")]) class MyState(BaseModel): document: str @@ -100,7 +100,7 @@ async def test_run_started_event_emission(): async def stream_fn( messages: MutableSequence[ChatMessage], 
options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")]) agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent(agent=agent) @@ -124,7 +124,7 @@ async def test_predict_state_custom_event_emission(): async def stream_fn( messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")]) agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) predict_config = { @@ -156,7 +156,7 @@ async def test_initial_state_snapshot_with_schema(): async def stream_fn( messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")]) agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) state_schema = {"document": {"type": "string"}} @@ -186,7 +186,7 @@ async def test_state_initialization_object_type(): async def stream_fn( messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")]) agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) state_schema: dict[str, dict[str, Any]] = {"recipe": {"type": "object", "properties": {}}} @@ -213,7 +213,7 @@ async def test_state_initialization_array_type(): async def stream_fn( messages: 
MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")]) agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) state_schema: dict[str, dict[str, Any]] = {"steps": {"type": "array", "items": {}}} @@ -240,7 +240,7 @@ async def test_run_finished_event_emission(): async def stream_fn( messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")]) agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent(agent=agent) @@ -262,7 +262,7 @@ async def test_tool_result_confirm_changes_accepted(): async def stream_fn( messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - yield ChatResponseUpdate(contents=[TextContent(text="Document updated")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="Document updated")]) agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent( @@ -309,7 +309,7 @@ async def test_tool_result_confirm_changes_rejected(): async def stream_fn( messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - yield ChatResponseUpdate(contents=[TextContent(text="OK")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="OK")]) agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent(agent=agent) @@ -343,7 +343,7 @@ async def 
test_tool_result_function_approval_accepted(): async def stream_fn( messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - yield ChatResponseUpdate(contents=[TextContent(text="OK")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="OK")]) agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent(agent=agent) @@ -389,7 +389,7 @@ async def test_tool_result_function_approval_rejected(): async def stream_fn( messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - yield ChatResponseUpdate(contents=[TextContent(text="OK")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="OK")]) agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent(agent=agent) @@ -431,7 +431,7 @@ async def stream_fn( metadata = options.get("metadata") if metadata: thread_metadata.update(metadata) - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")]) agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent(agent=agent) @@ -462,7 +462,7 @@ async def stream_fn( metadata = options.get("metadata") if metadata: thread_metadata.update(metadata) - yield ChatResponseUpdate(contents=[TextContent(text="Hello")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")]) agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent( @@ -492,7 +492,7 @@ async def test_no_messages_provided(): async def stream_fn( messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - yield 
ChatResponseUpdate(contents=[TextContent(text="Hello")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")]) agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent(agent=agent) @@ -516,7 +516,7 @@ async def test_message_end_event_emission(): async def stream_fn( messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - yield ChatResponseUpdate(contents=[TextContent(text="Hello world")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="Hello world")]) agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent(agent=agent) @@ -602,7 +602,7 @@ async def test_suppressed_summary_with_document_state(): async def stream_fn( messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - yield ChatResponseUpdate(contents=[TextContent(text="Response")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="Response")]) agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) wrapper = AgentFrameworkAgent( @@ -693,7 +693,7 @@ async def stream_fn( async def test_function_approval_mode_executes_tool(): """Test that function approval with approval_mode='always_require' sends the correct messages.""" - from agent_framework import FunctionResultContent, ai_function + from agent_framework import ai_function from agent_framework.ag_ui import AgentFrameworkAgent messages_received: list[Any] = [] @@ -712,7 +712,7 @@ async def stream_fn( # Capture the messages received by the chat client messages_received.clear() messages_received.extend(messages) - yield ChatResponseUpdate(contents=[TextContent(text="Processing completed")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="Processing completed")]) agent = 
ChatAgent( chat_client=StreamingChatClientStub(stream_fn), @@ -770,7 +770,7 @@ async def stream_fn( tool_result_found = False for msg in messages_received: for content in msg.contents: - if isinstance(content, FunctionResultContent): + if content.type == "function_result": tool_result_found = True assert content.call_id == "call_get_datetime_123" assert content.result == "2025/12/01 12:00:00" @@ -784,7 +784,7 @@ async def stream_fn( async def test_function_approval_mode_rejection(): """Test that function approval rejection creates a rejection response.""" - from agent_framework import FunctionResultContent, ai_function + from agent_framework import ai_function from agent_framework.ag_ui import AgentFrameworkAgent messages_received: list[Any] = [] @@ -803,7 +803,7 @@ async def stream_fn( # Capture the messages received by the chat client messages_received.clear() messages_received.extend(messages) - yield ChatResponseUpdate(contents=[TextContent(text="Operation cancelled")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="Operation cancelled")]) agent = ChatAgent( name="test_agent", @@ -855,7 +855,7 @@ async def stream_fn( rejection_found = False for msg in messages_received: for content in msg.contents: - if isinstance(content, FunctionResultContent): + if content.type == "function_result": rejection_found = True assert content.call_id == "call_delete_123" assert content.result == "Error: Tool call invocation was rejected by user." 
diff --git a/python/packages/ag-ui/tests/test_backend_tool_rendering.py b/python/packages/ag-ui/tests/test_backend_tool_rendering.py index 446da23ff2..594d127532 100644 --- a/python/packages/ag-ui/tests/test_backend_tool_rendering.py +++ b/python/packages/ag-ui/tests/test_backend_tool_rendering.py @@ -12,7 +12,7 @@ ToolCallResultEvent, ToolCallStartEvent, ) -from agent_framework import AgentResponseUpdate, FunctionCallContent, FunctionResultContent, TextContent +from agent_framework import AgentResponseUpdate, Content from agent_framework_ag_ui._events import AgentFrameworkEventBridge @@ -22,7 +22,7 @@ async def test_tool_call_flow(): bridge = AgentFrameworkEventBridge(run_id="test-run", thread_id="test-thread") # Step 1: Tool call starts - tool_call = FunctionCallContent( + tool_call = Content.from_function_call( call_id="weather-123", name="get_weather", arguments={"location": "Seattle"}, @@ -44,7 +44,7 @@ async def test_tool_call_flow(): assert "Seattle" in args_event.delta # Step 2: Tool result comes back - tool_result = FunctionResultContent( + tool_result = Content.from_function_result( call_id="weather-123", result="Weather in Seattle: Rainy, 52°F", ) @@ -71,8 +71,8 @@ async def test_text_with_tool_call(): bridge = AgentFrameworkEventBridge(run_id="test-run", thread_id="test-thread") # Agent says something then calls a tool - text_content = TextContent(text="Let me check the weather for you.") - tool_call = FunctionCallContent( + text_content = Content.from_text(text="Let me check the weather for you.") + tool_call = Content.from_function_call( call_id="weather-456", name="get_forecast", arguments={"location": "San Francisco", "days": 3}, @@ -102,9 +102,9 @@ async def test_multiple_tool_results(): # Multiple tool results results = [ - FunctionResultContent(call_id="tool-1", result="Result 1"), - FunctionResultContent(call_id="tool-2", result="Result 2"), - FunctionResultContent(call_id="tool-3", result="Result 3"), + 
Content.from_function_result(call_id="tool-1", result="Result 1"), + Content.from_function_result(call_id="tool-2", result="Result 2"), + Content.from_function_result(call_id="tool-3", result="Result 3"), ] update = AgentResponseUpdate(contents=results) diff --git a/python/packages/ag-ui/tests/test_document_writer_flow.py b/python/packages/ag-ui/tests/test_document_writer_flow.py index 2e5cec9f95..7e154682b4 100644 --- a/python/packages/ag-ui/tests/test_document_writer_flow.py +++ b/python/packages/ag-ui/tests/test_document_writer_flow.py @@ -3,7 +3,7 @@ """Tests for document writer predictive state flow with confirm_changes.""" from ag_ui.core import EventType, StateDeltaEvent, ToolCallArgsEvent, ToolCallEndEvent, ToolCallStartEvent -from agent_framework import AgentResponseUpdate, FunctionCallContent, FunctionResultContent, TextContent +from agent_framework import AgentResponseUpdate, Content from agent_framework_ag_ui._events import AgentFrameworkEventBridge @@ -21,7 +21,7 @@ async def test_streaming_document_with_state_deltas(): ) # Simulate streaming tool call - first chunk with name - tool_call_start = FunctionCallContent( + tool_call_start = Content.from_function_call( call_id="call_123", name="write_document_local", arguments='{"document":"Once', @@ -34,7 +34,9 @@ async def test_streaming_document_with_state_deltas(): assert any(e.type == EventType.TOOL_CALL_ARGS for e in events1) # Second chunk - incomplete JSON, should try partial extraction - tool_call_chunk2 = FunctionCallContent(call_id="call_123", name="write_document_local", arguments=" upon a time") + tool_call_chunk2 = Content.from_function_call( + call_id="call_123", name="write_document_local", arguments=" upon a time" + ) update2 = AgentResponseUpdate(contents=[tool_call_chunk2]) events2 = await bridge.from_agent_run_update(update2) @@ -71,7 +73,7 @@ async def test_confirm_changes_emission(): bridge.pending_state_updates = {"document": "A short story"} # Tool result - tool_result = 
FunctionResultContent( + tool_result = Content.from_function_result( call_id="call_123", result="Document written.", ) @@ -115,7 +117,7 @@ async def test_text_suppression_before_confirm(): bridge.should_stop_after_confirm = True # Text content that should be suppressed - text = TextContent(text="I have written a story about pirates.") + text = Content.from_text(text="I have written a story about pirates.") update = AgentResponseUpdate(contents=[text]) events = await bridge.from_agent_run_update(update) @@ -146,7 +148,7 @@ async def test_no_confirm_for_non_predictive_tools(): # Different tool (not in predict_state_config) bridge.current_tool_call_name = "get_weather" - tool_result = FunctionResultContent( + tool_result = Content.from_function_result( call_id="call_456", result="Sunny, 72°F", ) @@ -175,7 +177,7 @@ async def test_state_delta_deduplication(): ) # First tool call with document - tool_call1 = FunctionCallContent( + tool_call1 = Content.from_function_call( call_id="call_1", name="write_document_local", arguments='{"document":"Same text"}', @@ -189,7 +191,7 @@ async def test_state_delta_deduplication(): # Second tool call with SAME document (shouldn't emit new delta) bridge.current_tool_call_name = "write_document_local" - tool_call2 = FunctionCallContent( + tool_call2 = Content.from_function_call( call_id="call_2", name="write_document_local", arguments='{"document":"Same text"}', # Identical content @@ -216,7 +218,7 @@ async def test_predict_state_config_multiple_fields(): ) # Tool call with both fields - tool_call = FunctionCallContent( + tool_call = Content.from_function_call( call_id="call_999", name="create_post", arguments='{"title":"My Post","body":"Post content"}', diff --git a/python/packages/ag-ui/tests/test_endpoint.py b/python/packages/ag-ui/tests/test_endpoint.py index 59cb884c5c..e09bb32fce 100644 --- a/python/packages/ag-ui/tests/test_endpoint.py +++ b/python/packages/ag-ui/tests/test_endpoint.py @@ -6,7 +6,7 @@ import sys from pathlib 
import Path -from agent_framework import ChatAgent, ChatResponseUpdate, TextContent +from agent_framework import ChatAgent, ChatResponseUpdate, Content from fastapi import FastAPI, Header, HTTPException from fastapi.params import Depends from fastapi.testclient import TestClient @@ -20,7 +20,7 @@ def build_chat_client(response_text: str = "Test response") -> StreamingChatClientStub: """Create a typed chat client stub for endpoint tests.""" - updates = [ChatResponseUpdate(contents=[TextContent(text=response_text)])] + updates = [ChatResponseUpdate(contents=[Content.from_text(text=response_text)])] return StreamingChatClientStub(stream_from_updates(updates)) diff --git a/python/packages/ag-ui/tests/test_events_comprehensive.py b/python/packages/ag-ui/tests/test_events_comprehensive.py index 295ba00372..75e923123f 100644 --- a/python/packages/ag-ui/tests/test_events_comprehensive.py +++ b/python/packages/ag-ui/tests/test_events_comprehensive.py @@ -6,10 +6,7 @@ from agent_framework import ( AgentResponseUpdate, - FunctionApprovalRequestContent, - FunctionCallContent, - FunctionResultContent, - TextContent, + Content, ) @@ -19,7 +16,7 @@ async def test_basic_text_message_conversion(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - update = AgentResponseUpdate(contents=[TextContent(text="Hello")]) + update = AgentResponseUpdate(contents=[Content.from_text(text="Hello")]) events = await bridge.from_agent_run_update(update) assert len(events) == 2 @@ -35,8 +32,8 @@ async def test_text_message_streaming(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - update1 = AgentResponseUpdate(contents=[TextContent(text="Hello ")]) - update2 = AgentResponseUpdate(contents=[TextContent(text="world")]) + update1 = AgentResponseUpdate(contents=[Content.from_text(text="Hello ")]) + update2 = AgentResponseUpdate(contents=[Content.from_text(text="world")]) events1 = await bridge.from_agent_run_update(update1) events2 = 
await bridge.from_agent_run_update(update2) @@ -61,7 +58,7 @@ async def test_skip_text_content_for_structured_outputs(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread", skip_text_content=True) - update = AgentResponseUpdate(contents=[TextContent(text='{"result": "data"}')]) + update = AgentResponseUpdate(contents=[Content.from_text(text='{"result": "data"}')]) events = await bridge.from_agent_run_update(update) # No events should be emitted @@ -74,9 +71,9 @@ async def test_skip_text_content_for_empty_text(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - update1 = AgentResponseUpdate(contents=[TextContent(text="Hello ")]) - update2 = AgentResponseUpdate(contents=[TextContent(text="")]) # Empty chunk - update3 = AgentResponseUpdate(contents=[TextContent(text="world")]) + update1 = AgentResponseUpdate(contents=[Content.from_text(text="Hello ")]) + update2 = AgentResponseUpdate(contents=[Content.from_text(text="")]) # Empty chunk + update3 = AgentResponseUpdate(contents=[Content.from_text(text="world")]) events1 = await bridge.from_agent_run_update(update1) events2 = await bridge.from_agent_run_update(update2) @@ -105,7 +102,7 @@ async def test_tool_call_with_name(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - update = AgentResponseUpdate(contents=[FunctionCallContent(name="search_web", call_id="call_123")]) + update = AgentResponseUpdate(contents=[Content.from_function_call(name="search_web", call_id="call_123")]) events = await bridge.from_agent_run_update(update) assert len(events) == 1 @@ -121,15 +118,17 @@ async def test_tool_call_streaming_args(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") # First chunk: name only - update1 = AgentResponseUpdate(contents=[FunctionCallContent(name="search_web", call_id="call_123")]) + update1 = AgentResponseUpdate(contents=[Content.from_function_call(name="search_web", call_id="call_123")]) 
events1 = await bridge.from_agent_run_update(update1) # Second chunk: arguments chunk 1 (name can be empty string for continuation) - update2 = AgentResponseUpdate(contents=[FunctionCallContent(name="", call_id="call_123", arguments='{"query": "')]) + update2 = AgentResponseUpdate( + contents=[Content.from_function_call(name="", call_id="call_123", arguments='{"query": "')] + ) events2 = await bridge.from_agent_run_update(update2) # Third chunk: arguments chunk 2 - update3 = AgentResponseUpdate(contents=[FunctionCallContent(name="", call_id="call_123", arguments='AI"}')]) + update3 = AgentResponseUpdate(contents=[Content.from_function_call(name="", call_id="call_123", arguments='AI"}')]) events3 = await bridge.from_agent_run_update(update3) # First update: ToolCallStartEvent @@ -167,9 +166,11 @@ async def test_streaming_tool_call_no_duplicate_start_events(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") # Simulate streaming tool call: first chunk has name, subsequent chunks have name="" - update1 = AgentResponseUpdate(contents=[FunctionCallContent(name="get_weather", call_id="call_789")]) - update2 = AgentResponseUpdate(contents=[FunctionCallContent(name="", call_id="call_789", arguments='{"loc":')]) - update3 = AgentResponseUpdate(contents=[FunctionCallContent(name="", call_id="call_789", arguments='"SF"}')]) + update1 = AgentResponseUpdate(contents=[Content.from_function_call(name="get_weather", call_id="call_789")]) + update2 = AgentResponseUpdate( + contents=[Content.from_function_call(name="", call_id="call_789", arguments='{"loc":')] + ) + update3 = AgentResponseUpdate(contents=[Content.from_function_call(name="", call_id="call_789", arguments='"SF"}')]) events1 = await bridge.from_agent_run_update(update1) events2 = await bridge.from_agent_run_update(update2) @@ -193,7 +194,7 @@ async def test_tool_result_with_dict(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") result_data = {"status": 
"success", "count": 42} - update = AgentResponseUpdate(contents=[FunctionResultContent(call_id="call_123", result=result_data)]) + update = AgentResponseUpdate(contents=[Content.from_function_result(call_id="call_123", result=result_data)]) events = await bridge.from_agent_run_update(update) # Should emit ToolCallEndEvent + ToolCallResultEvent @@ -214,7 +215,7 @@ async def test_tool_result_with_string(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - update = AgentResponseUpdate(contents=[FunctionResultContent(call_id="call_123", result="Search complete")]) + update = AgentResponseUpdate(contents=[Content.from_function_result(call_id="call_123", result="Search complete")]) events = await bridge.from_agent_run_update(update) assert len(events) == 2 @@ -229,7 +230,7 @@ async def test_tool_result_with_none(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - update = AgentResponseUpdate(contents=[FunctionResultContent(call_id="call_123", result=None)]) + update = AgentResponseUpdate(contents=[Content.from_function_result(call_id="call_123", result=None)]) events = await bridge.from_agent_run_update(update) assert len(events) == 2 @@ -247,8 +248,8 @@ async def test_multiple_tool_results_in_sequence(): update = AgentResponseUpdate( contents=[ - FunctionResultContent(call_id="call_1", result="Result 1"), - FunctionResultContent(call_id="call_2", result="Result 2"), + Content.from_function_result(call_id="call_1", result="Result 1"), + Content.from_function_result(call_id="call_2", result="Result 2"), ] ) events = await bridge.from_agent_run_update(update) @@ -272,12 +273,12 @@ async def test_function_approval_request_basic(): require_confirmation=False, ) - func_call = FunctionCallContent( + func_call = Content.from_function_call( call_id="call_123", name="send_email", arguments={"to": "user@example.com", "subject": "Test"}, ) - approval = FunctionApprovalRequestContent( + approval = 
Content.from_function_approval_request( id="approval_001", function_call=func_call, ) @@ -312,8 +313,8 @@ async def test_empty_predict_state_config(): # Tool call with arguments update = AgentResponseUpdate( contents=[ - FunctionCallContent(name="write_doc", call_id="call_1", arguments='{"content": "test"}'), - FunctionResultContent(call_id="call_1", result="Done"), + Content.from_function_call(name="write_doc", call_id="call_1", arguments='{"content": "test"}'), + Content.from_function_result(call_id="call_1", result="Done"), ] ) events = await bridge.from_agent_run_update(update) @@ -347,8 +348,8 @@ async def test_tool_not_in_predict_state_config(): # Different tool name update = AgentResponseUpdate( contents=[ - FunctionCallContent(name="search_web", call_id="call_1", arguments='{"query": "AI"}'), - FunctionResultContent(call_id="call_1", result="Results"), + Content.from_function_call(name="search_web", call_id="call_1", arguments='{"query": "AI"}'), + Content.from_function_result(call_id="call_1", result="Results"), ] ) events = await bridge.from_agent_run_update(update) @@ -376,8 +377,8 @@ async def test_state_management_tracking(): # Streaming tool call update1 = AgentResponseUpdate( contents=[ - FunctionCallContent(name="write_doc", call_id="call_1"), - FunctionCallContent(name="", call_id="call_1", arguments='{"content": "Hello"}'), + Content.from_function_call(name="write_doc", call_id="call_1"), + Content.from_function_call(name="", call_id="call_1", arguments='{"content": "Hello"}'), ] ) await bridge.from_agent_run_update(update1) @@ -387,7 +388,7 @@ async def test_state_management_tracking(): assert bridge.pending_state_updates["document"] == "Hello" # Tool result should update current_state - update2 = AgentResponseUpdate(contents=[FunctionResultContent(call_id="call_1", result="Done")]) + update2 = AgentResponseUpdate(contents=[Content.from_function_result(call_id="call_1", result="Done")]) await bridge.from_agent_run_update(update2) # current_state 
should be updated @@ -413,12 +414,12 @@ async def test_wildcard_tool_argument(): # Complete tool call with dict arguments update = AgentResponseUpdate( contents=[ - FunctionCallContent( + Content.from_function_call( name="create_recipe", call_id="call_1", arguments={"title": "Pasta", "ingredients": ["pasta", "sauce"]}, ), - FunctionResultContent(call_id="call_1", result="Created"), + Content.from_function_result(call_id="call_1", result="Created"), ] ) events = await bridge.from_agent_run_update(update) @@ -503,14 +504,14 @@ async def test_state_snapshot_after_tool_result(): # Tool call with streaming args update1 = AgentResponseUpdate( contents=[ - FunctionCallContent(name="write_doc", call_id="call_1"), - FunctionCallContent(name="", call_id="call_1", arguments='{"content": "Test"}'), + Content.from_function_call(name="write_doc", call_id="call_1"), + Content.from_function_call(name="", call_id="call_1", arguments='{"content": "Test"}'), ] ) await bridge.from_agent_run_update(update1) # Tool result should trigger StateSnapshotEvent - update2 = AgentResponseUpdate(contents=[FunctionResultContent(call_id="call_1", result="Done")]) + update2 = AgentResponseUpdate(contents=[Content.from_function_result(call_id="call_1", result="Done")]) events = await bridge.from_agent_run_update(update2) # Should have: ToolCallEnd, ToolCallResult, StateSnapshot, ToolCallStart (confirm_changes), ToolCallArgs, ToolCallEnd @@ -526,12 +527,12 @@ async def test_message_id_persistence_across_chunks(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") # First chunk - update1 = AgentResponseUpdate(contents=[TextContent(text="Hello ")]) + update1 = AgentResponseUpdate(contents=[Content.from_text(text="Hello ")]) events1 = await bridge.from_agent_run_update(update1) message_id = events1[0].message_id # Second chunk - update2 = AgentResponseUpdate(contents=[TextContent(text="world")]) + update2 = AgentResponseUpdate(contents=[Content.from_text(text="world")]) 
events2 = await bridge.from_agent_run_update(update2) # Should use same message_id @@ -546,14 +547,16 @@ async def test_tool_call_id_tracking(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") # First chunk with name - update1 = AgentResponseUpdate(contents=[FunctionCallContent(name="search", call_id="call_1")]) + update1 = AgentResponseUpdate(contents=[Content.from_function_call(name="search", call_id="call_1")]) await bridge.from_agent_run_update(update1) assert bridge.current_tool_call_id == "call_1" assert bridge.current_tool_call_name == "search" # Second chunk with args but no name - update2 = AgentResponseUpdate(contents=[FunctionCallContent(name="", call_id="call_1", arguments='{"q":"AI"}')]) + update2 = AgentResponseUpdate( + contents=[Content.from_function_call(name="", call_id="call_1", arguments='{"q":"AI"}')] + ) events2 = await bridge.from_agent_run_update(update2) # Should still track same tool call @@ -576,8 +579,8 @@ async def test_tool_name_reset_after_result(): # Tool call update1 = AgentResponseUpdate( contents=[ - FunctionCallContent(name="write_doc", call_id="call_1"), - FunctionCallContent(name="", call_id="call_1", arguments='{"content": "Test"}'), + Content.from_function_call(name="write_doc", call_id="call_1"), + Content.from_function_call(name="", call_id="call_1", arguments='{"content": "Test"}'), ] ) await bridge.from_agent_run_update(update1) @@ -585,7 +588,7 @@ async def test_tool_name_reset_after_result(): assert bridge.current_tool_call_name == "write_doc" # Tool result with predictive state (should trigger confirm_changes and reset) - update2 = AgentResponseUpdate(contents=[FunctionResultContent(call_id="call_1", result="Done")]) + update2 = AgentResponseUpdate(contents=[Content.from_function_result(call_id="call_1", result="Done")]) await bridge.from_agent_run_update(update2) # Tool name should be reset @@ -604,9 +607,9 @@ async def test_function_approval_with_wildcard_argument(): }, ) - 
approval_content = FunctionApprovalRequestContent( + approval_content = Content.from_function_approval_request( id="approval_1", - function_call=FunctionCallContent( + function_call=Content.from_function_call( name="submit", call_id="call_1", arguments='{"key1": "value1", "key2": "value2"}' ), ) @@ -632,9 +635,11 @@ async def test_function_approval_missing_argument(): }, ) - approval_content = FunctionApprovalRequestContent( + approval_content = Content.from_function_approval_request( id="approval_1", - function_call=FunctionCallContent(name="process", call_id="call_1", arguments='{"other_field": "value"}'), + function_call=Content.from_function_call( + name="process", call_id="call_1", arguments='{"other_field": "value"}' + ), ) update = AgentResponseUpdate(contents=[approval_content]) @@ -654,8 +659,8 @@ async def test_empty_predict_state_config_no_deltas(): # Tool call with arguments update = AgentResponseUpdate( contents=[ - FunctionCallContent(name="search", call_id="call_1"), - FunctionCallContent(name="", call_id="call_1", arguments='{"query": "test"}'), + Content.from_function_call(name="search", call_id="call_1"), + Content.from_function_call(name="", call_id="call_1", arguments='{"query": "test"}'), ] ) events = await bridge.from_agent_run_update(update) @@ -678,8 +683,8 @@ async def test_tool_with_no_matching_config(): # Tool call for different tool update = AgentResponseUpdate( contents=[ - FunctionCallContent(name="search_web", call_id="call_1"), - FunctionCallContent(name="", call_id="call_1", arguments='{"query": "test"}'), + Content.from_function_call(name="search_web", call_id="call_1"), + Content.from_function_call(name="", call_id="call_1", arguments='{"query": "test"}'), ] ) events = await bridge.from_agent_run_update(update) @@ -696,7 +701,7 @@ async def test_tool_call_without_name_or_id(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") # This should not crash but log an error - update = 
AgentResponseUpdate(contents=[FunctionCallContent(name="", call_id="", arguments='{"arg": "val"}')]) + update = AgentResponseUpdate(contents=[Content.from_function_call(name="", call_id="", arguments='{"arg": "val"}')]) events = await bridge.from_agent_run_update(update) # Should emit ToolCallArgsEvent with generated ID @@ -717,7 +722,7 @@ async def test_state_delta_count_logging(): for i in range(15): update = AgentResponseUpdate( contents=[ - FunctionCallContent(name="", call_id="call_1", arguments=f'{{"text": "Content variation {i}"}}'), + Content.from_function_call(name="", call_id="call_1", arguments=f'{{"text": "Content variation {i}"}}'), ] ) # Set the tool name to match config @@ -737,7 +742,7 @@ async def test_tool_result_with_empty_list(): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - update = AgentResponseUpdate(contents=[FunctionResultContent(call_id="call_123", result=[])]) + update = AgentResponseUpdate(contents=[Content.from_function_result(call_id="call_123", result=[])]) events = await bridge.from_agent_run_update(update) assert len(events) == 2 @@ -760,7 +765,7 @@ class MockTextContent: bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") update = AgentResponseUpdate( - contents=[FunctionResultContent(call_id="call_123", result=[MockTextContent("Hello from MCP tool!")])] + contents=[Content.from_function_result(call_id="call_123", result=[MockTextContent("Hello from MCP tool!")])] ) events = await bridge.from_agent_run_update(update) @@ -785,7 +790,7 @@ class MockTextContent: update = AgentResponseUpdate( contents=[ - FunctionResultContent( + Content.from_function_result( call_id="call_123", result=[MockTextContent("First result"), MockTextContent("Second result")], ) @@ -812,7 +817,7 @@ class MockModel(BaseModel): bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") update = AgentResponseUpdate( - contents=[FunctionResultContent(call_id="call_123", 
result=[MockModel(value=1), MockModel(value=2)])] + contents=[Content.from_function_result(call_id="call_123", result=[MockModel(value=1), MockModel(value=2)])] ) events = await bridge.from_agent_run_update(update) diff --git a/python/packages/ag-ui/tests/test_human_in_the_loop.py b/python/packages/ag-ui/tests/test_human_in_the_loop.py index 00e64472b6..b643465e36 100644 --- a/python/packages/ag-ui/tests/test_human_in_the_loop.py +++ b/python/packages/ag-ui/tests/test_human_in_the_loop.py @@ -2,7 +2,7 @@ """Tests for human in the loop (function approval requests).""" -from agent_framework import AgentResponseUpdate, FunctionApprovalRequestContent, FunctionCallContent +from agent_framework import AgentResponseUpdate, Content from agent_framework_ag_ui._events import AgentFrameworkEventBridge @@ -17,12 +17,12 @@ async def test_function_approval_request_emission(): ) # Create approval request - func_call = FunctionCallContent( + func_call = Content.from_function_call( call_id="call_123", name="send_email", arguments={"to": "user@example.com", "subject": "Test"}, ) - approval_request = FunctionApprovalRequestContent( + approval_request = Content.from_function_approval_request( id="approval_001", function_call=func_call, ) @@ -56,12 +56,12 @@ async def test_function_approval_request_with_confirm_changes(): require_confirmation=True, ) - func_call = FunctionCallContent( + func_call = Content.from_function_call( call_id="call_456", name="delete_file", arguments={"path": "/tmp/test.txt"}, ) - approval_request = FunctionApprovalRequestContent( + approval_request = Content.from_function_approval_request( id="approval_002", function_call=func_call, ) @@ -109,22 +109,22 @@ async def test_multiple_approval_requests(): require_confirmation=False, ) - func_call_1 = FunctionCallContent( + func_call_1 = Content.from_function_call( call_id="call_1", name="create_event", arguments={"title": "Meeting"}, ) - approval_1 = FunctionApprovalRequestContent( + approval_1 = 
Content.from_function_approval_request( id="approval_1", function_call=func_call_1, ) - func_call_2 = FunctionCallContent( + func_call_2 = Content.from_function_call( call_id="call_2", name="book_room", arguments={"room": "Conference A"}, ) - approval_2 = FunctionApprovalRequestContent( + approval_2 = Content.from_function_approval_request( id="approval_2", function_call=func_call_2, ) @@ -164,12 +164,12 @@ async def test_function_approval_request_sets_stop_flag(): assert bridge.should_stop_after_confirm is False - func_call = FunctionCallContent( + func_call = Content.from_function_call( call_id="call_stop_test", name="get_datetime", arguments={}, ) - approval_request = FunctionApprovalRequestContent( + approval_request = Content.from_function_approval_request( id="approval_stop_test", function_call=func_call, ) diff --git a/python/packages/ag-ui/tests/test_message_adapters.py b/python/packages/ag-ui/tests/test_message_adapters.py index 9173314a28..970a4fe76b 100644 --- a/python/packages/ag-ui/tests/test_message_adapters.py +++ b/python/packages/ag-ui/tests/test_message_adapters.py @@ -5,7 +5,7 @@ import json import pytest -from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent, Role, TextContent +from agent_framework import ChatMessage, Content, Role from agent_framework_ag_ui._message_adapters import ( agent_framework_messages_to_agui, @@ -24,7 +24,7 @@ def sample_agui_message(): @pytest.fixture def sample_agent_framework_message(): """Create a sample Agent Framework message.""" - return ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")], message_id="msg-123") + return ChatMessage(role=Role.USER, contents=[Content.from_text(text="Hello")], message_id="msg-123") def test_agui_to_agent_framework_basic(sample_agui_message): @@ -89,7 +89,7 @@ def test_agui_tool_result_to_agent_framework(): assert message.role == Role.USER assert len(message.contents) == 1 - assert isinstance(message.contents[0], TextContent) + assert 
message.contents[0].type == "text" assert message.contents[0].text == '{"accepted": true, "steps": []}' assert message.additional_properties is not None @@ -141,7 +141,7 @@ def test_agui_tool_approval_updates_tool_call_arguments(): assert len(messages) == 2 assistant_msg = messages[0] - func_call = next(content for content in assistant_msg.contents if isinstance(content, FunctionCallContent)) + func_call = next(content for content in assistant_msg.contents if content.type == "function_call") assert func_call.arguments == { "steps": [ {"description": "Boil water", "status": "enabled"}, @@ -157,11 +157,9 @@ def test_agui_tool_approval_updates_tool_call_arguments(): ] } - from agent_framework import FunctionApprovalResponseContent - approval_msg = messages[1] approval_content = next( - content for content in approval_msg.contents if isinstance(content, FunctionApprovalResponseContent) + content for content in approval_msg.contents if content.type == "function_approval_response" ) assert approval_content.function_call.parse_arguments() == { "steps": [ @@ -211,12 +209,9 @@ def test_agui_tool_approval_from_confirm_changes_maps_to_function_call(): ] messages = agui_messages_to_agent_framework(messages_input) - - from agent_framework import FunctionApprovalResponseContent - approval_msg = messages[1] approval_content = next( - content for content in approval_msg.contents if isinstance(content, FunctionApprovalResponseContent) + content for content in approval_msg.contents if content.type == "function_approval_response" ) assert approval_content.function_call.call_id == "call_tool" @@ -259,12 +254,9 @@ def test_agui_tool_approval_from_confirm_changes_falls_back_to_sibling_call(): ] messages = agui_messages_to_agent_framework(messages_input) - - from agent_framework import FunctionApprovalResponseContent - approval_msg = messages[1] approval_content = next( - content for content in approval_msg.contents if isinstance(content, FunctionApprovalResponseContent) + content for 
content in approval_msg.contents if content.type == "function_approval_response" ) assert approval_content.function_call.call_id == "call_tool" @@ -315,12 +307,9 @@ def test_agui_tool_approval_from_generate_task_steps_maps_to_function_call(): ] messages = agui_messages_to_agent_framework(messages_input) - - from agent_framework import FunctionApprovalResponseContent - approval_msg = messages[1] approval_content = next( - content for content in approval_msg.contents if isinstance(content, FunctionApprovalResponseContent) + content for content in approval_msg.contents if content.type == "function_approval_response" ) assert approval_content.function_call.call_id == "call_tool" @@ -380,15 +369,14 @@ def test_agui_function_approvals(): assert msg.role == Role.USER assert len(msg.contents) == 2 - from agent_framework import FunctionApprovalResponseContent - - assert isinstance(msg.contents[0], FunctionApprovalResponseContent) + assert msg.contents[0].type == "function_approval_response" assert msg.contents[0].approved is True assert msg.contents[0].id == "approval-1" assert msg.contents[0].function_call.name == "search" assert msg.contents[0].function_call.call_id == "call-1" - assert isinstance(msg.contents[1], FunctionApprovalResponseContent) + assert msg.contents[1].type == "function_approval_response" + assert msg.contents[1].id == "approval-2" assert msg.contents[1].approved is False @@ -406,7 +394,7 @@ def test_agui_non_string_content(): assert len(messages) == 1 assert len(messages[0].contents) == 1 - assert isinstance(messages[0].contents[0], TextContent) + assert messages[0].contents[0].type == "text" assert "nested" in messages[0].contents[0].text @@ -440,9 +428,9 @@ def test_agui_with_tool_calls_to_agent_framework(): assert msg.role == Role.ASSISTANT assert msg.message_id == "msg-789" # First content is text, second is the function call - assert isinstance(msg.contents[0], TextContent) + assert msg.contents[0].type == "text" assert msg.contents[0].text == 
"Calling tool" - assert isinstance(msg.contents[1], FunctionCallContent) + assert msg.contents[1].type == "function_call" assert msg.contents[1].call_id == "call-123" assert msg.contents[1].name == "get_weather" assert msg.contents[1].arguments == {"location": "Seattle"} @@ -453,8 +441,8 @@ def test_agent_framework_to_agui_with_tool_calls(): msg = ChatMessage( role=Role.ASSISTANT, contents=[ - TextContent(text="Calling tool"), - FunctionCallContent(call_id="call-123", name="search", arguments={"query": "test"}), + Content.from_text(text="Calling tool"), + Content.from_function_call(call_id="call-123", name="search", arguments={"query": "test"}), ], message_id="msg-456", ) @@ -477,7 +465,7 @@ def test_agent_framework_to_agui_multiple_text_contents(): """Test concatenating multiple text contents.""" msg = ChatMessage( role=Role.ASSISTANT, - contents=[TextContent(text="Part 1 "), TextContent(text="Part 2")], + contents=[Content.from_text(text="Part 1 "), Content.from_text(text="Part 2")], ) messages = agent_framework_messages_to_agui([msg]) @@ -488,7 +476,7 @@ def test_agent_framework_to_agui_multiple_text_contents(): def test_agent_framework_to_agui_no_message_id(): """Test message without message_id - should auto-generate ID.""" - msg = ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")]) + msg = ChatMessage(role=Role.USER, contents=[Content.from_text(text="Hello")]) messages = agent_framework_messages_to_agui([msg]) @@ -500,7 +488,7 @@ def test_agent_framework_to_agui_no_message_id(): def test_agent_framework_to_agui_system_role(): """Test system role conversion.""" - msg = ChatMessage(role=Role.SYSTEM, contents=[TextContent(text="System")]) + msg = ChatMessage(role=Role.SYSTEM, contents=[Content.from_text(text="System")]) messages = agent_framework_messages_to_agui([msg]) @@ -510,7 +498,7 @@ def test_agent_framework_to_agui_system_role(): def test_extract_text_from_contents(): """Test extracting text from contents list.""" - contents = 
[TextContent(text="Hello "), TextContent(text="World")] + contents = [Content.from_text(text="Hello "), Content.from_text(text="World")] result = extract_text_from_contents(contents) @@ -533,7 +521,7 @@ def __init__(self, text: str): def test_extract_text_from_custom_contents(): """Test extracting text from custom content objects.""" - contents = [CustomTextContent(text="Custom "), TextContent(text="Mixed")] + contents = [CustomTextContent(text="Custom "), Content.from_text(text="Mixed")] result = extract_text_from_contents(contents) @@ -547,7 +535,7 @@ def test_agent_framework_to_agui_function_result_dict(): """Test converting FunctionResultContent with dict result to AG-UI.""" msg = ChatMessage( role=Role.TOOL, - contents=[FunctionResultContent(call_id="call-123", result={"key": "value", "count": 42})], + contents=[Content.from_function_result(call_id="call-123", result={"key": "value", "count": 42})], message_id="msg-789", ) @@ -564,7 +552,7 @@ def test_agent_framework_to_agui_function_result_none(): """Test converting FunctionResultContent with None result to AG-UI.""" msg = ChatMessage( role=Role.TOOL, - contents=[FunctionResultContent(call_id="call-123", result=None)], + contents=[Content.from_function_result(call_id="call-123", result=None)], message_id="msg-789", ) @@ -580,7 +568,7 @@ def test_agent_framework_to_agui_function_result_string(): """Test converting FunctionResultContent with string result to AG-UI.""" msg = ChatMessage( role=Role.TOOL, - contents=[FunctionResultContent(call_id="call-123", result="plain text result")], + contents=[Content.from_function_result(call_id="call-123", result="plain text result")], message_id="msg-789", ) @@ -595,7 +583,7 @@ def test_agent_framework_to_agui_function_result_empty_list(): """Test converting FunctionResultContent with empty list result to AG-UI.""" msg = ChatMessage( role=Role.TOOL, - contents=[FunctionResultContent(call_id="call-123", result=[])], + 
contents=[Content.from_function_result(call_id="call-123", result=[])], message_id="msg-789", ) @@ -617,7 +605,7 @@ class MockTextContent: msg = ChatMessage( role=Role.TOOL, - contents=[FunctionResultContent(call_id="call-123", result=[MockTextContent("Hello from MCP!")])], + contents=[Content.from_function_result(call_id="call-123", result=[MockTextContent("Hello from MCP!")])], message_id="msg-789", ) @@ -640,7 +628,7 @@ class MockTextContent: msg = ChatMessage( role=Role.TOOL, contents=[ - FunctionResultContent( + Content.from_function_result( call_id="call-123", result=[MockTextContent("First result"), MockTextContent("Second result")], ) diff --git a/python/packages/ag-ui/tests/test_message_hygiene.py b/python/packages/ag-ui/tests/test_message_hygiene.py index 380ff438bd..ecc01de3cb 100644 --- a/python/packages/ag-ui/tests/test_message_hygiene.py +++ b/python/packages/ag-ui/tests/test_message_hygiene.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. -from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent, TextContent +from agent_framework import ChatMessage, Content from agent_framework_ag_ui._message_adapters import _deduplicate_messages, _sanitize_tool_history @@ -10,7 +10,7 @@ def test_sanitize_tool_history_injects_confirm_changes_result() -> None: ChatMessage( role="assistant", contents=[ - FunctionCallContent( + Content.from_function_call( name="confirm_changes", call_id="call_confirm_123", arguments='{"changes": "test"}', @@ -19,7 +19,7 @@ def test_sanitize_tool_history_injects_confirm_changes_result() -> None: ), ChatMessage( role="user", - contents=[TextContent(text='{"accepted": true}')], + contents=[Content.from_text(text='{"accepted": true}')], ), ] @@ -37,11 +37,11 @@ def test_deduplicate_messages_prefers_non_empty_tool_results() -> None: messages = [ ChatMessage( role="tool", - contents=[FunctionResultContent(call_id="call1", result="")], + contents=[Content.from_function_result(call_id="call1", 
result="")], ), ChatMessage( role="tool", - contents=[FunctionResultContent(call_id="call1", result="result data")], + contents=[Content.from_function_result(call_id="call1", result="result data")], ), ] diff --git a/python/packages/ag-ui/tests/test_orchestrators.py b/python/packages/ag-ui/tests/test_orchestrators.py index 279ddedc82..c951246bfa 100644 --- a/python/packages/ag-ui/tests/test_orchestrators.py +++ b/python/packages/ag-ui/tests/test_orchestrators.py @@ -13,8 +13,8 @@ BaseChatClient, ChatAgent, ChatResponseUpdate, + Content, FunctionInvocationConfiguration, - TextContent, ai_function, ) @@ -79,11 +79,11 @@ async def mock_run_stream( if capture_messages is not None: capture_messages.extend(messages) yield AgentResponseUpdate( - contents=[TextContent(text="ok")], + contents=[Content.from_text(text="ok")], role="assistant", response_id=thread.metadata.get("ag_ui_run_id"), # type: ignore[attr-defined] (metadata always created in orchestrator) raw_representation=ChatResponseUpdate( - contents=[TextContent(text="ok")], + contents=[Content.from_text(text="ok")], conversation_id=thread.metadata.get("ag_ui_thread_id"), # type: ignore[attr-defined] (metadata always created in orchestrator) response_id=thread.metadata.get("ag_ui_run_id"), # type: ignore[attr-defined] (metadata always created in orchestrator) ), @@ -253,7 +253,7 @@ async def test_state_context_injected_when_tool_call_state_mismatch() -> None: if role_value != "system": continue for content in msg.contents or []: - if isinstance(content, TextContent) and content.text.startswith("Current state of the application:"): + if content.type == "text" and content.text.startswith("Current state of the application:"): state_messages.append(content.text) assert state_messages assert "Vegetarian" in state_messages[0] @@ -302,6 +302,6 @@ async def test_state_context_not_injected_when_tool_call_matches_state() -> None if role_value != "system": continue for content in msg.contents or []: - if isinstance(content, 
TextContent) and content.text.startswith("Current state of the application:"): + if content.type == "text" and content.text.startswith("Current state of the application:"): state_messages.append(content.text) assert not state_messages diff --git a/python/packages/ag-ui/tests/test_orchestrators_coverage.py b/python/packages/ag-ui/tests/test_orchestrators_coverage.py index 6c311d593a..d579c691b7 100644 --- a/python/packages/ag-ui/tests/test_orchestrators_coverage.py +++ b/python/packages/ag-ui/tests/test_orchestrators_coverage.py @@ -8,12 +8,7 @@ from types import SimpleNamespace from typing import Any -from agent_framework import ( - AgentResponseUpdate, - ChatMessage, - TextContent, - ai_function, -) +from agent_framework import AgentResponseUpdate, ChatMessage, Content, ai_function from pydantic import BaseModel from agent_framework_ag_ui._agent import AgentConfig @@ -48,14 +43,14 @@ async def test_human_in_the_loop_json_decode_error() -> None: messages = [ ChatMessage( role="tool", - contents=[TextContent(text="not valid json {")], + contents=[Content.from_text(text="not valid json {")], additional_properties={"is_tool_result": True}, ) ] agent = StubAgent( default_options={"tools": [approval_tool], "response_format": None}, - updates=[AgentResponseUpdate(contents=[TextContent(text="response")], role="assistant")], + updates=[AgentResponseUpdate(contents=[Content.from_text(text="response")], role="assistant")], ) context = TestExecutionContext( input_data=input_data, @@ -78,14 +73,14 @@ async def test_human_in_the_loop_json_decode_error() -> None: async def test_sanitize_tool_history_confirm_changes() -> None: """Test sanitize_tool_history logic for confirm_changes synthetic result.""" - from agent_framework import ChatMessage, FunctionCallContent, TextContent + from agent_framework import ChatMessage # Create messages that will trigger confirm_changes synthetic result injection messages = [ ChatMessage( role="assistant", contents=[ - FunctionCallContent( + 
Content.from_function_call( name="confirm_changes", call_id="call_confirm_123", arguments='{"changes": "test"}', @@ -94,7 +89,7 @@ async def test_sanitize_tool_history_confirm_changes() -> None: ), ChatMessage( role="user", - contents=[TextContent(text='{"accepted": true}')], + contents=[Content.from_text(text='{"accepted": true}')], ), ] @@ -134,17 +129,17 @@ async def test_sanitize_tool_history_confirm_changes() -> None: async def test_sanitize_tool_history_orphaned_tool_result() -> None: """Test sanitize_tool_history removes orphaned tool results.""" - from agent_framework import ChatMessage, FunctionResultContent, TextContent + from agent_framework import ChatMessage # Tool result without preceding assistant tool call messages = [ ChatMessage( role="tool", - contents=[FunctionResultContent(call_id="orphan_123", result="orphaned data")], + contents=[Content.from_function_result(call_id="orphan_123", result="orphaned data")], ), ChatMessage( role="user", - contents=[TextContent(text="Hello")], + contents=[Content.from_text(text="Hello")], ), ] @@ -214,20 +209,20 @@ async def test_orphaned_tool_result_sanitization() -> None: async def test_deduplicate_messages_empty_tool_results() -> None: """Test deduplicate_messages prefers non-empty tool results.""" - from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent + from agent_framework import ChatMessage messages = [ ChatMessage( role="assistant", - contents=[FunctionCallContent(name="test_tool", call_id="call_789", arguments="{}")], + contents=[Content.from_function_call(name="test_tool", call_id="call_789", arguments="{}")], ), ChatMessage( role="tool", - contents=[FunctionResultContent(call_id="call_789", result="")], + contents=[Content.from_function_result(call_id="call_789", result="")], ), ChatMessage( role="tool", - contents=[FunctionResultContent(call_id="call_789", result="real data")], + contents=[Content.from_function_result(call_id="call_789", result="real data")], ), ] @@ 
-259,20 +254,20 @@ async def test_deduplicate_messages_empty_tool_results() -> None: async def test_deduplicate_messages_duplicate_assistant_tool_calls() -> None: """Test deduplicate_messages removes duplicate assistant tool call messages.""" - from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent + from agent_framework import ChatMessage messages = [ ChatMessage( role="assistant", - contents=[FunctionCallContent(name="test_tool", call_id="call_abc", arguments="{}")], + contents=[Content.from_function_call(name="test_tool", call_id="call_abc", arguments="{}")], ), ChatMessage( role="assistant", - contents=[FunctionCallContent(name="test_tool", call_id="call_abc", arguments="{}")], + contents=[Content.from_function_call(name="test_tool", call_id="call_abc", arguments="{}")], ), ChatMessage( role="tool", - contents=[FunctionResultContent(call_id="call_abc", result="result")], + contents=[Content.from_function_result(call_id="call_abc", result="result")], ), ] @@ -303,20 +298,20 @@ async def test_deduplicate_messages_duplicate_assistant_tool_calls() -> None: async def test_deduplicate_messages_duplicate_system_messages() -> None: """Test that deduplication logic is invoked for system messages.""" - from agent_framework import ChatMessage, TextContent + from agent_framework import ChatMessage messages = [ ChatMessage( role="system", - contents=[TextContent(text="You are a helpful assistant.")], + contents=[Content.from_text(text="You are a helpful assistant.")], ), ChatMessage( role="system", - contents=[TextContent(text="You are a helpful assistant.")], + contents=[Content.from_text(text="You are a helpful assistant.")], ), ChatMessage( role="user", - contents=[TextContent(text="Hello")], + contents=[Content.from_text(text="Hello")], ), ] @@ -387,20 +382,20 @@ async def test_state_context_injection() -> None: async def test_state_context_injection_with_tool_calls_and_input_state() -> None: """Test state context is injected when state is 
provided, even with tool calls.""" - from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent, TextContent + from agent_framework import ChatMessage messages = [ ChatMessage( role="assistant", - contents=[FunctionCallContent(name="get_weather", call_id="call_xyz", arguments="{}")], + contents=[Content.from_function_call(name="get_weather", call_id="call_xyz", arguments="{}")], ), ChatMessage( role="tool", - contents=[FunctionResultContent(call_id="call_xyz", result="sunny")], + contents=[Content.from_function_result(call_id="call_xyz", result="sunny")], ), ChatMessage( role="user", - contents=[TextContent(text="Thanks")], + contents=[Content.from_text(text="Thanks")], ), ] @@ -452,7 +447,7 @@ class RecipeState(BaseModel): default_options=DEFAULT_OPTIONS, updates=[ AgentResponseUpdate( - contents=[TextContent(text='{"ingredients": ["tomato"], "message": "Added tomato"}')], + contents=[Content.from_text(text='{"ingredients": ["tomato"], "message": "Added tomato"}')], role="assistant", ) ], @@ -641,13 +636,13 @@ async def test_all_messages_filtered_handling() -> None: async def test_confirm_changes_with_invalid_json_fallback() -> None: """Test confirm_changes with invalid JSON falls back to normal processing.""" - from agent_framework import ChatMessage, FunctionCallContent, TextContent + from agent_framework import ChatMessage messages = [ ChatMessage( role="assistant", contents=[ - FunctionCallContent( + Content.from_function_call( name="confirm_changes", call_id="call_confirm_invalid", arguments='{"changes": "test"}', @@ -656,7 +651,7 @@ async def test_confirm_changes_with_invalid_json_fallback() -> None: ), ChatMessage( role="user", - contents=[TextContent(text="invalid json {")], + contents=[Content.from_text(text="invalid json {")], ), ] @@ -688,19 +683,18 @@ async def test_confirm_changes_with_invalid_json_fallback() -> None: async def test_confirm_changes_closes_active_message_before_finish() -> None: """Confirm-changes flow closes 
any active text message before run finishes.""" from ag_ui.core import TextMessageEndEvent, TextMessageStartEvent - from agent_framework import FunctionCallContent, FunctionResultContent updates = [ AgentResponseUpdate( contents=[ - FunctionCallContent( + Content.from_function_call( name="write_document_local", call_id="call_1", arguments='{"document": "Draft"}', ) ] ), - AgentResponseUpdate(contents=[FunctionResultContent(call_id="call_1", result="Done")]), + AgentResponseUpdate(contents=[Content.from_function_result(call_id="call_1", result="Done")]), ] orchestrator = DefaultOrchestrator() @@ -735,16 +729,16 @@ async def test_confirm_changes_closes_active_message_before_finish() -> None: async def test_tool_result_kept_when_call_id_matches() -> None: """Test tool result is kept when call_id matches pending tool calls.""" - from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent + from agent_framework import ChatMessage messages = [ ChatMessage( role="assistant", - contents=[FunctionCallContent(name="get_data", call_id="call_match", arguments="{}")], + contents=[Content.from_function_call(name="get_data", call_id="call_match", arguments="{}")], ), ChatMessage( role="tool", - contents=[FunctionResultContent(call_id="call_match", result="data")], + contents=[Content.from_function_result(call_id="call_match", result="data")], ), ] @@ -794,11 +788,11 @@ async def run_stream( **kwargs: Any, ) -> AsyncGenerator[AgentResponseUpdate, None]: self.messages_received = messages - yield AgentResponseUpdate(contents=[TextContent(text="response")], role="assistant") + yield AgentResponseUpdate(contents=[Content.from_text(text="response")], role="assistant") - from agent_framework import ChatMessage, TextContent + from agent_framework import ChatMessage - messages = [ChatMessage(role="user", contents=[TextContent(text="Hello")])] + messages = [ChatMessage(role="user", contents=[Content.from_text(text="Hello")])] orchestrator = DefaultOrchestrator() 
input_data: dict[str, Any] = {"messages": []} @@ -820,9 +814,9 @@ async def run_stream( async def test_initial_state_snapshot_with_array_schema() -> None: """Test state initialization with array type schema.""" - from agent_framework import ChatMessage, TextContent + from agent_framework import ChatMessage - messages = [ChatMessage(role="user", contents=[TextContent(text="Hello")])] + messages = [ChatMessage(role="user", contents=[Content.from_text(text="Hello")])] orchestrator = DefaultOrchestrator() input_data: dict[str, Any] = {"messages": [], "state": {}} @@ -851,9 +845,9 @@ async def test_response_format_skip_text_content() -> None: class OutputModel(BaseModel): result: str - from agent_framework import ChatMessage, TextContent + from agent_framework import ChatMessage - messages = [ChatMessage(role="user", contents=[TextContent(text="Hello")])] + messages = [ChatMessage(role="user", contents=[Content.from_text(text="Hello")])] orchestrator = DefaultOrchestrator() input_data: dict[str, Any] = {"messages": []} diff --git a/python/packages/ag-ui/tests/test_shared_state.py b/python/packages/ag-ui/tests/test_shared_state.py index 469f5f5ad8..4b3f5ebb23 100644 --- a/python/packages/ag-ui/tests/test_shared_state.py +++ b/python/packages/ag-ui/tests/test_shared_state.py @@ -8,7 +8,7 @@ import pytest from ag_ui.core import StateSnapshotEvent -from agent_framework import ChatAgent, ChatResponseUpdate, TextContent +from agent_framework import ChatAgent, ChatResponseUpdate, Content from agent_framework_ag_ui._agent import AgentFrameworkAgent from agent_framework_ag_ui._events import AgentFrameworkEventBridge @@ -20,7 +20,7 @@ @pytest.fixture def mock_agent() -> ChatAgent: """Create a mock agent for testing.""" - updates = [ChatResponseUpdate(contents=[TextContent(text="Hello!")])] + updates = [ChatResponseUpdate(contents=[Content.from_text(text="Hello!")])] chat_client = StreamingChatClientStub(stream_from_updates(updates)) return ChatAgent(name="test_agent", 
instructions="Test agent", chat_client=chat_client) diff --git a/python/packages/ag-ui/tests/test_state_manager.py b/python/packages/ag-ui/tests/test_state_manager.py index bc0a7b6a19..47b2940978 100644 --- a/python/packages/ag-ui/tests/test_state_manager.py +++ b/python/packages/ag-ui/tests/test_state_manager.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. from ag_ui.core import CustomEvent, EventType -from agent_framework import ChatMessage, TextContent +from agent_framework import ChatMessage from agent_framework_ag_ui._events import AgentFrameworkEventBridge from agent_framework_ag_ui._orchestration._state_manager import StateManager @@ -47,5 +47,5 @@ def test_state_context_only_when_new_user_turn() -> None: message = state_manager.state_context_message(is_new_user_turn=True, conversation_has_tool_calls=False) assert isinstance(message, ChatMessage) - assert isinstance(message.contents[0], TextContent) + assert message.contents[0].type == "text" assert "Current state of the application" in message.contents[0].text diff --git a/python/packages/ag-ui/tests/test_structured_output.py b/python/packages/ag-ui/tests/test_structured_output.py index b9a04353be..7c623f62d6 100644 --- a/python/packages/ag-ui/tests/test_structured_output.py +++ b/python/packages/ag-ui/tests/test_structured_output.py @@ -8,7 +8,7 @@ from pathlib import Path from typing import Any -from agent_framework import ChatAgent, ChatMessage, ChatOptions, ChatResponseUpdate, TextContent +from agent_framework import ChatAgent, ChatMessage, ChatOptions, ChatResponseUpdate, Content from pydantic import BaseModel sys.path.insert(0, str(Path(__file__).parent)) @@ -43,7 +43,7 @@ async def stream_fn( messages: MutableSequence[ChatMessage], options: ChatOptions, **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: yield ChatResponseUpdate( - contents=[TextContent(text='{"recipe": {"name": "Pasta"}, "message": "Here is your recipe"}')] + contents=[Content.from_text(text='{"recipe": {"name": 
"Pasta"}, "message": "Here is your recipe"}')] ) agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) @@ -86,7 +86,7 @@ async def stream_fn( {"id": "2", "description": "Step 2", "status": "pending"}, ] } - yield ChatResponseUpdate(contents=[TextContent(text=json.dumps(steps_data))]) + yield ChatResponseUpdate(contents=[Content.from_text(text=json.dumps(steps_data))]) agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) agent.default_options = ChatOptions(response_format=StepsOutput) @@ -118,7 +118,7 @@ async def test_structured_output_with_no_schema_match(): from agent_framework.ag_ui import AgentFrameworkAgent updates = [ - ChatResponseUpdate(contents=[TextContent(text='{"data": {"key": "value"}}')]), + ChatResponseUpdate(contents=[Content.from_text(text='{"data": {"key": "value"}}')]), ] agent = ChatAgent( @@ -156,7 +156,7 @@ class DataOutput(BaseModel): async def stream_fn( messages: MutableSequence[ChatMessage], options: ChatOptions, **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - yield ChatResponseUpdate(contents=[TextContent(text='{"data": {"key": "value"}, "info": "processed"}')]) + yield ChatResponseUpdate(contents=[Content.from_text(text='{"data": {"key": "value"}, "info": "processed"}')]) agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) agent.default_options = ChatOptions(response_format=DataOutput) @@ -185,7 +185,7 @@ async def test_no_structured_output_when_no_response_format(): """Test that structured output path is skipped when no response_format.""" from agent_framework.ag_ui import AgentFrameworkAgent - updates = [ChatResponseUpdate(contents=[TextContent(text="Regular text")])] + updates = [ChatResponseUpdate(contents=[Content.from_text(text="Regular text")])] agent = ChatAgent( name="test", @@ -216,7 +216,7 @@ async def stream_fn( messages: MutableSequence[ChatMessage], options: ChatOptions, 
**kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: output_data = {"recipe": {"name": "Salad"}, "message": "Fresh salad recipe ready"} - yield ChatResponseUpdate(contents=[TextContent(text=json.dumps(output_data))]) + yield ChatResponseUpdate(contents=[Content.from_text(text=json.dumps(output_data))]) agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) agent.default_options = ChatOptions(response_format=RecipeOutput) diff --git a/python/packages/ag-ui/tests/utils_test_ag_ui.py b/python/packages/ag-ui/tests/utils_test_ag_ui.py index c3fa590cd1..33f462257e 100644 --- a/python/packages/ag-ui/tests/utils_test_ag_ui.py +++ b/python/packages/ag-ui/tests/utils_test_ag_ui.py @@ -16,7 +16,7 @@ ChatMessage, ChatResponse, ChatResponseUpdate, - TextContent, + Content, ) from agent_framework._clients import TOptions_co @@ -91,7 +91,7 @@ def __init__( self.id = agent_id self.name = agent_name self.description = "stub agent" - self.updates = updates or [AgentResponseUpdate(contents=[TextContent(text="response")], role="assistant")] + self.updates = updates or [AgentResponseUpdate(contents=[Content.from_text(text="response")], role="assistant")] self.default_options: dict[str, Any] = ( default_options if isinstance(default_options, dict) else {"tools": None, "response_format": None} ) diff --git a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py index c9223e614b..c95851a3c8 100644 --- a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py +++ b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py @@ -7,31 +7,19 @@ from agent_framework import ( AGENT_FRAMEWORK_USER_AGENT, AIFunction, - Annotations, + Annotation, BaseChatClient, ChatMessage, ChatOptions, ChatResponse, ChatResponseUpdate, - CitationAnnotation, - CodeInterpreterToolCallContent, - CodeInterpreterToolResultContent, - Contents, - ErrorContent, + Content, 
FinishReason, - FunctionCallContent, - FunctionResultContent, HostedCodeInterpreterTool, - HostedFileContent, HostedMCPTool, HostedWebSearchTool, - MCPServerToolCallContent, - MCPServerToolResultContent, Role, - TextContent, - TextReasoningContent, TextSpanRegion, - UsageContent, UsageDetails, get_logger, prepare_function_call_results, @@ -653,9 +641,9 @@ def _process_stream_event(self, event: BetaRawMessageStreamEvent) -> ChatRespons """ match event.type: case "message_start": - usage_details: list[UsageContent] = [] + usage_details: list[Content] = [] if event.message.usage and (details := self._parse_usage_from_anthropic(event.message.usage)): - usage_details.append(UsageContent(details=details)) + usage_details.append(Content.from_usage(usage_details=details)) return ChatResponseUpdate( response_id=event.message.id, @@ -672,7 +660,7 @@ def _process_stream_event(self, event: BetaRawMessageStreamEvent) -> ChatRespons case "message_delta": usage = self._parse_usage_from_anthropic(event.usage) return ChatResponseUpdate( - contents=[UsageContent(details=usage, raw_representation=event.usage)] if usage else [], + contents=[Content.from_usage(usage_details=usage, raw_representation=event.usage)] if usage else [], finish_reason=FINISH_REASON_MAP.get(event.delta.stop_reason) if event.delta.stop_reason else None, raw_representation=event, ) @@ -702,24 +690,24 @@ def _parse_usage_from_anthropic(self, usage: BetaUsage | BetaMessageDeltaUsage | return None usage_details = UsageDetails(output_token_count=usage.output_tokens) if usage.input_tokens is not None: - usage_details.input_token_count = usage.input_tokens + usage_details["input_token_count"] = usage.input_tokens if usage.cache_creation_input_tokens is not None: - usage_details.additional_counts["anthropic.cache_creation_input_tokens"] = usage.cache_creation_input_tokens + usage_details["anthropic.cache_creation_input_tokens"] = usage.cache_creation_input_tokens if usage.cache_read_input_tokens is not None: - 
usage_details.additional_counts["anthropic.cache_read_input_tokens"] = usage.cache_read_input_tokens + usage_details["anthropic.cache_read_input_tokens"] = usage.cache_read_input_tokens return usage_details def _parse_contents_from_anthropic( self, content: Sequence[BetaContentBlock | BetaRawContentBlockDelta | BetaTextBlock], - ) -> list[Contents]: + ) -> list[Content]: """Parse contents from the Anthropic message.""" - contents: list[Contents] = [] + contents: list[Content] = [] for content_block in content: match content_block.type: case "text" | "text_delta": contents.append( - TextContent( + Content.from_text( text=content_block.text, raw_representation=content_block, annotations=self._parse_citations_from_anthropic(content_block), @@ -729,7 +717,7 @@ def _parse_contents_from_anthropic( self._last_call_id_name = (content_block.id, content_block.name) if content_block.type == "mcp_tool_use": contents.append( - MCPServerToolCallContent( + Content.from_mcp_server_tool_call( call_id=content_block.id, tool_name=content_block.name, server_name=None, @@ -739,10 +727,10 @@ def _parse_contents_from_anthropic( ) elif "code_execution" in (content_block.name or ""): contents.append( - CodeInterpreterToolCallContent( + Content.from_code_interpreter_tool_call( call_id=content_block.id, inputs=[ - TextContent( + Content.from_text( text=str(content_block.input), raw_representation=content_block, ) @@ -752,7 +740,7 @@ def _parse_contents_from_anthropic( ) else: contents.append( - FunctionCallContent( + Content.from_function_call( call_id=content_block.id, name=content_block.name, arguments=content_block.input, @@ -761,13 +749,13 @@ def _parse_contents_from_anthropic( ) case "mcp_tool_result": call_id, name = self._last_call_id_name or (None, None) - parsed_output: list[Contents] | None = None + parsed_output: list[Content] | None = None if content_block.content: if isinstance(content_block.content, list): parsed_output = 
self._parse_contents_from_anthropic(content_block.content) elif isinstance(content_block.content, (str, bytes)): parsed_output = [ - TextContent( + Content.from_text( text=str(content_block.content), raw_representation=content_block, ) @@ -775,28 +763,27 @@ def _parse_contents_from_anthropic( else: parsed_output = self._parse_contents_from_anthropic([content_block.content]) contents.append( - MCPServerToolResultContent( + Content.from_mcp_server_tool_result( call_id=content_block.tool_use_id, output=parsed_output, raw_representation=content_block, ) ) case "web_search_tool_result" | "web_fetch_tool_result": - call_id, name = self._last_call_id_name or (None, None) + call_id, _ = self._last_call_id_name or (None, None) contents.append( - FunctionResultContent( + Content.from_function_result( call_id=content_block.tool_use_id, - name=name if name and call_id == content_block.tool_use_id else "web_tool", result=content_block.content, raw_representation=content_block, ) ) case "code_execution_tool_result": - code_outputs: list[Contents] = [] + code_outputs: list[Content] = [] if content_block.content: if isinstance(content_block.content, BetaCodeExecutionToolResultError): code_outputs.append( - ErrorContent( + Content.from_error( message=content_block.content.error_code, raw_representation=content_block.content, ) @@ -804,41 +791,41 @@ def _parse_contents_from_anthropic( else: if content_block.content.stdout: code_outputs.append( - TextContent( + Content.from_text( text=content_block.content.stdout, raw_representation=content_block.content, ) ) if content_block.content.stderr: code_outputs.append( - ErrorContent( + Content.from_error( message=content_block.content.stderr, raw_representation=content_block.content, ) ) for code_file_content in content_block.content.content: code_outputs.append( - HostedFileContent( + Content.from_hosted_file( file_id=code_file_content.file_id, raw_representation=code_file_content, ) ) contents.append( - CodeInterpreterToolResultContent( 
+ Content.from_code_interpreter_tool_result( call_id=content_block.tool_use_id, raw_representation=content_block, outputs=code_outputs, ) ) case "bash_code_execution_tool_result": - bash_outputs: list[Contents] = [] + bash_outputs: list[Content] = [] if content_block.content: if isinstance( content_block.content, BetaBashCodeExecutionToolResultError, ): bash_outputs.append( - ErrorContent( + Content.from_error( message=content_block.content.error_code, raw_representation=content_block.content, ) @@ -846,39 +833,38 @@ def _parse_contents_from_anthropic( else: if content_block.content.stdout: bash_outputs.append( - TextContent( + Content.from_text( text=content_block.content.stdout, raw_representation=content_block.content, ) ) if content_block.content.stderr: bash_outputs.append( - ErrorContent( + Content.from_error( message=content_block.content.stderr, raw_representation=content_block.content, ) ) for bash_file_content in content_block.content.content: contents.append( - HostedFileContent( + Content.from_hosted_file( file_id=bash_file_content.file_id, raw_representation=bash_file_content, ) ) contents.append( - FunctionResultContent( + Content.from_function_result( call_id=content_block.tool_use_id, - name=content_block.type, result=bash_outputs, raw_representation=content_block, ) ) case "text_editor_code_execution_tool_result": - text_editor_outputs: list[Contents] = [] + text_editor_outputs: list[Content] = [] match content_block.content.type: case "text_editor_code_execution_tool_result_error": text_editor_outputs.append( - ErrorContent( + Content.from_error( message=content_block.content.error_code and getattr(content_block.content, "error_message", ""), raw_representation=content_block.content, @@ -887,10 +873,12 @@ def _parse_contents_from_anthropic( case "text_editor_code_execution_view_result": annotations = ( [ - CitationAnnotation( + Annotation( + type="citation", raw_representation=content_block.content, annotated_regions=[ TextSpanRegion( + 
type="text_span", start_index=content_block.content.start_line, end_index=content_block.content.start_line + (content_block.content.num_lines or 0), @@ -903,7 +891,7 @@ def _parse_contents_from_anthropic( else None ) text_editor_outputs.append( - TextContent( + Content.from_text( text=content_block.content.content, annotations=annotations, raw_representation=content_block.content, @@ -911,10 +899,12 @@ def _parse_contents_from_anthropic( ) case "text_editor_code_execution_str_replace_result": old_annotation = ( - CitationAnnotation( + Annotation( + type="citation", raw_representation=content_block.content, annotated_regions=[ TextSpanRegion( + type="text_span", start_index=content_block.content.old_start or 0, end_index=( (content_block.content.old_start or 0) @@ -928,13 +918,15 @@ def _parse_contents_from_anthropic( else None ) new_annotation = ( - CitationAnnotation( + Annotation( + type="citation", raw_representation=content_block.content, snippet="\n".join(content_block.content.lines) if content_block.content.lines else None, annotated_regions=[ TextSpanRegion( + type="text_span", start_index=content_block.content.new_start or 0, end_index=( (content_block.content.new_start or 0) @@ -950,7 +942,7 @@ def _parse_contents_from_anthropic( annotations = [ann for ann in [old_annotation, new_annotation] if ann is not None] text_editor_outputs.append( - TextContent( + Content.from_text( text=( "\n".join(content_block.content.lines) if content_block.content.lines else "" ), @@ -960,15 +952,14 @@ def _parse_contents_from_anthropic( ) case "text_editor_code_execution_create_result": text_editor_outputs.append( - TextContent( + Content.from_text( text=f"File update: {content_block.content.is_file_update}", raw_representation=content_block.content, ) ) contents.append( - FunctionResultContent( + Content.from_function_result( call_id=content_block.tool_use_id, - name=content_block.type, result=text_editor_outputs, raw_representation=content_block, ) @@ -981,7 +972,7 @@ def 
_parse_contents_from_anthropic( # This matches OpenAI's behavior where streaming chunks have name="". call_id, _ = self._last_call_id_name if self._last_call_id_name else ("", "") contents.append( - FunctionCallContent( + Content.from_function_call( call_id=call_id, name="", arguments=content_block.partial_json, @@ -990,7 +981,7 @@ def _parse_contents_from_anthropic( ) case "thinking" | "thinking_delta": contents.append( - TextReasoningContent( + Content.from_text_reasoning( text=content_block.thinking, raw_representation=content_block, ) @@ -1001,65 +992,65 @@ def _parse_contents_from_anthropic( def _parse_citations_from_anthropic( self, content_block: BetaContentBlock | BetaRawContentBlockDelta | BetaTextBlock - ) -> list[Annotations] | None: - content_citations = getattr(content_block, "citations", None) - if not content_citations: + ) -> list[Annotation] | None: + content_blocks = getattr(content_block, "citations", None) + if not content_blocks: return None - annotations: list[Annotations] = [] - for citation in content_citations: - cit = CitationAnnotation(raw_representation=citation) + annotations: list[Annotation] = [] + for citation in content_blocks: + cit = Annotation(type="citation", raw_representation=citation) match citation.type: case "char_location": - cit.title = citation.title - cit.snippet = citation.cited_text + cit["title"] = citation.title + cit["snippet"] = citation.cited_text if citation.file_id: - cit.file_id = citation.file_id - if not cit.annotated_regions: - cit.annotated_regions = [] - cit.annotated_regions.append( + cit["file_id"] = citation.file_id + cit.setdefault("annotated_regions", []) + cit["annotated_regions"].append( TextSpanRegion( + type="text_span", start_index=citation.start_char_index, end_index=citation.end_char_index, ) ) case "page_location": - cit.title = citation.document_title - cit.snippet = citation.cited_text + cit["title"] = citation.document_title + cit["snippet"] = citation.cited_text if citation.file_id: - 
cit.file_id = citation.file_id - if not cit.annotated_regions: - cit.annotated_regions = [] - cit.annotated_regions.append( + cit["file_id"] = citation.file_id + cit.setdefault("annotated_regions", []) + cit["annotated_regions"].append( TextSpanRegion( + type="text_span", start_index=citation.start_page_number, end_index=citation.end_page_number, ) ) case "content_block_location": - cit.title = citation.document_title - cit.snippet = citation.cited_text + cit["title"] = citation.document_title + cit["snippet"] = citation.cited_text if citation.file_id: - cit.file_id = citation.file_id - if not cit.annotated_regions: - cit.annotated_regions = [] - cit.annotated_regions.append( + cit["file_id"] = citation.file_id + cit.setdefault("annotated_regions", []) + cit["annotated_regions"].append( TextSpanRegion( + type="text_span", start_index=citation.start_block_index, end_index=citation.end_block_index, ) ) case "web_search_result_location": - cit.title = citation.title - cit.snippet = citation.cited_text - cit.url = citation.url + cit["title"] = citation.title + cit["snippet"] = citation.cited_text + cit["url"] = citation.url case "search_result_location": - cit.title = citation.title - cit.snippet = citation.cited_text - cit.url = citation.source - if not cit.annotated_regions: - cit.annotated_regions = [] - cit.annotated_regions.append( + cit["title"] = citation.title + cit["snippet"] = citation.cited_text + cit["url"] = citation.source + cit.setdefault("annotated_regions", []) + cit["annotated_regions"].append( TextSpanRegion( + type="text_span", start_index=citation.start_block_index, end_index=citation.end_block_index, ) diff --git a/python/packages/anthropic/tests/test_anthropic_client.py b/python/packages/anthropic/tests/test_anthropic_client.py index 828d9916c2..4476c6b3b6 100644 --- a/python/packages/anthropic/tests/test_anthropic_client.py +++ b/python/packages/anthropic/tests/test_anthropic_client.py @@ -10,16 +10,12 @@ ChatMessage, ChatOptions, 
ChatResponseUpdate, - DataContent, + Content, FinishReason, - FunctionCallContent, - FunctionResultContent, HostedCodeInterpreterTool, HostedMCPTool, HostedWebSearchTool, Role, - TextContent, - TextReasoningContent, ai_function, ) from agent_framework.exceptions import ServiceInitializationError @@ -170,7 +166,7 @@ def test_prepare_message_for_anthropic_function_call(mock_anthropic_client: Magi message = ChatMessage( role=Role.ASSISTANT, contents=[ - FunctionCallContent( + Content.from_function_call( call_id="call_123", name="get_weather", arguments={"location": "San Francisco"}, @@ -194,9 +190,8 @@ def test_prepare_message_for_anthropic_function_result(mock_anthropic_client: Ma message = ChatMessage( role=Role.TOOL, contents=[ - FunctionResultContent( + Content.from_function_result( call_id="call_123", - name="get_weather", result="Sunny, 72°F", ) ], @@ -219,7 +214,7 @@ def test_prepare_message_for_anthropic_text_reasoning(mock_anthropic_client: Mag chat_client = create_test_anthropic_client(mock_anthropic_client) message = ChatMessage( role=Role.ASSISTANT, - contents=[TextReasoningContent(text="Let me think about this...")], + contents=[Content.from_text_reasoning(text="Let me think about this...")], ) result = chat_client._prepare_message_for_anthropic(message) @@ -507,12 +502,12 @@ def test_process_message_basic(mock_anthropic_client: MagicMock) -> None: assert len(response.messages) == 1 assert response.messages[0].role == Role.ASSISTANT assert len(response.messages[0].contents) == 1 - assert isinstance(response.messages[0].contents[0], TextContent) + assert response.messages[0].contents[0].type == "text" assert response.messages[0].contents[0].text == "Hello there!" 
assert response.finish_reason == FinishReason.STOP assert response.usage_details is not None - assert response.usage_details.input_token_count == 10 - assert response.usage_details.output_token_count == 5 + assert response.usage_details["input_token_count"] == 10 + assert response.usage_details["output_token_count"] == 5 def test_process_message_with_tool_use(mock_anthropic_client: MagicMock) -> None: @@ -536,7 +531,7 @@ def test_process_message_with_tool_use(mock_anthropic_client: MagicMock) -> None response = chat_client._process_message(mock_message) assert len(response.messages[0].contents) == 1 - assert isinstance(response.messages[0].contents[0], FunctionCallContent) + assert response.messages[0].contents[0].type == "function_call" assert response.messages[0].contents[0].call_id == "call_123" assert response.messages[0].contents[0].name == "get_weather" assert response.finish_reason == FinishReason.TOOL_CALLS @@ -550,8 +545,8 @@ def test_parse_usage_from_anthropic_basic(mock_anthropic_client: MagicMock) -> N result = chat_client._parse_usage_from_anthropic(usage) assert result is not None - assert result.input_token_count == 10 - assert result.output_token_count == 5 + assert result["input_token_count"] == 10 + assert result["output_token_count"] == 5 def test_parse_usage_from_anthropic_none(mock_anthropic_client: MagicMock) -> None: @@ -571,7 +566,7 @@ def test_parse_contents_from_anthropic_text(mock_anthropic_client: MagicMock) -> result = chat_client._parse_contents_from_anthropic(content) assert len(result) == 1 - assert isinstance(result[0], TextContent) + assert result[0].type == "text" assert result[0].text == "Hello!" 
@@ -590,7 +585,7 @@ def test_parse_contents_from_anthropic_tool_use(mock_anthropic_client: MagicMock result = chat_client._parse_contents_from_anthropic(content) assert len(result) == 1 - assert isinstance(result[0], FunctionCallContent) + assert result[0].type == "function_call" assert result[0].call_id == "call_123" assert result[0].name == "get_weather" @@ -613,7 +608,7 @@ def test_parse_contents_from_anthropic_input_json_delta_no_duplicate_name(mock_a result = chat_client._parse_contents_from_anthropic([tool_use_content]) assert len(result) == 1 - assert isinstance(result[0], FunctionCallContent) + assert result[0].type == "function_call" assert result[0].call_id == "call_123" assert result[0].name == "get_weather" # Initial event has name @@ -624,7 +619,7 @@ def test_parse_contents_from_anthropic_input_json_delta_no_duplicate_name(mock_a result = chat_client._parse_contents_from_anthropic([delta_content_1]) assert len(result) == 1 - assert isinstance(result[0], FunctionCallContent) + assert result[0].type == "function_call" assert result[0].call_id == "call_123" assert result[0].name == "" # Delta events should have empty name assert result[0].arguments == '{"location":' @@ -636,7 +631,7 @@ def test_parse_contents_from_anthropic_input_json_delta_no_duplicate_name(mock_a result = chat_client._parse_contents_from_anthropic([delta_content_2]) assert len(result) == 1 - assert isinstance(result[0], FunctionCallContent) + assert result[0].type == "function_call" assert result[0].call_id == "call_123" assert result[0].name == "" # Still empty name for subsequent deltas assert result[0].arguments == '"San Francisco"}' @@ -771,9 +766,7 @@ async def test_anthropic_client_integration_function_calling() -> None: assert response is not None # Should contain function call - has_function_call = any( - isinstance(content, FunctionCallContent) for msg in response.messages for content in msg.contents - ) + has_function_call = any(content.type == "function_call" for msg in 
response.messages for content in msg.contents) assert has_function_call @@ -872,8 +865,8 @@ async def test_anthropic_client_integration_images() -> None: ChatMessage( role=Role.USER, contents=[ - TextContent(text="Describe this image"), - DataContent(media_type="image/jpeg", data=image_bytes), + Content.from_text(text="Describe this image"), + Content.from_data(media_type="image/jpeg", data=image_bytes), ], ), ] diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py index 6973517c14..e9dca3d1dd 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py @@ -9,6 +9,8 @@ from agent_framework import ( AGENT_FRAMEWORK_USER_AGENT, + AIFunction, + Annotation, BaseChatClient, ChatAgent, ChatMessage, @@ -16,23 +18,14 @@ ChatOptions, ChatResponse, ChatResponseUpdate, - CitationAnnotation, - Contents, - ContextProvider, - DataContent, - FunctionApprovalRequestContent, - FunctionApprovalResponseContent, - FunctionCallContent, - FunctionResultContent, - HostedFileContent, + Content, + HostedCodeInterpreterTool, + HostedFileSearchTool, HostedMCPTool, - Middleware, + HostedWebSearchTool, Role, - TextContent, TextSpanRegion, ToolProtocol, - UriContent, - UsageContent, UsageDetails, get_logger, prepare_function_call_results, @@ -422,7 +415,7 @@ async def _create_agent_stream( self, agent_id: str, run_options: dict[str, Any], - required_action_results: list[FunctionResultContent | FunctionApprovalResponseContent] | None, + required_action_results: list[Content] | None, ) -> tuple[AsyncAgentRunStream[AsyncAgentEventHandler[Any]] | AsyncAgentEventHandler[Any], str]: """Create the agent stream for processing. 
@@ -506,9 +499,9 @@ async def _prepare_thread( def _extract_url_citations( self, message_delta_chunk: MessageDeltaChunk, azure_search_tool_calls: list[dict[str, Any]] - ) -> list[CitationAnnotation]: + ) -> list[Annotation]: """Extract URL citations from MessageDeltaChunk.""" - url_citations: list[CitationAnnotation] = [] + url_citations: list[Annotation] = [] # Process each content item in the delta to find citations for content in message_delta_chunk.delta.content: @@ -520,6 +513,7 @@ def _extract_url_citations( if annotation.start_index and annotation.end_index: annotated_regions = [ TextSpanRegion( + type="text_span", start_index=annotation.start_index, end_index=annotation.end_index, ) @@ -530,9 +524,10 @@ def _extract_url_citations( annotation.url_citation.url, azure_search_tool_calls ) - # Create CitationAnnotation with real URL - citation = CitationAnnotation( - title=getattr(annotation.url_citation, "title", None), + # Create Annotation with real URL + citation = Annotation( + type="citation", + title=annotation.url_citation.title, url=real_url, snippet=None, annotated_regions=annotated_regions, @@ -542,7 +537,7 @@ def _extract_url_citations( return url_citations - def _extract_file_path_contents(self, message_delta_chunk: MessageDeltaChunk) -> list[HostedFileContent]: + def _extract_file_path_contents(self, message_delta_chunk: MessageDeltaChunk) -> list[Content]: """Extract file references from MessageDeltaChunk annotations. 
Code interpreter generates files that are referenced via file path or file citation @@ -559,7 +554,7 @@ def _extract_file_path_contents(self, message_delta_chunk: MessageDeltaChunk) -> Returns: List of HostedFileContent objects for any files referenced in annotations """ - file_contents: list[HostedFileContent] = [] + file_contents: list[Content] = [] for content in message_delta_chunk.delta.content: if isinstance(content, MessageDeltaTextContent) and content.text and content.text.annotations: @@ -570,14 +565,14 @@ def _extract_file_path_contents(self, message_delta_chunk: MessageDeltaChunk) -> if file_path is not None: file_id = getattr(file_path, "file_id", None) if file_id: - file_contents.append(HostedFileContent(file_id=file_id)) + file_contents.append(Content.from_hosted_file(file_id=file_id)) elif isinstance(annotation, MessageDeltaTextFileCitationAnnotation): # Extract file_id from the file_citation annotation file_citation = getattr(annotation, "file_citation", None) if file_citation is not None: file_id = getattr(file_citation, "file_id", None) if file_id: - file_contents.append(HostedFileContent(file_id=file_id)) + file_contents.append(Content.from_hosted_file(file_id=file_id)) return file_contents @@ -644,9 +639,9 @@ async def _process_stream( file_contents = self._extract_file_path_contents(event_data) # Create contents with citations if any exist - citation_content: list[Contents] = [] + citation_content: list[Content] = [] if event_data.text or url_citations: - text_content_obj = TextContent(text=event_data.text or "") + text_content_obj = Content.from_text(text=event_data.text or "") if url_citations: text_content_obj.annotations = url_citations citation_content.append(text_content_obj) @@ -722,7 +717,7 @@ async def _process_stream( self._capture_azure_search_tool_calls(event_data, azure_search_tool_calls) if event_data.usage: - usage_content = UsageContent( + usage_content = Content.from_usage( UsageDetails( 
input_token_count=event_data.usage.prompt_tokens, output_token_count=event_data.usage.completion_tokens, @@ -757,19 +752,21 @@ async def _process_stream( tool_call.code_interpreter, RunStepDeltaCodeInterpreterDetailItemObject, ): - code_contents: list[Contents] = [] + code_contents: list[Content] = [] if tool_call.code_interpreter.input is not None: logger.debug(f"Code Interpreter Input: {tool_call.code_interpreter.input}") if tool_call.code_interpreter.outputs is not None: for output in tool_call.code_interpreter.outputs: if isinstance(output, RunStepDeltaCodeInterpreterLogOutput) and output.logs: - code_contents.append(TextContent(text=output.logs)) + code_contents.append(Content.from_text(text=output.logs)) if ( isinstance(output, RunStepDeltaCodeInterpreterImageOutput) and output.image is not None and output.image.file_id is not None ): - code_contents.append(HostedFileContent(file_id=output.image.file_id)) + code_contents.append( + Content.from_hosted_file(file_id=output.image.file_id) + ) yield ChatResponseUpdate( role=Role.ASSISTANT, contents=code_contents, @@ -822,12 +819,12 @@ def _capture_azure_search_tool_calls( except Exception as ex: logger.debug(f"Failed to capture Azure AI Search tool call: {ex}") - def _parse_function_calls_from_azure_ai(self, event_data: ThreadRun, response_id: str | None) -> list[Contents]: + def _parse_function_calls_from_azure_ai(self, event_data: ThreadRun, response_id: str | None) -> list[Content]: """Parse function call contents from an Azure AI tool action event.""" if isinstance(event_data, ThreadRun) and event_data.required_action is not None: if isinstance(event_data.required_action, SubmitToolOutputsAction): return [ - FunctionCallContent( + Content.from_function_call( call_id=f'["{response_id}", "{tool.id}"]', name=tool.function.name, arguments=tool.function.arguments, @@ -837,9 +834,9 @@ def _parse_function_calls_from_azure_ai(self, event_data: ThreadRun, response_id ] if isinstance(event_data.required_action, 
SubmitToolApprovalAction): return [ - FunctionApprovalRequestContent( + Content.from_function_approval_request( id=f'["{response_id}", "{tool.id}"]', - function_call=FunctionCallContent( + function_call=Content.from_function_call( call_id=f'["{response_id}", "{tool.id}"]', name=tool.name, arguments=tool.arguments, @@ -875,7 +872,7 @@ async def _prepare_options( messages: MutableSequence[ChatMessage], options: Mapping[str, Any], **kwargs: Any, - ) -> tuple[dict[str, Any], list[FunctionResultContent | FunctionApprovalResponseContent] | None]: + ) -> tuple[dict[str, Any], list[Content] | None]: agent_definition = await self._load_agent_definition_if_needed() # Build run_options from options dict, excluding specific keys @@ -1052,7 +1049,7 @@ def _prepare_messages( ) -> tuple[ list[ThreadMessageOptions] | None, list[str], - list[FunctionResultContent | FunctionApprovalResponseContent] | None, + list[Content] | None, ]: """Prepare messages for Azure AI Agents API. @@ -1064,28 +1061,34 @@ def _prepare_messages( Tuple of (additional_messages, instructions, required_action_results) """ instructions: list[str] = [] - required_action_results: list[FunctionResultContent | FunctionApprovalResponseContent] | None = None + required_action_results: list[Content] | None = None additional_messages: list[ThreadMessageOptions] | None = None for chat_message in messages: if chat_message.role.value in ["system", "developer"]: - for text_content in [content for content in chat_message.contents if isinstance(content, TextContent)]: + for text_content in [content for content in chat_message.contents if content.type == "text"]: instructions.append(text_content.text) continue message_contents: list[MessageInputContentBlock] = [] for content in chat_message.contents: - if isinstance(content, TextContent): - message_contents.append(MessageInputTextBlock(text=content.text)) - elif isinstance(content, (DataContent, UriContent)) and content.has_top_level_media_type("image"): - 
message_contents.append(MessageInputImageUrlBlock(image_url=MessageImageUrlParam(url=content.uri))) - elif isinstance(content, (FunctionResultContent, FunctionApprovalResponseContent)): - if required_action_results is None: - required_action_results = [] - required_action_results.append(content) - elif isinstance(content.raw_representation, MessageInputContentBlock): - message_contents.append(content.raw_representation) + match content.type: + case "text": + message_contents.append(MessageInputTextBlock(text=content.text)) + case "data" | "uri": + if content.has_top_level_media_type("image"): + message_contents.append( + MessageInputImageUrlBlock(image_url=MessageImageUrlParam(url=content.uri)) + ) + # Only images are supported. Other media types are ignored. + case "function_result" | "function_approval_response": + if required_action_results is None: + required_action_results = [] + required_action_results.append(content) + case _: + if isinstance(content.raw_representation, MessageInputContentBlock): + message_contents.append(content.raw_representation) if message_contents: if additional_messages is None: @@ -1099,9 +1102,85 @@ def _prepare_messages( return additional_messages, instructions, required_action_results + async def _prepare_tools_for_azure_ai( + self, tools: Sequence["ToolProtocol | MutableMapping[str, Any]"], run_options: dict[str, Any] | None = None + ) -> list[ToolDefinition | dict[str, Any]]: + """Prepare tool definitions for the Azure AI Agents API.""" + tool_definitions: list[ToolDefinition | dict[str, Any]] = [] + for tool in tools: + match tool: + case AIFunction(): + tool_definitions.append(tool.to_json_schema_spec()) # type: ignore[reportUnknownArgumentType] + case HostedWebSearchTool(): + additional_props = tool.additional_properties or {} + config_args: dict[str, Any] = {} + if count := additional_props.get("count"): + config_args["count"] = count + if freshness := additional_props.get("freshness"): + config_args["freshness"] = freshness 
+ if market := additional_props.get("market"): + config_args["market"] = market + if set_lang := additional_props.get("set_lang"): + config_args["set_lang"] = set_lang + # Bing Grounding + connection_id = additional_props.get("connection_id") or os.getenv("BING_CONNECTION_ID") + # Custom Bing Search + custom_connection_id = additional_props.get("custom_connection_id") or os.getenv( + "BING_CUSTOM_CONNECTION_ID" + ) + custom_instance_name = additional_props.get("custom_instance_name") or os.getenv( + "BING_CUSTOM_INSTANCE_NAME" + ) + bing_search: BingGroundingTool | BingCustomSearchTool | None = None + if (connection_id) and not custom_connection_id and not custom_instance_name: + if connection_id: + conn_id = connection_id + else: + raise ServiceInitializationError("Parameter connection_id is not provided.") + bing_search = BingGroundingTool(connection_id=conn_id, **config_args) + if custom_connection_id and custom_instance_name: + bing_search = BingCustomSearchTool( + connection_id=custom_connection_id, + instance_name=custom_instance_name, + **config_args, + ) + if not bing_search: + raise ServiceInitializationError( + "Bing search tool requires either 'connection_id' for Bing Grounding " + "or both 'custom_connection_id' and 'custom_instance_name' for Custom Bing Search. 
" + "These can be provided via additional_properties or environment variables: " + "'BING_CONNECTION_ID', 'BING_CUSTOM_CONNECTION_ID', " + "'BING_CUSTOM_INSTANCE_NAME'" + ) + tool_definitions.extend(bing_search.definitions) + case HostedCodeInterpreterTool(): + tool_definitions.append(CodeInterpreterToolDefinition()) + case HostedMCPTool(): + mcp_tool = McpTool( + server_label=tool.name.replace(" ", "_"), + server_url=str(tool.url), + allowed_tools=list(tool.allowed_tools) if tool.allowed_tools else [], + ) + tool_definitions.extend(mcp_tool.definitions) + case HostedFileSearchTool(): + vector_stores = [inp for inp in tool.inputs or [] if inp.type == "hosted_vector_store"] + if vector_stores: + file_search = FileSearchTool(vector_store_ids=[vs.vector_store_id for vs in vector_stores]) + tool_definitions.extend(file_search.definitions) + # Set tool_resources for file search to work properly with Azure AI + if run_options is not None and "tool_resources" not in run_options: + run_options["tool_resources"] = file_search.resources + case ToolDefinition(): + tool_definitions.append(tool) + case dict(): + tool_definitions.append(tool) + case _: + raise ServiceInitializationError(f"Unsupported tool type: {type(tool)}") + return tool_definitions + def _prepare_tool_outputs_for_azure_ai( self, - required_action_results: list[FunctionResultContent | FunctionApprovalResponseContent] | None, + required_action_results: list[Content] | None, ) -> tuple[str | None, list[ToolOutput] | None, list[ToolApproval] | None]: """Prepare function results and approvals for submission to the Azure AI API.""" run_id: str | None = None @@ -1115,9 +1194,7 @@ def _prepare_tool_outputs_for_azure_ai( # We need to extract the run ID and ensure that the Output/Approval we send back to Azure # is only the call ID. 
run_and_call_ids: list[str] = ( - json.loads(content.call_id) - if isinstance(content, FunctionResultContent) - else json.loads(content.id) + json.loads(content.call_id) if content.type == "function_result" else json.loads(content.id) ) if ( @@ -1132,13 +1209,13 @@ def _prepare_tool_outputs_for_azure_ai( run_id = run_and_call_ids[0] call_id = run_and_call_ids[1] - if isinstance(content, FunctionResultContent): + if content.type == "function_result": if tool_outputs is None: tool_outputs = [] tool_outputs.append( ToolOutput(tool_call_id=call_id, output=prepare_function_call_results(content.result)) ) - elif isinstance(content, FunctionApprovalResponseContent): + elif content.type == "function_approval_response": if tool_approvals is None: tool_approvals = [] tool_approvals.append(ToolApproval(tool_call_id=call_id, approve=content.approved)) diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py index 3dfb1766b2..a78379eb61 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py @@ -12,7 +12,6 @@ ContextProvider, HostedMCPTool, Middleware, - TextContent, ToolProtocol, get_logger, use_chat_middleware, @@ -477,7 +476,7 @@ def _prepare_messages_for_azure_ai( # System/developer messages are turned into instructions, since there is no such message roles in Azure AI. 
for message in messages: if message.role.value in ["system", "developer"]: - for text_content in [content for content in message.contents if isinstance(content, TextContent)]: + for text_content in [content for content in message.contents if content.type == "text"]: instructions_list.append(text_content.text) else: result.append(message) diff --git a/python/packages/azure-ai/tests/test_azure_ai_agent_client.py b/python/packages/azure-ai/tests/test_azure_ai_agent_client.py index 21bedbf710..7b20caea7d 100644 --- a/python/packages/azure-ai/tests/test_azure_ai_agent_client.py +++ b/python/packages/azure-ai/tests/test_azure_ai_agent_client.py @@ -17,19 +17,12 @@ ChatOptions, ChatResponse, ChatResponseUpdate, - CitationAnnotation, - FunctionApprovalRequestContent, - FunctionApprovalResponseContent, - FunctionCallContent, - FunctionResultContent, + Content, HostedCodeInterpreterTool, - HostedFileContent, HostedFileSearchTool, HostedMCPTool, - HostedVectorStoreContent, + HostedWebSearchTool, Role, - TextContent, - UriContent, ) from agent_framework._serialization import SerializationMixin from agent_framework.exceptions import ServiceInitializationError @@ -368,7 +361,7 @@ async def test_azure_ai_chat_client_prepare_options_with_image_content(mock_agen # Mock get_agent mock_agents_client.get_agent = AsyncMock(return_value=None) - image_content = UriContent(uri="https://example.com/image.jpg", media_type="image/jpeg") + image_content = Content.from_uri(uri="https://example.com/image.jpg", media_type="image/jpeg") messages = [ChatMessage(role=Role.USER, contents=[image_content])] run_options, _ = await chat_client._prepare_options(messages, {}) # type: ignore @@ -551,7 +544,7 @@ def test_azure_ai_chat_client_parse_function_calls_from_azure_ai_basic(mock_agen result = chat_client._parse_function_calls_from_azure_ai(mock_event_data, "response_123") # type: ignore assert len(result) == 1 - assert isinstance(result[0], FunctionCallContent) + assert result[0].type == 
"function_call" assert result[0].name == "get_weather" assert result[0].call_id == '["response_123", "call_123"]' @@ -728,6 +721,121 @@ async def test_azure_ai_chat_client_prepare_options_mcp_with_headers(mock_agents assert mcp_resource["headers"] == headers +async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_web_search_bing_grounding( + mock_agents_client: MagicMock, +) -> None: + """Test _prepare_tools_for_azure_ai with HostedWebSearchTool using Bing Grounding.""" + + chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") + + web_search_tool = HostedWebSearchTool( + additional_properties={ + "connection_id": "test-connection-id", + "count": 5, + "freshness": "Day", + "market": "en-US", + "set_lang": "en", + } + ) + + # Mock BingGroundingTool + with patch("agent_framework_azure_ai._chat_client.BingGroundingTool") as mock_bing_grounding: + mock_bing_tool = MagicMock() + mock_bing_tool.definitions = [{"type": "bing_grounding"}] + mock_bing_grounding.return_value = mock_bing_tool + + result = await chat_client._prepare_tools_for_azure_ai([web_search_tool]) # type: ignore + + assert len(result) == 1 + assert result[0] == {"type": "bing_grounding"} + call_args = mock_bing_grounding.call_args[1] + assert call_args["count"] == 5 + assert call_args["freshness"] == "Day" + assert call_args["market"] == "en-US" + assert call_args["set_lang"] == "en" + assert "connection_id" in call_args + + +async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_web_search_bing_grounding_with_connection_id( + mock_agents_client: MagicMock, +) -> None: + """Test _prepare_tools_... 
with HostedWebSearchTool using Bing Grounding with connection_id (no HTTP call).""" + + chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") + + web_search_tool = HostedWebSearchTool( + additional_properties={ + "connection_id": "direct-connection-id", + "count": 3, + } + ) + + # Mock BingGroundingTool + with patch("agent_framework_azure_ai._chat_client.BingGroundingTool") as mock_bing_grounding: + mock_bing_tool = MagicMock() + mock_bing_tool.definitions = [{"type": "bing_grounding"}] + mock_bing_grounding.return_value = mock_bing_tool + + result = await chat_client._prepare_tools_for_azure_ai([web_search_tool]) # type: ignore + + assert len(result) == 1 + assert result[0] == {"type": "bing_grounding"} + mock_bing_grounding.assert_called_once_with(connection_id="direct-connection-id", count=3) + + +async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_web_search_custom_bing( + mock_agents_client: MagicMock, +) -> None: + """Test _prepare_tools_for_azure_ai with HostedWebSearchTool using Custom Bing Search.""" + + chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") + + web_search_tool = HostedWebSearchTool( + additional_properties={ + "custom_connection_id": "custom-connection-id", + "custom_instance_name": "custom-instance", + "count": 10, + } + ) + + # Mock BingCustomSearchTool + with patch("agent_framework_azure_ai._chat_client.BingCustomSearchTool") as mock_custom_bing: + mock_custom_tool = MagicMock() + mock_custom_tool.definitions = [{"type": "bing_custom_search"}] + mock_custom_bing.return_value = mock_custom_tool + + result = await chat_client._prepare_tools_for_azure_ai([web_search_tool]) # type: ignore + + assert len(result) == 1 + assert result[0] == {"type": "bing_custom_search"} + + +async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_file_search_with_vector_stores( + mock_agents_client: MagicMock, +) -> None: + """Test _prepare_tools_for_azure_ai with 
HostedFileSearchTool using vector stores.""" + + chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") + + vector_store_input = Content.from_hosted_vector_store(vector_store_id="vs-123") + file_search_tool = HostedFileSearchTool(inputs=[vector_store_input]) + + # Mock FileSearchTool + with patch("agent_framework_azure_ai._chat_client.FileSearchTool") as mock_file_search: + mock_file_tool = MagicMock() + mock_file_tool.definitions = [{"type": "file_search"}] + mock_file_tool.resources = {"vector_store_ids": ["vs-123"]} + mock_file_search.return_value = mock_file_tool + + run_options = {} + result = await chat_client._prepare_tools_for_azure_ai([file_search_tool], run_options) # type: ignore + + assert len(result) == 1 + assert result[0] == {"type": "file_search"} + assert run_options["tool_resources"] == {"vector_store_ids": ["vs-123"]} + mock_file_search.assert_called_once_with(vector_store_ids=["vs-123"]) + + async def test_azure_ai_chat_client_create_agent_stream_submit_tool_approvals( mock_agents_client: MagicMock, ) -> None: @@ -741,9 +849,9 @@ async def test_azure_ai_chat_client_create_agent_stream_submit_tool_approvals( chat_client._get_active_thread_run = AsyncMock(return_value=mock_thread_run) # type: ignore # Mock required action results with approval response that matches run ID - approval_response = FunctionApprovalResponseContent( + approval_response = Content.from_function_approval_response( id='["test-run-id", "test-call-id"]', - function_call=FunctionCallContent( + function_call=Content.from_function_call( call_id='["test-run-id", "test-call-id"]', name="test_function", arguments="{}" ), approved=True, @@ -839,7 +947,7 @@ async def test_azure_ai_chat_client_prepare_tool_outputs_for_azure_ai_function_r chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") # Test with simple result - function_result = FunctionResultContent(call_id='["run_123", "call_456"]', result="Simple result") 
+ function_result = Content.from_function_result(call_id='["run_123", "call_456"]', result="Simple result") run_id, tool_outputs, tool_approvals = chat_client._prepare_tool_outputs_for_azure_ai([function_result]) # type: ignore @@ -857,7 +965,7 @@ async def test_azure_ai_chat_client_convert_required_action_invalid_call_id(mock chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") # Invalid call_id format - should raise JSONDecodeError - function_result = FunctionResultContent(call_id="invalid_json", result="result") + function_result = Content.from_function_result(call_id="invalid_json", result="result") with pytest.raises(json.JSONDecodeError): chat_client._prepare_tool_outputs_for_azure_ai([function_result]) # type: ignore @@ -870,7 +978,7 @@ async def test_azure_ai_chat_client_convert_required_action_invalid_structure( chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") # Valid JSON but invalid structure (missing second element) - function_result = FunctionResultContent(call_id='["run_123"]', result="result") + function_result = Content.from_function_result(call_id='["run_123"]', result="result") run_id, tool_outputs, tool_approvals = chat_client._prepare_tool_outputs_for_azure_ai([function_result]) # type: ignore @@ -894,7 +1002,7 @@ def __init__(self, name: str, value: int): # Test with BaseModel result mock_result = MockResult(name="test", value=42) - function_result = FunctionResultContent(call_id='["run_123", "call_456"]', result=mock_result) + function_result = Content.from_function_result(call_id='["run_123", "call_456"]', result=mock_result) run_id, tool_outputs, tool_approvals = chat_client._prepare_tool_outputs_for_azure_ai([function_result]) # type: ignore @@ -922,7 +1030,7 @@ def __init__(self, data: str): # Test with multiple results - mix of BaseModel and regular objects mock_basemodel = MockResult(data="model_data") results_list = [mock_basemodel, {"key": "value"}, 
"string_result"] - function_result = FunctionResultContent(call_id='["run_123", "call_456"]', result=results_list) + function_result = Content.from_function_result(call_id='["run_123", "call_456"]', result=results_list) run_id, tool_outputs, tool_approvals = chat_client._prepare_tool_outputs_for_azure_ai([function_result]) # type: ignore @@ -948,9 +1056,11 @@ async def test_azure_ai_chat_client_convert_required_action_approval_response( chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent") # Test with approval response - need to provide required fields - approval_response = FunctionApprovalResponseContent( + approval_response = Content.from_function_approval_response( id='["run_123", "call_456"]', - function_call=FunctionCallContent(call_id='["run_123", "call_456"]', name="test_function", arguments="{}"), + function_call=Content.from_function_call( + call_id='["run_123", "call_456"]', name="test_function", arguments="{}" + ), approved=True, ) @@ -985,7 +1095,7 @@ async def test_azure_ai_chat_client_parse_function_calls_from_azure_ai_approval_ result = chat_client._parse_function_calls_from_azure_ai(mock_event_data, "response_123") # type: ignore assert len(result) == 1 - assert isinstance(result[0], FunctionApprovalRequestContent) + assert result[0].type == "function_approval_request" assert result[0].id == '["response_123", "approval_call_123"]' assert result[0].function_call.name == "approve_action" assert result[0].function_call.call_id == '["response_123", "approval_call_123"]' @@ -1064,7 +1174,7 @@ async def test_azure_ai_chat_client_create_agent_stream_submit_tool_outputs( chat_client._get_active_thread_run = AsyncMock(return_value=mock_thread_run) # type: ignore # Mock required action results with matching run ID - function_result = FunctionResultContent(call_id='["test-run-id", "test-call-id"]', result="test result") + function_result = Content.from_function_result(call_id='["test-run-id", "test-call-id"]', result="test 
result") # Mock submit_tool_outputs_stream mock_handler = MagicMock() @@ -1115,14 +1225,13 @@ def test_azure_ai_chat_client_extract_url_citations_with_citations(mock_agents_c # Verify results assert len(citations) == 1 citation = citations[0] - assert isinstance(citation, CitationAnnotation) - assert citation.url == "https://example.com/test" - assert citation.title == "Test Title" - assert citation.snippet is None - assert citation.annotated_regions is not None - assert len(citation.annotated_regions) == 1 - assert citation.annotated_regions[0].start_index == 10 - assert citation.annotated_regions[0].end_index == 20 + assert citation["url"] == "https://example.com/test" + assert citation["title"] == "Test Title" + assert citation["snippet"] is None + assert citation["annotated_regions"] is not None + assert len(citation["annotated_regions"]) == 1 + assert citation["annotated_regions"][0]["start_index"] == 10 + assert citation["annotated_regions"][0]["end_index"] == 20 def test_azure_ai_chat_client_extract_file_path_contents_with_file_path_annotation( @@ -1158,7 +1267,7 @@ def test_azure_ai_chat_client_extract_file_path_contents_with_file_path_annotati # Verify results assert len(file_contents) == 1 - assert isinstance(file_contents[0], HostedFileContent) + assert file_contents[0].type == "hosted_file" assert file_contents[0].file_id == "assistant-test-file-123" @@ -1195,7 +1304,7 @@ def test_azure_ai_chat_client_extract_file_path_contents_with_file_citation_anno # Verify results assert len(file_contents) == 1 - assert isinstance(file_contents[0], HostedFileContent) + assert file_contents[0].type == "hosted_file" assert file_contents[0].file_id == "cfile_test-citation-456" @@ -1305,7 +1414,7 @@ async def test_azure_ai_chat_client_streaming() -> None: assert chunk is not None assert isinstance(chunk, ChatResponseUpdate) for content in chunk.contents: - if isinstance(content, TextContent) and content.text: + if content.type == "text" and content.text: full_message += 
content.text assert any(word in full_message.lower() for word in ["sunny", "25"]) @@ -1331,7 +1440,7 @@ async def test_azure_ai_chat_client_streaming_tools() -> None: assert chunk is not None assert isinstance(chunk, ChatResponseUpdate) for content in chunk.contents: - if isinstance(content, TextContent) and content.text: + if content.type == "text" and content.text: full_message += content.text assert any(word in full_message.lower() for word in ["sunny", "25"]) @@ -1476,7 +1585,9 @@ async def test_azure_ai_chat_client_agent_file_search(): ) # 2. Create file search tool with uploaded resources - file_search_tool = HostedFileSearchTool(inputs=[HostedVectorStoreContent(vector_store_id=vector_store.id)]) + file_search_tool = HostedFileSearchTool( + inputs=[Content.from_hosted_vector_store(vector_store_id=vector_store.id)] + ) async with ChatAgent( chat_client=client, @@ -1795,7 +1906,7 @@ def test_azure_ai_chat_client_extract_url_citations_with_azure_search_enhanced_u # Verify real URL was used assert len(citations) == 1 citation = citations[0] - assert citation.url == "https://real-example.com/doc2" # doc_1 maps to index 1 + assert citation["url"] == "https://real-example.com/doc2" # doc_1 maps to index 1 def test_azure_ai_chat_client_init_with_auto_created_agents_client( diff --git a/python/packages/azure-ai/tests/test_azure_ai_client.py b/python/packages/azure-ai/tests/test_azure_ai_client.py index dad8f049fe..7e3f7e2db4 100644 --- a/python/packages/azure-ai/tests/test_azure_ai_client.py +++ b/python/packages/azure-ai/tests/test_azure_ai_client.py @@ -16,6 +16,7 @@ ChatMessage, ChatOptions, ChatResponse, + Content, HostedCodeInterpreterTool, HostedFileContent, HostedFileSearchTool, @@ -23,7 +24,6 @@ HostedVectorStoreContent, HostedWebSearchTool, Role, - TextContent, ) from agent_framework.exceptions import ServiceInitializationError from azure.ai.projects.aio import AIProjectClient @@ -298,9 +298,9 @@ async def 
test_prepare_messages_for_azure_ai_with_system_messages( client = create_test_azure_ai_client(mock_project_client) messages = [ - ChatMessage(role=Role.SYSTEM, contents=[TextContent(text="You are a helpful assistant.")]), - ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")]), - ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="System response")]), + ChatMessage(role=Role.SYSTEM, contents=[Content.from_text(text="You are a helpful assistant.")]), + ChatMessage(role=Role.USER, contents=[Content.from_text(text="Hello")]), + ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text(text="System response")]), ] result_messages, instructions = client._prepare_messages_for_azure_ai(messages) # type: ignore @@ -318,8 +318,8 @@ async def test_prepare_messages_for_azure_ai_no_system_messages( client = create_test_azure_ai_client(mock_project_client) messages = [ - ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")]), - ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Hi there!")]), + ChatMessage(role=Role.USER, contents=[Content.from_text(text="Hello")]), + ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text(text="Hi there!")]), ] result_messages, instructions = client._prepare_messages_for_azure_ai(messages) # type: ignore @@ -419,7 +419,7 @@ async def test_prepare_options_basic(mock_project_client: MagicMock) -> None: """Test prepare_options basic functionality.""" client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent", agent_version="1.0") - messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])] + messages = [ChatMessage(role=Role.USER, contents=[Content.from_text(text="Hello")])] with ( patch.object(client.__class__.__bases__[0], "_prepare_options", return_value={"model": "test-model"}), @@ -453,7 +453,7 @@ async def test_prepare_options_with_application_endpoint( agent_version="1", ) - messages = [ChatMessage(role=Role.USER, 
contents=[TextContent(text="Hello")])] + messages = [ChatMessage(role=Role.USER, contents=[Content.from_text(text="Hello")])] with ( patch.object(client.__class__.__bases__[0], "_prepare_options", return_value={"model": "test-model"}), @@ -492,7 +492,7 @@ async def test_prepare_options_with_application_project_client( agent_version="1", ) - messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])] + messages = [ChatMessage(role=Role.USER, contents=[Content.from_text(text="Hello")])] with ( patch.object(client.__class__.__bases__[0], "_prepare_options", return_value={"model": "test-model"}), @@ -848,7 +848,7 @@ async def test_prepare_options_excludes_response_format( """Test that prepare_options excludes response_format, text, and text_format from final run options.""" client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent", agent_version="1.0") - messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])] + messages = [ChatMessage(role=Role.USER, contents=[Content.from_text(text="Hello")])] chat_options: ChatOptions = {} with ( diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py index c84a30fbb4..67d99e5be1 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py @@ -36,18 +36,8 @@ from agent_framework import ( AgentResponse, - BaseContent, ChatMessage, - DataContent, - ErrorContent, - FunctionCallContent, - FunctionResultContent, - HostedFileContent, - HostedVectorStoreContent, - TextContent, - TextReasoningContent, - UriContent, - UsageContent, + Content, UsageDetails, get_logger, ) @@ -290,25 +280,25 @@ def from_ai_content(content: Any) -> DurableAgentStateContent: The corresponding DurableAgentStateContent subclass instance """ # 
Map AI content type to appropriate DurableAgentStateContent subclass - if isinstance(content, DataContent): + if isinstance(content, Content) and content.type == "data": return DurableAgentStateDataContent.from_data_content(content) - if isinstance(content, ErrorContent): + if isinstance(content, Content) and content.type == "error": return DurableAgentStateErrorContent.from_error_content(content) - if isinstance(content, FunctionCallContent): + if isinstance(content, Content) and content.type == "function_call": return DurableAgentStateFunctionCallContent.from_function_call_content(content) - if isinstance(content, FunctionResultContent): + if isinstance(content, Content) and content.type == "function_result": return DurableAgentStateFunctionResultContent.from_function_result_content(content) - if isinstance(content, HostedFileContent): + if isinstance(content, Content) and content.type == "hosted_file": return DurableAgentStateHostedFileContent.from_hosted_file_content(content) - if isinstance(content, HostedVectorStoreContent): + if isinstance(content, Content) and content.type == "hosted_vector_store": return DurableAgentStateHostedVectorStoreContent.from_hosted_vector_store_content(content) - if isinstance(content, TextContent): + if isinstance(content, Content) and content.type == "text": return DurableAgentStateTextContent.from_text_content(content) - if isinstance(content, TextReasoningContent): + if isinstance(content, Content) and content.type == "text_reasoning": return DurableAgentStateTextReasoningContent.from_text_reasoning_content(content) - if isinstance(content, UriContent): + if isinstance(content, Content) and content.type == "uri": return DurableAgentStateUriContent.from_uri_content(content) - if isinstance(content, UsageContent): + if isinstance(content, Content) and content.type == "usage": return DurableAgentStateUsageContent.from_usage_content(content) return DurableAgentStateUnknownContent.from_unknown_content(content) @@ -868,11 +858,11 @@ 
def to_dict(self) -> dict[str, Any]: } @staticmethod - def from_data_content(content: DataContent) -> DurableAgentStateDataContent: + def from_data_content(content: Content) -> DurableAgentStateDataContent: return DurableAgentStateDataContent(uri=content.uri, media_type=content.media_type) - def to_ai_content(self) -> DataContent: - return DataContent(uri=self.uri, media_type=self.media_type) + def to_ai_content(self) -> Content: + return Content.from_uri(uri=self.uri, media_type=self.media_type) class DurableAgentStateErrorContent(DurableAgentStateContent): @@ -907,13 +897,13 @@ def to_dict(self) -> dict[str, Any]: } @staticmethod - def from_error_content(content: ErrorContent) -> DurableAgentStateErrorContent: + def from_error_content(content: Content) -> DurableAgentStateErrorContent: return DurableAgentStateErrorContent( - message=content.message, error_code=content.error_code, details=content.details + message=content.message, error_code=content.error_code, details=content.error_details ) - def to_ai_content(self) -> ErrorContent: - return ErrorContent(message=self.message, error_code=self.error_code, details=self.details) + def to_ai_content(self) -> Content: + return Content.from_error(message=self.message, error_code=self.error_code, error_details=self.details) class DurableAgentStateFunctionCallContent(DurableAgentStateContent): @@ -949,7 +939,7 @@ def to_dict(self) -> dict[str, Any]: } @staticmethod - def from_function_call_content(content: FunctionCallContent) -> DurableAgentStateFunctionCallContent: + def from_function_call_content(content: Content) -> DurableAgentStateFunctionCallContent: # Ensure arguments is a dict; parse string if needed arguments: dict[str, Any] = {} if content.arguments: @@ -964,8 +954,8 @@ def from_function_call_content(content: FunctionCallContent) -> DurableAgentStat return DurableAgentStateFunctionCallContent(call_id=content.call_id, name=content.name, arguments=arguments) - def to_ai_content(self) -> FunctionCallContent: - 
return FunctionCallContent(call_id=self.call_id, name=self.name, arguments=self.arguments) + def to_ai_content(self) -> Content: + return Content.from_function_call(call_id=self.call_id, name=self.name, arguments=self.arguments) class DurableAgentStateFunctionResultContent(DurableAgentStateContent): @@ -997,11 +987,11 @@ def to_dict(self) -> dict[str, Any]: } @staticmethod - def from_function_result_content(content: FunctionResultContent) -> DurableAgentStateFunctionResultContent: + def from_function_result_content(content: Content) -> DurableAgentStateFunctionResultContent: return DurableAgentStateFunctionResultContent(call_id=content.call_id, result=content.result) - def to_ai_content(self) -> FunctionResultContent: - return FunctionResultContent(call_id=self.call_id, result=self.result) + def to_ai_content(self) -> Content: + return Content.from_function_result(call_id=self.call_id, result=self.result) class DurableAgentStateHostedFileContent(DurableAgentStateContent): @@ -1025,11 +1015,11 @@ def to_dict(self) -> dict[str, Any]: return {DurableStateFields.TYPE_DISCRIMINATOR: self.type, DurableStateFields.FILE_ID: self.file_id} @staticmethod - def from_hosted_file_content(content: HostedFileContent) -> DurableAgentStateHostedFileContent: + def from_hosted_file_content(content: Content) -> DurableAgentStateHostedFileContent: return DurableAgentStateHostedFileContent(file_id=content.file_id) - def to_ai_content(self) -> HostedFileContent: - return HostedFileContent(file_id=self.file_id) + def to_ai_content(self) -> Content: + return Content.from_hosted_file(file_id=self.file_id) class DurableAgentStateHostedVectorStoreContent(DurableAgentStateContent): @@ -1058,12 +1048,12 @@ def to_dict(self) -> dict[str, Any]: @staticmethod def from_hosted_vector_store_content( - content: HostedVectorStoreContent, + content: Content, ) -> DurableAgentStateHostedVectorStoreContent: return DurableAgentStateHostedVectorStoreContent(vector_store_id=content.vector_store_id) - def 
to_ai_content(self) -> HostedVectorStoreContent: - return HostedVectorStoreContent(vector_store_id=self.vector_store_id) + def to_ai_content(self) -> Content: + return Content.from_hosted_vector_store(vector_store_id=self.vector_store_id) class DurableAgentStateTextContent(DurableAgentStateContent): @@ -1085,11 +1075,11 @@ def to_dict(self) -> dict[str, Any]: return {DurableStateFields.TYPE_DISCRIMINATOR: self.type, DurableStateFields.TEXT: self.text} @staticmethod - def from_text_content(content: TextContent) -> DurableAgentStateTextContent: + def from_text_content(content: Content) -> DurableAgentStateTextContent: return DurableAgentStateTextContent(text=content.text) - def to_ai_content(self) -> TextContent: - return TextContent(text=self.text or "") + def to_ai_content(self) -> Content: + return Content.from_text(text=self.text or "") class DurableAgentStateTextReasoningContent(DurableAgentStateContent): @@ -1111,11 +1101,11 @@ def to_dict(self) -> dict[str, Any]: return {DurableStateFields.TYPE_DISCRIMINATOR: self.type, DurableStateFields.TEXT: self.text} @staticmethod - def from_text_reasoning_content(content: TextReasoningContent) -> DurableAgentStateTextReasoningContent: + def from_text_reasoning_content(content: Content) -> DurableAgentStateTextReasoningContent: return DurableAgentStateTextReasoningContent(text=content.text) - def to_ai_content(self) -> TextReasoningContent: - return TextReasoningContent(text=self.text or "") + def to_ai_content(self) -> Content: + return Content.from_text_reasoning(text=self.text) class DurableAgentStateUriContent(DurableAgentStateContent): @@ -1146,11 +1136,11 @@ def to_dict(self) -> dict[str, Any]: } @staticmethod - def from_uri_content(content: UriContent) -> DurableAgentStateUriContent: + def from_uri_content(content: Content) -> DurableAgentStateUriContent: return DurableAgentStateUriContent(uri=content.uri, media_type=content.media_type) - def to_ai_content(self) -> UriContent: - return UriContent(uri=self.uri, 
media_type=self.media_type) + def to_ai_content(self) -> Content: + return Content.from_uri(uri=self.uri, media_type=self.media_type) class DurableAgentStateUsage: @@ -1215,11 +1205,11 @@ def from_usage(usage: UsageDetails | None) -> DurableAgentStateUsage | None: def to_usage_details(self) -> UsageDetails: # Convert back to AI SDK UsageDetails - return UsageDetails( - input_token_count=self.input_token_count, - output_token_count=self.output_token_count, - total_token_count=self.total_token_count, - ) + return { + "input_token_count": self.input_token_count, + "output_token_count": self.output_token_count, + "total_token_count": self.total_token_count, + } class DurableAgentStateUsageContent(DurableAgentStateContent): @@ -1247,11 +1237,11 @@ def to_dict(self) -> dict[str, Any]: } @staticmethod - def from_usage_content(content: UsageContent) -> DurableAgentStateUsageContent: - return DurableAgentStateUsageContent(usage=DurableAgentStateUsage.from_usage(content.details)) + def from_usage_content(content: Content) -> DurableAgentStateUsageContent: + return DurableAgentStateUsageContent(usage=DurableAgentStateUsage.from_usage(content.usage_details)) - def to_ai_content(self) -> UsageContent: - return UsageContent(details=self.usage.to_usage_details()) + def to_ai_content(self) -> Content: + return Content.from_usage(usage_details=self.usage.to_usage_details()) class DurableAgentStateUnknownContent(DurableAgentStateContent): @@ -1279,7 +1269,7 @@ def to_dict(self) -> dict[str, Any]: def from_unknown_content(content: Any) -> DurableAgentStateUnknownContent: return DurableAgentStateUnknownContent(content=content) - def to_ai_content(self) -> BaseContent: + def to_ai_content(self) -> Content: if not self.content: raise Exception("The content is missing and cannot be converted to valid AI content.") - return BaseContent(content=self.content) + return Content(type=self.type, additional_properties={"content": self.content}) diff --git 
a/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py index f757004cbb..ba6040b99b 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py @@ -18,7 +18,7 @@ AgentResponse, AgentResponseUpdate, ChatMessage, - ErrorContent, + Content, Role, get_logger, ) @@ -193,7 +193,7 @@ async def run( # Create error message error_message = ChatMessage( - role=Role.ASSISTANT, contents=[ErrorContent(message=str(exc), error_code=type(exc).__name__)] + role=Role.ASSISTANT, contents=[Content.from_error(message=str(exc), error_code=type(exc).__name__)] ) error_response = AgentResponse(messages=[error_message]) diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py index 37f70a04e1..29d614e729 100644 --- a/python/packages/azurefunctions/tests/test_app.py +++ b/python/packages/azurefunctions/tests/test_app.py @@ -10,7 +10,7 @@ import azure.durable_functions as df import azure.functions as func import pytest -from agent_framework import AgentResponse, ChatMessage, ErrorContent +from agent_framework import AgentResponse, ChatMessage from agent_framework_azurefunctions import AgentFunctionApp from agent_framework_azurefunctions._app import WAIT_FOR_RESPONSE_FIELD, WAIT_FOR_RESPONSE_HEADER @@ -622,7 +622,7 @@ async def test_entity_handles_agent_error(self) -> None: assert isinstance(result, AgentResponse) assert len(result.messages) == 1 content = result.messages[0].contents[0] - assert isinstance(content, ErrorContent) + assert content.type == "error" assert "Agent error" in (content.message or "") assert content.error_code == "Exception" diff --git a/python/packages/azurefunctions/tests/test_entities.py b/python/packages/azurefunctions/tests/test_entities.py index 63bc685afb..5d980f8610 100644 --- 
a/python/packages/azurefunctions/tests/test_entities.py +++ b/python/packages/azurefunctions/tests/test_entities.py @@ -12,7 +12,7 @@ from unittest.mock import AsyncMock, Mock, patch import pytest -from agent_framework import AgentResponse, AgentResponseUpdate, ChatMessage, ErrorContent, Role +from agent_framework import AgentResponse, AgentResponseUpdate, ChatMessage, Role from pydantic import BaseModel from agent_framework_azurefunctions._durable_agent_state import ( @@ -608,7 +608,7 @@ async def test_run_agent_handles_agent_exception(self) -> None: assert isinstance(result, AgentResponse) assert len(result.messages) == 1 content = result.messages[0].contents[0] - assert isinstance(content, ErrorContent) + assert content.type == "error" assert "Agent failed" in (content.message or "") assert content.error_code == "Exception" @@ -627,7 +627,7 @@ async def test_run_agent_handles_value_error(self) -> None: assert isinstance(result, AgentResponse) assert len(result.messages) == 1 content = result.messages[0].contents[0] - assert isinstance(content, ErrorContent) + assert content.type == "error" assert content.error_code == "ValueError" assert "Invalid input" in str(content.message) @@ -646,7 +646,7 @@ async def test_run_agent_handles_timeout_error(self) -> None: assert isinstance(result, AgentResponse) assert len(result.messages) == 1 content = result.messages[0].contents[0] - assert isinstance(content, ErrorContent) + assert content.type == "error" assert content.error_code == "TimeoutError" def test_entity_function_handles_exception_in_operation(self) -> None: @@ -685,7 +685,7 @@ async def test_run_agent_preserves_message_on_error(self) -> None: assert isinstance(result, AgentResponse) assert len(result.messages) == 1 content = result.messages[0].contents[0] - assert isinstance(content, ErrorContent) + assert content.type == "error" class TestConversationHistory: diff --git a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py 
b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py index e9e1eeff96..60a4b829f2 100644 --- a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py +++ b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py @@ -16,14 +16,10 @@ ChatOptions, ChatResponse, ChatResponseUpdate, - Contents, + Content, FinishReason, - FunctionCallContent, - FunctionResultContent, Role, - TextContent, ToolProtocol, - UsageContent, UsageDetails, get_logger, prepare_function_call_results, @@ -328,7 +324,7 @@ async def _inner_get_streaming_response( response = await self._inner_get_response(messages=messages, options=options, **kwargs) contents = list(response.messages[0].contents if response.messages else []) if response.usage_details: - contents.append(UsageContent(details=response.usage_details)) + contents.append(Content.from_usage(details=response.usage_details)) yield ChatResponseUpdate( response_id=response.response_id, contents=contents, @@ -472,37 +468,41 @@ def _convert_message_to_content_blocks(self, message: ChatMessage) -> list[dict[ blocks.append(block) return blocks - def _convert_content_to_bedrock_block(self, content: Contents) -> dict[str, Any] | None: - if isinstance(content, TextContent): - return {"text": content.text} - if isinstance(content, FunctionCallContent): - arguments = content.parse_arguments() or {} - return { - "toolUse": { - "toolUseId": content.call_id or self._generate_tool_call_id(), - "name": content.name, - "input": arguments, + def _convert_content_to_bedrock_block(self, content: Content) -> dict[str, Any] | None: + match content.type: + case "text": + return {"text": content.text} + case "function_call": + arguments = content.parse_arguments() or {} + return { + "toolUse": { + "toolUseId": content.call_id or self._generate_tool_call_id(), + "name": content.name, + "input": arguments, + } } - } - if isinstance(content, FunctionResultContent): - tool_result_block = { - "toolResult": { - "toolUseId": content.call_id, - 
"content": self._convert_tool_result_to_blocks(content.result), - "status": "error" if content.exception else "success", + case "function_result": + tool_result_block = { + "toolResult": { + "toolUseId": content.call_id, + "content": self._convert_tool_result_to_blocks(content.result), + "status": "error" if content.exception else "success", + } } - } - if content.exception: - tool_result = tool_result_block["toolResult"] - existing_content = tool_result.get("content") - content_list: list[dict[str, Any]] - if isinstance(existing_content, list): - content_list = existing_content - else: - content_list = [] - tool_result["content"] = content_list - content_list.append({"text": str(content.exception)}) - return tool_result_block + if content.exception: + tool_result = tool_result_block["toolResult"] + existing_content = tool_result.get("content") + content_list: list[dict[str, Any]] + if isinstance(existing_content, list): + content_list = existing_content + else: + content_list = [] + tool_result["content"] = content_list + content_list.append({"text": str(content.exception)}) + return tool_result_block + case _: + # Bedrock does not support other content types at this time + pass return None def _convert_tool_result_to_blocks(self, result: Any) -> list[dict[str, Any]]: @@ -531,7 +531,7 @@ def _normalize_tool_result_value(self, value: Any) -> dict[str, Any]: return {"text": value} if isinstance(value, (int, float, bool)) or value is None: return {"json": value} - if isinstance(value, TextContent) and getattr(value, "text", None): + if isinstance(value, Content) and value.type == "text": return {"text": value.text} if hasattr(value, "to_dict"): try: @@ -586,23 +586,23 @@ def _process_converse_response(self, response: dict[str, Any]) -> ChatResponse: def _parse_usage(self, usage: dict[str, Any] | None) -> UsageDetails | None: if not usage: return None - details = UsageDetails() + details: UsageDetails = {} if (input_tokens := usage.get("inputTokens")) is not None: - 
details.input_token_count = input_tokens + details["input_token_count"] = input_tokens if (output_tokens := usage.get("outputTokens")) is not None: - details.output_token_count = output_tokens + details["output_token_count"] = output_tokens if (total_tokens := usage.get("totalTokens")) is not None: - details.additional_counts["bedrock.total_tokens"] = total_tokens + details["total_token_count"] = total_tokens return details def _parse_message_contents(self, content_blocks: Sequence[MutableMapping[str, Any]]) -> list[Any]: contents: list[Any] = [] for block in content_blocks: if text_value := block.get("text"): - contents.append(TextContent(text=text_value, raw_representation=block)) + contents.append(Content.from_text(text=text_value, raw_representation=block)) continue if (json_value := block.get("json")) is not None: - contents.append(TextContent(text=json.dumps(json_value), raw_representation=block)) + contents.append(Content.from_text(text=json.dumps(json_value), raw_representation=block)) continue tool_use = block.get("toolUse") if isinstance(tool_use, MutableMapping): @@ -610,7 +610,7 @@ def _parse_message_contents(self, content_blocks: Sequence[MutableMapping[str, A if not tool_name: raise ServiceInvalidResponseError("Bedrock response missing required tool name in toolUse block.") contents.append( - FunctionCallContent( + Content.from_function_call( call_id=tool_use.get("toolUseId") or self._generate_tool_call_id(), name=tool_name, arguments=tool_use.get("input"), @@ -626,7 +626,7 @@ def _parse_message_contents(self, content_blocks: Sequence[MutableMapping[str, A exception = RuntimeError(f"Bedrock tool result status: {status}") result_value = self._convert_bedrock_tool_result_to_value(tool_result.get("content")) contents.append( - FunctionResultContent( + Content.from_function_result( call_id=tool_result.get("toolUseId") or self._generate_tool_call_id(), result=result_value, exception=exception, diff --git 
a/python/packages/bedrock/tests/test_bedrock_client.py b/python/packages/bedrock/tests/test_bedrock_client.py index 5842426483..704eb2138a 100644 --- a/python/packages/bedrock/tests/test_bedrock_client.py +++ b/python/packages/bedrock/tests/test_bedrock_client.py @@ -6,7 +6,7 @@ from typing import Any import pytest -from agent_framework import ChatMessage, Role, TextContent +from agent_framework import ChatMessage, Content, Role from agent_framework.exceptions import ServiceInitializationError from agent_framework_bedrock import BedrockChatClient @@ -42,8 +42,8 @@ def test_get_response_invokes_bedrock_runtime() -> None: ) messages = [ - ChatMessage(role=Role.SYSTEM, contents=[TextContent(text="You are concise.")]), - ChatMessage(role=Role.USER, contents=[TextContent(text="hello")]), + ChatMessage(role=Role.SYSTEM, contents=[Content.from_text(text="You are concise.")]), + ChatMessage(role=Role.USER, contents=[Content.from_text(text="hello")]), ] response = asyncio.run(client.get_response(messages=messages, options={"max_tokens": 32})) @@ -53,7 +53,7 @@ def test_get_response_invokes_bedrock_runtime() -> None: assert payload["modelId"] == "amazon.titan-text" assert payload["messages"][0]["content"][0]["text"] == "hello" assert response.messages[0].contents[0].text == "Bedrock says hi" - assert response.usage_details and response.usage_details.input_token_count == 10 + assert response.usage_details and response.usage_details["input_token_count"] == 10 def test_build_request_requires_non_system_messages() -> None: @@ -63,7 +63,7 @@ def test_build_request_requires_non_system_messages() -> None: client=_StubBedrockRuntime(), ) - messages = [ChatMessage(role=Role.SYSTEM, contents=[TextContent(text="Only system text")])] + messages = [ChatMessage(role=Role.SYSTEM, contents=[Content.from_text(text="Only system text")])] with pytest.raises(ServiceInitializationError): client._prepare_options(messages, {}) diff --git a/python/packages/bedrock/tests/test_bedrock_settings.py 
b/python/packages/bedrock/tests/test_bedrock_settings.py index 1924c750c6..07898303de 100644 --- a/python/packages/bedrock/tests/test_bedrock_settings.py +++ b/python/packages/bedrock/tests/test_bedrock_settings.py @@ -9,10 +9,8 @@ AIFunction, ChatMessage, ChatOptions, - FunctionCallContent, - FunctionResultContent, + Content, Role, - TextContent, ) from pydantic import BaseModel @@ -49,7 +47,7 @@ def test_build_request_includes_tool_config() -> None: "tools": [tool], "tool_choice": {"mode": "required", "required_function_name": "get_weather"}, } - messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="hi")])] + messages = [ChatMessage(role=Role.USER, contents=[Content.from_text(text="hi")])] request = client._prepare_options(messages, options) @@ -61,14 +59,16 @@ def test_build_request_serializes_tool_history() -> None: client = _build_client() options: ChatOptions = {} messages = [ - ChatMessage(role=Role.USER, contents=[TextContent(text="how's weather?")]), + ChatMessage(role=Role.USER, contents=[Content.from_text(text="how's weather?")]), ChatMessage( role=Role.ASSISTANT, - contents=[FunctionCallContent(call_id="call-1", name="get_weather", arguments='{"location": "SEA"}')], + contents=[ + Content.from_function_call(call_id="call-1", name="get_weather", arguments='{"location": "SEA"}') + ], ), ChatMessage( role=Role.TOOL, - contents=[FunctionResultContent(call_id="call-1", result={"answer": "72F"})], + contents=[Content.from_function_result(call_id="call-1", result={"answer": "72F"})], ), ] @@ -101,9 +101,9 @@ def test_process_response_parses_tool_use_and_result() -> None: chat_response = client._process_converse_response(response) contents = chat_response.messages[0].contents - assert isinstance(contents[0], FunctionCallContent) + assert contents[0].type == "function_call" assert contents[0].name == "get_weather" - assert isinstance(contents[1], TextContent) + assert contents[1].type == "text" assert chat_response.finish_reason == 
client._map_finish_reason("tool_use") @@ -131,5 +131,5 @@ def test_process_response_parses_tool_result() -> None: chat_response = client._process_converse_response(response) contents = chat_response.messages[0].contents - assert isinstance(contents[0], FunctionResultContent) + assert contents[0].type == "function_result" assert contents[0].result == {"answer": 42} diff --git a/python/packages/chatkit/agent_framework_chatkit/_converter.py b/python/packages/chatkit/agent_framework_chatkit/_converter.py index 252ac8a753..b83fd40812 100644 --- a/python/packages/chatkit/agent_framework_chatkit/_converter.py +++ b/python/packages/chatkit/agent_framework_chatkit/_converter.py @@ -8,12 +8,8 @@ from agent_framework import ( ChatMessage, - DataContent, - FunctionCallContent, - FunctionResultContent, + Content, Role, - TextContent, - UriContent, ) from chatkit.types import ( AssistantMessageItem, @@ -91,8 +87,8 @@ async def user_message_to_input( if isinstance(content_part, UserMessageTextContent): text_content += content_part.text - # Convert attachments to DataContent or UriContent - data_contents: list[DataContent | UriContent] = [] + # Convert attachments to Content + data_contents: list[Content] = [] if item.attachments: for attachment in item.attachments: content = await self.attachment_to_message_content(attachment) @@ -108,9 +104,9 @@ async def user_message_to_input( user_message = ChatMessage(role=Role.USER, text=text_content.strip()) else: # Build contents list with both text and attachments - contents: list[TextContent | DataContent | UriContent] = [] + contents: list[Content] = [] if text_content.strip(): - contents.append(TextContent(text=text_content.strip())) + contents.append(Content.from_text(text=text_content.strip())) contents.extend(data_contents) user_message = ChatMessage(role=Role.USER, contents=contents) @@ -126,7 +122,7 @@ async def user_message_to_input( return messages - async def attachment_to_message_content(self, attachment: Attachment) -> 
DataContent | UriContent | None: + async def attachment_to_message_content(self, attachment: Attachment) -> Content | None: """Convert a ChatKit attachment to Agent Framework content. This method is called internally by `user_message_to_input()` to handle attachments. @@ -169,14 +165,14 @@ async def fetch_data(attachment_id: str) -> bytes: if self.attachment_data_fetcher is not None: try: data = await self.attachment_data_fetcher(attachment.id) - return DataContent(data=data, media_type=attachment.mime_type) + return Content.from_data(data=data, media_type=attachment.mime_type) except Exception as e: # If fetch fails, fall through to URL-based approach logger.debug(f"Failed to fetch attachment data for {attachment.id}: {e}") # For ImageAttachment, try to use preview_url if isinstance(attachment, ImageAttachment) and attachment.preview_url: - return UriContent(uri=str(attachment.preview_url), media_type=attachment.mime_type) + return Content.from_uri(uri=str(attachment.preview_url), media_type=attachment.mime_type) # For FileAttachment without data fetcher, skip the attachment # Subclasses can override this method to provide custom handling @@ -220,7 +216,7 @@ def hidden_context_to_input( """ return ChatMessage(role=Role.SYSTEM, text=f"{item.content}") - def tag_to_message_content(self, tag: UserMessageTagContent) -> TextContent: + def tag_to_message_content(self, tag: UserMessageTagContent) -> Content: """Convert a ChatKit tag (@-mention) to Agent Framework content. This method is called internally by `user_message_to_input()` to handle tags. 
@@ -248,10 +244,10 @@ def tag_to_message_content(self, tag: UserMessageTagContent) -> TextContent: type="input_tag", id="tag_1", text="john", data={"name": "John Doe"}, interactive=False ) content = converter.tag_to_message_content(tag) - # Returns: TextContent(text="Name:John Doe") + # Returns: Content.from_text(text="Name:John Doe") """ name = getattr(tag.data, "name", tag.text if hasattr(tag, "text") else "unknown") - return TextContent(text=f"Name:{name}") + return Content.from_text(text=f"Name:{name}") def task_to_input(self, item: TaskItem) -> ChatMessage | list[ChatMessage] | None: """Convert a ChatKit TaskItem to Agent Framework ChatMessage(s). @@ -448,7 +444,7 @@ async def client_tool_call_to_input(self, item: ClientToolCallItem) -> ChatMessa function_call_msg = ChatMessage( role=Role.ASSISTANT, contents=[ - FunctionCallContent( + Content.from_function_call( call_id=item.call_id, name=item.name, arguments=json.dumps(item.arguments), @@ -460,7 +456,7 @@ async def client_tool_call_to_input(self, item: ClientToolCallItem) -> ChatMessa function_result_msg = ChatMessage( role=Role.TOOL, contents=[ - FunctionResultContent( + Content.from_function_result( call_id=item.call_id, result=json.dumps(item.output) if item.output is not None else "", ) diff --git a/python/packages/chatkit/agent_framework_chatkit/_streaming.py b/python/packages/chatkit/agent_framework_chatkit/_streaming.py index b0273c5944..df44fa005d 100644 --- a/python/packages/chatkit/agent_framework_chatkit/_streaming.py +++ b/python/packages/chatkit/agent_framework_chatkit/_streaming.py @@ -6,7 +6,7 @@ from collections.abc import AsyncIterable, AsyncIterator, Callable from datetime import datetime -from agent_framework import AgentResponseUpdate, TextContent +from agent_framework import AgentResponseUpdate from chatkit.types import ( AssistantMessageContent, AssistantMessageContentPartTextDelta, @@ -77,7 +77,7 @@ def _default_id_generator(item_type: str) -> str: if update.contents: for content in 
update.contents: # Handle text content - only TextContent has a text attribute - if isinstance(content, TextContent) and content.text is not None: + if content.type == "text" and content.text is not None: # Yield incremental text delta for streaming display yield ThreadItemUpdated( type="thread.item.updated", diff --git a/python/packages/chatkit/tests/test_converter.py b/python/packages/chatkit/tests/test_converter.py index 457017f647..b75139bf58 100644 --- a/python/packages/chatkit/tests/test_converter.py +++ b/python/packages/chatkit/tests/test_converter.py @@ -5,7 +5,7 @@ from unittest.mock import Mock import pytest -from agent_framework import ChatMessage, Role, TextContent +from agent_framework import ChatMessage, Role from chatkit.types import UserMessageTextContent from agent_framework_chatkit import ThreadItemConverter, simple_to_agent_input @@ -133,7 +133,7 @@ def test_tag_to_message_content(self, converter): ) result = converter.tag_to_message_content(tag) - assert isinstance(result, TextContent) + assert result.type == "text" # Since data is a dict, getattr won't work, so it will fall back to text assert result.text == "Name:john" @@ -150,7 +150,7 @@ def test_tag_to_message_content_no_name(self, converter): ) result = converter.tag_to_message_content(tag) - assert isinstance(result, TextContent) + assert result.type == "text" assert result.text == "Name:jane" async def test_attachment_to_message_content_file_without_fetcher(self, converter): @@ -169,7 +169,6 @@ async def test_attachment_to_message_content_file_without_fetcher(self, converte async def test_attachment_to_message_content_image_with_preview_url(self, converter): """Test that ImageAttachment with preview_url creates UriContent.""" - from agent_framework import UriContent from chatkit.types import ImageAttachment attachment = ImageAttachment( @@ -181,13 +180,12 @@ async def test_attachment_to_message_content_image_with_preview_url(self, conver ) result = await 
converter.attachment_to_message_content(attachment) - assert isinstance(result, UriContent) + assert result.type == "uri" assert result.uri == "https://example.com/photo.jpg" assert result.media_type == "image/jpeg" async def test_attachment_to_message_content_with_data_fetcher(self): """Test attachment conversion with data fetcher.""" - from agent_framework import DataContent from chatkit.types import FileAttachment # Mock data fetcher @@ -204,14 +202,13 @@ async def fetch_data(attachment_id: str) -> bytes: ) result = await converter.attachment_to_message_content(attachment) - assert isinstance(result, DataContent) + assert result.type == "data" assert result.media_type == "application/pdf" async def test_to_agent_input_with_image_attachment(self): """Test converting user message with text and image attachment.""" from datetime import datetime - from agent_framework import UriContent from chatkit.types import ImageAttachment, UserMessageItem attachment = ImageAttachment( @@ -241,11 +238,11 @@ async def test_to_agent_input_with_image_attachment(self): assert len(message.contents) == 2 # First content should be text - assert isinstance(message.contents[0], TextContent) + assert message.contents[0].type == "text" assert message.contents[0].text == "Check out this photo!" 
# Second content should be UriContent for the image - assert isinstance(message.contents[1], UriContent) + assert message.contents[1].type == "uri" assert message.contents[1].uri == "https://example.com/photo.jpg" assert message.contents[1].media_type == "image/jpeg" @@ -253,7 +250,6 @@ async def test_to_agent_input_with_file_attachment_and_fetcher(self): """Test converting user message with file attachment using data fetcher.""" from datetime import datetime - from agent_framework import DataContent from chatkit.types import FileAttachment, UserMessageItem attachment = FileAttachment( @@ -285,10 +281,10 @@ async def fetch_data(attachment_id: str) -> bytes: assert len(message.contents) == 2 # First content should be text - assert isinstance(message.contents[0], TextContent) + assert message.contents[0].type == "text" # Second content should be DataContent for the file - assert isinstance(message.contents[1], DataContent) + assert message.contents[1].type == "data" assert message.contents[1].media_type == "application/pdf" def test_task_to_input(self, converter): diff --git a/python/packages/chatkit/tests/test_streaming.py b/python/packages/chatkit/tests/test_streaming.py index ead7c5f33e..ff552d79e8 100644 --- a/python/packages/chatkit/tests/test_streaming.py +++ b/python/packages/chatkit/tests/test_streaming.py @@ -4,7 +4,7 @@ from unittest.mock import Mock -from agent_framework import AgentResponseUpdate, Role, TextContent +from agent_framework import AgentResponseUpdate, Content, Role from chatkit.types import ( ThreadItemAddedEvent, ThreadItemDoneEvent, @@ -34,7 +34,7 @@ async def test_stream_single_text_update(self): """Test streaming single text update.""" async def single_update_stream(): - yield AgentResponseUpdate(role=Role.ASSISTANT, contents=[TextContent(text="Hello world")]) + yield AgentResponseUpdate(role=Role.ASSISTANT, contents=[Content.from_text(text="Hello world")]) events = [] async for event in stream_agent_response(single_update_stream(), 
thread_id="test_thread"): @@ -59,8 +59,8 @@ async def test_stream_multiple_text_updates(self): """Test streaming multiple text updates.""" async def multiple_updates_stream(): - yield AgentResponseUpdate(role=Role.ASSISTANT, contents=[TextContent(text="Hello ")]) - yield AgentResponseUpdate(role=Role.ASSISTANT, contents=[TextContent(text="world!")]) + yield AgentResponseUpdate(role=Role.ASSISTANT, contents=[Content.from_text(text="Hello ")]) + yield AgentResponseUpdate(role=Role.ASSISTANT, contents=[Content.from_text(text="world!")]) events = [] async for event in stream_agent_response(multiple_updates_stream(), thread_id="test_thread"): @@ -91,7 +91,7 @@ def custom_id_generator(item_type: str) -> str: return f"custom_{item_type}_123" async def single_update_stream(): - yield AgentResponseUpdate(role=Role.ASSISTANT, contents=[TextContent(text="Test")]) + yield AgentResponseUpdate(role=Role.ASSISTANT, contents=[Content.from_text(text="Test")]) events = [] async for event in stream_agent_response( @@ -125,9 +125,10 @@ async def empty_content_stream(): async def test_stream_non_text_content(self): """Test streaming updates with non-text content.""" # Mock a content object without text attribute - non_text_content = Mock() + non_text_content = Mock(spec=Content) + non_text_content.type = "image" # Don't set text attribute - del non_text_content.text + non_text_content.text = None async def non_text_stream(): yield AgentResponseUpdate(role=Role.ASSISTANT, contents=[non_text_content]) diff --git a/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py b/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py index 7dc676b06a..98d5a2b475 100644 --- a/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py +++ b/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py @@ -10,9 +10,9 @@ AgentThread, BaseAgent, ChatMessage, + Content, ContextProvider, Role, - TextContent, normalize_messages, ) from 
agent_framework._pydantic import AFBaseSettings @@ -332,7 +332,7 @@ async def _process_activities(self, activities: AsyncIterable[Any], streaming: b ): yield ChatMessage( role=Role.ASSISTANT, - contents=[TextContent(activity.text)], + contents=[Content.from_text(activity.text)], author_name=activity.from_property.name if activity.from_property else None, message_id=activity.id, raw_representation=activity, diff --git a/python/packages/copilotstudio/tests/test_copilot_agent.py b/python/packages/copilotstudio/tests/test_copilot_agent.py index 4777557d32..c4e2ff3e08 100644 --- a/python/packages/copilotstudio/tests/test_copilot_agent.py +++ b/python/packages/copilotstudio/tests/test_copilot_agent.py @@ -4,14 +4,7 @@ from unittest.mock import MagicMock, patch import pytest -from agent_framework import ( - AgentResponse, - AgentResponseUpdate, - AgentThread, - ChatMessage, - Role, - TextContent, -) +from agent_framework import AgentResponse, AgentResponseUpdate, AgentThread, ChatMessage, Content, Role from agent_framework.exceptions import ServiceException, ServiceInitializationError from microsoft_agents.copilotstudio.client import CopilotClient @@ -136,7 +129,7 @@ async def test_run_with_string_message(self, mock_copilot_client: MagicMock, moc assert isinstance(response, AgentResponse) assert len(response.messages) == 1 content = response.messages[0].contents[0] - assert isinstance(content, TextContent) + assert content.type == "text" assert content.text == "Test response" assert response.messages[0].role == Role.ASSISTANT @@ -150,13 +143,13 @@ async def test_run_with_chat_message(self, mock_copilot_client: MagicMock, mock_ mock_copilot_client.start_conversation.return_value = create_async_generator([conversation_activity]) mock_copilot_client.ask_question.return_value = create_async_generator([mock_activity]) - chat_message = ChatMessage(role=Role.USER, contents=[TextContent("test message")]) + chat_message = ChatMessage(role=Role.USER, 
contents=[Content.from_text("test message")]) response = await agent.run(chat_message) assert isinstance(response, AgentResponse) assert len(response.messages) == 1 content = response.messages[0].contents[0] - assert isinstance(content, TextContent) + assert content.type == "text" assert content.text == "Test response" assert response.messages[0].role == Role.ASSISTANT @@ -206,7 +199,7 @@ async def test_run_stream_with_string_message(self, mock_copilot_client: MagicMo async for response in agent.run_stream("test message"): assert isinstance(response, AgentResponseUpdate) content = response.contents[0] - assert isinstance(content, TextContent) + assert content.type == "text" assert content.text == "Streaming response" response_count += 1 @@ -233,7 +226,7 @@ async def test_run_stream_with_thread(self, mock_copilot_client: MagicMock) -> N async for response in agent.run_stream("test message", thread=thread): assert isinstance(response, AgentResponseUpdate) content = response.contents[0] - assert isinstance(content, TextContent) + assert content.type == "text" assert content.text == "Streaming response" response_count += 1 diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index c748d7f2df..1fdd7d10f7 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -1140,9 +1140,9 @@ async def _call_tool( # type: ignore # Convert result to MCP content if isinstance(result, str): - return [types.TextContent(type="text", text=result)] + return [types.TextContent(type="text", text=result)] - return [types.TextContent(type="text", text=str(result))] + return [types.TextContent(type="text", text=str(result))] @server.set_logging_level() # type: ignore async def _set_logging_level(level: types.LoggingLevel) -> None: # type: ignore diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index 
3a6d5b818c..908d7a0fd8 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. +import base64 import logging import re import sys @@ -29,13 +30,8 @@ ) from ._types import ( ChatMessage, - Contents, - DataContent, - FunctionCallContent, - FunctionResultContent, + Content, Role, - TextContent, - UriContent, ) from .exceptions import ToolException, ToolExecutionException @@ -82,7 +78,7 @@ def _parse_message_from_mcp( def _parse_contents_from_mcp_tool_result( mcp_type: types.CallToolResult, -) -> list[Contents]: +) -> list[Content]: """Parse an MCP CallToolResult into Agent Framework content types. This function extracts the complete _meta field from CallToolResult objects @@ -147,25 +143,27 @@ def _parse_content_from_mcp( | types.ToolUseContent | types.ToolResultContent ], -) -> list[Contents]: +) -> list[Content]: """Parse an MCP type into an Agent Framework type.""" mcp_types = mcp_type if isinstance(mcp_type, Sequence) else [mcp_type] - return_types: list[Contents] = [] + return_types: list[Content] = [] for mcp_type in mcp_types: match mcp_type: case types.TextContent(): - return_types.append(TextContent(text=mcp_type.text, raw_representation=mcp_type)) + return_types.append(Content.from_text(text=mcp_type.text, raw_representation=mcp_type)) case types.ImageContent() | types.AudioContent(): + # MCP protocol uses base64-encoded strings, convert to bytes + data_bytes = base64.b64decode(mcp_type.data) if isinstance(mcp_type.data, str) else mcp_type.data return_types.append( - DataContent( - data=mcp_type.data, + Content.from_data( + data=data_bytes, media_type=mcp_type.mimeType, raw_representation=mcp_type, ) ) case types.ResourceLink(): return_types.append( - UriContent( + Content.from_uri( uri=str(mcp_type.uri), media_type=mcp_type.mimeType or "application/json", raw_representation=mcp_type, @@ -173,7 +171,7 @@ def _parse_content_from_mcp( ) case 
types.ToolUseContent(): return_types.append( - FunctionCallContent( + Content.from_function_call( call_id=mcp_type.id, name=mcp_type.name, arguments=mcp_type.input, @@ -182,7 +180,7 @@ ) case types.ToolResultContent(): return_types.append( - FunctionResultContent( + Content.from_function_result( call_id=mcp_type.toolUseId, result=_parse_content_from_mcp(mcp_type.content) if mcp_type.content @@ -195,7 +193,7 @@ match mcp_type.resource: case types.TextResourceContents(): return_types.append( - TextContent( + Content.from_text( text=mcp_type.resource.text, raw_representation=mcp_type, additional_properties=( @@ -205,7 +203,7 @@ ) case types.BlobResourceContents(): return_types.append( - DataContent( + Content.from_data( uri=mcp_type.resource.blob, media_type=mcp_type.resource.mimeType, raw_representation=mcp_type, @@ -218,45 +216,41 @@ def _prepare_content_for_mcp( - content: Contents, + content: Content, ) -> types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource | types.ResourceLink | None: """Prepare an Agent Framework content type for MCP.""" - match content: - case TextContent(): - return types.TextContent(type="text", text=content.text) - case DataContent(): - if content.media_type and content.media_type.startswith("image/"): - return types.ImageContent(type="image", data=content.uri, mimeType=content.media_type) - if content.media_type and content.media_type.startswith("audio/"): - return types.AudioContent(type="audio", data=content.uri, mimeType=content.media_type) - if content.media_type and content.media_type.startswith("application/"): - return types.EmbeddedResource( - type="resource", - resource=types.BlobResourceContents( - blob=content.uri, - mimeType=content.media_type, - # uri's are not limited in MCP but they have to be set. 
- # the uri of data content, contains the data uri, which - # is not the uri meant here, UriContent would match this. - uri=( - content.additional_properties.get("uri", "af://binary") - if content.additional_properties - else "af://binary" - ), # type: ignore[reportArgumentType] - ), - ) - return None - case UriContent(): - return types.ResourceLink( - type="resource_link", - uri=content.uri, # type: ignore[reportArgumentType] - mimeType=content.media_type, - name=( - content.additional_properties.get("name", "Unknown") if content.additional_properties else "Unknown" + if content.type == "text": + return types.TextContent(type="text", text=content.text) # type: ignore[attr-defined] + if content.type == "data": + if content.media_type and content.media_type.startswith("image/"): # type: ignore[attr-defined] + return types.ImageContent(type="image", data=content.uri, mimeType=content.media_type) # type: ignore[attr-defined] + if content.media_type and content.media_type.startswith("audio/"): # type: ignore[attr-defined] + return types.AudioContent(type="audio", data=content.uri, mimeType=content.media_type) # type: ignore[attr-defined] + if content.media_type and content.media_type.startswith("application/"): # type: ignore[attr-defined] + return types.EmbeddedResource( + type="resource", + resource=types.BlobResourceContents( + blob=content.uri, # type: ignore[attr-defined] + mimeType=content.media_type, # type: ignore[attr-defined] + # uri's are not limited in MCP but they have to be set. + # the uri of data content, contains the data uri, which + # is not the uri meant here, UriContent would match this. 
+ uri=( + content.additional_properties.get("uri", "af://binary") + if content.additional_properties + else "af://binary" + ), # type: ignore[reportArgumentType] ), ) - case _: - return None + return None + if content.type == "uri": + return types.ResourceLink( + type="resource_link", + uri=content.uri, # type: ignore[reportArgumentType,attr-defined] + mimeType=content.media_type, # type: ignore[attr-defined] + name=(content.additional_properties.get("name", "Unknown") if content.additional_properties else "Unknown"), + ) + return None def _prepare_message_for_mcp( @@ -650,7 +644,7 @@ async def load_tools(self) -> None: input_model = _get_input_model_from_mcp_tool(tool) approval_mode = self._determine_approval_mode(local_name) # Create AIFunctions out of each tool - func: AIFunction[BaseModel, list[Contents] | Any | types.CallToolResult] = AIFunction( + func: AIFunction[BaseModel, list[Content] | Any | types.CallToolResult] = AIFunction( func=partial(self.call_tool, tool.name), name=local_name, description=tool.description or "", @@ -704,7 +698,7 @@ async def _ensure_connected(self) -> None: inner_exception=ex, ) from ex - async def call_tool(self, tool_name: str, **kwargs: Any) -> list[Contents] | Any | types.CallToolResult: + async def call_tool(self, tool_name: str, **kwargs: Any) -> list[Content] | Any | types.CallToolResult: """Call a tool with the given arguments. 
Args: diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index a0d0a13dc2..4b6362ae13 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -54,9 +54,7 @@ ChatMessage, ChatResponse, ChatResponseUpdate, - Contents, - FunctionApprovalResponseContent, - FunctionCallContent, + Content, ) from typing import overload @@ -104,15 +102,15 @@ def record(self, *args: Any, **kwargs: Any) -> None: # pragma: no cover - trivi def _parse_inputs( - inputs: "Contents | dict[str, Any] | str | list[Contents | dict[str, Any] | str] | None", -) -> list["Contents"]: - """Parse the inputs for a tool, ensuring they are of type Contents. + inputs: "Content | dict[str, Any] | str | list[Content | dict[str, Any] | str] | None", +) -> list["Content"]: + """Parse the inputs for a tool, ensuring they are of type Content. Args: - inputs: The inputs to parse. Can be a single item or list of Contents, dicts, or strings. + inputs: The inputs to parse. Can be a single item or list of Content, dicts, or strings. Returns: - A list of Contents objects. + A list of Content objects. Raises: ValueError: If an unsupported input type is encountered. @@ -122,43 +120,39 @@ def _parse_inputs( return [] from ._types import ( - BaseContent, - DataContent, - HostedFileContent, - HostedVectorStoreContent, - UriContent, + Content, ) - parsed_inputs: list["Contents"] = [] + parsed_inputs: list["Content"] = [] if not isinstance(inputs, list): inputs = [inputs] for input_item in inputs: if isinstance(input_item, str): # If it's a string, we assume it's a URI or similar identifier. # Convert it to a UriContent or similar type as needed. 
- parsed_inputs.append(UriContent(uri=input_item, media_type="text/plain")) + parsed_inputs.append(Content.from_uri(uri=input_item, media_type="text/plain")) elif isinstance(input_item, dict): # If it's a dict, we assume it contains properties for a specific content type. # we check if the required keys are present to determine the type. # for instance, if it has "uri" and "media_type", we treat it as UriContent. - # if is only has uri, then we treat it as DataContent. + # if it only has uri and media_type without a specific type indicator, we treat it as DataContent. etc. if "uri" in input_item: - parsed_inputs.append( - UriContent(**input_item) if "media_type" in input_item else DataContent(**input_item) - ) + # Keep backwards compatibility: uri+media_type -> UriContent, bare uri -> DataContent + parsed_inputs.append(Content.from_uri(**input_item) if "media_type" in input_item else Content.from_data(**input_item)) elif "file_id" in input_item: - parsed_inputs.append(HostedFileContent(**input_item)) + parsed_inputs.append(Content.from_hosted_file(**input_item)) elif "vector_store_id" in input_item: - parsed_inputs.append(HostedVectorStoreContent(**input_item)) + parsed_inputs.append(Content.from_hosted_vector_store(**input_item)) elif "data" in input_item: - parsed_inputs.append(DataContent(**input_item)) + # DataContent helper handles both uri and data parameters + parsed_inputs.append(Content.from_data(**input_item)) else: raise ValueError(f"Unsupported input type: {input_item}") - elif isinstance(input_item, BaseContent): + elif isinstance(input_item, Content): parsed_inputs.append(input_item) else: - raise TypeError(f"Unsupported input type: {type(input_item).__name__}. Expected Contents or dict.") + raise TypeError(f"Unsupported input type: {type(input_item).__name__}. 
Expected Content or dict.") return parsed_inputs @@ -254,7 +248,7 @@ class HostedCodeInterpreterTool(BaseTool): def __init__( self, *, - inputs: "Contents | dict[str, Any] | str | list[Contents | dict[str, Any] | str] | None" = None, + inputs: "Content | dict[str, Any] | str | list[Content | dict[str, Any] | str] | None" = None, description: str | None = None, additional_properties: dict[str, Any] | None = None, **kwargs: Any, @@ -266,8 +260,8 @@ def __init__( This should mostly be HostedFileContent or HostedVectorStoreContent. Can also be DataContent, depending on the service used. When supplying a list, it can contain: - - Contents instances - - dicts with properties for Contents (e.g., {"uri": "http://example.com", "media_type": "text/html"}) + - Content instances + - dicts with properties for Content (e.g., {"uri": "http://example.com", "media_type": "text/html"}) - strings (which will be converted to UriContent with media_type "text/plain"). If None, defaults to an empty list. description: A description of the tool. @@ -503,7 +497,7 @@ class HostedFileSearchTool(BaseTool): def __init__( self, *, - inputs: "Contents | dict[str, Any] | str | list[Contents | dict[str, Any] | str] | None" = None, + inputs: "Content | dict[str, Any] | str | list[Content | dict[str, Any] | str] | None" = None, max_results: int | None = None, description: str | None = None, additional_properties: dict[str, Any] | None = None, @@ -515,8 +509,8 @@ def __init__( inputs: A list of contents that the tool can accept as input. Defaults to None. This should be one or more HostedVectorStoreContents. When supplying a list, it can contain: - - Contents instances - - dicts with properties for Contents (e.g., {"uri": "http://example.com", "media_type": "text/html"}) + - Content instances + - dicts with properties for Content (e.g., {"uri": "http://example.com", "media_type": "text/html"}) - strings (which will be converted to UriContent with media_type "text/plain"). 
If None, defaults to an empty list. max_results: The maximum number of results to return from the file search. @@ -1480,7 +1474,7 @@ class FunctionExecutionResult: __slots__ = ("content", "terminate") - def __init__(self, content: "Contents", terminate: bool = False) -> None: + def __init__(self, content: "Content", terminate: bool = False) -> None: """Initialize FunctionExecutionResult. Args: @@ -1492,7 +1486,7 @@ def __init__(self, content: "Contents", terminate: bool = False) -> None: async def _auto_invoke_function( - function_call_content: "FunctionCallContent | FunctionApprovalResponseContent", + function_call_content: "Content", custom_args: dict[str, Any] | None = None, *, config: FunctionInvocationConfiguration, @@ -1500,7 +1494,7 @@ async def _auto_invoke_function( sequence_index: int | None = None, request_index: int | None = None, middleware_pipeline: Any = None, # Optional MiddlewarePipeline -) -> "FunctionExecutionResult | Contents": +) -> "FunctionExecutionResult | Content": """Invoke a function call requested by the agent, applying middleware that is defined. Args: @@ -1516,16 +1510,17 @@ async def _auto_invoke_function( Returns: A FunctionExecutionResult wrapping the content and terminate signal, - or a Contents object for approval/hosted tool scenarios. + or a Content object for approval/hosted tool scenarios. Raises: KeyError: If the requested function is not found in the tool map. """ + from ._types import Content + # Note: The scenarios for approval_mode="always_require", declaration_only, and # terminate_on_unknown_calls are all handled in _try_execute_function_calls before # this function is called. This function only handles the actual execution of approved, # non-declaration-only functions. 
- from ._types import FunctionCallContent, FunctionResultContent tool: AIFunction[BaseModel, Any] | None = None if function_call_content.type == "function_call": @@ -1534,7 +1529,7 @@ async def _auto_invoke_function( if tool is None: exc = KeyError(f'Function "{function_call_content.name}" not found.') return FunctionExecutionResult( - content=FunctionResultContent( + content=Content.from_function_result( call_id=function_call_content.call_id, result=f'Error: Requested function "{function_call_content.name}" not found.', exception=exc, @@ -1543,10 +1538,10 @@ async def _auto_invoke_function( else: # Note: Unapproved tools (approved=False) are handled in _replace_approval_contents_with_results # and never reach this function, so we only handle approved=True cases here. - inner_call = function_call_content.function_call - if not isinstance(inner_call, FunctionCallContent): + inner_call = function_call_content.function_call # type: ignore[attr-defined] + if inner_call.type != "function_call": return function_call_content - tool = tool_map.get(inner_call.name) + tool = tool_map.get(inner_call.name) # type: ignore[attr-defined] if tool is None: # we assume it is a hosted tool return function_call_content @@ -1567,7 +1562,7 @@ async def _auto_invoke_function( if config.include_detailed_errors: message = f"{message} Exception: {exc}" return FunctionExecutionResult( - content=FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc) + content=Content.from_function_result(call_id=function_call_content.call_id, result=message, exception=exc) ) if not middleware_pipeline or ( @@ -1581,7 +1576,7 @@ async def _auto_invoke_function( **runtime_kwargs if getattr(tool, "_forward_runtime_kwargs", False) else {}, ) return FunctionExecutionResult( - content=FunctionResultContent( + content=Content.from_function_result( call_id=function_call_content.call_id, result=function_result, ) @@ -1591,7 +1586,9 @@ async def _auto_invoke_function( if 
config.include_detailed_errors: message = f"{message} Exception: {exc}" return FunctionExecutionResult( - content=FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc) + content=Content.from_function_result( + call_id=function_call_content.call_id, result=message, exception=exc + ) ) # Execute through middleware pipeline if available from ._middleware import FunctionInvocationContext @@ -1617,7 +1614,7 @@ async def final_function_handler(context_obj: Any) -> Any: final_handler=final_function_handler, ) return FunctionExecutionResult( - content=FunctionResultContent( + content=Content.from_function_result( call_id=function_call_content.call_id, result=function_result, ), @@ -1628,7 +1625,7 @@ async def final_function_handler(context_obj: Any) -> Any: if config.include_detailed_errors: message = f"{message} Exception: {exc}" return FunctionExecutionResult( - content=FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc) + content=Content.from_function_result(call_id=function_call_content.call_id, result=message, exception=exc) ) @@ -1653,14 +1650,14 @@ def _get_tool_map( async def _try_execute_function_calls( custom_args: dict[str, Any], attempt_idx: int, - function_calls: Sequence["FunctionCallContent"] | Sequence["FunctionApprovalResponseContent"], + function_calls: Sequence["Content"], tools: "ToolProtocol \ | Callable[..., Any] \ | MutableMapping[str, Any] \ | Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]", config: FunctionInvocationConfiguration, middleware_pipeline: Any = None, # Optional MiddlewarePipeline to avoid circular imports -) -> tuple[Sequence["Contents"], bool]: +) -> tuple[Sequence["Content"], bool]: """Execute multiple function calls concurrently. 
Args: @@ -1673,12 +1670,12 @@ async def _try_execute_function_calls( Returns: A tuple of: - - A list of Contents containing the results of each function call, + - A list of Content containing the results of each function call, or the approval requests if any function requires approval, or the original function calls if any are declaration only. - A boolean indicating whether to terminate the function calling loop. """ - from ._types import FunctionApprovalRequestContent, FunctionCallContent + from ._types import Content tool_map = _get_tool_map(tools) approval_tools = [tool_name for tool_name, tool in tool_map.items() if tool.approval_mode == "always_require"] @@ -1689,27 +1686,27 @@ async def _try_execute_function_calls( approval_needed = False declaration_only_flag = False for fcc in function_calls: - if isinstance(fcc, FunctionCallContent) and fcc.name in approval_tools: + if fcc.type == "function_call" and fcc.name in approval_tools: # type: ignore[attr-defined] approval_needed = True break - if isinstance(fcc, FunctionCallContent) and (fcc.name in declaration_only or fcc.name in additional_tool_names): + if fcc.type == "function_call" and (fcc.name in declaration_only or fcc.name in additional_tool_names): # type: ignore[attr-defined] declaration_only_flag = True break - if config.terminate_on_unknown_calls and isinstance(fcc, FunctionCallContent) and fcc.name not in tool_map: - raise KeyError(f'Error: Requested function "{fcc.name}" not found.') + if config.terminate_on_unknown_calls and fcc.type == "function_call" and fcc.name not in tool_map: # type: ignore[attr-defined] + raise KeyError(f'Error: Requested function "{fcc.name}" not found.') # type: ignore[attr-defined] if approval_needed: - # approval can only be needed for Function Call Contents, not Approval Responses. + # approval can only be needed for Function Call Content, not Approval Responses. 
return ( [ - FunctionApprovalRequestContent(id=fcc.call_id, function_call=fcc) + Content.from_function_approval_request(id=fcc.call_id, function_call=fcc) # type: ignore[attr-defined] for fcc in function_calls - if isinstance(fcc, FunctionCallContent) + if fcc.type == "function_call" ], False, ) if declaration_only_flag: # return the declaration only tools to the user, since we cannot execute them. - return ([fcc for fcc in function_calls if isinstance(fcc, FunctionCallContent)], False) + return ([fcc for fcc in function_calls if fcc.type == "function_call"], False) # Run all function calls concurrently execution_results = await asyncio.gather(*[ @@ -1726,7 +1723,7 @@ async def _try_execute_function_calls( ]) # Unpack FunctionExecutionResult wrappers and check for terminate signal - contents: list[Contents] = [] + contents: list[Content] = [] should_terminate = False for result in execution_results: if isinstance(result, FunctionExecutionResult): @@ -1734,7 +1731,7 @@ async def _try_execute_function_calls( if result.terminate: should_terminate = True else: - # Direct Contents (e.g., from hosted tools) + # Direct Content (e.g., from hosted tools) contents.append(result) return (contents, should_terminate) @@ -1772,30 +1769,27 @@ def _extract_tools(options: dict[str, Any] | None) -> Any: def _collect_approval_responses( messages: "list[ChatMessage]", -) -> dict[str, "FunctionApprovalResponseContent"]: +) -> dict[str, "Content"]: """Collect approval responses (both approved and rejected) from messages.""" - from ._types import ChatMessage, FunctionApprovalResponseContent + from ._types import ChatMessage, Content - fcc_todo: dict[str, FunctionApprovalResponseContent] = {} + fcc_todo: dict[str, Content] = {} for msg in messages: for content in msg.contents if isinstance(msg, ChatMessage) else []: # Collect BOTH approved and rejected responses - if isinstance(content, FunctionApprovalResponseContent): - fcc_todo[content.id] = content + if content.type == 
"function_approval_response": + fcc_todo[content.id] = content # type: ignore[attr-defined] return fcc_todo def _replace_approval_contents_with_results( messages: "list[ChatMessage]", - fcc_todo: dict[str, "FunctionApprovalResponseContent"], - approved_function_results: "list[Contents]", + fcc_todo: dict[str, "Content"], + approved_function_results: "list[Content]", ) -> None: """Replace approval request/response contents with function call/result contents in-place.""" from ._types import ( - FunctionApprovalRequestContent, - FunctionApprovalResponseContent, - FunctionCallContent, - FunctionResultContent, + Content, Role, ) @@ -1803,23 +1797,25 @@ def _replace_approval_contents_with_results( for msg in messages: # First pass - collect existing function call IDs to avoid duplicates existing_call_ids = { - content.call_id for content in msg.contents if isinstance(content, FunctionCallContent) and content.call_id + content.call_id + for content in msg.contents + if content.type == "function_call" and content.call_id # type: ignore[attr-defined] } # Track approval requests that should be removed (duplicates) contents_to_remove = [] for content_idx, content in enumerate(msg.contents): - if isinstance(content, FunctionApprovalRequestContent): + if content.type == "function_approval_request": # Don't add the function call if it already exists (would create duplicate) - if content.function_call.call_id in existing_call_ids: + if content.function_call.call_id in existing_call_ids: # type: ignore[attr-defined] # Just mark for removal - the function call already exists contents_to_remove.append(content_idx) else: # Put back the function call content only if it doesn't exist - msg.contents[content_idx] = content.function_call - elif isinstance(content, FunctionApprovalResponseContent): - if content.approved and content.id in fcc_todo: + msg.contents[content_idx] = content.function_call # type: ignore[attr-defined] + elif content.type == "function_approval_response": + if 
content.approved and content.id in fcc_todo: # type: ignore[attr-defined] # Replace with the corresponding result if result_idx < len(approved_function_results): msg.contents[content_idx] = approved_function_results[result_idx] @@ -1828,7 +1824,7 @@ def _replace_approval_contents_with_results( else: # Create a "not approved" result for rejected calls # Use function_call.call_id (the function's ID), not content.id (approval's ID) - msg.contents[content_idx] = FunctionResultContent( + msg.contents[content_idx] = Content.from_function_result( call_id=content.function_call.call_id, result="Error: Tool call invocation was rejected by user.", ) @@ -1867,9 +1863,6 @@ async def function_invocation_wrapper( from ._middleware import extract_and_merge_function_middleware from ._types import ( ChatMessage, - FunctionApprovalRequestContent, - FunctionCallContent, - FunctionResultContent, prepare_messages, ) @@ -1893,7 +1886,7 @@ async def function_invocation_wrapper( tools = _extract_tools(options) # Only execute APPROVED function calls, not rejected ones approved_responses = [resp for resp in fcc_todo.values() if resp.approved] - approved_function_results: list[Contents] = [] + approved_function_results: list[Content] = [] if approved_responses: results, _ = await _try_execute_function_calls( custom_args=kwargs, @@ -1907,7 +1900,7 @@ async def function_invocation_wrapper( if any( fcr.exception is not None for fcr in approved_function_results - if isinstance(fcr, FunctionResultContent) + if fcr.type == "function_result" ): errors_in_a_row += 1 # no need to reset the counter here, since this is the start of a new attempt. 
@@ -1926,13 +1919,11 @@ async def function_invocation_wrapper( filtered_kwargs = {k: v for k, v in kwargs.items() if k not in ("thread", "tools", "tool_choice")} response = await func(self, messages=prepped_messages, options=options, **filtered_kwargs) # if there are function calls, we will handle them first - function_results = { - it.call_id for it in response.messages[0].contents if isinstance(it, FunctionResultContent) - } + function_results = {it.call_id for it in response.messages[0].contents if it.type == "function_result"} function_calls = [ it for it in response.messages[0].contents - if isinstance(it, FunctionCallContent) and it.call_id not in function_results + if it.type == "function_call" and it.call_id not in function_results ] if response.conversation_id is not None: @@ -1953,7 +1944,7 @@ async def function_invocation_wrapper( config=config, ) # Check if we have approval requests or function calls (not results) in the results - if any(isinstance(fccr, FunctionApprovalRequestContent) for fccr in function_call_results): + if any(fccr.type == "function_approval_request" for fccr in function_call_results): # Add approval requests to the existing assistant message (with tool_calls) # instead of creating a separate tool message from ._types import Role @@ -1965,7 +1956,7 @@ async def function_invocation_wrapper( result_message = ChatMessage(role="assistant", contents=function_call_results) response.messages.append(result_message) return response - if any(isinstance(fccr, FunctionCallContent) for fccr in function_call_results): + if any(fccr.type == "function_call" for fccr in function_call_results): # the function calls are already in the response, so we just continue return response @@ -1980,11 +1971,7 @@ async def function_invocation_wrapper( response.messages.insert(0, msg) return response - if any( - fcr.exception is not None - for fcr in function_call_results - if isinstance(fcr, FunctionResultContent) - ): + if any(fcr.exception is not None for fcr 
in function_call_results if fcr.type == "function_result"): errors_in_a_row += 1 if errors_in_a_row >= config.max_consecutive_errors_per_request: logger.warning( @@ -2071,8 +2058,6 @@ async def streaming_function_invocation_wrapper( ChatMessage, ChatResponse, ChatResponseUpdate, - FunctionCallContent, - FunctionResultContent, prepare_messages, ) @@ -2094,7 +2079,7 @@ async def streaming_function_invocation_wrapper( tools = _extract_tools(options) # Only execute APPROVED function calls, not rejected ones approved_responses = [resp for resp in fcc_todo.values() if resp.approved] - approved_function_results: list[Contents] = [] + approved_function_results: list[Content] = [] if approved_responses: results, _ = await _try_execute_function_calls( custom_args=kwargs, @@ -2108,7 +2093,7 @@ async def streaming_function_invocation_wrapper( if any( fcr.exception is not None for fcr in approved_function_results - if isinstance(fcr, FunctionResultContent) + if fcr.type == "function_result" ): errors_in_a_row += 1 # no need to reset the counter here, since this is the start of a new attempt. 
@@ -2124,10 +2109,9 @@ async def streaming_function_invocation_wrapper( # efficient check for FunctionCallContent in the updates # if there is at least one, this stops and continuous # if there are no FCC's then it returns - from ._types import FunctionApprovalRequestContent if not any( - isinstance(item, (FunctionCallContent, FunctionApprovalRequestContent)) + item.type in ("function_call", "function_approval_request") for upd in all_updates for item in upd.contents ): @@ -2139,13 +2123,11 @@ async def streaming_function_invocation_wrapper( response: "ChatResponse" = ChatResponse.from_chat_response_updates(all_updates) # get the function calls (excluding ones that already have results) - function_results = { - it.call_id for it in response.messages[0].contents if isinstance(it, FunctionResultContent) - } + function_results = {it.call_id for it in response.messages[0].contents if it.type == "function_result"} function_calls = [ it for it in response.messages[0].contents - if isinstance(it, FunctionCallContent) and it.call_id not in function_results + if it.type == "function_call" and it.call_id not in function_results ] # When conversation id is present, it means that messages are hosted on the server. 
@@ -2169,7 +2151,7 @@ async def streaming_function_invocation_wrapper( ) # Check if we have approval requests or function calls (not results) in the results - if any(isinstance(fccr, FunctionApprovalRequestContent) for fccr in function_call_results): + if any(fccr.type == "function_approval_request" for fccr in function_call_results): # Add approval requests to the existing assistant message (with tool_calls) # instead of creating a separate tool message from ._types import Role @@ -2184,7 +2166,7 @@ async def streaming_function_invocation_wrapper( yield ChatResponseUpdate(contents=function_call_results, role="assistant") response.messages.append(result_message) return - if any(isinstance(fccr, FunctionCallContent) for fccr in function_call_results): + if any(fccr.type == "function_call" for fccr in function_call_results): # the function calls were already yielded. return @@ -2195,11 +2177,7 @@ async def streaming_function_invocation_wrapper( yield ChatResponseUpdate(contents=function_call_results, role="tool") return - if any( - fcr.exception is not None - for fcr in function_call_results - if isinstance(fcr, FunctionResultContent) - ): + if any(fcr.exception is not None for fcr in function_call_results if fcr.type == "function_result"): errors_in_a_row += 1 if errors_in_a_row >= config.max_consecutive_errors_per_request: logger.warning( diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index fce99b3488..86f069edf9 100644 --- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -3,7 +3,6 @@ import base64 import json import re -import sys from collections.abc import ( AsyncIterable, Callable, @@ -13,7 +12,7 @@ Sequence, ) from copy import deepcopy -from typing import Any, ClassVar, Literal, TypedDict, TypeVar, cast, overload +from typing import Any, ClassVar, Final, Literal, TypedDict, TypeVar, overload from pydantic import BaseModel, ValidationError @@ 
-22,49 +21,23 @@ from ._tools import ToolProtocol, ai_function from .exceptions import AdditionItemMismatch, ContentError -if sys.version_info >= (3, 11): - from typing import Self # pragma: no cover -else: - from typing_extensions import Self # pragma: no cover - - __all__ = [ "AgentResponse", "AgentResponseUpdate", - "AnnotatedRegions", - "Annotations", - "BaseAnnotation", - "BaseContent", + "Annotation", "ChatMessage", - "ChatOptions", # Backward compatibility alias "ChatOptions", "ChatResponse", "ChatResponseUpdate", - "CitationAnnotation", - "CodeInterpreterToolCallContent", - "CodeInterpreterToolResultContent", - "Contents", - "DataContent", - "ErrorContent", + "Content", "FinishReason", - "FunctionApprovalRequestContent", - "FunctionApprovalResponseContent", - "FunctionCallContent", - "FunctionResultContent", - "HostedFileContent", - "HostedVectorStoreContent", - "ImageGenerationToolCallContent", - "ImageGenerationToolResultContent", - "MCPServerToolCallContent", - "MCPServerToolResultContent", "Role", - "TextContent", - "TextReasoningContent", + "TextSpanRegion", "TextSpanRegion", "ToolMode", - "UriContent", - "UsageContent", "UsageDetails", + "add_usage_details", + "detect_media_type_from_base64", "merge_chat_options", "normalize_messages", "normalize_tools", @@ -103,84 +76,251 @@ def __new__(mcs, name: str, bases: tuple[type, ...], namespace: dict[str, Any]) return cls -def _parse_content(content_data: MutableMapping[str, Any]) -> "Contents": - """Parse a single content data dictionary into the appropriate Content object. +def _parse_content_list(contents_data: Sequence[Any]) -> list["Content"]: + """Parse a list of content data dictionaries into appropriate Content objects. 
Args: - content_data: Content data (dict) + contents_data: List of content data (dicts or already constructed objects) Returns: - Content object + List of Content objects with unknown types logged and ignored + """ + contents: list["Content"] = [] + for content_data in contents_data: + if isinstance(content_data, Content): + contents.append(content_data) + continue + try: + contents.append(Content.from_dict(content_data)) + except ContentError as exc: + logger.warning(f"Skipping unknown content type or invalid content: {exc}") - Raises: - ContentError if parsing fails + return contents + + +# region Internal Helper functions for unified Content + + +def detect_media_type_from_base64(data_base64: str) -> str | None: + """Detect media type from base64-encoded data by examining magic bytes. + + This function examines the binary signature (magic bytes) at the start of the data + to identify common media types. It's reliable for binary formats like images, audio, + video, and documents, but cannot detect text-based formats like JSON or plain text. + + Args: + data_base64: Base64-encoded data (with or without data URI prefix). + + Returns: + The detected media type (e.g., 'image/png', 'audio/wav', 'application/pdf') + or None if the format is not recognized. + + Examples: + .. code-block:: python + + from agent_framework import detect_media_type_from_base64 + + # Detect from base64 string + base64_data = "iVBORw0KGgo..." + media_type = detect_media_type_from_base64(base64_data) + # Returns: "image/png" + + # Works with data URIs too + data_uri = "data:image/png;base64,iVBORw0KGgo..." 
+ media_type = detect_media_type_from_base64(data_uri) + # Returns: "image/png" """ - content_type: str | None = content_data.get("type", None) - match content_type: - case "text": - return TextContent.from_dict(content_data) - case "data": - return DataContent.from_dict(content_data) - case "uri": - return UriContent.from_dict(content_data) - case "error": - return ErrorContent.from_dict(content_data) - case "function_call": - return FunctionCallContent.from_dict(content_data) - case "function_result": - return FunctionResultContent.from_dict(content_data) - case "usage": - return UsageContent.from_dict(content_data) - case "hosted_file": - return HostedFileContent.from_dict(content_data) - case "hosted_vector_store": - return HostedVectorStoreContent.from_dict(content_data) - case "code_interpreter_tool_call": - return CodeInterpreterToolCallContent.from_dict(content_data) - case "code_interpreter_tool_result": - return CodeInterpreterToolResultContent.from_dict(content_data) - case "image_generation_tool_call": - return ImageGenerationToolCallContent.from_dict(content_data) - case "image_generation_tool_result": - return ImageGenerationToolResultContent.from_dict(content_data) - case "mcp_server_tool_call": - return MCPServerToolCallContent.from_dict(content_data) - case "mcp_server_tool_result": - return MCPServerToolResultContent.from_dict(content_data) - case "function_approval_request": - return FunctionApprovalRequestContent.from_dict(content_data) - case "function_approval_response": - return FunctionApprovalResponseContent.from_dict(content_data) - case "text_reasoning": - return TextReasoningContent.from_dict(content_data) - case None: - raise ContentError("Content type is missing") - case _: - raise ContentError(f"Unknown content type '{content_type}'") - - -def _parse_content_list(contents_data: Sequence[Any]) -> list["Contents"]: - """Parse a list of content data dictionaries into appropriate Content objects. 
+    # Remove data URI prefix if present
+    if data_base64.startswith("data:"):
+        if ";base64," in data_base64:
+            data_base64 = data_base64.split(";base64,", 1)[1]
+
+    try:
+        # Decode just the first few bytes to check magic numbers
+        decoded = base64.b64decode(data_base64[:50])
+    except Exception:
+        return None
+
+    # Check magic bytes for common formats
+    # Images
+    if decoded.startswith(b"\x89PNG\r\n\x1a\n"):
+        return "image/png"
+    if decoded.startswith(b"\xff\xd8\xff"):
+        return "image/jpeg"
+    if decoded.startswith(b"GIF87a") or decoded.startswith(b"GIF89a"):
+        return "image/gif"
+    if decoded.startswith(b"RIFF") and len(decoded) > 11 and decoded[8:12] == b"WEBP":
+        return "image/webp"
+    if decoded.startswith(b"BM"):
+        return "image/bmp"
+    if decoded.startswith(b"<svg") or decoded.startswith(b"<?xml"):
+        return "image/svg+xml"
+    # Documents
+    if decoded.startswith(b"%PDF"):
+        return "application/pdf"
+    # Audio
+    if decoded.startswith(b"RIFF") and len(decoded) > 11 and decoded[8:12] == b"WAVE":
+        return "audio/wav"
+    if decoded.startswith(b"ID3") or decoded.startswith(b"\xff\xfb") or decoded.startswith(b"\xff\xf3"):
+        return "audio/mpeg"
+    if decoded.startswith(b"OggS"):
+        return "audio/ogg"
+    if decoded.startswith(b"fLaC"):
+        return "audio/flac"
+
+    return None
+
+
+def _create_data_uri_from_base64(image_base64: str, media_type: str | None = None) -> tuple[str, str]:
+    """Create a data URI and media type from base64 data.
 
     Args:
-        contents_data: List of content data (dicts or already constructed objects)
+        image_base64: Base64-encoded image data (with or without data URI prefix).
+        media_type: Optional explicit media type. If not provided, will attempt to detect.
 
     Returns:
-        List of Content objects with unknown types logged and ignored
+        A tuple of (data_uri, media_type).
+
+    Raises:
+        ContentError: If media type cannot be determined.
""" - contents: list["Contents"] = [] - for content_data in contents_data: - if isinstance(content_data, dict): - try: - content = _parse_content(content_data) - contents.append(content) - except ContentError as exc: - logger.warning(f"Skipping unknown content type or invalid content: {exc}") + # If it's already a data URI, extract the parts + if image_base64.startswith("data:"): + if ";base64," in image_base64: + prefix, data = image_base64.split(";base64,", 1) + existing_media_type = prefix.split(":", 1)[1] if ":" in prefix else None + if media_type is None: + media_type = existing_media_type + image_base64 = data else: - # If it's already a content object, keep it as is - contents.append(content_data) + raise ContentError("Data URI must use base64 encoding") - return contents + # Detect format if media type not provided + if media_type is None: + detected_media_type = detect_media_type_from_base64(image_base64) + if detected_media_type: + media_type = detected_media_type + else: + raise ContentError("Could not detect media type from base64 data") + + # Construct data URI + data_uri = f"data:{media_type};base64,{image_base64}" + return data_uri, media_type + + +def _get_data_bytes_as_str(content: "Content") -> str | None: + """Extract base64 data string from data URI. + + Args: + content: The Content instance to extract data from. + + Returns: + The base64-encoded data as a string, or None if not a data content type. + + Raises: + ContentError: If the URI is not a valid data URI. + """ + if content.type not in ("data", "uri"): + return None + + uri = getattr(content, "uri", None) + if not uri: + return None + + if not uri.startswith("data:"): + return None + + if ";base64," not in uri: + raise ContentError("Data URI must use base64 encoding") + + _, data = uri.split(";base64,", 1) + return data + + +def _get_data_bytes(content: "Content") -> bytes | None: + """Extract and decode binary data from data URI. 
+ + Args: + content: The Content instance to extract data from. + + Returns: + The decoded binary data, or None if not a data content type. + + Raises: + ContentError: If the URI is not a valid data URI or decoding fails. + """ + data_str = _get_data_bytes_as_str(content) + if data_str is None: + return None + + try: + return base64.b64decode(data_str) + except Exception as e: + raise ContentError(f"Failed to decode base64 data: {e}") from e + + +KNOWN_URI_SCHEMAS: Final[set[str]] = {"http", "https", "ftp", "ftps", "file", "s3", "gs", "azure", "blob"} + + +def _validate_uri(uri: str, media_type: str | None) -> dict[str, Any]: + """Validate URI format and return validation result. + + Args: + uri: The URI to validate. + media_type: Optional media type associated with the URI. + + Returns: + If valid, returns a dict, with "type" key indicating "data" or "uri", along with the uri and media_type. + """ + if not uri: + raise ContentError("URI cannot be empty") + + # Check for data URI + if uri.startswith("data:"): + if "," not in uri: + raise ContentError("Data URI must contain a comma separating metadata and data") + prefix, _ = uri.split(",", 1) + if ";" in prefix: + parts = prefix.split(";") + if len(parts) < 2: + raise ContentError("Invalid data URI format") + # Check encoding + encoding = parts[-1] + if encoding not in ("base64", ""): + raise ContentError(f"Unsupported data URI encoding: {encoding}") + if media_type is None: + # attempt to extract: + media_type = parts[0][5:] # Remove 'data:' + return {"type": "data", "uri": uri, "media_type": media_type} + + # Check for common URI schemes + if ":" in uri: + scheme = uri.split(":", 1)[0].lower() + if not media_type: + logger.warning("Using URI without media type is not recommended.") + if scheme not in KNOWN_URI_SCHEMAS: + logger.info(f"Unknown URI scheme: {scheme}, allowed schemes are {KNOWN_URI_SCHEMAS}.") + return {"type": "uri", "uri": uri, "media_type": media_type} + + # No scheme found + raise 
ContentError("URI must contain a scheme (e.g., http://, data:, file://)") + + +def _serialize_value(value: Any, exclude_none: bool) -> Any: + """Recursively serialize a value for to_dict.""" + if value is None: + return None + if isinstance(value, Content): + return value.to_dict(exclude_none=exclude_none) + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + return [_serialize_value(item, exclude_none) for item in value] + if isinstance(value, Mapping): + return {k: _serialize_value(v, exclude_none) for k, v in value.items()} + if hasattr(value, "to_dict"): + return value.to_dict() # type: ignore[call-arg] + return value # endregion @@ -223,1918 +363,1032 @@ def _parse_content_list(contents_data: Sequence[Any]) -> list["Contents"]: "text/xml", ] +# region Unified Content Types + +ContentType = Literal[ + "text", + "text_reasoning", + "data", + "uri", + "error", + "function_call", + "function_result", + "usage", + "hosted_file", + "hosted_vector_store", + "code_interpreter_tool_call", + "code_interpreter_tool_result", + "image_generation_tool_call", + "image_generation_tool_result", + "mcp_server_tool_call", + "mcp_server_tool_result", + "function_approval_request", + "function_approval_response", +] -class UsageDetails(SerializationMixin): - """Provides usage details about a request/response. - - Attributes: - input_token_count: The number of tokens in the input. - output_token_count: The number of tokens in the output. - total_token_count: The total number of tokens used to produce the response. - additional_counts: A dictionary of additional token counts, can be set by passing kwargs. - - Examples: - .. 
code-block:: python - - from agent_framework import UsageDetails - - # Create usage details - usage = UsageDetails( - input_token_count=100, - output_token_count=50, - total_token_count=150, - ) - print(usage.total_token_count) # 150 - - # With additional counts - usage = UsageDetails( - input_token_count=100, - output_token_count=50, - total_token_count=150, - reasoning_tokens=25, - ) - print(usage.additional_counts["reasoning_tokens"]) # 25 - # Combine usage details - usage1 = UsageDetails(input_token_count=100, output_token_count=50) - usage2 = UsageDetails(input_token_count=200, output_token_count=100) - combined = usage1 + usage2 - print(combined.input_token_count) # 300 - """ +class TextSpanRegion(TypedDict, total=False): + """TypedDict representation of a text span region annotation.""" - DEFAULT_EXCLUDE: ClassVar[set[str]] = {"_extra_counts"} + type: Literal["text_span"] + start_index: int + end_index: int - def __init__( - self, - input_token_count: int | None = None, - output_token_count: int | None = None, - total_token_count: int | None = None, - **kwargs: int, - ) -> None: - """Initializes the UsageDetails instance. - Args: - input_token_count: The number of tokens in the input. - output_token_count: The number of tokens in the output. - total_token_count: The total number of tokens used to produce the response. +class Annotation(TypedDict, total=False): + """TypedDict representation of an annotation.""" - Keyword Args: - **kwargs: Additional token counts, can be set by passing keyword arguments. - They can be retrieved through the `additional_counts` property. 
- """ - self.input_token_count = input_token_count - self.output_token_count = output_token_count - self.total_token_count = total_token_count + type: Literal["citation"] + title: str + url: str + file_id: str + tool_name: str + snippet: str + annotated_regions: Sequence[TextSpanRegion] + additional_properties: dict[str, Any] + raw_representation: Any - # Validate that all kwargs are integers (preserving Pydantic behavior) - self._extra_counts: dict[str, int] = {} - for key, value in kwargs.items(): - if not isinstance(value, int): - raise ValueError(f"Additional counts must be integers, got {type(value).__name__}") - self._extra_counts[key] = value - def to_dict(self, *, exclude_none: bool = True, exclude: set[str] | None = None) -> dict[str, Any]: - """Convert the UsageDetails instance to a dictionary. +TContent = TypeVar("TContent", bound="Content") - Keyword Args: - exclude_none: Whether to exclude None values from the output. - exclude: Set of field names to exclude from the output. +# endregion - Returns: - Dictionary representation of the UsageDetails instance. - """ - # Get the base dict from parent class - result = super().to_dict(exclude_none=exclude_none, exclude=exclude) - # Add additional counts (extra fields) - if exclude is None: - exclude = set() +class UsageDetails(TypedDict, total=False): + """A dictionary representing usage details. - for key, value in self._extra_counts.items(): - if key in exclude: - continue - if exclude_none and value is None: - continue - result[key] = value + This is a non-closed dictionary, so any specific provider fields can be added as needed. + Whenever they can be mapped to standard fields, they will be. 
+ """ - return result + input_token_count: int | None + output_token_count: int | None + total_token_count: int | None - def __str__(self) -> str: - """Returns a string representation of the usage details.""" - return self.to_json() - @property - def additional_counts(self) -> dict[str, int]: - """Represents well-known additional counts for usage. This is not an exhaustive list. +def add_usage_details(usage1: UsageDetails | None, usage2: UsageDetails | None) -> UsageDetails: + """Add two UsageDetails dictionaries by summing all numeric values. - Remarks: - To make it possible to avoid collisions between similarly-named, but unrelated, additional counts - between different AI services, any keys not explicitly defined here should be prefixed with the - name of the AI service, e.g., "openai." or "azure.". The separator "." was chosen because it cannot - be a legal character in a JSON key. + Args: + usage1: First usage details dictionary. + usage2: Second usage details dictionary. - Over time additional counts may be added to the base class. 
- """ - return self._extra_counts - - def __setitem__(self, key: str, value: int) -> None: - """Sets an additional count for the usage details.""" - if not isinstance(value, int): - raise ValueError("Additional counts must be integers.") - self._extra_counts[key] = value - - def __add__(self, other: "UsageDetails | None") -> "UsageDetails": - """Combines two `UsageDetails` instances.""" - if not other: - return self - if not isinstance(other, UsageDetails): - raise ValueError("Can only add two usage details objects together.") - - additional_counts = self.additional_counts.copy() - if other.additional_counts: - for key, value in other.additional_counts.items(): - additional_counts[key] = additional_counts.get(key, 0) + (value or 0) - - return UsageDetails( - input_token_count=(self.input_token_count or 0) + (other.input_token_count or 0), - output_token_count=(self.output_token_count or 0) + (other.output_token_count or 0), - total_token_count=(self.total_token_count or 0) + (other.total_token_count or 0), - **additional_counts, - ) + Returns: + A new UsageDetails dictionary with summed values. - def __iadd__(self, other: "UsageDetails | None") -> Self: - if not other: - return self - if not isinstance(other, UsageDetails): - raise ValueError("Can only add usage details objects together.") + Examples: + .. 
code-block:: python - self.input_token_count = (self.input_token_count or 0) + (other.input_token_count or 0) - self.output_token_count = (self.output_token_count or 0) + (other.output_token_count or 0) - self.total_token_count = (self.total_token_count or 0) + (other.total_token_count or 0) + from agent_framework import UsageDetails, add_usage_details - for key, value in other.additional_counts.items(): - self.additional_counts[key] = self.additional_counts.get(key, 0) + (value or 0) + usage1 = UsageDetails(input_token_count=5, output_token_count=10) + usage2 = UsageDetails(input_token_count=3, output_token_count=6) + combined = add_usage_details(usage1, usage2) + # Result: {'input_token_count': 8, 'output_token_count': 16} + """ + if usage1 is None: + return usage2 or UsageDetails() + if usage2 is None: + return usage1 - return self + result = UsageDetails() - def __eq__(self, other: object) -> bool: - """Check if two UsageDetails instances are equal.""" - if not isinstance(other, UsageDetails): - return False + # Combine all keys from both dictionaries + all_keys = set(usage1.keys()) | set(usage2.keys()) - return ( - self.input_token_count == other.input_token_count - and self.output_token_count == other.output_token_count - and self.total_token_count == other.total_token_count - and self.additional_counts == other.additional_counts - ) + for key in all_keys: + val1 = usage1.get(key) + val2 = usage2.get(key) + # Sum if both present, otherwise use the non-None value + if val1 is not None and val2 is not None: + result[key] = val1 + val2 # type: ignore[literal-required] + elif val1 is not None: + result[key] = val1 # type: ignore[literal-required] + elif val2 is not None: + result[key] = val2 # type: ignore[literal-required] -# region BaseAnnotation + return result -class TextSpanRegion(SerializationMixin): - """Represents a region of text that has been annotated. +# region Content Class - Examples: - .. 
code-block:: python - from agent_framework import TextSpanRegion +class Content: + """Unified content container covering all content variants. - # Create a text span region - region = TextSpanRegion(start_index=0, end_index=10) - print(region.type) # "text_span" + This class provides a single unified type that handles all content variants. + Use the class methods like `Content.from_text()`, `Content.from_data()`, + `Content.from_uri()`, etc. to create instances. """ def __init__( self, + type: ContentType, *, - start_index: int | None = None, - end_index: int | None = None, - **kwargs: Any, + # Text content fields + text: str | None = None, + protected_data: str | None = None, + # Data/URI content fields + uri: str | None = None, + media_type: str | None = None, + # Error content fields + message: str | None = None, + error_code: str | None = None, + error_details: str | None = None, + # Usage content fields + usage_details: dict[str, Any] | UsageDetails | None = None, + # Function call/result fields + call_id: str | None = None, + name: str | None = None, + arguments: str | Mapping[str, Any] | None = None, + exception: str | None = None, + result: Any = None, + # Hosted file/vector store fields + file_id: str | None = None, + vector_store_id: str | None = None, + # Code interpreter tool fields + inputs: list["Content"] | None = None, + outputs: list["Content"] | Any | None = None, + # Image generation tool fields + image_id: str | None = None, + # MCP server tool fields + tool_name: str | None = None, + server_name: str | None = None, + output: Any = None, + # Function approval fields + id: str | None = None, + function_call: "Content | None" = None, + user_input_request: bool | None = None, + approved: bool | None = None, + # Common fields + annotations: Sequence[Annotation] | None = None, + additional_properties: MutableMapping[str, Any] | None = None, + raw_representation: Any | None = None, ) -> None: - """Initialize TextSpanRegion. 
+ """Create a content instance. - Keyword Args: - start_index: The start index of the text span. - end_index: The end index of the text span. - **kwargs: Additional keyword arguments. + Prefer using the classmethod constructors like `Content.from_text()` instead of calling __init__ directly. """ - self.type: Literal["text_span"] = "text_span" - self.start_index = start_index - self.end_index = end_index - - # Handle any additional kwargs - for key, value in kwargs.items(): - if not hasattr(self, key): - setattr(self, key, value) - - -AnnotatedRegions = TextSpanRegion + self.type = type + self.annotations = annotations + self.additional_properties: dict[str, Any] = additional_properties or {} + self.raw_representation = raw_representation + # Set all content-specific attributes + self.text = text + self.protected_data = protected_data + self.uri = uri + self.media_type = media_type + self.message = message + self.error_code = error_code + self.error_details = error_details + self.usage_details = usage_details + self.call_id = call_id + self.name = name + self.arguments = arguments + self.exception = exception + self.result = result + self.file_id = file_id + self.vector_store_id = vector_store_id + self.inputs = inputs + self.outputs = outputs + self.image_id = image_id + self.tool_name = tool_name + self.server_name = server_name + self.output = output + self.id = id + self.function_call = function_call + self.user_input_request = user_input_request + self.approved = approved -class BaseAnnotation(SerializationMixin): - """Base class for all AI Annotation types.""" + @classmethod + def from_text( + cls: type[TContent], + text: str, + *, + annotations: Sequence[Annotation] | None = None, + additional_properties: MutableMapping[str, Any] | None = None, + raw_representation: Any = None, + ) -> TContent: + """Create text content.""" + return cls( + "text", + text=text, + annotations=annotations, + additional_properties=additional_properties, + 
raw_representation=raw_representation, + ) - DEFAULT_EXCLUDE: ClassVar[set[str]] = {"raw_representation", "additional_properties"} + @classmethod + def from_text_reasoning( + cls: type[TContent], + *, + text: str | None = None, + protected_data: str | None = None, + annotations: Sequence[Annotation] | None = None, + additional_properties: MutableMapping[str, Any] | None = None, + raw_representation: Any = None, + ) -> TContent: + """Create text reasoning content.""" + return cls( + "text_reasoning", + text=text, + protected_data=protected_data, + annotations=annotations, + additional_properties=additional_properties, + raw_representation=raw_representation, + ) - def __init__( - self, + @classmethod + def from_data( + cls: type[TContent], + data: bytes, + media_type: str, *, - annotated_regions: list[AnnotatedRegions] | list[MutableMapping[str, Any]] | None = None, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - **kwargs: Any, - ) -> None: - """Initialize BaseAnnotation. + annotations: Sequence[Annotation] | None = None, + additional_properties: MutableMapping[str, Any] | None = None, + raw_representation: Any = None, + ) -> TContent: + r"""Create data content from raw binary data. - Keyword Args: - annotated_regions: A list of regions that have been annotated. Can be region objects or dicts. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content from an underlying implementation. - **kwargs: Additional keyword arguments (merged into additional_properties). 
- """ - # Handle annotated_regions conversion from dict format (for SerializationMixin support) - self.annotated_regions: list[AnnotatedRegions] | None = None - if annotated_regions is not None: - converted_regions: list[AnnotatedRegions] = [] - for region_data in annotated_regions: - if isinstance(region_data, MutableMapping): - if region_data.get("type", "") == "text_span": - converted_regions.append(TextSpanRegion.from_dict(region_data)) - else: - logger.warning(f"Unknown region type: {region_data.get('type', '')} in {region_data}") - else: - # Already a region object, keep as is - converted_regions.append(region_data) - self.annotated_regions = converted_regions + Use this to create content from binary data (images, audio, documents, etc.). + The data will be automatically base64-encoded into a data URI. - # Merge kwargs into additional_properties - self.additional_properties = additional_properties or {} - self.additional_properties.update(kwargs) + Args: + data: Raw binary data as bytes. This should be the actual binary data, + not a base64-encoded string. If you have a base64 string, + decode it first: base64.b64decode(base64_string) + media_type: The MIME type of the data (e.g., "image/png", "application/pdf"). + If you don't know the media type and have base64 data, you can detect it in some cases: - self.raw_representation = raw_representation + .. code-block:: python - def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: - """Convert the instance to a dictionary. + from agent_framework import detect_media_type_from_base64 - Extracts additional_properties fields to the root level. + media_type = detect_media_type_from_base64(base64_string) + if media_type is None: + raise ValueError("Could not detect media type") + data_bytes = base64.b64decode(base64_string) + content = Content.from_data(data=data_bytes, media_type=media_type) Keyword Args: - exclude: Set of field names to exclude from serialization. 
- exclude_none: Whether to exclude None values from the output. Defaults to True. + annotations: Optional annotations associated with the content. + additional_properties: Optional additional properties. + raw_representation: Optional raw representation from an underlying implementation. Returns: - Dictionary representation of the instance. - """ - # Get the base dict from SerializationMixin - result = super().to_dict(exclude=exclude, exclude_none=exclude_none) - - # Extract additional_properties to root level - if self.additional_properties: - result.update(self.additional_properties) - - return result - + A Content instance with type="data". -class CitationAnnotation(BaseAnnotation): - """Represents a citation annotation. - - Attributes: - type: The type of content, which is always "citation" for this class. - title: The title of the cited content. - url: The URL of the cited content. - file_id: The file identifier of the cited content, if applicable. - tool_name: The name of the tool that generated the citation, if applicable. - snippet: A snippet of the cited content, if applicable. - annotated_regions: A list of regions that have been annotated with this citation. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content from an underlying implementation. - - Examples: - .. code-block:: python + Raises: + TypeError: If data is not bytes. - from agent_framework import CitationAnnotation, TextSpanRegion + Examples: + .. 
code-block:: python - # Create a citation annotation - citation = CitationAnnotation( - title="Agent Framework Documentation", - url="https://example.com/docs", - snippet="This is a relevant excerpt...", - annotated_regions=[TextSpanRegion(start_index=0, end_index=25)], - ) - print(citation.title) # "Agent Framework Documentation" - """ + from agent_framework import Content, detect_media_type_from_base64 + import base64 - def __init__( - self, - *, - title: str | None = None, - url: str | None = None, - file_id: str | None = None, - tool_name: str | None = None, - snippet: str | None = None, - annotated_regions: list[AnnotatedRegions] | None = None, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - **kwargs: Any, - ) -> None: - """Initialize CitationAnnotation. + # Create from raw binary data with known media type + image_bytes = b"\x89PNG\r\n\x1a\n..." + content = Content.from_data(data=image_bytes, media_type="image/png") - Keyword Args: - title: The title of the cited content. - url: The URL of the cited content. - file_id: The file identifier of the cited content, if applicable. - tool_name: The name of the tool that generated the citation, if applicable. - snippet: A snippet of the cited content, if applicable. - annotated_regions: A list of regions that have been annotated with this citation. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content from an underlying implementation. - **kwargs: Additional keyword arguments. + # If you have a base64 string and need to detect media type + base64_string = "iVBORw0KGgo..." 
+ media_type = detect_media_type_from_base64(base64_string) + if media_type is None: + raise ValueError("Unknown media type") + image_bytes = base64.b64decode(base64_string) + content = Content.from_data(data=image_bytes, media_type=media_type) """ - super().__init__( - annotated_regions=annotated_regions, + try: + encoded_data = base64.b64encode(data).decode("utf-8") + except TypeError as e: + raise TypeError( + "Could not encode data to base64. Ensure 'data' is of type bytes.Or another b64encode compatible type." + ) from e + return cls( + "data", + uri=f"data:{media_type};base64,{encoded_data}", + media_type=media_type, + annotations=annotations, additional_properties=additional_properties, raw_representation=raw_representation, - **kwargs, ) - self.title = title - self.url = url - self.file_id = file_id - self.tool_name = tool_name - self.snippet = snippet - self.type: Literal["citation"] = "citation" + @classmethod + def from_uri( + cls: type[TContent], + uri: str, + *, + media_type: str | None = None, + annotations: Sequence[Annotation] | None = None, + additional_properties: MutableMapping[str, Any] | None = None, + raw_representation: Any = None, + ) -> TContent: + """Create content from a URI, can be both data URI or external URI. -Annotations = CitationAnnotation - - -# region BaseContent - -TContents = TypeVar("TContents", bound="BaseContent") - - -class BaseContent(SerializationMixin): - """Represents content used by AI services. - - Attributes: - annotations: Optional annotations associated with the content. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content from an underlying implementation. - - """ - - DEFAULT_EXCLUDE: ClassVar[set[str]] = {"raw_representation", "additional_properties"} + Use this when you already have a properly formed data URI + (e.g., "data:image/png;base64,iVBORw0KGgo..."). 
+ Or when you receive a link to a online resource (e.g., "https://example.com/image.png"). - def __init__( - self, - *, - annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - **kwargs: Any, - ) -> None: - """Initialize BaseContent. + Args: + uri: A URI string, + that either includes the media type and base64-encoded data, + or a valid URL to an external resource. Keyword Args: - annotations: Optional annotations associated with the content. Can be annotation objects or dicts. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content from an underlying implementation. - **kwargs: Additional keyword arguments (merged into additional_properties). - """ - self.annotations: list[Annotations] | None = None - # Handle annotations conversion from dict format (for SerializationMixin support) - if annotations is not None: - converted_annotations: list[Annotations] = [] - for annotation_data in annotations: - if isinstance(annotation_data, Annotations): - # If it's already an annotation object, keep it as is - converted_annotations.append(annotation_data) - elif isinstance(annotation_data, MutableMapping) and annotation_data.get("type", "") == "citation": - converted_annotations.append(CitationAnnotation.from_dict(annotation_data)) - else: - logger.debug( - f"Unknown annotation found: {annotation_data.get('type', 'no_type')}" - f" with data: {annotation_data}" - ) - self.annotations = converted_annotations + media_type: The MIME type of the data (e.g., "image/png", "application/pdf"). + This is optional but recommended for external URIs. + annotations: Optional annotations associated with the content. + additional_properties: Optional additional properties. + raw_representation: Optional raw representation from an underlying implementation. 
- # Merge kwargs into additional_properties - self.additional_properties = additional_properties or {} - self.additional_properties.update(kwargs) + Raises: + ContentError: If the URI is not valid. - self.raw_representation = raw_representation + Examples: + .. code-block:: python - def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: - """Convert the instance to a dictionary. + from agent_framework import Content - Extracts additional_properties fields to the root level. + # Create from a data URI + content = Content.from_uri(uri="data:image/png;base64,iVBORw0KGgo...", media_type="image/png") + assert content.type == "data" - Keyword Args: - exclude: Set of field names to exclude from serialization. - exclude_none: Whether to exclude None values from the output. Defaults to True. + # Create from an external URI + content = Content.from_uri(uri="https://example.com/image.png", media_type="image/png") + assert content.type == "uri" Returns: - Dictionary representation of the instance. + A Content instance with type="data" for data URIs or type="uri" for external URIs. """ - # Get the base dict from SerializationMixin - result = super().to_dict(exclude=exclude, exclude_none=exclude_none) - - # Extract additional_properties to root level - if self.additional_properties: - result.update(self.additional_properties) - - return result - - -class TextContent(BaseContent): - """Represents text content in a chat. - - Attributes: - text: The text content represented by this instance. - type: The type of content, which is always "text" for this class. - annotations: Optional annotations associated with the content. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - - Examples: - .. 
code-block:: python - - from agent_framework import TextContent - - # Create basic text content - text = TextContent(text="Hello, world!") - print(text.text) # "Hello, world!" - - # Concatenate text content - text1 = TextContent(text="Hello, ") - text2 = TextContent(text="world!") - combined = text1 + text2 - print(combined.text) # "Hello, world!" - """ - - def __init__( - self, - text: str, - *, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, - **kwargs: Any, - ): - """Initializes a TextContent instance. - - Args: - text: The text content represented by this instance. - - Keyword Args: - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - annotations: Optional annotations associated with the content. - **kwargs: Any additional keyword arguments. - """ - super().__init__( - annotations=annotations, - additional_properties=additional_properties, - raw_representation=raw_representation, - **kwargs, - ) - self.text = text - self.type: Literal["text"] = "text" - - def __add__(self, other: "TextContent") -> "TextContent": - """Concatenate two TextContent instances. - - The following things happen: - The text is concatenated. - The annotations are combined. - The additional properties are merged, with the values of shared keys of the first instance taking precedence. - The raw_representations are combined into a list of them, if they both have one. 
- """ - if not isinstance(other, TextContent): - raise TypeError("Incompatible type") - - # Merge raw representations - if self.raw_representation is None: - raw_representation = other.raw_representation - elif other.raw_representation is None: - raw_representation = self.raw_representation - else: - raw_representation = ( - self.raw_representation if isinstance(self.raw_representation, list) else [self.raw_representation] - ) + (other.raw_representation if isinstance(other.raw_representation, list) else [other.raw_representation]) - - # Merge annotations - if self.annotations is None: - annotations = other.annotations - elif other.annotations is None: - annotations = self.annotations - else: - annotations = self.annotations + other.annotations - - # Create new instance using from_dict for proper deserialization - result_dict = { - "text": self.text + other.text, - "type": "text", - "annotations": [ann.to_dict(exclude_none=False) for ann in annotations] if annotations else None, - "additional_properties": { - **(other.additional_properties or {}), - **(self.additional_properties or {}), - }, - "raw_representation": raw_representation, - } - return TextContent.from_dict(result_dict) - - def __iadd__(self, other: "TextContent") -> Self: - """In-place concatenation of two TextContent instances. - - The following things happen: - The text is concatenated. - The annotations are combined. - The additional properties are merged, with the values of shared keys of the first instance taking precedence. - The raw_representations are combined into a list of them, if they both have one. 
- """ - if not isinstance(other, TextContent): - raise TypeError("Incompatible type") - - # Concatenate text - self.text += other.text - - # Merge additional properties (self takes precedence) - if self.additional_properties is None: - self.additional_properties = {} - if other.additional_properties: - # Update from other first, then restore self's values to maintain precedence - self_props = self.additional_properties.copy() - self.additional_properties.update(other.additional_properties) - self.additional_properties.update(self_props) - - # Merge raw representations - if self.raw_representation is None: - self.raw_representation = other.raw_representation - elif other.raw_representation is not None: - self.raw_representation = ( - self.raw_representation if isinstance(self.raw_representation, list) else [self.raw_representation] - ) + (other.raw_representation if isinstance(other.raw_representation, list) else [other.raw_representation]) - - # Merge annotations - if other.annotations: - if self.annotations is None: - self.annotations = [] - self.annotations.extend(other.annotations) - - return self - - -class TextReasoningContent(BaseContent): - """Represents text reasoning content in a chat. - - Remarks: - This class and `TextContent` are superficially similar, but distinct. - - Attributes: - text: The text content represented by this instance. - type: The type of content, which is always "text_reasoning" for this class. - annotations: Optional annotations associated with the content. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - - Examples: - .. code-block:: python - - from agent_framework import TextReasoningContent - - # Create reasoning content - reasoning = TextReasoningContent(text="Let me think step by step...") - print(reasoning.text) # "Let me think step by step..." 
- - # Concatenate reasoning content - reasoning1 = TextReasoningContent(text="First, ") - reasoning2 = TextReasoningContent(text="second, ") - combined = reasoning1 + reasoning2 - print(combined.text) # "First, second, " - """ - - def __init__( - self, - text: str | None, - *, - protected_data: str | None = None, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, - **kwargs: Any, - ): - """Initializes a TextReasoningContent instance. - - Args: - text: The text content represented by this instance. - - Keyword Args: - protected_data: This property is used to store data from a provider that should be roundtripped back to the - provider but that is not intended for human consumption. It is often encrypted or otherwise redacted - information that is only intended to be sent back to the provider and not displayed to the user. It's - possible for a TextReasoningContent to contain only `protected_data` and have an empty `text` property. - This data also may be associated with the corresponding `text`, acting as a validation signature for it. - - Note that whereas `text` can be provider agnostic, `protected_data` is provider-specific, and is likely - to only be understood by the provider that created it. The data is often represented as a more complex - object, so it should be serialized to a string before storing so that the whole object is easily - serializable without loss. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - annotations: Optional annotations associated with the content. - **kwargs: Any additional keyword arguments. 
- """ - super().__init__( - annotations=annotations, - additional_properties=additional_properties, - raw_representation=raw_representation, - **kwargs, - ) - self.text = text - self.protected_data = protected_data - self.type: Literal["text_reasoning"] = "text_reasoning" - - def __add__(self, other: "TextReasoningContent") -> "TextReasoningContent": - """Concatenate two TextReasoningContent instances. - - The following things happen: - The text is concatenated. - The annotations are combined. - The additional properties are merged, with the values of shared keys of the first instance taking precedence. - The raw_representations are combined into a list of them, if they both have one. - """ - if not isinstance(other, TextReasoningContent): - raise TypeError("Incompatible type") - - # Merge raw representations - if self.raw_representation is None: - raw_representation = other.raw_representation - elif other.raw_representation is None: - raw_representation = self.raw_representation - else: - raw_representation = ( - self.raw_representation if isinstance(self.raw_representation, list) else [self.raw_representation] - ) + (other.raw_representation if isinstance(other.raw_representation, list) else [other.raw_representation]) - - # Merge annotations - if self.annotations is None: - annotations = other.annotations - elif other.annotations is None: - annotations = self.annotations - else: - annotations = self.annotations + other.annotations - - # Replace protected data. 
- # Discussion: https://github.com/microsoft/agent-framework/pull/2950#discussion_r2634345613 - protected_data = other.protected_data or self.protected_data - - # Create new instance using from_dict for proper deserialization - result_dict = { - "text": (self.text or "") + (other.text or "") if self.text is not None or other.text is not None else None, - "type": "text_reasoning", - "annotations": [ann.to_dict(exclude_none=False) for ann in annotations] if annotations else None, - "additional_properties": {**(self.additional_properties or {}), **(other.additional_properties or {})}, - "raw_representation": raw_representation, - "protected_data": protected_data, - } - return TextReasoningContent.from_dict(result_dict) - - def __iadd__(self, other: "TextReasoningContent") -> Self: - """In-place concatenation of two TextReasoningContent instances. - - The following things happen: - The text is concatenated. - The annotations are combined. - The additional properties are merged, with the values of shared keys of the first instance taking precedence. - The raw_representations are combined into a list of them, if they both have one. 
- """ - if not isinstance(other, TextReasoningContent): - raise TypeError("Incompatible type") - - # Concatenate text - if self.text is not None or other.text is not None: - self.text = (self.text or "") + (other.text or "") - # if both are None, should keep as None - - # Merge additional properties (self takes precedence) - if self.additional_properties is None: - self.additional_properties = {} - if other.additional_properties: - # Update from other first, then restore self's values to maintain precedence - self_props = self.additional_properties.copy() - self.additional_properties.update(other.additional_properties) - self.additional_properties.update(self_props) - - # Merge raw representations - if self.raw_representation is None: - self.raw_representation = other.raw_representation - elif other.raw_representation is not None: - self.raw_representation = ( - self.raw_representation if isinstance(self.raw_representation, list) else [self.raw_representation] - ) + (other.raw_representation if isinstance(other.raw_representation, list) else [other.raw_representation]) - - # Replace protected data. - # Discussion: https://github.com/microsoft/agent-framework/pull/2950#discussion_r2634345613 - if other.protected_data is not None: - self.protected_data = other.protected_data - - # Merge annotations - if other.annotations: - if self.annotations is None: - self.annotations = [] - self.annotations.extend(other.annotations) - - return self - - -TDataContent = TypeVar("TDataContent", bound="DataContent") - - -class DataContent(BaseContent): - """Represents binary data content with an associated media type (also known as a MIME type). - - Important: - This is for binary data that is represented as a data URI, not for online resources. - Use ``UriContent`` for online resources. - - Attributes: - uri: The URI of the data represented by this instance, typically in the form of a data URI. - Should be in the form: "data:{media_type};base64,{base64_data}". 
- media_type: The media type of the data. - type: The type of content, which is always "data" for this class. - annotations: Optional annotations associated with the content. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - - Examples: - .. code-block:: python - - from agent_framework import DataContent - - # Create from binary data - image_data = b"raw image bytes" - data_content = DataContent(data=image_data, media_type="image/png") - - # Create from base64-encoded string - base64_string = "iVBORw0KGgoAAAANS..." - data_content = DataContent(data=base64_string, media_type="image/png") - - # Create from data URI - data_uri = "data:image/png;base64,iVBORw0KGgoAAAANS..." - data_content = DataContent(uri=data_uri) - - # Check media type - if data_content.has_top_level_media_type("image"): - print("This is an image") - """ - - @overload - def __init__( - self, - *, - uri: str, - annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - **kwargs: Any, - ) -> None: - """Initializes a DataContent instance with a URI. - - Important: - This is for binary data that is represented as a data URI, not for online resources. - Use ``UriContent`` for online resources. - - Keyword Args: - uri: The URI of the data represented by this instance. - Should be in the form: "data:{media_type};base64,{base64_data}". - annotations: Optional annotations associated with the content. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - **kwargs: Any additional keyword arguments. 
- """ - - @overload - def __init__( - self, - *, - data: bytes, - media_type: str, - annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - **kwargs: Any, - ) -> None: - """Initializes a DataContent instance with binary data. - - Important: - This is for binary data that is represented as a data URI, not for online resources. - Use ``UriContent`` for online resources. - - Keyword Args: - data: The binary data represented by this instance. - The data is transformed into a base64-encoded data URI. - media_type: The media type of the data. - annotations: Optional annotations associated with the content. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - **kwargs: Any additional keyword arguments. - """ - - @overload - def __init__( - self, - *, - data: str, - media_type: str, - annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - **kwargs: Any, - ) -> None: - """Initializes a DataContent instance with base64-encoded string data. - - Important: - This is for binary data that is represented as a data URI, not for online resources. - Use ``UriContent`` for online resources. - - Keyword Args: - data: The base64-encoded string data represented by this instance. - The data is used directly to construct a data URI. - media_type: The media type of the data. - annotations: Optional annotations associated with the content. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - **kwargs: Any additional keyword arguments. 
- """ - - def __init__( - self, - *, - uri: str | None = None, - data: bytes | str | None = None, - media_type: str | None = None, - annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - **kwargs: Any, - ) -> None: - """Initializes a DataContent instance. - - Important: - This is for binary data that is represented as a data URI, not for online resources. - Use ``UriContent`` for online resources. - - Keyword Args: - uri: The URI of the data represented by this instance. - Should be in the form: "data:{media_type};base64,{base64_data}". - data: The binary data or base64-encoded string represented by this instance. - If bytes, the data is transformed into a base64-encoded data URI. - If str, it is assumed to be already base64-encoded and used directly. - media_type: The media type of the data. - annotations: Optional annotations associated with the content. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - **kwargs: Any additional keyword arguments. 
- """ - if uri is None: - if data is None or media_type is None: - raise ValueError("Either 'data' and 'media_type' or 'uri' must be provided.") - - base64_data: str = base64.b64encode(data).decode("utf-8") if isinstance(data, bytes) else data - uri = f"data:{media_type};base64,{base64_data}" - - # Validate URI format and extract media type if not provided - validated_uri = self._validate_uri(uri) - if media_type is None: - match = URI_PATTERN.match(validated_uri) - if match: - media_type = match.group("media_type") - - super().__init__( - annotations=annotations, - additional_properties=additional_properties, - raw_representation=raw_representation, - **kwargs, - ) - self.uri = validated_uri - self.media_type = media_type - self.type: Literal["data"] = "data" - - @classmethod - def _validate_uri(cls, uri: str) -> str: - """Validates the URI format and extracts the media type. - - Minimal data URI parser based on RFC 2397: https://datatracker.ietf.org/doc/html/rfc2397. - """ - match = URI_PATTERN.match(uri) - if not match: - raise ValueError(f"Invalid data URI format: {uri}") - media_type = match.group("media_type") - if media_type not in KNOWN_MEDIA_TYPES: - raise ValueError(f"Unknown media type: {media_type}") - return uri - - def has_top_level_media_type(self, top_level_media_type: Literal["application", "audio", "image", "text"]) -> bool: - return _has_top_level_media_type(self.media_type, top_level_media_type) - - @staticmethod - def detect_image_format_from_base64(image_base64: str) -> str: - """Detect image format from base64 data by examining the binary header. 
- - Args: - image_base64: Base64 encoded image data - - Returns: - Image format as string (png, jpeg, webp, gif) with png as fallback - """ - try: - # Constants for image format detection - # ~75 bytes of binary data should be enough to detect most image formats - FORMAT_DETECTION_BASE64_CHARS = 100 - - # Decode a small portion to detect format - decoded_data = base64.b64decode(image_base64[:FORMAT_DETECTION_BASE64_CHARS]) - if decoded_data.startswith(b"\x89PNG"): - return "png" - if decoded_data.startswith(b"\xff\xd8\xff"): - return "jpeg" - if decoded_data.startswith(b"RIFF") and b"WEBP" in decoded_data[:12]: - return "webp" - if decoded_data.startswith(b"GIF87a") or decoded_data.startswith(b"GIF89a"): - return "gif" - return "png" # Default fallback - except Exception: - return "png" # Fallback if decoding fails - - @staticmethod - def create_data_uri_from_base64(image_base64: str) -> tuple[str, str]: - """Create a data URI and media type from base64 image data. - - Args: - image_base64: Base64 encoded image data - - Returns: - Tuple of (data_uri, media_type) - """ - format_type = DataContent.detect_image_format_from_base64(image_base64) - uri = f"data:image/{format_type};base64,{image_base64}" - media_type = f"image/{format_type}" - return uri, media_type - - def get_data_bytes_as_str(self) -> str: - """Extracts and returns the base64-encoded data from the data URI. - - Returns: - The binary data as str. - """ - match = URI_PATTERN.match(self.uri) - if not match: - raise ValueError(f"Invalid data URI format: {self.uri}") - return match.group("base64_data") - - def get_data_bytes(self) -> bytes: - """Extracts and returns the binary data from the data URI. - - Returns: - The binary data as bytes. - """ - base64_data = self.get_data_bytes_as_str() - return base64.b64decode(base64_data) - - -class UriContent(BaseContent): - """Represents a URI content. - - Important: - This is used for content that is identified by a URI, such as an image or a file. 
- For (binary) data URIs, use ``DataContent`` instead. - - Attributes: - uri: The URI of the content, e.g., 'https://example.com/image.png'. - media_type: The media type of the content, e.g., 'image/png', 'application/json', etc. - type: The type of content, which is always "uri" for this class. - annotations: Optional annotations associated with the content. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - - Examples: - .. code-block:: python - - from agent_framework import UriContent - - # Create URI content for an image - image_uri = UriContent( - uri="https://example.com/image.png", - media_type="image/png", - ) - - # Create URI content for a document - doc_uri = UriContent( - uri="https://example.com/document.pdf", - media_type="application/pdf", - ) - - # Check if it's an image - if image_uri.has_top_level_media_type("image"): - print("This is an image URI") - """ - - def __init__( - self, - uri: str, - media_type: str, - *, - annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - **kwargs: Any, - ) -> None: - """Initializes a UriContent instance. - - Remarks: - This is used for content that is identified by a URI, such as an image or a file. - For (binary) data URIs, use `DataContent` instead. - - Args: - uri: The URI of the content. - media_type: The media type of the content. - - Keyword Args: - annotations: Optional annotations associated with the content. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - **kwargs: Any additional keyword arguments. 
- """ - super().__init__( - annotations=annotations, - additional_properties=additional_properties, - raw_representation=raw_representation, - **kwargs, - ) - self.uri = uri - self.media_type = media_type - self.type: Literal["uri"] = "uri" - - def has_top_level_media_type(self, top_level_media_type: Literal["application", "audio", "image", "text"]) -> bool: - """Returns a boolean indicating if the media type has the specified top-level media type. - - Args: - top_level_media_type: The top-level media type to check for, allowed values: - "image", "text", "application", "audio". - - """ - return _has_top_level_media_type(self.media_type, top_level_media_type) - - -def _has_top_level_media_type( - media_type: str | None, top_level_media_type: Literal["application", "audio", "image", "text"] -) -> bool: - if media_type is None: - return False - - slash_index = media_type.find("/") - span = media_type[:slash_index] if slash_index >= 0 else media_type - span = span.strip() - return span.lower() == top_level_media_type.lower() - - -class ErrorContent(BaseContent): - """Represents an error. - - Remarks: - Typically used for non-fatal errors, where something went wrong as part of the operation, - but the operation was still able to continue. - - Attributes: - error_code: The error code associated with the error. - details: Additional details about the error. - message: The error message. - type: The type of content, which is always "error" for this class. - annotations: Optional annotations associated with the content. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - - Examples: - .. 
code-block:: python - - from agent_framework import ErrorContent - - # Create an error content - error = ErrorContent( - message="Failed to process request", - error_code="PROCESSING_ERROR", - details="The input format was invalid", - ) - print(str(error)) # "Error PROCESSING_ERROR: Failed to process request" - - # Error without code - simple_error = ErrorContent(message="Something went wrong") - print(str(simple_error)) # "Something went wrong" - """ + return cls( + **_validate_uri(uri, media_type), + annotations=annotations, + additional_properties=additional_properties, + raw_representation=raw_representation, + ) - def __init__( - self, + @classmethod + def from_error( + cls: type[TContent], *, message: str | None = None, error_code: str | None = None, - details: str | None = None, - annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - **kwargs: Any, - ) -> None: - """Initializes an ErrorContent instance. - - Keyword Args: - message: The error message. - error_code: The error code associated with the error. - details: Additional details about the error. - annotations: Optional annotations associated with the content. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - **kwargs: Any additional keyword arguments. 
- """ - super().__init__( + error_details: str | None = None, + annotations: Sequence[Annotation] | None = None, + additional_properties: MutableMapping[str, Any] | None = None, + raw_representation: Any = None, + ) -> TContent: + """Create error content.""" + return cls( + "error", + message=message, + error_code=error_code, + error_details=error_details, annotations=annotations, additional_properties=additional_properties, raw_representation=raw_representation, - **kwargs, ) - self.message = message - self.error_code = error_code - self.details = details - self.type: Literal["error"] = "error" - - def __str__(self) -> str: - """Returns a string representation of the error.""" - return f"Error {self.error_code}: {self.message}" if self.error_code else self.message or "Unknown error" - - -class FunctionCallContent(BaseContent): - """Represents a function call request. - - Attributes: - call_id: The function call identifier. - name: The name of the function requested. - arguments: The arguments requested to be provided to the function. - exception: Any exception that occurred while mapping the original function call data to this representation. - type: The type of content, which is always "function_call" for this class. - annotations: Optional annotations associated with the content. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - - Examples: - .. 
code-block:: python - - from agent_framework import FunctionCallContent - - # Create a function call - func_call = FunctionCallContent( - call_id="call_123", - name="get_weather", - arguments={"location": "Seattle", "unit": "celsius"}, - ) - - # Parse arguments - args = func_call.parse_arguments() - print(args["location"]) # "Seattle" - - # Create with string arguments (gradual completion) - func_call_partial_1 = FunctionCallContent( - call_id="call_124", - name="search", - arguments='{"query": ', - ) - func_call_partial_2 = FunctionCallContent( - call_id="call_124", - name="search", - arguments='"latest news"}', - ) - full_call = func_call_partial_1 + func_call_partial_2 - args = full_call.parse_arguments() - print(args["query"]) # "latest news" - """ - def __init__( - self, - *, + @classmethod + def from_function_call( + cls: type[TContent], call_id: str, name: str, - arguments: str | dict[str, Any | None] | None = None, - exception: Exception | None = None, - annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - **kwargs: Any, - ) -> None: - """Initializes a FunctionCallContent instance. - - Keyword Args: - call_id: The function call identifier. - name: The name of the function requested. - arguments: The arguments requested to be provided to the function, - can be a string to allow gradual completion of the args. - exception: Any exception that occurred while mapping the original function call data to this representation. - annotations: Optional annotations associated with the content. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - **kwargs: Any additional keyword arguments. 
- """ - super().__init__( + *, + arguments: str | Mapping[str, Any] | None = None, + exception: str | None = None, + annotations: Sequence[Annotation] | None = None, + additional_properties: MutableMapping[str, Any] | None = None, + raw_representation: Any = None, + ) -> TContent: + """Create function call content.""" + return cls( + "function_call", + call_id=call_id, + name=name, + arguments=arguments, + exception=exception, annotations=annotations, additional_properties=additional_properties, raw_representation=raw_representation, - **kwargs, - ) - self.call_id = call_id - self.name = name - self.arguments = arguments - self.exception = exception - self.type: Literal["function_call"] = "function_call" - - def parse_arguments(self) -> dict[str, Any | None] | None: - """Parse the arguments into a dictionary. - - If they cannot be parsed as json or if the resulting json is not a dict, - they are returned as a dictionary with a single key "raw". - """ - if isinstance(self.arguments, str): - # If arguments are a string, try to parse it as JSON - try: - loaded = json.loads(self.arguments) - if isinstance(loaded, dict): - return loaded # type:ignore - return {"raw": loaded} - except (json.JSONDecodeError, TypeError): - return {"raw": self.arguments} - return self.arguments - - def __add__(self, other: "FunctionCallContent") -> "FunctionCallContent": - if not isinstance(other, FunctionCallContent): - raise TypeError("Incompatible type") - if other.call_id and self.call_id != other.call_id: - raise AdditionItemMismatch("", log_level=None) - if not self.arguments: - arguments = other.arguments - elif not other.arguments: - arguments = self.arguments - elif isinstance(self.arguments, str) and isinstance(other.arguments, str): - arguments = self.arguments + other.arguments - elif isinstance(self.arguments, dict) and isinstance(other.arguments, dict): - arguments = {**self.arguments, **other.arguments} - else: - raise TypeError("Incompatible argument types") - return 
FunctionCallContent( - call_id=self.call_id, - name=self.name, - arguments=arguments, - exception=self.exception or other.exception, - additional_properties={**(self.additional_properties or {}), **(other.additional_properties or {})}, - raw_representation=self.raw_representation or other.raw_representation, ) - -class FunctionResultContent(BaseContent): - """Represents the result of a function call. - - Attributes: - call_id: The identifier of the function call for which this is the result. - result: The result of the function call, or a generic error message if the function call failed. - exception: An exception that occurred if the function call failed. - type: The type of content, which is always "function_result" for this class. - annotations: Optional annotations associated with the content. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - - Examples: - .. code-block:: python - - from agent_framework import FunctionResultContent - - # Create a successful function result - result = FunctionResultContent( - call_id="call_123", - result={"temperature": 22, "condition": "sunny"}, - ) - - # Create a failed function result - failed_result = FunctionResultContent( - call_id="call_124", - result="Function execution failed", - exception=ValueError("Invalid location"), - ) - """ - - def __init__( - self, - *, + @classmethod + def from_function_result( + cls: type[TContent], call_id: str, - result: Any | None = None, - exception: Exception | None = None, - annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - **kwargs: Any, - ) -> None: - """Initializes a FunctionResultContent instance. - - Keyword Args: - call_id: The identifier of the function call for which this is the result. 
- result: The result of the function call, or a generic error message if the function call failed. - exception: An exception that occurred if the function call failed. - annotations: Optional annotations associated with the content. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - **kwargs: Any additional keyword arguments. - """ - super().__init__( + *, + result: Any = None, + exception: str | None = None, + annotations: Sequence[Annotation] | None = None, + additional_properties: MutableMapping[str, Any] | None = None, + raw_representation: Any = None, + ) -> TContent: + """Create function result content.""" + return cls( + "function_result", + call_id=call_id, + result=result, + exception=exception, annotations=annotations, additional_properties=additional_properties, raw_representation=raw_representation, - **kwargs, ) - self.call_id = call_id - self.result = result - self.exception = exception - self.type: Literal["function_result"] = "function_result" - - -class UsageContent(BaseContent): - """Represents usage information associated with a chat request and response. - Attributes: - details: The usage information, including input and output token counts, and any additional counts. - type: The type of content, which is always "usage" for this class. - annotations: Optional annotations associated with the content. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - - Examples: - .. 
code-block:: python - - from agent_framework import UsageContent, UsageDetails - - # Create usage content - usage = UsageContent( - details=UsageDetails( - input_token_count=100, - output_token_count=50, - total_token_count=150, - ), - ) - print(usage.details.total_token_count) # 150 - """ - - def __init__( - self, - details: UsageDetails | MutableMapping[str, Any], + @classmethod + def from_usage( + cls: type[TContent], + usage_details: UsageDetails, *, - annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - **kwargs: Any, - ) -> None: - """Initializes a UsageContent instance.""" - super().__init__( + annotations: Sequence[Annotation] | None = None, + additional_properties: MutableMapping[str, Any] | None = None, + raw_representation: Any = None, + ) -> TContent: + """Create usage content.""" + return cls( + "usage", + usage_details=usage_details, annotations=annotations, additional_properties=additional_properties, raw_representation=raw_representation, - **kwargs, ) - # Convert dict to UsageDetails if needed - if isinstance(details, MutableMapping): - details = UsageDetails.from_dict(details) - self.details = details - self.type: Literal["usage"] = "usage" - - -class HostedFileContent(BaseContent): - """Represents a hosted file content. - - Attributes: - file_id: The identifier of the hosted file. - type: The type of content, which is always "hosted_file" for this class. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - - Examples: - .. 
code-block:: python - from agent_framework import HostedFileContent - - # Create hosted file content - file_content = HostedFileContent(file_id="file-abc123") - print(file_content.file_id) # "file-abc123" - """ - - def __init__( - self, + @classmethod + def from_hosted_file( + cls: type[TContent], file_id: str, *, media_type: str | None = None, name: str | None = None, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - **kwargs: Any, - ) -> None: - """Initializes a HostedFileContent instance. - - Args: - file_id: The identifier of the hosted file. - media_type: Optional media type of the hosted file. - name: Optional display name of the hosted file. - - Keyword Args: - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - **kwargs: Any additional keyword arguments. - """ - super().__init__( + annotations: Sequence[Annotation] | None = None, + additional_properties: MutableMapping[str, Any] | None = None, + raw_representation: Any = None, + ) -> TContent: + """Create hosted file content.""" + return cls( + "hosted_file", + file_id=file_id, + media_type=media_type, + name=name, + annotations=annotations, additional_properties=additional_properties, raw_representation=raw_representation, - **kwargs, ) - self.file_id = file_id - self.media_type = media_type - self.name = name - self.type: Literal["hosted_file"] = "hosted_file" - - def has_top_level_media_type(self, top_level_media_type: Literal["application", "audio", "image", "text"]) -> bool: - """Returns a boolean indicating if the media type has the specified top-level media type.""" - return _has_top_level_media_type(self.media_type, top_level_media_type) - -class HostedVectorStoreContent(BaseContent): - """Represents a hosted vector store content. - - Attributes: - vector_store_id: The identifier of the hosted vector store. 
- type: The type of content, which is always "hosted_vector_store" for this class. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - - Examples: - .. code-block:: python - - from agent_framework import HostedVectorStoreContent - - # Create hosted vector store content - vs_content = HostedVectorStoreContent(vector_store_id="vs-xyz789") - print(vs_content.vector_store_id) # "vs-xyz789" - """ - - def __init__( - self, + @classmethod + def from_hosted_vector_store( + cls: type[TContent], vector_store_id: str, *, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - **kwargs: Any, - ) -> None: - """Initializes a HostedVectorStoreContent instance. - - Args: - vector_store_id: The identifier of the hosted vector store. + annotations: Sequence[Annotation] | None = None, + additional_properties: MutableMapping[str, Any] | None = None, + raw_representation: Any = None, + ) -> TContent: + """Create hosted vector store content.""" + return cls( + "hosted_vector_store", + vector_store_id=vector_store_id, + annotations=annotations, + additional_properties=additional_properties, + raw_representation=raw_representation, + ) - Keyword Args: - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - **kwargs: Any additional keyword arguments. 
- """ - super().__init__( + @classmethod + def from_code_interpreter_tool_call( + cls: type[TContent], + *, + call_id: str | None = None, + inputs: Sequence["Content"] | None = None, + annotations: Sequence[Annotation] | None = None, + additional_properties: MutableMapping[str, Any] | None = None, + raw_representation: Any = None, + ) -> TContent: + """Create code interpreter tool call content.""" + return cls( + "code_interpreter_tool_call", + call_id=call_id, + inputs=list(inputs) if inputs is not None else None, + annotations=annotations, additional_properties=additional_properties, raw_representation=raw_representation, - **kwargs, ) - self.vector_store_id = vector_store_id - self.type: Literal["hosted_vector_store"] = "hosted_vector_store" + @classmethod + def from_code_interpreter_tool_result( + cls: type[TContent], + *, + call_id: str | None = None, + outputs: Sequence["Content"] | None = None, + annotations: Sequence[Annotation] | None = None, + additional_properties: MutableMapping[str, Any] | None = None, + raw_representation: Any = None, + ) -> TContent: + """Create code interpreter tool result content.""" + return cls( + "code_interpreter_tool_result", + call_id=call_id, + outputs=list(outputs) if outputs is not None else None, + annotations=annotations, + additional_properties=additional_properties, + raw_representation=raw_representation, + ) -class CodeInterpreterToolCallContent(BaseContent): - """Represents a code interpreter tool call invocation by a hosted service.""" + @classmethod + def from_image_generation_tool_call( + cls: type[TContent], + *, + image_id: str | None = None, + annotations: Sequence[Annotation] | None = None, + additional_properties: MutableMapping[str, Any] | None = None, + raw_representation: Any = None, + ) -> TContent: + """Create image generation tool call content.""" + return cls( + "image_generation_tool_call", + image_id=image_id, + annotations=annotations, + additional_properties=additional_properties, + 
raw_representation=raw_representation, + ) - def __init__( - self, + @classmethod + def from_image_generation_tool_result( + cls: type[TContent], *, - call_id: str | None = None, - inputs: Sequence["Contents | MutableMapping[str, Any]"] | None = None, - annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - **kwargs: Any, - ) -> None: - super().__init__( + image_id: str | None = None, + outputs: Any = None, + annotations: Sequence[Annotation] | None = None, + additional_properties: MutableMapping[str, Any] | None = None, + raw_representation: Any = None, + ) -> TContent: + """Create image generation tool result content.""" + return cls( + "image_generation_tool_result", + image_id=image_id, + outputs=outputs, annotations=annotations, additional_properties=additional_properties, raw_representation=raw_representation, - **kwargs, ) - self.call_id = call_id - self.inputs: list["Contents"] | None = None - if inputs: - normalized_inputs: Sequence["Contents | MutableMapping[str, Any]"] = ( - inputs - if isinstance(inputs, Sequence) and not isinstance(inputs, (str, bytes, MutableMapping)) - else [inputs] - ) - self.inputs = _parse_content_list(list(normalized_inputs)) - self.type: Literal["code_interpreter_tool_call"] = "code_interpreter_tool_call" + @classmethod + def from_mcp_server_tool_call( + cls: type[TContent], + call_id: str, + tool_name: str, + *, + server_name: str | None = None, + arguments: str | Mapping[str, Any] | None = None, + annotations: Sequence[Annotation] | None = None, + additional_properties: MutableMapping[str, Any] | None = None, + raw_representation: Any = None, + ) -> TContent: + """Create MCP server tool call content.""" + return cls( + "mcp_server_tool_call", + call_id=call_id, + tool_name=tool_name, + server_name=server_name, + arguments=arguments, + annotations=annotations, + additional_properties=additional_properties, + 
raw_representation=raw_representation, + ) -class CodeInterpreterToolResultContent(BaseContent): - """Represents the result of a code interpreter tool invocation by a hosted service.""" + @classmethod + def from_mcp_server_tool_result( + cls: type[TContent], + call_id: str, + *, + output: Any = None, + annotations: Sequence[Annotation] | None = None, + additional_properties: MutableMapping[str, Any] | None = None, + raw_representation: Any = None, + ) -> TContent: + """Create MCP server tool result content.""" + return cls( + "mcp_server_tool_result", + call_id=call_id, + output=output, + annotations=annotations, + additional_properties=additional_properties, + raw_representation=raw_representation, + ) - def __init__( - self, + @classmethod + def from_function_approval_request( + cls: type[TContent], + id: str, + function_call: "Content", *, - call_id: str | None = None, - outputs: Sequence["Contents | MutableMapping[str, Any]"] | None = None, - annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - **kwargs: Any, - ) -> None: - super().__init__( + annotations: Sequence[Annotation] | None = None, + additional_properties: MutableMapping[str, Any] | None = None, + raw_representation: Any = None, + ) -> TContent: + """Create function approval request content.""" + return cls( + "function_approval_request", + id=id, + function_call=function_call, + user_input_request=True, + annotations=annotations, + additional_properties=additional_properties, + raw_representation=raw_representation, + ) + + @classmethod + def from_function_approval_response( + cls: type[TContent], + approved: bool, + id: str, + function_call: "Content", + *, + annotations: Sequence[Annotation] | None = None, + additional_properties: MutableMapping[str, Any] | None = None, + raw_representation: Any = None, + ) -> TContent: + """Create function approval response content.""" + return 
cls( + "function_approval_response", + approved=approved, + id=id, + function_call=function_call, annotations=annotations, additional_properties=additional_properties, raw_representation=raw_representation, - **kwargs, ) - self.call_id = call_id - self.outputs: list["Contents"] | None = None - if outputs: - normalized_outputs: Sequence["Contents | MutableMapping[str, Any]"] = ( - outputs - if isinstance(outputs, Sequence) and not isinstance(outputs, (str, bytes, MutableMapping)) - else [outputs] + + def to_function_approval_response( + self, + approved: bool, + ) -> "Content": + """Convert a function approval request content to a function approval response content.""" + if self.type != "function_approval_request": + raise ContentError( + "Can only convert 'function_approval_request' content to 'function_approval_response' content." ) - self.outputs = _parse_content_list(list(normalized_outputs)) - self.type: Literal["code_interpreter_tool_result"] = "code_interpreter_tool_result" + return Content.from_function_approval_response( + approved=approved, + id=self.id, # type: ignore[attr-defined] + function_call=self.function_call, # type: ignore[attr-defined] + annotations=self.annotations, + additional_properties=self.additional_properties, + raw_representation=self.raw_representation, + ) + def to_dict(self, *, exclude_none: bool = True, exclude: set[str] | None = None) -> dict[str, Any]: + """Serialize the content to a dictionary.""" + fields_to_capture = ( + "text", + "protected_data", + "uri", + "media_type", + "message", + "error_code", + "error_details", + "usage_details", + "call_id", + "name", + "arguments", + "exception", + "result", + "file_id", + "vector_store_id", + "inputs", + "outputs", + "image_id", + "tool_name", + "server_name", + "output", + "function_call", + "user_input_request", + "approved", + "id", + "additional_properties", + ) -class ImageGenerationToolCallContent(BaseContent): - """Represents the invocation of an image generation tool call by 
a hosted service.""" + exclude = exclude or set() + result: dict[str, Any] = {"type": self.type} - def __init__( - self, - *, - image_id: str | None = None, - annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - **kwargs: Any, - ) -> None: - """Initializes an ImageGenerationToolCallContent instance. + for field in fields_to_capture: + value = getattr(self, field, None) + if field in exclude: + continue + if exclude_none and value is None: + continue + result[field] = _serialize_value(value, exclude_none) - Keyword Args: - image_id: The identifier of the image to be generated. - annotations: Optional annotations associated with the content. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - **kwargs: Any additional keyword arguments. + if "annotations" not in exclude and self.annotations is not None: + result["annotations"] = [dict(annotation) for annotation in self.annotations] - """ - super().__init__( + return result + + def __eq__(self, other: object) -> bool: + """Check if two Content instances are equal by comparing their dict representations.""" + if not isinstance(other, Content): + return False + return self.to_dict(exclude_none=False) == other.to_dict(exclude_none=False) + + def __str__(self) -> str: + """Return a string representation of the Content.""" + if self.type == "error": + if self.error_code: + return f"Error {self.error_code}: {self.message or ''}" + return self.message or "Unknown error" + if self.type == "text": + return self.text or "" + return f"Content(type={self.type})" + + @classmethod + def from_dict(cls: type[TContent], data: Mapping[str, Any]) -> TContent: + """Create a Content instance from a mapping.""" + if not (content_type := data.get("type")): + raise ValueError("Content mapping requires 'type'") + remaining = 
dict(data) + remaining.pop("type", None) + annotations = remaining.pop("annotations", None) + additional_properties = remaining.pop("additional_properties", None) + raw_representation = remaining.pop("raw_representation", None) + + # Special handling for DataContent with data and media_type + if content_type == "data" and "data" in remaining and "media_type" in remaining: + # Use from_data() to properly create the DataContent with URI + return cls.from_data(remaining["data"], remaining["media_type"]) + + # Handle nested Content objects (e.g., function_call in function_approval_request) + if "function_call" in remaining and isinstance(remaining["function_call"], dict): + remaining["function_call"] = cls.from_dict(remaining["function_call"]) + + # Handle list of Content objects (e.g., inputs in code_interpreter_tool_call) + if "inputs" in remaining and isinstance(remaining["inputs"], list): + remaining["inputs"] = [ + cls.from_dict(item) if isinstance(item, dict) else item for item in remaining["inputs"] + ] + + if "outputs" in remaining and isinstance(remaining["outputs"], list): + remaining["outputs"] = [ + cls.from_dict(item) if isinstance(item, dict) else item for item in remaining["outputs"] + ] + + return cls( + type=content_type, annotations=annotations, additional_properties=additional_properties, raw_representation=raw_representation, - **kwargs, + **remaining, ) - self.image_id = image_id - self.type: Literal["image_generation_tool_call"] = "image_generation_tool_call" + def __add__(self, other: "Content") -> "Content": + """Concatenate or merge two Content instances.""" + if not isinstance(other, Content): + raise TypeError(f"Incompatible type: Cannot add Content with {type(other).__name__}") + + if self.type != other.type: + raise TypeError(f"Cannot add Content of type '{self.type}' with type '{other.type}'") + + if self.type == "text": + return self._add_text_content(other) + if self.type == "text_reasoning": + return 
self._add_text_reasoning_content(other) + if self.type == "function_call": + return self._add_function_call_content(other) + if self.type == "usage": + return self._add_usage_content(other) + raise ContentError(f"Addition not supported for content type: {self.type}") + + def _add_text_content(self, other: "Content") -> "Content": + """Add two TextContent instances.""" + # Merge raw representations + if self.raw_representation is None: + raw_representation = other.raw_representation + elif other.raw_representation is None: + raw_representation = self.raw_representation + else: + raw_representation = ( + self.raw_representation if isinstance(self.raw_representation, list) else [self.raw_representation] + ) + (other.raw_representation if isinstance(other.raw_representation, list) else [other.raw_representation]) -class ImageGenerationToolResultContent(BaseContent): - """Represents the result of an image generation tool call invocation by a hosted service.""" - - def __init__( - self, - *, - image_id: str | None = None, - outputs: DataContent | UriContent | None = None, - annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - **kwargs: Any, - ) -> None: - """Initializes an ImageGenerationToolResultContent instance. - - Keyword Args: - image_id: The identifier of the generated image. - outputs: The outputs of the image generation tool call. - annotations: Optional annotations associated with the content. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - **kwargs: Any additional keyword arguments. 
+ # Merge annotations + if self.annotations is None: + annotations = other.annotations + elif other.annotations is None: + annotations = self.annotations + else: + annotations = self.annotations + other.annotations - """ - super().__init__( + return Content( + "text", + text=self.text + other.text, # type: ignore[attr-defined] annotations=annotations, - additional_properties=additional_properties, + additional_properties={ + **(other.additional_properties or {}), + **(self.additional_properties or {}), + }, raw_representation=raw_representation, - **kwargs, ) - self.image_id = image_id - self.outputs: DataContent | UriContent | None = outputs - self.type: Literal["image_generation_tool_result"] = "image_generation_tool_result" + def _add_text_reasoning_content(self, other: "Content") -> "Content": + """Add two TextReasoningContent instances.""" + # Merge raw representations + if self.raw_representation is None: + raw_representation = other.raw_representation + elif other.raw_representation is None: + raw_representation = self.raw_representation + else: + raw_representation = ( + self.raw_representation if isinstance(self.raw_representation, list) else [self.raw_representation] + ) + (other.raw_representation if isinstance(other.raw_representation, list) else [other.raw_representation]) -class MCPServerToolCallContent(BaseContent): - """Represents a tool call request to a MCP server.""" + # Merge annotations + if self.annotations is None: + annotations = other.annotations + elif other.annotations is None: + annotations = self.annotations + else: + annotations = self.annotations + other.annotations - def __init__( - self, - call_id: str, - tool_name: str, - server_name: str | None = None, - *, - arguments: str | Mapping[str, Any] | None = None, - annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - **kwargs: Any, - ) -> None: - """Initializes a 
MCPServerToolCallContent instance. + # Concatenate text, handling None values + self_text = self.text or "" # type: ignore[attr-defined] + other_text = other.text or "" # type: ignore[attr-defined] + combined_text = self_text + other_text if (self_text or other_text) else None - Args: - call_id: The tool call identifier. - tool_name: The name of the tool requested. - server_name: The name of the MCP server where the tool is hosted. + # Handle protected_data replacement + protected_data = other.protected_data if other.protected_data is not None else self.protected_data # type: ignore[attr-defined] - Keyword Args: - arguments: The arguments requested to be provided to the tool, - can be a string to allow gradual completion of the args. - annotations: Optional annotations associated with the content. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - **kwargs: Any additional keyword arguments. 
- """ - if not call_id: - raise ValueError("call_id must be a non-empty string.") - if not tool_name: - raise ValueError("tool_name must be a non-empty string.") - super().__init__( + return Content( + "text_reasoning", + text=combined_text, + protected_data=protected_data, annotations=annotations, - additional_properties=additional_properties, + additional_properties={ + **(other.additional_properties or {}), + **(self.additional_properties or {}), + }, raw_representation=raw_representation, - **kwargs, ) - self.call_id = call_id - self.tool_name = tool_name - self.name = tool_name - self.server_name = server_name - self.arguments = arguments - self.type: Literal["mcp_server_tool_call"] = "mcp_server_tool_call" - - def parse_arguments(self) -> dict[str, Any] | None: - """Returns the parsed arguments for the MCP server tool call, if any.""" - if isinstance(self.arguments, str): - # If arguments are a string, try to parse it as JSON - try: - loaded = json.loads(self.arguments) - if isinstance(loaded, dict): - return loaded # type:ignore - return {"raw": loaded} - except (json.JSONDecodeError, TypeError): - return {"raw": self.arguments} - return cast(dict[str, Any] | None, self.arguments) + def _add_function_call_content(self, other: "Content") -> "Content": + """Add two FunctionCallContent instances.""" + other_call_id = getattr(other, "call_id", None) + self_call_id = getattr(self, "call_id", None) + if other_call_id and self_call_id != other_call_id: + raise ContentError("Cannot add function calls with different call_ids") + + self_arguments = getattr(self, "arguments", None) + other_arguments = getattr(other, "arguments", None) + + if not self_arguments: + arguments: str | Mapping[str, Any] | None = other_arguments + elif not other_arguments: + arguments = self_arguments + elif isinstance(self_arguments, str) and isinstance(other_arguments, str): + arguments = self_arguments + other_arguments + elif isinstance(self_arguments, dict) and 
isinstance(other_arguments, dict): + arguments = {**self_arguments, **other_arguments} + else: + raise TypeError("Incompatible argument types") -class MCPServerToolResultContent(BaseContent): - """Represents the result of a MCP server tool call.""" - - def __init__( - self, - call_id: str, - *, - output: Any | None = None, - annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - **kwargs: Any, - ) -> None: - """Initializes a MCPServerToolResultContent instance. - - Args: - call_id: The identifier of the tool call for which this is the result. + # Merge raw representations + if self.raw_representation is None: + raw_representation: Any = other.raw_representation + elif other.raw_representation is None: + raw_representation = self.raw_representation + else: + raw_representation = ( + self.raw_representation if isinstance(self.raw_representation, list) else [self.raw_representation] + ) + (other.raw_representation if isinstance(other.raw_representation, list) else [other.raw_representation]) - Keyword Args: - output: The output of the MCP server tool call. - annotations: Optional annotations associated with the content. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - **kwargs: Any additional keyword arguments. 
- """ - if not call_id: - raise ValueError("call_id must be a non-empty string.") - super().__init__( - annotations=annotations, - additional_properties=additional_properties, + return Content( + "function_call", + call_id=self_call_id, + name=getattr(self, "name", getattr(other, "name", None)), + arguments=arguments, + exception=getattr(self, "exception", None) or getattr(other, "exception", None), + additional_properties={ + **(self.additional_properties or {}), + **(other.additional_properties or {}), + }, raw_representation=raw_representation, - **kwargs, ) - self.call_id = call_id - self.output: Any | None = output - self.type: Literal["mcp_server_tool_result"] = "mcp_server_tool_result" + def _add_usage_content(self, other: "Content") -> "Content": + """Add two UsageContent instances by combining their usage details.""" + self_details = getattr(self, "usage_details", {}) + other_details = getattr(other, "usage_details", {}) + + # Combine token counts + combined_details: dict[str, Any] = {} + for key in set(list(self_details.keys()) + list(other_details.keys())): + self_val = self_details.get(key) + other_val = other_details.get(key) + if isinstance(self_val, int) and isinstance(other_val, int): + combined_details[key] = self_val + other_val + elif self_val is not None: + combined_details[key] = self_val + elif other_val is not None: + combined_details[key] = other_val -class BaseUserInputRequest(BaseContent): - """Base class for all user requests.""" - - def __init__( - self, - *, - id: str, - annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - **kwargs: Any, - ) -> None: - """Initialize BaseUserInputRequest. 
+ # Merge raw representations + if self.raw_representation is None: + raw_representation = other.raw_representation + elif other.raw_representation is None: + raw_representation = self.raw_representation + else: + raw_representation = ( + self.raw_representation if isinstance(self.raw_representation, list) else [self.raw_representation] + ) + (other.raw_representation if isinstance(other.raw_representation, list) else [other.raw_representation]) - Keyword Args: - id: The unique identifier for the request. - annotations: Optional annotations associated with the content. - additional_properties: Optional additional properties associated with the content. - raw_representation: Optional raw representation of the content. - **kwargs: Any additional keyword arguments. - """ - if not id or len(id) < 1: - raise ValueError("id must be at least 1 character long") - super().__init__( - annotations=annotations, - additional_properties=additional_properties, + return Content( + "usage", + usage_details=combined_details, + additional_properties={ + **(self.additional_properties or {}), + **(other.additional_properties or {}), + }, raw_representation=raw_representation, - **kwargs, ) - self.id = id - self.type: Literal["user_input_request"] = "user_input_request" + def has_top_level_media_type(self, top_level_media_type: Literal["application", "audio", "image", "text"]) -> bool: + """Check if content has a specific top-level media type. -class FunctionApprovalResponseContent(BaseContent): - """Represents a response for user approval of a function call. + Works with data, uri, and hosted_file content types. - Examples: - .. code-block:: python + Args: + top_level_media_type: The top-level media type to check for. - from agent_framework import FunctionApprovalResponseContent, FunctionCallContent + Returns: + True if the content's media type matches the specified top-level type. 
- # Create a function approval response - func_call = FunctionCallContent( - call_id="call_123", - name="send_email", - arguments={"to": "user@example.com"}, - ) - response = FunctionApprovalResponseContent( - approved=False, - id="approval_001", - function_call=func_call, - ) - print(response.approved) # False - """ + Raises: + ContentError: If the content type doesn't support media types. - def __init__( - self, - approved: bool, - *, - id: str, - function_call: FunctionCallContent | MCPServerToolCallContent | MutableMapping[str, Any], - annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - **kwargs: Any, - ) -> None: - """Initializes a FunctionApprovalResponseContent instance. + Examples: + .. code-block:: python - Args: - approved: Whether the function call was approved. + from agent_framework import Content - Keyword Args: - id: The unique identifier for the request. - function_call: The function call content to be approved. Can be a FunctionCallContent object or dict. - annotations: Optional list of annotations for the request. - additional_properties: Optional additional properties for the request. - raw_representation: Optional raw representation of the request. - **kwargs: Additional keyword arguments. 
+ image = Content.from_uri(uri="data:image/png;base64,abc123", media_type="image/png") + print(image.has_top_level_media_type("image")) # True + print(image.has_top_level_media_type("audio")) # False """ - super().__init__( - annotations=annotations, - additional_properties=additional_properties, - raw_representation=raw_representation, - **kwargs, - ) - self.id = id - self.approved = approved - # Convert dict to FunctionCallContent if needed (for SerializationMixin support) - self.function_call: FunctionCallContent | MCPServerToolCallContent - if isinstance(function_call, MutableMapping): - if function_call.get("type") == "mcp_server_tool_call": - self.function_call = MCPServerToolCallContent.from_dict(function_call) - else: - self.function_call = FunctionCallContent.from_dict(function_call) - else: - self.function_call = function_call - # Override the type for this specific subclass - self.type: Literal["function_approval_response"] = "function_approval_response" + if self.media_type is None: + raise ContentError("no media_type found") + slash_index = self.media_type.find("/") + span = self.media_type[:slash_index] if slash_index >= 0 else self.media_type + span = span.strip() + return span.lower() == top_level_media_type.lower() -class FunctionApprovalRequestContent(BaseContent): - """Represents a request for user approval of a function call. + def parse_arguments(self) -> dict[str, Any | None] | None: + """Parse arguments from function_call or mcp_server_tool_call content. - Examples: - .. code-block:: python + If arguments cannot be parsed as JSON or the result is not a dict, + they are returned as a dictionary with a single key "raw". - from agent_framework import FunctionApprovalRequestContent, FunctionCallContent + Returns: + Parsed arguments as a dictionary, or None if no arguments. 
- # Create a function approval request - func_call = FunctionCallContent( - call_id="call_123", - name="send_email", - arguments={"to": "user@example.com", "subject": "Hello"}, - ) - approval_request = FunctionApprovalRequestContent( - id="approval_001", - function_call=func_call, - ) + Raises: + ContentError: If the content type doesn't support arguments. - # Create response - approval_response = approval_request.create_response(approved=True) - print(approval_response.approved) # True - """ + Examples: + .. code-block:: python - def __init__( - self, - *, - id: str, - function_call: FunctionCallContent | MutableMapping[str, Any], - annotations: Sequence[Annotations | MutableMapping[str, Any]] | None = None, - additional_properties: dict[str, Any] | None = None, - raw_representation: Any | None = None, - **kwargs: Any, - ) -> None: - """Initializes a FunctionApprovalRequestContent instance. + from agent_framework import Content - Keyword Args: - id: The unique identifier for the request. - function_call: The function call content to be approved. Can be a FunctionCallContent object or dict. - annotations: Optional list of annotations for the request. - additional_properties: Optional additional properties for the request. - raw_representation: Optional raw representation of the request. - **kwargs: Additional keyword arguments. 
+ func_call = Content.from_function_call( + call_id="call_123", + name="send_email", + arguments='{"to": "user@example.com"}', + ) + args = func_call.parse_arguments() + print(args) # {"to": "user@example.com"} """ - super().__init__( - annotations=annotations, - additional_properties=additional_properties, - raw_representation=raw_representation, - **kwargs, - ) - self.id = id - self.function_call: FunctionCallContent - # Convert dict to FunctionCallContent if needed (for SerializationMixin support) - if isinstance(function_call, MutableMapping): - self.function_call = FunctionCallContent.from_dict(function_call) - else: - self.function_call = function_call - # Override the type for this specific subclass - self.type: Literal["function_approval_request"] = "function_approval_request" - - def create_response(self, approved: bool) -> "FunctionApprovalResponseContent": - """Create a response for the function approval request.""" - return FunctionApprovalResponseContent( - approved, - id=self.id, - function_call=self.function_call, - additional_properties=self.additional_properties, - ) + if self.arguments is None: + return None + if not self.arguments: + return {} + + if isinstance(self.arguments, str): + # If arguments are a string, try to parse it as JSON + try: + loaded = json.loads(self.arguments) + if isinstance(loaded, dict): + return loaded # type: ignore[return-value] + return {"raw": loaded} + except (json.JSONDecodeError, TypeError): + return {"raw": self.arguments} + return self.arguments -UserInputRequestContents = FunctionApprovalRequestContent - -Contents = ( - TextContent - | DataContent - | TextReasoningContent - | UriContent - | FunctionCallContent - | FunctionResultContent - | ErrorContent - | UsageContent - | HostedFileContent - | HostedVectorStoreContent - | CodeInterpreterToolCallContent - | CodeInterpreterToolResultContent - | ImageGenerationToolCallContent - | ImageGenerationToolResultContent - | MCPServerToolCallContent - | 
MCPServerToolResultContent - | FunctionApprovalRequestContent - | FunctionApprovalResponseContent -) + +# endregion -def _prepare_function_call_results_as_dumpable(content: Contents | Any | list[Contents | Any]) -> Any: +def _prepare_function_call_results_as_dumpable(content: "Content | Any | list[Content | Any]") -> Any: if isinstance(content, list): # Particularly deal with lists of Content return [_prepare_function_call_results_as_dumpable(item) for item in content] @@ -2150,9 +1404,9 @@ def _prepare_function_call_results_as_dumpable(content: Contents | Any | list[Co return content -def prepare_function_call_results(content: Contents | Any | list[Contents | Any]) -> str: +def prepare_function_call_results(content: "Content | Any | list[Content | Any]") -> str: """Prepare the values of the function call results.""" - if isinstance(content, Contents): + if isinstance(content, Content): # For BaseContent objects, use to_dict and serialize to JSON # Use default=str to handle datetime and other non-JSON-serializable objects return json.dumps(content.to_dict(exclude={"raw_representation", "additional_properties"}), default=str) @@ -2332,7 +1586,7 @@ class ChatMessage(SerializationMixin): # Create a message with contents assistant_msg = ChatMessage( role="assistant", - contents=[TextContent(text="The weather is sunny!")], + contents=[Content.from_text(text="The weather is sunny!")], ) print(assistant_msg.text) # "The weather is sunny!" 
@@ -2385,7 +1639,7 @@ def __init__( self, role: Role | Literal["system", "user", "assistant", "tool"], *, - contents: Sequence[Contents | Mapping[str, Any]], + contents: "Sequence[Content | Mapping[str, Any]]", author_name: str | None = None, message_id: str | None = None, additional_properties: MutableMapping[str, Any] | None = None, @@ -2412,7 +1666,7 @@ def __init__( role: Role | Literal["system", "user", "assistant", "tool"] | dict[str, Any], *, text: str | None = None, - contents: Sequence[Contents | Mapping[str, Any]] | None = None, + contents: "Sequence[Content | Mapping[str, Any]] | None" = None, author_name: str | None = None, message_id: str | None = None, additional_properties: MutableMapping[str, Any] | None = None, @@ -2444,7 +1698,7 @@ def __init__( parsed_contents = [] if contents is None else _parse_content_list(contents) if text is not None: - parsed_contents.append(TextContent(text=text)) + parsed_contents.append(Content.from_text(text=text)) self.role = role self.contents = parsed_contents @@ -2459,9 +1713,9 @@ def text(self) -> str: """Returns the text content of the message. Remarks: - This property concatenates the text of all TextContent objects in Contents. + This property concatenates the text of all TextContent objects in Content. 
""" - return " ".join(content.text for content in self.contents if isinstance(content, TextContent)) + return " ".join(content.text for content in self.contents if content.type == "text") def prepare_messages( @@ -2592,7 +1846,7 @@ def _process_update( # Slow path: only check for dict if type is None if content_type is None and isinstance(content, (dict, MutableMapping)): try: - content = _parse_content(content) + content = Content.from_dict(content) content_type = content.type except ContentError as exc: logger.warning(f"Skipping unknown content type or invalid content: {exc}") @@ -2602,13 +1856,13 @@ def _process_update( case "function_call" if message.contents and message.contents[-1].type == "function_call": try: message.contents[-1] += content # type: ignore[operator] - except AdditionItemMismatch: + except (AdditionItemMismatch, ContentError): message.contents.append(content) case "usage": if response.usage_details is None: response.usage_details = UsageDetails() # mypy doesn't narrow type based on match/case, but we know this is UsageContent - response.usage_details += content.details # type: ignore[union-attr, arg-type] + response.usage_details = add_usage_details(response.usage_details, content.usage_details) # type: ignore[arg-type] case _: message.contents.append(content) # Incorporate the update's properties into the response. 
@@ -2634,16 +1888,14 @@ def _process_update( response.model_id = update.model_id -def _coalesce_text_content( - contents: list["Contents"], type_: type["TextContent"] | type["TextReasoningContent"] -) -> None: +def _coalesce_text_content(contents: list["Content"], type_str: Literal["text", "text_reasoning"]) -> None: """Take any subsequence Text or TextReasoningContent items and coalesce them into a single item.""" if not contents: return - coalesced_contents: list["Contents"] = [] + coalesced_contents: list["Content"] = [] first_new_content: Any | None = None for content in contents: - if isinstance(content, type_): + if content.type == type_str: if first_new_content is None: first_new_content = deepcopy(content) else: @@ -2666,8 +1918,8 @@ def _coalesce_text_content( def _finalize_response(response: "ChatResponse | AgentResponse") -> None: """Finalizes the response by performing any necessary post-processing.""" for msg in response.messages: - _coalesce_text_content(msg.contents, TextContent) - _coalesce_text_content(msg.contents, TextReasoningContent) + _coalesce_text_content(msg.contents, "text") + _coalesce_text_content(msg.contents, "text_reasoning") class ChatResponse(SerializationMixin): @@ -2761,7 +2013,7 @@ def __init__( def __init__( self, *, - text: TextContent | str, + text: Content | str, response_id: str | None = None, conversation_id: str | None = None, model_id: str | None = None, @@ -2796,7 +2048,7 @@ def __init__( self, *, messages: ChatMessage | MutableSequence[ChatMessage] | list[dict[str, Any]] | None = None, - text: TextContent | str | None = None, + text: Content | str | None = None, response_id: str | None = None, conversation_id: str | None = None, model_id: str | None = None, @@ -2843,16 +2095,15 @@ def __init__( if text is not None: if isinstance(text, str): - text = TextContent(text=text) + text = Content.from_text(text=text) messages.append(ChatMessage(role=Role.ASSISTANT, contents=[text])) # Handle finish_reason conversion if 
isinstance(finish_reason, dict): finish_reason = FinishReason.from_dict(finish_reason) - # Handle usage_details conversion - if isinstance(usage_details, dict): - usage_details = UsageDetails.from_dict(usage_details) + # Handle usage_details - UsageDetails is now a TypedDict, so dict is already the right type + # No conversion needed self.messages = list(messages) self.response_id = response_id @@ -3032,7 +2283,7 @@ class ChatResponseUpdate(SerializationMixin): # Create a response update update = ChatResponseUpdate( - contents=[TextContent(text="Hello")], + contents=[Content.from_text(text="Hello")], role="assistant", message_id="msg_123", ) @@ -3061,8 +2312,8 @@ class ChatResponseUpdate(SerializationMixin): def __init__( self, *, - contents: Sequence[Contents | dict[str, Any]] | None = None, - text: TextContent | str | None = None, + contents: Sequence[Content | dict[str, Any]] | None = None, + text: Content | str | None = None, role: Role | Literal["system", "user", "assistant", "tool"] | dict[str, Any] | None = None, author_name: str | None = None, response_id: str | None = None, @@ -3099,7 +2350,7 @@ def __init__( if text is not None: if isinstance(text, str): - text = TextContent(text=text) + text = Content.from_text(text=text) contents.append(text) # Handle role conversion @@ -3127,7 +2378,7 @@ def __init__( @property def text(self) -> str: """Returns the concatenated text of all contents in the update.""" - return "".join(content.text for content in self.contents if isinstance(content, TextContent)) + return "".join(content.text for content in self.contents if content.type == "text") def __str__(self) -> str: return self.text @@ -3223,8 +2474,7 @@ def __init__( processed_messages.append(ChatMessage.from_dict(messages)) # Convert usage_details from dict if needed (for SerializationMixin support) - if isinstance(usage_details, MutableMapping): - usage_details = UsageDetails.from_dict(usage_details) + # UsageDetails is now a TypedDict, so dict is already the 
right type self.messages = processed_messages self.response_id = response_id @@ -3264,13 +2514,13 @@ def value(self) -> Any | None: return self._value @property - def user_input_requests(self) -> list[UserInputRequestContents]: + def user_input_requests(self) -> list[Content]: """Get all BaseUserInputRequest messages from the response.""" return [ content for msg in self.messages for content in msg.contents - if isinstance(content, UserInputRequestContents) + if isinstance(content, Content) and content.user_input_request ] @classmethod @@ -3367,11 +2617,11 @@ class AgentResponseUpdate(SerializationMixin): Examples: .. code-block:: python - from agent_framework import AgentResponseUpdate, TextContent + from agent_framework import AgentResponseUpdate, Content # Create an agent run update update = AgentResponseUpdate( - contents=[TextContent(text="Processing...")], + contents=[Content.from_text(text="Processing...")], role="assistant", response_id="run_123", ) @@ -3399,8 +2649,8 @@ class AgentResponseUpdate(SerializationMixin): def __init__( self, *, - contents: Sequence[Contents | MutableMapping[str, Any]] | None = None, - text: TextContent | str | None = None, + contents: Sequence[Content | MutableMapping[str, Any]] | None = None, + text: Content | str | None = None, role: Role | MutableMapping[str, Any] | str | None = None, author_name: str | None = None, response_id: str | None = None, @@ -3425,11 +2675,11 @@ def __init__( kwargs: will be combined with additional_properties if provided. 
""" - parsed_contents: list[Contents] = [] if contents is None else _parse_content_list(contents) + parsed_contents: list[Content] = [] if contents is None else _parse_content_list(contents) if text is not None: if isinstance(text, str): - text = TextContent(text=text) + text = Content.from_text(text=text) parsed_contents.append(text) # Convert role from dict if needed (for SerializationMixin support) @@ -3450,16 +2700,12 @@ def __init__( @property def text(self) -> str: """Get the concatenated text of all TextContent objects in contents.""" - return ( - "".join(content.text for content in self.contents if isinstance(content, TextContent)) - if self.contents - else "" - ) + return "".join(content.text for content in self.contents if content.type == "text") if self.contents else "" @property - def user_input_requests(self) -> list[UserInputRequestContents]: + def user_input_requests(self) -> list[Content]: """Get all BaseUserInputRequest messages from the response.""" - return [content for content in self.contents if isinstance(content, UserInputRequestContents)] + return [content for content in self.contents if isinstance(content, Content) and content.user_input_request] def __str__(self) -> str: return self.text diff --git a/python/packages/core/agent_framework/_workflows/_agent.py b/python/packages/core/agent_framework/_workflows/_agent.py index cd768ffc4d..341bc6239b 100644 --- a/python/packages/core/agent_framework/_workflows/_agent.py +++ b/python/packages/core/agent_framework/_workflows/_agent.py @@ -13,18 +13,13 @@ AgentResponseUpdate, AgentThread, BaseAgent, - BaseContent, ChatMessage, - Contents, - FunctionApprovalRequestContent, - FunctionApprovalResponseContent, - FunctionCallContent, - FunctionResultContent, + Content, Role, - TextContent, UsageDetails, ) +from .._types import add_usage_details from ..exceptions import AgentExecutionException from ._agent_executor import AgentExecutor from ._checkpoint import CheckpointStorage @@ -357,12 +352,12 @@ def 
_convert_workflow_event_to_agent_update( args = self.RequestInfoFunctionArgs(request_id=request_id, data=event.data).to_dict() - function_call = FunctionCallContent( + function_call = Content.from_function_call( call_id=request_id, name=self.REQUEST_INFO_FUNCTION_NAME, arguments=args, ) - approval_request = FunctionApprovalRequestContent( + approval_request = Content.from_function_approval_request( id=request_id, function_call=function_call, additional_properties={"request_id": request_id}, @@ -385,9 +380,9 @@ def _extract_function_responses(self, input_messages: list[ChatMessage]) -> dict function_responses: dict[str, Any] = {} for message in input_messages: for content in message.contents: - if isinstance(content, FunctionApprovalResponseContent): + if content.type == "function_approval_response": # Parse the function arguments to recover request payload - arguments_payload = content.function_call.arguments + arguments_payload = content.function_call.arguments # type: ignore[attr-defined] if isinstance(arguments_payload, str): try: parsed_args = self.RequestInfoFunctionArgs.from_json(arguments_payload) @@ -402,8 +397,8 @@ def _extract_function_responses(self, input_messages: list[ChatMessage]) -> dict "FunctionApprovalResponseContent arguments must be a mapping or JSON string." ) - request_id = parsed_args.request_id or content.id - if not content.approved: + request_id = parsed_args.request_id or content.id # type: ignore[attr-defined] + if not content.approved: # type: ignore[attr-defined] raise AgentExecutionException(f"Request '{request_id}' was not approved by the caller.") if request_id in self.pending_requests: @@ -412,10 +407,10 @@ def _extract_function_responses(self, input_messages: list[ChatMessage]) -> dict raise AgentExecutionException( "Only responses for pending requests are allowed when there are outstanding approvals." 
) - elif isinstance(content, FunctionResultContent): - request_id = content.call_id + elif content.type == "function_result": + request_id = content.call_id # type: ignore[attr-defined] if request_id in self.pending_requests: - response_data = content.result if hasattr(content, "result") else str(content) + response_data = content.result if hasattr(content, "result") else str(content) # type: ignore[attr-defined] function_responses[request_id] = response_data elif bool(self.pending_requests): raise AgentExecutionException( @@ -426,17 +421,17 @@ def _extract_function_responses(self, input_messages: list[ChatMessage]) -> dict raise AgentExecutionException("Unexpected content type while awaiting request info responses.") return function_responses - def _extract_contents(self, data: Any) -> list[Contents]: - """Recursively extract Contents from workflow output data.""" + def _extract_contents(self, data: Any) -> list[Content]: + """Recursively extract Content from workflow output data.""" if isinstance(data, ChatMessage): return list(data.contents) if isinstance(data, list): return [c for item in data for c in self._extract_contents(item)] - if isinstance(data, BaseContent): - return [cast(Contents, data)] + if isinstance(data, Content): + return [cast(Content, data)] if isinstance(data, str): - return [TextContent(text=data)] - return [TextContent(text=str(data))] + return [Content.from_text(text=data)] + return [Content.from_text(text=str(data))] class _ResponseState(TypedDict): """State for grouping response updates by message_id.""" @@ -508,13 +503,6 @@ def _parse_dt(value: str | None) -> tuple[int, datetime | str | None]: except Exception: return (0, v) - def _sum_usage(a: UsageDetails | None, b: UsageDetails | None) -> UsageDetails | None: - if a is None: - return b - if b is None: - return a - return a + b - def _merge_responses(current: AgentResponse | None, incoming: AgentResponse) -> AgentResponse: if current is None: return incoming @@ -534,7 +522,7 @@ def 
_add_raw(value: object) -> None: messages=(current.messages or []) + (incoming.messages or []), response_id=current.response_id or incoming.response_id, created_at=incoming.created_at or current.created_at, - usage_details=_sum_usage(current.usage_details, incoming.usage_details), + usage_details=add_usage_details(current.usage_details, incoming.usage_details), raw_representation=raw_list if raw_list else None, additional_properties=incoming.additional_properties or current.additional_properties, ) @@ -569,7 +557,7 @@ def _add_raw(value: object) -> None: if aggregated: final_messages.extend(aggregated.messages) if aggregated.usage_details: - merged_usage = _sum_usage(merged_usage, aggregated.usage_details) + merged_usage = add_usage_details(merged_usage, aggregated.usage_details) if aggregated.created_at and ( not latest_created_at or _parse_dt(aggregated.created_at) > _parse_dt(latest_created_at) ): @@ -593,7 +581,7 @@ def _add_raw(value: object) -> None: flattened = AgentResponse.from_agent_run_response_updates(global_dangling) final_messages.extend(flattened.messages) if flattened.usage_details: - merged_usage = _sum_usage(merged_usage, flattened.usage_details) + merged_usage = add_usage_details(merged_usage, flattened.usage_details) if flattened.created_at and ( not latest_created_at or _parse_dt(flattened.created_at) > _parse_dt(latest_created_at) ): diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index bcd47caca2..3d665bd56c 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from typing import Any, cast -from agent_framework import FunctionApprovalRequestContent, FunctionApprovalResponseContent +from agent_framework import Content from .._agents import AgentProtocol, ChatAgent from .._threads import AgentThread 
@@ -95,8 +95,8 @@ def __init__( super().__init__(exec_id) self._agent = agent self._agent_thread = agent_thread or self._agent.get_new_thread() - self._pending_agent_requests: dict[str, FunctionApprovalRequestContent] = {} - self._pending_responses_to_agent: list[FunctionApprovalResponseContent] = [] + self._pending_agent_requests: dict[str, Content] = {} + self._pending_responses_to_agent: list[Content] = [] self._output_response = output_response # AgentExecutor maintains an internal cache of messages in between runs @@ -179,8 +179,8 @@ async def from_messages( @response_handler async def handle_user_input_response( self, - original_request: FunctionApprovalRequestContent, - response: FunctionApprovalResponseContent, + original_request: Content, + response: Content, ctx: WorkflowContext[AgentExecutorResponse, AgentResponse], ) -> None: """Handle user input responses for function approvals during agent execution. @@ -345,7 +345,7 @@ async def _run_agent(self, ctx: WorkflowContext) -> AgentResponse | None: if response.user_input_requests: for user_input_request in response.user_input_requests: self._pending_agent_requests[user_input_request.id] = user_input_request - await ctx.request_info(user_input_request, FunctionApprovalResponseContent) + await ctx.request_info(user_input_request, Content) return None return response @@ -362,7 +362,7 @@ async def _run_agent_streaming(self, ctx: WorkflowContext) -> AgentResponse | No run_kwargs: dict[str, Any] = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY) updates: list[AgentResponseUpdate] = [] - user_input_requests: list[FunctionApprovalRequestContent] = [] + user_input_requests: list[Content] = [] async for update in self._agent.run_stream( self._cache, thread=self._agent_thread, @@ -388,7 +388,7 @@ async def _run_agent_streaming(self, ctx: WorkflowContext) -> AgentResponse | No if user_input_requests: for user_input_request in user_input_requests: self._pending_agent_requests[user_input_request.id] = 
user_input_request - await ctx.request_info(user_input_request, FunctionApprovalResponseContent) + await ctx.request_info(user_input_request, Content) return None return response diff --git a/python/packages/core/agent_framework/_workflows/_orchestrator_helpers.py b/python/packages/core/agent_framework/_workflows/_orchestrator_helpers.py index edcffaa530..09f118a6c6 100644 --- a/python/packages/core/agent_framework/_workflows/_orchestrator_helpers.py +++ b/python/packages/core/agent_framework/_workflows/_orchestrator_helpers.py @@ -37,8 +37,6 @@ def clean_conversation_for_handoff(conversation: list[ChatMessage]) -> list[Chat Returns: Cleaned conversation safe for handoff routing """ - from agent_framework import FunctionApprovalRequestContent, FunctionCallContent - cleaned: list[ChatMessage] = [] for msg in conversation: # Skip tool response messages entirely @@ -49,7 +47,7 @@ def clean_conversation_for_handoff(conversation: list[ChatMessage]) -> list[Chat has_tool_content = False if msg.contents: has_tool_content = any( - isinstance(content, (FunctionApprovalRequestContent, FunctionCallContent)) for content in msg.contents + content.type in ("function_approval_request", "function_call") for content in msg.contents ) # If no tool content, keep original diff --git a/python/packages/core/agent_framework/azure/_chat_client.py b/python/packages/core/agent_framework/azure/_chat_client.py index 164ae9040c..b60054165f 100644 --- a/python/packages/core/agent_framework/azure/_chat_client.py +++ b/python/packages/core/agent_framework/azure/_chat_client.py @@ -13,10 +13,10 @@ from pydantic import ValidationError from agent_framework import ( + Annotation, ChatResponse, ChatResponseUpdate, - CitationAnnotation, - TextContent, + Content, use_chat_middleware, use_function_invocation, ) @@ -267,8 +267,8 @@ class MyOptions(AzureOpenAIChatOptions, total=False): ) @override - def _parse_text_from_openai(self, choice: Choice | ChunkChoice) -> TextContent | None: - """Parse the choice 
into a TextContent object. + def _parse_text_from_openai(self, choice: Choice | ChunkChoice) -> Content | None: + """Parse the choice into a Content object with type='text'. Overwritten from OpenAIBaseChatClient to deal with Azure On Your Data function. For docs see: @@ -279,10 +279,10 @@ def _parse_text_from_openai(self, choice: Choice | ChunkChoice) -> TextContent | if message is None: # type: ignore return None if hasattr(message, "refusal") and message.refusal: - return TextContent(text=message.refusal, raw_representation=choice) + return Content.from_text(text=message.refusal, raw_representation=choice) if not message.content: return None - text_content = TextContent(text=message.content, raw_representation=choice) + text_content = Content.from_text(text=message.content, raw_representation=choice) if not message.model_extra or "context" not in message.model_extra: return text_content @@ -304,7 +304,8 @@ def _parse_text_from_openai(self, choice: Choice | ChunkChoice) -> TextContent | text_content.annotations = [] for citation in citations: text_content.annotations.append( - CitationAnnotation( + Annotation( + type="citation", title=citation.get("title", ""), url=citation.get("url", ""), snippet=citation.get("content", ""), diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index 70564c9354..452e9fc7b7 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -40,7 +40,7 @@ ChatMessage, ChatResponse, ChatResponseUpdate, - Contents, + Content, FinishReason, ) @@ -1744,8 +1744,10 @@ def _to_otel_message(message: "ChatMessage") -> dict[str, Any]: return {"role": message.role.value, "parts": [_to_otel_part(content) for content in message.contents]} -def _to_otel_part(content: "Contents") -> dict[str, Any] | None: +def _to_otel_part(content: "Content") -> dict[str, Any] | None: """Create a otel representation of a Content.""" 
+ from ._types import _get_data_bytes_as_str + match content.type: case "text": return {"type": "text", "content": content.text} @@ -1761,7 +1763,7 @@ def _to_otel_part(content: "Contents") -> dict[str, Any] | None: case "data": return { "type": "blob", - "content": content.get_data_bytes_as_str(), + "content": _get_data_bytes_as_str(content), "mime_type": content.media_type, "modality": content.media_type.split("/")[0] if content.media_type else None, } @@ -1802,10 +1804,10 @@ def _get_response_attributes( if model_id := getattr(response, "model_id", None): attributes[SpanAttributes.LLM_RESPONSE_MODEL] = model_id if capture_usage and (usage := response.usage_details): - if usage.input_token_count: - attributes[OtelAttr.INPUT_TOKENS] = usage.input_token_count - if usage.output_token_count: - attributes[OtelAttr.OUTPUT_TOKENS] = usage.output_token_count + if usage.get("input_token_count"): + attributes[OtelAttr.INPUT_TOKENS] = usage["input_token_count"] + if usage.get("output_token_count"): + attributes[OtelAttr.OUTPUT_TOKENS] = usage["output_token_count"] if duration: attributes[Meters.LLM_OPERATION_DURATION] = duration return attributes diff --git a/python/packages/core/agent_framework/openai/_assistants_client.py b/python/packages/core/agent_framework/openai/_assistants_client.py index 2b226f1d22..640d0b49d7 100644 --- a/python/packages/core/agent_framework/openai/_assistants_client.py +++ b/python/packages/core/agent_framework/openai/_assistants_client.py @@ -47,15 +47,8 @@ ChatOptions, ChatResponse, ChatResponseUpdate, - CodeInterpreterToolCallContent, - Contents, - FunctionCallContent, - FunctionResultContent, - MCPServerToolCallContent, + Content, Role, - TextContent, - UriContent, - UsageContent, UsageDetails, prepare_function_call_results, ) @@ -416,7 +409,7 @@ async def _create_assistant_stream( thread_id: str | None, assistant_id: str, run_options: dict[str, Any], - tool_results: list[FunctionResultContent] | None, + tool_results: list[Content] | None, ) 
-> tuple[Any, str]: """Create the assistant stream for processing. @@ -526,7 +519,7 @@ async def _process_stream_events(self, stream: Any, thread_id: str) -> AsyncIter and response.data.usage is not None ): usage = response.data.usage - usage_content = UsageContent( + usage_content = Content.from_usage( UsageDetails( input_token_count=usage.prompt_tokens, output_token_count=usage.completion_tokens, @@ -551,9 +544,9 @@ async def _process_stream_events(self, stream: Any, thread_id: str) -> AsyncIter role=Role.ASSISTANT, ) - def _parse_function_calls_from_assistants(self, event_data: Run, response_id: str | None) -> list[Contents]: + def _parse_function_calls_from_assistants(self, event_data: Run, response_id: str | None) -> list[Content]: """Parse function call contents from an assistants tool action event.""" - contents: list[Contents] = [] + contents: list[Content] = [] if event_data.required_action is not None: for tool_call in event_data.required_action.submit_tool_outputs.tool_calls: @@ -563,10 +556,12 @@ def _parse_function_calls_from_assistants(self, event_data: Run, response_id: st if tool_type == "code_interpreter" and getattr(tool_call_any, "code_interpreter", None): code_input = getattr(tool_call_any.code_interpreter, "input", None) inputs = ( - [TextContent(text=code_input, raw_representation=tool_call)] if code_input is not None else None + [Content.from_text(text=code_input, raw_representation=tool_call)] + if code_input is not None + else None ) contents.append( - CodeInterpreterToolCallContent( + Content.from_code_interpreter_tool_call( call_id=call_id, inputs=inputs, raw_representation=tool_call, @@ -574,7 +569,7 @@ def _parse_function_calls_from_assistants(self, event_data: Run, response_id: st ) elif tool_type == "mcp": contents.append( - MCPServerToolCallContent( + Content.from_mcp_server_tool_call( call_id=call_id, tool_name=getattr(tool_call, "name", "") or "", server_name=getattr(tool_call, "server_label", None), @@ -586,7 +581,7 @@ def 
_parse_function_calls_from_assistants(self, event_data: Run, response_id: st function_name = tool_call.function.name function_arguments = json.loads(tool_call.function.arguments) contents.append( - FunctionCallContent( + Content.from_function_call( call_id=call_id, name=function_name, arguments=function_arguments, @@ -600,7 +595,7 @@ def _prepare_options( messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any, - ) -> tuple[dict[str, Any], list[FunctionResultContent] | None]: + ) -> tuple[dict[str, Any], list[Content] | None]: from .._types import validate_tool_mode run_options: dict[str, Any] = {**kwargs} @@ -672,7 +667,7 @@ def _prepare_options( } instructions: list[str] = [] - tool_results: list[FunctionResultContent] | None = None + tool_results: list[Content] | None = None additional_messages: list[AdditionalMessage] | None = None @@ -681,21 +676,23 @@ def _prepare_options( # All other messages are added 1:1. for chat_message in messages: if chat_message.role.value in ["system", "developer"]: - for text_content in [content for content in chat_message.contents if isinstance(content, TextContent)]: - instructions.append(text_content.text) + for text_content in [content for content in chat_message.contents if content.type == "text"]: + text = getattr(text_content, "text", None) + if text: + instructions.append(text) continue message_contents: list[MessageContentPartParam] = [] for content in chat_message.contents: - if isinstance(content, TextContent): - message_contents.append(TextContentBlockParam(type="text", text=content.text)) - elif isinstance(content, UriContent) and content.has_top_level_media_type("image"): + if content.type == "text": + message_contents.append(TextContentBlockParam(type="text", text=content.text)) # type: ignore[attr-defined] + elif content.type == "uri" and content.has_top_level_media_type("image"): message_contents.append( - ImageURLContentBlockParam(type="image_url", image_url=ImageURLParam(url=content.uri)) + 
ImageURLContentBlockParam(type="image_url", image_url=ImageURLParam(url=content.uri)) # type: ignore[attr-defined] ) - elif isinstance(content, FunctionResultContent): + elif content.type == "function_result": if tool_results is None: tool_results = [] tool_results.append(content) @@ -720,7 +717,7 @@ def _prepare_options( def _prepare_tool_outputs_for_assistants( self, - tool_results: list[FunctionResultContent] | None, + tool_results: list[Content] | None, ) -> tuple[str | None, list[ToolOutput] | None]: """Prepare function results for submission to the assistants API.""" run_id: str | None = None diff --git a/python/packages/core/agent_framework/openai/_chat_client.py b/python/packages/core/agent_framework/openai/_chat_client.py index 2d1ef8b463..7a9d9f7fac 100644 --- a/python/packages/core/agent_framework/openai/_chat_client.py +++ b/python/packages/core/agent_framework/openai/_chat_client.py @@ -25,18 +25,9 @@ ChatOptions, ChatResponse, ChatResponseUpdate, - Contents, - DataContent, + Content, FinishReason, - FunctionApprovalRequestContent, - FunctionApprovalResponseContent, - FunctionCallContent, - FunctionResultContent, Role, - TextContent, - TextReasoningContent, - UriContent, - UsageContent, UsageDetails, prepare_function_call_results, ) @@ -294,13 +285,13 @@ def _parse_response_from_openai(self, response: ChatCompletion, options: dict[st response_metadata.update(self._get_metadata_from_chat_choice(choice)) if choice.finish_reason: finish_reason = FinishReason(value=choice.finish_reason) - contents: list[Contents] = [] + contents: list[Content] = [] if text_content := self._parse_text_from_openai(choice): contents.append(text_content) if parsed_tool_calls := [tool for tool in self._parse_tool_calls_from_openai(choice)]: contents.extend(parsed_tool_calls) if reasoning_details := getattr(choice.message, "reasoning_details", None): - contents.append(TextReasoningContent(None, protected_data=json.dumps(reasoning_details))) + 
contents.append(Content.from_text_reasoning(None, protected_data=json.dumps(reasoning_details))) messages.append(ChatMessage(role="assistant", contents=contents)) return ChatResponse( response_id=response.id, @@ -322,13 +313,15 @@ def _parse_response_update_from_openai( if chunk.usage: return ChatResponseUpdate( role=Role.ASSISTANT, - contents=[UsageContent(details=self._parse_usage_from_openai(chunk.usage), raw_representation=chunk)], + contents=[ + Content.from_usage(details=self._parse_usage_from_openai(chunk.usage), raw_representation=chunk) + ], model_id=chunk.model, additional_properties=chunk_metadata, response_id=chunk.id, message_id=chunk.id, ) - contents: list[Contents] = [] + contents: list[Content] = [] finish_reason: FinishReason | None = None for choice in chunk.choices: chunk_metadata.update(self._get_metadata_from_chat_choice(choice)) @@ -339,7 +332,7 @@ def _parse_response_update_from_openai( if text_content := self._parse_text_from_openai(choice): contents.append(text_content) if reasoning_details := getattr(choice.delta, "reasoning_details", None): - contents.append(TextReasoningContent(None, protected_data=json.dumps(reasoning_details))) + contents.append(Content.from_text_reasoning(None, protected_data=json.dumps(reasoning_details))) return ChatResponseUpdate( created_at=datetime.fromtimestamp(chunk.created, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ"), contents=contents, @@ -374,13 +367,13 @@ def _parse_usage_from_openai(self, usage: CompletionUsage) -> UsageDetails: details["prompt/cached_tokens"] = tokens return details - def _parse_text_from_openai(self, choice: Choice | ChunkChoice) -> TextContent | None: - """Parse the choice into a TextContent object.""" + def _parse_text_from_openai(self, choice: Choice | ChunkChoice) -> Content | None: + """Parse the choice into a Content object with type='text'.""" message = choice.message if isinstance(choice, Choice) else choice.delta if message.content: - return 
TextContent(text=message.content, raw_representation=choice) + return Content.from_text(text=message.content, raw_representation=choice) if hasattr(message, "refusal") and message.refusal: - return TextContent(text=message.refusal, raw_representation=choice) + return Content.from_text(text=message.refusal, raw_representation=choice) return None def _get_metadata_from_chat_response(self, response: ChatCompletion) -> dict[str, Any]: @@ -401,15 +394,15 @@ def _get_metadata_from_chat_choice(self, choice: Choice | ChunkChoice) -> dict[s "logprobs": getattr(choice, "logprobs", None), } - def _parse_tool_calls_from_openai(self, choice: Choice | ChunkChoice) -> list[Contents]: + def _parse_tool_calls_from_openai(self, choice: Choice | ChunkChoice) -> list[Content]: """Parse tool calls from an OpenAI response choice.""" - resp: list[Contents] = [] + resp: list[Content] = [] content = choice.message if isinstance(choice, Choice) else choice.delta if content and content.tool_calls: for tool in content.tool_calls: if not isinstance(tool, ChatCompletionMessageCustomToolCall) and tool.function: # ignoring tool.custom - fcc = FunctionCallContent( + fcc = Content.from_function_call( call_id=tool.id if tool.id else "", name=tool.function.name if tool.function.name else "", arguments=tool.function.arguments if tool.function.arguments else "", @@ -455,7 +448,7 @@ def _prepare_message_for_openai(self, message: ChatMessage) -> list[dict[str, An all_messages: list[dict[str, Any]] = [] for content in message.contents: # Skip approval content - it's internal framework state, not for the LLM - if isinstance(content, (FunctionApprovalRequestContent, FunctionApprovalResponseContent)): + if content.type in ("function_approval_request", "function_approval_response"): continue args: dict[str, Any] = { @@ -467,21 +460,21 @@ def _prepare_message_for_openai(self, message: ChatMessage) -> list[dict[str, An details := message.additional_properties["reasoning_details"] ): args["reasoning_details"] = 
details - match content: - case FunctionCallContent(): + match content.type: + case "function_call": if all_messages and "tool_calls" in all_messages[-1]: # If the last message already has tool calls, append to it all_messages[-1]["tool_calls"].append(self._prepare_content_for_openai(content)) else: args["tool_calls"] = [self._prepare_content_for_openai(content)] # type: ignore - case FunctionResultContent(): + case "function_result": args["tool_call_id"] = content.call_id # Always include content for tool results - API requires it even if empty # Functions returning None should still have a tool result message args["content"] = ( prepare_function_call_results(content.result) if content.result is not None else "" ) - case TextReasoningContent(protected_data=protected_data) if protected_data is not None: + case "text_reasoning" if (protected_data := content.protected_data) is not None: all_messages[-1]["reasoning_details"] = json.loads(protected_data) case _: if "content" not in args: @@ -492,27 +485,27 @@ def _prepare_message_for_openai(self, message: ChatMessage) -> list[dict[str, An all_messages.append(args) return all_messages - def _prepare_content_for_openai(self, content: Contents) -> dict[str, Any]: + def _prepare_content_for_openai(self, content: Content) -> dict[str, Any]: """Prepare content for OpenAI.""" - match content: - case FunctionCallContent(): + match content.type: + case "function_call": args = json.dumps(content.arguments) if isinstance(content.arguments, Mapping) else content.arguments return { "id": content.call_id, "type": "function", "function": {"name": content.name, "arguments": args}, } - case FunctionResultContent(): + case "function_result": return { "tool_call_id": content.call_id, "content": content.result, } - case DataContent() | UriContent() if content.has_top_level_media_type("image"): + case "data" | "uri" if content.has_top_level_media_type("image"): return { "type": "image_url", "image_url": {"url": content.uri}, } - case 
DataContent() | UriContent() if content.has_top_level_media_type("audio"): + case "data" | "uri" if content.has_top_level_media_type("audio"): if content.media_type and "wav" in content.media_type: audio_format = "wav" elif content.media_type and "mp3" in content.media_type: @@ -534,9 +527,7 @@ def _prepare_content_for_openai(self, content: Contents) -> dict[str, Any]: "format": audio_format, }, } - case DataContent() | UriContent() if content.has_top_level_media_type( - "application" - ) and content.uri.startswith("data:"): + case "data" | "uri" if content.has_top_level_media_type("application") and content.uri.startswith("data:"): # All application/* media types should be treated as files for OpenAI filename = getattr(content, "filename", None) or ( content.additional_properties.get("filename") diff --git a/python/packages/core/agent_framework/openai/_responses_client.py b/python/packages/core/agent_framework/openai/_responses_client.py index 37a35ae9bc..e5858499c4 100644 --- a/python/packages/core/agent_framework/openai/_responses_client.py +++ b/python/packages/core/agent_framework/openai/_responses_client.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. 
+import base64 import sys from collections.abc import ( AsyncIterable, @@ -48,33 +49,15 @@ use_function_invocation, ) from .._types import ( + Annotation, ChatMessage, ChatOptions, ChatResponse, ChatResponseUpdate, - CitationAnnotation, - CodeInterpreterToolCallContent, - CodeInterpreterToolResultContent, - Contents, - DataContent, - FunctionApprovalRequestContent, - FunctionApprovalResponseContent, - FunctionCallContent, - FunctionResultContent, - HostedFileContent, - HostedVectorStoreContent, - ImageGenerationToolCallContent, - ImageGenerationToolResultContent, - MCPServerToolCallContent, - MCPServerToolResultContent, + Content, Role, - TextContent, - TextReasoningContent, TextSpanRegion, - UriContent, - UsageContent, UsageDetails, - _parse_content, prepare_function_call_results, prepend_instructions_to_messages, validate_tool_mode, @@ -391,7 +374,7 @@ def _prepare_tools_for_openai( if tool.inputs: tool_args["file_ids"] = [] for tool_input in tool.inputs: - if isinstance(tool_input, HostedFileContent): + if tool_input.type == "hosted_file": tool_args["file_ids"].append(tool_input.file_id) # type: ignore[attr-defined] if not tool_args["file_ids"]: tool_args.pop("file_ids") @@ -417,7 +400,9 @@ def _prepare_tools_for_openai( if not tool.inputs: raise ValueError("HostedFileSearchTool requires inputs to be specified.") inputs: list[str] = [ - inp.vector_store_id for inp in tool.inputs if isinstance(inp, HostedVectorStoreContent) + inp.vector_store_id + for inp in tool.inputs + if inp.type == "hosted_vector_store" # type: ignore[attr-defined] ] if not inputs: raise ValueError( @@ -629,11 +614,11 @@ def _prepare_messages_for_openai(self, chat_messages: Sequence[ChatMessage]) -> for message in chat_messages: for content in message.contents: if ( - isinstance(content, FunctionCallContent) + content.type == "function_call" and content.additional_properties and "fc_id" in content.additional_properties ): - call_id_to_id[content.call_id] = 
content.additional_properties["fc_id"] + call_id_to_id[content.call_id] = content.additional_properties["fc_id"] # type: ignore[attr-defined] list_of_list = [self._prepare_message_for_openai(message, call_id_to_id) for message in chat_messages] # Flatten the list of lists into a single list return list(chain.from_iterable(list_of_list)) @@ -649,18 +634,18 @@ def _prepare_message_for_openai( "role": message.role.value if isinstance(message.role, Role) else message.role, } for content in message.contents: - match content: - case TextReasoningContent(): + match content.type: + case "text_reasoning": # Don't send reasoning content back to model continue - case FunctionResultContent(): + case "function_result": new_args: dict[str, Any] = {} new_args.update(self._prepare_content_for_openai(message.role, content, call_id_to_id)) all_messages.append(new_args) - case FunctionCallContent(): + case "function_call": function_call = self._prepare_content_for_openai(message.role, content, call_id_to_id) all_messages.append(function_call) # type: ignore - case FunctionApprovalResponseContent() | FunctionApprovalRequestContent(): + case "function_approval_response" | "function_approval_request": all_messages.append(self._prepare_content_for_openai(message.role, content, call_id_to_id)) # type: ignore case _: if "content" not in args: @@ -673,17 +658,17 @@ def _prepare_message_for_openai( def _prepare_content_for_openai( self, role: Role, - content: Contents, + content: Content, call_id_to_id: dict[str, str], ) -> dict[str, Any]: """Prepare content for the OpenAI Responses API format.""" - match content: - case TextContent(): + match content.type: + case "text": return { "type": "output_text" if role == Role.ASSISTANT else "input_text", "text": content.text, } - case TextReasoningContent(): + case "text_reasoning": ret: dict[str, Any] = { "type": "reasoning", "summary": { @@ -703,7 +688,7 @@ def _prepare_content_for_openai( if encrypted_content := props.get("encrypted_content"): 
ret["encrypted_content"] = encrypted_content return ret - case DataContent() | UriContent(): + case "data" | "uri": if content.has_top_level_media_type("image"): return { "type": "input_image", @@ -744,7 +729,7 @@ def _prepare_content_for_openai( file_obj["filename"] = filename return file_obj return {} - case FunctionCallContent(): + case "function_call": if not content.call_id: logger.warning(f"FunctionCallContent missing call_id for function '{content.name}'") return {} @@ -761,7 +746,7 @@ def _prepare_content_for_openai( "arguments": content.arguments, "status": None, } - case FunctionResultContent(): + case "function_result": # call_id for the result needs to be the same as the call_id for the function call args: dict[str, Any] = { "call_id": content.call_id, @@ -769,7 +754,7 @@ def _prepare_content_for_openai( "output": prepare_function_call_results(content.result), } return args - case FunctionApprovalRequestContent(): + case "function_approval_request": return { "type": "mcp_approval_request", "id": content.id, @@ -779,19 +764,19 @@ def _prepare_content_for_openai( if content.function_call.additional_properties else None, } - case FunctionApprovalResponseContent(): + case "function_approval_response": return { "type": "mcp_approval_response", "approval_request_id": content.id, "approve": content.approved, } - case HostedFileContent(): + case "hosted_file": return { "type": "input_file", "file_id": content.file_id, } case _: # should catch UsageDetails and ErrorContent and HostedVectorStoreContent - logger.debug("Unsupported content type passed (type: %s)", type(content)) + logger.debug("Unsupported content type passed (type: %s)", content.type) return {} # region Parse methods @@ -804,7 +789,7 @@ def _parse_response_from_openai( structured_response: BaseModel | None = response.output_parsed if isinstance(response, ParsedResponse) else None # type: ignore[reportUnknownMemberType] metadata: dict[str, Any] = response.metadata or {} - contents: list[Contents] = 
[] + contents: list[Content] = [] for item in response.output: # type: ignore[reportUnknownMemberType] match item.type: # types: @@ -829,7 +814,7 @@ def _parse_response_from_openai( for message_content in item.content: # type: ignore[reportMissingTypeArgument] match message_content.type: case "output_text": - text_content = TextContent( + text_content = Content.from_text( text=message_content.text, raw_representation=message_content, # type: ignore[reportUnknownArgumentType] ) @@ -840,7 +825,8 @@ def _parse_response_from_openai( match annotation.type: case "file_path": text_content.annotations.append( - CitationAnnotation( + Annotation( + type="citation", file_id=annotation.file_id, additional_properties={ "index": annotation.index, @@ -850,7 +836,8 @@ def _parse_response_from_openai( ) case "file_citation": text_content.annotations.append( - CitationAnnotation( + Annotation( + type="citation", url=annotation.filename, file_id=annotation.file_id, raw_representation=annotation, @@ -861,11 +848,13 @@ def _parse_response_from_openai( ) case "url_citation": text_content.annotations.append( - CitationAnnotation( + Annotation( + type="citation", title=annotation.title, url=annotation.url, annotated_regions=[ TextSpanRegion( + type="text_span", start_index=annotation.start_index, end_index=annotation.end_index, ) @@ -875,7 +864,8 @@ def _parse_response_from_openai( ) case "container_file_citation": text_content.annotations.append( - CitationAnnotation( + Annotation( + type="citation", file_id=annotation.file_id, url=annotation.filename, additional_properties={ @@ -883,6 +873,7 @@ def _parse_response_from_openai( }, annotated_regions=[ TextSpanRegion( + type="text_span", start_index=annotation.start_index, end_index=annotation.end_index, ) @@ -898,7 +889,7 @@ def _parse_response_from_openai( contents.append(text_content) case "refusal": contents.append( - TextContent( + Content.from_text( text=message_content.refusal, raw_representation=message_content, ) @@ -910,7 +901,7 
@@ def _parse_response_from_openai( if hasattr(item, "summary") and item.summary and index < len(item.summary): additional_properties = {"summary": item.summary[index]} contents.append( - TextReasoningContent( + Content.from_text_reasoning( text=reasoning_content.text, raw_representation=reasoning_content, additional_properties=additional_properties, @@ -919,23 +910,23 @@ def _parse_response_from_openai( if hasattr(item, "summary") and item.summary: for summary in item.summary: contents.append( - TextReasoningContent(text=summary.text, raw_representation=summary) # type: ignore[arg-type] + Content.from_text_reasoning(text=summary.text, raw_representation=summary) # type: ignore[arg-type] ) case "code_interpreter_call": # ResponseOutputCodeInterpreterCall call_id = getattr(item, "call_id", None) or getattr(item, "id", None) - outputs: list["Contents"] = [] + outputs: list["Content"] = [] if item_outputs := getattr(item, "outputs", None): for code_output in item_outputs: if getattr(code_output, "type", None) == "logs": outputs.append( - TextContent( + Content.from_text( text=code_output.logs, raw_representation=code_output, ) ) elif getattr(code_output, "type", None) == "image": outputs.append( - UriContent( + Content.from_uri( uri=code_output.url, raw_representation=code_output, media_type="image", @@ -943,14 +934,14 @@ def _parse_response_from_openai( ) if code := getattr(item, "code", None): contents.append( - CodeInterpreterToolCallContent( + Content.from_code_interpreter_tool_call( call_id=call_id, - inputs=[TextContent(text=code, raw_representation=item)], + inputs=[Content.from_text(text=code, raw_representation=item)], raw_representation=item, ) ) contents.append( - CodeInterpreterToolResultContent( + Content.from_code_interpreter_tool_result( call_id=call_id, outputs=outputs, raw_representation=item, @@ -958,7 +949,7 @@ def _parse_response_from_openai( ) case "function_call": # ResponseOutputFunctionCall contents.append( - FunctionCallContent( + 
Content.from_function_call( call_id=item.call_id if hasattr(item, "call_id") and item.call_id else "", name=item.name if hasattr(item, "name") else "", arguments=item.arguments if hasattr(item, "arguments") else "", @@ -968,9 +959,9 @@ def _parse_response_from_openai( ) case "mcp_approval_request": # ResponseOutputMcpApprovalRequest contents.append( - FunctionApprovalRequestContent( + Content.from_function_approval_request( id=item.id, - function_call=FunctionCallContent( + function_call=Content.from_function_call( call_id=item.id, name=item.name, arguments=item.arguments, @@ -982,7 +973,7 @@ def _parse_response_from_openai( case "mcp_call": call_id = item.id contents.append( - MCPServerToolCallContent( + Content.from_mcp_server_tool_call( call_id=call_id, tool_name=item.name, server_name=item.server_label, @@ -992,31 +983,38 @@ def _parse_response_from_openai( ) if item.output is not None: contents.append( - MCPServerToolResultContent( + Content.from_mcp_server_tool_result( call_id=call_id, - output=[TextContent(text=item.output)], + output=[Content.from_text(text=item.output)], raw_representation=item, ) ) case "image_generation_call": # ResponseOutputImageGenerationCall - image_output: DataContent | None = None + image_output: Content | None = None if item.result: - base64_data = item.result - image_format = DataContent.detect_image_format_from_base64(base64_data) - image_output = DataContent( - data=base64_data, - media_type=f"image/{image_format}" if image_format else "image/png", + base64_data = item.result # OpenAI returns base64-encoded string + # Detect media type from base64 data + from agent_framework._types import detect_media_type_from_base64 + + media_type = detect_media_type_from_base64(base64_data) + if media_type is None: + media_type = "image/png" # Default fallback + # Convert base64 string to bytes for Content.from_data + data_bytes = base64.b64decode(base64_data) + image_output = Content.from_data( + data=data_bytes, + media_type=media_type, 
raw_representation=item.result, ) image_id = item.id contents.append( - ImageGenerationToolCallContent( + Content.from_image_generation_tool_call( image_id=image_id, raw_representation=item, ) ) contents.append( - ImageGenerationToolResultContent( + Content.from_image_generation_tool_result( image_id=image_id, outputs=image_output, raw_representation=item, @@ -1056,11 +1054,10 @@ def _parse_chunk_from_openai( ) -> ChatResponseUpdate: """Parse an OpenAI Responses API streaming event into a ChatResponseUpdate.""" metadata: dict[str, Any] = {} - contents: list[Contents] = [] + contents: list[Content] = [] conversation_id: str | None = None response_id: str | None = None model = self.model_id - # TODO(peterychang): Add support for other content types match event.type: # types: # ResponseAudioDeltaEvent, @@ -1120,26 +1117,26 @@ def _parse_chunk_from_openai( event_part = event.part match event_part.type: case "output_text": - contents.append(TextContent(text=event_part.text, raw_representation=event)) + contents.append(Content.from_text(text=event_part.text, raw_representation=event)) metadata.update(self._get_metadata_from_response(event_part)) case "refusal": - contents.append(TextContent(text=event_part.refusal, raw_representation=event)) + contents.append(Content.from_text(text=event_part.refusal, raw_representation=event)) case _: pass case "response.output_text.delta": - contents.append(TextContent(text=event.delta, raw_representation=event)) + contents.append(Content.from_text(text=event.delta, raw_representation=event)) metadata.update(self._get_metadata_from_response(event)) case "response.reasoning_text.delta": - contents.append(TextReasoningContent(text=event.delta, raw_representation=event)) + contents.append(Content.from_text_reasoning(text=event.delta, raw_representation=event)) metadata.update(self._get_metadata_from_response(event)) case "response.reasoning_text.done": - contents.append(TextReasoningContent(text=event.text, raw_representation=event)) + 
contents.append(Content.from_text_reasoning(text=event.text, raw_representation=event)) metadata.update(self._get_metadata_from_response(event)) case "response.reasoning_summary_text.delta": - contents.append(TextReasoningContent(text=event.delta, raw_representation=event)) + contents.append(Content.from_text_reasoning(text=event.delta, raw_representation=event)) metadata.update(self._get_metadata_from_response(event)) case "response.reasoning_summary_text.done": - contents.append(TextReasoningContent(text=event.text, raw_representation=event)) + contents.append(Content.from_text_reasoning(text=event.text, raw_representation=event)) metadata.update(self._get_metadata_from_response(event)) case "response.created": response_id = event.response.id @@ -1154,7 +1151,7 @@ def _parse_chunk_from_openai( if event.response.usage: usage = self._parse_usage_from_openai(event.response.usage) if usage: - contents.append(UsageContent(details=usage, raw_representation=event)) + contents.append(Content.from_usage(details=usage, raw_representation=event)) case "response.output_item.added": event_item = event.item match event_item.type: @@ -1179,9 +1176,9 @@ def _parse_chunk_from_openai( ) case "mcp_approval_request": contents.append( - FunctionApprovalRequestContent( + Content.from_function_approval_request( id=event_item.id, - function_call=FunctionCallContent( + function_call=Content.from_function_call( call_id=event_item.id, name=event_item.name, arguments=event_item.arguments, @@ -1193,7 +1190,7 @@ def _parse_chunk_from_openai( case "mcp_call": call_id = getattr(event_item, "id", None) or getattr(event_item, "call_id", None) or "" contents.append( - MCPServerToolCallContent( + Content.from_mcp_server_tool_call( call_id=call_id, tool_name=getattr(event_item, "name", "") or "", server_name=getattr(event_item, "server_label", None), @@ -1206,7 +1203,7 @@ def _parse_chunk_from_openai( or getattr(event_item, "output", None) or getattr(event_item, "outputs", None) ) - parsed_output: 
list[Contents] | None = None + parsed_output: list[Content] | None = None if result_output: normalized = ( result_output @@ -1214,9 +1211,9 @@ def _parse_chunk_from_openai( and not isinstance(result_output, (str, bytes, MutableMapping)) else [result_output] ) - parsed_output = [_parse_content(output_item) for output_item in normalized] + parsed_output = [Content.from_dict(output_item) for output_item in normalized] contents.append( - MCPServerToolResultContent( + Content.from_mcp_server_tool_result( call_id=call_id, output=parsed_output, raw_representation=event_item, @@ -1224,19 +1221,19 @@ def _parse_chunk_from_openai( ) case "code_interpreter_call": # ResponseOutputCodeInterpreterCall call_id = getattr(event_item, "call_id", None) or getattr(event_item, "id", None) - outputs: list[Contents] = [] + outputs: list[Content] = [] if hasattr(event_item, "outputs") and event_item.outputs: for code_output in event_item.outputs: if getattr(code_output, "type", None) == "logs": outputs.append( - TextContent( + Content.from_text( text=cast(Any, code_output).logs, raw_representation=code_output, ) ) elif getattr(code_output, "type", None) == "image": outputs.append( - UriContent( + Content.from_uri( uri=cast(Any, code_output).url, raw_representation=code_output, media_type="image", @@ -1244,10 +1241,10 @@ def _parse_chunk_from_openai( ) if hasattr(event_item, "code") and event_item.code: contents.append( - CodeInterpreterToolCallContent( + Content.from_code_interpreter_tool_call( call_id=call_id, inputs=[ - TextContent( + Content.from_text( text=event_item.code, raw_representation=event_item, ) @@ -1256,7 +1253,7 @@ def _parse_chunk_from_openai( ) ) contents.append( - CodeInterpreterToolResultContent( + Content.from_code_interpreter_tool_result( call_id=call_id, outputs=outputs, raw_representation=event_item, @@ -1273,7 +1270,7 @@ def _parse_chunk_from_openai( ): additional_properties = {"summary": event_item.summary[index]} contents.append( - TextReasoningContent( + 
Content.from_text_reasoning( text=reasoning_content.text, raw_representation=reasoning_content, additional_properties=additional_properties, @@ -1285,7 +1282,7 @@ def _parse_chunk_from_openai( call_id, name = function_call_ids.get(event.output_index, (None, None)) if call_id and name: contents.append( - FunctionCallContent( + Content.from_function_call( call_id=call_id, name=name, arguments=event.delta, @@ -1301,11 +1298,17 @@ def _parse_chunk_from_openai( image_base64 = event.partial_image_b64 partial_index = event.partial_image_index - # Use helper function to create data URI from base64 - uri, media_type = DataContent.create_data_uri_from_base64(image_base64) + # Detect media type from base64 data + from agent_framework._types import detect_media_type_from_base64 - image_output = DataContent( - uri=uri, + media_type = detect_media_type_from_base64(image_base64) + if media_type is None: + media_type = "image/png" # Default fallback + + # Decode base64 and use Content.from_data + data_bytes = base64.b64decode(image_base64) + image_output = Content.from_data( + data=data_bytes, media_type=media_type, additional_properties={ "partial_image_index": partial_index, @@ -1316,13 +1319,13 @@ def _parse_chunk_from_openai( image_id = getattr(event, "item_id", None) contents.append( - ImageGenerationToolCallContent( + Content.from_image_generation_tool_call( image_id=image_id, raw_representation=event, ) ) contents.append( - ImageGenerationToolResultContent( + Content.from_image_generation_tool_result( image_id=image_id, outputs=image_output, raw_representation=event, @@ -1343,7 +1346,7 @@ def _get_ann_value(key: str) -> Any: if ann_type == "file_path": if ann_file_id: contents.append( - HostedFileContent( + Content.from_hosted_file( file_id=str(ann_file_id), additional_properties={ "annotation_index": event.annotation_index, @@ -1355,7 +1358,7 @@ def _get_ann_value(key: str) -> Any: elif ann_type == "file_citation": if ann_file_id: contents.append( - HostedFileContent( + 
Content.from_hosted_file( file_id=str(ann_file_id), additional_properties={ "annotation_index": event.annotation_index, @@ -1368,7 +1371,7 @@ def _get_ann_value(key: str) -> Any: elif ann_type == "container_file_citation": if ann_file_id: contents.append( - HostedFileContent( + Content.from_hosted_file( file_id=str(ann_file_id), additional_properties={ "annotation_index": event.annotation_index, diff --git a/python/packages/core/tests/azure/test_azure_assistants_client.py b/python/packages/core/tests/azure/test_azure_assistants_client.py index 6c65dac7c1..776951a9ea 100644 --- a/python/packages/core/tests/azure/test_azure_assistants_client.py +++ b/python/packages/core/tests/azure/test_azure_assistants_client.py @@ -18,7 +18,6 @@ ChatResponse, ChatResponseUpdate, HostedCodeInterpreterTool, - TextContent, ) from agent_framework.azure import AzureOpenAIAssistantsClient from agent_framework.exceptions import ServiceInitializationError @@ -332,7 +331,7 @@ async def test_azure_assistants_client_streaming() -> None: assert chunk is not None assert isinstance(chunk, ChatResponseUpdate) for content in chunk.contents: - if isinstance(content, TextContent) and content.text: + if content.type == "text" and content.text: full_message += content.text assert any(word in full_message.lower() for word in ["sunny", "25", "weather", "seattle"]) @@ -358,7 +357,7 @@ async def test_azure_assistants_client_streaming_tools() -> None: assert chunk is not None assert isinstance(chunk, ChatResponseUpdate) for content in chunk.contents: - if isinstance(content, TextContent) and content.text: + if content.type == "text" and content.text: full_message += content.text assert any(word in full_message.lower() for word in ["sunny", "25", "weather"]) diff --git a/python/packages/core/tests/azure/test_azure_chat_client.py b/python/packages/core/tests/azure/test_azure_chat_client.py index 3e9a3a5042..d91b58c646 100644 --- a/python/packages/core/tests/azure/test_azure_chat_client.py +++ 
b/python/packages/core/tests/azure/test_azure_chat_client.py @@ -25,7 +25,6 @@ ChatMessage, ChatResponse, ChatResponseUpdate, - TextContent, ai_function, ) from agent_framework._telemetry import USER_AGENT_KEY @@ -304,9 +303,9 @@ async def test_azure_on_your_data( ) assert len(content.messages) == 1 assert len(content.messages[0].contents) == 1 - assert isinstance(content.messages[0].contents[0], TextContent) + assert content.messages[0].contents[0].type == "text" assert len(content.messages[0].contents[0].annotations) == 1 - assert content.messages[0].contents[0].annotations[0].title == "test title" + assert content.messages[0].contents[0].annotations[0]["title"] == "test title" assert content.messages[0].contents[0].text == "test" mock_create.assert_awaited_once_with( @@ -374,9 +373,9 @@ async def test_azure_on_your_data_string( ) assert len(content.messages) == 1 assert len(content.messages[0].contents) == 1 - assert isinstance(content.messages[0].contents[0], TextContent) + assert content.messages[0].contents[0].type == "text" assert len(content.messages[0].contents[0].annotations) == 1 - assert content.messages[0].contents[0].annotations[0].title == "test title" + assert content.messages[0].contents[0].annotations[0]["title"] == "test title" assert content.messages[0].contents[0].text == "test" mock_create.assert_awaited_once_with( @@ -433,7 +432,7 @@ async def test_azure_on_your_data_fail( ) assert len(content.messages) == 1 assert len(content.messages[0].contents) == 1 - assert isinstance(content.messages[0].contents[0], TextContent) + assert content.messages[0].contents[0].type == "text" assert content.messages[0].contents[0].text == "test" mock_create.assert_awaited_once_with( @@ -731,7 +730,7 @@ async def test_azure_openai_chat_client_streaming() -> None: assert chunk.message_id is not None assert chunk.response_id is not None for content in chunk.contents: - if isinstance(content, TextContent) and content.text: + if content.type == "text" and 
content.text: full_message += content.text assert "Emily" in full_message or "David" in full_message @@ -757,7 +756,7 @@ async def test_azure_openai_chat_client_streaming_tools() -> None: assert chunk is not None assert isinstance(chunk, ChatResponseUpdate) for content in chunk.contents: - if isinstance(content, TextContent) and content.text: + if content.type == "text" and content.text: full_message += content.text assert "Emily" in full_message or "David" in full_message diff --git a/python/packages/core/tests/azure/test_azure_responses_client.py b/python/packages/core/tests/azure/test_azure_responses_client.py index 0e1c17f9a8..b2d4a59ab7 100644 --- a/python/packages/core/tests/azure/test_azure_responses_client.py +++ b/python/packages/core/tests/azure/test_azure_responses_client.py @@ -15,10 +15,10 @@ ChatClientProtocol, ChatMessage, ChatResponse, + Content, HostedCodeInterpreterTool, HostedFileSearchTool, HostedMCPTool, - HostedVectorStoreContent, HostedWebSearchTool, ai_function, ) @@ -48,7 +48,7 @@ async def get_weather(location: Annotated[str, "The location as a city name"]) - return f"The weather in {location} is sunny and 72°F." 
-async def create_vector_store(client: AzureOpenAIResponsesClient) -> tuple[str, HostedVectorStoreContent]: +async def create_vector_store(client: AzureOpenAIResponsesClient) -> tuple[str, Content]: """Create a vector store with sample documents for testing.""" file = await client.client.files.create( file=("todays_weather.txt", b"The weather today is sunny with a high of 75F."), purpose="assistants" @@ -61,7 +61,7 @@ async def create_vector_store(client: AzureOpenAIResponsesClient) -> tuple[str, if result.last_error is not None: raise Exception(f"Vector store file processing failed with status: {result.last_error.message}") - return file.id, HostedVectorStoreContent(vector_store_id=vector_store.id) + return file.id, Content.from_hosted_vector_store(vector_store_id=vector_store.id) async def delete_vector_store(client: AzureOpenAIResponsesClient, file_id: str, vector_store_id: str) -> None: diff --git a/python/packages/core/tests/core/conftest.py b/python/packages/core/tests/core/conftest.py index 1561392214..f2f6059b91 100644 --- a/python/packages/core/tests/core/conftest.py +++ b/python/packages/core/tests/core/conftest.py @@ -20,8 +20,8 @@ ChatMessage, ChatResponse, ChatResponseUpdate, + Content, Role, - TextContent, ToolProtocol, ai_function, use_chat_middleware, @@ -108,8 +108,8 @@ async def get_streaming_response( for update in self.streaming_responses.pop(0): yield update else: - yield ChatResponseUpdate(text=TextContent(text="test streaming response "), role="assistant") - yield ChatResponseUpdate(contents=[TextContent(text="another update")], role="assistant") + yield ChatResponseUpdate(text=Content.from_text(text="test streaming response "), role="assistant") + yield ChatResponseUpdate(contents=[Content.from_text(text="another update")], role="assistant") @use_chat_middleware @@ -233,7 +233,7 @@ async def run( **kwargs: Any, ) -> AgentResponse: logger.debug(f"Running mock agent, with: {messages=}, {thread=}, {kwargs=}") - return 
AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent("Response")])]) + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text("Response")])]) async def run_stream( self, @@ -243,7 +243,7 @@ async def run_stream( **kwargs: Any, ) -> AsyncIterable[AgentResponseUpdate]: logger.debug(f"Running mock agent stream, with: {messages=}, {thread=}, {kwargs=}") - yield AgentResponseUpdate(contents=[TextContent("Response")]) + yield AgentResponseUpdate(contents=[Content.from_text("Response")]) def get_new_thread(self) -> AgentThread: return MockAgentThread() diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index ee9054c143..a331f6f75c 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -18,12 +18,11 @@ ChatMessage, ChatMessageStore, ChatResponse, + Content, Context, ContextProvider, - FunctionCallContent, HostedCodeInterpreterTool, Role, - TextContent, ai_function, ) from agent_framework._mcp import MCPTool @@ -136,7 +135,7 @@ async def test_prepare_thread_does_not_mutate_agent_chat_options(chat_client: Ch async def test_chat_client_agent_update_thread_id(chat_client_base: ChatClientProtocol) -> None: mock_response = ChatResponse( - messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent("test response")])], + messages=[ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text("test response")])], conversation_id="123", ) chat_client_base.run_responses = [mock_response] @@ -200,7 +199,9 @@ async def test_chat_client_agent_author_name_is_used_from_response(chat_client_b chat_client_base.run_responses = [ ChatResponse( messages=[ - ChatMessage(role=Role.ASSISTANT, contents=[TextContent("test response")], author_name="TestAuthor") + ChatMessage( + role=Role.ASSISTANT, contents=[Content.from_text("test response")], author_name="TestAuthor" + ) ] ) ] @@ -264,7 +265,7 @@ async def 
test_chat_agent_context_providers_thread_created(chat_client_base: Cha mock_provider = MockContextProvider() chat_client_base.run_responses = [ ChatResponse( - messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent("test response")])], + messages=[ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text("test response")])], conversation_id="test-thread-id", ) ] @@ -345,7 +346,7 @@ async def test_chat_agent_context_providers_with_thread_service_id(chat_client_b mock_provider = MockContextProvider() chat_client_base.run_responses = [ ChatResponse( - messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent("test response")])], + messages=[ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text("test response")])], conversation_id="service-thread-123", ) ] @@ -575,7 +576,9 @@ def echo_thread_info(text: str, **kwargs: Any) -> str: # type: ignore[reportUnk ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="1", name="echo_thread_info", arguments='{"text": "hello"}')], + contents=[ + Content.from_function_call(call_id="1", name="echo_thread_info", arguments='{"text": "hello"}') + ], ) ), ChatResponse(messages=ChatMessage(role="assistant", text="done")), diff --git a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py index 6bd52975c2..39f441eb49 100644 --- a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py +++ b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py @@ -5,7 +5,7 @@ from collections.abc import Awaitable, Callable from typing import Any -from agent_framework import ChatAgent, ChatMessage, ChatResponse, FunctionCallContent, agent_middleware +from agent_framework import ChatAgent, ChatMessage, ChatResponse, Content, agent_middleware from agent_framework._middleware import AgentRunContext from .conftest import MockChatClient @@ -113,7 +113,7 @@ async def capture_middleware( 
ChatMessage( role="assistant", contents=[ - FunctionCallContent( + Content.from_function_call( call_id="call_c_1", name="call_c", arguments='{"task": "Please execute agent_c"}', @@ -170,10 +170,10 @@ async def capture_middleware( await next(context) # Setup mock streaming responses - from agent_framework import ChatResponseUpdate, TextContent + from agent_framework import ChatResponseUpdate chat_client.streaming_responses = [ - [ChatResponseUpdate(text=TextContent(text="Streaming response"), role="assistant")], + [ChatResponseUpdate(text=Content.from_text(text="Streaming response"), role="assistant")], ] sub_agent = ChatAgent( diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index 3aa8586a69..b2de663d03 100644 --- a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -11,11 +11,8 @@ ChatMessage, ChatResponse, ChatResponseUpdate, - FunctionApprovalRequestContent, - FunctionCallContent, - FunctionResultContent, + Content, Role, - TextContent, ai_function, ) from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware @@ -34,7 +31,9 @@ def ai_func(arg1: str) -> str: ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="1", name="test_function", arguments='{"arg1": "value1"}')], + contents=[ + Content.from_function_call(call_id="1", name="test_function", arguments='{"arg1": "value1"}') + ], ) ), ChatResponse(messages=ChatMessage(role="assistant", text="done")), @@ -43,12 +42,12 @@ def ai_func(arg1: str) -> str: assert exec_counter == 1 assert len(response.messages) == 3 assert response.messages[0].role == Role.ASSISTANT - assert isinstance(response.messages[0].contents[0], FunctionCallContent) + assert response.messages[0].contents[0].type == "function_call" assert response.messages[0].contents[0].name == "test_function" 
assert response.messages[0].contents[0].arguments == '{"arg1": "value1"}' assert response.messages[0].contents[0].call_id == "1" assert response.messages[1].role == Role.TOOL - assert isinstance(response.messages[1].contents[0], FunctionResultContent) + assert response.messages[1].contents[0].type == "function_result" assert response.messages[1].contents[0].call_id == "1" assert response.messages[1].contents[0].result == "Processed value1" assert response.messages[2].role == Role.ASSISTANT @@ -68,13 +67,17 @@ def ai_func(arg1: str) -> str: ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="1", name="test_function", arguments='{"arg1": "value1"}')], + contents=[ + Content.from_function_call(call_id="1", name="test_function", arguments='{"arg1": "value1"}') + ], ) ), ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="2", name="test_function", arguments='{"arg1": "value1"}')], + contents=[ + Content.from_function_call(call_id="2", name="test_function", arguments='{"arg1": "value1"}') + ], ) ), ChatResponse(messages=ChatMessage(role="assistant", text="done")), @@ -87,10 +90,10 @@ def ai_func(arg1: str) -> str: assert response.messages[2].role == Role.ASSISTANT assert response.messages[3].role == Role.TOOL assert response.messages[4].role == Role.ASSISTANT - assert isinstance(response.messages[0].contents[0], FunctionCallContent) - assert isinstance(response.messages[1].contents[0], FunctionResultContent) - assert isinstance(response.messages[2].contents[0], FunctionCallContent) - assert isinstance(response.messages[3].contents[0], FunctionResultContent) + assert response.messages[0].contents[0].type == "function_call" + assert response.messages[1].contents[0].type == "function_result" + assert response.messages[2].contents[0].type == "function_call" + assert response.messages[3].contents[0].type == "function_result" async def 
test_base_client_with_streaming_function_calling(chat_client_base: ChatClientProtocol): @@ -105,17 +108,17 @@ def ai_func(arg1: str) -> str: chat_client_base.streaming_responses = [ [ ChatResponseUpdate( - contents=[FunctionCallContent(call_id="1", name="test_function", arguments='{"arg1":')], + contents=[Content.from_function_call(call_id="1", name="test_function", arguments='{"arg1":')], role="assistant", ), ChatResponseUpdate( - contents=[FunctionCallContent(call_id="1", name="test_function", arguments='"value1"}')], + contents=[Content.from_function_call(call_id="1", name="test_function", arguments='"value1"}')], role="assistant", ), ], [ ChatResponseUpdate( - contents=[TextContent(text="Processed value1")], + contents=[Content.from_text(text="Processed value1")], role="assistant", ) ], @@ -150,7 +153,7 @@ def ai_func(user_query: str) -> str: messages=ChatMessage( role="assistant", contents=[ - FunctionCallContent( + Content.from_function_call( call_id="1", name="start_todo_investigation", arguments='{"user_query": "issue"}', @@ -207,7 +210,7 @@ def ai_func(user_query: str) -> str: messages=ChatMessage( role="assistant", contents=[ - FunctionCallContent( + Content.from_function_call( call_id="thread-1", name="start_threaded_investigation", arguments='{"user_query": "issue"}', @@ -334,7 +337,7 @@ def func_with_approval(arg1: str) -> str: function_name = "approval_func" if approval_required else "no_approval_func" # Single function call content - func_call = FunctionCallContent(call_id="1", name=function_name, arguments='{"arg1": "value1"}') + func_call = Content.from_function_call(call_id="1", name=function_name, arguments='{"arg1": "value1"}') completion = ChatMessage(role="assistant", text="done") chat_client_base.run_responses = [ @@ -344,23 +347,27 @@ def func_with_approval(arg1: str) -> str: chat_client_base.streaming_responses = [ [ ChatResponseUpdate( - contents=[FunctionCallContent(call_id="1", name=function_name, arguments='{"arg1":')], + 
contents=[Content.from_function_call(call_id="1", name=function_name, arguments='{"arg1":')], role="assistant", ), ChatResponseUpdate( - contents=[FunctionCallContent(call_id="1", name=function_name, arguments='"value1"}')], + contents=[Content.from_function_call(call_id="1", name=function_name, arguments='"value1"}')], role="assistant", ), ] - ] + ([] if approval_required else [[ChatResponseUpdate(contents=[TextContent(text="done")], role="assistant")]]) + ] + ( + [] + if approval_required + else [[ChatResponseUpdate(contents=[Content.from_text(text="done")], role="assistant")]] + ) else: # num_functions == 2 tools = [func_no_approval, func_with_approval] # Two function calls content func_calls = [ - FunctionCallContent(call_id="1", name="no_approval_func", arguments='{"arg1": "value1"}'), - FunctionCallContent(call_id="2", name="approval_func", arguments='{"arg1": "value2"}'), + Content.from_function_call(call_id="1", name="no_approval_func", arguments='{"arg1": "value1"}'), + Content.from_function_call(call_id="2", name="approval_func", arguments='{"arg1": "value2"}'), ] chat_client_base.run_responses = [ChatResponse(messages=ChatMessage(role="assistant", contents=func_calls))] @@ -405,24 +412,24 @@ def func_with_approval(arg1: str) -> str: assert len(messages) == 1 # Assistant message should have FunctionCallContent + FunctionApprovalRequestContent assert len(messages[0].contents) == 2 - assert isinstance(messages[0].contents[0], FunctionCallContent) - assert isinstance(messages[0].contents[1], FunctionApprovalRequestContent) + assert messages[0].contents[0].type == "function_call" + assert messages[0].contents[1].type == "function_approval_request" assert messages[0].contents[1].function_call.name == "approval_func" assert exec_counter == 0 # Function not executed yet else: # Streaming: 2 function call chunks + 1 approval request update (same assistant message) assert len(messages) == 3 - assert isinstance(messages[0].contents[0], FunctionCallContent) - assert 
isinstance(messages[1].contents[0], FunctionCallContent) - assert isinstance(messages[2].contents[0], FunctionApprovalRequestContent) + assert messages[0].contents[0].type == "function_call" + assert messages[1].contents[0].type == "function_call" + assert messages[2].contents[0].type == "function_approval_request" assert messages[2].contents[0].function_call.name == "approval_func" assert exec_counter == 0 # Function not executed yet else: # Single function without approval: call + result + final if not streaming: assert len(messages) == 3 - assert isinstance(messages[0].contents[0], FunctionCallContent) - assert isinstance(messages[1].contents[0], FunctionResultContent) + assert messages[0].contents[0].type == "function_call" + assert messages[1].contents[0].type == "function_result" assert messages[1].contents[0].result == "Processed value1" assert messages[2].role == Role.ASSISTANT assert messages[2].text == "done" @@ -430,9 +437,9 @@ def func_with_approval(arg1: str) -> str: else: # Streaming has: 2 function call updates + 1 result update + 1 final update assert len(messages) == 4 - assert isinstance(messages[0].contents[0], FunctionCallContent) - assert isinstance(messages[1].contents[0], FunctionCallContent) - assert isinstance(messages[2].contents[0], FunctionResultContent) + assert messages[0].contents[0].type == "function_call" + assert messages[1].contents[0].type == "function_call" + assert messages[2].contents[0].type == "function_result" assert messages[3].text == "done" assert exec_counter == 1 else: # num_functions == 2 @@ -443,26 +450,25 @@ def func_with_approval(arg1: str) -> str: assert len(messages) == 1 # Should have: 2 FunctionCallContent + 2 FunctionApprovalRequestContent assert len(messages[0].contents) == 4 - assert isinstance(messages[0].contents[0], FunctionCallContent) - assert isinstance(messages[0].contents[1], FunctionCallContent) + assert messages[0].contents[0].type == "function_call" + assert messages[0].contents[1].type == 
"function_call" # Both should result in approval requests - approval_requests = [c for c in messages[0].contents if isinstance(c, FunctionApprovalRequestContent)] + approval_requests = [c for c in messages[0].contents if c.type == "function_approval_request"] assert len(approval_requests) == 2 assert exec_counter == 0 # Neither function executed yet else: # Streaming: 2 function call updates + 1 approval request with 2 contents assert len(messages) == 3 - assert isinstance(messages[0].contents[0], FunctionCallContent) - assert isinstance(messages[1].contents[0], FunctionCallContent) + assert messages[0].contents[0].type == "function_call" + assert messages[1].contents[0].type == "function_call" # The approval request message contains both approval requests assert len(messages[2].contents) == 2 - assert all(isinstance(c, FunctionApprovalRequestContent) for c in messages[2].contents) + assert all(c.type == "function_approval_request" for c in messages[2].contents) assert exec_counter == 0 # Neither function executed yet async def test_rejected_approval(chat_client_base: ChatClientProtocol): """Test that rejecting an approval alongside an approved one is handled correctly.""" - from agent_framework import FunctionApprovalResponseContent exec_counter_approved = 0 exec_counter_rejected = 0 @@ -485,8 +491,8 @@ def func_rejected(arg1: str) -> str: messages=ChatMessage( role="assistant", contents=[ - FunctionCallContent(call_id="1", name="approved_func", arguments='{"arg1": "value1"}'), - FunctionCallContent(call_id="2", name="rejected_func", arguments='{"arg1": "value2"}'), + Content.from_function_call(call_id="1", name="approved_func", arguments='{"arg1": "value1"}'), + Content.from_function_call(call_id="2", name="rejected_func", arguments='{"arg1": "value2"}'), ], ) ), @@ -501,19 +507,19 @@ def func_rejected(arg1: str) -> str: assert len(response.messages) == 1 # Assistant message should have: 2 FunctionCallContent + 2 FunctionApprovalRequestContent assert 
len(response.messages[0].contents) == 4 - approval_requests = [c for c in response.messages[0].contents if isinstance(c, FunctionApprovalRequestContent)] + approval_requests = [c for c in response.messages[0].contents if c.type == "function_approval_request"] assert len(approval_requests) == 2 # Approve one and reject the other approval_req_1 = approval_requests[0] approval_req_2 = approval_requests[1] - approved_response = FunctionApprovalResponseContent( + approved_response = Content.from_function_approval_response( id=approval_req_1.id, function_call=approval_req_1.function_call, approved=True, ) - rejected_response = FunctionApprovalResponseContent( + rejected_response = Content.from_function_approval_response( id=approval_req_2.id, function_call=approval_req_2.function_call, approved=False, @@ -533,7 +539,7 @@ def func_rejected(arg1: str) -> str: rejected_result = None for msg in all_messages: for content in msg.contents: - if isinstance(content, FunctionResultContent): + if content.type == "function_result": if content.call_id == "1": approved_result = content elif content.call_id == "2": @@ -553,7 +559,7 @@ def func_rejected(arg1: str) -> str: # This ensures the message format is correct for OpenAI's API for msg in all_messages: for content in msg.contents: - if isinstance(content, FunctionResultContent): + if content.type == "function_result": assert msg.role == Role.TOOL, ( f"Message with FunctionResultContent must have role='tool', got '{msg.role}'" ) @@ -574,7 +580,7 @@ def func_with_approval(arg1: str) -> str: messages=ChatMessage( role="assistant", contents=[ - FunctionCallContent(call_id="1", name="test_func", arguments='{"arg1": "value1"}'), + Content.from_function_call(call_id="1", name="test_func", arguments='{"arg1": "value1"}'), ], ) ), @@ -588,14 +594,13 @@ def func_with_approval(arg1: str) -> str: assert len(response.messages) == 1 assert response.messages[0].role == Role.ASSISTANT assert len(response.messages[0].contents) == 2 - assert 
isinstance(response.messages[0].contents[0], FunctionCallContent) - assert isinstance(response.messages[0].contents[1], FunctionApprovalRequestContent) + assert response.messages[0].contents[0].type == "function_call" + assert response.messages[0].contents[1].type == "function_approval_request" assert exec_counter == 0 async def test_persisted_approval_messages_replay_correctly(chat_client_base: ChatClientProtocol): """Approval flow should work when messages are persisted and sent back (thread scenario).""" - from agent_framework import FunctionApprovalResponseContent exec_counter = 0 @@ -610,7 +615,7 @@ def func_with_approval(arg1: str) -> str: messages=ChatMessage( role="assistant", contents=[ - FunctionCallContent(call_id="1", name="test_func", arguments='{"arg1": "value1"}'), + Content.from_function_call(call_id="1", name="test_func", arguments='{"arg1": "value1"}'), ], ) ), @@ -624,13 +629,13 @@ def func_with_approval(arg1: str) -> str: # Store messages (like a thread would) persisted_messages = [ - ChatMessage(role="user", contents=[TextContent(text="hello")]), + ChatMessage(role="user", contents=[Content.from_text(text="hello")]), *response1.messages, ] # Send approval - approval_req = [c for c in response1.messages[0].contents if isinstance(c, FunctionApprovalRequestContent)][0] - approval_response = FunctionApprovalResponseContent( + approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0] + approval_response = Content.from_function_approval_response( id=approval_req.id, function_call=approval_req.function_call, approved=True, @@ -650,7 +655,6 @@ def func_with_approval(arg1: str) -> str: async def test_no_duplicate_function_calls_after_approval_processing(chat_client_base: ChatClientProtocol): """Processing approval should not create duplicate function calls in messages.""" - from agent_framework import FunctionApprovalResponseContent @ai_function(name="test_func", approval_mode="always_require") def 
func_with_approval(arg1: str) -> str: @@ -661,7 +665,7 @@ def func_with_approval(arg1: str) -> str: messages=ChatMessage( role="assistant", contents=[ - FunctionCallContent(call_id="1", name="test_func", arguments='{"arg1": "value1"}'), + Content.from_function_call(call_id="1", name="test_func", arguments='{"arg1": "value1"}'), ], ) ), @@ -672,8 +676,8 @@ def func_with_approval(arg1: str) -> str: "hello", options={"tool_choice": "auto", "tools": [func_with_approval]} ) - approval_req = [c for c in response1.messages[0].contents if isinstance(c, FunctionApprovalRequestContent)][0] - approval_response = FunctionApprovalResponseContent( + approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0] + approval_response = Content.from_function_approval_response( id=approval_req.id, function_call=approval_req.function_call, approved=True, @@ -687,7 +691,7 @@ def func_with_approval(arg1: str) -> str: 1 for msg in all_messages for content in msg.contents - if isinstance(content, FunctionCallContent) and content.call_id == "1" + if content.type == "function_call" and content.call_id == "1" ) assert function_call_count == 1 @@ -695,7 +699,6 @@ def func_with_approval(arg1: str) -> str: async def test_rejection_result_uses_function_call_id(chat_client_base: ChatClientProtocol): """Rejection error result should use the function call's call_id, not the approval's id.""" - from agent_framework import FunctionApprovalResponseContent @ai_function(name="test_func", approval_mode="always_require") def func_with_approval(arg1: str) -> str: @@ -706,7 +709,7 @@ def func_with_approval(arg1: str) -> str: messages=ChatMessage( role="assistant", contents=[ - FunctionCallContent(call_id="call_123", name="test_func", arguments='{"arg1": "value1"}'), + Content.from_function_call(call_id="call_123", name="test_func", arguments='{"arg1": "value1"}'), ], ) ), @@ -717,8 +720,8 @@ def func_with_approval(arg1: str) -> str: "hello", options={"tool_choice": 
"auto", "tools": [func_with_approval]} ) - approval_req = [c for c in response1.messages[0].contents if isinstance(c, FunctionApprovalRequestContent)][0] - rejection_response = FunctionApprovalResponseContent( + approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0] + rejection_response = Content.from_function_approval_response( id=approval_req.id, function_call=approval_req.function_call, approved=False, @@ -729,7 +732,7 @@ def func_with_approval(arg1: str) -> str: # Find the rejection result rejection_result = next( - (content for msg in all_messages for content in msg.contents if isinstance(content, FunctionResultContent)), + (content for msg in all_messages for content in msg.contents if content.type == "function_result"), None, ) @@ -753,13 +756,17 @@ def ai_func(arg1: str) -> str: ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="1", name="test_function", arguments='{"arg1": "value1"}')], + contents=[ + Content.from_function_call(call_id="1", name="test_function", arguments='{"arg1": "value1"}') + ], ) ), ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="2", name="test_function", arguments='{"arg1": "value2"}')], + contents=[ + Content.from_function_call(call_id="2", name="test_function", arguments='{"arg1": "value2"}') + ], ) ), # Failsafe response when tool_choice is set to "none" @@ -816,25 +823,33 @@ def error_func(arg1: str) -> str: ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="1", name="error_function", arguments='{"arg1": "value1"}')], + contents=[ + Content.from_function_call(call_id="1", name="error_function", arguments='{"arg1": "value1"}') + ], ) ), ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="2", name="error_function", arguments='{"arg1": "value2"}')], + contents=[ + Content.from_function_call(call_id="2", 
name="error_function", arguments='{"arg1": "value2"}') + ], ) ), ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="3", name="error_function", arguments='{"arg1": "value3"}')], + contents=[ + Content.from_function_call(call_id="3", name="error_function", arguments='{"arg1": "value3"}') + ], ) ), ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="4", name="error_function", arguments='{"arg1": "value4"}')], + contents=[ + Content.from_function_call(call_id="4", name="error_function", arguments='{"arg1": "value4"}') + ], ) ), ChatResponse(messages=ChatMessage(role="assistant", text="final response")), @@ -850,14 +865,14 @@ def error_func(arg1: str) -> str: content for msg in response.messages for content in msg.contents - if isinstance(content, FunctionResultContent) and content.exception + if content.type == "function_result" and content.exception ] # The first call errors, then the second call errors, hitting the limit # So we get 2 function calls with errors, but the responses show the behavior stopped assert len(error_results) >= 1 # At least one error occurred # Should have stopped making new function calls after hitting the error limit function_calls = [ - content for msg in response.messages for content in msg.contents if isinstance(content, FunctionCallContent) + content for msg in response.messages for content in msg.contents if content.type == "function_call" ] # Should have made at most 2 function calls before stopping assert len(function_calls) <= 2 @@ -877,7 +892,9 @@ def known_func(arg1: str) -> str: ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="1", name="unknown_function", arguments='{"arg1": "value1"}')], + contents=[ + Content.from_function_call(call_id="1", name="unknown_function", arguments='{"arg1": "value1"}') + ], ) ), ChatResponse(messages=ChatMessage(role="assistant", text="done")), @@ -890,7 +907,7 @@ def 
known_func(arg1: str) -> str: # Should have a result message indicating the tool wasn't found assert len(response.messages) == 3 - assert isinstance(response.messages[1].contents[0], FunctionResultContent) + assert response.messages[1].contents[0].type == "function_result" result_str = response.messages[1].contents[0].result or response.messages[1].contents[0].exception or "" assert "not found" in result_str.lower() assert exec_counter == 0 # Known function not executed @@ -910,7 +927,9 @@ def known_func(arg1: str) -> str: ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="1", name="unknown_function", arguments='{"arg1": "value1"}')], + contents=[ + Content.from_function_call(call_id="1", name="unknown_function", arguments='{"arg1": "value1"}') + ], ) ), ] @@ -946,7 +965,9 @@ def hidden_func(arg1: str) -> str: ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="1", name="hidden_function", arguments='{"arg1": "value1"}')], + contents=[ + Content.from_function_call(call_id="1", name="hidden_function", arguments='{"arg1": "value1"}') + ], ) ), ChatResponse(messages=ChatMessage(role="assistant", text="done")), @@ -967,7 +988,7 @@ def hidden_func(arg1: str) -> str: content for msg in response.messages for content in msg.contents - if isinstance(content, FunctionCallContent) and content.name == "hidden_function" + if content.type == "function_call" and content.name == "hidden_function" ] assert len(function_calls) >= 1 @@ -983,7 +1004,9 @@ def error_func(arg1: str) -> str: ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="1", name="error_function", arguments='{"arg1": "value1"}')], + contents=[ + Content.from_function_call(call_id="1", name="error_function", arguments='{"arg1": "value1"}') + ], ) ), ChatResponse(messages=ChatMessage(role="assistant", text="done")), @@ -996,7 +1019,7 @@ def error_func(arg1: str) -> str: # Should have a 
generic error message error_result = next( - content for msg in response.messages for content in msg.contents if isinstance(content, FunctionResultContent) + content for msg in response.messages for content in msg.contents if content.type == "function_result" ) assert error_result.result is not None assert error_result.exception is not None @@ -1015,7 +1038,9 @@ def error_func(arg1: str) -> str: ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="1", name="error_function", arguments='{"arg1": "value1"}')], + contents=[ + Content.from_function_call(call_id="1", name="error_function", arguments='{"arg1": "value1"}') + ], ) ), ChatResponse(messages=ChatMessage(role="assistant", text="done")), @@ -1028,7 +1053,7 @@ def error_func(arg1: str) -> str: # Should have detailed error message error_result = next( - content for msg in response.messages for content in msg.contents if isinstance(content, FunctionResultContent) + content for msg in response.messages for content in msg.contents if content.type == "function_result" ) assert error_result.result is not None assert error_result.exception is not None @@ -1083,7 +1108,9 @@ def typed_func(arg1: int) -> str: # Expects int, not str ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="1", name="typed_function", arguments='{"arg1": "not_an_int"}')], + contents=[ + Content.from_function_call(call_id="1", name="typed_function", arguments='{"arg1": "not_an_int"}') + ], ) ), ChatResponse(messages=ChatMessage(role="assistant", text="done")), @@ -1096,7 +1123,7 @@ def typed_func(arg1: int) -> str: # Expects int, not str # Should have detailed validation error error_result = next( - content for msg in response.messages for content in msg.contents if isinstance(content, FunctionResultContent) + content for msg in response.messages for content in msg.contents if content.type == "function_result" ) assert error_result.result is not None assert 
error_result.exception is not None @@ -1115,7 +1142,9 @@ def typed_func(arg1: int) -> str: # Expects int, not str ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="1", name="typed_function", arguments='{"arg1": "not_an_int"}')], + contents=[ + Content.from_function_call(call_id="1", name="typed_function", arguments='{"arg1": "not_an_int"}') + ], ) ), ChatResponse(messages=ChatMessage(role="assistant", text="done")), @@ -1128,7 +1157,7 @@ def typed_func(arg1: int) -> str: # Expects int, not str # Should have generic validation error error_result = next( - content for msg in response.messages for content in msg.contents if isinstance(content, FunctionResultContent) + content for msg in response.messages for content in msg.contents if content.type == "function_result" ) assert error_result.result is not None assert error_result.exception is not None @@ -1138,17 +1167,16 @@ def typed_func(arg1: int) -> str: # Expects int, not str async def test_hosted_tool_approval_response(chat_client_base: ChatClientProtocol): """Test handling of approval responses for hosted tools (tools not in tool_map).""" - from agent_framework import FunctionApprovalResponseContent @ai_function(name="local_function") def local_func(arg1: str) -> str: return f"Local {arg1}" # Create an approval response for a hosted tool that's not in our tool_map - hosted_function_call = FunctionCallContent( + hosted_function_call = Content.from_function_call( call_id="hosted_1", name="hosted_function", arguments='{"arg1": "value"}' ) - approval_response = FunctionApprovalResponseContent( + approval_response = Content.from_function_approval_response( id="approval_1", function_call=hosted_function_call, approved=True, @@ -1172,7 +1200,6 @@ def local_func(arg1: str) -> str: async def test_unapproved_tool_execution_raises_exception(chat_client_base: ChatClientProtocol): """Test that attempting to execute an unapproved tool raises ToolException.""" - from agent_framework 
import FunctionApprovalResponseContent @ai_function(name="test_function", approval_mode="always_require") def test_func(arg1: str) -> str: @@ -1183,7 +1210,7 @@ def test_func(arg1: str) -> str: messages=ChatMessage( role="assistant", contents=[ - FunctionCallContent(call_id="1", name="test_function", arguments='{"arg1": "value1"}'), + Content.from_function_call(call_id="1", name="test_function", arguments='{"arg1": "value1"}'), ], ) ), @@ -1193,10 +1220,10 @@ def test_func(arg1: str) -> str: # Get approval request response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [test_func]}) - approval_req = [c for c in response1.messages[0].contents if isinstance(c, FunctionApprovalRequestContent)][0] + approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0] # Create a rejection response (approved=False) - rejection_response = FunctionApprovalResponseContent( + rejection_response = Content.from_function_approval_response( id=approval_req.id, function_call=approval_req.function_call, approved=False, @@ -1214,8 +1241,7 @@ def test_func(arg1: str) -> str: content for msg in all_messages for content in msg.contents - if isinstance(content, FunctionResultContent) - and "rejected" in (content.result or content.exception or "").lower() + if content.type == "function_result" and "rejected" in (content.result or content.exception or "").lower() ), None, ) @@ -1227,7 +1253,6 @@ async def test_approved_function_call_with_error_without_detailed_errors(chat_cl When include_detailed_errors=False. 
""" - from agent_framework import FunctionApprovalResponseContent exec_counter = 0 @@ -1241,7 +1266,7 @@ def error_func(arg1: str) -> str: ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="1", name="error_func", arguments='{"arg1": "value1"}')], + contents=[Content.from_function_call(call_id="1", name="error_func", arguments='{"arg1": "value1"}')], ) ), ChatResponse(messages=ChatMessage(role="assistant", text="done")), @@ -1253,10 +1278,10 @@ def error_func(arg1: str) -> str: # Get approval request response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]}) - approval_req = [c for c in response1.messages[0].contents if isinstance(c, FunctionApprovalRequestContent)][0] + approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0] # Approve the function - approval_response = FunctionApprovalResponseContent( + approval_response = Content.from_function_approval_response( id=approval_req.id, function_call=approval_req.function_call, approved=True, @@ -1276,7 +1301,7 @@ def error_func(arg1: str) -> str: content for msg in all_messages for content in msg.contents - if isinstance(content, FunctionResultContent) and content.exception is not None + if content.type == "function_result" and content.exception is not None ), None, ) @@ -1291,7 +1316,6 @@ async def test_approved_function_call_with_error_with_detailed_errors(chat_clien When include_detailed_errors=True. 
""" - from agent_framework import FunctionApprovalResponseContent exec_counter = 0 @@ -1305,7 +1329,7 @@ def error_func(arg1: str) -> str: ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="1", name="error_func", arguments='{"arg1": "value1"}')], + contents=[Content.from_function_call(call_id="1", name="error_func", arguments='{"arg1": "value1"}')], ) ), ChatResponse(messages=ChatMessage(role="assistant", text="done")), @@ -1317,10 +1341,10 @@ def error_func(arg1: str) -> str: # Get approval request response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]}) - approval_req = [c for c in response1.messages[0].contents if isinstance(c, FunctionApprovalRequestContent)][0] + approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0] # Approve the function - approval_response = FunctionApprovalResponseContent( + approval_response = Content.from_function_approval_response( id=approval_req.id, function_call=approval_req.function_call, approved=True, @@ -1340,7 +1364,7 @@ def error_func(arg1: str) -> str: content for msg in all_messages for content in msg.contents - if isinstance(content, FunctionResultContent) and content.exception is not None + if content.type == "function_result" and content.exception is not None ), None, ) @@ -1353,7 +1377,6 @@ def error_func(arg1: str) -> str: async def test_approved_function_call_with_validation_error(chat_client_base: ChatClientProtocol): """Test that approved functions with validation errors are handled correctly.""" - from agent_framework import FunctionApprovalResponseContent exec_counter = 0 @@ -1367,7 +1390,9 @@ def typed_func(arg1: int) -> str: # Expects int, not str ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="1", name="typed_func", arguments='{"arg1": "not_an_int"}')], + contents=[ + Content.from_function_call(call_id="1", 
name="typed_func", arguments='{"arg1": "not_an_int"}') + ], ) ), ChatResponse(messages=ChatMessage(role="assistant", text="done")), @@ -1379,10 +1404,10 @@ def typed_func(arg1: int) -> str: # Expects int, not str # Get approval request response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [typed_func]}) - approval_req = [c for c in response1.messages[0].contents if isinstance(c, FunctionApprovalRequestContent)][0] + approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0] # Approve the function (even though it will fail validation) - approval_response = FunctionApprovalResponseContent( + approval_response = Content.from_function_approval_response( id=approval_req.id, function_call=approval_req.function_call, approved=True, @@ -1402,7 +1427,7 @@ def typed_func(arg1: int) -> str: # Expects int, not str content for msg in all_messages for content in msg.contents - if isinstance(content, FunctionResultContent) and content.exception is not None + if content.type == "function_result" and content.exception is not None ), None, ) @@ -1413,7 +1438,6 @@ def typed_func(arg1: int) -> str: # Expects int, not str async def test_approved_function_call_successful_execution(chat_client_base: ChatClientProtocol): """Test that approved functions execute successfully when no errors occur.""" - from agent_framework import FunctionApprovalResponseContent exec_counter = 0 @@ -1427,7 +1451,7 @@ def success_func(arg1: str) -> str: ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="1", name="success_func", arguments='{"arg1": "value1"}')], + contents=[Content.from_function_call(call_id="1", name="success_func", arguments='{"arg1": "value1"}')], ) ), ChatResponse(messages=ChatMessage(role="assistant", text="done")), @@ -1436,10 +1460,10 @@ def success_func(arg1: str) -> str: # Get approval request response1 = await chat_client_base.get_response("hello", 
options={"tool_choice": "auto", "tools": [success_func]}) - approval_req = [c for c in response1.messages[0].contents if isinstance(c, FunctionApprovalRequestContent)][0] + approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0] # Approve the function - approval_response = FunctionApprovalResponseContent( + approval_response = Content.from_function_approval_response( id=approval_req.id, function_call=approval_req.function_call, approved=True, @@ -1459,7 +1483,7 @@ def success_func(arg1: str) -> str: content for msg in all_messages for content in msg.contents - if isinstance(content, FunctionResultContent) and content.exception is None + if content.type == "function_result" and content.exception is None ), None, ) @@ -1486,7 +1510,9 @@ async def test_declaration_only_tool(chat_client_base: ChatClientProtocol): ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="1", name="declaration_func", arguments='{"arg1": "value1"}')], + contents=[ + Content.from_function_call(call_id="1", name="declaration_func", arguments='{"arg1": "value1"}') + ], ) ), ChatResponse(messages=ChatMessage(role="assistant", text="done")), @@ -1501,7 +1527,7 @@ async def test_declaration_only_tool(chat_client_base: ChatClientProtocol): content for msg in response.messages for content in msg.contents - if isinstance(content, FunctionCallContent) and content.name == "declaration_func" + if content.type == "function_call" and content.name == "declaration_func" ] assert len(function_calls) >= 1 @@ -1510,7 +1536,7 @@ async def test_declaration_only_tool(chat_client_base: ChatClientProtocol): content for msg in response.messages for content in msg.contents - if isinstance(content, FunctionResultContent) and content.call_id == "1" + if content.type == "function_result" and content.call_id == "1" ] assert len(function_results) == 0 @@ -1540,8 +1566,8 @@ async def func2(arg1: str) -> str: messages=ChatMessage( 
role="assistant", contents=[ - FunctionCallContent(call_id="1", name="func1", arguments='{"arg1": "value1"}'), - FunctionCallContent(call_id="2", name="func2", arguments='{"arg1": "value2"}'), + Content.from_function_call(call_id="1", name="func1", arguments='{"arg1": "value1"}'), + Content.from_function_call(call_id="2", name="func2", arguments='{"arg1": "value2"}'), ], ) ), @@ -1557,9 +1583,7 @@ async def func2(arg1: str) -> str: assert "func2_end" in exec_order # Should have results for both - results = [ - content for msg in response.messages for content in msg.contents if isinstance(content, FunctionResultContent) - ] + results = [content for msg in response.messages for content in msg.contents if content.type == "function_result"] assert len(results) == 2 @@ -1577,7 +1601,9 @@ def plain_function(arg1: str) -> str: ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="1", name="plain_function", arguments='{"arg1": "value1"}')], + contents=[ + Content.from_function_call(call_id="1", name="plain_function", arguments='{"arg1": "value1"}') + ], ) ), ChatResponse(messages=ChatMessage(role="assistant", text="done")), @@ -1588,9 +1614,7 @@ def plain_function(arg1: str) -> str: # Function should be executed assert exec_counter == 1 - result = next( - content for msg in response.messages for content in msg.contents if isinstance(content, FunctionResultContent) - ) + result = next(content for msg in response.messages for content in msg.contents if content.type == "function_result") assert result.result == "Plain value1" @@ -1606,7 +1630,9 @@ def test_func(arg1: str) -> str: ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="1", name="test_function", arguments='{"arg1": "value1"}')], + contents=[ + Content.from_function_call(call_id="1", name="test_function", arguments='{"arg1": "value1"}') + ], ), conversation_id="conv_123", # Simulate service-side thread ), @@ -1619,9 +1645,7 @@ def 
test_func(arg1: str) -> str: response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [test_func]}) # Should have executed the function - results = [ - content for msg in response.messages for content in msg.contents if isinstance(content, FunctionResultContent) - ] + results = [content for msg in response.messages for content in msg.contents if content.type == "function_result"] assert len(results) >= 1 assert response.conversation_id == "conv_123" @@ -1637,7 +1661,9 @@ def test_func(arg1: str) -> str: ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="1", name="test_function", arguments='{"arg1": "value1"}')], + contents=[ + Content.from_function_call(call_id="1", name="test_function", arguments='{"arg1": "value1"}') + ], ) ), ChatResponse(messages=ChatMessage(role="assistant", text="done")), @@ -1648,10 +1674,8 @@ def test_func(arg1: str) -> str: # Should have messages with both function call and function result assert len(response.messages) >= 2 # Check that we have both a function call and a function result - has_call = any(isinstance(content, FunctionCallContent) for msg in response.messages for content in msg.contents) - has_result = any( - isinstance(content, FunctionResultContent) for msg in response.messages for content in msg.contents - ) + has_call = any(content.type == "function_call" for msg in response.messages for content in msg.contents) + has_result = any(content.type == "function_result" for msg in response.messages for content in msg.contents) assert has_call assert has_result @@ -1673,13 +1697,17 @@ def sometimes_fails(arg1: str) -> str: ChatResponse( messages=ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="1", name="sometimes_fails", arguments='{"arg1": "value1"}')], + contents=[ + Content.from_function_call(call_id="1", name="sometimes_fails", arguments='{"arg1": "value1"}') + ], ) ), ChatResponse( messages=ChatMessage( 
role="assistant", - contents=[FunctionCallContent(call_id="2", name="sometimes_fails", arguments='{"arg1": "value2"}')], + contents=[ + Content.from_function_call(call_id="2", name="sometimes_fails", arguments='{"arg1": "value2"}') + ], ) ), ChatResponse(messages=ChatMessage(role="assistant", text="done")), @@ -1692,13 +1720,13 @@ def sometimes_fails(arg1: str) -> str: content for msg in response.messages for content in msg.contents - if isinstance(content, FunctionResultContent) and content.exception + if content.type == "function_result" and content.exception ] success_results = [ content for msg in response.messages for content in msg.contents - if isinstance(content, FunctionResultContent) and content.result + if content.type == "function_result" and content.result ] assert len(error_results) >= 1 @@ -1723,7 +1751,7 @@ def func_with_approval(arg1: str) -> str: chat_client_base.streaming_responses = [ [ ChatResponseUpdate( - contents=[FunctionCallContent(call_id="1", name="test_func", arguments='{"arg1": "value1"}')], + contents=[Content.from_function_call(call_id="1", name="test_func", arguments='{"arg1": "value1"}')], role="assistant", ), ], @@ -1738,10 +1766,7 @@ def func_with_approval(arg1: str) -> str: # Should have function call update and approval request approval_requests = [ - content - for update in updates - for content in update.contents - if isinstance(content, FunctionApprovalRequestContent) + content for update in updates for content in update.contents if content.type == "function_approval_request" ] assert len(approval_requests) == 1 assert approval_requests[0].function_call.name == "test_func" @@ -1762,26 +1787,26 @@ def ai_func(arg1: str) -> str: chat_client_base.streaming_responses = [ [ ChatResponseUpdate( - contents=[FunctionCallContent(call_id="1", name="test_function", arguments='{"arg1":')], + contents=[Content.from_function_call(call_id="1", name="test_function", arguments='{"arg1":')], role="assistant", ), ChatResponseUpdate( - 
contents=[FunctionCallContent(call_id="1", name="test_function", arguments='"value1"}')], + contents=[Content.from_function_call(call_id="1", name="test_function", arguments='"value1"}')], role="assistant", ), ], [ ChatResponseUpdate( - contents=[FunctionCallContent(call_id="2", name="test_function", arguments='{"arg1":')], + contents=[Content.from_function_call(call_id="2", name="test_function", arguments='{"arg1":')], role="assistant", ), ChatResponseUpdate( - contents=[FunctionCallContent(call_id="2", name="test_function", arguments='"value2"}')], + contents=[Content.from_function_call(call_id="2", name="test_function", arguments='"value2"}')], role="assistant", ), ], # Failsafe response when tool_choice is set to "none" - [ChatResponseUpdate(contents=[TextContent(text="giving up on tools")], role="assistant")], + [ChatResponseUpdate(contents=[Content.from_text(text="giving up on tools")], role="assistant")], ] # Set max_iterations to 1 in additional_properties @@ -1811,7 +1836,7 @@ def ai_func(arg1: str) -> str: return f"Processed {arg1}" chat_client_base.streaming_responses = [ - [ChatResponseUpdate(contents=[TextContent(text="response without function calling")], role="assistant")], + [ChatResponseUpdate(contents=[Content.from_text(text="response without function calling")], role="assistant")], ] # Disable function invocation @@ -1840,23 +1865,29 @@ def error_func(arg1: str) -> str: chat_client_base.streaming_responses = [ [ ChatResponseUpdate( - contents=[FunctionCallContent(call_id="1", name="error_function", arguments='{"arg1": "value1"}')], + contents=[ + Content.from_function_call(call_id="1", name="error_function", arguments='{"arg1": "value1"}') + ], role="assistant", ), ], [ ChatResponseUpdate( - contents=[FunctionCallContent(call_id="2", name="error_function", arguments='{"arg1": "value2"}')], + contents=[ + Content.from_function_call(call_id="2", name="error_function", arguments='{"arg1": "value2"}') + ], role="assistant", ), ], [ 
ChatResponseUpdate( - contents=[FunctionCallContent(call_id="3", name="error_function", arguments='{"arg1": "value3"}')], + contents=[ + Content.from_function_call(call_id="3", name="error_function", arguments='{"arg1": "value3"}') + ], role="assistant", ), ], - [ChatResponseUpdate(contents=[TextContent(text="final response")], role="assistant")], + [ChatResponseUpdate(contents=[Content.from_text(text="final response")], role="assistant")], ] # Set max_consecutive_errors to 2 @@ -1873,14 +1904,12 @@ def error_func(arg1: str) -> str: content for update in updates for content in update.contents - if isinstance(content, FunctionResultContent) and content.exception + if content.type == "function_result" and content.exception ] # At least one error occurred assert len(error_results) >= 1 # Should have stopped making new function calls after hitting the error limit - function_calls = [ - content for update in updates for content in update.contents if isinstance(content, FunctionCallContent) - ] + function_calls = [content for update in updates for content in update.contents if content.type == "function_call"] # Should have made at most 2 function calls before stopping assert len(function_calls) <= 2 @@ -1900,11 +1929,13 @@ def known_func(arg1: str) -> str: chat_client_base.streaming_responses = [ [ ChatResponseUpdate( - contents=[FunctionCallContent(call_id="1", name="unknown_function", arguments='{"arg1": "value1"}')], + contents=[ + Content.from_function_call(call_id="1", name="unknown_function", arguments='{"arg1": "value1"}') + ], role="assistant", ), ], - [ChatResponseUpdate(contents=[TextContent(text="done")], role="assistant")], + [ChatResponseUpdate(contents=[Content.from_text(text="done")], role="assistant")], ] # Set terminate_on_unknown_calls to False (default) @@ -1918,7 +1949,7 @@ def known_func(arg1: str) -> str: # Should have a result message indicating the tool wasn't found result_contents = [ - content for update in updates for content in update.contents 
if isinstance(content, FunctionResultContent) + content for update in updates for content in update.contents if content.type == "function_result" ] assert len(result_contents) >= 1 result_str = result_contents[0].result or result_contents[0].exception or "" @@ -1941,7 +1972,9 @@ def known_func(arg1: str) -> str: chat_client_base.streaming_responses = [ [ ChatResponseUpdate( - contents=[FunctionCallContent(call_id="1", name="unknown_function", arguments='{"arg1": "value1"}')], + contents=[ + Content.from_function_call(call_id="1", name="unknown_function", arguments='{"arg1": "value1"}') + ], role="assistant", ), ], @@ -1970,11 +2003,13 @@ def error_func(arg1: str) -> str: chat_client_base.streaming_responses = [ [ ChatResponseUpdate( - contents=[FunctionCallContent(call_id="1", name="error_function", arguments='{"arg1": "value1"}')], + contents=[ + Content.from_function_call(call_id="1", name="error_function", arguments='{"arg1": "value1"}') + ], role="assistant", ), ], - [ChatResponseUpdate(contents=[TextContent(text="done")], role="assistant")], + [ChatResponseUpdate(contents=[Content.from_text(text="done")], role="assistant")], ] # Set include_detailed_errors to True @@ -1988,7 +2023,7 @@ def error_func(arg1: str) -> str: # Should have detailed error message error_result = next( - content for update in updates for content in update.contents if isinstance(content, FunctionResultContent) + content for update in updates for content in update.contents if content.type == "function_result" ) assert error_result.result is not None assert error_result.exception is not None @@ -2008,11 +2043,13 @@ def error_func(arg1: str) -> str: chat_client_base.streaming_responses = [ [ ChatResponseUpdate( - contents=[FunctionCallContent(call_id="1", name="error_function", arguments='{"arg1": "value1"}')], + contents=[ + Content.from_function_call(call_id="1", name="error_function", arguments='{"arg1": "value1"}') + ], role="assistant", ), ], - 
[ChatResponseUpdate(contents=[TextContent(text="done")], role="assistant")], + [ChatResponseUpdate(contents=[Content.from_text(text="done")], role="assistant")], ] # Set include_detailed_errors to False (default) @@ -2026,7 +2063,7 @@ def error_func(arg1: str) -> str: # Should have a generic error message error_result = next( - content for update in updates for content in update.contents if isinstance(content, FunctionResultContent) + content for update in updates for content in update.contents if content.type == "function_result" ) assert error_result.result is not None assert error_result.exception is not None @@ -2044,11 +2081,13 @@ def typed_func(arg1: int) -> str: # Expects int, not str chat_client_base.streaming_responses = [ [ ChatResponseUpdate( - contents=[FunctionCallContent(call_id="1", name="typed_function", arguments='{"arg1": "not_an_int"}')], + contents=[ + Content.from_function_call(call_id="1", name="typed_function", arguments='{"arg1": "not_an_int"}') + ], role="assistant", ), ], - [ChatResponseUpdate(contents=[TextContent(text="done")], role="assistant")], + [ChatResponseUpdate(contents=[Content.from_text(text="done")], role="assistant")], ] # Set include_detailed_errors to True @@ -2062,7 +2101,7 @@ def typed_func(arg1: int) -> str: # Expects int, not str # Should have detailed validation error error_result = next( - content for update in updates for content in update.contents if isinstance(content, FunctionResultContent) + content for update in updates for content in update.contents if content.type == "function_result" ) assert error_result.result is not None assert error_result.exception is not None @@ -2080,11 +2119,13 @@ def typed_func(arg1: int) -> str: # Expects int, not str chat_client_base.streaming_responses = [ [ ChatResponseUpdate( - contents=[FunctionCallContent(call_id="1", name="typed_function", arguments='{"arg1": "not_an_int"}')], + contents=[ + Content.from_function_call(call_id="1", name="typed_function", arguments='{"arg1": 
"not_an_int"}') + ], role="assistant", ), ], - [ChatResponseUpdate(contents=[TextContent(text="done")], role="assistant")], + [ChatResponseUpdate(contents=[Content.from_text(text="done")], role="assistant")], ] # Set include_detailed_errors to False (default) @@ -2098,7 +2139,7 @@ def typed_func(arg1: int) -> str: # Expects int, not str # Should have generic validation error error_result = next( - content for update in updates for content in update.contents if isinstance(content, FunctionResultContent) + content for update in updates for content in update.contents if content.type == "function_result" ) assert error_result.result is not None assert error_result.exception is not None @@ -2129,15 +2170,15 @@ async def func2(arg1: str) -> str: chat_client_base.streaming_responses = [ [ ChatResponseUpdate( - contents=[FunctionCallContent(call_id="1", name="func1", arguments='{"arg1": "value1"}')], + contents=[Content.from_function_call(call_id="1", name="func1", arguments='{"arg1": "value1"}')], role="assistant", ), ChatResponseUpdate( - contents=[FunctionCallContent(call_id="2", name="func2", arguments='{"arg1": "value2"}')], + contents=[Content.from_function_call(call_id="2", name="func2", arguments='{"arg1": "value2"}')], role="assistant", ), ], - [ChatResponseUpdate(contents=[TextContent(text="done")], role="assistant")], + [ChatResponseUpdate(contents=[Content.from_text(text="done")], role="assistant")], ] updates = [] @@ -2153,9 +2194,7 @@ async def func2(arg1: str) -> str: assert "func2_end" in exec_order # Should have results for both - results = [ - content for update in updates for content in update.contents if isinstance(content, FunctionResultContent) - ] + results = [content for update in updates for content in update.contents if content.type == "function_result"] assert len(results) == 2 @@ -2173,7 +2212,7 @@ def func_with_approval(arg1: str) -> str: [ ChatResponseUpdate( contents=[ - FunctionCallContent(call_id="1", name="test_func", arguments='{"arg1": 
"value1"}'), + Content.from_function_call(call_id="1", name="test_func", arguments='{"arg1": "value1"}'), ], role="assistant", ), @@ -2188,10 +2227,7 @@ def func_with_approval(arg1: str) -> str: # Should have updates containing both the call and approval request approval_requests = [ - content - for update in updates - for content in update.contents - if isinstance(content, FunctionApprovalRequestContent) + content for update in updates for content in update.contents if content.type == "function_approval_request" ] assert len(approval_requests) == 1 assert exec_counter == 0 @@ -2213,17 +2249,21 @@ def sometimes_fails(arg1: str) -> str: chat_client_base.streaming_responses = [ [ ChatResponseUpdate( - contents=[FunctionCallContent(call_id="1", name="sometimes_fails", arguments='{"arg1": "value1"}')], + contents=[ + Content.from_function_call(call_id="1", name="sometimes_fails", arguments='{"arg1": "value1"}') + ], role="assistant", ), ], [ ChatResponseUpdate( - contents=[FunctionCallContent(call_id="2", name="sometimes_fails", arguments='{"arg1": "value2"}')], + contents=[ + Content.from_function_call(call_id="2", name="sometimes_fails", arguments='{"arg1": "value2"}') + ], role="assistant", ), ], - [ChatResponseUpdate(contents=[TextContent(text="done")], role="assistant")], + [ChatResponseUpdate(contents=[Content.from_text(text="done")], role="assistant")], ] updates = [] @@ -2237,13 +2277,13 @@ def sometimes_fails(arg1: str) -> str: content for update in updates for content in update.contents - if isinstance(content, FunctionResultContent) and content.exception + if content.type == "function_result" and content.exception ] success_results = [ content for update in updates for content in update.contents - if isinstance(content, FunctionResultContent) and content.result + if content.type == "function_result" and content.result ] assert len(error_results) >= 1 @@ -2278,7 +2318,9 @@ def ai_func(arg1: str) -> str: ChatResponse( messages=ChatMessage( role="assistant", - 
contents=[FunctionCallContent(call_id="1", name="test_function", arguments='{"arg1": "value1"}')], + contents=[ + Content.from_function_call(call_id="1", name="test_function", arguments='{"arg1": "value1"}') + ], ) ), ChatResponse(messages=ChatMessage(role="assistant", text="done")), @@ -2297,9 +2339,9 @@ def ai_func(arg1: str) -> str: # The loop should NOT have continued to call the LLM again assert len(response.messages) == 2 assert response.messages[0].role == Role.ASSISTANT - assert isinstance(response.messages[0].contents[0], FunctionCallContent) + assert response.messages[0].contents[0].type == "function_call" assert response.messages[1].role == Role.TOOL - assert isinstance(response.messages[1].contents[0], FunctionResultContent) + assert response.messages[1].contents[0].type == "function_result" assert response.messages[1].contents[0].result == "terminated by middleware" # Verify the second response is still in the queue (wasn't consumed) @@ -2343,8 +2385,10 @@ def terminating_func(arg1: str) -> str: messages=ChatMessage( role="assistant", contents=[ - FunctionCallContent(call_id="1", name="normal_function", arguments='{"arg1": "value1"}'), - FunctionCallContent(call_id="2", name="terminating_function", arguments='{"arg1": "value2"}'), + Content.from_function_call(call_id="1", name="normal_function", arguments='{"arg1": "value1"}'), + Content.from_function_call( + call_id="2", name="terminating_function", arguments='{"arg1": "value2"}' + ), ], ) ), @@ -2389,13 +2433,15 @@ def ai_func(arg1: str) -> str: chat_client_base.streaming_responses = [ [ ChatResponseUpdate( - contents=[FunctionCallContent(call_id="1", name="test_function", arguments='{"arg1": "value1"}')], + contents=[ + Content.from_function_call(call_id="1", name="test_function", arguments='{"arg1": "value1"}') + ], role="assistant", ), ], [ ChatResponseUpdate( - contents=[TextContent(text="done")], + contents=[Content.from_text(text="done")], role="assistant", ) ], diff --git 
a/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py b/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py index fc6acb435d..1a206d9646 100644 --- a/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py +++ b/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py @@ -8,8 +8,7 @@ ChatMessage, ChatResponse, ChatResponseUpdate, - FunctionCallContent, - TextContent, + Content, ai_function, ) from agent_framework._tools import _handle_function_calls_response, _handle_function_calls_streaming_response @@ -42,7 +41,9 @@ async def mock_get_response(self, messages, **kwargs): ChatMessage( role="assistant", contents=[ - FunctionCallContent(call_id="call_1", name="capture_kwargs_tool", arguments='{"x": 42}') + Content.from_function_call( + call_id="call_1", name="capture_kwargs_tool", arguments='{"x": 42}' + ) ], ) ] @@ -94,7 +95,9 @@ async def mock_get_response(self, messages, **kwargs): messages=[ ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="call_1", name="simple_tool", arguments='{"x": 99}')], + contents=[ + Content.from_function_call(call_id="call_1", name="simple_tool", arguments='{"x": 99}') + ], ) ] ) @@ -136,10 +139,10 @@ async def mock_get_response(self, messages, **kwargs): ChatMessage( role="assistant", contents=[ - FunctionCallContent( + Content.from_function_call( call_id="call_1", name="tracking_tool", arguments='{"name": "first"}' ), - FunctionCallContent( + Content.from_function_call( call_id="call_2", name="tracking_tool", arguments='{"name": "second"}' ), ], @@ -187,7 +190,7 @@ async def mock_get_streaming_response(self, messages, **kwargs): yield ChatResponseUpdate( role="assistant", contents=[ - FunctionCallContent( + Content.from_function_call( call_id="stream_call_1", name="streaming_capture_tool", arguments='{"value": "streaming-test"}', @@ -197,7 +200,9 @@ async def mock_get_streaming_response(self, messages, **kwargs): ) else: # Second call: 
return final response - yield ChatResponseUpdate(text=TextContent(text="Stream complete!"), role="assistant", is_finished=True) + yield ChatResponseUpdate( + text=Content.from_text(text="Stream complete!"), role="assistant", is_finished=True + ) wrapped = _handle_function_calls_streaming_response(mock_get_streaming_response) diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index c4e8cb09df..5780113976 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -13,14 +13,12 @@ from agent_framework import ( ChatMessage, - DataContent, + Content, MCPStdioTool, MCPStreamableHTTPTool, MCPWebsocketTool, Role, - TextContent, ToolProtocol, - UriContent, ) from agent_framework._mcp import ( MCPTool, @@ -65,7 +63,7 @@ def test_mcp_prompt_message_to_ai_content(): assert isinstance(ai_content, ChatMessage) assert ai_content.role.value == "user" assert len(ai_content.contents) == 1 - assert isinstance(ai_content.contents[0], TextContent) + assert ai_content.contents[0].type == "text" assert ai_content.contents[0].text == "Hello, world!" 
assert ai_content.raw_representation == mcp_message @@ -75,20 +73,20 @@ def test_parse_contents_from_mcp_tool_result(): mcp_result = types.CallToolResult( content=[ types.TextContent(type="text", text="Result text"), - types.ImageContent(type="image", data="xyz", mimeType="image/png"), - types.ImageContent(type="image", data=b"abc", mimeType="image/webp"), + types.ImageContent(type="image", data="eHl6", mimeType="image/png"), # base64 for "xyz" + types.ImageContent(type="image", data="YWJj", mimeType="image/webp"), # base64 for "abc" ] ) ai_contents = _parse_contents_from_mcp_tool_result(mcp_result) assert len(ai_contents) == 3 - assert isinstance(ai_contents[0], TextContent) + assert ai_contents[0].type == "text" assert ai_contents[0].text == "Result text" - assert isinstance(ai_contents[1], DataContent) - assert ai_contents[1].uri == "data:image/png;base64,xyz" + assert ai_contents[1].type == "data" + assert ai_contents[1].uri == "data:image/png;base64,eHl6" assert ai_contents[1].media_type == "image/png" - assert isinstance(ai_contents[2], DataContent) - assert ai_contents[2].uri == "data:image/webp;base64,abc" + assert ai_contents[2].type == "data" + assert ai_contents[2].uri == "data:image/webp;base64,YWJj" assert ai_contents[2].media_type == "image/webp" @@ -103,7 +101,7 @@ def test_mcp_call_tool_result_with_meta_error(): ai_contents = _parse_contents_from_mcp_tool_result(mcp_result) assert len(ai_contents) == 1 - assert isinstance(ai_contents[0], TextContent) + assert ai_contents[0].type == "text" assert ai_contents[0].text == "Error occurred" # Check that _meta data is merged into additional_properties @@ -134,7 +132,7 @@ def test_mcp_call_tool_result_with_meta_arbitrary_data(): ai_contents = _parse_contents_from_mcp_tool_result(mcp_result) assert len(ai_contents) == 1 - assert isinstance(ai_contents[0], TextContent) + assert ai_contents[0].type == "text" assert ai_contents[0].text == "Success result" # Check that _meta data is preserved in 
additional_properties @@ -172,7 +170,7 @@ def test_mcp_call_tool_result_with_meta_none(): ai_contents = _parse_contents_from_mcp_tool_result(mcp_result) assert len(ai_contents) == 1 - assert isinstance(ai_contents[0], TextContent) + assert ai_contents[0].type == "text" assert ai_contents[0].text == "No meta test" # Should handle gracefully when no _meta field exists @@ -187,7 +185,7 @@ def test_mcp_call_tool_result_regression_successful_workflow(): mcp_result = types.CallToolResult( content=[ types.TextContent(type="text", text="Success message"), - types.ImageContent(type="image", data="abc123", mimeType="image/jpeg"), + types.ImageContent(type="image", data="YWJjMTIz", mimeType="image/jpeg"), # base64 for "abc123" ] ) @@ -197,12 +195,12 @@ def test_mcp_call_tool_result_regression_successful_workflow(): assert len(ai_contents) == 2 text_content = ai_contents[0] - assert isinstance(text_content, TextContent) + assert text_content.type == "text" assert text_content.text == "Success message" image_content = ai_contents[1] - assert isinstance(image_content, DataContent) - assert image_content.uri == "data:image/jpeg;base64,abc123" + assert image_content.type == "data" + assert image_content.uri == "data:image/jpeg;base64,YWJjMTIz" assert image_content.media_type == "image/jpeg" # Should have no additional_properties when no _meta field @@ -215,30 +213,31 @@ def test_mcp_content_types_to_ai_content_text(): mcp_content = types.TextContent(type="text", text="Sample text") ai_content = _parse_content_from_mcp(mcp_content)[0] - assert isinstance(ai_content, TextContent) + assert ai_content.type == "text" assert ai_content.text == "Sample text" assert ai_content.raw_representation == mcp_content def test_mcp_content_types_to_ai_content_image(): """Test conversion of MCP image content to AI content.""" - mcp_content = types.ImageContent(type="image", data="abc", mimeType="image/jpeg") - mcp_content = types.ImageContent(type="image", data=b"abc", mimeType="image/jpeg") + # 
MCP can send data as base64 string or as bytes + mcp_content = types.ImageContent(type="image", data="YWJj", mimeType="image/jpeg") # base64 for b"abc" ai_content = _parse_content_from_mcp(mcp_content)[0] - assert isinstance(ai_content, DataContent) - assert ai_content.uri == "data:image/jpeg;base64,abc" + assert ai_content.type == "data" + assert ai_content.uri == "data:image/jpeg;base64,YWJj" assert ai_content.media_type == "image/jpeg" assert ai_content.raw_representation == mcp_content def test_mcp_content_types_to_ai_content_audio(): """Test conversion of MCP audio content to AI content.""" - mcp_content = types.AudioContent(type="audio", data="def", mimeType="audio/wav") + # Use properly padded base64 + mcp_content = types.AudioContent(type="audio", data="ZGVm", mimeType="audio/wav") # base64 for b"def" ai_content = _parse_content_from_mcp(mcp_content)[0] - assert isinstance(ai_content, DataContent) - assert ai_content.uri == "data:audio/wav;base64,def" + assert ai_content.type == "data" + assert ai_content.uri == "data:audio/wav;base64,ZGVm" assert ai_content.media_type == "audio/wav" assert ai_content.raw_representation == mcp_content @@ -253,7 +252,7 @@ def test_mcp_content_types_to_ai_content_resource_link(): ) ai_content = _parse_content_from_mcp(mcp_content)[0] - assert isinstance(ai_content, UriContent) + assert ai_content.type == "uri" assert ai_content.uri == "https://example.com/resource" assert ai_content.media_type == "application/json" assert ai_content.raw_representation == mcp_content @@ -269,7 +268,7 @@ def test_mcp_content_types_to_ai_content_embedded_resource_text(): mcp_content = types.EmbeddedResource(type="resource", resource=text_resource) ai_content = _parse_content_from_mcp(mcp_content)[0] - assert isinstance(ai_content, TextContent) + assert ai_content.type == "text" assert ai_content.text == "Embedded text content" assert ai_content.raw_representation == mcp_content @@ -285,7 +284,7 @@ def 
test_mcp_content_types_to_ai_content_embedded_resource_blob(): mcp_content = types.EmbeddedResource(type="resource", resource=blob_resource) ai_content = _parse_content_from_mcp(mcp_content)[0] - assert isinstance(ai_content, DataContent) + assert ai_content.type == "data" assert ai_content.uri == "data:application/octet-stream;base64,dGVzdCBkYXRh" assert ai_content.media_type == "application/octet-stream" assert ai_content.raw_representation == mcp_content @@ -293,7 +292,7 @@ def test_mcp_content_types_to_ai_content_embedded_resource_blob(): def test_ai_content_to_mcp_content_types_text(): """Test conversion of AI text content to MCP content.""" - ai_content = TextContent(text="Sample text") + ai_content = Content.from_text(text="Sample text") mcp_content = _prepare_content_for_mcp(ai_content) assert isinstance(mcp_content, types.TextContent) @@ -303,7 +302,7 @@ def test_ai_content_to_mcp_content_types_text(): def test_ai_content_to_mcp_content_types_data_image(): """Test conversion of AI data content to MCP content.""" - ai_content = DataContent(uri="data:image/png;base64,xyz", media_type="image/png") + ai_content = Content.from_uri(uri="data:image/png;base64,xyz", media_type="image/png") mcp_content = _prepare_content_for_mcp(ai_content) assert isinstance(mcp_content, types.ImageContent) @@ -314,7 +313,7 @@ def test_ai_content_to_mcp_content_types_data_image(): def test_ai_content_to_mcp_content_types_data_audio(): """Test conversion of AI data content to MCP content.""" - ai_content = DataContent(uri="data:audio/mpeg;base64,xyz", media_type="audio/mpeg") + ai_content = Content.from_uri(uri="data:audio/mpeg;base64,xyz", media_type="audio/mpeg") mcp_content = _prepare_content_for_mcp(ai_content) assert isinstance(mcp_content, types.AudioContent) @@ -325,7 +324,7 @@ def test_ai_content_to_mcp_content_types_data_audio(): def test_ai_content_to_mcp_content_types_data_binary(): """Test conversion of AI data content to MCP content.""" - ai_content = DataContent( + 
ai_content = Content.from_uri( uri="data:application/octet-stream;base64,xyz", media_type="application/octet-stream", ) @@ -339,7 +338,7 @@ def test_ai_content_to_mcp_content_types_data_binary(): def test_ai_content_to_mcp_content_types_uri(): """Test conversion of AI URI content to MCP content.""" - ai_content = UriContent(uri="https://example.com/resource", media_type="application/json") + ai_content = Content.from_uri(uri="https://example.com/resource", media_type="application/json") mcp_content = _prepare_content_for_mcp(ai_content) assert isinstance(mcp_content, types.ResourceLink) @@ -352,8 +351,8 @@ def test_prepare_message_for_mcp(): message = ChatMessage( role="user", contents=[ - TextContent(text="test"), - DataContent(uri="data:image/png;base64,xyz", media_type="image/png"), + Content.from_text(text="test"), + Content.from_uri(uri="data:image/png;base64,xyz", media_type="image/png"), ], ) mcp_contents = _prepare_message_for_mcp(message) @@ -871,7 +870,7 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: result = await func.invoke(param="test_value") assert len(result) == 1 - assert isinstance(result[0], TextContent) + assert result[0].type == "text" assert result[0].text == "Tool executed with metadata" # Verify that _meta data is present in additional_properties @@ -920,7 +919,7 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: result = await func.invoke(param="test_value") assert len(result) == 1 - assert isinstance(result[0], TextContent) + assert result[0].type == "text" assert result[0].text == "Tool executed successfully" @@ -969,7 +968,7 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: result = await func.invoke(params={"customer_id": 251}) assert len(result) == 1 - assert isinstance(result[0], TextContent) + assert result[0].type == "text" # Verify the session.call_tool was called with the correct nested structure server.session.call_tool.assert_called_once() @@ -1413,7 +1412,7 
@@ async def test_mcp_tool_sampling_callback_chat_client_exception(): async def test_mcp_tool_sampling_callback_no_valid_content(): """Test sampling callback when response has no valid content types.""" - from agent_framework import ChatMessage, DataContent, Role + from agent_framework import ChatMessage, Role tool = MCPStdioTool(name="test_tool", command="python") @@ -1424,7 +1423,7 @@ async def test_mcp_tool_sampling_callback_no_valid_content(): ChatMessage( role=Role.ASSISTANT, contents=[ - DataContent( + Content.from_uri( uri="data:application/json;base64,e30K", media_type="application/json", ) diff --git a/python/packages/core/tests/core/test_middleware.py b/python/packages/core/tests/core/test_middleware.py index ebb833f2b4..441896f92b 100644 --- a/python/packages/core/tests/core/test_middleware.py +++ b/python/packages/core/tests/core/test_middleware.py @@ -14,8 +14,8 @@ ChatMessage, ChatResponse, ChatResponseUpdate, + Content, Role, - TextContent, ) from agent_framework._middleware import ( AgentMiddleware, @@ -217,8 +217,8 @@ async def test_execute_stream_no_middleware(self, mock_agent: AgentProtocol) -> context = AgentRunContext(agent=mock_agent, messages=messages) async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentResponseUpdate]: - yield AgentResponseUpdate(contents=[TextContent(text="chunk1")]) - yield AgentResponseUpdate(contents=[TextContent(text="chunk2")]) + yield AgentResponseUpdate(contents=[Content.from_text(text="chunk1")]) + yield AgentResponseUpdate(contents=[Content.from_text(text="chunk2")]) updates: list[AgentResponseUpdate] = [] async for update in pipeline.execute_stream(mock_agent, messages, context, final_handler): @@ -250,8 +250,8 @@ async def process( async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentResponseUpdate]: execution_order.append("handler_start") - yield AgentResponseUpdate(contents=[TextContent(text="chunk1")]) - yield AgentResponseUpdate(contents=[TextContent(text="chunk2")]) + yield 
AgentResponseUpdate(contents=[Content.from_text(text="chunk1")]) + yield AgentResponseUpdate(contents=[Content.from_text(text="chunk2")]) execution_order.append("handler_end") updates: list[AgentResponseUpdate] = [] @@ -313,8 +313,8 @@ async def test_execute_stream_with_pre_next_termination(self, mock_agent: AgentP async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentResponseUpdate]: # Handler should not be executed when terminated before next() execution_order.append("handler_start") - yield AgentResponseUpdate(contents=[TextContent(text="chunk1")]) - yield AgentResponseUpdate(contents=[TextContent(text="chunk2")]) + yield AgentResponseUpdate(contents=[Content.from_text(text="chunk1")]) + yield AgentResponseUpdate(contents=[Content.from_text(text="chunk2")]) execution_order.append("handler_end") updates: list[AgentResponseUpdate] = [] @@ -336,8 +336,8 @@ async def test_execute_stream_with_post_next_termination(self, mock_agent: Agent async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentResponseUpdate]: execution_order.append("handler_start") - yield AgentResponseUpdate(contents=[TextContent(text="chunk1")]) - yield AgentResponseUpdate(contents=[TextContent(text="chunk2")]) + yield AgentResponseUpdate(contents=[Content.from_text(text="chunk1")]) + yield AgentResponseUpdate(contents=[Content.from_text(text="chunk2")]) execution_order.append("handler_end") updates: list[AgentResponseUpdate] = [] @@ -609,8 +609,8 @@ async def test_execute_stream_no_middleware(self, mock_chat_client: Any) -> None context = ChatContext(chat_client=mock_chat_client, messages=messages, options=chat_options) async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]: - yield ChatResponseUpdate(contents=[TextContent(text="chunk1")]) - yield ChatResponseUpdate(contents=[TextContent(text="chunk2")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="chunk1")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="chunk2")]) 
updates: list[ChatResponseUpdate] = [] async for update in pipeline.execute_stream(mock_chat_client, messages, chat_options, context, final_handler): @@ -641,8 +641,8 @@ async def process(self, context: ChatContext, next: Callable[[ChatContext], Awai async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]: execution_order.append("handler_start") - yield ChatResponseUpdate(contents=[TextContent(text="chunk1")]) - yield ChatResponseUpdate(contents=[TextContent(text="chunk2")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="chunk1")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="chunk2")]) execution_order.append("handler_end") updates: list[ChatResponseUpdate] = [] @@ -706,8 +706,8 @@ async def test_execute_stream_with_pre_next_termination(self, mock_chat_client: async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]: # Handler should not be executed when terminated before next() execution_order.append("handler_start") - yield ChatResponseUpdate(contents=[TextContent(text="chunk1")]) - yield ChatResponseUpdate(contents=[TextContent(text="chunk2")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="chunk1")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="chunk2")]) execution_order.append("handler_end") updates: list[ChatResponseUpdate] = [] @@ -730,8 +730,8 @@ async def test_execute_stream_with_post_next_termination(self, mock_chat_client: async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]: execution_order.append("handler_start") - yield ChatResponseUpdate(contents=[TextContent(text="chunk1")]) - yield ChatResponseUpdate(contents=[TextContent(text="chunk2")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="chunk1")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="chunk2")]) execution_order.append("handler_end") updates: list[ChatResponseUpdate] = [] @@ -1264,7 +1264,7 @@ async def final_handler(ctx: 
AgentRunContext) -> AgentResponse: async def final_stream_handler(ctx: AgentRunContext) -> AsyncIterable[AgentResponseUpdate]: streaming_flags.append(ctx.is_streaming) - yield AgentResponseUpdate(contents=[TextContent(text="chunk")]) + yield AgentResponseUpdate(contents=[Content.from_text(text="chunk")]) updates: list[AgentResponseUpdate] = [] async for update in pipeline.execute_stream(mock_agent, messages, context_stream, final_stream_handler): @@ -1292,9 +1292,9 @@ async def process( async def final_stream_handler(ctx: AgentRunContext) -> AsyncIterable[AgentResponseUpdate]: chunks_processed.append("stream_start") - yield AgentResponseUpdate(contents=[TextContent(text="chunk1")]) + yield AgentResponseUpdate(contents=[Content.from_text(text="chunk1")]) chunks_processed.append("chunk1_yielded") - yield AgentResponseUpdate(contents=[TextContent(text="chunk2")]) + yield AgentResponseUpdate(contents=[Content.from_text(text="chunk2")]) chunks_processed.append("chunk2_yielded") chunks_processed.append("stream_end") @@ -1342,7 +1342,7 @@ async def final_handler(ctx: ChatContext) -> ChatResponse: async def final_stream_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]: streaming_flags.append(ctx.is_streaming) - yield ChatResponseUpdate(contents=[TextContent(text="chunk")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="chunk")]) updates: list[ChatResponseUpdate] = [] async for update in pipeline.execute_stream( @@ -1371,9 +1371,9 @@ async def process(self, context: ChatContext, next: Callable[[ChatContext], Awai async def final_stream_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]: chunks_processed.append("stream_start") - yield ChatResponseUpdate(contents=[TextContent(text="chunk1")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="chunk1")]) chunks_processed.append("chunk1_yielded") - yield ChatResponseUpdate(contents=[TextContent(text="chunk2")]) + yield 
ChatResponseUpdate(contents=[Content.from_text(text="chunk2")]) chunks_processed.append("chunk2_yielded") chunks_processed.append("stream_end") @@ -1486,7 +1486,7 @@ async def process( async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentResponseUpdate]: nonlocal handler_called handler_called = True - yield AgentResponseUpdate(contents=[TextContent(text="should not execute")]) + yield AgentResponseUpdate(contents=[Content.from_text(text="should not execute")]) # When middleware doesn't call next(), streaming should yield no updates updates: list[AgentResponseUpdate] = [] @@ -1617,7 +1617,7 @@ async def process(self, context: ChatContext, next: Callable[[ChatContext], Awai async def final_handler(ctx: ChatContext) -> AsyncIterable[ChatResponseUpdate]: nonlocal handler_called handler_called = True - yield ChatResponseUpdate(contents=[TextContent(text="should not execute")]) + yield ChatResponseUpdate(contents=[Content.from_text(text="should not execute")]) # When middleware doesn't call next(), streaming should yield no updates updates: list[ChatResponseUpdate] = [] diff --git a/python/packages/core/tests/core/test_middleware_context_result.py b/python/packages/core/tests/core/test_middleware_context_result.py index bfcfb48e5f..f939a0f409 100644 --- a/python/packages/core/tests/core/test_middleware_context_result.py +++ b/python/packages/core/tests/core/test_middleware_context_result.py @@ -13,8 +13,8 @@ AgentResponseUpdate, ChatAgent, ChatMessage, + Content, Role, - TextContent, ) from agent_framework._middleware import ( AgentMiddleware, @@ -75,8 +75,8 @@ async def test_agent_middleware_response_override_streaming(self, mock_agent: Ag """Test that agent middleware can override response for streaming execution.""" async def override_stream() -> AsyncIterable[AgentResponseUpdate]: - yield AgentResponseUpdate(contents=[TextContent(text="overridden")]) - yield AgentResponseUpdate(contents=[TextContent(text=" stream")]) + yield 
AgentResponseUpdate(contents=[Content.from_text(text="overridden")]) + yield AgentResponseUpdate(contents=[Content.from_text(text=" stream")]) class StreamResponseOverrideMiddleware(AgentMiddleware): async def process( @@ -92,7 +92,7 @@ async def process( context = AgentRunContext(agent=mock_agent, messages=messages) async def final_handler(ctx: AgentRunContext) -> AsyncIterable[AgentResponseUpdate]: - yield AgentResponseUpdate(contents=[TextContent(text="original")]) + yield AgentResponseUpdate(contents=[Content.from_text(text="original")]) updates: list[AgentResponseUpdate] = [] async for update in pipeline.execute_stream(mock_agent, messages, context, final_handler): @@ -175,9 +175,9 @@ async def test_chat_agent_middleware_streaming_override(self) -> None: mock_chat_client = MockChatClient() async def custom_stream() -> AsyncIterable[AgentResponseUpdate]: - yield AgentResponseUpdate(contents=[TextContent(text="Custom")]) - yield AgentResponseUpdate(contents=[TextContent(text=" streaming")]) - yield AgentResponseUpdate(contents=[TextContent(text=" response!")]) + yield AgentResponseUpdate(contents=[Content.from_text(text="Custom")]) + yield AgentResponseUpdate(contents=[Content.from_text(text=" streaming")]) + yield AgentResponseUpdate(contents=[Content.from_text(text=" response!")]) class ChatAgentStreamOverrideMiddleware(AgentMiddleware): async def process( diff --git a/python/packages/core/tests/core/test_middleware_with_agent.py b/python/packages/core/tests/core/test_middleware_with_agent.py index 4994cf4981..445f13596a 100644 --- a/python/packages/core/tests/core/test_middleware_with_agent.py +++ b/python/packages/core/tests/core/test_middleware_with_agent.py @@ -13,10 +13,8 @@ ChatMiddleware, ChatResponse, ChatResponseUpdate, - FunctionCallContent, - FunctionResultContent, + Content, Role, - TextContent, agent_middleware, chat_middleware, function_middleware, @@ -201,7 +199,9 @@ async def process( ChatMessage( role=Role.ASSISTANT, contents=[ - 
FunctionCallContent(call_id="test_call", name="test_function", arguments={"text": "test"}) + Content.from_function_call( + call_id="test_call", name="test_function", arguments={"text": "test"} + ) ], ) ] @@ -256,7 +256,9 @@ async def process( ChatMessage( role=Role.ASSISTANT, contents=[ - FunctionCallContent(call_id="test_call", name="test_function", arguments={"text": "test"}) + Content.from_function_call( + call_id="test_call", name="test_function", arguments={"text": "test"} + ) ], ) ] @@ -365,8 +367,8 @@ async def process( # Set up mock streaming responses chat_client.streaming_responses = [ [ - ChatResponseUpdate(contents=[TextContent(text="Streaming")], role=Role.ASSISTANT), - ChatResponseUpdate(contents=[TextContent(text=" response")], role=Role.ASSISTANT), + ChatResponseUpdate(contents=[Content.from_text(text="Streaming")], role=Role.ASSISTANT), + ChatResponseUpdate(contents=[Content.from_text(text=" response")], role=Role.ASSISTANT), ] ] @@ -550,7 +552,7 @@ async def process( ChatMessage( role=Role.ASSISTANT, contents=[ - FunctionCallContent( + Content.from_function_call( call_id="call_123", name="sample_tool_function", arguments='{"location": "Seattle"}', @@ -585,8 +587,8 @@ async def process( # Verify function call and result are in the response all_contents = [content for message in response.messages for content in message.contents] - function_calls = [c for c in all_contents if isinstance(c, FunctionCallContent)] - function_results = [c for c in all_contents if isinstance(c, FunctionResultContent)] + function_calls = [c for c in all_contents if c.type == "function_call"] + function_results = [c for c in all_contents if c.type == "function_result"] assert len(function_calls) == 1 assert len(function_results) == 1 @@ -610,7 +612,7 @@ async def tracking_function_middleware( ChatMessage( role=Role.ASSISTANT, contents=[ - FunctionCallContent( + Content.from_function_call( call_id="call_456", name="sample_tool_function", arguments='{"location": "San 
Francisco"}', @@ -644,8 +646,8 @@ async def tracking_function_middleware( # Verify function call and result are in the response all_contents = [content for message in response.messages for content in message.contents] - function_calls = [c for c in all_contents if isinstance(c, FunctionCallContent)] - function_results = [c for c in all_contents if isinstance(c, FunctionResultContent)] + function_calls = [c for c in all_contents if c.type == "function_call"] + function_results = [c for c in all_contents if c.type == "function_result"] assert len(function_calls) == 1 assert len(function_results) == 1 @@ -682,7 +684,7 @@ async def process( ChatMessage( role=Role.ASSISTANT, contents=[ - FunctionCallContent( + Content.from_function_call( call_id="call_789", name="sample_tool_function", arguments='{"location": "New York"}', @@ -723,8 +725,8 @@ async def process( # Verify function call and result are in the response all_contents = [content for message in response.messages for content in message.contents] - function_calls = [c for c in all_contents if isinstance(c, FunctionCallContent)] - function_results = [c for c in all_contents if isinstance(c, FunctionResultContent)] + function_calls = [c for c in all_contents if c.type == "function_call"] + function_results = [c for c in all_contents if c.type == "function_result"] assert len(function_calls) == 1 assert len(function_results) == 1 @@ -769,14 +771,16 @@ async def kwargs_middleware( ChatMessage( role=Role.ASSISTANT, contents=[ - FunctionCallContent( + Content.from_function_call( call_id="test_call", name="sample_tool_function", arguments={"location": "Seattle"} ) ], ) ] ), - ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent("Function completed")])]), + ChatResponse( + messages=[ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text("Function completed")])] + ), ] # Create ChatAgent with function middleware @@ -1076,8 +1080,8 @@ async def process( # Set up mock streaming responses 
chat_client.streaming_responses = [ [ - ChatResponseUpdate(contents=[TextContent(text="Stream")], role=Role.ASSISTANT), - ChatResponseUpdate(contents=[TextContent(text=" response")], role=Role.ASSISTANT), + ChatResponseUpdate(contents=[Content.from_text(text="Stream")], role=Role.ASSISTANT), + ChatResponseUpdate(contents=[Content.from_text(text=" response")], role=Role.ASSISTANT), ] ] @@ -1159,7 +1163,7 @@ def custom_tool(message: str) -> str: ChatMessage( role=Role.ASSISTANT, contents=[ - FunctionCallContent( + Content.from_function_call( call_id="test_call", name="custom_tool", arguments='{"message": "test"}', @@ -1204,8 +1208,8 @@ def custom_tool(message: str) -> str: # Verify function call and result are in the response all_contents = [content for message in response.messages for content in message.contents] - function_calls = [c for c in all_contents if isinstance(c, FunctionCallContent)] - function_results = [c for c in all_contents if isinstance(c, FunctionResultContent)] + function_calls = [c for c in all_contents if c.type == "function_call"] + function_results = [c for c in all_contents if c.type == "function_result"] assert len(function_calls) == 1 assert len(function_results) == 1 @@ -1248,7 +1252,7 @@ def custom_tool(message: str) -> str: ChatMessage( role=Role.ASSISTANT, contents=[ - FunctionCallContent( + Content.from_function_call( call_id="test_call", name="custom_tool", arguments='{"message": "test"}', @@ -1315,7 +1319,7 @@ def custom_tool(message: str) -> str: ChatMessage( role=Role.ASSISTANT, contents=[ - FunctionCallContent( + Content.from_function_call( call_id="test_call", name="custom_tool", arguments='{"message": "test"}', @@ -1365,7 +1369,7 @@ def custom_tool(message: str) -> str: ChatMessage( role=Role.ASSISTANT, contents=[ - FunctionCallContent( + Content.from_function_call( call_id="test_call", name="custom_tool", arguments='{"message": "test"}', @@ -1704,8 +1708,8 @@ async def process(self, context: ChatContext, next: 
Callable[[ChatContext], Awai # Set up mock streaming responses chat_client.streaming_responses = [ [ - ChatResponseUpdate(contents=[TextContent(text="Stream")], role=Role.ASSISTANT), - ChatResponseUpdate(contents=[TextContent(text=" response")], role=Role.ASSISTANT), + ChatResponseUpdate(contents=[Content.from_text(text="Stream")], role=Role.ASSISTANT), + ChatResponseUpdate(contents=[Content.from_text(text=" response")], role=Role.ASSISTANT), ] ] @@ -1806,7 +1810,7 @@ async def function_middleware( ChatMessage( role=Role.ASSISTANT, contents=[ - FunctionCallContent( + Content.from_function_call( call_id="call_456", name="sample_tool_function", arguments='{"location": "San Francisco"}', @@ -1850,8 +1854,8 @@ async def function_middleware( # Verify function call and result are in the response all_contents = [content for message in response.messages for content in message.contents] - function_calls = [c for c in all_contents if isinstance(c, FunctionCallContent)] - function_results = [c for c in all_contents if isinstance(c, FunctionResultContent)] + function_calls = [c for c in all_contents if c.type == "function_call"] + function_results = [c for c in all_contents if c.type == "function_result"] assert len(function_calls) == 1 assert len(function_results) == 1 diff --git a/python/packages/core/tests/core/test_middleware_with_chat.py b/python/packages/core/tests/core/test_middleware_with_chat.py index 9d395284ea..a24d0e8037 100644 --- a/python/packages/core/tests/core/test_middleware_with_chat.py +++ b/python/packages/core/tests/core/test_middleware_with_chat.py @@ -9,7 +9,7 @@ ChatMessage, ChatMiddleware, ChatResponse, - FunctionCallContent, + Content, FunctionInvocationContext, Role, chat_middleware, @@ -349,7 +349,7 @@ def sample_tool(location: str) -> str: ChatMessage( role=Role.ASSISTANT, contents=[ - FunctionCallContent( + Content.from_function_call( call_id="call_1", name="sample_tool", arguments={"location": "San Francisco"}, @@ -405,7 +405,7 @@ def 
sample_tool(location: str) -> str: ChatMessage( role=Role.ASSISTANT, contents=[ - FunctionCallContent( + Content.from_function_call( call_id="call_2", name="sample_tool", arguments={"location": "New York"}, diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index 77442be322..ee6103a613 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -9,6 +9,7 @@ from agent_framework import ( AIFunction, + Content, HostedCodeInterpreterTool, HostedImageGenerationTool, HostedMCPTool, @@ -639,24 +640,22 @@ def test_parse_inputs_none(): def test_parse_inputs_string(): """Test _parse_inputs with string input.""" - from agent_framework import UriContent result = _parse_inputs("http://example.com") assert len(result) == 1 - assert isinstance(result[0], UriContent) + assert result[0].type == "uri" assert result[0].uri == "http://example.com" assert result[0].media_type == "text/plain" def test_parse_inputs_list_of_strings(): """Test _parse_inputs with list of strings.""" - from agent_framework import UriContent inputs = ["http://example.com", "https://test.org"] result = _parse_inputs(inputs) assert len(result) == 2 - assert all(isinstance(item, UriContent) for item in result) + assert all(item.type == "uri" for item in result) assert result[0].uri == "http://example.com" assert result[1].uri == "https://test.org" assert all(item.media_type == "text/plain" for item in result) @@ -664,88 +663,84 @@ def test_parse_inputs_list_of_strings(): def test_parse_inputs_uri_dict(): """Test _parse_inputs with URI dictionary.""" - from agent_framework import UriContent input_dict = {"uri": "http://example.com", "media_type": "application/json"} result = _parse_inputs(input_dict) assert len(result) == 1 - assert isinstance(result[0], UriContent) + assert result[0].type == "uri" assert result[0].uri == "http://example.com" assert result[0].media_type == "application/json" def 
test_parse_inputs_hosted_file_dict(): """Test _parse_inputs with hosted file dictionary.""" - from agent_framework import HostedFileContent input_dict = {"file_id": "file-123"} result = _parse_inputs(input_dict) assert len(result) == 1 - assert isinstance(result[0], HostedFileContent) + assert result[0].type == "hosted_file" assert result[0].file_id == "file-123" def test_parse_inputs_hosted_vector_store_dict(): """Test _parse_inputs with hosted vector store dictionary.""" - from agent_framework import HostedVectorStoreContent + from agent_framework import Content input_dict = {"vector_store_id": "vs-789"} result = _parse_inputs(input_dict) assert len(result) == 1 - assert isinstance(result[0], HostedVectorStoreContent) + assert isinstance(result[0], Content) + assert result[0].type == "hosted_vector_store" assert result[0].vector_store_id == "vs-789" def test_parse_inputs_data_dict(): """Test _parse_inputs with data dictionary.""" - from agent_framework import DataContent input_dict = {"data": b"test data", "media_type": "application/octet-stream"} result = _parse_inputs(input_dict) assert len(result) == 1 - assert isinstance(result[0], DataContent) + assert result[0].type == "data" assert result[0].uri == "data:application/octet-stream;base64,dGVzdCBkYXRh" assert result[0].media_type == "application/octet-stream" def test_parse_inputs_ai_contents_instance(): - """Test _parse_inputs with Contents instance.""" - from agent_framework import TextContent + """Test _parse_inputs with Content instance.""" - text_content = TextContent(text="Hello, world!") + text_content = Content.from_text(text="Hello, world!") result = _parse_inputs(text_content) assert len(result) == 1 - assert isinstance(result[0], TextContent) + assert result[0].type == "text" assert result[0].text == "Hello, world!" 
def test_parse_inputs_mixed_list(): """Test _parse_inputs with mixed input types.""" - from agent_framework import HostedFileContent, TextContent, UriContent inputs = [ "http://example.com", # string {"uri": "https://test.org", "media_type": "text/html"}, # URI dict {"file_id": "file-456"}, # hosted file dict - TextContent(text="Hello"), # Contents instance + Content.from_text(text="Hello"), # Content instance ] result = _parse_inputs(inputs) assert len(result) == 4 - assert isinstance(result[0], UriContent) + assert result[0].type == "uri" assert result[0].uri == "http://example.com" - assert isinstance(result[1], UriContent) + assert result[1].type == "uri" assert result[1].uri == "https://test.org" assert result[1].media_type == "text/html" - assert isinstance(result[2], HostedFileContent) + assert result[2].type == "hosted_file" assert result[2].file_id == "file-456" - assert isinstance(result[3], TextContent) + assert result[3].type == "text" assert result[3].text == "Hello" @@ -765,55 +760,51 @@ def test_parse_inputs_unsupported_type(): def test_hosted_code_interpreter_tool_with_string_input(): """Test HostedCodeInterpreterTool with string input.""" - from agent_framework import UriContent tool = HostedCodeInterpreterTool(inputs="http://example.com") assert len(tool.inputs) == 1 - assert isinstance(tool.inputs[0], UriContent) + assert tool.inputs[0].type == "uri" assert tool.inputs[0].uri == "http://example.com" def test_hosted_code_interpreter_tool_with_dict_inputs(): """Test HostedCodeInterpreterTool with dictionary inputs.""" - from agent_framework import HostedFileContent, UriContent inputs = [{"uri": "http://example.com", "media_type": "text/html"}, {"file_id": "file-123"}] tool = HostedCodeInterpreterTool(inputs=inputs) assert len(tool.inputs) == 2 - assert isinstance(tool.inputs[0], UriContent) + assert tool.inputs[0].type == "uri" assert tool.inputs[0].uri == "http://example.com" assert tool.inputs[0].media_type == "text/html" - assert 
isinstance(tool.inputs[1], HostedFileContent) + assert tool.inputs[1].type == "hosted_file" assert tool.inputs[1].file_id == "file-123" def test_hosted_code_interpreter_tool_with_ai_contents(): - """Test HostedCodeInterpreterTool with Contents instances.""" - from agent_framework import DataContent, TextContent + """Test HostedCodeInterpreterTool with Content instances.""" - inputs = [TextContent(text="Hello, world!"), DataContent(data=b"test", media_type="text/plain")] + inputs = [Content.from_text(text="Hello, world!"), Content.from_data(data=b"test", media_type="text/plain")] tool = HostedCodeInterpreterTool(inputs=inputs) assert len(tool.inputs) == 2 - assert isinstance(tool.inputs[0], TextContent) + assert tool.inputs[0].type == "text" assert tool.inputs[0].text == "Hello, world!" - assert isinstance(tool.inputs[1], DataContent) + assert tool.inputs[1].type == "data" assert tool.inputs[1].media_type == "text/plain" def test_hosted_code_interpreter_tool_with_single_input(): """Test HostedCodeInterpreterTool with single input (not in list).""" - from agent_framework import HostedFileContent input_dict = {"file_id": "file-single"} tool = HostedCodeInterpreterTool(inputs=input_dict) assert len(tool.inputs) == 1 - assert isinstance(tool.inputs[0], HostedFileContent) + assert tool.inputs[0].type == "hosted_file" assert tool.inputs[0].file_id == "file-single" @@ -983,7 +974,7 @@ async def get_streaming_response(self, messages, **kwargs): yield ChatResponseUpdate(contents=[content], role=msg.role) else: # Default response - yield ChatResponseUpdate(contents=["Default response"], role="assistant") + yield ChatResponseUpdate(text="Default response", role="assistant") return MockChatClient() @@ -1006,7 +997,7 @@ def requires_approval_tool(x: int) -> int: async def test_non_streaming_single_function_no_approval(): """Test non-streaming handler with single function call that doesn't require approval.""" - from agent_framework import ChatMessage, ChatResponse, 
FunctionCallContent + from agent_framework import ChatMessage, ChatResponse from agent_framework._tools import _handle_function_calls_response # Create mock client @@ -1017,11 +1008,11 @@ async def test_non_streaming_single_function_no_approval(): messages=[ ChatMessage( role="assistant", - contents=[FunctionCallContent(call_id="call_1", name="no_approval_tool", arguments='{"x": 5}')], + contents=[Content.from_function_call(call_id="call_1", name="no_approval_tool", arguments='{"x": 5}')], ) ] ) - final_response = ChatResponse(messages=[ChatMessage(role="assistant", contents=["The result is 10"])]) + final_response = ChatResponse(messages=[ChatMessage(role="assistant", text="The result is 10")]) call_count = [0] responses = [initial_response, final_response] @@ -1039,17 +1030,16 @@ async def mock_get_response(self, messages, **kwargs): # Verify: should have 3 messages: function call, function result, final answer assert len(result.messages) == 3 - assert isinstance(result.messages[0].contents[0], FunctionCallContent) - from agent_framework import FunctionResultContent + assert result.messages[0].contents[0].type == "function_call" - assert isinstance(result.messages[1].contents[0], FunctionResultContent) + assert result.messages[1].contents[0].type == "function_result" assert result.messages[1].contents[0].result == 10 # 5 * 2 - assert result.messages[2].contents[0] == "The result is 10" + assert result.messages[2].text == "The result is 10" async def test_non_streaming_single_function_requires_approval(): """Test non-streaming handler with single function call that requires approval.""" - from agent_framework import ChatMessage, ChatResponse, FunctionCallContent + from agent_framework import ChatMessage, ChatResponse from agent_framework._tools import _handle_function_calls_response mock_client = type("MockClient", (), {})() @@ -1059,7 +1049,9 @@ async def test_non_streaming_single_function_requires_approval(): messages=[ ChatMessage( role="assistant", - 
contents=[FunctionCallContent(call_id="call_1", name="requires_approval_tool", arguments='{"x": 5}')], + contents=[ + Content.from_function_call(call_id="call_1", name="requires_approval_tool", arguments='{"x": 5}') + ], ) ] ) @@ -1078,18 +1070,17 @@ async def mock_get_response(self, messages, **kwargs): result = await wrapped(mock_client, messages=[], options={"tools": [requires_approval_tool]}) # Verify: should return 1 message with function call and approval request - from agent_framework import FunctionApprovalRequestContent assert len(result.messages) == 1 assert len(result.messages[0].contents) == 2 - assert isinstance(result.messages[0].contents[0], FunctionCallContent) - assert isinstance(result.messages[0].contents[1], FunctionApprovalRequestContent) + assert result.messages[0].contents[0].type == "function_call" + assert result.messages[0].contents[1].type == "function_approval_request" assert result.messages[0].contents[1].function_call.name == "requires_approval_tool" async def test_non_streaming_two_functions_both_no_approval(): """Test non-streaming handler with two function calls, neither requiring approval.""" - from agent_framework import ChatMessage, ChatResponse, FunctionCallContent + from agent_framework import ChatMessage, ChatResponse from agent_framework._tools import _handle_function_calls_response mock_client = type("MockClient", (), {})() @@ -1100,15 +1091,13 @@ async def test_non_streaming_two_functions_both_no_approval(): ChatMessage( role="assistant", contents=[ - FunctionCallContent(call_id="call_1", name="no_approval_tool", arguments='{"x": 5}'), - FunctionCallContent(call_id="call_2", name="no_approval_tool", arguments='{"x": 3}'), + Content.from_function_call(call_id="call_1", name="no_approval_tool", arguments='{"x": 5}'), + Content.from_function_call(call_id="call_2", name="no_approval_tool", arguments='{"x": 3}'), ], ) ] ) - final_response = ChatResponse( - messages=[ChatMessage(role="assistant", contents=["Both tools executed 
successfully"])] - ) + final_response = ChatResponse(messages=[ChatMessage(role="assistant", text="Both tools executed successfully")]) call_count = [0] responses = [initial_response, final_response] @@ -1124,21 +1113,20 @@ async def mock_get_response(self, messages, **kwargs): result = await wrapped(mock_client, messages=[], options={"tools": [no_approval_tool]}) # Verify: should have function calls, results, and final answer - from agent_framework import FunctionResultContent assert len(result.messages) == 3 # First message has both function calls assert len(result.messages[0].contents) == 2 # Second message has both results assert len(result.messages[1].contents) == 2 - assert all(isinstance(c, FunctionResultContent) for c in result.messages[1].contents) + assert all(c.type == "function_result" for c in result.messages[1].contents) assert result.messages[1].contents[0].result == 10 # 5 * 2 assert result.messages[1].contents[1].result == 6 # 3 * 2 async def test_non_streaming_two_functions_both_require_approval(): """Test non-streaming handler with two function calls, both requiring approval.""" - from agent_framework import ChatMessage, ChatResponse, FunctionCallContent + from agent_framework import ChatMessage, ChatResponse from agent_framework._tools import _handle_function_calls_response mock_client = type("MockClient", (), {})() @@ -1149,8 +1137,8 @@ async def test_non_streaming_two_functions_both_require_approval(): ChatMessage( role="assistant", contents=[ - FunctionCallContent(call_id="call_1", name="requires_approval_tool", arguments='{"x": 5}'), - FunctionCallContent(call_id="call_2", name="requires_approval_tool", arguments='{"x": 3}'), + Content.from_function_call(call_id="call_1", name="requires_approval_tool", arguments='{"x": 5}'), + Content.from_function_call(call_id="call_2", name="requires_approval_tool", arguments='{"x": 3}'), ], ) ] @@ -1170,12 +1158,11 @@ async def mock_get_response(self, messages, **kwargs): result = await 
wrapped(mock_client, messages=[], options={"tools": [requires_approval_tool]}) # Verify: should return 1 message with function calls and approval requests - from agent_framework import FunctionApprovalRequestContent assert len(result.messages) == 1 assert len(result.messages[0].contents) == 4 # 2 function calls + 2 approval requests - function_calls = [c for c in result.messages[0].contents if isinstance(c, FunctionCallContent)] - approval_requests = [c for c in result.messages[0].contents if isinstance(c, FunctionApprovalRequestContent)] + function_calls = [c for c in result.messages[0].contents if c.type == "function_call"] + approval_requests = [c for c in result.messages[0].contents if c.type == "function_approval_request"] assert len(function_calls) == 2 assert len(approval_requests) == 2 assert approval_requests[0].function_call.name == "requires_approval_tool" @@ -1184,7 +1171,7 @@ async def mock_get_response(self, messages, **kwargs): async def test_non_streaming_two_functions_mixed_approval(): """Test non-streaming handler with two function calls, one requiring approval.""" - from agent_framework import ChatMessage, ChatResponse, FunctionCallContent + from agent_framework import ChatMessage, ChatResponse from agent_framework._tools import _handle_function_calls_response mock_client = type("MockClient", (), {})() @@ -1195,8 +1182,8 @@ async def test_non_streaming_two_functions_mixed_approval(): ChatMessage( role="assistant", contents=[ - FunctionCallContent(call_id="call_1", name="no_approval_tool", arguments='{"x": 5}'), - FunctionCallContent(call_id="call_2", name="requires_approval_tool", arguments='{"x": 3}'), + Content.from_function_call(call_id="call_1", name="no_approval_tool", arguments='{"x": 5}'), + Content.from_function_call(call_id="call_2", name="requires_approval_tool", arguments='{"x": 3}'), ], ) ] @@ -1216,17 +1203,16 @@ async def mock_get_response(self, messages, **kwargs): result = await wrapped(mock_client, messages=[], options={"tools": 
[no_approval_tool, requires_approval_tool]}) # Verify: should return approval requests for both (when one needs approval, all are sent for approval) - from agent_framework import FunctionApprovalRequestContent assert len(result.messages) == 1 assert len(result.messages[0].contents) == 4 # 2 function calls + 2 approval requests - approval_requests = [c for c in result.messages[0].contents if isinstance(c, FunctionApprovalRequestContent)] + approval_requests = [c for c in result.messages[0].contents if c.type == "function_approval_request"] assert len(approval_requests) == 2 async def test_streaming_single_function_no_approval(): """Test streaming handler with single function call that doesn't require approval.""" - from agent_framework import ChatResponseUpdate, FunctionCallContent + from agent_framework import ChatResponseUpdate from agent_framework._tools import _handle_function_calls_streaming_response mock_client = type("MockClient", (), {})() @@ -1234,11 +1220,11 @@ async def test_streaming_single_function_no_approval(): # Initial response with function call, then final response after function execution initial_updates = [ ChatResponseUpdate( - contents=[FunctionCallContent(call_id="call_1", name="no_approval_tool", arguments='{"x": 5}')], + contents=[Content.from_function_call(call_id="call_1", name="no_approval_tool", arguments='{"x": 5}')], role="assistant", ) ] - final_updates = [ChatResponseUpdate(contents=["The result is 10"], role="assistant")] + final_updates = [ChatResponseUpdate(text="The result is 10", role="assistant")] call_count = [0] updates_list = [initial_updates, final_updates] @@ -1257,22 +1243,23 @@ async def mock_get_streaming_response(self, messages, **kwargs): updates.append(update) # Verify: should have function call update, tool result update (injected), and final update - from agent_framework import FunctionResultContent, Role + from agent_framework import Role assert len(updates) >= 3 # First update is the function call - assert 
isinstance(updates[0].contents[0], FunctionCallContent) + assert updates[0].contents[0].type == "function_call" # Second update should be the tool result (injected by the wrapper) assert updates[1].role == Role.TOOL - assert isinstance(updates[1].contents[0], FunctionResultContent) + assert updates[1].contents[0].type == "function_result" assert updates[1].contents[0].result == 10 # 5 * 2 # Last update is the final message - assert updates[-1].contents[0] == "The result is 10" + assert updates[-1].contents[0].type == "text" + assert updates[-1].contents[0].text == "The result is 10" async def test_streaming_single_function_requires_approval(): """Test streaming handler with single function call that requires approval.""" - from agent_framework import ChatResponseUpdate, FunctionCallContent + from agent_framework import ChatResponseUpdate from agent_framework._tools import _handle_function_calls_streaming_response mock_client = type("MockClient", (), {})() @@ -1280,7 +1267,9 @@ async def test_streaming_single_function_requires_approval(): # Initial response with function call initial_updates = [ ChatResponseUpdate( - contents=[FunctionCallContent(call_id="call_1", name="requires_approval_tool", arguments='{"x": 5}')], + contents=[ + Content.from_function_call(call_id="call_1", name="requires_approval_tool", arguments='{"x": 5}') + ], role="assistant", ) ] @@ -1302,17 +1291,17 @@ async def mock_get_streaming_response(self, messages, **kwargs): updates.append(update) # Verify: should yield function call and then approval request - from agent_framework import FunctionApprovalRequestContent, Role + from agent_framework import Role assert len(updates) == 2 - assert isinstance(updates[0].contents[0], FunctionCallContent) + assert updates[0].contents[0].type == "function_call" assert updates[1].role == Role.ASSISTANT - assert isinstance(updates[1].contents[0], FunctionApprovalRequestContent) + assert updates[1].contents[0].type == "function_approval_request" async def 
test_streaming_two_functions_both_no_approval(): """Test streaming handler with two function calls, neither requiring approval.""" - from agent_framework import ChatResponseUpdate, FunctionCallContent + from agent_framework import ChatResponseUpdate from agent_framework._tools import _handle_function_calls_streaming_response mock_client = type("MockClient", (), {})() @@ -1320,15 +1309,14 @@ async def test_streaming_two_functions_both_no_approval(): # Initial response with two function calls to the same tool initial_updates = [ ChatResponseUpdate( - contents=[FunctionCallContent(call_id="call_1", name="no_approval_tool", arguments='{"x": 5}')], - role="assistant", - ), - ChatResponseUpdate( - contents=[FunctionCallContent(call_id="call_2", name="no_approval_tool", arguments='{"x": 3}')], + contents=[ + Content.from_function_call(call_id="call_1", name="no_approval_tool", arguments='{"x": 5}'), + Content.from_function_call(call_id="call_2", name="no_approval_tool", arguments='{"x": 3}'), + ], role="assistant", ), ] - final_updates = [ChatResponseUpdate(contents=["Both tools executed successfully"], role="assistant")] + final_updates = [ChatResponseUpdate(text="Both tools executed successfully", role="assistant")] call_count = [0] updates_list = [initial_updates, final_updates] @@ -1347,22 +1335,23 @@ async def mock_get_streaming_response(self, messages, **kwargs): updates.append(update) # Verify: should have both function calls, one tool result update with both results, and final message - from agent_framework import FunctionResultContent, Role + from agent_framework import Role - assert len(updates) >= 3 - # First two updates are function calls - assert isinstance(updates[0].contents[0], FunctionCallContent) - assert isinstance(updates[1].contents[0], FunctionCallContent) + assert len(updates) >= 2 + # First update has both function calls + assert len(updates[0].contents) == 2 + assert updates[0].contents[0].type == "function_call" + assert 
updates[0].contents[1].type == "function_call" # Should have a tool result update with both results tool_updates = [u for u in updates if u.role == Role.TOOL] assert len(tool_updates) == 1 assert len(tool_updates[0].contents) == 2 - assert all(isinstance(c, FunctionResultContent) for c in tool_updates[0].contents) + assert all(c.type == "function_result" for c in tool_updates[0].contents) async def test_streaming_two_functions_both_require_approval(): """Test streaming handler with two function calls, both requiring approval.""" - from agent_framework import ChatResponseUpdate, FunctionCallContent + from agent_framework import ChatResponseUpdate from agent_framework._tools import _handle_function_calls_streaming_response mock_client = type("MockClient", (), {})() @@ -1370,11 +1359,15 @@ async def test_streaming_two_functions_both_require_approval(): # Initial response with two function calls to the same tool initial_updates = [ ChatResponseUpdate( - contents=[FunctionCallContent(call_id="call_1", name="requires_approval_tool", arguments='{"x": 5}')], + contents=[ + Content.from_function_call(call_id="call_1", name="requires_approval_tool", arguments='{"x": 5}') + ], role="assistant", ), ChatResponseUpdate( - contents=[FunctionCallContent(call_id="call_2", name="requires_approval_tool", arguments='{"x": 3}')], + contents=[ + Content.from_function_call(call_id="call_2", name="requires_approval_tool", arguments='{"x": 3}') + ], role="assistant", ), ] @@ -1396,20 +1389,20 @@ async def mock_get_streaming_response(self, messages, **kwargs): updates.append(update) # Verify: should yield both function calls and then approval requests - from agent_framework import FunctionApprovalRequestContent, Role + from agent_framework import Role assert len(updates) == 3 - assert isinstance(updates[0].contents[0], FunctionCallContent) - assert isinstance(updates[1].contents[0], FunctionCallContent) + assert updates[0].contents[0].type == "function_call" + assert 
updates[1].contents[0].type == "function_call" # Assistant update with both approval requests assert updates[2].role == Role.ASSISTANT assert len(updates[2].contents) == 2 - assert all(isinstance(c, FunctionApprovalRequestContent) for c in updates[2].contents) + assert all(c.type == "function_approval_request" for c in updates[2].contents) async def test_streaming_two_functions_mixed_approval(): """Test streaming handler with two function calls, one requiring approval.""" - from agent_framework import ChatResponseUpdate, FunctionCallContent + from agent_framework import ChatResponseUpdate from agent_framework._tools import _handle_function_calls_streaming_response mock_client = type("MockClient", (), {})() @@ -1417,11 +1410,13 @@ async def test_streaming_two_functions_mixed_approval(): # Initial response with two function calls initial_updates = [ ChatResponseUpdate( - contents=[FunctionCallContent(call_id="call_1", name="no_approval_tool", arguments='{"x": 5}')], + contents=[Content.from_function_call(call_id="call_1", name="no_approval_tool", arguments='{"x": 5}')], role="assistant", ), ChatResponseUpdate( - contents=[FunctionCallContent(call_id="call_2", name="requires_approval_tool", arguments='{"x": 3}')], + contents=[ + Content.from_function_call(call_id="call_2", name="requires_approval_tool", arguments='{"x": 3}') + ], role="assistant", ), ] @@ -1445,15 +1440,15 @@ async def mock_get_streaming_response(self, messages, **kwargs): updates.append(update) # Verify: should yield both function calls and then approval requests (when one needs approval, all wait) - from agent_framework import FunctionApprovalRequestContent, Role + from agent_framework import Role assert len(updates) == 3 - assert isinstance(updates[0].contents[0], FunctionCallContent) - assert isinstance(updates[1].contents[0], FunctionCallContent) + assert updates[0].contents[0].type == "function_call" + assert updates[1].contents[0].type == "function_call" # Assistant update with both approval 
requests assert updates[2].role == Role.ASSISTANT assert len(updates[2].contents) == 2 - assert all(isinstance(c, FunctionApprovalRequestContent) for c in updates[2].contents) + assert all(c.type == "function_approval_request" for c in updates[2].contents) async def test_ai_function_with_kwargs_injection(): diff --git a/python/packages/core/tests/core/test_types.py b/python/packages/core/tests/core/test_types.py index fe719967db..15f67fb44f 100644 --- a/python/packages/core/tests/core/test_types.py +++ b/python/packages/core/tests/core/test_types.py @@ -1,6 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. -import base64 from collections.abc import AsyncIterable from datetime import datetime, timezone from typing import Any @@ -12,41 +11,23 @@ from agent_framework import ( AgentResponse, AgentResponseUpdate, - BaseContent, + Annotation, ChatMessage, ChatOptions, ChatResponse, ChatResponseUpdate, - CitationAnnotation, - CodeInterpreterToolCallContent, - CodeInterpreterToolResultContent, - DataContent, - ErrorContent, + Content, FinishReason, - FunctionApprovalRequestContent, - FunctionApprovalResponseContent, - FunctionCallContent, - FunctionResultContent, - HostedFileContent, - HostedVectorStoreContent, - ImageGenerationToolCallContent, - ImageGenerationToolResultContent, - MCPServerToolCallContent, - MCPServerToolResultContent, Role, - TextContent, - TextReasoningContent, TextSpanRegion, ToolMode, ToolProtocol, - UriContent, - UsageContent, UsageDetails, ai_function, merge_chat_options, prepare_function_call_results, ) -from agent_framework.exceptions import AdditionItemMismatch, ContentError +from agent_framework.exceptions import ContentError @fixture @@ -83,9 +64,11 @@ def simple_function(x: int, y: int) -> int: def test_text_content_positional(): - """Test the TextContent class to ensure it initializes correctly and inherits from BaseContent.""" + """Test the TextContent class to ensure it initializes correctly and inherits from Content.""" # Create an 
instance of TextContent - content = TextContent("Hello, world!", raw_representation="Hello, world!", additional_properties={"version": 1}) + content = Content.from_text( + "Hello, world!", raw_representation="Hello, world!", additional_properties={"version": 1} + ) # Check the type and content assert content.type == "text" @@ -93,15 +76,15 @@ def test_text_content_positional(): assert content.raw_representation == "Hello, world!" assert content.additional_properties["version"] == 1 # Ensure the instance is of type BaseContent - assert isinstance(content, BaseContent) + assert isinstance(content, Content) # Note: No longer using Pydantic validation, so type assignment should work content.type = "text" # This should work fine now def test_text_content_keyword(): - """Test the TextContent class to ensure it initializes correctly and inherits from BaseContent.""" + """Test the TextContent class to ensure it initializes correctly and inherits from Content.""" # Create an instance of TextContent - content = TextContent( + content = Content.from_text( text="Hello, world!", raw_representation="Hello, world!", additional_properties={"version": 1} ) @@ -111,7 +94,7 @@ def test_text_content_keyword(): assert content.raw_representation == "Hello, world!" 
assert content.additional_properties["version"] == 1 # Ensure the instance is of type BaseContent - assert isinstance(content, BaseContent) + assert isinstance(content, Content) # Note: No longer using Pydantic validation, so type assignment should work content.type = "text" # This should work fine now @@ -122,109 +105,104 @@ def test_text_content_keyword(): def test_data_content_bytes(): """Test the DataContent class to ensure it initializes correctly.""" # Create an instance of DataContent - content = DataContent(data=b"test", media_type="application/octet-stream", additional_properties={"version": 1}) + content = Content.from_data( + data=b"test", media_type="application/octet-stream", additional_properties={"version": 1} + ) # Check the type and content assert content.type == "data" assert content.uri == "data:application/octet-stream;base64,dGVzdA==" - assert content.has_top_level_media_type("application") is True - assert content.has_top_level_media_type("image") is False + assert content.media_type.startswith("application/") is True + assert content.media_type.startswith("image/") is False assert content.additional_properties["version"] == 1 # Ensure the instance is of type BaseContent - assert isinstance(content, BaseContent) + assert isinstance(content, Content) def test_data_content_uri(): - """Test the DataContent class to ensure it initializes correctly with a URI.""" - # Create an instance of DataContent with a URI - content = DataContent(uri="data:application/octet-stream;base64,dGVzdA==", additional_properties={"version": 1}) + """Test the Content.from_uri class to ensure it initializes correctly with a URI.""" + # Create an instance of Content.from_uri with a URI and explicit media_type + content = Content.from_uri( + uri="data:application/octet-stream;base64,dGVzdA==", + media_type="application/octet-stream", + additional_properties={"version": 1}, + ) # Check the type and content assert content.type == "data" assert content.uri == 
"data:application/octet-stream;base64,dGVzdA==" - # media_type is extracted from URI now + # media_type must be explicitly provided assert content.media_type == "application/octet-stream" - assert content.has_top_level_media_type("application") is True + assert content.media_type.startswith("application/") is True assert content.additional_properties["version"] == 1 # Ensure the instance is of type BaseContent - assert isinstance(content, BaseContent) + assert isinstance(content, Content) def test_data_content_invalid(): """Test the DataContent class to ensure it raises an error for invalid initialization.""" - # Attempt to create an instance of DataContent with invalid data - # not a proper uri - with raises(ValueError): - DataContent(uri="invalid_uri") - # unknown media type - with raises(ValueError): - DataContent(uri="data:application/random;base64,dGVzdA==") - # not valid base64 data would still be accepted by our basic validation - # but it's not a critical issue for now + with pytest.raises(ContentError): + Content.from_uri(uri="invalid_uri", media_type="text/plain") def test_data_content_empty(): """Test the DataContent class to ensure it raises an error for empty data.""" - # Attempt to create an instance of DataContent with empty data - with raises(ValueError): - DataContent(data=b"", media_type="application/octet-stream") + data = Content.from_data(data=b"", media_type="application/octet-stream") + assert data.uri == "data:application/octet-stream;base64," + assert data.media_type == "application/octet-stream" - # Attempt to create an instance of DataContent with empty URI - with raises(ValueError): - DataContent(uri="") +# def test_data_content_detect_image_format_from_base64(): +# """Test the detect_image_format_from_base64 static method.""" +# # Test each supported format +# png_data = b"\x89PNG\r\n\x1a\n" + b"fake_data" +# assert detect_image_format_from_base64(base64.b64encode(png_data).decode()) == "png" -def 
test_data_content_detect_image_format_from_base64(): - """Test the detect_image_format_from_base64 static method.""" - # Test each supported format - png_data = b"\x89PNG\r\n\x1a\n" + b"fake_data" - assert DataContent.detect_image_format_from_base64(base64.b64encode(png_data).decode()) == "png" +# jpeg_data = b"\xff\xd8\xff\xe0" + b"fake_data" +# assert DataContent.detect_image_format_from_base64(base64.b64encode(jpeg_data).decode()) == "jpeg" - jpeg_data = b"\xff\xd8\xff\xe0" + b"fake_data" - assert DataContent.detect_image_format_from_base64(base64.b64encode(jpeg_data).decode()) == "jpeg" +# webp_data = b"RIFF" + b"1234" + b"WEBP" + b"fake_data" +# assert DataContent.detect_image_format_from_base64(base64.b64encode(webp_data).decode()) == "webp" - webp_data = b"RIFF" + b"1234" + b"WEBP" + b"fake_data" - assert DataContent.detect_image_format_from_base64(base64.b64encode(webp_data).decode()) == "webp" +# gif_data = b"GIF89a" + b"fake_data" +# assert DataContent.detect_image_format_from_base64(base64.b64encode(gif_data).decode()) == "gif" - gif_data = b"GIF89a" + b"fake_data" - assert DataContent.detect_image_format_from_base64(base64.b64encode(gif_data).decode()) == "gif" +# # Test fallback behavior +# unknown_data = b"UNKNOWN_FORMAT" +# assert DataContent.detect_image_format_from_base64(base64.b64encode(unknown_data).decode()) == "png" - # Test fallback behavior - unknown_data = b"UNKNOWN_FORMAT" - assert DataContent.detect_image_format_from_base64(base64.b64encode(unknown_data).decode()) == "png" +# # Test error handling +# assert DataContent.detect_image_format_from_base64("invalid_base64!") == "png" +# assert DataContent.detect_image_format_from_base64("") == "png" - # Test error handling - assert DataContent.detect_image_format_from_base64("invalid_base64!") == "png" - assert DataContent.detect_image_format_from_base64("") == "png" +# def test_data_content_create_data_uri_from_base64(): +# """Test the create_data_uri_from_base64 class method.""" +# # Test 
with PNG data +# png_data = b"\x89PNG\r\n\x1a\n" + b"fake_data" +# png_base64 = base64.b64encode(png_data).decode() +# uri, media_type = Content.create_data_uri_from_base64(png_base64) -def test_data_content_create_data_uri_from_base64(): - """Test the create_data_uri_from_base64 class method.""" - # Test with PNG data - png_data = b"\x89PNG\r\n\x1a\n" + b"fake_data" - png_base64 = base64.b64encode(png_data).decode() - uri, media_type = DataContent.create_data_uri_from_base64(png_base64) +# assert uri == f"data:image/png;base64,{png_base64}" +# assert media_type == "image/png" - assert uri == f"data:image/png;base64,{png_base64}" - assert media_type == "image/png" +# # Test with different format +# jpeg_data = b"\xff\xd8\xff\xe0" + b"fake_data" +# jpeg_base64 = base64.b64encode(jpeg_data).decode() +# uri, media_type = DataContent.create_data_uri_from_base64(jpeg_base64) - # Test with different format - jpeg_data = b"\xff\xd8\xff\xe0" + b"fake_data" - jpeg_base64 = base64.b64encode(jpeg_data).decode() - uri, media_type = DataContent.create_data_uri_from_base64(jpeg_base64) +# assert uri == f"data:image/jpeg;base64,{jpeg_base64}" +# assert media_type == "image/jpeg" - assert uri == f"data:image/jpeg;base64,{jpeg_base64}" - assert media_type == "image/jpeg" +# # Test fallback for unknown format +# unknown_data = b"UNKNOWN_FORMAT" +# unknown_base64 = base64.b64encode(unknown_data).decode() +# uri, media_type = DataContent.create_data_uri_from_base64(unknown_base64) - # Test fallback for unknown format - unknown_data = b"UNKNOWN_FORMAT" - unknown_base64 = base64.b64encode(unknown_data).decode() - uri, media_type = DataContent.create_data_uri_from_base64(unknown_base64) - - assert uri == f"data:image/png;base64,{unknown_base64}" - assert media_type == "image/png" +# assert uri == f"data:image/png;base64,{unknown_base64}" +# assert media_type == "image/png" # region UriContent @@ -232,18 +210,16 @@ def test_data_content_create_data_uri_from_base64(): def 
test_uri_content(): """Test the UriContent class to ensure it initializes correctly.""" - content = UriContent(uri="http://example.com", media_type="image/jpg", additional_properties={"version": 1}) + content = Content.from_uri(uri="http://example.com", media_type="image/jpg", additional_properties={"version": 1}) # Check the type and content assert content.type == "uri" assert content.uri == "http://example.com" assert content.media_type == "image/jpg" - assert content.has_top_level_media_type("image") is True - assert content.has_top_level_media_type("application") is False + assert content.media_type.startswith("image/") is True + assert content.media_type.startswith("application/") is False assert content.additional_properties["version"] == 1 - - # Ensure the instance is of type BaseContent - assert isinstance(content, BaseContent) + assert isinstance(content, Content) # region: HostedFileContent @@ -251,101 +227,98 @@ def test_uri_content(): def test_hosted_file_content(): """Test the HostedFileContent class to ensure it initializes correctly.""" - content = HostedFileContent(file_id="file-123", additional_properties={"version": 1}) + content = Content.from_hosted_file(file_id="file-123", additional_properties={"version": 1}) # Check the type and content assert content.type == "hosted_file" assert content.file_id == "file-123" assert content.additional_properties["version"] == 1 - - # Ensure the instance is of type BaseContent - assert isinstance(content, BaseContent) + assert isinstance(content, Content) def test_hosted_file_content_minimal(): """Test the HostedFileContent class with minimal parameters.""" - content = HostedFileContent(file_id="file-456") + content = Content.from_hosted_file(file_id="file-456") # Check the type and content assert content.type == "hosted_file" assert content.file_id == "file-456" assert content.additional_properties == {} assert content.raw_representation is None - - # Ensure the instance is of type BaseContent - assert 
isinstance(content, BaseContent) + assert isinstance(content, Content) def test_hosted_file_content_optional_fields(): """HostedFileContent should capture optional media type and name.""" - content = HostedFileContent(file_id="file-789", media_type="image/png", name="plot.png") + content = Content.from_hosted_file(file_id="file-789", media_type="image/png", name="plot.png") assert content.media_type == "image/png" assert content.name == "plot.png" - assert content.has_top_level_media_type("image") - assert content.has_top_level_media_type("application") is False + assert content.media_type.startswith("image/") + assert content.media_type.startswith("application/") is False # region: CodeInterpreter content def test_code_interpreter_tool_call_content_parses_inputs(): - call = CodeInterpreterToolCallContent( + call = Content.from_code_interpreter_tool_call( call_id="call-1", - inputs=[{"type": "text", "text": "print('hi')"}], + inputs=[Content.from_text(text="print('hi')")], ) assert call.type == "code_interpreter_tool_call" assert call.call_id == "call-1" - assert call.inputs and isinstance(call.inputs[0], TextContent) + assert call.inputs and call.inputs[0].type == "text" assert call.inputs[0].text == "print('hi')" def test_code_interpreter_tool_result_content_outputs(): - result = CodeInterpreterToolResultContent( + result = Content.from_code_interpreter_tool_result( call_id="call-2", outputs=[ - {"type": "text", "text": "log output"}, - {"type": "uri", "uri": "https://example.com/file.png", "media_type": "image/png"}, + Content.from_text(text="log output"), + Content.from_uri(uri="https://example.com/file.png", media_type="image/png"), ], ) assert result.type == "code_interpreter_tool_result" assert result.call_id == "call-2" assert result.outputs is not None - assert isinstance(result.outputs[0], TextContent) - assert isinstance(result.outputs[1], UriContent) + assert result.outputs[0].type == "text" + assert result.outputs[1].type == "uri" # region: Image 
generation content def test_image_generation_tool_contents(): - call = ImageGenerationToolCallContent(image_id="img-1") - outputs = [DataContent(data=b"1234", media_type="image/png")] - result = ImageGenerationToolResultContent(image_id="img-1", outputs=outputs) + call = Content.from_image_generation_tool_call(image_id="img-1") + outputs = [Content.from_data(data=b"1234", media_type="image/png")] + result = Content.from_image_generation_tool_result(image_id="img-1", outputs=outputs) assert call.type == "image_generation_tool_call" assert call.image_id == "img-1" assert result.type == "image_generation_tool_result" assert result.image_id == "img-1" - assert result.outputs and isinstance(result.outputs[0], DataContent) + assert result.outputs and result.outputs[0].type == "data" # region: MCP server tool content def test_mcp_server_tool_call_and_result(): - call = MCPServerToolCallContent(call_id="c-1", tool_name="tool", server_name="server", arguments={"x": 1}) + call = Content.from_mcp_server_tool_call(call_id="c-1", tool_name="tool", server_name="server", arguments={"x": 1}) assert call.type == "mcp_server_tool_call" assert call.arguments == {"x": 1} - result = MCPServerToolResultContent(call_id="c-1", output=[{"type": "text", "text": "done"}]) + result = Content.from_mcp_server_tool_result(call_id="c-1", output=[{"type": "text", "text": "done"}]) assert result.type == "mcp_server_tool_result" assert result.output - with raises(ValueError): - MCPServerToolCallContent(call_id="", tool_name="tool") + # Empty call_id is allowed, validation happens elsewhere + call2 = Content.from_mcp_server_tool_call(call_id="", tool_name="tool", server_name="server") + assert call2.call_id == "" # region: HostedVectorStoreContent @@ -353,7 +326,7 @@ def test_mcp_server_tool_call_and_result(): def test_hosted_vector_store_content(): """Test the HostedVectorStoreContent class to ensure it initializes correctly.""" - content = HostedVectorStoreContent(vector_store_id="vs-789", 
additional_properties={"version": 1}) + content = Content.from_hosted_vector_store(vector_store_id="vs-789", additional_properties={"version": 1}) # Check the type and content assert content.type == "hosted_vector_store" @@ -361,13 +334,14 @@ def test_hosted_vector_store_content(): assert content.additional_properties["version"] == 1 # Ensure the instance is of type BaseContent - assert isinstance(content, HostedVectorStoreContent) - assert isinstance(content, BaseContent) + assert isinstance(content, Content) + assert content.type == "hosted_vector_store" + assert isinstance(content, Content) def test_hosted_vector_store_content_minimal(): """Test the HostedVectorStoreContent class with minimal parameters.""" - content = HostedVectorStoreContent(vector_store_id="vs-101112") + content = Content.from_hosted_vector_store(vector_store_id="vs-101112") # Check the type and content assert content.type == "hosted_vector_store" @@ -375,17 +349,13 @@ def test_hosted_vector_store_content_minimal(): assert content.additional_properties == {} assert content.raw_representation is None - # Ensure the instance is of type BaseContent - assert isinstance(content, HostedVectorStoreContent) - assert isinstance(content, BaseContent) - # region FunctionCallContent def test_function_call_content(): """Test the FunctionCallContent class to ensure it initializes correctly.""" - content = FunctionCallContent(call_id="1", name="example_function", arguments={"param1": "value1"}) + content = Content.from_function_call(call_id="1", name="example_function", arguments={"param1": "value1"}) # Check the type and content assert content.type == "function_call" @@ -393,42 +363,42 @@ def test_function_call_content(): assert content.arguments == {"param1": "value1"} # Ensure the instance is of type BaseContent - assert isinstance(content, BaseContent) + assert isinstance(content, Content) def test_function_call_content_parse_arguments(): - c1 = FunctionCallContent(call_id="1", name="f", 
arguments='{"a": 1, "b": 2}') + c1 = Content.from_function_call(call_id="1", name="f", arguments='{"a": 1, "b": 2}') assert c1.parse_arguments() == {"a": 1, "b": 2} - c2 = FunctionCallContent(call_id="1", name="f", arguments="not json") + c2 = Content.from_function_call(call_id="1", name="f", arguments="not json") assert c2.parse_arguments() == {"raw": "not json"} - c3 = FunctionCallContent(call_id="1", name="f", arguments={"x": None}) + c3 = Content.from_function_call(call_id="1", name="f", arguments={"x": None}) assert c3.parse_arguments() == {"x": None} def test_function_call_content_add_merging_and_errors(): # str + str concatenation - a = FunctionCallContent(call_id="1", name="f", arguments="abc") - b = FunctionCallContent(call_id="1", name="f", arguments="def") + a = Content.from_function_call(call_id="1", name="f", arguments="abc") + b = Content.from_function_call(call_id="1", name="f", arguments="def") c = a + b assert isinstance(c.arguments, str) and c.arguments == "abcdef" # dict + dict merge - a = FunctionCallContent(call_id="1", name="f", arguments={"x": 1}) - b = FunctionCallContent(call_id="1", name="f", arguments={"y": 2}) + a = Content.from_function_call(call_id="1", name="f", arguments={"x": 1}) + b = Content.from_function_call(call_id="1", name="f", arguments={"y": 2}) c = a + b assert c.arguments == {"x": 1, "y": 2} # incompatible argument types - a = FunctionCallContent(call_id="1", name="f", arguments="abc") - b = FunctionCallContent(call_id="1", name="f", arguments={"y": 2}) + a = Content.from_function_call(call_id="1", name="f", arguments="abc") + b = Content.from_function_call(call_id="1", name="f", arguments={"y": 2}) with raises(TypeError): _ = a + b # incompatible call ids - a = FunctionCallContent(call_id="1", name="f", arguments="abc") - b = FunctionCallContent(call_id="2", name="f", arguments="def") + a = Content.from_function_call(call_id="1", name="f", arguments="abc") + b = Content.from_function_call(call_id="2", name="f", 
arguments="def") - with raises(AdditionItemMismatch): + with raises(ContentError): _ = a + b @@ -437,14 +407,14 @@ def test_function_call_content_add_merging_and_errors(): def test_function_result_content(): """Test the FunctionResultContent class to ensure it initializes correctly.""" - content = FunctionResultContent(call_id="1", result={"param1": "value1"}) + content = Content.from_function_result(call_id="1", result={"param1": "value1"}) # Check the type and content assert content.type == "function_result" assert content.result == {"param1": "value1"} # Ensure the instance is of type BaseContent - assert isinstance(content, BaseContent) + assert isinstance(content, Content) # region UsageDetails @@ -452,13 +422,15 @@ def test_function_result_content(): def test_usage_details(): usage = UsageDetails(input_token_count=5, output_token_count=10, total_token_count=15) - assert usage.input_token_count == 5 - assert usage.output_token_count == 10 - assert usage.total_token_count == 15 - assert usage.additional_counts == {} + assert usage["input_token_count"] == 5 + assert usage["output_token_count"] == 10 + assert usage["total_token_count"] == 15 + assert usage.get("additional_counts", {}) == {} def test_usage_details_addition(): + from agent_framework._types import add_usage_details + usage1 = UsageDetails( input_token_count=5, output_token_count=10, @@ -474,39 +446,38 @@ def test_usage_details_addition(): test3=30, ) - combined_usage = usage1 + usage2 - assert combined_usage.input_token_count == 8 - assert combined_usage.output_token_count == 16 - assert combined_usage.total_token_count == 24 - assert combined_usage.additional_counts["test1"] == 20 - assert combined_usage.additional_counts["test2"] == 20 - assert combined_usage.additional_counts["test3"] == 30 + combined_usage = add_usage_details(usage1, usage2) + assert combined_usage["input_token_count"] == 8 + assert combined_usage["output_token_count"] == 16 + assert combined_usage["total_token_count"] == 24 + 
assert combined_usage["test1"] == 20 + assert combined_usage["test2"] == 20 + assert combined_usage["test3"] == 30 def test_usage_details_fail(): - with raises(ValueError): - UsageDetails(input_token_count=5, output_token_count=10, total_token_count=15, wrong_type="42.923") + # TypedDict doesn't validate types at runtime, so this test no longer applies + # Creating UsageDetails with wrong types won't raise ValueError + usage = UsageDetails(input_token_count=5, output_token_count=10, total_token_count=15, wrong_type="42.923") # type: ignore[typeddict-item] + assert usage["wrong_type"] == "42.923" # type: ignore[typeddict-item] def test_usage_details_additional_counts(): usage = UsageDetails(input_token_count=5, output_token_count=10, total_token_count=15, **{"test": 1}) - assert usage.additional_counts["test"] == 1 + assert usage.get("test") == 1 def test_usage_details_add_with_none_and_type_errors(): + from agent_framework._types import add_usage_details + u = UsageDetails(input_token_count=1) - # __add__ with None returns self (no change) - v = u + None - assert v is u - # __iadd__ with None leaves unchanged - u2 = UsageDetails(input_token_count=2) - u2 += None - assert u2.input_token_count == 2 - # wrong type raises - with raises(ValueError): - _ = u + 42 # type: ignore[arg-type] - with raises(ValueError): - u += 42 # type: ignore[arg-type] + # add_usage_details with None returns the non-None value + v = add_usage_details(u, None) + assert v == u + # add_usage_details with None on left + v2 = add_usage_details(None, u) + assert v2 == u + # TypedDict doesn't support + operator, use add_usage_details # region UserInputRequest and Response @@ -514,28 +485,29 @@ def test_usage_details_add_with_none_and_type_errors(): def test_function_approval_request_and_response_creation(): """Test creating a FunctionApprovalRequestContent and producing a response.""" - fc = FunctionCallContent(call_id="call-1", name="do_something", arguments={"a": 1}) - req = 
FunctionApprovalRequestContent(id="req-1", function_call=fc) + fc = Content.from_function_call(call_id="call-1", name="do_something", arguments={"a": 1}) + req = Content.from_function_approval_request(id="req-1", function_call=fc) assert req.type == "function_approval_request" assert req.function_call == fc assert req.id == "req-1" - assert isinstance(req, BaseContent) + assert isinstance(req, Content) - resp = req.create_response(True) + resp = req.to_function_approval_response(True) - assert isinstance(resp, FunctionApprovalResponseContent) + assert isinstance(resp, Content) + assert resp.type == "function_approval_response" assert resp.approved is True assert resp.function_call == fc assert resp.id == "req-1" def test_function_approval_serialization_roundtrip(): - fc = FunctionCallContent(call_id="c2", name="f", arguments='{"x":1}') - req = FunctionApprovalRequestContent(id="id-2", function_call=fc, additional_properties={"meta": 1}) + fc = Content.from_function_call(call_id="c2", name="f", arguments='{"x":1}') + req = Content.from_function_approval_request(id="id-2", function_call=fc, additional_properties={"meta": 1}) dumped = req.to_dict() - loaded = FunctionApprovalRequestContent.from_dict(dumped) + loaded = Content.from_dict(dumped) # Test that the basic properties match assert loaded.id == req.id @@ -545,15 +517,17 @@ def test_function_approval_serialization_roundtrip(): assert loaded.function_call.arguments == req.function_call.arguments # Skip the BaseModel validation test since we're no longer using Pydantic - # The Contents union will need to be handled differently when we fully migrate + # The Content union will need to be handled differently when we fully migrate def test_function_approval_accepts_mcp_call(): """Ensure FunctionApprovalRequestContent supports MCP server tool calls.""" - mcp_call = MCPServerToolCallContent(call_id="c-mcp", tool_name="tool", server_name="srv", arguments={"x": 1}) - req = FunctionApprovalRequestContent(id="req-mcp", 
function_call=mcp_call) + mcp_call = Content.from_mcp_server_tool_call( + call_id="c-mcp", tool_name="tool", server_name="srv", arguments={"x": 1} + ) + req = Content.from_function_approval_request(id="req-mcp", function_call=mcp_call) - assert isinstance(req.function_call, MCPServerToolCallContent) + assert isinstance(req.function_call, Content) assert req.function_call.call_id == "c-mcp" @@ -561,47 +535,21 @@ def test_function_approval_accepts_mcp_call(): @mark.parametrize( - "content_type, args", + "args", [ - (TextContent, {"text": "Hello, world!"}), - (DataContent, {"data": b"Hello, world!", "media_type": "text/plain"}), - (UriContent, {"uri": "http://example.com", "media_type": "text/html"}), - (FunctionCallContent, {"call_id": "1", "name": "example_function", "arguments": {}}), - (FunctionResultContent, {"call_id": "1", "result": {}}), - (HostedFileContent, {"file_id": "file-123"}), - (HostedVectorStoreContent, {"vector_store_id": "vs-789"}), + {"type": "text", "text": "Hello, world!"}, + {"type": "uri", "uri": "http://example.com", "media_type": "text/html"}, + {"type": "function_call", "call_id": "1", "name": "example_function", "arguments": {}}, + {"type": "function_result", "call_id": "1", "result": {}}, + {"type": "file", "file_id": "file-123"}, + {"type": "vector_store", "vector_store_id": "vs-789"}, ], ) -def test_ai_content_serialization(content_type: type[BaseContent], args: dict): - content = content_type(**args) +def test_ai_content_serialization(args: dict): + content = Content(**args) serialized = content.to_dict() - deserialized = content_type.from_dict(serialized) - # Note: Since we're no longer using Pydantic, we can't do direct equality comparison - # Instead, let's check that the deserialized object has the same attributes - - # Special handling for DataContent which doesn't expose the original 'data' parameter - if content_type == DataContent and "data" in args: - # For DataContent created with data, check uri and media_type instead - 
assert hasattr(deserialized, "uri") - assert hasattr(deserialized, "media_type") - assert deserialized.media_type == args["media_type"] # type: ignore - # Skip checking the 'data' attribute since it's converted to uri - for key, value in args.items(): - if key != "data": # Skip the 'data' key for DataContent - assert getattr(deserialized, key) == value - else: - # Normal attribute checking for other content types - for key, value in args.items(): - if value: - assert getattr(deserialized, key) == value - - # For now, skip the TestModel validation since it still uses Pydantic - # This would need to be updated when we migrate more classes - # class TestModel(BaseModel): - # content: Contents - # - # test_item = TestModel.model_validate({"content": serialized}) - # assert isinstance(test_item.content, content_type) + deserialized = Content.from_dict(serialized) + assert content == deserialized # region ChatMessage @@ -615,26 +563,26 @@ def test_chat_message_text(): # Check the type and content assert message.role == Role.USER assert len(message.contents) == 1 - assert isinstance(message.contents[0], TextContent) + assert message.contents[0].type == "text" assert message.contents[0].text == "Hello, how are you?" assert message.text == "Hello, how are you?" 
# Ensure the instance is of type BaseContent - assert isinstance(message.contents[0], BaseContent) + assert isinstance(message.contents[0], Content) def test_chat_message_contents(): """Test the ChatMessage class to ensure it initializes correctly with contents.""" # Create a ChatMessage with a role and multiple contents - content1 = TextContent("Hello, how are you?") - content2 = TextContent("I'm fine, thank you!") + content1 = Content.from_text("Hello, how are you?") + content2 = Content.from_text("I'm fine, thank you!") message = ChatMessage(role="user", contents=[content1, content2]) # Check the type and content assert message.role == Role.USER assert len(message.contents) == 2 - assert isinstance(message.contents[0], TextContent) - assert isinstance(message.contents[1], TextContent) + assert message.contents[0].type == "text" + assert message.contents[1].type == "text" assert message.contents[0].text == "Hello, how are you?" assert message.contents[1].text == "I'm fine, thank you!" assert message.text == "Hello, how are you? I'm fine, thank you!" @@ -829,22 +777,22 @@ class MySchema(BaseModel): def test_chat_response_update(): """Test the ChatResponseUpdate class to ensure it initializes correctly with a message.""" # Create a ChatMessage - message = TextContent(text="I'm doing well, thank you!") + message = Content.from_text(text="I'm doing well, thank you!") # Create a ChatResponseUpdate with the message response_update = ChatResponseUpdate(contents=[message]) # Check the type and content assert response_update.contents[0].text == "I'm doing well, thank you!" - assert isinstance(response_update.contents[0], TextContent) + assert response_update.contents[0].type == "text" assert response_update.text == "I'm doing well, thank you!" 
def test_chat_response_updates_to_chat_response_one(): """Test converting ChatResponseUpdate to ChatResponse.""" # Create a ChatMessage - message1 = TextContent("I'm doing well, ") - message2 = TextContent("thank you!") + message1 = Content.from_text("I'm doing well, ") + message2 = Content.from_text("thank you!") # Create a ChatResponseUpdate with the message response_updates = [ @@ -866,8 +814,8 @@ def test_chat_response_updates_to_chat_response_one(): def test_chat_response_updates_to_chat_response_two(): """Test converting ChatResponseUpdate to ChatResponse.""" # Create a ChatMessage - message1 = TextContent("I'm doing well, ") - message2 = TextContent("thank you!") + message1 = Content.from_text("I'm doing well, ") + message2 = Content.from_text("thank you!") # Create a ChatResponseUpdate with the message response_updates = [ @@ -890,13 +838,13 @@ def test_chat_response_updates_to_chat_response_two(): def test_chat_response_updates_to_chat_response_multiple(): """Test converting ChatResponseUpdate to ChatResponse.""" # Create a ChatMessage - message1 = TextContent("I'm doing well, ") - message2 = TextContent("thank you!") + message1 = Content.from_text("I'm doing well, ") + message2 = Content.from_text("thank you!") # Create a ChatResponseUpdate with the message response_updates = [ ChatResponseUpdate(text=message1, message_id="1"), - ChatResponseUpdate(contents=[TextReasoningContent(text="Additional context")], message_id="1"), + ChatResponseUpdate(contents=[Content.from_text_reasoning(text="Additional context")], message_id="1"), ChatResponseUpdate(text=message2, message_id="1"), ] @@ -914,15 +862,15 @@ def test_chat_response_updates_to_chat_response_multiple(): def test_chat_response_updates_to_chat_response_multiple_multiple(): """Test converting ChatResponseUpdate to ChatResponse.""" # Create a ChatMessage - message1 = TextContent("I'm doing well, ", raw_representation="I'm doing well, ") - message2 = TextContent("thank you!") + message1 = 
Content.from_text("I'm doing well, ", raw_representation="I'm doing well, ") + message2 = Content.from_text("thank you!") # Create a ChatResponseUpdate with the message response_updates = [ ChatResponseUpdate(text=message1, message_id="1"), ChatResponseUpdate(text=message2, message_id="1"), - ChatResponseUpdate(contents=[TextReasoningContent(text="Additional context")], message_id="1"), - ChatResponseUpdate(contents=[TextContent(text="More context")], message_id="1"), + ChatResponseUpdate(contents=[Content.from_text_reasoning(text="Additional context")], message_id="1"), + ChatResponseUpdate(contents=[Content.from_text(text="More context")], message_id="1"), ChatResponseUpdate(text="Final part", message_id="1"), ] @@ -936,11 +884,11 @@ def test_chat_response_updates_to_chat_response_multiple_multiple(): assert chat_response.messages[0].contents[0].raw_representation is not None assert len(chat_response.messages[0].contents) == 3 - assert isinstance(chat_response.messages[0].contents[0], TextContent) + assert chat_response.messages[0].contents[0].type == "text" assert chat_response.messages[0].contents[0].text == "I'm doing well, thank you!" - assert isinstance(chat_response.messages[0].contents[1], TextReasoningContent) + assert chat_response.messages[0].contents[1].type == "text_reasoning" assert chat_response.messages[0].contents[1].text == "Additional context" - assert isinstance(chat_response.messages[0].contents[2], TextContent) + assert chat_response.messages[0].contents[2].type == "text" assert chat_response.messages[0].contents[2].text == "More contextFinal part" assert chat_response.text == "I'm doing well, thank you! 
More contextFinal part" @@ -1139,8 +1087,8 @@ def chat_message() -> ChatMessage: @fixture -def text_content() -> TextContent: - return TextContent(text="Test content") +def text_content() -> Content: + return Content.from_text(text="Test content") @fixture @@ -1149,7 +1097,7 @@ def agent_response(chat_message: ChatMessage) -> AgentResponse: @fixture -def agent_response_update(text_content: TextContent) -> AgentResponseUpdate: +def agent_response_update(text_content: Content) -> AgentResponseUpdate: return AgentResponseUpdate(role=Role.ASSISTANT, contents=[text_content]) @@ -1197,7 +1145,7 @@ def test_agent_run_response_str_method(chat_message: ChatMessage) -> None: # region AgentResponseUpdate -def test_agent_run_response_update_init_content_list(text_content: TextContent) -> None: +def test_agent_run_response_update_init_content_list(text_content: Content) -> None: update = AgentResponseUpdate(contents=[text_content, text_content]) assert len(update.contents) == 2 assert update.contents[0] == text_content @@ -1208,7 +1156,7 @@ def test_agent_run_response_update_init_none_content() -> None: assert update.contents == [] -def test_agent_run_response_update_text_property(text_content: TextContent) -> None: +def test_agent_run_response_update_text_property(text_content: Content) -> None: update = AgentResponseUpdate(contents=[text_content, text_content]) assert update.text == "Test contentTest content" @@ -1218,7 +1166,7 @@ def test_agent_run_response_update_text_property_empty() -> None: assert update.text == "" -def test_agent_run_response_update_str_method(text_content: TextContent) -> None: +def test_agent_run_response_update_str_method(text_content: Content) -> None: update = AgentResponseUpdate(contents=[text_content]) assert str(update) == "Test content" @@ -1228,7 +1176,7 @@ def test_agent_run_response_update_created_at() -> None: # Test with a properly formatted UTC timestamp utc_timestamp = "2024-12-01T00:31:30.000000Z" update = AgentResponseUpdate( - 
contents=[TextContent(text="test")], + contents=[Content.from_text(text="test")], role=Role.ASSISTANT, created_at=utc_timestamp, ) @@ -1239,7 +1187,7 @@ def test_agent_run_response_update_created_at() -> None: now_utc = datetime.now(tz=timezone.utc) formatted_utc = now_utc.strftime("%Y-%m-%dT%H:%M:%S.%fZ") update_with_now = AgentResponseUpdate( - contents=[TextContent(text="test")], + contents=[Content.from_text(text="test")], role=Role.ASSISTANT, created_at=formatted_utc, ) @@ -1273,71 +1221,77 @@ def test_agent_run_response_created_at() -> None: def test_error_content_str(): - e1 = ErrorContent(message="Oops", error_code="E1") + e1 = Content.from_error(message="Oops", error_code="E1") assert str(e1) == "Error E1: Oops" - e2 = ErrorContent(message="Oops") + e2 = Content.from_error(message="Oops") assert str(e2) == "Oops" - e3 = ErrorContent() + e3 = Content.from_error() assert str(e3) == "Unknown error" -# region Annotations +# region Annotation def test_annotations_models_and_roundtrip(): - span = TextSpanRegion(start_index=0, end_index=5) - cit = CitationAnnotation(title="Doc", url="http://example.com", snippet="Snippet", annotated_regions=[span]) + span = TextSpanRegion(type="text_span", start_index=0, end_index=5) + cit = Annotation( + type="citation", title="Doc", url="http://example.com", snippet="Snippet", annotated_regions=[span] + ) # Attach to content - content = TextContent(text="hello", additional_properties={"v": 1}) + content = Content.from_text(text="hello", additional_properties={"v": 1}) content.annotations = [cit] dumped = content.to_dict() - loaded = TextContent.from_dict(dumped) + loaded = Content.from_dict(dumped) assert isinstance(loaded.annotations, list) assert len(loaded.annotations) == 1 - # After migration from Pydantic, annotations should be properly reconstructed as objects - assert isinstance(loaded.annotations[0], CitationAnnotation) + # After migration from Pydantic, annotations are now TypedDicts (dicts at runtime) + assert 
isinstance(loaded.annotations[0], dict) # Check the annotation properties loaded_cit = loaded.annotations[0] - assert loaded_cit.type == "citation" - assert loaded_cit.title == "Doc" - assert loaded_cit.url == "http://example.com" - assert loaded_cit.snippet == "Snippet" + assert loaded_cit["type"] == "citation" + assert loaded_cit["title"] == "Doc" + assert loaded_cit["url"] == "http://example.com" + assert loaded_cit["snippet"] == "Snippet" # Check the annotated_regions - assert isinstance(loaded_cit.annotated_regions, list) - assert len(loaded_cit.annotated_regions) == 1 - assert isinstance(loaded_cit.annotated_regions[0], TextSpanRegion) - assert loaded_cit.annotated_regions[0].type == "text_span" - assert loaded_cit.annotated_regions[0].start_index == 0 - assert loaded_cit.annotated_regions[0].end_index == 5 + assert isinstance(loaded_cit["annotated_regions"], list) + assert len(loaded_cit["annotated_regions"]) == 1 + assert isinstance(loaded_cit["annotated_regions"][0], dict) + assert loaded_cit["annotated_regions"][0]["type"] == "text_span" + assert loaded_cit["annotated_regions"][0]["start_index"] == 0 + assert loaded_cit["annotated_regions"][0]["end_index"] == 5 def test_function_call_merge_in_process_update_and_usage_aggregation(): # Two function call chunks with same call_id should merge - u1 = ChatResponseUpdate(contents=[FunctionCallContent(call_id="c1", name="f", arguments="{")], message_id="m") - u2 = ChatResponseUpdate(contents=[FunctionCallContent(call_id="c1", name="f", arguments="}")], message_id="m") + u1 = ChatResponseUpdate( + contents=[Content.from_function_call(call_id="c1", name="f", arguments="{")], message_id="m" + ) + u2 = ChatResponseUpdate( + contents=[Content.from_function_call(call_id="c1", name="f", arguments="}")], message_id="m" + ) # plus usage - u3 = ChatResponseUpdate(contents=[UsageContent(UsageDetails(input_token_count=1, output_token_count=2))]) + u3 = 
ChatResponseUpdate(contents=[Content.from_usage(UsageDetails(input_token_count=1, output_token_count=2))]) resp = ChatResponse.from_chat_response_updates([u1, u2, u3]) assert len(resp.messages) == 1 last_contents = resp.messages[0].contents - assert any(isinstance(c, FunctionCallContent) for c in last_contents) - fcs = [c for c in last_contents if isinstance(c, FunctionCallContent)] + assert any(c.type == "function_call" for c in last_contents) + fcs = [c for c in last_contents if c.type == "function_call"] assert len(fcs) == 1 assert fcs[0].arguments == "{}" assert resp.usage_details is not None - assert resp.usage_details.input_token_count == 1 - assert resp.usage_details.output_token_count == 2 + assert resp.usage_details["input_token_count"] == 1 + assert resp.usage_details["output_token_count"] == 2 def test_function_call_incompatible_ids_are_not_merged(): - u1 = ChatResponseUpdate(contents=[FunctionCallContent(call_id="a", name="f", arguments="x")], message_id="m") - u2 = ChatResponseUpdate(contents=[FunctionCallContent(call_id="b", name="f", arguments="y")], message_id="m") + u1 = ChatResponseUpdate(contents=[Content.from_function_call(call_id="a", name="f", arguments="x")], message_id="m") + u2 = ChatResponseUpdate(contents=[Content.from_function_call(call_id="b", name="f", arguments="y")], message_id="m") resp = ChatResponse.from_chat_response_updates([u1, u2]) - fcs = [c for c in resp.messages[0].contents if isinstance(c, FunctionCallContent)] + fcs = [c for c in resp.messages[0].contents if c.type == "function_call"] assert len(fcs) == 2 @@ -1379,13 +1333,13 @@ def test_response_update_propagates_fields_and_metadata(): def test_text_coalescing_preserves_first_properties(): - t1 = TextContent("A", raw_representation={"r": 1}, additional_properties={"p": 1}) - t2 = TextContent("B") + t1 = Content.from_text("A", raw_representation={"r": 1}, additional_properties={"p": 1}) + t2 = Content.from_text("B") upd1 = ChatResponseUpdate(text=t1, message_id="x") upd2 
= ChatResponseUpdate(text=t2, message_id="x") resp = ChatResponse.from_chat_response_updates([upd1, upd2]) # After coalescing there should be a single TextContent with merged text and preserved props from first - items = [c for c in resp.messages[0].contents if isinstance(c, TextContent)] + items = [c for c in resp.messages[0].contents if c.type == "text"] assert len(items) >= 1 assert items[0].text == "AB" assert items[0].raw_representation == {"r": 1} @@ -1393,9 +1347,9 @@ def test_text_coalescing_preserves_first_properties(): def test_function_call_content_parse_numeric_or_list(): - c_num = FunctionCallContent(call_id="1", name="f", arguments="123") + c_num = Content.from_function_call(call_id="1", name="f", arguments="123") assert c_num.parse_arguments() == {"raw": 123} - c_list = FunctionCallContent(call_id="1", name="f", arguments="[1,2]") + c_list = Content.from_function_call(call_id="1", name="f", arguments="[1,2]") assert c_list.parse_arguments() == {"raw": [1, 2]} @@ -1413,8 +1367,8 @@ def agent_run_response_async() -> AgentResponse: async def test_agent_run_response_from_async_generator(): async def gen(): - yield AgentResponseUpdate(contents=[TextContent("A")]) - yield AgentResponseUpdate(contents=[TextContent("B")]) + yield AgentResponseUpdate(contents=[Content.from_text("A")]) + yield AgentResponseUpdate(contents=[Content.from_text("B")]) r = await AgentResponse.from_agent_response_generator(gen()) assert r.text == "AB" @@ -1427,67 +1381,65 @@ def test_text_content_add_comprehensive_coverage(): """Test TextContent __add__ method with various combinations to improve coverage.""" # Test with None raw_representation - t1 = TextContent("Hello", raw_representation=None, annotations=None) - t2 = TextContent(" World", raw_representation=None, annotations=None) + t1 = Content.from_text("Hello", raw_representation=None, annotations=None) + t2 = Content.from_text(" World", raw_representation=None, annotations=None) result = t1 + t2 assert result.text == "Hello 
World" assert result.raw_representation is None assert result.annotations is None # Test first has raw_representation, second has None - t1 = TextContent("Hello", raw_representation="raw1", annotations=None) - t2 = TextContent(" World", raw_representation=None, annotations=None) + t1 = Content.from_text("Hello", raw_representation="raw1", annotations=None) + t2 = Content.from_text(" World", raw_representation=None, annotations=None) result = t1 + t2 assert result.text == "Hello World" assert result.raw_representation == "raw1" # Test first has None, second has raw_representation - t1 = TextContent("Hello", raw_representation=None, annotations=None) - t2 = TextContent(" World", raw_representation="raw2", annotations=None) + t1 = Content.from_text("Hello", raw_representation=None, annotations=None) + t2 = Content.from_text(" World", raw_representation="raw2", annotations=None) result = t1 + t2 assert result.text == "Hello World" assert result.raw_representation == "raw2" # Test both have raw_representation (non-list) - t1 = TextContent("Hello", raw_representation="raw1", annotations=None) - t2 = TextContent(" World", raw_representation="raw2", annotations=None) + t1 = Content.from_text("Hello", raw_representation="raw1", annotations=None) + t2 = Content.from_text(" World", raw_representation="raw2", annotations=None) result = t1 + t2 assert result.text == "Hello World" assert result.raw_representation == ["raw1", "raw2"] # Test first has list raw_representation, second has single - t1 = TextContent("Hello", raw_representation=["raw1", "raw2"], annotations=None) - t2 = TextContent(" World", raw_representation="raw3", annotations=None) + t1 = Content.from_text("Hello", raw_representation=["raw1", "raw2"], annotations=None) + t2 = Content.from_text(" World", raw_representation="raw3", annotations=None) result = t1 + t2 assert result.text == "Hello World" assert result.raw_representation == ["raw1", "raw2", "raw3"] # Test both have list raw_representation - t1 = 
TextContent("Hello", raw_representation=["raw1", "raw2"], annotations=None) - t2 = TextContent(" World", raw_representation=["raw3", "raw4"], annotations=None) + t1 = Content.from_text("Hello", raw_representation=["raw1", "raw2"], annotations=None) + t2 = Content.from_text(" World", raw_representation=["raw3", "raw4"], annotations=None) result = t1 + t2 assert result.text == "Hello World" assert result.raw_representation == ["raw1", "raw2", "raw3", "raw4"] # Test first has single raw_representation, second has list - t1 = TextContent("Hello", raw_representation="raw1", annotations=None) - t2 = TextContent(" World", raw_representation=["raw2", "raw3"], annotations=None) + t1 = Content.from_text("Hello", raw_representation="raw1", annotations=None) + t2 = Content.from_text(" World", raw_representation=["raw2", "raw3"], annotations=None) result = t1 + t2 assert result.text == "Hello World" assert result.raw_representation == ["raw1", "raw2", "raw3"] def test_text_content_iadd_coverage(): - """Test TextContent __iadd__ method for better coverage.""" + """Test TextContent += operator for better coverage.""" - t1 = TextContent("Hello", raw_representation="raw1", additional_properties={"key1": "val1"}) - t2 = TextContent(" World", raw_representation="raw2", additional_properties={"key2": "val2"}) + t1 = Content.from_text("Hello", raw_representation="raw1", additional_properties={"key1": "val1"}) + t2 = Content.from_text(" World", raw_representation="raw2", additional_properties={"key2": "val2"}) - original_id = id(t1) t1 += t2 - # Should modify in place - assert id(t1) == original_id + # Content doesn't implement __iadd__, so += creates a new object via __add__ assert t1.text == "Hello World" assert t1.raw_representation == ["raw1", "raw2"] assert t1.additional_properties == {"key1": "val1", "key2": "val2"} @@ -1496,23 +1448,22 @@ def test_text_content_iadd_coverage(): def test_text_reasoning_content_add_coverage(): """Test TextReasoningContent __add__ method for better 
coverage.""" - t1 = TextReasoningContent("Thinking 1") - t2 = TextReasoningContent(" Thinking 2") + t1 = Content.from_text_reasoning(text="Thinking 1") + t2 = Content.from_text_reasoning(text=" Thinking 2") result = t1 + t2 assert result.text == "Thinking 1 Thinking 2" def test_text_reasoning_content_iadd_coverage(): - """Test TextReasoningContent __iadd__ method for better coverage.""" + """Test TextReasoningContent += operator for better coverage.""" - t1 = TextReasoningContent("Thinking 1") - t2 = TextReasoningContent(" Thinking 2") + t1 = Content.from_text_reasoning(text="Thinking 1") + t2 = Content.from_text_reasoning(text=" Thinking 2") - original_id = id(t1) t1 += t2 - assert id(t1) == original_id + # Content doesn't implement __iadd__, so += creates a new object via __add__ assert t1.text == "Thinking 1 Thinking 2" @@ -1520,49 +1471,46 @@ def test_comprehensive_to_dict_exclude_options(): """Test to_dict methods with various exclude options for better coverage.""" # Test TextContent with exclude_none - text_content = TextContent("Hello", raw_representation=None, additional_properties={"prop": "val"}) + text_content = Content.from_text("Hello", raw_representation=None, additional_properties={"prop": "val"}) text_dict = text_content.to_dict(exclude_none=True) assert "raw_representation" not in text_dict - assert text_dict["prop"] == "val" + assert text_dict["additional_properties"]["prop"] == "val" # Test with custom exclude set text_dict_exclude = text_content.to_dict(exclude={"additional_properties"}) assert "additional_properties" not in text_dict_exclude assert "text" in text_dict_exclude - # Test UsageDetails with additional counts + # Test UsageDetails - it's a TypedDict now, not a class with to_dict usage = UsageDetails(input_token_count=5, custom_count=10) - usage_dict = usage.to_dict() - assert usage_dict["input_token_count"] == 5 - assert usage_dict["custom_count"] == 10 + assert usage["input_token_count"] == 5 + assert usage["custom_count"] == 10 - 
# Test UsageDetails exclude_none - usage_none = UsageDetails(input_token_count=5, output_token_count=None) - usage_dict_no_none = usage_none.to_dict(exclude_none=True) - assert "output_token_count" not in usage_dict_no_none - assert usage_dict_no_none["input_token_count"] == 5 + # Test UsageDetails exclude_none behavior isn't applicable to TypedDict + # TypedDict doesn't have a to_dict method def test_usage_details_iadd_edge_cases(): - """Test UsageDetails __iadd__ with edge cases for better coverage.""" + """Test UsageDetails addition with edge cases for better coverage.""" + from agent_framework._types import add_usage_details # Test with None values u1 = UsageDetails(input_token_count=None, output_token_count=5, custom1=10) u2 = UsageDetails(input_token_count=3, output_token_count=None, custom2=20) - u1 += u2 - assert u1.input_token_count == 3 - assert u1.output_token_count == 5 - assert u1.additional_counts["custom1"] == 10 - assert u1.additional_counts["custom2"] == 20 + result = add_usage_details(u1, u2) + assert result["input_token_count"] == 3 + assert result["output_token_count"] == 5 + assert result.get("custom1") == 10 + assert result.get("custom2") == 20 # Test merging additional counts u3 = UsageDetails(input_token_count=1, shared_count=5) u4 = UsageDetails(input_token_count=2, shared_count=15) - u3 += u4 - assert u3.input_token_count == 3 - assert u3.additional_counts["shared_count"] == 20 + result2 = add_usage_details(u3, u4) + assert result2["input_token_count"] == 3 + assert result2.get("shared_count") == 20 def test_chat_message_from_dict_with_mixed_content(): @@ -1579,9 +1527,9 @@ def test_chat_message_from_dict_with_mixed_content(): message = ChatMessage.from_dict(message_data) assert len(message.contents) == 3 # Unknown type is ignored - assert isinstance(message.contents[0], TextContent) - assert isinstance(message.contents[1], FunctionCallContent) - assert isinstance(message.contents[2], FunctionResultContent) + assert 
message.contents[0].type == "text" + assert message.contents[1].type == "function_call" + assert message.contents[2].type == "function_result" # Test round-trip message_dict = message.to_dict() @@ -1590,7 +1538,7 @@ def test_chat_message_from_dict_with_mixed_content(): def test_text_content_add_type_error(): """Test TextContent __add__ raises TypeError for incompatible types.""" - t1 = TextContent("Hello") + t1 = Content.from_text("Hello") with raises(TypeError, match="Incompatible type"): t1 + "not a TextContent" @@ -1601,12 +1549,13 @@ def test_comprehensive_serialization_methods(): # Test TextContent with all fields text_data = { + "type": "text", "text": "Hello world", "raw_representation": {"key": "value"}, - "prop": "val", + "additional_properties": {"prop": "val"}, "annotations": None, } - text_content = TextContent.from_dict(text_data) + text_content = Content.from_dict(text_data) assert text_content.text == "Hello world" assert text_content.raw_representation == {"key": "value"} assert text_content.additional_properties == {"prop": "val"} @@ -1614,7 +1563,7 @@ def test_comprehensive_serialization_methods(): # Test round-trip text_dict = text_content.to_dict() assert text_dict["text"] == "Hello world" - assert text_dict["prop"] == "val" + assert text_dict["additional_properties"] == {"prop": "val"} # Note: raw_representation is always excluded from to_dict() output # Test with exclude_none @@ -1622,8 +1571,13 @@ def test_comprehensive_serialization_methods(): assert "annotations" not in text_dict_no_none # Test FunctionResultContent - result_data = {"call_id": "call123", "result": "success", "additional_properties": {"meta": "data"}} - result_content = FunctionResultContent.from_dict(result_data) + result_data = { + "type": "function_result", + "call_id": "call123", + "result": "success", + "additional_properties": {"meta": "data"}, + } + result_content = Content.from_dict(result_data) assert result_content.call_id == "call123" assert result_content.result 
== "success" @@ -1633,9 +1587,9 @@ def test_chat_message_complex_content_serialization(): # Create a message with multiple content types contents = [ - TextContent("Hello"), - FunctionCallContent(call_id="call1", name="func", arguments={"arg": "val"}), - FunctionResultContent(call_id="call1", result="success"), + Content.from_text("Hello"), + Content.from_function_call(call_id="call1", name="func", arguments={"arg": "val"}), + Content.from_function_result(call_id="call1", result="success"), ] message = ChatMessage(role=Role.ASSISTANT, contents=contents) @@ -1650,9 +1604,9 @@ def test_chat_message_complex_content_serialization(): # Test from_dict round-trip reconstructed = ChatMessage.from_dict(message_dict) assert len(reconstructed.contents) == 3 - assert isinstance(reconstructed.contents[0], TextContent) - assert isinstance(reconstructed.contents[1], FunctionCallContent) - assert isinstance(reconstructed.contents[2], FunctionResultContent) + assert reconstructed.contents[0].type == "text" + assert reconstructed.contents[1].type == "function_call" + assert reconstructed.contents[2].type == "function_result" def test_usage_content_serialization_with_details(): @@ -1661,7 +1615,7 @@ def test_usage_content_serialization_with_details(): # Test from_dict with details as dict usage_data = { "type": "usage", - "details": { + "usage_details": { "type": "usage_details", "input_token_count": 10, "output_token_count": 20, @@ -1669,15 +1623,15 @@ def test_usage_content_serialization_with_details(): "custom_count": 5, }, } - usage_content = UsageContent.from_dict(usage_data) - assert isinstance(usage_content.details, UsageDetails) - assert usage_content.details.input_token_count == 10 - assert usage_content.details.additional_counts["custom_count"] == 5 + usage_content = Content(**usage_data) + assert isinstance(usage_content.usage_details, dict) + assert usage_content.usage_details["input_token_count"] == 10 + assert usage_content.usage_details["custom_count"] == 5 # Custom 
fields go directly in UsageDetails # Test to_dict with UsageDetails object usage_dict = usage_content.to_dict() - assert isinstance(usage_dict["details"], dict) - assert usage_dict["details"]["input_token_count"] == 10 + assert isinstance(usage_dict["usage_details"], dict) + assert usage_dict["usage_details"]["input_token_count"] == 10 def test_function_approval_response_content_serialization(): @@ -1695,8 +1649,8 @@ def test_function_approval_response_content_serialization(): "arguments": {"param": "value"}, }, } - response_content = FunctionApprovalResponseContent.from_dict(response_data) - assert isinstance(response_content.function_call, FunctionCallContent) + response_content = Content.from_dict(response_data) + assert response_content.function_call.type == "function_call" assert response_content.function_call.call_id == "call123" # Test to_dict with FunctionCallContent object @@ -1728,7 +1682,7 @@ def test_chat_response_complex_serialization(): assert len(response.messages) == 2 assert isinstance(response.messages[0], ChatMessage) assert isinstance(response.finish_reason, FinishReason) - assert isinstance(response.usage_details, UsageDetails) + assert isinstance(response.usage_details, dict) assert response.model_id == "gpt-4" # Should be stored as model_id # Test to_dict with complex objects @@ -1748,10 +1702,10 @@ def test_chat_response_update_all_content_types(): {"type": "text", "text": "Hello"}, {"type": "data", "data": b"base64data", "media_type": "text/plain"}, {"type": "uri", "uri": "http://example.com", "media_type": "text/html"}, - {"type": "error", "error": "An error occurred"}, + {"type": "error", "message": "An error occurred"}, {"type": "function_call", "call_id": "call1", "name": "func", "arguments": {}}, {"type": "function_result", "call_id": "call1", "result": "success"}, - {"type": "usage", "details": {"type": "usage_details", "input_token_count": 1}}, + {"type": "usage", "usage_details": {"input_token_count": 1}}, {"type": "hosted_file", 
"file_id": "file123"}, {"type": "hosted_vector_store", "vector_store_id": "vs123"}, { @@ -1771,18 +1725,18 @@ def test_chat_response_update_all_content_types(): update = ChatResponseUpdate.from_dict(update_data) assert len(update.contents) == 12 # unknown_type is skipped with warning - assert isinstance(update.contents[0], TextContent) - assert isinstance(update.contents[1], DataContent) - assert isinstance(update.contents[2], UriContent) - assert isinstance(update.contents[3], ErrorContent) - assert isinstance(update.contents[4], FunctionCallContent) - assert isinstance(update.contents[5], FunctionResultContent) - assert isinstance(update.contents[6], UsageContent) - assert isinstance(update.contents[7], HostedFileContent) - assert isinstance(update.contents[8], HostedVectorStoreContent) - assert isinstance(update.contents[9], FunctionApprovalRequestContent) - assert isinstance(update.contents[10], FunctionApprovalResponseContent) - assert isinstance(update.contents[11], TextReasoningContent) + assert update.contents[0].type == "text" + assert update.contents[1].type == "data" + assert update.contents[2].type == "uri" + assert update.contents[3].type == "error" + assert update.contents[4].type == "function_call" + assert update.contents[5].type == "function_result" + assert update.contents[6].type == "usage" + assert update.contents[7].type == "hosted_file" + assert update.contents[8].type == "hosted_vector_store" + assert update.contents[9].type == "function_approval_request" + assert update.contents[10].type == "function_approval_response" + assert update.contents[11].type == "text_reasoning" def test_agent_run_response_complex_serialization(): @@ -1804,7 +1758,7 @@ def test_agent_run_response_complex_serialization(): response = AgentResponse.from_dict(response_data) assert len(response.messages) == 2 assert isinstance(response.messages[0], ChatMessage) - assert isinstance(response.usage_details, UsageDetails) + assert isinstance(response.usage_details, dict) # 
Test to_dict response_dict = response.to_dict() @@ -1821,10 +1775,10 @@ def test_agent_run_response_update_all_content_types(): {"type": "text", "text": "Hello"}, {"type": "data", "data": b"base64data", "media_type": "text/plain"}, {"type": "uri", "uri": "http://example.com", "media_type": "text/html"}, - {"type": "error", "error": "An error occurred"}, + {"type": "error", "message": "An error occurred"}, {"type": "function_call", "call_id": "call1", "name": "func", "arguments": {}}, {"type": "function_result", "call_id": "call1", "result": "success"}, - {"type": "usage", "details": {"type": "usage_details", "input_token_count": 1}}, + {"type": "usage", "usage_details": {"input_token_count": 1}}, {"type": "hosted_file", "file_id": "file123"}, {"type": "hosted_vector_store", "vector_store_id": "vs123"}, { @@ -1868,7 +1822,7 @@ def test_agent_run_response_update_all_content_types(): "content_class,init_kwargs", [ pytest.param( - TextContent, + Content, { "type": "text", "text": "Hello world", @@ -1877,7 +1831,7 @@ def test_agent_run_response_update_all_content_types(): id="text_content", ), pytest.param( - TextReasoningContent, + Content, { "type": "text_reasoning", "text": "Reasoning text", @@ -1886,7 +1840,7 @@ def test_agent_run_response_update_all_content_types(): id="text_reasoning_content", ), pytest.param( - DataContent, + Content, { "type": "data", "uri": "data:text/plain;base64,dGVzdCBkYXRh", @@ -1894,7 +1848,7 @@ def test_agent_run_response_update_all_content_types(): id="data_content_with_uri", ), pytest.param( - DataContent, + Content, { "type": "data", "data": b"test data", @@ -1903,7 +1857,7 @@ def test_agent_run_response_update_all_content_types(): id="data_content_with_bytes", ), pytest.param( - UriContent, + Content, { "type": "uri", "uri": "http://example.com", @@ -1912,12 +1866,12 @@ def test_agent_run_response_update_all_content_types(): id="uri_content", ), pytest.param( - HostedFileContent, + Content, {"type": "hosted_file", "file_id": 
"file-123"}, id="hosted_file_content", ), pytest.param( - HostedVectorStoreContent, + Content, { "type": "hosted_vector_store", "vector_store_id": "vs-789", @@ -1925,7 +1879,7 @@ def test_agent_run_response_update_all_content_types(): id="hosted_vector_store_content", ), pytest.param( - FunctionCallContent, + Content, { "type": "function_call", "call_id": "call-1", @@ -1935,7 +1889,7 @@ def test_agent_run_response_update_all_content_types(): id="function_call_content", ), pytest.param( - FunctionResultContent, + Content, { "type": "function_result", "call_id": "call-1", @@ -1944,7 +1898,7 @@ def test_agent_run_response_update_all_content_types(): id="function_result_content", ), pytest.param( - ErrorContent, + Content, { "type": "error", "message": "Error occurred", @@ -1953,10 +1907,10 @@ def test_agent_run_response_update_all_content_types(): id="error_content", ), pytest.param( - UsageContent, + Content, { "type": "usage", - "details": { + "usage_details": { "type": "usage_details", "input_token_count": 10, "output_token_count": 20, @@ -1966,7 +1920,7 @@ def test_agent_run_response_update_all_content_types(): id="usage_content", ), pytest.param( - FunctionApprovalRequestContent, + Content, { "type": "function_approval_request", "id": "req-1", @@ -1975,7 +1929,7 @@ def test_agent_run_response_update_all_content_types(): id="function_approval_request", ), pytest.param( - FunctionApprovalResponseContent, + Content, { "type": "function_approval_response", "id": "resp-1", @@ -2078,10 +2032,10 @@ def test_agent_run_response_update_all_content_types(): ), ], ) -def test_content_roundtrip_serialization(content_class: type[BaseContent], init_kwargs: dict[str, Any]): +def test_content_roundtrip_serialization(content_class: type[Content], init_kwargs: dict[str, Any]): """Test to_dict/from_dict roundtrip for all content types.""" - # Create instance - content = content_class(**init_kwargs) + # Create instance using from_dict to handle nested dict-to-object conversions + 
content = content_class.from_dict(init_kwargs) # Serialize to dict content_dict = content.to_dict() @@ -2109,7 +2063,7 @@ def test_content_roundtrip_serialization(content_class: type[BaseContent], init_ continue # Special handling for DataContent created with 'data' parameter - if content_class == DataContent and key == "data": + if hasattr(content, "type") and content.type == "data" and key == "data": # DataContent converts 'data' to 'uri', so we skip checking 'data' attribute # Instead we verify that uri and media_type are set correctly assert hasattr(reconstructed, "uri") @@ -2131,108 +2085,48 @@ def test_content_roundtrip_serialization(content_class: type[BaseContent], init_ if isinstance(value[0], dict) and hasattr(reconstructed_value[0], "to_dict"): # Compare each item by serializing the reconstructed object assert len(reconstructed_value) == len(value) - + for orig_dict, recon_obj in zip(value, reconstructed_value): + recon_dict = recon_obj.to_dict() + # Compare all keys from original dict (reconstructed may have extra default fields) + for k, v in orig_dict.items(): + assert k in recon_dict, f"Key '{k}' missing from reconstructed dict" + # For nested lists, recursively compare + if isinstance(v, list) and v and isinstance(v[0], dict): + assert len(recon_dict[k]) == len(v) + for orig_item, recon_item in zip(v, recon_dict[k]): + # Compare essential keys, ignoring fields like additional_properties + for item_key, item_val in orig_item.items(): + assert item_key in recon_item + assert recon_item[item_key] == item_val + else: + assert recon_dict[k] == v, f"Value mismatch for key '{k}'" else: assert reconstructed_value == value # Special handling for dicts that get converted to objects (like UsageDetails, FunctionCallContent) elif isinstance(value, dict) and hasattr(reconstructed_value, "to_dict"): - # Compare the dict with the serialized form of the object, excluding 'type' key + # Compare the dict with the serialized form of the object reconstructed_dict = 
reconstructed_value.to_dict() - if value: - assert len(reconstructed_dict) == len(value) + # Verify all keys from the original dict are in the reconstructed dict + for k, v in value.items(): + assert k in reconstructed_dict, f"Key '{k}' missing from reconstructed dict" + assert reconstructed_dict[k] == v, f"Value mismatch for key '{k}'" else: assert reconstructed_value == value def test_text_content_with_annotations_serialization(): - """Test TextContent with CitationAnnotation and TextSpanRegion roundtrip serialization.""" - # Create TextSpanRegion - region = TextSpanRegion(start_index=0, end_index=5) - - # Create CitationAnnotation with region - citation = CitationAnnotation( - title="Test Citation", - url="http://example.com/citation", - file_id="file-123", - tool_name="test_tool", - snippet="This is a test snippet", - annotated_regions=[region], - additional_properties={"custom": "value"}, - ) - - # Create TextContent with annotation - content = TextContent( - text="Hello world", annotations=[citation], additional_properties={"content_key": "content_val"} - ) - - # Serialize to dict - content_dict = content.to_dict() - - # Verify structure - assert content_dict["type"] == "text" - assert content_dict["text"] == "Hello world" - assert content_dict["content_key"] == "content_val" - assert len(content_dict["annotations"]) == 1 - - # Verify annotation structure - annotation_dict = content_dict["annotations"][0] - assert annotation_dict["type"] == "citation" - assert annotation_dict["title"] == "Test Citation" - assert annotation_dict["url"] == "http://example.com/citation" - assert annotation_dict["file_id"] == "file-123" - assert annotation_dict["tool_name"] == "test_tool" - assert annotation_dict["snippet"] == "This is a test snippet" - assert annotation_dict["custom"] == "value" - - # Verify region structure - assert len(annotation_dict["annotated_regions"]) == 1 - region_dict = annotation_dict["annotated_regions"][0] - assert region_dict["type"] == "text_span" 
- assert region_dict["start_index"] == 0 - assert region_dict["end_index"] == 5 - - # Deserialize from dict - reconstructed = TextContent.from_dict(content_dict) - - # Verify reconstructed content - assert isinstance(reconstructed, TextContent) - assert reconstructed.text == "Hello world" - assert reconstructed.type == "text" - assert reconstructed.additional_properties == {"content_key": "content_val"} - - # Verify reconstructed annotation - assert len(reconstructed.annotations) == 1 # type: ignore[arg-type] - recon_annotation = reconstructed.annotations[0] # type: ignore[index] - assert isinstance(recon_annotation, CitationAnnotation) - assert recon_annotation.title == "Test Citation" - assert recon_annotation.url == "http://example.com/citation" - assert recon_annotation.file_id == "file-123" - assert recon_annotation.tool_name == "test_tool" - assert recon_annotation.snippet == "This is a test snippet" - assert recon_annotation.additional_properties == {"custom": "value"} - - # Verify reconstructed region - assert len(recon_annotation.annotated_regions) == 1 # type: ignore[arg-type] - recon_region = recon_annotation.annotated_regions[0] # type: ignore[index] - assert isinstance(recon_region, TextSpanRegion) - assert recon_region.start_index == 0 - assert recon_region.end_index == 5 - assert recon_region.type == "text_span" - - -def test_text_content_with_multiple_annotations_serialization(): """Test TextContent with multiple annotations roundtrip serialization.""" # Create multiple regions - region1 = TextSpanRegion(start_index=0, end_index=5) - region2 = TextSpanRegion(start_index=6, end_index=11) + region1 = TextSpanRegion(type="text_span", start_index=0, end_index=5) + region2 = TextSpanRegion(type="text_span", start_index=6, end_index=11) # Create multiple citations - citation1 = CitationAnnotation(title="Citation 1", url="http://example.com/1", annotated_regions=[region1]) + citation1 = Annotation(type="citation", title="Citation 1", 
url="http://example.com/1", annotated_regions=[region1]) - citation2 = CitationAnnotation(title="Citation 2", url="http://example.com/2", annotated_regions=[region2]) + citation2 = Annotation(type="citation", title="Citation 2", url="http://example.com/2", annotated_regions=[region2]) # Create TextContent with multiple annotations - content = TextContent(text="Hello world", annotations=[citation1, citation2]) + content = Content.from_text(text="Hello world", annotations=[citation1, citation2]) # Serialize content_dict = content.to_dict() @@ -2243,14 +2137,15 @@ def test_text_content_with_multiple_annotations_serialization(): assert content_dict["annotations"][1]["title"] == "Citation 2" # Deserialize - reconstructed = TextContent.from_dict(content_dict) + reconstructed = Content.from_dict(content_dict) # Verify reconstruction assert len(reconstructed.annotations) == 2 - assert all(isinstance(ann, CitationAnnotation) for ann in reconstructed.annotations) - assert reconstructed.annotations[0].title == "Citation 1" - assert reconstructed.annotations[1].title == "Citation 2" - assert all(isinstance(ann.annotated_regions[0], TextSpanRegion) for ann in reconstructed.annotations) + # Annotation are TypedDicts (dicts at runtime) + assert all(isinstance(ann, dict) for ann in reconstructed.annotations) + assert reconstructed.annotations[0]["title"] == "Citation 1" + assert reconstructed.annotations[1]["title"] == "Citation 2" + assert all(isinstance(ann["annotated_regions"][0], dict) for ann in reconstructed.annotations) # region prepare_function_call_results with Pydantic models diff --git a/python/packages/core/tests/openai/test_openai_assistants_client.py b/python/packages/core/tests/openai/test_openai_assistants_client.py index 424c1cc044..81ac7395e4 100644 --- a/python/packages/core/tests/openai/test_openai_assistants_client.py +++ b/python/packages/core/tests/openai/test_openai_assistants_client.py @@ -19,15 +19,10 @@ ChatMessage, ChatResponse, ChatResponseUpdate, - 
FunctionCallContent, - FunctionResultContent, + Content, HostedCodeInterpreterTool, HostedFileSearchTool, - HostedVectorStoreContent, Role, - TextContent, - UriContent, - UsageContent, ai_function, ) from agent_framework.exceptions import ServiceInitializationError @@ -68,7 +63,7 @@ def create_test_openai_assistants_client( return client -async def create_vector_store(client: OpenAIAssistantsClient) -> tuple[str, HostedVectorStoreContent]: +async def create_vector_store(client: OpenAIAssistantsClient) -> tuple[str, Content]: """Create a vector store with sample documents for testing.""" file = await client.client.files.create( file=("todays_weather.txt", b"The weather today is sunny with a high of 25C."), purpose="user_data" @@ -81,7 +76,7 @@ async def create_vector_store(client: OpenAIAssistantsClient) -> tuple[str, Host if result.last_error is not None: raise Exception(f"Vector store file processing failed with status: {result.last_error.message}") - return file.id, HostedVectorStoreContent(vector_store_id=vector_store.id) + return file.id, Content.from_hosted_vector_store(vector_store_id=vector_store.id) async def delete_vector_store(client: OpenAIAssistantsClient, file_id: str, vector_store_id: str) -> None: @@ -464,7 +459,7 @@ async def test_process_stream_events_requires_action(mock_async_openai: MagicMoc chat_client = create_test_openai_assistants_client(mock_async_openai) # Mock the _parse_function_calls_from_assistants method to return test content - test_function_content = FunctionCallContent(call_id="call-123", name="test_func", arguments={"arg": "value"}) + test_function_content = Content.from_function_call(call_id="call-123", name="test_func", arguments={"arg": "value"}) chat_client._parse_function_calls_from_assistants = MagicMock(return_value=[test_function_content]) # type: ignore # Create a mock Run object @@ -578,10 +573,10 @@ async def async_iterator() -> Any: # Check the usage content usage_content = update.contents[0] - assert 
isinstance(usage_content, UsageContent) - assert usage_content.details.input_token_count == 100 - assert usage_content.details.output_token_count == 50 - assert usage_content.details.total_token_count == 150 + assert usage_content.type == "usage" + assert usage_content.usage_details["input_token_count"] == 100 + assert usage_content.usage_details["output_token_count"] == 50 + assert usage_content.usage_details["total_token_count"] == 150 assert update.raw_representation == mock_run @@ -609,7 +604,7 @@ def test_parse_function_calls_from_assistants_basic(mock_async_openai: MagicMock # Test that one function call content was created assert len(contents) == 1 - assert isinstance(contents[0], FunctionCallContent) + assert contents[0].type == "function_call" assert contents[0].name == "get_weather" assert contents[0].arguments == {"location": "Seattle"} @@ -830,7 +825,7 @@ def test_prepare_options_with_image_content(mock_async_openai: MagicMock) -> Non chat_client = create_test_openai_assistants_client(mock_async_openai) # Create message with image content - image_content = UriContent(uri="https://example.com/image.jpg", media_type="image/jpeg") + image_content = Content.from_uri(uri="https://example.com/image.jpg", media_type="image/jpeg") messages = [ChatMessage(role=Role.USER, contents=[image_content])] # Call the method @@ -861,7 +856,7 @@ def test_prepare_tool_outputs_for_assistants_valid(mock_async_openai: MagicMock) chat_client = create_test_openai_assistants_client(mock_async_openai) call_id = json.dumps(["run-123", "call-456"]) - function_result = FunctionResultContent(call_id=call_id, result="Function executed successfully") + function_result = Content.from_function_result(call_id=call_id, result="Function executed successfully") run_id, tool_outputs = chat_client._prepare_tool_outputs_for_assistants([function_result]) # type: ignore @@ -881,8 +876,8 @@ def test_prepare_tool_outputs_for_assistants_mismatched_run_ids( # Create function results with different run 
IDs call_id1 = json.dumps(["run-123", "call-456"]) call_id2 = json.dumps(["run-789", "call-xyz"]) # Different run ID - function_result1 = FunctionResultContent(call_id=call_id1, result="Result 1") - function_result2 = FunctionResultContent(call_id=call_id2, result="Result 2") + function_result1 = Content.from_function_result(call_id=call_id1, result="Result 1") + function_result2 = Content.from_function_result(call_id=call_id2, result="Result 2") run_id, tool_outputs = chat_client._prepare_tool_outputs_for_assistants([function_result1, function_result2]) # type: ignore @@ -1006,7 +1001,7 @@ async def test_streaming() -> None: assert chunk is not None assert isinstance(chunk, ChatResponseUpdate) for content in chunk.contents: - if isinstance(content, TextContent) and content.text: + if content.type == "text" and content.text: full_message += content.text assert any(word in full_message.lower() for word in ["sunny", "25", "weather", "seattle"]) @@ -1035,7 +1030,7 @@ async def test_streaming_tools() -> None: assert chunk is not None assert isinstance(chunk, ChatResponseUpdate) for content in chunk.contents: - if isinstance(content, TextContent) and content.text: + if content.type == "text" and content.text: full_message += content.text assert any(word in full_message.lower() for word in ["sunny", "25", "weather"]) @@ -1121,7 +1116,7 @@ async def test_file_search_streaming() -> None: assert chunk is not None assert isinstance(chunk, ChatResponseUpdate) for content in chunk.contents: - if isinstance(content, TextContent) and content.text: + if content.type == "text" and content.text: full_message += content.text await delete_vector_store(openai_assistants_client, file_id, vector_store.vector_store_id) diff --git a/python/packages/core/tests/openai/test_openai_chat_client.py b/python/packages/core/tests/openai/test_openai_chat_client.py index 1f1d624345..32886de805 100644 --- a/python/packages/core/tests/openai/test_openai_chat_client.py +++ 
b/python/packages/core/tests/openai/test_openai_chat_client.py @@ -14,8 +14,7 @@ ChatClientProtocol, ChatMessage, ChatResponse, - DataContent, - FunctionResultContent, + Content, HostedWebSearchTool, ToolProtocol, ai_function, @@ -282,7 +281,9 @@ def test_function_result_falsy_values_handling(openai_unit_test_env: dict[str, s client = OpenAIChatClient() # Test with empty list (falsy but not None) - message_with_empty_list = ChatMessage(role="tool", contents=[FunctionResultContent(call_id="call-123", result=[])]) + message_with_empty_list = ChatMessage( + role="tool", contents=[Content.from_function_result(call_id="call-123", result=[])] + ) openai_messages = client._prepare_message_for_openai(message_with_empty_list) assert len(openai_messages) == 1 @@ -290,7 +291,7 @@ def test_function_result_falsy_values_handling(openai_unit_test_env: dict[str, s # Test with empty string (falsy but not None) message_with_empty_string = ChatMessage( - role="tool", contents=[FunctionResultContent(call_id="call-456", result="")] + role="tool", contents=[Content.from_function_result(call_id="call-456", result="")] ) openai_messages = client._prepare_message_for_openai(message_with_empty_string) @@ -298,7 +299,9 @@ def test_function_result_falsy_values_handling(openai_unit_test_env: dict[str, s assert openai_messages[0]["content"] == "" # Empty string should be preserved # Test with False (falsy but not None) - message_with_false = ChatMessage(role="tool", contents=[FunctionResultContent(call_id="call-789", result=False)]) + message_with_false = ChatMessage( + role="tool", contents=[Content.from_function_result(call_id="call-789", result=False)] + ) openai_messages = client._prepare_message_for_openai(message_with_false) assert len(openai_messages) == 1 @@ -317,7 +320,7 @@ def test_function_result_exception_handling(openai_unit_test_env: dict[str, str] message_with_exception = ChatMessage( role="tool", contents=[ - FunctionResultContent(call_id="call-123", result="Error: Function 
failed.", exception=test_exception) + Content.from_function_result(call_id="call-123", result="Error: Function failed.", exception=test_exception) ], ) @@ -339,7 +342,7 @@ def test_prepare_content_for_openai_data_content_image(openai_unit_test_env: dic client = OpenAIChatClient() # Test DataContent with image media type - image_data_content = DataContent( + image_data_content = Content.from_uri( uri="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg==", media_type="image/png", ) @@ -351,7 +354,7 @@ def test_prepare_content_for_openai_data_content_image(openai_unit_test_env: dic assert result["image_url"]["url"] == image_data_content.uri # Test DataContent with non-image media type should use default model_dump - text_data_content = DataContent(uri="data:text/plain;base64,SGVsbG8gV29ybGQ=", media_type="text/plain") + text_data_content = Content.from_uri(uri="data:text/plain;base64,SGVsbG8gV29ybGQ=", media_type="text/plain") result = client._prepare_content_for_openai(text_data_content) # type: ignore @@ -361,7 +364,7 @@ def test_prepare_content_for_openai_data_content_image(openai_unit_test_env: dic assert result["media_type"] == "text/plain" # Test DataContent with audio media type - audio_data_content = DataContent( + audio_data_content = Content.from_uri( uri="data:audio/wav;base64,UklGRjBEAABXQVZFZm10IBAAAAABAAEAQB8AAEAfAAABAAgAZGF0YQwEAAAAAAAAAAAA", media_type="audio/wav", ) @@ -375,7 +378,9 @@ def test_prepare_content_for_openai_data_content_image(openai_unit_test_env: dic assert result["input_audio"]["format"] == "wav" # Test DataContent with MP3 audio - mp3_data_content = DataContent(uri="data:audio/mp3;base64,//uQAAAAWGluZwAAAA8AAAACAAACcQ==", media_type="audio/mp3") + mp3_data_content = Content.from_uri( + uri="data:audio/mp3;base64,//uQAAAAWGluZwAAAA8AAAACAAACcQ==", media_type="audio/mp3" + ) result = client._prepare_content_for_openai(mp3_data_content) # type: ignore @@ -391,7 +396,7 @@ 
def test_prepare_content_for_openai_document_file_mapping(openai_unit_test_env: client = OpenAIChatClient() # Test PDF without filename - should omit filename in OpenAI payload - pdf_data_content = DataContent( + pdf_data_content = Content.from_uri( uri="data:application/pdf;base64,JVBERi0xLjQKJcfsj6IKNSAwIG9iago8PC9UeXBlL0NhdGFsb2cvUGFnZXMgMiAwIFI+PgplbmRvYmoKMiAwIG9iago8PC9UeXBlL1BhZ2VzL0tpZHNbMyAwIFJdL0NvdW50IDE+PgplbmRvYmoKMyAwIG9iago8PC9UeXBlL1BhZ2UvTWVkaWFCb3ggWzAgMCA2MTIgNzkyXS9QYXJlbnQgMiAwIFIvUmVzb3VyY2VzPDwvRm9udDw8L0YxIDQgMCBSPj4+Pi9Db250ZW50cyA1IDAgUj4+CmVuZG9iago0IDAgb2JqCjw8L1R5cGUvRm9udC9TdWJ0eXBlL1R5cGUxL0Jhc2VGb250L0hlbHZldGljYT4+CmVuZG9iago1IDAgb2JqCjw8L0xlbmd0aCA0ND4+CnN0cmVhbQpCVApxCjcwIDUwIFRECi9GMSA4IFRmCihIZWxsbyBXb3JsZCEpIFRqCkVUCmVuZHN0cmVhbQplbmRvYmoKeHJlZgowIDYKMDAwMDAwMDAwMCA2NTUzNSBmIAowMDAwMDAwMDA5IDAwMDAwIG4gCjAwMDAwMDAwNTggMDAwMDAgbiAKMDAwMDAwMDExNSAwMDAwMCBuIAowMDAwMDAwMjQ1IDAwMDAwIG4gCjAwMDAwMDAzMDcgMDAwMDAgbiAKdHJhaWxlcgo8PC9TaXplIDYvUm9vdCAxIDAgUj4+CnN0YXJ0eHJlZgo0MDUKJSVFT0Y=", media_type="application/pdf", ) @@ -407,7 +412,7 @@ def test_prepare_content_for_openai_document_file_mapping(openai_unit_test_env: assert result["file"]["file_data"] == pdf_data_content.uri # Test PDF with custom filename via additional_properties - pdf_with_filename = DataContent( + pdf_with_filename = Content.from_uri( uri="data:application/pdf;base64,JVBERi0xLjQ=", media_type="application/pdf", additional_properties={"filename": "report.pdf"}, @@ -441,7 +446,7 @@ def test_prepare_content_for_openai_document_file_mapping(openai_unit_test_env: for case in test_cases: # Test without filename - doc_content = DataContent( + doc_content = Content.from_uri( uri=f"data:{case['media_type']};base64,{case['base64']}", media_type=case["media_type"], ) @@ -454,7 +459,7 @@ def test_prepare_content_for_openai_document_file_mapping(openai_unit_test_env: assert result["file"]["file_data"] == doc_content.uri # Test with filename - should now use file format with 
filename - doc_with_filename = DataContent( + doc_with_filename = Content.from_uri( uri=f"data:{case['media_type']};base64,{case['base64']}", media_type=case["media_type"], additional_properties={"filename": case["filename"]}, @@ -468,7 +473,7 @@ def test_prepare_content_for_openai_document_file_mapping(openai_unit_test_env: assert result["file"]["file_data"] == doc_with_filename.uri # Test edge case: empty additional_properties dict - pdf_empty_props = DataContent( + pdf_empty_props = Content.from_uri( uri="data:application/pdf;base64,JVBERi0xLjQ=", media_type="application/pdf", additional_properties={}, @@ -480,7 +485,7 @@ def test_prepare_content_for_openai_document_file_mapping(openai_unit_test_env: assert "filename" not in result["file"] # Test edge case: None filename in additional_properties - pdf_none_filename = DataContent( + pdf_none_filename = Content.from_uri( uri="data:application/pdf;base64,JVBERi0xLjQ=", media_type="application/pdf", additional_properties={"filename": None}, diff --git a/python/packages/core/tests/openai/test_openai_responses_client.py b/python/packages/core/tests/openai/test_openai_responses_client.py index c91297d7df..644aeccd7c 100644 --- a/python/packages/core/tests/openai/test_openai_responses_client.py +++ b/python/packages/core/tests/openai/test_openai_responses_client.py @@ -33,26 +33,13 @@ ChatOptions, ChatResponse, ChatResponseUpdate, - CodeInterpreterToolCallContent, - CodeInterpreterToolResultContent, - DataContent, - FunctionApprovalRequestContent, - FunctionApprovalResponseContent, - FunctionCallContent, - FunctionResultContent, + Content, HostedCodeInterpreterTool, - HostedFileContent, HostedFileSearchTool, HostedImageGenerationTool, HostedMCPTool, - HostedVectorStoreContent, HostedWebSearchTool, - ImageGenerationToolCallContent, - ImageGenerationToolResultContent, Role, - TextContent, - TextReasoningContent, - UriContent, ai_function, ) from agent_framework.exceptions import ( @@ -81,7 +68,7 @@ class 
OutputStruct(BaseModel): async def create_vector_store( client: OpenAIResponsesClient, -) -> tuple[str, HostedVectorStoreContent]: +) -> tuple[str, Content]: """Create a vector store with sample documents for testing.""" file = await client.client.files.create( file=("todays_weather.txt", b"The weather today is sunny with a high of 75F."), @@ -99,7 +86,7 @@ async def create_vector_store( if result.last_error is not None: raise Exception(f"Vector store file processing failed with status: {result.last_error.message}") - return file.id, HostedVectorStoreContent(vector_store_id=vector_store.id) + return file.id, Content.from_hosted_vector_store(vector_store_id=vector_store.id) async def delete_vector_store(client: OpenAIResponsesClient, file_id: str, vector_store_id: str) -> None: @@ -285,7 +272,7 @@ def test_file_search_tool_with_invalid_inputs() -> None: client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") # Test with invalid inputs type (should trigger ValueError) - file_search_tool = HostedFileSearchTool(inputs=[HostedFileContent(file_id="invalid")]) + file_search_tool = HostedFileSearchTool(inputs=[Content.from_hosted_file(file_id="invalid")]) # Should raise an error due to invalid inputs with pytest.raises(ValueError, match="HostedFileSearchTool requires inputs to be of type"): @@ -314,7 +301,7 @@ def test_code_interpreter_tool_variations() -> None: # Test code interpreter with files code_tool_with_files = HostedCodeInterpreterTool( - inputs=[HostedFileContent(file_id="file1"), HostedFileContent(file_id="file2")] + inputs=[Content.from_hosted_file(file_id="file1"), Content.from_hosted_file(file_id="file2")] ) with pytest.raises(ServiceResponseException): @@ -367,14 +354,14 @@ def test_chat_message_parsing_with_function_calls() -> None: client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") # Create messages with function call and result content - function_call = FunctionCallContent( + function_call = 
Content.from_function_call( call_id="test-call-id", name="test_function", arguments='{"param": "value"}', additional_properties={"fc_id": "test-fc-id"}, ) - function_result = FunctionResultContent(call_id="test-call-id", result="Function executed successfully") + function_result = Content.from_function_result(call_id="test-call-id", result="Function executed successfully") messages = [ ChatMessage(role="user", text="Call a function"), @@ -516,7 +503,7 @@ def test_response_content_creation_with_annotations() -> None: response = client._parse_response_from_openai(mock_response, options={}) # type: ignore assert len(response.messages[0].contents) >= 1 - assert isinstance(response.messages[0].contents[0], TextContent) + assert response.messages[0].contents[0].type == "text" assert response.messages[0].contents[0].text == "Text with annotations." assert response.messages[0].contents[0].annotations is not None @@ -547,7 +534,7 @@ def test_response_content_creation_with_refusal() -> None: response = client._parse_response_from_openai(mock_response, options={}) # type: ignore assert len(response.messages[0].contents) == 1 - assert isinstance(response.messages[0].contents[0], TextContent) + assert response.messages[0].contents[0].type == "text" assert response.messages[0].contents[0].text == "I cannot provide that information." 
@@ -577,7 +564,7 @@ def test_response_content_creation_with_reasoning() -> None: response = client._parse_response_from_openai(mock_response, options={}) # type: ignore assert len(response.messages[0].contents) == 2 - assert isinstance(response.messages[0].contents[0], TextReasoningContent) + assert response.messages[0].contents[0].type == "text_reasoning" assert response.messages[0].contents[0].text == "Reasoning step" @@ -614,13 +601,13 @@ def test_response_content_creation_with_code_interpreter() -> None: assert len(response.messages[0].contents) == 2 call_content, result_content = response.messages[0].contents - assert isinstance(call_content, CodeInterpreterToolCallContent) + assert call_content.type == "code_interpreter_tool_call" assert call_content.inputs is not None - assert isinstance(call_content.inputs[0], TextContent) - assert isinstance(result_content, CodeInterpreterToolResultContent) + assert call_content.inputs[0].type == "text" + assert result_content.type == "code_interpreter_tool_result" assert result_content.outputs is not None - assert any(isinstance(out, TextContent) for out in result_content.outputs) - assert any(isinstance(out, UriContent) for out in result_content.outputs) + assert any(out.type == "text" for out in result_content.outputs) + assert any(out.type == "uri" for out in result_content.outputs) def test_response_content_creation_with_function_call() -> None: @@ -648,7 +635,7 @@ def test_response_content_creation_with_function_call() -> None: response = client._parse_response_from_openai(mock_response, options={}) # type: ignore assert len(response.messages[0].contents) == 1 - assert isinstance(response.messages[0].contents[0], FunctionCallContent) + assert response.messages[0].contents[0].type == "function_call" function_call = response.messages[0].contents[0] assert function_call.call_id == "call_123" assert function_call.name == "get_weather" @@ -708,7 +695,7 @@ def test_parse_response_from_openai_with_mcp_approval_request() -> 
None: response = client._parse_response_from_openai(mock_response, options={}) # type: ignore - assert isinstance(response.messages[0].contents[0], FunctionApprovalRequestContent) + assert response.messages[0].contents[0].type == "function_approval_request" req = response.messages[0].contents[0] assert req.id == "approval-1" assert req.function_call.name == "do_sensitive_action" @@ -874,8 +861,8 @@ def test_parse_chunk_from_openai_with_mcp_approval_request() -> None: mock_event.item = mock_item update = client._parse_chunk_from_openai(mock_event, chat_options, function_call_ids) - assert any(isinstance(c, FunctionApprovalRequestContent) for c in update.contents) - fa = next(c for c in update.contents if isinstance(c, FunctionApprovalRequestContent)) + assert any(c.type == "function_approval_request" for c in update.contents) + fa = next(c for c in update.contents if c.type == "function_approval_request") assert fa.id == "approval-stream-1" assert fa.function_call.name == "do_stream_action" @@ -925,12 +912,12 @@ async def test_end_to_end_mcp_approval_flow(span_exporter) -> None: with patch.object(client.client.responses, "create", side_effect=[mock_response1, mock_response2]) as mock_create: # First call: get the approval request response = await client.get_response(messages=[ChatMessage(role="user", text="Trigger approval")]) - assert isinstance(response.messages[0].contents[0], FunctionApprovalRequestContent) + assert response.messages[0].contents[0].type == "function_approval_request" req = response.messages[0].contents[0] assert req.id == "approval-1" # Build a user approval and send it (include required function_call) - approval = FunctionApprovalResponseContent(approved=True, id=req.id, function_call=req.function_call) + approval = Content.from_function_approval_response(approved=True, id=req.id, function_call=req.function_call) approval_message = ChatMessage(role="user", contents=[approval]) _ = await client.get_response(messages=[approval_message]) @@ -961,9 
+948,9 @@ def test_usage_details_basic() -> None: details = client._parse_usage_from_openai(mock_usage) # type: ignore assert details is not None - assert details.input_token_count == 100 - assert details.output_token_count == 50 - assert details.total_token_count == 150 + assert details["input_token_count"] == 100 + assert details["output_token_count"] == 50 + assert details["total_token_count"] == 150 def test_usage_details_with_cached_tokens() -> None: @@ -980,8 +967,8 @@ def test_usage_details_with_cached_tokens() -> None: details = client._parse_usage_from_openai(mock_usage) # type: ignore assert details is not None - assert details.input_token_count == 200 - assert details.additional_counts["openai.cached_input_tokens"] == 25 + assert details["input_token_count"] == 200 + assert details["openai.cached_input_tokens"] == 25 def test_usage_details_with_reasoning_tokens() -> None: @@ -998,8 +985,8 @@ def test_usage_details_with_reasoning_tokens() -> None: details = client._parse_usage_from_openai(mock_usage) # type: ignore assert details is not None - assert details.output_token_count == 80 - assert details.additional_counts["openai.reasoning_tokens"] == 30 + assert details["output_token_count"] == 80 + assert details["openai.reasoning_tokens"] == 30 def test_get_metadata_from_response() -> None: @@ -1098,7 +1085,7 @@ def test_streaming_annotation_added_with_file_path() -> None: assert len(response.contents) == 1 content = response.contents[0] - assert isinstance(content, HostedFileContent) + assert content.type == "hosted_file" assert content.file_id == "file-abc123" assert content.additional_properties is not None assert content.additional_properties.get("annotation_index") == 0 @@ -1125,7 +1112,7 @@ def test_streaming_annotation_added_with_file_citation() -> None: assert len(response.contents) == 1 content = response.contents[0] - assert isinstance(content, HostedFileContent) + assert content.type == "hosted_file" assert content.file_id == "file-xyz789" assert 
content.additional_properties is not None assert content.additional_properties.get("filename") == "sample.txt" @@ -1154,7 +1141,7 @@ def test_streaming_annotation_added_with_container_file_citation() -> None: assert len(response.contents) == 1 content = response.contents[0] - assert isinstance(content, HostedFileContent) + assert content.type == "hosted_file" assert content.file_id == "file-container123" assert content.additional_properties is not None assert content.additional_properties.get("container_id") == "container-456" @@ -1228,7 +1215,7 @@ def test_prepare_content_for_openai_image_content() -> None: client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") # Test image content with detail parameter and file_id - image_content_with_detail = UriContent( + image_content_with_detail = Content.from_uri( uri="https://example.com/image.jpg", media_type="image/jpeg", additional_properties={"detail": "high", "file_id": "file_123"}, @@ -1240,7 +1227,7 @@ def test_prepare_content_for_openai_image_content() -> None: assert result["file_id"] == "file_123" # Test image content without additional properties (defaults) - image_content_basic = UriContent(uri="https://example.com/basic.png", media_type="image/png") + image_content_basic = Content.from_uri(uri="https://example.com/basic.png", media_type="image/png") result = client._prepare_content_for_openai(Role.USER, image_content_basic, {}) # type: ignore assert result["type"] == "input_image" assert result["detail"] == "auto" @@ -1252,14 +1239,14 @@ def test_prepare_content_for_openai_audio_content() -> None: client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") # Test WAV audio content - wav_content = UriContent(uri="data:audio/wav;base64,abc123", media_type="audio/wav") + wav_content = Content.from_uri(uri="data:audio/wav;base64,abc123", media_type="audio/wav") result = client._prepare_content_for_openai(Role.USER, wav_content, {}) # type: ignore assert result["type"] == 
"input_audio" assert result["input_audio"]["data"] == "data:audio/wav;base64,abc123" assert result["input_audio"]["format"] == "wav" # Test MP3 audio content - mp3_content = UriContent(uri="data:audio/mp3;base64,def456", media_type="audio/mp3") + mp3_content = Content.from_uri(uri="data:audio/mp3;base64,def456", media_type="audio/mp3") result = client._prepare_content_for_openai(Role.USER, mp3_content, {}) # type: ignore assert result["type"] == "input_audio" assert result["input_audio"]["format"] == "mp3" @@ -1270,12 +1257,12 @@ def test_prepare_content_for_openai_unsupported_content() -> None: client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") # Test unsupported audio format - unsupported_audio = UriContent(uri="data:audio/ogg;base64,ghi789", media_type="audio/ogg") + unsupported_audio = Content.from_uri(uri="data:audio/ogg;base64,ghi789", media_type="audio/ogg") result = client._prepare_content_for_openai(Role.USER, unsupported_audio, {}) # type: ignore assert result == {} # Test non-media content - text_uri_content = UriContent(uri="https://example.com/document.txt", media_type="text/plain") + text_uri_content = Content.from_uri(uri="https://example.com/document.txt", media_type="text/plain") result = client._prepare_content_for_openai(Role.USER, text_uri_content, {}) # type: ignore assert result == {} @@ -1299,11 +1286,9 @@ def test_parse_chunk_from_openai_code_interpreter() -> None: result = client._parse_chunk_from_openai(mock_event_image, chat_options, function_call_ids) # type: ignore assert len(result.contents) == 1 - assert isinstance(result.contents[0], CodeInterpreterToolResultContent) + assert result.contents[0].type == "code_interpreter_tool_result" assert result.contents[0].outputs - assert any( - isinstance(out, UriContent) and out.uri == "https://example.com/plot.png" for out in result.contents[0].outputs - ) + assert any(out.type == "uri" and out.uri == "https://example.com/plot.png" for out in result.contents[0].outputs) 
def test_parse_chunk_from_openai_reasoning() -> None: @@ -1324,7 +1309,7 @@ def test_parse_chunk_from_openai_reasoning() -> None: result = client._parse_chunk_from_openai(mock_event_reasoning, chat_options, function_call_ids) # type: ignore assert len(result.contents) == 1 - assert isinstance(result.contents[0], TextReasoningContent) + assert result.contents[0].type == "text_reasoning" assert result.contents[0].text == "Analyzing the problem step by step..." if result.contents[0].additional_properties: assert result.contents[0].additional_properties["summary"] == "Problem analysis summary" @@ -1335,7 +1320,7 @@ def test_prepare_content_for_openai_text_reasoning_comprehensive() -> None: client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") # Test TextReasoningContent with all additional properties - comprehensive_reasoning = TextReasoningContent( + comprehensive_reasoning = Content.from_text_reasoning( text="Comprehensive reasoning summary", additional_properties={ "status": "in_progress", @@ -1371,7 +1356,7 @@ def test_streaming_reasoning_text_delta_event() -> None: response = client._parse_chunk_from_openai(event, chat_options, function_call_ids) # type: ignore assert len(response.contents) == 1 - assert isinstance(response.contents[0], TextReasoningContent) + assert response.contents[0].type == "text_reasoning" assert response.contents[0].text == "reasoning delta" assert response.contents[0].raw_representation == event mock_metadata.assert_called_once_with(event) @@ -1396,7 +1381,7 @@ def test_streaming_reasoning_text_done_event() -> None: response = client._parse_chunk_from_openai(event, chat_options, function_call_ids) # type: ignore assert len(response.contents) == 1 - assert isinstance(response.contents[0], TextReasoningContent) + assert response.contents[0].type == "text_reasoning" assert response.contents[0].text == "complete reasoning" assert response.contents[0].raw_representation == event mock_metadata.assert_called_once_with(event) 
@@ -1422,7 +1407,7 @@ def test_streaming_reasoning_summary_text_delta_event() -> None: response = client._parse_chunk_from_openai(event, chat_options, function_call_ids) # type: ignore assert len(response.contents) == 1 - assert isinstance(response.contents[0], TextReasoningContent) + assert response.contents[0].type == "text_reasoning" assert response.contents[0].text == "summary delta" assert response.contents[0].raw_representation == event mock_metadata.assert_called_once_with(event) @@ -1447,7 +1432,7 @@ def test_streaming_reasoning_summary_text_done_event() -> None: response = client._parse_chunk_from_openai(event, chat_options, function_call_ids) # type: ignore assert len(response.contents) == 1 - assert isinstance(response.contents[0], TextReasoningContent) + assert response.contents[0].type == "text_reasoning" assert response.contents[0].text == "complete summary" assert response.contents[0].raw_representation == event mock_metadata.assert_called_once_with(event) @@ -1488,8 +1473,8 @@ def test_streaming_reasoning_events_preserve_metadata() -> None: assert reasoning_response.additional_properties == {"test": "metadata"} # Content types should be different - assert isinstance(text_response.contents[0], TextContent) - assert isinstance(reasoning_response.contents[0], TextReasoningContent) + assert text_response.contents[0].type == "text" + assert reasoning_response.contents[0].type == "text_reasoning" def test_parse_response_from_openai_image_generation_raw_base64(): @@ -1521,11 +1506,11 @@ def test_parse_response_from_openai_image_generation_raw_base64(): # Verify the response contains call + result with DataContent output assert len(response.messages[0].contents) == 2 call_content, result_content = response.messages[0].contents - assert isinstance(call_content, ImageGenerationToolCallContent) - assert isinstance(result_content, ImageGenerationToolResultContent) + assert call_content.type == "image_generation_tool_call" + assert result_content.type == 
"image_generation_tool_result" assert result_content.outputs data_out = result_content.outputs - assert isinstance(data_out, DataContent) + assert data_out.type == "data" assert data_out.uri.startswith("data:image/png;base64,") assert data_out.media_type == "image/png" @@ -1558,11 +1543,11 @@ def test_parse_response_from_openai_image_generation_existing_data_uri(): # Verify the response contains call + result with DataContent output assert len(response.messages[0].contents) == 2 call_content, result_content = response.messages[0].contents - assert isinstance(call_content, ImageGenerationToolCallContent) - assert isinstance(result_content, ImageGenerationToolResultContent) + assert call_content.type == "image_generation_tool_call" + assert result_content.type == "image_generation_tool_result" assert result_content.outputs data_out = result_content.outputs - assert isinstance(data_out, DataContent) + assert data_out.type == "data" assert data_out.uri == f"data:image/webp;base64,{valid_webp_base64}" assert data_out.media_type == "image/webp" @@ -1591,9 +1576,9 @@ def test_parse_response_from_openai_image_generation_format_detection(): with patch.object(client, "_get_metadata_from_response", return_value={}): response_jpeg = client._parse_response_from_openai(mock_response_jpeg, options={}) # type: ignore result_contents = response_jpeg.messages[0].contents - assert isinstance(result_contents[1], ImageGenerationToolResultContent) + assert result_contents[1].type == "image_generation_tool_result" outputs = result_contents[1].outputs - assert outputs and isinstance(outputs, DataContent) + assert outputs and outputs.type == "data" assert outputs.media_type == "image/jpeg" assert "data:image/jpeg;base64," in outputs.uri @@ -1617,7 +1602,7 @@ def test_parse_response_from_openai_image_generation_format_detection(): with patch.object(client, "_get_metadata_from_response", return_value={}): response_webp = client._parse_response_from_openai(mock_response_webp, options={}) # 
type: ignore outputs_webp = response_webp.messages[0].contents[1].outputs - assert outputs_webp and isinstance(outputs_webp, DataContent) + assert outputs_webp and outputs_webp.type == "data" assert outputs_webp.media_type == "image/webp" assert "data:image/webp;base64," in outputs_webp.uri @@ -1650,7 +1635,7 @@ def test_parse_response_from_openai_image_generation_fallback(): # Verify it falls back to PNG format for unrecognized binary data assert len(response.messages[0].contents) == 2 result_content = response.messages[0].contents[1] - assert isinstance(result_content, ImageGenerationToolResultContent) + assert result_content.type == "image_generation_tool_result" assert result_content.outputs content = result_content.outputs assert content.media_type == "image/png" @@ -1944,7 +1929,7 @@ async def test_integration_streaming_file_search() -> None: assert chunk is not None assert isinstance(chunk, ChatResponseUpdate) for content in chunk.contents: - if isinstance(content, TextContent) and content.text: + if content.type == "text" and content.text: full_message += content.text await delete_vector_store(openai_responses_client, file_id, vector_store.vector_store_id) diff --git a/python/packages/core/tests/test_observability_datetime.py b/python/packages/core/tests/test_observability_datetime.py index 6ad3d77e1a..2510a5b355 100644 --- a/python/packages/core/tests/test_observability_datetime.py +++ b/python/packages/core/tests/test_observability_datetime.py @@ -5,7 +5,7 @@ import json from datetime import datetime -from agent_framework._types import FunctionResultContent +from agent_framework import Content from agent_framework.observability import _to_otel_part @@ -14,7 +14,7 @@ def test_datetime_in_tool_results() -> None: Reproduces issue #2219 where datetime objects caused TypeError. 
""" - content = FunctionResultContent( + content = Content.from_function_result( call_id="test-call", result={"timestamp": datetime(2025, 11, 16, 10, 30, 0)}, ) diff --git a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py index 95225cb4a3..0fa2bfd952 100644 --- a/python/packages/core/tests/workflow/test_agent_executor.py +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -11,9 +11,9 @@ BaseAgent, ChatMessage, ChatMessageStore, + Content, Role, SequentialBuilder, - TextContent, WorkflowOutputEvent, WorkflowRunState, WorkflowStatusEvent, @@ -49,7 +49,7 @@ async def run_stream( # type: ignore[override] **kwargs: Any, ) -> AsyncIterable[AgentResponseUpdate]: self.call_count += 1 - yield AgentResponseUpdate(contents=[TextContent(text=f"Response #{self.call_count}: {self.name}")]) + yield AgentResponseUpdate(contents=[Content.from_text(text=f"Response #{self.call_count}: {self.name}")]) async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: diff --git a/python/packages/core/tests/workflow/test_agent_executor_tool_calls.py b/python/packages/core/tests/workflow/test_agent_executor_tool_calls.py index ecf8b3d635..ac861d34b2 100644 --- a/python/packages/core/tests/workflow/test_agent_executor_tool_calls.py +++ b/python/packages/core/tests/workflow/test_agent_executor_tool_calls.py @@ -19,12 +19,9 @@ ChatMessage, ChatResponse, ChatResponseUpdate, - FunctionApprovalRequestContent, - FunctionCallContent, - FunctionResultContent, + Content, RequestInfoEvent, Role, - TextContent, WorkflowBuilder, WorkflowContext, WorkflowOutputEvent, @@ -60,14 +57,14 @@ async def run_stream( """Simulate streaming with tool calls and results.""" # First update: some text yield AgentResponseUpdate( - contents=[TextContent(text="Let me search for that...")], + contents=[Content.from_text(text="Let me search for that...")], role=Role.ASSISTANT, ) # Second update: tool call (no text!) 
yield AgentResponseUpdate( contents=[ - FunctionCallContent( + Content.from_function_call( call_id="call_123", name="search", arguments={"query": "weather"}, @@ -79,7 +76,7 @@ async def run_stream( # Third update: tool result (no text!) yield AgentResponseUpdate( contents=[ - FunctionResultContent( + Content.from_function_result( call_id="call_123", result={"temperature": 72, "condition": "sunny"}, ) @@ -89,7 +86,7 @@ async def run_stream( # Fourth update: final text response yield AgentResponseUpdate( - contents=[TextContent(text="The weather is sunny, 72°F.")], + contents=[Content.from_text(text="The weather is sunny, 72°F.")], role=Role.ASSISTANT, ) @@ -113,25 +110,25 @@ async def test_agent_executor_emits_tool_calls_in_streaming_mode() -> None: # First event: text update assert events[0].data is not None - assert isinstance(events[0].data.contents[0], TextContent) + assert events[0].data.contents[0].type == "text" assert "Let me search" in events[0].data.contents[0].text # Second event: function call assert events[1].data is not None - assert isinstance(events[1].data.contents[0], FunctionCallContent) + assert events[1].data.contents[0].type == "function_call" func_call = events[1].data.contents[0] assert func_call.call_id == "call_123" assert func_call.name == "search" # Third event: function result assert events[2].data is not None - assert isinstance(events[2].data.contents[0], FunctionResultContent) + assert events[2].data.contents[0].type == "function_result" func_result = events[2].data.contents[0] assert func_result.call_id == "call_123" # Fourth event: final text assert events[3].data is not None - assert isinstance(events[3].data.contents[0], TextContent) + assert events[3].data.contents[0].type == "text" assert "sunny" in events[3].data.contents[0].text @@ -161,10 +158,10 @@ async def get_response( messages=ChatMessage( role="assistant", contents=[ - FunctionCallContent( + Content.from_function_call( call_id="1", name="mock_tool_requiring_approval", 
arguments='{"query": "test"}' ), - FunctionCallContent( + Content.from_function_call( call_id="2", name="mock_tool_requiring_approval", arguments='{"query": "test"}' ), ], @@ -175,7 +172,7 @@ async def get_response( messages=ChatMessage( role="assistant", contents=[ - FunctionCallContent( + Content.from_function_call( call_id="1", name="mock_tool_requiring_approval", arguments='{"query": "test"}' ) ], @@ -196,10 +193,10 @@ async def get_streaming_response( if self._parallel_request: yield ChatResponseUpdate( contents=[ - FunctionCallContent( + Content.from_function_call( call_id="1", name="mock_tool_requiring_approval", arguments='{"query": "test"}' ), - FunctionCallContent( + Content.from_function_call( call_id="2", name="mock_tool_requiring_approval", arguments='{"query": "test"}' ), ], @@ -208,15 +205,15 @@ async def get_streaming_response( else: yield ChatResponseUpdate( contents=[ - FunctionCallContent( + Content.from_function_call( call_id="1", name="mock_tool_requiring_approval", arguments='{"query": "test"}' ) ], role="assistant", ) else: - yield ChatResponseUpdate(text=TextContent(text="Tool executed "), role="assistant") - yield ChatResponseUpdate(contents=[TextContent(text="successfully.")], role="assistant") + yield ChatResponseUpdate(text=Content.from_text(text="Tool executed "), role="assistant") + yield ChatResponseUpdate(contents=[Content.from_text(text="successfully.")], role="assistant") self._iteration += 1 @@ -243,12 +240,14 @@ async def test_agent_executor_tool_call_with_approval() -> None: # Assert assert len(events.get_request_info_events()) == 1 approval_request = events.get_request_info_events()[0] - assert isinstance(approval_request.data, FunctionApprovalRequestContent) + assert approval_request.data.type == "function_approval_request" assert approval_request.data.function_call.name == "mock_tool_requiring_approval" assert approval_request.data.function_call.arguments == '{"query": "test"}' # Act - events = await 
workflow.send_responses({approval_request.request_id: approval_request.data.create_response(True)}) + events = await workflow.send_responses({ + approval_request.request_id: approval_request.data.to_function_approval_response(True) + }) # Assert final_response = events.get_outputs() @@ -276,14 +275,14 @@ async def test_agent_executor_tool_call_with_approval_streaming() -> None: # Assert assert len(request_info_events) == 1 approval_request = request_info_events[0] - assert isinstance(approval_request.data, FunctionApprovalRequestContent) + assert approval_request.data.type == "function_approval_request" assert approval_request.data.function_call.name == "mock_tool_requiring_approval" assert approval_request.data.function_call.arguments == '{"query": "test"}' # Act output: str | None = None async for event in workflow.send_responses_streaming({ - approval_request.request_id: approval_request.data.create_response(True) + approval_request.request_id: approval_request.data.to_function_approval_response(True) }): if isinstance(event, WorkflowOutputEvent): output = event.data @@ -310,13 +309,13 @@ async def test_agent_executor_parallel_tool_call_with_approval() -> None: # Assert assert len(events.get_request_info_events()) == 2 for approval_request in events.get_request_info_events(): - assert isinstance(approval_request.data, FunctionApprovalRequestContent) + assert approval_request.data.type == "function_approval_request" assert approval_request.data.function_call.name == "mock_tool_requiring_approval" assert approval_request.data.function_call.arguments == '{"query": "test"}' # Act responses = { - approval_request.request_id: approval_request.data.create_response(True) # type: ignore + approval_request.request_id: approval_request.data.to_function_approval_response(True) # type: ignore for approval_request in events.get_request_info_events() } events = await workflow.send_responses(responses) @@ -347,13 +346,13 @@ async def 
test_agent_executor_parallel_tool_call_with_approval_streaming() -> No # Assert assert len(request_info_events) == 2 for approval_request in request_info_events: - assert isinstance(approval_request.data, FunctionApprovalRequestContent) + assert approval_request.data.type == "function_approval_request" assert approval_request.data.function_call.name == "mock_tool_requiring_approval" assert approval_request.data.function_call.arguments == '{"query": "test"}' # Act responses = { - approval_request.request_id: approval_request.data.create_response(True) # type: ignore + approval_request.request_id: approval_request.data.to_function_approval_response(True) # type: ignore for approval_request in request_info_events } diff --git a/python/packages/core/tests/workflow/test_full_conversation.py b/python/packages/core/tests/workflow/test_full_conversation.py index b1a3194468..9a8f4bd9c9 100644 --- a/python/packages/core/tests/workflow/test_full_conversation.py +++ b/python/packages/core/tests/workflow/test_full_conversation.py @@ -14,10 +14,10 @@ AgentThread, BaseAgent, ChatMessage, + Content, Executor, Role, SequentialBuilder, - TextContent, WorkflowBuilder, WorkflowContext, WorkflowRunState, @@ -50,7 +50,7 @@ async def run_stream( # type: ignore[override] **kwargs: Any, ) -> AsyncIterable[AgentResponseUpdate]: # This agent does not support streaming; yield a single complete response - yield AgentResponseUpdate(contents=[TextContent(text=self._reply_text)]) + yield AgentResponseUpdate(contents=[Content.from_text(text=self._reply_text)]) class _CaptureFullConversation(Executor): @@ -136,7 +136,7 @@ async def run_stream( # type: ignore[override] elif isinstance(m, str): norm.append(ChatMessage(role=Role.USER, text=m)) self._last_messages = norm - yield AgentResponseUpdate(contents=[TextContent(text=self._reply_text)]) + yield AgentResponseUpdate(contents=[Content.from_text(text=self._reply_text)]) async def test_sequential_adapter_uses_full_conversation() -> None: diff --git 
a/python/packages/core/tests/workflow/test_group_chat.py b/python/packages/core/tests/workflow/test_group_chat.py index c65f19d599..330a887a30 100644 --- a/python/packages/core/tests/workflow/test_group_chat.py +++ b/python/packages/core/tests/workflow/test_group_chat.py @@ -17,6 +17,7 @@ ChatMessage, ChatResponse, ChatResponseUpdate, + Content, GroupChatBuilder, GroupChatState, MagenticContext, @@ -25,7 +26,6 @@ MagenticProgressLedgerItem, RequestInfoEvent, Role, - TextContent, WorkflowOutputEvent, WorkflowRunState, WorkflowStatusEvent, @@ -57,7 +57,7 @@ def run_stream( # type: ignore[override] ) -> AsyncIterable[AgentResponseUpdate]: async def _stream() -> AsyncIterable[AgentResponseUpdate]: yield AgentResponseUpdate( - contents=[TextContent(text=self._reply_text)], role=Role.ASSISTANT, author_name=self.name + contents=[Content.from_text(text=self._reply_text)], role=Role.ASSISTANT, author_name=self.name ) return _stream() @@ -141,7 +141,7 @@ def run_stream( async def _stream_initial() -> AsyncIterable[AgentResponseUpdate]: yield AgentResponseUpdate( contents=[ - TextContent( + Content.from_text( text=( '{"terminate": false, "reason": "Selecting agent", ' '"next_speaker": "agent", "final_message": null}' @@ -157,7 +157,7 @@ async def _stream_initial() -> AsyncIterable[AgentResponseUpdate]: async def _stream_final() -> AsyncIterable[AgentResponseUpdate]: yield AgentResponseUpdate( contents=[ - TextContent( + Content.from_text( text=( '{"terminate": true, "reason": "Task complete", ' '"next_speaker": null, "final_message": "agent manager final"}' diff --git a/python/packages/core/tests/workflow/test_handoff.py b/python/packages/core/tests/workflow/test_handoff.py index 268f89d513..83531139c7 100644 --- a/python/packages/core/tests/workflow/test_handoff.py +++ b/python/packages/core/tests/workflow/test_handoff.py @@ -11,12 +11,11 @@ ChatMessage, ChatResponse, ChatResponseUpdate, - FunctionCallContent, + Content, HandoffAgentUserRequest, HandoffBuilder, 
RequestInfoEvent, Role, - TextContent, WorkflowEvent, WorkflowOutputEvent, resolve_agent_id, @@ -74,14 +73,16 @@ def _build_reply_contents( agent_name: str, handoff_to: str | None, call_id: str | None, -) -> list[TextContent | FunctionCallContent]: - contents: list[TextContent | FunctionCallContent] = [] +) -> list[Content]: + contents: list[Content] = [] if handoff_to and call_id: contents.append( - FunctionCallContent(call_id=call_id, name=f"handoff_to_{handoff_to}", arguments={"handoff_to": handoff_to}) + Content.from_function_call( + call_id=call_id, name=f"handoff_to_{handoff_to}", arguments={"handoff_to": handoff_to} + ) ) text = f"{agent_name} reply" - contents.append(TextContent(text=text)) + contents.append(Content.from_text(text=text)) return contents diff --git a/python/packages/core/tests/workflow/test_magentic.py b/python/packages/core/tests/workflow/test_magentic.py index 7e4a5bb48e..24c830968f 100644 --- a/python/packages/core/tests/workflow/test_magentic.py +++ b/python/packages/core/tests/workflow/test_magentic.py @@ -15,6 +15,7 @@ AgentThread, BaseAgent, ChatMessage, + Content, Executor, GroupChatRequestMessage, MagenticBuilder, @@ -28,7 +29,6 @@ RequestInfoEvent, Role, StandardMagenticManager, - TextContent, Workflow, WorkflowCheckpoint, WorkflowCheckpointException, @@ -172,7 +172,7 @@ def run_stream( # type: ignore[override] ) -> AsyncIterable[AgentResponseUpdate]: async def _stream() -> AsyncIterable[AgentResponseUpdate]: yield AgentResponseUpdate( - contents=[TextContent(text=self._reply_text)], role=Role.ASSISTANT, author_name=self.name + contents=[Content.from_text(text=self._reply_text)], role=Role.ASSISTANT, author_name=self.name ) return _stream() @@ -541,7 +541,7 @@ def __init__(self, name: str | None = None) -> None: async def run_stream(self, messages=None, *, thread=None, **kwargs): # type: ignore[override] yield AgentResponseUpdate( - contents=[TextContent(text="thread-ok")], + contents=[Content.from_text(text="thread-ok")], 
author_name=self.name, role=Role.ASSISTANT, ) @@ -563,7 +563,7 @@ def __init__(self) -> None: async def run_stream(self, messages=None, *, thread=None, **kwargs): # type: ignore[override] yield AgentResponseUpdate( - contents=[TextContent(text="assistants-ok")], + contents=[Content.from_text(text="assistants-ok")], author_name=self.name, role=Role.ASSISTANT, ) diff --git a/python/packages/core/tests/workflow/test_sequential.py b/python/packages/core/tests/workflow/test_sequential.py index d104eb8a02..a685db73db 100644 --- a/python/packages/core/tests/workflow/test_sequential.py +++ b/python/packages/core/tests/workflow/test_sequential.py @@ -12,10 +12,10 @@ AgentThread, BaseAgent, ChatMessage, + Content, Executor, Role, SequentialBuilder, - TextContent, TypeCompatibilityError, WorkflowContext, WorkflowOutputEvent, @@ -46,7 +46,7 @@ async def run_stream( # type: ignore[override] **kwargs: Any, ) -> AsyncIterable[AgentResponseUpdate]: # Minimal async generator with one assistant update - yield AgentResponseUpdate(contents=[TextContent(text=f"{self.name} reply")]) + yield AgentResponseUpdate(contents=[Content.from_text(text=f"{self.name} reply")]) class _SummarizerExec(Executor): diff --git a/python/packages/core/tests/workflow/test_workflow.py b/python/packages/core/tests/workflow/test_workflow.py index 5a0f54a24e..6b08b7b22a 100644 --- a/python/packages/core/tests/workflow/test_workflow.py +++ b/python/packages/core/tests/workflow/test_workflow.py @@ -18,12 +18,12 @@ AgentThread, BaseAgent, ChatMessage, + Content, Executor, FileCheckpointStorage, Message, RequestInfoEvent, Role, - TextContent, WorkflowBuilder, WorkflowCheckpointException, WorkflowContext, @@ -881,7 +881,7 @@ async def run_stream( """Streaming run - yields incremental updates.""" # Simulate streaming by yielding character by character for char in self._reply_text: - yield AgentResponseUpdate(contents=[TextContent(text=char)]) + yield AgentResponseUpdate(contents=[Content.from_text(text=char)]) async 
def test_agent_streaming_vs_non_streaming() -> None: diff --git a/python/packages/core/tests/workflow/test_workflow_agent.py b/python/packages/core/tests/workflow/test_workflow_agent.py index 7e47a82c9c..58e5c7e20f 100644 --- a/python/packages/core/tests/workflow/test_workflow_agent.py +++ b/python/packages/core/tests/workflow/test_workflow_agent.py @@ -14,16 +14,9 @@ AgentThread, ChatMessage, ChatMessageStore, - DataContent, + Content, Executor, - FunctionApprovalRequestContent, - FunctionApprovalResponseContent, - FunctionCallContent, - FunctionResultContent, Role, - TextContent, - UriContent, - UsageContent, UsageDetails, WorkflowAgent, WorkflowBuilder, @@ -44,17 +37,15 @@ def __init__(self, id: str, response_text: str, emit_streaming: bool = False): @handler async def handle_message(self, message: list[ChatMessage], ctx: WorkflowContext[list[ChatMessage]]) -> None: - input_text = ( - message[0].contents[0].text if message and isinstance(message[0].contents[0], TextContent) else "no input" - ) + input_text = message[0].contents[0].text if message and message[0].contents[0].type == "text" else "no input" response_text = f"{self.response_text}: {input_text}" # Create response message for both streaming and non-streaming cases - response_message = ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text=response_text)]) + response_message = ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text(text=response_text)]) # Emit update event. 
streaming_update = AgentResponseUpdate( - contents=[TextContent(text=response_text)], role=Role.ASSISTANT, message_id=str(uuid.uuid4()) + contents=[Content.from_text(text=response_text)], role=Role.ASSISTANT, message_id=str(uuid.uuid4()) ) await ctx.add_event(AgentRunUpdateEvent(executor_id=self.id, data=streaming_update)) @@ -76,7 +67,7 @@ async def handle_request_response( ) -> None: # Handle the response and emit completion response update = AgentResponseUpdate( - contents=[TextContent(text="Request completed successfully")], + contents=[Content.from_text(text="Request completed successfully")], role=Role.ASSISTANT, message_id=str(uuid.uuid4()), ) @@ -99,10 +90,10 @@ async def handle_message(self, messages: list[ChatMessage], ctx: WorkflowContext message_count = len(messages) response_text = f"Received {message_count} messages" - response_message = ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text=response_text)]) + response_message = ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text(text=response_text)]) streaming_update = AgentResponseUpdate( - contents=[TextContent(text=response_text)], role=Role.ASSISTANT, message_id=str(uuid.uuid4()) + contents=[Content.from_text(text=response_text)], role=Role.ASSISTANT, message_id=str(uuid.uuid4()) ) await ctx.add_event(AgentRunUpdateEvent(executor_id=self.id, data=streaming_update)) await ctx.send_message([response_message]) @@ -134,7 +125,7 @@ async def test_end_to_end_basic_workflow(self): for message in result.messages: first_content = message.contents[0] - if isinstance(first_content, TextContent): + if first_content.type == "text": text = first_content.text if text.startswith("Step1:"): step1_messages.append(message) @@ -172,11 +163,11 @@ async def test_end_to_end_basic_workflow_streaming(self): # Verify we got a streaming update assert updates[0].contents is not None - first_content: TextContent = updates[0].contents[0] # type: ignore[assignment] - second_content: TextContent = 
updates[1].contents[0] # type: ignore[assignment] - assert isinstance(first_content, TextContent) + first_content: Content = updates[0].contents[0] # type: ignore[assignment] + second_content: Content = updates[1].contents[0] # type: ignore[assignment] + assert first_content.type == "text" assert "Streaming1: Test input" in first_content.text - assert isinstance(second_content, TextContent) + assert second_content.type == "text" assert "Streaming2: Streaming1: Test input" in second_content.text async def test_end_to_end_request_info_handling(self): @@ -200,17 +191,15 @@ async def test_end_to_end_request_info_handling(self): approval_update: AgentResponseUpdate | None = None for update in updates: - if any(isinstance(content, FunctionApprovalRequestContent) for content in update.contents): + if any(content.type == "function_approval_request" for content in update.contents): approval_update = update break assert approval_update is not None, "Should have received a request_info approval request" - function_call = next( - content for content in approval_update.contents if isinstance(content, FunctionCallContent) - ) + function_call = next(content for content in approval_update.contents if content.type == "function_call") approval_request = next( - content for content in approval_update.contents if isinstance(content, FunctionApprovalRequestContent) + content for content in approval_update.contents if content.type == "function_approval_request" ) # Verify the function call has expected structure @@ -233,10 +222,10 @@ async def test_end_to_end_request_info_handling(self): data="User provided answer", ).to_dict() - approval_response = FunctionApprovalResponseContent( + approval_response = Content.from_function_approval_response( approved=True, id=approval_request.id, - function_call=FunctionCallContent( + function_call=Content.from_function_call( call_id=function_call.call_id, name=function_call.name, arguments=response_args, @@ -306,7 +295,7 @@ async def 
yielding_executor(messages: list[ChatMessage], ctx: WorkflowContext) - workflow = WorkflowBuilder().set_start_executor(yielding_executor).build() # Run directly - should return WorkflowOutputEvent in result - direct_result = await workflow.run([ChatMessage(role=Role.USER, contents=[TextContent(text="hello")])]) + direct_result = await workflow.run([ChatMessage(role=Role.USER, contents=[Content.from_text(text="hello")])]) direct_outputs = direct_result.get_outputs() assert len(direct_outputs) == 1 assert direct_outputs[0] == "processed: hello" @@ -340,14 +329,14 @@ async def yielding_executor(messages: list[ChatMessage], ctx: WorkflowContext) - assert "second output" in texts async def test_workflow_as_agent_yield_output_with_content_types(self) -> None: - """Test that yield_output preserves different content types (TextContent, DataContent, etc.).""" + """Test that yield_output preserves different content types (Content, Content, etc.).""" @executor async def content_yielding_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None: # Yield different content types - await ctx.yield_output(TextContent(text="text content")) - await ctx.yield_output(DataContent(data=b"binary data", media_type="application/octet-stream")) - await ctx.yield_output(UriContent(uri="https://example.com/image.png", media_type="image/png")) + await ctx.yield_output(Content.from_text(text="text content")) + await ctx.yield_output(Content.from_data(data=b"binary data", media_type="application/octet-stream")) + await ctx.yield_output(Content.from_uri(uri="https://example.com/image.png", media_type="image/png")) workflow = WorkflowBuilder().set_start_executor(content_yielding_executor).build() agent = workflow.as_agent("content-test-agent") @@ -358,13 +347,13 @@ async def content_yielding_executor(messages: list[ChatMessage], ctx: WorkflowCo assert len(result.messages) == 3 # Verify each content type is preserved - assert isinstance(result.messages[0].contents[0], TextContent) + 
assert result.messages[0].contents[0].type == "text" assert result.messages[0].contents[0].text == "text content" - assert isinstance(result.messages[1].contents[0], DataContent) + assert result.messages[1].contents[0].type == "data" assert result.messages[1].contents[0].media_type == "application/octet-stream" - assert isinstance(result.messages[2].contents[0], UriContent) + assert result.messages[2].contents[0].type == "uri" assert result.messages[2].contents[0].uri == "https://example.com/image.png" async def test_workflow_as_agent_yield_output_with_chat_message(self) -> None: @@ -374,7 +363,7 @@ async def test_workflow_as_agent_yield_output_with_chat_message(self) -> None: async def chat_message_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None: msg = ChatMessage( role=Role.ASSISTANT, - contents=[TextContent(text="response text")], + contents=[Content.from_text(text="response text")], author_name="custom-author", ) await ctx.yield_output(msg) @@ -404,7 +393,7 @@ def __str__(self) -> str: async def raw_yielding_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None: # Yield different types of data await ctx.yield_output("simple string") - await ctx.yield_output(TextContent(text="text content")) + await ctx.yield_output(Content.from_text(text="text content")) custom = CustomData(42) await ctx.yield_output(custom) @@ -420,7 +409,7 @@ async def raw_yielding_executor(messages: list[ChatMessage], ctx: WorkflowContex # Verify raw_representation is set for each update assert updates[0].raw_representation == "simple string" - assert isinstance(updates[1].raw_representation, TextContent) + assert updates[1].raw_representation.type == "text" assert updates[1].raw_representation.text == "text content" assert isinstance(updates[2].raw_representation, CustomData) assert updates[2].raw_representation.value == 42 @@ -428,19 +417,19 @@ async def raw_yielding_executor(messages: list[ChatMessage], ctx: WorkflowContex async def 
test_workflow_as_agent_yield_output_with_list_of_chat_messages(self) -> None: """Test that yield_output with list[ChatMessage] extracts contents from all messages. - Note: TextContent items are coalesced by _finalize_response, so multiple text contents - become a single merged TextContent in the final response. + Note: Content items are coalesced by _finalize_response, so multiple text contents + become a single merged Content in the final response. """ @executor async def list_yielding_executor(messages: list[ChatMessage], ctx: WorkflowContext) -> None: # Yield a list of ChatMessages (as SequentialBuilder does) msg_list = [ - ChatMessage(role=Role.USER, contents=[TextContent(text="first message")]), - ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="second message")]), + ChatMessage(role=Role.USER, contents=[Content.from_text(text="first message")]), + ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text(text="second message")]), ChatMessage( role=Role.ASSISTANT, - contents=[TextContent(text="third"), TextContent(text="fourth")], + contents=[Content.from_text(text="third"), Content.from_text(text="fourth")], ), ] await ctx.yield_output(msg_list) @@ -455,7 +444,7 @@ async def list_yielding_executor(messages: list[ChatMessage], ctx: WorkflowConte assert len(updates) == 1 assert len(updates[0].contents) == 4 - texts = [c.text for c in updates[0].contents if isinstance(c, TextContent)] + texts = [c.text for c in updates[0].contents if c.type == "text"] assert texts == ["first message", "second message", "third", "fourth"] # Verify run() coalesces text contents (expected behavior) @@ -463,7 +452,7 @@ async def list_yielding_executor(messages: list[ChatMessage], ctx: WorkflowConte assert isinstance(result, AgentResponse) assert len(result.messages) == 1 - # TextContent items are coalesced into one + # Content items are coalesced into one assert len(result.messages[0].contents) == 1 assert result.messages[0].text == "first messagesecond 
messagethirdfourth" @@ -599,7 +588,7 @@ async def run_stream( ) -> AsyncIterable[AgentResponseUpdate]: for word in self._response_text.split(): yield AgentResponseUpdate( - contents=[TextContent(text=word + " ")], + contents=[Content.from_text(text=word + " ")], role=Role.ASSISTANT, author_name=self._name, ) @@ -672,7 +661,7 @@ async def run_stream( self, messages: Any, *, thread: AgentThread | None = None, **kwargs: Any ) -> AsyncIterable[AgentResponseUpdate]: yield AgentResponseUpdate( - contents=[TextContent(text=self._response_text)], + contents=[Content.from_text(text=self._response_text)], role=Role.ASSISTANT, author_name=self._name, ) @@ -738,7 +727,7 @@ class AuthorNameExecutor(Executor): async def handle_message(self, message: list[ChatMessage], ctx: WorkflowContext[list[ChatMessage]]) -> None: # Emit update with explicit author_name update = AgentResponseUpdate( - contents=[TextContent(text="Response with author")], + contents=[Content.from_text(text="Response with author")], role=Role.ASSISTANT, author_name="custom_author_name", # Explicitly set message_id=str(uuid.uuid4()), @@ -790,7 +779,7 @@ def test_merge_updates_ordering_by_response_and_message_id(self): updates = [ # Response B, Message 2 (latest in resp B) AgentResponseUpdate( - contents=[TextContent(text="RespB-Msg2")], + contents=[Content.from_text(text="RespB-Msg2")], role=Role.ASSISTANT, response_id="resp-b", message_id="msg-2", @@ -798,7 +787,7 @@ def test_merge_updates_ordering_by_response_and_message_id(self): ), # Response A, Message 1 (earliest overall) AgentResponseUpdate( - contents=[TextContent(text="RespA-Msg1")], + contents=[Content.from_text(text="RespA-Msg1")], role=Role.ASSISTANT, response_id="resp-a", message_id="msg-1", @@ -806,7 +795,7 @@ def test_merge_updates_ordering_by_response_and_message_id(self): ), # Response B, Message 1 (earlier in resp B) AgentResponseUpdate( - contents=[TextContent(text="RespB-Msg1")], + contents=[Content.from_text(text="RespB-Msg1")], 
role=Role.ASSISTANT, response_id="resp-b", message_id="msg-1", @@ -814,7 +803,7 @@ def test_merge_updates_ordering_by_response_and_message_id(self): ), # Response A, Message 2 (later in resp A) AgentResponseUpdate( - contents=[TextContent(text="RespA-Msg2")], + contents=[Content.from_text(text="RespA-Msg2")], role=Role.ASSISTANT, response_id="resp-a", message_id="msg-2", @@ -822,7 +811,7 @@ def test_merge_updates_ordering_by_response_and_message_id(self): ), # Global dangling update (no response_id) - should go at end AgentResponseUpdate( - contents=[TextContent(text="Global-Dangling")], + contents=[Content.from_text(text="Global-Dangling")], role=Role.ASSISTANT, response_id=None, message_id="msg-global", @@ -841,9 +830,7 @@ def test_merge_updates_ordering_by_response_and_message_id(self): # Verify ordering: responses are processed by response_id groups, # within each group messages are chronologically ordered, # global dangling goes at the end - message_texts = [ - msg.contents[0].text if isinstance(msg.contents[0], TextContent) else "" for msg in result.messages - ] + message_texts = [msg.contents[0].text if msg.contents[0].type == "text" else "" for msg in result.messages] # The exact order depends on dict iteration order for response_ids, # but within each response group, chronological order should be maintained @@ -894,9 +881,9 @@ def test_merge_updates_metadata_aggregation(self): updates = [ AgentResponseUpdate( contents=[ - TextContent(text="First"), - UsageContent( - details=UsageDetails(input_token_count=10, output_token_count=5, total_token_count=15) + Content.from_text(text="First"), + Content.from_usage( + usage_details={"input_token_count": 10, "output_token_count": 5, "total_token_count": 15} ), ], role=Role.ASSISTANT, @@ -907,9 +894,9 @@ def test_merge_updates_metadata_aggregation(self): ), AgentResponseUpdate( contents=[ - TextContent(text="Second"), - UsageContent( - details=UsageDetails(input_token_count=20, output_token_count=8, 
total_token_count=28) + Content.from_text(text="Second"), + Content.from_usage( + usage_details={"input_token_count": 20, "output_token_count": 8, "total_token_count": 28} ), ], role=Role.ASSISTANT, @@ -920,8 +907,10 @@ def test_merge_updates_metadata_aggregation(self): ), AgentResponseUpdate( contents=[ - TextContent(text="Third"), - UsageContent(details=UsageDetails(input_token_count=5, output_token_count=3, total_token_count=8)), + Content.from_text(text="Third"), + Content.from_usage( + usage_details={"input_token_count": 5, "output_token_count": 3, "total_token_count": 8} + ), ], role=Role.ASSISTANT, response_id="resp-1", # Same response_id as first diff --git a/python/packages/core/tests/workflow/test_workflow_kwargs.py b/python/packages/core/tests/workflow/test_workflow_kwargs.py index 75c34f9d95..14ec9f43ec 100644 --- a/python/packages/core/tests/workflow/test_workflow_kwargs.py +++ b/python/packages/core/tests/workflow/test_workflow_kwargs.py @@ -12,12 +12,12 @@ BaseAgent, ChatMessage, ConcurrentBuilder, + Content, GroupChatBuilder, GroupChatState, HandoffBuilder, Role, SequentialBuilder, - TextContent, WorkflowRunState, WorkflowStatusEvent, ai_function, @@ -67,7 +67,7 @@ async def run_stream( **kwargs: Any, ) -> AsyncIterable[AgentResponseUpdate]: self.captured_kwargs.append(dict(kwargs)) - yield AgentResponseUpdate(contents=[TextContent(text=f"{self.name} response")]) + yield AgentResponseUpdate(contents=[Content.from_text(text=f"{self.name} response")]) # region Sequential Builder Tests diff --git a/python/packages/declarative/agent_framework_declarative/_loader.py b/python/packages/declarative/agent_framework_declarative/_loader.py index 0e9f4dc07f..70cc623dc4 100644 --- a/python/packages/declarative/agent_framework_declarative/_loader.py +++ b/python/packages/declarative/agent_framework_declarative/_loader.py @@ -10,11 +10,9 @@ ChatAgent, ChatClientProtocol, HostedCodeInterpreterTool, - HostedFileContent, HostedFileSearchTool, 
HostedMCPSpecificApproval, HostedMCPTool, - HostedVectorStoreContent, HostedWebSearchTool, ToolProtocol, ) @@ -739,14 +737,14 @@ def _parse_tool(self, tool_resource: Tool) -> ToolProtocol: if tool_resource.filters: add_props["filters"] = tool_resource.filters return HostedFileSearchTool( - inputs=[HostedVectorStoreContent(id) for id in tool_resource.vectorStoreIds or []], + inputs=[Content.from_hosted_vector_store(id) for id in tool_resource.vectorStoreIds or []], description=tool_resource.description, max_results=tool_resource.maximumResultCount, additional_properties=add_props, ) case CodeInterpreterTool(): return HostedCodeInterpreterTool( - inputs=[HostedFileContent(file_id=file) for file in tool_resource.fileIds or []], + inputs=[Content.from_hosted_file(file_id=file) for file in tool_resource.fileIds or []], description=tool_resource.description, ) case McpTool(): diff --git a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_agents.py b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_agents.py index 669f662a9b..18685ef401 100644 --- a/python/packages/declarative/agent_framework_declarative/_workflows/_executors_agents.py +++ b/python/packages/declarative/agent_framework_declarative/_workflows/_executors_agents.py @@ -21,8 +21,7 @@ from agent_framework import ( ChatMessage, - FunctionCallContent, - FunctionResultContent, + Content, WorkflowContext, handler, response_handler, @@ -191,7 +190,7 @@ def _validate_conversation_history(messages: list[ChatMessage], agent_name: str) if not hasattr(msg, "contents") or msg.contents is None: continue for content in msg.contents: - if isinstance(content, FunctionCallContent) and content.call_id: + if content.type == "function_call" and content.call_id: tool_call_ids.add(content.call_id) logger.debug( "Agent '%s': Found tool call '%s' (id=%s) in message %d", @@ -200,7 +199,7 @@ def _validate_conversation_history(messages: list[ChatMessage], agent_name: str) 
content.call_id, i, ) - elif isinstance(content, FunctionResultContent) and content.call_id: + elif content.type == "function_result" and content.call_id: tool_result_ids.add(content.call_id) logger.debug( "Agent '%s': Found tool result for call_id=%s in message %d", @@ -265,7 +264,7 @@ class AgentResult: response: str agent_name: str messages: list[ChatMessage] = field(default_factory=lambda: cast(list[ChatMessage], [])) - tool_calls: list[FunctionCallContent] = field(default_factory=lambda: cast(list[FunctionCallContent], [])) + tool_calls: list[Content] = field(default_factory=lambda: cast(list[Content], [])) error: str | None = None @@ -311,7 +310,7 @@ async def on_request(request: AgentExternalInputRequest) -> ExternalInputRespons agent_response: str iteration: int = 0 messages: list[ChatMessage] = field(default_factory=lambda: cast(list[ChatMessage], [])) - function_calls: list[FunctionCallContent] = field(default_factory=lambda: cast(list[FunctionCallContent], [])) + function_calls: list[Content] = field(default_factory=lambda: cast(list[Content], [])) @dataclass @@ -342,9 +341,7 @@ class AgentExternalInputResponse: user_input: str messages: list[ChatMessage] = field(default_factory=lambda: cast(list[ChatMessage], [])) - function_results: dict[str, FunctionResultContent] = field( - default_factory=lambda: cast(dict[str, FunctionResultContent], {}) - ) + function_results: dict[str, Content] = field(default_factory=lambda: cast(dict[str, Content], {})) @dataclass @@ -641,7 +638,7 @@ async def _invoke_agent_and_store_results( """ accumulated_response = "" all_messages: list[ChatMessage] = [] - tool_calls: list[FunctionCallContent] = [] + tool_calls: list[Content] = [] # Add user input to conversation history first (via state.append only) if input_text: @@ -679,7 +676,7 @@ async def _invoke_agent_and_store_results( all_messages = list(cast(list[ChatMessage], result_messages)) result_tool_calls: Any = getattr(result, "tool_calls", None) if result_tool_calls is 
not None: - tool_calls = list(cast(list[FunctionCallContent], result_tool_calls)) + tool_calls = list(cast(list[Content], result_tool_calls)) else: raise RuntimeError(f"Agent '{agent_name}' has no run or run_stream method") diff --git a/python/packages/devui/agent_framework_devui/_conversations.py b/python/packages/devui/agent_framework_devui/_conversations.py index 86db2172e1..868ca3e162 100644 --- a/python/packages/devui/agent_framework_devui/_conversations.py +++ b/python/packages/devui/agent_framework_devui/_conversations.py @@ -321,7 +321,7 @@ async def add_items(self, conversation_id: str, items: list[dict[str, Any]]) -> # Convert ChatMessage contents to OpenAI TextContent format message_content = [] for content_item in msg.contents: - if hasattr(content_item, "type") and content_item.type == "text": + if content_item.type == "text": # Extract text from TextContent object text_value = getattr(content_item, "text", "") message_content.append(TextContent(type="text", text=text_value)) diff --git a/python/packages/devui/agent_framework_devui/_executor.py b/python/packages/devui/agent_framework_devui/_executor.py index 585036bef9..5234e37c10 100644 --- a/python/packages/devui/agent_framework_devui/_executor.py +++ b/python/packages/devui/agent_framework_devui/_executor.py @@ -7,7 +7,7 @@ from collections.abc import AsyncGenerator from typing import Any -from agent_framework import AgentProtocol +from agent_framework import AgentProtocol, Content from agent_framework._workflows._events import RequestInfoEvent from ._conversations import ConversationStore, InMemoryConversationStore @@ -602,7 +602,7 @@ def _convert_input_to_chat_message(self, input_data: Any) -> Any: """ # Import Agent Framework types try: - from agent_framework import ChatMessage, DataContent, Role, TextContent + from agent_framework import ChatMessage, Content, Role except ImportError: # Fallback to string extraction if Agent Framework not available return 
self._extract_user_message_fallback(input_data) @@ -613,14 +613,12 @@ def _convert_input_to_chat_message(self, input_data: Any) -> Any: # Handle OpenAI ResponseInputParam (List[ResponseInputItemParam]) if isinstance(input_data, list): - return self._convert_openai_input_to_chat_message(input_data, ChatMessage, TextContent, DataContent, Role) + return self._convert_openai_input_to_chat_message(input_data, ChatMessage, Role) # Fallback for other formats return self._extract_user_message_fallback(input_data) - def _convert_openai_input_to_chat_message( - self, input_items: list[Any], ChatMessage: Any, TextContent: Any, DataContent: Any, Role: Any - ) -> Any: + def _convert_openai_input_to_chat_message(self, input_items: list[Any], ChatMessage: Any, Role: Any) -> Any: """Convert OpenAI ResponseInputParam to Agent Framework ChatMessage. Processes text, images, files, and other content types from OpenAI format @@ -629,14 +627,12 @@ def _convert_openai_input_to_chat_message( Args: input_items: List of OpenAI ResponseInputItemParam objects (dicts or objects) ChatMessage: ChatMessage class for creating chat messages - TextContent: TextContent class for text content - DataContent: DataContent class for data/media content Role: Role enum for message roles Returns: ChatMessage with converted content """ - contents = [] + contents: list[Content] = [] # Process each input item for item in input_items: @@ -649,7 +645,7 @@ def _convert_openai_input_to_chat_message( # Handle both string content and list content if isinstance(message_content, str): - contents.append(TextContent(text=message_content)) + contents.append(Content.from_text(text=message_content)) elif isinstance(message_content, list): for content_item in message_content: # Handle dict content items @@ -658,7 +654,7 @@ def _convert_openai_input_to_chat_message( if content_type == "input_text": text = content_item.get("text", "") - contents.append(TextContent(text=text)) + contents.append(Content.from_text(text=text)) 
elif content_type == "input_image": image_url = content_item.get("image_url", "") @@ -676,7 +672,7 @@ def _convert_openai_input_to_chat_message( media_type = "image/png" else: media_type = "image/png" - contents.append(DataContent(uri=image_url, media_type=media_type)) + contents.append(Content.from_uri(uri=image_url, media_type=media_type)) elif content_type == "input_file": # Handle file input @@ -710,7 +706,7 @@ def _convert_openai_input_to_chat_message( # Assume file_data is base64, create data URI data_uri = f"data:{media_type};base64,{file_data}" contents.append( - DataContent( + Content.from_uri( uri=data_uri, media_type=media_type, additional_properties=additional_props, @@ -718,7 +714,7 @@ def _convert_openai_input_to_chat_message( ) elif file_url: contents.append( - DataContent( + Content.from_uri( uri=file_url, media_type=media_type, additional_properties=additional_props, @@ -728,21 +724,19 @@ def _convert_openai_input_to_chat_message( elif content_type == "function_approval_response": # Handle function approval response (DevUI extension) try: - from agent_framework import FunctionApprovalResponseContent, FunctionCallContent - request_id = content_item.get("request_id", "") approved = content_item.get("approved", False) function_call_data = content_item.get("function_call", {}) # Create FunctionCallContent from the function_call data - function_call = FunctionCallContent( + function_call = Content.from_function_call( call_id=function_call_data.get("id", ""), name=function_call_data.get("name", ""), arguments=function_call_data.get("arguments", {}), ) # Create FunctionApprovalResponseContent with correct signature - approval_response = FunctionApprovalResponseContent( + approval_response = Content.from_function_approval_response( approved, # positional argument id=request_id, # keyword argument 'id', NOT 'request_id' function_call=function_call, # FunctionCallContent object @@ -764,7 +758,7 @@ def _convert_openai_input_to_chat_message( # If no contents 
found, create a simple text message if not contents: - contents.append(TextContent(text="")) + contents.append(Content.from_text(text="")) chat_message = ChatMessage(role=Role.USER, contents=contents) diff --git a/python/packages/devui/agent_framework_devui/_mapper.py b/python/packages/devui/agent_framework_devui/_mapper.py index 71bbda9b85..acdb939134 100644 --- a/python/packages/devui/agent_framework_devui/_mapper.py +++ b/python/packages/devui/agent_framework_devui/_mapper.py @@ -12,7 +12,7 @@ from typing import Any, Union from uuid import uuid4 -from agent_framework import ChatMessage, TextContent +from agent_framework import ChatMessage from openai.types.responses import ( Response, ResponseContentPartAddedEvent, @@ -92,7 +92,7 @@ def _serialize_content_recursive(value: Any) -> Any: if isinstance(value, (list, tuple)): serialized = [_serialize_content_recursive(item) for item in value] # For single-item lists containing text Content, extract just the text - # This handles the MCP case where result = [TextContent(text="Hello")] + # This handles the MCP case where result = [Content.from_text(text="Hello")] # and we want output = "Hello" not output = '[{"type": "text", "text": "Hello"}]' if len(serialized) == 1 and isinstance(serialized[0], dict) and serialized[0].get("type") == "text": return serialized[0].get("text", "") @@ -127,18 +127,18 @@ def __init__(self, max_contexts: int = 1000) -> None: # Register content type mappers for all 12 Agent Framework content types self.content_mappers = { - "TextContent": self._map_text_content, - "TextReasoningContent": self._map_reasoning_content, - "FunctionCallContent": self._map_function_call_content, - "FunctionResultContent": self._map_function_result_content, - "ErrorContent": self._map_error_content, - "UsageContent": self._map_usage_content, - "DataContent": self._map_data_content, - "UriContent": self._map_uri_content, - "HostedFileContent": self._map_hosted_file_content, - "HostedVectorStoreContent": 
self._map_hosted_vector_store_content, - "FunctionApprovalRequestContent": self._map_approval_request_content, - "FunctionApprovalResponseContent": self._map_approval_response_content, + "text": self._map_text_content, + "text_reasoning": self._map_reasoning_content, + "function_call": self._map_function_call_content, + "function_result": self._map_function_result_content, + "error": self._map_error_content, + "usage": self._map_usage_content, + "data": self._map_data_content, + "uri": self._map_uri_content, + "hosted_file": self._map_hosted_file_content, + "hosted_vector_store": self._map_hosted_vector_store_content, + "function_approval_request": self._map_approval_request_content, + "function_approval_response": self._map_approval_response_content, } async def convert_event(self, raw_event: Any, request: AgentFrameworkRequest) -> Sequence[Any]: @@ -603,7 +603,7 @@ async def _convert_agent_update(self, update: Any, context: dict[str, Any]) -> S return events # Check if we're streaming text content - has_text_content = any(isinstance(content, TextContent) for content in update.contents) + has_text_content = any(content.type == "text" for content in update.contents) # Check if we're in an executor context with an existing item executor_id = context.get("current_executor_id") @@ -647,10 +647,8 @@ async def _convert_agent_update(self, update: Any, context: dict[str, Any]) -> S # Process each content item for content in update.contents: - content_type = content.__class__.__name__ - # Special handling for TextContent to use proper delta events - if content_type == "TextContent" and "current_message_id" in context: + if content.type == "text" and "current_message_id" in context: # Stream text content via proper delta events events.append( ResponseTextDeltaEvent( @@ -663,9 +661,9 @@ async def _convert_agent_update(self, update: Any, context: dict[str, Any]) -> S sequence_number=self._next_sequence(context), ) ) - elif content_type in self.content_mappers: + elif 
content.type in self.content_mappers: # Use existing mappers for other content types - mapped_events = await self.content_mappers[content_type](content, context) + mapped_events = await self.content_mappers[content.type](content, context) if mapped_events is not None: # Handle None returns (e.g., UsageContent) if isinstance(mapped_events, list): events.extend(mapped_events) @@ -676,7 +674,7 @@ async def _convert_agent_update(self, update: Any, context: dict[str, Any]) -> S events.append(await self._create_unknown_content_event(content, context)) # Don't increment content_index for text deltas within the same part - if content_type != "TextContent": + if content.type != "text": context["content_index"] = context.get("content_index", 0) + 1 except Exception as e: @@ -708,10 +706,8 @@ async def _convert_agent_response(self, response: Any, context: dict[str, Any]) for message in messages: if hasattr(message, "contents") and message.contents: for content in message.contents: - content_type = content.__class__.__name__ - - if content_type in self.content_mappers: - mapped_events = await self.content_mappers[content_type](content, context) + if content.type in self.content_mappers: + mapped_events = await self.content_mappers[content.type](content, context) if mapped_events is not None: # Handle None returns (e.g., UsageContent) if isinstance(mapped_events, list): events.extend(mapped_events) @@ -726,9 +722,7 @@ async def _convert_agent_response(self, response: Any, context: dict[str, Any]) # Add usage information if present usage_details = getattr(response, "usage_details", None) if usage_details: - from agent_framework import UsageContent - - usage_content = UsageContent(details=usage_details) + usage_content = Content.from_usage(details=usage_details) await self._map_usage_content(usage_content, context) # Note: _map_usage_content returns None - it accumulates usage for final Response.usage @@ -1421,8 +1415,8 @@ async def _map_usage_content(self, content: Any, context: 
dict[str, Any]) -> Non Returns: None - no event emitted (usage goes in final Response.usage) """ - # Extract usage from UsageContent.details (UsageDetails object) - details = getattr(content, "details", None) + # Extract usage from UsageContent.usage_details (UsageDetails object) + details = getattr(content, "usage_details", None) total_tokens = getattr(details, "total_token_count", 0) or 0 prompt_tokens = getattr(details, "input_token_count", 0) or 0 completion_tokens = getattr(details, "output_token_count", 0) or 0 diff --git a/python/packages/devui/frontend/src/types/agent-framework.ts b/python/packages/devui/frontend/src/types/agent-framework.ts index a41b0ce9a1..5e26580d5f 100644 --- a/python/packages/devui/frontend/src/types/agent-framework.ts +++ b/python/packages/devui/frontend/src/types/agent-framework.ts @@ -187,7 +187,7 @@ export interface HostedVectorStoreContent extends BaseContent { } // Union type for all content -export type Contents = +export type Content = | TextContent | FunctionCallContent | FunctionResultContent @@ -209,7 +209,7 @@ export interface UsageDetails { // Agent run response update (streaming) export interface AgentResponseUpdate { - contents: Contents[]; + contents: Content[]; role?: Role; author_name?: string; response_id?: string; @@ -233,7 +233,7 @@ export interface AgentResponse { // Chat message export interface ChatMessage { - contents: Contents[]; + contents: Content[]; role?: Role; author_name?: string; message_id?: string; @@ -244,7 +244,7 @@ export interface ChatMessage { // Chat response update (model client streaming) export interface ChatResponseUpdate { - contents: Contents[]; + contents: Content[]; role?: Role; author_name?: string; response_id?: string; @@ -330,18 +330,18 @@ export interface TraceSpan { } // Helper type guards for Agent Framework content types -export function isTextContent(content: Contents): content is TextContent { +export function isTextContent(content: Content): content is TextContent { return 
content.type === "text"; } export function isFunctionCallContent( - content: Contents + content: Content ): content is FunctionCallContent { return content.type === "function_call"; } export function isFunctionResultContent( - content: Contents + content: Content ): content is FunctionResultContent { return content.type === "function_result"; } diff --git a/python/packages/devui/frontend/src/types/index.ts b/python/packages/devui/frontend/src/types/index.ts index 029a671a6b..3cbc471403 100644 --- a/python/packages/devui/frontend/src/types/index.ts +++ b/python/packages/devui/frontend/src/types/index.ts @@ -188,7 +188,7 @@ export interface MetaResponse { export interface ChatMessage { id: string; role: "user" | "assistant" | "system" | "tool"; - contents: import("./agent-framework").Contents[]; + contents: import("./agent-framework").Content[]; timestamp: string; streaming?: boolean; author_name?: string; diff --git a/python/packages/devui/tests/test_cleanup_hooks.py b/python/packages/devui/tests/test_cleanup_hooks.py index 2d6fd7b614..e821779686 100644 --- a/python/packages/devui/tests/test_cleanup_hooks.py +++ b/python/packages/devui/tests/test_cleanup_hooks.py @@ -7,7 +7,7 @@ from pathlib import Path import pytest -from agent_framework import AgentResponse, ChatMessage, Role, TextContent +from agent_framework import AgentResponse, ChatMessage, Content, Role from agent_framework_devui import register_cleanup from agent_framework_devui._discovery import EntityDiscovery @@ -36,7 +36,7 @@ def __init__(self, name: str = "TestAgent"): async def run_stream(self, messages=None, *, thread=None, **kwargs): """Mock streaming run method.""" yield AgentResponse( - messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Test response")])], + messages=[ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text(text="Test response")])], ) @@ -259,7 +259,7 @@ async def test_cleanup_with_file_based_discovery(): # Write agent module with cleanup registration 
agent_file = agent_dir / "__init__.py" agent_file.write_text(""" -from agent_framework import AgentResponse, ChatMessage, Role, TextContent +from agent_framework import AgentResponse, ChatMessage, Role, Content from agent_framework_devui import register_cleanup class MockCredential: @@ -279,7 +279,7 @@ class TestAgent: async def run_stream(self, messages=None, *, thread=None, **kwargs): yield AgentResponse( - messages=[ChatMessage(role=Role.ASSISTANT, content=[TextContent(text="Test")])], + messages=[ChatMessage(role=Role.ASSISTANT, content=[Content.from_text(text="Test")])], inner_messages=[], ) diff --git a/python/packages/devui/tests/test_discovery.py b/python/packages/devui/tests/test_discovery.py index 72e534b012..f2b321d75c 100644 --- a/python/packages/devui/tests/test_discovery.py +++ b/python/packages/devui/tests/test_discovery.py @@ -84,7 +84,7 @@ async def test_discovery_accepts_agents_with_only_run(): init_file = agent_dir / "__init__.py" init_file.write_text(""" -from agent_framework import AgentResponse, AgentThread, ChatMessage, Role, TextContent +from agent_framework import AgentResponse, AgentThread, ChatMessage, Role, Content class NonStreamingAgent: id = "non_streaming" @@ -95,7 +95,7 @@ async def run(self, messages=None, *, thread=None, **kwargs): return AgentResponse( messages=[ChatMessage( role=Role.ASSISTANT, - contents=[TextContent(text="response")] + contents=[Content.from_text(text="response")] )], response_id="test" ) @@ -210,7 +210,7 @@ class TestAgent: async def run(self, messages=None, *, thread=None, **kwargs): return AgentResponse( - messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="test")])], + messages=[ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text(text="test")])], response_id="test" ) diff --git a/python/packages/devui/tests/test_execution.py b/python/packages/devui/tests/test_execution.py index 15cb9bf4df..e367782597 100644 --- a/python/packages/devui/tests/test_execution.py +++ 
b/python/packages/devui/tests/test_execution.py @@ -566,7 +566,7 @@ def test_extract_workflow_hil_responses_handles_stringified_json(): async def test_executor_handles_non_streaming_agent(): """Test executor can handle agents with only run() method (no run_stream).""" - from agent_framework import AgentResponse, AgentThread, ChatMessage, Role, TextContent + from agent_framework import AgentResponse, AgentThread, ChatMessage, Content, Role class NonStreamingAgent: """Agent with only run() method - does NOT satisfy full AgentProtocol.""" @@ -577,7 +577,9 @@ class NonStreamingAgent: async def run(self, messages=None, *, thread=None, **kwargs): return AgentResponse( - messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text=f"Processed: {messages}")])], + messages=[ + ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text(text=f"Processed: {messages}")]) + ], response_id="test_123", ) diff --git a/python/packages/devui/tests/test_helpers.py b/python/packages/devui/tests/test_helpers.py index 1385dc867d..130ab475d9 100644 --- a/python/packages/devui/tests/test_helpers.py +++ b/python/packages/devui/tests/test_helpers.py @@ -28,11 +28,9 @@ ChatResponse, ChatResponseUpdate, ConcurrentBuilder, - FunctionCallContent, - FunctionResultContent, + Content, Role, SequentialBuilder, - TextContent, use_chat_middleware, ) from agent_framework._clients import TOptions_co @@ -93,7 +91,7 @@ async def get_streaming_response( for update in self.streaming_responses.pop(0): yield update else: - yield ChatResponseUpdate(text=TextContent(text="test streaming response"), role="assistant") + yield ChatResponseUpdate(text=Content.from_text(text="test streaming response"), role="assistant") @use_chat_middleware @@ -141,10 +139,10 @@ async def _inner_get_streaming_response( yield update else: # Simulate realistic streaming chunks - yield ChatResponseUpdate(text=TextContent(text="Mock "), role="assistant") - yield ChatResponseUpdate(text=TextContent(text="streaming "), 
role="assistant") - yield ChatResponseUpdate(text=TextContent(text="response "), role="assistant") - yield ChatResponseUpdate(text=TextContent(text="from ChatAgent"), role="assistant") + yield ChatResponseUpdate(text=Content.from_text(text="Mock "), role="assistant") + yield ChatResponseUpdate(text=Content.from_text(text="streaming "), role="assistant") + yield ChatResponseUpdate(text=Content.from_text(text="response "), role="assistant") + yield ChatResponseUpdate(text=Content.from_text(text="from ChatAgent"), role="assistant") # ============================================================================= @@ -175,7 +173,7 @@ async def run( ) -> AgentResponse: self.call_count += 1 return AgentResponse( - messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text=self.response_text)])] + messages=[ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text(text=self.response_text)])] ) async def run_stream( @@ -187,7 +185,7 @@ async def run_stream( ) -> AsyncIterable[AgentResponseUpdate]: self.call_count += 1 for chunk in self.streaming_chunks: - yield AgentResponseUpdate(contents=[TextContent(text=chunk)], role=Role.ASSISTANT) + yield AgentResponseUpdate(contents=[Content.from_text(text=chunk)], role=Role.ASSISTANT) class MockToolCallingAgent(BaseAgent): @@ -217,13 +215,13 @@ async def run_stream( self.call_count += 1 # First: text yield AgentResponseUpdate( - contents=[TextContent(text="Let me search for that...")], + contents=[Content.from_text(text="Let me search for that...")], role=Role.ASSISTANT, ) # Second: tool call yield AgentResponseUpdate( contents=[ - FunctionCallContent( + Content.from_function_call( call_id="call_123", name="search", arguments={"query": "weather"}, @@ -234,7 +232,7 @@ async def run_stream( # Third: tool result yield AgentResponseUpdate( contents=[ - FunctionResultContent( + Content.from_function_result( call_id="call_123", result={"temperature": 72, "condition": "sunny"}, ) @@ -243,7 +241,7 @@ async def run_stream( ) # 
Fourth: final text yield AgentResponseUpdate( - contents=[TextContent(text="The weather is sunny, 72°F.")], + contents=[Content.from_text(text="The weather is sunny, 72°F.")], role=Role.ASSISTANT, ) @@ -297,7 +295,7 @@ def create_mock_tool_agent(id: str = "tool_agent", name: str = "ToolAgent") -> M def create_agent_run_response(text: str = "Test response") -> AgentResponse: """Create an AgentResponse with the given text.""" - return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text=text)])]) + return AgentResponse(messages=[ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text(text=text)])]) def create_agent_executor_response( @@ -310,8 +308,8 @@ def create_agent_executor_response( executor_id=executor_id, agent_response=agent_response, full_conversation=[ - ChatMessage(role=Role.USER, contents=[TextContent(text="User input")]), - ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text=response_text)]), + ChatMessage(role=Role.USER, contents=[Content.from_text(text="User input")]), + ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text(text=response_text)]), ], ) diff --git a/python/packages/devui/tests/test_mapper.py b/python/packages/devui/tests/test_mapper.py index faf0c831d8..4ea3ba9333 100644 --- a/python/packages/devui/tests/test_mapper.py +++ b/python/packages/devui/tests/test_mapper.py @@ -14,11 +14,8 @@ # Import Agent Framework types from agent_framework._types import ( AgentResponseUpdate, - ErrorContent, - FunctionCallContent, - FunctionResultContent, + Content, Role, - TextContent, ) # Import real workflow event classes - NOT mocks! 
@@ -71,15 +68,17 @@ def test_request() -> AgentFrameworkRequest: def create_test_content(content_type: str, **kwargs: Any) -> Any: """Create test content objects.""" if content_type == "text": - return TextContent(text=kwargs.get("text", "Hello, world!")) + return Content.from_text(text=kwargs.get("text", "Hello, world!")) if content_type == "function_call": - return FunctionCallContent( + return Content.from_function_call( call_id=kwargs.get("call_id", "test_call_id"), name=kwargs.get("name", "test_func"), arguments=kwargs.get("arguments", {"param": "value"}), ) if content_type == "error": - return ErrorContent(message=kwargs.get("message", "Test error"), error_code=kwargs.get("code", "test_error")) + return Content.from_error( + message=kwargs.get("message", "Test error"), error_code=kwargs.get("code", "test_error") + ) raise ValueError(f"Unknown content type: {content_type}") @@ -162,7 +161,7 @@ async def test_function_result_content_with_string_result( mapper: MessageMapper, test_request: AgentFrameworkRequest ) -> None: """Test FunctionResultContent with plain string result (regular tools).""" - content = FunctionResultContent( + content = Content.from_function_result( call_id="test_call_123", result="Hello, World!", ) @@ -182,9 +181,9 @@ async def test_function_result_content_with_nested_content_objects( mapper: MessageMapper, test_request: AgentFrameworkRequest ) -> None: """Test FunctionResultContent with nested Content objects (MCP tools case).""" - content = FunctionResultContent( + content = Content.from_function_result( call_id="mcp_call_456", - result=[TextContent(text="Hello from MCP!")], + result=[Content.from_text(text="Hello from MCP!")], ) update = create_test_agent_update([content]) @@ -451,12 +450,12 @@ async def test_magentic_agent_run_update_event_with_agent_delta_metadata( This tests the ACTUAL event format Magentic emits - not a fake MagenticAgentDeltaEvent class. 
Magentic uses AgentRunUpdateEvent with additional_properties containing magentic_event_type. """ - from agent_framework._types import AgentResponseUpdate, Role, TextContent + from agent_framework._types import AgentResponseUpdate, Role from agent_framework._workflows._events import AgentRunUpdateEvent # Create the REAL event format that Magentic emits update = AgentResponseUpdate( - contents=[TextContent(text="Hello from agent")], + contents=[Content.from_text(text="Hello from agent")], role=Role.ASSISTANT, author_name="Writer", additional_properties={ @@ -482,12 +481,12 @@ async def test_magentic_orchestrator_message_event(mapper: MessageMapper, test_r Magentic emits orchestrator planning/instruction messages using AgentRunUpdateEvent with additional_properties containing magentic_event_type='orchestrator_message'. """ - from agent_framework._types import AgentResponseUpdate, Role, TextContent + from agent_framework._types import AgentResponseUpdate, Role from agent_framework._workflows._events import AgentRunUpdateEvent # Create orchestrator message event (REAL format from Magentic) update = AgentResponseUpdate( - contents=[TextContent(text="Planning: First, the writer will create content...")], + contents=[Content.from_text(text="Planning: First, the writer will create content...")], role=Role.ASSISTANT, author_name="Orchestrator", additional_properties={ @@ -518,20 +517,20 @@ async def test_magentic_events_use_same_event_class_as_other_workflows( additional_properties. Any mapper code checking for 'MagenticAgentDeltaEvent' class names is dead code. """ - from agent_framework._types import AgentResponseUpdate, Role, TextContent + from agent_framework._types import AgentResponseUpdate, Role from agent_framework._workflows._events import AgentRunUpdateEvent # Create events the way different workflows do it # 1. 
Regular workflow (no additional_properties) regular_update = AgentResponseUpdate( - contents=[TextContent(text="Regular workflow response")], + contents=[Content.from_text(text="Regular workflow response")], role=Role.ASSISTANT, ) regular_event = AgentRunUpdateEvent(executor_id="regular_executor", data=regular_update) # 2. Magentic workflow (with additional_properties) magentic_update = AgentResponseUpdate( - contents=[TextContent(text="Magentic workflow response")], + contents=[Content.from_text(text="Magentic workflow response")], role=Role.ASSISTANT, additional_properties={"magentic_event_type": "agent_delta"}, ) @@ -599,13 +598,13 @@ async def test_workflow_output_event(mapper: MessageMapper, test_request: AgentF async def test_workflow_output_event_with_list_data(mapper: MessageMapper, test_request: AgentFrameworkRequest) -> None: """Test WorkflowOutputEvent with list data (common for sequential/concurrent workflows).""" - from agent_framework import ChatMessage, Role, TextContent + from agent_framework import ChatMessage, Role from agent_framework._workflows._events import WorkflowOutputEvent # Sequential/Concurrent workflows often output list[ChatMessage] messages = [ - ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")]), - ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="World")]), + ChatMessage(role=Role.USER, contents=[Content.from_text(text="Hello")]), + ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text(text="World")]), ] event = WorkflowOutputEvent(data=messages, executor_id="complete") events = await mapper.convert_event(event, test_request) diff --git a/python/packages/devui/tests/test_multimodal_workflow.py b/python/packages/devui/tests/test_multimodal_workflow.py index f4a0061a23..b962fccd7b 100644 --- a/python/packages/devui/tests/test_multimodal_workflow.py +++ b/python/packages/devui/tests/test_multimodal_workflow.py @@ -49,7 +49,7 @@ def test_is_openai_multimodal_format_detects_message_format(self): def 
test_convert_openai_input_to_chat_message_with_image(self): """Test that OpenAI format with image is converted to ChatMessage with DataContent.""" - from agent_framework import ChatMessage, DataContent, Role, TextContent + from agent_framework import ChatMessage, Role discovery = MagicMock(spec=EntityDiscovery) mapper = MagicMock(spec=MessageMapper) @@ -78,11 +78,11 @@ def test_convert_openai_input_to_chat_message_with_image(self): assert len(result.contents) == 2, f"Expected 2 contents, got {len(result.contents)}" # First content should be text - assert isinstance(result.contents[0], TextContent) + assert result.contents[0].type == "text" assert result.contents[0].text == "Describe this image" # Second content should be image (DataContent) - assert isinstance(result.contents[1], DataContent) + assert result.contents[1].type == "data" assert result.contents[1].media_type == "image/png" assert result.contents[1].uri == TEST_IMAGE_DATA_URI @@ -90,7 +90,7 @@ def test_parse_workflow_input_handles_json_string_with_multimodal(self): """Test that _parse_workflow_input correctly handles JSON string with multimodal content.""" import asyncio - from agent_framework import ChatMessage, DataContent, TextContent + from agent_framework import ChatMessage discovery = MagicMock(spec=EntityDiscovery) mapper = MagicMock(spec=MessageMapper) @@ -120,11 +120,11 @@ def test_parse_workflow_input_handles_json_string_with_multimodal(self): assert len(result.contents) == 2 # Verify text content - assert isinstance(result.contents[0], TextContent) + assert result.contents[0].type == "text" assert result.contents[0].text == "What is in this image?" 
# Verify image content - assert isinstance(result.contents[1], DataContent) + assert result.contents[1].type == "data" assert result.contents[1].media_type == "image/png" def test_parse_workflow_input_still_handles_simple_dict(self): diff --git a/python/packages/lab/lightning/agent_framework_lab_lightning/__init__.py b/python/packages/lab/lightning/agent_framework_lab_lightning/__init__.py index 7e09894861..7017062b32 100644 --- a/python/packages/lab/lightning/agent_framework_lab_lightning/__init__.py +++ b/python/packages/lab/lightning/agent_framework_lab_lightning/__init__.py @@ -4,7 +4,7 @@ import importlib.metadata -from agent_framework.observability import OBSERVABILITY_SETTINGS +from agent_framework.observability import enable_instrumentation from agentlightning import AgentOpsTracer # type: ignore try: @@ -22,13 +22,12 @@ class AgentFrameworkTracer(AgentOpsTracer): # type: ignore def init(self) -> None: """Initialize the agent-framework-lab-lightning for training.""" - OBSERVABILITY_SETTINGS.enable_instrumentation = True + enable_instrumentation() super().init() def teardown(self) -> None: """Teardown the agent-framework-lab-lightning for training.""" super().teardown() - OBSERVABILITY_SETTINGS.enable_instrumentation = False __all__: list[str] = ["AgentFrameworkTracer"] diff --git a/python/packages/lab/lightning/tests/test_lightning.py b/python/packages/lab/lightning/tests/test_lightning.py index 5f85532de1..413087877e 100644 --- a/python/packages/lab/lightning/tests/test_lightning.py +++ b/python/packages/lab/lightning/tests/test_lightning.py @@ -10,7 +10,7 @@ agentlightning = pytest.importorskip("agentlightning") from agent_framework import AgentExecutor, AgentRunEvent, ChatAgent, WorkflowBuilder -from agent_framework.lab.lightning import AgentFrameworkTracer +from agent_framework_lab_lightning import AgentFrameworkTracer from agent_framework.openai import OpenAIChatClient from agentlightning import TracerTraceToTriplet from openai.types.chat import 
ChatCompletion, ChatCompletionMessage diff --git a/python/packages/lab/tau2/agent_framework_lab_tau2/_message_utils.py b/python/packages/lab/tau2/agent_framework_lab_tau2/_message_utils.py index 6356ad038c..770c902244 100644 --- a/python/packages/lab/tau2/agent_framework_lab_tau2/_message_utils.py +++ b/python/packages/lab/tau2/agent_framework_lab_tau2/_message_utils.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. -from agent_framework._types import ChatMessage, Contents, Role +from agent_framework._types import ChatMessage, Content, Role from loguru import logger @@ -12,7 +12,7 @@ def flip_messages(messages: list[ChatMessage]) -> list[ChatMessage]: messages to user messages (since users typically don't make function calls). """ - def filter_out_function_calls(messages: list[Contents]) -> list[Contents]: + def filter_out_function_calls(messages: list[Content]) -> list[Content]: """Remove function call content from message contents.""" return [content for content in messages if content.type != "function_call"] diff --git a/python/packages/lab/tau2/tests/test_message_utils.py b/python/packages/lab/tau2/tests/test_message_utils.py index 547d20ff26..255f0a96ae 100644 --- a/python/packages/lab/tau2/tests/test_message_utils.py +++ b/python/packages/lab/tau2/tests/test_message_utils.py @@ -2,7 +2,7 @@ from unittest.mock import patch -from agent_framework._types import ChatMessage, FunctionCallContent, FunctionResultContent, Role, TextContent +from agent_framework._types import ChatMessage, Content, Role from agent_framework_lab_tau2._message_utils import flip_messages, log_messages @@ -10,7 +10,10 @@ def test_flip_messages_user_to_assistant(): """Test flipping user message to assistant.""" messages = [ ChatMessage( - role=Role.USER, contents=[TextContent(text="Hello assistant")], author_name="User1", message_id="msg_001" + role=Role.USER, + contents=[Content.from_text(text="Hello assistant")], + author_name="User1", + message_id="msg_001", ) ] @@ -28,7 
+31,7 @@ def test_flip_messages_assistant_to_user(): messages = [ ChatMessage( role=Role.ASSISTANT, - contents=[TextContent(text="Hello user")], + contents=[Content.from_text(text="Hello user")], author_name="Assistant1", message_id="msg_002", ) @@ -45,12 +48,16 @@ def test_flip_messages_assistant_to_user(): def test_flip_messages_assistant_with_function_calls_filtered(): """Test that function calls are filtered out when flipping assistant to user.""" - function_call = FunctionCallContent(call_id="call_123", name="test_function", arguments={"param": "value"}) + function_call = Content.from_function_call(call_id="call_123", name="test_function", arguments={"param": "value"}) messages = [ ChatMessage( role=Role.ASSISTANT, - contents=[TextContent(text="I'll call a function"), function_call, TextContent(text="After the call")], + contents=[ + Content.from_text(text="I'll call a function"), + function_call, + Content.from_text(text="After the call"), + ], message_id="msg_003", ) ] @@ -68,7 +75,7 @@ def test_flip_messages_assistant_with_function_calls_filtered(): def test_flip_messages_assistant_with_only_function_calls_skipped(): """Test that assistant messages with only function calls are skipped.""" - function_call = FunctionCallContent(call_id="call_456", name="another_function", arguments={"key": "value"}) + function_call = Content.from_function_call(call_id="call_456", name="another_function", arguments={"key": "value"}) messages = [ ChatMessage(role=Role.ASSISTANT, contents=[function_call], message_id="msg_004") # Only function call, no text @@ -82,7 +89,7 @@ def test_flip_messages_assistant_with_only_function_calls_skipped(): def test_flip_messages_tool_messages_skipped(): """Test that tool messages are skipped.""" - function_result = FunctionResultContent(call_id="call_789", result={"success": True}) + function_result = Content.from_function_result(call_id="call_789", result={"success": True}) messages = [ChatMessage(role=Role.TOOL, contents=[function_result])] 
@@ -94,7 +101,9 @@ def test_flip_messages_tool_messages_skipped(): def test_flip_messages_system_messages_preserved(): """Test that system messages are preserved as-is.""" - messages = [ChatMessage(role=Role.SYSTEM, contents=[TextContent(text="System instruction")], message_id="sys_001")] + messages = [ + ChatMessage(role=Role.SYSTEM, contents=[Content.from_text(text="System instruction")], message_id="sys_001") + ] flipped = flip_messages(messages) @@ -106,16 +115,16 @@ def test_flip_messages_system_messages_preserved(): def test_flip_messages_mixed_conversation(): """Test flipping a mixed conversation.""" - function_call = FunctionCallContent(call_id="call_mixed", name="mixed_function", arguments={}) + function_call = Content.from_function_call(call_id="call_mixed", name="mixed_function", arguments={}) - function_result = FunctionResultContent(call_id="call_mixed", result="function result") + function_result = Content.from_function_result(call_id="call_mixed", result="function result") messages = [ - ChatMessage(role=Role.SYSTEM, contents=[TextContent(text="System prompt")]), - ChatMessage(role=Role.USER, contents=[TextContent(text="User question")]), - ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Assistant response"), function_call]), + ChatMessage(role=Role.SYSTEM, contents=[Content.from_text(text="System prompt")]), + ChatMessage(role=Role.USER, contents=[Content.from_text(text="User question")]), + ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text(text="Assistant response"), function_call]), ChatMessage(role=Role.TOOL, contents=[function_result]), - ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Final response")]), + ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text(text="Final response")]), ] flipped = flip_messages(messages) @@ -151,7 +160,10 @@ def test_flip_messages_preserves_metadata(): """Test that message metadata is preserved during flipping.""" messages = [ ChatMessage( - role=Role.USER, 
contents=[TextContent(text="Test message")], author_name="TestUser", message_id="test_123" + role=Role.USER, + contents=[Content.from_text(text="Test message")], + author_name="TestUser", + message_id="test_123", ) ] @@ -166,8 +178,8 @@ def test_flip_messages_preserves_metadata(): def test_log_messages_text_content(mock_logger): """Test logging messages with text content.""" messages = [ - ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")]), - ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Hi there!")]), + ChatMessage(role=Role.USER, contents=[Content.from_text(text="Hello")]), + ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text(text="Hi there!")]), ] log_messages(messages) @@ -179,7 +191,7 @@ def test_log_messages_text_content(mock_logger): @patch("agent_framework_lab_tau2._message_utils.logger") def test_log_messages_function_call(mock_logger): """Test logging messages with function calls.""" - function_call = FunctionCallContent(call_id="call_log", name="log_function", arguments={"param": "value"}) + function_call = Content.from_function_call(call_id="call_log", name="log_function", arguments={"param": "value"}) messages = [ChatMessage(role=Role.ASSISTANT, contents=[function_call])] @@ -195,7 +207,7 @@ def test_log_messages_function_call(mock_logger): @patch("agent_framework_lab_tau2._message_utils.logger") def test_log_messages_function_result(mock_logger): """Test logging messages with function results.""" - function_result = FunctionResultContent(call_id="call_result", result="success") + function_result = Content.from_function_result(call_id="call_result", result="success") messages = [ChatMessage(role=Role.TOOL, contents=[function_result])] @@ -211,10 +223,10 @@ def test_log_messages_function_result(mock_logger): def test_log_messages_different_roles(mock_logger): """Test logging messages with different roles get different colors.""" messages = [ - ChatMessage(role=Role.SYSTEM, contents=[TextContent(text="System")]), - 
ChatMessage(role=Role.USER, contents=[TextContent(text="User")]), - ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Assistant")]), - ChatMessage(role=Role.TOOL, contents=[TextContent(text="Tool")]), + ChatMessage(role=Role.SYSTEM, contents=[Content.from_text(text="System")]), + ChatMessage(role=Role.USER, contents=[Content.from_text(text="User")]), + ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text(text="Assistant")]), + ChatMessage(role=Role.TOOL, contents=[Content.from_text(text="Tool")]), ] log_messages(messages) @@ -238,7 +250,7 @@ def test_log_messages_different_roles(mock_logger): @patch("agent_framework_lab_tau2._message_utils.logger") def test_log_messages_escapes_html(mock_logger): """Test that HTML-like characters are properly escaped in log output.""" - messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Message with content")])] + messages = [ChatMessage(role=Role.USER, contents=[Content.from_text(text="Message with content")])] log_messages(messages) @@ -251,12 +263,12 @@ def test_log_messages_escapes_html(mock_logger): @patch("agent_framework_lab_tau2._message_utils.logger") def test_log_messages_mixed_content_types(mock_logger): """Test logging messages with mixed content types.""" - function_call = FunctionCallContent(call_id="mixed_call", name="mixed_function", arguments={"key": "value"}) + function_call = Content.from_function_call(call_id="mixed_call", name="mixed_function", arguments={"key": "value"}) messages = [ ChatMessage( role=Role.ASSISTANT, - contents=[TextContent(text="I'll call a function"), function_call, TextContent(text="Done!")], + contents=[Content.from_text(text="I'll call a function"), function_call, Content.from_text(text="Done!")], ) ] diff --git a/python/packages/lab/tau2/tests/test_sliding_window.py b/python/packages/lab/tau2/tests/test_sliding_window.py index e4ea3cd5ad..030b57a750 100644 --- a/python/packages/lab/tau2/tests/test_sliding_window.py +++ 
b/python/packages/lab/tau2/tests/test_sliding_window.py @@ -4,7 +4,7 @@ from unittest.mock import patch -from agent_framework._types import ChatMessage, FunctionCallContent, FunctionResultContent, Role, TextContent +from agent_framework._types import ChatMessage, Content, Role from agent_framework_lab_tau2._sliding_window import SlidingWindowChatMessageStore @@ -36,8 +36,8 @@ def test_initialization_with_parameters(): def test_initialization_with_messages(): """Test initializing with existing messages.""" messages = [ - ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")]), - ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Hi there!")]), + ChatMessage(role=Role.USER, contents=[Content.from_text(text="Hello")]), + ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text(text="Hi there!")]), ] sliding_window = SlidingWindowChatMessageStore(messages=messages, max_tokens=1000) @@ -51,8 +51,8 @@ async def test_add_messages_simple(): sliding_window = SlidingWindowChatMessageStore(max_tokens=10000) # Large limit new_messages = [ - ChatMessage(role=Role.USER, contents=[TextContent(text="What's the weather?")]), - ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="I can help with that.")]), + ChatMessage(role=Role.USER, contents=[Content.from_text(text="What's the weather?")]), + ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text(text="I can help with that.")]), ] await sliding_window.add_messages(new_messages) @@ -69,7 +69,8 @@ async def test_list_all_messages_vs_list_messages(): # Add many messages to trigger truncation messages = [ - ChatMessage(role=Role.USER, contents=[TextContent(text=f"Message {i} with some content")]) for i in range(10) + ChatMessage(role=Role.USER, contents=[Content.from_text(text=f"Message {i} with some content")]) + for i in range(10) ] await sliding_window.add_messages(messages) @@ -87,7 +88,7 @@ async def test_list_all_messages_vs_list_messages(): def test_get_token_count_basic(): """Test basic 
token counting.""" sliding_window = SlidingWindowChatMessageStore(max_tokens=1000) - sliding_window.truncated_messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])] + sliding_window.truncated_messages = [ChatMessage(role=Role.USER, contents=[Content.from_text(text="Hello")])] token_count = sliding_window.get_token_count() @@ -104,7 +105,7 @@ def test_get_token_count_with_system_message(): token_count_empty = sliding_window.get_token_count() # Add a message - sliding_window.truncated_messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])] + sliding_window.truncated_messages = [ChatMessage(role=Role.USER, contents=[Content.from_text(text="Hello")])] token_count_with_message = sliding_window.get_token_count() # With message should be more tokens @@ -114,7 +115,7 @@ def test_get_token_count_with_system_message(): def test_get_token_count_function_call(): """Test token counting with function calls.""" - function_call = FunctionCallContent(call_id="call_123", name="test_function", arguments={"param": "value"}) + function_call = Content.from_function_call(call_id="call_123", name="test_function", arguments={"param": "value"}) sliding_window = SlidingWindowChatMessageStore(max_tokens=1000) sliding_window.truncated_messages = [ChatMessage(role=Role.ASSISTANT, contents=[function_call])] @@ -125,7 +126,7 @@ def test_get_token_count_function_call(): def test_get_token_count_function_result(): """Test token counting with function results.""" - function_result = FunctionResultContent(call_id="call_123", result={"success": True, "data": "result"}) + function_result = Content.from_function_result(call_id="call_123", result={"success": True, "data": "result"}) sliding_window = SlidingWindowChatMessageStore(max_tokens=1000) sliding_window.truncated_messages = [ChatMessage(role=Role.TOOL, contents=[function_result])] @@ -143,13 +144,15 @@ def test_truncate_messages_removes_old_messages(mock_logger): messages = [ ChatMessage( 
role=Role.USER, - contents=[TextContent(text="This is a very long message that should exceed the token limit")], + contents=[Content.from_text(text="This is a very long message that should exceed the token limit")], ), ChatMessage( role=Role.ASSISTANT, - contents=[TextContent(text="This is another very long message that should also exceed the token limit")], + contents=[ + Content.from_text(text="This is another very long message that should also exceed the token limit") + ], ), - ChatMessage(role=Role.USER, contents=[TextContent(text="Short msg")]), + ChatMessage(role=Role.USER, contents=[Content.from_text(text="Short msg")]), ] sliding_window.truncated_messages = messages.copy() @@ -168,8 +171,10 @@ def test_truncate_messages_removes_leading_tool_messages(mock_logger): sliding_window = SlidingWindowChatMessageStore(max_tokens=10000) # Large limit # Create messages starting with tool message - tool_message = ChatMessage(role=Role.TOOL, contents=[FunctionResultContent(call_id="call_123", result="result")]) - user_message = ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")]) + tool_message = ChatMessage( + role=Role.TOOL, contents=[Content.from_function_result(call_id="call_123", result="result")] + ) + user_message = ChatMessage(role=Role.USER, contents=[Content.from_text(text="Hello")]) sliding_window.truncated_messages = [tool_message, user_message] sliding_window.truncate_messages() @@ -227,24 +232,27 @@ async def test_real_world_scenario(): # Simulate a conversation conversation = [ - ChatMessage(role=Role.USER, contents=[TextContent(text="Hello, how are you?")]), + ChatMessage(role=Role.USER, contents=[Content.from_text(text="Hello, how are you?")]), ChatMessage( - role=Role.ASSISTANT, contents=[TextContent(text="I'm doing well, thank you! How can I help you today?")] + role=Role.ASSISTANT, + contents=[Content.from_text(text="I'm doing well, thank you! 
How can I help you today?")], ), - ChatMessage(role=Role.USER, contents=[TextContent(text="Can you tell me about the weather?")]), + ChatMessage(role=Role.USER, contents=[Content.from_text(text="Can you tell me about the weather?")]), ChatMessage( role=Role.ASSISTANT, contents=[ - TextContent( + Content.from_text( text="I'd be happy to help with weather information, " "but I don't have access to current weather data." ) ], ), - ChatMessage(role=Role.USER, contents=[TextContent(text="What about telling me a joke instead?")]), + ChatMessage(role=Role.USER, contents=[Content.from_text(text="What about telling me a joke instead?")]), ChatMessage( role=Role.ASSISTANT, - contents=[TextContent(text="Sure! Why don't scientists trust atoms? Because they make up everything!")], + contents=[ + Content.from_text(text="Sure! Why don't scientists trust atoms? Because they make up everything!") + ], ), ] diff --git a/python/packages/lab/tau2/tests/test_tau2_utils.py b/python/packages/lab/tau2/tests/test_tau2_utils.py index 21e8622da8..252646081b 100644 --- a/python/packages/lab/tau2/tests/test_tau2_utils.py +++ b/python/packages/lab/tau2/tests/test_tau2_utils.py @@ -7,7 +7,7 @@ import pytest from agent_framework._tools import AIFunction -from agent_framework._types import ChatMessage, FunctionCallContent, FunctionResultContent, Role, TextContent +from agent_framework._types import ChatMessage, Content, Role from agent_framework_lab_tau2._tau2_utils import ( convert_agent_framework_messages_to_tau2_messages, convert_tau2_tool_to_ai_function, @@ -92,7 +92,7 @@ def test_convert_tau2_tool_to_ai_function_multiple_tools(tau2_airline_environmen def test_convert_agent_framework_messages_to_tau2_messages_system(): """Test converting system message.""" - messages = [ChatMessage(role=Role.SYSTEM, contents=[TextContent(text="System instruction")])] + messages = [ChatMessage(role=Role.SYSTEM, contents=[Content.from_text(text="System instruction")])] tau2_messages = 
convert_agent_framework_messages_to_tau2_messages(messages) @@ -104,7 +104,7 @@ def test_convert_agent_framework_messages_to_tau2_messages_system(): def test_convert_agent_framework_messages_to_tau2_messages_user(): """Test converting user message.""" - messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello assistant")])] + messages = [ChatMessage(role=Role.USER, contents=[Content.from_text(text="Hello assistant")])] tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages) @@ -117,7 +117,7 @@ def test_convert_agent_framework_messages_to_tau2_messages_user(): def test_convert_agent_framework_messages_to_tau2_messages_assistant(): """Test converting assistant message.""" - messages = [ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Hello user")])] + messages = [ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text(text="Hello user")])] tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages) @@ -130,9 +130,11 @@ def test_convert_agent_framework_messages_to_tau2_messages_assistant(): def test_convert_agent_framework_messages_to_tau2_messages_with_function_call(): """Test converting message with function call.""" - function_call = FunctionCallContent(call_id="call_123", name="test_function", arguments={"param": "value"}) + function_call = Content.from_function_call(call_id="call_123", name="test_function", arguments={"param": "value"}) - messages = [ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="I'll call a function"), function_call])] + messages = [ + ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text(text="I'll call a function"), function_call]) + ] tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages) @@ -152,7 +154,7 @@ def test_convert_agent_framework_messages_to_tau2_messages_with_function_call(): def test_convert_agent_framework_messages_to_tau2_messages_with_function_result(): """Test converting message with function result.""" - 
function_result = FunctionResultContent(call_id="call_123", result={"success": True, "data": "result data"}) + function_result = Content.from_function_result(call_id="call_123", result={"success": True, "data": "result data"}) messages = [ChatMessage(role=Role.TOOL, contents=[function_result])] @@ -170,7 +172,7 @@ def test_convert_agent_framework_messages_to_tau2_messages_with_function_result( def test_convert_agent_framework_messages_to_tau2_messages_with_error(): """Test converting function result with error.""" - function_result = FunctionResultContent( + function_result = Content.from_function_result( call_id="call_456", result="Error occurred", exception=Exception("Test error") ) @@ -185,7 +187,11 @@ def test_convert_agent_framework_messages_to_tau2_messages_with_error(): def test_convert_agent_framework_messages_to_tau2_messages_multiple_text_contents(): """Test converting message with multiple text contents.""" - messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="First part"), TextContent(text="Second part")])] + messages = [ + ChatMessage( + role=Role.USER, contents=[Content.from_text(text="First part"), Content.from_text(text="Second part")] + ) + ] tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages) @@ -196,16 +202,16 @@ def test_convert_agent_framework_messages_to_tau2_messages_multiple_text_content def test_convert_agent_framework_messages_to_tau2_messages_complex_scenario(): """Test converting complex scenario with multiple message types.""" - function_call = FunctionCallContent(call_id="call_789", name="complex_tool", arguments='{"key": "value"}') + function_call = Content.from_function_call(call_id="call_789", name="complex_tool", arguments='{"key": "value"}') - function_result = FunctionResultContent(call_id="call_789", result={"output": "tool result"}) + function_result = Content.from_function_result(call_id="call_789", result={"output": "tool result"}) messages = [ - ChatMessage(role=Role.SYSTEM, 
contents=[TextContent(text="System prompt")]), - ChatMessage(role=Role.USER, contents=[TextContent(text="User request")]), - ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="I'll help you"), function_call]), + ChatMessage(role=Role.SYSTEM, contents=[Content.from_text(text="System prompt")]), + ChatMessage(role=Role.USER, contents=[Content.from_text(text="User request")]), + ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text(text="I'll help you"), function_call]), ChatMessage(role=Role.TOOL, contents=[function_result]), - ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Based on the result...")]), + ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text(text="Based on the result...")]), ] tau2_messages = convert_agent_framework_messages_to_tau2_messages(messages) diff --git a/python/packages/mem0/tests/test_mem0_context_provider.py b/python/packages/mem0/tests/test_mem0_context_provider.py index 7464aad913..0013e7ce6e 100644 --- a/python/packages/mem0/tests/test_mem0_context_provider.py +++ b/python/packages/mem0/tests/test_mem0_context_provider.py @@ -4,7 +4,7 @@ from unittest.mock import AsyncMock, patch import pytest -from agent_framework import ChatMessage, Context, Role +from agent_framework import ChatMessage, Content, Context, Role from agent_framework.exceptions import ServiceInitializationError from agent_framework.mem0 import Mem0Provider @@ -409,14 +409,13 @@ async def test_model_invoking_function_approval_response_returns_none_instructio self, mock_mem0_client: AsyncMock ) -> None: """Test invoking with function approval response content messages returns context with None instructions.""" - from agent_framework import FunctionApprovalResponseContent, FunctionCallContent provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client) - function_call = FunctionCallContent(call_id="1", name="test_func", arguments='{"arg1": "value1"}') + function_call = Content.from_function_call(call_id="1", name="test_func", 
arguments='{"arg1": "value1"}') message = ChatMessage( role=Role.USER, contents=[ - FunctionApprovalResponseContent( + Content.from_function_approval_response( id="approval_1", function_call=function_call, approved=True, diff --git a/python/packages/ollama/agent_framework_ollama/_chat_client.py b/python/packages/ollama/agent_framework_ollama/_chat_client.py index 825ee47bec..3f1b325bb0 100644 --- a/python/packages/ollama/agent_framework_ollama/_chat_client.py +++ b/python/packages/ollama/agent_framework_ollama/_chat_client.py @@ -20,13 +20,8 @@ ChatOptions, ChatResponse, ChatResponseUpdate, - Contents, - DataContent, - FunctionCallContent, - FunctionResultContent, + Content, Role, - TextContent, - TextReasoningContent, ToolProtocol, UsageDetails, get_logger, @@ -452,16 +447,16 @@ def _format_system_message(self, message: ChatMessage) -> list[OllamaMessage]: return [OllamaMessage(role="system", content=message.text)] def _format_user_message(self, message: ChatMessage) -> list[OllamaMessage]: - if not any(isinstance(c, (DataContent, TextContent)) for c in message.contents) and not message.text: + if not any(c.type in {"text", "data"} for c in message.contents) and not message.text: raise ServiceInvalidRequestError( "Ollama connector currently only supports user messages with TextContent or DataContent." 
) - if not any(isinstance(c, DataContent) for c in message.contents): + if not any(isinstance(c, Content) and c.type == "data" for c in message.contents): return [OllamaMessage(role="user", content=message.text)] user_message = OllamaMessage(role="user", content=message.text) - data_contents = [c for c in message.contents if isinstance(c, DataContent)] + data_contents = [c for c in message.contents if isinstance(c, Content) and c.type == "data"] if data_contents: if not any(c.has_top_level_media_type("image") for c in data_contents): raise ServiceInvalidRequestError("Only image data content is supported for user messages in Ollama.") @@ -472,11 +467,11 @@ def _format_user_message(self, message: ChatMessage) -> list[OllamaMessage]: def _format_assistant_message(self, message: ChatMessage) -> list[OllamaMessage]: text_content = message.text # Ollama shouldn't have encrypted reasoning, so we just process text. - reasoning_contents = "".join((c.text or "") for c in message.contents if isinstance(c, TextReasoningContent)) + reasoning_contents = "".join((c.text or "") for c in message.contents if c.type == "text_reasoning") assistant_message = OllamaMessage(role="assistant", content=text_content, thinking=reasoning_contents) - tool_calls = [item for item in message.contents if isinstance(item, FunctionCallContent)] + tool_calls = [item for item in message.contents if item.type == "function_call"] if tool_calls: assistant_message["tool_calls"] = [ { @@ -497,15 +492,15 @@ def _format_tool_message(self, message: ChatMessage) -> list[OllamaMessage]: return [ OllamaMessage(role="tool", content=str(item.result), tool_name=item.call_id) for item in message.contents - if isinstance(item, FunctionResultContent) + if item.type == "function_result" ] - def _parse_contents_from_ollama(self, response: OllamaChatResponse) -> list[Contents]: - contents: list[Contents] = [] + def _parse_contents_from_ollama(self, response: OllamaChatResponse) -> list[Content]: + contents: list[Content] 
= [] if response.message.thinking: - contents.append(TextReasoningContent(text=response.message.thinking)) + contents.append(Content.from_text_reasoning(text=response.message.thinking)) if response.message.content: - contents.append(TextContent(text=response.message.content)) + contents.append(Content.from_text(text=response.message.content)) if response.message.tool_calls: tool_calls = self._parse_tool_calls_from_ollama(response.message.tool_calls) contents.extend(tool_calls) @@ -533,10 +528,10 @@ def _parse_response_from_ollama(self, response: OllamaChatResponse) -> ChatRespo ), ) - def _parse_tool_calls_from_ollama(self, tool_calls: Sequence[OllamaMessage.ToolCall]) -> list[Contents]: - resp: list[Contents] = [] + def _parse_tool_calls_from_ollama(self, tool_calls: Sequence[OllamaMessage.ToolCall]) -> list[Content]: + resp: list[Content] = [] for tool in tool_calls: - fcc = FunctionCallContent( + fcc = Content.from_function_call( call_id=tool.function.name, # Use name of function as call ID since Ollama doesn't provide a call ID name=tool.function.name, arguments=tool.function.arguments if isinstance(tool.function.arguments, dict) else "", diff --git a/python/packages/ollama/tests/test_ollama_chat_client.py b/python/packages/ollama/tests/test_ollama_chat_client.py index e2aebb2a6a..ed517fae20 100644 --- a/python/packages/ollama/tests/test_ollama_chat_client.py +++ b/python/packages/ollama/tests/test_ollama_chat_client.py @@ -9,13 +9,8 @@ BaseChatClient, ChatMessage, ChatResponseUpdate, - DataContent, - FunctionCallContent, - FunctionResultContent, + Content, HostedWebSearchTool, - TextContent, - TextReasoningContent, - UriContent, ai_function, chat_middleware, ) @@ -231,7 +226,7 @@ async def test_cmc_reasoning( ollama_client = OllamaChatClient() result = await ollama_client.get_response(messages=chat_history) - reasoning = "".join(c.text for c in result.messages.pop().contents if isinstance(c, TextReasoningContent)) + reasoning = "".join(c.text for c in 
result.messages.pop().contents if c.type == "text_reasoning") assert reasoning == "test" @@ -286,7 +281,7 @@ async def test_cmc_streaming_reasoning( result = ollama_client.get_streaming_response(messages=chat_history) async for chunk in result: - reasoning = "".join(c.text for c in chunk.contents if isinstance(c, TextReasoningContent)) + reasoning = "".join(c.text for c in chunk.contents if c.type == "text_reasoning") assert reasoning == "test" @@ -333,14 +328,14 @@ async def test_cmc_streaming_with_tool_call( chunks.append(chunk) # Check parsed Toolcalls - assert isinstance(chunks[0].contents[0], FunctionCallContent) + assert chunks[0].contents[0].type == "function_call" tool_call = chunks[0].contents[0] assert tool_call.name == "hello_world" assert tool_call.arguments == {"arg1": "value1"} - assert isinstance(chunks[1].contents[0], FunctionResultContent) + assert chunks[1].contents[0].type == "function_result" tool_result = chunks[1].contents[0] assert tool_result.result == "Hello World" - assert isinstance(chunks[2].contents[0], TextContent) + assert chunks[2].contents[0].type == "text" text_result = chunks[2].contents[0] assert text_result.text == "test" @@ -378,7 +373,7 @@ async def test_cmc_with_data_content_type( mock_chat.return_value = mock_chat_completion_response chat_history.append( ChatMessage( - contents=[DataContent(uri="data:image/png;base64,xyz", media_type="image/png")], + contents=[Content.from_uri(uri="data:image/png;base64,xyz", media_type="image/png")], role="user", ) ) @@ -401,7 +396,7 @@ async def test_cmc_with_invalid_data_content_media_type( # Remote Uris are not supported by Ollama client chat_history.append( ChatMessage( - contents=[DataContent(uri="data:audio/mp3;base64,xyz", media_type="audio/mp3")], + contents=[Content.from_uri(uri="data:audio/mp3;base64,xyz", media_type="audio/mp3")], role="user", ) ) @@ -424,7 +419,7 @@ async def test_cmc_with_invalid_content_type( # Remote Uris are not supported by Ollama client 
chat_history.append( ChatMessage( - contents=[UriContent(uri="http://example.com/image.png", media_type="image/png")], + contents=[Content.from_uri(uri="http://example.com/image.png", media_type="image/png")], role="user", ) ) @@ -444,7 +439,7 @@ async def test_cmc_integration_with_tool_call( result = await ollama_client.get_response(messages=chat_history, options={"tools": [hello_world]}) assert "hello" in result.text.lower() and "world" in result.text.lower() - assert isinstance(result.messages[-2].contents[0], FunctionResultContent) + assert result.messages[-2].contents[0].type == "function_result" tool_result = result.messages[-2].contents[0] assert tool_result.result == "Hello World" @@ -478,10 +473,10 @@ async def test_cmc_streaming_integration_with_tool_call( for c in chunks: if len(c.contents) > 0: - if isinstance(c.contents[0], FunctionResultContent): + if c.contents[0].type == "function_result": tool_result = c.contents[0] assert tool_result.result == "Hello World" - if isinstance(c.contents[0], FunctionCallContent): + if c.contents[0].type == "function_call": tool_call = c.contents[0] assert tool_call.name == "hello_world" diff --git a/python/packages/redis/tests/test_redis_chat_message_store.py b/python/packages/redis/tests/test_redis_chat_message_store.py index dc97d81872..a69aef1b0a 100644 --- a/python/packages/redis/tests/test_redis_chat_message_store.py +++ b/python/packages/redis/tests/test_redis_chat_message_store.py @@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from agent_framework import ChatMessage, Role, TextContent +from agent_framework import ChatMessage, Content, Role from agent_framework_redis import RedisChatMessageStore @@ -413,7 +413,7 @@ async def test_message_serialization_with_complex_content(self): # Message with multiple content types message = ChatMessage( role=Role.ASSISTANT, - contents=[TextContent(text="Hello"), TextContent(text="World")], + contents=[Content.from_text(text="Hello"), 
Content.from_text(text="World")], author_name="TestBot", message_id="complex_msg", additional_properties={"metadata": "test"}, diff --git a/python/samples/getting_started/agents/anthropic/anthropic_advanced.py b/python/samples/getting_started/agents/anthropic/anthropic_advanced.py index 7d6c64234c..7ba38d12b7 100644 --- a/python/samples/getting_started/agents/anthropic/anthropic_advanced.py +++ b/python/samples/getting_started/agents/anthropic/anthropic_advanced.py @@ -43,7 +43,7 @@ async def main() -> None: if isinstance(content, TextReasoningContent): print(f"\033[32m{content.text}\033[0m", end="", flush=True) if isinstance(content, UsageContent): - print(f"\n\033[34m[Usage so far: {content.details}]\033[0m\n", end="", flush=True) + print(f"\n\033[34m[Usage so far: {content.usage_details}]\033[0m\n", end="", flush=True) if chunk.text: print(chunk.text, end="", flush=True) diff --git a/python/samples/getting_started/agents/anthropic/anthropic_foundry.py b/python/samples/getting_started/agents/anthropic/anthropic_foundry.py index 04dbcab50d..728e4915c3 100644 --- a/python/samples/getting_started/agents/anthropic/anthropic_foundry.py +++ b/python/samples/getting_started/agents/anthropic/anthropic_foundry.py @@ -54,7 +54,7 @@ async def main() -> None: if isinstance(content, TextReasoningContent): print(f"\033[32m{content.text}\033[0m", end="", flush=True) if isinstance(content, UsageContent): - print(f"\n\033[34m[Usage so far: {content.details}]\033[0m\n", end="", flush=True) + print(f"\n\033[34m[Usage so far: {content.usage_details}]\033[0m\n", end="", flush=True) if chunk.text: print(chunk.text, end="", flush=True) diff --git a/python/samples/getting_started/agents/anthropic/anthropic_skills.py b/python/samples/getting_started/agents/anthropic/anthropic_skills.py index 6dcbe8ccaf..009f485761 100644 --- a/python/samples/getting_started/agents/anthropic/anthropic_skills.py +++ b/python/samples/getting_started/agents/anthropic/anthropic_skills.py @@ -61,7 +61,7 @@ 
async def main() -> None: case "text_reasoning": print(f"\033[32m{content.text}\033[0m", end="", flush=True) case "usage": - print(f"\n\033[34m[Usage so far: {content.details}]\033[0m\n", end="", flush=True) + print(f"\n\033[34m[Usage so far: {content.usage_details}]\033[0m\n", end="", flush=True) case "hosted_file": # Catch generated files files.append(content) diff --git a/python/samples/getting_started/agents/azure_ai/azure_ai_with_code_interpreter_file_generation.py b/python/samples/getting_started/agents/azure_ai/azure_ai_with_code_interpreter_file_generation.py index 4118b50f9d..3e2b520ede 100644 --- a/python/samples/getting_started/agents/azure_ai/azure_ai_with_code_interpreter_file_generation.py +++ b/python/samples/getting_started/agents/azure_ai/azure_ai_with_code_interpreter_file_generation.py @@ -4,10 +4,7 @@ from agent_framework import ( AgentResponseUpdate, - CitationAnnotation, HostedCodeInterpreterTool, - HostedFileContent, - TextContent, ) from agent_framework.azure import AzureAIProjectAgentProvider from azure.identity.aio import AzureCliCredential @@ -50,9 +47,9 @@ async def non_streaming_example() -> None: # AgentResponse has messages property, which contains ChatMessage objects for message in result.messages: for content in message.contents: - if isinstance(content, TextContent) and content.annotations: + if content.type == "text" and content.annotations: for annotation in content.annotations: - if isinstance(annotation, CitationAnnotation) and annotation.file_id: + if annotation.file_id: annotations_found.append(annotation.file_id) print(f"Found file annotation: file_id={annotation.file_id}") @@ -84,15 +81,15 @@ async def streaming_example() -> None: async for update in agent.run_stream(QUERY): if isinstance(update, AgentResponseUpdate): for content in update.contents: - if isinstance(content, TextContent): + if content.type == "text": if content.text: text_chunks.append(content.text) if content.annotations: for annotation in 
content.annotations: - if isinstance(annotation, CitationAnnotation) and annotation.file_id: + if annotation.file_id: annotations_found.append(annotation.file_id) print(f"Found streaming annotation: file_id={annotation.file_id}") - elif isinstance(content, HostedFileContent): + elif content.type == "hosted_file": file_ids_found.append(content.file_id) print(f"Found streaming HostedFileContent: file_id={content.file_id}") diff --git a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_azure_ai_search.py b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_azure_ai_search.py index 8f36d5ebec..52da0c450c 100644 --- a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_azure_ai_search.py +++ b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_azure_ai_search.py @@ -3,7 +3,7 @@ import asyncio import os -from agent_framework import CitationAnnotation +from agent_framework import Annotation from agent_framework.azure import AzureAIAgentsProvider from azure.ai.agents.aio import AgentsClient from azure.ai.projects.aio import AIProjectClient @@ -85,13 +85,11 @@ async def main() -> None: ) print(f"User: {user_input}") print("Agent: ", end="", flush=True) - # Stream the response and collect citations - citations: list[CitationAnnotation] = [] + citations: list[Annotation] = [] async for chunk in agent.run_stream(user_input): if chunk.text: print(chunk.text, end="", flush=True) - # Collect citations from Azure AI Search responses for content in getattr(chunk, "contents", []): annotations = getattr(content, "annotations", []) @@ -104,7 +102,7 @@ async def main() -> None: if citations: print("\n\nCitation:") for i, citation in enumerate(citations, 1): - print(f"[{i}] {citation.url}") + print(f"[{i}] {citation.get('url')}") print("\n" + "=" * 50 + "\n") print("Hotel search conversation completed!") diff --git a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_bing_grounding_citations.py 
b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_bing_grounding_citations.py index 752d7e5a54..b1483b141b 100644 --- a/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_bing_grounding_citations.py +++ b/python/samples/getting_started/agents/azure_ai_agent/azure_ai_with_bing_grounding_citations.py @@ -2,7 +2,7 @@ import asyncio -from agent_framework import CitationAnnotation, HostedWebSearchTool +from agent_framework import Annotation, HostedWebSearchTool from agent_framework.azure import AzureAIAgentsProvider from azure.identity.aio import AzureCliCredential @@ -57,7 +57,7 @@ async def main() -> None: print("Agent: ", end="", flush=True) # Stream the response and collect citations - citations: list[CitationAnnotation] = [] + citations: list[Annotation] = [] async for chunk in agent.run_stream(user_input): if chunk.text: print(chunk.text, end="", flush=True) @@ -74,9 +74,9 @@ async def main() -> None: if citations: print("\n\nCitations:") for i, citation in enumerate(citations, 1): - print(f"[{i}] {citation.title}: {citation.url}") - if citation.snippet: - print(f" Snippet: {citation.snippet}") + print(f"[{i}] {citation['title']}: {citation.get('url')}") + if "snippet" in citation: + print(f" Snippet: {citation.get('snippet')}") else: print("\nNo citations found in the response.") diff --git a/python/samples/getting_started/agents/ollama/ollama_chat_multimodal.py b/python/samples/getting_started/agents/ollama/ollama_chat_multimodal.py index 1b830c2692..c78053a91a 100644 --- a/python/samples/getting_started/agents/ollama/ollama_chat_multimodal.py +++ b/python/samples/getting_started/agents/ollama/ollama_chat_multimodal.py @@ -2,7 +2,7 @@ import asyncio -from agent_framework import ChatMessage, DataContent, Role, TextContent +from agent_framework import ChatMessage, Content, Role from agent_framework.ollama import OllamaChatClient """ @@ -35,8 +35,8 @@ async def test_image() -> None: message = ChatMessage( role=Role.USER, 
contents=[ - TextContent(text="What's in this image?"), - DataContent(uri=image_uri, media_type="image/png"), + Content.from_text(text="What's in this image?"), + Content.from_uri(uri=image_uri, media_type="image/png"), ], ) diff --git a/python/samples/getting_started/agents/openai/openai_responses_client_reasoning.py b/python/samples/getting_started/agents/openai/openai_responses_client_reasoning.py index 601d321883..06080db943 100644 --- a/python/samples/getting_started/agents/openai/openai_responses_client_reasoning.py +++ b/python/samples/getting_started/agents/openai/openai_responses_client_reasoning.py @@ -66,7 +66,7 @@ async def streaming_reasoning_example() -> None: usage = content print("\n") if usage: - print(f"Usage: {usage.details}") + print(f"Usage: {usage.usage_details}") async def main() -> None: diff --git a/python/samples/getting_started/multimodal_input/README.md b/python/samples/getting_started/multimodal_input/README.md index e67052fa8a..2254fe89f7 100644 --- a/python/samples/getting_started/multimodal_input/README.md +++ b/python/samples/getting_started/multimodal_input/README.md @@ -82,7 +82,7 @@ with open("path/to/your/image.jpg", "rb") as f: image_uri = f"data:image/jpeg;base64,{image_base64}" # Use in DataContent -DataContent( +Content.from_uri( uri=image_uri, media_type="image/jpeg" ) @@ -96,7 +96,7 @@ with open("path/to/your/image.jpg", "rb") as f: image_bytes = f.read() # Use in DataContent -DataContent( +Content.from_data( data=image_bytes, media_type="image/jpeg" ) diff --git a/python/samples/getting_started/multimodal_input/azure_chat_multimodal.py b/python/samples/getting_started/multimodal_input/azure_chat_multimodal.py index fb80d65ad1..d5c5e58476 100644 --- a/python/samples/getting_started/multimodal_input/azure_chat_multimodal.py +++ b/python/samples/getting_started/multimodal_input/azure_chat_multimodal.py @@ -2,7 +2,7 @@ import asyncio -from agent_framework import ChatMessage, DataContent, Role, TextContent +from 
agent_framework import ChatMessage, Content, Role from agent_framework.azure import AzureOpenAIChatClient from azure.identity import AzureCliCredential @@ -26,7 +26,10 @@ async def test_image() -> None: image_uri = create_sample_image() message = ChatMessage( role=Role.USER, - contents=[TextContent(text="What's in this image?"), DataContent(uri=image_uri, media_type="image/png")], + contents=[ + Content.from_text(text="What's in this image?"), + Content.from_uri(uri=image_uri, media_type="image/png"), + ], ) response = await client.get_response(message) diff --git a/python/samples/getting_started/multimodal_input/azure_responses_multimodal.py b/python/samples/getting_started/multimodal_input/azure_responses_multimodal.py index edab52a789..350de89aa4 100644 --- a/python/samples/getting_started/multimodal_input/azure_responses_multimodal.py +++ b/python/samples/getting_started/multimodal_input/azure_responses_multimodal.py @@ -3,7 +3,7 @@ import asyncio from pathlib import Path -from agent_framework import ChatMessage, DataContent, Role, TextContent +from agent_framework import ChatMessage, Content, Role from agent_framework.azure import AzureOpenAIResponsesClient from azure.identity import AzureCliCredential @@ -35,7 +35,10 @@ async def test_image() -> None: image_uri = create_sample_image() message = ChatMessage( role=Role.USER, - contents=[TextContent(text="What's in this image?"), DataContent(uri=image_uri, media_type="image/png")], + contents=[ + Content.from_text(text="What's in this image?"), + Content.from_uri(uri=image_uri, media_type="image/png"), + ], ) response = await client.get_response(message) @@ -50,8 +53,8 @@ async def test_pdf() -> None: message = ChatMessage( role=Role.USER, contents=[ - TextContent(text="What information can you extract from this document?"), - DataContent( + Content.from_text(text="What information can you extract from this document?"), + Content.from_data( data=pdf_bytes, media_type="application/pdf", 
additional_properties={"filename": "sample.pdf"}, diff --git a/python/samples/getting_started/multimodal_input/openai_chat_multimodal.py b/python/samples/getting_started/multimodal_input/openai_chat_multimodal.py index 1985d01bab..e0743340dd 100644 --- a/python/samples/getting_started/multimodal_input/openai_chat_multimodal.py +++ b/python/samples/getting_started/multimodal_input/openai_chat_multimodal.py @@ -5,7 +5,7 @@ import struct from pathlib import Path -from agent_framework import ChatMessage, DataContent, Role, TextContent +from agent_framework import ChatMessage, Content, Role from agent_framework.openai import OpenAIChatClient ASSETS_DIR = Path(__file__).resolve().parent.parent / "sample_assets" @@ -47,7 +47,10 @@ async def test_image() -> None: image_uri = create_sample_image() message = ChatMessage( role=Role.USER, - contents=[TextContent(text="What's in this image?"), DataContent(uri=image_uri, media_type="image/png")], + contents=[ + Content.from_text(text="What's in this image?"), + Content.from_uri(uri=image_uri, media_type="image/png"), + ], ) response = await client.get_response(message) @@ -62,8 +65,8 @@ async def test_audio() -> None: message = ChatMessage( role=Role.USER, contents=[ - TextContent(text="What do you hear in this audio?"), - DataContent(uri=audio_uri, media_type="audio/wav"), + Content.from_text(text="What do you hear in this audio?"), + Content.from_uri(uri=audio_uri, media_type="audio/wav"), ], ) @@ -79,8 +82,8 @@ async def test_pdf() -> None: message = ChatMessage( role=Role.USER, contents=[ - TextContent(text="What information can you extract from this document?"), - DataContent( + Content.from_text(text="What information can you extract from this document?"), + Content.from_data( data=pdf_bytes, media_type="application/pdf", additional_properties={"filename": "employee_report.pdf"} ), ], diff --git a/python/samples/getting_started/workflows/agents/workflow_as_agent_reflection_pattern.py 
b/python/samples/getting_started/workflows/agents/workflow_as_agent_reflection_pattern.py index bafe55dd4e..0320d02a1f 100644 --- a/python/samples/getting_started/workflows/agents/workflow_as_agent_reflection_pattern.py +++ b/python/samples/getting_started/workflows/agents/workflow_as_agent_reflection_pattern.py @@ -9,7 +9,7 @@ AgentRunUpdateEvent, ChatClientProtocol, ChatMessage, - Contents, + Content, Executor, Role, WorkflowBuilder, @@ -155,7 +155,7 @@ async def handle_review_response(self, review: ReviewResponse, ctx: WorkflowCont if review.approved: print("Worker: Response approved. Emitting to external consumer...") - contents: list[Contents] = [] + contents: list[Content] = [] for message in request.agent_messages: contents.extend(message.contents) diff --git a/python/uv.lock b/python/uv.lock index 2e68d4a6ac..25500d9ebb 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -3914,7 +3914,7 @@ wheels = [ [[package]] name = "openai-agents" -version = "0.6.5" +version = "0.6.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "griffe", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -3925,14 +3925,14 @@ dependencies = [ { name = "types-requests", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e7/5c/5ebface62a0efdc7298152dcd2d32164403e25e53f1088c042936d8d40f9/openai_agents-0.6.5.tar.gz", hash = "sha256:67e8cab27082d1a1fe6f3fecfcf89b41ff249988a75640bbcc2764952d603ef0", size = 2044506, upload-time = "2026-01-06T15:32:50.936Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a8/35/4fa8b007c831e1aaa58d337f0dc888a9a44e5dd32103c8f9c77dfe2201b0/openai_agents-0.6.6.tar.gz", hash = "sha256:3c4a9a96dc04307a6df5ef1462ce96451c6875e4fc0d24437c6b3b10e317d3a6", size = 
2096069, upload-time = "2026-01-15T05:38:33.045Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/17/db/16020e45d53366f2ed653ce0ddf959a647687d47180954de7654a133b910/openai_agents-0.6.5-py3-none-any.whl", hash = "sha256:c81d2eaa5c4563b8e893ba836fe170cf10ba974420ff283b4f001f84e7cb6e6b", size = 249352, upload-time = "2026-01-06T15:32:48.847Z" }, + { url = "https://files.pythonhosted.org/packages/2f/42/e2ba13b62cc4d51e49c64a876020ba57304b3d6466b7b1e7f3f201ac5124/openai_agents-0.6.6-py3-none-any.whl", hash = "sha256:737bc58516c0f63a8d9f1e8034fc9a8f1f852d373d3448cda0a8c81e5e75a239", size = 256303, upload-time = "2026-01-15T05:38:30.852Z" }, ] [[package]] name = "openai-chatkit" -version = "1.5.2" +version = "1.5.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jinja2", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -3941,9 +3941,9 @@ dependencies = [ { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "uvicorn", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0e/f3/3e7aafd6c29348e60d32082fb14e539661fe4100453a31b34d0fef1ff7b7/openai_chatkit-1.5.2.tar.gz", hash = "sha256:187d27b815f153fa060337c86ee3aab189f72269f23ac2bb2a35c6c88b83846d", size = 59268, upload-time = "2026-01-10T00:59:41.215Z" } +sdist = { url = "https://files.pythonhosted.org/packages/16/45/854c988728c8638064f3fb5652b30be395a164b526aa7fda2f24de6652c9/openai_chatkit-1.5.3.tar.gz", hash = "sha256:ef43c500a9a8f7e066a99cd04369cee5ee4dacdb1ec94976cb3b17a2caab185f", size = 59768, upload-time = "2026-01-15T06:25:51.225Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/dc/b6/475a4c723fb2e0de30feea505505eabe77666aa7d81855d356fb289e3d8a/openai_chatkit-1.5.2-py3-none-any.whl", hash = 
"sha256:3bf3f140f314924ef1d4148ce5174cff6aa4c5d1760f988ba2aa267fd434f960", size = 41482, upload-time = "2026-01-10T00:59:40.023Z" }, + { url = "https://files.pythonhosted.org/packages/9d/40/f5dacf061be090a912a3977033123044856bb0e195bbf0a6afa27dd288a7/openai_chatkit-1.5.3-py3-none-any.whl", hash = "sha256:aa6a5bda8620d5875cdeccd843f675cf74c56842622c4a24dc1e89f3b74304fb", size = 41731, upload-time = "2026-01-15T06:25:50.105Z" }, ] [[package]] @@ -4688,7 +4688,7 @@ wheels = [ [[package]] name = "py2docfx" -version = "0.1.22" +version = "0.1.23.dev2382197" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pyyaml", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -4696,7 +4696,7 @@ dependencies = [ { name = "wheel", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/c2/94/6475c7faa94a1d90303f624936471cd0f4c20430bd2c92deab607cd0ff31/py2docfx-0.1.22-py3-none-any.whl", hash = "sha256:ccee611af2aefe9f39f446f72b5e07d3369bbdb77c13ebafb69ed1a116116467", size = 11420273, upload-time = "2025-10-10T07:16:25.294Z" }, + { url = "https://files.pythonhosted.org/packages/3d/5b/201ff811df866172fc422ecbfe9fc4a0463a361b1fe4a92905a7bed53fac/py2docfx-0.1.23.dev2382197-py3-none-any.whl", hash = "sha256:d343babcdee307365776e4c7b9ef95d4f2e218c9d3d23072a483a1ca93c1b316", size = 11339298, upload-time = "2026-01-15T06:35:07.866Z" }, ] [[package]] From 848b13d72e93ab2cd98e38153aefd5136bb78dc8 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Fri, 16 Jan 2026 21:29:42 +0100 Subject: [PATCH 2/7] fixed linting --- .../a2a/agent_framework_a2a/_agent.py | 2 +- .../ag-ui/agent_framework_ag_ui/_client.py | 24 +++++----- .../ag-ui/agent_framework_ag_ui/_events.py | 32 +++++++------- .../_message_adapters.py | 10 ++--- .../_orchestration/_helpers.py | 2 +- .../agent_framework_ag_ui/_orchestrators.py | 10 ++--- 
.../packages/ag-ui/getting_started/client.py | 6 +-- .../ag-ui/getting_started/client_advanced.py | 16 +++---- .../getting_started/client_with_agent.py | 22 ++++++---- .../agent_framework_anthropic/_chat_client.py | 18 ++++---- .../agent_framework_azure_ai/_chat_client.py | 16 +++---- .../agent_framework_azure_ai/_client.py | 2 +- .../_durable_agent_state.py | 24 +++++----- .../agent_framework_bedrock/_chat_client.py | 4 +- .../packages/core/agent_framework/_agents.py | 4 +- python/packages/core/agent_framework/_mcp.py | 2 +- .../packages/core/agent_framework/_tools.py | 44 ++++++++++++------- .../packages/core/agent_framework/_types.py | 29 ++++++------ .../core/agent_framework/_workflows/_agent.py | 10 ++--- .../_workflows/_agent_executor.py | 6 +-- .../openai/_assistants_client.py | 6 +-- .../agent_framework/openai/_chat_client.py | 26 ++++++----- .../openai/_responses_client.py | 20 ++++----- .../agent_framework_declarative/_loader.py | 1 + .../devui/agent_framework_devui/_executor.py | 2 +- .../devui/agent_framework_devui/_mapper.py | 4 +- .../_message_utils.py | 2 +- .../_sliding_window.py | 2 +- .../agent_framework_lab_tau2/_tau2_utils.py | 2 +- .../agent_framework_ollama/_chat_client.py | 6 +-- python/pyproject.toml | 1 + 31 files changed, 183 insertions(+), 172 deletions(-) diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index 84cc8fb20c..00e045fba6 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -330,7 +330,7 @@ def _prepare_message_for_a2a(self, message: ChatMessage) -> A2AMessage: A2APart( root=FilePart( file=FileWithBytes( - bytes=_get_uri_data(content.uri), + bytes=_get_uri_data(content.uri), # type: ignore[arg-type] mime_type=content.media_type, ), metadata=content.additional_properties, diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_client.py b/python/packages/ag-ui/agent_framework_ag_ui/_client.py 
index 3da099b3b8..a336f28b76 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_client.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_client.py @@ -54,8 +54,8 @@ def _unwrap_server_function_call_contents(contents: MutableSequence[Content | dict[str, Any]]) -> None: """Replace server_function_call instances with their underlying call content.""" for idx, content in enumerate(contents): - if content.type == "server_function_call": - contents[idx] = content.function_call # type: ignore[assignment] + if content.type == "server_function_call": # type: ignore[union-attr] + contents[idx] = content.function_call # type: ignore[assignment, union-attr] TBaseChatClient = TypeVar("TBaseChatClient", bound=type[BaseChatClient[Any]]) @@ -273,10 +273,10 @@ def _extract_state_from_messages( if isinstance(content, Content) and content.type == "data" and content.media_type == "application/json": try: uri = content.uri - if uri.startswith("data:application/json;base64,"): + if uri.startswith("data:application/json;base64,"): # type: ignore[union-attr] import base64 - encoded_data = uri.split(",", 1)[1] + encoded_data = uri.split(",", 1)[1] # type: ignore[union-attr] decoded_bytes = base64.b64decode(encoded_data) state = json.loads(decoded_bytes.decode("utf-8")) @@ -414,19 +414,19 @@ async def _inner_get_streaming_response( ) # Distinguish client vs server tools for i, content in enumerate(update.contents): - if content.type == "function_call": + if content.type == "function_call": # type: ignore[attr-defined] logger.debug( - f"[AGUIChatClient] Function call: {content.name}, in client_tool_set: {content.name in client_tool_set}" + f"[AGUIChatClient] Function call: {content.name}, in client_tool_set: {content.name in client_tool_set}" # type: ignore[attr-defined] ) - if content.name in client_tool_set: + if content.name in client_tool_set: # type: ignore[attr-defined] # Client tool - let @use_function_invocation execute it - if not content.additional_properties: - 
content.additional_properties = {} - content.additional_properties["agui_thread_id"] = thread_id + if not content.additional_properties: # type: ignore[attr-defined] + content.additional_properties = {} # type: ignore[attr-defined] + content.additional_properties["agui_thread_id"] = thread_id # type: ignore[attr-defined] else: # Server tool - wrap so @use_function_invocation ignores it - logger.debug(f"[AGUIChatClient] Wrapping server tool: {content.name}") - self._register_server_tool_placeholder(content.name) + logger.debug(f"[AGUIChatClient] Wrapping server tool: {content.name}") # type: ignore[union-attr] + self._register_server_tool_placeholder(content.name) # type: ignore[arg-type] update.contents[i] = Content(type="server_function_call", function_call=content) # type: ignore yield update diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_events.py b/python/packages/ag-ui/agent_framework_ag_ui/_events.py index bff1780170..34c1e3ed86 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_events.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_events.py @@ -108,7 +108,7 @@ async def from_agent_run_update(self, update: AgentResponseUpdate) -> list[BaseE def _handle_text_content(self, content: Content) -> list[BaseEvent]: events: list[BaseEvent] = [] - logger.info(f" TextContent found: length={len(content.text)}") + logger.info(f" TextContent found: length={len(content.text)}") # type: ignore[arg-type] logger.info( " Flags: skip_text_content=%s, should_stop_after_confirm=%s", self.skip_text_content, @@ -121,7 +121,7 @@ def _handle_text_content(self, content: Content) -> list[BaseEvent]: if self.should_stop_after_confirm: logger.info(" SKIPPING TextContent: waiting for confirm_changes response") - self.suppressed_summary += content.text + self.suppressed_summary += content.text # type: ignore[operator] logger.info(f" Suppressed summary length={len(self.suppressed_summary)}") return events @@ -309,7 +309,7 @@ def _handle_function_result_content(self, 
content: Content) -> list[BaseEvent]: result_event = ToolCallResultEvent( message_id=result_message_id, - tool_call_id=content.call_id, + tool_call_id=content.call_id, # type: ignore[arg-type] content=result_content, role="tool", ) @@ -464,10 +464,10 @@ def _emit_function_approval_tool_call(self, function_call: Content) -> list[Base def _handle_function_approval_request_content(self, content: Content) -> list[BaseEvent]: events: list[BaseEvent] = [] logger.info("=== FUNCTION APPROVAL REQUEST ===") - logger.info(f" Function: {content.function_call.name}") - logger.info(f" Call ID: {content.function_call.call_id}") + logger.info(f" Function: {content.function_call.name}") # type: ignore[union-attr] + logger.info(f" Call ID: {content.function_call.call_id}") # type: ignore[union-attr] - parsed_args = content.function_call.parse_arguments() + parsed_args = content.function_call.parse_arguments() # type: ignore[union-attr] parsed_arg_keys = list(parsed_args.keys()) if parsed_args else "None" logger.info(f" Parsed args keys: {parsed_arg_keys}") @@ -477,12 +477,12 @@ def _handle_function_approval_request_content(self, content: Content) -> list[Ba list(self.predict_state_config.keys()) if self.predict_state_config else "None", ) for state_key, config in self.predict_state_config.items(): - if config["tool"] != content.function_call.name: + if config["tool"] != content.function_call.name: # type: ignore[union-attr] continue tool_arg_name = config["tool_argument"] logger.info( " MATCHED tool '%s' for state key '%s', arg='%s'", - content.function_call.name, + content.function_call.name, # type: ignore[union-attr] state_key, tool_arg_name, ) @@ -499,11 +499,11 @@ def _handle_function_approval_request_content(self, content: Content) -> list[Ba ) events.append(state_snapshot) - if content.function_call.call_id: + if content.function_call.call_id: # type: ignore[union-attr] end_event = ToolCallEndEvent( - tool_call_id=content.function_call.call_id, + 
tool_call_id=content.function_call.call_id, # type: ignore[union-attr] ) - logger.info(f"Emitting ToolCallEndEvent for approval-required tool '{content.function_call.call_id}'") + logger.info(f"Emitting ToolCallEndEvent for approval-required tool '{content.function_call.call_id}'") # type: ignore[union-attr] events.append(end_event) # Emit the function_approval_request custom event for UI implementations that support it @@ -512,18 +512,18 @@ def _handle_function_approval_request_content(self, content: Content) -> list[Ba value={ "id": content.id, "function_call": { - "call_id": content.function_call.call_id, - "name": content.function_call.name, - "arguments": content.function_call.parse_arguments(), + "call_id": content.function_call.call_id, # type: ignore[union-attr] + "name": content.function_call.name, # type: ignore[union-attr] + "arguments": content.function_call.parse_arguments(), # type: ignore[union-attr] }, }, ) - logger.info(f"Emitting function_approval_request custom event for '{content.function_call.name}'") + logger.info(f"Emitting function_approval_request custom event for '{content.function_call.name}'") # type: ignore[union-attr] events.append(approval_event) # Emit a UI-friendly approval tool call for function approvals. 
if self.require_confirmation: - events.extend(self._emit_function_approval_tool_call(content.function_call)) + events.extend(self._emit_function_approval_tool_call(content.function_call)) # type: ignore[arg-type] # Signal orchestrator to stop the run and wait for user approval response self.should_stop_after_confirm = True diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py index b46a7cd288..cf14641258 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py @@ -91,11 +91,11 @@ def _sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: user_text = "" for content in msg.contents or []: if content.type == "text": - user_text = content.text + user_text = content.text # type: ignore[assignment] break try: - parsed = json.loads(user_text) + parsed = json.loads(user_text) # type: ignore[arg-type] if "accepted" in parsed: logger.info( f"Injecting synthetic tool result for confirm_changes call_id={pending_confirm_changes_id}" @@ -460,8 +460,8 @@ def _filter_modified_args( _update_tool_call_arguments(messages, str(approval_call_id), merged_args) # Create a new FunctionCallContent with the modified arguments func_call_for_approval = Content.from_function_call( - call_id=matching_func_call.call_id, - name=matching_func_call.name, + call_id=matching_func_call.call_id, # type: ignore[arg-type] + name=matching_func_call.name, # type: ignore[arg-type] arguments=json.dumps(filtered_args), ) logger.info(f"Using modified arguments from approval: {filtered_args}") @@ -647,7 +647,7 @@ def agent_framework_messages_to_agui(messages: list[ChatMessage] | list[dict[str for content in msg.contents: if content.type == "text": - content_text += content.text + content_text += content.text # type: ignore[operator] elif content.type == "function_call": tool_calls.append( { diff --git 
a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py index b327192367..b2e3b1d5eb 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py @@ -53,7 +53,7 @@ def is_state_context_message(message: ChatMessage) -> bool: if get_role_value(message) != "system": return False for content in message.contents: - if content.type == "text" and content.text.startswith("Current state of the application:"): + if content.type == "text" and content.text.startswith("Current state of the application:"): # type: ignore[union-attr] return True return False diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py index 6ac93810db..2bd24de8c8 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py @@ -288,7 +288,7 @@ async def run( break try: - tool_result = json.loads(tool_content_text) + tool_result = json.loads(tool_content_text) # type: ignore[arg-type] accepted = tool_result.get("accepted", False) steps = tool_result.get("steps", []) @@ -326,7 +326,7 @@ async def run( except json.JSONDecodeError: logger.error(f"Failed to parse tool result: {tool_content_text}") - yield RunErrorEvent(message=f"Invalid tool result format: {tool_content_text[:100]}") + yield RunErrorEvent(message=f"Invalid tool result format: {tool_content_text[:100]}") # type: ignore[index] yield event_bridge.create_run_finished_event() @@ -440,7 +440,7 @@ async def run( if hasattr(msg, "contents") and msg.contents: for j, content in enumerate(msg.contents): if content.type == "text": - logger.debug(" Content %s: %s - text_length=%s", j, content.type, len(content.text)) + logger.debug(" Content %s: %s - text_length=%s", j, content.type, len(content.text)) # 
type: ignore[arg-type] elif content.type == "function_call": arg_length = len(str(content.arguments)) if content.arguments else 0 logger.debug( @@ -538,9 +538,9 @@ async def _resolve_approval_responses( if idx < len(approved_function_results) and approved_function_results[idx].type == "function_result": normalized_results.append(approved_function_results[idx]) continue - call_id = approval.function_call.call_id or approval.id + call_id = approval.function_call.call_id or approval.id # type: ignore[union-attr] normalized_results.append( - Content.from_function_result(call_id=call_id, result="Error: Tool call invocation failed.") + Content.from_function_result(call_id=call_id, result="Error: Tool call invocation failed.") # type: ignore[arg-type] ) _replace_approval_contents_with_results(messages, fcc_todo, normalized_results) # type: ignore diff --git a/python/packages/ag-ui/getting_started/client.py b/python/packages/ag-ui/getting_started/client.py index 61bdf0bfb3..7b56103050 100644 --- a/python/packages/ag-ui/getting_started/client.py +++ b/python/packages/ag-ui/getting_started/client.py @@ -50,11 +50,9 @@ async def main(): print("\nAssistant: ", end="", flush=True) # Display text content as it streams - from agent_framework import TextContent - for content in update.contents: - if isinstance(content, TextContent) and content.text: - print(f"\033[96m{content.text}\033[0m", end="", flush=True) + if hasattr(content, "text") and content.text: # type: ignore[attr-defined] + print(f"\033[96m{content.text}\033[0m", end="", flush=True) # type: ignore[attr-defined] # Display finish reason if present if update.finish_reason: diff --git a/python/packages/ag-ui/getting_started/client_advanced.py b/python/packages/ag-ui/getting_started/client_advanced.py index 08698a80a0..3c7ae6a334 100644 --- a/python/packages/ag-ui/getting_started/client_advanced.py +++ b/python/packages/ag-ui/getting_started/client_advanced.py @@ -73,11 +73,9 @@ async def streaming_example(client: 
AGUIChatClient, thread_id: str | None = None if not thread_id and update.additional_properties: thread_id = update.additional_properties.get("thread_id") - from agent_framework import TextContent - for content in update.contents: - if isinstance(content, TextContent) and content.text: - print(content.text, end="", flush=True) + if content.type == "text" and content.text: # type: ignore[attr-defined] + print(content.text, end="", flush=True) # type: ignore[attr-defined] print("\n") return thread_id @@ -138,13 +136,11 @@ async def tool_example(client: AGUIChatClient, thread_id: str | None = None): print(f"Assistant: {response.text}") # Show tool calls if any - from agent_framework import FunctionCallContent - tool_called = False for message in response.messages: for content in message.contents: - if isinstance(content, FunctionCallContent): - print(f"\n[Tool Called: {content.name}]") + if content.type == "function_call": # type: ignore[attr-defined] + print(f"\n[Tool Called: {content.name}]") # type: ignore[attr-defined] tool_called = True if not tool_called: @@ -176,7 +172,7 @@ async def conversation_example(client: AGUIChatClient): # Second turn - using same thread print("\nUser: What's my name?\n") - response2 = await client.get_response("What's my name?", metadata={"thread_id": thread_id}) + response2 = await client.get_response("What's my name?", options={"metadata": {"thread_id": thread_id}}) print(f"Assistant: {response2.text}") # Check if context was maintained @@ -186,7 +182,7 @@ async def conversation_example(client: AGUIChatClient): # Third turn print("\nUser: Can you also tell me what 10 * 5 is?\n") response3 = await client.get_response( - "Can you also tell me what 10 * 5 is?", metadata={"thread_id": thread_id}, tools=[calculate] + "Can you also tell me what 10 * 5 is?", options={"metadata": {"thread_id": thread_id}}, tools=[calculate] ) print(f"Assistant: {response3.text}") diff --git a/python/packages/ag-ui/getting_started/client_with_agent.py 
b/python/packages/ag-ui/getting_started/client_with_agent.py index 91b099820b..63a89b4344 100644 --- a/python/packages/ag-ui/getting_started/client_with_agent.py +++ b/python/packages/ag-ui/getting_started/client_with_agent.py @@ -22,7 +22,7 @@ import logging import os -from agent_framework import ChatAgent, FunctionCallContent, FunctionResultContent, TextContent, ai_function +from agent_framework import ChatAgent, ai_function from agent_framework.ag_ui import AGUIChatClient # Enable debug logging @@ -141,8 +141,9 @@ def _preview_for_message(m) -> str: # Build from contents when no direct text parts: list[str] = [] for c in getattr(m, "contents", []) or []: - if isinstance(c, FunctionCallContent): - args = c.arguments + content_type = getattr(c, "type", None) + if content_type == "function_call": + args = getattr(c, "arguments", None) if isinstance(args, dict): try: import json as _json @@ -152,12 +153,15 @@ def _preview_for_message(m) -> str: args_str = str(args) else: args_str = str(args or "{}") - parts.append(f"tool_call {c.name} {args_str}") - elif isinstance(c, FunctionResultContent): - parts.append(f"tool_result[{c.call_id}]: {str(c.result)[:40]}") - elif isinstance(c, TextContent): - if c.text: - parts.append(c.text) + parts.append(f"tool_call {getattr(c, 'name', '?')} {args_str}") + elif content_type == "function_result": + call_id = getattr(c, "call_id", "?") + result = getattr(c, "result", None) + parts.append(f"tool_result[{call_id}]: {str(result)[:40]}") + elif content_type == "text": + text = getattr(c, "text", None) + if text: + parts.append(text) else: typename = getattr(c, "type", c.__class__.__name__) parts.append(f"<{typename}>") diff --git a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py index c95851a3c8..4fdcdfadc7 100644 --- a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py +++ 
b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py @@ -474,7 +474,7 @@ def _prepare_message_for_anthropic(self, message: ChatMessage) -> dict[str, Any] a_content.append({ "type": "image", "source": { - "data": content.get_data_bytes_as_str(), + "data": content.get_data_bytes_as_str(), # type: ignore[attr-defined] "media_type": content.media_type, "type": "base64", }, @@ -692,9 +692,9 @@ def _parse_usage_from_anthropic(self, usage: BetaUsage | BetaMessageDeltaUsage | if usage.input_tokens is not None: usage_details["input_token_count"] = usage.input_tokens if usage.cache_creation_input_tokens is not None: - usage_details["anthropic.cache_creation_input_tokens"] = usage.cache_creation_input_tokens + usage_details["anthropic.cache_creation_input_tokens"] = usage.cache_creation_input_tokens # type: ignore[typeddict-unknown-key] if usage.cache_read_input_tokens is not None: - usage_details["anthropic.cache_read_input_tokens"] = usage.cache_read_input_tokens + usage_details["anthropic.cache_read_input_tokens"] = usage.cache_read_input_tokens # type: ignore[typeddict-unknown-key] return usage_details def _parse_contents_from_anthropic( @@ -748,7 +748,7 @@ def _parse_contents_from_anthropic( ) ) case "mcp_tool_result": - call_id, name = self._last_call_id_name or (None, None) + call_id, _ = self._last_call_id_name or (None, None) parsed_output: list[Content] | None = None if content_block.content: if isinstance(content_block.content, list): @@ -921,7 +921,7 @@ def _parse_contents_from_anthropic( Annotation( type="citation", raw_representation=content_block.content, - snippet="\n".join(content_block.content.lines) + snippet="\n".join(content_block.content.lines) # type: ignore[typeddict-item] if content_block.content.lines else None, annotated_regions=[ @@ -1006,7 +1006,7 @@ def _parse_citations_from_anthropic( if citation.file_id: cit["file_id"] = citation.file_id cit.setdefault("annotated_regions", []) - cit["annotated_regions"].append( + 
cit["annotated_regions"].append( # type: ignore[attr-defined] TextSpanRegion( type="text_span", start_index=citation.start_char_index, @@ -1019,7 +1019,7 @@ def _parse_citations_from_anthropic( if citation.file_id: cit["file_id"] = citation.file_id cit.setdefault("annotated_regions", []) - cit["annotated_regions"].append( + cit["annotated_regions"].append( # type: ignore[attr-defined] TextSpanRegion( type="text_span", start_index=citation.start_page_number, @@ -1032,7 +1032,7 @@ def _parse_citations_from_anthropic( if citation.file_id: cit["file_id"] = citation.file_id cit.setdefault("annotated_regions", []) - cit["annotated_regions"].append( + cit["annotated_regions"].append( # type: ignore[attr-defined] TextSpanRegion( type="text_span", start_index=citation.start_block_index, @@ -1048,7 +1048,7 @@ def _parse_citations_from_anthropic( cit["snippet"] = citation.cited_text cit["url"] = citation.source cit.setdefault("annotated_regions", []) - cit["annotated_regions"].append( + cit["annotated_regions"].append( # type: ignore[attr-defined] TextSpanRegion( type="text_span", start_index=citation.start_block_index, diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py index e9dca3d1dd..f7571e1e71 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py @@ -527,9 +527,9 @@ def _extract_url_citations( # Create Annotation with real URL citation = Annotation( type="citation", - title=annotation.url_citation.title, + title=annotation.url_citation.title, # type: ignore[typeddict-item] url=real_url, - snippet=None, + snippet=None, # type: ignore[typeddict-item] annotated_regions=annotated_regions, raw_representation=annotation, ) @@ -1067,7 +1067,7 @@ def _prepare_messages( for chat_message in messages: if chat_message.role.value in ["system", "developer"]: for text_content in [content for 
content in chat_message.contents if content.type == "text"]: - instructions.append(text_content.text) + instructions.append(text_content.text) # type: ignore[arg-type] continue message_contents: list[MessageInputContentBlock] = [] @@ -1075,11 +1075,11 @@ def _prepare_messages( for content in chat_message.contents: match content.type: case "text": - message_contents.append(MessageInputTextBlock(text=content.text)) + message_contents.append(MessageInputTextBlock(text=content.text)) # type: ignore[arg-type] case "data" | "uri": if content.has_top_level_media_type("image"): message_contents.append( - MessageInputImageUrlBlock(image_url=MessageImageUrlParam(url=content.uri)) + MessageInputImageUrlBlock(image_url=MessageImageUrlParam(url=content.uri)) # type: ignore[arg-type] ) # Only images are supported. Other media types are ignored. case "function_result" | "function_approval_response": @@ -1165,7 +1165,7 @@ async def _prepare_tools_for_azure_ai( case HostedFileSearchTool(): vector_stores = [inp for inp in tool.inputs or [] if inp.type == "hosted_vector_store"] if vector_stores: - file_search = FileSearchTool(vector_store_ids=[vs.vector_store_id for vs in vector_stores]) + file_search = FileSearchTool(vector_store_ids=[vs.vector_store_id for vs in vector_stores]) # type: ignore[misc] tool_definitions.extend(file_search.definitions) # Set tool_resources for file search to work properly with Azure AI if run_options is not None and "tool_resources" not in run_options: @@ -1194,7 +1194,7 @@ def _prepare_tool_outputs_for_azure_ai( # We need to extract the run ID and ensure that the Output/Approval we send back to Azure # is only the call ID. 
run_and_call_ids: list[str] = ( - json.loads(content.call_id) if content.type == "function_result" else json.loads(content.id) + json.loads(content.call_id) if content.type == "function_result" else json.loads(content.id) # type: ignore[arg-type] ) if ( @@ -1218,7 +1218,7 @@ def _prepare_tool_outputs_for_azure_ai( elif content.type == "function_approval_response": if tool_approvals is None: tool_approvals = [] - tool_approvals.append(ToolApproval(tool_call_id=call_id, approve=content.approved)) + tool_approvals.append(ToolApproval(tool_call_id=call_id, approve=content.approved)) # type: ignore[arg-type] return run_id, tool_outputs, tool_approvals diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py index a78379eb61..c64b6df44c 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py @@ -477,7 +477,7 @@ def _prepare_messages_for_azure_ai( for message in messages: if message.role.value in ["system", "developer"]: for text_content in [content for content in message.contents if content.type == "text"]: - instructions_list.append(text_content.text) + instructions_list.append(text_content.text) # type: ignore[arg-type] else: result.append(message) diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py index 67d99e5be1..d33d9ea91c 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py @@ -689,7 +689,7 @@ def from_run_response(correlation_id: str, response: AgentResponse) -> DurableAg correlation_id=correlation_id, created_at=_parse_created_at(response.created_at), messages=[DurableAgentStateMessage.from_chat_message(m) for m in 
response.messages], - usage=DurableAgentStateUsage.from_usage(response.usage_details), + usage=DurableAgentStateUsage.from_usage(response.usage_details), # type: ignore[arg-type] ) @@ -859,7 +859,7 @@ def to_dict(self) -> dict[str, Any]: @staticmethod def from_data_content(content: Content) -> DurableAgentStateDataContent: - return DurableAgentStateDataContent(uri=content.uri, media_type=content.media_type) + return DurableAgentStateDataContent(uri=content.uri, media_type=content.media_type) # type: ignore[arg-type] def to_ai_content(self) -> Content: return Content.from_uri(uri=self.uri, media_type=self.media_type) @@ -952,7 +952,7 @@ def from_function_call_content(content: Content) -> DurableAgentStateFunctionCal except json.JSONDecodeError: arguments = {} - return DurableAgentStateFunctionCallContent(call_id=content.call_id, name=content.name, arguments=arguments) + return DurableAgentStateFunctionCallContent(call_id=content.call_id, name=content.name, arguments=arguments) # type: ignore[arg-type] def to_ai_content(self) -> Content: return Content.from_function_call(call_id=self.call_id, name=self.name, arguments=self.arguments) @@ -988,7 +988,7 @@ def to_dict(self) -> dict[str, Any]: @staticmethod def from_function_result_content(content: Content) -> DurableAgentStateFunctionResultContent: - return DurableAgentStateFunctionResultContent(call_id=content.call_id, result=content.result) + return DurableAgentStateFunctionResultContent(call_id=content.call_id, result=content.result) # type: ignore[arg-type] def to_ai_content(self) -> Content: return Content.from_function_result(call_id=self.call_id, result=self.result) @@ -1016,7 +1016,7 @@ def to_dict(self) -> dict[str, Any]: @staticmethod def from_hosted_file_content(content: Content) -> DurableAgentStateHostedFileContent: - return DurableAgentStateHostedFileContent(file_id=content.file_id) + return DurableAgentStateHostedFileContent(file_id=content.file_id) # type: ignore[arg-type] def to_ai_content(self) -> 
Content: return Content.from_hosted_file(file_id=self.file_id) @@ -1050,7 +1050,7 @@ def to_dict(self) -> dict[str, Any]: def from_hosted_vector_store_content( content: Content, ) -> DurableAgentStateHostedVectorStoreContent: - return DurableAgentStateHostedVectorStoreContent(vector_store_id=content.vector_store_id) + return DurableAgentStateHostedVectorStoreContent(vector_store_id=content.vector_store_id) # type: ignore[arg-type] def to_ai_content(self) -> Content: return Content.from_hosted_vector_store(vector_store_id=self.vector_store_id) @@ -1137,7 +1137,7 @@ def to_dict(self) -> dict[str, Any]: @staticmethod def from_uri_content(content: Content) -> DurableAgentStateUriContent: - return DurableAgentStateUriContent(uri=content.uri, media_type=content.media_type) + return DurableAgentStateUriContent(uri=content.uri, media_type=content.media_type) # type: ignore[arg-type] def to_ai_content(self) -> Content: return Content.from_uri(uri=self.uri, media_type=self.media_type) @@ -1194,13 +1194,13 @@ def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateUsage: ) @staticmethod - def from_usage(usage: UsageDetails | None) -> DurableAgentStateUsage | None: + def from_usage(usage: UsageDetails | dict[str, int] | None) -> DurableAgentStateUsage | None: if usage is None: return None return DurableAgentStateUsage( - input_token_count=usage.input_token_count, - output_token_count=usage.output_token_count, - total_token_count=usage.total_token_count, + input_token_count=usage.get("input_token_count"), + output_token_count=usage.get("output_token_count"), + total_token_count=usage.get("total_token_count"), ) def to_usage_details(self) -> UsageDetails: @@ -1272,4 +1272,4 @@ def from_unknown_content(content: Any) -> DurableAgentStateUnknownContent: def to_ai_content(self) -> Content: if not self.content: raise Exception("The content is missing and cannot be converted to valid AI content.") - return Content(type=self.type, additional_properties={"content": self.content}) + 
return Content(type=self.type, additional_properties={"content": self.content}) # type: ignore diff --git a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py index 60a4b829f2..a6325a6603 100644 --- a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py +++ b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py @@ -324,7 +324,7 @@ async def _inner_get_streaming_response( response = await self._inner_get_response(messages=messages, options=options, **kwargs) contents = list(response.messages[0].contents if response.messages else []) if response.usage_details: - contents.append(Content.from_usage(details=response.usage_details)) + contents.append(Content.from_usage(usage_details=response.usage_details)) # type: ignore[arg-type] yield ChatResponseUpdate( response_id=response.response_id, contents=contents, @@ -629,7 +629,7 @@ def _parse_message_contents(self, content_blocks: Sequence[MutableMapping[str, A Content.from_function_result( call_id=tool_result.get("toolUseId") or self._generate_tool_call_id(), result=result_value, - exception=exception, + exception=str(exception) if exception else None, # type: ignore[arg-type] raw_representation=block, ) ) diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 1fdd7d10f7..412221af1f 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -1140,9 +1140,9 @@ async def _call_tool( # type: ignore # Convert result to MCP content if isinstance(result, str): - return [types.Content.from_text(type="text", text=result)] + return [types.TextContent(type="text", text=result)] # type: ignore[attr-defined] - return [types.Content.from_text(type="text", text=str(result))] + return [types.TextContent(type="text", text=str(result))] # type: ignore[attr-defined] @server.set_logging_level() # type: ignore async def 
_set_logging_level(level: types.LoggingLevel) -> None: # type: ignore diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index 908d7a0fd8..db912dd3a5 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -185,7 +185,7 @@ def _parse_content_from_mcp( result=_parse_content_from_mcp(mcp_type.content) if mcp_type.content else mcp_type.structuredContent, - exception=Exception() if mcp_type.isError else None, + exception=str(Exception()) if mcp_type.isError else None, # type: ignore[arg-type] raw_representation=mcp_type, ) ) diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 4b6362ae13..3ea7d33e72 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -1524,28 +1524,28 @@ async def _auto_invoke_function( tool: AIFunction[BaseModel, Any] | None = None if function_call_content.type == "function_call": - tool = tool_map.get(function_call_content.name) + tool = tool_map.get(function_call_content.name) # type: ignore[arg-type] # Tool should exist because _try_execute_function_calls validates this if tool is None: exc = KeyError(f'Function "{function_call_content.name}" not found.') return FunctionExecutionResult( content=Content.from_function_result( - call_id=function_call_content.call_id, + call_id=function_call_content.call_id, # type: ignore[arg-type] result=f'Error: Requested function "{function_call_content.name}" not found.', - exception=exc, + exception=str(exc), # type: ignore[arg-type] ) ) else: # Note: Unapproved tools (approved=False) are handled in _replace_approval_contents_with_results # and never reach this function, so we only handle approved=True cases here. 
inner_call = function_call_content.function_call # type: ignore[attr-defined] - if inner_call.type != "function_call": + if inner_call.type != "function_call": # type: ignore[union-attr] return function_call_content - tool = tool_map.get(inner_call.name) # type: ignore[attr-defined] + tool = tool_map.get(inner_call.name) # type: ignore[attr-defined, union-attr, arg-type] if tool is None: # we assume it is a hosted tool return function_call_content - function_call_content = inner_call + function_call_content = inner_call # type: ignore[assignment] parsed_args: dict[str, Any] = dict(function_call_content.parse_arguments() or {}) @@ -1562,7 +1562,11 @@ async def _auto_invoke_function( if config.include_detailed_errors: message = f"{message} Exception: {exc}" return FunctionExecutionResult( - content=Content.from_function_result(call_id=function_call_content.call_id, result=message, exception=exc) + content=Content.from_function_result( + call_id=function_call_content.call_id, # type: ignore[arg-type] + result=message, + exception=str(exc), # type: ignore[arg-type] + ) ) if not middleware_pipeline or ( @@ -1577,7 +1581,7 @@ async def _auto_invoke_function( ) return FunctionExecutionResult( content=Content.from_function_result( - call_id=function_call_content.call_id, + call_id=function_call_content.call_id, # type: ignore[arg-type] result=function_result, ) ) @@ -1587,7 +1591,9 @@ async def _auto_invoke_function( message = f"{message} Exception: {exc}" return FunctionExecutionResult( content=Content.from_function_result( - call_id=function_call_content.call_id, result=message, exception=exc + call_id=function_call_content.call_id, # type: ignore[arg-type] + result=message, + exception=str(exc), ) ) # Execute through middleware pipeline if available @@ -1615,7 +1621,7 @@ async def final_function_handler(context_obj: Any) -> Any: ) return FunctionExecutionResult( content=Content.from_function_result( - call_id=function_call_content.call_id, + 
call_id=function_call_content.call_id, # type: ignore[arg-type] result=function_result, ), terminate=middleware_context.terminate, @@ -1625,7 +1631,11 @@ async def final_function_handler(context_obj: Any) -> Any: if config.include_detailed_errors: message = f"{message} Exception: {exc}" return FunctionExecutionResult( - content=Content.from_function_result(call_id=function_call_content.call_id, result=message, exception=exc) + content=Content.from_function_result( + call_id=function_call_content.call_id, # type: ignore[arg-type] + result=message, + exception=str(exc), # type: ignore[arg-type] + ) ) @@ -1698,7 +1708,7 @@ async def _try_execute_function_calls( # approval can only be needed for Function Call Content, not Approval Responses. return ( [ - Content.from_function_approval_request(id=fcc.call_id, function_call=fcc) # type: ignore[attr-defined] + Content.from_function_approval_request(id=fcc.call_id, function_call=fcc) # type: ignore[attr-defined, arg-type] for fcc in function_calls if fcc.type == "function_call" ], @@ -1778,7 +1788,7 @@ def _collect_approval_responses( for content in msg.contents if isinstance(msg, ChatMessage) else []: # Collect BOTH approved and rejected responses if content.type == "function_approval_response": - fcc_todo[content.id] = content # type: ignore[attr-defined] + fcc_todo[content.id] = content # type: ignore[attr-defined, index] return fcc_todo @@ -1797,7 +1807,7 @@ def _replace_approval_contents_with_results( for msg in messages: # First pass - collect existing function call IDs to avoid duplicates existing_call_ids = { - content.call_id + content.call_id # type: ignore[union-attr, operator] for content in msg.contents if content.type == "function_call" and content.call_id # type: ignore[attr-defined] } @@ -1808,12 +1818,12 @@ def _replace_approval_contents_with_results( for content_idx, content in enumerate(msg.contents): if content.type == "function_approval_request": # Don't add the function call if it already exists 
(would create duplicate) - if content.function_call.call_id in existing_call_ids: # type: ignore[attr-defined] + if content.function_call.call_id in existing_call_ids: # type: ignore[attr-defined, union-attr, operator] # Just mark for removal - the function call already exists contents_to_remove.append(content_idx) else: # Put back the function call content only if it doesn't exist - msg.contents[content_idx] = content.function_call # type: ignore[attr-defined] + msg.contents[content_idx] = content.function_call # type: ignore[attr-defined, assignment] elif content.type == "function_approval_response": if content.approved and content.id in fcc_todo: # type: ignore[attr-defined] # Replace with the corresponding result @@ -1825,7 +1835,7 @@ def _replace_approval_contents_with_results( # Create a "not approved" result for rejected calls # Use function_call.call_id (the function's ID), not content.id (approval's ID) msg.contents[content_idx] = Content.from_function_result( - call_id=content.function_call.call_id, + call_id=content.function_call.call_id, # type: ignore[union-attr, arg-type] result="Error: Tool call invocation was rejected by user.", ) msg.role = Role.TOOL diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index 86f069edf9..3d837003ce 100644 --- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -131,9 +131,8 @@ def detect_media_type_from_base64(data_base64: str) -> str | None: # Returns: "image/png" """ # Remove data URI prefix if present - if data_base64.startswith("data:"): - if ";base64," in data_base64: - data_base64 = data_base64.split(";base64,", 1)[1] + if data_base64.startswith("data:") and ";base64," in data_base64: + data_base64 = data_base64.split(";base64,", 1)[1] try: # Decode just the first few bytes to check magic numbers @@ -236,7 +235,7 @@ def _get_data_bytes_as_str(content: "Content") -> str | None: raise ContentError("Data 
URI must use base64 encoding") _, data = uri.split(";base64,", 1) - return data + return data # type: ignore[return-value, no-any-return] def _get_data_bytes(content: "Content") -> bytes | None: @@ -462,7 +461,7 @@ def add_usage_details(usage1: UsageDetails | None, usage2: UsageDetails | None) # Sum if both present, otherwise use the non-None value if val1 is not None and val2 is not None: - result[key] = val1 + val2 # type: ignore[literal-required] + result[key] = val1 + val2 # type: ignore[literal-required, operator] elif val1 is not None: result[key] = val1 # type: ignore[literal-required] elif val2 is not None: @@ -532,7 +531,7 @@ def __init__( """ self.type = type self.annotations = annotations - self.additional_properties: dict[str, Any] = additional_properties or {} + self.additional_properties: dict[str, Any] = additional_properties or {} # type: ignore[assignment] self.raw_representation = raw_representation # Set all content-specific attributes @@ -1032,8 +1031,8 @@ def to_function_approval_response( ) return Content.from_function_approval_response( approved=approved, - id=self.id, # type: ignore[attr-defined] - function_call=self.function_call, # type: ignore[attr-defined] + id=self.id, # type: ignore[attr-defined, arg-type] + function_call=self.function_call, # type: ignore[attr-defined, arg-type] annotations=self.annotations, additional_properties=self.additional_properties, raw_representation=self.raw_representation, @@ -1177,11 +1176,11 @@ def _add_text_content(self, other: "Content") -> "Content": elif other.annotations is None: annotations = self.annotations else: - annotations = self.annotations + other.annotations + annotations = self.annotations + other.annotations # type: ignore[operator] return Content( "text", - text=self.text + other.text, # type: ignore[attr-defined] + text=self.text + other.text, # type: ignore[attr-defined, operator] annotations=annotations, additional_properties={ **(other.additional_properties or {}), @@ -1208,7 +1207,7 
@@ def _add_text_reasoning_content(self, other: "Content") -> "Content": elif other.annotations is None: annotations = self.annotations else: - annotations = self.annotations + other.annotations + annotations = self.annotations + other.annotations # type: ignore[operator] # Concatenate text, handling None values self_text = self.text or "" # type: ignore[attr-defined] @@ -1382,7 +1381,7 @@ def parse_arguments(self) -> dict[str, Any | None] | None: return {"raw": loaded} except (json.JSONDecodeError, TypeError): return {"raw": self.arguments} - return self.arguments + return self.arguments # type: ignore[return-value] # endregion @@ -1715,7 +1714,7 @@ def text(self) -> str: Remarks: This property concatenates the text of all TextContent objects in Content. """ - return " ".join(content.text for content in self.contents if content.type == "text") + return " ".join(content.text for content in self.contents if content.type == "text") # type: ignore[misc] def prepare_messages( @@ -2378,7 +2377,7 @@ def __init__( @property def text(self) -> str: """Returns the concatenated text of all contents in the update.""" - return "".join(content.text for content in self.contents if content.type == "text") + return "".join(content.text for content in self.contents if content.type == "text") # type: ignore[misc] def __str__(self) -> str: return self.text @@ -2700,7 +2699,7 @@ def __init__( @property def text(self) -> str: """Get the concatenated text of all TextContent objects in contents.""" - return "".join(content.text for content in self.contents if content.type == "text") if self.contents else "" + return "".join(content.text for content in self.contents if content.type == "text") if self.contents else "" # type: ignore[misc] @property def user_input_requests(self) -> list[Content]: diff --git a/python/packages/core/agent_framework/_workflows/_agent.py b/python/packages/core/agent_framework/_workflows/_agent.py index 341bc6239b..e2b1a6071c 100644 --- 
a/python/packages/core/agent_framework/_workflows/_agent.py +++ b/python/packages/core/agent_framework/_workflows/_agent.py @@ -382,7 +382,7 @@ def _extract_function_responses(self, input_messages: list[ChatMessage]) -> dict for content in message.contents: if content.type == "function_approval_response": # Parse the function arguments to recover request payload - arguments_payload = content.function_call.arguments # type: ignore[attr-defined] + arguments_payload = content.function_call.arguments # type: ignore[attr-defined, union-attr] if isinstance(arguments_payload, str): try: parsed_args = self.RequestInfoFunctionArgs.from_json(arguments_payload) @@ -428,7 +428,7 @@ def _extract_contents(self, data: Any) -> list[Content]: if isinstance(data, list): return [c for item in data for c in self._extract_contents(item)] if isinstance(data, Content): - return [cast(Content, data)] + return [data] # type: ignore[redundant-cast] if isinstance(data, str): return [Content.from_text(text=data)] return [Content.from_text(text=str(data))] @@ -522,7 +522,7 @@ def _add_raw(value: object) -> None: messages=(current.messages or []) + (incoming.messages or []), response_id=current.response_id or incoming.response_id, created_at=incoming.created_at or current.created_at, - usage_details=add_usage_details(current.usage_details, incoming.usage_details), + usage_details=add_usage_details(current.usage_details, incoming.usage_details), # type: ignore[arg-type] raw_representation=raw_list if raw_list else None, additional_properties=incoming.additional_properties or current.additional_properties, ) @@ -557,7 +557,7 @@ def _add_raw(value: object) -> None: if aggregated: final_messages.extend(aggregated.messages) if aggregated.usage_details: - merged_usage = add_usage_details(merged_usage, aggregated.usage_details) + merged_usage = add_usage_details(merged_usage, aggregated.usage_details) # type: ignore[arg-type] if aggregated.created_at and ( not latest_created_at or 
_parse_dt(aggregated.created_at) > _parse_dt(latest_created_at) ): @@ -581,7 +581,7 @@ def _add_raw(value: object) -> None: flattened = AgentResponse.from_agent_run_response_updates(global_dangling) final_messages.extend(flattened.messages) if flattened.usage_details: - merged_usage = add_usage_details(merged_usage, flattened.usage_details) + merged_usage = add_usage_details(merged_usage, flattened.usage_details) # type: ignore[arg-type] if flattened.created_at and ( not latest_created_at or _parse_dt(flattened.created_at) > _parse_dt(latest_created_at) ): diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 3d665bd56c..9beaf06a65 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -193,7 +193,7 @@ async def handle_user_input_response( ctx: The workflow context for emitting events and outputs. 
""" self._pending_responses_to_agent.append(response) - self._pending_agent_requests.pop(original_request.id, None) + self._pending_agent_requests.pop(original_request.id, None) # type: ignore[arg-type] if not self._pending_agent_requests: # All pending requests have been resolved; resume agent execution @@ -344,7 +344,7 @@ async def _run_agent(self, ctx: WorkflowContext) -> AgentResponse | None: # Handle any user input requests if response.user_input_requests: for user_input_request in response.user_input_requests: - self._pending_agent_requests[user_input_request.id] = user_input_request + self._pending_agent_requests[user_input_request.id] = user_input_request # type: ignore[index] await ctx.request_info(user_input_request, Content) return None @@ -387,7 +387,7 @@ async def _run_agent_streaming(self, ctx: WorkflowContext) -> AgentResponse | No # Handle any user input requests after the streaming completes if user_input_requests: for user_input_request in user_input_requests: - self._pending_agent_requests[user_input_request.id] = user_input_request + self._pending_agent_requests[user_input_request.id] = user_input_request # type: ignore[index] await ctx.request_info(user_input_request, Content) return None diff --git a/python/packages/core/agent_framework/openai/_assistants_client.py b/python/packages/core/agent_framework/openai/_assistants_client.py index 640d0b49d7..12ad1b5797 100644 --- a/python/packages/core/agent_framework/openai/_assistants_client.py +++ b/python/packages/core/agent_framework/openai/_assistants_client.py @@ -687,10 +687,10 @@ def _prepare_options( for content in chat_message.contents: if content.type == "text": - message_contents.append(TextContentBlockParam(type="text", text=content.text)) # type: ignore[attr-defined] + message_contents.append(TextContentBlockParam(type="text", text=content.text)) # type: ignore[attr-defined, typeddict-item] elif content.type == "uri" and content.has_top_level_media_type("image"): message_contents.append( 
- ImageURLContentBlockParam(type="image_url", image_url=ImageURLParam(url=content.uri)) # type: ignore[attr-defined] + ImageURLContentBlockParam(type="image_url", image_url=ImageURLParam(url=content.uri)) # type: ignore[attr-defined, typeddict-item] ) elif content.type == "function_result": if tool_results is None: @@ -728,7 +728,7 @@ def _prepare_tool_outputs_for_assistants( # When creating the FunctionCallContent, we created it with a CallId == [runId, callId]. # We need to extract the run ID and ensure that the ToolOutput we send back to Azure # is only the call ID. - run_and_call_ids: list[str] = json.loads(function_result_content.call_id) + run_and_call_ids: list[str] = json.loads(function_result_content.call_id) # type: ignore[arg-type] if ( not run_and_call_ids diff --git a/python/packages/core/agent_framework/openai/_chat_client.py b/python/packages/core/agent_framework/openai/_chat_client.py index 7a9d9f7fac..2b4023e85a 100644 --- a/python/packages/core/agent_framework/openai/_chat_client.py +++ b/python/packages/core/agent_framework/openai/_chat_client.py @@ -291,7 +291,7 @@ def _parse_response_from_openai(self, response: ChatCompletion, options: dict[st if parsed_tool_calls := [tool for tool in self._parse_tool_calls_from_openai(choice)]: contents.extend(parsed_tool_calls) if reasoning_details := getattr(choice.message, "reasoning_details", None): - contents.append(Content.from_text_reasoning(None, protected_data=json.dumps(reasoning_details))) + contents.append(Content.from_text_reasoning(protected_data=json.dumps(reasoning_details))) messages.append(ChatMessage(role="assistant", contents=contents)) return ChatResponse( response_id=response.id, @@ -314,7 +314,9 @@ def _parse_response_update_from_openai( return ChatResponseUpdate( role=Role.ASSISTANT, contents=[ - Content.from_usage(details=self._parse_usage_from_openai(chunk.usage), raw_representation=chunk) + Content.from_usage( + usage_details=self._parse_usage_from_openai(chunk.usage), 
raw_representation=chunk + ) ], model_id=chunk.model, additional_properties=chunk_metadata, @@ -332,7 +334,7 @@ def _parse_response_update_from_openai( if text_content := self._parse_text_from_openai(choice): contents.append(text_content) if reasoning_details := getattr(choice.delta, "reasoning_details", None): - contents.append(Content.from_text_reasoning(None, protected_data=json.dumps(reasoning_details))) + contents.append(Content.from_text_reasoning(protected_data=json.dumps(reasoning_details))) return ChatResponseUpdate( created_at=datetime.fromtimestamp(chunk.created, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ"), contents=contents, @@ -353,18 +355,18 @@ def _parse_usage_from_openai(self, usage: CompletionUsage) -> UsageDetails: ) if usage.completion_tokens_details: if tokens := usage.completion_tokens_details.accepted_prediction_tokens: - details["completion/accepted_prediction_tokens"] = tokens + details["completion/accepted_prediction_tokens"] = tokens # type: ignore[typeddict-unknown-key] if tokens := usage.completion_tokens_details.audio_tokens: - details["completion/audio_tokens"] = tokens + details["completion/audio_tokens"] = tokens # type: ignore[typeddict-unknown-key] if tokens := usage.completion_tokens_details.reasoning_tokens: - details["completion/reasoning_tokens"] = tokens + details["completion/reasoning_tokens"] = tokens # type: ignore[typeddict-unknown-key] if tokens := usage.completion_tokens_details.rejected_prediction_tokens: - details["completion/rejected_prediction_tokens"] = tokens + details["completion/rejected_prediction_tokens"] = tokens # type: ignore[typeddict-unknown-key] if usage.prompt_tokens_details: if tokens := usage.prompt_tokens_details.audio_tokens: - details["prompt/audio_tokens"] = tokens + details["prompt/audio_tokens"] = tokens # type: ignore[typeddict-unknown-key] if tokens := usage.prompt_tokens_details.cached_tokens: - details["prompt/cached_tokens"] = tokens + details["prompt/cached_tokens"] = tokens # type: 
ignore[typeddict-unknown-key] return details def _parse_text_from_openai(self, choice: Choice | ChunkChoice) -> Content | None: @@ -516,9 +518,9 @@ def _prepare_content_for_openai(self, content: Content) -> dict[str, Any]: # Extract base64 data from data URI audio_data = content.uri - if audio_data.startswith("data:"): + if audio_data.startswith("data:"): # type: ignore[union-attr] # Extract just the base64 part after "data:audio/format;base64," - audio_data = audio_data.split(",", 1)[-1] + audio_data = audio_data.split(",", 1)[-1] # type: ignore[union-attr] return { "type": "input_audio", @@ -527,7 +529,7 @@ def _prepare_content_for_openai(self, content: Content) -> dict[str, Any]: "format": audio_format, }, } - case "data" | "uri" if content.has_top_level_media_type("application") and content.uri.startswith("data:"): + case "data" | "uri" if content.has_top_level_media_type("application") and content.uri.startswith("data:"): # type: ignore[union-attr] # All application/* media types should be treated as files for OpenAI filename = getattr(content, "filename", None) or ( content.additional_properties.get("filename") diff --git a/python/packages/core/agent_framework/openai/_responses_client.py b/python/packages/core/agent_framework/openai/_responses_client.py index e5858499c4..9262bb3e8e 100644 --- a/python/packages/core/agent_framework/openai/_responses_client.py +++ b/python/packages/core/agent_framework/openai/_responses_client.py @@ -400,7 +400,7 @@ def _prepare_tools_for_openai( if not tool.inputs: raise ValueError("HostedFileSearchTool requires inputs to be specified.") inputs: list[str] = [ - inp.vector_store_id + inp.vector_store_id # type: ignore[misc] for inp in tool.inputs if inp.type == "hosted_vector_store" # type: ignore[attr-defined] ] @@ -618,7 +618,7 @@ def _prepare_messages_for_openai(self, chat_messages: Sequence[ChatMessage]) -> and content.additional_properties and "fc_id" in content.additional_properties ): - call_id_to_id[content.call_id] = 
content.additional_properties["fc_id"] # type: ignore[attr-defined] + call_id_to_id[content.call_id] = content.additional_properties["fc_id"] # type: ignore[attr-defined, index] list_of_list = [self._prepare_message_for_openai(message, call_id_to_id) for message in chat_messages] # Flatten the list of lists into a single list return list(chain.from_iterable(list_of_list)) @@ -757,11 +757,11 @@ def _prepare_content_for_openai( case "function_approval_request": return { "type": "mcp_approval_request", - "id": content.id, - "arguments": content.function_call.arguments, - "name": content.function_call.name, - "server_label": content.function_call.additional_properties.get("server_label") - if content.function_call.additional_properties + "id": content.id, # type: ignore[union-attr] + "arguments": content.function_call.arguments, # type: ignore[union-attr] + "name": content.function_call.name, # type: ignore[union-attr] + "server_label": content.function_call.additional_properties.get("server_label") # type: ignore[union-attr] + if content.function_call.additional_properties # type: ignore[union-attr] else None, } case "function_approval_response": @@ -1151,7 +1151,7 @@ def _parse_chunk_from_openai( if event.response.usage: usage = self._parse_usage_from_openai(event.response.usage) if usage: - contents.append(Content.from_usage(details=usage, raw_representation=event)) + contents.append(Content.from_usage(usage_details=usage, raw_representation=event)) case "response.output_item.added": event_item = event.item match event_item.type: @@ -1405,9 +1405,9 @@ def _parse_usage_from_openai(self, usage: ResponseUsage) -> UsageDetails | None: total_token_count=usage.total_tokens, ) if usage.input_tokens_details and usage.input_tokens_details.cached_tokens: - details["openai.cached_input_tokens"] = usage.input_tokens_details.cached_tokens + details["openai.cached_input_tokens"] = usage.input_tokens_details.cached_tokens # type: ignore[typeddict-unknown-key] if 
usage.output_tokens_details and usage.output_tokens_details.reasoning_tokens: - details["openai.reasoning_tokens"] = usage.output_tokens_details.reasoning_tokens + details["openai.reasoning_tokens"] = usage.output_tokens_details.reasoning_tokens # type: ignore[typeddict-unknown-key] return details def _get_metadata_from_response(self, output: Any) -> dict[str, Any]: diff --git a/python/packages/declarative/agent_framework_declarative/_loader.py b/python/packages/declarative/agent_framework_declarative/_loader.py index 70cc623dc4..6181be5f62 100644 --- a/python/packages/declarative/agent_framework_declarative/_loader.py +++ b/python/packages/declarative/agent_framework_declarative/_loader.py @@ -9,6 +9,7 @@ AIFunction, ChatAgent, ChatClientProtocol, + Content, HostedCodeInterpreterTool, HostedFileSearchTool, HostedMCPSpecificApproval, diff --git a/python/packages/devui/agent_framework_devui/_executor.py b/python/packages/devui/agent_framework_devui/_executor.py index 5234e37c10..cf4fa0066f 100644 --- a/python/packages/devui/agent_framework_devui/_executor.py +++ b/python/packages/devui/agent_framework_devui/_executor.py @@ -602,7 +602,7 @@ def _convert_input_to_chat_message(self, input_data: Any) -> Any: """ # Import Agent Framework types try: - from agent_framework import ChatMessage, Content, Role + from agent_framework import ChatMessage, Role except ImportError: # Fallback to string extraction if Agent Framework not available return self._extract_user_message_fallback(input_data) diff --git a/python/packages/devui/agent_framework_devui/_mapper.py b/python/packages/devui/agent_framework_devui/_mapper.py index acdb939134..182cf3526e 100644 --- a/python/packages/devui/agent_framework_devui/_mapper.py +++ b/python/packages/devui/agent_framework_devui/_mapper.py @@ -12,7 +12,7 @@ from typing import Any, Union from uuid import uuid4 -from agent_framework import ChatMessage +from agent_framework import ChatMessage, Content from openai.types.responses import ( Response, 
ResponseContentPartAddedEvent, @@ -722,7 +722,7 @@ async def _convert_agent_response(self, response: Any, context: dict[str, Any]) # Add usage information if present usage_details = getattr(response, "usage_details", None) if usage_details: - usage_content = Content.from_usage(details=usage_details) + usage_content = Content.from_usage(usage_details=usage_details) await self._map_usage_content(usage_content, context) # Note: _map_usage_content returns None - it accumulates usage for final Response.usage diff --git a/python/packages/lab/tau2/agent_framework_lab_tau2/_message_utils.py b/python/packages/lab/tau2/agent_framework_lab_tau2/_message_utils.py index 770c902244..2da621a21a 100644 --- a/python/packages/lab/tau2/agent_framework_lab_tau2/_message_utils.py +++ b/python/packages/lab/tau2/agent_framework_lab_tau2/_message_utils.py @@ -58,7 +58,7 @@ def log_messages(messages: list[ChatMessage]) -> None: for content in msg.contents: if hasattr(content, "type"): if content.type == "text": - escape_text = content.text.replace("<", r"\<") + escape_text = content.text.replace("<", r"\<") # type: ignore[union-attr] if msg.role == Role.SYSTEM: logger_.info(f"[SYSTEM] {escape_text}") elif msg.role == Role.USER: diff --git a/python/packages/lab/tau2/agent_framework_lab_tau2/_sliding_window.py b/python/packages/lab/tau2/agent_framework_lab_tau2/_sliding_window.py index 3734e2e55b..e34d9f48a4 100644 --- a/python/packages/lab/tau2/agent_framework_lab_tau2/_sliding_window.py +++ b/python/packages/lab/tau2/agent_framework_lab_tau2/_sliding_window.py @@ -77,7 +77,7 @@ def get_token_count(self) -> int: for content in msg.contents: if hasattr(content, "type"): if content.type == "text": - total_tokens += len(self.encoding.encode(content.text)) + total_tokens += len(self.encoding.encode(content.text)) # type: ignore[arg-type] elif content.type == "function_call": total_tokens += 4 # Serialize function call and count tokens diff --git 
a/python/packages/lab/tau2/agent_framework_lab_tau2/_tau2_utils.py b/python/packages/lab/tau2/agent_framework_lab_tau2/_tau2_utils.py index 228072aae1..6c64bb44be 100644 --- a/python/packages/lab/tau2/agent_framework_lab_tau2/_tau2_utils.py +++ b/python/packages/lab/tau2/agent_framework_lab_tau2/_tau2_utils.py @@ -60,7 +60,7 @@ def convert_agent_framework_messages_to_tau2_messages(messages: list[ChatMessage text_content = None text_contents = [c for c in msg.contents if hasattr(c, "text") and hasattr(c, "type") and c.type == "text"] if text_contents: - text_content = " ".join(c.text for c in text_contents) + text_content = " ".join(c.text for c in text_contents) # type: ignore[misc] # Extract function calls and convert to ToolCall objects function_calls = [c for c in msg.contents if hasattr(c, "type") and c.type == "function_call"] diff --git a/python/packages/ollama/agent_framework_ollama/_chat_client.py b/python/packages/ollama/agent_framework_ollama/_chat_client.py index 3f1b325bb0..ff2f5b6a37 100644 --- a/python/packages/ollama/agent_framework_ollama/_chat_client.py +++ b/python/packages/ollama/agent_framework_ollama/_chat_client.py @@ -452,16 +452,16 @@ def _format_user_message(self, message: ChatMessage) -> list[OllamaMessage]: "Ollama connector currently only supports user messages with TextContent or DataContent." 
) - if not any(isinstance(c, Content) and c.type == "data" for c in message.contents): + if not any(c.type == "data" for c in message.contents): return [OllamaMessage(role="user", content=message.text)] user_message = OllamaMessage(role="user", content=message.text) - data_contents = [c for c in message.contents if isinstance(c, Content) and c.type == "data"] + data_contents = [c for c in message.contents if c.type == "data"] if data_contents: if not any(c.has_top_level_media_type("image") for c in data_contents): raise ServiceInvalidRequestError("Only image data content is supported for user messages in Ollama.") # Ollama expects base64 strings without prefix - user_message["images"] = [c.uri.split(",")[1] for c in data_contents] + user_message["images"] = [c.uri.split(",")[1] for c in data_contents if c.uri] return [user_message] def _format_assistant_message(self, message: ChatMessage) -> list[OllamaMessage]: diff --git a/python/pyproject.toml b/python/pyproject.toml index 45cfdc4b55..97e90fad5e 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -146,6 +146,7 @@ ignore = [ "TD003", # allow missing link to todo issue "FIX002", # allow todo "B027", # allow empty non-abstract method in ABC + "B905", # `zip()` without an explicit `strict=` parameter "RUF067", # allow version detection in __init__.py ] From aa845ab6fe875f6b7c0d1659380af3bbc19c1fd0 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Sun, 18 Jan 2026 16:05:02 +0100 Subject: [PATCH 3/7] fixes --- .../tests/test_agent_wrapper_comprehensive.py | 4 +- .../ag-ui/tests/test_service_thread_id.py | 6 +-- .../agent_framework_azure_ai/_chat_client.py | 8 +++ .../agent_framework_azure_ai/_shared.py | 31 +++++------ .../azure-ai/tests/test_agent_provider.py | 4 +- .../azure-ai/tests/test_azure_ai_client.py | 6 +-- .../core/agent_framework/_workflows/_agent.py | 4 +- .../tests/workflow/test_workflow_agent.py | 52 +++++++++++-------- 8 files changed, 64 insertions(+), 51 deletions(-) diff --git 
a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py index 8b708b3ac7..f8f5c1db8a 100644 --- a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py @@ -650,7 +650,7 @@ async def stream_fn( thread = kwargs.get("thread") request_service_thread_id = thread.service_thread_id if thread else None yield ChatResponseUpdate( - contents=[TextContent(text="Response")], response_id="resp_67890", conversation_id="conv_12345" + contents=[Content.from_text(text="Response")], response_id="resp_67890", conversation_id="conv_12345" ) agent = ChatAgent(chat_client=StreamingChatClientStub(stream_fn)) @@ -677,7 +677,7 @@ async def stream_fn( thread = kwargs.get("thread") request_service_thread_id = thread.service_thread_id if thread else None yield ChatResponseUpdate( - contents=[TextContent(text="Response")], response_id="resp_67890", conversation_id="conv_12345" + contents=[Content.from_text(text="Response")], response_id="resp_67890", conversation_id="conv_12345" ) agent = ChatAgent(chat_client=StreamingChatClientStub(stream_fn)) diff --git a/python/packages/ag-ui/tests/test_service_thread_id.py b/python/packages/ag-ui/tests/test_service_thread_id.py index 8c00f7b67c..eab60abf7a 100644 --- a/python/packages/ag-ui/tests/test_service_thread_id.py +++ b/python/packages/ag-ui/tests/test_service_thread_id.py @@ -7,7 +7,7 @@ from typing import Any from ag_ui.core import RunFinishedEvent, RunStartedEvent -from agent_framework import TextContent +from agent_framework import Content from agent_framework._types import AgentResponseUpdate, ChatResponseUpdate sys.path.insert(0, str(Path(__file__).parent)) @@ -20,10 +20,10 @@ async def test_service_thread_id_when_there_are_updates(): updates: list[AgentResponseUpdate] = [ AgentResponseUpdate( - contents=[TextContent(text="Hello, user!")], + contents=[Content.from_text(text="Hello, 
user!")], response_id="resp_67890", raw_representation=ChatResponseUpdate( - contents=[TextContent(text="Hello, user!")], + contents=[Content.from_text(text="Hello, user!")], conversation_id="conv_12345", response_id="resp_67890", ), diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py index f7571e1e71..4bb646da19 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py @@ -2,6 +2,7 @@ import ast import json +import os import re import sys from collections.abc import AsyncIterable, Callable, Mapping, MutableMapping, MutableSequence, Sequence @@ -19,10 +20,12 @@ ChatResponse, ChatResponseUpdate, Content, + ContextProvider, HostedCodeInterpreterTool, HostedFileSearchTool, HostedMCPTool, HostedWebSearchTool, + Middleware, Role, TextSpanRegion, ToolProtocol, @@ -43,9 +46,14 @@ AgentStreamEvent, AsyncAgentEventHandler, AsyncAgentRunStream, + BingCustomSearchTool, + BingGroundingTool, + CodeInterpreterToolDefinition, + FileSearchTool, FunctionName, FunctionToolDefinition, ListSortOrder, + McpTool, MessageDeltaChunk, MessageDeltaTextContent, MessageDeltaTextFileCitationAnnotation, diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_shared.py b/python/packages/azure-ai/agent_framework_azure_ai/_shared.py index 61340abb72..020969cd12 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_shared.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_shared.py @@ -6,12 +6,10 @@ from agent_framework import ( AIFunction, - Contents, + Content, HostedCodeInterpreterTool, - HostedFileContent, HostedFileSearchTool, HostedMCPTool, - HostedVectorStoreContent, HostedWebSearchTool, ToolProtocol, get_logger, @@ -189,9 +187,9 @@ def to_azure_ai_agent_tools( ) tool_definitions.extend(mcp_tool.definitions) case HostedFileSearchTool(): - vector_stores = [inp for inp in 
tool.inputs or [] if isinstance(inp, HostedVectorStoreContent)] + vector_stores = [inp for inp in tool.inputs or [] if inp.type == "hosted_vector_store"] if vector_stores: - file_search = AgentsFileSearchTool(vector_store_ids=[vs.vector_store_id for vs in vector_stores]) + file_search = AgentsFileSearchTool(vector_store_ids=[vs.vector_store_id for vs in vector_stores]) # type: ignore[misc] tool_definitions.extend(file_search.definitions) # Set tool_resources for file search to work properly with Azure AI if run_options is not None and "tool_resources" not in run_options: @@ -247,7 +245,7 @@ def _convert_dict_tool(tool: dict[str, Any]) -> ToolProtocol | dict[str, Any] | if tool_type == "file_search": file_search_config = tool.get("file_search", {}) vector_store_ids = file_search_config.get("vector_store_ids", []) - inputs = [HostedVectorStoreContent(vector_store_id=vs_id) for vs_id in vector_store_ids] + inputs = [Content.from_hosted_vector_store(vector_store_id=vs_id) for vs_id in vector_store_ids] return HostedFileSearchTool(inputs=inputs if inputs else None) # type: ignore if tool_type == "bing_grounding": @@ -287,7 +285,7 @@ def _convert_sdk_tool(tool: ToolDefinition) -> ToolProtocol | dict[str, Any] | N if tool_type == "file_search": file_search_config = getattr(tool, "file_search", None) vector_store_ids = getattr(file_search_config, "vector_store_ids", []) if file_search_config else [] - inputs = [HostedVectorStoreContent(vector_store_id=vs_id) for vs_id in vector_store_ids] + inputs = [Content.from_hosted_vector_store(vector_store_id=vs_id) for vs_id in vector_store_ids] return HostedFileSearchTool(inputs=inputs if inputs else None) # type: ignore if tool_type == "bing_grounding": @@ -372,18 +370,18 @@ def from_azure_ai_tools(tools: Sequence[Tool | dict[str, Any]] | None) -> list[T elif tool_type == "code_interpreter": ci_tool = cast(CodeInterpreterTool, tool_dict) container = ci_tool.get("container", {}) - ci_inputs: list[Contents] = [] + ci_inputs: 
list[Content] = [] if "file_ids" in container: for file_id in container["file_ids"]: - ci_inputs.append(HostedFileContent(file_id=file_id)) + ci_inputs.append(Content.from_hosted_file(file_id=file_id)) agent_tools.append(HostedCodeInterpreterTool(inputs=ci_inputs if ci_inputs else None)) # type: ignore elif tool_type == "file_search": fs_tool = cast(ProjectsFileSearchTool, tool_dict) - fs_inputs: list[Contents] = [] + fs_inputs: list[Content] = [] if "vector_store_ids" in fs_tool: for vs_id in fs_tool["vector_store_ids"]: - fs_inputs.append(HostedVectorStoreContent(vector_store_id=vs_id)) + fs_inputs.append(Content.from_hosted_vector_store(vector_store_id=vs_id)) agent_tools.append( HostedFileSearchTool( @@ -433,8 +431,8 @@ def to_azure_ai_tools( file_ids: list[str] = [] if tool.inputs: for tool_input in tool.inputs: - if isinstance(tool_input, HostedFileContent): - file_ids.append(tool_input.file_id) + if tool_input.type == "hosted_file": + file_ids.append(tool_input.file_id) # type: ignore[misc, arg-type] container = CodeInterpreterToolAuto(file_ids=file_ids if file_ids else None) ci_tool: CodeInterpreterTool = CodeInterpreterTool(container=container) azure_tools.append(ci_tool) @@ -453,11 +451,14 @@ def to_azure_ai_tools( if not tool.inputs: raise ValueError("HostedFileSearchTool requires inputs to be specified.") vector_store_ids: list[str] = [ - inp.vector_store_id for inp in tool.inputs if isinstance(inp, HostedVectorStoreContent) + inp.vector_store_id # type: ignore[misc] + for inp in tool.inputs + if inp.type == "hosted_vector_store" ] if not vector_store_ids: raise ValueError( - "HostedFileSearchTool requires inputs to be of type `HostedVectorStoreContent`." + "HostedFileSearchTool requires inputs to be of type `Content` with " + "type 'hosted_vector_store'." 
) fs_tool: ProjectsFileSearchTool = ProjectsFileSearchTool(vector_store_ids=vector_store_ids) if tool.max_results: diff --git a/python/packages/azure-ai/tests/test_agent_provider.py b/python/packages/azure-ai/tests/test_agent_provider.py index 3df8d318ec..edfd749f4c 100644 --- a/python/packages/azure-ai/tests/test_agent_provider.py +++ b/python/packages/azure-ai/tests/test_agent_provider.py @@ -7,10 +7,10 @@ import pytest from agent_framework import ( ChatAgent, + Content, HostedCodeInterpreterTool, HostedFileSearchTool, HostedMCPTool, - HostedVectorStoreContent, HostedWebSearchTool, ai_function, ) @@ -509,7 +509,7 @@ def test_to_azure_ai_agent_tools_code_interpreter() -> None: def test_to_azure_ai_agent_tools_file_search() -> None: """Test converting HostedFileSearchTool with vector stores.""" - tool = HostedFileSearchTool(inputs=[HostedVectorStoreContent(vector_store_id="vs-123")]) + tool = HostedFileSearchTool(inputs=[Content.from_hosted_vector_store(vector_store_id="vs-123")]) run_options: dict[str, Any] = {} result = to_azure_ai_agent_tools([tool], run_options) diff --git a/python/packages/azure-ai/tests/test_azure_ai_client.py b/python/packages/azure-ai/tests/test_azure_ai_client.py index 7e3f7e2db4..aba45b3f1b 100644 --- a/python/packages/azure-ai/tests/test_azure_ai_client.py +++ b/python/packages/azure-ai/tests/test_azure_ai_client.py @@ -18,10 +18,8 @@ ChatResponse, Content, HostedCodeInterpreterTool, - HostedFileContent, HostedFileSearchTool, HostedMCPTool, - HostedVectorStoreContent, HostedWebSearchTool, Role, ) @@ -992,7 +990,7 @@ def test_from_azure_ai_tools() -> None: tool_input = parsed_tools[0].inputs[0] - assert tool_input and isinstance(tool_input, HostedFileContent) and tool_input.file_id == "file-1" + assert tool_input and tool_input.type == "hosted_file" and tool_input.file_id == "file-1" # Test File Search tool fs_tool = FileSearchTool(vector_store_ids=["vs-1"], max_num_results=5) @@ -1004,7 +1002,7 @@ def test_from_azure_ai_tools() -> None: 
tool_input = parsed_tools[0].inputs[0] - assert tool_input and isinstance(tool_input, HostedVectorStoreContent) and tool_input.vector_store_id == "vs-1" + assert tool_input and tool_input.type == "hosted_vector_store" and tool_input.vector_store_id == "vs-1" assert parsed_tools[0].max_results == 5 # Test Web Search tool diff --git a/python/packages/core/agent_framework/_workflows/_agent.py b/python/packages/core/agent_framework/_workflows/_agent.py index e2b1a6071c..345e120c1f 100644 --- a/python/packages/core/agent_framework/_workflows/_agent.py +++ b/python/packages/core/agent_framework/_workflows/_agent.py @@ -463,7 +463,7 @@ def merge_updates(updates: list[AgentResponseUpdate], response_id: str) -> Agent for u in updates: if u.response_id: for content in u.contents: - if isinstance(content, FunctionCallContent) and content.call_id: + if content.type == "function_call" and content.call_id: call_id_to_response_id[content.call_id] = u.response_id # Second pass: group updates, associating FunctionResultContent with their calls @@ -475,7 +475,7 @@ def merge_updates(updates: list[AgentResponseUpdate], response_id: str) -> Agent # If no response_id, check if this is a FunctionResultContent that matches a call if not effective_response_id: for content in u.contents: - if isinstance(content, FunctionResultContent) and content.call_id: + if content.type == "function_result" and content.call_id: effective_response_id = call_id_to_response_id.get(content.call_id) if effective_response_id: break diff --git a/python/packages/core/tests/workflow/test_workflow_agent.py b/python/packages/core/tests/workflow/test_workflow_agent.py index 58e5c7e20f..9514efdf74 100644 --- a/python/packages/core/tests/workflow/test_workflow_agent.py +++ b/python/packages/core/tests/workflow/test_workflow_agent.py @@ -974,7 +974,7 @@ def test_merge_updates_function_result_ordering_github_2977(self): updates = [ # User question AgentResponseUpdate( - contents=[TextContent(text="What is the 
weather?")], + contents=[Content.from_text(text="What is the weather?")], role=Role.USER, response_id="resp-1", message_id="msg-1", @@ -982,7 +982,9 @@ def test_merge_updates_function_result_ordering_github_2977(self): ), # Assistant with function call AgentResponseUpdate( - contents=[FunctionCallContent(call_id=call_id, name="get_weather", arguments='{"location": "NYC"}')], + contents=[ + Content.from_function_call(call_id=call_id, name="get_weather", arguments='{"location": "NYC"}') + ], role=Role.ASSISTANT, response_id="resp-1", message_id="msg-2", @@ -991,7 +993,7 @@ def test_merge_updates_function_result_ordering_github_2977(self): # Function result: no response_id previously caused this to go to global_dangling # and be placed at the end (the bug); fix now correctly associates via call_id AgentResponseUpdate( - contents=[FunctionResultContent(call_id=call_id, result="Sunny, 72F")], + contents=[Content.from_function_result(call_id=call_id, result="Sunny, 72F")], role=Role.TOOL, response_id=None, message_id="msg-3", @@ -999,7 +1001,7 @@ def test_merge_updates_function_result_ordering_github_2977(self): ), # Final assistant answer AgentResponseUpdate( - contents=[TextContent(text="The weather in NYC is sunny and 72F.")], + contents=[Content.from_text(text="The weather in NYC is sunny and 72F.")], role=Role.ASSISTANT, response_id="resp-1", message_id="msg-4", @@ -1015,11 +1017,11 @@ def test_merge_updates_function_result_ordering_github_2977(self): content_sequence = [] for msg in result.messages: for content in msg.contents: - if isinstance(content, TextContent): + if content.type == "text": content_sequence.append(("text", msg.role)) - elif isinstance(content, FunctionCallContent): + elif content.type == "function_call": content_sequence.append(("function_call", msg.role)) - elif isinstance(content, FunctionResultContent): + elif content.type == "function_result": content_sequence.append(("function_result", msg.role)) # Verify correct ordering: user -> 
function_call -> function_result -> assistant_answer @@ -1040,10 +1042,10 @@ def test_merge_updates_function_result_ordering_github_2977(self): function_result_idx = None for i, msg in enumerate(result.messages): for content in msg.contents: - if isinstance(content, FunctionCallContent): + if content.type == "function_call": function_call_idx = i assert content.call_id == call_id - elif isinstance(content, FunctionResultContent): + elif content.type == "function_result": function_result_idx = i assert content.call_id == call_id @@ -1070,7 +1072,7 @@ def test_merge_updates_multiple_function_results_ordering_github_2977(self): updates = [ # User question AgentResponseUpdate( - contents=[TextContent(text="What's the weather and time?")], + contents=[Content.from_text(text="What's the weather and time?")], role=Role.USER, response_id="resp-1", message_id="msg-1", @@ -1078,7 +1080,9 @@ def test_merge_updates_multiple_function_results_ordering_github_2977(self): ), # Assistant with first function call AgentResponseUpdate( - contents=[FunctionCallContent(call_id=call_id_1, name="get_weather", arguments='{"location": "NYC"}')], + contents=[ + Content.from_function_call(call_id=call_id_1, name="get_weather", arguments='{"location": "NYC"}') + ], role=Role.ASSISTANT, response_id="resp-1", message_id="msg-2", @@ -1086,7 +1090,9 @@ def test_merge_updates_multiple_function_results_ordering_github_2977(self): ), # Assistant with second function call AgentResponseUpdate( - contents=[FunctionCallContent(call_id=call_id_2, name="get_time", arguments='{"timezone": "EST"}')], + contents=[ + Content.from_function_call(call_id=call_id_2, name="get_time", arguments='{"timezone": "EST"}') + ], role=Role.ASSISTANT, response_id="resp-1", message_id="msg-3", @@ -1094,7 +1100,7 @@ def test_merge_updates_multiple_function_results_ordering_github_2977(self): ), # Second function result arrives first (no response_id) AgentResponseUpdate( - contents=[FunctionResultContent(call_id=call_id_2, 
result="3:00 PM EST")], + contents=[Content.from_function_result(call_id=call_id_2, result="3:00 PM EST")], role=Role.TOOL, response_id=None, message_id="msg-4", @@ -1102,7 +1108,7 @@ def test_merge_updates_multiple_function_results_ordering_github_2977(self): ), # First function result arrives second (no response_id) AgentResponseUpdate( - contents=[FunctionResultContent(call_id=call_id_1, result="Sunny, 72F")], + contents=[Content.from_function_result(call_id=call_id_1, result="Sunny, 72F")], role=Role.TOOL, response_id=None, message_id="msg-5", @@ -1110,7 +1116,7 @@ def test_merge_updates_multiple_function_results_ordering_github_2977(self): ), # Final assistant answer AgentResponseUpdate( - contents=[TextContent(text="It's sunny (72F) and 3 PM in NYC.")], + contents=[Content.from_text(text="It's sunny (72F) and 3 PM in NYC.")], role=Role.ASSISTANT, response_id="resp-1", message_id="msg-6", @@ -1126,11 +1132,11 @@ def test_merge_updates_multiple_function_results_ordering_github_2977(self): content_sequence = [] for msg in result.messages: for content in msg.contents: - if isinstance(content, TextContent): + if content.type == "text": content_sequence.append(("text", None)) - elif isinstance(content, FunctionCallContent): + elif content.type == "function_call": content_sequence.append(("function_call", content.call_id)) - elif isinstance(content, FunctionResultContent): + elif content.type == "function_result": content_sequence.append(("function_result", content.call_id)) # Verify all function results appear before the final assistant text @@ -1161,7 +1167,7 @@ def test_merge_updates_function_result_no_matching_call(self): """ updates = [ AgentResponseUpdate( - contents=[TextContent(text="Hello")], + contents=[Content.from_text(text="Hello")], role=Role.USER, response_id="resp-1", message_id="msg-1", @@ -1169,14 +1175,14 @@ def test_merge_updates_function_result_no_matching_call(self): ), # Function result with no matching call AgentResponseUpdate( - 
contents=[FunctionResultContent(call_id="orphan_call_id", result="orphan result")], + contents=[Content.from_function_result(call_id="orphan_call_id", result="orphan result")], role=Role.TOOL, response_id=None, message_id="msg-2", created_at="2024-01-01T12:00:01Z", ), AgentResponseUpdate( - contents=[TextContent(text="Goodbye")], + contents=[Content.from_text(text="Goodbye")], role=Role.ASSISTANT, response_id="resp-1", message_id="msg-3", @@ -1192,9 +1198,9 @@ def test_merge_updates_function_result_no_matching_call(self): content_types = [] for msg in result.messages: for content in msg.contents: - if isinstance(content, TextContent): + if content.type == "text": content_types.append("text") - elif isinstance(content, FunctionResultContent): + elif content.type == "function_result": content_types.append("function_result") # Order: text (user), text (assistant), function_result (orphan at end) From 9d35bce88fda9a80eb18c8ba481bcedd0f5eed93 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Mon, 19 Jan 2026 11:05:30 +0100 Subject: [PATCH 4/7] fixed data format handling --- .../packages/core/agent_framework/_types.py | 112 ++++++++---------- .../openai/_responses_client.py | 39 ++---- python/packages/core/tests/core/test_types.py | 104 ++++++++-------- .../devui/agent_framework_devui/_mapper.py | 8 +- 4 files changed, 125 insertions(+), 138 deletions(-) diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index 3d837003ce..9fdf9bcbf6 100644 --- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -101,7 +101,12 @@ def _parse_content_list(contents_data: Sequence[Any]) -> list["Content"]: # region Internal Helper functions for unified Content -def detect_media_type_from_base64(data_base64: str) -> str | None: +def detect_media_type_from_base64( + *, + data_bytes: bytes | None = None, + data_str: str | None = None, + data_uri: str | None = None, +) -> str | None: 
"""Detect media type from base64-encoded data by examining magic bytes. This function examines the binary signature (magic bytes) at the start of the data @@ -109,7 +114,14 @@ def detect_media_type_from_base64(data_base64: str) -> str | None: video, and documents, but cannot detect text-based formats like JSON or plain text. Args: - data_base64: Base64-encoded data (with or without data URI prefix). + data_bytes: Raw binary data. + data_str: Base64-encoded data (without data URI prefix). + data_uri: Full data URI string (e.g., "data:image/png;base64,iVBORw0KGgo..."). + This will look at the actual data to determine the media_type and not at the URI prefix. + Will also not compare those two values. + + Raises: + ValueError: If not exactly 1 of data_bytes, data_str, or data_uri is provided, or if base64 decoding fails. Returns: The detected media type (e.g., 'image/png', 'audio/wav', 'application/pdf') @@ -130,85 +142,57 @@ def detect_media_type_from_base64(data_base64: str) -> str | None: media_type = detect_media_type_from_base64(data_uri) # Returns: "image/png" """ - # Remove data URI prefix if present - if data_base64.startswith("data:") and ";base64," in data_base64: - data_base64 = data_base64.split(";base64,", 1)[1] - - try: - # Decode just the first few bytes to check magic numbers - decoded = base64.b64decode(data_base64[:50]) - except Exception: - return None + data: bytes | None = None + if data_bytes is not None: + data = data_bytes + if data_uri is not None: + if data is not None: + raise ValueError("Provide exactly one of data_bytes, data_str, or data_uri.") + # Remove data URI prefix if present + data_str = data_uri.split(";base64,", 1)[1] + if data_str is not None: + if data is not None: + raise ValueError("Provide exactly one of data_bytes, data_str, or data_uri.") + try: + # Decode just the first few bytes to check magic numbers + data = base64.b64decode(data_str[:50]) + except Exception as exc: + raise ValueError("Invalid base64 data provided.") 
from exc + if data is None: + raise ValueError("Provide exactly one of data_bytes, data_str, or data_uri.") # Check magic bytes for common formats # Images - if decoded.startswith(b"\x89PNG\r\n\x1a\n"): + if data.startswith(b"\x89PNG\r\n\x1a\n"): return "image/png" - if decoded.startswith(b"\xff\xd8\xff"): + if data.startswith(b"\xff\xd8\xff"): return "image/jpeg" - if decoded.startswith(b"GIF87a") or decoded.startswith(b"GIF89a"): + if data.startswith(b"GIF87a") or data.startswith(b"GIF89a"): return "image/gif" - if decoded.startswith(b"RIFF") and len(decoded) > 11 and decoded[8:12] == b"WEBP": + if data.startswith(b"RIFF") and len(data) > 11 and data[8:12] == b"WEBP": return "image/webp" - if decoded.startswith(b"BM"): + if data.startswith(b"BM"): return "image/bmp" - if decoded.startswith(b" 11 and decoded[8:12] == b"WAVE": + if data.startswith(b"RIFF") and len(data) > 11 and data[8:12] == b"WAVE": return "audio/wav" - if decoded.startswith(b"ID3") or decoded.startswith(b"\xff\xfb") or decoded.startswith(b"\xff\xf3"): + if data.startswith(b"ID3") or data.startswith(b"\xff\xfb") or data.startswith(b"\xff\xf3"): return "audio/mpeg" - if decoded.startswith(b"OggS"): + if data.startswith(b"OggS"): return "audio/ogg" - if decoded.startswith(b"fLaC"): + if data.startswith(b"fLaC"): return "audio/flac" return None -def _create_data_uri_from_base64(image_base64: str, media_type: str | None = None) -> tuple[str, str]: - """Create a data URI and media type from base64 data. - - Args: - image_base64: Base64-encoded image data (with or without data URI prefix). - media_type: Optional explicit media type. If not provided, will attempt to detect. - - Returns: - A tuple of (data_uri, media_type). - - Raises: - ContentError: If media type cannot be determined. 
- """ - # If it's already a data URI, extract the parts - if image_base64.startswith("data:"): - if ";base64," in image_base64: - prefix, data = image_base64.split(";base64,", 1) - existing_media_type = prefix.split(":", 1)[1] if ":" in prefix else None - if media_type is None: - media_type = existing_media_type - image_base64 = data - else: - raise ContentError("Data URI must use base64 encoding") - - # Detect format if media type not provided - if media_type is None: - detected_media_type = detect_media_type_from_base64(image_base64) - if detected_media_type: - media_type = detected_media_type - else: - raise ContentError("Could not detect media type from base64 data") - - # Construct data URI - data_uri = f"data:{media_type};base64,{image_base64}" - return data_uri, media_type - - def _get_data_bytes_as_str(content: "Content") -> str | None: """Extract base64 data string from data URI. @@ -623,7 +607,7 @@ def from_data( .. code-block:: python - from agent_framework import detect_media_type_from_base64 + from agent_framework import detect_media_type_from_base64, Content media_type = detect_media_type_from_base64(base64_string) if media_type is None: @@ -719,6 +703,14 @@ def from_uri( content = Content.from_uri(uri="https://example.com/image.png", media_type="image/png") assert content.type == "uri" + # When receiving a raw already encode data string, you can do this: + raw_base64_string = "iVBORw0KGgo..." + content = Content.from_uri( + uri=f"data:{(detect_media_type_from_base64(data_str=raw_base64_string) or 'image/png')};base64,{ + raw_base64_string + }" + ) + Returns: A Content instance with type="data" for data URIs or type="uri" for external URIs. 
""" diff --git a/python/packages/core/agent_framework/openai/_responses_client.py b/python/packages/core/agent_framework/openai/_responses_client.py index 9262bb3e8e..d574849351 100644 --- a/python/packages/core/agent_framework/openai/_responses_client.py +++ b/python/packages/core/agent_framework/openai/_responses_client.py @@ -1,6 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. -import base64 import sys from collections.abc import ( AsyncIterable, @@ -58,6 +57,7 @@ Role, TextSpanRegion, UsageDetails, + detect_media_type_from_base64, prepare_function_call_results, prepend_instructions_to_messages, validate_tool_mode, @@ -991,19 +991,13 @@ def _parse_response_from_openai( ) case "image_generation_call": # ResponseOutputImageGenerationCall image_output: Content | None = None - if item.result: - base64_data = item.result # OpenAI returns base64-encoded string - # Detect media type from base64 data - from agent_framework._types import detect_media_type_from_base64 - - media_type = detect_media_type_from_base64(base64_data) - if media_type is None: - media_type = "image/png" # Default fallback - # Convert base64 string to bytes for Content.from_data - data_bytes = base64.b64decode(base64_data) - image_output = Content.from_data( - data=data_bytes, - media_type=media_type, + if item.result is not None: + # item.result contains raw base64 string + # so we call detect_media_type_from_base64 to get the media type and fallback to image/png + image_output = Content.from_uri( + uri=f"data:{(detect_media_type_from_base64(data_str=item.result) or 'image/png')};base64,{ + item.result + }", raw_representation=item.result, ) image_id = item.id @@ -1297,19 +1291,10 @@ def _parse_chunk_from_openai( # Handle streaming partial image generation image_base64 = event.partial_image_b64 partial_index = event.partial_image_index - - # Detect media type from base64 data - from agent_framework._types import detect_media_type_from_base64 - - media_type = 
detect_media_type_from_base64(image_base64) - if media_type is None: - media_type = "image/png" # Default fallback - - # Decode base64 and use Content.from_data - data_bytes = base64.b64decode(image_base64) - image_output = Content.from_data( - data=data_bytes, - media_type=media_type, + image_output = Content.from_uri( + uri=f"data:{(detect_media_type_from_base64(data_str=image_base64) or 'image/png')};base64,{ + image_base64 + }", additional_properties={ "partial_image_index": partial_index, "is_partial_image": True, diff --git a/python/packages/core/tests/core/test_types.py b/python/packages/core/tests/core/test_types.py index 15f67fb44f..3e5317fdae 100644 --- a/python/packages/core/tests/core/test_types.py +++ b/python/packages/core/tests/core/test_types.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. +import base64 from collections.abc import AsyncIterable from datetime import datetime, timezone from typing import Any @@ -24,6 +25,7 @@ ToolProtocol, UsageDetails, ai_function, + detect_media_type_from_base64, merge_chat_options, prepare_function_call_results, ) @@ -154,55 +156,63 @@ def test_data_content_empty(): assert data.media_type == "application/octet-stream" -# def test_data_content_detect_image_format_from_base64(): -# """Test the detect_image_format_from_base64 static method.""" -# # Test each supported format -# png_data = b"\x89PNG\r\n\x1a\n" + b"fake_data" -# assert detect_image_format_from_base64(base64.b64encode(png_data).decode()) == "png" - -# jpeg_data = b"\xff\xd8\xff\xe0" + b"fake_data" -# assert DataContent.detect_image_format_from_base64(base64.b64encode(jpeg_data).decode()) == "jpeg" - -# webp_data = b"RIFF" + b"1234" + b"WEBP" + b"fake_data" -# assert DataContent.detect_image_format_from_base64(base64.b64encode(webp_data).decode()) == "webp" - -# gif_data = b"GIF89a" + b"fake_data" -# assert DataContent.detect_image_format_from_base64(base64.b64encode(gif_data).decode()) == "gif" - -# # Test fallback behavior -# 
unknown_data = b"UNKNOWN_FORMAT" -# assert DataContent.detect_image_format_from_base64(base64.b64encode(unknown_data).decode()) == "png" - -# # Test error handling -# assert DataContent.detect_image_format_from_base64("invalid_base64!") == "png" -# assert DataContent.detect_image_format_from_base64("") == "png" - - -# def test_data_content_create_data_uri_from_base64(): -# """Test the create_data_uri_from_base64 class method.""" -# # Test with PNG data -# png_data = b"\x89PNG\r\n\x1a\n" + b"fake_data" -# png_base64 = base64.b64encode(png_data).decode() -# uri, media_type = Content.create_data_uri_from_base64(png_base64) - -# assert uri == f"data:image/png;base64,{png_base64}" -# assert media_type == "image/png" - -# # Test with different format -# jpeg_data = b"\xff\xd8\xff\xe0" + b"fake_data" -# jpeg_base64 = base64.b64encode(jpeg_data).decode() -# uri, media_type = DataContent.create_data_uri_from_base64(jpeg_base64) - -# assert uri == f"data:image/jpeg;base64,{jpeg_base64}" -# assert media_type == "image/jpeg" +def test_data_content_detect_image_format_from_base64(): + """Test the detect_image_format_from_base64 static method.""" + # Test each supported format + png_data = b"\x89PNG\r\n\x1a\n" + b"fake_data" + assert detect_media_type_from_base64(data_bytes=png_data) == "image/png" + assert detect_media_type_from_base64(data_str=base64.b64encode(png_data).decode()) == "image/png" + + jpeg_data = b"\xff\xd8\xff\xe0" + b"fake_data" + assert detect_media_type_from_base64(data_bytes=jpeg_data) == "image/jpeg" + assert detect_media_type_from_base64(data_str=base64.b64encode(jpeg_data).decode()) == "image/jpeg" + + webp_data = b"RIFF" + b"1234" + b"WEBP" + b"fake_data" + assert detect_media_type_from_base64(data_str=base64.b64encode(webp_data).decode()) == "image/webp" + gif_data = b"GIF89a" + b"fake_data" + assert detect_media_type_from_base64(data_str=base64.b64encode(gif_data).decode()) == "image/gif" + + # Test fallback behavior + unknown_data = b"UNKNOWN_FORMAT" 
+ assert detect_media_type_from_base64(data_str=base64.b64encode(unknown_data).decode()) is None + assert ( + detect_media_type_from_base64( + data_uri=f"data:application/octet-stream;base64,{base64.b64encode(unknown_data).decode()}" + ) + is None + ) + assert detect_media_type_from_base64(data_bytes=unknown_data) is None + # Test error handling + with pytest.raises(ValueError, match="Invalid base64 data provided."): + detect_media_type_from_base64(data_str="invalid_base64!") + detect_media_type_from_base64(data_str="") + + with pytest.raises(ValueError, match="Provide exactly one of data_bytes, data_str, or data_uri."): + detect_media_type_from_base64() + detect_media_type_from_base64( + data_bytes=b"data", data_str="data", data_uri="data:application/octet-stream;base64,AAA" + ) + detect_media_type_from_base64(data_bytes=b"data", data_str="data") + detect_media_type_from_base64(data_bytes=b"data", data_uri="data:application/octet-stream;base64,AAA") + detect_media_type_from_base64(data_str="data", data_uri="data:application/octet-stream;base64,AAA") + + +def test_data_content_create_data_uri_from_base64(): + """Test the create_data_uri_from_base64 class method.""" + # Test with PNG data + png_data = b"\x89PNG\r\n\x1a\n" + b"fake_data" + content = Content.from_data(png_data, media_type=detect_media_type_from_base64(data_bytes=png_data)) + + assert content.uri == f"data:image/png;base64,{base64.b64encode(png_data).decode()}" + assert content.media_type == "image/png" -# # Test fallback for unknown format -# unknown_data = b"UNKNOWN_FORMAT" -# unknown_base64 = base64.b64encode(unknown_data).decode() -# uri, media_type = DataContent.create_data_uri_from_base64(unknown_base64) + # Test with different format + jpeg_data = b"\xff\xd8\xff\xe0" + b"fake_data" + jpeg_base64 = base64.b64encode(jpeg_data).decode() + content = Content.from_data(jpeg_data, media_type=detect_media_type_from_base64(data_bytes=jpeg_data)) -# assert uri == f"data:image/png;base64,{unknown_base64}" 
-# assert media_type == "image/png" + assert content.uri == f"data:image/jpeg;base64,{jpeg_base64}" + assert content.media_type == "image/jpeg" # region UriContent diff --git a/python/packages/devui/agent_framework_devui/_mapper.py b/python/packages/devui/agent_framework_devui/_mapper.py index 182cf3526e..f11a6811ce 100644 --- a/python/packages/devui/agent_framework_devui/_mapper.py +++ b/python/packages/devui/agent_framework_devui/_mapper.py @@ -1416,10 +1416,10 @@ async def _map_usage_content(self, content: Any, context: dict[str, Any]) -> Non None - no event emitted (usage goes in final Response.usage) """ # Extract usage from UsageContent.usage_details (UsageDetails object) - details = getattr(content, "usage_details", None) - total_tokens = getattr(details, "total_token_count", 0) or 0 - prompt_tokens = getattr(details, "input_token_count", 0) or 0 - completion_tokens = getattr(details, "output_token_count", 0) or 0 + details = content.usage_details or {} + total_tokens = details.get("total_token_count", 0) + prompt_tokens = details.get("input_token_count", 0) + completion_tokens = details.get("output_token_count", 0) # Accumulate for final Response.usage request_id = context.get("request_id", "default") From 085332529eff44937b99756b4ef8f6f9cd2f7457 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Mon, 19 Jan 2026 13:41:59 +0100 Subject: [PATCH 5/7] fix for 3.10 mypy --- .../core/agent_framework/openai/_responses_client.py | 10 ++++------ python/packages/lab/lightning/tests/test_lightning.py | 6 +++--- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/python/packages/core/agent_framework/openai/_responses_client.py b/python/packages/core/agent_framework/openai/_responses_client.py index d574849351..778c7c7c54 100644 --- a/python/packages/core/agent_framework/openai/_responses_client.py +++ b/python/packages/core/agent_framework/openai/_responses_client.py @@ -995,9 +995,8 @@ def _parse_response_from_openai( # item.result contains raw base64 
string # so we call detect_media_type_from_base64 to get the media type and fallback to image/png image_output = Content.from_uri( - uri=f"data:{(detect_media_type_from_base64(data_str=item.result) or 'image/png')};base64,{ - item.result - }", + uri=f"data:{detect_media_type_from_base64(data_str=item.result) or 'image/png'}" + f";base64,{item.result}", raw_representation=item.result, ) image_id = item.id @@ -1292,9 +1291,8 @@ def _parse_chunk_from_openai( image_base64 = event.partial_image_b64 partial_index = event.partial_image_index image_output = Content.from_uri( - uri=f"data:{(detect_media_type_from_base64(data_str=image_base64) or 'image/png')};base64,{ - image_base64 - }", + uri=f"data:{detect_media_type_from_base64(data_str=image_base64) or 'image/png'}" + f";base64,{image_base64}", additional_properties={ "partial_image_index": partial_index, "is_partial_image": True, diff --git a/python/packages/lab/lightning/tests/test_lightning.py b/python/packages/lab/lightning/tests/test_lightning.py index 413087877e..c56adf2b20 100644 --- a/python/packages/lab/lightning/tests/test_lightning.py +++ b/python/packages/lab/lightning/tests/test_lightning.py @@ -9,7 +9,7 @@ agentlightning = pytest.importorskip("agentlightning") -from agent_framework import AgentExecutor, AgentRunEvent, ChatAgent, WorkflowBuilder +from agent_framework import AgentExecutor, AgentRunEvent, ChatAgent, WorkflowBuilder, Workflow from agent_framework_lab_lightning import AgentFrameworkTracer from agent_framework.openai import OpenAIChatClient from agentlightning import TracerTraceToTriplet @@ -106,7 +106,7 @@ def workflow_two_agents(): yield workflow -async def test_openai_workflow_two_agents(workflow_two_agents): +async def test_openai_workflow_two_agents(workflow_two_agents: Workflow): events = await workflow_two_agents.run("Please analyze the quarterly sales data") # Get all AgentRunEvent data @@ -121,7 +121,7 @@ async def test_openai_workflow_two_agents(workflow_two_agents): ) -async def 
test_observability(workflow_two_agents): +async def test_observability(workflow_two_agents: Workflow): r"""Expected trace tree: [workflow.run] From fb03609e378ab1c2ac4299143c435f73b936a341 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Mon, 19 Jan 2026 14:28:04 +0100 Subject: [PATCH 6/7] fix --- python/packages/core/tests/azure/test_azure_chat_client.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/packages/core/tests/azure/test_azure_chat_client.py b/python/packages/core/tests/azure/test_azure_chat_client.py index d91b58c646..84d7d897ff 100644 --- a/python/packages/core/tests/azure/test_azure_chat_client.py +++ b/python/packages/core/tests/azure/test_azure_chat_client.py @@ -627,9 +627,7 @@ async def test_streaming_with_none_delta( results.append(msg) assert len(results) > 0 - assert any( - isinstance(content, TextContent) and content.text == "test" for msg in results for content in msg.contents - ) + assert any(content.type == "text" and content.text == "test" for msg in results for content in msg.contents) assert any(msg.contents for msg in results) From a6e6944754aa3c54b1a08a7967ed8f06d594e4ac Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Tue, 20 Jan 2026 23:01:12 +0100 Subject: [PATCH 7/7] fix int test --- python/packages/core/agent_framework/_types.py | 3 +-- .../packages/core/agent_framework/openai/_responses_client.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index 9fdf9bcbf6..d586f9ff5d 100644 --- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -154,8 +154,7 @@ def detect_media_type_from_base64( if data is not None: raise ValueError("Provide exactly one of data_bytes, data_str, or data_uri.") try: - # Decode just the first few bytes to check magic numbers - data = base64.b64decode(data_str[:50]) + data = base64.b64decode(data_str) except 
Exception as exc: raise ValueError("Invalid base64 data provided.") from exc if data is None: diff --git a/python/packages/core/agent_framework/openai/_responses_client.py b/python/packages/core/agent_framework/openai/_responses_client.py index 778c7c7c54..464fbdbb28 100644 --- a/python/packages/core/agent_framework/openai/_responses_client.py +++ b/python/packages/core/agent_framework/openai/_responses_client.py @@ -214,7 +214,6 @@ async def _inner_get_response( response = await client.responses.parse(stream=False, **run_options) else: response = await client.responses.create(stream=False, **run_options) - return self._parse_response_from_openai(response, options=options) except BadRequestError as ex: if ex.code == "content_filter": raise OpenAIContentFilterException( @@ -230,6 +229,7 @@ async def _inner_get_response( f"{type(self)} service failed to complete the prompt: {ex}", inner_exception=ex, ) from ex + return self._parse_response_from_openai(response, options=options) @override async def _inner_get_streaming_response(