diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py
index 5f8c27b182..00e045fba6 100644
--- a/python/packages/a2a/agent_framework_a2a/_agent.py
+++ b/python/packages/a2a/agent_framework_a2a/_agent.py
@@ -31,11 +31,8 @@
AgentThread,
BaseAgent,
ChatMessage,
- Contents,
- DataContent,
+ Content,
Role,
- TextContent,
- UriContent,
normalize_messages,
prepend_agent_framework_to_user_agent,
)
@@ -333,7 +330,7 @@ def _prepare_message_for_a2a(self, message: ChatMessage) -> A2AMessage:
A2APart(
root=FilePart(
file=FileWithBytes(
- bytes=_get_uri_data(content.uri),
+ bytes=_get_uri_data(content.uri), # type: ignore[arg-type]
mime_type=content.media_type,
),
metadata=content.additional_properties,
@@ -362,19 +359,19 @@ def _prepare_message_for_a2a(self, message: ChatMessage) -> A2AMessage:
metadata=cast(dict[str, Any], message.additional_properties),
)
- def _parse_contents_from_a2a(self, parts: Sequence[A2APart]) -> list[Contents]:
- """Parse A2A Parts into Agent Framework Contents.
+ def _parse_contents_from_a2a(self, parts: Sequence[A2APart]) -> list[Content]:
+ """Parse A2A Parts into Agent Framework Content.
Transforms A2A protocol Parts into framework-native Content objects,
handling text, file (URI/bytes), and data parts with metadata preservation.
"""
- contents: list[Contents] = []
+ contents: list[Content] = []
for part in parts:
inner_part = part.root
match inner_part.kind:
case "text":
contents.append(
- TextContent(
+ Content.from_text(
text=inner_part.text,
additional_properties=inner_part.metadata,
raw_representation=inner_part,
@@ -383,7 +380,7 @@ def _parse_contents_from_a2a(self, parts: Sequence[A2APart]) -> list[Contents]:
case "file":
if isinstance(inner_part.file, FileWithUri):
contents.append(
- UriContent(
+ Content.from_uri(
uri=inner_part.file.uri,
media_type=inner_part.file.mime_type or "",
additional_properties=inner_part.metadata,
@@ -392,7 +389,7 @@ def _parse_contents_from_a2a(self, parts: Sequence[A2APart]) -> list[Contents]:
)
elif isinstance(inner_part.file, FileWithBytes):
contents.append(
- DataContent(
+ Content.from_data(
data=base64.b64decode(inner_part.file.bytes),
media_type=inner_part.file.mime_type or "",
additional_properties=inner_part.metadata,
@@ -401,7 +398,7 @@ def _parse_contents_from_a2a(self, parts: Sequence[A2APart]) -> list[Contents]:
)
case "data":
contents.append(
- TextContent(
+ Content.from_text(
text=json.dumps(inner_part.data),
additional_properties=inner_part.metadata,
raw_representation=inner_part,
diff --git a/python/packages/a2a/tests/test_a2a_agent.py b/python/packages/a2a/tests/test_a2a_agent.py
index 5d77345b20..eca97b2ac6 100644
--- a/python/packages/a2a/tests/test_a2a_agent.py
+++ b/python/packages/a2a/tests/test_a2a_agent.py
@@ -24,12 +24,8 @@
AgentResponse,
AgentResponseUpdate,
ChatMessage,
- DataContent,
- ErrorContent,
- HostedFileContent,
+ Content,
Role,
- TextContent,
- UriContent,
)
from agent_framework.a2a import A2AAgent
from pytest import fixture, raises
@@ -289,8 +285,8 @@ def test_parse_contents_from_a2a_conversion(a2a_agent: A2AAgent) -> None:
# Verify conversion
assert len(contents) == 2
- assert isinstance(contents[0], TextContent)
- assert isinstance(contents[1], TextContent)
+ assert contents[0].type == "text"
+ assert contents[1].type == "text"
assert contents[0].text == "First part"
assert contents[1].text == "Second part"
@@ -299,7 +295,7 @@ def test_prepare_message_for_a2a_with_error_content(a2a_agent: A2AAgent) -> None
"""Test _prepare_message_for_a2a with ErrorContent."""
# Create ChatMessage with ErrorContent
- error_content = ErrorContent(message="Test error message")
+ error_content = Content.from_error(message="Test error message")
message = ChatMessage(role=Role.USER, contents=[error_content])
# Convert to A2A message
@@ -314,7 +310,7 @@ def test_prepare_message_for_a2a_with_uri_content(a2a_agent: A2AAgent) -> None:
"""Test _prepare_message_for_a2a with UriContent."""
# Create ChatMessage with UriContent
- uri_content = UriContent(uri="http://example.com/file.pdf", media_type="application/pdf")
+ uri_content = Content.from_uri(uri="http://example.com/file.pdf", media_type="application/pdf")
message = ChatMessage(role=Role.USER, contents=[uri_content])
# Convert to A2A message
@@ -330,7 +326,7 @@ def test_prepare_message_for_a2a_with_data_content(a2a_agent: A2AAgent) -> None:
"""Test _prepare_message_for_a2a with DataContent."""
# Create ChatMessage with DataContent (base64 data URI)
- data_content = DataContent(uri="data:text/plain;base64,SGVsbG8gV29ybGQ=", media_type="text/plain")
+ data_content = Content.from_uri(uri="data:text/plain;base64,SGVsbG8gV29ybGQ=", media_type="text/plain")
message = ChatMessage(role=Role.USER, contents=[data_content])
# Convert to A2A message
@@ -368,7 +364,7 @@ async def test_run_stream_with_message_response(a2a_agent: A2AAgent, mock_a2a_cl
assert len(updates[0].contents) == 1
content = updates[0].contents[0]
- assert isinstance(content, TextContent)
+ assert content.type == "text"
assert content.text == "Streaming response from agent!"
assert updates[0].response_id == "msg-stream-123"
@@ -414,10 +410,10 @@ def test_prepare_message_for_a2a_with_multiple_contents() -> None:
message = ChatMessage(
role=Role.USER,
contents=[
- TextContent(text="Here's the analysis:"),
- DataContent(data=b"binary data", media_type="application/octet-stream"),
- UriContent(uri="https://example.com/image.png", media_type="image/png"),
- TextContent(text='{"structured": "data"}'),
+ Content.from_text(text="Here's the analysis:"),
+ Content.from_data(data=b"binary data", media_type="application/octet-stream"),
+ Content.from_uri(uri="https://example.com/image.png", media_type="image/png"),
+ Content.from_text(text='{"structured": "data"}'),
],
)
@@ -445,7 +441,7 @@ def test_parse_contents_from_a2a_with_data_part() -> None:
assert len(contents) == 1
- assert isinstance(contents[0], TextContent)
+ assert contents[0].type == "text"
assert contents[0].text == '{"key": "value", "number": 42}'
assert contents[0].additional_properties == {"source": "test"}
@@ -470,7 +466,7 @@ def test_prepare_message_for_a2a_with_hosted_file() -> None:
# Create message with hosted file content
message = ChatMessage(
role=Role.USER,
- contents=[HostedFileContent(file_id="hosted://storage/document.pdf")],
+ contents=[Content.from_hosted_file(file_id="hosted://storage/document.pdf")],
)
result = agent._prepare_message_for_a2a(message) # noqa: SLF001
@@ -507,7 +503,7 @@ def test_parse_contents_from_a2a_with_hosted_file_uri() -> None:
assert len(contents) == 1
- assert isinstance(contents[0], UriContent)
+ assert contents[0].type == "uri"
assert contents[0].uri == "hosted://storage/document.pdf"
assert contents[0].media_type == "" # Converted None to empty string
diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_client.py b/python/packages/ag-ui/agent_framework_ag_ui/_client.py
index e31036803c..a336f28b76 100644
--- a/python/packages/ag-ui/agent_framework_ag_ui/_client.py
+++ b/python/packages/ag-ui/agent_framework_ag_ui/_client.py
@@ -17,12 +17,10 @@
ChatMessage,
ChatResponse,
ChatResponseUpdate,
- DataContent,
- FunctionCallContent,
+ Content,
+ use_chat_middleware,
+ use_function_invocation,
)
-from agent_framework._middleware import use_chat_middleware
-from agent_framework._tools import use_function_invocation
-from agent_framework._types import BaseContent, Contents
from agent_framework.observability import use_instrumentation
from ._event_converters import AGUIEventConverter
@@ -53,26 +51,11 @@
logger: logging.Logger = logging.getLogger(__name__)
-class ServerFunctionCallContent(BaseContent):
- """Wrapper for server function calls to prevent client re-execution.
-
- All function calls from the remote server are server-side executions.
- This wrapper prevents @use_function_invocation from trying to execute them again.
- """
-
- function_call_content: FunctionCallContent
-
- def __init__(self, function_call_content: FunctionCallContent) -> None:
- """Initialize with the function call content."""
- super().__init__(type="server_function_call")
- self.function_call_content = function_call_content
-
-
-def _unwrap_server_function_call_contents(contents: MutableSequence[Contents | dict[str, Any]]) -> None:
- """Replace ServerFunctionCallContent instances with their underlying call content."""
+def _unwrap_server_function_call_contents(contents: MutableSequence[Content | dict[str, Any]]) -> None:
+ """Replace server_function_call instances with their underlying call content."""
for idx, content in enumerate(contents):
- if isinstance(content, ServerFunctionCallContent):
- contents[idx] = content.function_call_content # type: ignore[assignment]
+ if content.type == "server_function_call": # type: ignore[union-attr]
+ contents[idx] = content.function_call # type: ignore[assignment, union-attr]
TBaseChatClient = TypeVar("TBaseChatClient", bound=type[BaseChatClient[Any]])
@@ -93,7 +76,7 @@ def _apply_server_function_call_unwrap(chat_client: TBaseChatClient) -> TBaseCha
@wraps(original_get_streaming_response)
async def streaming_wrapper(self, *args: Any, **kwargs: Any) -> AsyncIterable[ChatResponseUpdate]:
async for update in original_get_streaming_response(self, *args, **kwargs):
- _unwrap_server_function_call_contents(cast(MutableSequence[Contents | dict[str, Any]], update.contents))
+ _unwrap_server_function_call_contents(cast(MutableSequence[Content | dict[str, Any]], update.contents))
yield update
chat_client.get_streaming_response = streaming_wrapper # type: ignore[assignment]
@@ -105,9 +88,7 @@ async def response_wrapper(self, *args: Any, **kwargs: Any) -> ChatResponse:
response = await original_get_response(self, *args, **kwargs)
if response.messages:
for message in response.messages:
- _unwrap_server_function_call_contents(
- cast(MutableSequence[Contents | dict[str, Any]], message.contents)
- )
+ _unwrap_server_function_call_contents(cast(MutableSequence[Content | dict[str, Any]], message.contents))
return response
chat_client.get_response = response_wrapper # type: ignore[assignment]
@@ -289,13 +270,13 @@ def _extract_state_from_messages(
last_message = messages[-1]
for content in last_message.contents:
- if isinstance(content, DataContent) and content.media_type == "application/json":
+ if isinstance(content, Content) and content.type == "data" and content.media_type == "application/json":
try:
uri = content.uri
- if uri.startswith("data:application/json;base64,"):
+ if uri.startswith("data:application/json;base64,"): # type: ignore[union-attr]
import base64
- encoded_data = uri.split(",", 1)[1]
+ encoded_data = uri.split(",", 1)[1] # type: ignore[union-attr]
decoded_bytes = base64.b64decode(encoded_data)
state = json.loads(decoded_bytes.decode("utf-8"))
@@ -433,19 +414,19 @@ async def _inner_get_streaming_response(
)
# Distinguish client vs server tools
for i, content in enumerate(update.contents):
- if isinstance(content, FunctionCallContent):
+ if content.type == "function_call": # type: ignore[attr-defined]
logger.debug(
- f"[AGUIChatClient] Function call: {content.name}, in client_tool_set: {content.name in client_tool_set}"
+ f"[AGUIChatClient] Function call: {content.name}, in client_tool_set: {content.name in client_tool_set}" # type: ignore[attr-defined]
)
- if content.name in client_tool_set:
+ if content.name in client_tool_set: # type: ignore[attr-defined]
# Client tool - let @use_function_invocation execute it
- if not content.additional_properties:
- content.additional_properties = {}
- content.additional_properties["agui_thread_id"] = thread_id
+ if not content.additional_properties: # type: ignore[attr-defined]
+ content.additional_properties = {} # type: ignore[attr-defined]
+ content.additional_properties["agui_thread_id"] = thread_id # type: ignore[attr-defined]
else:
# Server tool - wrap so @use_function_invocation ignores it
- logger.debug(f"[AGUIChatClient] Wrapping server tool: {content.name}")
- self._register_server_tool_placeholder(content.name)
- update.contents[i] = ServerFunctionCallContent(content) # type: ignore
+ logger.debug(f"[AGUIChatClient] Wrapping server tool: {content.name}") # type: ignore[union-attr]
+ self._register_server_tool_placeholder(content.name) # type: ignore[arg-type]
+ update.contents[i] = Content(type="server_function_call", function_call=content) # type: ignore
yield update
diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_event_converters.py b/python/packages/ag-ui/agent_framework_ag_ui/_event_converters.py
index 0f485739c9..bd2d989f2a 100644
--- a/python/packages/ag-ui/agent_framework_ag_ui/_event_converters.py
+++ b/python/packages/ag-ui/agent_framework_ag_ui/_event_converters.py
@@ -6,12 +6,9 @@
from agent_framework import (
ChatResponseUpdate,
- ErrorContent,
+ Content,
FinishReason,
- FunctionCallContent,
- FunctionResultContent,
Role,
- TextContent,
)
@@ -117,7 +114,7 @@ def _handle_text_message_content(self, event: dict[str, Any]) -> ChatResponseUpd
return ChatResponseUpdate(
role=Role.ASSISTANT,
message_id=self.current_message_id,
- contents=[TextContent(text=delta)],
+ contents=[Content.from_text(text=delta)],
)
def _handle_text_message_end(self, event: dict[str, Any]) -> ChatResponseUpdate | None:
@@ -133,7 +130,7 @@ def _handle_tool_call_start(self, event: dict[str, Any]) -> ChatResponseUpdate:
return ChatResponseUpdate(
role=Role.ASSISTANT,
contents=[
- FunctionCallContent(
+ Content.from_function_call(
call_id=self.current_tool_call_id or "",
name=self.current_tool_name or "",
arguments="",
@@ -149,7 +146,7 @@ def _handle_tool_call_args(self, event: dict[str, Any]) -> ChatResponseUpdate:
return ChatResponseUpdate(
role=Role.ASSISTANT,
contents=[
- FunctionCallContent(
+ Content.from_function_call(
call_id=self.current_tool_call_id or "",
name=self.current_tool_name or "",
arguments=delta,
@@ -170,7 +167,7 @@ def _handle_tool_call_result(self, event: dict[str, Any]) -> ChatResponseUpdate:
return ChatResponseUpdate(
role=Role.TOOL,
contents=[
- FunctionResultContent(
+ Content.from_function_result(
call_id=tool_call_id,
result=result,
)
@@ -197,7 +194,7 @@ def _handle_run_error(self, event: dict[str, Any]) -> ChatResponseUpdate:
role=Role.ASSISTANT,
finish_reason=FinishReason.CONTENT_FILTER,
contents=[
- ErrorContent(
+ Content.from_error(
message=error_message,
error_code="RUN_ERROR",
)
diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_events.py b/python/packages/ag-ui/agent_framework_ag_ui/_events.py
index ddf3ebba01..34c1e3ed86 100644
--- a/python/packages/ag-ui/agent_framework_ag_ui/_events.py
+++ b/python/packages/ag-ui/agent_framework_ag_ui/_events.py
@@ -25,10 +25,7 @@
)
from agent_framework import (
AgentResponseUpdate,
- FunctionApprovalRequestContent,
- FunctionCallContent,
- FunctionResultContent,
- TextContent,
+ Content,
prepare_function_call_results,
)
@@ -96,20 +93,22 @@ async def from_agent_run_update(self, update: AgentResponseUpdate) -> list[BaseE
logger.info(f"Processing AgentRunUpdate with {len(update.contents)} content items")
for idx, content in enumerate(update.contents):
logger.info(f" Content {idx}: type={type(content).__name__}")
- if isinstance(content, TextContent):
- events.extend(self._handle_text_content(content))
- elif isinstance(content, FunctionCallContent):
- events.extend(self._handle_function_call_content(content))
- elif isinstance(content, FunctionResultContent):
- events.extend(self._handle_function_result_content(content))
- elif isinstance(content, FunctionApprovalRequestContent):
- events.extend(self._handle_function_approval_request_content(content))
-
+ match content.type:
+ case "text":
+ events.extend(self._handle_text_content(content))
+ case "function_call":
+ events.extend(self._handle_function_call_content(content))
+ case "function_result":
+ events.extend(self._handle_function_result_content(content))
+ case "function_approval_request":
+ events.extend(self._handle_function_approval_request_content(content))
+ case _:
+ logger.warning(f" Unsupported content type: {content.type}, skipping.")
return events
- def _handle_text_content(self, content: TextContent) -> list[BaseEvent]:
+ def _handle_text_content(self, content: Content) -> list[BaseEvent]:
events: list[BaseEvent] = []
- logger.info(f" TextContent found: length={len(content.text)}")
+ logger.info(f" TextContent found: length={len(content.text)}") # type: ignore[arg-type]
logger.info(
" Flags: skip_text_content=%s, should_stop_after_confirm=%s",
self.skip_text_content,
@@ -122,7 +121,7 @@ def _handle_text_content(self, content: TextContent) -> list[BaseEvent]:
if self.should_stop_after_confirm:
logger.info(" SKIPPING TextContent: waiting for confirm_changes response")
- self.suppressed_summary += content.text
+ self.suppressed_summary += content.text # type: ignore[operator]
logger.info(f" Suppressed summary length={len(self.suppressed_summary)}")
return events
@@ -150,14 +149,14 @@ def _handle_text_content(self, content: TextContent) -> list[BaseEvent]:
events.append(event)
return events
- def _handle_function_call_content(self, content: FunctionCallContent) -> list[BaseEvent]:
+ def _handle_function_call_content(self, content: Content) -> list[BaseEvent]:
events: list[BaseEvent] = []
if content.name:
logger.debug(f"Tool call: {content.name} (call_id: {content.call_id})")
if not content.name and not content.call_id and not self.current_tool_call_name:
args_length = len(str(content.arguments)) if content.arguments else 0
- logger.warning(f"FunctionCallContent missing name and call_id. args_length={args_length}")
+ logger.warning(f"Content missing name and call_id. args_length={args_length}")
tool_call_id = self._coalesce_tool_call_id(content)
# Only emit ToolCallStartEvent once per tool call (when it's a new tool call)
@@ -190,7 +189,7 @@ def _handle_function_call_content(self, content: FunctionCallContent) -> list[Ba
return events
- def _coalesce_tool_call_id(self, content: FunctionCallContent) -> str:
+ def _coalesce_tool_call_id(self, content: Content) -> str:
if content.call_id:
return content.call_id
if self.current_tool_call_id:
@@ -286,7 +285,7 @@ def _emit_predictive_state_deltas(self, argument_chunk: str) -> list[BaseEvent]:
self.pending_state_updates[state_key] = state_value
return events
- def _handle_function_result_content(self, content: FunctionResultContent) -> list[BaseEvent]:
+ def _handle_function_result_content(self, content: Content) -> list[BaseEvent]:
events: list[BaseEvent] = []
if content.call_id:
end_event = ToolCallEndEvent(
@@ -310,7 +309,7 @@ def _handle_function_result_content(self, content: FunctionResultContent) -> lis
result_event = ToolCallResultEvent(
message_id=result_message_id,
- tool_call_id=content.call_id,
+ tool_call_id=content.call_id, # type: ignore[arg-type]
content=result_content,
role="tool",
)
@@ -367,7 +366,7 @@ def _emit_state_snapshot_and_confirmation(self) -> list[BaseEvent]:
self.current_tool_call_name = None
return events
- def _emit_confirm_changes_tool_call(self, function_call: FunctionCallContent | None = None) -> list[BaseEvent]:
+ def _emit_confirm_changes_tool_call(self, function_call: Content | None = None) -> list[BaseEvent]:
"""Emit a confirm_changes tool call for Dojo UI compatibility.
Args:
@@ -419,7 +418,7 @@ def _emit_confirm_changes_tool_call(self, function_call: FunctionCallContent | N
logger.info("Set flag to stop run after confirm_changes")
return events
- def _emit_function_approval_tool_call(self, function_call: FunctionCallContent) -> list[BaseEvent]:
+ def _emit_function_approval_tool_call(self, function_call: Content) -> list[BaseEvent]:
"""Emit a tool call that can drive UI approval for function requests."""
tool_call_name = "confirm_changes"
if self.approval_tool_name and self.approval_tool_name != function_call.name:
@@ -462,13 +461,13 @@ def _emit_function_approval_tool_call(self, function_call: FunctionCallContent)
logger.info("Set flag to stop run after confirm_changes")
return events
- def _handle_function_approval_request_content(self, content: FunctionApprovalRequestContent) -> list[BaseEvent]:
+ def _handle_function_approval_request_content(self, content: Content) -> list[BaseEvent]:
events: list[BaseEvent] = []
logger.info("=== FUNCTION APPROVAL REQUEST ===")
- logger.info(f" Function: {content.function_call.name}")
- logger.info(f" Call ID: {content.function_call.call_id}")
+ logger.info(f" Function: {content.function_call.name}") # type: ignore[union-attr]
+ logger.info(f" Call ID: {content.function_call.call_id}") # type: ignore[union-attr]
- parsed_args = content.function_call.parse_arguments()
+ parsed_args = content.function_call.parse_arguments() # type: ignore[union-attr]
parsed_arg_keys = list(parsed_args.keys()) if parsed_args else "None"
logger.info(f" Parsed args keys: {parsed_arg_keys}")
@@ -478,12 +477,12 @@ def _handle_function_approval_request_content(self, content: FunctionApprovalReq
list(self.predict_state_config.keys()) if self.predict_state_config else "None",
)
for state_key, config in self.predict_state_config.items():
- if config["tool"] != content.function_call.name:
+ if config["tool"] != content.function_call.name: # type: ignore[union-attr]
continue
tool_arg_name = config["tool_argument"]
logger.info(
" MATCHED tool '%s' for state key '%s', arg='%s'",
- content.function_call.name,
+ content.function_call.name, # type: ignore[union-attr]
state_key,
tool_arg_name,
)
@@ -500,11 +499,11 @@ def _handle_function_approval_request_content(self, content: FunctionApprovalReq
)
events.append(state_snapshot)
- if content.function_call.call_id:
+ if content.function_call.call_id: # type: ignore[union-attr]
end_event = ToolCallEndEvent(
- tool_call_id=content.function_call.call_id,
+ tool_call_id=content.function_call.call_id, # type: ignore[union-attr]
)
- logger.info(f"Emitting ToolCallEndEvent for approval-required tool '{content.function_call.call_id}'")
+ logger.info(f"Emitting ToolCallEndEvent for approval-required tool '{content.function_call.call_id}'") # type: ignore[union-attr]
events.append(end_event)
# Emit the function_approval_request custom event for UI implementations that support it
@@ -513,18 +512,18 @@ def _handle_function_approval_request_content(self, content: FunctionApprovalReq
value={
"id": content.id,
"function_call": {
- "call_id": content.function_call.call_id,
- "name": content.function_call.name,
- "arguments": content.function_call.parse_arguments(),
+ "call_id": content.function_call.call_id, # type: ignore[union-attr]
+ "name": content.function_call.name, # type: ignore[union-attr]
+ "arguments": content.function_call.parse_arguments(), # type: ignore[union-attr]
},
},
)
- logger.info(f"Emitting function_approval_request custom event for '{content.function_call.name}'")
+ logger.info(f"Emitting function_approval_request custom event for '{content.function_call.name}'") # type: ignore[union-attr]
events.append(approval_event)
# Emit a UI-friendly approval tool call for function approvals.
if self.require_confirmation:
- events.extend(self._emit_function_approval_tool_call(content.function_call))
+ events.extend(self._emit_function_approval_tool_call(content.function_call)) # type: ignore[arg-type]
# Signal orchestrator to stop the run and wait for user approval response
self.should_stop_after_confirm = True
diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py
index 1ff858e9f5..cf14641258 100644
--- a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py
+++ b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py
@@ -8,11 +8,8 @@
from agent_framework import (
ChatMessage,
- FunctionApprovalResponseContent,
- FunctionCallContent,
- FunctionResultContent,
+ Content,
Role,
- TextContent,
prepare_function_call_results,
)
@@ -40,11 +37,11 @@ def _sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]:
tool_ids = {
str(content.call_id)
for content in msg.contents or []
- if isinstance(content, FunctionCallContent) and content.call_id
+ if content.type == "function_call" and content.call_id
}
confirm_changes_call = None
for content in msg.contents or []:
- if isinstance(content, FunctionCallContent) and content.name == "confirm_changes":
+ if content.type == "function_call" and content.name == "confirm_changes":
confirm_changes_call = content
break
@@ -59,7 +56,7 @@ def _sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]:
approval_call_ids: set[str] = set()
approval_accepted: bool | None = None
for content in msg.contents or []:
- if type(content) is FunctionApprovalResponseContent:
+ if content.type == "function_approval_response":
if content.function_call and content.function_call.call_id:
approval_call_ids.add(str(content.function_call.call_id))
if approval_accepted is None:
@@ -79,7 +76,7 @@ def _sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]:
synthetic_result = ChatMessage(
role="tool",
contents=[
- FunctionResultContent(
+ Content.from_function_result(
call_id=pending_confirm_changes_id,
result="Confirmed" if approval_accepted else "Rejected",
)
@@ -93,12 +90,12 @@ def _sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]:
if pending_confirm_changes_id:
user_text = ""
for content in msg.contents or []:
- if isinstance(content, TextContent):
- user_text = content.text
+ if content.type == "text":
+ user_text = content.text # type: ignore[assignment]
break
try:
- parsed = json.loads(user_text)
+ parsed = json.loads(user_text) # type: ignore[arg-type]
if "accepted" in parsed:
logger.info(
f"Injecting synthetic tool result for confirm_changes call_id={pending_confirm_changes_id}"
@@ -106,7 +103,7 @@ def _sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]:
synthetic_result = ChatMessage(
role="tool",
contents=[
- FunctionResultContent(
+ Content.from_function_result(
call_id=pending_confirm_changes_id,
result="Confirmed" if parsed.get("accepted") else "Rejected",
)
@@ -130,7 +127,7 @@ def _sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]:
synthetic_result = ChatMessage(
role="tool",
contents=[
- FunctionResultContent(
+ Content.from_function_result(
call_id=pending_call_id,
result="Tool execution skipped - user provided follow-up message",
)
@@ -149,7 +146,7 @@ def _sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]:
continue
keep = False
for content in msg.contents or []:
- if isinstance(content, FunctionResultContent):
+ if content.type == "function_result" and content.call_id:
call_id = str(content.call_id)
if call_id in pending_tool_call_ids:
keep = True
@@ -175,7 +172,7 @@ def _deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]:
for idx, msg in enumerate(messages):
role_value = get_role_value(msg)
- if role_value == "tool" and msg.contents and isinstance(msg.contents[0], FunctionResultContent):
+ if role_value == "tool" and msg.contents and msg.contents[0].type == "function_result":
call_id = str(msg.contents[0].call_id)
key: Any = (role_value, call_id)
@@ -184,7 +181,7 @@ def _deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]:
existing_msg = unique_messages[existing_idx]
existing_result = None
- if existing_msg.contents and isinstance(existing_msg.contents[0], FunctionResultContent):
+ if existing_msg.contents and existing_msg.contents[0].type == "function_result":
existing_result = existing_msg.contents[0].result
new_result = msg.contents[0].result
@@ -198,11 +195,9 @@ def _deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]:
seen_keys[key] = len(unique_messages)
unique_messages.append(msg)
- elif (
- role_value == "assistant" and msg.contents and any(isinstance(c, FunctionCallContent) for c in msg.contents)
- ):
+ elif role_value == "assistant" and msg.contents and any(c.type == "function_call" for c in msg.contents):
tool_call_ids = tuple(
- sorted(str(c.call_id) for c in msg.contents if isinstance(c, FunctionCallContent) and c.call_id)
+ sorted(str(c.call_id) for c in msg.contents if c.type == "function_call" and c.call_id)
)
key = (role_value, tool_call_ids)
@@ -275,15 +270,14 @@ def _update_tool_call_arguments(
function_payload_dict["arguments"] = modified_args
return
- def _find_matching_func_call(call_id: str) -> FunctionCallContent | None:
+ def _find_matching_func_call(call_id: str) -> Content | None:
for prev_msg in result:
role_val = prev_msg.role.value if hasattr(prev_msg.role, "value") else str(prev_msg.role)
if role_val != "assistant":
continue
for content in prev_msg.contents or []:
- if isinstance(content, FunctionCallContent):
- if content.call_id == call_id and content.name != "confirm_changes":
- return content
+ if content.type == "function_call" and content.call_id == call_id and content.name != "confirm_changes":
+ return content
return None
def _parse_arguments(arguments: Any) -> dict[str, Any] | None:
@@ -301,9 +295,9 @@ def _resolve_approval_call_id(tool_call_id: str, parsed_payload: dict[str, Any]
continue
direct_call = None
confirm_call = None
- sibling_calls: list[FunctionCallContent] = []
+ sibling_calls: list[Content] = []
for content in prev_msg.contents or []:
- if not isinstance(content, FunctionCallContent):
+ if content.type != "function_call":
continue
if content.call_id == tool_call_id:
direct_call = content
@@ -407,7 +401,7 @@ def _filter_modified_args(
if not (
(m.role.value if hasattr(m.role, "value") else str(m.role)) == "tool"
and any(
- isinstance(c, FunctionResultContent) and c.call_id == approval_call_id
+ c.type == "function_result" and c.call_id == approval_call_id
for c in (m.contents or [])
)
)
@@ -465,9 +459,9 @@ def _filter_modified_args(
matching_func_call.arguments = updated_args
_update_tool_call_arguments(messages, str(approval_call_id), merged_args)
# Create a new FunctionCallContent with the modified arguments
- func_call_for_approval = FunctionCallContent(
- call_id=matching_func_call.call_id,
- name=matching_func_call.name,
+ func_call_for_approval = Content.from_function_call(
+ call_id=matching_func_call.call_id, # type: ignore[arg-type]
+ name=matching_func_call.name, # type: ignore[arg-type]
arguments=json.dumps(filtered_args),
)
logger.info(f"Using modified arguments from approval: {filtered_args}")
@@ -476,7 +470,7 @@ def _filter_modified_args(
func_call_for_approval = matching_func_call
# Create FunctionApprovalResponseContent for the agent framework
- approval_response = FunctionApprovalResponseContent(
+ approval_response = Content.from_function_approval_response(
approved=accepted,
id=str(approval_call_id),
function_call=func_call_for_approval,
@@ -491,7 +485,7 @@ def _filter_modified_args(
# Keep the old behavior for backwards compatibility
chat_msg = ChatMessage(
role=Role.USER,
- contents=[TextContent(text=approval_payload_text)],
+ contents=[Content.from_text(text=approval_payload_text)],
additional_properties={"is_tool_result": True, "tool_call_id": str(tool_call_id or "")},
)
if "id" in msg:
@@ -511,7 +505,7 @@ def _filter_modified_args(
func_result = str(result_content)
chat_msg = ChatMessage(
role=Role.TOOL,
- contents=[FunctionResultContent(call_id=str(tool_call_id), result=func_result)],
+ contents=[Content.from_function_result(call_id=str(tool_call_id), result=func_result)],
)
if "id" in msg:
chat_msg.message_id = msg["id"]
@@ -527,21 +521,21 @@ def _filter_modified_args(
chat_msg = ChatMessage(
role=Role.TOOL,
- contents=[FunctionResultContent(call_id=str(tool_call_id), result=result_content)],
+ contents=[Content.from_function_result(call_id=str(tool_call_id), result=result_content)],
)
if "id" in msg:
chat_msg.message_id = msg["id"]
result.append(chat_msg)
continue
- # If assistant message includes tool calls, convert to FunctionCallContent(s)
+ # If assistant message includes tool calls, convert to Content.from_function_call(s)
tool_calls = msg.get("tool_calls") or msg.get("toolCalls")
if tool_calls:
contents: list[Any] = []
# Include any assistant text content if present
content_text = msg.get("content")
if isinstance(content_text, str) and content_text:
- contents.append(TextContent(text=content_text))
+ contents.append(Content.from_text(text=content_text))
# Convert each tool call entry
for tc in tool_calls:
if not isinstance(tc, dict):
@@ -558,7 +552,7 @@ def _filter_modified_args(
arguments = func_dict.get("arguments")
contents.append(
- FunctionCallContent(
+ Content.from_function_call(
call_id=call_id,
name=name,
arguments=arguments,
@@ -580,14 +574,14 @@ def _filter_modified_args(
approval_contents: list[Any] = []
for approval in msg["function_approvals"]:
# Create FunctionCallContent with the modified arguments
- func_call = FunctionCallContent(
+ func_call = Content.from_function_call(
call_id=approval.get("call_id", ""),
name=approval.get("name", ""),
arguments=approval.get("arguments", {}),
)
# Create the approval response
- approval_response = FunctionApprovalResponseContent(
+ approval_response = Content.from_function_approval_response(
approved=approval.get("approved", True),
id=approval.get("id", ""),
function_call=func_call,
@@ -599,9 +593,9 @@ def _filter_modified_args(
# Regular text message
content = msg.get("content", "")
if isinstance(content, str):
- chat_msg = ChatMessage(role=role, contents=[TextContent(text=content)])
+ chat_msg = ChatMessage(role=role, contents=[Content.from_text(text=content)])
else:
- chat_msg = ChatMessage(role=role, contents=[TextContent(text=str(content))])
+ chat_msg = ChatMessage(role=role, contents=[Content.from_text(text=str(content))])
if "id" in msg:
chat_msg.message_id = msg["id"]
@@ -652,9 +646,9 @@ def agent_framework_messages_to_agui(messages: list[ChatMessage] | list[dict[str
tool_result_call_id: str | None = None
for content in msg.contents:
- if isinstance(content, TextContent):
- content_text += content.text
- elif isinstance(content, FunctionCallContent):
+ if content.type == "text":
+ content_text += content.text # type: ignore[operator]
+ elif content.type == "function_call":
tool_calls.append(
{
"id": content.call_id,
@@ -665,7 +659,7 @@ def agent_framework_messages_to_agui(messages: list[ChatMessage] | list[dict[str
},
}
)
- elif isinstance(content, FunctionResultContent):
+ elif content.type == "function_result":
# Tool result content - extract call_id and result
tool_result_call_id = content.call_id
# Serialize result to string using core utility
@@ -702,8 +696,13 @@ def extract_text_from_contents(contents: list[Any]) -> str:
"""
text_parts: list[str] = []
for content in contents:
- if isinstance(content, TextContent):
- text_parts.append(content.text)
+ if type_ := getattr(content, "type", None):
+ if type_ == "text_reasoning":
+ continue
+ if text := getattr(content, "text", None):
+ text_parts.append(text)
+ continue
+ # TODO (moonbox3): should this handle both text and text_reasoning?
elif hasattr(content, "text"):
text_parts.append(content.text)
return "".join(text_parts)
diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py
index ebf6ef6f57..b2e3b1d5eb 100644
--- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py
+++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py
@@ -9,10 +9,7 @@
from ag_ui.core import StateSnapshotEvent
from agent_framework import (
ChatMessage,
- FunctionApprovalResponseContent,
- FunctionCallContent,
- FunctionResultContent,
- TextContent,
+ Content,
)
from .._utils import get_role_value, safe_json_parse
@@ -37,9 +34,9 @@ def pending_tool_call_ids(messages: list[ChatMessage]) -> set[str]:
resolved_ids: set[str] = set()
for msg in messages:
for content in msg.contents:
- if isinstance(content, FunctionCallContent) and content.call_id:
+ if content.type == "function_call" and content.call_id:
pending_ids.add(str(content.call_id))
- elif isinstance(content, FunctionResultContent) and content.call_id:
+ elif content.type == "function_result" and content.call_id:
resolved_ids.add(str(content.call_id))
return pending_ids - resolved_ids
@@ -56,7 +53,7 @@ def is_state_context_message(message: ChatMessage) -> bool:
if get_role_value(message) != "system":
return False
for content in message.contents:
- if isinstance(content, TextContent) and content.text.startswith("Current state of the application:"):
+ if content.type == "text" and content.text.startswith("Current state of the application:"): # type: ignore[union-attr]
return True
return False
@@ -139,7 +136,7 @@ def tool_calls_match_state(
if get_role_value(msg) != "assistant":
continue
for content in msg.contents:
- if isinstance(content, FunctionCallContent) and content.name == tool_name:
+ if content.type == "function_call" and content.name == tool_name:
tool_args = safe_json_parse(content.arguments)
break
if tool_args is not None:
@@ -287,7 +284,7 @@ def collect_approved_state_snapshots(
if get_role_value(msg) != "user":
continue
for content in msg.contents:
- if type(content) is FunctionApprovalResponseContent:
+ if content.type == "function_approval_response":
if not content.function_call or not content.approved:
continue
parsed_args = content.function_call.parse_arguments()
@@ -319,7 +316,7 @@ def collect_approved_state_snapshots(
return events
-def latest_approval_response(messages: list[ChatMessage]) -> FunctionApprovalResponseContent | None:
+def latest_approval_response(messages: list[ChatMessage]) -> Content | None:
"""Get the latest approval response from messages.
Args:
@@ -332,12 +329,12 @@ def latest_approval_response(messages: list[ChatMessage]) -> FunctionApprovalRes
return None
last_message = messages[-1]
for content in last_message.contents:
- if type(content) is FunctionApprovalResponseContent:
+ if content.type == "function_approval_response":
return content
return None
-def approval_steps(approval: FunctionApprovalResponseContent) -> list[Any]:
+def approval_steps(approval: Content) -> list[Any]:
"""Extract steps from an approval response.
Args:
@@ -346,9 +343,7 @@ def approval_steps(approval: FunctionApprovalResponseContent) -> list[Any]:
Returns:
List of steps, or empty list if none
"""
- state_args: Any | None = None
- if approval.additional_properties:
- state_args = approval.additional_properties.get("ag_ui_state_args")
+ state_args = approval.additional_properties.get("ag_ui_state_args", None)
if isinstance(state_args, dict):
steps = state_args.get("steps")
if isinstance(steps, list):
@@ -365,7 +360,7 @@ def approval_steps(approval: FunctionApprovalResponseContent) -> list[Any]:
def is_step_based_approval(
- approval: FunctionApprovalResponseContent,
+ approval: Content,
predict_state_config: dict[str, dict[str, str]] | None,
) -> bool:
"""Check if an approval is step-based.
diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_state_manager.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_state_manager.py
index 7d8a23d84c..05cc55228d 100644
--- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_state_manager.py
+++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_state_manager.py
@@ -6,7 +6,7 @@
from typing import Any
from ag_ui.core import CustomEvent, EventType
-from agent_framework import ChatMessage, TextContent
+from agent_framework import ChatMessage, Content
class StateManager:
@@ -71,7 +71,7 @@ def state_context_message(self, is_new_user_turn: bool, conversation_has_tool_ca
return ChatMessage(
role="system",
contents=[
- TextContent(
+ Content.from_text(
text=(
"Current state of the application:\n"
f"{state_json}\n\n"
diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py
index b5566f0aec..2bd24de8c8 100644
--- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py
+++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py
@@ -25,13 +25,11 @@
AgentProtocol,
AgentThread,
ChatAgent,
- FunctionCallContent,
- FunctionResultContent,
- TextContent,
+ Content,
+ FunctionInvocationConfiguration,
)
from agent_framework._middleware import extract_and_merge_function_middleware
from agent_framework._tools import (
- FunctionInvocationConfiguration,
_collect_approval_responses, # type: ignore
_replace_approval_contents_with_results, # type: ignore
_try_execute_function_calls, # type: ignore
@@ -285,12 +283,12 @@ async def run(
last_message = context.last_message
if last_message:
for content in last_message.contents:
- if isinstance(content, TextContent):
+ if content.type == "text":
tool_content_text = content.text
break
try:
- tool_result = json.loads(tool_content_text)
+ tool_result = json.loads(tool_content_text) # type: ignore[arg-type]
accepted = tool_result.get("accepted", False)
steps = tool_result.get("steps", [])
@@ -328,7 +326,7 @@ async def run(
except json.JSONDecodeError:
logger.error(f"Failed to parse tool result: {tool_content_text}")
- yield RunErrorEvent(message=f"Invalid tool result format: {tool_content_text[:100]}")
+ yield RunErrorEvent(message=f"Invalid tool result format: {tool_content_text[:100]}") # type: ignore[index]
yield event_bridge.create_run_finished_event()
@@ -441,25 +439,24 @@ async def run(
logger.info(f" Message {i}: role={role}, id={msg_id}")
if hasattr(msg, "contents") and msg.contents:
for j, content in enumerate(msg.contents):
- content_type = type(content).__name__
- if isinstance(content, TextContent):
- logger.debug(" Content %s: %s - text_length=%s", j, content_type, len(content.text))
- elif isinstance(content, FunctionCallContent):
+ if content.type == "text":
+ logger.debug(" Content %s: %s - text_length=%s", j, content.type, len(content.text)) # type: ignore[arg-type]
+ elif content.type == "function_call":
arg_length = len(str(content.arguments)) if content.arguments else 0
logger.debug(
- " Content %s: %s - %s args_length=%s", j, content_type, content.name, arg_length
+ " Content %s: %s - %s args_length=%s", j, content.type, content.name, arg_length
)
- elif isinstance(content, FunctionResultContent):
+ elif content.type == "function_result":
result_preview = type(content.result).__name__ if content.result is not None else "None"
logger.debug(
" Content %s: %s - call_id=%s, result_type=%s",
j,
- content_type,
+ content.type,
content.call_id,
result_preview,
)
else:
- logger.debug(f" Content {j}: {content_type}")
+ logger.debug(f" Content {j}: {content.type}")
pending_tool_calls: list[dict[str, Any]] = []
tool_calls_by_id: dict[str, dict[str, Any]] = {}
@@ -536,16 +533,14 @@ async def _resolve_approval_responses(
logger.error("Failed to execute approved tool calls; injecting error results.")
approved_function_results = []
- normalized_results: list[FunctionResultContent] = []
+ normalized_results: list[Content] = []
for idx, approval in enumerate(approved_responses):
- if idx < len(approved_function_results) and isinstance(
- approved_function_results[idx], FunctionResultContent
- ):
+ if idx < len(approved_function_results) and approved_function_results[idx].type == "function_result":
normalized_results.append(approved_function_results[idx])
continue
- call_id = approval.function_call.call_id or approval.id
+ call_id = approval.function_call.call_id or approval.id # type: ignore[union-attr]
normalized_results.append(
- FunctionResultContent(call_id=call_id, result="Error: Tool call invocation failed.")
+ Content.from_function_result(call_id=call_id, result="Error: Tool call invocation failed.") # type: ignore[arg-type]
)
_replace_approval_contents_with_results(messages, fcc_todo, normalized_results) # type: ignore
@@ -661,8 +656,8 @@ def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnap
if all_updates is not None:
all_updates.append(update)
if event_bridge.current_message_id is None and update.contents:
- has_tool_call = any(isinstance(content, FunctionCallContent) for content in update.contents)
- has_text = any(isinstance(content, TextContent) for content in update.contents)
+ has_tool_call = any(content.type == "function_call" for content in update.contents)
+ has_text = any(content.type == "text" for content in update.contents)
if has_tool_call and not has_text:
tool_message_id = generate_event_id()
event_bridge.current_message_id = tool_message_id
diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_types.py b/python/packages/ag-ui/agent_framework_ag_ui/_types.py
index 226abae692..f88dceb78b 100644
--- a/python/packages/ag-ui/agent_framework_ag_ui/_types.py
+++ b/python/packages/ag-ui/agent_framework_ag_ui/_types.py
@@ -6,6 +6,7 @@
from typing import Any, TypedDict
from agent_framework import ChatOptions
+from pydantic import BaseModel, Field
if sys.version_info >= (3, 13):
from typing import TypeVar
@@ -19,8 +20,6 @@
"RunMetadata",
]
-from pydantic import BaseModel, Field
-
class PredictStateConfig(TypedDict):
"""Configuration for predictive state updates."""
diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/task_steps_agent.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/task_steps_agent.py
index 572df2720b..9a4acf4319 100644
--- a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/task_steps_agent.py
+++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/task_steps_agent.py
@@ -18,7 +18,7 @@
TextMessageStartEvent,
ToolCallStartEvent,
)
-from agent_framework import ChatAgent, ChatClientProtocol, ai_function
+from agent_framework import ChatAgent, ChatClientProtocol, ChatMessage, Content, ai_function
from agent_framework.ag_ui import AgentFrameworkAgent
from pydantic import BaseModel, Field
@@ -221,7 +221,6 @@ async def run_agent(self, input_data: dict[str, Any]) -> AsyncGenerator[Any, Non
chat_client = chat_agent.chat_client # type: ignore
# Build messages for summary call
- from agent_framework._types import ChatMessage, TextContent
original_messages = input_data.get("messages", [])
@@ -234,7 +233,7 @@ async def run_agent(self, input_data: dict[str, Any]) -> AsyncGenerator[Any, Non
messages.append(
ChatMessage(
role=msg.get("role", "user"),
- contents=[TextContent(text=content_str)],
+ contents=[Content.from_text(text=content_str)],
)
)
elif isinstance(msg, ChatMessage):
@@ -245,7 +244,7 @@ async def run_agent(self, input_data: dict[str, Any]) -> AsyncGenerator[Any, Non
ChatMessage(
role="user",
contents=[
- TextContent(
+ Content.from_text(
text="The steps have been successfully executed. Provide a brief one-sentence summary."
)
],
diff --git a/python/packages/ag-ui/getting_started/client.py b/python/packages/ag-ui/getting_started/client.py
index 61bdf0bfb3..7b56103050 100644
--- a/python/packages/ag-ui/getting_started/client.py
+++ b/python/packages/ag-ui/getting_started/client.py
@@ -50,11 +50,9 @@ async def main():
print("\nAssistant: ", end="", flush=True)
# Display text content as it streams
- from agent_framework import TextContent
-
for content in update.contents:
- if isinstance(content, TextContent) and content.text:
- print(f"\033[96m{content.text}\033[0m", end="", flush=True)
+ if hasattr(content, "text") and content.text: # type: ignore[attr-defined]
+ print(f"\033[96m{content.text}\033[0m", end="", flush=True) # type: ignore[attr-defined]
# Display finish reason if present
if update.finish_reason:
diff --git a/python/packages/ag-ui/getting_started/client_advanced.py b/python/packages/ag-ui/getting_started/client_advanced.py
index 08698a80a0..3c7ae6a334 100644
--- a/python/packages/ag-ui/getting_started/client_advanced.py
+++ b/python/packages/ag-ui/getting_started/client_advanced.py
@@ -73,11 +73,9 @@ async def streaming_example(client: AGUIChatClient, thread_id: str | None = None
if not thread_id and update.additional_properties:
thread_id = update.additional_properties.get("thread_id")
- from agent_framework import TextContent
-
for content in update.contents:
- if isinstance(content, TextContent) and content.text:
- print(content.text, end="", flush=True)
+ if content.type == "text" and content.text: # type: ignore[attr-defined]
+ print(content.text, end="", flush=True) # type: ignore[attr-defined]
print("\n")
return thread_id
@@ -138,13 +136,11 @@ async def tool_example(client: AGUIChatClient, thread_id: str | None = None):
print(f"Assistant: {response.text}")
# Show tool calls if any
- from agent_framework import FunctionCallContent
-
tool_called = False
for message in response.messages:
for content in message.contents:
- if isinstance(content, FunctionCallContent):
- print(f"\n[Tool Called: {content.name}]")
+ if content.type == "function_call": # type: ignore[attr-defined]
+ print(f"\n[Tool Called: {content.name}]") # type: ignore[attr-defined]
tool_called = True
if not tool_called:
@@ -176,7 +172,7 @@ async def conversation_example(client: AGUIChatClient):
# Second turn - using same thread
print("\nUser: What's my name?\n")
- response2 = await client.get_response("What's my name?", metadata={"thread_id": thread_id})
+ response2 = await client.get_response("What's my name?", options={"metadata": {"thread_id": thread_id}})
print(f"Assistant: {response2.text}")
# Check if context was maintained
@@ -186,7 +182,7 @@ async def conversation_example(client: AGUIChatClient):
# Third turn
print("\nUser: Can you also tell me what 10 * 5 is?\n")
response3 = await client.get_response(
- "Can you also tell me what 10 * 5 is?", metadata={"thread_id": thread_id}, tools=[calculate]
+ "Can you also tell me what 10 * 5 is?", options={"metadata": {"thread_id": thread_id}}, tools=[calculate]
)
print(f"Assistant: {response3.text}")
diff --git a/python/packages/ag-ui/getting_started/client_with_agent.py b/python/packages/ag-ui/getting_started/client_with_agent.py
index 91b099820b..63a89b4344 100644
--- a/python/packages/ag-ui/getting_started/client_with_agent.py
+++ b/python/packages/ag-ui/getting_started/client_with_agent.py
@@ -22,7 +22,7 @@
import logging
import os
-from agent_framework import ChatAgent, FunctionCallContent, FunctionResultContent, TextContent, ai_function
+from agent_framework import ChatAgent, ai_function
from agent_framework.ag_ui import AGUIChatClient
# Enable debug logging
@@ -141,8 +141,9 @@ def _preview_for_message(m) -> str:
# Build from contents when no direct text
parts: list[str] = []
for c in getattr(m, "contents", []) or []:
- if isinstance(c, FunctionCallContent):
- args = c.arguments
+ content_type = getattr(c, "type", None)
+ if content_type == "function_call":
+ args = getattr(c, "arguments", None)
if isinstance(args, dict):
try:
import json as _json
@@ -152,12 +153,15 @@ def _preview_for_message(m) -> str:
args_str = str(args)
else:
args_str = str(args or "{}")
- parts.append(f"tool_call {c.name} {args_str}")
- elif isinstance(c, FunctionResultContent):
- parts.append(f"tool_result[{c.call_id}]: {str(c.result)[:40]}")
- elif isinstance(c, TextContent):
- if c.text:
- parts.append(c.text)
+ parts.append(f"tool_call {getattr(c, 'name', '?')} {args_str}")
+ elif content_type == "function_result":
+ call_id = getattr(c, "call_id", "?")
+ result = getattr(c, "result", None)
+ parts.append(f"tool_result[{call_id}]: {str(result)[:40]}")
+ elif content_type == "text":
+ text = getattr(c, "text", None)
+ if text:
+ parts.append(text)
else:
typename = getattr(c, "type", c.__class__.__name__)
parts.append(f"<{typename}>")
diff --git a/python/packages/ag-ui/tests/test_ag_ui_client.py b/python/packages/ag-ui/tests/test_ag_ui_client.py
index bc1cc6d711..b05810972e 100644
--- a/python/packages/ag-ui/tests/test_ag_ui_client.py
+++ b/python/packages/ag-ui/tests/test_ag_ui_client.py
@@ -11,14 +11,13 @@
ChatOptions,
ChatResponse,
ChatResponseUpdate,
- FunctionCallContent,
+ Content,
Role,
- TextContent,
ai_function,
)
from pytest import MonkeyPatch
-from agent_framework_ag_ui._client import AGUIChatClient, ServerFunctionCallContent
+from agent_framework_ag_ui._client import AGUIChatClient
from agent_framework_ag_ui._http_service import AGUIHttpService
@@ -96,13 +95,11 @@ async def test_extract_state_from_messages_with_state(self) -> None:
state_json = json.dumps(state_data)
state_b64 = base64.b64encode(state_json.encode("utf-8")).decode("utf-8")
- from agent_framework import DataContent
-
messages = [
ChatMessage(role="user", text="Hello"),
ChatMessage(
role="user",
- contents=[DataContent(uri=f"data:application/json;base64,{state_b64}")],
+ contents=[Content.from_uri(uri=f"data:application/json;base64,{state_b64}")],
),
]
@@ -121,12 +118,10 @@ async def test_extract_state_invalid_json(self) -> None:
invalid_json = "not valid json"
state_b64 = base64.b64encode(invalid_json.encode("utf-8")).decode("utf-8")
- from agent_framework import DataContent
-
messages = [
ChatMessage(
role="user",
- contents=[DataContent(uri=f"data:application/json;base64,{state_b64}")],
+ contents=[Content.from_uri(uri=f"data:application/json;base64,{state_b64}")],
),
]
@@ -200,8 +195,8 @@ async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str
first_content = updates[1].contents[0]
second_content = updates[2].contents[0]
- assert isinstance(first_content, TextContent)
- assert isinstance(second_content, TextContent)
+ assert first_content.type == "text"
+ assert second_content.type == "text"
assert first_content.text == "Hello"
assert second_content.text == " world"
@@ -294,13 +289,12 @@ async def mock_post_run(*args: object, **kwargs: Any) -> AsyncGenerator[dict[str
updates.append(update)
function_calls = [
- content for update in updates for content in update.contents if isinstance(content, FunctionCallContent)
+ content for update in updates for content in update.contents if content.type == "function_call"
]
assert function_calls
assert function_calls[0].name == "get_time_zone"
- assert not any(
- isinstance(content, ServerFunctionCallContent) for update in updates for content in update.contents
- )
+
+ assert not any(content.type == "server_function_call" for update in updates for content in update.contents)
async def test_server_tool_calls_not_executed_locally(self, monkeypatch: MonkeyPatch) -> None:
"""Server tools should not trigger local function invocation even when client tools exist."""
@@ -343,13 +337,11 @@ async def test_state_transmission(self, monkeypatch: MonkeyPatch) -> None:
state_json = json.dumps(state_data)
state_b64 = base64.b64encode(state_json.encode("utf-8")).decode("utf-8")
- from agent_framework import DataContent
-
messages = [
ChatMessage(role="user", text="Hello"),
ChatMessage(
role="user",
- contents=[DataContent(uri=f"data:application/json;base64,{state_b64}")],
+ contents=[Content.from_uri(uri=f"data:application/json;base64,{state_b64}")],
),
]
diff --git a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py
index f919c00a56..f8f5c1db8a 100644
--- a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py
+++ b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py
@@ -9,7 +9,7 @@
from typing import Any
import pytest
-from agent_framework import ChatAgent, ChatMessage, ChatOptions, ChatResponseUpdate, TextContent
+from agent_framework import ChatAgent, ChatMessage, ChatOptions, ChatResponseUpdate, Content
from pydantic import BaseModel
sys.path.insert(0, str(Path(__file__).parent))
@@ -23,7 +23,7 @@ async def test_agent_initialization_basic():
async def stream_fn(
messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
- yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
+ yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")])
agent = ChatAgent[ChatOptions](
chat_client=StreamingChatClientStub(stream_fn),
@@ -45,7 +45,7 @@ async def test_agent_initialization_with_state_schema():
async def stream_fn(
messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
- yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
+ yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
state_schema: dict[str, dict[str, Any]] = {"document": {"type": "string"}}
@@ -61,7 +61,7 @@ async def test_agent_initialization_with_predict_state_config():
async def stream_fn(
messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
- yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
+ yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
predict_config = {"document": {"tool": "write_doc", "tool_argument": "content"}}
@@ -77,7 +77,7 @@ async def test_agent_initialization_with_pydantic_state_schema():
async def stream_fn(
messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
- yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
+ yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")])
class MyState(BaseModel):
document: str
@@ -100,7 +100,7 @@ async def test_run_started_event_emission():
async def stream_fn(
messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
- yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
+ yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(agent=agent)
@@ -124,7 +124,7 @@ async def test_predict_state_custom_event_emission():
async def stream_fn(
messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
- yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
+ yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
predict_config = {
@@ -156,7 +156,7 @@ async def test_initial_state_snapshot_with_schema():
async def stream_fn(
messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
- yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
+ yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
state_schema = {"document": {"type": "string"}}
@@ -186,7 +186,7 @@ async def test_state_initialization_object_type():
async def stream_fn(
messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
- yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
+ yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
state_schema: dict[str, dict[str, Any]] = {"recipe": {"type": "object", "properties": {}}}
@@ -213,7 +213,7 @@ async def test_state_initialization_array_type():
async def stream_fn(
messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
- yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
+ yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
state_schema: dict[str, dict[str, Any]] = {"steps": {"type": "array", "items": {}}}
@@ -240,7 +240,7 @@ async def test_run_finished_event_emission():
async def stream_fn(
messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
- yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
+ yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(agent=agent)
@@ -262,7 +262,7 @@ async def test_tool_result_confirm_changes_accepted():
async def stream_fn(
messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
- yield ChatResponseUpdate(contents=[TextContent(text="Document updated")])
+ yield ChatResponseUpdate(contents=[Content.from_text(text="Document updated")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(
@@ -309,7 +309,7 @@ async def test_tool_result_confirm_changes_rejected():
async def stream_fn(
messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
- yield ChatResponseUpdate(contents=[TextContent(text="OK")])
+ yield ChatResponseUpdate(contents=[Content.from_text(text="OK")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(agent=agent)
@@ -343,7 +343,7 @@ async def test_tool_result_function_approval_accepted():
async def stream_fn(
messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
- yield ChatResponseUpdate(contents=[TextContent(text="OK")])
+ yield ChatResponseUpdate(contents=[Content.from_text(text="OK")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(agent=agent)
@@ -389,7 +389,7 @@ async def test_tool_result_function_approval_rejected():
async def stream_fn(
messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
- yield ChatResponseUpdate(contents=[TextContent(text="OK")])
+ yield ChatResponseUpdate(contents=[Content.from_text(text="OK")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(agent=agent)
@@ -431,7 +431,7 @@ async def stream_fn(
metadata = options.get("metadata")
if metadata:
thread_metadata.update(metadata)
- yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
+ yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(agent=agent)
@@ -462,7 +462,7 @@ async def stream_fn(
metadata = options.get("metadata")
if metadata:
thread_metadata.update(metadata)
- yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
+ yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(
@@ -492,7 +492,7 @@ async def test_no_messages_provided():
async def stream_fn(
messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
- yield ChatResponseUpdate(contents=[TextContent(text="Hello")])
+ yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(agent=agent)
@@ -516,7 +516,7 @@ async def test_message_end_event_emission():
async def stream_fn(
messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
- yield ChatResponseUpdate(contents=[TextContent(text="Hello world")])
+ yield ChatResponseUpdate(contents=[Content.from_text(text="Hello world")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(agent=agent)
@@ -602,7 +602,7 @@ async def test_suppressed_summary_with_document_state():
async def stream_fn(
messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
- yield ChatResponseUpdate(contents=[TextContent(text="Response")])
+ yield ChatResponseUpdate(contents=[Content.from_text(text="Response")])
agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(
@@ -650,7 +650,7 @@ async def stream_fn(
thread = kwargs.get("thread")
request_service_thread_id = thread.service_thread_id if thread else None
yield ChatResponseUpdate(
- contents=[TextContent(text="Response")], response_id="resp_67890", conversation_id="conv_12345"
+ contents=[Content.from_text(text="Response")], response_id="resp_67890", conversation_id="conv_12345"
)
agent = ChatAgent(chat_client=StreamingChatClientStub(stream_fn))
@@ -677,7 +677,7 @@ async def stream_fn(
thread = kwargs.get("thread")
request_service_thread_id = thread.service_thread_id if thread else None
yield ChatResponseUpdate(
- contents=[TextContent(text="Response")], response_id="resp_67890", conversation_id="conv_12345"
+ contents=[Content.from_text(text="Response")], response_id="resp_67890", conversation_id="conv_12345"
)
agent = ChatAgent(chat_client=StreamingChatClientStub(stream_fn))
@@ -693,7 +693,7 @@ async def stream_fn(
async def test_function_approval_mode_executes_tool():
"""Test that function approval with approval_mode='always_require' sends the correct messages."""
- from agent_framework import FunctionResultContent, ai_function
+ from agent_framework import ai_function
from agent_framework.ag_ui import AgentFrameworkAgent
messages_received: list[Any] = []
@@ -712,7 +712,7 @@ async def stream_fn(
# Capture the messages received by the chat client
messages_received.clear()
messages_received.extend(messages)
- yield ChatResponseUpdate(contents=[TextContent(text="Processing completed")])
+ yield ChatResponseUpdate(contents=[Content.from_text(text="Processing completed")])
agent = ChatAgent(
chat_client=StreamingChatClientStub(stream_fn),
@@ -770,7 +770,7 @@ async def stream_fn(
tool_result_found = False
for msg in messages_received:
for content in msg.contents:
- if isinstance(content, FunctionResultContent):
+ if content.type == "function_result":
tool_result_found = True
assert content.call_id == "call_get_datetime_123"
assert content.result == "2025/12/01 12:00:00"
@@ -784,7 +784,7 @@ async def stream_fn(
async def test_function_approval_mode_rejection():
"""Test that function approval rejection creates a rejection response."""
- from agent_framework import FunctionResultContent, ai_function
+ from agent_framework import ai_function
from agent_framework.ag_ui import AgentFrameworkAgent
messages_received: list[Any] = []
@@ -803,7 +803,7 @@ async def stream_fn(
# Capture the messages received by the chat client
messages_received.clear()
messages_received.extend(messages)
- yield ChatResponseUpdate(contents=[TextContent(text="Operation cancelled")])
+ yield ChatResponseUpdate(contents=[Content.from_text(text="Operation cancelled")])
agent = ChatAgent(
name="test_agent",
@@ -855,7 +855,7 @@ async def stream_fn(
rejection_found = False
for msg in messages_received:
for content in msg.contents:
- if isinstance(content, FunctionResultContent):
+ if content.type == "function_result":
rejection_found = True
assert content.call_id == "call_delete_123"
assert content.result == "Error: Tool call invocation was rejected by user."
diff --git a/python/packages/ag-ui/tests/test_backend_tool_rendering.py b/python/packages/ag-ui/tests/test_backend_tool_rendering.py
index 446da23ff2..594d127532 100644
--- a/python/packages/ag-ui/tests/test_backend_tool_rendering.py
+++ b/python/packages/ag-ui/tests/test_backend_tool_rendering.py
@@ -12,7 +12,7 @@
ToolCallResultEvent,
ToolCallStartEvent,
)
-from agent_framework import AgentResponseUpdate, FunctionCallContent, FunctionResultContent, TextContent
+from agent_framework import AgentResponseUpdate, Content
from agent_framework_ag_ui._events import AgentFrameworkEventBridge
@@ -22,7 +22,7 @@ async def test_tool_call_flow():
bridge = AgentFrameworkEventBridge(run_id="test-run", thread_id="test-thread")
# Step 1: Tool call starts
- tool_call = FunctionCallContent(
+ tool_call = Content.from_function_call(
call_id="weather-123",
name="get_weather",
arguments={"location": "Seattle"},
@@ -44,7 +44,7 @@ async def test_tool_call_flow():
assert "Seattle" in args_event.delta
# Step 2: Tool result comes back
- tool_result = FunctionResultContent(
+ tool_result = Content.from_function_result(
call_id="weather-123",
result="Weather in Seattle: Rainy, 52°F",
)
@@ -71,8 +71,8 @@ async def test_text_with_tool_call():
bridge = AgentFrameworkEventBridge(run_id="test-run", thread_id="test-thread")
# Agent says something then calls a tool
- text_content = TextContent(text="Let me check the weather for you.")
- tool_call = FunctionCallContent(
+ text_content = Content.from_text(text="Let me check the weather for you.")
+ tool_call = Content.from_function_call(
call_id="weather-456",
name="get_forecast",
arguments={"location": "San Francisco", "days": 3},
@@ -102,9 +102,9 @@ async def test_multiple_tool_results():
# Multiple tool results
results = [
- FunctionResultContent(call_id="tool-1", result="Result 1"),
- FunctionResultContent(call_id="tool-2", result="Result 2"),
- FunctionResultContent(call_id="tool-3", result="Result 3"),
+ Content.from_function_result(call_id="tool-1", result="Result 1"),
+ Content.from_function_result(call_id="tool-2", result="Result 2"),
+ Content.from_function_result(call_id="tool-3", result="Result 3"),
]
update = AgentResponseUpdate(contents=results)
diff --git a/python/packages/ag-ui/tests/test_document_writer_flow.py b/python/packages/ag-ui/tests/test_document_writer_flow.py
index 2e5cec9f95..7e154682b4 100644
--- a/python/packages/ag-ui/tests/test_document_writer_flow.py
+++ b/python/packages/ag-ui/tests/test_document_writer_flow.py
@@ -3,7 +3,7 @@
"""Tests for document writer predictive state flow with confirm_changes."""
from ag_ui.core import EventType, StateDeltaEvent, ToolCallArgsEvent, ToolCallEndEvent, ToolCallStartEvent
-from agent_framework import AgentResponseUpdate, FunctionCallContent, FunctionResultContent, TextContent
+from agent_framework import AgentResponseUpdate, Content
from agent_framework_ag_ui._events import AgentFrameworkEventBridge
@@ -21,7 +21,7 @@ async def test_streaming_document_with_state_deltas():
)
# Simulate streaming tool call - first chunk with name
- tool_call_start = FunctionCallContent(
+ tool_call_start = Content.from_function_call(
call_id="call_123",
name="write_document_local",
arguments='{"document":"Once',
@@ -34,7 +34,9 @@ async def test_streaming_document_with_state_deltas():
assert any(e.type == EventType.TOOL_CALL_ARGS for e in events1)
# Second chunk - incomplete JSON, should try partial extraction
- tool_call_chunk2 = FunctionCallContent(call_id="call_123", name="write_document_local", arguments=" upon a time")
+ tool_call_chunk2 = Content.from_function_call(
+ call_id="call_123", name="write_document_local", arguments=" upon a time"
+ )
update2 = AgentResponseUpdate(contents=[tool_call_chunk2])
events2 = await bridge.from_agent_run_update(update2)
@@ -71,7 +73,7 @@ async def test_confirm_changes_emission():
bridge.pending_state_updates = {"document": "A short story"}
# Tool result
- tool_result = FunctionResultContent(
+ tool_result = Content.from_function_result(
call_id="call_123",
result="Document written.",
)
@@ -115,7 +117,7 @@ async def test_text_suppression_before_confirm():
bridge.should_stop_after_confirm = True
# Text content that should be suppressed
- text = TextContent(text="I have written a story about pirates.")
+ text = Content.from_text(text="I have written a story about pirates.")
update = AgentResponseUpdate(contents=[text])
events = await bridge.from_agent_run_update(update)
@@ -146,7 +148,7 @@ async def test_no_confirm_for_non_predictive_tools():
# Different tool (not in predict_state_config)
bridge.current_tool_call_name = "get_weather"
- tool_result = FunctionResultContent(
+ tool_result = Content.from_function_result(
call_id="call_456",
result="Sunny, 72°F",
)
@@ -175,7 +177,7 @@ async def test_state_delta_deduplication():
)
# First tool call with document
- tool_call1 = FunctionCallContent(
+ tool_call1 = Content.from_function_call(
call_id="call_1",
name="write_document_local",
arguments='{"document":"Same text"}',
@@ -189,7 +191,7 @@ async def test_state_delta_deduplication():
# Second tool call with SAME document (shouldn't emit new delta)
bridge.current_tool_call_name = "write_document_local"
- tool_call2 = FunctionCallContent(
+ tool_call2 = Content.from_function_call(
call_id="call_2",
name="write_document_local",
arguments='{"document":"Same text"}', # Identical content
@@ -216,7 +218,7 @@ async def test_predict_state_config_multiple_fields():
)
# Tool call with both fields
- tool_call = FunctionCallContent(
+ tool_call = Content.from_function_call(
call_id="call_999",
name="create_post",
arguments='{"title":"My Post","body":"Post content"}',
diff --git a/python/packages/ag-ui/tests/test_endpoint.py b/python/packages/ag-ui/tests/test_endpoint.py
index 59cb884c5c..e09bb32fce 100644
--- a/python/packages/ag-ui/tests/test_endpoint.py
+++ b/python/packages/ag-ui/tests/test_endpoint.py
@@ -6,7 +6,7 @@
import sys
from pathlib import Path
-from agent_framework import ChatAgent, ChatResponseUpdate, TextContent
+from agent_framework import ChatAgent, ChatResponseUpdate, Content
from fastapi import FastAPI, Header, HTTPException
from fastapi.params import Depends
from fastapi.testclient import TestClient
@@ -20,7 +20,7 @@
def build_chat_client(response_text: str = "Test response") -> StreamingChatClientStub:
"""Create a typed chat client stub for endpoint tests."""
- updates = [ChatResponseUpdate(contents=[TextContent(text=response_text)])]
+ updates = [ChatResponseUpdate(contents=[Content.from_text(text=response_text)])]
return StreamingChatClientStub(stream_from_updates(updates))
diff --git a/python/packages/ag-ui/tests/test_events_comprehensive.py b/python/packages/ag-ui/tests/test_events_comprehensive.py
index 295ba00372..75e923123f 100644
--- a/python/packages/ag-ui/tests/test_events_comprehensive.py
+++ b/python/packages/ag-ui/tests/test_events_comprehensive.py
@@ -6,10 +6,7 @@
from agent_framework import (
AgentResponseUpdate,
- FunctionApprovalRequestContent,
- FunctionCallContent,
- FunctionResultContent,
- TextContent,
+ Content,
)
@@ -19,7 +16,7 @@ async def test_basic_text_message_conversion():
bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread")
- update = AgentResponseUpdate(contents=[TextContent(text="Hello")])
+ update = AgentResponseUpdate(contents=[Content.from_text(text="Hello")])
events = await bridge.from_agent_run_update(update)
assert len(events) == 2
@@ -35,8 +32,8 @@ async def test_text_message_streaming():
bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread")
- update1 = AgentResponseUpdate(contents=[TextContent(text="Hello ")])
- update2 = AgentResponseUpdate(contents=[TextContent(text="world")])
+ update1 = AgentResponseUpdate(contents=[Content.from_text(text="Hello ")])
+ update2 = AgentResponseUpdate(contents=[Content.from_text(text="world")])
events1 = await bridge.from_agent_run_update(update1)
events2 = await bridge.from_agent_run_update(update2)
@@ -61,7 +58,7 @@ async def test_skip_text_content_for_structured_outputs():
bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread", skip_text_content=True)
- update = AgentResponseUpdate(contents=[TextContent(text='{"result": "data"}')])
+ update = AgentResponseUpdate(contents=[Content.from_text(text='{"result": "data"}')])
events = await bridge.from_agent_run_update(update)
# No events should be emitted
@@ -74,9 +71,9 @@ async def test_skip_text_content_for_empty_text():
bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread")
- update1 = AgentResponseUpdate(contents=[TextContent(text="Hello ")])
- update2 = AgentResponseUpdate(contents=[TextContent(text="")]) # Empty chunk
- update3 = AgentResponseUpdate(contents=[TextContent(text="world")])
+ update1 = AgentResponseUpdate(contents=[Content.from_text(text="Hello ")])
+ update2 = AgentResponseUpdate(contents=[Content.from_text(text="")]) # Empty chunk
+ update3 = AgentResponseUpdate(contents=[Content.from_text(text="world")])
events1 = await bridge.from_agent_run_update(update1)
events2 = await bridge.from_agent_run_update(update2)
@@ -105,7 +102,7 @@ async def test_tool_call_with_name():
bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread")
- update = AgentResponseUpdate(contents=[FunctionCallContent(name="search_web", call_id="call_123")])
+ update = AgentResponseUpdate(contents=[Content.from_function_call(name="search_web", call_id="call_123")])
events = await bridge.from_agent_run_update(update)
assert len(events) == 1
@@ -121,15 +118,17 @@ async def test_tool_call_streaming_args():
bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread")
# First chunk: name only
- update1 = AgentResponseUpdate(contents=[FunctionCallContent(name="search_web", call_id="call_123")])
+ update1 = AgentResponseUpdate(contents=[Content.from_function_call(name="search_web", call_id="call_123")])
events1 = await bridge.from_agent_run_update(update1)
# Second chunk: arguments chunk 1 (name can be empty string for continuation)
- update2 = AgentResponseUpdate(contents=[FunctionCallContent(name="", call_id="call_123", arguments='{"query": "')])
+ update2 = AgentResponseUpdate(
+ contents=[Content.from_function_call(name="", call_id="call_123", arguments='{"query": "')]
+ )
events2 = await bridge.from_agent_run_update(update2)
# Third chunk: arguments chunk 2
- update3 = AgentResponseUpdate(contents=[FunctionCallContent(name="", call_id="call_123", arguments='AI"}')])
+ update3 = AgentResponseUpdate(contents=[Content.from_function_call(name="", call_id="call_123", arguments='AI"}')])
events3 = await bridge.from_agent_run_update(update3)
# First update: ToolCallStartEvent
@@ -167,9 +166,11 @@ async def test_streaming_tool_call_no_duplicate_start_events():
bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread")
# Simulate streaming tool call: first chunk has name, subsequent chunks have name=""
- update1 = AgentResponseUpdate(contents=[FunctionCallContent(name="get_weather", call_id="call_789")])
- update2 = AgentResponseUpdate(contents=[FunctionCallContent(name="", call_id="call_789", arguments='{"loc":')])
- update3 = AgentResponseUpdate(contents=[FunctionCallContent(name="", call_id="call_789", arguments='"SF"}')])
+ update1 = AgentResponseUpdate(contents=[Content.from_function_call(name="get_weather", call_id="call_789")])
+ update2 = AgentResponseUpdate(
+ contents=[Content.from_function_call(name="", call_id="call_789", arguments='{"loc":')]
+ )
+ update3 = AgentResponseUpdate(contents=[Content.from_function_call(name="", call_id="call_789", arguments='"SF"}')])
events1 = await bridge.from_agent_run_update(update1)
events2 = await bridge.from_agent_run_update(update2)
@@ -193,7 +194,7 @@ async def test_tool_result_with_dict():
bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread")
result_data = {"status": "success", "count": 42}
- update = AgentResponseUpdate(contents=[FunctionResultContent(call_id="call_123", result=result_data)])
+ update = AgentResponseUpdate(contents=[Content.from_function_result(call_id="call_123", result=result_data)])
events = await bridge.from_agent_run_update(update)
# Should emit ToolCallEndEvent + ToolCallResultEvent
@@ -214,7 +215,7 @@ async def test_tool_result_with_string():
bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread")
- update = AgentResponseUpdate(contents=[FunctionResultContent(call_id="call_123", result="Search complete")])
+ update = AgentResponseUpdate(contents=[Content.from_function_result(call_id="call_123", result="Search complete")])
events = await bridge.from_agent_run_update(update)
assert len(events) == 2
@@ -229,7 +230,7 @@ async def test_tool_result_with_none():
bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread")
- update = AgentResponseUpdate(contents=[FunctionResultContent(call_id="call_123", result=None)])
+ update = AgentResponseUpdate(contents=[Content.from_function_result(call_id="call_123", result=None)])
events = await bridge.from_agent_run_update(update)
assert len(events) == 2
@@ -247,8 +248,8 @@ async def test_multiple_tool_results_in_sequence():
update = AgentResponseUpdate(
contents=[
- FunctionResultContent(call_id="call_1", result="Result 1"),
- FunctionResultContent(call_id="call_2", result="Result 2"),
+ Content.from_function_result(call_id="call_1", result="Result 1"),
+ Content.from_function_result(call_id="call_2", result="Result 2"),
]
)
events = await bridge.from_agent_run_update(update)
@@ -272,12 +273,12 @@ async def test_function_approval_request_basic():
require_confirmation=False,
)
- func_call = FunctionCallContent(
+ func_call = Content.from_function_call(
call_id="call_123",
name="send_email",
arguments={"to": "user@example.com", "subject": "Test"},
)
- approval = FunctionApprovalRequestContent(
+ approval = Content.from_function_approval_request(
id="approval_001",
function_call=func_call,
)
@@ -312,8 +313,8 @@ async def test_empty_predict_state_config():
# Tool call with arguments
update = AgentResponseUpdate(
contents=[
- FunctionCallContent(name="write_doc", call_id="call_1", arguments='{"content": "test"}'),
- FunctionResultContent(call_id="call_1", result="Done"),
+ Content.from_function_call(name="write_doc", call_id="call_1", arguments='{"content": "test"}'),
+ Content.from_function_result(call_id="call_1", result="Done"),
]
)
events = await bridge.from_agent_run_update(update)
@@ -347,8 +348,8 @@ async def test_tool_not_in_predict_state_config():
# Different tool name
update = AgentResponseUpdate(
contents=[
- FunctionCallContent(name="search_web", call_id="call_1", arguments='{"query": "AI"}'),
- FunctionResultContent(call_id="call_1", result="Results"),
+ Content.from_function_call(name="search_web", call_id="call_1", arguments='{"query": "AI"}'),
+ Content.from_function_result(call_id="call_1", result="Results"),
]
)
events = await bridge.from_agent_run_update(update)
@@ -376,8 +377,8 @@ async def test_state_management_tracking():
# Streaming tool call
update1 = AgentResponseUpdate(
contents=[
- FunctionCallContent(name="write_doc", call_id="call_1"),
- FunctionCallContent(name="", call_id="call_1", arguments='{"content": "Hello"}'),
+ Content.from_function_call(name="write_doc", call_id="call_1"),
+ Content.from_function_call(name="", call_id="call_1", arguments='{"content": "Hello"}'),
]
)
await bridge.from_agent_run_update(update1)
@@ -387,7 +388,7 @@ async def test_state_management_tracking():
assert bridge.pending_state_updates["document"] == "Hello"
# Tool result should update current_state
- update2 = AgentResponseUpdate(contents=[FunctionResultContent(call_id="call_1", result="Done")])
+ update2 = AgentResponseUpdate(contents=[Content.from_function_result(call_id="call_1", result="Done")])
await bridge.from_agent_run_update(update2)
# current_state should be updated
@@ -413,12 +414,12 @@ async def test_wildcard_tool_argument():
# Complete tool call with dict arguments
update = AgentResponseUpdate(
contents=[
- FunctionCallContent(
+ Content.from_function_call(
name="create_recipe",
call_id="call_1",
arguments={"title": "Pasta", "ingredients": ["pasta", "sauce"]},
),
- FunctionResultContent(call_id="call_1", result="Created"),
+ Content.from_function_result(call_id="call_1", result="Created"),
]
)
events = await bridge.from_agent_run_update(update)
@@ -503,14 +504,14 @@ async def test_state_snapshot_after_tool_result():
# Tool call with streaming args
update1 = AgentResponseUpdate(
contents=[
- FunctionCallContent(name="write_doc", call_id="call_1"),
- FunctionCallContent(name="", call_id="call_1", arguments='{"content": "Test"}'),
+ Content.from_function_call(name="write_doc", call_id="call_1"),
+ Content.from_function_call(name="", call_id="call_1", arguments='{"content": "Test"}'),
]
)
await bridge.from_agent_run_update(update1)
# Tool result should trigger StateSnapshotEvent
- update2 = AgentResponseUpdate(contents=[FunctionResultContent(call_id="call_1", result="Done")])
+ update2 = AgentResponseUpdate(contents=[Content.from_function_result(call_id="call_1", result="Done")])
events = await bridge.from_agent_run_update(update2)
# Should have: ToolCallEnd, ToolCallResult, StateSnapshot, ToolCallStart (confirm_changes), ToolCallArgs, ToolCallEnd
@@ -526,12 +527,12 @@ async def test_message_id_persistence_across_chunks():
bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread")
# First chunk
- update1 = AgentResponseUpdate(contents=[TextContent(text="Hello ")])
+ update1 = AgentResponseUpdate(contents=[Content.from_text(text="Hello ")])
events1 = await bridge.from_agent_run_update(update1)
message_id = events1[0].message_id
# Second chunk
- update2 = AgentResponseUpdate(contents=[TextContent(text="world")])
+ update2 = AgentResponseUpdate(contents=[Content.from_text(text="world")])
events2 = await bridge.from_agent_run_update(update2)
# Should use same message_id
@@ -546,14 +547,16 @@ async def test_tool_call_id_tracking():
bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread")
# First chunk with name
- update1 = AgentResponseUpdate(contents=[FunctionCallContent(name="search", call_id="call_1")])
+ update1 = AgentResponseUpdate(contents=[Content.from_function_call(name="search", call_id="call_1")])
await bridge.from_agent_run_update(update1)
assert bridge.current_tool_call_id == "call_1"
assert bridge.current_tool_call_name == "search"
# Second chunk with args but no name
- update2 = AgentResponseUpdate(contents=[FunctionCallContent(name="", call_id="call_1", arguments='{"q":"AI"}')])
+ update2 = AgentResponseUpdate(
+ contents=[Content.from_function_call(name="", call_id="call_1", arguments='{"q":"AI"}')]
+ )
events2 = await bridge.from_agent_run_update(update2)
# Should still track same tool call
@@ -576,8 +579,8 @@ async def test_tool_name_reset_after_result():
# Tool call
update1 = AgentResponseUpdate(
contents=[
- FunctionCallContent(name="write_doc", call_id="call_1"),
- FunctionCallContent(name="", call_id="call_1", arguments='{"content": "Test"}'),
+ Content.from_function_call(name="write_doc", call_id="call_1"),
+ Content.from_function_call(name="", call_id="call_1", arguments='{"content": "Test"}'),
]
)
await bridge.from_agent_run_update(update1)
@@ -585,7 +588,7 @@ async def test_tool_name_reset_after_result():
assert bridge.current_tool_call_name == "write_doc"
# Tool result with predictive state (should trigger confirm_changes and reset)
- update2 = AgentResponseUpdate(contents=[FunctionResultContent(call_id="call_1", result="Done")])
+ update2 = AgentResponseUpdate(contents=[Content.from_function_result(call_id="call_1", result="Done")])
await bridge.from_agent_run_update(update2)
# Tool name should be reset
@@ -604,9 +607,9 @@ async def test_function_approval_with_wildcard_argument():
},
)
- approval_content = FunctionApprovalRequestContent(
+ approval_content = Content.from_function_approval_request(
id="approval_1",
- function_call=FunctionCallContent(
+ function_call=Content.from_function_call(
name="submit", call_id="call_1", arguments='{"key1": "value1", "key2": "value2"}'
),
)
@@ -632,9 +635,11 @@ async def test_function_approval_missing_argument():
},
)
- approval_content = FunctionApprovalRequestContent(
+ approval_content = Content.from_function_approval_request(
id="approval_1",
- function_call=FunctionCallContent(name="process", call_id="call_1", arguments='{"other_field": "value"}'),
+ function_call=Content.from_function_call(
+ name="process", call_id="call_1", arguments='{"other_field": "value"}'
+ ),
)
update = AgentResponseUpdate(contents=[approval_content])
@@ -654,8 +659,8 @@ async def test_empty_predict_state_config_no_deltas():
# Tool call with arguments
update = AgentResponseUpdate(
contents=[
- FunctionCallContent(name="search", call_id="call_1"),
- FunctionCallContent(name="", call_id="call_1", arguments='{"query": "test"}'),
+ Content.from_function_call(name="search", call_id="call_1"),
+ Content.from_function_call(name="", call_id="call_1", arguments='{"query": "test"}'),
]
)
events = await bridge.from_agent_run_update(update)
@@ -678,8 +683,8 @@ async def test_tool_with_no_matching_config():
# Tool call for different tool
update = AgentResponseUpdate(
contents=[
- FunctionCallContent(name="search_web", call_id="call_1"),
- FunctionCallContent(name="", call_id="call_1", arguments='{"query": "test"}'),
+ Content.from_function_call(name="search_web", call_id="call_1"),
+ Content.from_function_call(name="", call_id="call_1", arguments='{"query": "test"}'),
]
)
events = await bridge.from_agent_run_update(update)
@@ -696,7 +701,7 @@ async def test_tool_call_without_name_or_id():
bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread")
# This should not crash but log an error
- update = AgentResponseUpdate(contents=[FunctionCallContent(name="", call_id="", arguments='{"arg": "val"}')])
+ update = AgentResponseUpdate(contents=[Content.from_function_call(name="", call_id="", arguments='{"arg": "val"}')])
events = await bridge.from_agent_run_update(update)
# Should emit ToolCallArgsEvent with generated ID
@@ -717,7 +722,7 @@ async def test_state_delta_count_logging():
for i in range(15):
update = AgentResponseUpdate(
contents=[
- FunctionCallContent(name="", call_id="call_1", arguments=f'{{"text": "Content variation {i}"}}'),
+ Content.from_function_call(name="", call_id="call_1", arguments=f'{{"text": "Content variation {i}"}}'),
]
)
# Set the tool name to match config
@@ -737,7 +742,7 @@ async def test_tool_result_with_empty_list():
bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread")
- update = AgentResponseUpdate(contents=[FunctionResultContent(call_id="call_123", result=[])])
+ update = AgentResponseUpdate(contents=[Content.from_function_result(call_id="call_123", result=[])])
events = await bridge.from_agent_run_update(update)
assert len(events) == 2
@@ -760,7 +765,7 @@ class MockTextContent:
bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread")
update = AgentResponseUpdate(
- contents=[FunctionResultContent(call_id="call_123", result=[MockTextContent("Hello from MCP tool!")])]
+ contents=[Content.from_function_result(call_id="call_123", result=[MockTextContent("Hello from MCP tool!")])]
)
events = await bridge.from_agent_run_update(update)
@@ -785,7 +790,7 @@ class MockTextContent:
update = AgentResponseUpdate(
contents=[
- FunctionResultContent(
+ Content.from_function_result(
call_id="call_123",
result=[MockTextContent("First result"), MockTextContent("Second result")],
)
@@ -812,7 +817,7 @@ class MockModel(BaseModel):
bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread")
update = AgentResponseUpdate(
- contents=[FunctionResultContent(call_id="call_123", result=[MockModel(value=1), MockModel(value=2)])]
+ contents=[Content.from_function_result(call_id="call_123", result=[MockModel(value=1), MockModel(value=2)])]
)
events = await bridge.from_agent_run_update(update)
diff --git a/python/packages/ag-ui/tests/test_human_in_the_loop.py b/python/packages/ag-ui/tests/test_human_in_the_loop.py
index 00e64472b6..b643465e36 100644
--- a/python/packages/ag-ui/tests/test_human_in_the_loop.py
+++ b/python/packages/ag-ui/tests/test_human_in_the_loop.py
@@ -2,7 +2,7 @@
"""Tests for human in the loop (function approval requests)."""
-from agent_framework import AgentResponseUpdate, FunctionApprovalRequestContent, FunctionCallContent
+from agent_framework import AgentResponseUpdate, Content
from agent_framework_ag_ui._events import AgentFrameworkEventBridge
@@ -17,12 +17,12 @@ async def test_function_approval_request_emission():
)
# Create approval request
- func_call = FunctionCallContent(
+ func_call = Content.from_function_call(
call_id="call_123",
name="send_email",
arguments={"to": "user@example.com", "subject": "Test"},
)
- approval_request = FunctionApprovalRequestContent(
+ approval_request = Content.from_function_approval_request(
id="approval_001",
function_call=func_call,
)
@@ -56,12 +56,12 @@ async def test_function_approval_request_with_confirm_changes():
require_confirmation=True,
)
- func_call = FunctionCallContent(
+ func_call = Content.from_function_call(
call_id="call_456",
name="delete_file",
arguments={"path": "/tmp/test.txt"},
)
- approval_request = FunctionApprovalRequestContent(
+ approval_request = Content.from_function_approval_request(
id="approval_002",
function_call=func_call,
)
@@ -109,22 +109,22 @@ async def test_multiple_approval_requests():
require_confirmation=False,
)
- func_call_1 = FunctionCallContent(
+ func_call_1 = Content.from_function_call(
call_id="call_1",
name="create_event",
arguments={"title": "Meeting"},
)
- approval_1 = FunctionApprovalRequestContent(
+ approval_1 = Content.from_function_approval_request(
id="approval_1",
function_call=func_call_1,
)
- func_call_2 = FunctionCallContent(
+ func_call_2 = Content.from_function_call(
call_id="call_2",
name="book_room",
arguments={"room": "Conference A"},
)
- approval_2 = FunctionApprovalRequestContent(
+ approval_2 = Content.from_function_approval_request(
id="approval_2",
function_call=func_call_2,
)
@@ -164,12 +164,12 @@ async def test_function_approval_request_sets_stop_flag():
assert bridge.should_stop_after_confirm is False
- func_call = FunctionCallContent(
+ func_call = Content.from_function_call(
call_id="call_stop_test",
name="get_datetime",
arguments={},
)
- approval_request = FunctionApprovalRequestContent(
+ approval_request = Content.from_function_approval_request(
id="approval_stop_test",
function_call=func_call,
)
diff --git a/python/packages/ag-ui/tests/test_message_adapters.py b/python/packages/ag-ui/tests/test_message_adapters.py
index 9173314a28..970a4fe76b 100644
--- a/python/packages/ag-ui/tests/test_message_adapters.py
+++ b/python/packages/ag-ui/tests/test_message_adapters.py
@@ -5,7 +5,7 @@
import json
import pytest
-from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent, Role, TextContent
+from agent_framework import ChatMessage, Content, Role
from agent_framework_ag_ui._message_adapters import (
agent_framework_messages_to_agui,
@@ -24,7 +24,7 @@ def sample_agui_message():
@pytest.fixture
def sample_agent_framework_message():
"""Create a sample Agent Framework message."""
- return ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")], message_id="msg-123")
+ return ChatMessage(role=Role.USER, contents=[Content.from_text(text="Hello")], message_id="msg-123")
def test_agui_to_agent_framework_basic(sample_agui_message):
@@ -89,7 +89,7 @@ def test_agui_tool_result_to_agent_framework():
assert message.role == Role.USER
assert len(message.contents) == 1
- assert isinstance(message.contents[0], TextContent)
+ assert message.contents[0].type == "text"
assert message.contents[0].text == '{"accepted": true, "steps": []}'
assert message.additional_properties is not None
@@ -141,7 +141,7 @@ def test_agui_tool_approval_updates_tool_call_arguments():
assert len(messages) == 2
assistant_msg = messages[0]
- func_call = next(content for content in assistant_msg.contents if isinstance(content, FunctionCallContent))
+ func_call = next(content for content in assistant_msg.contents if content.type == "function_call")
assert func_call.arguments == {
"steps": [
{"description": "Boil water", "status": "enabled"},
@@ -157,11 +157,9 @@ def test_agui_tool_approval_updates_tool_call_arguments():
]
}
- from agent_framework import FunctionApprovalResponseContent
-
approval_msg = messages[1]
approval_content = next(
- content for content in approval_msg.contents if isinstance(content, FunctionApprovalResponseContent)
+ content for content in approval_msg.contents if content.type == "function_approval_response"
)
assert approval_content.function_call.parse_arguments() == {
"steps": [
@@ -211,12 +209,9 @@ def test_agui_tool_approval_from_confirm_changes_maps_to_function_call():
]
messages = agui_messages_to_agent_framework(messages_input)
-
- from agent_framework import FunctionApprovalResponseContent
-
approval_msg = messages[1]
approval_content = next(
- content for content in approval_msg.contents if isinstance(content, FunctionApprovalResponseContent)
+ content for content in approval_msg.contents if content.type == "function_approval_response"
)
assert approval_content.function_call.call_id == "call_tool"
@@ -259,12 +254,9 @@ def test_agui_tool_approval_from_confirm_changes_falls_back_to_sibling_call():
]
messages = agui_messages_to_agent_framework(messages_input)
-
- from agent_framework import FunctionApprovalResponseContent
-
approval_msg = messages[1]
approval_content = next(
- content for content in approval_msg.contents if isinstance(content, FunctionApprovalResponseContent)
+ content for content in approval_msg.contents if content.type == "function_approval_response"
)
assert approval_content.function_call.call_id == "call_tool"
@@ -315,12 +307,9 @@ def test_agui_tool_approval_from_generate_task_steps_maps_to_function_call():
]
messages = agui_messages_to_agent_framework(messages_input)
-
- from agent_framework import FunctionApprovalResponseContent
-
approval_msg = messages[1]
approval_content = next(
- content for content in approval_msg.contents if isinstance(content, FunctionApprovalResponseContent)
+ content for content in approval_msg.contents if content.type == "function_approval_response"
)
assert approval_content.function_call.call_id == "call_tool"
@@ -380,15 +369,14 @@ def test_agui_function_approvals():
assert msg.role == Role.USER
assert len(msg.contents) == 2
- from agent_framework import FunctionApprovalResponseContent
-
- assert isinstance(msg.contents[0], FunctionApprovalResponseContent)
+ assert msg.contents[0].type == "function_approval_response"
assert msg.contents[0].approved is True
assert msg.contents[0].id == "approval-1"
assert msg.contents[0].function_call.name == "search"
assert msg.contents[0].function_call.call_id == "call-1"
- assert isinstance(msg.contents[1], FunctionApprovalResponseContent)
+ assert msg.contents[1].type == "function_approval_response"
+ assert msg.contents[1].id == "approval-2"
assert msg.contents[1].approved is False
@@ -406,7 +394,7 @@ def test_agui_non_string_content():
assert len(messages) == 1
assert len(messages[0].contents) == 1
- assert isinstance(messages[0].contents[0], TextContent)
+ assert messages[0].contents[0].type == "text"
assert "nested" in messages[0].contents[0].text
@@ -440,9 +428,9 @@ def test_agui_with_tool_calls_to_agent_framework():
assert msg.role == Role.ASSISTANT
assert msg.message_id == "msg-789"
# First content is text, second is the function call
- assert isinstance(msg.contents[0], TextContent)
+ assert msg.contents[0].type == "text"
assert msg.contents[0].text == "Calling tool"
- assert isinstance(msg.contents[1], FunctionCallContent)
+ assert msg.contents[1].type == "function_call"
assert msg.contents[1].call_id == "call-123"
assert msg.contents[1].name == "get_weather"
assert msg.contents[1].arguments == {"location": "Seattle"}
@@ -453,8 +441,8 @@ def test_agent_framework_to_agui_with_tool_calls():
msg = ChatMessage(
role=Role.ASSISTANT,
contents=[
- TextContent(text="Calling tool"),
- FunctionCallContent(call_id="call-123", name="search", arguments={"query": "test"}),
+ Content.from_text(text="Calling tool"),
+ Content.from_function_call(call_id="call-123", name="search", arguments={"query": "test"}),
],
message_id="msg-456",
)
@@ -477,7 +465,7 @@ def test_agent_framework_to_agui_multiple_text_contents():
"""Test concatenating multiple text contents."""
msg = ChatMessage(
role=Role.ASSISTANT,
- contents=[TextContent(text="Part 1 "), TextContent(text="Part 2")],
+ contents=[Content.from_text(text="Part 1 "), Content.from_text(text="Part 2")],
)
messages = agent_framework_messages_to_agui([msg])
@@ -488,7 +476,7 @@ def test_agent_framework_to_agui_multiple_text_contents():
def test_agent_framework_to_agui_no_message_id():
"""Test message without message_id - should auto-generate ID."""
- msg = ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])
+ msg = ChatMessage(role=Role.USER, contents=[Content.from_text(text="Hello")])
messages = agent_framework_messages_to_agui([msg])
@@ -500,7 +488,7 @@ def test_agent_framework_to_agui_no_message_id():
def test_agent_framework_to_agui_system_role():
"""Test system role conversion."""
- msg = ChatMessage(role=Role.SYSTEM, contents=[TextContent(text="System")])
+ msg = ChatMessage(role=Role.SYSTEM, contents=[Content.from_text(text="System")])
messages = agent_framework_messages_to_agui([msg])
@@ -510,7 +498,7 @@ def test_agent_framework_to_agui_system_role():
def test_extract_text_from_contents():
"""Test extracting text from contents list."""
- contents = [TextContent(text="Hello "), TextContent(text="World")]
+ contents = [Content.from_text(text="Hello "), Content.from_text(text="World")]
result = extract_text_from_contents(contents)
@@ -533,7 +521,7 @@ def __init__(self, text: str):
def test_extract_text_from_custom_contents():
"""Test extracting text from custom content objects."""
- contents = [CustomTextContent(text="Custom "), TextContent(text="Mixed")]
+ contents = [CustomTextContent(text="Custom "), Content.from_text(text="Mixed")]
result = extract_text_from_contents(contents)
@@ -547,7 +535,7 @@ def test_agent_framework_to_agui_function_result_dict():
"""Test converting FunctionResultContent with dict result to AG-UI."""
msg = ChatMessage(
role=Role.TOOL,
- contents=[FunctionResultContent(call_id="call-123", result={"key": "value", "count": 42})],
+ contents=[Content.from_function_result(call_id="call-123", result={"key": "value", "count": 42})],
message_id="msg-789",
)
@@ -564,7 +552,7 @@ def test_agent_framework_to_agui_function_result_none():
"""Test converting FunctionResultContent with None result to AG-UI."""
msg = ChatMessage(
role=Role.TOOL,
- contents=[FunctionResultContent(call_id="call-123", result=None)],
+ contents=[Content.from_function_result(call_id="call-123", result=None)],
message_id="msg-789",
)
@@ -580,7 +568,7 @@ def test_agent_framework_to_agui_function_result_string():
"""Test converting FunctionResultContent with string result to AG-UI."""
msg = ChatMessage(
role=Role.TOOL,
- contents=[FunctionResultContent(call_id="call-123", result="plain text result")],
+ contents=[Content.from_function_result(call_id="call-123", result="plain text result")],
message_id="msg-789",
)
@@ -595,7 +583,7 @@ def test_agent_framework_to_agui_function_result_empty_list():
"""Test converting FunctionResultContent with empty list result to AG-UI."""
msg = ChatMessage(
role=Role.TOOL,
- contents=[FunctionResultContent(call_id="call-123", result=[])],
+ contents=[Content.from_function_result(call_id="call-123", result=[])],
message_id="msg-789",
)
@@ -617,7 +605,7 @@ class MockTextContent:
msg = ChatMessage(
role=Role.TOOL,
- contents=[FunctionResultContent(call_id="call-123", result=[MockTextContent("Hello from MCP!")])],
+ contents=[Content.from_function_result(call_id="call-123", result=[MockTextContent("Hello from MCP!")])],
message_id="msg-789",
)
@@ -640,7 +628,7 @@ class MockTextContent:
msg = ChatMessage(
role=Role.TOOL,
contents=[
- FunctionResultContent(
+ Content.from_function_result(
call_id="call-123",
result=[MockTextContent("First result"), MockTextContent("Second result")],
)
diff --git a/python/packages/ag-ui/tests/test_message_hygiene.py b/python/packages/ag-ui/tests/test_message_hygiene.py
index 380ff438bd..ecc01de3cb 100644
--- a/python/packages/ag-ui/tests/test_message_hygiene.py
+++ b/python/packages/ag-ui/tests/test_message_hygiene.py
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.
-from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent, TextContent
+from agent_framework import ChatMessage, Content
from agent_framework_ag_ui._message_adapters import _deduplicate_messages, _sanitize_tool_history
@@ -10,7 +10,7 @@ def test_sanitize_tool_history_injects_confirm_changes_result() -> None:
ChatMessage(
role="assistant",
contents=[
- FunctionCallContent(
+ Content.from_function_call(
name="confirm_changes",
call_id="call_confirm_123",
arguments='{"changes": "test"}',
@@ -19,7 +19,7 @@ def test_sanitize_tool_history_injects_confirm_changes_result() -> None:
),
ChatMessage(
role="user",
- contents=[TextContent(text='{"accepted": true}')],
+ contents=[Content.from_text(text='{"accepted": true}')],
),
]
@@ -37,11 +37,11 @@ def test_deduplicate_messages_prefers_non_empty_tool_results() -> None:
messages = [
ChatMessage(
role="tool",
- contents=[FunctionResultContent(call_id="call1", result="")],
+ contents=[Content.from_function_result(call_id="call1", result="")],
),
ChatMessage(
role="tool",
- contents=[FunctionResultContent(call_id="call1", result="result data")],
+ contents=[Content.from_function_result(call_id="call1", result="result data")],
),
]
diff --git a/python/packages/ag-ui/tests/test_orchestrators.py b/python/packages/ag-ui/tests/test_orchestrators.py
index 279ddedc82..c951246bfa 100644
--- a/python/packages/ag-ui/tests/test_orchestrators.py
+++ b/python/packages/ag-ui/tests/test_orchestrators.py
@@ -13,8 +13,8 @@
BaseChatClient,
ChatAgent,
ChatResponseUpdate,
+ Content,
FunctionInvocationConfiguration,
- TextContent,
ai_function,
)
@@ -79,11 +79,11 @@ async def mock_run_stream(
if capture_messages is not None:
capture_messages.extend(messages)
yield AgentResponseUpdate(
- contents=[TextContent(text="ok")],
+ contents=[Content.from_text(text="ok")],
role="assistant",
response_id=thread.metadata.get("ag_ui_run_id"), # type: ignore[attr-defined] (metadata always created in orchestrator)
raw_representation=ChatResponseUpdate(
- contents=[TextContent(text="ok")],
+ contents=[Content.from_text(text="ok")],
conversation_id=thread.metadata.get("ag_ui_thread_id"), # type: ignore[attr-defined] (metadata always created in orchestrator)
response_id=thread.metadata.get("ag_ui_run_id"), # type: ignore[attr-defined] (metadata always created in orchestrator)
),
@@ -253,7 +253,7 @@ async def test_state_context_injected_when_tool_call_state_mismatch() -> None:
if role_value != "system":
continue
for content in msg.contents or []:
- if isinstance(content, TextContent) and content.text.startswith("Current state of the application:"):
+ if content.type == "text" and content.text.startswith("Current state of the application:"):
state_messages.append(content.text)
assert state_messages
assert "Vegetarian" in state_messages[0]
@@ -302,6 +302,6 @@ async def test_state_context_not_injected_when_tool_call_matches_state() -> None
if role_value != "system":
continue
for content in msg.contents or []:
- if isinstance(content, TextContent) and content.text.startswith("Current state of the application:"):
+ if content.type == "text" and content.text.startswith("Current state of the application:"):
state_messages.append(content.text)
assert not state_messages
diff --git a/python/packages/ag-ui/tests/test_orchestrators_coverage.py b/python/packages/ag-ui/tests/test_orchestrators_coverage.py
index 6c311d593a..d579c691b7 100644
--- a/python/packages/ag-ui/tests/test_orchestrators_coverage.py
+++ b/python/packages/ag-ui/tests/test_orchestrators_coverage.py
@@ -8,12 +8,7 @@
from types import SimpleNamespace
from typing import Any
-from agent_framework import (
- AgentResponseUpdate,
- ChatMessage,
- TextContent,
- ai_function,
-)
+from agent_framework import AgentResponseUpdate, ChatMessage, Content, ai_function
from pydantic import BaseModel
from agent_framework_ag_ui._agent import AgentConfig
@@ -48,14 +43,14 @@ async def test_human_in_the_loop_json_decode_error() -> None:
messages = [
ChatMessage(
role="tool",
- contents=[TextContent(text="not valid json {")],
+ contents=[Content.from_text(text="not valid json {")],
additional_properties={"is_tool_result": True},
)
]
agent = StubAgent(
default_options={"tools": [approval_tool], "response_format": None},
- updates=[AgentResponseUpdate(contents=[TextContent(text="response")], role="assistant")],
+ updates=[AgentResponseUpdate(contents=[Content.from_text(text="response")], role="assistant")],
)
context = TestExecutionContext(
input_data=input_data,
@@ -78,14 +73,14 @@ async def test_human_in_the_loop_json_decode_error() -> None:
async def test_sanitize_tool_history_confirm_changes() -> None:
"""Test sanitize_tool_history logic for confirm_changes synthetic result."""
- from agent_framework import ChatMessage, FunctionCallContent, TextContent
+ from agent_framework import ChatMessage
# Create messages that will trigger confirm_changes synthetic result injection
messages = [
ChatMessage(
role="assistant",
contents=[
- FunctionCallContent(
+ Content.from_function_call(
name="confirm_changes",
call_id="call_confirm_123",
arguments='{"changes": "test"}',
@@ -94,7 +89,7 @@ async def test_sanitize_tool_history_confirm_changes() -> None:
),
ChatMessage(
role="user",
- contents=[TextContent(text='{"accepted": true}')],
+ contents=[Content.from_text(text='{"accepted": true}')],
),
]
@@ -134,17 +129,17 @@ async def test_sanitize_tool_history_confirm_changes() -> None:
async def test_sanitize_tool_history_orphaned_tool_result() -> None:
"""Test sanitize_tool_history removes orphaned tool results."""
- from agent_framework import ChatMessage, FunctionResultContent, TextContent
+ from agent_framework import ChatMessage
# Tool result without preceding assistant tool call
messages = [
ChatMessage(
role="tool",
- contents=[FunctionResultContent(call_id="orphan_123", result="orphaned data")],
+ contents=[Content.from_function_result(call_id="orphan_123", result="orphaned data")],
),
ChatMessage(
role="user",
- contents=[TextContent(text="Hello")],
+ contents=[Content.from_text(text="Hello")],
),
]
@@ -214,20 +209,20 @@ async def test_orphaned_tool_result_sanitization() -> None:
async def test_deduplicate_messages_empty_tool_results() -> None:
"""Test deduplicate_messages prefers non-empty tool results."""
- from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent
+ from agent_framework import ChatMessage
messages = [
ChatMessage(
role="assistant",
- contents=[FunctionCallContent(name="test_tool", call_id="call_789", arguments="{}")],
+ contents=[Content.from_function_call(name="test_tool", call_id="call_789", arguments="{}")],
),
ChatMessage(
role="tool",
- contents=[FunctionResultContent(call_id="call_789", result="")],
+ contents=[Content.from_function_result(call_id="call_789", result="")],
),
ChatMessage(
role="tool",
- contents=[FunctionResultContent(call_id="call_789", result="real data")],
+ contents=[Content.from_function_result(call_id="call_789", result="real data")],
),
]
@@ -259,20 +254,20 @@ async def test_deduplicate_messages_empty_tool_results() -> None:
async def test_deduplicate_messages_duplicate_assistant_tool_calls() -> None:
"""Test deduplicate_messages removes duplicate assistant tool call messages."""
- from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent
+ from agent_framework import ChatMessage
messages = [
ChatMessage(
role="assistant",
- contents=[FunctionCallContent(name="test_tool", call_id="call_abc", arguments="{}")],
+ contents=[Content.from_function_call(name="test_tool", call_id="call_abc", arguments="{}")],
),
ChatMessage(
role="assistant",
- contents=[FunctionCallContent(name="test_tool", call_id="call_abc", arguments="{}")],
+ contents=[Content.from_function_call(name="test_tool", call_id="call_abc", arguments="{}")],
),
ChatMessage(
role="tool",
- contents=[FunctionResultContent(call_id="call_abc", result="result")],
+ contents=[Content.from_function_result(call_id="call_abc", result="result")],
),
]
@@ -303,20 +298,20 @@ async def test_deduplicate_messages_duplicate_assistant_tool_calls() -> None:
async def test_deduplicate_messages_duplicate_system_messages() -> None:
"""Test that deduplication logic is invoked for system messages."""
- from agent_framework import ChatMessage, TextContent
+ from agent_framework import ChatMessage
messages = [
ChatMessage(
role="system",
- contents=[TextContent(text="You are a helpful assistant.")],
+ contents=[Content.from_text(text="You are a helpful assistant.")],
),
ChatMessage(
role="system",
- contents=[TextContent(text="You are a helpful assistant.")],
+ contents=[Content.from_text(text="You are a helpful assistant.")],
),
ChatMessage(
role="user",
- contents=[TextContent(text="Hello")],
+ contents=[Content.from_text(text="Hello")],
),
]
@@ -387,20 +382,20 @@ async def test_state_context_injection() -> None:
async def test_state_context_injection_with_tool_calls_and_input_state() -> None:
"""Test state context is injected when state is provided, even with tool calls."""
- from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent, TextContent
+ from agent_framework import ChatMessage
messages = [
ChatMessage(
role="assistant",
- contents=[FunctionCallContent(name="get_weather", call_id="call_xyz", arguments="{}")],
+ contents=[Content.from_function_call(name="get_weather", call_id="call_xyz", arguments="{}")],
),
ChatMessage(
role="tool",
- contents=[FunctionResultContent(call_id="call_xyz", result="sunny")],
+ contents=[Content.from_function_result(call_id="call_xyz", result="sunny")],
),
ChatMessage(
role="user",
- contents=[TextContent(text="Thanks")],
+ contents=[Content.from_text(text="Thanks")],
),
]
@@ -452,7 +447,7 @@ class RecipeState(BaseModel):
default_options=DEFAULT_OPTIONS,
updates=[
AgentResponseUpdate(
- contents=[TextContent(text='{"ingredients": ["tomato"], "message": "Added tomato"}')],
+ contents=[Content.from_text(text='{"ingredients": ["tomato"], "message": "Added tomato"}')],
role="assistant",
)
],
@@ -641,13 +636,13 @@ async def test_all_messages_filtered_handling() -> None:
async def test_confirm_changes_with_invalid_json_fallback() -> None:
"""Test confirm_changes with invalid JSON falls back to normal processing."""
- from agent_framework import ChatMessage, FunctionCallContent, TextContent
+ from agent_framework import ChatMessage
messages = [
ChatMessage(
role="assistant",
contents=[
- FunctionCallContent(
+ Content.from_function_call(
name="confirm_changes",
call_id="call_confirm_invalid",
arguments='{"changes": "test"}',
@@ -656,7 +651,7 @@ async def test_confirm_changes_with_invalid_json_fallback() -> None:
),
ChatMessage(
role="user",
- contents=[TextContent(text="invalid json {")],
+ contents=[Content.from_text(text="invalid json {")],
),
]
@@ -688,19 +683,18 @@ async def test_confirm_changes_with_invalid_json_fallback() -> None:
async def test_confirm_changes_closes_active_message_before_finish() -> None:
"""Confirm-changes flow closes any active text message before run finishes."""
from ag_ui.core import TextMessageEndEvent, TextMessageStartEvent
- from agent_framework import FunctionCallContent, FunctionResultContent
updates = [
AgentResponseUpdate(
contents=[
- FunctionCallContent(
+ Content.from_function_call(
name="write_document_local",
call_id="call_1",
arguments='{"document": "Draft"}',
)
]
),
- AgentResponseUpdate(contents=[FunctionResultContent(call_id="call_1", result="Done")]),
+ AgentResponseUpdate(contents=[Content.from_function_result(call_id="call_1", result="Done")]),
]
orchestrator = DefaultOrchestrator()
@@ -735,16 +729,16 @@ async def test_confirm_changes_closes_active_message_before_finish() -> None:
async def test_tool_result_kept_when_call_id_matches() -> None:
"""Test tool result is kept when call_id matches pending tool calls."""
- from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent
+ from agent_framework import ChatMessage
messages = [
ChatMessage(
role="assistant",
- contents=[FunctionCallContent(name="get_data", call_id="call_match", arguments="{}")],
+ contents=[Content.from_function_call(name="get_data", call_id="call_match", arguments="{}")],
),
ChatMessage(
role="tool",
- contents=[FunctionResultContent(call_id="call_match", result="data")],
+ contents=[Content.from_function_result(call_id="call_match", result="data")],
),
]
@@ -794,11 +788,11 @@ async def run_stream(
**kwargs: Any,
) -> AsyncGenerator[AgentResponseUpdate, None]:
self.messages_received = messages
- yield AgentResponseUpdate(contents=[TextContent(text="response")], role="assistant")
+ yield AgentResponseUpdate(contents=[Content.from_text(text="response")], role="assistant")
- from agent_framework import ChatMessage, TextContent
+ from agent_framework import ChatMessage
- messages = [ChatMessage(role="user", contents=[TextContent(text="Hello")])]
+ messages = [ChatMessage(role="user", contents=[Content.from_text(text="Hello")])]
orchestrator = DefaultOrchestrator()
input_data: dict[str, Any] = {"messages": []}
@@ -820,9 +814,9 @@ async def run_stream(
async def test_initial_state_snapshot_with_array_schema() -> None:
"""Test state initialization with array type schema."""
- from agent_framework import ChatMessage, TextContent
+ from agent_framework import ChatMessage
- messages = [ChatMessage(role="user", contents=[TextContent(text="Hello")])]
+ messages = [ChatMessage(role="user", contents=[Content.from_text(text="Hello")])]
orchestrator = DefaultOrchestrator()
input_data: dict[str, Any] = {"messages": [], "state": {}}
@@ -851,9 +845,9 @@ async def test_response_format_skip_text_content() -> None:
class OutputModel(BaseModel):
result: str
- from agent_framework import ChatMessage, TextContent
+ from agent_framework import ChatMessage
- messages = [ChatMessage(role="user", contents=[TextContent(text="Hello")])]
+ messages = [ChatMessage(role="user", contents=[Content.from_text(text="Hello")])]
orchestrator = DefaultOrchestrator()
input_data: dict[str, Any] = {"messages": []}
diff --git a/python/packages/ag-ui/tests/test_service_thread_id.py b/python/packages/ag-ui/tests/test_service_thread_id.py
index 8c00f7b67c..eab60abf7a 100644
--- a/python/packages/ag-ui/tests/test_service_thread_id.py
+++ b/python/packages/ag-ui/tests/test_service_thread_id.py
@@ -7,7 +7,7 @@
from typing import Any
from ag_ui.core import RunFinishedEvent, RunStartedEvent
-from agent_framework import TextContent
+from agent_framework import Content
from agent_framework._types import AgentResponseUpdate, ChatResponseUpdate
sys.path.insert(0, str(Path(__file__).parent))
@@ -20,10 +20,10 @@ async def test_service_thread_id_when_there_are_updates():
updates: list[AgentResponseUpdate] = [
AgentResponseUpdate(
- contents=[TextContent(text="Hello, user!")],
+ contents=[Content.from_text(text="Hello, user!")],
response_id="resp_67890",
raw_representation=ChatResponseUpdate(
- contents=[TextContent(text="Hello, user!")],
+ contents=[Content.from_text(text="Hello, user!")],
conversation_id="conv_12345",
response_id="resp_67890",
),
diff --git a/python/packages/ag-ui/tests/test_shared_state.py b/python/packages/ag-ui/tests/test_shared_state.py
index 469f5f5ad8..4b3f5ebb23 100644
--- a/python/packages/ag-ui/tests/test_shared_state.py
+++ b/python/packages/ag-ui/tests/test_shared_state.py
@@ -8,7 +8,7 @@
import pytest
from ag_ui.core import StateSnapshotEvent
-from agent_framework import ChatAgent, ChatResponseUpdate, TextContent
+from agent_framework import ChatAgent, ChatResponseUpdate, Content
from agent_framework_ag_ui._agent import AgentFrameworkAgent
from agent_framework_ag_ui._events import AgentFrameworkEventBridge
@@ -20,7 +20,7 @@
@pytest.fixture
def mock_agent() -> ChatAgent:
"""Create a mock agent for testing."""
- updates = [ChatResponseUpdate(contents=[TextContent(text="Hello!")])]
+ updates = [ChatResponseUpdate(contents=[Content.from_text(text="Hello!")])]
chat_client = StreamingChatClientStub(stream_from_updates(updates))
return ChatAgent(name="test_agent", instructions="Test agent", chat_client=chat_client)
diff --git a/python/packages/ag-ui/tests/test_state_manager.py b/python/packages/ag-ui/tests/test_state_manager.py
index bc0a7b6a19..47b2940978 100644
--- a/python/packages/ag-ui/tests/test_state_manager.py
+++ b/python/packages/ag-ui/tests/test_state_manager.py
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.
from ag_ui.core import CustomEvent, EventType
-from agent_framework import ChatMessage, TextContent
+from agent_framework import ChatMessage
from agent_framework_ag_ui._events import AgentFrameworkEventBridge
from agent_framework_ag_ui._orchestration._state_manager import StateManager
@@ -47,5 +47,5 @@ def test_state_context_only_when_new_user_turn() -> None:
message = state_manager.state_context_message(is_new_user_turn=True, conversation_has_tool_calls=False)
assert isinstance(message, ChatMessage)
- assert isinstance(message.contents[0], TextContent)
+ assert message.contents[0].type == "text"
assert "Current state of the application" in message.contents[0].text
diff --git a/python/packages/ag-ui/tests/test_structured_output.py b/python/packages/ag-ui/tests/test_structured_output.py
index b9a04353be..7c623f62d6 100644
--- a/python/packages/ag-ui/tests/test_structured_output.py
+++ b/python/packages/ag-ui/tests/test_structured_output.py
@@ -8,7 +8,7 @@
from pathlib import Path
from typing import Any
-from agent_framework import ChatAgent, ChatMessage, ChatOptions, ChatResponseUpdate, TextContent
+from agent_framework import ChatAgent, ChatMessage, ChatOptions, ChatResponseUpdate, Content
from pydantic import BaseModel
sys.path.insert(0, str(Path(__file__).parent))
@@ -43,7 +43,7 @@ async def stream_fn(
messages: MutableSequence[ChatMessage], options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
yield ChatResponseUpdate(
- contents=[TextContent(text='{"recipe": {"name": "Pasta"}, "message": "Here is your recipe"}')]
+ contents=[Content.from_text(text='{"recipe": {"name": "Pasta"}, "message": "Here is your recipe"}')]
)
agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
@@ -86,7 +86,7 @@ async def stream_fn(
{"id": "2", "description": "Step 2", "status": "pending"},
]
}
- yield ChatResponseUpdate(contents=[TextContent(text=json.dumps(steps_data))])
+ yield ChatResponseUpdate(contents=[Content.from_text(text=json.dumps(steps_data))])
agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
agent.default_options = ChatOptions(response_format=StepsOutput)
@@ -118,7 +118,7 @@ async def test_structured_output_with_no_schema_match():
from agent_framework.ag_ui import AgentFrameworkAgent
updates = [
- ChatResponseUpdate(contents=[TextContent(text='{"data": {"key": "value"}}')]),
+ ChatResponseUpdate(contents=[Content.from_text(text='{"data": {"key": "value"}}')]),
]
agent = ChatAgent(
@@ -156,7 +156,7 @@ class DataOutput(BaseModel):
async def stream_fn(
messages: MutableSequence[ChatMessage], options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
- yield ChatResponseUpdate(contents=[TextContent(text='{"data": {"key": "value"}, "info": "processed"}')])
+ yield ChatResponseUpdate(contents=[Content.from_text(text='{"data": {"key": "value"}, "info": "processed"}')])
agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
agent.default_options = ChatOptions(response_format=DataOutput)
@@ -185,7 +185,7 @@ async def test_no_structured_output_when_no_response_format():
"""Test that structured output path is skipped when no response_format."""
from agent_framework.ag_ui import AgentFrameworkAgent
- updates = [ChatResponseUpdate(contents=[TextContent(text="Regular text")])]
+ updates = [ChatResponseUpdate(contents=[Content.from_text(text="Regular text")])]
agent = ChatAgent(
name="test",
@@ -216,7 +216,7 @@ async def stream_fn(
messages: MutableSequence[ChatMessage], options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
output_data = {"recipe": {"name": "Salad"}, "message": "Fresh salad recipe ready"}
- yield ChatResponseUpdate(contents=[TextContent(text=json.dumps(output_data))])
+ yield ChatResponseUpdate(contents=[Content.from_text(text=json.dumps(output_data))])
agent = ChatAgent(name="test", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
agent.default_options = ChatOptions(response_format=RecipeOutput)
diff --git a/python/packages/ag-ui/tests/utils_test_ag_ui.py b/python/packages/ag-ui/tests/utils_test_ag_ui.py
index c3fa590cd1..33f462257e 100644
--- a/python/packages/ag-ui/tests/utils_test_ag_ui.py
+++ b/python/packages/ag-ui/tests/utils_test_ag_ui.py
@@ -16,7 +16,7 @@
ChatMessage,
ChatResponse,
ChatResponseUpdate,
- TextContent,
+ Content,
)
from agent_framework._clients import TOptions_co
@@ -91,7 +91,7 @@ def __init__(
self.id = agent_id
self.name = agent_name
self.description = "stub agent"
- self.updates = updates or [AgentResponseUpdate(contents=[TextContent(text="response")], role="assistant")]
+ self.updates = updates or [AgentResponseUpdate(contents=[Content.from_text(text="response")], role="assistant")]
self.default_options: dict[str, Any] = (
default_options if isinstance(default_options, dict) else {"tools": None, "response_format": None}
)
diff --git a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py
index c9223e614b..4fdcdfadc7 100644
--- a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py
+++ b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py
@@ -7,31 +7,19 @@
from agent_framework import (
AGENT_FRAMEWORK_USER_AGENT,
AIFunction,
- Annotations,
+ Annotation,
BaseChatClient,
ChatMessage,
ChatOptions,
ChatResponse,
ChatResponseUpdate,
- CitationAnnotation,
- CodeInterpreterToolCallContent,
- CodeInterpreterToolResultContent,
- Contents,
- ErrorContent,
+ Content,
FinishReason,
- FunctionCallContent,
- FunctionResultContent,
HostedCodeInterpreterTool,
- HostedFileContent,
HostedMCPTool,
HostedWebSearchTool,
- MCPServerToolCallContent,
- MCPServerToolResultContent,
Role,
- TextContent,
- TextReasoningContent,
TextSpanRegion,
- UsageContent,
UsageDetails,
get_logger,
prepare_function_call_results,
@@ -486,7 +474,7 @@ def _prepare_message_for_anthropic(self, message: ChatMessage) -> dict[str, Any]
a_content.append({
"type": "image",
"source": {
- "data": content.get_data_bytes_as_str(),
+ "data": content.get_data_bytes_as_str(), # type: ignore[attr-defined]
"media_type": content.media_type,
"type": "base64",
},
@@ -653,9 +641,9 @@ def _process_stream_event(self, event: BetaRawMessageStreamEvent) -> ChatRespons
"""
match event.type:
case "message_start":
- usage_details: list[UsageContent] = []
+ usage_details: list[Content] = []
if event.message.usage and (details := self._parse_usage_from_anthropic(event.message.usage)):
- usage_details.append(UsageContent(details=details))
+ usage_details.append(Content.from_usage(usage_details=details))
return ChatResponseUpdate(
response_id=event.message.id,
@@ -672,7 +660,7 @@ def _process_stream_event(self, event: BetaRawMessageStreamEvent) -> ChatRespons
case "message_delta":
usage = self._parse_usage_from_anthropic(event.usage)
return ChatResponseUpdate(
- contents=[UsageContent(details=usage, raw_representation=event.usage)] if usage else [],
+ contents=[Content.from_usage(usage_details=usage, raw_representation=event.usage)] if usage else [],
finish_reason=FINISH_REASON_MAP.get(event.delta.stop_reason) if event.delta.stop_reason else None,
raw_representation=event,
)
@@ -702,24 +690,24 @@ def _parse_usage_from_anthropic(self, usage: BetaUsage | BetaMessageDeltaUsage |
return None
usage_details = UsageDetails(output_token_count=usage.output_tokens)
if usage.input_tokens is not None:
- usage_details.input_token_count = usage.input_tokens
+ usage_details["input_token_count"] = usage.input_tokens
if usage.cache_creation_input_tokens is not None:
- usage_details.additional_counts["anthropic.cache_creation_input_tokens"] = usage.cache_creation_input_tokens
+ usage_details["anthropic.cache_creation_input_tokens"] = usage.cache_creation_input_tokens # type: ignore[typeddict-unknown-key]
if usage.cache_read_input_tokens is not None:
- usage_details.additional_counts["anthropic.cache_read_input_tokens"] = usage.cache_read_input_tokens
+ usage_details["anthropic.cache_read_input_tokens"] = usage.cache_read_input_tokens # type: ignore[typeddict-unknown-key]
return usage_details
def _parse_contents_from_anthropic(
self,
content: Sequence[BetaContentBlock | BetaRawContentBlockDelta | BetaTextBlock],
- ) -> list[Contents]:
+ ) -> list[Content]:
"""Parse contents from the Anthropic message."""
- contents: list[Contents] = []
+ contents: list[Content] = []
for content_block in content:
match content_block.type:
case "text" | "text_delta":
contents.append(
- TextContent(
+ Content.from_text(
text=content_block.text,
raw_representation=content_block,
annotations=self._parse_citations_from_anthropic(content_block),
@@ -729,7 +717,7 @@ def _parse_contents_from_anthropic(
self._last_call_id_name = (content_block.id, content_block.name)
if content_block.type == "mcp_tool_use":
contents.append(
- MCPServerToolCallContent(
+ Content.from_mcp_server_tool_call(
call_id=content_block.id,
tool_name=content_block.name,
server_name=None,
@@ -739,10 +727,10 @@ def _parse_contents_from_anthropic(
)
elif "code_execution" in (content_block.name or ""):
contents.append(
- CodeInterpreterToolCallContent(
+ Content.from_code_interpreter_tool_call(
call_id=content_block.id,
inputs=[
- TextContent(
+ Content.from_text(
text=str(content_block.input),
raw_representation=content_block,
)
@@ -752,7 +740,7 @@ def _parse_contents_from_anthropic(
)
else:
contents.append(
- FunctionCallContent(
+ Content.from_function_call(
call_id=content_block.id,
name=content_block.name,
arguments=content_block.input,
@@ -760,14 +748,14 @@ def _parse_contents_from_anthropic(
)
)
case "mcp_tool_result":
- call_id, name = self._last_call_id_name or (None, None)
- parsed_output: list[Contents] | None = None
+ call_id, _ = self._last_call_id_name or (None, None)
+ parsed_output: list[Content] | None = None
if content_block.content:
if isinstance(content_block.content, list):
parsed_output = self._parse_contents_from_anthropic(content_block.content)
elif isinstance(content_block.content, (str, bytes)):
parsed_output = [
- TextContent(
+ Content.from_text(
text=str(content_block.content),
raw_representation=content_block,
)
@@ -775,28 +763,27 @@ def _parse_contents_from_anthropic(
else:
parsed_output = self._parse_contents_from_anthropic([content_block.content])
contents.append(
- MCPServerToolResultContent(
+ Content.from_mcp_server_tool_result(
call_id=content_block.tool_use_id,
output=parsed_output,
raw_representation=content_block,
)
)
case "web_search_tool_result" | "web_fetch_tool_result":
- call_id, name = self._last_call_id_name or (None, None)
+ call_id, _ = self._last_call_id_name or (None, None)
contents.append(
- FunctionResultContent(
+ Content.from_function_result(
call_id=content_block.tool_use_id,
- name=name if name and call_id == content_block.tool_use_id else "web_tool",
result=content_block.content,
raw_representation=content_block,
)
)
case "code_execution_tool_result":
- code_outputs: list[Contents] = []
+ code_outputs: list[Content] = []
if content_block.content:
if isinstance(content_block.content, BetaCodeExecutionToolResultError):
code_outputs.append(
- ErrorContent(
+ Content.from_error(
message=content_block.content.error_code,
raw_representation=content_block.content,
)
@@ -804,41 +791,41 @@ def _parse_contents_from_anthropic(
else:
if content_block.content.stdout:
code_outputs.append(
- TextContent(
+ Content.from_text(
text=content_block.content.stdout,
raw_representation=content_block.content,
)
)
if content_block.content.stderr:
code_outputs.append(
- ErrorContent(
+ Content.from_error(
message=content_block.content.stderr,
raw_representation=content_block.content,
)
)
for code_file_content in content_block.content.content:
code_outputs.append(
- HostedFileContent(
+ Content.from_hosted_file(
file_id=code_file_content.file_id,
raw_representation=code_file_content,
)
)
contents.append(
- CodeInterpreterToolResultContent(
+ Content.from_code_interpreter_tool_result(
call_id=content_block.tool_use_id,
raw_representation=content_block,
outputs=code_outputs,
)
)
case "bash_code_execution_tool_result":
- bash_outputs: list[Contents] = []
+ bash_outputs: list[Content] = []
if content_block.content:
if isinstance(
content_block.content,
BetaBashCodeExecutionToolResultError,
):
bash_outputs.append(
- ErrorContent(
+ Content.from_error(
message=content_block.content.error_code,
raw_representation=content_block.content,
)
@@ -846,39 +833,38 @@ def _parse_contents_from_anthropic(
else:
if content_block.content.stdout:
bash_outputs.append(
- TextContent(
+ Content.from_text(
text=content_block.content.stdout,
raw_representation=content_block.content,
)
)
if content_block.content.stderr:
bash_outputs.append(
- ErrorContent(
+ Content.from_error(
message=content_block.content.stderr,
raw_representation=content_block.content,
)
)
for bash_file_content in content_block.content.content:
contents.append(
- HostedFileContent(
+ Content.from_hosted_file(
file_id=bash_file_content.file_id,
raw_representation=bash_file_content,
)
)
contents.append(
- FunctionResultContent(
+ Content.from_function_result(
call_id=content_block.tool_use_id,
- name=content_block.type,
result=bash_outputs,
raw_representation=content_block,
)
)
case "text_editor_code_execution_tool_result":
- text_editor_outputs: list[Contents] = []
+ text_editor_outputs: list[Content] = []
match content_block.content.type:
case "text_editor_code_execution_tool_result_error":
text_editor_outputs.append(
- ErrorContent(
+ Content.from_error(
message=content_block.content.error_code
and getattr(content_block.content, "error_message", ""),
raw_representation=content_block.content,
@@ -887,10 +873,12 @@ def _parse_contents_from_anthropic(
case "text_editor_code_execution_view_result":
annotations = (
[
- CitationAnnotation(
+ Annotation(
+ type="citation",
raw_representation=content_block.content,
annotated_regions=[
TextSpanRegion(
+ type="text_span",
start_index=content_block.content.start_line,
end_index=content_block.content.start_line
+ (content_block.content.num_lines or 0),
@@ -903,7 +891,7 @@ def _parse_contents_from_anthropic(
else None
)
text_editor_outputs.append(
- TextContent(
+ Content.from_text(
text=content_block.content.content,
annotations=annotations,
raw_representation=content_block.content,
@@ -911,10 +899,12 @@ def _parse_contents_from_anthropic(
)
case "text_editor_code_execution_str_replace_result":
old_annotation = (
- CitationAnnotation(
+ Annotation(
+ type="citation",
raw_representation=content_block.content,
annotated_regions=[
TextSpanRegion(
+ type="text_span",
start_index=content_block.content.old_start or 0,
end_index=(
(content_block.content.old_start or 0)
@@ -928,13 +918,15 @@ def _parse_contents_from_anthropic(
else None
)
new_annotation = (
- CitationAnnotation(
+ Annotation(
+ type="citation",
raw_representation=content_block.content,
- snippet="\n".join(content_block.content.lines)
+ snippet="\n".join(content_block.content.lines) # type: ignore[typeddict-item]
if content_block.content.lines
else None,
annotated_regions=[
TextSpanRegion(
+ type="text_span",
start_index=content_block.content.new_start or 0,
end_index=(
(content_block.content.new_start or 0)
@@ -950,7 +942,7 @@ def _parse_contents_from_anthropic(
annotations = [ann for ann in [old_annotation, new_annotation] if ann is not None]
text_editor_outputs.append(
- TextContent(
+ Content.from_text(
text=(
"\n".join(content_block.content.lines) if content_block.content.lines else ""
),
@@ -960,15 +952,14 @@ def _parse_contents_from_anthropic(
)
case "text_editor_code_execution_create_result":
text_editor_outputs.append(
- TextContent(
+ Content.from_text(
text=f"File update: {content_block.content.is_file_update}",
raw_representation=content_block.content,
)
)
contents.append(
- FunctionResultContent(
+ Content.from_function_result(
call_id=content_block.tool_use_id,
- name=content_block.type,
result=text_editor_outputs,
raw_representation=content_block,
)
@@ -981,7 +972,7 @@ def _parse_contents_from_anthropic(
# This matches OpenAI's behavior where streaming chunks have name="".
call_id, _ = self._last_call_id_name if self._last_call_id_name else ("", "")
contents.append(
- FunctionCallContent(
+ Content.from_function_call(
call_id=call_id,
name="",
arguments=content_block.partial_json,
@@ -990,7 +981,7 @@ def _parse_contents_from_anthropic(
)
case "thinking" | "thinking_delta":
contents.append(
- TextReasoningContent(
+ Content.from_text_reasoning(
text=content_block.thinking,
raw_representation=content_block,
)
@@ -1001,65 +992,65 @@ def _parse_contents_from_anthropic(
def _parse_citations_from_anthropic(
self, content_block: BetaContentBlock | BetaRawContentBlockDelta | BetaTextBlock
- ) -> list[Annotations] | None:
- content_citations = getattr(content_block, "citations", None)
- if not content_citations:
+ ) -> list[Annotation] | None:
+ content_blocks = getattr(content_block, "citations", None)
+ if not content_blocks:
return None
- annotations: list[Annotations] = []
- for citation in content_citations:
- cit = CitationAnnotation(raw_representation=citation)
+ annotations: list[Annotation] = []
+ for citation in content_blocks:
+ cit = Annotation(type="citation", raw_representation=citation)
match citation.type:
case "char_location":
- cit.title = citation.title
- cit.snippet = citation.cited_text
+ cit["title"] = citation.title
+ cit["snippet"] = citation.cited_text
if citation.file_id:
- cit.file_id = citation.file_id
- if not cit.annotated_regions:
- cit.annotated_regions = []
- cit.annotated_regions.append(
+ cit["file_id"] = citation.file_id
+ cit.setdefault("annotated_regions", [])
+ cit["annotated_regions"].append( # type: ignore[attr-defined]
TextSpanRegion(
+ type="text_span",
start_index=citation.start_char_index,
end_index=citation.end_char_index,
)
)
case "page_location":
- cit.title = citation.document_title
- cit.snippet = citation.cited_text
+ cit["title"] = citation.document_title
+ cit["snippet"] = citation.cited_text
if citation.file_id:
- cit.file_id = citation.file_id
- if not cit.annotated_regions:
- cit.annotated_regions = []
- cit.annotated_regions.append(
+ cit["file_id"] = citation.file_id
+ cit.setdefault("annotated_regions", [])
+ cit["annotated_regions"].append( # type: ignore[attr-defined]
TextSpanRegion(
+ type="text_span",
start_index=citation.start_page_number,
end_index=citation.end_page_number,
)
)
case "content_block_location":
- cit.title = citation.document_title
- cit.snippet = citation.cited_text
+ cit["title"] = citation.document_title
+ cit["snippet"] = citation.cited_text
if citation.file_id:
- cit.file_id = citation.file_id
- if not cit.annotated_regions:
- cit.annotated_regions = []
- cit.annotated_regions.append(
+ cit["file_id"] = citation.file_id
+ cit.setdefault("annotated_regions", [])
+ cit["annotated_regions"].append( # type: ignore[attr-defined]
TextSpanRegion(
+ type="text_span",
start_index=citation.start_block_index,
end_index=citation.end_block_index,
)
)
case "web_search_result_location":
- cit.title = citation.title
- cit.snippet = citation.cited_text
- cit.url = citation.url
+ cit["title"] = citation.title
+ cit["snippet"] = citation.cited_text
+ cit["url"] = citation.url
case "search_result_location":
- cit.title = citation.title
- cit.snippet = citation.cited_text
- cit.url = citation.source
- if not cit.annotated_regions:
- cit.annotated_regions = []
- cit.annotated_regions.append(
+ cit["title"] = citation.title
+ cit["snippet"] = citation.cited_text
+ cit["url"] = citation.source
+ cit.setdefault("annotated_regions", [])
+ cit["annotated_regions"].append( # type: ignore[attr-defined]
TextSpanRegion(
+ type="text_span",
start_index=citation.start_block_index,
end_index=citation.end_block_index,
)
diff --git a/python/packages/anthropic/tests/test_anthropic_client.py b/python/packages/anthropic/tests/test_anthropic_client.py
index 828d9916c2..4476c6b3b6 100644
--- a/python/packages/anthropic/tests/test_anthropic_client.py
+++ b/python/packages/anthropic/tests/test_anthropic_client.py
@@ -10,16 +10,12 @@
ChatMessage,
ChatOptions,
ChatResponseUpdate,
- DataContent,
+ Content,
FinishReason,
- FunctionCallContent,
- FunctionResultContent,
HostedCodeInterpreterTool,
HostedMCPTool,
HostedWebSearchTool,
Role,
- TextContent,
- TextReasoningContent,
ai_function,
)
from agent_framework.exceptions import ServiceInitializationError
@@ -170,7 +166,7 @@ def test_prepare_message_for_anthropic_function_call(mock_anthropic_client: Magi
message = ChatMessage(
role=Role.ASSISTANT,
contents=[
- FunctionCallContent(
+ Content.from_function_call(
call_id="call_123",
name="get_weather",
arguments={"location": "San Francisco"},
@@ -194,9 +190,8 @@ def test_prepare_message_for_anthropic_function_result(mock_anthropic_client: Ma
message = ChatMessage(
role=Role.TOOL,
contents=[
- FunctionResultContent(
+ Content.from_function_result(
call_id="call_123",
- name="get_weather",
result="Sunny, 72°F",
)
],
@@ -219,7 +214,7 @@ def test_prepare_message_for_anthropic_text_reasoning(mock_anthropic_client: Mag
chat_client = create_test_anthropic_client(mock_anthropic_client)
message = ChatMessage(
role=Role.ASSISTANT,
- contents=[TextReasoningContent(text="Let me think about this...")],
+ contents=[Content.from_text_reasoning(text="Let me think about this...")],
)
result = chat_client._prepare_message_for_anthropic(message)
@@ -507,12 +502,12 @@ def test_process_message_basic(mock_anthropic_client: MagicMock) -> None:
assert len(response.messages) == 1
assert response.messages[0].role == Role.ASSISTANT
assert len(response.messages[0].contents) == 1
- assert isinstance(response.messages[0].contents[0], TextContent)
+ assert response.messages[0].contents[0].type == "text"
assert response.messages[0].contents[0].text == "Hello there!"
assert response.finish_reason == FinishReason.STOP
assert response.usage_details is not None
- assert response.usage_details.input_token_count == 10
- assert response.usage_details.output_token_count == 5
+ assert response.usage_details["input_token_count"] == 10
+ assert response.usage_details["output_token_count"] == 5
def test_process_message_with_tool_use(mock_anthropic_client: MagicMock) -> None:
@@ -536,7 +531,7 @@ def test_process_message_with_tool_use(mock_anthropic_client: MagicMock) -> None
response = chat_client._process_message(mock_message)
assert len(response.messages[0].contents) == 1
- assert isinstance(response.messages[0].contents[0], FunctionCallContent)
+ assert response.messages[0].contents[0].type == "function_call"
assert response.messages[0].contents[0].call_id == "call_123"
assert response.messages[0].contents[0].name == "get_weather"
assert response.finish_reason == FinishReason.TOOL_CALLS
@@ -550,8 +545,8 @@ def test_parse_usage_from_anthropic_basic(mock_anthropic_client: MagicMock) -> N
result = chat_client._parse_usage_from_anthropic(usage)
assert result is not None
- assert result.input_token_count == 10
- assert result.output_token_count == 5
+ assert result["input_token_count"] == 10
+ assert result["output_token_count"] == 5
def test_parse_usage_from_anthropic_none(mock_anthropic_client: MagicMock) -> None:
@@ -571,7 +566,7 @@ def test_parse_contents_from_anthropic_text(mock_anthropic_client: MagicMock) ->
result = chat_client._parse_contents_from_anthropic(content)
assert len(result) == 1
- assert isinstance(result[0], TextContent)
+ assert result[0].type == "text"
assert result[0].text == "Hello!"
@@ -590,7 +585,7 @@ def test_parse_contents_from_anthropic_tool_use(mock_anthropic_client: MagicMock
result = chat_client._parse_contents_from_anthropic(content)
assert len(result) == 1
- assert isinstance(result[0], FunctionCallContent)
+ assert result[0].type == "function_call"
assert result[0].call_id == "call_123"
assert result[0].name == "get_weather"
@@ -613,7 +608,7 @@ def test_parse_contents_from_anthropic_input_json_delta_no_duplicate_name(mock_a
result = chat_client._parse_contents_from_anthropic([tool_use_content])
assert len(result) == 1
- assert isinstance(result[0], FunctionCallContent)
+ assert result[0].type == "function_call"
assert result[0].call_id == "call_123"
assert result[0].name == "get_weather" # Initial event has name
@@ -624,7 +619,7 @@ def test_parse_contents_from_anthropic_input_json_delta_no_duplicate_name(mock_a
result = chat_client._parse_contents_from_anthropic([delta_content_1])
assert len(result) == 1
- assert isinstance(result[0], FunctionCallContent)
+ assert result[0].type == "function_call"
assert result[0].call_id == "call_123"
assert result[0].name == "" # Delta events should have empty name
assert result[0].arguments == '{"location":'
@@ -636,7 +631,7 @@ def test_parse_contents_from_anthropic_input_json_delta_no_duplicate_name(mock_a
result = chat_client._parse_contents_from_anthropic([delta_content_2])
assert len(result) == 1
- assert isinstance(result[0], FunctionCallContent)
+ assert result[0].type == "function_call"
assert result[0].call_id == "call_123"
assert result[0].name == "" # Still empty name for subsequent deltas
assert result[0].arguments == '"San Francisco"}'
@@ -771,9 +766,7 @@ async def test_anthropic_client_integration_function_calling() -> None:
assert response is not None
# Should contain function call
- has_function_call = any(
- isinstance(content, FunctionCallContent) for msg in response.messages for content in msg.contents
- )
+ has_function_call = any(content.type == "function_call" for msg in response.messages for content in msg.contents)
assert has_function_call
@@ -872,8 +865,8 @@ async def test_anthropic_client_integration_images() -> None:
ChatMessage(
role=Role.USER,
contents=[
- TextContent(text="Describe this image"),
- DataContent(media_type="image/jpeg", data=image_bytes),
+ Content.from_text(text="Describe this image"),
+ Content.from_data(media_type="image/jpeg", data=image_bytes),
],
),
]
diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py
index 6973517c14..4bb646da19 100644
--- a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py
+++ b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py
@@ -2,6 +2,7 @@
import ast
import json
+import os
import re
import sys
from collections.abc import AsyncIterable, Callable, Mapping, MutableMapping, MutableSequence, Sequence
@@ -9,6 +10,8 @@
from agent_framework import (
AGENT_FRAMEWORK_USER_AGENT,
+ AIFunction,
+ Annotation,
BaseChatClient,
ChatAgent,
ChatMessage,
@@ -16,23 +19,16 @@
ChatOptions,
ChatResponse,
ChatResponseUpdate,
- CitationAnnotation,
- Contents,
+ Content,
ContextProvider,
- DataContent,
- FunctionApprovalRequestContent,
- FunctionApprovalResponseContent,
- FunctionCallContent,
- FunctionResultContent,
- HostedFileContent,
+ HostedCodeInterpreterTool,
+ HostedFileSearchTool,
HostedMCPTool,
+ HostedWebSearchTool,
Middleware,
Role,
- TextContent,
TextSpanRegion,
ToolProtocol,
- UriContent,
- UsageContent,
UsageDetails,
get_logger,
prepare_function_call_results,
@@ -50,9 +46,14 @@
AgentStreamEvent,
AsyncAgentEventHandler,
AsyncAgentRunStream,
+ BingCustomSearchTool,
+ BingGroundingTool,
+ CodeInterpreterToolDefinition,
+ FileSearchTool,
FunctionName,
FunctionToolDefinition,
ListSortOrder,
+ McpTool,
MessageDeltaChunk,
MessageDeltaTextContent,
MessageDeltaTextFileCitationAnnotation,
@@ -422,7 +423,7 @@ async def _create_agent_stream(
self,
agent_id: str,
run_options: dict[str, Any],
- required_action_results: list[FunctionResultContent | FunctionApprovalResponseContent] | None,
+ required_action_results: list[Content] | None,
) -> tuple[AsyncAgentRunStream[AsyncAgentEventHandler[Any]] | AsyncAgentEventHandler[Any], str]:
"""Create the agent stream for processing.
@@ -506,9 +507,9 @@ async def _prepare_thread(
def _extract_url_citations(
self, message_delta_chunk: MessageDeltaChunk, azure_search_tool_calls: list[dict[str, Any]]
- ) -> list[CitationAnnotation]:
+ ) -> list[Annotation]:
"""Extract URL citations from MessageDeltaChunk."""
- url_citations: list[CitationAnnotation] = []
+ url_citations: list[Annotation] = []
# Process each content item in the delta to find citations
for content in message_delta_chunk.delta.content:
@@ -520,6 +521,7 @@ def _extract_url_citations(
if annotation.start_index and annotation.end_index:
annotated_regions = [
TextSpanRegion(
+ type="text_span",
start_index=annotation.start_index,
end_index=annotation.end_index,
)
@@ -530,11 +532,12 @@ def _extract_url_citations(
annotation.url_citation.url, azure_search_tool_calls
)
- # Create CitationAnnotation with real URL
- citation = CitationAnnotation(
- title=getattr(annotation.url_citation, "title", None),
+ # Create Annotation with real URL
+ citation = Annotation(
+ type="citation",
+ title=annotation.url_citation.title, # type: ignore[typeddict-item]
url=real_url,
- snippet=None,
+ snippet=None, # type: ignore[typeddict-item]
annotated_regions=annotated_regions,
raw_representation=annotation,
)
@@ -542,7 +545,7 @@ def _extract_url_citations(
return url_citations
- def _extract_file_path_contents(self, message_delta_chunk: MessageDeltaChunk) -> list[HostedFileContent]:
+ def _extract_file_path_contents(self, message_delta_chunk: MessageDeltaChunk) -> list[Content]:
"""Extract file references from MessageDeltaChunk annotations.
Code interpreter generates files that are referenced via file path or file citation
@@ -559,7 +562,7 @@ def _extract_file_path_contents(self, message_delta_chunk: MessageDeltaChunk) ->
Returns:
List of HostedFileContent objects for any files referenced in annotations
"""
- file_contents: list[HostedFileContent] = []
+ file_contents: list[Content] = []
for content in message_delta_chunk.delta.content:
if isinstance(content, MessageDeltaTextContent) and content.text and content.text.annotations:
@@ -570,14 +573,14 @@ def _extract_file_path_contents(self, message_delta_chunk: MessageDeltaChunk) ->
if file_path is not None:
file_id = getattr(file_path, "file_id", None)
if file_id:
- file_contents.append(HostedFileContent(file_id=file_id))
+ file_contents.append(Content.from_hosted_file(file_id=file_id))
elif isinstance(annotation, MessageDeltaTextFileCitationAnnotation):
# Extract file_id from the file_citation annotation
file_citation = getattr(annotation, "file_citation", None)
if file_citation is not None:
file_id = getattr(file_citation, "file_id", None)
if file_id:
- file_contents.append(HostedFileContent(file_id=file_id))
+ file_contents.append(Content.from_hosted_file(file_id=file_id))
return file_contents
@@ -644,9 +647,9 @@ async def _process_stream(
file_contents = self._extract_file_path_contents(event_data)
# Create contents with citations if any exist
- citation_content: list[Contents] = []
+ citation_content: list[Content] = []
if event_data.text or url_citations:
- text_content_obj = TextContent(text=event_data.text or "")
+ text_content_obj = Content.from_text(text=event_data.text or "")
if url_citations:
text_content_obj.annotations = url_citations
citation_content.append(text_content_obj)
@@ -722,7 +725,7 @@ async def _process_stream(
self._capture_azure_search_tool_calls(event_data, azure_search_tool_calls)
if event_data.usage:
- usage_content = UsageContent(
+ usage_content = Content.from_usage(
UsageDetails(
input_token_count=event_data.usage.prompt_tokens,
output_token_count=event_data.usage.completion_tokens,
@@ -757,19 +760,21 @@ async def _process_stream(
tool_call.code_interpreter,
RunStepDeltaCodeInterpreterDetailItemObject,
):
- code_contents: list[Contents] = []
+ code_contents: list[Content] = []
if tool_call.code_interpreter.input is not None:
logger.debug(f"Code Interpreter Input: {tool_call.code_interpreter.input}")
if tool_call.code_interpreter.outputs is not None:
for output in tool_call.code_interpreter.outputs:
if isinstance(output, RunStepDeltaCodeInterpreterLogOutput) and output.logs:
- code_contents.append(TextContent(text=output.logs))
+ code_contents.append(Content.from_text(text=output.logs))
if (
isinstance(output, RunStepDeltaCodeInterpreterImageOutput)
and output.image is not None
and output.image.file_id is not None
):
- code_contents.append(HostedFileContent(file_id=output.image.file_id))
+ code_contents.append(
+ Content.from_hosted_file(file_id=output.image.file_id)
+ )
yield ChatResponseUpdate(
role=Role.ASSISTANT,
contents=code_contents,
@@ -822,12 +827,12 @@ def _capture_azure_search_tool_calls(
except Exception as ex:
logger.debug(f"Failed to capture Azure AI Search tool call: {ex}")
- def _parse_function_calls_from_azure_ai(self, event_data: ThreadRun, response_id: str | None) -> list[Contents]:
+ def _parse_function_calls_from_azure_ai(self, event_data: ThreadRun, response_id: str | None) -> list[Content]:
"""Parse function call contents from an Azure AI tool action event."""
if isinstance(event_data, ThreadRun) and event_data.required_action is not None:
if isinstance(event_data.required_action, SubmitToolOutputsAction):
return [
- FunctionCallContent(
+ Content.from_function_call(
call_id=f'["{response_id}", "{tool.id}"]',
name=tool.function.name,
arguments=tool.function.arguments,
@@ -837,9 +842,9 @@ def _parse_function_calls_from_azure_ai(self, event_data: ThreadRun, response_id
]
if isinstance(event_data.required_action, SubmitToolApprovalAction):
return [
- FunctionApprovalRequestContent(
+ Content.from_function_approval_request(
id=f'["{response_id}", "{tool.id}"]',
- function_call=FunctionCallContent(
+ function_call=Content.from_function_call(
call_id=f'["{response_id}", "{tool.id}"]',
name=tool.name,
arguments=tool.arguments,
@@ -875,7 +880,7 @@ async def _prepare_options(
messages: MutableSequence[ChatMessage],
options: Mapping[str, Any],
**kwargs: Any,
- ) -> tuple[dict[str, Any], list[FunctionResultContent | FunctionApprovalResponseContent] | None]:
+ ) -> tuple[dict[str, Any], list[Content] | None]:
agent_definition = await self._load_agent_definition_if_needed()
# Build run_options from options dict, excluding specific keys
@@ -1052,7 +1057,7 @@ def _prepare_messages(
) -> tuple[
list[ThreadMessageOptions] | None,
list[str],
- list[FunctionResultContent | FunctionApprovalResponseContent] | None,
+ list[Content] | None,
]:
"""Prepare messages for Azure AI Agents API.
@@ -1064,28 +1069,34 @@ def _prepare_messages(
Tuple of (additional_messages, instructions, required_action_results)
"""
instructions: list[str] = []
- required_action_results: list[FunctionResultContent | FunctionApprovalResponseContent] | None = None
+ required_action_results: list[Content] | None = None
additional_messages: list[ThreadMessageOptions] | None = None
for chat_message in messages:
if chat_message.role.value in ["system", "developer"]:
- for text_content in [content for content in chat_message.contents if isinstance(content, TextContent)]:
- instructions.append(text_content.text)
+ for text_content in [content for content in chat_message.contents if content.type == "text"]:
+ instructions.append(text_content.text) # type: ignore[arg-type]
continue
message_contents: list[MessageInputContentBlock] = []
for content in chat_message.contents:
- if isinstance(content, TextContent):
- message_contents.append(MessageInputTextBlock(text=content.text))
- elif isinstance(content, (DataContent, UriContent)) and content.has_top_level_media_type("image"):
- message_contents.append(MessageInputImageUrlBlock(image_url=MessageImageUrlParam(url=content.uri)))
- elif isinstance(content, (FunctionResultContent, FunctionApprovalResponseContent)):
- if required_action_results is None:
- required_action_results = []
- required_action_results.append(content)
- elif isinstance(content.raw_representation, MessageInputContentBlock):
- message_contents.append(content.raw_representation)
+ match content.type:
+ case "text":
+ message_contents.append(MessageInputTextBlock(text=content.text)) # type: ignore[arg-type]
+ case "data" | "uri":
+ if content.has_top_level_media_type("image"):
+ message_contents.append(
+ MessageInputImageUrlBlock(image_url=MessageImageUrlParam(url=content.uri)) # type: ignore[arg-type]
+ )
+ # Only images are supported. Other media types are ignored.
+ case "function_result" | "function_approval_response":
+ if required_action_results is None:
+ required_action_results = []
+ required_action_results.append(content)
+ case _:
+ if isinstance(content.raw_representation, MessageInputContentBlock):
+ message_contents.append(content.raw_representation)
if message_contents:
if additional_messages is None:
@@ -1099,9 +1110,85 @@ def _prepare_messages(
return additional_messages, instructions, required_action_results
+ async def _prepare_tools_for_azure_ai(
+ self, tools: Sequence["ToolProtocol | MutableMapping[str, Any]"], run_options: dict[str, Any] | None = None
+ ) -> list[ToolDefinition | dict[str, Any]]:
+ """Prepare tool definitions for the Azure AI Agents API."""
+ tool_definitions: list[ToolDefinition | dict[str, Any]] = []
+ for tool in tools:
+ match tool:
+ case AIFunction():
+ tool_definitions.append(tool.to_json_schema_spec()) # type: ignore[reportUnknownArgumentType]
+ case HostedWebSearchTool():
+ additional_props = tool.additional_properties or {}
+ config_args: dict[str, Any] = {}
+ if count := additional_props.get("count"):
+ config_args["count"] = count
+ if freshness := additional_props.get("freshness"):
+ config_args["freshness"] = freshness
+ if market := additional_props.get("market"):
+ config_args["market"] = market
+ if set_lang := additional_props.get("set_lang"):
+ config_args["set_lang"] = set_lang
+ # Bing Grounding
+ connection_id = additional_props.get("connection_id") or os.getenv("BING_CONNECTION_ID")
+ # Custom Bing Search
+ custom_connection_id = additional_props.get("custom_connection_id") or os.getenv(
+ "BING_CUSTOM_CONNECTION_ID"
+ )
+ custom_instance_name = additional_props.get("custom_instance_name") or os.getenv(
+ "BING_CUSTOM_INSTANCE_NAME"
+ )
+ bing_search: BingGroundingTool | BingCustomSearchTool | None = None
+ if (connection_id) and not custom_connection_id and not custom_instance_name:
+ if connection_id:
+ conn_id = connection_id
+ else:
+ raise ServiceInitializationError("Parameter connection_id is not provided.")
+ bing_search = BingGroundingTool(connection_id=conn_id, **config_args)
+ if custom_connection_id and custom_instance_name:
+ bing_search = BingCustomSearchTool(
+ connection_id=custom_connection_id,
+ instance_name=custom_instance_name,
+ **config_args,
+ )
+ if not bing_search:
+ raise ServiceInitializationError(
+ "Bing search tool requires either 'connection_id' for Bing Grounding "
+ "or both 'custom_connection_id' and 'custom_instance_name' for Custom Bing Search. "
+ "These can be provided via additional_properties or environment variables: "
+ "'BING_CONNECTION_ID', 'BING_CUSTOM_CONNECTION_ID', "
+ "'BING_CUSTOM_INSTANCE_NAME'"
+ )
+ tool_definitions.extend(bing_search.definitions)
+ case HostedCodeInterpreterTool():
+ tool_definitions.append(CodeInterpreterToolDefinition())
+ case HostedMCPTool():
+ mcp_tool = McpTool(
+ server_label=tool.name.replace(" ", "_"),
+ server_url=str(tool.url),
+ allowed_tools=list(tool.allowed_tools) if tool.allowed_tools else [],
+ )
+ tool_definitions.extend(mcp_tool.definitions)
+ case HostedFileSearchTool():
+ vector_stores = [inp for inp in tool.inputs or [] if inp.type == "hosted_vector_store"]
+ if vector_stores:
+ file_search = FileSearchTool(vector_store_ids=[vs.vector_store_id for vs in vector_stores]) # type: ignore[misc]
+ tool_definitions.extend(file_search.definitions)
+ # Set tool_resources for file search to work properly with Azure AI
+ if run_options is not None and "tool_resources" not in run_options:
+ run_options["tool_resources"] = file_search.resources
+ case ToolDefinition():
+ tool_definitions.append(tool)
+ case dict():
+ tool_definitions.append(tool)
+ case _:
+ raise ServiceInitializationError(f"Unsupported tool type: {type(tool)}")
+ return tool_definitions
+
def _prepare_tool_outputs_for_azure_ai(
self,
- required_action_results: list[FunctionResultContent | FunctionApprovalResponseContent] | None,
+ required_action_results: list[Content] | None,
) -> tuple[str | None, list[ToolOutput] | None, list[ToolApproval] | None]:
"""Prepare function results and approvals for submission to the Azure AI API."""
run_id: str | None = None
@@ -1115,9 +1202,7 @@ def _prepare_tool_outputs_for_azure_ai(
# We need to extract the run ID and ensure that the Output/Approval we send back to Azure
# is only the call ID.
run_and_call_ids: list[str] = (
- json.loads(content.call_id)
- if isinstance(content, FunctionResultContent)
- else json.loads(content.id)
+ json.loads(content.call_id) if content.type == "function_result" else json.loads(content.id) # type: ignore[arg-type]
)
if (
@@ -1132,16 +1217,16 @@ def _prepare_tool_outputs_for_azure_ai(
run_id = run_and_call_ids[0]
call_id = run_and_call_ids[1]
- if isinstance(content, FunctionResultContent):
+ if content.type == "function_result":
if tool_outputs is None:
tool_outputs = []
tool_outputs.append(
ToolOutput(tool_call_id=call_id, output=prepare_function_call_results(content.result))
)
- elif isinstance(content, FunctionApprovalResponseContent):
+ elif content.type == "function_approval_response":
if tool_approvals is None:
tool_approvals = []
- tool_approvals.append(ToolApproval(tool_call_id=call_id, approve=content.approved))
+ tool_approvals.append(ToolApproval(tool_call_id=call_id, approve=content.approved)) # type: ignore[arg-type]
return run_id, tool_outputs, tool_approvals
diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py
index 3dfb1766b2..c64b6df44c 100644
--- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py
+++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py
@@ -12,7 +12,6 @@
ContextProvider,
HostedMCPTool,
Middleware,
- TextContent,
ToolProtocol,
get_logger,
use_chat_middleware,
@@ -477,8 +476,8 @@ def _prepare_messages_for_azure_ai(
# System/developer messages are turned into instructions, since there is no such message roles in Azure AI.
for message in messages:
if message.role.value in ["system", "developer"]:
- for text_content in [content for content in message.contents if isinstance(content, TextContent)]:
- instructions_list.append(text_content.text)
+ for text_content in [content for content in message.contents if content.type == "text"]:
+ instructions_list.append(text_content.text) # type: ignore[arg-type]
else:
result.append(message)
diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_shared.py b/python/packages/azure-ai/agent_framework_azure_ai/_shared.py
index 61340abb72..020969cd12 100644
--- a/python/packages/azure-ai/agent_framework_azure_ai/_shared.py
+++ b/python/packages/azure-ai/agent_framework_azure_ai/_shared.py
@@ -6,12 +6,10 @@
from agent_framework import (
AIFunction,
- Contents,
+ Content,
HostedCodeInterpreterTool,
- HostedFileContent,
HostedFileSearchTool,
HostedMCPTool,
- HostedVectorStoreContent,
HostedWebSearchTool,
ToolProtocol,
get_logger,
@@ -189,9 +187,9 @@ def to_azure_ai_agent_tools(
)
tool_definitions.extend(mcp_tool.definitions)
case HostedFileSearchTool():
- vector_stores = [inp for inp in tool.inputs or [] if isinstance(inp, HostedVectorStoreContent)]
+ vector_stores = [inp for inp in tool.inputs or [] if inp.type == "hosted_vector_store"]
if vector_stores:
- file_search = AgentsFileSearchTool(vector_store_ids=[vs.vector_store_id for vs in vector_stores])
+ file_search = AgentsFileSearchTool(vector_store_ids=[vs.vector_store_id for vs in vector_stores]) # type: ignore[misc]
tool_definitions.extend(file_search.definitions)
# Set tool_resources for file search to work properly with Azure AI
if run_options is not None and "tool_resources" not in run_options:
@@ -247,7 +245,7 @@ def _convert_dict_tool(tool: dict[str, Any]) -> ToolProtocol | dict[str, Any] |
if tool_type == "file_search":
file_search_config = tool.get("file_search", {})
vector_store_ids = file_search_config.get("vector_store_ids", [])
- inputs = [HostedVectorStoreContent(vector_store_id=vs_id) for vs_id in vector_store_ids]
+ inputs = [Content.from_hosted_vector_store(vector_store_id=vs_id) for vs_id in vector_store_ids]
return HostedFileSearchTool(inputs=inputs if inputs else None) # type: ignore
if tool_type == "bing_grounding":
@@ -287,7 +285,7 @@ def _convert_sdk_tool(tool: ToolDefinition) -> ToolProtocol | dict[str, Any] | N
if tool_type == "file_search":
file_search_config = getattr(tool, "file_search", None)
vector_store_ids = getattr(file_search_config, "vector_store_ids", []) if file_search_config else []
- inputs = [HostedVectorStoreContent(vector_store_id=vs_id) for vs_id in vector_store_ids]
+ inputs = [Content.from_hosted_vector_store(vector_store_id=vs_id) for vs_id in vector_store_ids]
return HostedFileSearchTool(inputs=inputs if inputs else None) # type: ignore
if tool_type == "bing_grounding":
@@ -372,18 +370,18 @@ def from_azure_ai_tools(tools: Sequence[Tool | dict[str, Any]] | None) -> list[T
elif tool_type == "code_interpreter":
ci_tool = cast(CodeInterpreterTool, tool_dict)
container = ci_tool.get("container", {})
- ci_inputs: list[Contents] = []
+ ci_inputs: list[Content] = []
if "file_ids" in container:
for file_id in container["file_ids"]:
- ci_inputs.append(HostedFileContent(file_id=file_id))
+ ci_inputs.append(Content.from_hosted_file(file_id=file_id))
agent_tools.append(HostedCodeInterpreterTool(inputs=ci_inputs if ci_inputs else None)) # type: ignore
elif tool_type == "file_search":
fs_tool = cast(ProjectsFileSearchTool, tool_dict)
- fs_inputs: list[Contents] = []
+ fs_inputs: list[Content] = []
if "vector_store_ids" in fs_tool:
for vs_id in fs_tool["vector_store_ids"]:
- fs_inputs.append(HostedVectorStoreContent(vector_store_id=vs_id))
+ fs_inputs.append(Content.from_hosted_vector_store(vector_store_id=vs_id))
agent_tools.append(
HostedFileSearchTool(
@@ -433,8 +431,8 @@ def to_azure_ai_tools(
file_ids: list[str] = []
if tool.inputs:
for tool_input in tool.inputs:
- if isinstance(tool_input, HostedFileContent):
- file_ids.append(tool_input.file_id)
+ if tool_input.type == "hosted_file":
+ file_ids.append(tool_input.file_id) # type: ignore[misc, arg-type]
container = CodeInterpreterToolAuto(file_ids=file_ids if file_ids else None)
ci_tool: CodeInterpreterTool = CodeInterpreterTool(container=container)
azure_tools.append(ci_tool)
@@ -453,11 +451,14 @@ def to_azure_ai_tools(
if not tool.inputs:
raise ValueError("HostedFileSearchTool requires inputs to be specified.")
vector_store_ids: list[str] = [
- inp.vector_store_id for inp in tool.inputs if isinstance(inp, HostedVectorStoreContent)
+ inp.vector_store_id # type: ignore[misc]
+ for inp in tool.inputs
+ if inp.type == "hosted_vector_store"
]
if not vector_store_ids:
raise ValueError(
- "HostedFileSearchTool requires inputs to be of type `HostedVectorStoreContent`."
+ "HostedFileSearchTool requires inputs to be of type `Content` with "
+ "type 'hosted_vector_store'."
)
fs_tool: ProjectsFileSearchTool = ProjectsFileSearchTool(vector_store_ids=vector_store_ids)
if tool.max_results:
diff --git a/python/packages/azure-ai/tests/test_agent_provider.py b/python/packages/azure-ai/tests/test_agent_provider.py
index 3df8d318ec..edfd749f4c 100644
--- a/python/packages/azure-ai/tests/test_agent_provider.py
+++ b/python/packages/azure-ai/tests/test_agent_provider.py
@@ -7,10 +7,10 @@
import pytest
from agent_framework import (
ChatAgent,
+ Content,
HostedCodeInterpreterTool,
HostedFileSearchTool,
HostedMCPTool,
- HostedVectorStoreContent,
HostedWebSearchTool,
ai_function,
)
@@ -509,7 +509,7 @@ def test_to_azure_ai_agent_tools_code_interpreter() -> None:
def test_to_azure_ai_agent_tools_file_search() -> None:
"""Test converting HostedFileSearchTool with vector stores."""
- tool = HostedFileSearchTool(inputs=[HostedVectorStoreContent(vector_store_id="vs-123")])
+ tool = HostedFileSearchTool(inputs=[Content.from_hosted_vector_store(vector_store_id="vs-123")])
run_options: dict[str, Any] = {}
result = to_azure_ai_agent_tools([tool], run_options)
diff --git a/python/packages/azure-ai/tests/test_azure_ai_agent_client.py b/python/packages/azure-ai/tests/test_azure_ai_agent_client.py
index 21bedbf710..7b20caea7d 100644
--- a/python/packages/azure-ai/tests/test_azure_ai_agent_client.py
+++ b/python/packages/azure-ai/tests/test_azure_ai_agent_client.py
@@ -17,19 +17,12 @@
ChatOptions,
ChatResponse,
ChatResponseUpdate,
- CitationAnnotation,
- FunctionApprovalRequestContent,
- FunctionApprovalResponseContent,
- FunctionCallContent,
- FunctionResultContent,
+ Content,
HostedCodeInterpreterTool,
- HostedFileContent,
HostedFileSearchTool,
HostedMCPTool,
- HostedVectorStoreContent,
+ HostedWebSearchTool,
Role,
- TextContent,
- UriContent,
)
from agent_framework._serialization import SerializationMixin
from agent_framework.exceptions import ServiceInitializationError
@@ -368,7 +361,7 @@ async def test_azure_ai_chat_client_prepare_options_with_image_content(mock_agen
# Mock get_agent
mock_agents_client.get_agent = AsyncMock(return_value=None)
- image_content = UriContent(uri="https://example.com/image.jpg", media_type="image/jpeg")
+ image_content = Content.from_uri(uri="https://example.com/image.jpg", media_type="image/jpeg")
messages = [ChatMessage(role=Role.USER, contents=[image_content])]
run_options, _ = await chat_client._prepare_options(messages, {}) # type: ignore
@@ -551,7 +544,7 @@ def test_azure_ai_chat_client_parse_function_calls_from_azure_ai_basic(mock_agen
result = chat_client._parse_function_calls_from_azure_ai(mock_event_data, "response_123") # type: ignore
assert len(result) == 1
- assert isinstance(result[0], FunctionCallContent)
+ assert result[0].type == "function_call"
assert result[0].name == "get_weather"
assert result[0].call_id == '["response_123", "call_123"]'
@@ -728,6 +721,121 @@ async def test_azure_ai_chat_client_prepare_options_mcp_with_headers(mock_agents
assert mcp_resource["headers"] == headers
+async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_web_search_bing_grounding(
+ mock_agents_client: MagicMock,
+) -> None:
+ """Test _prepare_tools_for_azure_ai with HostedWebSearchTool using Bing Grounding."""
+
+ chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent")
+
+ web_search_tool = HostedWebSearchTool(
+ additional_properties={
+ "connection_id": "test-connection-id",
+ "count": 5,
+ "freshness": "Day",
+ "market": "en-US",
+ "set_lang": "en",
+ }
+ )
+
+ # Mock BingGroundingTool
+ with patch("agent_framework_azure_ai._chat_client.BingGroundingTool") as mock_bing_grounding:
+ mock_bing_tool = MagicMock()
+ mock_bing_tool.definitions = [{"type": "bing_grounding"}]
+ mock_bing_grounding.return_value = mock_bing_tool
+
+ result = await chat_client._prepare_tools_for_azure_ai([web_search_tool]) # type: ignore
+
+ assert len(result) == 1
+ assert result[0] == {"type": "bing_grounding"}
+ call_args = mock_bing_grounding.call_args[1]
+ assert call_args["count"] == 5
+ assert call_args["freshness"] == "Day"
+ assert call_args["market"] == "en-US"
+ assert call_args["set_lang"] == "en"
+ assert "connection_id" in call_args
+
+
+async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_web_search_bing_grounding_with_connection_id(
+ mock_agents_client: MagicMock,
+) -> None:
+    """Test _prepare_tools_for_azure_ai with HostedWebSearchTool via Bing Grounding connection_id (no HTTP call)."""
+
+ chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent")
+
+ web_search_tool = HostedWebSearchTool(
+ additional_properties={
+ "connection_id": "direct-connection-id",
+ "count": 3,
+ }
+ )
+
+ # Mock BingGroundingTool
+ with patch("agent_framework_azure_ai._chat_client.BingGroundingTool") as mock_bing_grounding:
+ mock_bing_tool = MagicMock()
+ mock_bing_tool.definitions = [{"type": "bing_grounding"}]
+ mock_bing_grounding.return_value = mock_bing_tool
+
+ result = await chat_client._prepare_tools_for_azure_ai([web_search_tool]) # type: ignore
+
+ assert len(result) == 1
+ assert result[0] == {"type": "bing_grounding"}
+ mock_bing_grounding.assert_called_once_with(connection_id="direct-connection-id", count=3)
+
+
+async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_web_search_custom_bing(
+ mock_agents_client: MagicMock,
+) -> None:
+ """Test _prepare_tools_for_azure_ai with HostedWebSearchTool using Custom Bing Search."""
+
+ chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent")
+
+ web_search_tool = HostedWebSearchTool(
+ additional_properties={
+ "custom_connection_id": "custom-connection-id",
+ "custom_instance_name": "custom-instance",
+ "count": 10,
+ }
+ )
+
+ # Mock BingCustomSearchTool
+ with patch("agent_framework_azure_ai._chat_client.BingCustomSearchTool") as mock_custom_bing:
+ mock_custom_tool = MagicMock()
+ mock_custom_tool.definitions = [{"type": "bing_custom_search"}]
+ mock_custom_bing.return_value = mock_custom_tool
+
+ result = await chat_client._prepare_tools_for_azure_ai([web_search_tool]) # type: ignore
+
+ assert len(result) == 1
+ assert result[0] == {"type": "bing_custom_search"}
+
+
+async def test_azure_ai_chat_client_prepare_tools_for_azure_ai_file_search_with_vector_stores(
+ mock_agents_client: MagicMock,
+) -> None:
+ """Test _prepare_tools_for_azure_ai with HostedFileSearchTool using vector stores."""
+
+ chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent")
+
+ vector_store_input = Content.from_hosted_vector_store(vector_store_id="vs-123")
+ file_search_tool = HostedFileSearchTool(inputs=[vector_store_input])
+
+ # Mock FileSearchTool
+ with patch("agent_framework_azure_ai._chat_client.FileSearchTool") as mock_file_search:
+ mock_file_tool = MagicMock()
+ mock_file_tool.definitions = [{"type": "file_search"}]
+ mock_file_tool.resources = {"vector_store_ids": ["vs-123"]}
+ mock_file_search.return_value = mock_file_tool
+
+ run_options = {}
+ result = await chat_client._prepare_tools_for_azure_ai([file_search_tool], run_options) # type: ignore
+
+ assert len(result) == 1
+ assert result[0] == {"type": "file_search"}
+ assert run_options["tool_resources"] == {"vector_store_ids": ["vs-123"]}
+ mock_file_search.assert_called_once_with(vector_store_ids=["vs-123"])
+
+
async def test_azure_ai_chat_client_create_agent_stream_submit_tool_approvals(
mock_agents_client: MagicMock,
) -> None:
@@ -741,9 +849,9 @@ async def test_azure_ai_chat_client_create_agent_stream_submit_tool_approvals(
chat_client._get_active_thread_run = AsyncMock(return_value=mock_thread_run) # type: ignore
# Mock required action results with approval response that matches run ID
- approval_response = FunctionApprovalResponseContent(
+ approval_response = Content.from_function_approval_response(
id='["test-run-id", "test-call-id"]',
- function_call=FunctionCallContent(
+ function_call=Content.from_function_call(
call_id='["test-run-id", "test-call-id"]', name="test_function", arguments="{}"
),
approved=True,
@@ -839,7 +947,7 @@ async def test_azure_ai_chat_client_prepare_tool_outputs_for_azure_ai_function_r
chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent")
# Test with simple result
- function_result = FunctionResultContent(call_id='["run_123", "call_456"]', result="Simple result")
+ function_result = Content.from_function_result(call_id='["run_123", "call_456"]', result="Simple result")
run_id, tool_outputs, tool_approvals = chat_client._prepare_tool_outputs_for_azure_ai([function_result]) # type: ignore
@@ -857,7 +965,7 @@ async def test_azure_ai_chat_client_convert_required_action_invalid_call_id(mock
chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent")
# Invalid call_id format - should raise JSONDecodeError
- function_result = FunctionResultContent(call_id="invalid_json", result="result")
+ function_result = Content.from_function_result(call_id="invalid_json", result="result")
with pytest.raises(json.JSONDecodeError):
chat_client._prepare_tool_outputs_for_azure_ai([function_result]) # type: ignore
@@ -870,7 +978,7 @@ async def test_azure_ai_chat_client_convert_required_action_invalid_structure(
chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent")
# Valid JSON but invalid structure (missing second element)
- function_result = FunctionResultContent(call_id='["run_123"]', result="result")
+ function_result = Content.from_function_result(call_id='["run_123"]', result="result")
run_id, tool_outputs, tool_approvals = chat_client._prepare_tool_outputs_for_azure_ai([function_result]) # type: ignore
@@ -894,7 +1002,7 @@ def __init__(self, name: str, value: int):
# Test with BaseModel result
mock_result = MockResult(name="test", value=42)
- function_result = FunctionResultContent(call_id='["run_123", "call_456"]', result=mock_result)
+ function_result = Content.from_function_result(call_id='["run_123", "call_456"]', result=mock_result)
run_id, tool_outputs, tool_approvals = chat_client._prepare_tool_outputs_for_azure_ai([function_result]) # type: ignore
@@ -922,7 +1030,7 @@ def __init__(self, data: str):
# Test with multiple results - mix of BaseModel and regular objects
mock_basemodel = MockResult(data="model_data")
results_list = [mock_basemodel, {"key": "value"}, "string_result"]
- function_result = FunctionResultContent(call_id='["run_123", "call_456"]', result=results_list)
+ function_result = Content.from_function_result(call_id='["run_123", "call_456"]', result=results_list)
run_id, tool_outputs, tool_approvals = chat_client._prepare_tool_outputs_for_azure_ai([function_result]) # type: ignore
@@ -948,9 +1056,11 @@ async def test_azure_ai_chat_client_convert_required_action_approval_response(
chat_client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent")
# Test with approval response - need to provide required fields
- approval_response = FunctionApprovalResponseContent(
+ approval_response = Content.from_function_approval_response(
id='["run_123", "call_456"]',
- function_call=FunctionCallContent(call_id='["run_123", "call_456"]', name="test_function", arguments="{}"),
+ function_call=Content.from_function_call(
+ call_id='["run_123", "call_456"]', name="test_function", arguments="{}"
+ ),
approved=True,
)
@@ -985,7 +1095,7 @@ async def test_azure_ai_chat_client_parse_function_calls_from_azure_ai_approval_
result = chat_client._parse_function_calls_from_azure_ai(mock_event_data, "response_123") # type: ignore
assert len(result) == 1
- assert isinstance(result[0], FunctionApprovalRequestContent)
+ assert result[0].type == "function_approval_request"
assert result[0].id == '["response_123", "approval_call_123"]'
assert result[0].function_call.name == "approve_action"
assert result[0].function_call.call_id == '["response_123", "approval_call_123"]'
@@ -1064,7 +1174,7 @@ async def test_azure_ai_chat_client_create_agent_stream_submit_tool_outputs(
chat_client._get_active_thread_run = AsyncMock(return_value=mock_thread_run) # type: ignore
# Mock required action results with matching run ID
- function_result = FunctionResultContent(call_id='["test-run-id", "test-call-id"]', result="test result")
+ function_result = Content.from_function_result(call_id='["test-run-id", "test-call-id"]', result="test result")
# Mock submit_tool_outputs_stream
mock_handler = MagicMock()
@@ -1115,14 +1225,13 @@ def test_azure_ai_chat_client_extract_url_citations_with_citations(mock_agents_c
# Verify results
assert len(citations) == 1
citation = citations[0]
- assert isinstance(citation, CitationAnnotation)
- assert citation.url == "https://example.com/test"
- assert citation.title == "Test Title"
- assert citation.snippet is None
- assert citation.annotated_regions is not None
- assert len(citation.annotated_regions) == 1
- assert citation.annotated_regions[0].start_index == 10
- assert citation.annotated_regions[0].end_index == 20
+ assert citation["url"] == "https://example.com/test"
+ assert citation["title"] == "Test Title"
+ assert citation["snippet"] is None
+ assert citation["annotated_regions"] is not None
+ assert len(citation["annotated_regions"]) == 1
+ assert citation["annotated_regions"][0]["start_index"] == 10
+ assert citation["annotated_regions"][0]["end_index"] == 20
def test_azure_ai_chat_client_extract_file_path_contents_with_file_path_annotation(
@@ -1158,7 +1267,7 @@ def test_azure_ai_chat_client_extract_file_path_contents_with_file_path_annotati
# Verify results
assert len(file_contents) == 1
- assert isinstance(file_contents[0], HostedFileContent)
+ assert file_contents[0].type == "hosted_file"
assert file_contents[0].file_id == "assistant-test-file-123"
@@ -1195,7 +1304,7 @@ def test_azure_ai_chat_client_extract_file_path_contents_with_file_citation_anno
# Verify results
assert len(file_contents) == 1
- assert isinstance(file_contents[0], HostedFileContent)
+ assert file_contents[0].type == "hosted_file"
assert file_contents[0].file_id == "cfile_test-citation-456"
@@ -1305,7 +1414,7 @@ async def test_azure_ai_chat_client_streaming() -> None:
assert chunk is not None
assert isinstance(chunk, ChatResponseUpdate)
for content in chunk.contents:
- if isinstance(content, TextContent) and content.text:
+ if content.type == "text" and content.text:
full_message += content.text
assert any(word in full_message.lower() for word in ["sunny", "25"])
@@ -1331,7 +1440,7 @@ async def test_azure_ai_chat_client_streaming_tools() -> None:
assert chunk is not None
assert isinstance(chunk, ChatResponseUpdate)
for content in chunk.contents:
- if isinstance(content, TextContent) and content.text:
+ if content.type == "text" and content.text:
full_message += content.text
assert any(word in full_message.lower() for word in ["sunny", "25"])
@@ -1476,7 +1585,9 @@ async def test_azure_ai_chat_client_agent_file_search():
)
# 2. Create file search tool with uploaded resources
- file_search_tool = HostedFileSearchTool(inputs=[HostedVectorStoreContent(vector_store_id=vector_store.id)])
+ file_search_tool = HostedFileSearchTool(
+ inputs=[Content.from_hosted_vector_store(vector_store_id=vector_store.id)]
+ )
async with ChatAgent(
chat_client=client,
@@ -1795,7 +1906,7 @@ def test_azure_ai_chat_client_extract_url_citations_with_azure_search_enhanced_u
# Verify real URL was used
assert len(citations) == 1
citation = citations[0]
- assert citation.url == "https://real-example.com/doc2" # doc_1 maps to index 1
+ assert citation["url"] == "https://real-example.com/doc2" # doc_1 maps to index 1
def test_azure_ai_chat_client_init_with_auto_created_agents_client(
diff --git a/python/packages/azure-ai/tests/test_azure_ai_client.py b/python/packages/azure-ai/tests/test_azure_ai_client.py
index dad8f049fe..aba45b3f1b 100644
--- a/python/packages/azure-ai/tests/test_azure_ai_client.py
+++ b/python/packages/azure-ai/tests/test_azure_ai_client.py
@@ -16,14 +16,12 @@
ChatMessage,
ChatOptions,
ChatResponse,
+ Content,
HostedCodeInterpreterTool,
- HostedFileContent,
HostedFileSearchTool,
HostedMCPTool,
- HostedVectorStoreContent,
HostedWebSearchTool,
Role,
- TextContent,
)
from agent_framework.exceptions import ServiceInitializationError
from azure.ai.projects.aio import AIProjectClient
@@ -298,9 +296,9 @@ async def test_prepare_messages_for_azure_ai_with_system_messages(
client = create_test_azure_ai_client(mock_project_client)
messages = [
- ChatMessage(role=Role.SYSTEM, contents=[TextContent(text="You are a helpful assistant.")]),
- ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")]),
- ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="System response")]),
+ ChatMessage(role=Role.SYSTEM, contents=[Content.from_text(text="You are a helpful assistant.")]),
+ ChatMessage(role=Role.USER, contents=[Content.from_text(text="Hello")]),
+ ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text(text="System response")]),
]
result_messages, instructions = client._prepare_messages_for_azure_ai(messages) # type: ignore
@@ -318,8 +316,8 @@ async def test_prepare_messages_for_azure_ai_no_system_messages(
client = create_test_azure_ai_client(mock_project_client)
messages = [
- ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")]),
- ChatMessage(role=Role.ASSISTANT, contents=[TextContent(text="Hi there!")]),
+ ChatMessage(role=Role.USER, contents=[Content.from_text(text="Hello")]),
+ ChatMessage(role=Role.ASSISTANT, contents=[Content.from_text(text="Hi there!")]),
]
result_messages, instructions = client._prepare_messages_for_azure_ai(messages) # type: ignore
@@ -419,7 +417,7 @@ async def test_prepare_options_basic(mock_project_client: MagicMock) -> None:
"""Test prepare_options basic functionality."""
client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent", agent_version="1.0")
- messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])]
+ messages = [ChatMessage(role=Role.USER, contents=[Content.from_text(text="Hello")])]
with (
patch.object(client.__class__.__bases__[0], "_prepare_options", return_value={"model": "test-model"}),
@@ -453,7 +451,7 @@ async def test_prepare_options_with_application_endpoint(
agent_version="1",
)
- messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])]
+ messages = [ChatMessage(role=Role.USER, contents=[Content.from_text(text="Hello")])]
with (
patch.object(client.__class__.__bases__[0], "_prepare_options", return_value={"model": "test-model"}),
@@ -492,7 +490,7 @@ async def test_prepare_options_with_application_project_client(
agent_version="1",
)
- messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])]
+ messages = [ChatMessage(role=Role.USER, contents=[Content.from_text(text="Hello")])]
with (
patch.object(client.__class__.__bases__[0], "_prepare_options", return_value={"model": "test-model"}),
@@ -848,7 +846,7 @@ async def test_prepare_options_excludes_response_format(
"""Test that prepare_options excludes response_format, text, and text_format from final run options."""
client = create_test_azure_ai_client(mock_project_client, agent_name="test-agent", agent_version="1.0")
- messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="Hello")])]
+ messages = [ChatMessage(role=Role.USER, contents=[Content.from_text(text="Hello")])]
chat_options: ChatOptions = {}
with (
@@ -992,7 +990,7 @@ def test_from_azure_ai_tools() -> None:
tool_input = parsed_tools[0].inputs[0]
- assert tool_input and isinstance(tool_input, HostedFileContent) and tool_input.file_id == "file-1"
+ assert tool_input and tool_input.type == "hosted_file" and tool_input.file_id == "file-1"
# Test File Search tool
fs_tool = FileSearchTool(vector_store_ids=["vs-1"], max_num_results=5)
@@ -1004,7 +1002,7 @@ def test_from_azure_ai_tools() -> None:
tool_input = parsed_tools[0].inputs[0]
- assert tool_input and isinstance(tool_input, HostedVectorStoreContent) and tool_input.vector_store_id == "vs-1"
+ assert tool_input and tool_input.type == "hosted_vector_store" and tool_input.vector_store_id == "vs-1"
assert parsed_tools[0].max_results == 5
# Test Web Search tool
diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py
index c84a30fbb4..d33d9ea91c 100644
--- a/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py
+++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_durable_agent_state.py
@@ -36,18 +36,8 @@
from agent_framework import (
AgentResponse,
- BaseContent,
ChatMessage,
- DataContent,
- ErrorContent,
- FunctionCallContent,
- FunctionResultContent,
- HostedFileContent,
- HostedVectorStoreContent,
- TextContent,
- TextReasoningContent,
- UriContent,
- UsageContent,
+ Content,
UsageDetails,
get_logger,
)
@@ -290,25 +280,25 @@ def from_ai_content(content: Any) -> DurableAgentStateContent:
The corresponding DurableAgentStateContent subclass instance
"""
# Map AI content type to appropriate DurableAgentStateContent subclass
- if isinstance(content, DataContent):
+ if isinstance(content, Content) and content.type == "data":
return DurableAgentStateDataContent.from_data_content(content)
- if isinstance(content, ErrorContent):
+ if isinstance(content, Content) and content.type == "error":
return DurableAgentStateErrorContent.from_error_content(content)
- if isinstance(content, FunctionCallContent):
+ if isinstance(content, Content) and content.type == "function_call":
return DurableAgentStateFunctionCallContent.from_function_call_content(content)
- if isinstance(content, FunctionResultContent):
+ if isinstance(content, Content) and content.type == "function_result":
return DurableAgentStateFunctionResultContent.from_function_result_content(content)
- if isinstance(content, HostedFileContent):
+ if isinstance(content, Content) and content.type == "hosted_file":
return DurableAgentStateHostedFileContent.from_hosted_file_content(content)
- if isinstance(content, HostedVectorStoreContent):
+ if isinstance(content, Content) and content.type == "hosted_vector_store":
return DurableAgentStateHostedVectorStoreContent.from_hosted_vector_store_content(content)
- if isinstance(content, TextContent):
+ if isinstance(content, Content) and content.type == "text":
return DurableAgentStateTextContent.from_text_content(content)
- if isinstance(content, TextReasoningContent):
+ if isinstance(content, Content) and content.type == "text_reasoning":
return DurableAgentStateTextReasoningContent.from_text_reasoning_content(content)
- if isinstance(content, UriContent):
+ if isinstance(content, Content) and content.type == "uri":
return DurableAgentStateUriContent.from_uri_content(content)
- if isinstance(content, UsageContent):
+ if isinstance(content, Content) and content.type == "usage":
return DurableAgentStateUsageContent.from_usage_content(content)
return DurableAgentStateUnknownContent.from_unknown_content(content)
@@ -699,7 +689,7 @@ def from_run_response(correlation_id: str, response: AgentResponse) -> DurableAg
correlation_id=correlation_id,
created_at=_parse_created_at(response.created_at),
messages=[DurableAgentStateMessage.from_chat_message(m) for m in response.messages],
- usage=DurableAgentStateUsage.from_usage(response.usage_details),
+ usage=DurableAgentStateUsage.from_usage(response.usage_details), # type: ignore[arg-type]
)
@@ -868,11 +858,11 @@ def to_dict(self) -> dict[str, Any]:
}
@staticmethod
- def from_data_content(content: DataContent) -> DurableAgentStateDataContent:
- return DurableAgentStateDataContent(uri=content.uri, media_type=content.media_type)
+ def from_data_content(content: Content) -> DurableAgentStateDataContent:
+ return DurableAgentStateDataContent(uri=content.uri, media_type=content.media_type) # type: ignore[arg-type]
- def to_ai_content(self) -> DataContent:
- return DataContent(uri=self.uri, media_type=self.media_type)
+ def to_ai_content(self) -> Content:
+ return Content.from_uri(uri=self.uri, media_type=self.media_type)
class DurableAgentStateErrorContent(DurableAgentStateContent):
@@ -907,13 +897,13 @@ def to_dict(self) -> dict[str, Any]:
}
@staticmethod
- def from_error_content(content: ErrorContent) -> DurableAgentStateErrorContent:
+ def from_error_content(content: Content) -> DurableAgentStateErrorContent:
return DurableAgentStateErrorContent(
- message=content.message, error_code=content.error_code, details=content.details
+ message=content.message, error_code=content.error_code, details=content.error_details
)
- def to_ai_content(self) -> ErrorContent:
- return ErrorContent(message=self.message, error_code=self.error_code, details=self.details)
+ def to_ai_content(self) -> Content:
+ return Content.from_error(message=self.message, error_code=self.error_code, error_details=self.details)
class DurableAgentStateFunctionCallContent(DurableAgentStateContent):
@@ -949,7 +939,7 @@ def to_dict(self) -> dict[str, Any]:
}
@staticmethod
- def from_function_call_content(content: FunctionCallContent) -> DurableAgentStateFunctionCallContent:
+ def from_function_call_content(content: Content) -> DurableAgentStateFunctionCallContent:
# Ensure arguments is a dict; parse string if needed
arguments: dict[str, Any] = {}
if content.arguments:
@@ -962,10 +952,10 @@ def from_function_call_content(content: FunctionCallContent) -> DurableAgentStat
except json.JSONDecodeError:
arguments = {}
- return DurableAgentStateFunctionCallContent(call_id=content.call_id, name=content.name, arguments=arguments)
+ return DurableAgentStateFunctionCallContent(call_id=content.call_id, name=content.name, arguments=arguments) # type: ignore[arg-type]
- def to_ai_content(self) -> FunctionCallContent:
- return FunctionCallContent(call_id=self.call_id, name=self.name, arguments=self.arguments)
+ def to_ai_content(self) -> Content:
+ return Content.from_function_call(call_id=self.call_id, name=self.name, arguments=self.arguments)
class DurableAgentStateFunctionResultContent(DurableAgentStateContent):
@@ -997,11 +987,11 @@ def to_dict(self) -> dict[str, Any]:
}
@staticmethod
- def from_function_result_content(content: FunctionResultContent) -> DurableAgentStateFunctionResultContent:
- return DurableAgentStateFunctionResultContent(call_id=content.call_id, result=content.result)
+ def from_function_result_content(content: Content) -> DurableAgentStateFunctionResultContent:
+ return DurableAgentStateFunctionResultContent(call_id=content.call_id, result=content.result) # type: ignore[arg-type]
- def to_ai_content(self) -> FunctionResultContent:
- return FunctionResultContent(call_id=self.call_id, result=self.result)
+ def to_ai_content(self) -> Content:
+ return Content.from_function_result(call_id=self.call_id, result=self.result)
class DurableAgentStateHostedFileContent(DurableAgentStateContent):
@@ -1025,11 +1015,11 @@ def to_dict(self) -> dict[str, Any]:
return {DurableStateFields.TYPE_DISCRIMINATOR: self.type, DurableStateFields.FILE_ID: self.file_id}
@staticmethod
- def from_hosted_file_content(content: HostedFileContent) -> DurableAgentStateHostedFileContent:
- return DurableAgentStateHostedFileContent(file_id=content.file_id)
+ def from_hosted_file_content(content: Content) -> DurableAgentStateHostedFileContent:
+ return DurableAgentStateHostedFileContent(file_id=content.file_id) # type: ignore[arg-type]
- def to_ai_content(self) -> HostedFileContent:
- return HostedFileContent(file_id=self.file_id)
+ def to_ai_content(self) -> Content:
+ return Content.from_hosted_file(file_id=self.file_id)
class DurableAgentStateHostedVectorStoreContent(DurableAgentStateContent):
@@ -1058,12 +1048,12 @@ def to_dict(self) -> dict[str, Any]:
@staticmethod
def from_hosted_vector_store_content(
- content: HostedVectorStoreContent,
+ content: Content,
) -> DurableAgentStateHostedVectorStoreContent:
- return DurableAgentStateHostedVectorStoreContent(vector_store_id=content.vector_store_id)
+ return DurableAgentStateHostedVectorStoreContent(vector_store_id=content.vector_store_id) # type: ignore[arg-type]
- def to_ai_content(self) -> HostedVectorStoreContent:
- return HostedVectorStoreContent(vector_store_id=self.vector_store_id)
+ def to_ai_content(self) -> Content:
+ return Content.from_hosted_vector_store(vector_store_id=self.vector_store_id)
class DurableAgentStateTextContent(DurableAgentStateContent):
@@ -1085,11 +1075,11 @@ def to_dict(self) -> dict[str, Any]:
return {DurableStateFields.TYPE_DISCRIMINATOR: self.type, DurableStateFields.TEXT: self.text}
@staticmethod
- def from_text_content(content: TextContent) -> DurableAgentStateTextContent:
+ def from_text_content(content: Content) -> DurableAgentStateTextContent:
return DurableAgentStateTextContent(text=content.text)
- def to_ai_content(self) -> TextContent:
- return TextContent(text=self.text or "")
+ def to_ai_content(self) -> Content:
+ return Content.from_text(text=self.text or "")
class DurableAgentStateTextReasoningContent(DurableAgentStateContent):
@@ -1111,11 +1101,11 @@ def to_dict(self) -> dict[str, Any]:
return {DurableStateFields.TYPE_DISCRIMINATOR: self.type, DurableStateFields.TEXT: self.text}
@staticmethod
- def from_text_reasoning_content(content: TextReasoningContent) -> DurableAgentStateTextReasoningContent:
+ def from_text_reasoning_content(content: Content) -> DurableAgentStateTextReasoningContent:
return DurableAgentStateTextReasoningContent(text=content.text)
- def to_ai_content(self) -> TextReasoningContent:
- return TextReasoningContent(text=self.text or "")
+ def to_ai_content(self) -> Content:
+ return Content.from_text_reasoning(text=self.text)
class DurableAgentStateUriContent(DurableAgentStateContent):
@@ -1146,11 +1136,11 @@ def to_dict(self) -> dict[str, Any]:
}
@staticmethod
- def from_uri_content(content: UriContent) -> DurableAgentStateUriContent:
- return DurableAgentStateUriContent(uri=content.uri, media_type=content.media_type)
+ def from_uri_content(content: Content) -> DurableAgentStateUriContent:
+ return DurableAgentStateUriContent(uri=content.uri, media_type=content.media_type) # type: ignore[arg-type]
- def to_ai_content(self) -> UriContent:
- return UriContent(uri=self.uri, media_type=self.media_type)
+ def to_ai_content(self) -> Content:
+ return Content.from_uri(uri=self.uri, media_type=self.media_type)
class DurableAgentStateUsage:
@@ -1204,22 +1194,22 @@ def from_dict(cls, data: dict[str, Any]) -> DurableAgentStateUsage:
)
@staticmethod
- def from_usage(usage: UsageDetails | None) -> DurableAgentStateUsage | None:
+ def from_usage(usage: UsageDetails | dict[str, int] | None) -> DurableAgentStateUsage | None:
if usage is None:
return None
return DurableAgentStateUsage(
- input_token_count=usage.input_token_count,
- output_token_count=usage.output_token_count,
- total_token_count=usage.total_token_count,
+ input_token_count=usage.get("input_token_count"),
+ output_token_count=usage.get("output_token_count"),
+ total_token_count=usage.get("total_token_count"),
)
def to_usage_details(self) -> UsageDetails:
# Convert back to AI SDK UsageDetails
- return UsageDetails(
- input_token_count=self.input_token_count,
- output_token_count=self.output_token_count,
- total_token_count=self.total_token_count,
- )
+ return {
+ "input_token_count": self.input_token_count,
+ "output_token_count": self.output_token_count,
+ "total_token_count": self.total_token_count,
+ }
class DurableAgentStateUsageContent(DurableAgentStateContent):
@@ -1247,11 +1237,11 @@ def to_dict(self) -> dict[str, Any]:
}
@staticmethod
- def from_usage_content(content: UsageContent) -> DurableAgentStateUsageContent:
- return DurableAgentStateUsageContent(usage=DurableAgentStateUsage.from_usage(content.details))
+ def from_usage_content(content: Content) -> DurableAgentStateUsageContent:
+ return DurableAgentStateUsageContent(usage=DurableAgentStateUsage.from_usage(content.usage_details))
- def to_ai_content(self) -> UsageContent:
- return UsageContent(details=self.usage.to_usage_details())
+ def to_ai_content(self) -> Content:
+ return Content.from_usage(usage_details=self.usage.to_usage_details())
class DurableAgentStateUnknownContent(DurableAgentStateContent):
@@ -1279,7 +1269,7 @@ def to_dict(self) -> dict[str, Any]:
def from_unknown_content(content: Any) -> DurableAgentStateUnknownContent:
return DurableAgentStateUnknownContent(content=content)
- def to_ai_content(self) -> BaseContent:
+ def to_ai_content(self) -> Content:
if not self.content:
raise Exception("The content is missing and cannot be converted to valid AI content.")
- return BaseContent(content=self.content)
+ return Content(type=self.type, additional_properties={"content": self.content}) # type: ignore
diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py
index f757004cbb..ba6040b99b 100644
--- a/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py
+++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_entities.py
@@ -18,7 +18,7 @@
AgentResponse,
AgentResponseUpdate,
ChatMessage,
- ErrorContent,
+ Content,
Role,
get_logger,
)
@@ -193,7 +193,7 @@ async def run(
# Create error message
error_message = ChatMessage(
- role=Role.ASSISTANT, contents=[ErrorContent(message=str(exc), error_code=type(exc).__name__)]
+ role=Role.ASSISTANT, contents=[Content.from_error(message=str(exc), error_code=type(exc).__name__)]
)
error_response = AgentResponse(messages=[error_message])
diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py
index 37f70a04e1..29d614e729 100644
--- a/python/packages/azurefunctions/tests/test_app.py
+++ b/python/packages/azurefunctions/tests/test_app.py
@@ -10,7 +10,7 @@
import azure.durable_functions as df
import azure.functions as func
import pytest
-from agent_framework import AgentResponse, ChatMessage, ErrorContent
+from agent_framework import AgentResponse, ChatMessage
from agent_framework_azurefunctions import AgentFunctionApp
from agent_framework_azurefunctions._app import WAIT_FOR_RESPONSE_FIELD, WAIT_FOR_RESPONSE_HEADER
@@ -622,7 +622,7 @@ async def test_entity_handles_agent_error(self) -> None:
assert isinstance(result, AgentResponse)
assert len(result.messages) == 1
content = result.messages[0].contents[0]
- assert isinstance(content, ErrorContent)
+ assert content.type == "error"
assert "Agent error" in (content.message or "")
assert content.error_code == "Exception"
diff --git a/python/packages/azurefunctions/tests/test_entities.py b/python/packages/azurefunctions/tests/test_entities.py
index 63bc685afb..5d980f8610 100644
--- a/python/packages/azurefunctions/tests/test_entities.py
+++ b/python/packages/azurefunctions/tests/test_entities.py
@@ -12,7 +12,7 @@
from unittest.mock import AsyncMock, Mock, patch
import pytest
-from agent_framework import AgentResponse, AgentResponseUpdate, ChatMessage, ErrorContent, Role
+from agent_framework import AgentResponse, AgentResponseUpdate, ChatMessage, Role
from pydantic import BaseModel
from agent_framework_azurefunctions._durable_agent_state import (
@@ -608,7 +608,7 @@ async def test_run_agent_handles_agent_exception(self) -> None:
assert isinstance(result, AgentResponse)
assert len(result.messages) == 1
content = result.messages[0].contents[0]
- assert isinstance(content, ErrorContent)
+ assert content.type == "error"
assert "Agent failed" in (content.message or "")
assert content.error_code == "Exception"
@@ -627,7 +627,7 @@ async def test_run_agent_handles_value_error(self) -> None:
assert isinstance(result, AgentResponse)
assert len(result.messages) == 1
content = result.messages[0].contents[0]
- assert isinstance(content, ErrorContent)
+ assert content.type == "error"
assert content.error_code == "ValueError"
assert "Invalid input" in str(content.message)
@@ -646,7 +646,7 @@ async def test_run_agent_handles_timeout_error(self) -> None:
assert isinstance(result, AgentResponse)
assert len(result.messages) == 1
content = result.messages[0].contents[0]
- assert isinstance(content, ErrorContent)
+ assert content.type == "error"
assert content.error_code == "TimeoutError"
def test_entity_function_handles_exception_in_operation(self) -> None:
@@ -685,7 +685,7 @@ async def test_run_agent_preserves_message_on_error(self) -> None:
assert isinstance(result, AgentResponse)
assert len(result.messages) == 1
content = result.messages[0].contents[0]
- assert isinstance(content, ErrorContent)
+ assert content.type == "error"
class TestConversationHistory:
diff --git a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py
index e9e1eeff96..a6325a6603 100644
--- a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py
+++ b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py
@@ -16,14 +16,10 @@
ChatOptions,
ChatResponse,
ChatResponseUpdate,
- Contents,
+ Content,
FinishReason,
- FunctionCallContent,
- FunctionResultContent,
Role,
- TextContent,
ToolProtocol,
- UsageContent,
UsageDetails,
get_logger,
prepare_function_call_results,
@@ -328,7 +324,7 @@ async def _inner_get_streaming_response(
response = await self._inner_get_response(messages=messages, options=options, **kwargs)
contents = list(response.messages[0].contents if response.messages else [])
if response.usage_details:
- contents.append(UsageContent(details=response.usage_details))
+ contents.append(Content.from_usage(usage_details=response.usage_details)) # type: ignore[arg-type]
yield ChatResponseUpdate(
response_id=response.response_id,
contents=contents,
@@ -472,37 +468,41 @@ def _convert_message_to_content_blocks(self, message: ChatMessage) -> list[dict[
blocks.append(block)
return blocks
- def _convert_content_to_bedrock_block(self, content: Contents) -> dict[str, Any] | None:
- if isinstance(content, TextContent):
- return {"text": content.text}
- if isinstance(content, FunctionCallContent):
- arguments = content.parse_arguments() or {}
- return {
- "toolUse": {
- "toolUseId": content.call_id or self._generate_tool_call_id(),
- "name": content.name,
- "input": arguments,
+ def _convert_content_to_bedrock_block(self, content: Content) -> dict[str, Any] | None:
+ match content.type:
+ case "text":
+ return {"text": content.text}
+ case "function_call":
+ arguments = content.parse_arguments() or {}
+ return {
+ "toolUse": {
+ "toolUseId": content.call_id or self._generate_tool_call_id(),
+ "name": content.name,
+ "input": arguments,
+ }
}
- }
- if isinstance(content, FunctionResultContent):
- tool_result_block = {
- "toolResult": {
- "toolUseId": content.call_id,
- "content": self._convert_tool_result_to_blocks(content.result),
- "status": "error" if content.exception else "success",
+ case "function_result":
+ tool_result_block = {
+ "toolResult": {
+ "toolUseId": content.call_id,
+ "content": self._convert_tool_result_to_blocks(content.result),
+ "status": "error" if content.exception else "success",
+ }
}
- }
- if content.exception:
- tool_result = tool_result_block["toolResult"]
- existing_content = tool_result.get("content")
- content_list: list[dict[str, Any]]
- if isinstance(existing_content, list):
- content_list = existing_content
- else:
- content_list = []
- tool_result["content"] = content_list
- content_list.append({"text": str(content.exception)})
- return tool_result_block
+ if content.exception:
+ tool_result = tool_result_block["toolResult"]
+ existing_content = tool_result.get("content")
+ content_list: list[dict[str, Any]]
+ if isinstance(existing_content, list):
+ content_list = existing_content
+ else:
+ content_list = []
+ tool_result["content"] = content_list
+ content_list.append({"text": str(content.exception)})
+ return tool_result_block
+ case _:
+ # Bedrock does not support other content types at this time
+ pass
return None
def _convert_tool_result_to_blocks(self, result: Any) -> list[dict[str, Any]]:
@@ -531,7 +531,7 @@ def _normalize_tool_result_value(self, value: Any) -> dict[str, Any]:
return {"text": value}
if isinstance(value, (int, float, bool)) or value is None:
return {"json": value}
- if isinstance(value, TextContent) and getattr(value, "text", None):
+ if isinstance(value, Content) and value.type == "text":
return {"text": value.text}
if hasattr(value, "to_dict"):
try:
@@ -586,23 +586,23 @@ def _process_converse_response(self, response: dict[str, Any]) -> ChatResponse:
def _parse_usage(self, usage: dict[str, Any] | None) -> UsageDetails | None:
if not usage:
return None
- details = UsageDetails()
+ details: UsageDetails = {}
if (input_tokens := usage.get("inputTokens")) is not None:
- details.input_token_count = input_tokens
+ details["input_token_count"] = input_tokens
if (output_tokens := usage.get("outputTokens")) is not None:
- details.output_token_count = output_tokens
+ details["output_token_count"] = output_tokens
if (total_tokens := usage.get("totalTokens")) is not None:
- details.additional_counts["bedrock.total_tokens"] = total_tokens
+ details["total_token_count"] = total_tokens
return details
def _parse_message_contents(self, content_blocks: Sequence[MutableMapping[str, Any]]) -> list[Any]:
contents: list[Any] = []
for block in content_blocks:
if text_value := block.get("text"):
- contents.append(TextContent(text=text_value, raw_representation=block))
+ contents.append(Content.from_text(text=text_value, raw_representation=block))
continue
if (json_value := block.get("json")) is not None:
- contents.append(TextContent(text=json.dumps(json_value), raw_representation=block))
+ contents.append(Content.from_text(text=json.dumps(json_value), raw_representation=block))
continue
tool_use = block.get("toolUse")
if isinstance(tool_use, MutableMapping):
@@ -610,7 +610,7 @@ def _parse_message_contents(self, content_blocks: Sequence[MutableMapping[str, A
if not tool_name:
raise ServiceInvalidResponseError("Bedrock response missing required tool name in toolUse block.")
contents.append(
- FunctionCallContent(
+ Content.from_function_call(
call_id=tool_use.get("toolUseId") or self._generate_tool_call_id(),
name=tool_name,
arguments=tool_use.get("input"),
@@ -626,10 +626,10 @@ def _parse_message_contents(self, content_blocks: Sequence[MutableMapping[str, A
exception = RuntimeError(f"Bedrock tool result status: {status}")
result_value = self._convert_bedrock_tool_result_to_value(tool_result.get("content"))
contents.append(
- FunctionResultContent(
+ Content.from_function_result(
call_id=tool_result.get("toolUseId") or self._generate_tool_call_id(),
result=result_value,
- exception=exception,
+ exception=str(exception) if exception else None, # type: ignore[arg-type]
raw_representation=block,
)
)
diff --git a/python/packages/bedrock/tests/test_bedrock_client.py b/python/packages/bedrock/tests/test_bedrock_client.py
index 5842426483..704eb2138a 100644
--- a/python/packages/bedrock/tests/test_bedrock_client.py
+++ b/python/packages/bedrock/tests/test_bedrock_client.py
@@ -6,7 +6,7 @@
from typing import Any
import pytest
-from agent_framework import ChatMessage, Role, TextContent
+from agent_framework import ChatMessage, Content, Role
from agent_framework.exceptions import ServiceInitializationError
from agent_framework_bedrock import BedrockChatClient
@@ -42,8 +42,8 @@ def test_get_response_invokes_bedrock_runtime() -> None:
)
messages = [
- ChatMessage(role=Role.SYSTEM, contents=[TextContent(text="You are concise.")]),
- ChatMessage(role=Role.USER, contents=[TextContent(text="hello")]),
+ ChatMessage(role=Role.SYSTEM, contents=[Content.from_text(text="You are concise.")]),
+ ChatMessage(role=Role.USER, contents=[Content.from_text(text="hello")]),
]
response = asyncio.run(client.get_response(messages=messages, options={"max_tokens": 32}))
@@ -53,7 +53,7 @@ def test_get_response_invokes_bedrock_runtime() -> None:
assert payload["modelId"] == "amazon.titan-text"
assert payload["messages"][0]["content"][0]["text"] == "hello"
assert response.messages[0].contents[0].text == "Bedrock says hi"
- assert response.usage_details and response.usage_details.input_token_count == 10
+ assert response.usage_details and response.usage_details["input_token_count"] == 10
def test_build_request_requires_non_system_messages() -> None:
@@ -63,7 +63,7 @@ def test_build_request_requires_non_system_messages() -> None:
client=_StubBedrockRuntime(),
)
- messages = [ChatMessage(role=Role.SYSTEM, contents=[TextContent(text="Only system text")])]
+ messages = [ChatMessage(role=Role.SYSTEM, contents=[Content.from_text(text="Only system text")])]
with pytest.raises(ServiceInitializationError):
client._prepare_options(messages, {})
diff --git a/python/packages/bedrock/tests/test_bedrock_settings.py b/python/packages/bedrock/tests/test_bedrock_settings.py
index 1924c750c6..07898303de 100644
--- a/python/packages/bedrock/tests/test_bedrock_settings.py
+++ b/python/packages/bedrock/tests/test_bedrock_settings.py
@@ -9,10 +9,8 @@
AIFunction,
ChatMessage,
ChatOptions,
- FunctionCallContent,
- FunctionResultContent,
+ Content,
Role,
- TextContent,
)
from pydantic import BaseModel
@@ -49,7 +47,7 @@ def test_build_request_includes_tool_config() -> None:
"tools": [tool],
"tool_choice": {"mode": "required", "required_function_name": "get_weather"},
}
- messages = [ChatMessage(role=Role.USER, contents=[TextContent(text="hi")])]
+ messages = [ChatMessage(role=Role.USER, contents=[Content.from_text(text="hi")])]
request = client._prepare_options(messages, options)
@@ -61,14 +59,16 @@ def test_build_request_serializes_tool_history() -> None:
client = _build_client()
options: ChatOptions = {}
messages = [
- ChatMessage(role=Role.USER, contents=[TextContent(text="how's weather?")]),
+ ChatMessage(role=Role.USER, contents=[Content.from_text(text="how's weather?")]),
ChatMessage(
role=Role.ASSISTANT,
- contents=[FunctionCallContent(call_id="call-1", name="get_weather", arguments='{"location": "SEA"}')],
+ contents=[
+ Content.from_function_call(call_id="call-1", name="get_weather", arguments='{"location": "SEA"}')
+ ],
),
ChatMessage(
role=Role.TOOL,
- contents=[FunctionResultContent(call_id="call-1", result={"answer": "72F"})],
+ contents=[Content.from_function_result(call_id="call-1", result={"answer": "72F"})],
),
]
@@ -101,9 +101,9 @@ def test_process_response_parses_tool_use_and_result() -> None:
chat_response = client._process_converse_response(response)
contents = chat_response.messages[0].contents
- assert isinstance(contents[0], FunctionCallContent)
+ assert contents[0].type == "function_call"
assert contents[0].name == "get_weather"
- assert isinstance(contents[1], TextContent)
+ assert contents[1].type == "text"
assert chat_response.finish_reason == client._map_finish_reason("tool_use")
@@ -131,5 +131,5 @@ def test_process_response_parses_tool_result() -> None:
chat_response = client._process_converse_response(response)
contents = chat_response.messages[0].contents
- assert isinstance(contents[0], FunctionResultContent)
+ assert contents[0].type == "function_result"
assert contents[0].result == {"answer": 42}
diff --git a/python/packages/chatkit/agent_framework_chatkit/_converter.py b/python/packages/chatkit/agent_framework_chatkit/_converter.py
index 252ac8a753..b83fd40812 100644
--- a/python/packages/chatkit/agent_framework_chatkit/_converter.py
+++ b/python/packages/chatkit/agent_framework_chatkit/_converter.py
@@ -8,12 +8,8 @@
from agent_framework import (
ChatMessage,
- DataContent,
- FunctionCallContent,
- FunctionResultContent,
+ Content,
Role,
- TextContent,
- UriContent,
)
from chatkit.types import (
AssistantMessageItem,
@@ -91,8 +87,8 @@ async def user_message_to_input(
if isinstance(content_part, UserMessageTextContent):
text_content += content_part.text
- # Convert attachments to DataContent or UriContent
- data_contents: list[DataContent | UriContent] = []
+ # Convert attachments to Content
+ data_contents: list[Content] = []
if item.attachments:
for attachment in item.attachments:
content = await self.attachment_to_message_content(attachment)
@@ -108,9 +104,9 @@ async def user_message_to_input(
user_message = ChatMessage(role=Role.USER, text=text_content.strip())
else:
# Build contents list with both text and attachments
- contents: list[TextContent | DataContent | UriContent] = []
+ contents: list[Content] = []
if text_content.strip():
- contents.append(TextContent(text=text_content.strip()))
+ contents.append(Content.from_text(text=text_content.strip()))
contents.extend(data_contents)
user_message = ChatMessage(role=Role.USER, contents=contents)
@@ -126,7 +122,7 @@ async def user_message_to_input(
return messages
- async def attachment_to_message_content(self, attachment: Attachment) -> DataContent | UriContent | None:
+ async def attachment_to_message_content(self, attachment: Attachment) -> Content | None:
"""Convert a ChatKit attachment to Agent Framework content.
This method is called internally by `user_message_to_input()` to handle attachments.
@@ -169,14 +165,14 @@ async def fetch_data(attachment_id: str) -> bytes:
if self.attachment_data_fetcher is not None:
try:
data = await self.attachment_data_fetcher(attachment.id)
- return DataContent(data=data, media_type=attachment.mime_type)
+ return Content.from_data(data=data, media_type=attachment.mime_type)
except Exception as e:
# If fetch fails, fall through to URL-based approach
logger.debug(f"Failed to fetch attachment data for {attachment.id}: {e}")
# For ImageAttachment, try to use preview_url
if isinstance(attachment, ImageAttachment) and attachment.preview_url:
- return UriContent(uri=str(attachment.preview_url), media_type=attachment.mime_type)
+ return Content.from_uri(uri=str(attachment.preview_url), media_type=attachment.mime_type)
# For FileAttachment without data fetcher, skip the attachment
# Subclasses can override this method to provide custom handling
@@ -220,7 +216,7 @@ def hidden_context_to_input(
"""
return ChatMessage(role=Role.SYSTEM, text=f"{item.content}")
- def tag_to_message_content(self, tag: UserMessageTagContent) -> TextContent:
+ def tag_to_message_content(self, tag: UserMessageTagContent) -> Content:
"""Convert a ChatKit tag (@-mention) to Agent Framework content.
This method is called internally by `user_message_to_input()` to handle tags.
@@ -248,10 +244,10 @@ def tag_to_message_content(self, tag: UserMessageTagContent) -> TextContent:
type="input_tag", id="tag_1", text="john", data={"name": "John Doe"}, interactive=False
)
content = converter.tag_to_message_content(tag)
- # Returns: TextContent(text="Name:John Doe")
+ # Returns: Content.from_text(text="Name:John Doe")
"""
name = getattr(tag.data, "name", tag.text if hasattr(tag, "text") else "unknown")
- return TextContent(text=f"Name:{name}")
+ return Content.from_text(text=f"Name:{name}")
def task_to_input(self, item: TaskItem) -> ChatMessage | list[ChatMessage] | None:
"""Convert a ChatKit TaskItem to Agent Framework ChatMessage(s).
@@ -448,7 +444,7 @@ async def client_tool_call_to_input(self, item: ClientToolCallItem) -> ChatMessa
function_call_msg = ChatMessage(
role=Role.ASSISTANT,
contents=[
- FunctionCallContent(
+ Content.from_function_call(
call_id=item.call_id,
name=item.name,
arguments=json.dumps(item.arguments),
@@ -460,7 +456,7 @@ async def client_tool_call_to_input(self, item: ClientToolCallItem) -> ChatMessa
function_result_msg = ChatMessage(
role=Role.TOOL,
contents=[
- FunctionResultContent(
+ Content.from_function_result(
call_id=item.call_id,
result=json.dumps(item.output) if item.output is not None else "",
)
diff --git a/python/packages/chatkit/agent_framework_chatkit/_streaming.py b/python/packages/chatkit/agent_framework_chatkit/_streaming.py
index b0273c5944..df44fa005d 100644
--- a/python/packages/chatkit/agent_framework_chatkit/_streaming.py
+++ b/python/packages/chatkit/agent_framework_chatkit/_streaming.py
@@ -6,7 +6,7 @@
from collections.abc import AsyncIterable, AsyncIterator, Callable
from datetime import datetime
-from agent_framework import AgentResponseUpdate, TextContent
+from agent_framework import AgentResponseUpdate
from chatkit.types import (
AssistantMessageContent,
AssistantMessageContentPartTextDelta,
@@ -77,7 +77,7 @@ def _default_id_generator(item_type: str) -> str:
if update.contents:
for content in update.contents:
# Handle text content - only TextContent has a text attribute
- if isinstance(content, TextContent) and content.text is not None:
+ if content.type == "text" and content.text is not None:
# Yield incremental text delta for streaming display
yield ThreadItemUpdated(
type="thread.item.updated",
diff --git a/python/packages/chatkit/tests/test_converter.py b/python/packages/chatkit/tests/test_converter.py
index 457017f647..b75139bf58 100644
--- a/python/packages/chatkit/tests/test_converter.py
+++ b/python/packages/chatkit/tests/test_converter.py
@@ -5,7 +5,7 @@
from unittest.mock import Mock
import pytest
-from agent_framework import ChatMessage, Role, TextContent
+from agent_framework import ChatMessage, Role
from chatkit.types import UserMessageTextContent
from agent_framework_chatkit import ThreadItemConverter, simple_to_agent_input
@@ -133,7 +133,7 @@ def test_tag_to_message_content(self, converter):
)
result = converter.tag_to_message_content(tag)
- assert isinstance(result, TextContent)
+ assert result.type == "text"
# Since data is a dict, getattr won't work, so it will fall back to text
assert result.text == "Name:john"
@@ -150,7 +150,7 @@ def test_tag_to_message_content_no_name(self, converter):
)
result = converter.tag_to_message_content(tag)
- assert isinstance(result, TextContent)
+ assert result.type == "text"
assert result.text == "Name:jane"
async def test_attachment_to_message_content_file_without_fetcher(self, converter):
@@ -169,7 +169,6 @@ async def test_attachment_to_message_content_file_without_fetcher(self, converte
async def test_attachment_to_message_content_image_with_preview_url(self, converter):
"""Test that ImageAttachment with preview_url creates UriContent."""
- from agent_framework import UriContent
from chatkit.types import ImageAttachment
attachment = ImageAttachment(
@@ -181,13 +180,12 @@ async def test_attachment_to_message_content_image_with_preview_url(self, conver
)
result = await converter.attachment_to_message_content(attachment)
- assert isinstance(result, UriContent)
+ assert result.type == "uri"
assert result.uri == "https://example.com/photo.jpg"
assert result.media_type == "image/jpeg"
async def test_attachment_to_message_content_with_data_fetcher(self):
"""Test attachment conversion with data fetcher."""
- from agent_framework import DataContent
from chatkit.types import FileAttachment
# Mock data fetcher
@@ -204,14 +202,13 @@ async def fetch_data(attachment_id: str) -> bytes:
)
result = await converter.attachment_to_message_content(attachment)
- assert isinstance(result, DataContent)
+ assert result.type == "data"
assert result.media_type == "application/pdf"
async def test_to_agent_input_with_image_attachment(self):
"""Test converting user message with text and image attachment."""
from datetime import datetime
- from agent_framework import UriContent
from chatkit.types import ImageAttachment, UserMessageItem
attachment = ImageAttachment(
@@ -241,11 +238,11 @@ async def test_to_agent_input_with_image_attachment(self):
assert len(message.contents) == 2
# First content should be text
- assert isinstance(message.contents[0], TextContent)
+ assert message.contents[0].type == "text"
assert message.contents[0].text == "Check out this photo!"
# Second content should be UriContent for the image
- assert isinstance(message.contents[1], UriContent)
+ assert message.contents[1].type == "uri"
assert message.contents[1].uri == "https://example.com/photo.jpg"
assert message.contents[1].media_type == "image/jpeg"
@@ -253,7 +250,6 @@ async def test_to_agent_input_with_file_attachment_and_fetcher(self):
"""Test converting user message with file attachment using data fetcher."""
from datetime import datetime
- from agent_framework import DataContent
from chatkit.types import FileAttachment, UserMessageItem
attachment = FileAttachment(
@@ -285,10 +281,10 @@ async def fetch_data(attachment_id: str) -> bytes:
assert len(message.contents) == 2
# First content should be text
- assert isinstance(message.contents[0], TextContent)
+ assert message.contents[0].type == "text"
# Second content should be DataContent for the file
- assert isinstance(message.contents[1], DataContent)
+ assert message.contents[1].type == "data"
assert message.contents[1].media_type == "application/pdf"
def test_task_to_input(self, converter):
diff --git a/python/packages/chatkit/tests/test_streaming.py b/python/packages/chatkit/tests/test_streaming.py
index ead7c5f33e..ff552d79e8 100644
--- a/python/packages/chatkit/tests/test_streaming.py
+++ b/python/packages/chatkit/tests/test_streaming.py
@@ -4,7 +4,7 @@
from unittest.mock import Mock
-from agent_framework import AgentResponseUpdate, Role, TextContent
+from agent_framework import AgentResponseUpdate, Content, Role
from chatkit.types import (
ThreadItemAddedEvent,
ThreadItemDoneEvent,
@@ -34,7 +34,7 @@ async def test_stream_single_text_update(self):
"""Test streaming single text update."""
async def single_update_stream():
- yield AgentResponseUpdate(role=Role.ASSISTANT, contents=[TextContent(text="Hello world")])
+ yield AgentResponseUpdate(role=Role.ASSISTANT, contents=[Content.from_text(text="Hello world")])
events = []
async for event in stream_agent_response(single_update_stream(), thread_id="test_thread"):
@@ -59,8 +59,8 @@ async def test_stream_multiple_text_updates(self):
"""Test streaming multiple text updates."""
async def multiple_updates_stream():
- yield AgentResponseUpdate(role=Role.ASSISTANT, contents=[TextContent(text="Hello ")])
- yield AgentResponseUpdate(role=Role.ASSISTANT, contents=[TextContent(text="world!")])
+ yield AgentResponseUpdate(role=Role.ASSISTANT, contents=[Content.from_text(text="Hello ")])
+ yield AgentResponseUpdate(role=Role.ASSISTANT, contents=[Content.from_text(text="world!")])
events = []
async for event in stream_agent_response(multiple_updates_stream(), thread_id="test_thread"):
@@ -91,7 +91,7 @@ def custom_id_generator(item_type: str) -> str:
return f"custom_{item_type}_123"
async def single_update_stream():
- yield AgentResponseUpdate(role=Role.ASSISTANT, contents=[TextContent(text="Test")])
+ yield AgentResponseUpdate(role=Role.ASSISTANT, contents=[Content.from_text(text="Test")])
events = []
async for event in stream_agent_response(
@@ -125,9 +125,10 @@ async def empty_content_stream():
async def test_stream_non_text_content(self):
"""Test streaming updates with non-text content."""
# Mock a content object without text attribute
- non_text_content = Mock()
+ non_text_content = Mock(spec=Content)
+ non_text_content.type = "image"
# Don't set text attribute
- del non_text_content.text
+ non_text_content.text = None
async def non_text_stream():
yield AgentResponseUpdate(role=Role.ASSISTANT, contents=[non_text_content])
diff --git a/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py b/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py
index 7dc676b06a..98d5a2b475 100644
--- a/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py
+++ b/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py
@@ -10,9 +10,9 @@
AgentThread,
BaseAgent,
ChatMessage,
+ Content,
ContextProvider,
Role,
- TextContent,
normalize_messages,
)
from agent_framework._pydantic import AFBaseSettings
@@ -332,7 +332,7 @@ async def _process_activities(self, activities: AsyncIterable[Any], streaming: b
):
yield ChatMessage(
role=Role.ASSISTANT,
- contents=[TextContent(activity.text)],
+ contents=[Content.from_text(activity.text)],
author_name=activity.from_property.name if activity.from_property else None,
message_id=activity.id,
raw_representation=activity,
diff --git a/python/packages/copilotstudio/tests/test_copilot_agent.py b/python/packages/copilotstudio/tests/test_copilot_agent.py
index 4777557d32..c4e2ff3e08 100644
--- a/python/packages/copilotstudio/tests/test_copilot_agent.py
+++ b/python/packages/copilotstudio/tests/test_copilot_agent.py
@@ -4,14 +4,7 @@
from unittest.mock import MagicMock, patch
import pytest
-from agent_framework import (
- AgentResponse,
- AgentResponseUpdate,
- AgentThread,
- ChatMessage,
- Role,
- TextContent,
-)
+from agent_framework import AgentResponse, AgentResponseUpdate, AgentThread, ChatMessage, Content, Role
from agent_framework.exceptions import ServiceException, ServiceInitializationError
from microsoft_agents.copilotstudio.client import CopilotClient
@@ -136,7 +129,7 @@ async def test_run_with_string_message(self, mock_copilot_client: MagicMock, moc
assert isinstance(response, AgentResponse)
assert len(response.messages) == 1
content = response.messages[0].contents[0]
- assert isinstance(content, TextContent)
+ assert content.type == "text"
assert content.text == "Test response"
assert response.messages[0].role == Role.ASSISTANT
@@ -150,13 +143,13 @@ async def test_run_with_chat_message(self, mock_copilot_client: MagicMock, mock_
mock_copilot_client.start_conversation.return_value = create_async_generator([conversation_activity])
mock_copilot_client.ask_question.return_value = create_async_generator([mock_activity])
- chat_message = ChatMessage(role=Role.USER, contents=[TextContent("test message")])
+ chat_message = ChatMessage(role=Role.USER, contents=[Content.from_text("test message")])
response = await agent.run(chat_message)
assert isinstance(response, AgentResponse)
assert len(response.messages) == 1
content = response.messages[0].contents[0]
- assert isinstance(content, TextContent)
+ assert content.type == "text"
assert content.text == "Test response"
assert response.messages[0].role == Role.ASSISTANT
@@ -206,7 +199,7 @@ async def test_run_stream_with_string_message(self, mock_copilot_client: MagicMo
async for response in agent.run_stream("test message"):
assert isinstance(response, AgentResponseUpdate)
content = response.contents[0]
- assert isinstance(content, TextContent)
+ assert content.type == "text"
assert content.text == "Streaming response"
response_count += 1
@@ -233,7 +226,7 @@ async def test_run_stream_with_thread(self, mock_copilot_client: MagicMock) -> N
async for response in agent.run_stream("test message", thread=thread):
assert isinstance(response, AgentResponseUpdate)
content = response.contents[0]
- assert isinstance(content, TextContent)
+ assert content.type == "text"
assert content.text == "Streaming response"
response_count += 1
diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py
index c748d7f2df..412221af1f 100644
--- a/python/packages/core/agent_framework/_agents.py
+++ b/python/packages/core/agent_framework/_agents.py
@@ -1140,9 +1140,9 @@ async def _call_tool( # type: ignore
# Convert result to MCP content
if isinstance(result, str):
- return [types.TextContent(type="text", text=result)]
+ return [types.TextContent(type="text", text=result)] # type: ignore[attr-defined]
- return [types.TextContent(type="text", text=str(result))]
+ return [types.TextContent(type="text", text=str(result))] # type: ignore[attr-defined]
@server.set_logging_level() # type: ignore
async def _set_logging_level(level: types.LoggingLevel) -> None: # type: ignore
diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py
index 3a6d5b818c..db912dd3a5 100644
--- a/python/packages/core/agent_framework/_mcp.py
+++ b/python/packages/core/agent_framework/_mcp.py
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.
+import base64
import logging
import re
import sys
@@ -29,13 +30,8 @@
)
from ._types import (
ChatMessage,
- Contents,
- DataContent,
- FunctionCallContent,
- FunctionResultContent,
+ Content,
Role,
- TextContent,
- UriContent,
)
from .exceptions import ToolException, ToolExecutionException
@@ -82,7 +78,7 @@ def _parse_message_from_mcp(
def _parse_contents_from_mcp_tool_result(
mcp_type: types.CallToolResult,
-) -> list[Contents]:
+) -> list[Content]:
"""Parse an MCP CallToolResult into Agent Framework content types.
This function extracts the complete _meta field from CallToolResult objects
@@ -147,25 +143,27 @@ def _parse_content_from_mcp(
| types.ToolUseContent
| types.ToolResultContent
],
-) -> list[Contents]:
+) -> list[Content]:
"""Parse an MCP type into an Agent Framework type."""
mcp_types = mcp_type if isinstance(mcp_type, Sequence) else [mcp_type]
- return_types: list[Contents] = []
+ return_types: list[Content] = []
for mcp_type in mcp_types:
match mcp_type:
case types.TextContent():
- return_types.append(TextContent(text=mcp_type.text, raw_representation=mcp_type))
+ return_types.append(Content.from_text(text=mcp_type.text, raw_representation=mcp_type))
case types.ImageContent() | types.AudioContent():
+ # MCP protocol uses base64-encoded strings, convert to bytes
+ data_bytes = base64.b64decode(mcp_type.data) if isinstance(mcp_type.data, str) else mcp_type.data
return_types.append(
- DataContent(
- data=mcp_type.data,
+ Content.from_data(
+ data=data_bytes,
media_type=mcp_type.mimeType,
raw_representation=mcp_type,
)
)
case types.ResourceLink():
return_types.append(
- UriContent(
+ Content.from_uri(
uri=str(mcp_type.uri),
media_type=mcp_type.mimeType or "application/json",
raw_representation=mcp_type,
@@ -173,7 +171,7 @@ def _parse_content_from_mcp(
)
case types.ToolUseContent():
return_types.append(
- FunctionCallContent(
+ Content.from_function_call(
call_id=mcp_type.id,
name=mcp_type.name,
arguments=mcp_type.input,
@@ -182,12 +180,12 @@ def _parse_content_from_mcp(
)
case types.ToolResultContent():
return_types.append(
- FunctionResultContent(
+ Content.from_function_result(
call_id=mcp_type.toolUseId,
result=_parse_content_from_mcp(mcp_type.content)
if mcp_type.content
else mcp_type.structuredContent,
- exception=Exception() if mcp_type.isError else None,
+ exception=str(Exception()) if mcp_type.isError else None, # type: ignore[arg-type]
raw_representation=mcp_type,
)
)
@@ -195,7 +193,7 @@ def _parse_content_from_mcp(
match mcp_type.resource:
case types.TextResourceContents():
return_types.append(
- TextContent(
+ Content.from_text(
text=mcp_type.resource.text,
raw_representation=mcp_type,
additional_properties=(
@@ -205,7 +203,7 @@ def _parse_content_from_mcp(
)
case types.BlobResourceContents():
return_types.append(
- DataContent(
+ Content.from_uri(
uri=mcp_type.resource.blob,
media_type=mcp_type.resource.mimeType,
raw_representation=mcp_type,
@@ -218,45 +216,41 @@ def _parse_content_from_mcp(
def _prepare_content_for_mcp(
- content: Contents,
+ content: Content,
) -> types.TextContent | types.ImageContent | types.AudioContent | types.EmbeddedResource | types.ResourceLink | None:
"""Prepare an Agent Framework content type for MCP."""
- match content:
- case TextContent():
- return types.TextContent(type="text", text=content.text)
- case DataContent():
- if content.media_type and content.media_type.startswith("image/"):
- return types.ImageContent(type="image", data=content.uri, mimeType=content.media_type)
- if content.media_type and content.media_type.startswith("audio/"):
- return types.AudioContent(type="audio", data=content.uri, mimeType=content.media_type)
- if content.media_type and content.media_type.startswith("application/"):
- return types.EmbeddedResource(
- type="resource",
- resource=types.BlobResourceContents(
- blob=content.uri,
- mimeType=content.media_type,
- # uri's are not limited in MCP but they have to be set.
- # the uri of data content, contains the data uri, which
- # is not the uri meant here, UriContent would match this.
- uri=(
- content.additional_properties.get("uri", "af://binary")
- if content.additional_properties
- else "af://binary"
- ), # type: ignore[reportArgumentType]
- ),
- )
- return None
- case UriContent():
- return types.ResourceLink(
- type="resource_link",
- uri=content.uri, # type: ignore[reportArgumentType]
- mimeType=content.media_type,
- name=(
- content.additional_properties.get("name", "Unknown") if content.additional_properties else "Unknown"
+ if content.type == "text":
+ return types.TextContent(type="text", text=content.text) # type: ignore[attr-defined]
+ if content.type == "data":
+ if content.media_type and content.media_type.startswith("image/"): # type: ignore[attr-defined]
+ return types.ImageContent(type="image", data=content.uri, mimeType=content.media_type) # type: ignore[attr-defined]
+ if content.media_type and content.media_type.startswith("audio/"): # type: ignore[attr-defined]
+ return types.AudioContent(type="audio", data=content.uri, mimeType=content.media_type) # type: ignore[attr-defined]
+ if content.media_type and content.media_type.startswith("application/"): # type: ignore[attr-defined]
+ return types.EmbeddedResource(
+ type="resource",
+ resource=types.BlobResourceContents(
+ blob=content.uri, # type: ignore[attr-defined]
+ mimeType=content.media_type, # type: ignore[attr-defined]
+                        # URIs are not limited in MCP, but they have to be set.
+                        # The uri of a data content holds the data URI, which
+                        # is not the URI meant here; a uri-type Content would match this.
+ uri=(
+ content.additional_properties.get("uri", "af://binary")
+ if content.additional_properties
+ else "af://binary"
+ ), # type: ignore[reportArgumentType]
),
)
- case _:
- return None
+ return None
+ if content.type == "uri":
+ return types.ResourceLink(
+ type="resource_link",
+ uri=content.uri, # type: ignore[reportArgumentType,attr-defined]
+ mimeType=content.media_type, # type: ignore[attr-defined]
+ name=(content.additional_properties.get("name", "Unknown") if content.additional_properties else "Unknown"),
+ )
+ return None
def _prepare_message_for_mcp(
@@ -650,7 +644,7 @@ async def load_tools(self) -> None:
input_model = _get_input_model_from_mcp_tool(tool)
approval_mode = self._determine_approval_mode(local_name)
# Create AIFunctions out of each tool
- func: AIFunction[BaseModel, list[Contents] | Any | types.CallToolResult] = AIFunction(
+ func: AIFunction[BaseModel, list[Content] | Any | types.CallToolResult] = AIFunction(
func=partial(self.call_tool, tool.name),
name=local_name,
description=tool.description or "",
@@ -704,7 +698,7 @@ async def _ensure_connected(self) -> None:
inner_exception=ex,
) from ex
- async def call_tool(self, tool_name: str, **kwargs: Any) -> list[Contents] | Any | types.CallToolResult:
+ async def call_tool(self, tool_name: str, **kwargs: Any) -> list[Content] | Any | types.CallToolResult:
"""Call a tool with the given arguments.
Args:
diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py
index a0d0a13dc2..3ea7d33e72 100644
--- a/python/packages/core/agent_framework/_tools.py
+++ b/python/packages/core/agent_framework/_tools.py
@@ -54,9 +54,7 @@
ChatMessage,
ChatResponse,
ChatResponseUpdate,
- Contents,
- FunctionApprovalResponseContent,
- FunctionCallContent,
+ Content,
)
from typing import overload
@@ -104,15 +102,15 @@ def record(self, *args: Any, **kwargs: Any) -> None: # pragma: no cover - trivi
def _parse_inputs(
- inputs: "Contents | dict[str, Any] | str | list[Contents | dict[str, Any] | str] | None",
-) -> list["Contents"]:
- """Parse the inputs for a tool, ensuring they are of type Contents.
+ inputs: "Content | dict[str, Any] | str | list[Content | dict[str, Any] | str] | None",
+) -> list["Content"]:
+ """Parse the inputs for a tool, ensuring they are of type Content.
Args:
- inputs: The inputs to parse. Can be a single item or list of Contents, dicts, or strings.
+ inputs: The inputs to parse. Can be a single item or list of Content, dicts, or strings.
Returns:
- A list of Contents objects.
+ A list of Content objects.
Raises:
ValueError: If an unsupported input type is encountered.
@@ -122,43 +120,39 @@ def _parse_inputs(
return []
from ._types import (
- BaseContent,
- DataContent,
- HostedFileContent,
- HostedVectorStoreContent,
- UriContent,
+ Content,
)
- parsed_inputs: list["Contents"] = []
+ parsed_inputs: list["Content"] = []
if not isinstance(inputs, list):
inputs = [inputs]
for input_item in inputs:
if isinstance(input_item, str):
# If it's a string, we assume it's a URI or similar identifier.
# Convert it to a UriContent or similar type as needed.
- parsed_inputs.append(UriContent(uri=input_item, media_type="text/plain"))
+ parsed_inputs.append(Content.from_uri(uri=input_item, media_type="text/plain"))
elif isinstance(input_item, dict):
# If it's a dict, we assume it contains properties for a specific content type.
# we check if the required keys are present to determine the type.
# for instance, if it has "uri" and "media_type", we treat it as UriContent.
- # if is only has uri, then we treat it as DataContent.
+            # any dict with a "uri" key is now treated as URI content (no DataContent is inferred).
# etc.
if "uri" in input_item:
- parsed_inputs.append(
- UriContent(**input_item) if "media_type" in input_item else DataContent(**input_item)
- )
+                # NOTE(review): dicts with "uri" always become URI content here; the old
+                # DataContent fallback (uri without media_type) is gone — confirm callers.
+ parsed_inputs.append(Content.from_uri(**input_item))
elif "file_id" in input_item:
- parsed_inputs.append(HostedFileContent(**input_item))
+ parsed_inputs.append(Content.from_hosted_file(**input_item))
elif "vector_store_id" in input_item:
- parsed_inputs.append(HostedVectorStoreContent(**input_item))
+ parsed_inputs.append(Content.from_hosted_vector_store(**input_item))
elif "data" in input_item:
- parsed_inputs.append(DataContent(**input_item))
+                # Content.from_data handles both uri and data parameters
+ parsed_inputs.append(Content.from_data(**input_item))
else:
raise ValueError(f"Unsupported input type: {input_item}")
- elif isinstance(input_item, BaseContent):
+ elif isinstance(input_item, Content):
parsed_inputs.append(input_item)
else:
- raise TypeError(f"Unsupported input type: {type(input_item).__name__}. Expected Contents or dict.")
+ raise TypeError(f"Unsupported input type: {type(input_item).__name__}. Expected Content or dict.")
return parsed_inputs
@@ -254,7 +248,7 @@ class HostedCodeInterpreterTool(BaseTool):
def __init__(
self,
*,
- inputs: "Contents | dict[str, Any] | str | list[Contents | dict[str, Any] | str] | None" = None,
+ inputs: "Content | dict[str, Any] | str | list[Content | dict[str, Any] | str] | None" = None,
description: str | None = None,
additional_properties: dict[str, Any] | None = None,
**kwargs: Any,
@@ -266,8 +260,8 @@ def __init__(
This should mostly be HostedFileContent or HostedVectorStoreContent.
Can also be DataContent, depending on the service used.
When supplying a list, it can contain:
- - Contents instances
- - dicts with properties for Contents (e.g., {"uri": "http://example.com", "media_type": "text/html"})
+ - Content instances
+ - dicts with properties for Content (e.g., {"uri": "http://example.com", "media_type": "text/html"})
- strings (which will be converted to UriContent with media_type "text/plain").
If None, defaults to an empty list.
description: A description of the tool.
@@ -503,7 +497,7 @@ class HostedFileSearchTool(BaseTool):
def __init__(
self,
*,
- inputs: "Contents | dict[str, Any] | str | list[Contents | dict[str, Any] | str] | None" = None,
+ inputs: "Content | dict[str, Any] | str | list[Content | dict[str, Any] | str] | None" = None,
max_results: int | None = None,
description: str | None = None,
additional_properties: dict[str, Any] | None = None,
@@ -515,8 +509,8 @@ def __init__(
inputs: A list of contents that the tool can accept as input. Defaults to None.
This should be one or more HostedVectorStoreContents.
When supplying a list, it can contain:
- - Contents instances
- - dicts with properties for Contents (e.g., {"uri": "http://example.com", "media_type": "text/html"})
+ - Content instances
+ - dicts with properties for Content (e.g., {"uri": "http://example.com", "media_type": "text/html"})
- strings (which will be converted to UriContent with media_type "text/plain").
If None, defaults to an empty list.
max_results: The maximum number of results to return from the file search.
@@ -1480,7 +1474,7 @@ class FunctionExecutionResult:
__slots__ = ("content", "terminate")
- def __init__(self, content: "Contents", terminate: bool = False) -> None:
+ def __init__(self, content: "Content", terminate: bool = False) -> None:
"""Initialize FunctionExecutionResult.
Args:
@@ -1492,7 +1486,7 @@ def __init__(self, content: "Contents", terminate: bool = False) -> None:
async def _auto_invoke_function(
- function_call_content: "FunctionCallContent | FunctionApprovalResponseContent",
+ function_call_content: "Content",
custom_args: dict[str, Any] | None = None,
*,
config: FunctionInvocationConfiguration,
@@ -1500,7 +1494,7 @@ async def _auto_invoke_function(
sequence_index: int | None = None,
request_index: int | None = None,
middleware_pipeline: Any = None, # Optional MiddlewarePipeline
-) -> "FunctionExecutionResult | Contents":
+) -> "FunctionExecutionResult | Content":
"""Invoke a function call requested by the agent, applying middleware that is defined.
Args:
@@ -1516,41 +1510,42 @@ async def _auto_invoke_function(
Returns:
A FunctionExecutionResult wrapping the content and terminate signal,
- or a Contents object for approval/hosted tool scenarios.
+ or a Content object for approval/hosted tool scenarios.
Raises:
KeyError: If the requested function is not found in the tool map.
"""
+ from ._types import Content
+
# Note: The scenarios for approval_mode="always_require", declaration_only, and
# terminate_on_unknown_calls are all handled in _try_execute_function_calls before
# this function is called. This function only handles the actual execution of approved,
# non-declaration-only functions.
- from ._types import FunctionCallContent, FunctionResultContent
tool: AIFunction[BaseModel, Any] | None = None
if function_call_content.type == "function_call":
- tool = tool_map.get(function_call_content.name)
+ tool = tool_map.get(function_call_content.name) # type: ignore[arg-type]
# Tool should exist because _try_execute_function_calls validates this
if tool is None:
exc = KeyError(f'Function "{function_call_content.name}" not found.')
return FunctionExecutionResult(
- content=FunctionResultContent(
- call_id=function_call_content.call_id,
+ content=Content.from_function_result(
+ call_id=function_call_content.call_id, # type: ignore[arg-type]
result=f'Error: Requested function "{function_call_content.name}" not found.',
- exception=exc,
+ exception=str(exc), # type: ignore[arg-type]
)
)
else:
# Note: Unapproved tools (approved=False) are handled in _replace_approval_contents_with_results
# and never reach this function, so we only handle approved=True cases here.
- inner_call = function_call_content.function_call
- if not isinstance(inner_call, FunctionCallContent):
+ inner_call = function_call_content.function_call # type: ignore[attr-defined]
+ if inner_call.type != "function_call": # type: ignore[union-attr]
return function_call_content
- tool = tool_map.get(inner_call.name)
+ tool = tool_map.get(inner_call.name) # type: ignore[attr-defined, union-attr, arg-type]
if tool is None:
# we assume it is a hosted tool
return function_call_content
- function_call_content = inner_call
+ function_call_content = inner_call # type: ignore[assignment]
parsed_args: dict[str, Any] = dict(function_call_content.parse_arguments() or {})
@@ -1567,7 +1562,11 @@ async def _auto_invoke_function(
if config.include_detailed_errors:
message = f"{message} Exception: {exc}"
return FunctionExecutionResult(
- content=FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc)
+ content=Content.from_function_result(
+ call_id=function_call_content.call_id, # type: ignore[arg-type]
+ result=message,
+ exception=str(exc), # type: ignore[arg-type]
+ )
)
if not middleware_pipeline or (
@@ -1581,8 +1580,8 @@ async def _auto_invoke_function(
**runtime_kwargs if getattr(tool, "_forward_runtime_kwargs", False) else {},
)
return FunctionExecutionResult(
- content=FunctionResultContent(
- call_id=function_call_content.call_id,
+ content=Content.from_function_result(
+ call_id=function_call_content.call_id, # type: ignore[arg-type]
result=function_result,
)
)
@@ -1591,7 +1590,11 @@ async def _auto_invoke_function(
if config.include_detailed_errors:
message = f"{message} Exception: {exc}"
return FunctionExecutionResult(
- content=FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc)
+ content=Content.from_function_result(
+ call_id=function_call_content.call_id, # type: ignore[arg-type]
+ result=message,
+ exception=str(exc),
+ )
)
# Execute through middleware pipeline if available
from ._middleware import FunctionInvocationContext
@@ -1617,8 +1620,8 @@ async def final_function_handler(context_obj: Any) -> Any:
final_handler=final_function_handler,
)
return FunctionExecutionResult(
- content=FunctionResultContent(
- call_id=function_call_content.call_id,
+ content=Content.from_function_result(
+ call_id=function_call_content.call_id, # type: ignore[arg-type]
result=function_result,
),
terminate=middleware_context.terminate,
@@ -1628,7 +1631,11 @@ async def final_function_handler(context_obj: Any) -> Any:
if config.include_detailed_errors:
message = f"{message} Exception: {exc}"
return FunctionExecutionResult(
- content=FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc)
+ content=Content.from_function_result(
+ call_id=function_call_content.call_id, # type: ignore[arg-type]
+ result=message,
+ exception=str(exc), # type: ignore[arg-type]
+ )
)
@@ -1653,14 +1660,14 @@ def _get_tool_map(
async def _try_execute_function_calls(
custom_args: dict[str, Any],
attempt_idx: int,
- function_calls: Sequence["FunctionCallContent"] | Sequence["FunctionApprovalResponseContent"],
+ function_calls: Sequence["Content"],
tools: "ToolProtocol \
| Callable[..., Any] \
| MutableMapping[str, Any] \
| Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]",
config: FunctionInvocationConfiguration,
middleware_pipeline: Any = None, # Optional MiddlewarePipeline to avoid circular imports
-) -> tuple[Sequence["Contents"], bool]:
+) -> tuple[Sequence["Content"], bool]:
"""Execute multiple function calls concurrently.
Args:
@@ -1673,12 +1680,12 @@ async def _try_execute_function_calls(
Returns:
A tuple of:
- - A list of Contents containing the results of each function call,
+ - A list of Content containing the results of each function call,
or the approval requests if any function requires approval,
or the original function calls if any are declaration only.
- A boolean indicating whether to terminate the function calling loop.
"""
- from ._types import FunctionApprovalRequestContent, FunctionCallContent
+ from ._types import Content
tool_map = _get_tool_map(tools)
approval_tools = [tool_name for tool_name, tool in tool_map.items() if tool.approval_mode == "always_require"]
@@ -1689,27 +1696,27 @@ async def _try_execute_function_calls(
approval_needed = False
declaration_only_flag = False
for fcc in function_calls:
- if isinstance(fcc, FunctionCallContent) and fcc.name in approval_tools:
+ if fcc.type == "function_call" and fcc.name in approval_tools: # type: ignore[attr-defined]
approval_needed = True
break
- if isinstance(fcc, FunctionCallContent) and (fcc.name in declaration_only or fcc.name in additional_tool_names):
+ if fcc.type == "function_call" and (fcc.name in declaration_only or fcc.name in additional_tool_names): # type: ignore[attr-defined]
declaration_only_flag = True
break
- if config.terminate_on_unknown_calls and isinstance(fcc, FunctionCallContent) and fcc.name not in tool_map:
- raise KeyError(f'Error: Requested function "{fcc.name}" not found.')
+ if config.terminate_on_unknown_calls and fcc.type == "function_call" and fcc.name not in tool_map: # type: ignore[attr-defined]
+ raise KeyError(f'Error: Requested function "{fcc.name}" not found.') # type: ignore[attr-defined]
if approval_needed:
- # approval can only be needed for Function Call Contents, not Approval Responses.
+ # approval can only be needed for Function Call Content, not Approval Responses.
return (
[
- FunctionApprovalRequestContent(id=fcc.call_id, function_call=fcc)
+ Content.from_function_approval_request(id=fcc.call_id, function_call=fcc) # type: ignore[attr-defined, arg-type]
for fcc in function_calls
- if isinstance(fcc, FunctionCallContent)
+ if fcc.type == "function_call"
],
False,
)
if declaration_only_flag:
# return the declaration only tools to the user, since we cannot execute them.
- return ([fcc for fcc in function_calls if isinstance(fcc, FunctionCallContent)], False)
+ return ([fcc for fcc in function_calls if fcc.type == "function_call"], False)
# Run all function calls concurrently
execution_results = await asyncio.gather(*[
@@ -1726,7 +1733,7 @@ async def _try_execute_function_calls(
])
# Unpack FunctionExecutionResult wrappers and check for terminate signal
- contents: list[Contents] = []
+ contents: list[Content] = []
should_terminate = False
for result in execution_results:
if isinstance(result, FunctionExecutionResult):
@@ -1734,7 +1741,7 @@ async def _try_execute_function_calls(
if result.terminate:
should_terminate = True
else:
- # Direct Contents (e.g., from hosted tools)
+ # Direct Content (e.g., from hosted tools)
contents.append(result)
return (contents, should_terminate)
@@ -1772,30 +1779,27 @@ def _extract_tools(options: dict[str, Any] | None) -> Any:
def _collect_approval_responses(
messages: "list[ChatMessage]",
-) -> dict[str, "FunctionApprovalResponseContent"]:
+) -> dict[str, "Content"]:
"""Collect approval responses (both approved and rejected) from messages."""
- from ._types import ChatMessage, FunctionApprovalResponseContent
+ from ._types import ChatMessage, Content
- fcc_todo: dict[str, FunctionApprovalResponseContent] = {}
+ fcc_todo: dict[str, Content] = {}
for msg in messages:
for content in msg.contents if isinstance(msg, ChatMessage) else []:
# Collect BOTH approved and rejected responses
- if isinstance(content, FunctionApprovalResponseContent):
- fcc_todo[content.id] = content
+ if content.type == "function_approval_response":
+ fcc_todo[content.id] = content # type: ignore[attr-defined, index]
return fcc_todo
def _replace_approval_contents_with_results(
messages: "list[ChatMessage]",
- fcc_todo: dict[str, "FunctionApprovalResponseContent"],
- approved_function_results: "list[Contents]",
+ fcc_todo: dict[str, "Content"],
+ approved_function_results: "list[Content]",
) -> None:
"""Replace approval request/response contents with function call/result contents in-place."""
from ._types import (
- FunctionApprovalRequestContent,
- FunctionApprovalResponseContent,
- FunctionCallContent,
- FunctionResultContent,
+ Content,
Role,
)
@@ -1803,23 +1807,25 @@ def _replace_approval_contents_with_results(
for msg in messages:
# First pass - collect existing function call IDs to avoid duplicates
existing_call_ids = {
- content.call_id for content in msg.contents if isinstance(content, FunctionCallContent) and content.call_id
+ content.call_id # type: ignore[union-attr, operator]
+ for content in msg.contents
+ if content.type == "function_call" and content.call_id # type: ignore[attr-defined]
}
# Track approval requests that should be removed (duplicates)
contents_to_remove = []
for content_idx, content in enumerate(msg.contents):
- if isinstance(content, FunctionApprovalRequestContent):
+ if content.type == "function_approval_request":
# Don't add the function call if it already exists (would create duplicate)
- if content.function_call.call_id in existing_call_ids:
+ if content.function_call.call_id in existing_call_ids: # type: ignore[attr-defined, union-attr, operator]
# Just mark for removal - the function call already exists
contents_to_remove.append(content_idx)
else:
# Put back the function call content only if it doesn't exist
- msg.contents[content_idx] = content.function_call
- elif isinstance(content, FunctionApprovalResponseContent):
- if content.approved and content.id in fcc_todo:
+ msg.contents[content_idx] = content.function_call # type: ignore[attr-defined, assignment]
+ elif content.type == "function_approval_response":
+ if content.approved and content.id in fcc_todo: # type: ignore[attr-defined]
# Replace with the corresponding result
if result_idx < len(approved_function_results):
msg.contents[content_idx] = approved_function_results[result_idx]
@@ -1828,8 +1834,8 @@ def _replace_approval_contents_with_results(
else:
# Create a "not approved" result for rejected calls
# Use function_call.call_id (the function's ID), not content.id (approval's ID)
- msg.contents[content_idx] = FunctionResultContent(
- call_id=content.function_call.call_id,
+ msg.contents[content_idx] = Content.from_function_result(
+ call_id=content.function_call.call_id, # type: ignore[union-attr, arg-type]
result="Error: Tool call invocation was rejected by user.",
)
msg.role = Role.TOOL
@@ -1867,9 +1873,6 @@ async def function_invocation_wrapper(
from ._middleware import extract_and_merge_function_middleware
from ._types import (
ChatMessage,
- FunctionApprovalRequestContent,
- FunctionCallContent,
- FunctionResultContent,
prepare_messages,
)
@@ -1893,7 +1896,7 @@ async def function_invocation_wrapper(
tools = _extract_tools(options)
# Only execute APPROVED function calls, not rejected ones
approved_responses = [resp for resp in fcc_todo.values() if resp.approved]
- approved_function_results: list[Contents] = []
+ approved_function_results: list[Content] = []
if approved_responses:
results, _ = await _try_execute_function_calls(
custom_args=kwargs,
@@ -1907,7 +1910,7 @@ async def function_invocation_wrapper(
if any(
fcr.exception is not None
for fcr in approved_function_results
- if isinstance(fcr, FunctionResultContent)
+ if fcr.type == "function_result"
):
errors_in_a_row += 1
# no need to reset the counter here, since this is the start of a new attempt.
@@ -1926,13 +1929,11 @@ async def function_invocation_wrapper(
filtered_kwargs = {k: v for k, v in kwargs.items() if k not in ("thread", "tools", "tool_choice")}
response = await func(self, messages=prepped_messages, options=options, **filtered_kwargs)
# if there are function calls, we will handle them first
- function_results = {
- it.call_id for it in response.messages[0].contents if isinstance(it, FunctionResultContent)
- }
+ function_results = {it.call_id for it in response.messages[0].contents if it.type == "function_result"}
function_calls = [
it
for it in response.messages[0].contents
- if isinstance(it, FunctionCallContent) and it.call_id not in function_results
+ if it.type == "function_call" and it.call_id not in function_results
]
if response.conversation_id is not None:
@@ -1953,7 +1954,7 @@ async def function_invocation_wrapper(
config=config,
)
# Check if we have approval requests or function calls (not results) in the results
- if any(isinstance(fccr, FunctionApprovalRequestContent) for fccr in function_call_results):
+ if any(fccr.type == "function_approval_request" for fccr in function_call_results):
# Add approval requests to the existing assistant message (with tool_calls)
# instead of creating a separate tool message
from ._types import Role
@@ -1965,7 +1966,7 @@ async def function_invocation_wrapper(
result_message = ChatMessage(role="assistant", contents=function_call_results)
response.messages.append(result_message)
return response
- if any(isinstance(fccr, FunctionCallContent) for fccr in function_call_results):
+ if any(fccr.type == "function_call" for fccr in function_call_results):
# the function calls are already in the response, so we just continue
return response
@@ -1980,11 +1981,7 @@ async def function_invocation_wrapper(
response.messages.insert(0, msg)
return response
- if any(
- fcr.exception is not None
- for fcr in function_call_results
- if isinstance(fcr, FunctionResultContent)
- ):
+ if any(fcr.exception is not None for fcr in function_call_results if fcr.type == "function_result"):
errors_in_a_row += 1
if errors_in_a_row >= config.max_consecutive_errors_per_request:
logger.warning(
@@ -2071,8 +2068,6 @@ async def streaming_function_invocation_wrapper(
ChatMessage,
ChatResponse,
ChatResponseUpdate,
- FunctionCallContent,
- FunctionResultContent,
prepare_messages,
)
@@ -2094,7 +2089,7 @@ async def streaming_function_invocation_wrapper(
tools = _extract_tools(options)
# Only execute APPROVED function calls, not rejected ones
approved_responses = [resp for resp in fcc_todo.values() if resp.approved]
- approved_function_results: list[Contents] = []
+ approved_function_results: list[Content] = []
if approved_responses:
results, _ = await _try_execute_function_calls(
custom_args=kwargs,
@@ -2108,7 +2103,7 @@ async def streaming_function_invocation_wrapper(
if any(
fcr.exception is not None
for fcr in approved_function_results
- if isinstance(fcr, FunctionResultContent)
+ if fcr.type == "function_result"
):
errors_in_a_row += 1
# no need to reset the counter here, since this is the start of a new attempt.
@@ -2124,10 +2119,9 @@ async def streaming_function_invocation_wrapper(
# efficient check for FunctionCallContent in the updates
# if there is at least one, this stops and continuous
# if there are no FCC's then it returns
- from ._types import FunctionApprovalRequestContent
if not any(
- isinstance(item, (FunctionCallContent, FunctionApprovalRequestContent))
+ item.type in ("function_call", "function_approval_request")
for upd in all_updates
for item in upd.contents
):
@@ -2139,13 +2133,11 @@ async def streaming_function_invocation_wrapper(
response: "ChatResponse" = ChatResponse.from_chat_response_updates(all_updates)
# get the function calls (excluding ones that already have results)
- function_results = {
- it.call_id for it in response.messages[0].contents if isinstance(it, FunctionResultContent)
- }
+ function_results = {it.call_id for it in response.messages[0].contents if it.type == "function_result"}
function_calls = [
it
for it in response.messages[0].contents
- if isinstance(it, FunctionCallContent) and it.call_id not in function_results
+ if it.type == "function_call" and it.call_id not in function_results
]
# When conversation id is present, it means that messages are hosted on the server.
@@ -2169,7 +2161,7 @@ async def streaming_function_invocation_wrapper(
)
# Check if we have approval requests or function calls (not results) in the results
- if any(isinstance(fccr, FunctionApprovalRequestContent) for fccr in function_call_results):
+ if any(fccr.type == "function_approval_request" for fccr in function_call_results):
# Add approval requests to the existing assistant message (with tool_calls)
# instead of creating a separate tool message
from ._types import Role
@@ -2184,7 +2176,7 @@ async def streaming_function_invocation_wrapper(
yield ChatResponseUpdate(contents=function_call_results, role="assistant")
response.messages.append(result_message)
return
- if any(isinstance(fccr, FunctionCallContent) for fccr in function_call_results):
+ if any(fccr.type == "function_call" for fccr in function_call_results):
# the function calls were already yielded.
return
@@ -2195,11 +2187,7 @@ async def streaming_function_invocation_wrapper(
yield ChatResponseUpdate(contents=function_call_results, role="tool")
return
- if any(
- fcr.exception is not None
- for fcr in function_call_results
- if isinstance(fcr, FunctionResultContent)
- ):
+ if any(fcr.exception is not None for fcr in function_call_results if fcr.type == "function_result"):
errors_in_a_row += 1
if errors_in_a_row >= config.max_consecutive_errors_per_request:
logger.warning(
diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py
index fce99b3488..d586f9ff5d 100644
--- a/python/packages/core/agent_framework/_types.py
+++ b/python/packages/core/agent_framework/_types.py
@@ -3,7 +3,6 @@
import base64
import json
import re
-import sys
from collections.abc import (
AsyncIterable,
Callable,
@@ -13,7 +12,7 @@
Sequence,
)
from copy import deepcopy
-from typing import Any, ClassVar, Literal, TypedDict, TypeVar, cast, overload
+from typing import Any, ClassVar, Final, Literal, TypedDict, TypeVar, overload
from pydantic import BaseModel, ValidationError
@@ -22,49 +21,22 @@
from ._tools import ToolProtocol, ai_function
from .exceptions import AdditionItemMismatch, ContentError
-if sys.version_info >= (3, 11):
- from typing import Self # pragma: no cover
-else:
- from typing_extensions import Self # pragma: no cover
-
-
__all__ = [
"AgentResponse",
"AgentResponseUpdate",
- "AnnotatedRegions",
- "Annotations",
- "BaseAnnotation",
- "BaseContent",
+ "Annotation",
"ChatMessage",
- "ChatOptions", # Backward compatibility alias
"ChatOptions",
"ChatResponse",
"ChatResponseUpdate",
- "CitationAnnotation",
- "CodeInterpreterToolCallContent",
- "CodeInterpreterToolResultContent",
- "Contents",
- "DataContent",
- "ErrorContent",
+ "Content",
"FinishReason",
- "FunctionApprovalRequestContent",
- "FunctionApprovalResponseContent",
- "FunctionCallContent",
- "FunctionResultContent",
- "HostedFileContent",
- "HostedVectorStoreContent",
- "ImageGenerationToolCallContent",
- "ImageGenerationToolResultContent",
- "MCPServerToolCallContent",
- "MCPServerToolResultContent",
"Role",
- "TextContent",
- "TextReasoningContent",
"TextSpanRegion",
"ToolMode",
- "UriContent",
- "UsageContent",
"UsageDetails",
+ "add_usage_details",
+ "detect_media_type_from_base64",
"merge_chat_options",
"normalize_messages",
"normalize_tools",
@@ -103,84 +76,233 @@ def __new__(mcs, name: str, bases: tuple[type, ...], namespace: dict[str, Any])
return cls
-def _parse_content(content_data: MutableMapping[str, Any]) -> "Contents":
- """Parse a single content data dictionary into the appropriate Content object.
+def _parse_content_list(contents_data: Sequence[Any]) -> list["Content"]:
+ """Parse a list of content data dictionaries into appropriate Content objects.
Args:
- content_data: Content data (dict)
+ contents_data: List of content data (dicts or already constructed objects)
Returns:
- Content object
+ List of Content objects with unknown types logged and ignored
+ """
+ contents: list["Content"] = []
+ for content_data in contents_data:
+ if isinstance(content_data, Content):
+ contents.append(content_data)
+ continue
+ try:
+ contents.append(Content.from_dict(content_data))
+ except ContentError as exc:
+ logger.warning(f"Skipping unknown content type or invalid content: {exc}")
+
+ return contents
+
+
+# region Internal Helper functions for unified Content
+
+
+def detect_media_type_from_base64(
+ *,
+ data_bytes: bytes | None = None,
+ data_str: str | None = None,
+ data_uri: str | None = None,
+) -> str | None:
+ """Detect media type from base64-encoded data by examining magic bytes.
+
+ This function examines the binary signature (magic bytes) at the start of the data
+ to identify common media types. It's reliable for binary formats like images, audio,
+ video, and documents, but cannot detect text-based formats like JSON or plain text.
+
+ Args:
+ data_bytes: Raw binary data.
+ data_str: Base64-encoded data (without data URI prefix).
+ data_uri: Full data URI string (e.g., "data:image/png;base64,iVBORw0KGgo...").
+ This will look at the actual data to determine the media_type and not at the URI prefix.
+ Will also not compare those two values.
Raises:
- ContentError if parsing fails
+ ValueError: If not exactly 1 of data_bytes, data_str, or data_uri is provided, or if base64 decoding fails.
+
+ Returns:
+ The detected media type (e.g., 'image/png', 'audio/wav', 'application/pdf')
+ or None if the format is not recognized.
+
+ Examples:
+ .. code-block:: python
+
+ from agent_framework import detect_media_type_from_base64
+
+ # Detect from base64 string
+ base64_data = "iVBORw0KGgo..."
+ media_type = detect_media_type_from_base64(data_str=base64_data)
+ # Returns: "image/png"
+
+ # Works with data URIs too
+ data_uri = "data:image/png;base64,iVBORw0KGgo..."
+ media_type = detect_media_type_from_base64(data_uri=data_uri)
+ # Returns: "image/png"
"""
- content_type: str | None = content_data.get("type", None)
- match content_type:
- case "text":
- return TextContent.from_dict(content_data)
- case "data":
- return DataContent.from_dict(content_data)
- case "uri":
- return UriContent.from_dict(content_data)
- case "error":
- return ErrorContent.from_dict(content_data)
- case "function_call":
- return FunctionCallContent.from_dict(content_data)
- case "function_result":
- return FunctionResultContent.from_dict(content_data)
- case "usage":
- return UsageContent.from_dict(content_data)
- case "hosted_file":
- return HostedFileContent.from_dict(content_data)
- case "hosted_vector_store":
- return HostedVectorStoreContent.from_dict(content_data)
- case "code_interpreter_tool_call":
- return CodeInterpreterToolCallContent.from_dict(content_data)
- case "code_interpreter_tool_result":
- return CodeInterpreterToolResultContent.from_dict(content_data)
- case "image_generation_tool_call":
- return ImageGenerationToolCallContent.from_dict(content_data)
- case "image_generation_tool_result":
- return ImageGenerationToolResultContent.from_dict(content_data)
- case "mcp_server_tool_call":
- return MCPServerToolCallContent.from_dict(content_data)
- case "mcp_server_tool_result":
- return MCPServerToolResultContent.from_dict(content_data)
- case "function_approval_request":
- return FunctionApprovalRequestContent.from_dict(content_data)
- case "function_approval_response":
- return FunctionApprovalResponseContent.from_dict(content_data)
- case "text_reasoning":
- return TextReasoningContent.from_dict(content_data)
- case None:
- raise ContentError("Content type is missing")
- case _:
- raise ContentError(f"Unknown content type '{content_type}'")
-
-
-def _parse_content_list(contents_data: Sequence[Any]) -> list["Contents"]:
- """Parse a list of content data dictionaries into appropriate Content objects.
+ data: bytes | None = None
+ if data_bytes is not None:
+ data = data_bytes
+ if data_uri is not None:
+ if data is not None:
+ raise ValueError("Provide exactly one of data_bytes, data_str, or data_uri.")
+ # Remove data URI prefix if present
+ data_str = data_uri.split(";base64,", 1)[1]
+ if data_str is not None:
+ if data is not None:
+ raise ValueError("Provide exactly one of data_bytes, data_str, or data_uri.")
+ try:
+ data = base64.b64decode(data_str)
+ except Exception as exc:
+ raise ValueError("Invalid base64 data provided.") from exc
+ if data is None:
+ raise ValueError("Provide exactly one of data_bytes, data_str, or data_uri.")
+
+ # Check magic bytes for common formats
+ # Images
+ if data.startswith(b"\x89PNG\r\n\x1a\n"):
+ return "image/png"
+ if data.startswith(b"\xff\xd8\xff"):
+ return "image/jpeg"
+ if data.startswith(b"GIF87a") or data.startswith(b"GIF89a"):
+ return "image/gif"
+ if data.startswith(b"RIFF") and len(data) > 11 and data[8:12] == b"WEBP":
+ return "image/webp"
+ if data.startswith(b"BM"):
+ return "image/bmp"
+ if data.startswith(b"