Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions python/packages/core/AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,10 @@ agent = OpenAIChatClient().as_agent(
from agent_framework import ChatAgent, AgentMiddleware, AgentContext

class LoggingMiddleware(AgentMiddleware):
async def process(self, context: AgentContext, next) -> AgentResponse:
async def process(self, context: AgentContext, call_next) -> None:
print(f"Input: {context.messages}")
response = await next(context)
print(f"Output: {response}")
return response
await call_next(context)
print(f"Output: {context.result}")

agent = ChatAgent(..., middleware=[LoggingMiddleware()])
```
Expand Down
74 changes: 37 additions & 37 deletions python/packages/core/agent_framework/_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class AgentContext:
options: The options for the agent invocation as a dict.
stream: Whether this is a streaming invocation.
metadata: Metadata dictionary for sharing data between agent middleware.
result: Agent execution result. Can be observed after calling ``next()``
result: Agent execution result. Can be observed after calling ``call_next()``
to see the actual execution result or can be set to override the execution result.
For non-streaming: should be AgentResponse.
For streaming: should be ResponseStream[AgentResponseUpdate, AgentResponse].
Expand All @@ -135,7 +135,7 @@ class AgentContext:


class LoggingMiddleware(AgentMiddleware):
async def process(self, context: AgentContext, next):
async def process(self, context: AgentContext, call_next):
print(f"Agent: {context.agent.name}")
print(f"Messages: {len(context.messages)}")
print(f"Thread: {context.thread}")
Expand All @@ -145,7 +145,7 @@ async def process(self, context: AgentContext, next):
context.metadata["start_time"] = time.time()

# Continue execution
await next(context)
await call_next(context)

# Access result after execution
print(f"Result: {context.result}")
Expand Down Expand Up @@ -208,7 +208,7 @@ class FunctionInvocationContext:
function: The function being invoked.
arguments: The validated arguments for the function.
metadata: Metadata dictionary for sharing data between function middleware.
result: Function execution result. Can be observed after calling ``next()``
result: Function execution result. Can be observed after calling ``call_next()``
to see the actual execution result or can be set to override the execution result.

kwargs: Additional keyword arguments passed to the chat method that invoked this function.
Expand All @@ -220,7 +220,7 @@ class FunctionInvocationContext:


class ValidationMiddleware(FunctionMiddleware):
async def process(self, context: FunctionInvocationContext, next):
async def process(self, context: FunctionInvocationContext, call_next):
print(f"Function: {context.function.name}")
print(f"Arguments: {context.arguments}")

Expand All @@ -229,7 +229,7 @@ async def process(self, context: FunctionInvocationContext, next):
raise MiddlewareTermination("Validation failed")

# Continue execution
await next(context)
await call_next(context)
"""

def __init__(
Expand Down Expand Up @@ -268,7 +268,7 @@ class ChatContext:
options: The options for the chat request as a dict.
stream: Whether this is a streaming invocation.
metadata: Metadata dictionary for sharing data between chat middleware.
result: Chat execution result. Can be observed after calling ``next()``
result: Chat execution result. Can be observed after calling ``call_next()``
to see the actual execution result or can be set to override the execution result.
For non-streaming: should be ChatResponse.
For streaming: should be ResponseStream[ChatResponseUpdate, ChatResponse].
Expand All @@ -284,7 +284,7 @@ class ChatContext:


class TokenCounterMiddleware(ChatMiddleware):
async def process(self, context: ChatContext, next):
async def process(self, context: ChatContext, call_next):
print(f"Chat client: {context.chat_client.__class__.__name__}")
print(f"Messages: {len(context.messages)}")
print(f"Model: {context.options.get('model_id')}")
Expand All @@ -293,7 +293,7 @@ async def process(self, context: ChatContext, next):
context.metadata["input_tokens"] = self.count_tokens(context.messages)

# Continue execution
await next(context)
await call_next(context)

# Access result and count output tokens
if context.result:
Expand Down Expand Up @@ -363,9 +363,9 @@ class RetryMiddleware(AgentMiddleware):
def __init__(self, max_retries: int = 3):
self.max_retries = max_retries

async def process(self, context: AgentContext, next):
async def process(self, context: AgentContext, call_next):
for attempt in range(self.max_retries):
await next(context)
await call_next(context)
if context.result and not context.result.is_error:
break
print(f"Retry {attempt + 1}/{self.max_retries}")
Expand All @@ -379,24 +379,24 @@ async def process(self, context: AgentContext, next):
async def process(
self,
context: AgentContext,
next: Callable[[AgentContext], Awaitable[None]],
call_next: Callable[[AgentContext], Awaitable[None]],
) -> None:
"""Process an agent invocation.

Args:
context: Agent invocation context containing agent, messages, and metadata.
Use context.stream to determine if this is a streaming call.
Middleware can set context.result to override execution, or observe
the actual execution result after calling next().
the actual execution result after calling call_next().
For non-streaming: AgentResponse
For streaming: ResponseStream[AgentResponseUpdate, AgentResponse]
next: Function to call the next middleware or final agent execution.
call_next: Function to call the next middleware or final agent execution.
Does not return anything - all data flows through the context.

Note:
Middleware should not return anything. All data manipulation should happen
within the context object. Set context.result to override execution,
or observe context.result after calling next() for actual results.
or observe context.result after calling call_next() for actual results.
"""
...

Expand All @@ -422,7 +422,7 @@ class CachingMiddleware(FunctionMiddleware):
def __init__(self):
self.cache = {}

async def process(self, context: FunctionInvocationContext, next):
async def process(self, context: FunctionInvocationContext, call_next):
cache_key = f"{context.function.name}:{context.arguments}"

# Check cache
Expand All @@ -431,7 +431,7 @@ async def process(self, context: FunctionInvocationContext, next):
raise MiddlewareTermination()

# Execute function
await next(context)
await call_next(context)

# Cache result
if context.result:
Expand All @@ -446,21 +446,21 @@ async def process(self, context: FunctionInvocationContext, next):
async def process(
self,
context: FunctionInvocationContext,
next: Callable[[FunctionInvocationContext], Awaitable[None]],
call_next: Callable[[FunctionInvocationContext], Awaitable[None]],
) -> None:
"""Process a function invocation.

Args:
context: Function invocation context containing function, arguments, and metadata.
Middleware can set context.result to override execution, or observe
the actual execution result after calling next().
next: Function to call the next middleware or final function execution.
the actual execution result after calling call_next().
call_next: Function to call the next middleware or final function execution.
Does not return anything - all data flows through the context.

Note:
Middleware should not return anything. All data manipulation should happen
within the context object. Set context.result to override execution,
or observe context.result after calling next() for actual results.
or observe context.result after calling call_next() for actual results.
"""
...

Expand All @@ -486,14 +486,14 @@ class SystemPromptMiddleware(ChatMiddleware):
def __init__(self, system_prompt: str):
self.system_prompt = system_prompt

async def process(self, context: ChatContext, next):
async def process(self, context: ChatContext, call_next):
# Add system prompt to messages
from agent_framework import ChatMessage

context.messages.insert(0, ChatMessage(role="system", text=self.system_prompt))

# Continue execution
await next(context)
await call_next(context)


# Use with an agent
Expand All @@ -508,24 +508,24 @@ async def process(self, context: ChatContext, next):
async def process(
self,
context: ChatContext,
next: Callable[[ChatContext], Awaitable[None]],
call_next: Callable[[ChatContext], Awaitable[None]],
) -> None:
"""Process a chat client request.

Args:
context: Chat invocation context containing chat client, messages, options, and metadata.
Use context.stream to determine if this is a streaming call.
Middleware can set context.result to override execution, or observe
the actual execution result after calling next().
the actual execution result after calling call_next().
For non-streaming: ChatResponse
For streaming: ResponseStream[ChatResponseUpdate, ChatResponse]
next: Function to call the next middleware or final chat execution.
call_next: Function to call the next middleware or final chat execution.
Does not return anything - all data flows through the context.

Note:
Middleware should not return anything. All data manipulation should happen
within the context object. Set context.result to override execution,
or observe context.result after calling next() for actual results.
or observe context.result after calling call_next() for actual results.
"""
...

Expand Down Expand Up @@ -576,9 +576,9 @@ def agent_middleware(func: AgentMiddlewareCallable) -> AgentMiddlewareCallable:


@agent_middleware
async def logging_middleware(context: AgentContext, next):
async def logging_middleware(context: AgentContext, call_next):
print(f"Before: {context.agent.name}")
await next(context)
await call_next(context)
print(f"After: {context.result}")


Expand Down Expand Up @@ -609,9 +609,9 @@ def function_middleware(func: FunctionMiddlewareCallable) -> FunctionMiddlewareC


@function_middleware
async def logging_middleware(context: FunctionInvocationContext, next):
async def logging_middleware(context: FunctionInvocationContext, call_next):
print(f"Calling: {context.function.name}")
await next(context)
await call_next(context)
print(f"Result: {context.result}")


Expand Down Expand Up @@ -642,9 +642,9 @@ def chat_middleware(func: ChatMiddlewareCallable) -> ChatMiddlewareCallable:


@chat_middleware
async def logging_middleware(context: ChatContext, next):
async def logging_middleware(context: ChatContext, call_next):
print(f"Messages: {len(context.messages)}")
await next(context)
await call_next(context)
print(f"Response: {context.result}")


Expand All @@ -669,8 +669,8 @@ class MiddlewareWrapper(Generic[TContext]):
def __init__(self, func: Callable[[TContext, Callable[[TContext], Awaitable[None]]], Awaitable[None]]) -> None:
self.func = func

async def process(self, context: TContext, next: Callable[[TContext], Awaitable[None]]) -> None:
await self.func(context, next)
async def process(self, context: TContext, call_next: Callable[[TContext], Awaitable[None]]) -> None:
await self.func(context, call_next)


class BaseMiddlewarePipeline(ABC):
Expand Down Expand Up @@ -1226,7 +1226,7 @@ def _determine_middleware_type(middleware: Any) -> MiddlewareType:
sig = inspect.signature(middleware)
params = list(sig.parameters.values())

# Must have at least 2 parameters (context and next)
# Must have at least 2 parameters (context and call_next)
if len(params) >= 2:
first_param = params[0]
if hasattr(first_param.annotation, "__name__"):
Expand All @@ -1240,7 +1240,7 @@ def _determine_middleware_type(middleware: Any) -> MiddlewareType:
else:
# Not enough parameters - can't be valid middleware
raise MiddlewareException(
f"MiddlewareTypes function must have at least 2 parameters (context, next), "
f"Middleware function must have at least 2 parameters (context, call_next), "
f"but {middleware.__name__} has {len(params)}"
)
except Exception as e:
Expand Down
42 changes: 28 additions & 14 deletions python/packages/core/tests/core/test_as_tool_kwargs_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ async def test_as_tool_forwards_runtime_kwargs(self, chat_client: MockChatClient
captured_kwargs: dict[str, Any] = {}

@agent_middleware
async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None:
async def capture_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
# Capture kwargs passed to the sub-agent
captured_kwargs.update(context.kwargs)
await next(context)
await call_next(context)

# Setup mock response
chat_client.responses = [
Expand Down Expand Up @@ -60,9 +62,11 @@ async def test_as_tool_excludes_arg_name_from_forwarded_kwargs(self, chat_client
captured_kwargs: dict[str, Any] = {}

@agent_middleware
async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None:
async def capture_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
captured_kwargs.update(context.kwargs)
await next(context)
await call_next(context)

# Setup mock response
chat_client.responses = [
Expand Down Expand Up @@ -95,10 +99,12 @@ async def test_as_tool_nested_delegation_propagates_kwargs(self, chat_client: Mo
captured_kwargs_list: list[dict[str, Any]] = []

@agent_middleware
async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None:
async def capture_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
# Capture kwargs at each level
captured_kwargs_list.append(dict(context.kwargs))
await next(context)
await call_next(context)

# Setup mock responses to trigger nested tool invocation: B calls tool C, then completes.
chat_client.responses = [
Expand Down Expand Up @@ -156,9 +162,11 @@ async def test_as_tool_streaming_mode_forwards_kwargs(self, chat_client: MockCha
captured_kwargs: dict[str, Any] = {}

@agent_middleware
async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None:
async def capture_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
captured_kwargs.update(context.kwargs)
await next(context)
await call_next(context)

# Setup mock streaming responses
from agent_framework import ChatResponseUpdate
Expand Down Expand Up @@ -216,9 +224,11 @@ async def test_as_tool_kwargs_with_chat_options(self, chat_client: MockChatClien
captured_kwargs: dict[str, Any] = {}

@agent_middleware
async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None:
async def capture_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
captured_kwargs.update(context.kwargs)
await next(context)
await call_next(context)

# Setup mock response
chat_client.responses = [
Expand Down Expand Up @@ -256,14 +266,16 @@ async def test_as_tool_kwargs_isolated_per_invocation(self, chat_client: MockCha
call_count = 0

@agent_middleware
async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None:
async def capture_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
nonlocal call_count
call_count += 1
if call_count == 1:
first_call_kwargs.update(context.kwargs)
elif call_count == 2:
second_call_kwargs.update(context.kwargs)
await next(context)
await call_next(context)

# Setup mock responses for both calls
chat_client.responses = [
Expand Down Expand Up @@ -306,9 +318,11 @@ async def test_as_tool_excludes_conversation_id_from_forwarded_kwargs(self, chat
captured_kwargs: dict[str, Any] = {}

@agent_middleware
async def capture_middleware(context: AgentContext, next: Callable[[AgentContext], Awaitable[None]]) -> None:
async def capture_middleware(
context: AgentContext, call_next: Callable[[AgentContext], Awaitable[None]]
) -> None:
captured_kwargs.update(context.kwargs)
await next(context)
await call_next(context)

# Setup mock response
chat_client.responses = [
Expand Down
Loading
Loading