Skip to content

Commit bddc70b

Browse files
google-genai-botcopybara-github
authored andcommitted
fix: Better handling the A2A streaming tasks so calling Agent can tell whether it's in progress updates (thought) or the final response
PiperOrigin-RevId: 817682171
1 parent 85ed500 commit bddc70b

2 files changed

Lines changed: 212 additions & 5 deletions

File tree

src/google/adk/agents/remote_a2a_agent.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from a2a.types import Part as A2APart
3838
from a2a.types import Role
3939
from a2a.types import TaskArtifactUpdateEvent as A2ATaskArtifactUpdateEvent
40+
from a2a.types import TaskState
4041
from a2a.types import TaskStatusUpdateEvent as A2ATaskStatusUpdateEvent
4142
from a2a.types import TransportProtocol as A2ATransport
4243
except ImportError as e:
@@ -414,14 +415,25 @@ async def _handle_a2a_response(
414415
# response for a non-streaming task, which is the full task state.
415416
# We process this to get the initial message.
416417
event = convert_a2a_task_to_event(task, self.name, ctx)
418+
# for streaming task, we update the event with the task status.
419+
# We update the event as Thought updates.
420+
if task and task.status and task.status.state == TaskState.submitted:
421+
event.content.parts[0].thought = True
417422
elif (
418423
isinstance(update, A2ATaskStatusUpdateEvent)
424+
and update.status
419425
and update.status.message
420426
):
421427
# This is a streaming task status update with a message.
422428
event = convert_a2a_message_to_event(
423429
update.status.message, self.name, ctx
424430
)
431+
if event.content and update.status.state in [
432+
TaskState.submitted,
433+
TaskState.working,
434+
]:
435+
for part in event.content.parts:
436+
part.thought = True
425437
elif isinstance(update, A2ATaskArtifactUpdateEvent) and (
426438
not update.append or update.last_chunk
427439
):

tests/unittests/agents/test_remote_a2a_agent.py

Lines changed: 200 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from google.adk.events.event import Event
2525
from google.adk.sessions.session import Session
26+
from google.genai import types as genai_types
2627
import httpx
2728
import pytest
2829

@@ -41,12 +42,14 @@
4142
from a2a.types import AgentSkill
4243
from a2a.types import Artifact
4344
from a2a.types import Message as A2AMessage
45+
from a2a.types import Part as A2ATaskStatus
4446
from a2a.types import SendMessageSuccessResponse
4547
from a2a.types import Task as A2ATask
4648
from a2a.types import TaskArtifactUpdateEvent
4749
from a2a.types import TaskState
4850
from a2a.types import TaskStatus
4951
from a2a.types import TaskStatusUpdateEvent
52+
from a2a.types import TextPart
5053
from google.adk.agents.invocation_context import InvocationContext
5154
from google.adk.agents.remote_a2a_agent import A2A_METADATA_PREFIX
5255
from google.adk.agents.remote_a2a_agent import AgentCardResolutionError
@@ -693,17 +696,21 @@ async def test_handle_a2a_response_success_with_message(self):
693696
assert A2A_METADATA_PREFIX + "context_id" in result.custom_metadata
694697

695698
@pytest.mark.asyncio
696-
async def test_handle_a2a_response_with_task_and_no_update(self):
697-
"""Test successful A2A response handling with task and no update."""
699+
async def test_handle_a2a_response_with_task_completed_and_no_update(self):
700+
"""Test successful A2A response handling with non-streeaming task and no update."""
698701
mock_a2a_task = Mock(spec=A2ATask)
699702
mock_a2a_task.id = "task-123"
700703
mock_a2a_task.context_id = "context-123"
704+
mock_a2a_task.status = Mock(spec=A2ATaskStatus)
705+
mock_a2a_task.status.state = TaskState.completed
701706

702707
# Create a proper Event mock that can handle custom_metadata
708+
mock_a2a_part = Mock(spec=TextPart)
703709
mock_event = Event(
704710
author=self.agent.name,
705711
invocation_id=self.mock_context.invocation_id,
706712
branch=self.mock_context.branch,
713+
content=genai_types.Content(role="model", parts=[mock_a2a_part]),
707714
)
708715

709716
with patch(
@@ -721,6 +728,49 @@ async def test_handle_a2a_response_with_task_and_no_update(self):
721728
self.agent.name,
722729
self.mock_context,
723730
)
731+
# Check the parts are not updated as Thought
732+
assert result.content.parts[0].thought is None
733+
# Check that metadata was added
734+
assert result.custom_metadata is not None
735+
assert A2A_METADATA_PREFIX + "task_id" in result.custom_metadata
736+
assert A2A_METADATA_PREFIX + "context_id" in result.custom_metadata
737+
738+
@pytest.mark.asyncio
739+
async def test_handle_a2a_response_with_task_submitted_and_no_update(self):
740+
"""Test successful A2A response handling with streaming task and no update."""
741+
mock_a2a_task = Mock(spec=A2ATask)
742+
mock_a2a_task.id = "task-123"
743+
mock_a2a_task.context_id = "context-123"
744+
mock_a2a_task.status = Mock(spec=A2ATaskStatus)
745+
mock_a2a_task.status.state = TaskState.submitted
746+
747+
# Create a proper Event mock that can handle custom_metadata
748+
mock_a2a_part = Mock(spec=TextPart)
749+
mock_event = Event(
750+
author=self.agent.name,
751+
invocation_id=self.mock_context.invocation_id,
752+
branch=self.mock_context.branch,
753+
content=genai_types.Content(role="model", parts=[mock_a2a_part]),
754+
)
755+
756+
with patch(
757+
"google.adk.agents.remote_a2a_agent.convert_a2a_task_to_event"
758+
) as mock_convert:
759+
mock_convert.return_value = mock_event
760+
761+
result = await self.agent._handle_a2a_response(
762+
(mock_a2a_task, None), self.mock_context
763+
)
764+
765+
assert result == mock_event
766+
mock_convert.assert_called_once_with(
767+
mock_a2a_task,
768+
self.agent.name,
769+
self.mock_context,
770+
)
771+
# Check the parts are updated as Thought
772+
assert result.content.parts[0].thought is True
773+
assert result.content.parts[0].thought_signature is None
724774
# Check that metadata was added
725775
assert result.custom_metadata is not None
726776
assert A2A_METADATA_PREFIX + "task_id" in result.custom_metadata
@@ -740,10 +790,57 @@ async def test_handle_a2a_response_with_task_status_update_with_message(self):
740790
mock_update.status.message = mock_a2a_message
741791

742792
# Create a proper Event mock that can handle custom_metadata
793+
mock_a2a_part = Mock(spec=TextPart)
794+
mock_event = Event(
795+
author=self.agent.name,
796+
invocation_id=self.mock_context.invocation_id,
797+
branch=self.mock_context.branch,
798+
content=genai_types.Content(role="model", parts=[mock_a2a_part]),
799+
)
800+
801+
with patch(
802+
"google.adk.agents.remote_a2a_agent.convert_a2a_message_to_event"
803+
) as mock_convert:
804+
mock_convert.return_value = mock_event
805+
806+
result = await self.agent._handle_a2a_response(
807+
(mock_a2a_task, mock_update), self.mock_context
808+
)
809+
810+
assert result == mock_event
811+
mock_convert.assert_called_once_with(
812+
mock_a2a_message,
813+
self.agent.name,
814+
self.mock_context,
815+
)
816+
# Check that metadata was added
817+
assert result.custom_metadata is not None
818+
assert result.content.parts[0].thought is None
819+
assert A2A_METADATA_PREFIX + "task_id" in result.custom_metadata
820+
assert A2A_METADATA_PREFIX + "context_id" in result.custom_metadata
821+
822+
@pytest.mark.asyncio
823+
async def test_handle_a2a_response_with_task_status_working_update_with_message(
824+
self,
825+
):
826+
"""Test handling of a task status update with a message."""
827+
mock_a2a_task = Mock(spec=A2ATask)
828+
mock_a2a_task.id = "task-123"
829+
mock_a2a_task.context_id = "context-123"
830+
831+
mock_a2a_message = Mock(spec=A2AMessage)
832+
mock_update = Mock(spec=TaskStatusUpdateEvent)
833+
mock_update.status = Mock(TaskStatus)
834+
mock_update.status.state = TaskState.working
835+
mock_update.status.message = mock_a2a_message
836+
837+
# Create a proper Event mock that can handle custom_metadata
838+
mock_a2a_part = Mock(spec=TextPart)
743839
mock_event = Event(
744840
author=self.agent.name,
745841
invocation_id=self.mock_context.invocation_id,
746842
branch=self.mock_context.branch,
843+
content=genai_types.Content(role="model", parts=[mock_a2a_part]),
747844
)
748845

749846
with patch(
@@ -763,6 +860,7 @@ async def test_handle_a2a_response_with_task_status_update_with_message(self):
763860
)
764861
# Check that metadata was added
765862
assert result.custom_metadata is not None
863+
assert result.content.parts[0].thought is True
766864
assert A2A_METADATA_PREFIX + "task_id" in result.custom_metadata
767865
assert A2A_METADATA_PREFIX + "context_id" in result.custom_metadata
768866

@@ -985,17 +1083,21 @@ async def test_handle_a2a_response_success_with_message(self):
9851083
assert A2A_METADATA_PREFIX + "context_id" in result.custom_metadata
9861084

9871085
@pytest.mark.asyncio
988-
async def test_handle_a2a_response_with_task_and_no_update(self):
989-
"""Test successful A2A response handling with task and no update."""
1086+
async def test_handle_a2a_response_with_task_completed_and_no_update(self):
1087+
"""Test successful A2A response handling with non-streeaming task and no update."""
9901088
mock_a2a_task = Mock(spec=A2ATask)
9911089
mock_a2a_task.id = "task-123"
9921090
mock_a2a_task.context_id = "context-123"
1091+
mock_a2a_task.status = Mock(spec=A2ATaskStatus)
1092+
mock_a2a_task.status.state = TaskState.completed
9931093

9941094
# Create a proper Event mock that can handle custom_metadata
1095+
mock_a2a_part = Mock(spec=TextPart)
9951096
mock_event = Event(
9961097
author=self.agent.name,
9971098
invocation_id=self.mock_context.invocation_id,
9981099
branch=self.mock_context.branch,
1100+
content=genai_types.Content(role="model", parts=[mock_a2a_part]),
9991101
)
10001102

10011103
with patch(
@@ -1009,8 +1111,53 @@ async def test_handle_a2a_response_with_task_and_no_update(self):
10091111

10101112
assert result == mock_event
10111113
mock_convert.assert_called_once_with(
1012-
mock_a2a_task, self.agent.name, self.mock_context
1114+
mock_a2a_task,
1115+
self.agent.name,
1116+
self.mock_context,
10131117
)
1118+
# Check the parts are not updated as Thought
1119+
assert result.content.parts[0].thought is None
1120+
# Check that metadata was added
1121+
assert result.custom_metadata is not None
1122+
assert A2A_METADATA_PREFIX + "task_id" in result.custom_metadata
1123+
assert A2A_METADATA_PREFIX + "context_id" in result.custom_metadata
1124+
1125+
@pytest.mark.asyncio
1126+
async def test_handle_a2a_response_with_task_submitted_and_no_update(self):
1127+
"""Test successful A2A response handling with streaming task and no update."""
1128+
mock_a2a_task = Mock(spec=A2ATask)
1129+
mock_a2a_task.id = "task-123"
1130+
mock_a2a_task.context_id = "context-123"
1131+
mock_a2a_task.status = Mock(spec=A2ATaskStatus)
1132+
mock_a2a_task.status.state = TaskState.submitted
1133+
1134+
# Create a proper Event mock that can handle custom_metadata
1135+
mock_a2a_part = Mock(spec=TextPart)
1136+
mock_event = Event(
1137+
author=self.agent.name,
1138+
invocation_id=self.mock_context.invocation_id,
1139+
branch=self.mock_context.branch,
1140+
content=genai_types.Content(role="model", parts=[mock_a2a_part]),
1141+
)
1142+
1143+
with patch(
1144+
"google.adk.agents.remote_a2a_agent.convert_a2a_task_to_event"
1145+
) as mock_convert:
1146+
mock_convert.return_value = mock_event
1147+
1148+
result = await self.agent._handle_a2a_response(
1149+
(mock_a2a_task, None), self.mock_context
1150+
)
1151+
1152+
assert result == mock_event
1153+
mock_convert.assert_called_once_with(
1154+
mock_a2a_task,
1155+
self.agent.name,
1156+
self.mock_context,
1157+
)
1158+
# Check the parts are updated as Thought
1159+
assert result.content.parts[0].thought is True
1160+
assert result.content.parts[0].thought_signature is None
10141161
# Check that metadata was added
10151162
assert result.custom_metadata is not None
10161163
assert A2A_METADATA_PREFIX + "task_id" in result.custom_metadata
@@ -1030,10 +1177,57 @@ async def test_handle_a2a_response_with_task_status_update_with_message(self):
10301177
mock_update.status.message = mock_a2a_message
10311178

10321179
# Create a proper Event mock that can handle custom_metadata
1180+
mock_a2a_part = Mock(spec=TextPart)
1181+
mock_event = Event(
1182+
author=self.agent.name,
1183+
invocation_id=self.mock_context.invocation_id,
1184+
branch=self.mock_context.branch,
1185+
content=genai_types.Content(role="model", parts=[mock_a2a_part]),
1186+
)
1187+
1188+
with patch(
1189+
"google.adk.agents.remote_a2a_agent.convert_a2a_message_to_event"
1190+
) as mock_convert:
1191+
mock_convert.return_value = mock_event
1192+
1193+
result = await self.agent._handle_a2a_response(
1194+
(mock_a2a_task, mock_update), self.mock_context
1195+
)
1196+
1197+
assert result == mock_event
1198+
mock_convert.assert_called_once_with(
1199+
mock_a2a_message,
1200+
self.agent.name,
1201+
self.mock_context,
1202+
)
1203+
# Check that metadata was added
1204+
assert result.custom_metadata is not None
1205+
assert result.content.parts[0].thought is None
1206+
assert A2A_METADATA_PREFIX + "task_id" in result.custom_metadata
1207+
assert A2A_METADATA_PREFIX + "context_id" in result.custom_metadata
1208+
1209+
@pytest.mark.asyncio
1210+
async def test_handle_a2a_response_with_task_status_working_update_with_message(
1211+
self,
1212+
):
1213+
"""Test handling of a task status update with a message."""
1214+
mock_a2a_task = Mock(spec=A2ATask)
1215+
mock_a2a_task.id = "task-123"
1216+
mock_a2a_task.context_id = "context-123"
1217+
1218+
mock_a2a_message = Mock(spec=A2AMessage)
1219+
mock_update = Mock(spec=TaskStatusUpdateEvent)
1220+
mock_update.status = Mock(TaskStatus)
1221+
mock_update.status.state = TaskState.working
1222+
mock_update.status.message = mock_a2a_message
1223+
1224+
# Create a proper Event mock that can handle custom_metadata
1225+
mock_a2a_part = Mock(spec=TextPart)
10331226
mock_event = Event(
10341227
author=self.agent.name,
10351228
invocation_id=self.mock_context.invocation_id,
10361229
branch=self.mock_context.branch,
1230+
content=genai_types.Content(role="model", parts=[mock_a2a_part]),
10371231
)
10381232

10391233
with patch(
@@ -1053,6 +1247,7 @@ async def test_handle_a2a_response_with_task_status_update_with_message(self):
10531247
)
10541248
# Check that metadata was added
10551249
assert result.custom_metadata is not None
1250+
assert result.content.parts[0].thought is True
10561251
assert A2A_METADATA_PREFIX + "task_id" in result.custom_metadata
10571252
assert A2A_METADATA_PREFIX + "context_id" in result.custom_metadata
10581253

0 commit comments

Comments
 (0)