Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/packages/core/agent_framework/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
_version = "0.0.0" # Fallback for development mode
__version__: Final[str] = _version

from ._agent_context import * # noqa: F403
from ._agents import * # noqa: F403
from ._clients import * # noqa: F403
from ._logging import * # noqa: F403
Expand Down
61 changes: 61 additions & 0 deletions python/packages/core/agent_framework/_agent_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) Microsoft. All rights reserved.

from __future__ import annotations

from collections.abc import Iterator
from contextlib import contextmanager
from contextvars import ContextVar
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from ._middleware import AgentContext

__all__ = [
"get_current_agent_run_context",
]

_current_agent_run_context: ContextVar[AgentContext | None] = ContextVar("agent_run_context", default=None)


def get_current_agent_run_context() -> AgentContext | None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will also need to add an example that will demonstrate how to use this functionality and how it resolves the problem described in the issue associated with this PR.

"""Get the current agent run context, if any.

Returns the AgentContext for the currently executing agent run,
or None if called outside of an agent run. This enables sub-agents
(invoked as tools) to access their parent agent's run context.

Returns:
The current AgentContext, or None if not within an agent run.

Examples:
.. code-block:: python

from agent_framework import get_current_agent_run_context


@tool
async def my_tool() -> str:
parent_ctx = get_current_agent_run_context()
if parent_ctx and parent_ctx.thread:
# Access parent's conversation_id
conv_id = parent_ctx.thread.service_thread_id
return "done"
"""
return _current_agent_run_context.get()


@contextmanager
def agent_run_scope(context: AgentContext) -> Iterator[None]:
"""Context manager to set the agent run context for the duration of a block.

This is used internally by the agent framework to establish the current
run context. The context is automatically restored when the block exits.

Args:
context: The AgentContext to set as current.
"""
token = _current_agent_run_context.set(context)
try:
yield
finally:
_current_agent_run_context.reset(token)
16 changes: 16 additions & 0 deletions python/packages/core/agent_framework/_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,14 @@ async def _run_non_streaming() -> AgentResponse[Any]:
options=options,
kwargs=kwargs,
)

# Update ambient context with resolved thread for sub-agent conversation_id inheritance
from ._agent_context import get_current_agent_run_context

parent_context = get_current_agent_run_context()
if parent_context is not None and parent_context.thread is None:
parent_context.thread = ctx["thread"]

response = await self.client.get_response( # type: ignore[call-overload]
messages=ctx["thread_messages"],
stream=False,
Expand Down Expand Up @@ -948,6 +956,14 @@ async def _get_stream() -> ResponseStream[ChatResponseUpdate, ChatResponse]:
kwargs=kwargs,
)
ctx: _RunContext = ctx_holder["ctx"] # type: ignore[assignment] # Safe: we just assigned it

# Update ambient context with resolved thread for sub-agent conversation_id inheritance
from ._agent_context import get_current_agent_run_context

parent_context = get_current_agent_run_context()
if parent_context is not None and parent_context.thread is None:
parent_context.thread = ctx["thread"]

return self.client.get_response( # type: ignore[call-overload, no-any-return]
messages=ctx["thread_messages"],
stream=True,
Expand Down
64 changes: 56 additions & 8 deletions python/packages/core/agent_framework/_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from enum import Enum
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, overload

from ._agent_context import agent_run_scope
from ._clients import SupportsChatGetResponse
from ._types import (
AgentResponse,
Expand Down Expand Up @@ -1075,6 +1076,38 @@ def _middleware_handler(
)


def _wrap_stream_with_context(
inner_stream: ResponseStream[AgentResponseUpdate, AgentResponse],
context: AgentContext,
) -> ResponseStream[AgentResponseUpdate, AgentResponse]:
"""Wrap a ResponseStream to maintain agent run context during iteration.

This ensures that `get_current_agent_run_context()` returns the correct context
when tools are invoked during streaming, including sub-agents wrapped as tools.

Args:
inner_stream: The inner ResponseStream to wrap.
context: The AgentContext to maintain during iteration.

Returns:
A new ResponseStream that maintains the context during iteration.
"""

async def _iterate_with_context() -> AsyncIterable[AgentResponseUpdate]:
with agent_run_scope(context):
async for update in inner_stream:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't consume a stream unless we absolutely have to.

yield update

async def _finalize_with_context(updates: Sequence[AgentResponseUpdate]) -> AgentResponse:
with agent_run_scope(context):
return await inner_stream.get_final_response()

return ResponseStream(
_iterate_with_context(),
finalizer=_finalize_with_context,
)


class AgentMiddlewareLayer:
"""Layer for agents to apply agent middleware around run execution."""

Expand Down Expand Up @@ -1155,10 +1188,7 @@ def run(
combined_kwargs = dict(kwargs)
combined_kwargs["middleware"] = combined_function_chat_middleware if combined_function_chat_middleware else None

# Execute with middleware if available
if not pipeline.has_middlewares:
return super().run(messages, stream=stream, thread=thread, options=options, **combined_kwargs) # type: ignore[misc, no-any-return]

# Always create AgentContext for ambient access (enables sub-agents to access parent context)
context = AgentContext(
agent=self, # type: ignore[arg-type]
messages=prepare_messages(messages), # type: ignore[arg-type]
Expand All @@ -1168,11 +1198,29 @@ def run(
kwargs=combined_kwargs,
)

# Execute without middleware pipeline if none configured
if not pipeline.has_middlewares:
if stream:
# For streaming, wrap to maintain context during iteration
inner_stream: ResponseStream[AgentResponseUpdate, AgentResponse] = super().run( # type: ignore[misc, assignment]
messages, stream=True, thread=thread, options=options, **combined_kwargs
)
return _wrap_stream_with_context(inner_stream, context)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like _wrap_stream_with_context function is used only when the pipeline doesn't have middleware registered, but how about the case where middleware exists, should we add wrapping there as well?


async def _no_middleware_run() -> AgentResponse:
with agent_run_scope(context):
return await super(AgentMiddlewareLayer, self).run( # type: ignore[misc, no-any-return]
messages, stream=False, thread=thread, options=options, **combined_kwargs
)

return _no_middleware_run()

async def _execute() -> AgentResponse | ResponseStream[AgentResponseUpdate, AgentResponse] | None:
return await pipeline.execute(
context=context,
final_handler=self._middleware_handler,
)
with agent_run_scope(context):
return await pipeline.execute(
context=context,
final_handler=self._middleware_handler,
)

if stream:
# For streaming, wrap execution in ResponseStream.from_awaitable
Expand Down
10 changes: 10 additions & 0 deletions python/packages/core/agent_framework/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,6 +1509,9 @@ def _update_conversation_id(
) -> None:
"""Update kwargs and options with conversation id.

Also updates the ambient agent run context's thread if available,
enabling sub-agents (invoked as tools) to inherit the conversation_id.

Args:
kwargs: The keyword arguments dictionary to update.
conversation_id: The conversation ID to set, or None to skip.
Expand All @@ -1525,6 +1528,13 @@ def _update_conversation_id(
if options is not None:
options["conversation_id"] = conversation_id

# Update the ambient context's thread so sub-agents can inherit the conversation_id
from ._agent_context import get_current_agent_run_context

parent_context = get_current_agent_run_context()
if parent_context and parent_context.thread:
parent_context.thread.service_thread_id = conversation_id


async def _ensure_response_stream(
stream_like: ResponseStream[Any, Any] | Awaitable[ResponseStream[Any, Any]],
Expand Down
156 changes: 156 additions & 0 deletions python/packages/core/tests/core/test_agent_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
# Copyright (c) Microsoft. All rights reserved.

"""Tests for agent run context propagation.

These tests verify that:
1. Agent run context is properly set during agent execution
2. Sub-agents can access parent context via get_current_agent_run_context()
3. Context is isolated between concurrent agent runs
"""

from __future__ import annotations

import asyncio
from collections.abc import Awaitable, Callable
from unittest.mock import AsyncMock

from agent_framework import (
Agent,
AgentContext,
ChatResponse,
Message,
agent_middleware,
get_current_agent_run_context,
)
from agent_framework._agent_context import agent_run_scope

from .conftest import MockChatClient


class TestAgentContext:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New unit tests don't test the case where context is retrieved within the tool, as described in PR description.

"""Tests for ambient agent run context."""

async def test_context_not_available_outside_run(self) -> None:
"""Test that context is None outside of agent run."""
context = get_current_agent_run_context()
assert context is None

async def test_context_scope_properly_restored(self) -> None:
"""Test that context is properly restored after scope exits."""
mock_agent = AsyncMock()
mock_agent.name = "test"

mock_context = AgentContext(
agent=mock_agent,
messages=[],
thread=None,
options=None,
stream=False,
kwargs={},
)

# Before scope
assert get_current_agent_run_context() is None

# Inside scope
with agent_run_scope(mock_context):
assert get_current_agent_run_context() is mock_context

# After scope
assert get_current_agent_run_context() is None

async def test_nested_context_scopes(self) -> None:
"""Test that nested context scopes work correctly."""
mock_agent1 = AsyncMock()
mock_agent1.name = "agent1"
mock_agent2 = AsyncMock()
mock_agent2.name = "agent2"

context1 = AgentContext(agent=mock_agent1, messages=[], thread=None, options=None, stream=False, kwargs={})
context2 = AgentContext(agent=mock_agent2, messages=[], thread=None, options=None, stream=False, kwargs={})

with agent_run_scope(context1):
assert get_current_agent_run_context() is context1

with agent_run_scope(context2):
assert get_current_agent_run_context() is context2

# After inner scope exits, should restore outer context
assert get_current_agent_run_context() is context1

assert get_current_agent_run_context() is None

async def test_context_isolated_between_concurrent_tasks(self) -> None:
"""Test that context is isolated between concurrent async tasks."""
results: dict[str, AgentContext | None] = {}
mock_agent1 = AsyncMock()
mock_agent1.name = "agent1"
mock_agent2 = AsyncMock()
mock_agent2.name = "agent2"

context1 = AgentContext(agent=mock_agent1, messages=[], thread=None, options=None, stream=False, kwargs={})
context2 = AgentContext(agent=mock_agent2, messages=[], thread=None, options=None, stream=False, kwargs={})

async def task1() -> None:
with agent_run_scope(context1):
await asyncio.sleep(0.01) # Yield to other task
results["task1"] = get_current_agent_run_context()

async def task2() -> None:
with agent_run_scope(context2):
await asyncio.sleep(0.01) # Yield to other task
results["task2"] = get_current_agent_run_context()

await asyncio.gather(task1(), task2())

# Each task should see its own context
assert results["task1"] is context1
assert results["task2"] is context2

async def test_context_available_in_middleware(self) -> None:
"""Test that agent run context is available in agent middleware."""
chat_client = MockChatClient()
captured_context: AgentContext | None = None

@agent_middleware
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
nonlocal captured_context
# Get ambient context - should be the same as the passed context
captured_context = get_current_agent_run_context()
await call_next()

chat_client.responses = [
ChatResponse(messages=[Message("assistant", ["Response"])]),
]

agent = Agent(client=chat_client, name="test_agent", middleware=[capture_middleware])
await agent.run("Test message")

assert captured_context is not None
assert captured_context.agent is agent

async def test_context_available_in_streaming_middleware(self) -> None:
"""Test that agent run context is available in middleware during streaming."""
chat_client = MockChatClient()
captured_context: AgentContext | None = None

@agent_middleware
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
nonlocal captured_context
captured_context = get_current_agent_run_context()
await call_next()

chat_client.responses = [
ChatResponse(messages=[Message("assistant", ["Streaming response"])]),
]

agent = Agent(client=chat_client, name="test_agent", middleware=[capture_middleware])

# Run with streaming and consume the response
async for _update in agent.run("Test message", stream=True):
pass

# Context should have been available in middleware
assert captured_context is not None
assert captured_context.agent is agent
assert captured_context.stream is True
Loading