Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def _clone_chat_agent(self, agent: ChatAgent) -> ChatAgent:
name=agent.name,
description=agent.description,
chat_message_store_factory=agent.chat_message_store_factory,
context_providers=agent.context_provider,
context_provider=agent.context_provider,
middleware=middleware,
default_options=cloned_options, # type: ignore[arg-type]
)
Expand Down
44 changes: 44 additions & 0 deletions python/packages/orchestrations/tests/test_handoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
ChatResponse,
ChatResponseUpdate,
Content,
Context,
ContextProvider,
ResponseStream,
WorkflowEvent,
resolve_agent_id,
Expand Down Expand Up @@ -303,6 +305,48 @@ async def mock_get_response(messages: Any, options: dict[str, Any] | None = None
assert last_tool_choice == {"mode": "required"}, f"Expected 'required', got {last_tool_choice}"


async def test_context_provider_preserved_during_handoff():
"""Verify that context_provider is preserved when cloning agents in handoff workflows."""
# Track whether context provider methods were called
provider_calls: list[str] = []

class TestContextProvider(ContextProvider):
"""A test context provider that tracks its invocations."""

async def invoking(self, messages: Sequence[ChatMessage], **kwargs: Any) -> Context:
provider_calls.append("invoking")
return Context(instructions="Test context from provider.")

# Create context provider
context_provider = TestContextProvider()

# Create a mock chat client
mock_client = MockChatClient(name="test_agent")

# Create agent with context provider using proper constructor
agent = ChatAgent(
chat_client=mock_client,
name="test_agent",
id="test_agent",
context_provider=context_provider,
)

# Verify the original agent has the context provider
assert agent.context_provider is context_provider, "Original agent should have context provider"

# Build handoff workflow - this should clone the agent and preserve context_provider
workflow = HandoffBuilder(participants=[agent]).with_start_agent(agent).build()

# Run workflow with a simple message to trigger context provider
await _drain(workflow.run("Test message", stream=True))

# Verify context provider was invoked during the workflow execution
assert len(provider_calls) > 0, (
"Context provider should be called during workflow execution, "
"indicating it was properly preserved during agent cloning"
)


# region Participant Factory Tests


Expand Down
Loading