Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion python/packages/core/agent_framework/_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,14 +1024,19 @@ async def _prepare_run_context(
await self._async_exit_stack.enter_async_context(mcp_server)
final_tools.extend(mcp_server.functions)

# Merge runtime kwargs into additional_function_arguments so they're available
# in function middleware context and tool invocation.
existing_additional_args = opts.pop("additional_function_arguments", None) or {}
additional_function_arguments = {**kwargs, **existing_additional_args}

# Build options dict from run() options merged with provided options
run_opts: dict[str, Any] = {
"model_id": opts.pop("model_id", None),
"conversation_id": active_session.service_session_id
if active_session
else opts.pop("conversation_id", None),
"allow_multiple_tool_calls": opts.pop("allow_multiple_tool_calls", None),
"additional_function_arguments": opts.pop("additional_function_arguments", None),
"additional_function_arguments": additional_function_arguments or None,
"frequency_penalty": opts.pop("frequency_penalty", None),
"logit_bias": opts.pop("logit_bias", None),
"max_tokens": opts.pop("max_tokens", None),
Expand Down
7 changes: 1 addition & 6 deletions python/packages/core/agent_framework/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2871,12 +2871,7 @@ class _ChatOptionsBase(TypedDict, total=False):
presence_penalty: float

# Tool configuration (forward reference to avoid circular import)
tools: (
ToolTypes
| Callable[..., Any]
| Sequence[ToolTypes | Callable[..., Any]]
| None
)
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None
tool_choice: ToolMode | Literal["auto", "required", "none"]
allow_multiple_tool_calls: bool

Expand Down
173 changes: 173 additions & 0 deletions python/packages/core/tests/core/test_middleware_with_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,179 @@ async def kwargs_middleware(
assert modified_kwargs["new_param"] == "added_by_middleware"
assert modified_kwargs["custom_param"] == "test_value"

async def test_run_kwargs_available_in_function_middleware(self, chat_client_base: "MockBaseChatClient") -> None:
"""Test that kwargs passed directly to agent.run() appear in FunctionInvocationContext.kwargs,
including complex nested values like dicts."""
captured_kwargs: dict[str, Any] = {}

@function_middleware
async def capture_middleware(
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
captured_kwargs.update(context.kwargs)
await call_next()

chat_client_base.run_responses = [
ChatResponse(
messages=[
Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_1", name="sample_tool_function", arguments='{"location": "Seattle"}'
)
],
)
]
),
ChatResponse(messages=[Message(role="assistant", text="Done!")]),
]

agent = Agent(client=chat_client_base, middleware=[capture_middleware], tools=[sample_tool_function])

session_metadata = {"tenant": "acme-corp", "region": "us-west"}
await agent.run(
[Message(role="user", text="Get weather")],
user_id="user-456",
session_metadata=session_metadata,
)

assert "user_id" in captured_kwargs, f"Expected 'user_id' in kwargs: {captured_kwargs}"
assert captured_kwargs["user_id"] == "user-456"
assert captured_kwargs["session_metadata"] == {"tenant": "acme-corp", "region": "us-west"}

async def test_run_kwargs_merged_with_additional_function_arguments(
self, chat_client_base: "MockBaseChatClient"
) -> None:
"""Test that explicit additional_function_arguments in options take precedence over run kwargs."""
captured_kwargs: dict[str, Any] = {}

@function_middleware
async def capture_middleware(
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
captured_kwargs.update(context.kwargs)
await call_next()

chat_client_base.run_responses = [
ChatResponse(
messages=[
Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_1", name="sample_tool_function", arguments='{"location": "Seattle"}'
)
],
)
]
),
ChatResponse(messages=[Message(role="assistant", text="Done!")]),
]

agent = Agent(client=chat_client_base, middleware=[capture_middleware], tools=[sample_tool_function])

await agent.run(
[Message(role="user", text="Get weather")],
# This kwarg should be overridden by additional_function_arguments
user_id="from-kwargs",
tenant_id="from-kwargs",
options={
"additional_function_arguments": {
"user_id": "from-options",
"extra_key": "only-in-options",
}
},
)

# additional_function_arguments takes precedence for overlapping keys
assert captured_kwargs["user_id"] == "from-options"
# Non-overlapping kwargs from run() still come through
assert captured_kwargs["tenant_id"] == "from-kwargs"
# Keys only in additional_function_arguments are present
assert captured_kwargs["extra_key"] == "only-in-options"

async def test_run_kwargs_consistent_across_multiple_tool_calls(
self, chat_client_base: "MockBaseChatClient"
) -> None:
"""Test that kwargs are consistent across multiple tool invocations in a single run."""
invocation_kwargs: list[dict[str, Any]] = []

@function_middleware
async def capture_middleware(
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
invocation_kwargs.append(dict(context.kwargs))
await call_next()

chat_client_base.run_responses = [
ChatResponse(
messages=[
Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_1", name="sample_tool_function", arguments='{"location": "Seattle"}'
),
Content.from_function_call(
call_id="call_2", name="sample_tool_function", arguments='{"location": "Portland"}'
),
],
)
]
),
ChatResponse(messages=[Message(role="assistant", text="Done!")]),
]

agent = Agent(client=chat_client_base, middleware=[capture_middleware], tools=[sample_tool_function])

await agent.run(
[Message(role="user", text="Get weather for both cities")],
user_id="user-456",
request_id="req-001",
)

assert len(invocation_kwargs) == 2
for kw in invocation_kwargs:
assert kw["user_id"] == "user-456"
assert kw["request_id"] == "req-001"

async def test_run_without_kwargs_produces_empty_context_kwargs(
self, chat_client_base: "MockBaseChatClient"
) -> None:
"""Test that when no kwargs are passed to run(), FunctionInvocationContext.kwargs is empty."""
captured_kwargs: dict[str, Any] = {}

@function_middleware
async def capture_middleware(
context: FunctionInvocationContext, call_next: Callable[[], Awaitable[None]]
) -> None:
captured_kwargs.update(context.kwargs)
await call_next()

chat_client_base.run_responses = [
ChatResponse(
messages=[
Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_1", name="sample_tool_function", arguments='{"location": "Seattle"}'
)
],
)
]
),
ChatResponse(messages=[Message(role="assistant", text="Done!")]),
]

agent = Agent(client=chat_client_base, middleware=[capture_middleware], tools=[sample_tool_function])

await agent.run([Message(role="user", text="Get weather")])

# No runtime kwargs should be present
assert "user_id" not in captured_kwargs


