diff --git a/.gitignore b/.gitignore index 2980152..df5f375 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .venv/ node_modules/ src/vendor/ -.vscode/ \ No newline at end of file +.vscode/ +.wrangler/ \ No newline at end of file diff --git a/README.md b/README.md index 0c2ea4e..8ceb50c 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,13 @@ -# Vendoring Packages: FastAPI + Jinja2 Example +# Python Workers: FastAPI-MCP Example -*Note: You must have Python Packages enabled on your account for built-in packages to work. Request Access to our Closed Beta using [This Form](https://forms.gle/FcjjhV3YtPyjRPaL8)* - -This is an example of a Python Worker that uses a built-in package (FastAPI) with a vendored package (Jinja2). +This is an example of a Python Worker that uses the FastAPI-MCP package. ## Adding Packages -Built-in packages can be selected from [this list](https://developers.cloudflare.com/workers/languages/python/packages/#supported-packages) and added to your `requirements.txt` file. These can be used with no other explicit install step. - Vendored packages are added to your source files and need to be installed in a special manner. The Python Workers team plans to make this process automatic in the future, but for now, manual steps need to be taken. ### Vendoring Packages -[//]: # (NOTE: when updating the instructions below, be sure to also update the vendoring.yml CI workflow) - First, install Python3.12 and pip for Python 3.12. *Currently, other versions of Python will not work - use 3.12!* @@ -30,34 +24,11 @@ Within our virtual environment, install the pyodide CLI: .venv/bin/pyodide venv .venv-pyodide ``` -Next, add packages to your vendor.txt file. Here we'll add jinja2 -``` -jinja2 -``` - -Lastly, add these packages to your source files at `src/vendor`. For any additional packages, re-run this command. +Lastly, download the vendored packages. For any additional packages, re-run this command. ```console .venv-pyodide/bin/pip install -t src/vendor -r vendor.txt ``` -### Using Vendored packages - -In your wrangler.toml, make the vendor directory available: - -```toml -[[rules]] -globs = ["vendor/**"] -type = "Data" -fallthrough = true -``` - -Now, you can import and use the packages: - -```python -import jinja2 -# ... etc ... -``` - ### Developing and Deploying To develop your Worker, run `npx wrangler@latest dev`. diff --git a/src/asgi.py b/src/asgi.py index 804b98e..f3220c7 100644 --- a/src/asgi.py +++ b/src/asgi.py @@ -1,159 +1,6 @@ -from asyncio import Future, Event, Queue, ensure_future, sleep, create_task +from asyncio import Event, Future, Queue, create_task, ensure_future, sleep from contextlib import contextmanager from inspect import isawaitable -import typing - -if typing.TYPE_CHECKING: - from typing import ( - Any, - Callable, - Literal, - Optional, - Protocol, - TypedDict, - Union, - NotRequired, - ) - from collections.abc import Awaitable, Iterable, MutableMapping - - class HTTPRequestEvent(TypedDict): - type: Literal["http.request"] - body: bytes - more_body: bool - - class HTTPResponseDebugEvent(TypedDict): - type: Literal["http.response.debug"] - info: dict[str, object] - - class HTTPResponseStartEvent(TypedDict): - type: Literal["http.response.start"] - status: int - headers: NotRequired[Iterable[tuple[bytes, bytes]]] - trailers: NotRequired[bool] - - class HTTPResponseBodyEvent(TypedDict): - type: Literal["http.response.body"] - body: bytes - more_body: NotRequired[bool] - - class HTTPResponseTrailersEvent(TypedDict): - type: Literal["http.response.trailers"] - headers: Iterable[tuple[bytes, bytes]] - more_trailers: bool - - class HTTPServerPushEvent(TypedDict): - type: Literal["http.response.push"] - path: str - headers: Iterable[tuple[bytes, bytes]] - - class HTTPDisconnectEvent(TypedDict): - type: Literal["http.disconnect"] - - class WebSocketConnectEvent(TypedDict): - type: Literal["websocket.connect"] - - class WebSocketAcceptEvent(TypedDict): - type: Literal["websocket.accept"] - subprotocol: NotRequired[str | None] - headers: NotRequired[Iterable[tuple[bytes, bytes]]] - - class _WebSocketReceiveEventBytes(TypedDict): - type: Literal["websocket.receive"] - bytes: bytes - text: NotRequired[None] - - class _WebSocketReceiveEventText(TypedDict): - type: Literal["websocket.receive"] - bytes: NotRequired[None] - text: str - - WebSocketReceiveEvent = Union[ - _WebSocketReceiveEventBytes, _WebSocketReceiveEventText - ] - - class _WebSocketSendEventBytes(TypedDict): - type: Literal["websocket.send"] - bytes: bytes - text: NotRequired[None] - - class _WebSocketSendEventText(TypedDict): - type: Literal["websocket.send"] - bytes: NotRequired[None] - text: str - - WebSocketSendEvent = Union[_WebSocketSendEventBytes, _WebSocketSendEventText] - - class WebSocketResponseStartEvent(TypedDict): - type: Literal["websocket.http.response.start"] - status: int - headers: Iterable[tuple[bytes, bytes]] - - class WebSocketResponseBodyEvent(TypedDict): - type: Literal["websocket.http.response.body"] - body: bytes - more_body: NotRequired[bool] - - class WebSocketDisconnectEvent(TypedDict): - type: Literal["websocket.disconnect"] - code: int - reason: NotRequired[str | None] - - class WebSocketCloseEvent(TypedDict): - type: Literal["websocket.close"] - code: NotRequired[int] - reason: NotRequired[str | None] - - class LifespanStartupEvent(TypedDict): - type: Literal["lifespan.startup"] - - class LifespanShutdownEvent(TypedDict): - type: Literal["lifespan.shutdown"] - - class LifespanStartupCompleteEvent(TypedDict): - type: Literal["lifespan.startup.complete"] - - class LifespanStartupFailedEvent(TypedDict): - type: Literal["lifespan.startup.failed"] - message: str - - class LifespanShutdownCompleteEvent(TypedDict): - type: Literal["lifespan.shutdown.complete"] - - class LifespanShutdownFailedEvent(TypedDict): - type: Literal["lifespan.shutdown.failed"] - message: str - - WebSocketEvent = Union[ - WebSocketReceiveEvent, WebSocketDisconnectEvent, WebSocketConnectEvent - ] - - ASGIReceiveEvent = Union[ - HTTPRequestEvent, - HTTPDisconnectEvent, - WebSocketConnectEvent, - WebSocketReceiveEvent, - WebSocketDisconnectEvent, - LifespanStartupEvent, - LifespanShutdownEvent, - ] - - ASGISendEvent = Union[ - HTTPResponseStartEvent, - HTTPResponseBodyEvent, - HTTPResponseTrailersEvent, - HTTPServerPushEvent, - HTTPDisconnectEvent, - WebSocketAcceptEvent, - WebSocketSendEvent, - WebSocketResponseStartEvent, - WebSocketResponseBodyEvent, - WebSocketCloseEvent, - LifespanStartupCompleteEvent, - LifespanStartupFailedEvent, - LifespanShutdownCompleteEvent, - LifespanShutdownFailedEvent, - ] - ASGI = {"spec_version": "2.0", "version": "3.0"} @@ -251,8 +98,9 @@ async def send(got): return shutdown -async def process_request(app, req, env): +async def process_request(app, req, env, ctx): from js import Object, Response, TransformStream + from pyodide.ffi import create_proxy status = None @@ -274,14 +122,12 @@ async def process_request(app, req, env): await receive_queue.put({"body": b"", "more_body": False, "type": "http.request"}) async def receive(): - print("Receiving") message = None if not receive_queue.empty(): message = await receive_queue.get() else: await finished_response.wait() message = {"type": "http.disconnect"} - print(f"Received {message}") return message # Create a transform stream for handling streaming responses @@ -290,12 +136,11 @@ async def receive(): writable = transform_stream.writable writer = writable.getWriter() - async def send(got: "ASGISendEvent"): + async def send(got): nonlocal status nonlocal headers nonlocal is_sse - print(got) if got["type"] == "http.response.start": status = got["status"] # Like above, we need to convert byte-pairs into string explicitly. @@ -305,20 +150,18 @@ async def send(got: "ASGISendEvent"): if k.lower() == "content-type" and v.lower().startswith( "text/event-stream" ): - print("SSE RESPONSE") is_sse = True - - # For SSE, create and return the response immediately after http.response.start - resp = Response.new( - readable, headers=Object.fromEntries(headers), status=status - ) - result.set_result(resp) break + if is_sse: + # For SSE, create and return the response immediately after http.response.start + resp = Response.new( + readable, headers=Object.fromEntries(headers), status=status + ) + result.set_result(resp) elif got["type"] == "http.response.body": body = got["body"] more_body = got.get("more_body", False) - print(f"{body=}, {more_body=}") # Convert body to JS buffer px = create_proxy(body) @@ -337,6 +180,7 @@ async def send(got: "ASGISendEvent"): buf.data, headers=Object.fromEntries(headers), status=status ) result.set_result(resp) + await writer.close() finished_response.set() # Run the application in the background to handle SSE @@ -346,17 +190,13 @@ async def run_app(): # If we get here and no response has been set yet, the app didn't generate a response if not result.done(): - await writer.close() # Close the writer - finished_response.set() - result.set_exception( - RuntimeError("The application did not generate a response") - ) + raise RuntimeError("The application did not generate a response") # noqa: TRY301 except Exception as e: # Handle any errors in the application if not result.done(): + result.set_exception(e) await writer.close() # Close the writer finished_response.set() - result.set_exception(e) # Create task to run the application in the background app_task = create_task(run_app()) @@ -367,7 +207,13 @@ async def run_app(): # For non-SSE responses, we need to wait for the application to complete if not is_sse: await app_task - print(f"Returning response! {is_sse}") + else: # noqa: PLR5501 + if ctx is not None: + ctx.waitUntil(create_proxy(app_task)) + else: + raise RuntimeError( + "Server-Side-Events require ctx to be passed to asgi.fetch" + ) return response @@ -423,9 +269,9 @@ async def ws_receive(): return Response.new(None, status=101, webSocket=client) -async def fetch(app, req, env): +async def fetch(app, req, env, ctx=None): shutdown = await start_application(app) - result = await process_request(app, req, env) + result = await process_request(app, req, env, ctx) await shutdown() return result diff --git a/src/exceptions.py b/src/exceptions.py index 90c71a6..5421d9d 100644 --- a/src/exceptions.py +++ b/src/exceptions.py @@ -1,8 +1,10 @@ +from logger import logger from starlette.exceptions import HTTPException from starlette.requests import Request from starlette.responses import PlainTextResponse, Response async def http_exception(request: Request, exc: Exception) -> Response: assert isinstance(exc, HTTPException) + logger.exception(exc) if exc.status_code in {204, 304}: return Response(status_code=exc.status_code, headers=exc.headers) return PlainTextResponse(exc.detail, status_code=exc.status_code, headers=exc.headers) diff --git a/src/httpx_patch.py b/src/httpx_patch.py new file mode 100644 index 0000000..877b461 --- /dev/null +++ b/src/httpx_patch.py @@ -0,0 +1,11 @@ +from httpx._transports.jsfetch import AsyncJavascriptFetchTransport + +orig_handle_async_request = AsyncJavascriptFetchTransport.handle_async_request + +async def handle_async_request(self, request): + response = await orig_handle_async_request(self, request) + # fix content-encoding headers because the javascript fetch handles that + response.headers.update({"content-encoding": "identity"}) + return response + +AsyncJavascriptFetchTransport.handle_async_request = handle_async_request \ No newline at end of file diff --git a/src/logger.py b/src/logger.py new file mode 100644 index 0000000..b4bf7c0 --- /dev/null +++ b/src/logger.py @@ -0,0 +1,39 @@ +import structlog +import logging +import sys + +# Create two handlers - one for stdout and one for stderr +stdout_handler = logging.StreamHandler(sys.stdout) +stderr_handler = logging.StreamHandler(sys.stderr) + +# Configure stdout handler to only handle INFO and DEBUG +stdout_handler.setLevel(logging.DEBUG) +stdout_handler.addFilter(lambda record: record.levelno <= logging.INFO) + +# Configure stderr handler to only handle WARNING and above +stderr_handler.setLevel(logging.WARNING) + +# Get the root logger and add both handlers +root_logger = logging.getLogger() +root_logger.setLevel(logging.INFO) # Allow all logs to pass through +root_logger.addHandler(stdout_handler) +root_logger.addHandler(stderr_handler) + +structlog.configure( + processors=[ + structlog.stdlib.filter_by_level, + structlog.contextvars.merge_contextvars, + structlog.processors.add_log_level, + structlog.processors.TimeStamper(fmt="iso"), + structlog.processors.EventRenamer("message"), + structlog.processors.StackInfoRenderer(), + structlog.processors.ExceptionRenderer(), + structlog.processors.JSONRenderer(), + ], + context_class=dict, + logger_factory=structlog.stdlib.LoggerFactory(), + cache_logger_on_first_use=True +) + +# Get a logger +logger: structlog.stdlib.BoundLogger = structlog.get_logger() diff --git a/src/uvicorn.py b/src/uvicorn.py new file mode 100644 index 0000000..ecaf15a --- /dev/null +++ b/src/uvicorn.py @@ -0,0 +1,5 @@ +# This file must exist as a hack to satisfy mcp. +# mcp has an optional dependency on uvicorn but still imports it at the top scope, see: +# https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/server/fastmcp/server.py#L18 +# Because we never call `run_sse_async` this is not required. However, Python workers used asgi.py +# rather than uvicorn which is why this hack is needed. With this, the import succeeds. \ No newline at end of file diff --git a/src/worker.py b/src/worker.py index 2ff853c..88a9cd3 100644 --- a/src/worker.py +++ b/src/worker.py @@ -1,81 +1,70 @@ -# from pydantic import BaseModel, create_model -# from fastapi import FastAPI, Request -# from fastapi_mcp import FastApiMCP +from workers import DurableObject +from logger import logger +import sys +import httpx_patch # noqa: F401 +sys.path.insert(0, "/session/metadata/vendor") +sys.path.insert(0, "/session/metadata") + + +def setup_server(env): + from fastapi import FastAPI, Request + from pydantic import BaseModel + from fastapi_mcp import FastApiMCP + from exceptions import HTTPException, http_exception + + app = FastAPI() + app.add_exception_handler(HTTPException, http_exception) + + mcp = FastApiMCP(app) + + # Mount the MCP server directly to your FastAPI app + mcp.mount() + # Auto-generated operation_id (something like "read_user_users__user_id__get") + @app.get("/") + async def root(): + return {"message": "Hello, World!"} + + @app.get("/env") + async def root(): + return {"message": "Here is an example of getting an environment variable: " + env.MESSAGE} + + class Item(BaseModel): + name: str + description: str | None = None + price: float + tax: float | None = None + + @app.post("/items/") + async def create_item(item: Item): + return item + + @app.put("/items/{item_id}") + async def create_item(item_id: int, item: Item, q: str | None = None): + result = {"item_id": item_id, **item.dict()} + if q: + result.update({"q": q}) + return result + + @app.get("/items/{item_id}") + async def read_item(item_id: int): + return {"item_id": item_id} + + mcp.setup_server() + return mcp, app + + +class FastMCPServer(DurableObject): + def __init__(self, ctx, env): + self.ctx = ctx + self.env = env + self.mcp, self.app = setup_server(self.env) + + async def call(self, request): + import asgi + return await asgi.fetch(self.app, request, self.env, self.ctx) -############## NORMAL ############## - -# from js import Response - -# async def on_fetch(request, env): -# return Response.new("Hello") - -#################################### - -############## FASTMCP ############## -from exceptions import HTTPException, http_exception -from mcp.server.fastmcp import FastMCP - -mcp = FastMCP("Demo") - -@mcp.tool() -def add(a: int, b: int) -> int: - """Add two numbers""" - return a + b - -@mcp.resource("greeting://{name}") -def get_greeting(name: str) -> str: - """Get a personalized greeting""" - return f"Hello, {name}!" - -@mcp.tool() -def calculate_bmi(weight_kg: float, height_m: float) -> float: - """Calculate BMI given weight in kg and height in meters""" - return weight_kg / (height_m**2) - -@mcp.prompt() -def echo_prompt(message: str) -> str: - """Create an echo prompt""" - return f"Please process this message: {message}" - -# mcp depends on uvicorn and imports it at the top scope that we have to patch that to move the -# import into the function that uses it. -# TODO(now): Change uvicorn to optional in mcp -app = mcp.sse_app() -# Starlette default http exception handler is sync which starlette tries to run in threadpool -# in https://github.com/encode/starlette/blob/master/starlette/_exception_handler.py#L61. -# Since we don't support threads we need to override it with the same function but async. -# TODO(now): change starlette's http_exception to be async, it is strictly slower to spawn a new -# thread -app.add_exception_handler(HTTPException, http_exception) async def on_fetch(request, env): - import asgi - return await asgi.fetch(app, request, env) - -##################################### - -############## FASTAPI ############## -# from fastapi import FastAPI, Request - -# app = FastAPI() - - -# @app.get("/") -# async def root(): -# message = "This is an example of FastAPI" -# return {"message": message} - - -# @app.get("/env") -# async def env(req: Request): -# env = req.scope["env"] -# return { -# "message": "Here is an example of getting an environment variable: " -# + env.MESSAGE -# } - -# async def on_fetch(request, env): -# import asgi -# return await asgi.fetch(app, request, env) - -##################################### + id = env.ns.idFromName("A") + obj = env.ns.get(id) + return await obj.call(request) diff --git a/vendor.txt b/vendor.txt index fd4697a..13cdb40 100644 --- a/vendor.txt +++ b/vendor.txt @@ -1 +1,4 @@ -mcp \ No newline at end of file +fastapi-mcp +fastapi +pydantic +structlog diff --git a/wrangler.jsonc b/wrangler.jsonc index 32fd0f3..63fd1de 100644 --- a/wrangler.jsonc +++ b/wrangler.jsonc @@ -2,11 +2,11 @@ "name": "fastapi-worker", "main": "src/worker.py", "compatibility_flags": [ - "python_workers" + "python_workers", ], "compatibility_date": "2025-04-10", "vars": { - "API_HOST": "example.com" + "MESSAGE": "hello world" }, "rules": [ { @@ -16,5 +16,24 @@ "type": "Data", "fallthrough": true } - ] + ], + "durable_objects": { + "bindings": [ + { + "name": "ns", + "class_name": "FastMCPServer" + } + ] + }, + "migrations": [ + { + "tag": "v1", + "new_sqlite_classes": [ + "FastMCPServer" + ] + } + ], + "observability": { + "enabled": true + } } \ No newline at end of file