Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 9 additions & 12 deletions python/packages/a2a/agent_framework_a2a/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,8 @@
AgentThread,
BaseAgent,
ChatMessage,
Contents,
DataContent,
Content,
Role,
TextContent,
UriContent,
normalize_messages,
prepend_agent_framework_to_user_agent,
)
Expand Down Expand Up @@ -333,7 +330,7 @@ def _prepare_message_for_a2a(self, message: ChatMessage) -> A2AMessage:
A2APart(
root=FilePart(
file=FileWithBytes(
bytes=_get_uri_data(content.uri),
bytes=_get_uri_data(content.uri), # type: ignore[arg-type]
mime_type=content.media_type,
),
metadata=content.additional_properties,
Expand Down Expand Up @@ -362,19 +359,19 @@ def _prepare_message_for_a2a(self, message: ChatMessage) -> A2AMessage:
metadata=cast(dict[str, Any], message.additional_properties),
)

def _parse_contents_from_a2a(self, parts: Sequence[A2APart]) -> list[Contents]:
"""Parse A2A Parts into Agent Framework Contents.
def _parse_contents_from_a2a(self, parts: Sequence[A2APart]) -> list[Content]:
"""Parse A2A Parts into Agent Framework Content.

Transforms A2A protocol Parts into framework-native Content objects,
handling text, file (URI/bytes), and data parts with metadata preservation.
"""
contents: list[Contents] = []
contents: list[Content] = []
for part in parts:
inner_part = part.root
match inner_part.kind:
case "text":
contents.append(
TextContent(
Content.from_text(
text=inner_part.text,
additional_properties=inner_part.metadata,
raw_representation=inner_part,
Expand All @@ -383,7 +380,7 @@ def _parse_contents_from_a2a(self, parts: Sequence[A2APart]) -> list[Contents]:
case "file":
if isinstance(inner_part.file, FileWithUri):
contents.append(
UriContent(
Content.from_uri(
uri=inner_part.file.uri,
media_type=inner_part.file.mime_type or "",
additional_properties=inner_part.metadata,
Expand All @@ -392,7 +389,7 @@ def _parse_contents_from_a2a(self, parts: Sequence[A2APart]) -> list[Contents]:
)
elif isinstance(inner_part.file, FileWithBytes):
contents.append(
DataContent(
Content.from_data(
data=base64.b64decode(inner_part.file.bytes),
media_type=inner_part.file.mime_type or "",
additional_properties=inner_part.metadata,
Expand All @@ -401,7 +398,7 @@ def _parse_contents_from_a2a(self, parts: Sequence[A2APart]) -> list[Contents]:
)
case "data":
contents.append(
TextContent(
Content.from_text(
text=json.dumps(inner_part.data),
additional_properties=inner_part.metadata,
raw_representation=inner_part,
Expand Down
32 changes: 14 additions & 18 deletions python/packages/a2a/tests/test_a2a_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,8 @@
AgentResponse,
AgentResponseUpdate,
ChatMessage,
DataContent,
ErrorContent,
HostedFileContent,
Content,
Role,
TextContent,
UriContent,
)
from agent_framework.a2a import A2AAgent
from pytest import fixture, raises
Expand Down Expand Up @@ -289,8 +285,8 @@ def test_parse_contents_from_a2a_conversion(a2a_agent: A2AAgent) -> None:

# Verify conversion
assert len(contents) == 2
assert isinstance(contents[0], TextContent)
assert isinstance(contents[1], TextContent)
assert contents[0].type == "text"
assert contents[1].type == "text"
assert contents[0].text == "First part"
assert contents[1].text == "Second part"

Expand All @@ -299,7 +295,7 @@ def test_prepare_message_for_a2a_with_error_content(a2a_agent: A2AAgent) -> None
"""Test _prepare_message_for_a2a with ErrorContent."""

# Create ChatMessage with ErrorContent
error_content = ErrorContent(message="Test error message")
error_content = Content.from_error(message="Test error message")
message = ChatMessage(role=Role.USER, contents=[error_content])

# Convert to A2A message
Expand All @@ -314,7 +310,7 @@ def test_prepare_message_for_a2a_with_uri_content(a2a_agent: A2AAgent) -> None:
"""Test _prepare_message_for_a2a with UriContent."""

# Create ChatMessage with UriContent
uri_content = UriContent(uri="http://example.com/file.pdf", media_type="application/pdf")
uri_content = Content.from_uri(uri="http://example.com/file.pdf", media_type="application/pdf")
message = ChatMessage(role=Role.USER, contents=[uri_content])

# Convert to A2A message
Expand All @@ -330,7 +326,7 @@ def test_prepare_message_for_a2a_with_data_content(a2a_agent: A2AAgent) -> None:
"""Test _prepare_message_for_a2a with DataContent."""

# Create ChatMessage with DataContent (base64 data URI)
data_content = DataContent(uri="data:text/plain;base64,SGVsbG8gV29ybGQ=", media_type="text/plain")
data_content = Content.from_uri(uri="data:text/plain;base64,SGVsbG8gV29ybGQ=", media_type="text/plain")
message = ChatMessage(role=Role.USER, contents=[data_content])

# Convert to A2A message
Expand Down Expand Up @@ -368,7 +364,7 @@ async def test_run_stream_with_message_response(a2a_agent: A2AAgent, mock_a2a_cl
assert len(updates[0].contents) == 1

content = updates[0].contents[0]
assert isinstance(content, TextContent)
assert content.type == "text"
assert content.text == "Streaming response from agent!"

assert updates[0].response_id == "msg-stream-123"
Expand Down Expand Up @@ -414,10 +410,10 @@ def test_prepare_message_for_a2a_with_multiple_contents() -> None:
message = ChatMessage(
role=Role.USER,
contents=[
TextContent(text="Here's the analysis:"),
DataContent(data=b"binary data", media_type="application/octet-stream"),
UriContent(uri="https://example.com/image.png", media_type="image/png"),
TextContent(text='{"structured": "data"}'),
Content.from_text(text="Here's the analysis:"),
Content.from_data(data=b"binary data", media_type="application/octet-stream"),
Content.from_uri(uri="https://example.com/image.png", media_type="image/png"),
Content.from_text(text='{"structured": "data"}'),
],
)

Expand Down Expand Up @@ -445,7 +441,7 @@ def test_parse_contents_from_a2a_with_data_part() -> None:

assert len(contents) == 1

assert isinstance(contents[0], TextContent)
assert contents[0].type == "text"
assert contents[0].text == '{"key": "value", "number": 42}'
assert contents[0].additional_properties == {"source": "test"}

Expand All @@ -470,7 +466,7 @@ def test_prepare_message_for_a2a_with_hosted_file() -> None:
# Create message with hosted file content
message = ChatMessage(
role=Role.USER,
contents=[HostedFileContent(file_id="hosted://storage/document.pdf")],
contents=[Content.from_hosted_file(file_id="hosted://storage/document.pdf")],
)

result = agent._prepare_message_for_a2a(message) # noqa: SLF001
Expand Down Expand Up @@ -507,7 +503,7 @@ def test_parse_contents_from_a2a_with_hosted_file_uri() -> None:

assert len(contents) == 1

assert isinstance(contents[0], UriContent)
assert contents[0].type == "uri"
assert contents[0].uri == "hosted://storage/document.pdf"
assert contents[0].media_type == "" # Converted None to empty string

Expand Down
61 changes: 21 additions & 40 deletions python/packages/ag-ui/agent_framework_ag_ui/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@
ChatMessage,
ChatResponse,
ChatResponseUpdate,
DataContent,
FunctionCallContent,
Content,
use_chat_middleware,
use_function_invocation,
)
from agent_framework._middleware import use_chat_middleware
from agent_framework._tools import use_function_invocation
from agent_framework._types import BaseContent, Contents
from agent_framework.observability import use_instrumentation

from ._event_converters import AGUIEventConverter
Expand Down Expand Up @@ -53,26 +51,11 @@
logger: logging.Logger = logging.getLogger(__name__)


class ServerFunctionCallContent(BaseContent):
"""Wrapper for server function calls to prevent client re-execution.

All function calls from the remote server are server-side executions.
This wrapper prevents @use_function_invocation from trying to execute them again.
"""

function_call_content: FunctionCallContent

def __init__(self, function_call_content: FunctionCallContent) -> None:
"""Initialize with the function call content."""
super().__init__(type="server_function_call")
self.function_call_content = function_call_content


def _unwrap_server_function_call_contents(contents: MutableSequence[Contents | dict[str, Any]]) -> None:
"""Replace ServerFunctionCallContent instances with their underlying call content."""
def _unwrap_server_function_call_contents(contents: MutableSequence[Content | dict[str, Any]]) -> None:
"""Replace server_function_call instances with their underlying call content."""
for idx, content in enumerate(contents):
if isinstance(content, ServerFunctionCallContent):
contents[idx] = content.function_call_content # type: ignore[assignment]
if content.type == "server_function_call": # type: ignore[union-attr]
contents[idx] = content.function_call # type: ignore[assignment, union-attr]


TBaseChatClient = TypeVar("TBaseChatClient", bound=type[BaseChatClient[Any]])
Expand All @@ -93,7 +76,7 @@ def _apply_server_function_call_unwrap(chat_client: TBaseChatClient) -> TBaseCha
@wraps(original_get_streaming_response)
async def streaming_wrapper(self, *args: Any, **kwargs: Any) -> AsyncIterable[ChatResponseUpdate]:
async for update in original_get_streaming_response(self, *args, **kwargs):
_unwrap_server_function_call_contents(cast(MutableSequence[Contents | dict[str, Any]], update.contents))
_unwrap_server_function_call_contents(cast(MutableSequence[Content | dict[str, Any]], update.contents))
yield update

chat_client.get_streaming_response = streaming_wrapper # type: ignore[assignment]
Expand All @@ -105,9 +88,7 @@ async def response_wrapper(self, *args: Any, **kwargs: Any) -> ChatResponse:
response = await original_get_response(self, *args, **kwargs)
if response.messages:
for message in response.messages:
_unwrap_server_function_call_contents(
cast(MutableSequence[Contents | dict[str, Any]], message.contents)
)
_unwrap_server_function_call_contents(cast(MutableSequence[Content | dict[str, Any]], message.contents))
return response

chat_client.get_response = response_wrapper # type: ignore[assignment]
Expand Down Expand Up @@ -289,13 +270,13 @@ def _extract_state_from_messages(
last_message = messages[-1]

for content in last_message.contents:
if isinstance(content, DataContent) and content.media_type == "application/json":
if isinstance(content, Content) and content.type == "data" and content.media_type == "application/json":
try:
uri = content.uri
if uri.startswith("data:application/json;base64,"):
if uri.startswith("data:application/json;base64,"): # type: ignore[union-attr]
import base64

encoded_data = uri.split(",", 1)[1]
encoded_data = uri.split(",", 1)[1] # type: ignore[union-attr]
decoded_bytes = base64.b64decode(encoded_data)
state = json.loads(decoded_bytes.decode("utf-8"))

Expand Down Expand Up @@ -433,19 +414,19 @@ async def _inner_get_streaming_response(
)
# Distinguish client vs server tools
for i, content in enumerate(update.contents):
if isinstance(content, FunctionCallContent):
if content.type == "function_call": # type: ignore[attr-defined]
logger.debug(
f"[AGUIChatClient] Function call: {content.name}, in client_tool_set: {content.name in client_tool_set}"
f"[AGUIChatClient] Function call: {content.name}, in client_tool_set: {content.name in client_tool_set}" # type: ignore[attr-defined]
)
if content.name in client_tool_set:
if content.name in client_tool_set: # type: ignore[attr-defined]
# Client tool - let @use_function_invocation execute it
if not content.additional_properties:
content.additional_properties = {}
content.additional_properties["agui_thread_id"] = thread_id
if not content.additional_properties: # type: ignore[attr-defined]
content.additional_properties = {} # type: ignore[attr-defined]
content.additional_properties["agui_thread_id"] = thread_id # type: ignore[attr-defined]
else:
# Server tool - wrap so @use_function_invocation ignores it
logger.debug(f"[AGUIChatClient] Wrapping server tool: {content.name}")
self._register_server_tool_placeholder(content.name)
update.contents[i] = ServerFunctionCallContent(content) # type: ignore
logger.debug(f"[AGUIChatClient] Wrapping server tool: {content.name}") # type: ignore[union-attr]
self._register_server_tool_placeholder(content.name) # type: ignore[arg-type]
update.contents[i] = Content(type="server_function_call", function_call=content) # type: ignore

yield update
15 changes: 6 additions & 9 deletions python/packages/ag-ui/agent_framework_ag_ui/_event_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,9 @@

from agent_framework import (
ChatResponseUpdate,
ErrorContent,
Content,
FinishReason,
FunctionCallContent,
FunctionResultContent,
Role,
TextContent,
)


Expand Down Expand Up @@ -117,7 +114,7 @@ def _handle_text_message_content(self, event: dict[str, Any]) -> ChatResponseUpd
return ChatResponseUpdate(
role=Role.ASSISTANT,
message_id=self.current_message_id,
contents=[TextContent(text=delta)],
contents=[Content.from_text(text=delta)],
)

def _handle_text_message_end(self, event: dict[str, Any]) -> ChatResponseUpdate | None:
Expand All @@ -133,7 +130,7 @@ def _handle_tool_call_start(self, event: dict[str, Any]) -> ChatResponseUpdate:
return ChatResponseUpdate(
role=Role.ASSISTANT,
contents=[
FunctionCallContent(
Content.from_function_call(
call_id=self.current_tool_call_id or "",
name=self.current_tool_name or "",
arguments="",
Expand All @@ -149,7 +146,7 @@ def _handle_tool_call_args(self, event: dict[str, Any]) -> ChatResponseUpdate:
return ChatResponseUpdate(
role=Role.ASSISTANT,
contents=[
FunctionCallContent(
Content.from_function_call(
call_id=self.current_tool_call_id or "",
name=self.current_tool_name or "",
arguments=delta,
Expand All @@ -170,7 +167,7 @@ def _handle_tool_call_result(self, event: dict[str, Any]) -> ChatResponseUpdate:
return ChatResponseUpdate(
role=Role.TOOL,
contents=[
FunctionResultContent(
Content.from_function_result(
call_id=tool_call_id,
result=result,
)
Expand All @@ -197,7 +194,7 @@ def _handle_run_error(self, event: dict[str, Any]) -> ChatResponseUpdate:
role=Role.ASSISTANT,
finish_reason=FinishReason.CONTENT_FILTER,
contents=[
ErrorContent(
Content.from_error(
message=error_message,
error_code="RUN_ERROR",
)
Expand Down
Loading
Loading