@@ -306,30 +306,12 @@ async def _run_with_trace(
306306 if not session :
307307 raise ValueError (f'Session not found: { session_id } ' )
308308
309- invocation_context = self ._new_invocation_context (
310- session ,
309+ invocation_context = await self ._setup_context_for_new_invocation (
310+ session = session ,
311311 new_message = new_message ,
312312 run_config = run_config ,
313+ state_delta = state_delta ,
313314 )
314- root_agent = self .agent
315-
316- # Modify user message before execution.
317- modified_user_message = await invocation_context .plugin_manager .run_on_user_message_callback (
318- invocation_context = invocation_context , user_message = new_message
319- )
320- if modified_user_message is not None :
321- new_message = modified_user_message
322-
323- if new_message :
324- await self ._append_new_message_to_session (
325- session ,
326- new_message ,
327- invocation_context ,
328- run_config .save_input_blobs_as_artifacts ,
329- state_delta ,
330- )
331-
332- invocation_context .agent = self ._find_agent_to_run (session , root_agent )
333315
334316 async def execute (ctx : InvocationContext ) -> AsyncGenerator [Event ]:
335317 async with Aclosing (ctx .agent .run_async (ctx )) as agen :
@@ -420,6 +402,7 @@ async def _exec_with_plugin(
420402
421403 async def _append_new_message_to_session (
422404 self ,
405+ * ,
423406 session : Session ,
424407 new_message : types .Content ,
425408 invocation_context : InvocationContext ,
@@ -433,6 +416,7 @@ async def _append_new_message_to_session(
433416 new_message: The new message to append.
434417 invocation_context: The invocation context for the message.
435418 save_input_blobs_as_artifacts: Whether to save input blobs as artifacts.
419+ state_delta: Optional state changes to apply to the session.
436420 """
437421 if not new_message .parts :
438422 raise ValueError ('No parts in the new_message.' )
@@ -661,6 +645,44 @@ def _is_transferable_across_agent_tree(self, agent_to_run: BaseAgent) -> bool:
661645 agent = agent .parent_agent
662646 return True
663647
648+ async def _setup_context_for_new_invocation (
649+ self ,
650+ * ,
651+ session : Session ,
652+ new_message : types .Content ,
653+ run_config : RunConfig ,
654+ state_delta : Optional [dict [str , Any ]],
655+ ) -> InvocationContext :
656+ """Sets up the context for a new invocation.
657+
658+ Args:
659+ session: The session to setup the invocation context for.
660+ new_message: The new message to process and append to the session.
661+ run_config: The run config of the agent.
662+ state_delta: Optional state changes to apply to the session.
663+
664+ Returns:
665+ The invocation context for the new invocation.
666+ """
667+ # Step 1: Create invocation context in memory.
668+ invocation_context = self ._new_invocation_context (
669+ session ,
670+ new_message = new_message ,
671+ run_config = run_config ,
672+ )
673+ # Step 2: Handle new message, by running callbacks and appending to
674+ # session.
675+ await self ._handle_new_message (
676+ session = session ,
677+ new_message = new_message ,
678+ invocation_context = invocation_context ,
679+ run_config = run_config ,
680+ state_delta = state_delta ,
681+ )
682+ # Step 3: Set agent to run for the invocation.
683+ invocation_context .agent = self ._find_agent_to_run (session , self .agent )
684+ return invocation_context
685+
664686 def _new_invocation_context (
665687 self ,
666688 session : Session ,
@@ -743,6 +765,42 @@ def _new_invocation_context_for_live(
743765 run_config = run_config ,
744766 )
745767
768+ async def _handle_new_message (
769+ self ,
770+ * ,
771+ session : Session ,
772+ new_message : types .Content ,
773+ invocation_context : InvocationContext ,
774+ run_config : RunConfig ,
775+ state_delta : Optional [dict [str , Any ]],
776+ ) -> None :
777+ """Handles a new message by running callbacks and appending to session.
778+
779+ Args:
780+ session: The session of the new message.
781+ new_message: The new message to process and append to the session.
782+ invocation_context: The invocation context to use for the message
783+ handling.
784+ run_config: The run config of the agent.
785+ state_delta: Optional state changes to apply to the session.
786+ """
787+ modified_user_message = (
788+ await invocation_context .plugin_manager .run_on_user_message_callback (
789+ invocation_context = invocation_context , user_message = new_message
790+ )
791+ )
792+ if modified_user_message is not None :
793+ new_message = modified_user_message
794+
795+ if new_message :
796+ await self ._append_new_message_to_session (
797+ session = session ,
798+ new_message = new_message ,
799+ invocation_context = invocation_context ,
800+ save_input_blobs_as_artifacts = run_config .save_input_blobs_as_artifacts ,
801+ state_delta = state_delta ,
802+ )
803+
746804 def _collect_toolset (self , agent : BaseAgent ) -> set [BaseToolset ]:
747805 toolsets = set ()
748806 if isinstance (agent , LlmAgent ):
0 commit comments