diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_run.py index 8ab9fa0956..46f2dd4829 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_run.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_run.py @@ -7,7 +7,7 @@ import json import logging import uuid -from collections.abc import Awaitable +from collections.abc import AsyncIterable, Awaitable from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, cast @@ -172,12 +172,12 @@ class FlowState: tool_call_id: str | None = None # Current tool call being streamed tool_call_name: str | None = None # Name of current tool call waiting_for_approval: bool = False # Stop after approval request - current_state: dict[str, Any] = field(default_factory=dict) # Shared state + current_state: dict[str, Any] = field(default_factory=dict) # pyright: ignore[reportUnknownVariableType] accumulated_text: str = "" # For MessagesSnapshotEvent - pending_tool_calls: list[dict[str, Any]] = field(default_factory=list) # For MessagesSnapshotEvent - tool_calls_by_id: dict[str, dict[str, Any]] = field(default_factory=dict) - tool_results: list[dict[str, Any]] = field(default_factory=list) - tool_calls_ended: set[str] = field(default_factory=set) # Track which tool calls have been ended + pending_tool_calls: list[dict[str, Any]] = field(default_factory=list) # pyright: ignore[reportUnknownVariableType] + tool_calls_by_id: dict[str, dict[str, Any]] = field(default_factory=dict) # pyright: ignore[reportUnknownVariableType] + tool_results: list[dict[str, Any]] = field(default_factory=list) # pyright: ignore[reportUnknownVariableType] + tool_calls_ended: set[str] = field(default_factory=set) # pyright: ignore[reportUnknownVariableType] def get_tool_name(self, call_id: str | None) -> str | None: """Get tool name by call ID.""" @@ -191,6 +191,40 @@ def get_pending_without_end(self) -> list[dict[str, Any]]: return [tc for tc in self.pending_tool_calls if tc.get("id") not in self.tool_calls_ended] +async def _normalize_response_stream(response_stream: Any) -> AsyncIterable[Any]: + """Normalize agent streaming return types to an async iterable. + + Supports: + - ResponseStream (standard agent stream type) + - AsyncIterable[AgentResponseUpdate] (workflow-style stream) + - Awaitable that resolves to either of the above + """ + if isinstance(response_stream, Awaitable): + resolved_stream = await cast(Awaitable[Any], response_stream) + if isinstance(resolved_stream, ResponseStream): + # AG-UI consumes update iteration only; ResponseStream finalizers are not used here. + return cast(AsyncIterable[Any], resolved_stream) + if isinstance(resolved_stream, AsyncIterable): + return cast(AsyncIterable[Any], resolved_stream) + resolved_type = f"{type(resolved_stream).__module__}.{type(resolved_stream).__name__}" + raise AgentExecutionException( + "Agent did not return a streaming AsyncIterable response. " + f"Awaitable resolved to unsupported type: {resolved_type}." + ) + + if isinstance(response_stream, ResponseStream): + # AG-UI consumes update iteration only; ResponseStream finalizers are not used here. + return cast(AsyncIterable[Any], response_stream) + + if isinstance(response_stream, AsyncIterable): + return cast(AsyncIterable[Any], response_stream) + + stream_type = f"{type(response_stream).__module__}.{type(response_stream).__name__}" + raise AgentExecutionException( + f"Agent did not return a streaming AsyncIterable response. Received unsupported type: {stream_type}." + ) + + def _create_state_context_message( current_state: dict[str, Any], state_schema: dict[str, Any], @@ -460,7 +494,7 @@ def _emit_approval_request( parent_message_id=flow.message_id, ) ) - args = { + args: dict[str, Any] = { "function_name": func_name, "function_call_id": func_call_id, "function_arguments": make_json_safe(func_call.parse_arguments()) or {}, @@ -515,7 +549,8 @@ def _is_confirm_changes_response(messages: list[Any]) -> bool: if not messages: return False last = messages[-1] - if not last.additional_properties.get("is_tool_result", False): + additional_properties = cast(dict[str, Any], getattr(last, "additional_properties", {}) or {}) + if not additional_properties.get("is_tool_result", False): return False # Parse the content to check if it has the confirm_changes structure @@ -523,6 +558,8 @@ def _is_confirm_changes_response(messages: list[Any]) -> bool: if getattr(content, "type", None) == "text" and content.text: try: result = json.loads(content.text) + if not isinstance(result, dict): + continue # confirm_changes results have 'accepted' and 'steps' keys if "accepted" in result and "steps" in result: return True @@ -548,13 +585,19 @@ def _handle_step_based_approval(messages: list[Any]) -> list[BaseEvent]: message = "Acknowledged." else: try: - result = json.loads(approval_text) - accepted = result.get("accepted", False) - steps = result.get("steps", []) + parsed_result = json.loads(approval_text) + result: dict[str, Any] = cast(dict[str, Any], parsed_result) if isinstance(parsed_result, dict) else {} + accepted = bool(result.get("accepted", False)) + steps_raw = result.get("steps", []) + steps: list[dict[str, Any]] = [] + if isinstance(steps_raw, list): + for step_raw in cast(list[Any], steps_raw): + if isinstance(step_raw, dict): + steps.append(cast(dict[str, Any], step_raw)) if accepted: # Generate acceptance message with step descriptions - enabled_steps = [s for s in steps if s.get("status") == "enabled"] + enabled_steps: list[dict[str, Any]] = [step for step in steps if step.get("status") == "enabled"] if enabled_steps: message_parts = [f"Executing {len(enabled_steps)} approved steps:\n\n"] for i, step in enumerate(enabled_steps, 1): @@ -678,8 +721,9 @@ def _convert_approval_results_to_tool_messages(messages: list[Any]) -> None: result.append(msg) continue - function_results = [c for c in (msg.contents or []) if getattr(c, "type", None) == "function_result"] - other_contents = [c for c in (msg.contents or []) if getattr(c, "type", None) != "function_result"] + msg_contents = cast(list[Content], getattr(msg, "contents", None) or []) + function_results: list[Content] = [content for content in msg_contents if content.type == "function_result"] + other_contents: list[Content] = [content for content in msg_contents if content.type != "function_result"] if not function_results: result.append(msg) @@ -695,7 +739,7 @@ def _convert_approval_results_to_tool_messages(messages: list[Any]) -> None: # Then user message with remaining content (if any) if other_contents: - result.append(Message(role=msg.role, contents=other_contents)) + result.append(Message(role="user", contents=other_contents)) messages[:] = result @@ -765,21 +809,24 @@ async def run_agent_stream( if input_data.get("state"): flow.current_state = dict(input_data["state"]) + state_schema = cast(dict[str, Any], getattr(config, "state_schema", {}) or {}) + predict_state_config = cast(dict[str, dict[str, str]], getattr(config, "predict_state_config", {}) or {}) + # Apply schema defaults for missing state keys - if config.state_schema: - for key, schema in config.state_schema.items(): + if state_schema: + for key, schema in state_schema.items(): if key in flow.current_state: continue - if isinstance(schema, dict) and schema.get("type") == "array": + if isinstance(schema, dict) and cast(dict[str, Any], schema).get("type") == "array": flow.current_state[key] = [] else: flow.current_state[key] = {} # Initialize predictive state handler if configured predictive_handler: PredictiveStateHandler | None = None - if config.predict_state_config: + if predict_state_config: predictive_handler = PredictiveStateHandler( - predict_state_config=config.predict_state_config, + predict_state_config=predict_state_config, current_state=flow.current_state, ) @@ -789,11 +836,11 @@ async def run_agent_stream( # Check for structured output mode (skip text content) skip_text = False - response_format = None - from agent_framework import Agent - - if isinstance(agent, Agent): - response_format = agent.default_options.get("response_format") + response_format: type[Any] | None = None + default_options = getattr(agent, "default_options", None) + if isinstance(default_options, dict): + typed_default_options = cast(dict[str, Any], default_options) + response_format = cast(type[Any] | None, typed_default_options.get("response_format")) skip_text = response_format is not None # Handle empty messages (emit RunStarted immediately since no agent response) @@ -831,8 +878,9 @@ async def run_agent_stream( run_kwargs["tools"] = tools # Filter out AG-UI internal metadata keys before passing to chat client # These are used internally for orchestration and should not be sent to the LLM provider - client_metadata = { - k: v for k, v in (getattr(session, "metadata", None) or {}).items() if k not in AG_UI_INTERNAL_METADATA_KEYS + session_metadata = cast(dict[str, Any], getattr(session, "metadata", None) or {}) + client_metadata: dict[str, Any] = { + k: v for k, v in session_metadata.items() if k not in AG_UI_INTERNAL_METADATA_KEYS } safe_metadata = _build_safe_metadata(client_metadata) if client_metadata else {} if safe_metadata: @@ -863,19 +911,14 @@ async def run_agent_stream( # Inject state context message so the model knows current application state # This is critical for shared state scenarios where the UI state needs to be visible - if config.state_schema and flow.current_state: - messages = _inject_state_context(messages, flow.current_state, config.state_schema) + if state_schema and flow.current_state: + messages = _inject_state_context(messages, flow.current_state, state_schema) # Stream from agent - emit RunStarted after first update to get service IDs run_started_emitted = False all_updates: list[Any] = [] # Collect for structured output processing response_stream = agent.run(messages, stream=True, **run_kwargs) - if isinstance(response_stream, ResponseStream): - stream = response_stream - else: - stream = await cast(Awaitable[ResponseStream[Any, Any]], response_stream) - if not isinstance(stream, ResponseStream): - raise AgentExecutionException("Chat client did not return a ResponseStream.") + stream = await _normalize_response_stream(response_stream) async for update in stream: # Collect updates for structured output processing if response_format is not None: @@ -891,18 +934,18 @@ async def run_agent_stream( # NOW emit RunStarted with proper IDs yield RunStartedEvent(run_id=run_id, thread_id=thread_id) # Emit PredictState custom event if configured - if config.predict_state_config: + if predict_state_config: predict_state_value = [ { "state_key": state_key, "tool": cfg["tool"], "tool_argument": cfg["tool_argument"], } - for state_key, cfg in config.predict_state_config.items() + for state_key, cfg in predict_state_config.items() ] yield CustomEvent(name="PredictState", value=predict_state_value) # Emit initial state snapshot only if we have both state_schema and state - if config.state_schema and flow.current_state: + if state_schema and flow.current_state: yield StateSnapshotEvent(snapshot=flow.current_state) run_started_emitted = True @@ -933,17 +976,17 @@ async def run_agent_stream( # If no updates at all, still emit RunStarted if not run_started_emitted: yield RunStartedEvent(run_id=run_id, thread_id=thread_id) - if config.predict_state_config: + if predict_state_config: predict_state_value = [ { "state_key": state_key, "tool": cfg["tool"], "tool_argument": cfg["tool_argument"], } - for state_key, cfg in config.predict_state_config.items() + for state_key, cfg in predict_state_config.items() ] yield CustomEvent(name="PredictState", value=predict_state_value) - if config.state_schema and flow.current_state: + if state_schema and flow.current_state: yield StateSnapshotEvent(snapshot=flow.current_state) # Process structured output if response_format is set @@ -951,31 +994,33 @@ async def run_agent_stream( from agent_framework import AgentResponse from pydantic import BaseModel - logger.info(f"Processing structured output, update count: {len(all_updates)}") - final_response = AgentResponse.from_updates(all_updates, output_format_type=response_format) - - if final_response.value and isinstance(final_response.value, BaseModel): - response_dict = final_response.value.model_dump(mode="json", exclude_none=True) - logger.info(f"Received structured output keys: {list(response_dict.keys())}") - - # Extract state updates - if no state_schema, all non-message fields are state - state_keys = ( - set(config.state_schema.keys()) if config.state_schema else set(response_dict.keys()) - {"message"} - ) - state_updates = {k: v for k, v in response_dict.items() if k in state_keys} - - if state_updates: - flow.current_state.update(state_updates) - yield StateSnapshotEvent(snapshot=flow.current_state) - logger.info(f"Emitted StateSnapshotEvent with updates: {list(state_updates.keys())}") - - # Emit message field as text if present - if "message" in response_dict and response_dict["message"]: - message_id = generate_event_id() - yield TextMessageStartEvent(message_id=message_id, role="assistant") - yield TextMessageContentEvent(message_id=message_id, delta=response_dict["message"]) - yield TextMessageEndEvent(message_id=message_id) - logger.info(f"Emitted conversational message with length={len(response_dict['message'])}") + if not (isinstance(response_format, type) and issubclass(response_format, BaseModel)): + logger.warning("Skipping structured output parsing: response_format is not a Pydantic model type.") + else: + logger.info(f"Processing structured output, update count: {len(all_updates)}") + final_response = AgentResponse.from_updates(all_updates, output_format_type=response_format) + + if final_response.value and isinstance(final_response.value, BaseModel): + response_dict = final_response.value.model_dump(mode="json", exclude_none=True) + logger.info(f"Received structured output keys: {list(response_dict.keys())}") + + # Extract state updates - if no state_schema, all non-message fields are state + state_keys = set(state_schema.keys()) if state_schema else set(response_dict.keys()) - {"message"} + state_updates = {k: v for k, v in response_dict.items() if k in state_keys} + + if state_updates: + flow.current_state.update(state_updates) + yield StateSnapshotEvent(snapshot=flow.current_state) + logger.info(f"Emitted StateSnapshotEvent with updates: {list(state_updates.keys())}") + + # Emit message field as text if present + message_text = response_dict.get("message") + if isinstance(message_text, str) and message_text: + message_id = generate_event_id() + yield TextMessageStartEvent(message_id=message_id, role="assistant") + yield TextMessageContentEvent(message_id=message_id, delta=message_text) + yield TextMessageEndEvent(message_id=message_id) + logger.info(f"Emitted conversational message with length={len(message_text)}") # Feature #1: Emit ToolCallEndEvent for declaration-only tools (tools without results) pending_without_end = flow.get_pending_without_end() @@ -989,8 +1034,8 @@ async def run_agent_stream( yield ToolCallEndEvent(tool_call_id=tool_call_id) # For predictive tools with require_confirmation, emit confirm_changes - if config.require_confirmation and config.predict_state_config and tool_name: - is_predictive_tool = any(cfg["tool"] == tool_name for cfg in config.predict_state_config.values()) + if config.require_confirmation and predict_state_config and tool_name: + is_predictive_tool = any(cfg["tool"] == tool_name for cfg in predict_state_config.values()) if is_predictive_tool: logger.info(f"Emitting confirm_changes for predictive tool '{tool_name}'") # Extract state value from tool arguments for StateSnapshot @@ -1071,7 +1116,7 @@ async def run_agent_stream( last_call_id = last_result.get("toolCallId") last_tool_name = flow.get_tool_name(last_call_id) if not _should_suppress_intermediate_snapshot( - last_tool_name, config.predict_state_config, config.require_confirmation + last_tool_name, predict_state_config, config.require_confirmation ): yield _build_messages_snapshot(flow, snapshot_messages) diff --git a/python/packages/ag-ui/tests/ag_ui/test_endpoint.py b/python/packages/ag-ui/tests/ag_ui/test_endpoint.py index 6d80fff588..dfdab0f07c 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_endpoint.py +++ b/python/packages/ag-ui/tests/ag_ui/test_endpoint.py @@ -6,6 +6,7 @@ import pytest from agent_framework import Agent, ChatResponseUpdate, Content +from agent_framework.orchestrations import SequentialBuilder from fastapi import FastAPI, Header, HTTPException from fastapi.params import Depends from fastapi.testclient import TestClient @@ -165,6 +166,28 @@ async def test_endpoint_event_streaming(build_chat_client): assert found_run_finished +async def test_endpoint_with_workflow_as_agent_stream_output(build_chat_client): + """Test endpoint handles workflow-as-agent stream outputs.""" + app = FastAPI() + brainstorm_agent = Agent(name="brainstorm", instructions="Brainstorm ideas", client=build_chat_client("Idea")) + reviewer_agent = Agent(name="reviewer", instructions="Review ideas", client=build_chat_client("Review")) + agent = SequentialBuilder(participants=[brainstorm_agent, reviewer_agent]).build().as_agent() + + add_agent_framework_fastapi_endpoint(app, agent, path="/workflow-like") + + client = TestClient(app) + response = client.post("/workflow-like", json={"messages": [{"role": "user", "content": "Hello"}]}) + + assert response.status_code == 200 + content = response.content.decode("utf-8") + lines = [line for line in content.split("\n") if line.startswith("data: ")] + event_types = [json.loads(line[6:]).get("type") for line in lines] + + assert "RUN_STARTED" in event_types + assert "TEXT_MESSAGE_CONTENT" in event_types + assert "RUN_FINISHED" in event_types + + async def test_endpoint_error_handling(build_chat_client): """Test endpoint error handling during request parsing.""" app = FastAPI() diff --git a/python/packages/ag-ui/tests/ag_ui/test_run.py b/python/packages/ag-ui/tests/ag_ui/test_run.py index 9d21bd2d0a..8cee6e4338 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_run.py +++ b/python/packages/ag-ui/tests/ag_ui/test_run.py @@ -2,11 +2,13 @@ """Tests for _run.py helper functions and FlowState.""" +import pytest from ag_ui.core import ( TextMessageEndEvent, TextMessageStartEvent, ) -from agent_framework import Content, Message +from agent_framework import AgentResponseUpdate, Content, Message, ResponseStream +from agent_framework.exceptions import AgentExecutionException from agent_framework_ag_ui._run import ( FlowState, @@ -16,6 +18,7 @@ _emit_tool_result, _has_only_tool_calls, _inject_state_context, + _normalize_response_stream, _should_suppress_intermediate_snapshot, ) @@ -179,6 +182,54 @@ def test_get_pending_without_end(self): assert result[0]["id"] == "call_2" +class TestNormalizeResponseStream: + """Tests for _normalize_response_stream helper.""" + + async def test_accepts_response_stream(self): + """Accept standard ResponseStream values.""" + + async def _stream(): + yield AgentResponseUpdate(contents=[Content.from_text("hello")], role="assistant") + + stream = await _normalize_response_stream(ResponseStream(_stream())) + updates = [update async for update in stream] + + assert len(updates) == 1 + assert updates[0].contents[0].text == "hello" + + async def test_accepts_async_iterable(self): + """Accept workflow-style async generator streams.""" + + async def _stream(): + yield AgentResponseUpdate(contents=[Content.from_text("hello")], role="assistant") + + stream = await _normalize_response_stream(_stream()) + updates = [update async for update in stream] + + assert len(updates) == 1 + assert updates[0].contents[0].text == "hello" + + async def test_accepts_awaitable_resolving_to_async_iterable(self): + """Accept awaitables that resolve to async iterable streams.""" + + async def _stream(): + yield AgentResponseUpdate(contents=[Content.from_text("hello")], role="assistant") + + async def _resolve(): + return _stream() + + stream = await _normalize_response_stream(_resolve()) + updates = [update async for update in stream] + + assert len(updates) == 1 + assert updates[0].contents[0].text == "hello" + + async def test_rejects_non_stream_values(self): + """Reject unsupported stream return values.""" + with pytest.raises(AgentExecutionException): + await _normalize_response_stream("not-a-stream") + + class TestCreateStateContextMessage: """Tests for _create_state_context_message function."""