diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py index a2f9a4eea8..b0ff5001fa 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py @@ -309,7 +309,7 @@ def _clone_chat_agent(self, agent: ChatAgent) -> ChatAgent: name=agent.name, description=agent.description, chat_message_store_factory=agent.chat_message_store_factory, - context_providers=agent.context_provider, + context_provider=agent.context_provider, middleware=middleware, default_options=cloned_options, # type: ignore[arg-type] ) diff --git a/python/packages/orchestrations/tests/test_handoff.py b/python/packages/orchestrations/tests/test_handoff.py index ab9f6e45cb..67a65007f9 100644 --- a/python/packages/orchestrations/tests/test_handoff.py +++ b/python/packages/orchestrations/tests/test_handoff.py @@ -11,6 +11,8 @@ ChatResponse, ChatResponseUpdate, Content, + Context, + ContextProvider, ResponseStream, WorkflowEvent, resolve_agent_id, @@ -303,6 +305,48 @@ async def mock_get_response(messages: Any, options: dict[str, Any] | None = None assert last_tool_choice == {"mode": "required"}, f"Expected 'required', got {last_tool_choice}" +async def test_context_provider_preserved_during_handoff(): + """Verify that context_provider is preserved when cloning agents in handoff workflows.""" + # Track whether context provider methods were called + provider_calls: list[str] = [] + + class TestContextProvider(ContextProvider): + """A test context provider that tracks its invocations.""" + + async def invoking(self, messages: Sequence[ChatMessage], **kwargs: Any) -> Context: + provider_calls.append("invoking") + return Context(instructions="Test context from provider.") + + # Create context provider + context_provider = TestContextProvider() + + # Create a mock chat client + mock_client = MockChatClient(name="test_agent") + + # Create agent with context provider using proper constructor + agent = ChatAgent( + chat_client=mock_client, + name="test_agent", + id="test_agent", + context_provider=context_provider, + ) + + # Verify the original agent has the context provider + assert agent.context_provider is context_provider, "Original agent should have context provider" + + # Build handoff workflow - this should clone the agent and preserve context_provider + workflow = HandoffBuilder(participants=[agent]).with_start_agent(agent).build() + + # Run workflow with a simple message to trigger context provider + await _drain(workflow.run("Test message", stream=True)) + + # Verify context provider was invoked during the workflow execution + assert len(provider_calls) > 0, ( + "Context provider should be called during workflow execution, " + "indicating it was properly preserved during agent cloning" + ) + + # region Participant Factory Tests