Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/packages/ag-ui/agent_framework_ag_ui/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,28 @@
from ._endpoint import add_agent_framework_fastapi_endpoint
from ._event_converters import AGUIEventConverter
from ._http_service import AGUIHttpService
from ._types import AGUIRequest

try:
__version__ = importlib.metadata.version(__name__)
except importlib.metadata.PackageNotFoundError:
__version__ = "0.0.0"

# Default OpenAPI tags for AG-UI endpoints
DEFAULT_TAGS = ["AG-UI"]

__all__ = [
"AgentFrameworkAgent",
"add_agent_framework_fastapi_endpoint",
"AGUIChatClient",
"AGUIEventConverter",
"AGUIHttpService",
"AGUIRequest",
"ConfirmationStrategy",
"DefaultConfirmationStrategy",
"TaskPlannerConfirmationStrategy",
"RecipeConfirmationStrategy",
"DocumentWriterConfirmationStrategy",
"DEFAULT_TAGS",
"__version__",
]
11 changes: 7 additions & 4 deletions python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

from ag_ui.encoder import EventEncoder
from agent_framework import AgentProtocol
from fastapi import FastAPI, Request
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

from ._agent import AgentFrameworkAgent
from ._types import AGUIRequest

logger = logging.getLogger(__name__)

Expand All @@ -24,6 +25,7 @@ def add_agent_framework_fastapi_endpoint(
predict_state_config: dict[str, dict[str, str]] | None = None,
allow_origins: list[str] | None = None,
default_state: dict[str, Any] | None = None,
tags: list[str] | None = None,
) -> None:
"""Add an AG-UI endpoint to a FastAPI app.

Expand All @@ -36,6 +38,7 @@ def add_agent_framework_fastapi_endpoint(
Format: {"state_key": {"tool": "tool_name", "tool_argument": "arg_name"}}
allow_origins: CORS origins (not yet implemented)
default_state: Optional initial state to seed when the client does not provide state keys
tags: OpenAPI tags for endpoint categorization (defaults to ["AG-UI"])
"""
if isinstance(agent, AgentProtocol):
wrapped_agent = AgentFrameworkAgent(
Expand All @@ -46,15 +49,15 @@ def add_agent_framework_fastapi_endpoint(
else:
wrapped_agent = agent

@app.post(path)
async def agent_endpoint(request: Request): # type: ignore[misc]
@app.post(path, tags=tags or ["AG-UI"]) # type: ignore[arg-type]
async def agent_endpoint(request_body: AGUIRequest): # type: ignore[misc]
"""Handle AG-UI agent requests.

Note: Function is accessed via FastAPI's decorator registration,
despite appearing unused to static analysis.
"""
try:
input_data = await request.json()
input_data = request_body.model_dump(exclude_none=True)
if default_state:
state = input_data.setdefault("state", {})
for key, value in default_state.items():
Expand Down
23 changes: 23 additions & 0 deletions python/packages/ag-ui/agent_framework_ag_ui/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from typing import Any, TypedDict

from pydantic import BaseModel, Field


class PredictStateConfig(TypedDict):
"""Configuration for predictive state updates."""
Expand All @@ -25,3 +27,24 @@ class AgentState(TypedDict):
"""Base state for AG-UI agents."""

messages: list[Any] | None


class AGUIRequest(BaseModel):
"""Request model for AG-UI endpoints."""

messages: list[dict[str, Any]] = Field(
...,
description="AG-UI format messages array",
)
run_id: str | None = Field(
None,
description="Optional run identifier for tracking",
)
thread_id: str | None = Field(
None,
description="Optional thread identifier for conversation context",
)
state: dict[str, Any] | None = Field(
None,
description="Optional shared state for agentic generative UI",
)
125 changes: 120 additions & 5 deletions python/packages/ag-ui/tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,8 @@ async def test_endpoint_error_handling():
# Send invalid JSON to trigger parsing error before streaming
response = client.post("/failing", data=b"invalid json", headers={"content-type": "application/json"}) # type: ignore

# The exception handler catches it and returns JSON error
assert response.status_code == 200
content = json.loads(response.content)
assert "error" in content
assert content["error"] == "An internal error has occurred."
# Pydantic validation now returns 422 for invalid request body
assert response.status_code == 422


async def test_endpoint_multiple_paths():
Expand Down Expand Up @@ -266,3 +263,121 @@ async def test_endpoint_complex_input():
)

assert response.status_code == 200


async def test_endpoint_openapi_schema():
"""Test that endpoint generates proper OpenAPI schema with request model."""
app = FastAPI()
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())

