Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ jobs:
UV_CACHE_DIR: /tmp/.uv-cache
# Unit tests
- name: Run all tests
run: uv run poe test -A
run: uv run poe test -A --junitxml=pytest.xml
working-directory: ./python

# Surface failing tests
- name: Surface failing tests
if: always()
uses: pmeier/pytest-results-action@v0.7.2
with:
path: ./python/**.xml
path: ./python/pytest.xml
summary: true
display-options: fEX
fail-on-empty: false
Expand Down
5 changes: 5 additions & 0 deletions python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from ._run_common import (
FlowState,
_build_run_finished_event, # type: ignore
_close_reasoning_block, # type: ignore
_emit_content, # type: ignore
_extract_resume_payload, # type: ignore
_has_only_tool_calls, # type: ignore
Expand Down Expand Up @@ -1058,6 +1059,10 @@ async def run_agent_stream(
}
)

# Close any open reasoning block
for event in _close_reasoning_block(flow):
yield event

# Close any open message
if flow.message_id:
logger.debug(f"End of run: closing text message message_id={flow.message_id}")
Expand Down
106 changes: 82 additions & 24 deletions python/packages/ag-ui/agent_framework_ag_ui/_run_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class FlowState:
interrupts: list[dict[str, Any]] = field(default_factory=list) # pyright: ignore[reportUnknownVariableType]
reasoning_messages: list[dict[str, Any]] = field(default_factory=list) # pyright: ignore[reportUnknownVariableType]
accumulated_reasoning: dict[str, str] = field(default_factory=dict) # pyright: ignore[reportUnknownVariableType]
reasoning_message_id: str | None = None

def get_tool_name(self, call_id: str | None) -> str | None:
"""Get tool name by call ID."""
Expand Down Expand Up @@ -462,12 +463,39 @@ def _emit_mcp_tool_result(
return _emit_tool_result_common(content.call_id, raw_output, flow, predictive_handler)


def _close_reasoning_block(flow: FlowState) -> list[BaseEvent]:
"""Close an open reasoning block, emitting end events.

Should be called when the reasoning block is complete -- e.g. when
non-reasoning content arrives or at end of a run.
"""
if not flow.reasoning_message_id:
return []
message_id = flow.reasoning_message_id
flow.reasoning_message_id = None
return [
ReasoningMessageEndEvent(message_id=message_id),
ReasoningEndEvent(message_id=message_id),
]


def _emit_text_reasoning(content: Content, flow: FlowState | None = None) -> list[BaseEvent]:
"""Emit AG-UI reasoning events for text_reasoning content.

Uses the protocol-defined reasoning event types so that AG-UI consumers
such as CopilotKit can render reasoning natively.

When *flow* is provided the function follows the streaming pattern: it
emits ``ReasoningStartEvent`` / ``ReasoningMessageStartEvent`` only on
the first delta for a given ``message_id`` and just
``ReasoningMessageContentEvent`` for subsequent deltas. The matching
``ReasoningMessageEndEvent`` / ``ReasoningEndEvent`` are deferred until
``_close_reasoning_block`` is called (e.g. when non-reasoning content
arrives or at end-of-run).

Without *flow* (backward-compat) the full Start→Content→End sequence is
emitted for every call.

Only ``content.text`` is used for the visible reasoning message. If
``content.protected_data`` is present it is emitted as a
``ReasoningEncryptedValueEvent`` so that consumers can persist encrypted
Expand All @@ -483,26 +511,49 @@ def _emit_text_reasoning(content: Content, flow: FlowState | None = None) -> lis

message_id = content.id or generate_event_id()

events: list[BaseEvent] = [
ReasoningStartEvent(message_id=message_id),
ReasoningMessageStartEvent(message_id=message_id, role="assistant"),
]
events: list[BaseEvent] = []

if text:
events.append(ReasoningMessageContentEvent(message_id=message_id, delta=text))
if flow is not None:
# Streaming mode: track open reasoning block in flow state.
if flow.reasoning_message_id != message_id:
# Close any previously open reasoning block (different message_id).
events.extend(_close_reasoning_block(flow))
# Open new reasoning block.
events.append(ReasoningStartEvent(message_id=message_id))
events.append(ReasoningMessageStartEvent(message_id=message_id, role="assistant"))
flow.reasoning_message_id = message_id

events.append(ReasoningMessageEndEvent(message_id=message_id))
if text:
events.append(ReasoningMessageContentEvent(message_id=message_id, delta=text))

if content.protected_data is not None:
events.append(
ReasoningEncryptedValueEvent(
subtype="message",
entity_id=message_id,
encrypted_value=content.protected_data,
)
)
else:
# No flow -- backward-compatible full sequence per call.
events.append(ReasoningStartEvent(message_id=message_id))
events.append(ReasoningMessageStartEvent(message_id=message_id, role="assistant"))

if content.protected_data is not None:
events.append(
ReasoningEncryptedValueEvent(
subtype="message",
entity_id=message_id,
encrypted_value=content.protected_data,
if text:
events.append(ReasoningMessageContentEvent(message_id=message_id, delta=text))

events.append(ReasoningMessageEndEvent(message_id=message_id))

if content.protected_data is not None:
events.append(
ReasoningEncryptedValueEvent(
subtype="message",
entity_id=message_id,
encrypted_value=content.protected_data,
)
)
)

events.append(ReasoningEndEvent(message_id=message_id))
events.append(ReasoningEndEvent(message_id=message_id))

# Persist reasoning into flow state for MESSAGES_SNAPSHOT.
# Accumulate reasoning text per message_id, similar to flow.accumulated_text,
Expand Down Expand Up @@ -546,23 +597,30 @@ def _emit_content(
) -> list[BaseEvent]:
"""Emit appropriate events for any content type."""
content_type = getattr(content, "type", None)

# Close open reasoning block when switching to non-reasoning content.
if content_type != "text_reasoning":
events = _close_reasoning_block(flow)
else:
events = []

if content_type == "text":
return _emit_text(content, flow, skip_text)
return events + _emit_text(content, flow, skip_text)
if content_type == "function_call":
return _emit_tool_call(content, flow, predictive_handler)
return events + _emit_tool_call(content, flow, predictive_handler)
if content_type == "function_result":
return _emit_tool_result(content, flow, predictive_handler)
return events + _emit_tool_result(content, flow, predictive_handler)
if content_type == "function_approval_request":
return _emit_approval_request(content, flow, predictive_handler, require_confirmation)
return events + _emit_approval_request(content, flow, predictive_handler, require_confirmation)
if content_type == "usage":
return _emit_usage(content)
return events + _emit_usage(content)
if content_type == "oauth_consent_request":
return _emit_oauth_consent(content)
return events + _emit_oauth_consent(content)
if content_type == "mcp_server_tool_call":
return _emit_mcp_tool_call(content, flow)
return events + _emit_mcp_tool_call(content, flow)
if content_type == "mcp_server_tool_result":
return _emit_mcp_tool_result(content, flow, predictive_handler)
return events + _emit_mcp_tool_result(content, flow, predictive_handler)
if content_type == "text_reasoning":
return _emit_text_reasoning(content, flow)
logger.debug("Skipping unsupported content type in AG-UI emitter: %s", content_type)
return []
return events
4 changes: 4 additions & 0 deletions python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ._run_common import (
FlowState,
_build_run_finished_event,
_close_reasoning_block,
_emit_content,
_extract_resume_payload,
_normalize_resume_interrupts,
Expand Down Expand Up @@ -729,6 +730,9 @@ def _drain_open_message() -> list[TextMessageEndEvent]:
run_error_emitted = True
terminal_emitted = True

for reasoning_evt in _close_reasoning_block(flow):
yield reasoning_evt

for end_event in _drain_open_message():
yield end_event

Expand Down
141 changes: 140 additions & 1 deletion python/packages/ag-ui/tests/ag_ui/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
ReasoningMessageEndEvent,
ReasoningMessageStartEvent,
ReasoningStartEvent,
TextMessageContentEvent,
TextMessageEndEvent,
TextMessageStartEvent,
ToolCallArgsEvent,
Expand All @@ -29,6 +30,7 @@
from agent_framework_ag_ui._run_common import (
FlowState,
_build_run_finished_event,
_close_reasoning_block,
_emit_approval_request,
_emit_content,
_emit_mcp_tool_call,
Expand Down Expand Up @@ -1344,8 +1346,11 @@ def test_routes_text_reasoning(self):

events = _emit_content(content, flow)

assert len(events) == 5
# Streaming pattern: Start + MessageStart + Content (no End events yet)
assert len(events) == 3
assert isinstance(events[0], ReasoningStartEvent)
assert isinstance(events[1], ReasoningMessageStartEvent)
assert isinstance(events[2], ReasoningMessageContentEvent)


class TestReasoningInSnapshot:
Expand Down Expand Up @@ -1501,3 +1506,137 @@ def test_reasoning_encrypted_value_updated_on_later_delta(self):
assert len(flow.reasoning_messages) == 1
assert flow.reasoning_messages[0]["content"] == "part1 part2"
assert flow.reasoning_messages[0]["encryptedValue"] == "encrypted-payload"

def test_reasoning_done_after_deltas_does_not_duplicate(self):
"""A done-style content arriving after deltas does not duplicate accumulated text.

The upstream client should skip done events when deltas preceded them,
but if one leaks through, the accumulator must not double-append.
This test verifies that only the delta-produced text is stored.
"""
flow = FlowState()
msg_id = "reason_dedup"

delta1 = Content.from_text_reasoning(id=msg_id, text="Hello ")
delta2 = Content.from_text_reasoning(id=msg_id, text="world")

_emit_text_reasoning(delta1, flow)
_emit_text_reasoning(delta2, flow)

# Accumulated text should equal the concatenation of deltas only
assert len(flow.reasoning_messages) == 1
assert flow.reasoning_messages[0]["content"] == "Hello world"
assert flow.reasoning_messages[0]["id"] == msg_id

def test_reasoning_deltas_emit_one_content_event_each(self):
"""Each reasoning delta emits exactly one ReasoningMessageContentEvent
within a single Start/End sequence (streaming pattern)."""
flow = FlowState()
msg_id = "reason_evt"

delta1 = Content.from_text_reasoning(id=msg_id, text="Think ")
delta2 = Content.from_text_reasoning(id=msg_id, text="hard")

events1 = _emit_text_reasoning(delta1, flow)
events2 = _emit_text_reasoning(delta2, flow)
close_events = _close_reasoning_block(flow)

all_events = events1 + events2 + close_events
content_events = [e for e in all_events if isinstance(e, ReasoningMessageContentEvent)]

assert len(content_events) == 2
assert content_events[0].delta == "Think "
assert content_events[1].delta == "hard"

# Streaming pattern: one Start/End sequence wrapping both content events
start_events = [e for e in all_events if isinstance(e, ReasoningStartEvent)]
end_events = [e for e in all_events if isinstance(e, ReasoningEndEvent)]
msg_start_events = [e for e in all_events if isinstance(e, ReasoningMessageStartEvent)]
msg_end_events = [e for e in all_events if isinstance(e, ReasoningMessageEndEvent)]
assert len(start_events) == 1
assert len(end_events) == 1
assert len(msg_start_events) == 1
assert len(msg_end_events) == 1

def test_reasoning_streaming_event_order(self):
"""Streaming reasoning emits Start once, then Content per delta, then End on close."""
flow = FlowState()
msg_id = "reason_order"

d1 = Content.from_text_reasoning(id=msg_id, text="A ")
d2 = Content.from_text_reasoning(id=msg_id, text="B ")
d3 = Content.from_text_reasoning(id=msg_id, text="C")

events = []
events.extend(_emit_text_reasoning(d1, flow))
events.extend(_emit_text_reasoning(d2, flow))
events.extend(_emit_text_reasoning(d3, flow))
events.extend(_close_reasoning_block(flow))

assert isinstance(events[0], ReasoningStartEvent)
assert isinstance(events[1], ReasoningMessageStartEvent)
assert isinstance(events[2], ReasoningMessageContentEvent)
assert events[2].delta == "A "
assert isinstance(events[3], ReasoningMessageContentEvent)
assert events[3].delta == "B "
assert isinstance(events[4], ReasoningMessageContentEvent)
assert events[4].delta == "C"
assert isinstance(events[5], ReasoningMessageEndEvent)
assert isinstance(events[6], ReasoningEndEvent)
assert len(events) == 7

def test_close_reasoning_block_noop_when_not_open(self):
"""_close_reasoning_block returns empty list when no reasoning block is open."""
flow = FlowState()
assert _close_reasoning_block(flow) == []

def test_close_reasoning_block_resets_state(self):
"""_close_reasoning_block clears reasoning_message_id."""
flow = FlowState()
_emit_text_reasoning(Content.from_text_reasoning(id="r1", text="x"), flow)
assert flow.reasoning_message_id == "r1"

_close_reasoning_block(flow)
assert flow.reasoning_message_id is None

def test_emit_content_closes_reasoning_on_text(self):
"""Switching from reasoning to text content auto-closes reasoning block."""
flow = FlowState()
reasoning = Content.from_text_reasoning(id="r1", text="thinking")
text = Content.from_text("answer")

r_events = _emit_content(reasoning, flow)
t_events = _emit_content(text, flow)

# reasoning events: Start + MsgStart + Content
assert isinstance(r_events[0], ReasoningStartEvent)
# text events should start with reasoning End events
assert isinstance(t_events[0], ReasoningMessageEndEvent)
assert isinstance(t_events[1], ReasoningEndEvent)
# then text start

assert isinstance(t_events[2], TextMessageStartEvent)
assert isinstance(t_events[3], TextMessageContentEvent)

def test_reasoning_distinct_ids_close_previous_block(self):
"""Emitting reasoning with a new message_id auto-closes the previous block."""
flow = FlowState()
c1 = Content.from_text_reasoning(id="block1", text="first")
c2 = Content.from_text_reasoning(id="block2", text="second")

events1 = _emit_text_reasoning(c1, flow)
events2 = _emit_text_reasoning(c2, flow)
close = _close_reasoning_block(flow)

# events1: Start(block1) + MsgStart(block1) + Content(block1)
assert events1[0].message_id == "block1"
# events2: MsgEnd(block1) + End(block1) + Start(block2) + MsgStart(block2) + Content(block2)
assert isinstance(events2[0], ReasoningMessageEndEvent)
assert events2[0].message_id == "block1"
assert isinstance(events2[1], ReasoningEndEvent)
assert events2[1].message_id == "block1"
assert isinstance(events2[2], ReasoningStartEvent)
assert events2[2].message_id == "block2"
# close: MsgEnd(block2) + End(block2)
assert isinstance(close[0], ReasoningMessageEndEvent)
assert close[0].message_id == "block2"
Original file line number Diff line number Diff line change
Expand Up @@ -315,10 +315,7 @@ async def get_latest(self, *, workflow_name: str) -> WorkflowCheckpoint | None:
"""
await self._ensure_container_proxy()

query = (
"SELECT * FROM c WHERE c.workflow_name = @workflow_name "
"ORDER BY c.timestamp DESC OFFSET 0 LIMIT 1"
)
query = "SELECT * FROM c WHERE c.workflow_name = @workflow_name ORDER BY c.timestamp DESC OFFSET 0 LIMIT 1"
parameters: list[dict[str, object]] = [
{"name": "@workflow_name", "value": workflow_name},
]
Expand Down Expand Up @@ -351,10 +348,7 @@ async def list_checkpoint_ids(self, *, workflow_name: str) -> list[CheckpointID]
"""
await self._ensure_container_proxy()

query = (
"SELECT c.checkpoint_id FROM c WHERE c.workflow_name = @workflow_name "
"ORDER BY c.timestamp ASC"
)
query = "SELECT c.checkpoint_id FROM c WHERE c.workflow_name = @workflow_name ORDER BY c.timestamp ASC"
parameters: list[dict[str, object]] = [
{"name": "@workflow_name", "value": workflow_name},
]
Expand Down
Loading
Loading