From 44920aa35e9c0144b9deb7e3e4488e9adda345de Mon Sep 17 00:00:00 2001 From: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Date: Thu, 5 Feb 2026 16:15:24 -0800 Subject: [PATCH 1/2] Renamed AgentRunContext to AgentContext --- python/packages/core/AGENTS.md | 6 +- .../core/agent_framework/_middleware.py | 44 ++--- .../core/agent_framework/_serialization.py | 10 +- .../core/test_as_tool_kwargs_propagation.py | 30 +-- .../core/tests/core/test_middleware.py | 182 +++++++----------- .../core/test_middleware_context_result.py | 52 ++--- .../tests/core/test_middleware_with_agent.py | 92 +++------ .../agent_framework_purview/_middleware.py | 6 +- .../packages/purview/tests/test_middleware.py | 50 ++--- python/samples/concepts/tools/README.md | 4 +- .../getting_started/middleware/README.md | 2 +- .../agent_and_run_level_middleware.py | 16 +- .../middleware/class_based_middleware.py | 10 +- .../middleware/decorator_middleware.py | 4 +- .../middleware/function_based_middleware.py | 6 +- .../middleware/middleware_termination.py | 10 +- .../override_result_with_middleware.py | 6 +- .../middleware/thread_behavior_middleware.py | 8 +- 18 files changed, 219 insertions(+), 319 deletions(-) diff --git a/python/packages/core/AGENTS.md b/python/packages/core/AGENTS.md index 946d077c8b..440bf1b3d0 100644 --- a/python/packages/core/AGENTS.md +++ b/python/packages/core/AGENTS.md @@ -55,7 +55,7 @@ agent_framework/ - **`AgentMiddleware`** - Intercepts agent `run()` calls - **`ChatMiddleware`** - Intercepts chat client `get_response()` calls - **`FunctionMiddleware`** - Intercepts function/tool invocations -- **`AgentRunContext`** / **`ChatContext`** / **`FunctionInvocationContext`** - Context objects passed through middleware +- **`AgentContext`** / **`ChatContext`** / **`FunctionInvocationContext`** - Context objects passed through middleware ### Threads (`_threads.py`) @@ -114,10 +114,10 @@ agent = OpenAIChatClient().as_agent( ### Middleware Pipeline ```python -from agent_framework import ChatAgent, AgentMiddleware, AgentRunContext +from agent_framework import ChatAgent, AgentMiddleware, AgentContext class LoggingMiddleware(AgentMiddleware): - async def invoke(self, context: AgentRunContext, next) -> AgentResponse: + async def invoke(self, context: AgentContext, next) -> AgentResponse: print(f"Input: {context.messages}") response = await next(context) print(f"Output: {response}") diff --git a/python/packages/core/agent_framework/_middleware.py b/python/packages/core/agent_framework/_middleware.py index 44a55b13b3..7f6619570e 100644 --- a/python/packages/core/agent_framework/_middleware.py +++ b/python/packages/core/agent_framework/_middleware.py @@ -43,10 +43,10 @@ TResponseModelT = TypeVar("TResponseModelT", bound=BaseModel) __all__ = [ + "AgentContext", "AgentMiddleware", "AgentMiddlewareLayer", "AgentMiddlewareTypes", - "AgentRunContext", "ChatAndFunctionMiddlewareTypes", "ChatContext", "ChatMiddleware", @@ -109,7 +109,7 @@ class MiddlewareType(str, Enum): CHAT = "chat" -class AgentRunContext: +class AgentContext: """Context object for agent middleware invocations. This context is passed through the agent middleware pipeline and contains all information @@ -131,11 +131,11 @@ class AgentRunContext: Examples: .. code-block:: python - from agent_framework import AgentMiddleware, AgentRunContext + from agent_framework import AgentMiddleware, AgentContext class LoggingMiddleware(AgentMiddleware): - async def process(self, context: AgentRunContext, next): + async def process(self, context: AgentContext, next): print(f"Agent: {context.agent.name}") print(f"Messages: {len(context.messages)}") print(f"Thread: {context.thread}") @@ -170,7 +170,7 @@ def __init__( | None = None, stream_cleanup_hooks: Sequence[Callable[[], Awaitable[None] | None]] | None = None, ) -> None: - """Initialize the AgentRunContext. + """Initialize the AgentContext. Args: agent: The agent being invoked. @@ -356,14 +356,14 @@ class AgentMiddleware(ABC): Examples: .. code-block:: python - from agent_framework import AgentMiddleware, AgentRunContext, ChatAgent + from agent_framework import AgentMiddleware, AgentContext, ChatAgent class RetryMiddleware(AgentMiddleware): def __init__(self, max_retries: int = 3): self.max_retries = max_retries - async def process(self, context: AgentRunContext, next): + async def process(self, context: AgentContext, next): for attempt in range(self.max_retries): await next(context) if context.result and not context.result.is_error: @@ -378,8 +378,8 @@ async def process(self, context: AgentRunContext, next): @abstractmethod async def process( self, - context: AgentRunContext, - next: Callable[[AgentRunContext], Awaitable[None]], + context: AgentContext, + next: Callable[[AgentContext], Awaitable[None]], ) -> None: """Process an agent invocation. @@ -531,7 +531,7 @@ async def process( # Pure function type definitions for convenience -AgentMiddlewareCallable = Callable[[AgentRunContext, Callable[[AgentRunContext], Awaitable[None]]], Awaitable[None]] +AgentMiddlewareCallable = Callable[[AgentContext, Callable[[AgentContext], Awaitable[None]]], Awaitable[None]] AgentMiddlewareTypes: TypeAlias = AgentMiddleware | AgentMiddlewareCallable FunctionMiddlewareCallable = Callable[ @@ -561,7 +561,7 @@ def agent_middleware(func: AgentMiddlewareCallable) -> AgentMiddlewareCallable: """Decorator to mark a function as agent middleware. This decorator explicitly identifies a function as agent middleware, - which processes AgentRunContext objects. + which processes AgentContext objects. Args: func: The middleware function to mark as agent middleware. @@ -572,11 +572,11 @@ def agent_middleware(func: AgentMiddlewareCallable) -> AgentMiddlewareCallable: Examples: .. code-block:: python - from agent_framework import agent_middleware, AgentRunContext, ChatAgent + from agent_framework import agent_middleware, AgentContext, ChatAgent @agent_middleware - async def logging_middleware(context: AgentRunContext, next): + async def logging_middleware(context: AgentContext, next): print(f"Before: {context.agent.name}") await next(context) print(f"After: {context.result}") @@ -752,9 +752,9 @@ def _register_middleware(self, middleware: AgentMiddlewareTypes) -> None: async def execute( self, - context: AgentRunContext, + context: AgentContext, final_handler: Callable[ - [AgentRunContext], Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse] + [AgentContext], Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse] ], ) -> AgentResponse | ResponseStream[AgentResponseUpdate, AgentResponse] | None: """Execute the agent middleware pipeline for streaming or non-streaming. @@ -772,17 +772,17 @@ async def execute( context.result = await context.result return context.result - def create_next_handler(index: int) -> Callable[[AgentRunContext], Awaitable[None]]: + def create_next_handler(index: int) -> Callable[[AgentContext], Awaitable[None]]: if index >= len(self._middleware): - async def final_wrapper(c: AgentRunContext) -> None: + async def final_wrapper(c: AgentContext) -> None: c.result = final_handler(c) # type: ignore[assignment] if inspect.isawaitable(c.result): c.result = await c.result return final_wrapper - async def current_handler(c: AgentRunContext) -> None: + async def current_handler(c: AgentContext) -> None: # MiddlewareTermination bubbles up to execute() to skip post-processing await self._middleware[index].process(c, create_next_handler(index + 1)) @@ -1161,7 +1161,7 @@ def run( if not pipeline.has_middlewares: return super().run(messages, stream=stream, thread=thread, options=options, **combined_kwargs) # type: ignore[misc, no-any-return] - context = AgentRunContext( + context = AgentContext( agent=self, # type: ignore[arg-type] messages=prepare_messages(messages), # type: ignore[arg-type] thread=thread, @@ -1194,7 +1194,7 @@ async def _execute_stream() -> ResponseStream[AgentResponseUpdate, AgentResponse return _execute() # type: ignore[return-value] def _middleware_handler( - self, context: AgentRunContext + self, context: AgentContext ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: return super().run( # type: ignore[misc, no-any-return] context.messages, @@ -1231,7 +1231,7 @@ def _determine_middleware_type(middleware: Any) -> MiddlewareType: first_param = params[0] if hasattr(first_param.annotation, "__name__"): annotation_name = first_param.annotation.__name__ - if annotation_name == "AgentRunContext": + if annotation_name == "AgentContext": param_type = MiddlewareType.AGENT elif annotation_name == "FunctionInvocationContext": param_type = MiddlewareType.FUNCTION @@ -1270,7 +1270,7 @@ def _determine_middleware_type(middleware: Any) -> MiddlewareType: raise MiddlewareException( f"Cannot determine middleware type for function {middleware.__name__}. " f"Please either use @agent_middleware/@function_middleware/@chat_middleware decorators " - f"or specify parameter types (AgentRunContext, FunctionInvocationContext, or ChatContext)." + f"or specify parameter types (AgentContext, FunctionInvocationContext, or ChatContext)." ) diff --git a/python/packages/core/agent_framework/_serialization.py b/python/packages/core/agent_framework/_serialization.py index 0e9a34fed4..dd6b8f871f 100644 --- a/python/packages/core/agent_framework/_serialization.py +++ b/python/packages/core/agent_framework/_serialization.py @@ -477,12 +477,12 @@ async def get_current_weather(location: Annotated[str, "The city name"]) -> str: .. code-block:: python - from agent_framework._middleware import AgentRunContext + from agent_framework._middleware import AgentContext from agent_framework import BaseAgent - # AgentRunContext has INJECTABLE = {"agent", "result"} + # AgentContext has INJECTABLE = {"agent", "result"} context_data = { - "type": "agent_run_context", + "type": "agent_context", "messages": [{"role": "user", "text": "Hello"}], "stream": False, "metadata": {"session_id": "abc123"}, @@ -492,14 +492,14 @@ async def get_current_weather(location: Annotated[str, "The city name"]) -> str: # Inject agent and result during middleware processing my_agent = BaseAgent(name="test-agent") dependencies = { - "agent_run_context": { + "agent_context": { "agent": my_agent, "result": None, # Will be populated during execution } } # Reconstruct context with agent dependency for middleware chain - context = AgentRunContext.from_dict(context_data, dependencies=dependencies) + context = AgentContext.from_dict(context_data, dependencies=dependencies) # MiddlewareTypes can now access context.agent and process the execution This injection system allows the agent framework to maintain clean separation diff --git a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py index 8d262a5c23..8a2c4ceb5b 100644 --- a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py +++ b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py @@ -6,7 +6,7 @@ from typing import Any from agent_framework import ChatAgent, ChatMessage, ChatResponse, Content, agent_middleware -from agent_framework._middleware import AgentRunContext +from agent_framework._middleware import AgentContext from .conftest import MockChatClient @@ -19,9 +19,7 @@ async def test_as_tool_forwards_runtime_kwargs(self, chat_client: MockChatClient captured_kwargs: dict[str, Any] = {} @agent_middleware - async def capture_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Capture kwargs passed to the sub-agent captured_kwargs.update(context.kwargs) await next(context) @@ -62,9 +60,7 @@ async def test_as_tool_excludes_arg_name_from_forwarded_kwargs(self, chat_client captured_kwargs: dict[str, Any] = {} @agent_middleware - async def capture_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: captured_kwargs.update(context.kwargs) await next(context) @@ -99,9 +95,7 @@ async def test_as_tool_nested_delegation_propagates_kwargs(self, chat_client: Mo captured_kwargs_list: list[dict[str, Any]] = [] @agent_middleware - async def capture_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Capture kwargs at each level captured_kwargs_list.append(dict(context.kwargs)) await next(context) @@ -162,9 +156,7 @@ async def test_as_tool_streaming_mode_forwards_kwargs(self, chat_client: MockCha captured_kwargs: dict[str, Any] = {} @agent_middleware - async def capture_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: captured_kwargs.update(context.kwargs) await next(context) @@ -224,9 +216,7 @@ async def test_as_tool_kwargs_with_chat_options(self, chat_client: MockChatClien captured_kwargs: dict[str, Any] = {} @agent_middleware - async def capture_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: captured_kwargs.update(context.kwargs) await next(context) @@ -266,9 +256,7 @@ async def test_as_tool_kwargs_isolated_per_invocation(self, chat_client: MockCha call_count = 0 @agent_middleware - async def capture_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: nonlocal call_count call_count += 1 if call_count == 1: @@ -318,9 +306,7 @@ async def test_as_tool_excludes_conversation_id_from_forwarded_kwargs(self, chat captured_kwargs: dict[str, Any] = {} @agent_middleware - async def capture_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: captured_kwargs.update(context.kwargs) await next(context) diff --git a/python/packages/core/tests/core/test_middleware.py b/python/packages/core/tests/core/test_middleware.py index f6a0267500..e6403fa2e2 100644 --- a/python/packages/core/tests/core/test_middleware.py +++ b/python/packages/core/tests/core/test_middleware.py @@ -18,9 +18,9 @@ ResponseStream, ) from agent_framework._middleware import ( + AgentContext, AgentMiddleware, AgentMiddlewarePipeline, - AgentRunContext, ChatContext, ChatMiddleware, ChatMiddlewarePipeline, @@ -32,13 +32,13 @@ from agent_framework._tools import FunctionTool -class TestAgentRunContext: - """Test cases for AgentRunContext.""" +class TestAgentContext: + """Test cases for AgentContext.""" def test_init_with_defaults(self, mock_agent: AgentProtocol) -> None: - """Test AgentRunContext initialization with default values.""" + """Test AgentContext initialization with default values.""" messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) assert context.agent is mock_agent assert context.messages == messages @@ -46,10 +46,10 @@ def test_init_with_defaults(self, mock_agent: AgentProtocol) -> None: assert context.metadata == {} def test_init_with_custom_values(self, mock_agent: AgentProtocol) -> None: - """Test AgentRunContext initialization with custom values.""" + """Test AgentContext initialization with custom values.""" messages = [ChatMessage(role="user", text="test")] metadata = {"key": "value"} - context = AgentRunContext(agent=mock_agent, messages=messages, stream=True, metadata=metadata) + context = AgentContext(agent=mock_agent, messages=messages, stream=True, metadata=metadata) assert context.agent is mock_agent assert context.messages == messages @@ -57,12 +57,12 @@ def test_init_with_custom_values(self, mock_agent: AgentProtocol) -> None: assert context.metadata == metadata def test_init_with_thread(self, mock_agent: AgentProtocol) -> None: - """Test AgentRunContext initialization with thread parameter.""" + """Test AgentContext initialization with thread parameter.""" from agent_framework import AgentThread messages = [ChatMessage(role="user", text="test")] thread = AgentThread() - context = AgentRunContext(agent=mock_agent, messages=messages, thread=thread) + context = AgentContext(agent=mock_agent, messages=messages, thread=thread) assert context.agent is mock_agent assert context.messages == messages @@ -135,11 +135,11 @@ class TestAgentMiddlewarePipeline: """Test cases for AgentMiddlewarePipeline.""" class PreNextTerminateMiddleware(AgentMiddleware): - async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: raise MiddlewareTermination class PostNextTerminateMiddleware(AgentMiddleware): - async def process(self, context: AgentRunContext, next: Any) -> None: + async def process(self, context: AgentContext, next: Any) -> None: await next(context) raise MiddlewareTermination @@ -157,7 +157,7 @@ def test_init_with_class_middleware(self) -> None: def test_init_with_function_middleware(self) -> None: """Test AgentMiddlewarePipeline initialization with function-based middleware.""" - async def test_middleware(context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + async def test_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: await next(context) pipeline = AgentMiddlewarePipeline(test_middleware) @@ -167,11 +167,11 @@ async def test_execute_no_middleware(self, mock_agent: AgentProtocol) -> None: """Test pipeline execution with no middleware.""" pipeline = AgentMiddlewarePipeline() messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) expected_response = AgentResponse(messages=[ChatMessage(role="assistant", text="response")]) - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: return expected_response result = await pipeline.execute(context, final_handler) @@ -185,9 +185,7 @@ class OrderTrackingMiddleware(AgentMiddleware): def __init__(self, name: str): self.name = name - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append(f"{self.name}_before") await next(context) execution_order.append(f"{self.name}_after") @@ -195,11 +193,11 @@ async def process( middleware = OrderTrackingMiddleware("test") pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) expected_response = AgentResponse(messages=[ChatMessage(role="assistant", text="response")]) - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: execution_order.append("handler") return expected_response @@ -211,9 +209,9 @@ async def test_execute_stream_no_middleware(self, mock_agent: AgentProtocol) -> """Test pipeline streaming execution with no middleware.""" pipeline = AgentMiddlewarePipeline() messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages, stream=True) + context = AgentContext(agent=mock_agent, messages=messages, stream=True) - async def final_handler(ctx: AgentRunContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: + async def final_handler(ctx: AgentContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: async def _stream() -> AsyncIterable[AgentResponseUpdate]: yield AgentResponseUpdate(contents=[Content.from_text(text="chunk1")]) yield AgentResponseUpdate(contents=[Content.from_text(text="chunk2")]) @@ -238,9 +236,7 @@ class StreamOrderTrackingMiddleware(AgentMiddleware): def __init__(self, name: str): self.name = name - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append(f"{self.name}_before") await next(context) execution_order.append(f"{self.name}_after") @@ -248,9 +244,9 @@ async def process( middleware = StreamOrderTrackingMiddleware("test") pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages, stream=True) + context = AgentContext(agent=mock_agent, messages=messages, stream=True) - async def final_handler(ctx: AgentRunContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: + async def final_handler(ctx: AgentContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: async def _stream() -> AsyncIterable[AgentResponseUpdate]: execution_order.append("handler_start") yield AgentResponseUpdate(contents=[Content.from_text(text="chunk1")]) @@ -274,10 +270,10 @@ async def test_execute_with_pre_next_termination(self, mock_agent: AgentProtocol middleware = self.PreNextTerminateMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) execution_order: list[str] = [] - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: # Handler should not be executed when terminated before next() execution_order.append("handler") return AgentResponse(messages=[ChatMessage(role="assistant", text="response")]) @@ -292,10 +288,10 @@ async def test_execute_with_post_next_termination(self, mock_agent: AgentProtoco middleware = self.PostNextTerminateMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) execution_order: list[str] = [] - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: execution_order.append("handler") return AgentResponse(messages=[ChatMessage(role="assistant", text="response")]) @@ -310,10 +306,10 @@ async def test_execute_stream_with_pre_next_termination(self, mock_agent: AgentP middleware = self.PreNextTerminateMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages, stream=True) + context = AgentContext(agent=mock_agent, messages=messages, stream=True) execution_order: list[str] = [] - async def final_handler(ctx: AgentRunContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: + async def final_handler(ctx: AgentContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: async def _stream() -> AsyncIterable[AgentResponseUpdate]: # Handler should not be executed when terminated before next() execution_order.append("handler_start") @@ -338,10 +334,10 @@ async def test_execute_stream_with_post_next_termination(self, mock_agent: Agent middleware = self.PostNextTerminateMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages, stream=True) + context = AgentContext(agent=mock_agent, messages=messages, stream=True) execution_order: list[str] = [] - async def final_handler(ctx: AgentRunContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: + async def final_handler(ctx: AgentContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: async def _stream() -> AsyncIterable[AgentResponseUpdate]: execution_order.append("handler_start") yield AgentResponseUpdate(contents=[Content.from_text(text="chunk1")]) @@ -367,9 +363,7 @@ async def test_execute_with_thread_in_context(self, mock_agent: AgentProtocol) - captured_thread = None class ThreadCapturingMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: nonlocal captured_thread captured_thread = context.thread await next(context) @@ -378,11 +372,11 @@ async def process( pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] thread = AgentThread() - context = AgentRunContext(agent=mock_agent, messages=messages, thread=thread) + context = AgentContext(agent=mock_agent, messages=messages, thread=thread) expected_response = AgentResponse(messages=[ChatMessage(role="assistant", text="response")]) - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: return expected_response result = await pipeline.execute(context, final_handler) @@ -394,9 +388,7 @@ async def test_execute_with_no_thread_in_context(self, mock_agent: AgentProtocol captured_thread = "not_none" # Use string to distinguish from None class ThreadCapturingMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: nonlocal captured_thread captured_thread = context.thread await next(context) @@ -404,11 +396,11 @@ async def process( middleware = ThreadCapturingMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages, thread=None) + context = AgentContext(agent=mock_agent, messages=messages, thread=None) expected_response = AgentResponse(messages=[ChatMessage(role="assistant", text="response")]) - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: return expected_response result = await pipeline.execute(context, final_handler) @@ -774,9 +766,7 @@ async def test_agent_middleware_execution(self, mock_agent: AgentProtocol) -> No metadata_updates: list[str] = [] class MetadataAgentMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: context.metadata["before"] = True metadata_updates.append("before") await next(context) @@ -786,9 +776,9 @@ async def process( middleware = MetadataAgentMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: metadata_updates.append("handler") return AgentResponse(messages=[ChatMessage(role="assistant", text="response")]) @@ -839,9 +829,7 @@ async def test_agent_function_middleware(self, mock_agent: AgentProtocol) -> Non """Test function-based agent middleware.""" execution_order: list[str] = [] - async def test_agent_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def test_agent_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("function_before") context.metadata["function_middleware"] = True await next(context) @@ -849,9 +837,9 @@ async def test_agent_middleware( pipeline = AgentMiddlewarePipeline(test_agent_middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: execution_order.append("handler") return AgentResponse(messages=[ChatMessage(role="assistant", text="response")]) @@ -896,25 +884,21 @@ async def test_mixed_agent_middleware(self, mock_agent: AgentProtocol) -> None: execution_order: list[str] = [] class ClassMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("class_before") await next(context) execution_order.append("class_after") - async def function_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def function_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("function_before") await next(context) execution_order.append("function_after") pipeline = AgentMiddlewarePipeline(ClassMiddleware(), function_middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: execution_order.append("handler") return AgentResponse(messages=[ChatMessage(role="assistant", text="response")]) @@ -997,25 +981,19 @@ async def test_agent_middleware_execution_order(self, mock_agent: AgentProtocol) execution_order: list[str] = [] class FirstMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("first_before") await next(context) execution_order.append("first_after") class SecondMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("second_before") await next(context) execution_order.append("second_after") class ThirdMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("third_before") await next(context) execution_order.append("third_after") @@ -1023,9 +1001,9 @@ async def process( middleware = [FirstMiddleware(), SecondMiddleware(), ThirdMiddleware()] pipeline = AgentMiddlewarePipeline(*middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: execution_order.append("handler") return AgentResponse(messages=[ChatMessage(role="assistant", text="response")]) @@ -1136,9 +1114,7 @@ async def test_agent_context_validation(self, mock_agent: AgentProtocol) -> None """Test that agent context contains expected data.""" class ContextValidationMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Verify context has all expected attributes assert hasattr(context, "agent") assert hasattr(context, "messages") @@ -1161,9 +1137,9 @@ async def process( middleware = ContextValidationMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: # Verify metadata was set by middleware assert ctx.metadata.get("validated") is True return AgentResponse(messages=[ChatMessage(role="assistant", text="response")]) @@ -1260,9 +1236,7 @@ async def test_streaming_flag_validation(self, mock_agent: AgentProtocol) -> Non streaming_flags: list[bool] = [] class StreamingFlagMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: streaming_flags.append(context.stream) await next(context) @@ -1271,18 +1245,18 @@ async def process( messages = [ChatMessage(role="user", text="test")] # Test non-streaming - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: streaming_flags.append(ctx.stream) return AgentResponse(messages=[ChatMessage(role="assistant", text="response")]) await pipeline.execute(context, final_handler) # Test streaming - context_stream = AgentRunContext(agent=mock_agent, messages=messages, stream=True) + context_stream = AgentContext(agent=mock_agent, messages=messages, stream=True) - async def final_stream_handler(ctx: AgentRunContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: + async def final_stream_handler(ctx: AgentContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: async def _stream() -> AsyncIterable[AgentResponseUpdate]: streaming_flags.append(ctx.stream) yield AgentResponseUpdate(contents=[Content.from_text(text="chunk")]) @@ -1302,9 +1276,7 @@ async def test_streaming_middleware_behavior(self, mock_agent: AgentProtocol) -> chunks_processed: list[str] = [] class StreamProcessingMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: chunks_processed.append("before_stream") await next(context) chunks_processed.append("after_stream") @@ -1312,9 +1284,9 @@ async def process( middleware = StreamProcessingMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages, stream=True) + context = AgentContext(agent=mock_agent, messages=messages, stream=True) - async def final_stream_handler(ctx: AgentRunContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: + async def final_stream_handler(ctx: AgentContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: async def _stream() -> AsyncIterable[AgentResponseUpdate]: chunks_processed.append("stream_start") yield AgentResponseUpdate(contents=[Content.from_text(text="chunk1")]) @@ -1436,7 +1408,7 @@ class FunctionTestArgs(BaseModel): class TestAgentMiddleware(AgentMiddleware): """Test implementation of AgentMiddleware.""" - async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: await next(context) @@ -1469,20 +1441,18 @@ async def test_agent_middleware_no_next_no_execution(self, mock_agent: AgentProt """Test that when agent middleware doesn't call next(), no execution happens.""" class NoNextMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Don't call next() - this should prevent any execution pass middleware = NoNextMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) handler_called = False - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: nonlocal handler_called handler_called = True return AgentResponse(messages=[ChatMessage(role="assistant", text="should not execute")]) @@ -1498,20 +1468,18 @@ async def test_agent_middleware_no_next_no_streaming_execution(self, mock_agent: """Test that when agent middleware doesn't call next(), no streaming execution happens.""" class NoNextStreamingMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Don't call next() - this should prevent any execution pass middleware = NoNextStreamingMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages, stream=True) + context = AgentContext(agent=mock_agent, messages=messages, stream=True) handler_called = False - async def final_handler(ctx: AgentRunContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: + async def final_handler(ctx: AgentContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: async def _stream() -> AsyncIterable[AgentResponseUpdate]: nonlocal handler_called handler_called = True @@ -1566,26 +1534,22 @@ async def test_multiple_middlewares_early_stop(self, mock_agent: AgentProtocol) execution_order: list[str] = [] class FirstMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("first") # Don't call next() - this should stop the pipeline class SecondMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("second") await next(context) pipeline = AgentMiddlewarePipeline(FirstMiddleware(), SecondMiddleware()) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) handler_called = False - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: nonlocal handler_called handler_called = True return AgentResponse(messages=[ChatMessage(role="assistant", text="should not execute")]) diff --git a/python/packages/core/tests/core/test_middleware_context_result.py b/python/packages/core/tests/core/test_middleware_context_result.py index 64eec8dc3b..29bb2e3aa2 100644 --- a/python/packages/core/tests/core/test_middleware_context_result.py +++ b/python/packages/core/tests/core/test_middleware_context_result.py @@ -17,9 +17,9 @@ ResponseStream, ) from agent_framework._middleware import ( + AgentContext, AgentMiddleware, AgentMiddlewarePipeline, - AgentRunContext, FunctionInvocationContext, FunctionMiddleware, FunctionMiddlewarePipeline, @@ -43,9 +43,7 @@ async def test_agent_middleware_response_override_non_streaming(self, mock_agent override_response = AgentResponse(messages=[ChatMessage(role="assistant", text="overridden response")]) class ResponseOverrideMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Execute the pipeline first, then override the response await next(context) context.result = override_response @@ -53,11 +51,11 @@ async def process( middleware = ResponseOverrideMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages) + context = AgentContext(agent=mock_agent, messages=messages) handler_called = False - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: nonlocal handler_called handler_called = True return AgentResponse(messages=[ChatMessage(role="assistant", text="original response")]) @@ -79,9 +77,7 @@ async def override_stream() -> AsyncIterable[AgentResponseUpdate]: yield AgentResponseUpdate(contents=[Content.from_text(text=" stream")]) class StreamResponseOverrideMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Execute the pipeline first, then override the response stream await next(context) context.result = ResponseStream(override_stream()) @@ -89,9 +85,9 @@ async def process( middleware = StreamResponseOverrideMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages, stream=True) + context = AgentContext(agent=mock_agent, messages=messages, stream=True) - async def final_handler(ctx: AgentRunContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: + async def final_handler(ctx: AgentContext) -> ResponseStream[AgentResponseUpdate, AgentResponse]: async def _stream() -> AsyncIterable[AgentResponseUpdate]: yield AgentResponseUpdate(contents=[Content.from_text(text="original")]) @@ -145,9 +141,7 @@ async def test_chat_agent_middleware_response_override(self) -> None: mock_chat_client = MockChatClient() class ChatAgentResponseOverrideMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Always call next() first to allow execution await next(context) # Then conditionally override based on content @@ -184,9 +178,7 @@ async def custom_stream() -> AsyncIterable[AgentResponseUpdate]: yield AgentResponseUpdate(contents=[Content.from_text(text=" response!")]) class ChatAgentStreamOverrideMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Check if we want to override BEFORE calling next to avoid creating unused streams if any("custom stream" in msg.text for msg in context.messages if msg.text): context.result = ResponseStream(custom_stream()) @@ -223,9 +215,7 @@ async def test_agent_middleware_conditional_no_next(self, mock_agent: AgentProto """Test that when agent middleware conditionally doesn't call next(), no execution happens.""" class ConditionalNoNextMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Only call next() if message contains "execute" if any("execute" in msg.text for msg in context.messages if msg.text): await next(context) @@ -236,14 +226,14 @@ async def process( handler_called = False - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: nonlocal handler_called handler_called = True return AgentResponse(messages=[ChatMessage(role="assistant", text="executed response")]) # Test case where next() is NOT called no_execute_messages = [ChatMessage(role="user", text="Don't run this")] - no_execute_context = AgentRunContext(agent=mock_agent, messages=no_execute_messages, stream=False) + no_execute_context = AgentContext(agent=mock_agent, messages=no_execute_messages, stream=False) no_execute_result = await pipeline.execute(no_execute_context, final_handler) # When middleware doesn't call next(), result should be empty AgentResponse @@ -255,7 +245,7 @@ async def final_handler(ctx: AgentRunContext) -> AgentResponse: # Test case where next() IS called execute_messages = [ChatMessage(role="user", text="Please execute this")] - execute_context = AgentRunContext(agent=mock_agent, messages=execute_messages, stream=False) + execute_context = AgentContext(agent=mock_agent, messages=execute_messages, stream=False) execute_result = await pipeline.execute(execute_context, final_handler) assert execute_result is not None @@ -318,9 +308,7 @@ async def test_agent_middleware_response_observability(self, mock_agent: AgentPr observed_responses: list[AgentResponse] = [] class ObservabilityMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Context should be empty before next() assert context.result is None @@ -335,9 +323,9 @@ async def process( middleware = ObservabilityMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages, stream=False) + context = AgentContext(agent=mock_agent, messages=messages, stream=False) - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: return AgentResponse(messages=[ChatMessage(role="assistant", text="executed response")]) result = await pipeline.execute(context, final_handler) @@ -386,9 +374,7 @@ async def test_agent_middleware_post_execution_override(self, mock_agent: AgentP """Test that middleware can override response after observing execution.""" class PostExecutionOverrideMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Call next to execute first await next(context) @@ -405,9 +391,9 @@ async def process( middleware = PostExecutionOverrideMiddleware() pipeline = AgentMiddlewarePipeline(middleware) messages = [ChatMessage(role="user", text="test")] - context = AgentRunContext(agent=mock_agent, messages=messages, stream=False) + context = AgentContext(agent=mock_agent, messages=messages, stream=False) - async def final_handler(ctx: AgentRunContext) -> AgentResponse: + async def final_handler(ctx: AgentContext) -> AgentResponse: return AgentResponse(messages=[ChatMessage(role="assistant", text="response to modify")]) result = await pipeline.execute(context, final_handler) diff --git a/python/packages/core/tests/core/test_middleware_with_agent.py b/python/packages/core/tests/core/test_middleware_with_agent.py index 50146ab008..1bb91137e7 100644 --- a/python/packages/core/tests/core/test_middleware_with_agent.py +++ b/python/packages/core/tests/core/test_middleware_with_agent.py @@ -6,9 +6,9 @@ import pytest from agent_framework import ( + AgentContext, AgentMiddleware, AgentResponseUpdate, - AgentRunContext, ChatAgent, ChatClientProtocol, ChatContext, @@ -44,9 +44,7 @@ class TrackingAgentMiddleware(AgentMiddleware): def __init__(self, name: str): self.name = name - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append(f"{self.name}_before") await next(context) execution_order.append(f"{self.name}_after") @@ -122,9 +120,7 @@ async def test_agent_middleware_with_pre_termination(self, chat_client: "MockCha execution_order: list[str] = [] class PreTerminationMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("middleware_before") raise MiddlewareTermination # Code after raise is unreachable @@ -153,9 +149,7 @@ async def test_agent_middleware_with_post_termination(self, chat_client: "MockCh execution_order: list[str] = [] class PostTerminationMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("middleware_before") await next(context) execution_order.append("middleware_after") @@ -225,7 +219,7 @@ async def test_function_based_agent_middleware_with_chat_agent(self, chat_client execution_order: list[str] = [] async def tracking_agent_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + context: AgentContext, next: Callable[[AgentContext], Awaitable[None]] ) -> None: execution_order.append("agent_function_before") await next(context) @@ -290,9 +284,7 @@ async def test_agent_middleware_with_streaming(self, chat_client: "MockChatClien streaming_flags: list[bool] = [] class StreamingTrackingMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("middleware_before") streaming_flags.append(context.stream) await next(context) @@ -334,9 +326,7 @@ async def test_non_streaming_vs_streaming_flag_validation(self, chat_client: "Mo streaming_flags: list[bool] = [] class FlagTrackingMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: streaming_flags.append(context.stream) await next(context) @@ -368,9 +358,7 @@ class OrderedMiddleware(AgentMiddleware): def __init__(self, name: str): self.name = name - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append(f"{self.name}_before") await next(context) execution_order.append(f"{self.name}_after") @@ -400,15 +388,13 @@ async def test_mixed_middleware_types_with_chat_agent(self, chat_client_base: "M execution_order: list[str] = [] class ClassAgentMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("class_agent_before") await next(context) execution_order.append("class_agent_after") async def function_agent_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + context: AgentContext, next: Callable[[AgentContext], Awaitable[None]] ) -> None: execution_order.append("function_agent_before") await next(context) @@ -447,15 +433,13 @@ async def test_mixed_middleware_types_with_supported_client(self, chat_client_ba execution_order: list[str] = [] class ClassAgentMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("class_agent_before") await next(context) execution_order.append("class_agent_after") async def function_agent_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + context: AgentContext, next: Callable[[AgentContext], Awaitable[None]] ) -> None: execution_order.append("function_agent_before") await next(context) @@ -646,8 +630,8 @@ async def test_mixed_agent_and_function_middleware_with_tool_calls( class TrackingAgentMiddleware(AgentMiddleware): async def process( self, - context: AgentRunContext, - next: Callable[[AgentRunContext], Awaitable[None]], + context: AgentContext, + next: Callable[[AgentContext], Awaitable[None]], ) -> None: execution_order.append("agent_middleware_before") await next(context) @@ -801,7 +785,7 @@ def __init__(self, name: str, execution_log: list[str]): self.name = name self.execution_log = execution_log - async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: self.execution_log.append(f"{self.name}_start") await next(context) self.execution_log.append(f"{self.name}_end") @@ -924,7 +908,7 @@ def __init__(self, name: str, execution_log: list[str]): self.name = name self.execution_log = execution_log - async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: self.execution_log.append(f"{self.name}_start") await next(context) self.execution_log.append(f"{self.name}_end") @@ -976,9 +960,7 @@ class MetadataAgentMiddleware(AgentMiddleware): def __init__(self, name: str): self.name = name - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_log.append(f"{self.name}_start") # Set metadata to pass information to run middleware context.metadata[f"{self.name}_key"] = f"{self.name}_value" @@ -989,9 +971,7 @@ class MetadataRunMiddleware(AgentMiddleware): def __init__(self, name: str): self.name = name - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_log.append(f"{self.name}_start") # Read metadata set by agent middleware for key, value in context.metadata.items(): @@ -1049,9 +1029,7 @@ class StreamingTrackingMiddleware(AgentMiddleware): def __init__(self, name: str): self.name = name - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_log.append(f"{self.name}_start") streaming_flags.append(context.stream) await next(context) @@ -1093,9 +1071,7 @@ async def test_agent_and_run_level_both_agent_and_function_middleware( # Agent-level middleware class AgentLevelAgentMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_log.append("agent_level_agent_start") context.metadata["agent_level_agent"] = "processed" await next(context) @@ -1114,9 +1090,7 @@ async def process( # Run-level middleware class RunLevelAgentMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_log.append("run_level_agent_start") # Verify agent-level middleware metadata is available assert "agent_level_agent" in context.metadata @@ -1218,7 +1192,7 @@ async def test_decorator_and_type_match(self, chat_client_base: "MockBaseChatCli @agent_middleware async def matching_agent_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + context: AgentContext, next: Callable[[AgentContext], Awaitable[None]] ) -> None: execution_order.append("decorator_type_match_agent") await next(context) @@ -1346,7 +1320,7 @@ async def test_only_type_specified(self, chat_client_base: "MockBaseChatClient") execution_order: list[str] = [] # No decorator - async def type_only_agent(context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + async def type_only_agent(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("type_only_agent") await next(context) @@ -1440,16 +1414,14 @@ async def test_function_middleware(context: Any, next: Any) -> None: class TestChatAgentThreadBehavior: - """Test cases for thread behavior in AgentRunContext across multiple runs.""" + """Test cases for thread behavior in AgentContext across multiple runs.""" - async def test_agent_run_context_thread_behavior_across_multiple_runs(self, chat_client: "MockChatClient") -> None: - """Test that AgentRunContext.thread property behaves correctly across multiple agent runs.""" + async def test_agent_context_thread_behavior_across_multiple_runs(self, chat_client: "MockChatClient") -> None: + """Test that AgentContext.thread property behaves correctly across multiple agent runs.""" thread_states: list[dict[str, Any]] = [] class ThreadTrackingMiddleware(AgentMiddleware): - async def process( - self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Capture state before next() call thread_messages = [] if context.thread and context.thread.message_store: @@ -1804,9 +1776,7 @@ async def test_combined_middleware(self) -> None: """Test ChatAgent with combined middleware types.""" execution_order: list[str] = [] - async def agent_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def agent_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: execution_order.append("agent_middleware_before") await next(context) execution_order.append("agent_middleware_after") @@ -1844,9 +1814,7 @@ async def test_agent_middleware_can_access_and_override_custom_kwargs(self) -> N modified_kwargs: dict[str, Any] = {} @agent_middleware - async def kwargs_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] - ) -> None: + async def kwargs_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Capture the original kwargs captured_kwargs.update(context.kwargs) @@ -1897,7 +1865,7 @@ async def kwargs_middleware( # class TrackingMiddleware(AgentMiddleware): # async def process( -# self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] +# self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]] # ) -> None: # execution_order.append("before") # await next(context) diff --git a/python/packages/purview/agent_framework_purview/_middleware.py b/python/packages/purview/agent_framework_purview/_middleware.py index 2aabd5a57b..dba7a3f649 100644 --- a/python/packages/purview/agent_framework_purview/_middleware.py +++ b/python/packages/purview/agent_framework_purview/_middleware.py @@ -2,7 +2,7 @@ from collections.abc import Awaitable, Callable -from agent_framework import AgentMiddleware, AgentRunContext, ChatContext, ChatMiddleware, MiddlewareTermination +from agent_framework import AgentContext, AgentMiddleware, ChatContext, ChatMiddleware, MiddlewareTermination from agent_framework._logging import get_logger from azure.core.credentials import TokenCredential from azure.core.credentials_async import AsyncTokenCredential @@ -47,8 +47,8 @@ def __init__( async def process( self, - context: AgentRunContext, - next: Callable[[AgentRunContext], Awaitable[None]], + context: AgentContext, + next: Callable[[AgentContext], Awaitable[None]], ) -> None: # type: ignore[override] resolved_user_id: str | None = None try: diff --git a/python/packages/purview/tests/test_middleware.py b/python/packages/purview/tests/test_middleware.py index 7c9edacd1a..b0aadd8cd5 100644 --- a/python/packages/purview/tests/test_middleware.py +++ b/python/packages/purview/tests/test_middleware.py @@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from agent_framework import AgentResponse, AgentRunContext, ChatMessage, MiddlewareTermination +from agent_framework import AgentContext, AgentResponse, ChatMessage, MiddlewareTermination from azure.core.credentials import AccessToken from agent_framework_purview import PurviewPolicyMiddleware, PurviewSettings @@ -49,12 +49,12 @@ async def test_middleware_allows_clean_prompt( self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock ) -> None: """Test middleware allows prompt that passes policy check.""" - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello, how are you?")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello, how are you?")]) with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")): next_called = False - async def mock_next(ctx: AgentRunContext) -> None: + async def mock_next(ctx: AgentContext) -> None: nonlocal next_called next_called = True ctx.result = AgentResponse(messages=[ChatMessage(role="assistant", text="I'm good, thanks!")]) @@ -68,12 +68,12 @@ async def test_middleware_blocks_prompt_on_policy_violation( self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock ) -> None: """Test middleware blocks prompt that violates policy.""" - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Sensitive information")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Sensitive information")]) with patch.object(middleware._processor, "process_messages", return_value=(True, "user-123")): next_called = False - async def mock_next(ctx: AgentRunContext) -> None: + async def mock_next(ctx: AgentContext) -> None: nonlocal next_called next_called = True @@ -88,7 +88,7 @@ async def mock_next(ctx: AgentRunContext) -> None: async def test_middleware_checks_response(self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock) -> None: """Test middleware checks agent response for policy violations.""" - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")]) call_count = 0 @@ -100,7 +100,7 @@ async def mock_process_messages(messages, activity, user_id=None): with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): - async def mock_next(ctx: AgentRunContext) -> None: + async def mock_next(ctx: AgentContext) -> None: ctx.result = AgentResponse( messages=[ChatMessage(role="assistant", text="Here's some sensitive information")] ) @@ -120,11 +120,11 @@ async def test_middleware_handles_result_without_messages( # Set ignore_exceptions to True so AttributeError is caught and logged middleware._settings.ignore_exceptions = True - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")]) with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")): - async def mock_next(ctx: AgentRunContext) -> None: + async def mock_next(ctx: AgentContext) -> None: ctx.result = "Some non-standard result" await middleware.process(context, mock_next) @@ -137,11 +137,11 @@ async def test_middleware_processor_receives_correct_activity( """Test middleware passes correct activity type to processor.""" from agent_framework_purview._models import Activity - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Test")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Test")]) with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_process: - async def mock_next(ctx: AgentRunContext) -> None: + async def mock_next(ctx: AgentContext) -> None: ctx.result = AgentResponse(messages=[ChatMessage(role="assistant", text="Response")]) await middleware.process(context, mock_next) @@ -154,12 +154,12 @@ async def test_middleware_streaming_skips_post_check( self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock ) -> None: """Test that streaming results skip post-check evaluation.""" - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")]) context.stream = True with patch.object(middleware._processor, "process_messages", return_value=(False, "user-123")) as mock_proc: - async def mock_next(ctx: AgentRunContext) -> None: + async def mock_next(ctx: AgentContext) -> None: ctx.result = AgentResponse(messages=[ChatMessage(role="assistant", text="streaming")]) await middleware.process(context, mock_next) @@ -172,7 +172,7 @@ async def test_middleware_payment_required_in_pre_check_raises_by_default( """Test that 402 in pre-check is raised when ignore_payment_required=False.""" from agent_framework_purview._exceptions import PurviewPaymentRequiredError - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")]) with patch.object( middleware._processor, @@ -180,7 +180,7 @@ async def test_middleware_payment_required_in_pre_check_raises_by_default( side_effect=PurviewPaymentRequiredError("Payment required"), ): - async def mock_next(_: AgentRunContext) -> None: + async def mock_next(_: AgentContext) -> None: raise AssertionError("next should not be called") with pytest.raises(PurviewPaymentRequiredError): @@ -192,7 +192,7 @@ async def test_middleware_payment_required_in_post_check_raises_by_default( """Test that 402 in post-check is raised when ignore_payment_required=False.""" from agent_framework_purview._exceptions import PurviewPaymentRequiredError - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")]) call_count = 0 @@ -205,7 +205,7 @@ async def side_effect(*args, **kwargs): with patch.object(middleware._processor, "process_messages", side_effect=side_effect): - async def mock_next(ctx: AgentRunContext) -> None: + async def mock_next(ctx: AgentContext) -> None: ctx.result = AgentResponse(messages=[ChatMessage(role="assistant", text="OK")]) with pytest.raises(PurviewPaymentRequiredError): @@ -217,7 +217,7 @@ async def test_middleware_post_check_exception_raises_when_ignore_exceptions_fal """Test that post-check exceptions are propagated when ignore_exceptions=False.""" middleware._settings.ignore_exceptions = False - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Hello")]) call_count = 0 @@ -230,7 +230,7 @@ async def side_effect(*args, **kwargs): with patch.object(middleware._processor, "process_messages", side_effect=side_effect): - async def mock_next(ctx: AgentRunContext) -> None: + async def mock_next(ctx: AgentContext) -> None: ctx.result = AgentResponse(messages=[ChatMessage(role="assistant", text="OK")]) with pytest.raises(ValueError, match="Post-check blew up"): @@ -243,13 +243,13 @@ async def test_middleware_handles_pre_check_exception( # Set ignore_exceptions to True middleware._settings.ignore_exceptions = True - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Test")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Test")]) with patch.object( middleware._processor, "process_messages", side_effect=Exception("Pre-check error") ) as mock_process: - async def mock_next(ctx: AgentRunContext) -> None: + async def mock_next(ctx: AgentContext) -> None: ctx.result = AgentResponse(messages=[ChatMessage(role="assistant", text="Response")]) await middleware.process(context, mock_next) @@ -266,7 +266,7 @@ async def test_middleware_handles_post_check_exception( # Set ignore_exceptions to True middleware._settings.ignore_exceptions = True - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Test")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Test")]) call_count = 0 @@ -279,7 +279,7 @@ async def mock_process_messages(*args, **kwargs): with patch.object(middleware._processor, "process_messages", side_effect=mock_process_messages): - async def mock_next(ctx: AgentRunContext) -> None: + async def mock_next(ctx: AgentContext) -> None: ctx.result = AgentResponse(messages=[ChatMessage(role="assistant", text="Response")]) await middleware.process(context, mock_next) @@ -297,7 +297,7 @@ async def test_middleware_with_ignore_exceptions_true(self, mock_credential: Asy mock_agent = MagicMock() mock_agent.name = "test-agent" - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Test")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Test")]) # Mock processor to raise an exception async def mock_process_messages(*args, **kwargs): @@ -321,7 +321,7 @@ async def test_middleware_with_ignore_exceptions_false(self, mock_credential: As mock_agent = MagicMock() mock_agent.name = "test-agent" - context = AgentRunContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Test")]) + context = AgentContext(agent=mock_agent, messages=[ChatMessage(role="user", text="Test")]) # Mock processor to raise an exception async def mock_process_messages(*args, **kwargs): diff --git a/python/samples/concepts/tools/README.md b/python/samples/concepts/tools/README.md index 3a270b25aa..04b7c04569 100644 --- a/python/samples/concepts/tools/README.md +++ b/python/samples/concepts/tools/README.md @@ -34,7 +34,7 @@ sequenceDiagram Note over Agent,AML: Agent Middleware Layer Agent->>AML: run() with middleware param AML->>AML: categorize_middleware() → split by type - AML->>AMP: execute(AgentRunContext) + AML->>AMP: execute(AgentContext) loop Agent Middleware Chain AMP->>AMP: middleware[i].process(context, next) @@ -127,7 +127,7 @@ sequenceDiagram **Entry Point:** `Agent.run(messages, thread, options, middleware)` -**Context Object:** `AgentRunContext` +**Context Object:** `AgentContext` | Field | Type | Description | |-------|------|-------------| diff --git a/python/samples/getting_started/middleware/README.md b/python/samples/getting_started/middleware/README.md index 3d1bd61d27..659e81647a 100644 --- a/python/samples/getting_started/middleware/README.md +++ b/python/samples/getting_started/middleware/README.md @@ -13,7 +13,7 @@ This folder contains examples demonstrating various middleware patterns with the | [`exception_handling_with_middleware.py`](exception_handling_with_middleware.py) | Demonstrates how to use middleware for centralized exception handling in function calls. Shows how to catch exceptions from functions, provide graceful error responses, and override function results when errors occur to provide user-friendly messages. | | [`override_result_with_middleware.py`](override_result_with_middleware.py) | Shows how to use middleware to intercept and modify function results after execution, supporting both regular and streaming agent responses. Demonstrates result filtering, formatting, enhancement, and custom streaming response generation. | | [`shared_state_middleware.py`](shared_state_middleware.py) | Demonstrates how to implement function-based middleware within a class to share state between multiple middleware functions. Shows how middleware can work together by sharing state, including call counting and result enhancement. | -| [`thread_behavior_middleware.py`](thread_behavior_middleware.py) | Demonstrates how middleware can access and track thread state across multiple agent runs. Shows how `AgentRunContext.thread` behaves differently before and after the `next()` call, how conversation history accumulates in threads, and timing of thread message updates. Essential for understanding conversation flow in middleware. | +| [`thread_behavior_middleware.py`](thread_behavior_middleware.py) | Demonstrates how middleware can access and track thread state across multiple agent runs. Shows how `AgentContext.thread` behaves differently before and after the `next()` call, how conversation history accumulates in threads, and timing of thread message updates. Essential for understanding conversation flow in middleware. | | [`agent_and_run_level_middleware.py`](agent_and_run_level_middleware.py) | Explains the difference between agent-level middleware (applied to ALL runs of the agent) and run-level middleware (applied to specific runs only). Shows security validation, performance monitoring, and context-specific middleware patterns. | | [`chat_middleware.py`](chat_middleware.py) | Demonstrates how to use chat middleware to observe and override inputs sent to AI models. Shows how to intercept chat requests, log and modify input messages, and override entire responses before they reach the underlying AI service. | diff --git a/python/samples/getting_started/middleware/agent_and_run_level_middleware.py b/python/samples/getting_started/middleware/agent_and_run_level_middleware.py index 32fd7a2e52..c90dd1936b 100644 --- a/python/samples/getting_started/middleware/agent_and_run_level_middleware.py +++ b/python/samples/getting_started/middleware/agent_and_run_level_middleware.py @@ -7,9 +7,9 @@ from typing import Annotated from agent_framework import ( + AgentContext, AgentMiddleware, AgentResponse, - AgentRunContext, FunctionInvocationContext, tool, ) @@ -49,7 +49,7 @@ def get_weather( class SecurityAgentMiddleware(AgentMiddleware): """Agent-level security middleware that validates all requests.""" - async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: print("[SecurityMiddleware] Checking security for all requests...") # Check for security violations in the last user message @@ -66,8 +66,8 @@ async def process(self, context: AgentRunContext, next: Callable[[AgentRunContex async def performance_monitor_middleware( - context: AgentRunContext, - next: Callable[[AgentRunContext], Awaitable[None]], + context: AgentContext, + next: Callable[[AgentContext], Awaitable[None]], ) -> None: """Agent-level performance monitoring for all runs.""" print("[PerformanceMonitor] Starting performance monitoring...") @@ -85,7 +85,7 @@ async def performance_monitor_middleware( class HighPriorityMiddleware(AgentMiddleware): """Run-level middleware for high priority requests.""" - async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: print("[HighPriority] Processing high priority request with expedited handling...") # Read metadata set by agent-level middleware @@ -101,8 +101,8 @@ async def process(self, context: AgentRunContext, next: Callable[[AgentRunContex async def debugging_middleware( - context: AgentRunContext, - next: Callable[[AgentRunContext], Awaitable[None]], + context: AgentContext, + next: Callable[[AgentContext], Awaitable[None]], ) -> None: """Run-level debugging middleware for troubleshooting specific runs.""" print("[Debug] Debug mode enabled for this run") @@ -126,7 +126,7 @@ class CachingMiddleware(AgentMiddleware): def __init__(self) -> None: self.cache: dict[str, AgentResponse] = {} - async def process(self, context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]]) -> None: + async def process(self, context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: # Create a simple cache key from the last message last_message = context.messages[-1] if context.messages else None cache_key: str = last_message.text if last_message and last_message.text else "no_message" diff --git a/python/samples/getting_started/middleware/class_based_middleware.py b/python/samples/getting_started/middleware/class_based_middleware.py index 65fa279f19..727c0a2821 100644 --- a/python/samples/getting_started/middleware/class_based_middleware.py +++ b/python/samples/getting_started/middleware/class_based_middleware.py @@ -7,9 +7,9 @@ from typing import Annotated from agent_framework import ( + AgentContext, AgentMiddleware, AgentResponse, - AgentRunContext, ChatMessage, FunctionInvocationContext, FunctionMiddleware, @@ -49,8 +49,8 @@ class SecurityAgentMiddleware(AgentMiddleware): async def process( self, - context: AgentRunContext, - next: Callable[[AgentRunContext], Awaitable[None]], + context: AgentContext, + next: Callable[[AgentContext], Awaitable[None]], ) -> None: # Check for potential security violations in the query # Look at the last user message @@ -61,9 +61,7 @@ async def process( print("[SecurityAgentMiddleware] Security Warning: Detected sensitive information, blocking request.") # Override the result with warning message context.result = AgentResponse( - messages=[ - ChatMessage("assistant", ["Detected sensitive information, the request is blocked."]) - ] + messages=[ChatMessage("assistant", ["Detected sensitive information, the request is blocked."])] ) # Simply don't call next() to prevent execution return diff --git a/python/samples/getting_started/middleware/decorator_middleware.py b/python/samples/getting_started/middleware/decorator_middleware.py index f16407918c..2ea1196bc3 100644 --- a/python/samples/getting_started/middleware/decorator_middleware.py +++ b/python/samples/getting_started/middleware/decorator_middleware.py @@ -20,7 +20,7 @@ The framework supports the following middleware detection scenarios: 1. Both decorator and parameter type specified: - - Validates that they match (e.g., @agent_middleware with AgentRunContext) + - Validates that they match (e.g., @agent_middleware with AgentContext) - Throws exception if they don't match for safety 2. Only decorator specified: @@ -28,7 +28,7 @@ - No type annotations needed - framework handles context types automatically 3. Only parameter type specified: - - Uses type annotations (AgentRunContext, FunctionInvocationContext) for detection + - Uses type annotations (AgentContext, FunctionInvocationContext) for detection 4. Neither decorator nor parameter type specified: - Throws exception requiring either decorator or type annotation diff --git a/python/samples/getting_started/middleware/function_based_middleware.py b/python/samples/getting_started/middleware/function_based_middleware.py index 21defef491..1616aa5fc3 100644 --- a/python/samples/getting_started/middleware/function_based_middleware.py +++ b/python/samples/getting_started/middleware/function_based_middleware.py @@ -7,7 +7,7 @@ from typing import Annotated from agent_framework import ( - AgentRunContext, + AgentContext, FunctionInvocationContext, tool, ) @@ -42,8 +42,8 @@ def get_weather( async def security_agent_middleware( - context: AgentRunContext, - next: Callable[[AgentRunContext], Awaitable[None]], + context: AgentContext, + next: Callable[[AgentContext], Awaitable[None]], ) -> None: """Agent middleware that checks for security violations.""" # Check for potential security violations in the query diff --git a/python/samples/getting_started/middleware/middleware_termination.py b/python/samples/getting_started/middleware/middleware_termination.py index ea32bc606b..69fa5766d9 100644 --- a/python/samples/getting_started/middleware/middleware_termination.py +++ b/python/samples/getting_started/middleware/middleware_termination.py @@ -6,9 +6,9 @@ from typing import Annotated from agent_framework import ( + AgentContext, AgentMiddleware, AgentResponse, - AgentRunContext, ChatMessage, tool, ) @@ -47,8 +47,8 @@ def __init__(self, blocked_words: list[str]): async def process( self, - context: AgentRunContext, - next: Callable[[AgentRunContext], Awaitable[None]], + context: AgentContext, + next: Callable[[AgentContext], Awaitable[None]], ) -> None: # Check if the user message contains any blocked words last_message = context.messages[-1] if context.messages else None @@ -87,8 +87,8 @@ def __init__(self, max_responses: int = 1): async def process( self, - context: AgentRunContext, - next: Callable[[AgentRunContext], Awaitable[None]], + context: AgentContext, + next: Callable[[AgentContext], Awaitable[None]], ) -> None: print(f"[PostTerminationMiddleware] Processing request (response count: {self.response_count})") diff --git a/python/samples/getting_started/middleware/override_result_with_middleware.py b/python/samples/getting_started/middleware/override_result_with_middleware.py index 06351d1803..8aef8f8e3b 100644 --- a/python/samples/getting_started/middleware/override_result_with_middleware.py +++ b/python/samples/getting_started/middleware/override_result_with_middleware.py @@ -7,9 +7,9 @@ from typing import Annotated from agent_framework import ( + AgentContext, AgentResponse, AgentResponseUpdate, - AgentRunContext, ChatContext, ChatMessage, ChatResponse, @@ -104,9 +104,7 @@ def _append_validation_note(response: ChatResponse) -> ChatResponse: context.result.messages.append(ChatMessage(role=Role.ASSISTANT, text=validation_note)) -async def agent_cleanup_middleware( - context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] -) -> None: +async def agent_cleanup_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None: """Agent middleware that validates chat middleware effects and cleans the result.""" await next(context) diff --git a/python/samples/getting_started/middleware/thread_behavior_middleware.py b/python/samples/getting_started/middleware/thread_behavior_middleware.py index 93f72d567a..0665d23720 100644 --- a/python/samples/getting_started/middleware/thread_behavior_middleware.py +++ b/python/samples/getting_started/middleware/thread_behavior_middleware.py @@ -5,7 +5,7 @@ from typing import Annotated from agent_framework import ( - AgentRunContext, + AgentContext, ChatMessageStore, tool, ) @@ -19,7 +19,7 @@ This sample demonstrates how middleware can access and track thread state across multiple agent runs. The example shows: -- How AgentRunContext.thread property behaves across multiple runs +- How AgentContext.thread property behaves across multiple runs - How middleware can access conversation history through the thread - The timing of when thread messages are populated (before vs after next() call) - How to track thread state changes across runs @@ -45,8 +45,8 @@ def get_weather( async def thread_tracking_middleware( - context: AgentRunContext, - next: Callable[[AgentRunContext], Awaitable[None]], + context: AgentContext, + next: Callable[[AgentContext], Awaitable[None]], ) -> None: """MiddlewareTypes that tracks and logs thread behavior across runs.""" thread_messages = [] From 23dbe1ff476ebafc9b365868b76f05197881e135 Mon Sep 17 00:00:00 2001 From: Dmytro Struk <13853051+dmytrostruk@users.noreply.github.com> Date: Thu, 5 Feb 2026 16:24:34 -0800 Subject: [PATCH 2/2] Update python/packages/core/AGENTS.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- python/packages/core/AGENTS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/packages/core/AGENTS.md b/python/packages/core/AGENTS.md index 440bf1b3d0..a41f5ed42f 100644 --- a/python/packages/core/AGENTS.md +++ b/python/packages/core/AGENTS.md @@ -117,7 +117,7 @@ agent = OpenAIChatClient().as_agent( from agent_framework import ChatAgent, AgentMiddleware, AgentContext class LoggingMiddleware(AgentMiddleware): - async def invoke(self, context: AgentContext, next) -> AgentResponse: + async def process(self, context: AgentContext, next) -> AgentResponse: print(f"Input: {context.messages}") response = await next(context) print(f"Output: {response}")