From 8733b1159e8c95b590b816964ddc0f6486fc97d5 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Mon, 5 Jan 2026 16:31:06 +0900 Subject: [PATCH 1/8] fix(ag-ui): execute tools after approval in human-in-the-loop flow --- .../_message_adapters.py | 63 ++++++- .../_orchestration/_message_hygiene.py | 24 ++- .../tests/test_agent_wrapper_comprehensive.py | 178 ++++++++++++++++++ 3 files changed, 255 insertions(+), 10 deletions(-) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py index 11d2977f90..b9bc512c2c 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py @@ -2,6 +2,7 @@ """Message format conversion between AG-UI and Agent Framework.""" +import json from typing import Any, cast from agent_framework import ( @@ -59,21 +60,65 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Cha # Distinguish approval payloads from actual tool results is_approval = False if isinstance(result_content, str) and result_content: - import json as _json - try: - parsed = _json.loads(result_content) + parsed = json.loads(result_content) is_approval = isinstance(parsed, dict) and "accepted" in parsed except Exception: is_approval = False if is_approval: - # Approval responses should be treated as user messages to trigger human-in-the-loop flow - chat_msg = ChatMessage( - role=Role.USER, - contents=[TextContent(text=str(result_content))], - additional_properties={"is_tool_result": True, "tool_call_id": str(tool_call_id or "")}, - ) + # Look for the matching function call in previous messages to create + # a proper FunctionApprovalResponseContent. This enables the agent framework + # to execute the approved tool (fix for GitHub issue #3034). + parsed_approval = json.loads(result_content) + accepted = parsed_approval.get("accepted", False) + + # Find the function call that matches this tool_call_id + matching_func_call = None + for prev_msg in result: + role_val = prev_msg.role.value if hasattr(prev_msg.role, "value") else str(prev_msg.role) + if role_val != "assistant": + continue + for content in prev_msg.contents or []: + if isinstance(content, FunctionCallContent): + if content.call_id == tool_call_id and content.name != "confirm_changes": + matching_func_call = content + break + + if matching_func_call: + # Remove any existing tool result for this call_id since the framework + # will re-execute the tool after approval. Keeping old results causes + # OpenAI API errors ("tool message must follow assistant with tool_calls"). + result = [ + m + for m in result + if not ( + (m.role.value if hasattr(m.role, "value") else str(m.role)) == "tool" + and any( + isinstance(c, FunctionResultContent) and c.call_id == tool_call_id + for c in (m.contents or []) + ) + ) + ] + + # Create FunctionApprovalResponseContent for the agent framework + approval_response = FunctionApprovalResponseContent( + approved=accepted, + id=str(tool_call_id), + function_call=matching_func_call, + ) + chat_msg = ChatMessage( + role=Role.USER, + contents=[approval_response], + ) + else: + # No matching function call found - this is likely a confirm_changes approval + # Keep the old behavior for backwards compatibility + chat_msg = ChatMessage( + role=Role.USER, + contents=[TextContent(text=str(result_content))], + additional_properties={"is_tool_result": True, "tool_call_id": str(tool_call_id or "")}, + ) if "id" in msg: chat_msg.message_id = msg["id"] result.append(chat_msg) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_message_hygiene.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_message_hygiene.py index 97c990781b..15cb7bd43c 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_message_hygiene.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_message_hygiene.py @@ -6,7 +6,13 @@ import logging from typing import Any -from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent, TextContent +from agent_framework import ( + ChatMessage, + FunctionApprovalResponseContent, + FunctionCallContent, + FunctionResultContent, + TextContent, +) logger = logging.getLogger(__name__) @@ -40,6 +46,22 @@ def sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: continue if role_value == "user": + # Check if this message contains FunctionApprovalResponseContent + # If so, the framework will handle tool execution - don't inject synthetic results + approval_call_ids: set[str] = set() + for content in msg.contents or []: + if isinstance(content, FunctionApprovalResponseContent): + if content.function_call and content.function_call.call_id: + approval_call_ids.add(str(content.function_call.call_id)) + + if approval_call_ids and pending_tool_call_ids: + # Remove approved call_ids from pending - the framework will execute them + pending_tool_call_ids -= approval_call_ids + logger.info( + f"FunctionApprovalResponseContent found for call_ids={approval_call_ids} - " + "framework will handle execution" + ) + if pending_confirm_changes_id: user_text = "" for content in msg.contents or []: diff --git a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py index beb6f8af2c..c1456c3d63 100644 --- a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py @@ -630,3 +630,181 @@ async def stream_fn( # Should contain some reference to the document full_text = "".join(e.delta for e in text_events) assert "written" in full_text.lower() or "document" in full_text.lower() + + +async def test_function_approval_mode_executes_tool(): + """Test that function approval with approval_mode='always_require' sends the correct messages.""" + from agent_framework import FunctionApprovalResponseContent, ai_function + from agent_framework.ag_ui import AgentFrameworkAgent + + messages_received: list[Any] = [] + + @ai_function( + name="get_datetime", + description="Get the current date and time", + approval_mode="always_require", + ) + def get_datetime() -> str: + return "2025/12/01 12:00:00" + + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + # Capture the messages received by the chat client + messages_received.clear() + messages_received.extend(messages) + yield ChatResponseUpdate(contents=[TextContent(text="Processing completed")]) + + agent = ChatAgent( + name="test_agent", + instructions="Test", + chat_client=StreamingChatClientStub(stream_fn), + tools=[get_datetime], + ) + wrapper = AgentFrameworkAgent(agent=agent) + + # Simulate the conversation history with: + # 1. User message asking for time + # 2. Assistant message with the function call that needs approval + # 3. Tool approval message from user + tool_result: dict[str, Any] = {"accepted": True} + input_data: dict[str, Any] = { + "messages": [ + { + "role": "user", + "content": "What time is it?", + }, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_get_datetime_123", + "type": "function", + "function": { + "name": "get_datetime", + "arguments": "{}", + }, + } + ], + }, + { + "role": "tool", + "content": json.dumps(tool_result), + "toolCallId": "call_get_datetime_123", + }, + ], + } + + events: list[Any] = [] + async for event in wrapper.run_agent(input_data): + events.append(event) + + # Verify the run completed successfully + run_started = [e for e in events if e.type == "RUN_STARTED"] + run_finished = [e for e in events if e.type == "RUN_FINISHED"] + assert len(run_started) == 1 + assert len(run_finished) == 1 + + # Verify that a FunctionApprovalResponseContent was created and sent to the agent + # This is the key fix - the orchestrator should create an approval response + approval_responses_found = False + for msg in messages_received: + for content in msg.contents: + if isinstance(content, FunctionApprovalResponseContent): + approval_responses_found = True + assert content.approved is True + assert content.function_call.name == "get_datetime" + assert content.function_call.call_id == "call_get_datetime_123" + break + + assert approval_responses_found, ( + "FunctionApprovalResponseContent should be included in messages sent to agent. " + "This is required for the agent framework to execute the approved function." + ) + + +async def test_function_approval_mode_rejection(): + """Test that function approval rejection creates a rejection response.""" + from agent_framework import FunctionApprovalResponseContent, ai_function + from agent_framework.ag_ui import AgentFrameworkAgent + + messages_received: list[Any] = [] + + @ai_function( + name="delete_all_data", + description="Delete all user data", + approval_mode="always_require", + ) + def delete_all_data() -> str: + return "All data deleted" + + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + # Capture the messages received by the chat client + messages_received.clear() + messages_received.extend(messages) + yield ChatResponseUpdate(contents=[TextContent(text="Operation cancelled")]) + + agent = ChatAgent( + name="test_agent", + instructions="Test", + chat_client=StreamingChatClientStub(stream_fn), + tools=[delete_all_data], + ) + wrapper = AgentFrameworkAgent(agent=agent) + + # Simulate rejection + tool_result: dict[str, Any] = {"accepted": False} + input_data: dict[str, Any] = { + "messages": [ + { + "role": "user", + "content": "Delete all my data", + }, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_delete_123", + "type": "function", + "function": { + "name": "delete_all_data", + "arguments": "{}", + }, + } + ], + }, + { + "role": "tool", + "content": json.dumps(tool_result), + "toolCallId": "call_delete_123", + }, + ], + } + + events: list[Any] = [] + async for event in wrapper.run_agent(input_data): + events.append(event) + + # Verify the run completed + run_finished = [e for e in events if e.type == "RUN_FINISHED"] + assert len(run_finished) == 1 + + # Verify that a FunctionApprovalResponseContent with approved=False was created + rejection_found = False + for msg in messages_received: + for content in msg.contents: + if isinstance(content, FunctionApprovalResponseContent): + rejection_found = True + assert content.approved is False + assert content.function_call.name == "delete_all_data" + assert content.function_call.call_id == "call_delete_123" + break + + assert rejection_found, ( + "FunctionApprovalResponseContent with approved=False should be included in messages sent to agent. " + "This tells the agent framework that the tool was rejected." + ) From ba288d97241a68e750a58dbabe4ad38b9d06b367 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Wed, 7 Jan 2026 16:27:36 +0900 Subject: [PATCH 2/8] Fix shared state bug --- .../ag-ui/agent_framework_ag_ui/_events.py | 3 + .../_message_adapters.py | 125 +++++++++- .../_orchestration/_message_hygiene.py | 4 +- .../_orchestration/_state_manager.py | 6 +- .../agent_framework_ag_ui/_orchestrators.py | 217 ++++++++++++++++-- .../agents/human_in_the_loop_agent.py | 6 +- .../tests/test_agent_wrapper_comprehensive.py | 38 ++- .../ag-ui/tests/test_message_adapters.py | 84 +++++++ .../ag-ui/tests/test_orchestrators.py | 124 ++++++++++ .../tests/test_orchestrators_coverage.py | 8 +- 10 files changed, 565 insertions(+), 50 deletions(-) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_events.py b/python/packages/ag-ui/agent_framework_ag_ui/_events.py index 184da0239e..1bdcb60297 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_events.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_events.py @@ -88,6 +88,7 @@ def __init__( self.tool_results: list[dict[str, Any]] = [] # Track tool results self.tool_calls_ended: set[str] = set() # Track which tool calls have had ToolCallEndEvent emitted self.accumulated_text_content: str = "" # Track accumulated text for final MessagesSnapshotEvent + self.messages_snapshot_emitted: bool = False # Track snapshot emission to avoid duplicates async def from_agent_run_update(self, update: AgentRunResponseUpdate) -> list[BaseEvent]: """ @@ -452,6 +453,7 @@ def _emit_snapshot_for_tool_result(self) -> list[BaseEvent]: ) logger.info(f"Emitting MessagesSnapshotEvent with {len(all_messages)} messages") events.append(messages_snapshot_event) + self.messages_snapshot_emitted = True return events def _emit_state_snapshot_and_confirmation(self) -> list[BaseEvent]: @@ -552,6 +554,7 @@ def _emit_confirm_changes_tool_call(self) -> list[BaseEvent]: ) logger.info(f"Emitting MessagesSnapshotEvent for confirm_changes with {len(all_messages)} messages") events.append(messages_snapshot_event) + self.messages_snapshot_emitted = True self.should_stop_after_confirm = True logger.info("Set flag to stop run after confirm_changes") diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py index b9bc512c2c..51cdaf8cc3 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py @@ -37,6 +37,31 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Cha Returns: List of Agent Framework ChatMessage objects """ + + def _update_tool_call_arguments( + raw_messages: list[dict[str, Any]], + tool_call_id: str, + modified_args: dict[str, Any], + ) -> None: + for raw_msg in raw_messages: + tool_calls = raw_msg.get("tool_calls") or raw_msg.get("toolCalls") + if not isinstance(tool_calls, list): + continue + for tool_call in tool_calls: + if not isinstance(tool_call, dict): + continue + if str(tool_call.get("id", "")) != tool_call_id: + continue + function_payload = tool_call.get("function") + if not isinstance(function_payload, dict): + return + existing_args = function_payload.get("arguments") + if isinstance(existing_args, str): + function_payload["arguments"] = json.dumps(modified_args) + else: + function_payload["arguments"] = modified_args + return + result: list[ChatMessage] = [] for msg in messages: # Handle standard tool result messages early (role="tool") to preserve provider invariants @@ -58,20 +83,31 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Cha result_content = msg.get("result", "") # Distinguish approval payloads from actual tool results - is_approval = False + parsed: dict[str, Any] | None = None if isinstance(result_content, str) and result_content: try: - parsed = json.loads(result_content) - is_approval = isinstance(parsed, dict) and "accepted" in parsed + parsed_candidate = json.loads(result_content) except Exception: - is_approval = False + parsed_candidate = None + if isinstance(parsed_candidate, dict): + parsed = parsed_candidate + elif isinstance(result_content, dict): + parsed = result_content + + is_approval = parsed is not None and "accepted" in parsed if is_approval: # Look for the matching function call in previous messages to create # a proper FunctionApprovalResponseContent. This enables the agent framework # to execute the approved tool (fix for GitHub issue #3034). - parsed_approval = json.loads(result_content) - accepted = parsed_approval.get("accepted", False) + accepted = parsed.get("accepted", False) if parsed is not None else False + approval_payload_text = result_content if isinstance(result_content, str) else json.dumps(parsed) + + # Log the full approval payload to debug modified arguments + import logging + + logger = logging.getLogger(__name__) + logger.info(f"Approval payload received: {parsed}") # Find the function call that matches this tool_call_id matching_func_call = None @@ -101,11 +137,68 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Cha ) ] + # Check if the approval payload contains modified arguments + # The UI sends back the modified state (e.g., deselected steps) in the approval payload + modified_args = {k: v for k, v in parsed.items() if k != "accepted"} if parsed else {} + state_args: dict[str, Any] | None = None + if modified_args: + original_args = matching_func_call.parse_arguments() or {} + merged_args: dict[str, Any] + if isinstance(original_args, dict) and original_args: + merged_args = {**original_args, **modified_args} + else: + merged_args = dict(modified_args) + + if isinstance(modified_args.get("steps"), list): + original_steps = original_args.get("steps") if isinstance(original_args, dict) else None + if isinstance(original_steps, list): + approved_steps = modified_args.get("steps") or [] + approved_by_description = { + step.get("description"): step + for step in approved_steps + if isinstance(step, dict) and step.get("description") + } + merged_steps: list[Any] = [] + for step in original_steps: + if not isinstance(step, dict): + merged_steps.append(step) + continue + description = step.get("description") + approved_step = approved_by_description.get(description) + status = ( + approved_step.get("status") + if isinstance(approved_step, dict) and approved_step.get("status") + else "disabled" + ) + updated_step = step.copy() + updated_step["status"] = status + merged_steps.append(updated_step) + merged_args["steps"] = merged_steps + state_args = merged_args + + # Keep the original tool call and AG-UI snapshot in sync with approved args. + updated_args = ( + json.dumps(merged_args) if isinstance(matching_func_call.arguments, str) else merged_args + ) + matching_func_call.arguments = updated_args + _update_tool_call_arguments(messages, str(tool_call_id), merged_args) + # Create a new FunctionCallContent with the modified arguments + func_call_for_approval = FunctionCallContent( + call_id=matching_func_call.call_id, + name=matching_func_call.name, + arguments=json.dumps(modified_args), + ) + logger.info(f"Using modified arguments from approval: {modified_args}") + else: + # No modified arguments - use the original function call + func_call_for_approval = matching_func_call + # Create FunctionApprovalResponseContent for the agent framework approval_response = FunctionApprovalResponseContent( approved=accepted, id=str(tool_call_id), - function_call=matching_func_call, + function_call=func_call_for_approval, + additional_properties={"ag_ui_state_args": state_args} if state_args else None, ) chat_msg = ChatMessage( role=Role.USER, @@ -116,7 +209,7 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Cha # Keep the old behavior for backwards compatibility chat_msg = ChatMessage( role=Role.USER, - contents=[TextContent(text=str(result_content))], + contents=[TextContent(text=approval_payload_text)], additional_properties={"is_tool_result": True, "tool_call_id": str(tool_call_id or "")}, ) if "id" in msg: @@ -368,6 +461,22 @@ def agui_messages_to_snapshot_format(messages: list[dict[str, Any]]) -> list[dic elif content is None: normalized_msg["content"] = "" + tool_calls = normalized_msg.get("tool_calls") or normalized_msg.get("toolCalls") + if isinstance(tool_calls, list): + for tool_call in tool_calls: + if not isinstance(tool_call, dict): + continue + function_payload = tool_call.get("function") + if not isinstance(function_payload, dict): + continue + if "arguments" not in function_payload: + continue + arguments = function_payload.get("arguments") + if arguments is None: + function_payload["arguments"] = "" + elif not isinstance(arguments, str): + function_payload["arguments"] = json.dumps(arguments) + # Normalize tool_call_id to toolCallId for tool messages if normalized_msg.get("role") == "tool": if "tool_call_id" in normalized_msg: diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_message_hygiene.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_message_hygiene.py index 15cb7bd43c..6be50a9f54 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_message_hygiene.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_message_hygiene.py @@ -50,7 +50,7 @@ def sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: # If so, the framework will handle tool execution - don't inject synthetic results approval_call_ids: set[str] = set() for content in msg.contents or []: - if isinstance(content, FunctionApprovalResponseContent): + if type(content) is FunctionApprovalResponseContent: if content.function_call and content.function_call.call_id: approval_call_ids.add(str(content.function_call.call_id)) @@ -58,7 +58,7 @@ def sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: # Remove approved call_ids from pending - the framework will execute them pending_tool_call_ids -= approval_call_ids logger.info( - f"FunctionApprovalResponseContent found for call_ids={approval_call_ids} - " + f"FunctionApprovalResponseContent found for call_ids={sorted(approval_call_ids)} - " "framework will handle execution" ) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_state_manager.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_state_manager.py index 45c16afef4..7d8a23d84c 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_state_manager.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_state_manager.py @@ -22,9 +22,11 @@ def __init__( self.predict_state_config = predict_state_config or {} self.require_confirmation = require_confirmation self.current_state: dict[str, Any] = {} + self._state_from_input: bool = False def initialize(self, initial_state: dict[str, Any] | None) -> dict[str, Any]: """Initialize state with schema defaults.""" + self._state_from_input = initial_state is not None self.current_state = (initial_state or {}).copy() self._apply_schema_defaults() return self.current_state @@ -60,7 +62,9 @@ def state_context_message(self, is_new_user_turn: bool, conversation_has_tool_ca """Inject state context only when starting a new user turn.""" if not self.current_state or not self.state_schema: return None - if not is_new_user_turn or conversation_has_tool_calls: + if not is_new_user_turn: + return None + if conversation_has_tool_calls and not self._state_from_input: return None state_json = json.dumps(self.current_state, indent=2) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py index 6bdff552b6..4dae942ca2 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py @@ -13,6 +13,7 @@ BaseEvent, MessagesSnapshotEvent, RunErrorEvent, + StateSnapshotEvent, TextMessageContentEvent, TextMessageEndEvent, TextMessageStartEvent, @@ -21,10 +22,18 @@ AgentProtocol, AgentThread, ChatAgent, + FunctionApprovalResponseContent, FunctionCallContent, FunctionResultContent, TextContent, ) +from agent_framework._middleware import extract_and_merge_function_middleware +from agent_framework._tools import ( + FunctionInvocationConfiguration, + _collect_approval_responses, + _replace_approval_contents_with_results, + _try_execute_function_calls, +) from ._utils import convert_agui_tools_to_agent_framework, generate_event_id @@ -291,7 +300,7 @@ async def run( predict_state_config=context.config.predict_state_config, require_confirmation=context.config.require_confirmation, ) - current_state = state_manager.initialize(context.input_data.get("state", {})) + current_state = state_manager.initialize(context.input_data.get("state")) event_bridge = AgentFrameworkEventBridge( run_id=context.run_id, @@ -386,6 +395,41 @@ async def run( else: logger.info(f" Content {j}: {content_type}") + # Check for FunctionApprovalResponseContent and emit updated state snapshot + # This ensures the UI shows the approved state (e.g., 2 steps) not the original (3 steps) + for msg in provider_messages: + role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) + if role_value != "user": + continue + for content in msg.contents or []: + if type(content) is FunctionApprovalResponseContent: + if content.function_call and content.approved: + parsed_args = content.function_call.parse_arguments() + state_args = None + if content.additional_properties: + state_args = content.additional_properties.get("ag_ui_state_args") + if not isinstance(state_args, dict): + state_args = parsed_args + if state_args and context.config.predict_state_config: + for state_key, config in context.config.predict_state_config.items(): + if config["tool"] == content.function_call.name: + tool_arg_name = config["tool_argument"] + if tool_arg_name == "*": + state_value = state_args + elif isinstance(state_args, dict) and tool_arg_name in state_args: + state_value = state_args[tool_arg_name] + else: + continue + # Update current_state and emit snapshot + current_state[state_key] = state_value + event_bridge.current_state[state_key] = state_value + logger.info( + f"Emitting StateSnapshotEvent for approved state key '{state_key}' " + f"with {len(state_value) if isinstance(state_value, list) else 'N/A'} items" + ) + yield StateSnapshotEvent(snapshot=current_state) + break + messages_to_run: list[Any] = [] is_new_user_turn = False if provider_messages: @@ -393,21 +437,88 @@ async def run( role_value = last_msg.role.value if hasattr(last_msg.role, "value") else str(last_msg.role) is_new_user_turn = role_value == "user" - conversation_has_tool_calls = False - for msg in provider_messages: - role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) - if role_value == "assistant" and hasattr(msg, "contents") and msg.contents: - if any(isinstance(content, FunctionCallContent) for content in msg.contents): - conversation_has_tool_calls = True - break + def _tool_calls_match_state() -> bool: + if not state_manager.predict_state_config or not state_manager.current_state: + return False + + def _parse_args(arguments: Any) -> dict[str, Any] | None: + if isinstance(arguments, dict): + return arguments + if isinstance(arguments, str): + try: + parsed = json.loads(arguments) + except json.JSONDecodeError: + return None + if isinstance(parsed, dict): + return parsed + return None + + for state_key, config in state_manager.predict_state_config.items(): + tool_name = config["tool"] + tool_arg_name = config["tool_argument"] + tool_args: dict[str, Any] | None = None + + for msg in reversed(provider_messages): + role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) + if role_value != "assistant": + continue + for content in msg.contents or []: + if isinstance(content, FunctionCallContent) and content.name == tool_name: + tool_args = _parse_args(content.arguments) + break + if tool_args is not None: + break + + if not tool_args: + return False + + if tool_arg_name == "*": + state_value = tool_args + elif tool_arg_name in tool_args: + state_value = tool_args[tool_arg_name] + else: + return False + + if state_manager.current_state.get(state_key) != state_value: + return False + + return True + + conversation_has_tool_calls = _tool_calls_match_state() state_context_msg = state_manager.state_context_message( is_new_user_turn=is_new_user_turn, conversation_has_tool_calls=conversation_has_tool_calls ) - if state_context_msg: - messages_to_run.append(state_context_msg) - messages_to_run.extend(provider_messages) + def _is_state_context_message(message: Any) -> bool: + role_value = message.role.value if hasattr(message.role, "value") else str(message.role) + if role_value != "system": + return False + for content in message.contents or []: + if isinstance(content, TextContent) and content.text.startswith("Current state of the application:"): + return True + return False + + def _pending_tool_call_ids(messages: list[Any]) -> set[str]: + pending_ids: set[str] = set() + resolved_ids: set[str] = set() + for msg in messages: + for content in msg.contents or []: + if isinstance(content, FunctionCallContent) and content.call_id: + pending_ids.add(str(content.call_id)) + elif isinstance(content, FunctionResultContent) and content.call_id: + resolved_ids.add(str(content.call_id)) + return pending_ids - resolved_ids + + if state_context_msg: + messages_to_run = [msg for msg in provider_messages if not _is_state_context_message(msg)] + if not _pending_tool_call_ids(messages_to_run): + insert_index = len(messages_to_run) - 1 if is_new_user_turn else len(messages_to_run) + if insert_index < 0: + insert_index = 0 + messages_to_run.insert(insert_index, state_context_msg) + else: + messages_to_run.extend(provider_messages) client_tools = convert_agui_tools_to_agent_framework(context.input_data.get("tools")) logger.info(f"[TOOLS] Client sent {len(client_tools) if client_tools else 0} tools") @@ -441,10 +552,67 @@ async def run( if safe_metadata: run_kwargs["store"] = True + async def _resolve_approval_responses( + messages: list[Any], + tools_for_execution: list[Any], + ) -> None: + fcc_todo = _collect_approval_responses(messages) + if not fcc_todo: + return + + approved_responses = [resp for resp in fcc_todo.values() if resp.approved] + approved_function_results: list[Any] = [] + if approved_responses and tools_for_execution: + chat_client = getattr(context.agent, "chat_client", None) + config = ( + getattr(chat_client, "function_invocation_configuration", None) or FunctionInvocationConfiguration() + ) + middleware_pipeline = extract_and_merge_function_middleware(chat_client, run_kwargs) + try: + results, _ = await _try_execute_function_calls( + custom_args=run_kwargs, + attempt_idx=0, + function_calls=approved_responses, + tools=tools_for_execution, + middleware_pipeline=middleware_pipeline, + config=config, + ) + approved_function_results = list(results) + except Exception: + logger.error("Failed to execute approved tool calls; injecting error results.") + approved_function_results = [] + + normalized_results: list[FunctionResultContent] = [] + for idx, approval in enumerate(approved_responses): + if idx < len(approved_function_results) and isinstance( + approved_function_results[idx], FunctionResultContent + ): + normalized_results.append(approved_function_results[idx]) + continue + call_id = approval.function_call.call_id or approval.id + normalized_results.append( + FunctionResultContent(call_id=call_id, result="Error: Tool call invocation failed.") + ) + + _replace_approval_contents_with_results(messages, fcc_todo, normalized_results) # type: ignore + + await _resolve_approval_responses(messages_to_run, server_tools) + async for update in context.agent.run_stream(messages_to_run, **run_kwargs): update_count += 1 logger.info(f"[STREAM] Received update #{update_count} from agent") all_updates.append(update) + if event_bridge.current_message_id is None and update.contents: + has_tool_call = any(isinstance(content, FunctionCallContent) for content in update.contents) + has_text = any(isinstance(content, TextContent) for content in update.contents) + if has_tool_call and not has_text: + tool_message_id = generate_event_id() + event_bridge.current_message_id = tool_message_id + logger.info( + "[STREAM] Emitting TextMessageStartEvent for tool-only response message_id=%s", + tool_message_id, + ) + yield TextMessageStartEvent(message_id=tool_message_id, role="assistant") events = await event_bridge.from_agent_run_update(update) logger.info(f"[STREAM] Update #{update_count} produced {len(events)} events") for event in events: @@ -508,6 +676,7 @@ async def run( logger.info(f"[FINALIZE] Emitting TextMessageEndEvent for message_id={event_bridge.current_message_id}") yield event_bridge.create_message_end_event(event_bridge.current_message_id) + has_text_content = bool(event_bridge.accumulated_text_content) assistant_text_message = { "id": event_bridge.current_message_id, "role": "assistant", @@ -519,14 +688,15 @@ async def run( if event_bridge.pending_tool_calls: tool_call_message = { - "id": generate_event_id(), + "id": event_bridge.current_message_id if not has_text_content else generate_event_id(), "role": "assistant", "tool_calls": event_bridge.pending_tool_calls.copy(), } all_messages.append(tool_call_message) all_messages.extend(event_bridge.tool_results.copy()) - all_messages.append(assistant_text_message) + if has_text_content: + all_messages.append(assistant_text_message) messages_snapshot = MessagesSnapshotEvent( messages=all_messages, # type: ignore[arg-type] @@ -539,6 +709,27 @@ async def run( yield messages_snapshot else: logger.info("[FINALIZE] No current_message_id - skipping TextMessageEndEvent") + if not event_bridge.messages_snapshot_emitted and ( + event_bridge.pending_tool_calls or event_bridge.tool_results + ): + converted_input_messages = agui_messages_to_snapshot_format(event_bridge.input_messages) + all_messages = converted_input_messages.copy() + + if event_bridge.pending_tool_calls: + tool_call_message = { + "id": generate_event_id(), + "role": "assistant", + "tool_calls": event_bridge.pending_tool_calls.copy(), + } + all_messages.append(tool_call_message) + + all_messages.extend(event_bridge.tool_results.copy()) + + messages_snapshot = MessagesSnapshotEvent( + messages=all_messages, # type: ignore[arg-type] + ) + logger.info(f"[FINALIZE] Emitting MessagesSnapshotEvent with {len(all_messages)} messages (tool-only)") + yield messages_snapshot logger.info("[FINALIZE] Emitting RUN_FINISHED event") yield event_bridge.create_run_finished_event() diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/human_in_the_loop_agent.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/human_in_the_loop_agent.py index abbd113418..ab7a3533cd 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/human_in_the_loop_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/human_in_the_loop_agent.py @@ -75,8 +75,10 @@ def human_in_the_loop_agent(chat_client: ChatClientProtocol) -> ChatAgent: 9. "Calibrate systems" 10. "Final testing" - After calling the function, provide a brief acknowledgment like: - "I've created a plan with 10 steps. You can customize which steps to enable before I proceed." + IMPORTANT: When you call generate_task_steps, the user will be shown the steps and asked to approve. + Do NOT output any text along with the function call - just call the function. + After the user approves and the function executes, THEN provide a brief acknowledgment like: + "The plan has been created with X steps selected." """, chat_client=chat_client, tools=[generate_task_steps], diff --git a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py index c1456c3d63..281b81c968 100644 --- a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py @@ -634,7 +634,7 @@ async def stream_fn( async def test_function_approval_mode_executes_tool(): """Test that function approval with approval_mode='always_require' sends the correct messages.""" - from agent_framework import FunctionApprovalResponseContent, ai_function + from agent_framework import FunctionResultContent, ai_function from agent_framework.ag_ui import AgentFrameworkAgent messages_received: list[Any] = [] @@ -706,27 +706,26 @@ async def stream_fn( assert len(run_started) == 1 assert len(run_finished) == 1 - # Verify that a FunctionApprovalResponseContent was created and sent to the agent - # This is the key fix - the orchestrator should create an approval response - approval_responses_found = False + # Verify that a FunctionResultContent was created and sent to the agent + # Approved tool calls are resolved before the model run. + tool_result_found = False for msg in messages_received: for content in msg.contents: - if isinstance(content, FunctionApprovalResponseContent): - approval_responses_found = True - assert content.approved is True - assert content.function_call.name == "get_datetime" - assert content.function_call.call_id == "call_get_datetime_123" + if isinstance(content, FunctionResultContent): + tool_result_found = True + assert content.call_id == "call_get_datetime_123" + assert content.result == "2025/12/01 12:00:00" break - assert approval_responses_found, ( - "FunctionApprovalResponseContent should be included in messages sent to agent. " - "This is required for the agent framework to execute the approved function." + assert tool_result_found, ( + "FunctionResultContent should be included in messages sent to agent. " + "This is required for the model to see the approved tool execution result." ) async def test_function_approval_mode_rejection(): """Test that function approval rejection creates a rejection response.""" - from agent_framework import FunctionApprovalResponseContent, ai_function + from agent_framework import FunctionResultContent, ai_function from agent_framework.ag_ui import AgentFrameworkAgent messages_received: list[Any] = [] @@ -793,18 +792,17 @@ async def stream_fn( run_finished = [e for e in events if e.type == "RUN_FINISHED"] assert len(run_finished) == 1 - # Verify that a FunctionApprovalResponseContent with approved=False was created + # Verify that a FunctionResultContent with rejection payload was created rejection_found = False for msg in messages_received: for content in msg.contents: - if isinstance(content, FunctionApprovalResponseContent): + if isinstance(content, FunctionResultContent): rejection_found = True - assert content.approved is False - assert content.function_call.name == "delete_all_data" - assert content.function_call.call_id == "call_delete_123" + assert content.call_id == "call_delete_123" + assert content.result == "Error: Tool call invocation was rejected by user." break assert rejection_found, ( - "FunctionApprovalResponseContent with approved=False should be included in messages sent to agent. " - "This tells the agent framework that the tool was rejected." + "FunctionResultContent with rejection details should be included in messages sent to agent. " + "This tells the model that the tool was rejected." ) diff --git a/python/packages/ag-ui/tests/test_message_adapters.py b/python/packages/ag-ui/tests/test_message_adapters.py index a21375b87b..6caf106ef8 100644 --- a/python/packages/ag-ui/tests/test_message_adapters.py +++ b/python/packages/ag-ui/tests/test_message_adapters.py @@ -2,6 +2,8 @@ """Tests for message adapters.""" +import json + import pytest from agent_framework import ChatMessage, FunctionCallContent, Role, TextContent @@ -68,6 +70,88 @@ def test_agui_tool_result_to_agent_framework(): assert message.additional_properties.get("tool_call_id") == "call_123" +def test_agui_tool_approval_updates_tool_call_arguments(): + """Tool approval updates matching tool call arguments for snapshots and agent context.""" + messages_input = [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "generate_task_steps", + "arguments": { + "steps": [ + {"description": "Boil water", "status": "enabled"}, + {"description": "Brew coffee", "status": "enabled"}, + {"description": "Serve coffee", "status": "enabled"}, + ] + }, + }, + } + ], + "id": "msg_1", + }, + { + "role": "tool", + "content": json.dumps( + { + "accepted": True, + "steps": [ + {"description": "Boil water", "status": "enabled"}, + {"description": "Serve coffee", "status": "enabled"}, + ], + } + ), + "toolCallId": "call_123", + "id": "msg_2", + }, + ] + + messages = agui_messages_to_agent_framework(messages_input) + + assert len(messages) == 2 + assistant_msg = messages[0] + func_call = next(content for content in assistant_msg.contents if isinstance(content, FunctionCallContent)) + assert func_call.arguments == { + "steps": [ + {"description": "Boil water", "status": "enabled"}, + {"description": "Brew coffee", "status": "disabled"}, + {"description": "Serve coffee", "status": "enabled"}, + ] + } + assert messages_input[0]["tool_calls"][0]["function"]["arguments"] == { + "steps": [ + {"description": "Boil water", "status": "enabled"}, + {"description": "Brew coffee", "status": "disabled"}, + {"description": "Serve coffee", "status": "enabled"}, + ] + } + + from agent_framework import FunctionApprovalResponseContent + + approval_msg = messages[1] + approval_content = next( + content for content in approval_msg.contents if isinstance(content, FunctionApprovalResponseContent) + ) + assert approval_content.function_call.parse_arguments() == { + "steps": [ + {"description": "Boil water", "status": "enabled"}, + {"description": "Serve coffee", "status": "enabled"}, + ] + } + assert approval_content.additional_properties is not None + assert approval_content.additional_properties.get("ag_ui_state_args") == { + "steps": [ + {"description": "Boil water", "status": "enabled"}, + {"description": "Brew coffee", "status": "disabled"}, + {"description": "Serve coffee", "status": "enabled"}, + ] + } + + def test_agui_multiple_messages_to_agent_framework(): """Test converting multiple AG-UI messages.""" messages_input = [ diff --git a/python/packages/ag-ui/tests/test_orchestrators.py b/python/packages/ag-ui/tests/test_orchestrators.py index af90ea2e88..8c00602538 100644 --- a/python/packages/ag-ui/tests/test_orchestrators.py +++ b/python/packages/ag-ui/tests/test_orchestrators.py @@ -42,6 +42,29 @@ async def run_stream( yield AgentRunResponseUpdate(contents=[TextContent(text="ok")], role="assistant") +class RecordingAgent: + """Agent stub that captures messages passed to run_stream.""" + + def __init__(self) -> None: + self.chat_options = SimpleNamespace(tools=[], response_format=None) + self.tools: list[Any] = [] + self.chat_client = SimpleNamespace( + function_invocation_configuration=FunctionInvocationConfiguration(), + ) + self.seen_messages: list[Any] | None = None + + async def run_stream( + self, + messages: list[Any], + *, + thread: Any, + tools: list[Any] | None = None, + **kwargs: Any, + ) -> AsyncGenerator[AgentRunResponseUpdate, None]: + self.seen_messages = messages + yield AgentRunResponseUpdate(contents=[TextContent(text="ok")], role="assistant") + + async def test_default_orchestrator_merges_client_tools() -> None: """Client tool declarations are merged with server tools before running agent.""" @@ -151,3 +174,104 @@ async def test_default_orchestrator_with_snake_case_ids() -> None: last_event = events[-1] assert last_event.run_id == "test-snakecase-runid" assert last_event.thread_id == "test-snakecase-threadid" + + +async def test_state_context_injected_when_tool_call_state_mismatch() -> None: + """State context should be injected when current state differs from tool call args.""" + + agent = RecordingAgent() + orchestrator = DefaultOrchestrator() + + tool_recipe = {"title": "Salad", "special_preferences": []} + current_recipe = {"title": "Salad", "special_preferences": ["Vegetarian"]} + + input_data = { + "state": {"recipe": current_recipe}, + "messages": [ + {"role": "system", "content": "Instructions"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "update_recipe", "arguments": {"recipe": tool_recipe}}, + } + ], + }, + {"role": "user", "content": "What are the dietary preferences?"}, + ], + } + + context = ExecutionContext( + input_data=input_data, + agent=agent, + config=AgentConfig( + state_schema={"recipe": {"type": "object"}}, + predict_state_config={"recipe": {"tool": "update_recipe", "tool_argument": "recipe"}}, + require_confirmation=False, + ), + ) + + async for _event in orchestrator.run(context): + pass + + assert agent.seen_messages is not None + state_messages = [] + for msg in agent.seen_messages: + role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) + if role_value != "system": + continue + for content in msg.contents or []: + if isinstance(content, TextContent) and content.text.startswith("Current state of the application:"): + state_messages.append(content.text) + assert state_messages + assert "Vegetarian" in state_messages[0] + + +async def test_state_context_not_injected_when_tool_call_matches_state() -> None: + """State context should be skipped when tool call args match current state.""" + + agent = RecordingAgent() + orchestrator = DefaultOrchestrator() + + input_data = { + "messages": [ + {"role": "system", "content": "Instructions"}, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "update_recipe", "arguments": {"recipe": {}}}, + } + ], + }, + {"role": "user", "content": "What are the dietary preferences?"}, + ], + } + + context = ExecutionContext( + input_data=input_data, + agent=agent, + config=AgentConfig( + state_schema={"recipe": {"type": "object"}}, + predict_state_config={"recipe": {"tool": "update_recipe", "tool_argument": "recipe"}}, + require_confirmation=False, + ), + ) + + async for _event in orchestrator.run(context): + pass + + assert agent.seen_messages is not None + state_messages = [] + for msg in agent.seen_messages: + role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) + if role_value != "system": + continue + for content in msg.contents or []: + if isinstance(content, TextContent) and content.text.startswith("Current state of the application:"): + state_messages.append(content.text) + assert not state_messages diff --git a/python/packages/ag-ui/tests/test_orchestrators_coverage.py b/python/packages/ag-ui/tests/test_orchestrators_coverage.py index 1da11bffbc..279cdb26b4 100644 --- a/python/packages/ag-ui/tests/test_orchestrators_coverage.py +++ b/python/packages/ag-ui/tests/test_orchestrators_coverage.py @@ -385,8 +385,8 @@ async def test_state_context_injection() -> None: assert "banana" in system_messages[0].contents[0].text -async def test_no_state_context_injection_with_tool_calls() -> None: - """Test state context is NOT injected if conversation has tool calls.""" +async def test_state_context_injection_with_tool_calls_and_input_state() -> None: + """Test state context is injected when state is provided, even with tool calls.""" from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent, TextContent messages = [ @@ -420,13 +420,13 @@ async def test_no_state_context_injection_with_tool_calls() -> None: async for event in orchestrator.run(context): events.append(event) - # Should NOT inject state context system message since conversation has tool calls + # Should inject state context system message because input state is provided system_messages = [ msg for msg in agent.messages_received if (msg.role.value if hasattr(msg.role, "value") else str(msg.role)) == "system" ] - assert len(system_messages) == 0 + assert len(system_messages) == 1 async def test_structured_output_processing() -> None: From 4df06ed17b5d6dde7e82c0f170bb5937dec6cb2e Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Thu, 8 Jan 2026 12:37:57 +0900 Subject: [PATCH 3/8] Bug fix finalized --- .../ag-ui/agent_framework_ag_ui/_events.py | 8 ++-- .../_message_adapters.py | 17 ++++++- .../agent_framework_ag_ui/_orchestrators.py | 4 ++ .../ag-ui/tests/test_message_adapters.py | 27 +++++++++++ .../tests/test_orchestrators_coverage.py | 48 +++++++++++++++++++ 5 files changed, 99 insertions(+), 5 deletions(-) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_events.py b/python/packages/ag-ui/agent_framework_ag_ui/_events.py index 84c4579745..f32a755b74 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_events.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_events.py @@ -433,14 +433,14 @@ def _emit_snapshot_for_tool_result(self) -> list[BaseEvent]: break if should_emit_snapshot and not is_predictive_without_confirmation: - from ._message_adapters import agent_framework_messages_to_agui + from ._message_adapters import agui_messages_to_snapshot_format assistant_message = { "id": generate_event_id(), "role": "assistant", "tool_calls": self.pending_tool_calls.copy(), } - converted_input_messages = agent_framework_messages_to_agui(self.input_messages) + converted_input_messages = agui_messages_to_snapshot_format(self.input_messages) all_messages = converted_input_messages + [assistant_message] + self.tool_results.copy() messages_snapshot_event = MessagesSnapshotEvent( @@ -533,7 +533,7 @@ def _emit_confirm_changes_tool_call(self) -> list[BaseEvent]: ) events.append(confirm_end) - from ._message_adapters import agent_framework_messages_to_agui + from ._message_adapters import agui_messages_to_snapshot_format assistant_message = { "id": generate_event_id(), @@ -541,7 +541,7 @@ def _emit_confirm_changes_tool_call(self) -> list[BaseEvent]: "tool_calls": self.pending_tool_calls.copy(), } - converted_input_messages = agent_framework_messages_to_agui(self.input_messages) + converted_input_messages = agui_messages_to_snapshot_format(self.input_messages) all_messages = converted_input_messages + [assistant_message] + self.tool_results.copy() messages_snapshot_event = MessagesSnapshotEvent( diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py index 079c23a6df..3e0a865be7 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py @@ -28,6 +28,19 @@ Role.SYSTEM: "system", } +_ALLOWED_AGUI_ROLES = {"user", "assistant", "system", "tool"} + + +def _normalize_agui_role(raw_role: Any) -> str: + if not isinstance(raw_role, str): + return "user" + role = raw_role.lower() + if role == "developer": + return "system" + if role in _ALLOWED_AGUI_ROLES: + return role + return "user" + def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[ChatMessage]: """Convert AG-UI messages to Agent Framework format. @@ -70,7 +83,7 @@ def _update_tool_call_arguments( for msg in messages: # Handle standard tool result messages early (role="tool") to preserve provider invariants # This path maps AG‑UI tool messages to FunctionResultContent with the correct tool_call_id - role_str = msg.get("role", "user") + role_str = _normalize_agui_role(msg.get("role", "user")) if role_str == "tool": # Prefer explicit tool_call_id fields; fall back to backend fields only if necessary tool_call_id = msg.get("tool_call_id") or msg.get("toolCallId") @@ -354,6 +367,7 @@ def agent_framework_messages_to_agui(messages: list[ChatMessage] | list[dict[str if isinstance(msg, dict): # Always work on a copy to avoid mutating input normalized_msg = msg.copy() + normalized_msg["role"] = _normalize_agui_role(normalized_msg.get("role")) # Ensure ID exists if "id" not in normalized_msg: normalized_msg["id"] = generate_event_id() @@ -496,6 +510,7 @@ def agui_messages_to_snapshot_format(messages: list[dict[str, Any]]) -> list[dic function_payload_dict["arguments"] = json.dumps(arguments) # Normalize tool_call_id to toolCallId for tool messages + normalized_msg["role"] = _normalize_agui_role(normalized_msg.get("role")) if normalized_msg.get("role") == "tool": if "tool_call_id" in normalized_msg: normalized_msg["toolCallId"] = normalized_msg["tool_call_id"] diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py index 4dae942ca2..3cdb9912b3 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py @@ -623,6 +623,10 @@ async def _resolve_approval_responses( if event_bridge.should_stop_after_confirm: logger.info("Stopping run after confirm_changes - waiting for user response") + if event_bridge.current_message_id: + logger.info(f"[CONFIRM] Emitting TextMessageEndEvent for message_id={event_bridge.current_message_id}") + yield event_bridge.create_message_end_event(event_bridge.current_message_id) + event_bridge.current_message_id = None yield event_bridge.create_run_finished_event() return diff --git a/python/packages/ag-ui/tests/test_message_adapters.py b/python/packages/ag-ui/tests/test_message_adapters.py index f085631e56..27af3aae93 100644 --- a/python/packages/ag-ui/tests/test_message_adapters.py +++ b/python/packages/ag-ui/tests/test_message_adapters.py @@ -10,6 +10,7 @@ from agent_framework_ag_ui._message_adapters import ( agent_framework_messages_to_agui, agui_messages_to_agent_framework, + agui_messages_to_snapshot_format, extract_text_from_contents, ) @@ -45,6 +46,32 @@ def test_agent_framework_to_agui_basic(sample_agent_framework_message): assert messages[0]["id"] == "msg-123" +def test_agent_framework_to_agui_normalizes_dict_roles(): + """Dict inputs normalize unknown roles for UI compatibility.""" + messages = [ + {"role": "developer", "content": "policy"}, + {"role": "weird_role", "content": "payload"}, + ] + + converted = agent_framework_messages_to_agui(messages) + + assert converted[0]["role"] == "system" + assert converted[1]["role"] == "user" + + +def test_agui_snapshot_format_normalizes_roles(): + """Snapshot normalization coerces roles into supported AG-UI values.""" + messages = [ + {"role": "Developer", "content": "policy"}, + {"role": "unknown", "content": "payload"}, + ] + + normalized = agui_messages_to_snapshot_format(messages) + + assert normalized[0]["role"] == "system" + assert normalized[1]["role"] == "user" + + def test_agui_tool_result_to_agent_framework(): """Test converting AG-UI tool result message to Agent Framework.""" tool_result_message = { diff --git a/python/packages/ag-ui/tests/test_orchestrators_coverage.py b/python/packages/ag-ui/tests/test_orchestrators_coverage.py index 279cdb26b4..c2ded90eb5 100644 --- a/python/packages/ag-ui/tests/test_orchestrators_coverage.py +++ b/python/packages/ag-ui/tests/test_orchestrators_coverage.py @@ -685,6 +685,54 @@ async def test_confirm_changes_with_invalid_json_fallback() -> None: assert len(user_messages) == 1 +async def test_confirm_changes_closes_active_message_before_finish() -> None: + """Confirm-changes flow closes any active text message before run finishes.""" + from ag_ui.core import TextMessageEndEvent, TextMessageStartEvent + from agent_framework import FunctionCallContent, FunctionResultContent + + updates = [ + AgentRunResponseUpdate( + contents=[ + FunctionCallContent( + name="write_document_local", + call_id="call_1", + arguments='{"document": "Draft"}', + ) + ] + ), + AgentRunResponseUpdate(contents=[FunctionResultContent(call_id="call_1", result="Done")]), + ] + + orchestrator = DefaultOrchestrator() + input_data: dict[str, Any] = {"messages": [{"role": "user", "content": "Start"}]} + agent = StubAgent( + chat_options=DEFAULT_CHAT_OPTIONS, + updates=updates, + ) + context = TestExecutionContext( + input_data=input_data, + agent=agent, + config=AgentConfig( + predict_state_config={"document": {"tool": "write_document_local", "tool_argument": "document"}}, + require_confirmation=True, + ), + ) + + events: list[Any] = [] + async for event in orchestrator.run(context): + events.append(event) + + start_events = [e for e in events if isinstance(e, TextMessageStartEvent)] + end_events = [e for e in events if isinstance(e, TextMessageEndEvent)] + assert len(start_events) == 1 + assert len(end_events) == 1 + assert end_events[0].message_id == start_events[0].message_id + + end_index = events.index(end_events[0]) + finished_index = events.index([e for e in events if e.type == "RUN_FINISHED"][0]) + assert end_index < finished_index + + async def test_tool_result_kept_when_call_id_matches() -> None: """Test tool result is kept when call_id matches pending tool calls.""" from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent From dd339d26bd4ad937d5eafe7e25d12f3eeb457676 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Thu, 8 Jan 2026 13:13:54 +0900 Subject: [PATCH 4/8] Refactoring to clean up code --- .../ag-ui/agent_framework_ag_ui/_events.py | 161 ------------- .../_message_adapters.py | 193 ++++++++++++++++ .../_orchestration/_message_hygiene.py | 198 ---------------- .../agent_framework_ag_ui/_orchestrators.py | 215 ++++++++++++++---- .../tests/test_backend_tool_rendering.py | 4 +- .../ag-ui/tests/test_events_comprehensive.py | 4 +- .../ag-ui/tests/test_helpers_ag_ui.py | 9 +- .../ag-ui/tests/test_message_hygiene.py | 9 +- .../tests/test_orchestrators_coverage.py | 2 +- 9 files changed, 374 insertions(+), 421 deletions(-) delete mode 100644 python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_message_hygiene.py diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_events.py b/python/packages/ag-ui/agent_framework_ag_ui/_events.py index f32a755b74..19cb2982b9 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_events.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_events.py @@ -11,8 +11,6 @@ from ag_ui.core import ( BaseEvent, CustomEvent, - EventType, - MessagesSnapshotEvent, RunFinishedEvent, RunStartedEvent, StateDeltaEvent, @@ -49,7 +47,6 @@ def __init__( predict_state_config: dict[str, dict[str, str]] | None = None, current_state: dict[str, Any] | None = None, skip_text_content: bool = False, - input_messages: list[Any] | None = None, require_confirmation: bool = True, ) -> None: """ @@ -62,7 +59,6 @@ def __init__( Format: {"state_key": {"tool": "tool_name", "tool_argument": "arg_name"}} current_state: Reference to the current state dict for tracking updates. skip_text_content: If True, skip emitting TextMessageContentEvents (for structured outputs). - input_messages: The input messages from the conversation history. require_confirmation: Whether predictive state updates require user confirmation. """ self.run_id = run_id @@ -83,14 +79,6 @@ def __init__( self.should_stop_after_confirm: bool = False # Flag to stop run after confirm_changes self.suppressed_summary: str = "" # Store LLM summary to show after confirmation - # For MessagesSnapshotEvent: track tool calls and results - self.input_messages = input_messages or [] - self.pending_tool_calls: list[dict[str, Any]] = [] # Track tool calls for assistant message - self.tool_results: list[dict[str, Any]] = [] # Track tool results - self.tool_calls_ended: set[str] = set() # Track which tool calls have had ToolCallEndEvent emitted - self.accumulated_text_content: str = "" # Track accumulated text for final MessagesSnapshotEvent - self.messages_snapshot_emitted: bool = False # Track snapshot emission to avoid duplicates - async def from_agent_run_update(self, update: AgentRunResponseUpdate) -> list[BaseEvent]: """ Convert an AgentRunResponseUpdate to AG-UI events. @@ -156,7 +144,6 @@ def _handle_text_content(self, content: TextContent) -> list[BaseEvent]: message_id=self.current_message_id, delta=content.text, ) - self.accumulated_text_content += content.text logger.info(f" EMITTING TextMessageContentEvent with text_len={len(content.text)}") events.append(event) return events @@ -185,17 +172,6 @@ def _handle_function_call_content(self, content: FunctionCallContent) -> list[Ba ) logger.info(f"Emitting ToolCallStartEvent with name='{content.name}', id='{tool_call_id}'") events.append(tool_start_event) - - self.pending_tool_calls.append( - { - "id": tool_call_id, - "type": "function", - "function": { - "name": content.name, - "arguments": "", - }, - } - ) elif tool_call_id: self.current_tool_call_id = tool_call_id @@ -208,13 +184,7 @@ def _handle_function_call_content(self, content: FunctionCallContent) -> list[Ba ) events.append(args_event) - for tool_call in self.pending_tool_calls: - if tool_call["id"] == tool_call_id: - tool_call["function"]["arguments"] += delta_str - break - events.extend(self._emit_predictive_state_deltas(delta_str)) - events.extend(self._legacy_predictive_state(content)) return events @@ -319,59 +289,6 @@ def _emit_predictive_state_deltas(self, argument_chunk: str) -> list[BaseEvent]: self.pending_state_updates[state_key] = state_value return events - def _legacy_predictive_state(self, content: FunctionCallContent) -> list[BaseEvent]: - events: list[BaseEvent] = [] - if not (content.name and content.arguments): - return events - parsed_args = content.parse_arguments() - if not parsed_args: - return events - - logger.info( - "Checking predict_state_config keys: %s", - list(self.predict_state_config.keys()) if self.predict_state_config else "None", - ) - for state_key, config in self.predict_state_config.items(): - logger.info(f"Checking state_key='{state_key}'") - if config["tool"] != content.name: - continue - tool_arg_name = config["tool_argument"] - logger.info(f"MATCHED tool '{content.name}' for state key '{state_key}', arg='{tool_arg_name}'") - - state_value: Any - if tool_arg_name == "*": - state_value = parsed_args - logger.info(f"Using all args as state value, keys: {list(state_value.keys())}") - elif tool_arg_name in parsed_args: - state_value = parsed_args[tool_arg_name] - logger.info(f"Using specific arg '{tool_arg_name}' as state value") - else: - logger.warning(f"Tool argument '{tool_arg_name}' not found in parsed args") - continue - - previous_value = self.last_emitted_state.get(state_key, object()) - if previous_value == state_value: - logger.info( - "Skipping duplicate StateDeltaEvent for key '%s' - value unchanged", - state_key, - ) - continue - - state_delta_event = StateDeltaEvent( - delta=[ - { - "op": "replace", - "path": f"/{state_key}", - "value": state_value, - } - ], - ) - logger.info(f"Emitting StateDeltaEvent for key '{state_key}', value type: {type(state_value)}") # type: ignore - events.append(state_delta_event) - self.pending_state_updates[state_key] = state_value - self.last_emitted_state[state_key] = state_value - return events - def _handle_function_result_content(self, content: FunctionResultContent) -> list[BaseEvent]: events: list[BaseEvent] = [] if content.call_id: @@ -380,7 +297,6 @@ def _handle_function_result_content(self, content: FunctionResultContent) -> lis ) logger.info(f"Emitting ToolCallEndEvent for completed tool call '{content.call_id}'") events.append(end_event) - self.tool_calls_ended.add(content.call_id) if self.state_delta_count > 0: logger.info( @@ -402,56 +318,10 @@ def _handle_function_result_content(self, content: FunctionResultContent) -> lis role="tool", ) events.append(result_event) - - self.tool_results.append( - { - "id": result_message_id, - "role": "tool", - "toolCallId": content.call_id, - "content": result_content, - } - ) - - events.extend(self._emit_snapshot_for_tool_result()) events.extend(self._emit_state_snapshot_and_confirmation()) return events - def _emit_snapshot_for_tool_result(self) -> list[BaseEvent]: - events: list[BaseEvent] = [] - should_emit_snapshot = self.pending_tool_calls and self.tool_results - - is_predictive_without_confirmation = False - if should_emit_snapshot and self.current_tool_call_name and self.predict_state_config: - for _, config in self.predict_state_config.items(): - if config["tool"] == self.current_tool_call_name and not self.require_confirmation: - is_predictive_without_confirmation = True - logger.info( - "Skipping intermediate MessagesSnapshotEvent for predictive tool '%s' - delaying until summary", - self.current_tool_call_name, - ) - break - - if should_emit_snapshot and not is_predictive_without_confirmation: - from ._message_adapters import agui_messages_to_snapshot_format - - assistant_message = { - "id": generate_event_id(), - "role": "assistant", - "tool_calls": self.pending_tool_calls.copy(), - } - converted_input_messages = agui_messages_to_snapshot_format(self.input_messages) - all_messages = converted_input_messages + [assistant_message] + self.tool_results.copy() - - messages_snapshot_event = MessagesSnapshotEvent( - type=EventType.MESSAGES_SNAPSHOT, - messages=all_messages, # type: ignore[arg-type] - ) - logger.info(f"Emitting MessagesSnapshotEvent with {len(all_messages)} messages") - events.append(messages_snapshot_event) - self.messages_snapshot_emitted = True - return events - def _emit_state_snapshot_and_confirmation(self) -> list[BaseEvent]: events: list[BaseEvent] = [] if self.pending_state_updates: @@ -505,17 +375,6 @@ def _emit_confirm_changes_tool_call(self) -> list[BaseEvent]: confirm_call_id = generate_event_id() logger.info("Emitting confirm_changes tool call for predictive update") - self.pending_tool_calls.append( - { - "id": confirm_call_id, - "type": "function", - "function": { - "name": "confirm_changes", - "arguments": "{}", - }, - } - ) - confirm_start = ToolCallStartEvent( tool_call_id=confirm_call_id, tool_call_name="confirm_changes", @@ -533,25 +392,6 @@ def _emit_confirm_changes_tool_call(self) -> list[BaseEvent]: ) events.append(confirm_end) - from ._message_adapters import agui_messages_to_snapshot_format - - assistant_message = { - "id": generate_event_id(), - "role": "assistant", - "tool_calls": self.pending_tool_calls.copy(), - } - - converted_input_messages = agui_messages_to_snapshot_format(self.input_messages) - all_messages = converted_input_messages + [assistant_message] + self.tool_results.copy() - - messages_snapshot_event = MessagesSnapshotEvent( - type=EventType.MESSAGES_SNAPSHOT, - messages=all_messages, # type: ignore[arg-type] - ) - logger.info(f"Emitting MessagesSnapshotEvent for confirm_changes with {len(all_messages)} messages") - events.append(messages_snapshot_event) - self.messages_snapshot_emitted = True - self.should_stop_after_confirm = True logger.info("Set flag to stop run after confirm_changes") return events @@ -604,7 +444,6 @@ def _handle_function_approval_request_content(self, content: FunctionApprovalReq ) logger.info(f"Emitting ToolCallEndEvent for approval-required tool '{content.function_call.call_id}'") events.append(end_event) - self.tool_calls_ended.add(content.function_call.call_id) approval_event = CustomEvent( name="function_approval_request", diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py index 3e0a865be7..2da519bac7 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py @@ -3,6 +3,7 @@ """Message format conversion between AG-UI and Agent Framework.""" import json +import logging from typing import Any, cast from agent_framework import ( @@ -30,6 +31,8 @@ _ALLOWED_AGUI_ROLES = {"user", "assistant", "system", "tool"} +logger = logging.getLogger(__name__) + def _normalize_agui_role(raw_role: Any) -> str: if not isinstance(raw_role, str): @@ -42,6 +45,196 @@ def _normalize_agui_role(raw_role: Any) -> str: return "user" +def _sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: + """Normalize tool ordering and inject synthetic results for AG-UI edge cases.""" + sanitized: list[ChatMessage] = [] + pending_tool_call_ids: set[str] | None = None + pending_confirm_changes_id: str | None = None + + for msg in messages: + role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) + + if role_value == "assistant": + tool_ids = { + str(content.call_id) + for content in msg.contents or [] + if isinstance(content, FunctionCallContent) and content.call_id + } + confirm_changes_call = None + for content in msg.contents or []: + if isinstance(content, FunctionCallContent) and content.name == "confirm_changes": + confirm_changes_call = content + break + + sanitized.append(msg) + pending_tool_call_ids = tool_ids if tool_ids else None + pending_confirm_changes_id = ( + str(confirm_changes_call.call_id) if confirm_changes_call and confirm_changes_call.call_id else None + ) + continue + + if role_value == "user": + approval_call_ids: set[str] = set() + for content in msg.contents or []: + if type(content) is FunctionApprovalResponseContent: + if content.function_call and content.function_call.call_id: + approval_call_ids.add(str(content.function_call.call_id)) + + if approval_call_ids and pending_tool_call_ids: + pending_tool_call_ids -= approval_call_ids + logger.info( + f"FunctionApprovalResponseContent found for call_ids={sorted(approval_call_ids)} - " + "framework will handle execution" + ) + + if pending_confirm_changes_id: + user_text = "" + for content in msg.contents or []: + if isinstance(content, TextContent): + user_text = content.text + break + + try: + parsed = json.loads(user_text) + if "accepted" in parsed: + logger.info( + f"Injecting synthetic tool result for confirm_changes call_id={pending_confirm_changes_id}" + ) + synthetic_result = ChatMessage( + role="tool", + contents=[ + FunctionResultContent( + call_id=pending_confirm_changes_id, + result="Confirmed" if parsed.get("accepted") else "Rejected", + ) + ], + ) + sanitized.append(synthetic_result) + if pending_tool_call_ids: + pending_tool_call_ids.discard(pending_confirm_changes_id) + pending_confirm_changes_id = None + continue + except (json.JSONDecodeError, KeyError) as exc: + logger.debug(f"Could not parse user message as confirm_changes response: {type(exc).__name__}") + + if pending_tool_call_ids: + logger.info( + f"User message arrived with {len(pending_tool_call_ids)} pending tool calls - " + "injecting synthetic results" + ) + for pending_call_id in pending_tool_call_ids: + logger.info(f"Injecting synthetic tool result for pending call_id={pending_call_id}") + synthetic_result = ChatMessage( + role="tool", + contents=[ + FunctionResultContent( + call_id=pending_call_id, + result="Tool execution skipped - user provided follow-up message", + ) + ], + ) + sanitized.append(synthetic_result) + pending_tool_call_ids = None + pending_confirm_changes_id = None + + sanitized.append(msg) + pending_confirm_changes_id = None + continue + + if role_value == "tool": + if not pending_tool_call_ids: + continue + keep = False + for content in msg.contents or []: + if isinstance(content, FunctionResultContent): + call_id = str(content.call_id) + if call_id in pending_tool_call_ids: + keep = True + if call_id == pending_confirm_changes_id: + pending_confirm_changes_id = None + break + if keep: + sanitized.append(msg) + continue + + sanitized.append(msg) + pending_tool_call_ids = None + pending_confirm_changes_id = None + + return sanitized + + +def _deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]: + """Remove duplicate messages while preserving order.""" + seen_keys: dict[Any, int] = {} + unique_messages: list[ChatMessage] = [] + + for idx, msg in enumerate(messages): + role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) + + if role_value == "tool" and msg.contents and isinstance(msg.contents[0], FunctionResultContent): + call_id = str(msg.contents[0].call_id) + key: Any = (role_value, call_id) + + if key in seen_keys: + existing_idx = seen_keys[key] + existing_msg = unique_messages[existing_idx] + + existing_result = None + if existing_msg.contents and isinstance(existing_msg.contents[0], FunctionResultContent): + existing_result = existing_msg.contents[0].result + new_result = msg.contents[0].result + + if (not existing_result or existing_result == "") and new_result: + logger.info(f"Replacing empty tool result at index {existing_idx} with data from index {idx}") + unique_messages[existing_idx] = msg + else: + logger.info(f"Skipping duplicate tool result at index {idx}: call_id={call_id}") + continue + + seen_keys[key] = len(unique_messages) + unique_messages.append(msg) + + elif ( + role_value == "assistant" and msg.contents and any(isinstance(c, FunctionCallContent) for c in msg.contents) + ): + tool_call_ids = tuple( + sorted(str(c.call_id) for c in msg.contents if isinstance(c, FunctionCallContent) and c.call_id) + ) + key = (role_value, tool_call_ids) + + if key in seen_keys: + logger.info(f"Skipping duplicate assistant tool call at index {idx}") + continue + + seen_keys[key] = len(unique_messages) + unique_messages.append(msg) + + else: + content_str = str([str(c) for c in msg.contents]) if msg.contents else "" + key = (role_value, hash(content_str)) + + if key in seen_keys: + logger.info(f"Skipping duplicate message at index {idx}: role={role_value}") + continue + + seen_keys[key] = len(unique_messages) + unique_messages.append(msg) + + return unique_messages + + +def normalize_agui_input_messages( + messages: list[dict[str, Any]], +) -> tuple[list[ChatMessage], list[dict[str, Any]]]: + """Normalize raw AG-UI messages into provider and snapshot formats.""" + provider_messages = agui_messages_to_agent_framework(messages) + provider_messages = _sanitize_tool_history(provider_messages) + provider_messages = _deduplicate_messages(provider_messages) + snapshot_messages = agui_messages_to_snapshot_format(messages) + return provider_messages, snapshot_messages + + def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[ChatMessage]: """Convert AG-UI messages to Agent Framework format. diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_message_hygiene.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_message_hygiene.py deleted file mode 100644 index 6be50a9f54..0000000000 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_message_hygiene.py +++ /dev/null @@ -1,198 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Message hygiene utilities for orchestrators.""" - -import json -import logging -from typing import Any - -from agent_framework import ( - ChatMessage, - FunctionApprovalResponseContent, - FunctionCallContent, - FunctionResultContent, - TextContent, -) - -logger = logging.getLogger(__name__) - - -def sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: - """Normalize tool ordering and inject synthetic results for AG-UI edge cases.""" - sanitized: list[ChatMessage] = [] - pending_tool_call_ids: set[str] | None = None - pending_confirm_changes_id: str | None = None - - for msg in messages: - role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) - - if role_value == "assistant": - tool_ids = { - str(content.call_id) - for content in msg.contents or [] - if isinstance(content, FunctionCallContent) and content.call_id - } - confirm_changes_call = None - for content in msg.contents or []: - if isinstance(content, FunctionCallContent) and content.name == "confirm_changes": - confirm_changes_call = content - break - - sanitized.append(msg) - pending_tool_call_ids = tool_ids if tool_ids else None - pending_confirm_changes_id = ( - str(confirm_changes_call.call_id) if confirm_changes_call and confirm_changes_call.call_id else None - ) - continue - - if role_value == "user": - # Check if this message contains FunctionApprovalResponseContent - # If so, the framework will handle tool execution - don't inject synthetic results - approval_call_ids: set[str] = set() - for content in msg.contents or []: - if type(content) is FunctionApprovalResponseContent: - if content.function_call and content.function_call.call_id: - approval_call_ids.add(str(content.function_call.call_id)) - - if approval_call_ids and pending_tool_call_ids: - # Remove approved call_ids from pending - the framework will execute them - pending_tool_call_ids -= approval_call_ids - logger.info( - f"FunctionApprovalResponseContent found for call_ids={sorted(approval_call_ids)} - " - "framework will handle execution" - ) - - if pending_confirm_changes_id: - user_text = "" - for content in msg.contents or []: - if isinstance(content, TextContent): - user_text = content.text - break - - try: - parsed = json.loads(user_text) - if "accepted" in parsed: - logger.info( - f"Injecting synthetic tool result for confirm_changes call_id={pending_confirm_changes_id}" - ) - synthetic_result = ChatMessage( - role="tool", - contents=[ - FunctionResultContent( - call_id=pending_confirm_changes_id, - result="Confirmed" if parsed.get("accepted") else "Rejected", - ) - ], - ) - sanitized.append(synthetic_result) - if pending_tool_call_ids: - pending_tool_call_ids.discard(pending_confirm_changes_id) - pending_confirm_changes_id = None - continue - except (json.JSONDecodeError, KeyError) as exc: - logger.debug("Could not parse user message as confirm_changes response: %s", type(exc).__name__) - - if pending_tool_call_ids: - logger.info( - f"User message arrived with {len(pending_tool_call_ids)} pending tool calls - injecting synthetic results" - ) - for pending_call_id in pending_tool_call_ids: - logger.info(f"Injecting synthetic tool result for pending call_id={pending_call_id}") - synthetic_result = ChatMessage( - role="tool", - contents=[ - FunctionResultContent( - call_id=pending_call_id, - result="Tool execution skipped - user provided follow-up message", - ) - ], - ) - sanitized.append(synthetic_result) - pending_tool_call_ids = None - pending_confirm_changes_id = None - - sanitized.append(msg) - pending_confirm_changes_id = None - continue - - if role_value == "tool": - if not pending_tool_call_ids: - continue - keep = False - for content in msg.contents or []: - if isinstance(content, FunctionResultContent): - call_id = str(content.call_id) - if call_id in pending_tool_call_ids: - keep = True - if call_id == pending_confirm_changes_id: - pending_confirm_changes_id = None - break - if keep: - sanitized.append(msg) - continue - - sanitized.append(msg) - pending_tool_call_ids = None - pending_confirm_changes_id = None - - return sanitized - - -def deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]: - """Remove duplicate messages while preserving order.""" - seen_keys: dict[Any, int] = {} - unique_messages: list[ChatMessage] = [] - - for idx, msg in enumerate(messages): - role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) - - if role_value == "tool" and msg.contents and isinstance(msg.contents[0], FunctionResultContent): - call_id = str(msg.contents[0].call_id) - key: Any = (role_value, call_id) - - if key in seen_keys: - existing_idx = seen_keys[key] - existing_msg = unique_messages[existing_idx] - - existing_result = None - if existing_msg.contents and isinstance(existing_msg.contents[0], FunctionResultContent): - existing_result = existing_msg.contents[0].result - new_result = msg.contents[0].result - - if (not existing_result or existing_result == "") and new_result: - logger.info(f"Replacing empty tool result at index {existing_idx} with data from index {idx}") - unique_messages[existing_idx] = msg - else: - logger.info(f"Skipping duplicate tool result at index {idx}: call_id={call_id}") - continue - - seen_keys[key] = len(unique_messages) - unique_messages.append(msg) - - elif ( - role_value == "assistant" and msg.contents and any(isinstance(c, FunctionCallContent) for c in msg.contents) - ): - tool_call_ids = tuple( - sorted(str(c.call_id) for c in msg.contents if isinstance(c, FunctionCallContent) and c.call_id) - ) - key = (role_value, tool_call_ids) - - if key in seen_keys: - logger.info(f"Skipping duplicate assistant tool call at index {idx}") - continue - - seen_keys[key] = len(unique_messages) - unique_messages.append(msg) - - else: - content_str = str([str(c) for c in msg.contents]) if msg.contents else "" - key = (role_value, hash(content_str)) - - if key in seen_keys: - logger.info(f"Skipping duplicate message at index {idx}: role={role_value}") - continue - - seen_keys[key] = len(unique_messages) - unique_messages.append(msg) - - return unique_messages diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py index 3cdb9912b3..f2c15a4516 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py @@ -17,6 +17,10 @@ TextMessageContentEvent, TextMessageEndEvent, TextMessageStartEvent, + ToolCallArgsEvent, + ToolCallEndEvent, + ToolCallResultEvent, + ToolCallStartEvent, ) from agent_framework import ( AgentProtocol, @@ -30,9 +34,9 @@ from agent_framework._middleware import extract_and_merge_function_middleware from agent_framework._tools import ( FunctionInvocationConfiguration, - _collect_approval_responses, - _replace_approval_contents_with_results, - _try_execute_function_calls, + _collect_approval_responses, # type: ignore + _replace_approval_contents_with_results, # type: ignore + _try_execute_function_calls, # type: ignore ) from ._utils import convert_agui_tools_to_agent_framework, generate_event_id @@ -70,6 +74,7 @@ def __init__( # Lazy-loaded properties self._messages = None + self._snapshot_messages = None self._last_message = None self._run_id: str | None = None self._thread_id: str | None = None @@ -78,12 +83,27 @@ def __init__( def messages(self): """Get converted Agent Framework messages (lazy loaded).""" if self._messages is None: - from ._message_adapters import agui_messages_to_agent_framework + from ._message_adapters import normalize_agui_input_messages raw = self.input_data.get("messages", []) - self._messages = agui_messages_to_agent_framework(raw) + if not isinstance(raw, list): + raw = [] + self._messages, self._snapshot_messages = normalize_agui_input_messages(raw) return self._messages + @property + def snapshot_messages(self) -> list[dict[str, Any]]: + """Get normalized AG-UI snapshot messages (lazy loaded).""" + if self._snapshot_messages is None: + if self._messages is None: + _ = self.messages + else: + from ._message_adapters import agent_framework_messages_to_agui, agui_messages_to_snapshot_format + + raw_snapshot = agent_framework_messages_to_agui(self._messages) + self._snapshot_messages = agui_messages_to_snapshot_format(raw_snapshot) + return self._snapshot_messages or [] + @property def last_message(self): """Get the last message in the conversation (lazy loaded).""" @@ -279,8 +299,6 @@ async def run( AG-UI events """ from ._events import AgentFrameworkEventBridge - from ._message_adapters import agui_messages_to_snapshot_format - from ._orchestration._message_hygiene import deduplicate_messages, sanitize_tool_history from ._orchestration._state_manager import StateManager from ._orchestration._tooling import ( collect_server_tools, @@ -308,7 +326,6 @@ async def run( predict_state_config=context.config.predict_state_config, current_state=current_state, skip_text_content=skip_text_content, - input_messages=context.input_data.get("messages", []), require_confirmation=context.config.require_confirmation, ) @@ -330,17 +347,18 @@ async def run( if current_state: thread.metadata["current_state"] = current_state # type: ignore[attr-defined] - raw_messages = context.messages or [] - if not raw_messages: + provider_messages = context.messages or [] + snapshot_messages = context.snapshot_messages + if not provider_messages: logger.warning("No messages provided in AG-UI input") yield event_bridge.create_run_finished_event() return - logger.info(f"Received {len(raw_messages)} raw messages from client") - for i, msg in enumerate(raw_messages): + logger.info(f"Received {len(provider_messages)} provider messages from client") + for i, msg in enumerate(provider_messages): role = msg.role.value if hasattr(msg.role, "value") else str(msg.role) msg_id = getattr(msg, "message_id", None) - logger.info(f" Raw message {i}: role={role}, id={msg_id}") + logger.info(f" Message {i}: role={role}, id={msg_id}") if hasattr(msg, "contents") and msg.contents: for j, content in enumerate(msg.contents): content_type = type(content).__name__ @@ -363,15 +381,12 @@ async def run( else: logger.debug(f" Content {j}: {content_type}") - sanitized_messages = sanitize_tool_history(raw_messages) - provider_messages = deduplicate_messages(sanitized_messages) - if not provider_messages: logger.info("No provider-eligible messages after filtering; finishing run without invoking agent.") yield event_bridge.create_run_finished_event() return - logger.info(f"Processing {len(provider_messages)} provider messages after sanitization/deduplication") + logger.info(f"Processing {len(provider_messages)} provider messages after normalization") for i, msg in enumerate(provider_messages): role = msg.role.value if hasattr(msg.role, "value") else str(msg.role) logger.info(f" Message {i}: role={role}") @@ -395,6 +410,14 @@ async def run( else: logger.info(f" Content {j}: {content_type}") + pending_tool_calls: list[dict[str, Any]] = [] + tool_calls_by_id: dict[str, dict[str, Any]] = {} + tool_results: list[dict[str, Any]] = [] + tool_calls_ended: set[str] = set() + messages_snapshot_emitted = False + accumulated_text_content = "" + active_message_id: str | None = None + # Check for FunctionApprovalResponseContent and emit updated state snapshot # This ensures the UI shows the approved state (e.g., 2 steps) not the original (3 steps) for msg in provider_messages: @@ -596,6 +619,48 @@ async def _resolve_approval_responses( _replace_approval_contents_with_results(messages, fcc_todo, normalized_results) # type: ignore + def _should_emit_tool_snapshot(tool_name: str | None) -> bool: + if not pending_tool_calls or not tool_results: + return False + if tool_name and context.config.predict_state_config and not context.config.require_confirmation: + for config in context.config.predict_state_config.values(): + if config["tool"] == tool_name: + logger.info( + f"Skipping intermediate MessagesSnapshotEvent for predictive tool '{tool_name}' " + " - delaying until summary" + ) + return False + return True + + def _build_messages_snapshot() -> MessagesSnapshotEvent: + has_text_content = bool(accumulated_text_content) + all_messages = snapshot_messages.copy() + + if pending_tool_calls: + tool_call_message_id = ( + active_message_id if not has_text_content and active_message_id else generate_event_id() + ) + tool_call_message = { + "id": tool_call_message_id, + "role": "assistant", + "tool_calls": pending_tool_calls.copy(), + } + all_messages.append(tool_call_message) + + all_messages.extend(tool_results) + + if has_text_content and active_message_id: + assistant_text_message = { + "id": active_message_id, + "role": "assistant", + "content": accumulated_text_content, + } + all_messages.append(assistant_text_message) + + return MessagesSnapshotEvent( + messages=all_messages, # type: ignore[arg-type] + ) + await _resolve_approval_responses(messages_to_run, server_tools) async for update in context.agent.run_stream(messages_to_run, **run_kwargs): @@ -608,6 +673,8 @@ async def _resolve_approval_responses( if has_tool_call and not has_text: tool_message_id = generate_event_id() event_bridge.current_message_id = tool_message_id + active_message_id = tool_message_id + accumulated_text_content = "" logger.info( "[STREAM] Emitting TextMessageStartEvent for tool-only response message_id=%s", tool_message_id, @@ -616,8 +683,67 @@ async def _resolve_approval_responses( events = await event_bridge.from_agent_run_update(update) logger.info(f"[STREAM] Update #{update_count} produced {len(events)} events") for event in events: + if isinstance(event, TextMessageStartEvent): + active_message_id = event.message_id + accumulated_text_content = "" + elif isinstance(event, TextMessageContentEvent): + accumulated_text_content += event.delta + elif isinstance(event, ToolCallStartEvent): + tool_call_entry = tool_calls_by_id.get(event.tool_call_id) + if tool_call_entry is None: + tool_call_entry = { + "id": event.tool_call_id, + "type": "function", + "function": { + "name": event.tool_call_name, + "arguments": "", + }, + } + pending_tool_calls.append(tool_call_entry) + tool_calls_by_id[event.tool_call_id] = tool_call_entry + else: + tool_call_entry["function"]["name"] = event.tool_call_name + elif isinstance(event, ToolCallArgsEvent): + tool_call_entry = tool_calls_by_id.get(event.tool_call_id) + if tool_call_entry is None: + tool_call_entry = { + "id": event.tool_call_id, + "type": "function", + "function": { + "name": "", + "arguments": "", + }, + } + pending_tool_calls.append(tool_call_entry) + tool_calls_by_id[event.tool_call_id] = tool_call_entry + tool_call_entry["function"]["arguments"] += event.delta + elif isinstance(event, ToolCallEndEvent): + tool_calls_ended.add(event.tool_call_id) + elif isinstance(event, ToolCallResultEvent): + tool_results.append( + { + "id": event.message_id, + "role": "tool", + "toolCallId": event.tool_call_id, + "content": event.content, + } + ) logger.info(f"[STREAM] Yielding event: {type(event).__name__}") yield event + if isinstance(event, ToolCallResultEvent): + tool_name = tool_calls_by_id.get(event.tool_call_id, {}).get("function", {}).get("name") + if _should_emit_tool_snapshot(tool_name): + messages_snapshot_emitted = True + messages_snapshot = _build_messages_snapshot() + logger.info(f"[STREAM] Yielding event: {type(messages_snapshot).__name__}") + yield messages_snapshot + elif isinstance(event, ToolCallEndEvent): + tool_name = tool_calls_by_id.get(event.tool_call_id, {}).get("function", {}).get("name") + if tool_name == "confirm_changes": + messages_snapshot_emitted = True + messages_snapshot = _build_messages_snapshot() + logger.info(f"[STREAM] Yielding event: {type(messages_snapshot).__name__}") + yield messages_snapshot logger.info(f"[STREAM] Agent stream completed. Total updates: {update_count}") @@ -630,10 +756,8 @@ async def _resolve_approval_responses( yield event_bridge.create_run_finished_event() return - if event_bridge.pending_tool_calls: - pending_without_end = [ - tc for tc in event_bridge.pending_tool_calls if tc.get("id") not in event_bridge.tool_calls_ended - ] + if pending_tool_calls: + pending_without_end = [tc for tc in pending_tool_calls if tc.get("id") not in tool_calls_ended] if pending_without_end: logger.info( "Found %s pending tool calls without end event - emitting ToolCallEndEvent", @@ -642,8 +766,6 @@ async def _resolve_approval_responses( for tool_call in pending_without_end: tool_call_id = tool_call.get("id") if tool_call_id: - from ag_ui.core import ToolCallEndEvent - end_event = ToolCallEndEvent(tool_call_id=tool_call_id) logger.info(f"Emitting ToolCallEndEvent for declaration-only tool call '{tool_call_id}'") yield end_event @@ -680,58 +802,55 @@ async def _resolve_approval_responses( logger.info(f"[FINALIZE] Emitting TextMessageEndEvent for message_id={event_bridge.current_message_id}") yield event_bridge.create_message_end_event(event_bridge.current_message_id) - has_text_content = bool(event_bridge.accumulated_text_content) - assistant_text_message = { - "id": event_bridge.current_message_id, - "role": "assistant", - "content": event_bridge.accumulated_text_content, - } - - converted_input_messages = agui_messages_to_snapshot_format(event_bridge.input_messages) - all_messages = converted_input_messages.copy() + has_text_content = bool(accumulated_text_content) + all_messages = snapshot_messages.copy() - if event_bridge.pending_tool_calls: + if pending_tool_calls: tool_call_message = { "id": event_bridge.current_message_id if not has_text_content else generate_event_id(), "role": "assistant", - "tool_calls": event_bridge.pending_tool_calls.copy(), + "tool_calls": pending_tool_calls.copy(), } all_messages.append(tool_call_message) - all_messages.extend(event_bridge.tool_results.copy()) - if has_text_content: - all_messages.append(assistant_text_message) + all_messages.extend(tool_results) + if has_text_content and active_message_id: + all_messages.append( + { + "id": active_message_id, + "role": "assistant", + "content": accumulated_text_content, + } + ) messages_snapshot = MessagesSnapshotEvent( messages=all_messages, # type: ignore[arg-type] ) + messages_snapshot_emitted = True logger.info( - "[FINALIZE] Emitting MessagesSnapshotEvent with %s messages (text content length: %s)", - len(all_messages), - len(event_bridge.accumulated_text_content), + f"[FINALIZE] Emitting MessagesSnapshotEvent with {len(all_messages)} messages " + f"(text content length: {len(accumulated_text_content)})" ) yield messages_snapshot else: logger.info("[FINALIZE] No current_message_id - skipping TextMessageEndEvent") - if not event_bridge.messages_snapshot_emitted and ( - event_bridge.pending_tool_calls or event_bridge.tool_results - ): - converted_input_messages = agui_messages_to_snapshot_format(event_bridge.input_messages) - all_messages = converted_input_messages.copy() + if not messages_snapshot_emitted and (pending_tool_calls or tool_results): + all_messages = snapshot_messages.copy() - if event_bridge.pending_tool_calls: + if pending_tool_calls: tool_call_message = { "id": generate_event_id(), "role": "assistant", - "tool_calls": event_bridge.pending_tool_calls.copy(), + "tool_calls": pending_tool_calls.copy(), } all_messages.append(tool_call_message) - all_messages.extend(event_bridge.tool_results.copy()) + all_messages.extend(tool_results) messages_snapshot = MessagesSnapshotEvent( messages=all_messages, # type: ignore[arg-type] ) + messages_snapshot_emitted = True logger.info(f"[FINALIZE] Emitting MessagesSnapshotEvent with {len(all_messages)} messages (tool-only)") yield messages_snapshot diff --git a/python/packages/ag-ui/tests/test_backend_tool_rendering.py b/python/packages/ag-ui/tests/test_backend_tool_rendering.py index 6fefc14665..97654182cf 100644 --- a/python/packages/ag-ui/tests/test_backend_tool_rendering.py +++ b/python/packages/ag-ui/tests/test_backend_tool_rendering.py @@ -52,8 +52,8 @@ async def test_tool_call_flow(): update2 = AgentRunResponseUpdate(contents=[tool_result]) events2 = await bridge.from_agent_run_update(update2) - # Should have: ToolCallEndEvent, ToolCallResultEvent, MessagesSnapshotEvent - assert len(events2) == 3 + # Should have: ToolCallEndEvent, ToolCallResultEvent + assert len(events2) == 2 assert isinstance(events2[0], ToolCallEndEvent) assert isinstance(events2[1], ToolCallResultEvent) diff --git a/python/packages/ag-ui/tests/test_events_comprehensive.py b/python/packages/ag-ui/tests/test_events_comprehensive.py index 20b53cc18f..5426796654 100644 --- a/python/packages/ag-ui/tests/test_events_comprehensive.py +++ b/python/packages/ag-ui/tests/test_events_comprehensive.py @@ -284,14 +284,12 @@ async def test_empty_predict_state_config(): assert "STATE_DELTA" not in event_types assert "STATE_SNAPSHOT" not in event_types - # Should have: ToolCallStart, ToolCallArgs, ToolCallEnd, ToolCallResult, MessagesSnapshot - # MessagesSnapshotEvent is emitted after tool results to track the conversation + # Should have: ToolCallStart, ToolCallArgs, ToolCallEnd, ToolCallResult assert event_types == [ "TOOL_CALL_START", "TOOL_CALL_ARGS", "TOOL_CALL_END", "TOOL_CALL_RESULT", - "MESSAGES_SNAPSHOT", ] diff --git a/python/packages/ag-ui/tests/test_helpers_ag_ui.py b/python/packages/ag-ui/tests/test_helpers_ag_ui.py index bfb528511e..fc82b11510 100644 --- a/python/packages/ag-ui/tests/test_helpers_ag_ui.py +++ b/python/packages/ag-ui/tests/test_helpers_ag_ui.py @@ -18,6 +18,7 @@ from agent_framework._clients import BaseChatClient from agent_framework._types import ChatResponse, ChatResponseUpdate +from agent_framework_ag_ui._message_adapters import _deduplicate_messages, _sanitize_tool_history from agent_framework_ag_ui._orchestrators import ExecutionContext StreamFn = Callable[..., AsyncIterator[ChatResponseUpdate]] @@ -134,5 +135,9 @@ def get_new_thread(self, **kwargs: Any) -> AgentThread: class TestExecutionContext(ExecutionContext): """ExecutionContext helper that allows setting messages for tests.""" - def set_messages(self, messages: list[ChatMessage]) -> None: - self._messages = messages + def set_messages(self, messages: list[ChatMessage], *, normalize: bool = True) -> None: + if normalize: + self._messages = _deduplicate_messages(_sanitize_tool_history(messages)) + else: + self._messages = messages + self._snapshot_messages = None diff --git a/python/packages/ag-ui/tests/test_message_hygiene.py b/python/packages/ag-ui/tests/test_message_hygiene.py index ba775fa7d9..380ff438bd 100644 --- a/python/packages/ag-ui/tests/test_message_hygiene.py +++ b/python/packages/ag-ui/tests/test_message_hygiene.py @@ -2,10 +2,7 @@ from agent_framework import ChatMessage, FunctionCallContent, FunctionResultContent, TextContent -from agent_framework_ag_ui._orchestration._message_hygiene import ( - deduplicate_messages, - sanitize_tool_history, -) +from agent_framework_ag_ui._message_adapters import _deduplicate_messages, _sanitize_tool_history def test_sanitize_tool_history_injects_confirm_changes_result() -> None: @@ -26,7 +23,7 @@ def test_sanitize_tool_history_injects_confirm_changes_result() -> None: ), ] - sanitized = sanitize_tool_history(messages) + sanitized = _sanitize_tool_history(messages) tool_messages = [ msg for msg in sanitized if (msg.role.value if hasattr(msg.role, "value") else str(msg.role)) == "tool" @@ -48,6 +45,6 @@ def test_deduplicate_messages_prefers_non_empty_tool_results() -> None: ), ] - deduped = deduplicate_messages(messages) + deduped = _deduplicate_messages(messages) assert len(deduped) == 1 assert deduped[0].contents[0].result == "result data" diff --git a/python/packages/ag-ui/tests/test_orchestrators_coverage.py b/python/packages/ag-ui/tests/test_orchestrators_coverage.py index c2ded90eb5..041e25c3d2 100644 --- a/python/packages/ag-ui/tests/test_orchestrators_coverage.py +++ b/python/packages/ag-ui/tests/test_orchestrators_coverage.py @@ -62,7 +62,7 @@ async def test_human_in_the_loop_json_decode_error() -> None: agent=agent, config=AgentConfig(), ) - context.set_messages(messages) + context.set_messages(messages, normalize=False) assert orchestrator.can_handle(context) From 73ce055bdf65ca8bea5d24dc1619c13c7f2dd924 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Thu, 8 Jan 2026 14:45:27 +0900 Subject: [PATCH 5/8] Code cleanup --- .../agent_framework_ag_ui/_orchestrators.py | 459 +++++++++--------- 1 file changed, 224 insertions(+), 235 deletions(-) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py index f2c15a4516..a7e934bf8e 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py @@ -44,11 +44,202 @@ if TYPE_CHECKING: from ._agent import AgentConfig from ._confirmation_strategies import ConfirmationStrategy + from ._events import AgentFrameworkEventBridge + from ._orchestration._state_manager import StateManager logger = logging.getLogger(__name__) +def _role_value(message: Any) -> str: + role = getattr(message, "role", None) + if role is None: + return "" + if hasattr(role, "value"): + return str(role.value) + return str(role) + + +def _pending_tool_call_ids(messages: list[Any]) -> set[str]: + pending_ids: set[str] = set() + resolved_ids: set[str] = set() + for msg in messages: + for content in msg.contents or []: + if isinstance(content, FunctionCallContent) and content.call_id: + pending_ids.add(str(content.call_id)) + elif isinstance(content, FunctionResultContent) and content.call_id: + resolved_ids.add(str(content.call_id)) + return pending_ids - resolved_ids + + +def _is_state_context_message(message: Any) -> bool: + if _role_value(message) != "system": + return False + for content in message.contents or []: + if isinstance(content, TextContent) and content.text.startswith("Current state of the application:"): + return True + return False + + +def _ensure_tool_call_entry( + tool_call_id: str, + tool_calls_by_id: dict[str, dict[str, Any]], + pending_tool_calls: list[dict[str, Any]], +) -> dict[str, Any]: + entry = tool_calls_by_id.get(tool_call_id) + if entry is None: + entry = { + "id": tool_call_id, + "type": "function", + "function": { + "name": "", + "arguments": "", + }, + } + tool_calls_by_id[tool_call_id] = entry + pending_tool_calls.append(entry) + return entry + + +def _tool_name_for_call_id(tool_calls_by_id: dict[str, dict[str, Any]], tool_call_id: str) -> str | None: + entry = tool_calls_by_id.get(tool_call_id) + if not entry: + return None + function = entry.get("function") + if not isinstance(function, dict): + return None + name = function.get("name") + return str(name) if name else None + + +def _tool_calls_match_state(provider_messages: list[Any], state_manager: "StateManager") -> bool: + if not state_manager.predict_state_config or not state_manager.current_state: + return False + + def _parse_args(arguments: Any) -> dict[str, Any] | None: + if isinstance(arguments, dict): + return arguments + if isinstance(arguments, str): + try: + parsed = json.loads(arguments) + except json.JSONDecodeError: + return None + if isinstance(parsed, dict): + return parsed + return None + + for state_key, config in state_manager.predict_state_config.items(): + tool_name = config["tool"] + tool_arg_name = config["tool_argument"] + tool_args: dict[str, Any] | None = None + + for msg in reversed(provider_messages): + if _role_value(msg) != "assistant": + continue + for content in msg.contents or []: + if isinstance(content, FunctionCallContent) and content.name == tool_name: + tool_args = _parse_args(content.arguments) + break + if tool_args is not None: + break + + if not tool_args: + return False + + if tool_arg_name == "*": + state_value = tool_args + elif tool_arg_name in tool_args: + state_value = tool_args[tool_arg_name] + else: + return False + + if state_manager.current_state.get(state_key) != state_value: + return False + + return True + + +def _select_messages_to_run(provider_messages: list[Any], state_manager: "StateManager") -> list[Any]: + if not provider_messages: + return [] + + is_new_user_turn = _role_value(provider_messages[-1]) == "user" + conversation_has_tool_calls = _tool_calls_match_state(provider_messages, state_manager) + state_context_msg = state_manager.state_context_message( + is_new_user_turn=is_new_user_turn, conversation_has_tool_calls=conversation_has_tool_calls + ) + if not state_context_msg: + return list(provider_messages) + + messages_to_run = [msg for msg in provider_messages if not _is_state_context_message(msg)] + if _pending_tool_call_ids(messages_to_run): + return messages_to_run + + insert_index = len(messages_to_run) - 1 if is_new_user_turn else len(messages_to_run) + if insert_index < 0: + insert_index = 0 + messages_to_run.insert(insert_index, state_context_msg) + return messages_to_run + + +def _build_safe_metadata(thread_metadata: dict[str, Any] | None) -> dict[str, Any]: + if not thread_metadata: + return {} + safe_metadata: dict[str, Any] = {} + for key, value in thread_metadata.items(): + value_str = value if isinstance(value, str) else json.dumps(value) + if len(value_str) > 512: + value_str = value_str[:512] + safe_metadata[key] = value_str + return safe_metadata + + +def _collect_approved_state_snapshots( + provider_messages: list[Any], + predict_state_config: dict[str, dict[str, str]], + current_state: dict[str, Any], + event_bridge: "AgentFrameworkEventBridge", +) -> list[StateSnapshotEvent]: + if not predict_state_config: + return [] + + events: list[StateSnapshotEvent] = [] + for msg in provider_messages: + if _role_value(msg) != "user": + continue + for content in msg.contents or []: + if type(content) is FunctionApprovalResponseContent: + if not content.function_call or not content.approved: + continue + parsed_args = content.function_call.parse_arguments() + state_args = None + if content.additional_properties: + state_args = content.additional_properties.get("ag_ui_state_args") + if not isinstance(state_args, dict): + state_args = parsed_args + if not state_args: + continue + for state_key, config in predict_state_config.items(): + if config["tool"] != content.function_call.name: + continue + tool_arg_name = config["tool_argument"] + if tool_arg_name == "*": + state_value = state_args + elif isinstance(state_args, dict) and tool_arg_name in state_args: + state_value = state_args[tool_arg_name] + else: + continue + current_state[state_key] = state_value + event_bridge.current_state[state_key] = state_value + logger.info( + f"Emitting StateSnapshotEvent for approved state key '{state_key}' " + f"with {len(state_value) if isinstance(state_value, list) else 'N/A'} items" + ) + events.append(StateSnapshotEvent(snapshot=current_state)) + break + return events + + class ExecutionContext: """Shared context for orchestrators.""" @@ -356,7 +547,7 @@ async def run( logger.info(f"Received {len(provider_messages)} provider messages from client") for i, msg in enumerate(provider_messages): - role = msg.role.value if hasattr(msg.role, "value") else str(msg.role) + role = _role_value(msg) msg_id = getattr(msg, "message_id", None) logger.info(f" Message {i}: role={role}, id={msg_id}") if hasattr(msg, "contents") and msg.contents: @@ -381,35 +572,6 @@ async def run( else: logger.debug(f" Content {j}: {content_type}") - if not provider_messages: - logger.info("No provider-eligible messages after filtering; finishing run without invoking agent.") - yield event_bridge.create_run_finished_event() - return - - logger.info(f"Processing {len(provider_messages)} provider messages after normalization") - for i, msg in enumerate(provider_messages): - role = msg.role.value if hasattr(msg.role, "value") else str(msg.role) - logger.info(f" Message {i}: role={role}") - if hasattr(msg, "contents") and msg.contents: - for j, content in enumerate(msg.contents): - content_type = type(content).__name__ - if isinstance(content, TextContent): - logger.info(f" Content {j}: {content_type} - text_length={len(content.text)}") - elif isinstance(content, FunctionCallContent): - arg_length = len(str(content.arguments)) if content.arguments else 0 - logger.info(" Content %s: %s - %s args_length=%s", j, content_type, content.name, arg_length) - elif isinstance(content, FunctionResultContent): - result_preview = type(content.result).__name__ if content.result is not None else "None" - logger.info( - " Content %s: %s - call_id=%s, result_type=%s", - j, - content_type, - content.call_id, - result_preview, - ) - else: - logger.info(f" Content {j}: {content_type}") - pending_tool_calls: list[dict[str, Any]] = [] tool_calls_by_id: dict[str, dict[str, Any]] = {} tool_results: list[dict[str, Any]] = [] @@ -420,128 +582,15 @@ async def run( # Check for FunctionApprovalResponseContent and emit updated state snapshot # This ensures the UI shows the approved state (e.g., 2 steps) not the original (3 steps) - for msg in provider_messages: - role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) - if role_value != "user": - continue - for content in msg.contents or []: - if type(content) is FunctionApprovalResponseContent: - if content.function_call and content.approved: - parsed_args = content.function_call.parse_arguments() - state_args = None - if content.additional_properties: - state_args = content.additional_properties.get("ag_ui_state_args") - if not isinstance(state_args, dict): - state_args = parsed_args - if state_args and context.config.predict_state_config: - for state_key, config in context.config.predict_state_config.items(): - if config["tool"] == content.function_call.name: - tool_arg_name = config["tool_argument"] - if tool_arg_name == "*": - state_value = state_args - elif isinstance(state_args, dict) and tool_arg_name in state_args: - state_value = state_args[tool_arg_name] - else: - continue - # Update current_state and emit snapshot - current_state[state_key] = state_value - event_bridge.current_state[state_key] = state_value - logger.info( - f"Emitting StateSnapshotEvent for approved state key '{state_key}' " - f"with {len(state_value) if isinstance(state_value, list) else 'N/A'} items" - ) - yield StateSnapshotEvent(snapshot=current_state) - break - - messages_to_run: list[Any] = [] - is_new_user_turn = False - if provider_messages: - last_msg = provider_messages[-1] - role_value = last_msg.role.value if hasattr(last_msg.role, "value") else str(last_msg.role) - is_new_user_turn = role_value == "user" - - def _tool_calls_match_state() -> bool: - if not state_manager.predict_state_config or not state_manager.current_state: - return False - - def _parse_args(arguments: Any) -> dict[str, Any] | None: - if isinstance(arguments, dict): - return arguments - if isinstance(arguments, str): - try: - parsed = json.loads(arguments) - except json.JSONDecodeError: - return None - if isinstance(parsed, dict): - return parsed - return None - - for state_key, config in state_manager.predict_state_config.items(): - tool_name = config["tool"] - tool_arg_name = config["tool_argument"] - tool_args: dict[str, Any] | None = None - - for msg in reversed(provider_messages): - role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) - if role_value != "assistant": - continue - for content in msg.contents or []: - if isinstance(content, FunctionCallContent) and content.name == tool_name: - tool_args = _parse_args(content.arguments) - break - if tool_args is not None: - break - - if not tool_args: - return False - - if tool_arg_name == "*": - state_value = tool_args - elif tool_arg_name in tool_args: - state_value = tool_args[tool_arg_name] - else: - return False - - if state_manager.current_state.get(state_key) != state_value: - return False - - return True + for snapshot_evt in _collect_approved_state_snapshots( + provider_messages, + context.config.predict_state_config, + current_state, + event_bridge, + ): + yield snapshot_evt - conversation_has_tool_calls = _tool_calls_match_state() - - state_context_msg = state_manager.state_context_message( - is_new_user_turn=is_new_user_turn, conversation_has_tool_calls=conversation_has_tool_calls - ) - - def _is_state_context_message(message: Any) -> bool: - role_value = message.role.value if hasattr(message.role, "value") else str(message.role) - if role_value != "system": - return False - for content in message.contents or []: - if isinstance(content, TextContent) and content.text.startswith("Current state of the application:"): - return True - return False - - def _pending_tool_call_ids(messages: list[Any]) -> set[str]: - pending_ids: set[str] = set() - resolved_ids: set[str] = set() - for msg in messages: - for content in msg.contents or []: - if isinstance(content, FunctionCallContent) and content.call_id: - pending_ids.add(str(content.call_id)) - elif isinstance(content, FunctionResultContent) and content.call_id: - resolved_ids.add(str(content.call_id)) - return pending_ids - resolved_ids - - if state_context_msg: - messages_to_run = [msg for msg in provider_messages if not _is_state_context_message(msg)] - if not _pending_tool_call_ids(messages_to_run): - insert_index = len(messages_to_run) - 1 if is_new_user_turn else len(messages_to_run) - if insert_index < 0: - insert_index = 0 - messages_to_run.insert(insert_index, state_context_msg) - else: - messages_to_run.extend(provider_messages) + messages_to_run = _select_messages_to_run(provider_messages, state_manager) client_tools = convert_agui_tools_to_agent_framework(context.input_data.get("tools")) logger.info(f"[TOOLS] Client sent {len(client_tools) if client_tools else 0} tools") @@ -555,17 +604,11 @@ def _pending_tool_call_ids(messages: list[Any]) -> set[str]: register_additional_client_tools(context.agent, client_tools) tools_param = merge_tools(server_tools, client_tools) - all_updates: list[Any] = [] + collect_updates = response_format is not None + all_updates: list[Any] | None = [] if collect_updates else None update_count = 0 # Prepare metadata for chat client (Azure requires string values) - safe_metadata: dict[str, Any] = {} - thread_metadata = getattr(thread, "metadata", None) - if thread_metadata: - for key, value in thread_metadata.items(): - value_str = value if isinstance(value, str) else json.dumps(value) - if len(value_str) > 512: - value_str = value_str[:512] - safe_metadata[key] = value_str + safe_metadata = _build_safe_metadata(getattr(thread, "metadata", None)) run_kwargs: dict[str, Any] = { "thread": thread, @@ -632,14 +675,17 @@ def _should_emit_tool_snapshot(tool_name: str | None) -> bool: return False return True - def _build_messages_snapshot() -> MessagesSnapshotEvent: + def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnapshotEvent: has_text_content = bool(accumulated_text_content) all_messages = snapshot_messages.copy() if pending_tool_calls: - tool_call_message_id = ( - active_message_id if not has_text_content and active_message_id else generate_event_id() - ) + if tool_message_id and not has_text_content: + tool_call_message_id = tool_message_id + else: + tool_call_message_id = ( + active_message_id if not has_text_content and active_message_id else generate_event_id() + ) tool_call_message = { "id": tool_call_message_id, "role": "assistant", @@ -666,7 +712,8 @@ def _build_messages_snapshot() -> MessagesSnapshotEvent: async for update in context.agent.run_stream(messages_to_run, **run_kwargs): update_count += 1 logger.info(f"[STREAM] Received update #{update_count} from agent") - all_updates.append(update) + if all_updates is not None: + all_updates.append(update) if event_bridge.current_message_id is None and update.contents: has_tool_call = any(isinstance(content, FunctionCallContent) for content in update.contents) has_text = any(isinstance(content, TextContent) for content in update.contents) @@ -689,33 +736,10 @@ def _build_messages_snapshot() -> MessagesSnapshotEvent: elif isinstance(event, TextMessageContentEvent): accumulated_text_content += event.delta elif isinstance(event, ToolCallStartEvent): - tool_call_entry = tool_calls_by_id.get(event.tool_call_id) - if tool_call_entry is None: - tool_call_entry = { - "id": event.tool_call_id, - "type": "function", - "function": { - "name": event.tool_call_name, - "arguments": "", - }, - } - pending_tool_calls.append(tool_call_entry) - tool_calls_by_id[event.tool_call_id] = tool_call_entry - else: - tool_call_entry["function"]["name"] = event.tool_call_name + tool_call_entry = _ensure_tool_call_entry(event.tool_call_id, tool_calls_by_id, pending_tool_calls) + tool_call_entry["function"]["name"] = event.tool_call_name elif isinstance(event, ToolCallArgsEvent): - tool_call_entry = tool_calls_by_id.get(event.tool_call_id) - if tool_call_entry is None: - tool_call_entry = { - "id": event.tool_call_id, - "type": "function", - "function": { - "name": "", - "arguments": "", - }, - } - pending_tool_calls.append(tool_call_entry) - tool_calls_by_id[event.tool_call_id] = tool_call_entry + tool_call_entry = _ensure_tool_call_entry(event.tool_call_id, tool_calls_by_id, pending_tool_calls) tool_call_entry["function"]["arguments"] += event.delta elif isinstance(event, ToolCallEndEvent): tool_calls_ended.add(event.tool_call_id) @@ -731,14 +755,14 @@ def _build_messages_snapshot() -> MessagesSnapshotEvent: logger.info(f"[STREAM] Yielding event: {type(event).__name__}") yield event if isinstance(event, ToolCallResultEvent): - tool_name = tool_calls_by_id.get(event.tool_call_id, {}).get("function", {}).get("name") + tool_name = _tool_name_for_call_id(tool_calls_by_id, event.tool_call_id) if _should_emit_tool_snapshot(tool_name): messages_snapshot_emitted = True messages_snapshot = _build_messages_snapshot() logger.info(f"[STREAM] Yielding event: {type(messages_snapshot).__name__}") yield messages_snapshot elif isinstance(event, ToolCallEndEvent): - tool_name = tool_calls_by_id.get(event.tool_call_id, {}).get("function", {}).get("name") + tool_name = _tool_name_for_call_id(tool_calls_by_id, event.tool_call_id) if tool_name == "confirm_changes": messages_snapshot_emitted = True messages_snapshot = _build_messages_snapshot() @@ -770,7 +794,7 @@ def _build_messages_snapshot() -> MessagesSnapshotEvent: logger.info(f"Emitting ToolCallEndEvent for declaration-only tool call '{tool_call_id}'") yield end_event - if all_updates and response_format: + if response_format and all_updates: from agent_framework import AgentRunResponse from pydantic import BaseModel @@ -802,56 +826,21 @@ def _build_messages_snapshot() -> MessagesSnapshotEvent: logger.info(f"[FINALIZE] Emitting TextMessageEndEvent for message_id={event_bridge.current_message_id}") yield event_bridge.create_message_end_event(event_bridge.current_message_id) - has_text_content = bool(accumulated_text_content) - all_messages = snapshot_messages.copy() - - if pending_tool_calls: - tool_call_message = { - "id": event_bridge.current_message_id if not has_text_content else generate_event_id(), - "role": "assistant", - "tool_calls": pending_tool_calls.copy(), - } - all_messages.append(tool_call_message) - - all_messages.extend(tool_results) - if has_text_content and active_message_id: - all_messages.append( - { - "id": active_message_id, - "role": "assistant", - "content": accumulated_text_content, - } - ) - - messages_snapshot = MessagesSnapshotEvent( - messages=all_messages, # type: ignore[arg-type] - ) + messages_snapshot = _build_messages_snapshot(tool_message_id=event_bridge.current_message_id) messages_snapshot_emitted = True logger.info( - f"[FINALIZE] Emitting MessagesSnapshotEvent with {len(all_messages)} messages " + f"[FINALIZE] Emitting MessagesSnapshotEvent with {len(messages_snapshot.messages)} messages " f"(text content length: {len(accumulated_text_content)})" ) yield messages_snapshot else: logger.info("[FINALIZE] No current_message_id - skipping TextMessageEndEvent") if not messages_snapshot_emitted and (pending_tool_calls or tool_results): - all_messages = snapshot_messages.copy() - - if pending_tool_calls: - tool_call_message = { - "id": generate_event_id(), - "role": "assistant", - "tool_calls": pending_tool_calls.copy(), - } - all_messages.append(tool_call_message) - - all_messages.extend(tool_results) - - messages_snapshot = MessagesSnapshotEvent( - messages=all_messages, # type: ignore[arg-type] - ) + messages_snapshot = _build_messages_snapshot() messages_snapshot_emitted = True - logger.info(f"[FINALIZE] Emitting MessagesSnapshotEvent with {len(all_messages)} messages (tool-only)") + logger.info( + f"[FINALIZE] Emitting MessagesSnapshotEvent with {len(messages_snapshot.messages)} messages" + ) yield messages_snapshot logger.info("[FINALIZE] Emitting RUN_FINISHED event") From ad2dc753e383648432ff5b8d7cbb0dab99ba416b Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Fri, 9 Jan 2026 08:20:23 +0900 Subject: [PATCH 6/8] More fixes --- .../ag-ui/agent_framework_ag_ui/_events.py | 84 +++++++++- .../_message_adapters.py | 139 +++++++++++++--- .../agent_framework_ag_ui/_orchestrators.py | 114 +++++++++++++- .../ag-ui/tests/test_events_comprehensive.py | 7 +- .../ag-ui/tests/test_human_in_the_loop.py | 85 ++++++++++ .../ag-ui/tests/test_message_adapters.py | 149 ++++++++++++++++++ 6 files changed, 550 insertions(+), 28 deletions(-) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_events.py b/python/packages/ag-ui/agent_framework_ag_ui/_events.py index 19cb2982b9..a31d4cf024 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_events.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_events.py @@ -48,6 +48,7 @@ def __init__( current_state: dict[str, Any] | None = None, skip_text_content: bool = False, require_confirmation: bool = True, + approval_tool_name: str | None = None, ) -> None: """ Initialize the event bridge. @@ -71,6 +72,7 @@ def __init__( self.pending_state_updates: dict[str, Any] = {} # Track updates from tool calls self.skip_text_content = skip_text_content self.require_confirmation = require_confirmation + self.approval_tool_name = approval_tool_name # For predictive state updates: accumulate streaming arguments self.streaming_tool_args: str = "" # Accumulated JSON string @@ -370,7 +372,14 @@ def _emit_state_snapshot_and_confirmation(self) -> list[BaseEvent]: self.current_tool_call_name = None return events - def _emit_confirm_changes_tool_call(self) -> list[BaseEvent]: + def _emit_confirm_changes_tool_call(self, function_call: FunctionCallContent | None = None) -> list[BaseEvent]: + """Emit a confirm_changes tool call for Dojo UI compatibility. + + Args: + function_call: Optional function call that needs confirmation. + If provided, includes function info in the confirm_changes args + so Dojo UI can display what's being confirmed. + """ events: list[BaseEvent] = [] confirm_call_id = generate_event_id() logger.info("Emitting confirm_changes tool call for predictive update") @@ -378,12 +387,31 @@ def _emit_confirm_changes_tool_call(self) -> list[BaseEvent]: confirm_start = ToolCallStartEvent( tool_call_id=confirm_call_id, tool_call_name="confirm_changes", + parent_message_id=self.current_message_id, ) events.append(confirm_start) + # Include function info if this is for a function approval + # This helps Dojo UI display meaningful confirmation info + if function_call: + args_dict = { + "function_name": function_call.name, + "function_call_id": function_call.call_id, + "function_arguments": function_call.parse_arguments() or {}, + "steps": [ + { + "description": f"Execute {function_call.name}", + "status": "enabled", + } + ], + } + args_json = json.dumps(args_dict) + else: + args_json = "{}" + confirm_args = ToolCallArgsEvent( tool_call_id=confirm_call_id, - delta="{}", + delta=args_json, ) events.append(confirm_args) @@ -396,6 +424,49 @@ def _emit_confirm_changes_tool_call(self) -> list[BaseEvent]: logger.info("Set flag to stop run after confirm_changes") return events + def _emit_function_approval_tool_call(self, function_call: FunctionCallContent) -> list[BaseEvent]: + """Emit a tool call that can drive UI approval for function requests.""" + tool_call_name = "confirm_changes" + if self.approval_tool_name and self.approval_tool_name != function_call.name: + tool_call_name = self.approval_tool_name + + tool_call_id = generate_event_id() + tool_start = ToolCallStartEvent( + tool_call_id=tool_call_id, + tool_call_name=tool_call_name, + parent_message_id=self.current_message_id, + ) + events: list[BaseEvent] = [tool_start] + + args_dict = { + "function_name": function_call.name, + "function_call_id": function_call.call_id, + "function_arguments": function_call.parse_arguments() or {}, + "steps": [ + { + "description": f"Execute {function_call.name}", + "status": "enabled", + } + ], + } + args_json = json.dumps(args_dict) + + events.append( + ToolCallArgsEvent( + tool_call_id=tool_call_id, + delta=args_json, + ) + ) + events.append( + ToolCallEndEvent( + tool_call_id=tool_call_id, + ) + ) + + self.should_stop_after_confirm = True + logger.info("Set flag to stop run after confirm_changes") + return events + def _handle_function_approval_request_content(self, content: FunctionApprovalRequestContent) -> list[BaseEvent]: events: list[BaseEvent] = [] logger.info("=== FUNCTION APPROVAL REQUEST ===") @@ -445,6 +516,7 @@ def _handle_function_approval_request_content(self, content: FunctionApprovalReq logger.info(f"Emitting ToolCallEndEvent for approval-required tool '{content.function_call.call_id}'") events.append(end_event) + # Emit the function_approval_request custom event for UI implementations that support it approval_event = CustomEvent( name="function_approval_request", value={ @@ -458,6 +530,14 @@ def _handle_function_approval_request_content(self, content: FunctionApprovalReq ) logger.info(f"Emitting function_approval_request custom event for '{content.function_call.name}'") events.append(approval_event) + + # Emit a UI-friendly approval tool call for function approvals. + if self.require_confirmation: + events.extend(self._emit_function_approval_tool_call(content.function_call)) + + # Signal orchestrator to stop the run and wait for user approval response + self.should_stop_after_confirm = True + logger.info("Set flag to stop run - waiting for function approval response") return events def create_run_started_event(self) -> RunStartedEvent: diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py index 2da519bac7..0ff88a8e36 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py @@ -75,10 +75,15 @@ def _sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: if role_value == "user": approval_call_ids: set[str] = set() + approval_accepted: bool | None = None for content in msg.contents or []: if type(content) is FunctionApprovalResponseContent: if content.function_call and content.function_call.call_id: approval_call_ids.add(str(content.function_call.call_id)) + if approval_accepted is None: + approval_accepted = bool(content.approved) + else: + approval_accepted = approval_accepted and bool(content.approved) if approval_call_ids and pending_tool_call_ids: pending_tool_call_ids -= approval_call_ids @@ -87,6 +92,22 @@ def _sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: "framework will handle execution" ) + if pending_confirm_changes_id and approval_accepted is not None: + logger.info(f"Injecting synthetic tool result for confirm_changes call_id={pending_confirm_changes_id}") + synthetic_result = ChatMessage( + role="tool", + contents=[ + FunctionResultContent( + call_id=pending_confirm_changes_id, + result="Confirmed" if approval_accepted else "Rejected", + ) + ], + ) + sanitized.append(synthetic_result) + if pending_tool_call_ids: + pending_tool_call_ids.discard(pending_confirm_changes_id) + pending_confirm_changes_id = None + if pending_confirm_changes_id: user_text = "" for content in msg.contents or []: @@ -272,6 +293,84 @@ def _update_tool_call_arguments( function_payload_dict["arguments"] = modified_args return + def _find_matching_func_call(call_id: str) -> FunctionCallContent | None: + for prev_msg in result: + role_val = prev_msg.role.value if hasattr(prev_msg.role, "value") else str(prev_msg.role) + if role_val != "assistant": + continue + for content in prev_msg.contents or []: + if isinstance(content, FunctionCallContent): + if content.call_id == call_id and content.name != "confirm_changes": + return content + return None + + def _parse_arguments(arguments: Any) -> dict[str, Any] | None: + if isinstance(arguments, dict): + return arguments + if isinstance(arguments, str): + try: + parsed = json.loads(arguments) + except json.JSONDecodeError: + return None + if isinstance(parsed, dict): + return parsed + return None + + def _resolve_approval_call_id(tool_call_id: str, parsed_payload: dict[str, Any] | None) -> str | None: + if parsed_payload: + explicit_call_id = parsed_payload.get("function_call_id") + if explicit_call_id: + return str(explicit_call_id) + + for prev_msg in result: + role_val = prev_msg.role.value if hasattr(prev_msg.role, "value") else str(prev_msg.role) + if role_val != "assistant": + continue + direct_call = None + confirm_call = None + sibling_calls: list[FunctionCallContent] = [] + for content in prev_msg.contents or []: + if not isinstance(content, FunctionCallContent): + continue + if content.call_id == tool_call_id: + direct_call = content + if content.name == "confirm_changes" and content.call_id == tool_call_id: + confirm_call = content + elif content.name != "confirm_changes": + sibling_calls.append(content) + + if direct_call: + direct_args = direct_call.parse_arguments() or {} + if isinstance(direct_args, dict): + explicit_call_id = direct_args.get("function_call_id") + if explicit_call_id: + return str(explicit_call_id) + + if not confirm_call: + continue + + confirm_args = confirm_call.parse_arguments() or {} + if isinstance(confirm_args, dict): + explicit_call_id = confirm_args.get("function_call_id") + if explicit_call_id: + return str(explicit_call_id) + + if len(sibling_calls) == 1 and sibling_calls[0].call_id: + return str(sibling_calls[0].call_id) + + return None + + def _filter_modified_args( + modified_args: dict[str, Any], + original_args: dict[str, Any] | None, + ) -> dict[str, Any]: + if not modified_args: + return {} + if not isinstance(original_args, dict) or not original_args: + return {} + allowed_keys = set(original_args.keys()) + return {key: value for key, value in modified_args.items() if key in allowed_keys} + result: list[ChatMessage] = [] for msg in messages: # Handle standard tool result messages early (role="tool") to preserve provider invariants @@ -319,17 +418,11 @@ def _update_tool_call_arguments( logger = logging.getLogger(__name__) logger.info(f"Approval payload received: {parsed}") - # Find the function call that matches this tool_call_id - matching_func_call = None - for prev_msg in result: - role_val = prev_msg.role.value if hasattr(prev_msg.role, "value") else str(prev_msg.role) - if role_val != "assistant": - continue - for content in prev_msg.contents or []: - if isinstance(content, FunctionCallContent): - if content.call_id == tool_call_id and content.name != "confirm_changes": - matching_func_call = content - break + approval_call_id = tool_call_id + resolved_call_id = _resolve_approval_call_id(tool_call_id, parsed) + if resolved_call_id: + approval_call_id = resolved_call_id + matching_func_call = _find_matching_func_call(approval_call_id) if matching_func_call: # Remove any existing tool result for this call_id since the framework @@ -341,7 +434,7 @@ def _update_tool_call_arguments( if not ( (m.role.value if hasattr(m.role, "value") else str(m.role)) == "tool" and any( - isinstance(c, FunctionResultContent) and c.call_id == tool_call_id + isinstance(c, FunctionResultContent) and c.call_id == approval_call_id for c in (m.contents or []) ) ) @@ -350,19 +443,21 @@ def _update_tool_call_arguments( # Check if the approval payload contains modified arguments # The UI sends back the modified state (e.g., deselected steps) in the approval payload modified_args = {k: v for k, v in parsed.items() if k != "accepted"} if parsed else {} + original_args = matching_func_call.parse_arguments() + filtered_args = _filter_modified_args(modified_args, original_args) state_args: dict[str, Any] | None = None - if modified_args: - original_args = matching_func_call.parse_arguments() or {} + if filtered_args: + original_args = original_args or {} merged_args: dict[str, Any] if isinstance(original_args, dict) and original_args: - merged_args = {**original_args, **modified_args} + merged_args = {**original_args, **filtered_args} else: - merged_args = dict(modified_args) + merged_args = dict(filtered_args) - if isinstance(modified_args.get("steps"), list): + if isinstance(filtered_args.get("steps"), list): original_steps = original_args.get("steps") if isinstance(original_args, dict) else None if isinstance(original_steps, list): - approved_steps_list: list[Any] = list(modified_args.get("steps") or []) + approved_steps_list = list(filtered_args.get("steps") or []) approved_by_description: dict[str, dict[str, Any]] = {} for step_item in approved_steps_list: if isinstance(step_item, dict): @@ -395,14 +490,14 @@ def _update_tool_call_arguments( json.dumps(merged_args) if isinstance(matching_func_call.arguments, str) else merged_args ) matching_func_call.arguments = updated_args - _update_tool_call_arguments(messages, str(tool_call_id), merged_args) + _update_tool_call_arguments(messages, str(approval_call_id), merged_args) # Create a new FunctionCallContent with the modified arguments func_call_for_approval = FunctionCallContent( call_id=matching_func_call.call_id, name=matching_func_call.name, - arguments=json.dumps(modified_args), + arguments=json.dumps(filtered_args), ) - logger.info(f"Using modified arguments from approval: {modified_args}") + logger.info(f"Using modified arguments from approval: {filtered_args}") else: # No modified arguments - use the original function call func_call_for_approval = matching_func_call @@ -410,7 +505,7 @@ def _update_tool_call_arguments( # Create FunctionApprovalResponseContent for the agent framework approval_response = FunctionApprovalResponseContent( approved=accepted, - id=str(tool_call_id), + id=str(approval_call_id), function_call=func_call_for_approval, additional_properties={"ag_ui_state_args": state_args} if state_args else None, ) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py index a7e934bf8e..6ade68cf51 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py @@ -159,6 +159,34 @@ def _parse_args(arguments: Any) -> dict[str, Any] | None: return True +def _schema_has_steps(schema: Any) -> bool: + if not isinstance(schema, dict): + return False + properties = schema.get("properties") + if not isinstance(properties, dict): + return False + steps_schema = properties.get("steps") + if not isinstance(steps_schema, dict): + return False + return steps_schema.get("type") == "array" + + +def _select_approval_tool_name(client_tools: list[Any] | None) -> str | None: + if not client_tools: + return None + for tool in client_tools: + tool_name = getattr(tool, "name", None) + if not tool_name: + continue + params_fn = getattr(tool, "parameters", None) + if not callable(params_fn): + continue + schema = params_fn() + if _schema_has_steps(schema): + return str(tool_name) + return None + + def _select_messages_to_run(provider_messages: list[Any], state_manager: "StateManager") -> list[Any]: if not provider_messages: return [] @@ -240,6 +268,53 @@ def _collect_approved_state_snapshots( return events +def _latest_approval_response(messages: list[Any]) -> FunctionApprovalResponseContent | None: + if not messages: + return None + last_message = messages[-1] + for content in last_message.contents or []: + if type(content) is FunctionApprovalResponseContent: + return content + return None + + +def _approval_steps(approval: FunctionApprovalResponseContent) -> list[Any]: + state_args: Any | None = None + if approval.additional_properties: + state_args = approval.additional_properties.get("ag_ui_state_args") + if isinstance(state_args, dict): + steps = state_args.get("steps") + if isinstance(steps, list): + return steps + + if approval.function_call: + parsed_args = approval.function_call.parse_arguments() + if isinstance(parsed_args, dict): + steps = parsed_args.get("steps") + if isinstance(steps, list): + return steps + + return [] + + +def _is_step_based_approval( + approval: FunctionApprovalResponseContent, + predict_state_config: dict[str, dict[str, str]], +) -> bool: + steps = _approval_steps(approval) + if steps: + return True + if not approval.function_call: + return False + if not predict_state_config: + return False + tool_name = approval.function_call.name + for config in predict_state_config.values(): + if config.get("tool") == tool_name and config.get("tool_argument") == "steps": + return True + return False + + class ExecutionContext: """Shared context for orchestrators.""" @@ -504,6 +579,9 @@ async def run( response_format = context.agent.chat_options.response_format skip_text_content = response_format is not None + client_tools = convert_agui_tools_to_agent_framework(context.input_data.get("tools")) + approval_tool_name = _select_approval_tool_name(client_tools) + state_manager = StateManager( state_schema=context.config.state_schema, predict_state_config=context.config.predict_state_config, @@ -518,6 +596,7 @@ async def run( current_state=current_state, skip_text_content=skip_text_content, require_confirmation=context.config.require_confirmation, + approval_tool_name=approval_tool_name, ) yield event_bridge.create_run_started_event() @@ -592,7 +671,6 @@ async def run( messages_to_run = _select_messages_to_run(provider_messages, state_manager) - client_tools = convert_agui_tools_to_agent_framework(context.input_data.get("tools")) logger.info(f"[TOOLS] Client sent {len(client_tools) if client_tools else 0} tools") if client_tools: for tool in client_tools: @@ -707,7 +785,37 @@ def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnap messages=all_messages, # type: ignore[arg-type] ) - await _resolve_approval_responses(messages_to_run, server_tools) + # Use tools_param if available (includes client tools), otherwise fall back to server_tools + # This ensures both server tools AND client tools can be executed after approval + tools_for_approval = tools_param if tools_param is not None else server_tools + latest_approval = _latest_approval_response(messages_to_run) + await _resolve_approval_responses(messages_to_run, tools_for_approval) + + if latest_approval and _is_step_based_approval(latest_approval, context.config.predict_state_config): + from ._confirmation_strategies import DefaultConfirmationStrategy + + strategy = context.confirmation_strategy + if strategy is None: + strategy = DefaultConfirmationStrategy() + + steps = _approval_steps(latest_approval) + if steps: + if latest_approval.approved: + confirmation_message = strategy.on_approval_accepted(steps) + else: + confirmation_message = strategy.on_approval_rejected(steps) + else: + if latest_approval.approved: + confirmation_message = strategy.on_state_confirmed() + else: + confirmation_message = strategy.on_state_rejected() + + message_id = generate_event_id() + yield TextMessageStartEvent(message_id=message_id, role="assistant") + yield TextMessageContentEvent(message_id=message_id, delta=confirmation_message) + yield TextMessageEndEvent(message_id=message_id) + yield event_bridge.create_run_finished_event() + return async for update in context.agent.run_stream(messages_to_run, **run_kwargs): update_count += 1 @@ -772,7 +880,7 @@ def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnap logger.info(f"[STREAM] Agent stream completed. Total updates: {update_count}") if event_bridge.should_stop_after_confirm: - logger.info("Stopping run after confirm_changes - waiting for user response") + logger.info("Stopping run - waiting for user approval/confirmation response") if event_bridge.current_message_id: logger.info(f"[CONFIRM] Emitting TextMessageEndEvent for message_id={event_bridge.current_message_id}") yield event_bridge.create_message_end_event(event_bridge.current_message_id) diff --git a/python/packages/ag-ui/tests/test_events_comprehensive.py b/python/packages/ag-ui/tests/test_events_comprehensive.py index 5426796654..cfd45ea5c8 100644 --- a/python/packages/ag-ui/tests/test_events_comprehensive.py +++ b/python/packages/ag-ui/tests/test_events_comprehensive.py @@ -231,7 +231,12 @@ async def test_function_approval_request_basic(): """Test FunctionApprovalRequestContent conversion.""" from agent_framework_ag_ui._events import AgentFrameworkEventBridge - bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") + # Set require_confirmation=False to test just the function_approval_request event + bridge = AgentFrameworkEventBridge( + run_id="test_run", + thread_id="test_thread", + require_confirmation=False, + ) func_call = FunctionCallContent( call_id="call_123", diff --git a/python/packages/ag-ui/tests/test_human_in_the_loop.py b/python/packages/ag-ui/tests/test_human_in_the_loop.py index 92f6d69926..55a2869c91 100644 --- a/python/packages/ag-ui/tests/test_human_in_the_loop.py +++ b/python/packages/ag-ui/tests/test_human_in_the_loop.py @@ -10,9 +10,11 @@ async def test_function_approval_request_emission(): """Test that CustomEvent is emitted for FunctionApprovalRequestContent.""" + # Set require_confirmation=False to test just the function_approval_request event bridge = AgentFrameworkEventBridge( run_id="test_run", thread_id="test_thread", + require_confirmation=False, ) # Create approval request @@ -47,11 +49,65 @@ async def test_function_approval_request_emission(): assert event.value["function_call"]["arguments"]["subject"] == "Test" +async def test_function_approval_request_with_confirm_changes(): + """Test that confirm_changes is also emitted when require_confirmation=True.""" + bridge = AgentFrameworkEventBridge( + run_id="test_run", + thread_id="test_thread", + require_confirmation=True, + ) + + func_call = FunctionCallContent( + call_id="call_456", + name="delete_file", + arguments={"path": "/tmp/test.txt"}, + ) + approval_request = FunctionApprovalRequestContent( + id="approval_002", + function_call=func_call, + ) + + update = AgentRunResponseUpdate(contents=[approval_request]) + events = await bridge.from_agent_run_update(update) + + # Should emit: ToolCallEndEvent, CustomEvent, and confirm_changes (Start, Args, End) = 5 events + assert len(events) == 5 + + # Check ToolCallEndEvent + assert events[0].type == "TOOL_CALL_END" + assert events[0].tool_call_id == "call_456" + + # Check function_approval_request CustomEvent + assert events[1].type == "CUSTOM" + assert events[1].name == "function_approval_request" + + # Check confirm_changes tool call events + assert events[2].type == "TOOL_CALL_START" + assert events[2].tool_call_name == "confirm_changes" + assert events[3].type == "TOOL_CALL_ARGS" + # Verify confirm_changes includes function info for Dojo UI + import json + + args = json.loads(events[3].delta) + assert args["function_name"] == "delete_file" + assert args["function_call_id"] == "call_456" + assert args["function_arguments"] == {"path": "/tmp/test.txt"} + assert args["steps"] == [ + { + "description": "Execute delete_file", + "status": "enabled", + } + ] + assert events[4].type == "TOOL_CALL_END" + + async def test_multiple_approval_requests(): """Test handling multiple approval requests in one update.""" + # Set require_confirmation=False to simplify the test bridge = AgentFrameworkEventBridge( run_id="test_run", thread_id="test_thread", + require_confirmation=False, ) func_call_1 = FunctionCallContent( @@ -94,3 +150,32 @@ async def test_multiple_approval_requests(): assert events[3].type == "CUSTOM" assert events[3].name == "function_approval_request" assert events[3].value["id"] == "approval_2" + + +async def test_function_approval_request_sets_stop_flag(): + """Test that function approval request sets should_stop_after_confirm flag. + + This ensures the orchestrator stops the run after emitting the approval request, + allowing the UI to send back an approval response. + """ + bridge = AgentFrameworkEventBridge( + run_id="test_run", + thread_id="test_thread", + ) + + assert bridge.should_stop_after_confirm is False + + func_call = FunctionCallContent( + call_id="call_stop_test", + name="get_datetime", + arguments={}, + ) + approval_request = FunctionApprovalRequestContent( + id="approval_stop_test", + function_call=func_call, + ) + + update = AgentRunResponseUpdate(contents=[approval_request]) + await bridge.from_agent_run_update(update) + + assert bridge.should_stop_after_confirm is True diff --git a/python/packages/ag-ui/tests/test_message_adapters.py b/python/packages/ag-ui/tests/test_message_adapters.py index 27af3aae93..9173314a28 100644 --- a/python/packages/ag-ui/tests/test_message_adapters.py +++ b/python/packages/ag-ui/tests/test_message_adapters.py @@ -179,6 +179,155 @@ def test_agui_tool_approval_updates_tool_call_arguments(): } +def test_agui_tool_approval_from_confirm_changes_maps_to_function_call(): + """Confirm_changes approvals map back to the original tool call when metadata is present.""" + messages_input = [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_tool", + "type": "function", + "function": {"name": "get_datetime", "arguments": {}}, + }, + { + "id": "call_confirm", + "type": "function", + "function": { + "name": "confirm_changes", + "arguments": {"function_call_id": "call_tool"}, + }, + }, + ], + "id": "msg_1", + }, + { + "role": "tool", + "content": json.dumps({"accepted": True, "function_call_id": "call_tool"}), + "toolCallId": "call_confirm", + "id": "msg_2", + }, + ] + + messages = agui_messages_to_agent_framework(messages_input) + + from agent_framework import FunctionApprovalResponseContent + + approval_msg = messages[1] + approval_content = next( + content for content in approval_msg.contents if isinstance(content, FunctionApprovalResponseContent) + ) + + assert approval_content.function_call.call_id == "call_tool" + assert approval_content.function_call.name == "get_datetime" + assert approval_content.function_call.parse_arguments() == {} + assert messages_input[0]["tool_calls"][0]["function"]["arguments"] == {} + + +def test_agui_tool_approval_from_confirm_changes_falls_back_to_sibling_call(): + """Confirm_changes approvals map to the only sibling tool call when metadata is missing.""" + messages_input = [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_tool", + "type": "function", + "function": {"name": "get_datetime", "arguments": {}}, + }, + { + "id": "call_confirm", + "type": "function", + "function": {"name": "confirm_changes", "arguments": {}}, + }, + ], + "id": "msg_1", + }, + { + "role": "tool", + "content": json.dumps( + { + "accepted": True, + "steps": [{"description": "Approve get_datetime", "status": "enabled"}], + } + ), + "toolCallId": "call_confirm", + "id": "msg_2", + }, + ] + + messages = agui_messages_to_agent_framework(messages_input) + + from agent_framework import FunctionApprovalResponseContent + + approval_msg = messages[1] + approval_content = next( + content for content in approval_msg.contents if isinstance(content, FunctionApprovalResponseContent) + ) + + assert approval_content.function_call.call_id == "call_tool" + assert approval_content.function_call.name == "get_datetime" + assert approval_content.function_call.parse_arguments() == {} + assert messages_input[0]["tool_calls"][0]["function"]["arguments"] == {} + + +def test_agui_tool_approval_from_generate_task_steps_maps_to_function_call(): + """Approval tool payloads map to the referenced function call when function_call_id is present.""" + messages_input = [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_tool", + "type": "function", + "function": {"name": "get_datetime", "arguments": {}}, + }, + { + "id": "call_steps", + "type": "function", + "function": { + "name": "generate_task_steps", + "arguments": { + "function_name": "get_datetime", + "function_call_id": "call_tool", + "function_arguments": {}, + "steps": [{"description": "Execute get_datetime", "status": "enabled"}], + }, + }, + }, + ], + "id": "msg_1", + }, + { + "role": "tool", + "content": json.dumps( + { + "accepted": True, + "steps": [{"description": "Execute get_datetime", "status": "enabled"}], + } + ), + "toolCallId": "call_steps", + "id": "msg_2", + }, + ] + + messages = agui_messages_to_agent_framework(messages_input) + + from agent_framework import FunctionApprovalResponseContent + + approval_msg = messages[1] + approval_content = next( + content for content in approval_msg.contents if isinstance(content, FunctionApprovalResponseContent) + ) + + assert approval_content.function_call.call_id == "call_tool" + assert approval_content.function_call.name == "get_datetime" + assert approval_content.function_call.parse_arguments() == {} + + def test_agui_multiple_messages_to_agent_framework(): """Test converting multiple AG-UI messages.""" messages_input = [ From 2bf562e4eeb18a1a64171bc33976912d019fa231 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Fri, 9 Jan 2026 09:18:46 +0900 Subject: [PATCH 7/8] More code cleanup --- .../_confirmation_strategies.py | 202 +++++---- .../ag-ui/agent_framework_ag_ui/_events.py | 23 +- .../_message_adapters.py | 57 +-- .../_orchestration/_helpers.py | 391 ++++++++++++++++++ .../_orchestration/_predictive_state.py | 230 +++++++++++ .../agent_framework_ag_ui/_orchestrators.py | 315 ++------------ .../ag-ui/agent_framework_ag_ui/_utils.py | 97 ++++- 7 files changed, 890 insertions(+), 425 deletions(-) create mode 100644 python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py create mode 100644 python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_predictive_state.py diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_confirmation_strategies.py b/python/packages/ag-ui/agent_framework_ag_ui/_confirmation_strategies.py index 8bba842705..35e648c100 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_confirmation_strategies.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_confirmation_strategies.py @@ -11,21 +11,61 @@ class ConfirmationStrategy(ABC): - """Strategy for generating confirmation messages during human-in-the-loop flows.""" + """Strategy for generating confirmation messages during human-in-the-loop flows. + Subclasses must define the message properties. The methods use those properties + by default, but can be overridden for complete customization. + """ + + @property + @abstractmethod + def approval_header(self) -> str: + """Header for approval accepted message. Must be overridden.""" + ... + + @property + @abstractmethod + def approval_footer(self) -> str: + """Footer for approval accepted message. Must be overridden.""" + ... + + @property + @abstractmethod + def rejection_message(self) -> str: + """Message when user rejects. Must be overridden.""" + ... + + @property @abstractmethod + def state_confirmed_message(self) -> str: + """Message when state is confirmed. Must be overridden.""" + ... + + @property + @abstractmethod + def state_rejected_message(self) -> str: + """Message when state is rejected. Must be overridden.""" + ... + def on_approval_accepted(self, steps: list[dict[str, Any]]) -> str: """Generate message when user approves function execution. + Default implementation uses header/footer properties. + Override for complete customization. + Args: steps: List of approved steps with 'description', 'status', etc. Returns: Message to display to user """ - ... + enabled_steps = [s for s in steps if s.get("status") == "enabled"] + message_parts = [self.approval_header.format(count=len(enabled_steps))] + for i, step in enumerate(enabled_steps, 1): + message_parts.append(f"{i}. {step['description']}\n") + message_parts.append(self.approval_footer) + return "".join(message_parts) - @abstractmethod def on_approval_rejected(self, steps: list[dict[str, Any]]) -> str: """Generate message when user rejects function execution. @@ -35,141 +75,143 @@ def on_approval_rejected(self, steps: list[dict[str, Any]]) -> str: Returns: Message to display to user """ - ... + return self.rejection_message - @abstractmethod def on_state_confirmed(self) -> str: """Generate message when user confirms predictive state changes. Returns: Message to display to user """ - ... + return self.state_confirmed_message - @abstractmethod def on_state_rejected(self) -> str: """Generate message when user rejects predictive state changes. Returns: Message to display to user """ - ... + return self.state_rejected_message class DefaultConfirmationStrategy(ConfirmationStrategy): - """Generic confirmation messages suitable for most agents. - - This preserves the original behavior from v1. - """ + """Generic confirmation messages suitable for most agents.""" - def on_approval_accepted(self, steps: list[dict[str, Any]]) -> str: - """Generate generic approval message with step list.""" - enabled_steps = [s for s in steps if s.get("status") == "enabled"] - - message_parts = [f"Executing {len(enabled_steps)} approved steps:\n\n"] - - for i, step in enumerate(enabled_steps, 1): - message_parts.append(f"{i}. {step['description']}\n") - - message_parts.append("\nAll steps completed successfully!") + @property + def approval_header(self) -> str: + return "Executing {count} approved steps:\n\n" - return "".join(message_parts) + @property + def approval_footer(self) -> str: + return "\nAll steps completed successfully!" - def on_approval_rejected(self, steps: list[dict[str, Any]]) -> str: - """Generate generic rejection message.""" + @property + def rejection_message(self) -> str: return "No problem! What would you like me to change about the plan?" - def on_state_confirmed(self) -> str: - """Generate generic state confirmation message.""" + @property + def state_confirmed_message(self) -> str: return "Changes confirmed and applied successfully!" - def on_state_rejected(self) -> str: - """Generate generic state rejection message.""" + @property + def state_rejected_message(self) -> str: return "No problem! What would you like me to change?" class TaskPlannerConfirmationStrategy(ConfirmationStrategy): """Domain-specific confirmation messages for task planning agents.""" - def on_approval_accepted(self, steps: list[dict[str, Any]]) -> str: - """Generate task-specific approval message.""" - enabled_steps = [s for s in steps if s.get("status") == "enabled"] - - message_parts = ["Executing your requested tasks:\n\n"] - - for i, step in enumerate(enabled_steps, 1): - message_parts.append(f"{i}. {step['description']}\n") + @property + def approval_header(self) -> str: + return "Executing your requested tasks:\n\n" - message_parts.append("\nAll tasks completed successfully!") + @property + def approval_footer(self) -> str: + return "\nAll tasks completed successfully!" - return "".join(message_parts) - - def on_approval_rejected(self, steps: list[dict[str, Any]]) -> str: - """Generate task-specific rejection message.""" + @property + def rejection_message(self) -> str: return "No problem! Let me revise the plan. What would you like me to change?" - def on_state_confirmed(self) -> str: - """Task planners typically don't use state confirmation.""" + @property + def state_confirmed_message(self) -> str: return "Tasks confirmed and ready to execute!" - def on_state_rejected(self) -> str: - """Task planners typically don't use state confirmation.""" + @property + def state_rejected_message(self) -> str: return "No problem! How should I adjust the task list?" class RecipeConfirmationStrategy(ConfirmationStrategy): """Domain-specific confirmation messages for recipe agents.""" - def on_approval_accepted(self, steps: list[dict[str, Any]]) -> str: - """Generate recipe-specific approval message.""" - enabled_steps = [s for s in steps if s.get("status") == "enabled"] - - message_parts = ["Updating your recipe:\n\n"] + @property + def approval_header(self) -> str: + return "Updating your recipe:\n\n" - for i, step in enumerate(enabled_steps, 1): - message_parts.append(f"{i}. {step['description']}\n") - - message_parts.append("\nRecipe updated successfully!") - - return "".join(message_parts) + @property + def approval_footer(self) -> str: + return "\nRecipe updated successfully!" - def on_approval_rejected(self, steps: list[dict[str, Any]]) -> str: - """Generate recipe-specific rejection message.""" + @property + def rejection_message(self) -> str: return "No problem! What ingredients or steps should I change?" - def on_state_confirmed(self) -> str: - """Generate recipe-specific state confirmation message.""" + @property + def state_confirmed_message(self) -> str: return "Recipe changes applied successfully!" - def on_state_rejected(self) -> str: - """Generate recipe-specific state rejection message.""" + @property + def state_rejected_message(self) -> str: return "No problem! What would you like me to adjust in the recipe?" class DocumentWriterConfirmationStrategy(ConfirmationStrategy): """Domain-specific confirmation messages for document writing agents.""" - def on_approval_accepted(self, steps: list[dict[str, Any]]) -> str: - """Generate document-specific approval message.""" - enabled_steps = [s for s in steps if s.get("status") == "enabled"] - - message_parts = ["Applying your edits:\n\n"] - - for i, step in enumerate(enabled_steps, 1): - message_parts.append(f"{i}. {step['description']}\n") - - message_parts.append("\nDocument updated successfully!") + @property + def approval_header(self) -> str: + return "Applying your edits:\n\n" - return "".join(message_parts) + @property + def approval_footer(self) -> str: + return "\nDocument updated successfully!" - def on_approval_rejected(self, steps: list[dict[str, Any]]) -> str: - """Generate document-specific rejection message.""" + @property + def rejection_message(self) -> str: return "No problem! Which changes should I keep or modify?" - def on_state_confirmed(self) -> str: - """Generate document-specific state confirmation message.""" + @property + def state_confirmed_message(self) -> str: return "Document edits applied!" - def on_state_rejected(self) -> str: - """Generate document-specific state rejection message.""" + @property + def state_rejected_message(self) -> str: return "No problem! What should I change about the document?" + + +def apply_confirmation_strategy( + strategy: ConfirmationStrategy | None, + accepted: bool, + steps: list[dict[str, Any]], +) -> str: + """Apply a confirmation strategy to generate a message. + + This helper consolidates the pattern used in multiple orchestrators. + + Args: + strategy: Strategy to use, or None for default + accepted: Whether the user approved + steps: List of steps (may be empty for state confirmations) + + Returns: + Generated message string + """ + if strategy is None: + strategy = DefaultConfirmationStrategy() + + if not steps: + # State confirmation (no steps) + return strategy.on_state_confirmed() if accepted else strategy.on_state_rejected() + # Step-based approval + return strategy.on_approval_accepted(steps) if accepted else strategy.on_approval_rejected(steps) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_events.py b/python/packages/ag-ui/agent_framework_ag_ui/_events.py index a31d4cf024..633fc501db 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_events.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_events.py @@ -32,7 +32,7 @@ prepare_function_call_results, ) -from ._utils import generate_event_id +from ._utils import extract_state_from_tool_args, generate_event_id, safe_json_parse logger = logging.getLogger(__name__) @@ -209,10 +209,8 @@ def _emit_predictive_state_deltas(self, argument_chunk: str) -> list[BaseEvent]: self.current_tool_call_name, ) - parsed_args = None - try: - parsed_args = json.loads(self.streaming_tool_args) - except json.JSONDecodeError: + parsed_args = safe_json_parse(self.streaming_tool_args) + if parsed_args is None: for state_key, config in self.predict_state_config.items(): if config["tool"] != self.current_tool_call_name: continue @@ -256,11 +254,8 @@ def _emit_predictive_state_deltas(self, argument_chunk: str) -> list[BaseEvent]: continue tool_arg_name = config["tool_argument"] - if tool_arg_name == "*": - state_value = parsed_args - elif tool_arg_name in parsed_args: - state_value = parsed_args[tool_arg_name] - else: + state_value = extract_state_from_tool_args(parsed_args, tool_arg_name) + if state_value is None: continue if state_key not in self.last_emitted_state or self.last_emitted_state[state_key] != state_value: @@ -493,12 +488,8 @@ def _handle_function_approval_request_content(self, content: FunctionApprovalReq tool_arg_name, ) - state_value: Any - if tool_arg_name == "*": - state_value = parsed_args - elif tool_arg_name in parsed_args: - state_value = parsed_args[tool_arg_name] - else: + state_value = extract_state_from_tool_args(parsed_args, tool_arg_name) + if state_value is None: logger.warning(f" Tool argument '{tool_arg_name}' not found in parsed args") continue diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py index 0ff88a8e36..1ff858e9f5 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py @@ -16,35 +16,17 @@ prepare_function_call_results, ) -# Role mapping constants -_AGUI_TO_FRAMEWORK_ROLE = { - "user": Role.USER, - "assistant": Role.ASSISTANT, - "system": Role.SYSTEM, -} - -_FRAMEWORK_TO_AGUI_ROLE = { - Role.USER: "user", - Role.ASSISTANT: "assistant", - Role.SYSTEM: "system", -} - -_ALLOWED_AGUI_ROLES = {"user", "assistant", "system", "tool"} +from ._utils import ( + AGUI_TO_FRAMEWORK_ROLE, + FRAMEWORK_TO_AGUI_ROLE, + get_role_value, + normalize_agui_role, + safe_json_parse, +) logger = logging.getLogger(__name__) -def _normalize_agui_role(raw_role: Any) -> str: - if not isinstance(raw_role, str): - return "user" - role = raw_role.lower() - if role == "developer": - return "system" - if role in _ALLOWED_AGUI_ROLES: - return role - return "user" - - def _sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: """Normalize tool ordering and inject synthetic results for AG-UI edge cases.""" sanitized: list[ChatMessage] = [] @@ -52,7 +34,7 @@ def _sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: pending_confirm_changes_id: str | None = None for msg in messages: - role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) + role_value = get_role_value(msg) if role_value == "assistant": tool_ids = { @@ -191,7 +173,7 @@ def _deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]: unique_messages: list[ChatMessage] = [] for idx, msg in enumerate(messages): - role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) + role_value = get_role_value(msg) if role_value == "tool" and msg.contents and isinstance(msg.contents[0], FunctionResultContent): call_id = str(msg.contents[0].call_id) @@ -305,16 +287,7 @@ def _find_matching_func_call(call_id: str) -> FunctionCallContent | None: return None def _parse_arguments(arguments: Any) -> dict[str, Any] | None: - if isinstance(arguments, dict): - return arguments - if isinstance(arguments, str): - try: - parsed = json.loads(arguments) - except json.JSONDecodeError: - return None - if isinstance(parsed, dict): - return parsed - return None + return safe_json_parse(arguments) def _resolve_approval_call_id(tool_call_id: str, parsed_payload: dict[str, Any] | None) -> str | None: if parsed_payload: @@ -375,7 +348,7 @@ def _filter_modified_args( for msg in messages: # Handle standard tool result messages early (role="tool") to preserve provider invariants # This path maps AG‑UI tool messages to FunctionResultContent with the correct tool_call_id - role_str = _normalize_agui_role(msg.get("role", "user")) + role_str = normalize_agui_role(msg.get("role", "user")) if role_str == "tool": # Prefer explicit tool_call_id fields; fall back to backend fields only if necessary tool_call_id = msg.get("tool_call_id") or msg.get("toolCallId") @@ -599,7 +572,7 @@ def _filter_modified_args( # No special handling required for assistant/plain messages here - role = _AGUI_TO_FRAMEWORK_ROLE.get(role_str, Role.USER) + role = AGUI_TO_FRAMEWORK_ROLE.get(role_str, Role.USER) # Check if this message contains function approvals if "function_approvals" in msg and msg["function_approvals"]: @@ -655,7 +628,7 @@ def agent_framework_messages_to_agui(messages: list[ChatMessage] | list[dict[str if isinstance(msg, dict): # Always work on a copy to avoid mutating input normalized_msg = msg.copy() - normalized_msg["role"] = _normalize_agui_role(normalized_msg.get("role")) + normalized_msg["role"] = normalize_agui_role(normalized_msg.get("role")) # Ensure ID exists if "id" not in normalized_msg: normalized_msg["id"] = generate_event_id() @@ -672,7 +645,7 @@ def agent_framework_messages_to_agui(messages: list[ChatMessage] | list[dict[str continue # Convert ChatMessage to AG-UI format - role = _FRAMEWORK_TO_AGUI_ROLE.get(msg.role, "user") + role = FRAMEWORK_TO_AGUI_ROLE.get(msg.role, "user") content_text = "" tool_calls: list[dict[str, Any]] = [] @@ -798,7 +771,7 @@ def agui_messages_to_snapshot_format(messages: list[dict[str, Any]]) -> list[dic function_payload_dict["arguments"] = json.dumps(arguments) # Normalize tool_call_id to toolCallId for tool messages - normalized_msg["role"] = _normalize_agui_role(normalized_msg.get("role")) + normalized_msg["role"] = normalize_agui_role(normalized_msg.get("role")) if normalized_msg.get("role") == "tool": if "tool_call_id" in normalized_msg: normalized_msg["toolCallId"] = normalized_msg["tool_call_id"] diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py new file mode 100644 index 0000000000..ebf6ef6f57 --- /dev/null +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py @@ -0,0 +1,391 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Helper functions for orchestration logic.""" + +import json +import logging +from typing import TYPE_CHECKING, Any + +from ag_ui.core import StateSnapshotEvent +from agent_framework import ( + ChatMessage, + FunctionApprovalResponseContent, + FunctionCallContent, + FunctionResultContent, + TextContent, +) + +from .._utils import get_role_value, safe_json_parse + +if TYPE_CHECKING: + from .._events import AgentFrameworkEventBridge + from ._state_manager import StateManager + +logger = logging.getLogger(__name__) + + +def pending_tool_call_ids(messages: list[ChatMessage]) -> set[str]: + """Get IDs of tool calls without corresponding results. + + Args: + messages: List of messages to scan + + Returns: + Set of pending tool call IDs + """ + pending_ids: set[str] = set() + resolved_ids: set[str] = set() + for msg in messages: + for content in msg.contents: + if isinstance(content, FunctionCallContent) and content.call_id: + pending_ids.add(str(content.call_id)) + elif isinstance(content, FunctionResultContent) and content.call_id: + resolved_ids.add(str(content.call_id)) + return pending_ids - resolved_ids + + +def is_state_context_message(message: ChatMessage) -> bool: + """Check if a message is a state context system message. + + Args: + message: Message to check + + Returns: + True if this is a state context message + """ + if get_role_value(message) != "system": + return False + for content in message.contents: + if isinstance(content, TextContent) and content.text.startswith("Current state of the application:"): + return True + return False + + +def ensure_tool_call_entry( + tool_call_id: str, + tool_calls_by_id: dict[str, dict[str, Any]], + pending_tool_calls: list[dict[str, Any]], +) -> dict[str, Any]: + """Get or create a tool call entry in the tracking dicts. + + Args: + tool_call_id: The tool call ID + tool_calls_by_id: Dict mapping IDs to tool call entries + pending_tool_calls: List of pending tool calls + + Returns: + The tool call entry dict + """ + entry = tool_calls_by_id.get(tool_call_id) + if entry is None: + entry = { + "id": tool_call_id, + "type": "function", + "function": { + "name": "", + "arguments": "", + }, + } + tool_calls_by_id[tool_call_id] = entry + pending_tool_calls.append(entry) + return entry + + +def tool_name_for_call_id( + tool_calls_by_id: dict[str, dict[str, Any]], + tool_call_id: str, +) -> str | None: + """Get the tool name for a given call ID. + + Args: + tool_calls_by_id: Dict mapping IDs to tool call entries + tool_call_id: The tool call ID to look up + + Returns: + Tool name or None if not found + """ + entry = tool_calls_by_id.get(tool_call_id) + if not entry: + return None + function = entry.get("function") + if not isinstance(function, dict): + return None + name = function.get("name") + return str(name) if name else None + + +def tool_calls_match_state( + provider_messages: list[ChatMessage], + state_manager: "StateManager", +) -> bool: + """Check if tool calls in messages match current state. + + Args: + provider_messages: Messages to check + state_manager: State manager with config and current state + + Returns: + True if tool calls match state configuration + """ + if not state_manager.predict_state_config or not state_manager.current_state: + return False + + for state_key, config in state_manager.predict_state_config.items(): + tool_name = config["tool"] + tool_arg_name = config["tool_argument"] + tool_args: dict[str, Any] | None = None + + for msg in reversed(provider_messages): + if get_role_value(msg) != "assistant": + continue + for content in msg.contents: + if isinstance(content, FunctionCallContent) and content.name == tool_name: + tool_args = safe_json_parse(content.arguments) + break + if tool_args is not None: + break + + if not tool_args: + return False + + if tool_arg_name == "*": + state_value = tool_args + elif tool_arg_name in tool_args: + state_value = tool_args[tool_arg_name] + else: + return False + + if state_manager.current_state.get(state_key) != state_value: + return False + + return True + + +def schema_has_steps(schema: Any) -> bool: + """Check if a schema has a steps array property. + + Args: + schema: JSON schema to check + + Returns: + True if schema has steps array + """ + if not isinstance(schema, dict): + return False + properties = schema.get("properties") + if not isinstance(properties, dict): + return False + steps_schema = properties.get("steps") + if not isinstance(steps_schema, dict): + return False + return steps_schema.get("type") == "array" + + +def select_approval_tool_name(client_tools: list[Any] | None) -> str | None: + """Select appropriate approval tool from client tools. + + Args: + client_tools: List of client tool definitions + + Returns: + Name of approval tool, or None if not found + """ + if not client_tools: + return None + for tool in client_tools: + tool_name = getattr(tool, "name", None) + if not tool_name: + continue + params_fn = getattr(tool, "parameters", None) + if not callable(params_fn): + continue + schema = params_fn() + if schema_has_steps(schema): + return str(tool_name) + return None + + +def select_messages_to_run( + provider_messages: list[ChatMessage], + state_manager: "StateManager", +) -> list[ChatMessage]: + """Select and prepare messages for agent execution. + + Injects state context message when appropriate. + + Args: + provider_messages: Original messages from client + state_manager: State manager instance + + Returns: + Messages ready for agent execution + """ + if not provider_messages: + return [] + + is_new_user_turn = get_role_value(provider_messages[-1]) == "user" + conversation_has_tool_calls = tool_calls_match_state(provider_messages, state_manager) + state_context_msg = state_manager.state_context_message( + is_new_user_turn=is_new_user_turn, conversation_has_tool_calls=conversation_has_tool_calls + ) + if not state_context_msg: + return list(provider_messages) + + messages_to_run = [msg for msg in provider_messages if not is_state_context_message(msg)] + if pending_tool_call_ids(messages_to_run): + return messages_to_run + + insert_index = len(messages_to_run) - 1 if is_new_user_turn else len(messages_to_run) + if insert_index < 0: + insert_index = 0 + messages_to_run.insert(insert_index, state_context_msg) + return messages_to_run + + +def build_safe_metadata(thread_metadata: dict[str, Any] | None) -> dict[str, Any]: + """Build metadata dict with truncated string values. + + Args: + thread_metadata: Raw metadata dict + + Returns: + Metadata with string values truncated to 512 chars + """ + if not thread_metadata: + return {} + safe_metadata: dict[str, Any] = {} + for key, value in thread_metadata.items(): + value_str = value if isinstance(value, str) else json.dumps(value) + if len(value_str) > 512: + value_str = value_str[:512] + safe_metadata[key] = value_str + return safe_metadata + + +def collect_approved_state_snapshots( + provider_messages: list[ChatMessage], + predict_state_config: dict[str, dict[str, str]] | None, + current_state: dict[str, Any], + event_bridge: "AgentFrameworkEventBridge", +) -> list[StateSnapshotEvent]: + """Collect state snapshots from approved function calls. + + Args: + provider_messages: Messages containing approvals + predict_state_config: Predictive state configuration + current_state: Current state dict (will be mutated) + event_bridge: Event bridge for creating events + + Returns: + List of state snapshot events + """ + if not predict_state_config: + return [] + + events: list[StateSnapshotEvent] = [] + for msg in provider_messages: + if get_role_value(msg) != "user": + continue + for content in msg.contents: + if type(content) is FunctionApprovalResponseContent: + if not content.function_call or not content.approved: + continue + parsed_args = content.function_call.parse_arguments() + state_args = None + if content.additional_properties: + state_args = content.additional_properties.get("ag_ui_state_args") + if not isinstance(state_args, dict): + state_args = parsed_args + if not state_args: + continue + for state_key, config in predict_state_config.items(): + if config["tool"] != content.function_call.name: + continue + tool_arg_name = config["tool_argument"] + if tool_arg_name == "*": + state_value = state_args + elif isinstance(state_args, dict) and tool_arg_name in state_args: + state_value = state_args[tool_arg_name] + else: + continue + current_state[state_key] = state_value + event_bridge.current_state[state_key] = state_value + logger.info( + f"Emitting StateSnapshotEvent for approved state key '{state_key}' " + f"with {len(state_value) if isinstance(state_value, list) else 'N/A'} items" + ) + events.append(StateSnapshotEvent(snapshot=current_state)) + break + return events + + +def latest_approval_response(messages: list[ChatMessage]) -> FunctionApprovalResponseContent | None: + """Get the latest approval response from messages. + + Args: + messages: Messages to search + + Returns: + Latest approval response or None + """ + if not messages: + return None + last_message = messages[-1] + for content in last_message.contents: + if type(content) is FunctionApprovalResponseContent: + return content + return None + + +def approval_steps(approval: FunctionApprovalResponseContent) -> list[Any]: + """Extract steps from an approval response. + + Args: + approval: Approval response content + + Returns: + List of steps, or empty list if none + """ + state_args: Any | None = None + if approval.additional_properties: + state_args = approval.additional_properties.get("ag_ui_state_args") + if isinstance(state_args, dict): + steps = state_args.get("steps") + if isinstance(steps, list): + return steps + + if approval.function_call: + parsed_args = approval.function_call.parse_arguments() + if isinstance(parsed_args, dict): + steps = parsed_args.get("steps") + if isinstance(steps, list): + return steps + + return [] + + +def is_step_based_approval( + approval: FunctionApprovalResponseContent, + predict_state_config: dict[str, dict[str, str]] | None, +) -> bool: + """Check if an approval is step-based. + + Args: + approval: Approval response to check + predict_state_config: Predictive state configuration + + Returns: + True if this is a step-based approval + """ + steps = approval_steps(approval) + if steps: + return True + if not approval.function_call: + return False + if not predict_state_config: + return False + tool_name = approval.function_call.name + for config in predict_state_config.values(): + if config.get("tool") == tool_name and config.get("tool_argument") == "steps": + return True + return False diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_predictive_state.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_predictive_state.py new file mode 100644 index 0000000000..8662036bbf --- /dev/null +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_predictive_state.py @@ -0,0 +1,230 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Predictive state handling utilities.""" + +import json +import logging +import re +from typing import Any + +from ag_ui.core import StateDeltaEvent + +from .._utils import safe_json_parse + +logger = logging.getLogger(__name__) + + +class PredictiveStateHandler: + """Handles predictive state updates from streaming tool calls.""" + + def __init__( + self, + predict_state_config: dict[str, dict[str, str]] | None = None, + current_state: dict[str, Any] | None = None, + ) -> None: + """Initialize the handler. + + Args: + predict_state_config: Configuration mapping state keys to tool/argument pairs + current_state: Reference to current state dict + """ + self.predict_state_config = predict_state_config or {} + self.current_state = current_state or {} + self.streaming_tool_args: str = "" + self.last_emitted_state: dict[str, Any] = {} + self.state_delta_count: int = 0 + self.pending_state_updates: dict[str, Any] = {} + + def reset_streaming(self) -> None: + """Reset streaming state for a new tool call.""" + self.streaming_tool_args = "" + self.state_delta_count = 0 + + def extract_state_value( + self, + tool_name: str, + args: dict[str, Any] | str | None, + ) -> tuple[str, Any] | None: + """Extract state value from tool arguments based on config. + + Args: + tool_name: Name of the tool being called + args: Tool arguments (dict or JSON string) + + Returns: + Tuple of (state_key, state_value) or None if no match + """ + if not self.predict_state_config: + return None + + parsed_args = safe_json_parse(args) if isinstance(args, str) else args + if not parsed_args: + return None + + for state_key, config in self.predict_state_config.items(): + if config["tool"] != tool_name: + continue + tool_arg_name = config["tool_argument"] + if tool_arg_name == "*": + return (state_key, parsed_args) + if tool_arg_name in parsed_args: + return (state_key, parsed_args[tool_arg_name]) + + return None + + def is_predictive_tool(self, tool_name: str | None) -> bool: + """Check if a tool is configured for predictive state. + + Args: + tool_name: Name of the tool to check + + Returns: + True if tool is in predictive state config + """ + if not tool_name or not self.predict_state_config: + return False + for config in self.predict_state_config.values(): + if config["tool"] == tool_name: + return True + return False + + def emit_streaming_deltas( + self, + tool_name: str | None, + argument_chunk: str, + ) -> list[StateDeltaEvent]: + """Process streaming argument chunk and emit state deltas. + + Args: + tool_name: Name of the current tool + argument_chunk: New chunk of JSON arguments + + Returns: + List of state delta events to emit + """ + events: list[StateDeltaEvent] = [] + if not tool_name or not self.predict_state_config: + return events + + self.streaming_tool_args += argument_chunk + logger.debug( + "Predictive state: accumulated %s chars for tool '%s'", + len(self.streaming_tool_args), + tool_name, + ) + + # Try to parse complete JSON first + parsed_args = None + try: + parsed_args = json.loads(self.streaming_tool_args) + except json.JSONDecodeError: + # Fall back to regex matching for partial JSON + events.extend(self._emit_partial_deltas(tool_name)) + + if parsed_args: + events.extend(self._emit_complete_deltas(tool_name, parsed_args)) + + return events + + def _emit_partial_deltas(self, tool_name: str) -> list[StateDeltaEvent]: + """Emit deltas from partial JSON using regex matching. + + Args: + tool_name: Name of the current tool + + Returns: + List of state delta events + """ + events: list[StateDeltaEvent] = [] + + for state_key, config in self.predict_state_config.items(): + if config["tool"] != tool_name: + continue + tool_arg_name = config["tool_argument"] + pattern = rf'"{re.escape(tool_arg_name)}":\s*"([^"]*)' + match = re.search(pattern, self.streaming_tool_args) + + if match: + partial_value = match.group(1).replace("\\n", "\n").replace('\\"', '"').replace("\\\\", "\\") + + if state_key not in self.last_emitted_state or self.last_emitted_state[state_key] != partial_value: + event = self._create_delta_event(state_key, partial_value) + events.append(event) + self.last_emitted_state[state_key] = partial_value + self.pending_state_updates[state_key] = partial_value + + return events + + def _emit_complete_deltas( + self, + tool_name: str, + parsed_args: dict[str, Any], + ) -> list[StateDeltaEvent]: + """Emit deltas from complete parsed JSON. + + Args: + tool_name: Name of the current tool + parsed_args: Fully parsed arguments dict + + Returns: + List of state delta events + """ + events: list[StateDeltaEvent] = [] + + for state_key, config in self.predict_state_config.items(): + if config["tool"] != tool_name: + continue + tool_arg_name = config["tool_argument"] + + if tool_arg_name == "*": + state_value = parsed_args + elif tool_arg_name in parsed_args: + state_value = parsed_args[tool_arg_name] + else: + continue + + if state_key not in self.last_emitted_state or self.last_emitted_state[state_key] != state_value: + event = self._create_delta_event(state_key, state_value) + events.append(event) + self.last_emitted_state[state_key] = state_value + self.pending_state_updates[state_key] = state_value + + return events + + def _create_delta_event(self, state_key: str, value: Any) -> StateDeltaEvent: + """Create a state delta event with logging. + + Args: + state_key: The state key being updated + value: The new value + + Returns: + StateDeltaEvent instance + """ + self.state_delta_count += 1 + if self.state_delta_count % 10 == 1: + logger.info( + "StateDeltaEvent #%s for '%s': op=replace, path=/%s, value_length=%s", + self.state_delta_count, + state_key, + state_key, + len(str(value)), + ) + elif self.state_delta_count % 100 == 0: + logger.info(f"StateDeltaEvent #{self.state_delta_count} emitted") + + return StateDeltaEvent( + delta=[ + { + "op": "replace", + "path": f"/{state_key}", + "value": value, + } + ], + ) + + def apply_pending_updates(self) -> None: + """Apply pending updates to current state and clear them.""" + for key, value in self.pending_state_updates.items(): + self.current_state[key] = value + self.pending_state_updates.clear() diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py index 6ade68cf51..3067e3e4a7 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py @@ -13,7 +13,6 @@ BaseEvent, MessagesSnapshotEvent, RunErrorEvent, - StateSnapshotEvent, TextMessageContentEvent, TextMessageEndEvent, TextMessageStartEvent, @@ -26,7 +25,6 @@ AgentProtocol, AgentThread, ChatAgent, - FunctionApprovalResponseContent, FunctionCallContent, FunctionResultContent, TextContent, @@ -39,282 +37,32 @@ _try_execute_function_calls, # type: ignore ) -from ._utils import convert_agui_tools_to_agent_framework, generate_event_id +from ._orchestration._helpers import ( + approval_steps, + build_safe_metadata, + collect_approved_state_snapshots, + ensure_tool_call_entry, + is_step_based_approval, + latest_approval_response, + select_approval_tool_name, + select_messages_to_run, + tool_name_for_call_id, +) +from ._orchestration._tooling import ( + collect_server_tools, + merge_tools, + register_additional_client_tools, +) +from ._utils import convert_agui_tools_to_agent_framework, generate_event_id, get_role_value if TYPE_CHECKING: from ._agent import AgentConfig from ._confirmation_strategies import ConfirmationStrategy - from ._events import AgentFrameworkEventBridge - from ._orchestration._state_manager import StateManager logger = logging.getLogger(__name__) -def _role_value(message: Any) -> str: - role = getattr(message, "role", None) - if role is None: - return "" - if hasattr(role, "value"): - return str(role.value) - return str(role) - - -def _pending_tool_call_ids(messages: list[Any]) -> set[str]: - pending_ids: set[str] = set() - resolved_ids: set[str] = set() - for msg in messages: - for content in msg.contents or []: - if isinstance(content, FunctionCallContent) and content.call_id: - pending_ids.add(str(content.call_id)) - elif isinstance(content, FunctionResultContent) and content.call_id: - resolved_ids.add(str(content.call_id)) - return pending_ids - resolved_ids - - -def _is_state_context_message(message: Any) -> bool: - if _role_value(message) != "system": - return False - for content in message.contents or []: - if isinstance(content, TextContent) and content.text.startswith("Current state of the application:"): - return True - return False - - -def _ensure_tool_call_entry( - tool_call_id: str, - tool_calls_by_id: dict[str, dict[str, Any]], - pending_tool_calls: list[dict[str, Any]], -) -> dict[str, Any]: - entry = tool_calls_by_id.get(tool_call_id) - if entry is None: - entry = { - "id": tool_call_id, - "type": "function", - "function": { - "name": "", - "arguments": "", - }, - } - tool_calls_by_id[tool_call_id] = entry - pending_tool_calls.append(entry) - return entry - - -def _tool_name_for_call_id(tool_calls_by_id: dict[str, dict[str, Any]], tool_call_id: str) -> str | None: - entry = tool_calls_by_id.get(tool_call_id) - if not entry: - return None - function = entry.get("function") - if not isinstance(function, dict): - return None - name = function.get("name") - return str(name) if name else None - - -def _tool_calls_match_state(provider_messages: list[Any], state_manager: "StateManager") -> bool: - if not state_manager.predict_state_config or not state_manager.current_state: - return False - - def _parse_args(arguments: Any) -> dict[str, Any] | None: - if isinstance(arguments, dict): - return arguments - if isinstance(arguments, str): - try: - parsed = json.loads(arguments) - except json.JSONDecodeError: - return None - if isinstance(parsed, dict): - return parsed - return None - - for state_key, config in state_manager.predict_state_config.items(): - tool_name = config["tool"] - tool_arg_name = config["tool_argument"] - tool_args: dict[str, Any] | None = None - - for msg in reversed(provider_messages): - if _role_value(msg) != "assistant": - continue - for content in msg.contents or []: - if isinstance(content, FunctionCallContent) and content.name == tool_name: - tool_args = _parse_args(content.arguments) - break - if tool_args is not None: - break - - if not tool_args: - return False - - if tool_arg_name == "*": - state_value = tool_args - elif tool_arg_name in tool_args: - state_value = tool_args[tool_arg_name] - else: - return False - - if state_manager.current_state.get(state_key) != state_value: - return False - - return True - - -def _schema_has_steps(schema: Any) -> bool: - if not isinstance(schema, dict): - return False - properties = schema.get("properties") - if not isinstance(properties, dict): - return False - steps_schema = properties.get("steps") - if not isinstance(steps_schema, dict): - return False - return steps_schema.get("type") == "array" - - -def _select_approval_tool_name(client_tools: list[Any] | None) -> str | None: - if not client_tools: - return None - for tool in client_tools: - tool_name = getattr(tool, "name", None) - if not tool_name: - continue - params_fn = getattr(tool, "parameters", None) - if not callable(params_fn): - continue - schema = params_fn() - if _schema_has_steps(schema): - return str(tool_name) - return None - - -def _select_messages_to_run(provider_messages: list[Any], state_manager: "StateManager") -> list[Any]: - if not provider_messages: - return [] - - is_new_user_turn = _role_value(provider_messages[-1]) == "user" - conversation_has_tool_calls = _tool_calls_match_state(provider_messages, state_manager) - state_context_msg = state_manager.state_context_message( - is_new_user_turn=is_new_user_turn, conversation_has_tool_calls=conversation_has_tool_calls - ) - if not state_context_msg: - return list(provider_messages) - - messages_to_run = [msg for msg in provider_messages if not _is_state_context_message(msg)] - if _pending_tool_call_ids(messages_to_run): - return messages_to_run - - insert_index = len(messages_to_run) - 1 if is_new_user_turn else len(messages_to_run) - if insert_index < 0: - insert_index = 0 - messages_to_run.insert(insert_index, state_context_msg) - return messages_to_run - - -def _build_safe_metadata(thread_metadata: dict[str, Any] | None) -> dict[str, Any]: - if not thread_metadata: - return {} - safe_metadata: dict[str, Any] = {} - for key, value in thread_metadata.items(): - value_str = value if isinstance(value, str) else json.dumps(value) - if len(value_str) > 512: - value_str = value_str[:512] - safe_metadata[key] = value_str - return safe_metadata - - -def _collect_approved_state_snapshots( - provider_messages: list[Any], - predict_state_config: dict[str, dict[str, str]], - current_state: dict[str, Any], - event_bridge: "AgentFrameworkEventBridge", -) -> list[StateSnapshotEvent]: - if not predict_state_config: - return [] - - events: list[StateSnapshotEvent] = [] - for msg in provider_messages: - if _role_value(msg) != "user": - continue - for content in msg.contents or []: - if type(content) is FunctionApprovalResponseContent: - if not content.function_call or not content.approved: - continue - parsed_args = content.function_call.parse_arguments() - state_args = None - if content.additional_properties: - state_args = content.additional_properties.get("ag_ui_state_args") - if not isinstance(state_args, dict): - state_args = parsed_args - if not state_args: - continue - for state_key, config in predict_state_config.items(): - if config["tool"] != content.function_call.name: - continue - tool_arg_name = config["tool_argument"] - if tool_arg_name == "*": - state_value = state_args - elif isinstance(state_args, dict) and tool_arg_name in state_args: - state_value = state_args[tool_arg_name] - else: - continue - current_state[state_key] = state_value - event_bridge.current_state[state_key] = state_value - logger.info( - f"Emitting StateSnapshotEvent for approved state key '{state_key}' " - f"with {len(state_value) if isinstance(state_value, list) else 'N/A'} items" - ) - events.append(StateSnapshotEvent(snapshot=current_state)) - break - return events - - -def _latest_approval_response(messages: list[Any]) -> FunctionApprovalResponseContent | None: - if not messages: - return None - last_message = messages[-1] - for content in last_message.contents or []: - if type(content) is FunctionApprovalResponseContent: - return content - return None - - -def _approval_steps(approval: FunctionApprovalResponseContent) -> list[Any]: - state_args: Any | None = None - if approval.additional_properties: - state_args = approval.additional_properties.get("ag_ui_state_args") - if isinstance(state_args, dict): - steps = state_args.get("steps") - if isinstance(steps, list): - return steps - - if approval.function_call: - parsed_args = approval.function_call.parse_arguments() - if isinstance(parsed_args, dict): - steps = parsed_args.get("steps") - if isinstance(steps, list): - return steps - - return [] - - -def _is_step_based_approval( - approval: FunctionApprovalResponseContent, - predict_state_config: dict[str, dict[str, str]], -) -> bool: - steps = _approval_steps(approval) - if steps: - return True - if not approval.function_call: - return False - if not predict_state_config: - return False - tool_name = approval.function_call.name - for config in predict_state_config.values(): - if config.get("tool") == tool_name and config.get("tool_argument") == "steps": - return True - return False - - class ExecutionContext: """Shared context for orchestrators.""" @@ -566,11 +314,6 @@ async def run( """ from ._events import AgentFrameworkEventBridge from ._orchestration._state_manager import StateManager - from ._orchestration._tooling import ( - collect_server_tools, - merge_tools, - register_additional_client_tools, - ) logger.info(f"Starting default agent run for thread_id={context.thread_id}, run_id={context.run_id}") @@ -580,7 +323,7 @@ async def run( skip_text_content = response_format is not None client_tools = convert_agui_tools_to_agent_framework(context.input_data.get("tools")) - approval_tool_name = _select_approval_tool_name(client_tools) + approval_tool_name = select_approval_tool_name(client_tools) state_manager = StateManager( state_schema=context.config.state_schema, @@ -626,7 +369,7 @@ async def run( logger.info(f"Received {len(provider_messages)} provider messages from client") for i, msg in enumerate(provider_messages): - role = _role_value(msg) + role = get_role_value(msg) msg_id = getattr(msg, "message_id", None) logger.info(f" Message {i}: role={role}, id={msg_id}") if hasattr(msg, "contents") and msg.contents: @@ -661,7 +404,7 @@ async def run( # Check for FunctionApprovalResponseContent and emit updated state snapshot # This ensures the UI shows the approved state (e.g., 2 steps) not the original (3 steps) - for snapshot_evt in _collect_approved_state_snapshots( + for snapshot_evt in collect_approved_state_snapshots( provider_messages, context.config.predict_state_config, current_state, @@ -669,7 +412,7 @@ async def run( ): yield snapshot_evt - messages_to_run = _select_messages_to_run(provider_messages, state_manager) + messages_to_run = select_messages_to_run(provider_messages, state_manager) logger.info(f"[TOOLS] Client sent {len(client_tools) if client_tools else 0} tools") if client_tools: @@ -686,7 +429,7 @@ async def run( all_updates: list[Any] | None = [] if collect_updates else None update_count = 0 # Prepare metadata for chat client (Azure requires string values) - safe_metadata = _build_safe_metadata(getattr(thread, "metadata", None)) + safe_metadata = build_safe_metadata(getattr(thread, "metadata", None)) run_kwargs: dict[str, Any] = { "thread": thread, @@ -788,17 +531,17 @@ def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnap # Use tools_param if available (includes client tools), otherwise fall back to server_tools # This ensures both server tools AND client tools can be executed after approval tools_for_approval = tools_param if tools_param is not None else server_tools - latest_approval = _latest_approval_response(messages_to_run) + latest_approval = latest_approval_response(messages_to_run) await _resolve_approval_responses(messages_to_run, tools_for_approval) - if latest_approval and _is_step_based_approval(latest_approval, context.config.predict_state_config): + if latest_approval and is_step_based_approval(latest_approval, context.config.predict_state_config): from ._confirmation_strategies import DefaultConfirmationStrategy strategy = context.confirmation_strategy if strategy is None: strategy = DefaultConfirmationStrategy() - steps = _approval_steps(latest_approval) + steps = approval_steps(latest_approval) if steps: if latest_approval.approved: confirmation_message = strategy.on_approval_accepted(steps) @@ -844,10 +587,10 @@ def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnap elif isinstance(event, TextMessageContentEvent): accumulated_text_content += event.delta elif isinstance(event, ToolCallStartEvent): - tool_call_entry = _ensure_tool_call_entry(event.tool_call_id, tool_calls_by_id, pending_tool_calls) + tool_call_entry = ensure_tool_call_entry(event.tool_call_id, tool_calls_by_id, pending_tool_calls) tool_call_entry["function"]["name"] = event.tool_call_name elif isinstance(event, ToolCallArgsEvent): - tool_call_entry = _ensure_tool_call_entry(event.tool_call_id, tool_calls_by_id, pending_tool_calls) + tool_call_entry = ensure_tool_call_entry(event.tool_call_id, tool_calls_by_id, pending_tool_calls) tool_call_entry["function"]["arguments"] += event.delta elif isinstance(event, ToolCallEndEvent): tool_calls_ended.add(event.tool_call_id) @@ -863,14 +606,14 @@ def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnap logger.info(f"[STREAM] Yielding event: {type(event).__name__}") yield event if isinstance(event, ToolCallResultEvent): - tool_name = _tool_name_for_call_id(tool_calls_by_id, event.tool_call_id) + tool_name = tool_name_for_call_id(tool_calls_by_id, event.tool_call_id) if _should_emit_tool_snapshot(tool_name): messages_snapshot_emitted = True messages_snapshot = _build_messages_snapshot() logger.info(f"[STREAM] Yielding event: {type(messages_snapshot).__name__}") yield messages_snapshot elif isinstance(event, ToolCallEndEvent): - tool_name = _tool_name_for_call_id(tool_calls_by_id, event.tool_call_id) + tool_name = tool_name_for_call_id(tool_calls_by_id, event.tool_call_id) if tool_name == "confirm_changes": messages_snapshot_emitted = True messages_snapshot = _build_messages_snapshot() diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_utils.py b/python/packages/ag-ui/agent_framework_ag_ui/_utils.py index 8b271988dc..c0da986308 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_utils.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_utils.py @@ -3,13 +3,29 @@ """Utility functions for AG-UI integration.""" import copy +import json import uuid from collections.abc import Callable, MutableMapping, Sequence from dataclasses import asdict, is_dataclass from datetime import date, datetime from typing import Any -from agent_framework import AIFunction, ToolProtocol +from agent_framework import AIFunction, Role, ToolProtocol + +# Role mapping constants +AGUI_TO_FRAMEWORK_ROLE: dict[str, Role] = { + "user": Role.USER, + "assistant": Role.ASSISTANT, + "system": Role.SYSTEM, +} + +FRAMEWORK_TO_AGUI_ROLE: dict[Role, str] = { + Role.USER: "user", + Role.ASSISTANT: "assistant", + Role.SYSTEM: "system", +} + +ALLOWED_AGUI_ROLES: set[str] = {"user", "assistant", "system", "tool"} def generate_event_id() -> str: @@ -17,6 +33,85 @@ def generate_event_id() -> str: return str(uuid.uuid4()) +def safe_json_parse(value: Any) -> dict[str, Any] | None: + """Safely parse a value as JSON dict. + + Args: + value: String or dict to parse + + Returns: + Parsed dict or None if parsing fails + """ + if isinstance(value, dict): + return value + if isinstance(value, str): + try: + parsed = json.loads(value) + if isinstance(parsed, dict): + return parsed + except json.JSONDecodeError: + pass + return None + + +def get_role_value(message: Any) -> str: + """Extract role string from a message object. + + Handles both enum roles (with .value) and string roles. + + Args: + message: Message object with role attribute + + Returns: + Role as lowercase string, or empty string if not found + """ + role = getattr(message, "role", None) + if role is None: + return "" + if hasattr(role, "value"): + return str(role.value) + return str(role) + + +def normalize_agui_role(raw_role: Any) -> str: + """Normalize an AG-UI role to a standard role string. + + Args: + raw_role: Raw role value from AG-UI message + + Returns: + Normalized role string (user, assistant, system, or tool) + """ + if not isinstance(raw_role, str): + return "user" + role = raw_role.lower() + if role == "developer": + return "system" + if role in ALLOWED_AGUI_ROLES: + return role + return "user" + + +def extract_state_from_tool_args( + args: dict[str, Any] | None, + tool_arg_name: str, +) -> Any: + """Extract state value from tool arguments based on config. + + Args: + args: Parsed tool arguments dict + tool_arg_name: Name of the argument to extract, or "*" for entire args + + Returns: + Extracted state value, or None if not found + """ + if not args: + return None + if tool_arg_name == "*": + return args + return args.get(tool_arg_name) + + def merge_state(current: dict[str, Any], update: dict[str, Any]) -> dict[str, Any]: """Merge state updates. From 47040e3973b5243a4a278d128d484b54c3a0f034 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Fri, 9 Jan 2026 09:28:29 +0900 Subject: [PATCH 8/8] Add version detection in __init__.py to ruff ignore list --- python/pyproject.toml | 3 +- python/uv.lock | 74 +++++++++++++++++++++---------------------- 2 files changed, 39 insertions(+), 38 deletions(-) diff --git a/python/pyproject.toml b/python/pyproject.toml index ed98cc8020..cbdf3f0d75 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -145,7 +145,8 @@ ignore = [ "D418", # allow overload to have a docstring "TD003", # allow missing link to todo issue "FIX002", # allow todo - "B027" # allow empty non-abstract method in ABC + "B027", # allow empty non-abstract method in ABC + "RUF067", # allow version detection in __init__.py ] [tool.ruff.lint.per-file-ignores] diff --git a/python/uv.lock b/python/uv.lock index 0ad64ed9ea..15d1fd855e 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -2237,7 +2237,7 @@ wheels = [ [[package]] name = "google-api-core" -version = "2.28.1" +version = "2.29.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-auth", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -2246,9 +2246,9 @@ dependencies = [ { name = "protobuf", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "requests", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/61/da/83d7043169ac2c8c7469f0e375610d78ae2160134bf1b80634c482fa079c/google_api_core-2.28.1.tar.gz", hash = "sha256:2b405df02d68e68ce0fbc138559e6036559e685159d148ae5861013dc201baf8", size = 176759, upload-time = "2025-10-28T21:34:51.529Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0d/10/05572d33273292bac49c2d1785925f7bc3ff2fe50e3044cf1062c1dde32e/google_api_core-2.29.0.tar.gz", hash = "sha256:84181be0f8e6b04006df75ddfe728f24489f0af57c96a529ff7cf45bc28797f7", size = 177828, upload-time = "2026-01-08T22:21:39.269Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ed/d4/90197b416cb61cefd316964fd9e7bd8324bcbafabf40eef14a9f20b81974/google_api_core-2.28.1-py3-none-any.whl", hash = "sha256:4021b0f8ceb77a6fb4de6fde4502cecab45062e66ff4f2895169e0b35bc9466c", size = 173706, upload-time = "2025-10-28T21:34:50.151Z" }, + { url = "https://files.pythonhosted.org/packages/77/b6/85c4d21067220b9a78cfb81f516f9725ea6befc1544ec9bd2c1acd97c324/google_api_core-2.29.0-py3-none-any.whl", hash = "sha256:d30bc60980daa36e314b5d5a3e5958b0200cb44ca8fa1be2b614e932b75a3ea9", size = 173906, upload-time = "2026-01-08T22:21:36.093Z" }, ] [[package]] @@ -4277,11 +4277,11 @@ wheels = [ [[package]] name = "pathspec" -version = "1.0.1" +version = "1.0.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/28/2e/83722ece0f6ee24387d6cb830dd562ddbcd6ce0b9d76072c6849670c31b4/pathspec-1.0.1.tar.gz", hash = "sha256:e2769b508d0dd47b09af6ee2c75b2744a2cb1f474ae4b1494fd6a1b7a841613c", size = 129791, upload-time = "2026-01-06T13:02:55.15Z" } +sdist = { url = "https://files.pythonhosted.org/packages/41/b9/6eb731b52f132181a9144bbe77ff82117f6b2d2fbfba49aaab2c014c4760/pathspec-1.0.2.tar.gz", hash = "sha256:fa32b1eb775ed9ba8d599b22c5f906dc098113989da2c00bf8b210078ca7fb92", size = 130502, upload-time = "2026-01-08T04:33:27.613Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d2/fe/2257c71721aeab6a6e8aa1f00d01f2a20f58547d249a6c8fef5791f559fc/pathspec-1.0.1-py3-none-any.whl", hash = "sha256:8870061f22c58e6d83463cfce9a7dd6eca0512c772c1001fb09ac64091816721", size = 54584, upload-time = "2026-01-06T13:02:53.601Z" }, + { url = "https://files.pythonhosted.org/packages/78/6b/14fc9049d78435fd29e82846c777bd7ed9c470013dc8d0260fff3ff1c11e/pathspec-1.0.2-py3-none-any.whl", hash = "sha256:62f8558917908d237d399b9b338ef455a814801a4688bc41074b25feefd93472", size = 54844, upload-time = "2026-01-08T04:33:26.4Z" }, ] [[package]] @@ -4485,7 +4485,7 @@ wheels = [ [[package]] name = "posthog" -version = "7.5.0" +version = "7.5.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "backoff", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, @@ -4495,9 +4495,9 @@ dependencies = [ { name = "six", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/01/7d/7b81b79ab79de1d47230267a389df127d15761202123c0f705d00621ca61/posthog-7.5.0.tar.gz", hash = "sha256:ae57605508ff16bd5a89f392efb26c88e8f3019db8f35611fd94273bf51048e3", size = 144880, upload-time = "2026-01-07T13:11:52.07Z" } +sdist = { url = "https://files.pythonhosted.org/packages/98/3b/866af11cb12e9d35feffcd480d4ebf31f87b2164926b9c670cbdafabc814/posthog-7.5.1.tar.gz", hash = "sha256:d8a8165b3d47465023ea2f919982a34890e2dda76402ec47d6c68424b2534a55", size = 145244, upload-time = "2026-01-08T21:18:39.266Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/cd/c7/42c0cf72d37256fed5552517ddcbe549b6b1408f38b78bc6d980a1d06bc2/posthog-7.5.0-py3-none-any.whl", hash = "sha256:e1cba868a804fe1a13d5c0aaf5bab70aa89fd067d73a4046fa9d3699e225c9d0", size = 167271, upload-time = "2026-01-07T13:11:48.948Z" }, + { url = "https://files.pythonhosted.org/packages/1f/03/ba011712ce9d07fe87dcfb72474c388d960e6d0c4f2262d2ae11fd27f0c5/posthog-7.5.1-py3-none-any.whl", hash = "sha256:fd3431ce32c9bbfb1e3775e3633c32ee589c052b0054fafe5ed9e4b17c1969d3", size = 167555, upload-time = "2026-01-08T21:18:37.437Z" }, ] [[package]] @@ -5017,15 +5017,15 @@ wheels = [ [[package]] name = "pyright" -version = "1.1.407" +version = "1.1.408" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nodeenv", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a6/1b/0aa08ee42948b61745ac5b5b5ccaec4669e8884b53d31c8ec20b2fcd6b6f/pyright-1.1.407.tar.gz", hash = "sha256:099674dba5c10489832d4a4b2d302636152a9a42d317986c38474c76fe562262", size = 4122872, upload-time = "2025-10-24T23:17:15.145Z" } +sdist = { url = "https://files.pythonhosted.org/packages/74/b2/5db700e52554b8f025faa9c3c624c59f1f6c8841ba81ab97641b54322f16/pyright-1.1.408.tar.gz", hash = "sha256:f28f2321f96852fa50b5829ea492f6adb0e6954568d1caa3f3af3a5f555eb684", size = 4400578, upload-time = "2026-01-08T08:07:38.795Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/dc/93/b69052907d032b00c40cb656d21438ec00b3a471733de137a3f65a49a0a0/pyright-1.1.407-py3-none-any.whl", hash = "sha256:6dd419f54fcc13f03b52285796d65e639786373f433e243f8b94cf93a7444d21", size = 5997008, upload-time = "2025-10-24T23:17:13.159Z" }, + { url = "https://files.pythonhosted.org/packages/0c/82/a2c93e32800940d9573fb28c346772a14778b84ba7524e691b324620ab89/pyright-1.1.408-py3-none-any.whl", hash = "sha256:090b32865f4fdb1e0e6cd82bf5618480d48eecd2eb2e70f960982a3d9a4c17c1", size = 6399144, upload-time = "2026-01-08T08:07:37.082Z" }, ] [[package]] @@ -5626,28 +5626,28 @@ wheels = [ [[package]] name = "ruff" -version = "0.14.10" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/57/08/52232a877978dd8f9cf2aeddce3e611b40a63287dfca29b6b8da791f5e8d/ruff-0.14.10.tar.gz", hash = "sha256:9a2e830f075d1a42cd28420d7809ace390832a490ed0966fe373ba288e77aaf4", size = 5859763, upload-time = "2025-12-18T19:28:57.98Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/60/01/933704d69f3f05ee16ef11406b78881733c186fe14b6a46b05cfcaf6d3b2/ruff-0.14.10-py3-none-linux_armv6l.whl", hash = "sha256:7a3ce585f2ade3e1f29ec1b92df13e3da262178df8c8bdf876f48fa0e8316c49", size = 13527080, upload-time = "2025-12-18T19:29:25.642Z" }, - { url = "https://files.pythonhosted.org/packages/df/58/a0349197a7dfa603ffb7f5b0470391efa79ddc327c1e29c4851e85b09cc5/ruff-0.14.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:674f9be9372907f7257c51f1d4fc902cb7cf014b9980152b802794317941f08f", size = 13797320, upload-time = "2025-12-18T19:29:02.571Z" }, - { url = "https://files.pythonhosted.org/packages/7b/82/36be59f00a6082e38c23536df4e71cdbc6af8d7c707eade97fcad5c98235/ruff-0.14.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d85713d522348837ef9df8efca33ccb8bd6fcfc86a2cde3ccb4bc9d28a18003d", size = 12918434, upload-time = "2025-12-18T19:28:51.202Z" }, - { url = "https://files.pythonhosted.org/packages/a6/00/45c62a7f7e34da92a25804f813ebe05c88aa9e0c25e5cb5a7d23dd7450e3/ruff-0.14.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6987ebe0501ae4f4308d7d24e2d0fe3d7a98430f5adfd0f1fead050a740a3a77", size = 13371961, upload-time = "2025-12-18T19:29:04.991Z" }, - { url = "https://files.pythonhosted.org/packages/40/31/a5906d60f0405f7e57045a70f2d57084a93ca7425f22e1d66904769d1628/ruff-0.14.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:16a01dfb7b9e4eee556fbfd5392806b1b8550c9b4a9f6acd3dbe6812b193c70a", size = 13275629, upload-time = "2025-12-18T19:29:21.381Z" }, - { url = "https://files.pythonhosted.org/packages/3e/60/61c0087df21894cf9d928dc04bcd4fb10e8b2e8dca7b1a276ba2155b2002/ruff-0.14.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7165d31a925b7a294465fa81be8c12a0e9b60fb02bf177e79067c867e71f8b1f", size = 14029234, upload-time = "2025-12-18T19:29:00.132Z" }, - { url = "https://files.pythonhosted.org/packages/44/84/77d911bee3b92348b6e5dab5a0c898d87084ea03ac5dc708f46d88407def/ruff-0.14.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:c561695675b972effb0c0a45db233f2c816ff3da8dcfbe7dfc7eed625f218935", size = 15449890, upload-time = "2025-12-18T19:28:53.573Z" }, - { url = "https://files.pythonhosted.org/packages/e9/36/480206eaefa24a7ec321582dda580443a8f0671fdbf6b1c80e9c3e93a16a/ruff-0.14.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4bb98fcbbc61725968893682fd4df8966a34611239c9fd07a1f6a07e7103d08e", size = 15123172, upload-time = "2025-12-18T19:29:23.453Z" }, - { url = "https://files.pythonhosted.org/packages/5c/38/68e414156015ba80cef5473d57919d27dfb62ec804b96180bafdeaf0e090/ruff-0.14.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f24b47993a9d8cb858429e97bdf8544c78029f09b520af615c1d261bf827001d", size = 14460260, upload-time = "2025-12-18T19:29:27.808Z" }, - { url = "https://files.pythonhosted.org/packages/b3/19/9e050c0dca8aba824d67cc0db69fb459c28d8cd3f6855b1405b3f29cc91d/ruff-0.14.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59aabd2e2c4fd614d2862e7939c34a532c04f1084476d6833dddef4afab87e9f", size = 14229978, upload-time = "2025-12-18T19:29:11.32Z" }, - { url = "https://files.pythonhosted.org/packages/51/eb/e8dd1dd6e05b9e695aa9dd420f4577debdd0f87a5ff2fedda33c09e9be8c/ruff-0.14.10-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:213db2b2e44be8625002dbea33bb9c60c66ea2c07c084a00d55732689d697a7f", size = 14338036, upload-time = "2025-12-18T19:29:09.184Z" }, - { url = "https://files.pythonhosted.org/packages/6a/12/f3e3a505db7c19303b70af370d137795fcfec136d670d5de5391e295c134/ruff-0.14.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:b914c40ab64865a17a9a5b67911d14df72346a634527240039eb3bd650e5979d", size = 13264051, upload-time = "2025-12-18T19:29:13.431Z" }, - { url = "https://files.pythonhosted.org/packages/08/64/8c3a47eaccfef8ac20e0484e68e0772013eb85802f8a9f7603ca751eb166/ruff-0.14.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:1484983559f026788e3a5c07c81ef7d1e97c1c78ed03041a18f75df104c45405", size = 13283998, upload-time = "2025-12-18T19:29:06.994Z" }, - { url = "https://files.pythonhosted.org/packages/12/84/534a5506f4074e5cc0529e5cd96cfc01bb480e460c7edf5af70d2bcae55e/ruff-0.14.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c70427132db492d25f982fffc8d6c7535cc2fd2c83fc8888f05caaa248521e60", size = 13601891, upload-time = "2025-12-18T19:28:55.811Z" }, - { url = "https://files.pythonhosted.org/packages/0d/1e/14c916087d8598917dbad9b2921d340f7884824ad6e9c55de948a93b106d/ruff-0.14.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5bcf45b681e9f1ee6445d317ce1fa9d6cba9a6049542d1c3d5b5958986be8830", size = 14336660, upload-time = "2025-12-18T19:29:16.531Z" }, - { url = "https://files.pythonhosted.org/packages/f2/1c/d7b67ab43f30013b47c12b42d1acd354c195351a3f7a1d67f59e54227ede/ruff-0.14.10-py3-none-win32.whl", hash = "sha256:104c49fc7ab73f3f3a758039adea978869a918f31b73280db175b43a2d9b51d6", size = 13196187, upload-time = "2025-12-18T19:29:19.006Z" }, - { url = "https://files.pythonhosted.org/packages/fb/9c/896c862e13886fae2af961bef3e6312db9ebc6adc2b156fe95e615dee8c1/ruff-0.14.10-py3-none-win_amd64.whl", hash = "sha256:466297bd73638c6bdf06485683e812db1c00c7ac96d4ddd0294a338c62fdc154", size = 14661283, upload-time = "2025-12-18T19:29:30.16Z" }, - { url = "https://files.pythonhosted.org/packages/74/31/b0e29d572670dca3674eeee78e418f20bdf97fa8aa9ea71380885e175ca0/ruff-0.14.10-py3-none-win_arm64.whl", hash = "sha256:e51d046cf6dda98a4633b8a8a771451107413b0f07183b2bef03f075599e44e6", size = 13729839, upload-time = "2025-12-18T19:28:48.636Z" }, +version = "0.14.11" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d4/77/9a7fe084d268f8855d493e5031ea03fa0af8cc05887f638bf1c4e3363eb8/ruff-0.14.11.tar.gz", hash = "sha256:f6dc463bfa5c07a59b1ff2c3b9767373e541346ea105503b4c0369c520a66958", size = 5993417, upload-time = "2026-01-08T19:11:58.322Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f0/a6/a4c40a5aaa7e331f245d2dc1ac8ece306681f52b636b40ef87c88b9f7afd/ruff-0.14.11-py3-none-linux_armv6l.whl", hash = "sha256:f6ff2d95cbd335841a7217bdfd9c1d2e44eac2c584197ab1385579d55ff8830e", size = 12951208, upload-time = "2026-01-08T19:12:09.218Z" }, + { url = "https://files.pythonhosted.org/packages/5c/5c/360a35cb7204b328b685d3129c08aca24765ff92b5a7efedbdd6c150d555/ruff-0.14.11-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:6f6eb5c1c8033680f4172ea9c8d3706c156223010b8b97b05e82c59bdc774ee6", size = 13330075, upload-time = "2026-01-08T19:12:02.549Z" }, + { url = "https://files.pythonhosted.org/packages/1b/9e/0cc2f1be7a7d33cae541824cf3f95b4ff40d03557b575912b5b70273c9ec/ruff-0.14.11-py3-none-macosx_11_0_arm64.whl", hash = "sha256:f2fc34cc896f90080fca01259f96c566f74069a04b25b6205d55379d12a6855e", size = 12257809, upload-time = "2026-01-08T19:12:00.366Z" }, + { url = "https://files.pythonhosted.org/packages/a7/e5/5faab97c15bb75228d9f74637e775d26ac703cc2b4898564c01ab3637c02/ruff-0.14.11-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:53386375001773ae812b43205d6064dae49ff0968774e6befe16a994fc233caa", size = 12678447, upload-time = "2026-01-08T19:12:13.899Z" }, + { url = "https://files.pythonhosted.org/packages/1b/33/e9767f60a2bef779fb5855cab0af76c488e0ce90f7bb7b8a45c8a2ba4178/ruff-0.14.11-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a697737dce1ca97a0a55b5ff0434ee7205943d4874d638fe3ae66166ff46edbe", size = 12758560, upload-time = "2026-01-08T19:11:42.55Z" }, + { url = "https://files.pythonhosted.org/packages/eb/84/4c6cf627a21462bb5102f7be2a320b084228ff26e105510cd2255ea868e5/ruff-0.14.11-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6845ca1da8ab81ab1dce755a32ad13f1db72e7fba27c486d5d90d65e04d17b8f", size = 13599296, upload-time = "2026-01-08T19:11:30.371Z" }, + { url = "https://files.pythonhosted.org/packages/88/e1/92b5ed7ea66d849f6157e695dc23d5d6d982bd6aa8d077895652c38a7cae/ruff-0.14.11-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:e36ce2fd31b54065ec6f76cb08d60159e1b32bdf08507862e32f47e6dde8bcbf", size = 15048981, upload-time = "2026-01-08T19:12:04.742Z" }, + { url = "https://files.pythonhosted.org/packages/61/df/c1bd30992615ac17c2fb64b8a7376ca22c04a70555b5d05b8f717163cf9f/ruff-0.14.11-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:590bcc0e2097ecf74e62a5c10a6b71f008ad82eb97b0a0079e85defe19fe74d9", size = 14633183, upload-time = "2026-01-08T19:11:40.069Z" }, + { url = "https://files.pythonhosted.org/packages/04/e9/fe552902f25013dd28a5428a42347d9ad20c4b534834a325a28305747d64/ruff-0.14.11-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:53fe71125fc158210d57fe4da26e622c9c294022988d08d9347ec1cf782adafe", size = 14050453, upload-time = "2026-01-08T19:11:37.555Z" }, + { url = "https://files.pythonhosted.org/packages/ae/93/f36d89fa021543187f98991609ce6e47e24f35f008dfe1af01379d248a41/ruff-0.14.11-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a35c9da08562f1598ded8470fcfef2afb5cf881996e6c0a502ceb61f4bc9c8a3", size = 13757889, upload-time = "2026-01-08T19:12:07.094Z" }, + { url = "https://files.pythonhosted.org/packages/b7/9f/c7fb6ecf554f28709a6a1f2a7f74750d400979e8cd47ed29feeaa1bd4db8/ruff-0.14.11-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:0f3727189a52179393ecf92ec7057c2210203e6af2676f08d92140d3e1ee72c1", size = 13955832, upload-time = "2026-01-08T19:11:55.064Z" }, + { url = "https://files.pythonhosted.org/packages/db/a0/153315310f250f76900a98278cf878c64dfb6d044e184491dd3289796734/ruff-0.14.11-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:eb09f849bd37147a789b85995ff734a6c4a095bed5fd1608c4f56afc3634cde2", size = 12586522, upload-time = "2026-01-08T19:11:35.356Z" }, + { url = "https://files.pythonhosted.org/packages/2f/2b/a73a2b6e6d2df1d74bf2b78098be1572191e54bec0e59e29382d13c3adc5/ruff-0.14.11-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:c61782543c1231bf71041461c1f28c64b961d457d0f238ac388e2ab173d7ecb7", size = 12724637, upload-time = "2026-01-08T19:11:47.796Z" }, + { url = "https://files.pythonhosted.org/packages/f0/41/09100590320394401cd3c48fc718a8ba71c7ddb1ffd07e0ad6576b3a3df2/ruff-0.14.11-py3-none-musllinux_1_2_i686.whl", hash = "sha256:82ff352ea68fb6766140381748e1f67f83c39860b6446966cff48a315c3e2491", size = 13145837, upload-time = "2026-01-08T19:11:32.87Z" }, + { url = "https://files.pythonhosted.org/packages/3b/d8/e035db859d1d3edf909381eb8ff3e89a672d6572e9454093538fe6f164b0/ruff-0.14.11-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:728e56879df4ca5b62a9dde2dd0eb0edda2a55160c0ea28c4025f18c03f86984", size = 13850469, upload-time = "2026-01-08T19:12:11.694Z" }, + { url = "https://files.pythonhosted.org/packages/4e/02/bb3ff8b6e6d02ce9e3740f4c17dfbbfb55f34c789c139e9cd91985f356c7/ruff-0.14.11-py3-none-win32.whl", hash = "sha256:337c5dd11f16ee52ae217757d9b82a26400be7efac883e9e852646f1557ed841", size = 12851094, upload-time = "2026-01-08T19:11:45.163Z" }, + { url = "https://files.pythonhosted.org/packages/58/f1/90ddc533918d3a2ad628bc3044cdfc094949e6d4b929220c3f0eb8a1c998/ruff-0.14.11-py3-none-win_amd64.whl", hash = "sha256:f981cea63d08456b2c070e64b79cb62f951aa1305282974d4d5216e6e0178ae6", size = 14001379, upload-time = "2026-01-08T19:11:52.591Z" }, + { url = "https://files.pythonhosted.org/packages/c4/1c/1dbe51782c0e1e9cfce1d1004752672d2d4629ea46945d19d731ad772b3b/ruff-0.14.11-py3-none-win_arm64.whl", hash = "sha256:649fb6c9edd7f751db276ef42df1f3df41c38d67d199570ae2a7bd6cbc3590f0", size = 12938644, upload-time = "2026-01-08T19:11:50.027Z" }, ] [[package]] @@ -6764,14 +6764,14 @@ wheels = [ [[package]] name = "werkzeug" -version = "3.1.4" +version = "3.1.5" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "markupsafe", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/45/ea/b0f8eeb287f8df9066e56e831c7824ac6bab645dd6c7a8f4b2d767944f9b/werkzeug-3.1.4.tar.gz", hash = "sha256:cd3cd98b1b92dc3b7b3995038826c68097dcb16f9baa63abe35f20eafeb9fe5e", size = 864687, upload-time = "2025-11-29T02:15:22.841Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5a/70/1469ef1d3542ae7c2c7b72bd5e3a4e6ee69d7978fa8a3af05a38eca5becf/werkzeug-3.1.5.tar.gz", hash = "sha256:6a548b0e88955dd07ccb25539d7d0cc97417ee9e179677d22c7041c8f078ce67", size = 864754, upload-time = "2026-01-08T17:49:23.247Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2f/f9/9e082990c2585c744734f85bec79b5dae5df9c974ffee58fe421652c8e91/werkzeug-3.1.4-py3-none-any.whl", hash = "sha256:2ad50fb9ed09cc3af22c54698351027ace879a0b60a3b5edf5730b2f7d876905", size = 224960, upload-time = "2025-11-29T02:15:21.13Z" }, + { url = "https://files.pythonhosted.org/packages/ad/e4/8d97cca767bcc1be76d16fb76951608305561c6e056811587f36cb1316a8/werkzeug-3.1.5-py3-none-any.whl", hash = "sha256:5111e36e91086ece91f93268bb39b4a35c1e6f1feac762c9c822ded0a4e322dc", size = 225025, upload-time = "2026-01-08T17:49:21.859Z" }, ] [[package]]