diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py index 4b00330283..e9ce610b10 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py @@ -684,6 +684,10 @@ def _build_messages_snapshot( } ) + # Add reasoning messages so frontends that reconcile state from + # MESSAGES_SNAPSHOT retain reasoning content after streaming ends. + all_messages.extend(flow.reasoning_messages) + return MessagesSnapshotEvent(messages=all_messages) # type: ignore[arg-type] @@ -1061,7 +1065,9 @@ async def run_agent_stream( # Emit MessagesSnapshotEvent if we have tool calls or results # Feature #5: Suppress intermediate snapshots for predictive tools without confirmation - should_emit_snapshot = flow.pending_tool_calls or flow.tool_results or flow.accumulated_text + should_emit_snapshot = ( + flow.pending_tool_calls or flow.tool_results or flow.accumulated_text or flow.reasoning_messages + ) if should_emit_snapshot: # Check if we should suppress for predictive tool last_tool_name = None diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py index 2e5294a6b6..5e4fced97c 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py @@ -604,6 +604,10 @@ def _filter_modified_args( # Handle standard tool result messages early (role="tool") to preserve provider invariants # This path maps AG‑UI tool messages to function_result content with the correct tool_call_id role_str = normalize_agui_role(msg.get("role", "user")) + if role_str == "reasoning": + # Reasoning messages are UI-only state carried in MESSAGES_SNAPSHOT. + # They should not be forwarded to the LLM provider. + continue if role_str == "tool": # Prefer explicit tool_call_id fields; fall back to backend fields only if necessary tool_call_id = msg.get("tool_call_id") or msg.get("toolCallId") @@ -1020,6 +1024,11 @@ def agui_messages_to_snapshot_format(messages: list[dict[str, Any]]) -> list[dic elif "toolCallId" not in normalized_msg: normalized_msg["toolCallId"] = "" + # Normalize encrypted_value to encryptedValue for reasoning messages + if normalized_msg.get("role") == "reasoning" and "encrypted_value" in normalized_msg: + normalized_msg["encryptedValue"] = normalized_msg["encrypted_value"] + del normalized_msg["encrypted_value"] + result.append(normalized_msg) return result diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_run_common.py b/python/packages/ag-ui/agent_framework_ag_ui/_run_common.py index 0a9f4cea9c..155f559a94 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_run_common.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_run_common.py @@ -126,6 +126,8 @@ class FlowState: tool_results: list[dict[str, Any]] = field(default_factory=list) # pyright: ignore[reportUnknownVariableType] tool_calls_ended: set[str] = field(default_factory=set) # pyright: ignore[reportUnknownVariableType] interrupts: list[dict[str, Any]] = field(default_factory=list) # pyright: ignore[reportUnknownVariableType] + reasoning_messages: list[dict[str, Any]] = field(default_factory=list) # pyright: ignore[reportUnknownVariableType] + accumulated_reasoning: dict[str, str] = field(default_factory=dict) # pyright: ignore[reportUnknownVariableType] def get_tool_name(self, call_id: str | None) -> str | None: """Get tool name by call ID.""" @@ -460,7 +462,7 @@ def _emit_mcp_tool_result( return _emit_tool_result_common(content.call_id, raw_output, flow, predictive_handler) -def _emit_text_reasoning(content: Content) -> list[BaseEvent]: +def _emit_text_reasoning(content: Content, flow: FlowState | None = None) -> list[BaseEvent]: """Emit AG-UI reasoning events for text_reasoning content. Uses the protocol-defined reasoning event types so that AG-UI consumers @@ -470,6 +472,10 @@ def _emit_text_reasoning(content: Content) -> list[BaseEvent]: ``content.protected_data`` is present it is emitted as a ``ReasoningEncryptedValueEvent`` so that consumers can persist encrypted reasoning for state continuity without conflating it with display text. + + When *flow* is provided the reasoning message is persisted into + ``flow.reasoning_messages`` so that ``_build_messages_snapshot`` can + include it in the final ``MESSAGES_SNAPSHOT``. """ text = content.text or "" if not text and content.protected_data is None: @@ -498,6 +504,36 @@ def _emit_text_reasoning(content: Content) -> list[BaseEvent]: events.append(ReasoningEndEvent(message_id=message_id)) + # Persist reasoning into flow state for MESSAGES_SNAPSHOT. + # Accumulate reasoning text per message_id, similar to flow.accumulated_text, + # so that incremental deltas build the full reasoning string. + if flow is not None: + if text: + previous_text = flow.accumulated_reasoning.get(message_id, "") + flow.accumulated_reasoning[message_id] = previous_text + text + full_text = flow.accumulated_reasoning.get(message_id, text or "") + + # Update existing reasoning entry for this message_id if present; otherwise append a new one. + existing_entry: dict[str, Any] | None = None + for entry in flow.reasoning_messages: + if isinstance(entry, dict) and entry.get("id") == message_id: + existing_entry = entry + break + + if existing_entry is None: + reasoning_entry: dict[str, Any] = { + "id": message_id, + "role": "reasoning", + "content": full_text, + } + if content.protected_data is not None: + reasoning_entry["encryptedValue"] = content.protected_data + flow.reasoning_messages.append(reasoning_entry) + else: + existing_entry["content"] = full_text + if content.protected_data is not None: + existing_entry["encryptedValue"] = content.protected_data + return events @@ -527,6 +563,6 @@ def _emit_content( if content_type == "mcp_server_tool_result": return _emit_mcp_tool_result(content, flow, predictive_handler) if content_type == "text_reasoning": - return _emit_text_reasoning(content) + return _emit_text_reasoning(content, flow) logger.debug("Skipping unsupported content type in AG-UI emitter: %s", content_type) return [] diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_utils.py b/python/packages/ag-ui/agent_framework_ag_ui/_utils.py index bfda3948ec..c68301f7d2 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_utils.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_utils.py @@ -27,7 +27,7 @@ "system": "system", } -ALLOWED_AGUI_ROLES: set[str] = {"user", "assistant", "system", "tool"} +ALLOWED_AGUI_ROLES: set[str] = {"user", "assistant", "system", "tool", "reasoning"} def generate_event_id() -> str: @@ -82,7 +82,7 @@ def normalize_agui_role(raw_role: Any) -> str: raw_role: Raw role value from AG-UI message Returns: - Normalized role string (user, assistant, system, or tool) + Normalized role string (user, assistant, system, tool, or reasoning) """ if not isinstance(raw_role, str): return "user" diff --git a/python/packages/ag-ui/tests/ag_ui/test_message_adapters.py b/python/packages/ag-ui/tests/ag_ui/test_message_adapters.py index cc4f1230df..9508b53085 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_message_adapters.py +++ b/python/packages/ag-ui/tests/ag_ui/test_message_adapters.py @@ -1669,3 +1669,94 @@ def test_agui_fresh_approval_is_still_processed(): assert len(approval_contents) == 1, "Fresh approval should produce function_approval_response" assert approval_contents[0].approved is True assert approval_contents[0].function_call.name == "get_datetime" + + +class TestReasoningRoundTrip: + """Tests for reasoning message handling in inbound/outbound adapters.""" + + def test_reasoning_skipped_on_inbound(self): + """Reasoning messages from prior snapshot are not forwarded to the LLM.""" + messages_input = [ + {"id": "u1", "role": "user", "content": "Hello"}, + {"id": "r1", "role": "reasoning", "content": "Thinking..."}, + {"id": "a1", "role": "assistant", "content": "Hi there"}, + ] + + result = agui_messages_to_agent_framework(messages_input) + + roles = [m.role if hasattr(m.role, "value") else str(m.role) for m in result] + assert "reasoning" not in roles + assert len(result) == 2 + + def test_reasoning_preserved_in_snapshot_format(self): + """Reasoning messages retain their role through snapshot normalization.""" + messages_input = [ + {"id": "u1", "role": "user", "content": "Hello"}, + {"id": "r1", "role": "reasoning", "content": "Thinking about this..."}, + {"id": "a1", "role": "assistant", "content": "Answer"}, + ] + + result = agui_messages_to_snapshot_format(messages_input) + + reasoning_msgs = [m for m in result if m.get("role") == "reasoning"] + assert len(reasoning_msgs) == 1 + assert reasoning_msgs[0]["content"] == "Thinking about this..." + + def test_reasoning_with_encrypted_value_in_snapshot_format(self): + """Reasoning with encryptedValue passes through snapshot normalization.""" + messages_input = [ + { + "id": "r1", + "role": "reasoning", + "content": "visible", + "encryptedValue": "secret-data", + }, + ] + + result = agui_messages_to_snapshot_format(messages_input) + + assert len(result) == 1 + assert result[0]["role"] == "reasoning" + assert result[0]["encryptedValue"] == "secret-data" + + def test_reasoning_encrypted_value_snake_case_normalized(self): + """Snake-case encrypted_value is normalized to encryptedValue in snapshot format.""" + messages_input = [ + { + "id": "r1", + "role": "reasoning", + "content": "visible", + "encrypted_value": "snake-case-data", + }, + ] + + result = agui_messages_to_snapshot_format(messages_input) + + assert len(result) == 1 + assert result[0]["encryptedValue"] == "snake-case-data" + assert "encrypted_value" not in result[0] + + def test_multi_turn_with_reasoning_in_prior_snapshot(self): + """Second turn with reasoning from prior snapshot does not corrupt messages.""" + messages_input = [ + {"id": "u1", "role": "user", "content": "First question"}, + {"id": "r1", "role": "reasoning", "content": "Prior reasoning"}, + {"id": "a1", "role": "assistant", "content": "First answer"}, + {"id": "u2", "role": "user", "content": "Follow-up question"}, + ] + + result = agui_messages_to_agent_framework(messages_input) + + roles = [m.role if hasattr(m.role, "value") else str(m.role) for m in result] + # Reasoning is filtered out, other messages preserved in order + assert roles == ["user", "assistant", "user"] + # Content not corrupted + texts = [] + for m in result: + for c in m.contents or []: + if hasattr(c, "text") and c.text: + texts.append(c.text) + assert "First question" in texts + assert "First answer" in texts + assert "Follow-up question" in texts + assert "Prior reasoning" not in texts diff --git a/python/packages/ag-ui/tests/ag_ui/test_run.py b/python/packages/ag-ui/tests/ag_ui/test_run.py index ae8c5e85b0..0e5c329ce9 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_run.py +++ b/python/packages/ag-ui/tests/ag_ui/test_run.py @@ -1346,3 +1346,158 @@ def test_routes_text_reasoning(self): assert len(events) == 5 assert isinstance(events[0], ReasoningStartEvent) + + +class TestReasoningInSnapshot: + """Tests for reasoning message inclusion in MESSAGES_SNAPSHOT.""" + + def test_reasoning_persisted_to_flow_state(self): + """_emit_text_reasoning with flow persists reasoning into flow.reasoning_messages.""" + flow = FlowState() + content = Content.from_text_reasoning( + id="reason_persist", + text="Let me think step by step.", + ) + + _emit_text_reasoning(content, flow) + + assert len(flow.reasoning_messages) == 1 + assert flow.reasoning_messages[0]["id"] == "reason_persist" + assert flow.reasoning_messages[0]["role"] == "reasoning" + assert flow.reasoning_messages[0]["content"] == "Let me think step by step." + assert "encryptedValue" not in flow.reasoning_messages[0] + + def test_reasoning_with_encrypted_value_persisted(self): + """Reasoning with protected_data preserves encryptedValue in flow state.""" + flow = FlowState() + content = Content.from_text_reasoning( + id="reason_enc", + text="visible reasoning", + protected_data="encrypted-data-123", + ) + + _emit_text_reasoning(content, flow) + + assert len(flow.reasoning_messages) == 1 + assert flow.reasoning_messages[0]["encryptedValue"] == "encrypted-data-123" + + def test_snapshot_includes_reasoning(self): + """_build_messages_snapshot includes reasoning messages from flow state.""" + from agent_framework_ag_ui._agent_run import _build_messages_snapshot + + flow = FlowState() + flow.accumulated_text = "Here is my answer." + flow.reasoning_messages = [ + {"id": "r1", "role": "reasoning", "content": "Thinking..."}, + ] + + snapshot = _build_messages_snapshot(flow, []) + + roles = [m.get("role") if isinstance(m, dict) else getattr(m, "role", None) for m in snapshot.messages] + assert "reasoning" in roles + + def test_snapshot_preserves_reasoning_encrypted_value(self): + """Snapshot reasoning with encryptedValue is preserved end-to-end.""" + from agent_framework_ag_ui._agent_run import _build_messages_snapshot + + flow = FlowState() + content = Content.from_text_reasoning( + id="reason_e2e", + text="visible", + protected_data="secret-data", + ) + _emit_text_reasoning(content, flow) + + text_content = Content.from_text("Final answer.") + _emit_text(text_content, flow) + + snapshot = _build_messages_snapshot(flow, []) + + reasoning_msgs = [ + m + for m in snapshot.messages + if (m.get("role") if isinstance(m, dict) else getattr(m, "role", None)) == "reasoning" + ] + assert len(reasoning_msgs) == 1 + msg = reasoning_msgs[0] + if isinstance(msg, dict): + assert msg["content"] == "visible" + assert msg["encryptedValue"] == "secret-data" + + def test_emit_content_routes_reasoning_with_flow(self): + """_emit_content passes flow to _emit_text_reasoning for persistence.""" + flow = FlowState() + content = Content.from_text_reasoning(text="routed reasoning") + + _emit_content(content, flow) + + assert len(flow.reasoning_messages) == 1 + assert flow.reasoning_messages[0]["content"] == "routed reasoning" + + def test_reasoning_without_flow_does_not_error(self): + """Calling _emit_text_reasoning without flow still works (backward compat).""" + content = Content.from_text_reasoning(text="no flow") + + events = _emit_text_reasoning(content) + + assert len(events) == 5 + assert isinstance(events[0], ReasoningStartEvent) + + def test_snapshot_reasoning_ordering(self): + """Reasoning messages appear after assistant text in snapshot.""" + from agent_framework_ag_ui._agent_run import _build_messages_snapshot + + flow = FlowState() + reasoning_content = Content.from_text_reasoning(id="r1", text="Thinking...") + _emit_text_reasoning(reasoning_content, flow) + + text_content = Content.from_text("Answer") + _emit_text(text_content, flow) + + snapshot = _build_messages_snapshot(flow, [{"id": "u1", "role": "user", "content": "Hi"}]) + + # user -> assistant text -> reasoning + assert len(snapshot.messages) == 3 + roles = [m.get("role") if isinstance(m, dict) else getattr(m, "role", None) for m in snapshot.messages] + assert roles == ["user", "assistant", "reasoning"] + + def test_reasoning_accumulates_incremental_deltas(self): + """Multiple reasoning deltas with the same id accumulate into one entry.""" + flow = FlowState() + content1 = Content.from_text_reasoning(id="reason_inc", text="First ") + content2 = Content.from_text_reasoning(id="reason_inc", text="second ") + content3 = Content.from_text_reasoning(id="reason_inc", text="third.") + + _emit_text_reasoning(content1, flow) + _emit_text_reasoning(content2, flow) + _emit_text_reasoning(content3, flow) + + assert len(flow.reasoning_messages) == 1 + assert flow.reasoning_messages[0]["id"] == "reason_inc" + assert flow.reasoning_messages[0]["content"] == "First second third." + + def test_reasoning_accumulates_distinct_message_ids(self): + """Reasoning entries with different ids are stored separately.""" + flow = FlowState() + content_a = Content.from_text_reasoning(id="a", text="alpha") + content_b = Content.from_text_reasoning(id="b", text="beta") + + _emit_text_reasoning(content_a, flow) + _emit_text_reasoning(content_b, flow) + + assert len(flow.reasoning_messages) == 2 + assert flow.reasoning_messages[0]["content"] == "alpha" + assert flow.reasoning_messages[1]["content"] == "beta" + + def test_reasoning_encrypted_value_updated_on_later_delta(self): + """encryptedValue is set even when it arrives with a later delta.""" + flow = FlowState() + content1 = Content.from_text_reasoning(id="enc_late", text="part1 ") + content2 = Content.from_text_reasoning(id="enc_late", text="part2", protected_data="encrypted-payload") + + _emit_text_reasoning(content1, flow) + _emit_text_reasoning(content2, flow) + + assert len(flow.reasoning_messages) == 1 + assert flow.reasoning_messages[0]["content"] == "part1 part2" + assert flow.reasoning_messages[0]["encryptedValue"] == "encrypted-payload" diff --git a/python/packages/ag-ui/tests/ag_ui/test_utils.py b/python/packages/ag-ui/tests/ag_ui/test_utils.py index 0f453132f7..f353d2f0a7 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_utils.py +++ b/python/packages/ag-ui/tests/ag_ui/test_utils.py @@ -450,6 +450,7 @@ def test_normalize_agui_role_valid(): assert normalize_agui_role("assistant") == "assistant" assert normalize_agui_role("system") == "system" assert normalize_agui_role("tool") == "tool" + assert normalize_agui_role("reasoning") == "reasoning" def test_normalize_agui_role_invalid():