1919from google .adk .agents .base_agent import BaseAgent
2020from google .adk .agents .invocation_context import InvocationContext
2121from google .adk .agents .loop_agent import LoopAgent
22+ from google .adk .agents .loop_agent import LoopAgentState
23+ from google .adk .apps import ResumabilityConfig
2224from google .adk .events .event import Event
2325from google .adk .events .event_actions import EventActions
2426from google .adk .sessions .in_memory_session_service import InMemorySessionService
2527from google .genai import types
2628import pytest
2729from typing_extensions import override
2830
31+ from .. import testing_utils
32+
33+ END_OF_AGENT = testing_utils .END_OF_AGENT
34+
2935
3036class _TestingAgent (BaseAgent ):
3137
@@ -72,13 +78,13 @@ async def _run_async_impl(
7278 author = self .name ,
7379 invocation_id = ctx .invocation_id ,
7480 content = types .Content (
75- parts = [types .Part (text = f 'I have done my job after escalation!!' )]
81+ parts = [types .Part (text = 'I have done my job after escalation!!' )]
7682 ),
7783 )
7884
7985
8086async def _create_parent_invocation_context (
81- test_name : str , agent : BaseAgent
87+ test_name : str , agent : BaseAgent , resumable : bool = False
8288) -> InvocationContext :
8389 session_service = InMemorySessionService ()
8490 session = await session_service .create_session (
@@ -89,11 +95,13 @@ async def _create_parent_invocation_context(
8995 agent = agent ,
9096 session = session ,
9197 session_service = session_service ,
98+ resumability_config = ResumabilityConfig (is_resumable = resumable ),
9299 )
93100
94101
95102@pytest .mark .asyncio
96- async def test_run_async (request : pytest .FixtureRequest ):
103+ @pytest .mark .parametrize ('resumable' , [True , False ])
104+ async def test_run_async (request : pytest .FixtureRequest , resumable : bool ):
97105 agent = _TestingAgent (name = f'{ request .function .__name__ } _test_agent' )
98106 loop_agent = LoopAgent (
99107 name = f'{ request .function .__name__ } _test_loop_agent' ,
@@ -103,15 +111,60 @@ async def test_run_async(request: pytest.FixtureRequest):
103111 ],
104112 )
105113 parent_ctx = await _create_parent_invocation_context (
106- request .function .__name__ , loop_agent
114+ request .function .__name__ , loop_agent , resumable = resumable
115+ )
116+ events = [e async for e in loop_agent .run_async (parent_ctx )]
117+
118+ simplified_events = testing_utils .simplify_resumable_app_events (events )
119+ if resumable :
120+ expected_events = [
121+ (
122+ loop_agent .name ,
123+ {'current_sub_agent' : agent .name , 'times_looped' : 0 },
124+ ),
125+ (agent .name , f'Hello, async { agent .name } !' ),
126+ (
127+ loop_agent .name ,
128+ {'current_sub_agent' : agent .name , 'times_looped' : 1 },
129+ ),
130+ (agent .name , f'Hello, async { agent .name } !' ),
131+ (loop_agent .name , END_OF_AGENT ),
132+ ]
133+ else :
134+ expected_events = [
135+ (agent .name , f'Hello, async { agent .name } !' ),
136+ (agent .name , f'Hello, async { agent .name } !' ),
137+ ]
138+ assert simplified_events == expected_events
139+
140+
141+ @pytest .mark .asyncio
142+ async def test_resume_async (request : pytest .FixtureRequest ):
143+ agent_1 = _TestingAgent (name = f'{ request .function .__name__ } _test_agent_1' )
144+ agent_2 = _TestingAgent (name = f'{ request .function .__name__ } _test_agent_2' )
145+ loop_agent = LoopAgent (
146+ name = f'{ request .function .__name__ } _test_loop_agent' ,
147+ max_iterations = 2 ,
148+ sub_agents = [
149+ agent_1 ,
150+ agent_2 ,
151+ ],
107152 )
153+ parent_ctx = await _create_parent_invocation_context (
154+ request .function .__name__ , loop_agent , resumable = True
155+ )
156+ parent_ctx .agent_states [loop_agent .name ] = LoopAgentState (
157+ current_sub_agent = agent_2 .name , times_looped = 1
158+ ).model_dump (mode = 'json' )
159+
108160 events = [e async for e in loop_agent .run_async (parent_ctx )]
109161
110- assert len (events ) == 2
111- assert events [0 ].author == agent .name
112- assert events [1 ].author == agent .name
113- assert events [0 ].content .parts [0 ].text == f'Hello, async { agent .name } !'
114- assert events [1 ].content .parts [0 ].text == f'Hello, async { agent .name } !'
162+ simplified_events = testing_utils .simplify_resumable_app_events (events )
163+ expected_events = [
164+ (agent_2 .name , f'Hello, async { agent_2 .name } !' ),
165+ (loop_agent .name , END_OF_AGENT ),
166+ ]
167+ assert simplified_events == expected_events
115168
116169
117170@pytest .mark .asyncio
@@ -129,7 +182,10 @@ async def test_run_async_skip_if_no_sub_agent(request: pytest.FixtureRequest):
129182
130183
131184@pytest .mark .asyncio
132- async def test_run_async_with_escalate_action (request : pytest .FixtureRequest ):
185+ @pytest .mark .parametrize ('resumable' , [True , False ])
186+ async def test_run_async_with_escalate_action (
187+ request : pytest .FixtureRequest , resumable : bool
188+ ):
133189 non_escalating_agent = _TestingAgent (
134190 name = f'{ request .function .__name__ } _test_non_escalating_agent'
135191 )
@@ -144,20 +200,52 @@ async def test_run_async_with_escalate_action(request: pytest.FixtureRequest):
144200 sub_agents = [non_escalating_agent , escalating_agent , ignored_agent ],
145201 )
146202 parent_ctx = await _create_parent_invocation_context (
147- request .function .__name__ , loop_agent
203+ request .function .__name__ , loop_agent , resumable = resumable
148204 )
149205 events = [e async for e in loop_agent .run_async (parent_ctx )]
150206
151- # Only two events are generated because the sub escalating_agent escalates.
152- assert len (events ) == 3
153- assert events [0 ].author == non_escalating_agent .name
154- assert events [1 ].author == escalating_agent .name
155- assert events [0 ].content .parts [0 ].text == (
156- f'Hello, async { non_escalating_agent .name } !'
157- )
158- assert events [1 ].content .parts [0 ].text == (
159- f'Hello, async { escalating_agent .name } !'
160- )
161- assert (
162- events [2 ].content .parts [0 ].text == 'I have done my job after escalation!!'
163- )
207+ simplified_events = testing_utils .simplify_resumable_app_events (events )
208+
209+ if resumable :
210+ expected_events = [
211+ (
212+ loop_agent .name ,
213+ {
214+ 'current_sub_agent' : non_escalating_agent .name ,
215+ 'times_looped' : 0 ,
216+ },
217+ ),
218+ (
219+ non_escalating_agent .name ,
220+ f'Hello, async { non_escalating_agent .name } !' ,
221+ ),
222+ (
223+ loop_agent .name ,
224+ {'current_sub_agent' : escalating_agent .name , 'times_looped' : 0 },
225+ ),
226+ (
227+ escalating_agent .name ,
228+ f'Hello, async { escalating_agent .name } !' ,
229+ ),
230+ (
231+ escalating_agent .name ,
232+ 'I have done my job after escalation!!' ,
233+ ),
234+ (loop_agent .name , END_OF_AGENT ),
235+ ]
236+ else :
237+ expected_events = [
238+ (
239+ non_escalating_agent .name ,
240+ f'Hello, async { non_escalating_agent .name } !' ,
241+ ),
242+ (
243+ escalating_agent .name ,
244+ f'Hello, async { escalating_agent .name } !' ,
245+ ),
246+ (
247+ escalating_agent .name ,
248+ 'I have done my job after escalation!!' ,
249+ ),
250+ ]
251+ assert simplified_events == expected_events
0 commit comments