add_agent_framework_fastapi_endpoint(app, agent, path="/schema-test")

client = TestClient(app)
response = client.get("/openapi.json")

assert response.status_code == 200
openapi_spec = response.json()

# Verify the endpoint exists in the schema
assert "/schema-test" in openapi_spec["paths"]
endpoint_spec = openapi_spec["paths"]["/schema-test"]["post"]

# Verify request body schema is defined
assert "requestBody" in endpoint_spec
request_body = endpoint_spec["requestBody"]
assert "content" in request_body
assert "application/json" in request_body["content"]

# Verify schema references AGUIRequest model
schema_ref = request_body["content"]["application/json"]["schema"]
assert "$ref" in schema_ref
assert "AGUIRequest" in schema_ref["$ref"]

# Verify AGUIRequest model is in components
assert "components" in openapi_spec
assert "schemas" in openapi_spec["components"]
assert "AGUIRequest" in openapi_spec["components"]["schemas"]

# Verify AGUIRequest has required fields
agui_request_schema = openapi_spec["components"]["schemas"]["AGUIRequest"]
assert "properties" in agui_request_schema
assert "messages" in agui_request_schema["properties"]
assert "run_id" in agui_request_schema["properties"]
assert "thread_id" in agui_request_schema["properties"]
assert "state" in agui_request_schema["properties"]
assert "required" in agui_request_schema
assert "messages" in agui_request_schema["required"]


async def test_endpoint_default_tags():
"""Test that endpoint uses default 'AG-UI' tag."""
app = FastAPI()
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())

add_agent_framework_fastapi_endpoint(app, agent, path="/default-tags")

client = TestClient(app)
response = client.get("/openapi.json")

assert response.status_code == 200
openapi_spec = response.json()

endpoint_spec = openapi_spec["paths"]["/default-tags"]["post"]
assert "tags" in endpoint_spec
assert endpoint_spec["tags"] == ["AG-UI"]


async def test_endpoint_custom_tags():
"""Test that endpoint accepts custom tags."""
app = FastAPI()
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())

add_agent_framework_fastapi_endpoint(app, agent, path="/custom-tags", tags=["Custom", "Agent"])

client = TestClient(app)
response = client.get("/openapi.json")

assert response.status_code == 200
openapi_spec = response.json()

endpoint_spec = openapi_spec["paths"]["/custom-tags"]["post"]
assert "tags" in endpoint_spec
assert endpoint_spec["tags"] == ["Custom", "Agent"]


async def test_endpoint_missing_required_field():
"""Test that endpoint validates required fields with Pydantic."""
app = FastAPI()
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())

add_agent_framework_fastapi_endpoint(app, agent, path="/validation")

client = TestClient(app)

# Missing required 'messages' field should trigger validation error
response = client.post("/validation", json={"run_id": "test-123"})

assert response.status_code == 422
error_detail = response.json()
assert "detail" in error_detail


async def test_endpoint_internal_error_handling():
"""Test endpoint error handling when an exception occurs before streaming starts."""
from unittest.mock import patch

app = FastAPI()
agent = ChatAgent(name="test", instructions="Test agent", chat_client=build_chat_client())

# Use default_state to trigger the code path that can raise an exception
add_agent_framework_fastapi_endpoint(app, agent, path="/error-test", default_state={"key": "value"})

client = TestClient(app)

# Mock copy.deepcopy to raise an exception during default_state processing
with patch("agent_framework_ag_ui._endpoint.copy.deepcopy") as mock_deepcopy:
mock_deepcopy.side_effect = Exception("Simulated internal error")
response = client.post("/error-test", json={"messages": [{"role": "user", "content": "Hello"}]})

assert response.status_code == 200
assert response.json() == {"error": "An internal error has occurred."}
Loading