Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions python/packages/core/agent_framework/_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,13 +454,16 @@ async def agent_wrapper(**kwargs: Any) -> str:
# Extract the input from kwargs using the specified arg_name
input_text = kwargs.get(arg_name, "")

# Forward all kwargs except the arg_name to support runtime context propagation
forwarded_kwargs = {k: v for k, v in kwargs.items() if k != arg_name}

if stream_callback is None:
# Use non-streaming mode
return (await self.run(input_text)).text
return (await self.run(input_text, **forwarded_kwargs)).text

# Use streaming mode - accumulate updates and create final response
response_updates: list[AgentRunResponseUpdate] = []
async for update in self.run_stream(input_text):
async for update in self.run_stream(input_text, **forwarded_kwargs):
response_updates.append(update)
if is_async_callback:
await stream_callback(update) # type: ignore[misc]
Expand All @@ -470,12 +473,14 @@ async def agent_wrapper(**kwargs: Any) -> str:
# Create final text from accumulated updates
return AgentRunResponse.from_agent_run_response_updates(response_updates).text

return AIFunction(
agent_tool: AIFunction[BaseModel, str] = AIFunction(
name=tool_name,
description=tool_description,
func=agent_wrapper,
input_model=input_model, # type: ignore
)
agent_tool._forward_runtime_kwargs = True # type: ignore
return agent_tool

def _normalize_messages(
self,
Expand Down Expand Up @@ -868,7 +873,9 @@ async def run(
user=user,
**(additional_chat_options or {}),
)
response = await self.chat_client.get_response(messages=thread_messages, chat_options=co, **kwargs)
# Drop chat_options from kwargs: it is already passed explicitly below, so forwarding it again would raise a duplicate keyword argument error
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"}
response = await self.chat_client.get_response(messages=thread_messages, chat_options=co, **filtered_kwargs)

await self._update_thread_with_type_and_conversation_id(thread, response.conversation_id)

Expand Down Expand Up @@ -1000,9 +1007,11 @@ async def run_stream(
**(additional_chat_options or {}),
)

# Drop chat_options from kwargs: it is already passed explicitly below, so forwarding it again would raise a duplicate keyword argument error
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"}
response_updates: list[ChatResponseUpdate] = []
async for update in self.chat_client.get_streaming_response(
messages=thread_messages, chat_options=co, **kwargs
messages=thread_messages, chat_options=co, **filtered_kwargs
):
response_updates.append(update)

Expand Down
25 changes: 19 additions & 6 deletions python/packages/core/agent_framework/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,7 @@ def __init__(
self.invocation_exception_count = 0
self._invocation_duration_histogram = _default_histogram()
self.type: Literal["ai_function"] = "ai_function"
self._forward_runtime_kwargs: bool = False

@property
def declaration_only(self) -> bool:
Expand Down Expand Up @@ -690,11 +691,16 @@ async def invoke(
global OBSERVABILITY_SETTINGS
from .observability import OBSERVABILITY_SETTINGS

tool_call_id = kwargs.pop("tool_call_id", None)
original_kwargs = dict(kwargs)
tool_call_id = original_kwargs.pop("tool_call_id", None)
if arguments is not None:
if not isinstance(arguments, self.input_model):
raise TypeError(f"Expected {self.input_model.__name__}, got {type(arguments).__name__}")
kwargs = arguments.model_dump(exclude_none=True)
if getattr(self, "_forward_runtime_kwargs", False) and original_kwargs:
kwargs.update(original_kwargs)
else:
kwargs = original_kwargs
if not OBSERVABILITY_SETTINGS.ENABLED: # type: ignore[name-defined]
logger.info(f"Function name: {self.name}")
logger.debug(f"Function arguments: {kwargs}")
Expand Down Expand Up @@ -1228,15 +1234,20 @@ async def _auto_invoke_function(

parsed_args: dict[str, Any] = dict(function_call_content.parse_arguments() or {})

# Merge with user-supplied args; right-hand side dominates, so parsed args win on conflicts.
merged_args: dict[str, Any] = (custom_args or {}) | parsed_args
# Filter out internal framework kwargs before passing to tools.
runtime_kwargs: dict[str, Any] = {
key: value
for key, value in (custom_args or {}).items()
if key not in {"_function_middleware_pipeline", "middleware"}
}
try:
args = tool.input_model.model_validate(merged_args)
args = tool.input_model.model_validate(parsed_args)
except ValidationError as exc:
message = "Error: Argument parsing failed."
if config.include_detailed_errors:
message = f"{message} Exception: {exc}"
return FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc)

if not middleware_pipeline or (
not hasattr(middleware_pipeline, "has_middlewares") and not middleware_pipeline.has_middlewares
):
Expand All @@ -1245,7 +1256,8 @@ async def _auto_invoke_function(
function_result = await tool.invoke(
arguments=args,
tool_call_id=function_call_content.call_id,
) # type: ignore[arg-type]
**runtime_kwargs if getattr(tool, "_forward_runtime_kwargs", False) else {},
)
return FunctionResultContent(
call_id=function_call_content.call_id,
result=function_result,
Expand All @@ -1261,13 +1273,14 @@ async def _auto_invoke_function(
middleware_context = FunctionInvocationContext(
function=tool,
arguments=args,
kwargs=custom_args or {},
kwargs=runtime_kwargs.copy(),
)

async def final_function_handler(context_obj: Any) -> Any:
return await tool.invoke(
arguments=context_obj.arguments,
tool_call_id=function_call_content.call_id,
**context_obj.kwargs if getattr(tool, "_forward_runtime_kwargs", False) else {},
)

try:
Expand Down
6 changes: 4 additions & 2 deletions python/packages/core/agent_framework/observability.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,6 +1104,7 @@ async def trace_run(
if not OBSERVABILITY_SETTINGS.ENABLED:
# If model diagnostics are not enabled, just return the completion
return await run_func(self, messages=messages, thread=thread, **kwargs)
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"}
attributes = _get_span_attributes(
operation_name=OtelAttr.AGENT_INVOKE_OPERATION,
provider_name=provider_name,
Expand All @@ -1112,7 +1113,7 @@ async def trace_run(
agent_description=self.description,
thread_id=thread.service_thread_id if thread else None,
chat_options=getattr(self, "chat_options", None),
**kwargs,
**filtered_kwargs,
)
with _get_span(attributes=attributes, span_name_attribute=OtelAttr.AGENT_NAME) as span:
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and messages:
Expand Down Expand Up @@ -1173,6 +1174,7 @@ async def trace_run_streaming(

all_updates: list["AgentRunResponseUpdate"] = []

filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"}
attributes = _get_span_attributes(
operation_name=OtelAttr.AGENT_INVOKE_OPERATION,
provider_name=provider_name,
Expand All @@ -1181,7 +1183,7 @@ async def trace_run_streaming(
agent_description=self.description,
thread_id=thread.service_thread_id if thread else None,
chat_options=getattr(self, "chat_options", None),
**kwargs,
**filtered_kwargs,
)
with _get_span(attributes=attributes, span_name_attribute=OtelAttr.AGENT_NAME) as span:
if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and messages:
Expand Down
Loading
Loading