class TestMiddlewareDynamicRebuild:
"""Test cases for dynamic middleware pipeline rebuilding with Agent."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ async def exception_handling_middleware(
print(f"[ExceptionHandlingMiddleware] Caught TimeoutError: {e}")
# Override function result to provide custom message in response.
context.result = (
"Request Timeout: The data service is taking longer than expected to respond.",
"Respond with message - 'Sorry for the inconvenience, please try again later.'",
"Request Timeout: The data service is taking longer than expected to respond."
"Respond with message - 'Sorry for the inconvenience, please try again later.'"
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import asyncio
import re
from collections.abc import Awaitable, Callable
from collections.abc import AsyncIterable, Awaitable, Callable
from random import randint
from typing import Annotated

Expand All @@ -13,9 +13,9 @@
ChatContext,
ChatResponse,
ChatResponseUpdate,
Content,
Message,
ResponseStream,
Role,
tool,
)
from agent_framework.openai import OpenAIResponsesClient
Expand Down Expand Up @@ -66,22 +66,20 @@ async def weather_override_middleware(context: ChatContext, call_next: Callable[
]

if context.stream and isinstance(context.result, ResponseStream):
index = {"value": 0}

def _update_hook(update: ChatResponseUpdate) -> ChatResponseUpdate:
for content in update.contents or []:
if not content.text:
continue
content.text = f"Weather Advisory: [{index['value']}] {content.text}"
index["value"] += 1
return update
async def _override_stream() -> AsyncIterable[ChatResponseUpdate]:
for i, chunk_text in enumerate(chunks):
yield ChatResponseUpdate(
contents=[Content.from_text(text=f"Weather Advisory: [{i}] {chunk_text}")],
role="assistant",
)

context.result.with_transform_hook(_update_hook)
context.result = ResponseStream(_override_stream())
else:
# For non-streaming: just replace with a new message
current_text = context.result.text if isinstance(context.result, ChatResponse) else ""
custom_message = f"Weather Advisory: [0] {''.join(chunks)} Original message was: {current_text}"
context.result = ChatResponse(messages=[Message(role=Role.ASSISTANT, text=custom_message)])
context.result = ChatResponse(messages=[Message(role="assistant", text=custom_message)])


async def validate_weather_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
Expand All @@ -96,12 +94,12 @@ async def validate_weather_middleware(context: ChatContext, call_next: Callable[
if context.stream and isinstance(context.result, ResponseStream):

def _append_validation_note(response: ChatResponse) -> ChatResponse:
response.messages.append(Message(role=Role.ASSISTANT, text=validation_note))
response.messages.append(Message(role="assistant", text=validation_note))
return response

context.result.with_finalizer(_append_validation_note)
elif isinstance(context.result, ChatResponse):
context.result.messages.append(Message(role=Role.ASSISTANT, text=validation_note))
context.result.messages.append(Message(role="assistant", text=validation_note))


async def agent_cleanup_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
Expand Down Expand Up @@ -154,7 +152,7 @@ def _sanitize(response: AgentResponse) -> AgentResponse:
if not found_validation:
raise RuntimeError("Expected validation note not found in agent response.")

cleaned_messages.append(Message(role=Role.ASSISTANT, text=" Agent: OK"))
cleaned_messages.append(Message(role="assistant", text=" Agent: OK"))
response.messages = cleaned_messages
return response

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

def create_sample_image() -> str:
"""Create a simple 1x1 pixel PNG image for testing."""
# This is a tiny red pixel in PNG format
# This is a tiny yellow pixel in PNG format
png_data = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="
return f"data:image/png;base64,{png_data}"

Expand All @@ -32,7 +32,7 @@ async def test_image() -> None:
],
)

response = await client.get_response(message)
response = await client.get_response([message])
print(f"Image Response: {response}")


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def load_sample_pdf() -> bytes:

def create_sample_image() -> str:
"""Create a simple 1x1 pixel PNG image for testing."""
# This is a tiny red pixel in PNG format
# This is a tiny yellow pixel in PNG format
png_data = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="
return f"data:image/png;base64,{png_data}"

Expand All @@ -41,7 +41,7 @@ async def test_image() -> None:
],
)

response = await client.get_response(message)
response = await client.get_response([message])
print(f"Image Response: {response}")


Expand All @@ -62,7 +62,7 @@ async def test_pdf() -> None:
],
)

response = await client.get_response(message)
response = await client.get_response([message])
print(f"PDF Response: {response}")


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def load_sample_pdf() -> bytes:

def create_sample_image() -> str:
"""Create a simple 1x1 pixel PNG image for testing."""
# This is a tiny red pixel in PNG format
# This is a tiny yellow pixel in PNG format
png_data = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8/5+hHgAHggJ/PchI7wAAAABJRU5ErkJggg=="
return f"data:image/png;base64,{png_data}"

Expand Down Expand Up @@ -53,7 +53,7 @@ async def test_image() -> None:
],
)

response = await client.get_response(message)
response = await client.get_response([message])
print(f"Image Response: {response}")


Expand All @@ -70,7 +70,7 @@ async def test_audio() -> None:
],
)

response = await client.get_response(message)
response = await client.get_response([message])
print(f"Audio Response: {response}")


Expand All @@ -89,7 +89,7 @@ async def test_pdf() -> None:
],
)

response = await client.get_response(message)
response = await client.get_response([message])
print(f"PDF Response: {response}")


Expand Down