-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Python: Access agent context in as_tool scenarios
#3731
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
764c6b3
e9fafdd
d3f7710
eda6241
d27bfac
08b1ca0
c14fa7a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| # Copyright (c) Microsoft. All rights reserved. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from collections.abc import Iterator | ||
| from contextlib import contextmanager | ||
| from contextvars import ContextVar | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| if TYPE_CHECKING: | ||
| from ._middleware import AgentContext | ||
|
|
||
| __all__ = [ | ||
| "get_current_agent_run_context", | ||
| ] | ||
|
|
||
| _current_agent_run_context: ContextVar[AgentContext | None] = ContextVar("agent_run_context", default=None) | ||
|
|
||
|
|
||
| def get_current_agent_run_context() -> AgentContext | None: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We will also need to add an example that will demonstrate how to use this functionality and how it resolves the problem described in the issue associated with this PR. |
||
| """Get the current agent run context, if any. | ||
|
|
||
| Returns the AgentContext for the currently executing agent run, | ||
| or None if called outside of an agent run. This enables sub-agents | ||
| (invoked as tools) to access their parent agent's run context. | ||
|
|
||
| Returns: | ||
| The current AgentContext, or None if not within an agent run. | ||
|
|
||
| Examples: | ||
| .. code-block:: python | ||
|
|
||
| from agent_framework import get_current_agent_run_context | ||
|
|
||
|
|
||
| @tool | ||
| async def my_tool() -> str: | ||
| parent_ctx = get_current_agent_run_context() | ||
| if parent_ctx and parent_ctx.thread: | ||
| # Access parent's conversation_id | ||
| conv_id = parent_ctx.thread.service_thread_id | ||
| return "done" | ||
| """ | ||
| return _current_agent_run_context.get() | ||
|
|
||
|
|
||
| @contextmanager | ||
| def agent_run_scope(context: AgentContext) -> Iterator[None]: | ||
| """Context manager to set the agent run context for the duration of a block. | ||
|
|
||
| This is used internally by the agent framework to establish the current | ||
| run context. The context is automatically restored when the block exits. | ||
|
|
||
| Args: | ||
| context: The AgentContext to set as current. | ||
| """ | ||
| token = _current_agent_run_context.set(context) | ||
| try: | ||
| yield | ||
| finally: | ||
| _current_agent_run_context.reset(token) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |
| from enum import Enum | ||
| from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, overload | ||
|
|
||
| from ._agent_context import agent_run_scope | ||
| from ._clients import SupportsChatGetResponse | ||
| from ._types import ( | ||
| AgentResponse, | ||
|
|
@@ -1075,6 +1076,38 @@ def _middleware_handler( | |
| ) | ||
|
|
||
|
|
||
| def _wrap_stream_with_context( | ||
| inner_stream: ResponseStream[AgentResponseUpdate, AgentResponse], | ||
| context: AgentContext, | ||
| ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: | ||
| """Wrap a ResponseStream to maintain agent run context during iteration. | ||
|
|
||
| This ensures that `get_current_agent_run_context()` returns the correct context | ||
| when tools are invoked during streaming, including sub-agents wrapped as tools. | ||
|
|
||
| Args: | ||
| inner_stream: The inner ResponseStream to wrap. | ||
| context: The AgentContext to maintain during iteration. | ||
|
|
||
| Returns: | ||
| A new ResponseStream that maintains the context during iteration. | ||
| """ | ||
|
|
||
| async def _iterate_with_context() -> AsyncIterable[AgentResponseUpdate]: | ||
| with agent_run_scope(context): | ||
| async for update in inner_stream: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We shouldn't consume a stream unless we absolutely have to. |
||
| yield update | ||
|
|
||
| async def _finalize_with_context(updates: Sequence[AgentResponseUpdate]) -> AgentResponse: | ||
| with agent_run_scope(context): | ||
| return await inner_stream.get_final_response() | ||
|
|
||
| return ResponseStream( | ||
| _iterate_with_context(), | ||
| finalizer=_finalize_with_context, | ||
| ) | ||
|
|
||
|
|
||
| class AgentMiddlewareLayer: | ||
| """Layer for agents to apply agent middleware around run execution.""" | ||
|
|
||
|
|
@@ -1155,10 +1188,7 @@ def run( | |
| combined_kwargs = dict(kwargs) | ||
| combined_kwargs["middleware"] = combined_function_chat_middleware if combined_function_chat_middleware else None | ||
|
|
||
| # Execute with middleware if available | ||
| if not pipeline.has_middlewares: | ||
| return super().run(messages, stream=stream, thread=thread, options=options, **combined_kwargs) # type: ignore[misc, no-any-return] | ||
|
|
||
| # Always create AgentContext for ambient access (enables sub-agents to access parent context) | ||
| context = AgentContext( | ||
| agent=self, # type: ignore[arg-type] | ||
| messages=prepare_messages(messages), # type: ignore[arg-type] | ||
|
|
@@ -1168,11 +1198,29 @@ def run( | |
| kwargs=combined_kwargs, | ||
| ) | ||
|
|
||
| # Execute without middleware pipeline if none configured | ||
| if not pipeline.has_middlewares: | ||
| if stream: | ||
| # For streaming, wrap to maintain context during iteration | ||
| inner_stream: ResponseStream[AgentResponseUpdate, AgentResponse] = super().run( # type: ignore[misc, assignment] | ||
| messages, stream=True, thread=thread, options=options, **combined_kwargs | ||
| ) | ||
| return _wrap_stream_with_context(inner_stream, context) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks like |
||
|
|
||
| async def _no_middleware_run() -> AgentResponse: | ||
| with agent_run_scope(context): | ||
| return await super(AgentMiddlewareLayer, self).run( # type: ignore[misc, no-any-return] | ||
| messages, stream=False, thread=thread, options=options, **combined_kwargs | ||
| ) | ||
|
|
||
| return _no_middleware_run() | ||
|
|
||
| async def _execute() -> AgentResponse | ResponseStream[AgentResponseUpdate, AgentResponse] | None: | ||
| return await pipeline.execute( | ||
| context=context, | ||
| final_handler=self._middleware_handler, | ||
| ) | ||
| with agent_run_scope(context): | ||
| return await pipeline.execute( | ||
| context=context, | ||
| final_handler=self._middleware_handler, | ||
| ) | ||
|
|
||
| if stream: | ||
| # For streaming, wrap execution in ResponseStream.from_awaitable | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,156 @@ | ||
| # Copyright (c) Microsoft. All rights reserved. | ||
|
|
||
| """Tests for agent run context propagation. | ||
|
|
||
| These tests verify that: | ||
| 1. Agent run context is properly set during agent execution | ||
| 2. Sub-agents can access parent context via get_current_agent_run_context() | ||
| 3. Context is isolated between concurrent agent runs | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| from collections.abc import Awaitable, Callable | ||
| from unittest.mock import AsyncMock | ||
|
|
||
| from agent_framework import ( | ||
| Agent, | ||
| AgentContext, | ||
| ChatResponse, | ||
| Message, | ||
| agent_middleware, | ||
| get_current_agent_run_context, | ||
| ) | ||
| from agent_framework._agent_context import agent_run_scope | ||
|
|
||
| from .conftest import MockChatClient | ||
|
|
||
|
|
||
| class TestAgentContext: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. New unit tests don't test the case where context is retrieved within the tool, as described in PR description. |
||
| """Tests for ambient agent run context.""" | ||
|
|
||
| async def test_context_not_available_outside_run(self) -> None: | ||
| """Test that context is None outside of agent run.""" | ||
| context = get_current_agent_run_context() | ||
| assert context is None | ||
|
|
||
| async def test_context_scope_properly_restored(self) -> None: | ||
| """Test that context is properly restored after scope exits.""" | ||
| mock_agent = AsyncMock() | ||
| mock_agent.name = "test" | ||
|
|
||
| mock_context = AgentContext( | ||
| agent=mock_agent, | ||
| messages=[], | ||
| thread=None, | ||
| options=None, | ||
| stream=False, | ||
| kwargs={}, | ||
| ) | ||
|
|
||
| # Before scope | ||
| assert get_current_agent_run_context() is None | ||
|
|
||
| # Inside scope | ||
| with agent_run_scope(mock_context): | ||
| assert get_current_agent_run_context() is mock_context | ||
|
|
||
| # After scope | ||
| assert get_current_agent_run_context() is None | ||
|
|
||
| async def test_nested_context_scopes(self) -> None: | ||
| """Test that nested context scopes work correctly.""" | ||
| mock_agent1 = AsyncMock() | ||
| mock_agent1.name = "agent1" | ||
| mock_agent2 = AsyncMock() | ||
| mock_agent2.name = "agent2" | ||
|
|
||
| context1 = AgentContext(agent=mock_agent1, messages=[], thread=None, options=None, stream=False, kwargs={}) | ||
| context2 = AgentContext(agent=mock_agent2, messages=[], thread=None, options=None, stream=False, kwargs={}) | ||
|
|
||
| with agent_run_scope(context1): | ||
| assert get_current_agent_run_context() is context1 | ||
|
|
||
| with agent_run_scope(context2): | ||
| assert get_current_agent_run_context() is context2 | ||
|
|
||
| # After inner scope exits, should restore outer context | ||
| assert get_current_agent_run_context() is context1 | ||
|
|
||
| assert get_current_agent_run_context() is None | ||
|
|
||
| async def test_context_isolated_between_concurrent_tasks(self) -> None: | ||
| """Test that context is isolated between concurrent async tasks.""" | ||
| results: dict[str, AgentContext | None] = {} | ||
| mock_agent1 = AsyncMock() | ||
| mock_agent1.name = "agent1" | ||
| mock_agent2 = AsyncMock() | ||
| mock_agent2.name = "agent2" | ||
|
|
||
| context1 = AgentContext(agent=mock_agent1, messages=[], thread=None, options=None, stream=False, kwargs={}) | ||
| context2 = AgentContext(agent=mock_agent2, messages=[], thread=None, options=None, stream=False, kwargs={}) | ||
|
|
||
| async def task1() -> None: | ||
| with agent_run_scope(context1): | ||
| await asyncio.sleep(0.01) # Yield to other task | ||
| results["task1"] = get_current_agent_run_context() | ||
|
|
||
| async def task2() -> None: | ||
| with agent_run_scope(context2): | ||
| await asyncio.sleep(0.01) # Yield to other task | ||
| results["task2"] = get_current_agent_run_context() | ||
|
|
||
| await asyncio.gather(task1(), task2()) | ||
|
|
||
| # Each task should see its own context | ||
| assert results["task1"] is context1 | ||
| assert results["task2"] is context2 | ||
|
|
||
| async def test_context_available_in_middleware(self) -> None: | ||
| """Test that agent run context is available in agent middleware.""" | ||
| chat_client = MockChatClient() | ||
| captured_context: AgentContext | None = None | ||
|
|
||
| @agent_middleware | ||
| async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: | ||
| nonlocal captured_context | ||
| # Get ambient context - should be the same as the passed context | ||
| captured_context = get_current_agent_run_context() | ||
| await call_next() | ||
|
|
||
| chat_client.responses = [ | ||
| ChatResponse(messages=[Message("assistant", ["Response"])]), | ||
| ] | ||
|
|
||
| agent = Agent(client=chat_client, name="test_agent", middleware=[capture_middleware]) | ||
| await agent.run("Test message") | ||
|
|
||
| assert captured_context is not None | ||
| assert captured_context.agent is agent | ||
|
|
||
| async def test_context_available_in_streaming_middleware(self) -> None: | ||
| """Test that agent run context is available in middleware during streaming.""" | ||
| chat_client = MockChatClient() | ||
| captured_context: AgentContext | None = None | ||
|
|
||
| @agent_middleware | ||
| async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: | ||
| nonlocal captured_context | ||
| captured_context = get_current_agent_run_context() | ||
| await call_next() | ||
|
|
||
| chat_client.responses = [ | ||
| ChatResponse(messages=[Message("assistant", ["Streaming response"])]), | ||
| ] | ||
|
|
||
| agent = Agent(client=chat_client, name="test_agent", middleware=[capture_middleware]) | ||
|
|
||
| # Run with streaming and consume the response | ||
| async for _update in agent.run("Test message", stream=True): | ||
| pass | ||
|
|
||
| # Context should have been available in middleware | ||
| assert captured_context is not None | ||
| assert captured_context.agent is agent | ||
| assert captured_context.stream is True | ||
Uh oh!
There was an error while loading. Please reload this page.