diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py index dfa64e9bdb..d9a197df9e 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py @@ -44,7 +44,32 @@ def _sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: confirm_changes_call = content break - sanitized.append(msg) + # Filter out confirm_changes from assistant messages before sending to LLM. + # confirm_changes is a synthetic tool for the approval UI flow - the LLM shouldn't + # see it because it may contain stale function_arguments that confuse the model + # (e.g., showing 5 steps when only 2 were approved). + # When we filter out confirm_changes, we also remove it from tool_ids and don't + # set pending_confirm_changes_id, so no synthetic result is injected for it. + # This is required because OpenAI validates that every tool result has a matching + # tool call in the previous assistant message. 
+ if confirm_changes_call: + filtered_contents = [ + c for c in (msg.contents or []) if not (c.type == "function_call" and c.name == "confirm_changes") + ] + if filtered_contents: + # Create a new message without confirm_changes to avoid mutating the input + filtered_msg = ChatMessage(role=msg.role, contents=filtered_contents) + sanitized.append(filtered_msg) + # If no contents left after filtering, don't append anything + + # Remove confirm_changes from tool_ids since we filtered it from the message + if confirm_changes_call.call_id: + tool_ids.discard(str(confirm_changes_call.call_id)) + # Don't set pending_confirm_changes_id - we don't want a synthetic result + confirm_changes_call = None + else: + sanitized.append(msg) + pending_tool_call_ids = tool_ids if tool_ids else None pending_confirm_changes_id = ( str(confirm_changes_call.call_id) if confirm_changes_call and confirm_changes_call.call_id else None @@ -66,7 +91,7 @@ def _sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: if approval_call_ids and pending_tool_call_ids: pending_tool_call_ids -= approval_call_ids logger.info( - f"FunctionApprovalResponseContent found for call_ids={sorted(approval_call_ids)} - " + f"function_approval_response content found for call_ids={sorted(approval_call_ids)} - " "framework will handle execution" ) @@ -93,6 +118,8 @@ def _sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: user_text = content.text # type: ignore[assignment] break + if not user_text: + continue try: parsed = json.loads(user_text) # type: ignore[arg-type] if "accepted" in parsed: @@ -149,6 +176,10 @@ def _sanitize_tool_history(messages: list[ChatMessage]) -> list[ChatMessage]: call_id = str(content.call_id) if call_id in pending_tool_call_ids: keep = True + # Remove the call_id from pending since we now have its result. + # This prevents duplicate synthetic "skipped" results from being + # injected when a user message arrives later. 
+ pending_tool_call_ids.discard(call_id) if call_id == pending_confirm_changes_id: pending_confirm_changes_id = None break @@ -337,7 +368,7 @@ def _filter_modified_args( result: list[ChatMessage] = [] for msg in messages: # Handle standard tool result messages early (role="tool") to preserve provider invariants - # This path maps AG‑UI tool messages to FunctionResultContent with the correct tool_call_id + # This path maps AG‑UI tool messages to function_result content with the correct tool_call_id role_str = normalize_agui_role(msg.get("role", "user")) if role_str == "tool": # Prefer explicit tool_call_id fields; fall back to backend fields only if necessary @@ -370,7 +401,7 @@ def _filter_modified_args( if is_approval: # Look for the matching function call in previous messages to create - # a proper FunctionApprovalResponseContent. This enables the agent framework + # proper function_approval_response content. This enables the agent framework # to execute the approved tool (fix for GitHub issue #3034). accepted = parsed.get("accepted", False) if parsed is not None else False approval_payload_text = result_content if isinstance(result_content, str) else json.dumps(parsed) @@ -447,11 +478,17 @@ def _filter_modified_args( merged_args["steps"] = merged_steps state_args = merged_args - # Keep the original tool call and AG-UI snapshot in sync with approved args. - updated_args = ( - json.dumps(merged_args) if isinstance(matching_func_call.arguments, str) else merged_args + # Update the ChatMessage tool call with only enabled steps (for LLM context). + # The LLM should only see the steps that were actually approved/executed. + updated_args_for_llm = ( + json.dumps(filtered_args) + if isinstance(matching_func_call.arguments, str) + else filtered_args ) - matching_func_call.arguments = updated_args + matching_func_call.arguments = updated_args_for_llm + + # Update raw messages with all steps + status (for MESSAGES_SNAPSHOT display). 
+ # This allows the UI to show which steps were enabled/disabled. _update_tool_call_arguments(messages, str(approval_call_id), merged_args) # Create a new FunctionCallContent with the modified arguments func_call_for_approval = Content.from_function_call( @@ -464,7 +501,7 @@ def _filter_modified_args( # No modified arguments - use the original function call func_call_for_approval = matching_func_call - # Create FunctionApprovalResponseContent for the agent framework + # Create function_approval_response content for the agent framework approval_response = Content.from_function_approval_response( approved=accepted, id=str(approval_call_id), @@ -488,7 +525,7 @@ def _filter_modified_args( result.append(chat_msg) continue - # Cast result_content to acceptable type for FunctionResultContent + # Cast result_content to acceptable type for function_result content func_result: str | dict[str, Any] | list[Any] if isinstance(result_content, str): func_result = result_content @@ -565,7 +602,7 @@ def _filter_modified_args( # Check if this message contains function approvals if "function_approvals" in msg and msg["function_approvals"]: - # Convert function approvals to FunctionApprovalResponseContent + # Convert function approvals to function_approval_response content approval_contents: list[Any] = [] for approval in msg["function_approvals"]: # Create FunctionCallContent with the modified arguments diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_run.py index 7cd9e0c686..c6faf8fb9e 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_run.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_run.py @@ -45,6 +45,7 @@ convert_agui_tools_to_agent_framework, generate_event_id, get_conversation_id_from_update, + get_role_value, make_json_safe, ) @@ -344,7 +345,7 @@ def _emit_tool_result( flow: FlowState, predictive_handler: PredictiveStateHandler | None = None, ) -> list[BaseEvent]: - """Emit ToolCallResult events for 
FunctionResultContent.""" + """Emit ToolCallResult events for function_result content.""" events: list[BaseEvent] = [] # Cannot emit tool result without a call_id to associate it with @@ -385,6 +386,13 @@ def _emit_tool_result( # After tool result, any subsequent text should start a new message flow.tool_call_id = None flow.tool_call_name = None + + # Close any open text message before resetting message_id (issue #3568) + # This handles the case where a TextMessageStartEvent was emitted for tool-only + # messages (Feature #4) but needs to be closed before starting a new message + if flow.message_id: + logger.debug("Closing text message (issue #3568 fix): message_id=%s", flow.message_id) + events.append(TextMessageEndEvent(message_id=flow.message_id)) flow.message_id = None # Reset so next text content starts a new message return events @@ -454,9 +462,21 @@ def _emit_approval_request( "function_arguments": make_json_safe(func_call.parse_arguments()) or {}, "steps": [{"description": f"Execute {func_name}", "status": "enabled"}], } - events.append(ToolCallArgsEvent(tool_call_id=confirm_id, delta=json.dumps(args))) + args_json = json.dumps(args) + events.append(ToolCallArgsEvent(tool_call_id=confirm_id, delta=args_json)) events.append(ToolCallEndEvent(tool_call_id=confirm_id)) + # Track confirm_changes in pending_tool_calls for MessagesSnapshotEvent + # The frontend needs to see this in the snapshot to render the confirmation dialog + confirm_entry = { + "id": confirm_id, + "type": "function", + "function": {"name": "confirm_changes", "arguments": args_json}, + } + flow.pending_tool_calls.append(confirm_entry) + flow.tool_calls_by_id[confirm_id] = confirm_entry + flow.tool_calls_ended.add(confirm_id) # Mark as ended since we emit End event + flow.waiting_for_approval = True return events @@ -496,7 +516,7 @@ def _is_confirm_changes_response(messages: list[Any]) -> bool: # Parse the content to check if it has the confirm_changes structure for content in last.contents: - 
if getattr(content, "type", None) == "text": + if getattr(content, "type", None) == "text" and content.text: try: result = json.loads(content.text) # confirm_changes results have 'accepted' and 'steps' keys @@ -516,31 +536,34 @@ def _handle_step_based_approval(messages: list[Any]) -> list[BaseEvent]: # Parse the approval content approval_text = "" for content in last.contents: - if getattr(content, "type", None) == "text": + if getattr(content, "type", None) == "text" and content.text: approval_text = content.text break - try: - result = json.loads(approval_text) - accepted = result.get("accepted", False) - steps = result.get("steps", []) - - if accepted: - # Generate acceptance message with step descriptions - enabled_steps = [s for s in steps if s.get("status") == "enabled"] - if enabled_steps: - message_parts = [f"Executing {len(enabled_steps)} approved steps:\n\n"] - for i, step in enumerate(enabled_steps, 1): - message_parts.append(f"{i}. {step.get('description', 'Step')}\n") - message_parts.append("\nAll steps completed successfully!") - message = "".join(message_parts) - else: - message = "Changes confirmed and applied successfully!" - else: - # Rejection message - message = "No problem! What would you like me to change about the plan?" - except json.JSONDecodeError: + if not approval_text: message = "Acknowledged." + else: + try: + result = json.loads(approval_text) + accepted = result.get("accepted", False) + steps = result.get("steps", []) + + if accepted: + # Generate acceptance message with step descriptions + enabled_steps = [s for s in steps if s.get("status") == "enabled"] + if enabled_steps: + message_parts = [f"Executing {len(enabled_steps)} approved steps:\n\n"] + for i, step in enumerate(enabled_steps, 1): + message_parts.append(f"{i}. {step.get('description', 'Step')}\n") + message_parts.append("\nAll steps completed successfully!") + message = "".join(message_parts) + else: + message = "Changes confirmed and applied successfully!" 
+ else: + # Rejection message + message = "No problem! What would you like me to change about the plan?" + except json.JSONDecodeError: + message = "Acknowledged." message_id = generate_event_id() events.append(TextMessageStartEvent(message_id=message_id, role="assistant")) @@ -558,8 +581,8 @@ async def _resolve_approval_responses( ) -> None: """Execute approved function calls and replace approval content with results. - This modifies the messages list in place, replacing FunctionApprovalResponseContent - with FunctionResultContent containing the actual tool execution result. + This modifies the messages list in place, replacing function_approval_response + content with function_result content containing the actual tool execution result. Args: messages: List of messages (will be modified in place) @@ -622,6 +645,53 @@ async def _resolve_approval_responses( _replace_approval_contents_with_results(messages, fcc_todo, normalized_results) # type: ignore + # Post-process: Convert user messages with function_result content to proper tool messages. + # After _replace_approval_contents_with_results, approved tool calls have their results + # placed in user messages. OpenAI requires tool results to be in role="tool" messages. + # This transformation ensures the message history is valid for the LLM provider. + _convert_approval_results_to_tool_messages(messages) + + +def _convert_approval_results_to_tool_messages(messages: list[Any]) -> None: + """Convert function_result content in user messages to proper tool messages. + + After approval processing, tool results end up in user messages. OpenAI and other + providers require tool results to be in role="tool" messages. This function + extracts function_result content from user messages and creates proper tool messages. + + This modifies the messages list in place. 
+ + Args: + messages: List of ChatMessage objects to process + """ + result: list[Any] = [] + + for msg in messages: + if get_role_value(msg) != "user": + result.append(msg) + continue + + function_results = [c for c in (msg.contents or []) if getattr(c, "type", None) == "function_result"] + other_contents = [c for c in (msg.contents or []) if getattr(c, "type", None) != "function_result"] + + if not function_results: + result.append(msg) + continue + + logger.info( + f"Converting {len(function_results)} function_result content(s) from user message to tool message(s)" + ) + + # Tool messages first (right after the preceding assistant message per OpenAI requirements) + for func_result in function_results: + result.append(ChatMessage(role="tool", contents=[func_result])) + + # Then user message with remaining content (if any) + if other_contents: + result.append(ChatMessage(role=msg.role, contents=other_contents)) + + messages[:] = result + def _build_messages_snapshot( flow: FlowState, @@ -630,25 +700,29 @@ def _build_messages_snapshot( """Build MessagesSnapshotEvent from current flow state.""" all_messages = list(snapshot_messages) - # Add assistant message with tool calls + # Add assistant message with tool calls only (no content) if flow.pending_tool_calls: tool_call_message = { "id": flow.message_id or generate_event_id(), "role": "assistant", "tool_calls": flow.pending_tool_calls.copy(), } - if flow.accumulated_text: - tool_call_message["content"] = flow.accumulated_text all_messages.append(tool_call_message) # Add tool results all_messages.extend(flow.tool_results) - # Add text-only assistant message if no tool calls - if flow.accumulated_text and not flow.pending_tool_calls: + # Add text-only assistant message if there is accumulated text + # This is a separate message from the tool calls message to maintain + # the expected AG-UI protocol format (see issue #3619) + if flow.accumulated_text: + # Use a new ID for the content message if we had tool calls 
(separate message) + content_message_id = ( + generate_event_id() if flow.pending_tool_calls else (flow.message_id or generate_event_id()) + ) all_messages.append( { - "id": flow.message_id or generate_event_id(), + "id": content_message_id, "role": "assistant", "content": flow.accumulated_text, } @@ -827,6 +901,8 @@ async def run_agent_stream( # Emit events for each content item for content in update.contents: + content_type = getattr(content, "type", None) + logger.debug(f"Processing content type={content_type}, message_id={flow.message_id}") for event in _emit_content( content, flow, @@ -922,6 +998,20 @@ async def run_agent_stream( tool_call_id, ) + # Parse function arguments - skip confirm_changes if we can't parse + # (we can't ask user to confirm something we can't properly display) + try: + function_arguments = json.loads(tool_call.get("function", {}).get("arguments", "{}")) + except json.JSONDecodeError: + logger.warning( + "Failed to decode JSON arguments for confirm_changes tool '%s' " + "(tool_call_id=%s). 
Skipping confirmation flow - cannot display " + "malformed arguments to user for approval.", + tool_name, + tool_call_id, + ) + continue # Skip to next tool call without emitting confirm_changes + # Emit confirm_changes tool call confirm_id = generate_event_id() yield ToolCallStartEvent( @@ -932,15 +1022,28 @@ async def run_agent_stream( confirm_args = { "function_name": tool_name, "function_call_id": tool_call_id, - "function_arguments": json.loads(tool_call.get("function", {}).get("arguments", "{}")), + "function_arguments": function_arguments, "steps": [{"description": f"Execute {tool_name}", "status": "enabled"}], } - yield ToolCallArgsEvent(tool_call_id=confirm_id, delta=json.dumps(confirm_args)) + confirm_args_json = json.dumps(confirm_args) + yield ToolCallArgsEvent(tool_call_id=confirm_id, delta=confirm_args_json) yield ToolCallEndEvent(tool_call_id=confirm_id) + + # Track confirm_changes in pending_tool_calls for MessagesSnapshotEvent + # The frontend needs to see this in the snapshot to render the confirmation dialog + confirm_entry = { + "id": confirm_id, + "type": "function", + "function": {"name": "confirm_changes", "arguments": confirm_args_json}, + } + flow.pending_tool_calls.append(confirm_entry) + flow.tool_calls_by_id[confirm_id] = confirm_entry + flow.tool_calls_ended.add(confirm_id) # Mark as ended since we emit End event flow.waiting_for_approval = True # Close any open message if flow.message_id: + logger.debug(f"End of run: closing text message message_id={flow.message_id}") yield TextMessageEndEvent(message_id=flow.message_id) # Emit MessagesSnapshotEvent if we have tool calls or results diff --git a/python/packages/ag-ui/tests/test_message_adapters.py b/python/packages/ag-ui/tests/test_message_adapters.py index 85fe778e09..b2461d5bab 100644 --- a/python/packages/ag-ui/tests/test_message_adapters.py +++ b/python/packages/ag-ui/tests/test_message_adapters.py @@ -98,7 +98,14 @@ def test_agui_tool_result_to_agent_framework(): def 
test_agui_tool_approval_updates_tool_call_arguments(): - """Tool approval updates matching tool call arguments for snapshots and agent context.""" + """Tool approval updates matching tool call arguments for snapshots and agent context. + + The LLM context (ChatMessage) should contain only enabled steps, so the LLM + generates responses based on what was actually approved/executed. + + The raw messages (for MESSAGES_SNAPSHOT) should contain all steps with status, + so the UI can show which steps were enabled/disabled. + """ messages_input = [ { "role": "assistant", @@ -142,13 +149,14 @@ def test_agui_tool_approval_updates_tool_call_arguments(): assert len(messages) == 2 assistant_msg = messages[0] func_call = next(content for content in assistant_msg.contents if content.type == "function_call") + # LLM context should only have enabled steps (what was actually approved) assert func_call.arguments == { "steps": [ {"description": "Boil water", "status": "enabled"}, - {"description": "Brew coffee", "status": "disabled"}, {"description": "Serve coffee", "status": "enabled"}, ] } + # Raw messages (for MESSAGES_SNAPSHOT) should have all steps with status assert messages_input[0]["tool_calls"][0]["function"]["arguments"] == { "steps": [ {"description": "Boil water", "status": "enabled"}, diff --git a/python/packages/ag-ui/tests/test_message_hygiene.py b/python/packages/ag-ui/tests/test_message_hygiene.py index 03c8a1b9b3..42e098e4f6 100644 --- a/python/packages/ag-ui/tests/test_message_hygiene.py +++ b/python/packages/ag-ui/tests/test_message_hygiene.py @@ -5,7 +5,13 @@ from agent_framework_ag_ui._message_adapters import _deduplicate_messages, _sanitize_tool_history -def test_sanitize_tool_history_injects_confirm_changes_result() -> None: +def test_sanitize_tool_history_filters_out_confirm_changes_only_message() -> None: + """Test that assistant messages with ONLY confirm_changes are filtered out entirely. 
+ + When an assistant message contains only a confirm_changes tool call (no other tools), + the entire message should be filtered out because confirm_changes is a synthetic + tool for the approval UI flow that shouldn't be sent to the LLM. + """ messages = [ ChatMessage( role="assistant", @@ -25,10 +31,17 @@ def test_sanitize_tool_history_injects_confirm_changes_result() -> None: sanitized = _sanitize_tool_history(messages) - tool_messages = [msg for msg in sanitized if (msg.role if hasattr(msg.role, "value") else str(msg.role)) == "tool"] - assert len(tool_messages) == 1 - assert str(tool_messages[0].contents[0].call_id) == "call_confirm_123" - assert tool_messages[0].contents[0].result == "Confirmed" + # Assistant message with only confirm_changes should be filtered out + assistant_messages = [ + msg for msg in sanitized if (msg.role.value if hasattr(msg.role, "value") else str(msg.role)) == "assistant" + ] + assert len(assistant_messages) == 0 + + # No synthetic tool result should be injected since confirm_changes was filtered out + tool_messages = [ + msg for msg in sanitized if (msg.role.value if hasattr(msg.role, "value") else str(msg.role)) == "tool" + ] + assert len(tool_messages) == 0 def test_deduplicate_messages_prefers_non_empty_tool_results() -> None: @@ -46,3 +59,212 @@ def test_deduplicate_messages_prefers_non_empty_tool_results() -> None: deduped = _deduplicate_messages(messages) assert len(deduped) == 1 assert deduped[0].contents[0].result == "result data" + + +def test_convert_approval_results_to_tool_messages() -> None: + """Test that function_result content in user messages gets converted to tool messages. + + This is a regression test for the MCP tool double-call bug where approved tool + results ended up in user messages instead of tool messages, causing OpenAI to + reject the request with 'tool_call_ids did not have response messages'. 
+ """ + from agent_framework_ag_ui._run import _convert_approval_results_to_tool_messages + + # Simulate what happens after _resolve_approval_responses: + # A user message contains function_result content (the executed tool result) + messages = [ + ChatMessage( + role="assistant", + contents=[ + Content.from_function_call(call_id="call_123", name="my_mcp_tool", arguments="{}"), + ], + ), + ChatMessage( + role="user", + contents=[ + Content.from_function_result(call_id="call_123", result="tool execution result"), + ], + ), + ] + + _convert_approval_results_to_tool_messages(messages) + + # After conversion, the function result should be in a tool message, not user message + assert len(messages) == 2 + + # First message unchanged + assert messages[0].role == "assistant" + + # Second message should now be role="tool" + assert messages[1].role == "tool" + assert messages[1].contents[0].type == "function_result" + assert messages[1].contents[0].call_id == "call_123" + + +def test_convert_approval_results_preserves_other_user_content() -> None: + """Test that user messages with mixed content are handled correctly. + + If a user message has both function_result content and other content (like text), + the function_result content should be extracted to a tool message while the + remaining content stays in the user message. 
+ """ + from agent_framework_ag_ui._run import _convert_approval_results_to_tool_messages + + messages = [ + ChatMessage( + role="assistant", + contents=[ + Content.from_function_call(call_id="call_123", name="my_tool", arguments="{}"), + ], + ), + ChatMessage( + role="user", + contents=[ + Content.from_text(text="User also said something"), + Content.from_function_result(call_id="call_123", result="tool result"), + ], + ), + ] + + _convert_approval_results_to_tool_messages(messages) + + # Should have 3 messages now: assistant, tool (with result), user (with text) + # OpenAI requires tool messages immediately after the assistant message with the tool call + assert len(messages) == 3 + + # First message unchanged + assert messages[0].role == "assistant" + + # Second message should be tool with result (must come right after assistant per OpenAI requirements) + assert messages[1].role == "tool" + assert messages[1].contents[0].type == "function_result" + + # Third message should be user with just text + assert messages[2].role == "user" + assert len(messages[2].contents) == 1 + assert messages[2].contents[0].type == "text" + + +def test_sanitize_tool_history_filters_confirm_changes_keeps_other_tools() -> None: + """Test that confirm_changes is filtered but other tools are preserved. + + When an assistant message contains both a real tool call and confirm_changes, + confirm_changes should be filtered out while the real tool call is kept. + No synthetic result is injected for confirm_changes since it's filtered. 
+ """ + messages = [ + # User asks something + ChatMessage( + role="user", + contents=[Content.from_text(text="What time is it?")], + ), + # Assistant calls MCP tool + confirm_changes + ChatMessage( + role="assistant", + contents=[ + Content.from_function_call(call_id="call_1", name="get_datetime", arguments="{}"), + Content.from_function_call(call_id="call_c1", name="confirm_changes", arguments="{}"), + ], + ), + # Tool result for the actual MCP tool + ChatMessage( + role="tool", + contents=[Content.from_function_result(call_id="call_1", result="2024-01-01 12:00:00")], + ), + # User asks something else + ChatMessage( + role="user", + contents=[Content.from_text(text="What's the date?")], + ), + ] + + sanitized = _sanitize_tool_history(messages) + + # Find the assistant message + assistant_messages = [ + msg for msg in sanitized if (msg.role.value if hasattr(msg.role, "value") else str(msg.role)) == "assistant" + ] + assert len(assistant_messages) == 1 + + # Assistant message should only have get_datetime, not confirm_changes + function_call_names = [c.name for c in assistant_messages[0].contents if c.type == "function_call"] + assert "get_datetime" in function_call_names + assert "confirm_changes" not in function_call_names + + # Only one tool message (for call_1), no synthetic for confirm_changes + tool_messages = [ + msg for msg in sanitized if (msg.role.value if hasattr(msg.role, "value") else str(msg.role)) == "tool" + ] + assert len(tool_messages) == 1 + assert str(tool_messages[0].contents[0].call_id) == "call_1" + + +def test_sanitize_tool_history_filters_confirm_changes_from_assistant_messages() -> None: + """Test that confirm_changes is removed from assistant messages sent to LLM. 
+ + This is a regression test for the human-in-the-loop bug where the LLM would see + confirm_changes with function_arguments containing the original steps (e.g., 5 steps) + even when the user only approved a subset (e.g., 2 steps), causing the LLM to + respond with "Here's your 5-step plan" instead of "Here's your 2-step plan". + """ + messages = [ + ChatMessage( + role="user", + contents=[Content.from_text(text="Build a robot")], + ), + # Assistant message with both generate_task_steps and confirm_changes + ChatMessage( + role="assistant", + contents=[ + Content.from_function_call( + call_id="call_1", + name="generate_task_steps", + arguments='{"steps": [{"description": "Step 1"}, {"description": "Step 2"}]}', + ), + Content.from_function_call( + call_id="call_c1", + name="confirm_changes", + arguments='{"function_arguments": {"steps": [{"description": "Step 1"}, {"description": "Step 2"}]}}', + ), + ], + ), + # Approval response + ChatMessage( + role="user", + contents=[ + Content.from_function_approval_response( + approved=True, + id="call_1", + function_call=Content.from_function_call( + call_id="call_1", + name="generate_task_steps", + arguments='{"steps": [{"description": "Step 1"}]}', # Only 1 step approved + ), + ), + ], + ), + ] + + sanitized = _sanitize_tool_history(messages) + + # Find the assistant message in sanitized output + assistant_messages = [ + msg for msg in sanitized if (msg.role.value if hasattr(msg.role, "value") else str(msg.role)) == "assistant" + ] + + assert len(assistant_messages) == 1 + + # The assistant message should NOT contain confirm_changes + assistant_contents = assistant_messages[0].contents or [] + function_call_names = [c.name for c in assistant_contents if c.type == "function_call"] + assert "generate_task_steps" in function_call_names + assert "confirm_changes" not in function_call_names + + # No synthetic tool result for confirm_changes (it was filtered from the message) + tool_messages = [ + msg for msg in sanitized if 
(msg.role.value if hasattr(msg.role, "value") else str(msg.role)) == "tool" + ] + # No tool results expected since there are no completed tool calls + # (the approval response is handled separately by the framework) + tool_call_ids = {str(msg.contents[0].call_id) for msg in tool_messages} + assert "call_c1" not in tool_call_ids # No synthetic result for confirm_changes diff --git a/python/packages/ag-ui/tests/test_run.py b/python/packages/ag-ui/tests/test_run.py index 7fb7055ae0..a5bc700675 100644 --- a/python/packages/ag-ui/tests/test_run.py +++ b/python/packages/ag-ui/tests/test_run.py @@ -2,12 +2,18 @@ """Tests for _run.py helper functions and FlowState.""" +from ag_ui.core import ( + TextMessageEndEvent, + TextMessageStartEvent, +) from agent_framework import ChatMessage, Content from agent_framework_ag_ui._run import ( FlowState, _build_safe_metadata, _create_state_context_message, + _emit_content, + _emit_tool_result, _has_only_tool_calls, _inject_state_context, _should_suppress_intermediate_snapshot, @@ -351,6 +357,50 @@ def test_emit_tool_call_generates_id(): assert flow.tool_call_id is not None # ID should be generated +def test_emit_tool_result_closes_open_message(): + """Test _emit_tool_result emits TextMessageEndEvent for open text message. + + This is a regression test for issue #3568, where TEXT_MESSAGE_END was not + emitted when using MCP tools because the message_id was reset without + closing the message first. 
+ """ + flow = FlowState() + # Simulate an open text message (e.g., from Feature #4 tool-only detection) + flow.message_id = "open-msg-123" + flow.tool_call_id = "call_456" + + content = Content.from_function_result(call_id="call_456", result="tool result") + + events = _emit_tool_result(content, flow, predictive_handler=None) + + # Should have: ToolCallEndEvent, ToolCallResultEvent, TextMessageEndEvent + assert len(events) == 3 + + # Verify TextMessageEndEvent is emitted for the open message + text_end_events = [e for e in events if isinstance(e, TextMessageEndEvent)] + assert len(text_end_events) == 1 + assert text_end_events[0].message_id == "open-msg-123" + + # Verify message_id is reset after + assert flow.message_id is None + + +def test_emit_tool_result_no_open_message(): + """Test _emit_tool_result works when there's no open text message.""" + flow = FlowState() + # No open message + flow.message_id = None + flow.tool_call_id = "call_456" + + content = Content.from_function_result(call_id="call_456", result="tool result") + + events = _emit_tool_result(content, flow, predictive_handler=None) + + # Should have: ToolCallEndEvent, ToolCallResultEvent (no TextMessageEndEvent) + text_end_events = [e for e in events if isinstance(e, TextMessageEndEvent)] + assert len(text_end_events) == 0 + + def test_extract_approved_state_updates_no_handler(): """Test _extract_approved_state_updates returns empty with no handler.""" from agent_framework_ag_ui._run import _extract_approved_state_updates @@ -369,3 +419,268 @@ def test_extract_approved_state_updates_no_approval(): messages = [ChatMessage("user", [Content.from_text("Hello")])] result = _extract_approved_state_updates(messages, handler) assert result == {} + + +class TestBuildMessagesSnapshot: + """Tests for _build_messages_snapshot function.""" + + def test_tool_calls_and_text_are_separate_messages(self): + """Test that tool calls and text content are emitted as separate messages. 
+ + This is a regression test for issue #3619 where tool calls and content + were incorrectly merged into a single assistant message. + """ + from agent_framework_ag_ui._run import FlowState, _build_messages_snapshot + + flow = FlowState() + flow.message_id = "msg-123" + flow.pending_tool_calls = [ + {"id": "call_1", "function": {"name": "get_weather", "arguments": '{"city": "NYC"}'}}, + ] + flow.accumulated_text = "Here is the weather information." + flow.tool_results = [{"id": "result-1", "role": "tool", "content": '{"temp": 72}', "toolCallId": "call_1"}] + + result = _build_messages_snapshot(flow, []) + + # Should have 3 messages: tool call msg, tool result, text content msg + assert len(result.messages) == 3 + + # First message: assistant with tool calls only (no content) + assistant_tool_msg = result.messages[0] + assert assistant_tool_msg.role == "assistant" + assert assistant_tool_msg.tool_calls is not None + assert len(assistant_tool_msg.tool_calls) == 1 + assert assistant_tool_msg.content is None + + # Second message: tool result + tool_result_msg = result.messages[1] + assert tool_result_msg.role == "tool" + + # Third message: assistant with content only (no tool calls) + assistant_text_msg = result.messages[2] + assert assistant_text_msg.role == "assistant" + assert assistant_text_msg.content == "Here is the weather information." 
+ assert assistant_text_msg.tool_calls is None + + # The text message should have a different ID than the tool call message + assert assistant_text_msg.id != assistant_tool_msg.id + + def test_only_tool_calls_no_text(self): + """Test snapshot with only tool calls and no accumulated text.""" + from agent_framework_ag_ui._run import FlowState, _build_messages_snapshot + + flow = FlowState() + flow.message_id = "msg-123" + flow.pending_tool_calls = [ + {"id": "call_1", "function": {"name": "get_weather", "arguments": "{}"}}, + ] + flow.accumulated_text = "" + flow.tool_results = [] + + result = _build_messages_snapshot(flow, []) + + # Should have 1 message: tool call msg only + assert len(result.messages) == 1 + assert result.messages[0].role == "assistant" + assert result.messages[0].tool_calls is not None + assert result.messages[0].content is None + + def test_only_text_no_tool_calls(self): + """Test snapshot with only text and no tool calls.""" + from agent_framework_ag_ui._run import FlowState, _build_messages_snapshot + + flow = FlowState() + flow.message_id = "msg-123" + flow.pending_tool_calls = [] + flow.accumulated_text = "Hello world" + flow.tool_results = [] + + result = _build_messages_snapshot(flow, []) + + # Should have 1 message: text content msg only + assert len(result.messages) == 1 + assert result.messages[0].role == "assistant" + assert result.messages[0].content == "Hello world" + assert result.messages[0].tool_calls is None + # Should use the existing message_id + assert result.messages[0].id == "msg-123" + + def test_preserves_snapshot_messages(self): + """Test that existing snapshot messages are preserved.""" + from agent_framework_ag_ui._run import FlowState, _build_messages_snapshot + + flow = FlowState() + flow.pending_tool_calls = [] + flow.accumulated_text = "" + + existing_messages = [ + {"id": "user-1", "role": "user", "content": "Hello"}, + {"id": "assist-1", "role": "assistant", "content": "Hi there"}, + ] + + result = 
_build_messages_snapshot(flow, existing_messages) + + assert len(result.messages) == 2 + assert result.messages[0].id == "user-1" + assert result.messages[1].id == "assist-1" + + +def test_malformed_json_in_confirm_args_skips_confirmation(): + """Test that malformed JSON in tool arguments skips confirm_changes flow. + + This is a regression test to ensure that when tool arguments contain malformed + JSON, the code skips the confirmation flow entirely rather than crashing or + showing incomplete data to the user. + """ + import json + + # Simulate the parsing logic - malformed JSON should trigger skip + malformed_arguments = "{ invalid json }" + tool_call = {"function": {"name": "write_doc", "arguments": malformed_arguments}} + + # This is what the code should do - detect parsing failure and skip + should_skip_confirmation = False + try: + json.loads(tool_call.get("function", {}).get("arguments", "{}")) + except json.JSONDecodeError: + should_skip_confirmation = True + + # Should skip confirmation when JSON is malformed + assert should_skip_confirmation is True + + # Valid JSON should proceed with confirmation + valid_arguments = '{"content": "hello"}' + tool_call_valid = {"function": {"name": "write_doc", "arguments": valid_arguments}} + should_skip_confirmation = False + try: + function_arguments = json.loads(tool_call_valid.get("function", {}).get("arguments", "{}")) + except json.JSONDecodeError: + should_skip_confirmation = True + + assert should_skip_confirmation is False + assert function_arguments == {"content": "hello"} + + +class TestTextMessageEventBalancing: + """Tests for proper TEXT_MESSAGE_START/END event balancing. + + These tests verify that the streaming flow produces balanced pairs of + TextMessageStartEvent and TextMessageEndEvent, especially when tool + execution is involved. + """ + + def test_tool_only_flow_produces_balanced_events(self): + """Test that a tool-only response produces balanced TEXT_MESSAGE events. 
+ + This simulates the scenario where the LLM immediately calls a tool + without any initial text, then returns text after the tool result. + """ + flow = FlowState() + all_events: list = [] + + # Step 1: LLM outputs function_call only (no text) + func_call_content = Content.from_function_call( + call_id="call_weather", + name="get_weather", + arguments='{"city": "Seattle"}', + ) + + # Feature #4 check: this should trigger TextMessageStartEvent + contents = [func_call_content] + if not flow.message_id and _has_only_tool_calls(contents): + flow.message_id = "tool-msg-1" + all_events.append(TextMessageStartEvent(message_id=flow.message_id, role="assistant")) + + # Emit tool call events + all_events.extend(_emit_content(func_call_content, flow)) + + # Step 2: Tool executes and returns result + func_result_content = Content.from_function_result( + call_id="call_weather", + result='{"temp": 55, "conditions": "rainy"}', + ) + + # This should close the text message + all_events.extend(_emit_tool_result(func_result_content, flow)) + + # Verify message_id was reset + assert flow.message_id is None, "message_id should be reset after tool result" + + # Step 3: LLM outputs text response + text_content = Content.from_text("The weather in Seattle is 55°F and rainy.") + + # Since message_id is None, _emit_text should create a new one + for event in _emit_content(text_content, flow): + all_events.append(event) + + # Step 4: End of stream - emit final TextMessageEndEvent + if flow.message_id: + all_events.append(TextMessageEndEvent(message_id=flow.message_id)) + + # Verify event counts + start_events = [e for e in all_events if isinstance(e, TextMessageStartEvent)] + end_events = [e for e in all_events if isinstance(e, TextMessageEndEvent)] + + # Should have 2 TextMessageStartEvent and 2 TextMessageEndEvent + assert len(start_events) == 2, f"Expected 2 start events, got {len(start_events)}" + assert len(end_events) == 2, f"Expected 2 end events, got {len(end_events)}" + + # Verify 
order: first message should start and end before second starts + # Find indices + start_indices = [i for i, e in enumerate(all_events) if isinstance(e, TextMessageStartEvent)] + end_indices = [i for i, e in enumerate(all_events) if isinstance(e, TextMessageEndEvent)] + + # First end should come before second start + assert end_indices[0] < start_indices[1], ( + f"First TextMessageEndEvent (index {end_indices[0]}) should come " + f"before second TextMessageStartEvent (index {start_indices[1]})" + ) + + def test_text_then_tool_flow(self): + """Test flow where LLM outputs text first, then calls a tool. + + This simulates: "Let me check the weather..." -> tool call -> tool result -> "The weather is..." + """ + flow = FlowState() + all_events: list = [] + + # Step 1: LLM outputs text first + text1 = Content.from_text("Let me check the weather for you.") + all_events.extend(_emit_content(text1, flow)) + + # Verify message_id is set + assert flow.message_id is not None, "message_id should be set after text" + first_msg_id = flow.message_id + + # Step 2: LLM outputs function_call + func_call = Content.from_function_call( + call_id="call_1", + name="get_weather", + arguments="{}", + ) + all_events.extend(_emit_content(func_call, flow)) + + # Step 3: Tool result comes back + func_result = Content.from_function_result(call_id="call_1", result="sunny") + all_events.extend(_emit_tool_result(func_result, flow)) + + # Verify message_id was reset and first message was closed + assert flow.message_id is None + end_events_so_far = [e for e in all_events if isinstance(e, TextMessageEndEvent)] + assert len(end_events_so_far) == 1 + assert end_events_so_far[0].message_id == first_msg_id + + # Step 4: LLM outputs follow-up text + text2 = Content.from_text("The weather is sunny!") + all_events.extend(_emit_content(text2, flow)) + + # Step 5: End of stream + if flow.message_id: + all_events.append(TextMessageEndEvent(message_id=flow.message_id)) + + # Verify balance + start_events = [e 
for e in all_events if isinstance(e, TextMessageStartEvent)] + end_events = [e for e in all_events if isinstance(e, TextMessageEndEvent)] + + assert len(start_events) == 2 + assert len(end_events) == 2