|
| 1 | +"""AG-UI LangGraph single agent with Gateway MCP tools and Memory. |
| 2 | +
|
| 3 | +Uses ag-ui-langgraph to produce native AG-UI SSE events. |
| 4 | +AgentCore proxies these unchanged when deployed with --protocol AGUI. |
| 5 | +""" |
| 6 | + |
| 7 | +import logging |
| 8 | +import os |
| 9 | + |
| 10 | +from fastapi import FastAPI, Request |
| 11 | +from fastapi.responses import StreamingResponse |
| 12 | + |
| 13 | +from ag_ui.core.types import RunAgentInput |
| 14 | +from ag_ui.encoder import EventEncoder |
| 15 | +from ag_ui_langgraph import LangGraphAgent |
| 16 | + |
| 17 | +from langgraph.prebuilt import create_react_agent |
| 18 | +from langchain_aws import ChatBedrock |
| 19 | +from langchain_mcp_adapters.client import MultiServerMCPClient |
| 20 | +from langgraph_checkpoint_aws import AgentCoreMemorySaver |
| 21 | + |
| 22 | +from bedrock_agentcore.identity.auth import requires_access_token |
| 23 | +from utils.auth import extract_user_id_from_request, setup_agentcore_context |
| 24 | +from utils.ssm import get_ssm_parameter |
| 25 | +from patched_langgraph_agent import PatchedLangGraphAgent |
| 26 | + |
| 27 | +logger = logging.getLogger(__name__) |
| 28 | + |
| 29 | + |
| 30 | +@requires_access_token( |
| 31 | + provider_name=os.environ["GATEWAY_CREDENTIAL_PROVIDER_NAME"], |
| 32 | + auth_flow="M2M", |
| 33 | + scopes=[] |
| 34 | +) |
| 35 | +async def _fetch_gateway_token(access_token: str) -> str: |
| 36 | + return access_token |
| 37 | + |
| 38 | + |
| 39 | +async def create_gateway_mcp_client() -> MultiServerMCPClient: |
| 40 | + stack_name = os.environ.get("STACK_NAME") |
| 41 | + if not stack_name: |
| 42 | + raise ValueError("STACK_NAME environment variable is required") |
| 43 | + if not stack_name.replace("-", "").replace("_", "").isalnum(): |
| 44 | + raise ValueError("Invalid STACK_NAME format") |
| 45 | + |
| 46 | + gateway_url = get_ssm_parameter(f"/{stack_name}/gateway_url") |
| 47 | + logger.info("[AGUI-LG] Gateway URL: %s", gateway_url) |
| 48 | + |
| 49 | + fresh_token = await _fetch_gateway_token() |
| 50 | + logger.info("[AGUI-LG] Gateway token fetched (%d chars)", len(fresh_token)) |
| 51 | + return MultiServerMCPClient({ |
| 52 | + "gateway": { |
| 53 | + "transport": "streamable_http", |
| 54 | + "url": gateway_url, |
| 55 | + "headers": {"Authorization": f"Bearer {fresh_token}"} |
| 56 | + } |
| 57 | + }) |
| 58 | + |
| 59 | + |
| 60 | +async def create_agent(user_id: str) -> LangGraphAgent: |
| 61 | + system_prompt = """You are a helpful assistant with access to tools via the Gateway. |
| 62 | + When asked about your tools, list them and explain what they do.""" |
| 63 | + |
| 64 | + bedrock_model = ChatBedrock( |
| 65 | + model_id="us.anthropic.claude-sonnet-4-5-20250929-v1:0", |
| 66 | + temperature=0.1, |
| 67 | + streaming=True |
| 68 | + ) |
| 69 | + |
| 70 | + memory_id = os.environ.get("MEMORY_ID") |
| 71 | + if not memory_id: |
| 72 | + raise ValueError("MEMORY_ID environment variable is required") |
| 73 | + |
| 74 | + checkpointer = AgentCoreMemorySaver( |
| 75 | + memory_id=memory_id, |
| 76 | + region_name=os.environ.get("AWS_DEFAULT_REGION", "us-east-1") |
| 77 | + ) |
| 78 | + |
| 79 | + mcp_client = await create_gateway_mcp_client() |
| 80 | + tools = await mcp_client.get_tools() |
| 81 | + logger.info("[AGUI-LG] Loaded %d tools from Gateway", len(tools)) |
| 82 | + |
| 83 | + # Code Interpreter |
| 84 | + region = os.environ.get("AWS_DEFAULT_REGION", "us-east-1") |
| 85 | + try: |
| 86 | + from langgraph_code_interpreter import LangGraphCodeInterpreterTools |
| 87 | + code_tools = LangGraphCodeInterpreterTools(region) |
| 88 | + tools.append(code_tools.execute_python_securely) |
| 89 | + logger.info("[AGUI-LG] Code Interpreter loaded") |
| 90 | + except Exception as e: |
| 91 | + logger.warning("[AGUI-LG] Code Interpreter not available: %s", e) |
| 92 | + |
| 93 | + graph = create_react_agent( |
| 94 | + model=bedrock_model, |
| 95 | + tools=tools, |
| 96 | + checkpointer=checkpointer, |
| 97 | + prompt=system_prompt |
| 98 | + ) |
| 99 | + |
| 100 | + return PatchedLangGraphAgent( |
| 101 | + name="agui_langgraph_agent", |
| 102 | + graph=graph, |
| 103 | + description="AG-UI LangGraph agent with Gateway MCP tools and Memory", |
| 104 | + config={"configurable": {"actor_id": user_id}}, |
| 105 | + ) |
| 106 | + |
| 107 | + |
| 108 | +# --- Create app and register endpoint --- |
| 109 | + |
| 110 | +app = FastAPI(title="AG-UI LangGraph Agent") |
| 111 | + |
| 112 | + |
| 113 | +@app.post("/invocations") |
| 114 | +async def invocations(input_data: RunAgentInput, request: Request): |
| 115 | + try: |
| 116 | + setup_agentcore_context(request) |
| 117 | + user_id = extract_user_id_from_request(request) |
| 118 | + agent = await create_agent(user_id) |
| 119 | + |
| 120 | + encoder = EventEncoder(accept=request.headers.get("accept")) |
| 121 | + |
| 122 | + async def event_generator(): |
| 123 | + async for event in agent.run(input_data): |
| 124 | + yield encoder.encode(event) |
| 125 | + |
| 126 | + return StreamingResponse(event_generator(), media_type=encoder.get_content_type()) |
| 127 | + except Exception: |
| 128 | + logger.exception("[AGUI-LG] /invocations failed") |
| 129 | + raise |
| 130 | + |
| 131 | + |
| 132 | +@app.get("/ping") |
| 133 | +def ping(): |
| 134 | + return {"status": "Healthy"} |
| 135 | + |
| 136 | + |
| 137 | +if __name__ == "__main__": |
| 138 | + import uvicorn |
| 139 | + port = int(os.getenv("PORT", "8080")) |
| 140 | + uvicorn.run("langgraph_agent:app", host="0.0.0.0", port=port, reload=True) |
0 commit comments