Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions veadk/toolkits/apps/reverse_mcp/server_with_reverse_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
import asyncio
import json
import uuid
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Callable, Optional

from fastapi import FastAPI, HTTPException, Request, Response, WebSocket
from fastapi.responses import StreamingResponse
from google.adk.agents.run_config import StreamingMode
from google.adk.agents.run_config import RunConfig, StreamingMode
from google.adk.artifacts import InMemoryArtifactService
from google.adk.cli.adk_web_server import RunAgentRequest
from google.adk.runners import Runner as GoogleRunner, RunConfig
from google.adk.runners import Runner as GoogleRunner
from google.adk.sessions import InMemorySessionService, Session
from google.adk.tools.mcp_tool.mcp_session_manager import (
StreamableHTTPConnectionParams,
Expand All @@ -42,6 +42,12 @@
REVERSE_MCP_HEADER_KEY = "X-Reverse-MCP-ID"


class ExtraRoute(BaseModel):
path: str
endpoint: Callable
methods: list[str]


class WebsocketSessionManager:
def __init__(self):
# ws id -> ws instance
Expand Down Expand Up @@ -93,13 +99,21 @@ def __init__(
agent: "Agent",
host: str = "0.0.0.0",
port: int = 8000,
extra_routes: list[ExtraRoute] | None = None,
):
self.agent = agent

self.host = host
self.port = port

self.app = FastAPI()
self.extra_routes = extra_routes

self.app = FastAPI(
openapi_url=None,
docs_url=None,
redoc_url=None,
swagger_ui_oauth2_redirect_url=None,
)

self.artifact_service = InMemoryArtifactService()

Expand Down Expand Up @@ -215,7 +229,8 @@ def _get_session_service(websocket_id: str) -> InMemorySessionService:
"""Get session service for the websocket client."""
if websocket_id not in self.ws_session_service_mgr:
raise HTTPException(
status_code=404, detail=f"WebSocket client {websocket_id} not found"
status_code=404,
detail=f"WebSocket client {websocket_id} not found",
)
return self.ws_session_service_mgr[websocket_id]

Expand Down Expand Up @@ -276,7 +291,9 @@ async def create_session_with_id(
return session

@self.app.post("/run_sse")
async def run_agent_sse(req: RunAgentRequestWithWsId) -> StreamingResponse:
async def run_agent_sse(
req: RunAgentRequestWithWsId,
) -> StreamingResponse:
"""Run agent with SSE streaming."""
session_service = _get_session_service(req.websocket_id)

Expand Down Expand Up @@ -337,7 +354,10 @@ async def event_generator():
content_event.actions.artifact_delta = {}
artifact_event = event.model_copy(deep=True)
artifact_event.content = None
events_to_stream = [content_event, artifact_event]
events_to_stream = [
content_event,
artifact_event,
]

for event_to_stream in events_to_stream:
sse_event = event_to_stream.model_dump_json(
Expand All @@ -354,6 +374,14 @@ async def event_generator():
media_type="text/event-stream",
)

if self.extra_routes:
for route in self.extra_routes:
self.app.add_api_route(
path=route.path,
endpoint=route.endpoint,
methods=route.methods,
)

# build the fake MPC server,
# and intercept all requests to the client websocket client.
# NOTE: This catch-all route must be defined LAST
Expand Down