diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 7ded0af8d5..b000d4d41d 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -337,6 +337,7 @@ async def _notify_thread_of_new_messages( thread: AgentThread, input_messages: ChatMessage | Sequence[ChatMessage], response_messages: ChatMessage | Sequence[ChatMessage], + **kwargs: Any, ) -> None: """Notify the thread of new messages. @@ -346,13 +347,14 @@ async def _notify_thread_of_new_messages( thread: The thread to notify of new messages. input_messages: The input messages to notify about. response_messages: The response messages to notify about. + **kwargs: Any extra arguments to pass from the agent run. """ if isinstance(input_messages, ChatMessage) or len(input_messages) > 0: await thread.on_new_messages(input_messages) if isinstance(response_messages, ChatMessage) or len(response_messages) > 0: await thread.on_new_messages(response_messages) if thread.context_provider: - await thread.context_provider.invoked(input_messages, response_messages) + await thread.context_provider.invoked(input_messages, response_messages, **kwargs) @property def display_name(self) -> str: @@ -969,7 +971,7 @@ async def run_stream( """ input_messages = self._normalize_messages(messages) thread, run_chat_options, thread_messages = await self._prepare_thread_and_messages( - thread=thread, input_messages=input_messages + thread=thread, input_messages=input_messages, **kwargs ) agent_name = self._get_agent_name() # Resolve final tool list (runtime provided tools + local MCP server tools) @@ -1039,7 +1041,7 @@ async def run_stream( response = ChatResponse.from_chat_response_updates(response_updates, output_format_type=co.response_format) await self._update_thread_with_type_and_conversation_id(thread, response.conversation_id) - await self._notify_thread_of_new_messages(thread, input_messages, response.messages) + await self._notify_thread_of_new_messages(thread, input_messages, response.messages, **kwargs) @override def get_new_thread( @@ -1234,6 +1236,7 @@ async def _prepare_thread_and_messages( *, thread: AgentThread | None, input_messages: list[ChatMessage] | None = None, + **kwargs: Any, ) -> tuple[AgentThread, ChatOptions, list[ChatMessage]]: """Prepare the thread and messages for agent execution. @@ -1243,6 +1246,7 @@ async def _prepare_thread_and_messages( Keyword Args: thread: The conversation thread. input_messages: Messages to process. + **kwargs: Any extra arguments to pass from the agent run. Returns: A tuple containing: @@ -1263,7 +1267,7 @@ async def _prepare_thread_and_messages( context: Context | None = None if self.context_provider: async with self.context_provider: - context = await self.context_provider.invoking(input_messages or []) + context = await self.context_provider.invoking(input_messages or [], **kwargs) if context: if context.messages: thread_messages.extend(context.messages)