diff --git a/python/CHANGELOG.md b/python/CHANGELOG.md index 0edfd01c42..5d41d5ddf6 100644 --- a/python/CHANGELOG.md +++ b/python/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [1.0.0b251117] - 2025-11-17 + +### Fixed + +- **agent-framework-ag-ui**: Fix ag-ui state handling issues ([#2289](https://github.com/microsoft/agent-framework/pull/2289)) + ## [1.0.0b251114] - 2025-11-14 ### Added diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_events.py b/python/packages/ag-ui/agent_framework_ag_ui/_events.py index e10f3e92c8..8aec59d52c 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_events.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_events.py @@ -86,6 +86,7 @@ def __init__( self.pending_tool_calls: list[dict[str, Any]] = [] # Track tool calls for assistant message self.tool_results: list[dict[str, Any]] = [] # Track tool results self.tool_calls_ended: set[str] = set() # Track which tool calls have had ToolCallEndEvent emitted + self.accumulated_text_content: str = "" # Track accumulated text for final MessagesSnapshotEvent async def from_agent_run_update(self, update: AgentRunResponseUpdate) -> list[BaseEvent]: """ @@ -99,18 +100,29 @@ async def from_agent_run_update(self, update: AgentRunResponseUpdate) -> list[Ba """ events: list[BaseEvent] = [] - for content in update.contents: + logger.info(f"Processing AgentRunUpdate with {len(update.contents)} content items") + for idx, content in enumerate(update.contents): + logger.info(f" Content {idx}: type={type(content).__name__}") if isinstance(content, TextContent): + logger.info( + f" TextContent found: text_length={len(content.text)}, text_preview='{content.text[:100]}'" + ) + logger.info( + f" Flags: skip_text_content={self.skip_text_content}, should_stop_after_confirm={self.should_stop_after_confirm}" + ) + # Skip text content if using structured outputs (it's just the JSON) if self.skip_text_content: + logger.info(" SKIPPING TextContent: skip_text_content is True") continue # Skip text content if we're about to emit confirm_changes # The summary should only appear after user confirms if self.should_stop_after_confirm: - logger.debug("Skipping text content - waiting for confirm_changes response") + logger.info(" SKIPPING TextContent: waiting for confirm_changes response") # Save the summary text to show after confirmation self.suppressed_summary += content.text + logger.info(f" Suppressed summary now has {len(self.suppressed_summary)} chars") continue if not self.current_message_id: @@ -119,14 +131,16 @@ async def from_agent_run_update(self, update: AgentRunResponseUpdate) -> list[Ba message_id=self.current_message_id, role="assistant", ) - logger.debug(f"Emitting TextMessageStartEvent with message_id={self.current_message_id}") + logger.info(f" EMITTING TextMessageStartEvent with message_id={self.current_message_id}") events.append(start_event) event = TextMessageContentEvent( message_id=self.current_message_id, delta=content.text, ) - logger.debug(f"Emitting TextMessageContentEvent with delta: {content.text}") + # Accumulate text content for final MessagesSnapshotEvent + self.accumulated_text_content += content.text + logger.info(f" EMITTING TextMessageContentEvent with delta: '{content.text}'") events.append(event) elif isinstance(content, FunctionCallContent): @@ -427,7 +441,24 @@ async def from_agent_run_update(self, update: AgentRunResponseUpdate) -> list[Ba # Emit MessagesSnapshotEvent with the complete conversation including tool calls and results # This is required for CopilotKit's useCopilotAction to detect tool result - if self.pending_tool_calls and self.tool_results: + # HOWEVER: Skip this for predictive tools when require_confirmation=False, because + # the agent will generate a follow-up text message and we'll emit a complete snapshot at the end. + # Emitting here would create an incomplete snapshot that gets replaced, causing UI flicker. + should_emit_snapshot = self.pending_tool_calls and self.tool_results + + # Check if this is a predictive tool that will have a follow-up message + is_predictive_without_confirmation = False + if should_emit_snapshot and self.current_tool_call_name and self.predict_state_config: + for state_key, config in self.predict_state_config.items(): + if config["tool"] == self.current_tool_call_name and not self.require_confirmation: + is_predictive_without_confirmation = True + logger.info( + f"Skipping intermediate MessagesSnapshotEvent for predictive tool '{self.current_tool_call_name}' " + "- will emit complete snapshot after follow-up message" + ) + break + + if should_emit_snapshot and not is_predictive_without_confirmation: # Import message adapter from ._message_adapters import agent_framework_messages_to_agui diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py index a6c1001386..11d2977f90 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py @@ -283,8 +283,62 @@ def extract_text_from_contents(contents: list[Any]) -> str: return "".join(text_parts) +def agui_messages_to_snapshot_format(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Normalize AG-UI messages for MessagesSnapshotEvent. + + Converts AG-UI input format (with 'input_text' type) to snapshot format (with 'text' type). + + Args: + messages: List of AG-UI messages in input format + + Returns: + List of normalized messages suitable for MessagesSnapshotEvent + """ + from ._utils import generate_event_id + + result: list[dict[str, Any]] = [] + for msg in messages: + normalized_msg = msg.copy() + + # Ensure ID exists + if "id" not in normalized_msg: + normalized_msg["id"] = generate_event_id() + + # Normalize content field + content = normalized_msg.get("content") + if isinstance(content, list): + # Convert content array format to simple string + text_parts = [] + for item in content: + if isinstance(item, dict): + # Convert 'input_text' to 'text' type + if item.get("type") == "input_text": + text_parts.append(item.get("text", "")) + elif item.get("type") == "text": + text_parts.append(item.get("text", "")) + else: + # Other types - just extract text field if present + text_parts.append(item.get("text", "")) + normalized_msg["content"] = "".join(text_parts) + elif content is None: + normalized_msg["content"] = "" + + # Normalize tool_call_id to toolCallId for tool messages + if normalized_msg.get("role") == "tool": + if "tool_call_id" in normalized_msg: + normalized_msg["toolCallId"] = normalized_msg["tool_call_id"] + del normalized_msg["tool_call_id"] + elif "toolCallId" not in normalized_msg: + normalized_msg["toolCallId"] = "" + + result.append(normalized_msg) + + return result + + __all__ = [ "agui_messages_to_agent_framework", "agent_framework_messages_to_agui", + "agui_messages_to_snapshot_format", "extract_text_from_contents", ] diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py index 1f6a43d8c1..6da46d819f 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py @@ -11,6 +11,7 @@ from ag_ui.core import ( BaseEvent, + MessagesSnapshotEvent, RunErrorEvent, TextMessageContentEvent, TextMessageEndEvent, @@ -588,32 +589,37 @@ def deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]: # We should NOT add to thread.on_new_messages() as that would cause duplication. # Instead, we pass messages directly to the agent via messages_to_run. - # Inject current state as system message context if we have state + # Inject current state as system message context if we have state and this is a new user turn messages_to_run: list[Any] = [] + # Check if the last message is from the user (new turn) vs assistant/tool (mid-execution) + is_new_user_turn = False + if provider_messages: + last_msg = provider_messages[-1] + is_new_user_turn = last_msg.role.value == "user" + + # Check if conversation has tool calls (indicates mid-execution) conversation_has_tool_calls = False - logger.debug(f"Checking {len(provider_messages)} provider messages for tool calls") - for i, msg in enumerate(provider_messages): - logger.debug( - f" Message {i}: role={msg.role.value}, contents={len(msg.contents) if hasattr(msg, 'contents') and msg.contents else 0}" - ) for msg in provider_messages: if msg.role.value == "assistant" and hasattr(msg, "contents") and msg.contents: if any(isinstance(content, FunctionCallContent) for content in msg.contents): conversation_has_tool_calls = True break - if current_state and context.config.state_schema and not conversation_has_tool_calls: + + # Only inject state context on new user turns AND when conversation doesn't have tool calls + # (tool calls indicate we're mid-execution, so state context was already injected) + if current_state and context.config.state_schema and is_new_user_turn and not conversation_has_tool_calls: state_json = json.dumps(current_state, indent=2) state_context_msg = ChatMessage( role="system", contents=[ TextContent( text=f"""Current state of the application: -{state_json} + {state_json} -When modifying state, you MUST include ALL existing data plus your changes. -For example, if adding a new ingredient, include all existing ingredients PLUS the new one. -Never replace existing data - always append or merge.""" + When modifying state, you MUST include ALL existing data plus your changes. + For example, if adding one new item to a list, include ALL existing items PLUS the one new item. + Never replace existing data - always preserve and append or merge.""" ) ], ) @@ -714,12 +720,19 @@ def deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]: # Collect all updates to get the final structured output all_updates: list[Any] = [] + update_count = 0 async for update in context.agent.run_stream(messages_to_run, thread=thread, tools=tools_param): + update_count += 1 + logger.info(f"[STREAM] Received update #{update_count} from agent") all_updates.append(update) events = await event_bridge.from_agent_run_update(update) + logger.info(f"[STREAM] Update #{update_count} produced {len(events)} events") for event in events: + logger.info(f"[STREAM] Yielding event: {type(event).__name__}") yield event + logger.info(f"[STREAM] Agent stream completed. Total updates: {update_count}") + # After agent completes, check if we should stop (waiting for user to confirm changes) if event_bridge.should_stop_after_confirm: logger.info("Stopping run after confirm_changes - waiting for user response") @@ -793,9 +806,56 @@ def deduplicate_messages(messages: list[ChatMessage]) -> list[ChatMessage]: yield TextMessageEndEvent(message_id=message_id) logger.info(f"Emitted conversational message: {response_dict['message'][:100]}...") + logger.info(f"[FINALIZE] Checking for unclosed message. current_message_id={event_bridge.current_message_id}") if event_bridge.current_message_id: + logger.info(f"[FINALIZE] Emitting TextMessageEndEvent for message_id={event_bridge.current_message_id}") yield event_bridge.create_message_end_event(event_bridge.current_message_id) + # Emit MessagesSnapshotEvent to persist the final assistant text message + from ._message_adapters import agui_messages_to_snapshot_format + + # Build the final assistant message with accumulated text content + assistant_text_message = { + "id": event_bridge.current_message_id, + "role": "assistant", + "content": event_bridge.accumulated_text_content, + } + + # Convert input messages to snapshot format (normalize content structure) + # event_bridge.input_messages are already in AG-UI format, just need normalization + converted_input_messages = agui_messages_to_snapshot_format(event_bridge.input_messages) + + # Build complete messages array + # Include: input messages + any pending tool calls/results + final text message + all_messages = converted_input_messages.copy() + + # Add assistant message with tool calls if any + if event_bridge.pending_tool_calls: + tool_call_message = { + "id": generate_event_id(), + "role": "assistant", + "tool_calls": event_bridge.pending_tool_calls.copy(), + } + all_messages.append(tool_call_message) + + # Add tool results if any + all_messages.extend(event_bridge.tool_results.copy()) + + # Add final text message + all_messages.append(assistant_text_message) + + messages_snapshot = MessagesSnapshotEvent( + messages=all_messages, # type: ignore[arg-type] + ) + logger.info( + f"[FINALIZE] Emitting MessagesSnapshotEvent with {len(all_messages)} messages " + f"(text content length: {len(event_bridge.accumulated_text_content)})" + ) + yield messages_snapshot + else: + logger.info("[FINALIZE] No current_message_id - skipping TextMessageEndEvent") + + logger.info("[FINALIZE] Emitting RUN_FINISHED event") yield event_bridge.create_run_finished_event() logger.info(f"Completed agent run for thread_id={context.thread_id}, run_id={context.run_id}") diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/recipe_agent.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/recipe_agent.py index dfdd058bc7..e39fb8c75e 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/recipe_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/recipe_agent.py @@ -130,4 +130,5 @@ def recipe_agent(chat_client: ChatClientProtocol) -> AgentFrameworkAgent: "recipe": {"tool": "update_recipe", "tool_argument": "recipe"}, }, confirmation_strategy=RecipeConfirmationStrategy(), + require_confirmation=False, ) diff --git a/python/packages/ag-ui/pyproject.toml b/python/packages/ag-ui/pyproject.toml index 9b31801054..26208aef04 100644 --- a/python/packages/ag-ui/pyproject.toml +++ b/python/packages/ag-ui/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "agent-framework-ag-ui" -version = "1.0.0b251114" +version = "1.0.0b251117" description = "AG-UI protocol integration for Agent Framework" readme = "README.md" license-files = ["LICENSE"] diff --git a/python/packages/ag-ui/tests/test_client.py b/python/packages/ag-ui/tests/test_ag_ui_client.py similarity index 100% rename from python/packages/ag-ui/tests/test_client.py rename to python/packages/ag-ui/tests/test_ag_ui_client.py