From e0a3f1f3907604d03ba0e7d8769e66414ff2cfdb Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Tue, 18 Nov 2025 16:12:13 +0900 Subject: [PATCH 1/2] as tool kwargs --- .../packages/core/agent_framework/_agents.py | 8 +- .../packages/core/agent_framework/_tools.py | 26 +- .../core/test_as_tool_kwargs_propagation.py | 318 ++++++++++++ python/packages/core/tests/core/test_tools.py | 19 + python/samples/README.md | 1 + .../middleware/runtime_context_delegation.py | 456 ++++++++++++++++++ 6 files changed, 823 insertions(+), 5 deletions(-) create mode 100644 python/packages/core/tests/core/test_as_tool_kwargs_propagation.py create mode 100644 python/samples/getting_started/middleware/runtime_context_delegation.py diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index e3ea1bdea6..a158ce8973 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -454,13 +454,16 @@ async def agent_wrapper(**kwargs: Any) -> str: # Extract the input from kwargs using the specified arg_name input_text = kwargs.get(arg_name, "") + # Forward all kwargs except the arg_name to support runtime context propagation + forwarded_kwargs = {k: v for k, v in kwargs.items() if k != arg_name} + if stream_callback is None: # Use non-streaming mode - return (await self.run(input_text)).text + return (await self.run(input_text, **forwarded_kwargs)).text # Use streaming mode - accumulate updates and create final response response_updates: list[AgentRunResponseUpdate] = [] - async for update in self.run_stream(input_text): + async for update in self.run_stream(input_text, **forwarded_kwargs): response_updates.append(update) if is_async_callback: await stream_callback(update) # type: ignore[misc] @@ -475,6 +478,7 @@ async def agent_wrapper(**kwargs: Any) -> str: description=tool_description, func=agent_wrapper, input_model=input_model, # type: ignore + forward_additional_kwargs=True, ) def _normalize_messages( diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 6edd258e15..b1188e051c 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -587,6 +587,7 @@ def __init__( additional_properties: dict[str, Any] | None = None, func: Callable[..., Awaitable[ReturnT] | ReturnT] | None = None, input_model: type[ArgsT] | Mapping[str, Any] | None = None, + forward_additional_kwargs: bool = False, **kwargs: Any, ) -> None: """Initialize the AIFunction. @@ -605,6 +606,8 @@ def __init__( input_model: The Pydantic model that defines the input parameters for the function. This can also be a JSON schema dictionary. If not provided, it will be inferred from the function signature. + forward_additional_kwargs: Whether to forward unknown kwargs directly to the wrapped function. + Used by agent tools created via ``as_tool`` so they can receive runtime context kwargs. **kwargs: Additional keyword arguments. """ super().__init__( @@ -626,6 +629,7 @@ def __init__( self.invocation_exception_count = 0 self._invocation_duration_histogram = _default_histogram() self.type: Literal["ai_function"] = "ai_function" + self.forward_additional_kwargs = forward_additional_kwargs @property def declaration_only(self) -> bool: @@ -691,10 +695,17 @@ async def invoke( from .observability import OBSERVABILITY_SETTINGS tool_call_id = kwargs.pop("tool_call_id", None) + provided_kwargs = dict(kwargs) if arguments is not None: if not isinstance(arguments, self.input_model): raise TypeError(f"Expected {self.input_model.__name__}, got {type(arguments).__name__}") - kwargs = arguments.model_dump(exclude_none=True) + argument_values = arguments.model_dump(exclude_none=True) + if self.forward_additional_kwargs: + kwargs = {**provided_kwargs, **argument_values} + else: + kwargs = argument_values + else: + kwargs = provided_kwargs if not OBSERVABILITY_SETTINGS.ENABLED: # type: ignore[name-defined] logger.info(f"Function name: {self.name}") logger.debug(f"Function arguments: {kwargs}") @@ -1228,8 +1239,15 @@ async def _auto_invoke_function( parsed_args: dict[str, Any] = dict(function_call_content.parse_arguments() or {}) + # Filter out internal framework kwargs before merging/passing to tools. + runtime_kwargs: dict[str, Any] = { + key: value + for key, value in (custom_args or {}).items() + if key not in {"_function_middleware_pipeline", "middleware", "chat_options", "tools"} + } + # Merge with user-supplied args; right-hand side dominates, so parsed args win on conflicts. - merged_args: dict[str, Any] = (custom_args or {}) | parsed_args + merged_args: dict[str, Any] = runtime_kwargs | parsed_args try: args = tool.input_model.model_validate(merged_args) except ValidationError as exc: @@ -1245,6 +1263,7 @@ async def _auto_invoke_function( function_result = await tool.invoke( arguments=args, tool_call_id=function_call_content.call_id, + **runtime_kwargs, ) # type: ignore[arg-type] return FunctionResultContent( call_id=function_call_content.call_id, @@ -1261,13 +1280,14 @@ async def _auto_invoke_function( middleware_context = FunctionInvocationContext( function=tool, arguments=args, - kwargs=custom_args or {}, + kwargs=runtime_kwargs.copy(), ) async def final_function_handler(context_obj: Any) -> Any: return await tool.invoke( arguments=context_obj.arguments, tool_call_id=function_call_content.call_id, + **context_obj.kwargs, ) try: diff --git a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py new file mode 100644 index 0000000000..bb07e04564 --- /dev/null +++ b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py @@ -0,0 +1,318 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tests for kwargs propagation through as_tool() method.""" + +from collections.abc import Awaitable, Callable +from typing import Any + +from agent_framework import ChatAgent, ChatMessage, ChatResponse, FunctionCallContent, agent_middleware +from agent_framework._middleware import AgentRunContext + +from .conftest import MockChatClient + + +class TestAsToolKwargsPropagation: + """Test cases for kwargs propagation through as_tool() delegation.""" + + async def test_as_tool_forwards_runtime_kwargs(self, chat_client: MockChatClient) -> None: + """Test that runtime kwargs are forwarded through as_tool() to sub-agent.""" + captured_kwargs: dict[str, Any] = {} + + @agent_middleware + async def capture_middleware( + context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + # Capture kwargs passed to the sub-agent + print(f"Middleware captured kwargs: {context.kwargs}") + captured_kwargs.update(context.kwargs) + await next(context) + + # Setup mock response + chat_client.responses = [ + ChatResponse(messages=[ChatMessage(role="assistant", text="Response from sub-agent")]), + ] + + # Create sub-agent with middleware + sub_agent = ChatAgent( + chat_client=chat_client, + name="sub_agent", + middleware=[capture_middleware], + ) + + # Create tool from sub-agent + tool = sub_agent.as_tool(name="delegate", arg_name="task") + + # Directly invoke the tool with kwargs (simulating what happens during agent execution) + result = await tool.invoke( + arguments=tool.input_model(task="Test delegation"), + api_token="secret-xyz-123", + user_id="user-456", + session_id="session-789", + ) + print(f"Tool result: {result}") + print(f"Captured kwargs after invoke: {captured_kwargs}") + + # Verify kwargs were forwarded to sub-agent + assert "api_token" in captured_kwargs, f"Expected 'api_token' in {captured_kwargs}" + assert captured_kwargs["api_token"] == "secret-xyz-123" + assert "user_id" in captured_kwargs + assert captured_kwargs["user_id"] == "user-456" + assert "session_id" in captured_kwargs + assert captured_kwargs["session_id"] == "session-789" + + async def test_as_tool_excludes_arg_name_from_forwarded_kwargs(self, chat_client: MockChatClient) -> None: + """Test that the arg_name parameter is not forwarded as a kwarg.""" + captured_kwargs: dict[str, Any] = {} + + @agent_middleware + async def capture_middleware( + context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + captured_kwargs.update(context.kwargs) + await next(context) + + # Setup mock response + chat_client.responses = [ + ChatResponse(messages=[ChatMessage(role="assistant", text="Response from sub-agent")]), + ] + + sub_agent = ChatAgent( + chat_client=chat_client, + name="sub_agent", + middleware=[capture_middleware], + ) + + tool = sub_agent.as_tool(arg_name="custom_task") + + # Invoke tool with both the arg_name field and additional kwargs + await tool.invoke( + arguments=tool.input_model(custom_task="Test task"), + api_token="token-123", + custom_task="should_be_excluded", # This should be filtered out + ) + + # The arg_name ("custom_task") should NOT be in the forwarded kwargs + assert "custom_task" not in captured_kwargs + # But other kwargs should be present + assert "api_token" in captured_kwargs + assert captured_kwargs["api_token"] == "token-123" + + async def test_as_tool_nested_delegation_propagates_kwargs(self, chat_client: MockChatClient) -> None: + """Test that kwargs propagate through multiple levels of delegation (A → B → C).""" + captured_kwargs_list: list[dict[str, Any]] = [] + + @agent_middleware + async def capture_middleware( + context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + # Capture kwargs at each level + captured_kwargs_list.append(dict(context.kwargs)) + await next(context) + + # Setup mock responses to trigger nested tool invocation: B calls tool C, then completes. + chat_client.responses = [ + ChatResponse( + messages=[ + ChatMessage( + role="assistant", + contents=[ + FunctionCallContent( + call_id="call_c_1", + name="call_c", + arguments='{"task": "Please execute agent_c"}', + ) + ], + ) + ] + ), + ChatResponse(messages=[ChatMessage(role="assistant", text="Response from agent_c")]), + ChatResponse(messages=[ChatMessage(role="assistant", text="Response from agent_b")]), + ] + + # Create agent C (bottom level) + agent_c = ChatAgent( + chat_client=chat_client, + name="agent_c", + middleware=[capture_middleware], + ) + + # Create agent B (middle level) - delegates to C + agent_b = ChatAgent( + chat_client=chat_client, + name="agent_b", + tools=[agent_c.as_tool(name="call_c")], + middleware=[capture_middleware], + ) + + # Create tool from B for direct invocation + tool_b = agent_b.as_tool(name="call_b") + + # Invoke tool B with kwargs - should propagate to both B and C + await tool_b.invoke( + arguments=tool_b.input_model(task="Test cascade"), + trace_id="trace-abc-123", + tenant_id="tenant-xyz", + ) + + # Verify both levels received the kwargs + # We should have 2 captures: one from B, one from C + assert len(captured_kwargs_list) >= 2 + for kwargs_dict in captured_kwargs_list: + assert kwargs_dict.get("trace_id") == "trace-abc-123" + assert kwargs_dict.get("tenant_id") == "tenant-xyz" + + async def test_as_tool_streaming_mode_forwards_kwargs(self, chat_client: MockChatClient) -> None: + """Test that kwargs are forwarded in streaming mode.""" + captured_kwargs: dict[str, Any] = {} + + @agent_middleware + async def capture_middleware( + context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + captured_kwargs.update(context.kwargs) + await next(context) + + # Setup mock streaming responses + from agent_framework import ChatResponseUpdate, TextContent + + chat_client.streaming_responses = [ + [ChatResponseUpdate(text=TextContent(text="Streaming response"), role="assistant")], + ] + + sub_agent = ChatAgent( + chat_client=chat_client, + name="sub_agent", + middleware=[capture_middleware], + ) + + captured_updates: list[Any] = [] + + async def stream_callback(update: Any) -> None: + captured_updates.append(update) + + tool = sub_agent.as_tool(stream_callback=stream_callback) + + # Invoke tool with kwargs while streaming callback is active + await tool.invoke( + arguments=tool.input_model(task="Test streaming"), + api_key="streaming-key-999", + ) + + # Verify kwargs were forwarded even in streaming mode + assert "api_key" in captured_kwargs + assert captured_kwargs["api_key"] == "streaming-key-999" + assert len(captured_updates) == 1 + + async def test_as_tool_empty_kwargs_still_works(self, chat_client: MockChatClient) -> None: + """Test that as_tool works correctly when no extra kwargs are provided.""" + # Setup mock response + chat_client.responses = [ + ChatResponse(messages=[ChatMessage(role="assistant", text="Response from agent")]), + ] + + sub_agent = ChatAgent( + chat_client=chat_client, + name="sub_agent", + ) + + tool = sub_agent.as_tool() + + # Invoke without any extra kwargs - should work without errors + result = await tool.invoke(arguments=tool.input_model(task="Simple task")) + + # Verify tool executed successfully + assert result is not None + + async def test_as_tool_kwargs_with_chat_options(self, chat_client: MockChatClient) -> None: + """Test that kwargs including chat_options are properly forwarded.""" + captured_kwargs: dict[str, Any] = {} + + @agent_middleware + async def capture_middleware( + context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + captured_kwargs.update(context.kwargs) + await next(context) + + # Setup mock response + chat_client.responses = [ + ChatResponse(messages=[ChatMessage(role="assistant", text="Response with options")]), + ] + + sub_agent = ChatAgent( + chat_client=chat_client, + name="sub_agent", + middleware=[capture_middleware], + ) + + tool = sub_agent.as_tool() + + # Invoke with various kwargs + await tool.invoke( + arguments=tool.input_model(task="Test with options"), + temperature=0.8, + max_tokens=500, + custom_param="custom_value", + ) + + # Verify all kwargs were forwarded + assert "temperature" in captured_kwargs + assert captured_kwargs["temperature"] == 0.8 + assert "max_tokens" in captured_kwargs + assert captured_kwargs["max_tokens"] == 500 + assert "custom_param" in captured_kwargs + assert captured_kwargs["custom_param"] == "custom_value" + + async def test_as_tool_kwargs_isolated_per_invocation(self, chat_client: MockChatClient) -> None: + """Test that kwargs are isolated per invocation and don't leak between calls.""" + first_call_kwargs: dict[str, Any] = {} + second_call_kwargs: dict[str, Any] = {} + call_count = 0 + + @agent_middleware + async def capture_middleware( + context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] + ) -> None: + nonlocal call_count + call_count += 1 + if call_count == 1: + first_call_kwargs.update(context.kwargs) + elif call_count == 2: + second_call_kwargs.update(context.kwargs) + await next(context) + + # Setup mock responses for both calls + chat_client.responses = [ + ChatResponse(messages=[ChatMessage(role="assistant", text="First response")]), + ChatResponse(messages=[ChatMessage(role="assistant", text="Second response")]), + ] + + sub_agent = ChatAgent( + chat_client=chat_client, + name="sub_agent", + middleware=[capture_middleware], + ) + + tool = sub_agent.as_tool() + + # First call with specific kwargs + await tool.invoke( + arguments=tool.input_model(task="First task"), + session_id="session-1", + api_token="token-1", + ) + + # Second call with different kwargs + await tool.invoke( + arguments=tool.input_model(task="Second task"), + session_id="session-2", + api_token="token-2", + ) + + # Verify first call had its own kwargs + assert first_call_kwargs.get("session_id") == "session-1" + assert first_call_kwargs.get("api_token") == "token-1" + + # Verify second call had its own kwargs (not leaked from first) + assert second_call_kwargs.get("session_id") == "session-2" + assert second_call_kwargs.get("api_token") == "token-2" diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index acd9157363..f3477c8163 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -191,6 +191,25 @@ def telemetry_test_tool(x: int, y: int) -> int: assert attributes[OtelAttr.TOOL_CALL_ID] == "test_call_id" +async def test_ai_function_invoke_ignores_additional_kwargs() -> None: + """Ensure ai_function tools drop unknown kwargs when invoked with validated arguments.""" + + @ai_function + async def simple_tool(message: str) -> str: + """Echo tool.""" + return message.upper() + + args = simple_tool.input_model(message="hello world") + + # These kwargs simulate runtime context passed through function invocation. + result = await simple_tool.invoke( + arguments=args, + api_token="secret-token", + chat_options={"model_id": "dummy"}, + ) + + assert result == "HELLO WORLD" + async def test_ai_function_invoke_telemetry_with_pydantic_args(span_exporter: InMemorySpanExporter): """Test the ai_function invoke method with Pydantic model arguments.""" diff --git a/python/samples/README.md b/python/samples/README.md index 937c303c77..3c4d9172c7 100644 --- a/python/samples/README.md +++ b/python/samples/README.md @@ -205,6 +205,7 @@ This directory contains samples demonstrating the capabilities of Microsoft Agen | [`getting_started/middleware/function_based_middleware.py`](./getting_started/middleware/function_based_middleware.py) | Function-based middleware example | | [`getting_started/middleware/middleware_termination.py`](./getting_started/middleware/middleware_termination.py) | Middleware termination example | | [`getting_started/middleware/override_result_with_middleware.py`](./getting_started/middleware/override_result_with_middleware.py) | Override result with middleware example | +| [`getting_started/middleware/runtime_context_delegation.py`](./getting_started/middleware/runtime_context_delegation.py) | Runtime context delegation example demonstrating how to pass API tokens, session data, and other context through hierarchical agent delegation | | [`getting_started/middleware/shared_state_middleware.py`](./getting_started/middleware/shared_state_middleware.py) | Shared state middleware example | | [`getting_started/middleware/thread_behavior_middleware.py`](./getting_started/middleware/thread_behavior_middleware.py) | Thread behavior middleware example demonstrating how to track conversation state across multiple agent runs | diff --git a/python/samples/getting_started/middleware/runtime_context_delegation.py b/python/samples/getting_started/middleware/runtime_context_delegation.py new file mode 100644 index 0000000000..94719396c6 --- /dev/null +++ b/python/samples/getting_started/middleware/runtime_context_delegation.py @@ -0,0 +1,456 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +from collections.abc import Awaitable, Callable +from typing import Annotated + +from agent_framework import FunctionInvocationContext, ai_function, function_middleware +from agent_framework.openai import OpenAIChatClient +from pydantic import Field + +""" +Runtime Context Delegation Patterns + +This sample demonstrates different patterns for passing runtime context (API tokens, +session data, etc.) to tools and sub-agents. + +Patterns Demonstrated: + +1. **Pattern 1: Single Agent with Middleware & Closure** (Lines 130-180) + - Best for: Single agent with multiple tools + - How: Middleware stores kwargs in container, tools access via closure + - Pros: Simple, explicit state management + - Cons: Requires container instance per agent + +2. **Pattern 2: Hierarchical Agents with kwargs Propagation** (Lines 190-240) + - Best for: Parent-child agent delegation with as_tool() + - How: kwargs automatically propagate through as_tool() wrapper + - Pros: Automatic, works with nested delegation, clean separation + - Cons: None - this is the recommended pattern for hierarchical agents + +3. **Pattern 3: Mixed - Hierarchical with Middleware** (Lines 250-300) + - Best for: Complex scenarios needing both delegation and state management + - How: Combines automatic kwargs propagation with middleware processing + - Pros: Maximum flexibility, can transform/validate context at each level + - Cons: More complex setup + +Key Concepts: +- Runtime Context: Session-specific data like API tokens, user IDs, tenant info +- Middleware: Intercepts function calls to access/modify kwargs +- Closure: Functions capturing variables from outer scope +- kwargs Propagation: Automatic forwarding of runtime context through delegation chains +""" + + +class SessionContextContainer: + """Container for runtime session context accessible via closure.""" + + def __init__(self) -> None: + """Initialize with None values for runtime context.""" + self.api_token: str | None = None + self.user_id: str | None = None + self.session_metadata: dict[str, str] = {} + + async def inject_context_middleware( + self, + context: FunctionInvocationContext, + next: Callable[[FunctionInvocationContext], Awaitable[None]], + ) -> None: + """Middleware that extracts runtime context from kwargs and stores in container. + + This middleware runs before tool execution and makes runtime context + available to tools via the container instance. + """ + # Extract runtime context from kwargs + self.api_token = context.kwargs.get("api_token") + self.user_id = context.kwargs.get("user_id") + self.session_metadata = context.kwargs.get("session_metadata", {}) + + # Log what we captured (for demonstration) + if self.api_token or self.user_id: + print("[Middleware] Captured runtime context:") + print(f" - API Token: {'***' + self.api_token[-4:] if self.api_token else 'None'}") + print(f" - User ID: {self.user_id}") + print(f" - Session Metadata: {self.session_metadata}") + + # Continue to tool execution + await next(context) + + +# Create a container instance that will be shared via closure +runtime_context = SessionContextContainer() + + +@ai_function +async def send_email( + to: Annotated[str, Field(description="Recipient email address")], + subject: Annotated[str, Field(description="Email subject line")], + body: Annotated[str, Field(description="Email body content")], +) -> str: + """Send an email using authenticated API (simulated). + + This function accesses runtime context (API token, user ID) via closure + from the runtime_context container. + """ + # Access runtime context via closure + token = runtime_context.api_token + user_id = runtime_context.user_id + tenant = runtime_context.session_metadata.get("tenant", "unknown") + + print("\n[send_email] Executing with runtime context:") + print(f" - Token: {'***' + token[-4:] if token else 'NOT PROVIDED'}") + print(f" - User ID: {user_id or 'NOT PROVIDED'}") + print(f" - Tenant: {tenant}") + print(f" - To: {to}") + print(f" - Subject: {subject}") + + # Simulate API call with authentication + if not token: + return "ERROR: No API token provided - cannot send email" + + # Simulate sending email + return f"Email sent to {to} from user {user_id} (tenant: {tenant}). Subject: '{subject}'" + + +@ai_function +async def send_notification( + message: Annotated[str, Field(description="Notification message to send")], + priority: Annotated[str, Field(description="Priority level: low, medium, high")] = "medium", +) -> str: + """Send a push notification using authenticated API (simulated). + + This function accesses runtime context via closure from runtime_context. + """ + token = runtime_context.api_token + user_id = runtime_context.user_id + + print("\n[send_notification] Executing with runtime context:") + print(f" - Token: {'***' + token[-4:] if token else 'NOT PROVIDED'}") + print(f" - User ID: {user_id or 'NOT PROVIDED'}") + print(f" - Message: {message}") + print(f" - Priority: {priority}") + + if not token: + return "ERROR: No API token provided - cannot send notification" + + return f"Notification sent to user {user_id} with priority {priority}: {message}" + + +async def pattern_1_single_agent_with_closure() -> None: + """Pattern 1: Single agent with middleware and closure for runtime context.""" + print("\n" + "=" * 70) + print("PATTERN 1: Single Agent with Middleware & Closure") + print("=" * 70) + print("Use case: Single agent with multiple tools sharing runtime context") + print() + + client = OpenAIChatClient(model_id="gpt-4o-mini") + + # Create agent with both tools and shared context via middleware + communication_agent = client.create_agent( + name="communication_agent", + instructions=( + "You are a communication assistant that can send emails and notifications. " + "Use send_email for email tasks and send_notification for notification tasks." + ), + tools=[send_email, send_notification], + # Both tools share the same context container via middleware + middleware=[runtime_context.inject_context_middleware], + ) + + # Test 1: Send email with runtime context + print("\n" + "=" * 70) + print("TEST 1: Email with Runtime Context") + print("=" * 70) + + user_query = ( + "Send an email to john@example.com with subject 'Meeting Tomorrow' and body 'Don't forget our 2pm meeting.'" + ) + print(f"\nUser: {user_query}") + + result1 = await communication_agent.run( + user_query, + # Runtime context passed as kwargs + api_token="sk-test-token-xyz-789", + user_id="user-12345", + session_metadata={"tenant": "acme-corp", "region": "us-west"}, + ) + + print(f"\nAgent: {result1.text}") + + # Test 2: Send notification with different runtime context + print("\n" + "=" * 70) + print("TEST 2: Notification with Different Runtime Context") + print("=" * 70) + + user_query2 = "Send a high priority notification saying 'Your order has shipped!'" + print(f"\nUser: {user_query2}") + + result2 = await communication_agent.run( + user_query2, + # Different runtime context for this request + api_token="sk-prod-token-abc-456", + user_id="user-67890", + session_metadata={"tenant": "store-inc", "region": "eu-central"}, + ) + + print(f"\nAgent: {result2.text}") + + # Test 3: Both email and notification in one request + print("\n" + "=" * 70) + print("TEST 3: Multiple Tools in One Request") + print("=" * 70) + + user_query3 = ( + "Send an email to alice@example.com about the new feature launch " + "and also send a notification to remind about the team meeting." + ) + print(f"\nUser: {user_query3}") + + result3 = await communication_agent.run( + user_query3, + api_token="sk-dev-token-def-123", + user_id="user-11111", + session_metadata={"tenant": "dev-team", "region": "us-east"}, + ) + + print(f"\nAgent: {result3.text}") + + # Test 4: Missing context - show error handling + print("\n" + "=" * 70) + print("TEST 4: Missing Runtime Context (Error Case)") + print("=" * 70) + + user_query4 = "Send an email to test@example.com with subject 'Test'" + print(f"\nUser: {user_query4}") + print("Note: Running WITHOUT api_token to demonstrate error handling") + + result4 = await communication_agent.run( + user_query4, + # Missing api_token - tools should handle gracefully + user_id="user-22222", + ) + + print(f"\nAgent: {result4.text}") + + print("\n✓ Pattern 1 complete - Middleware & closure pattern works for single agents") + + +# Pattern 2: Hierarchical agents with automatic kwargs propagation +# ================================================================ + + +# Create tools for sub-agents (these will use kwargs propagation) +@ai_function +async def send_email_v2( + to: Annotated[str, Field(description="Recipient email")], + subject: Annotated[str, Field(description="Subject")], + body: Annotated[str, Field(description="Body")], +) -> str: + """Send email - demonstrates kwargs propagation pattern.""" + # In this pattern, we can create a middleware to access kwargs + # But for simplicity, we'll just simulate the operation + return f"Email sent to {to} with subject '{subject}'" + + +@ai_function +async def send_sms( + phone: Annotated[str, Field(description="Phone number")], + message: Annotated[str, Field(description="SMS message")], +) -> str: + """Send SMS message.""" + return f"SMS sent to {phone}: {message}" + + +async def pattern_2_hierarchical_with_kwargs_propagation() -> None: + """Pattern 2: Hierarchical agents with automatic kwargs propagation through as_tool().""" + print("\n" + "=" * 70) + print("PATTERN 2: Hierarchical Agents with kwargs Propagation") + print("=" * 70) + print("Use case: Parent agent delegates to specialized sub-agents") + print("Feature: Runtime kwargs automatically propagate through as_tool()") + print() + + # Track kwargs at each level + email_agent_kwargs: dict[str, object] = {} + sms_agent_kwargs: dict[str, object] = {} + + @function_middleware + async def email_kwargs_tracker( + context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]] + ) -> None: + email_agent_kwargs.update(context.kwargs) + print(f"[EmailAgent] Received runtime context: {list(context.kwargs.keys())}") + await next(context) + + @function_middleware + async def sms_kwargs_tracker( + context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]] + ) -> None: + sms_agent_kwargs.update(context.kwargs) + print(f"[SMSAgent] Received runtime context: {list(context.kwargs.keys())}") + await next(context) + + client = OpenAIChatClient(model_id="gpt-4o-mini") + + # Create specialized sub-agents + email_agent = client.create_agent( + name="email_agent", + instructions="You send emails using the send_email_v2 tool.", + tools=[send_email_v2], + middleware=[email_kwargs_tracker], + ) + + sms_agent = client.create_agent( + name="sms_agent", + instructions="You send SMS messages using the send_sms tool.", + tools=[send_sms], + middleware=[sms_kwargs_tracker], + ) + + # Create coordinator that delegates to sub-agents + coordinator = client.create_agent( + name="coordinator", + instructions=( + "You coordinate communication tasks. " + "Use email_sender for emails and sms_sender for SMS. " + "Delegate to the appropriate specialized agent." + ), + tools=[ + email_agent.as_tool( + name="email_sender", + description="Send emails to recipients", + arg_name="task", + ), + sms_agent.as_tool( + name="sms_sender", + description="Send SMS messages", + arg_name="task", + ), + ], + ) + + # Test: Runtime context propagates automatically + print("Test: Send email with runtime context\n") + await coordinator.run( + "Send an email to john@example.com with subject 'Meeting' and body 'See you at 2pm'", + api_token="secret-token-abc", + user_id="user-999", + tenant_id="tenant-acme", + ) + + print(f"\n[Verification] EmailAgent received: {email_agent_kwargs}") + print(f" - api_token: {email_agent_kwargs.get('api_token')}") + print(f" - user_id: {email_agent_kwargs.get('user_id')}") + print(f" - tenant_id: {email_agent_kwargs.get('tenant_id')}") + + print("\n✓ Pattern 2 complete - kwargs automatically propagate through as_tool()") + + +# Pattern 3: Mixed pattern - hierarchical with middleware processing +# =================================================================== + + +class AuthContextMiddleware: + """Middleware that validates and transforms runtime context.""" + + def __init__(self) -> None: + self.validated_tokens: list[str] = [] + + async def validate_and_track( + self, context: FunctionInvocationContext, next: Callable[[FunctionInvocationContext], Awaitable[None]] + ) -> None: + """Validate API token and track usage.""" + api_token = context.kwargs.get("api_token") + + if api_token: + # Simulate token validation + if api_token.startswith("valid-"): + print(f"[AuthMiddleware] ✓ Token validated: ***{api_token[-4:]}") + self.validated_tokens.append(api_token) + else: + print(f"[AuthMiddleware] ✗ Invalid token: {api_token}") + # Could set context.terminate = True to block execution + else: + print("[AuthMiddleware] ⚠ No API token provided") + + await next(context) + + +@ai_function +async def protected_operation(operation: Annotated[str, Field(description="Operation to perform")]) -> str: + """Protected operation that requires authentication.""" + return f"Executed protected operation: {operation}" + + +async def pattern_3_hierarchical_with_middleware() -> None: + """Pattern 3: Hierarchical agents with middleware processing at each level.""" + print("\n" + "=" * 70) + print("PATTERN 3: Hierarchical with Middleware Processing") + print("=" * 70) + print("Use case: Multi-level validation/transformation of runtime context") + print() + + auth_middleware = AuthContextMiddleware() + + client = OpenAIChatClient(model_id="gpt-4o-mini") + + # Sub-agent with validation middleware + protected_agent = client.create_agent( + name="protected_agent", + instructions="You perform protected operations that require authentication.", + tools=[protected_operation], + middleware=[auth_middleware.validate_and_track], + ) + + # Coordinator delegates to protected agent + coordinator = client.create_agent( + name="coordinator", + instructions="You coordinate protected operations. Delegate to protected_executor.", + tools=[ + protected_agent.as_tool( + name="protected_executor", + description="Execute protected operations", + ) + ], + ) + + # Test with valid token + print("Test 1: Valid token\n") + await coordinator.run( + "Execute operation: backup_database", + api_token="valid-token-xyz-789", + user_id="admin-123", + ) + + # Test with invalid token + print("\nTest 2: Invalid token\n") + await coordinator.run( + "Execute operation: delete_records", + api_token="invalid-token-bad", + user_id="user-456", + ) + + print(f"\n[Validation Summary] Validated tokens: {len(auth_middleware.validated_tokens)}") + print("✓ Pattern 3 complete - Middleware can validate/transform context at each level") + + +async def main() -> None: + """Demonstrate all runtime context delegation patterns.""" + print("=" * 70) + print("Runtime Context Delegation Patterns Demo") + print("=" * 70) + print() + + # Run Pattern 1 + await pattern_1_single_agent_with_closure() + + # Run Pattern 2 + await pattern_2_hierarchical_with_kwargs_propagation() + + # Run Pattern 3 + await pattern_3_hierarchical_with_middleware() + + +if __name__ == "__main__": + asyncio.run(main()) From 90ef4261a304f2312847020569ac31f22ec4d8a7 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Wed, 19 Nov 2025 07:50:01 +0900 Subject: [PATCH 2/2] simplify --- .../packages/core/agent_framework/_agents.py | 13 ++++--- .../packages/core/agent_framework/_tools.py | 35 ++++++++----------- .../core/agent_framework/observability.py | 6 ++-- .../core/test_as_tool_kwargs_propagation.py | 5 +-- python/packages/core/tests/core/test_tools.py | 1 + python/uv.lock | 2 +- 6 files changed, 30 insertions(+), 32 deletions(-) diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index a158ce8973..3c8ae17b91 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -473,13 +473,14 @@ async def agent_wrapper(**kwargs: Any) -> str: # Create final text from accumulated updates return AgentRunResponse.from_agent_run_response_updates(response_updates).text - return AIFunction( + agent_tool: AIFunction[BaseModel, str] = AIFunction( name=tool_name, description=tool_description, func=agent_wrapper, input_model=input_model, # type: ignore - forward_additional_kwargs=True, ) + agent_tool._forward_runtime_kwargs = True # type: ignore + return agent_tool def _normalize_messages( self, @@ -872,7 +873,9 @@ async def run( user=user, **(additional_chat_options or {}), ) - response = await self.chat_client.get_response(messages=thread_messages, chat_options=co, **kwargs) + # Filter chat_options from kwargs to prevent duplicate keyword argument + filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"} + response = await self.chat_client.get_response(messages=thread_messages, chat_options=co, **filtered_kwargs) await self._update_thread_with_type_and_conversation_id(thread, response.conversation_id) @@ -1004,9 +1007,11 @@ async def run_stream( **(additional_chat_options or {}), ) + # Filter chat_options from kwargs to prevent duplicate keyword argument + filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"} response_updates: list[ChatResponseUpdate] = [] async for update in self.chat_client.get_streaming_response( - messages=thread_messages, chat_options=co, **kwargs + messages=thread_messages, chat_options=co, **filtered_kwargs ): response_updates.append(update) diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index b1188e051c..a2cf502eff 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -587,7 +587,6 @@ def __init__( additional_properties: dict[str, Any] | None = None, func: Callable[..., Awaitable[ReturnT] | ReturnT] | None = None, input_model: type[ArgsT] | Mapping[str, Any] | None = None, - forward_additional_kwargs: bool = False, **kwargs: Any, ) -> None: """Initialize the AIFunction. @@ -606,8 +605,6 @@ def __init__( input_model: The Pydantic model that defines the input parameters for the function. This can also be a JSON schema dictionary. If not provided, it will be inferred from the function signature. - forward_additional_kwargs: Whether to forward unknown kwargs directly to the wrapped function. - Used by agent tools created via ``as_tool`` so they can receive runtime context kwargs. **kwargs: Additional keyword arguments. """ super().__init__( @@ -629,7 +626,7 @@ def __init__( self.invocation_exception_count = 0 self._invocation_duration_histogram = _default_histogram() self.type: Literal["ai_function"] = "ai_function" - self.forward_additional_kwargs = forward_additional_kwargs + self._forward_runtime_kwargs: bool = False @property def declaration_only(self) -> bool: @@ -694,18 +691,16 @@ async def invoke( global OBSERVABILITY_SETTINGS from .observability import OBSERVABILITY_SETTINGS - tool_call_id = kwargs.pop("tool_call_id", None) - provided_kwargs = dict(kwargs) + original_kwargs = dict(kwargs) + tool_call_id = original_kwargs.pop("tool_call_id", None) if arguments is not None: if not isinstance(arguments, self.input_model): raise TypeError(f"Expected {self.input_model.__name__}, got {type(arguments).__name__}") - argument_values = arguments.model_dump(exclude_none=True) - if self.forward_additional_kwargs: - kwargs = {**provided_kwargs, **argument_values} - else: - kwargs = argument_values + kwargs = arguments.model_dump(exclude_none=True) + if getattr(self, "_forward_runtime_kwargs", False) and original_kwargs: + kwargs.update(original_kwargs) else: - kwargs = provided_kwargs + kwargs = original_kwargs if not OBSERVABILITY_SETTINGS.ENABLED: # type: ignore[name-defined] logger.info(f"Function name: {self.name}") logger.debug(f"Function arguments: {kwargs}") @@ -1239,22 +1234,20 @@ async def _auto_invoke_function( parsed_args: dict[str, Any] = dict(function_call_content.parse_arguments() or {}) - # Filter out internal framework kwargs before merging/passing to tools. + # Filter out internal framework kwargs before passing to tools. runtime_kwargs: dict[str, Any] = { key: value for key, value in (custom_args or {}).items() - if key not in {"_function_middleware_pipeline", "middleware", "chat_options", "tools"} + if key not in {"_function_middleware_pipeline", "middleware"} } - - # Merge with user-supplied args; right-hand side dominates, so parsed args win on conflicts. - merged_args: dict[str, Any] = runtime_kwargs | parsed_args try: - args = tool.input_model.model_validate(merged_args) + args = tool.input_model.model_validate(parsed_args) except ValidationError as exc: message = "Error: Argument parsing failed." if config.include_detailed_errors: message = f"{message} Exception: {exc}" return FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc) + if not middleware_pipeline or ( not hasattr(middleware_pipeline, "has_middlewares") and not middleware_pipeline.has_middlewares ): @@ -1263,8 +1256,8 @@ async def _auto_invoke_function( function_result = await tool.invoke( arguments=args, tool_call_id=function_call_content.call_id, - **runtime_kwargs, - ) # type: ignore[arg-type] + **runtime_kwargs if getattr(tool, "_forward_runtime_kwargs", False) else {}, + ) return FunctionResultContent( call_id=function_call_content.call_id, result=function_result, @@ -1287,7 +1280,7 @@ async def final_function_handler(context_obj: Any) -> Any: return await tool.invoke( arguments=context_obj.arguments, tool_call_id=function_call_content.call_id, - **context_obj.kwargs, + **context_obj.kwargs if getattr(tool, "_forward_runtime_kwargs", False) else {}, ) try: diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index 3e44fae23c..8888f69579 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -1104,6 +1104,7 @@ async def trace_run( if not OBSERVABILITY_SETTINGS.ENABLED: # If model diagnostics are not enabled, just return the completion return await run_func(self, messages=messages, thread=thread, **kwargs) + filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"} attributes = _get_span_attributes( operation_name=OtelAttr.AGENT_INVOKE_OPERATION, provider_name=provider_name, @@ -1112,7 +1113,7 @@ async def trace_run( agent_description=self.description, thread_id=thread.service_thread_id if thread else None, chat_options=getattr(self, "chat_options", None), - **kwargs, + **filtered_kwargs, ) with _get_span(attributes=attributes, span_name_attribute=OtelAttr.AGENT_NAME) as span: if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and messages: @@ -1173,6 +1174,7 @@ async def trace_run_streaming( all_updates: list["AgentRunResponseUpdate"] = [] + filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"} attributes = _get_span_attributes( operation_name=OtelAttr.AGENT_INVOKE_OPERATION, provider_name=provider_name, @@ -1181,7 +1183,7 @@ async def trace_run_streaming( agent_description=self.description, thread_id=thread.service_thread_id if thread else None, chat_options=getattr(self, "chat_options", None), - **kwargs, + **filtered_kwargs, ) with _get_span(attributes=attributes, span_name_attribute=OtelAttr.AGENT_NAME) as span: if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and messages: diff --git a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py index bb07e04564..8669ecc3d6 100644 --- a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py +++ b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py @@ -23,7 +23,6 @@ async def capture_middleware( context: AgentRunContext, next: Callable[[AgentRunContext], Awaitable[None]] ) -> None: # Capture kwargs passed to the sub-agent - print(f"Middleware captured kwargs: {context.kwargs}") captured_kwargs.update(context.kwargs) await next(context) @@ -43,14 +42,12 @@ async def capture_middleware( tool = sub_agent.as_tool(name="delegate", arg_name="task") # Directly invoke the tool with kwargs (simulating what happens during agent execution) - result = await tool.invoke( + _ = await tool.invoke( arguments=tool.input_model(task="Test delegation"), api_token="secret-xyz-123", user_id="user-456", session_id="session-789", ) - print(f"Tool result: {result}") - print(f"Captured kwargs after invoke: {captured_kwargs}") # Verify kwargs were forwarded to sub-agent assert "api_token" in captured_kwargs, f"Expected 'api_token' in {captured_kwargs}" diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index f3477c8163..49772e7fdd 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -210,6 +210,7 @@ async def simple_tool(message: str) -> str: assert result == "HELLO WORLD" + async def test_ai_function_invoke_telemetry_with_pydantic_args(span_exporter: InMemorySpanExporter): """Test the ai_function invoke method with Pydantic model arguments.""" diff --git a/python/uv.lock b/python/uv.lock index f0657d5f68..a9a0e9592d 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -190,7 +190,7 @@ requires-dist = [ [[package]] name = "agent-framework-ag-ui" -version = "1.0.0b251114" +version = "1.0.0b251117" source = { editable = "packages/ag-ui" } dependencies = [ { name = "ag-ui-protocol", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },