Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 114 additions & 69 deletions python/packages/ag-ui/agent_framework_ag_ui/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import json
import logging
import uuid
from collections.abc import Awaitable
from collections.abc import AsyncIterable, Awaitable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, cast

Expand Down Expand Up @@ -172,12 +172,12 @@ class FlowState:
tool_call_id: str | None = None # Current tool call being streamed
tool_call_name: str | None = None # Name of current tool call
waiting_for_approval: bool = False # Stop after approval request
current_state: dict[str, Any] = field(default_factory=dict) # Shared state
current_state: dict[str, Any] = field(default_factory=dict) # pyright: ignore[reportUnknownVariableType]
accumulated_text: str = "" # For MessagesSnapshotEvent
pending_tool_calls: list[dict[str, Any]] = field(default_factory=list) # For MessagesSnapshotEvent
tool_calls_by_id: dict[str, dict[str, Any]] = field(default_factory=dict)
tool_results: list[dict[str, Any]] = field(default_factory=list)
tool_calls_ended: set[str] = field(default_factory=set) # Track which tool calls have been ended
pending_tool_calls: list[dict[str, Any]] = field(default_factory=list) # pyright: ignore[reportUnknownVariableType]
tool_calls_by_id: dict[str, dict[str, Any]] = field(default_factory=dict) # pyright: ignore[reportUnknownVariableType]
tool_results: list[dict[str, Any]] = field(default_factory=list) # pyright: ignore[reportUnknownVariableType]
tool_calls_ended: set[str] = field(default_factory=set) # pyright: ignore[reportUnknownVariableType]

def get_tool_name(self, call_id: str | None) -> str | None:
"""Get tool name by call ID."""
Expand All @@ -191,6 +191,40 @@ def get_pending_without_end(self) -> list[dict[str, Any]]:
return [tc for tc in self.pending_tool_calls if tc.get("id") not in self.tool_calls_ended]


async def _normalize_response_stream(response_stream: Any) -> AsyncIterable[Any]:
"""Normalize agent streaming return types to an async iterable.

Supports:
- ResponseStream (standard agent stream type)
- AsyncIterable[AgentResponseUpdate] (workflow-style stream)
- Awaitable that resolves to either of the above
"""
if isinstance(response_stream, Awaitable):
resolved_stream = await cast(Awaitable[Any], response_stream)
if isinstance(resolved_stream, ResponseStream):
# AG-UI consumes update iteration only; ResponseStream finalizers are not used here.
return cast(AsyncIterable[Any], resolved_stream)
if isinstance(resolved_stream, AsyncIterable):
return cast(AsyncIterable[Any], resolved_stream)
resolved_type = f"{type(resolved_stream).__module__}.{type(resolved_stream).__name__}"
raise AgentExecutionException(
"Agent did not return a streaming AsyncIterable response. "
f"Awaitable resolved to unsupported type: {resolved_type}."
)

if isinstance(response_stream, ResponseStream):
# AG-UI consumes update iteration only; ResponseStream finalizers are not used here.
return cast(AsyncIterable[Any], response_stream)

if isinstance(response_stream, AsyncIterable):
return cast(AsyncIterable[Any], response_stream)

stream_type = f"{type(response_stream).__module__}.{type(response_stream).__name__}"
raise AgentExecutionException(
f"Agent did not return a streaming AsyncIterable response. Received unsupported type: {stream_type}."
)


def _create_state_context_message(
current_state: dict[str, Any],
state_schema: dict[str, Any],
Expand Down Expand Up @@ -460,7 +494,7 @@ def _emit_approval_request(
parent_message_id=flow.message_id,
)
)
args = {
args: dict[str, Any] = {
"function_name": func_name,
"function_call_id": func_call_id,
"function_arguments": make_json_safe(func_call.parse_arguments()) or {},
Expand Down Expand Up @@ -515,14 +549,17 @@ def _is_confirm_changes_response(messages: list[Any]) -> bool:
if not messages:
return False
last = messages[-1]
if not last.additional_properties.get("is_tool_result", False):
additional_properties = cast(dict[str, Any], getattr(last, "additional_properties", {}) or {})
if not additional_properties.get("is_tool_result", False):
return False

# Parse the content to check if it has the confirm_changes structure
for content in last.contents:
if getattr(content, "type", None) == "text" and content.text:
try:
result = json.loads(content.text)
if not isinstance(result, dict):
continue
# confirm_changes results have 'accepted' and 'steps' keys
if "accepted" in result and "steps" in result:
return True
Expand All @@ -548,13 +585,19 @@ def _handle_step_based_approval(messages: list[Any]) -> list[BaseEvent]:
message = "Acknowledged."
else:
try:
result = json.loads(approval_text)
accepted = result.get("accepted", False)
steps = result.get("steps", [])
parsed_result = json.loads(approval_text)
result: dict[str, Any] = cast(dict[str, Any], parsed_result) if isinstance(parsed_result, dict) else {}
accepted = bool(result.get("accepted", False))
steps_raw = result.get("steps", [])
steps: list[dict[str, Any]] = []
if isinstance(steps_raw, list):
for step_raw in cast(list[Any], steps_raw):
if isinstance(step_raw, dict):
steps.append(cast(dict[str, Any], step_raw))

if accepted:
# Generate acceptance message with step descriptions
enabled_steps = [s for s in steps if s.get("status") == "enabled"]
enabled_steps: list[dict[str, Any]] = [step for step in steps if step.get("status") == "enabled"]
if enabled_steps:
message_parts = [f"Executing {len(enabled_steps)} approved steps:\n\n"]
for i, step in enumerate(enabled_steps, 1):
Expand Down Expand Up @@ -678,8 +721,9 @@ def _convert_approval_results_to_tool_messages(messages: list[Any]) -> None:
result.append(msg)
continue

function_results = [c for c in (msg.contents or []) if getattr(c, "type", None) == "function_result"]
other_contents = [c for c in (msg.contents or []) if getattr(c, "type", None) != "function_result"]
msg_contents = cast(list[Content], getattr(msg, "contents", None) or [])
function_results: list[Content] = [content for content in msg_contents if content.type == "function_result"]
other_contents: list[Content] = [content for content in msg_contents if content.type != "function_result"]

if not function_results:
result.append(msg)
Expand All @@ -695,7 +739,7 @@ def _convert_approval_results_to_tool_messages(messages: list[Any]) -> None:

# Then user message with remaining content (if any)
if other_contents:
result.append(Message(role=msg.role, contents=other_contents))
result.append(Message(role="user", contents=other_contents))

messages[:] = result

Expand Down Expand Up @@ -765,21 +809,24 @@ async def run_agent_stream(
if input_data.get("state"):
flow.current_state = dict(input_data["state"])

state_schema = cast(dict[str, Any], getattr(config, "state_schema", {}) or {})
predict_state_config = cast(dict[str, dict[str, str]], getattr(config, "predict_state_config", {}) or {})

# Apply schema defaults for missing state keys
if config.state_schema:
for key, schema in config.state_schema.items():
if state_schema:
for key, schema in state_schema.items():
if key in flow.current_state:
continue
if isinstance(schema, dict) and schema.get("type") == "array":
if isinstance(schema, dict) and cast(dict[str, Any], schema).get("type") == "array":
flow.current_state[key] = []
else:
flow.current_state[key] = {}

# Initialize predictive state handler if configured
predictive_handler: PredictiveStateHandler | None = None
if config.predict_state_config:
if predict_state_config:
predictive_handler = PredictiveStateHandler(
predict_state_config=config.predict_state_config,
predict_state_config=predict_state_config,
current_state=flow.current_state,
)

Expand All @@ -789,11 +836,11 @@ async def run_agent_stream(

# Check for structured output mode (skip text content)
skip_text = False
response_format = None
from agent_framework import Agent

if isinstance(agent, Agent):
response_format = agent.default_options.get("response_format")
response_format: type[Any] | None = None
default_options = getattr(agent, "default_options", None)
if isinstance(default_options, dict):
typed_default_options = cast(dict[str, Any], default_options)
response_format = cast(type[Any] | None, typed_default_options.get("response_format"))
skip_text = response_format is not None

# Handle empty messages (emit RunStarted immediately since no agent response)
Expand Down Expand Up @@ -831,8 +878,9 @@ async def run_agent_stream(
run_kwargs["tools"] = tools
# Filter out AG-UI internal metadata keys before passing to chat client
# These are used internally for orchestration and should not be sent to the LLM provider
client_metadata = {
k: v for k, v in (getattr(session, "metadata", None) or {}).items() if k not in AG_UI_INTERNAL_METADATA_KEYS
session_metadata = cast(dict[str, Any], getattr(session, "metadata", None) or {})
client_metadata: dict[str, Any] = {
k: v for k, v in session_metadata.items() if k not in AG_UI_INTERNAL_METADATA_KEYS
}
safe_metadata = _build_safe_metadata(client_metadata) if client_metadata else {}
if safe_metadata:
Expand Down Expand Up @@ -863,19 +911,14 @@ async def run_agent_stream(

# Inject state context message so the model knows current application state
# This is critical for shared state scenarios where the UI state needs to be visible
if config.state_schema and flow.current_state:
messages = _inject_state_context(messages, flow.current_state, config.state_schema)
if state_schema and flow.current_state:
messages = _inject_state_context(messages, flow.current_state, state_schema)

# Stream from agent - emit RunStarted after first update to get service IDs
run_started_emitted = False
all_updates: list[Any] = [] # Collect for structured output processing
response_stream = agent.run(messages, stream=True, **run_kwargs)
if isinstance(response_stream, ResponseStream):
stream = response_stream
else:
stream = await cast(Awaitable[ResponseStream[Any, Any]], response_stream)
if not isinstance(stream, ResponseStream):
raise AgentExecutionException("Chat client did not return a ResponseStream.")
stream = await _normalize_response_stream(response_stream)
async for update in stream:
# Collect updates for structured output processing
if response_format is not None:
Expand All @@ -891,18 +934,18 @@ async def run_agent_stream(
# NOW emit RunStarted with proper IDs
yield RunStartedEvent(run_id=run_id, thread_id=thread_id)
# Emit PredictState custom event if configured
if config.predict_state_config:
if predict_state_config:
predict_state_value = [
{
"state_key": state_key,
"tool": cfg["tool"],
"tool_argument": cfg["tool_argument"],
}
for state_key, cfg in config.predict_state_config.items()
for state_key, cfg in predict_state_config.items()
]
yield CustomEvent(name="PredictState", value=predict_state_value)
# Emit initial state snapshot only if we have both state_schema and state
if config.state_schema and flow.current_state:
if state_schema and flow.current_state:
yield StateSnapshotEvent(snapshot=flow.current_state)
run_started_emitted = True

Expand Down Expand Up @@ -933,49 +976,51 @@ async def run_agent_stream(
# If no updates at all, still emit RunStarted
if not run_started_emitted:
yield RunStartedEvent(run_id=run_id, thread_id=thread_id)
if config.predict_state_config:
if predict_state_config:
predict_state_value = [
{
"state_key": state_key,
"tool": cfg["tool"],
"tool_argument": cfg["tool_argument"],
}
for state_key, cfg in config.predict_state_config.items()
for state_key, cfg in predict_state_config.items()
]
yield CustomEvent(name="PredictState", value=predict_state_value)
if config.state_schema and flow.current_state:
if state_schema and flow.current_state:
yield StateSnapshotEvent(snapshot=flow.current_state)

# Process structured output if response_format is set
if response_format is not None and all_updates:
from agent_framework import AgentResponse
from pydantic import BaseModel

logger.info(f"Processing structured output, update count: {len(all_updates)}")
final_response = AgentResponse.from_updates(all_updates, output_format_type=response_format)

if final_response.value and isinstance(final_response.value, BaseModel):
response_dict = final_response.value.model_dump(mode="json", exclude_none=True)
logger.info(f"Received structured output keys: {list(response_dict.keys())}")

# Extract state updates - if no state_schema, all non-message fields are state
state_keys = (
set(config.state_schema.keys()) if config.state_schema else set(response_dict.keys()) - {"message"}
)
state_updates = {k: v for k, v in response_dict.items() if k in state_keys}

if state_updates:
flow.current_state.update(state_updates)
yield StateSnapshotEvent(snapshot=flow.current_state)
logger.info(f"Emitted StateSnapshotEvent with updates: {list(state_updates.keys())}")

# Emit message field as text if present
if "message" in response_dict and response_dict["message"]:
message_id = generate_event_id()
yield TextMessageStartEvent(message_id=message_id, role="assistant")
yield TextMessageContentEvent(message_id=message_id, delta=response_dict["message"])
yield TextMessageEndEvent(message_id=message_id)
logger.info(f"Emitted conversational message with length={len(response_dict['message'])}")
if not (isinstance(response_format, type) and issubclass(response_format, BaseModel)):
logger.warning("Skipping structured output parsing: response_format is not a Pydantic model type.")
else:
logger.info(f"Processing structured output, update count: {len(all_updates)}")
final_response = AgentResponse.from_updates(all_updates, output_format_type=response_format)

if final_response.value and isinstance(final_response.value, BaseModel):
response_dict = final_response.value.model_dump(mode="json", exclude_none=True)
logger.info(f"Received structured output keys: {list(response_dict.keys())}")

# Extract state updates - if no state_schema, all non-message fields are state
state_keys = set(state_schema.keys()) if state_schema else set(response_dict.keys()) - {"message"}
state_updates = {k: v for k, v in response_dict.items() if k in state_keys}

if state_updates:
flow.current_state.update(state_updates)
yield StateSnapshotEvent(snapshot=flow.current_state)
logger.info(f"Emitted StateSnapshotEvent with updates: {list(state_updates.keys())}")

# Emit message field as text if present
message_text = response_dict.get("message")
if isinstance(message_text, str) and message_text:
message_id = generate_event_id()
yield TextMessageStartEvent(message_id=message_id, role="assistant")
yield TextMessageContentEvent(message_id=message_id, delta=message_text)
yield TextMessageEndEvent(message_id=message_id)
logger.info(f"Emitted conversational message with length={len(message_text)}")

# Feature #1: Emit ToolCallEndEvent for declaration-only tools (tools without results)
pending_without_end = flow.get_pending_without_end()
Expand All @@ -989,8 +1034,8 @@ async def run_agent_stream(
yield ToolCallEndEvent(tool_call_id=tool_call_id)

# For predictive tools with require_confirmation, emit confirm_changes
if config.require_confirmation and config.predict_state_config and tool_name:
is_predictive_tool = any(cfg["tool"] == tool_name for cfg in config.predict_state_config.values())
if config.require_confirmation and predict_state_config and tool_name:
is_predictive_tool = any(cfg["tool"] == tool_name for cfg in predict_state_config.values())
if is_predictive_tool:
logger.info(f"Emitting confirm_changes for predictive tool '{tool_name}'")
# Extract state value from tool arguments for StateSnapshot
Expand Down Expand Up @@ -1071,7 +1116,7 @@ async def run_agent_stream(
last_call_id = last_result.get("toolCallId")
last_tool_name = flow.get_tool_name(last_call_id)
if not _should_suppress_intermediate_snapshot(
last_tool_name, config.predict_state_config, config.require_confirmation
last_tool_name, predict_state_config, config.require_confirmation
):
yield _build_messages_snapshot(flow, snapshot_messages)

Expand Down
23 changes: 23 additions & 0 deletions python/packages/ag-ui/tests/ag_ui/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import pytest
from agent_framework import Agent, ChatResponseUpdate, Content
from agent_framework.orchestrations import SequentialBuilder
from fastapi import FastAPI, Header, HTTPException
from fastapi.params import Depends
from fastapi.testclient import TestClient
Expand Down Expand Up @@ -165,6 +166,28 @@ async def test_endpoint_event_streaming(build_chat_client):
assert found_run_finished


async def test_endpoint_with_workflow_as_agent_stream_output(build_chat_client):
"""Test endpoint handles workflow-as-agent stream outputs."""
app = FastAPI()
brainstorm_agent = Agent(name="brainstorm", instructions="Brainstorm ideas", client=build_chat_client("Idea"))
reviewer_agent = Agent(name="reviewer", instructions="Review ideas", client=build_chat_client("Review"))
agent = SequentialBuilder(participants=[brainstorm_agent, reviewer_agent]).build().as_agent()

add_agent_framework_fastapi_endpoint(app, agent, path="/workflow-like")

client = TestClient(app)
response = client.post("/workflow-like", json={"messages": [{"role": "user", "content": "Hello"}]})

assert response.status_code == 200
content = response.content.decode("utf-8")
lines = [line for line in content.split("\n") if line.startswith("data: ")]
event_types = [json.loads(line[6:]).get("type") for line in lines]

assert "RUN_STARTED" in event_types
assert "TEXT_MESSAGE_CONTENT" in event_types
assert "RUN_FINISHED" in event_types


async def test_endpoint_error_handling(build_chat_client):
"""Test endpoint error handling during request parsing."""
app = FastAPI()
Expand Down
Loading