From 6988e5b829211d6455975dd9893569efbe0841b3 Mon Sep 17 00:00:00 2001 From: Volodymyr Panivko Date: Mon, 29 Dec 2025 16:04:50 +0100 Subject: [PATCH] Add Middleware handlers to StreamableHttpTransport --- composer.json | 2 + docs/transports.md | 40 +++++ .../Transport/StreamableHttpTransport.php | 94 ++++++++-- .../Transport/StreamableHttpTransportTest.php | 164 ++++++++++++++++++ 4 files changed, 283 insertions(+), 17 deletions(-) create mode 100644 tests/Unit/Server/Transport/StreamableHttpTransportTest.php diff --git a/composer.json b/composer.json index 83a08f39..b7d2483f 100644 --- a/composer.json +++ b/composer.json @@ -28,6 +28,8 @@ "psr/event-dispatcher": "^1.0", "psr/http-factory": "^1.1", "psr/http-message": "^1.1 || ^2.0", + "psr/http-server-handler": "^1.0", + "psr/http-server-middleware": "^1.0", "psr/log": "^1.0 || ^2.0 || ^3.0", "symfony/finder": "^5.4 || ^6.4 || ^7.3 || ^8.0", "symfony/uid": "^5.4 || ^6.4 || ^7.3 || ^8.0" diff --git a/docs/transports.md b/docs/transports.md index 290fd49c..a68875d9 100644 --- a/docs/transports.md +++ b/docs/transports.md @@ -179,6 +179,46 @@ Default CORS headers: - `Access-Control-Allow-Methods: GET, POST, DELETE, OPTIONS` - `Access-Control-Allow-Headers: Content-Type, Mcp-Session-Id, Mcp-Protocol-Version, Last-Event-ID, Authorization, Accept` +### PSR-15 Middleware + +`StreamableHttpTransport` can run a PSR-15 middleware chain before it processes the request. Middleware can log, +enforce auth, or short-circuit with a response for any HTTP method. + +```php +use Mcp\Server\Transport\StreamableHttpTransport; +use Psr\Http\Message\ResponseFactoryInterface; +use Psr\Http\Message\ServerRequestInterface; +use Psr\Http\Server\MiddlewareInterface; +use Psr\Http\Server\RequestHandlerInterface; + +final class AuthMiddleware implements MiddlewareInterface +{ + public function __construct(private ResponseFactoryInterface $responses) + { + } + + public function process(ServerRequestInterface $request, RequestHandlerInterface $handler) + { + if (!$request->hasHeader('Authorization')) { + return $this->responses->createResponse(401); + } + + return $handler->handle($request); + } +} + +$transport = new StreamableHttpTransport( + $request, + $responseFactory, + $streamFactory, + [], + $logger, + [new AuthMiddleware($responseFactory)], +); +``` + +If middleware returns a response, the transport will still ensure CORS headers are present unless you set them yourself. + ### Architecture The HTTP transport doesn't run its own web server. Instead, it processes PSR-7 requests and returns PSR-7 responses that diff --git a/src/Server/Transport/StreamableHttpTransport.php b/src/Server/Transport/StreamableHttpTransport.php index 2b1e6869..73c491d8 100644 --- a/src/Server/Transport/StreamableHttpTransport.php +++ b/src/Server/Transport/StreamableHttpTransport.php @@ -17,6 +17,8 @@ use Psr\Http\Message\ResponseInterface; use Psr\Http\Message\ServerRequestInterface; use Psr\Http\Message\StreamFactoryInterface; +use Psr\Http\Server\MiddlewareInterface; +use Psr\Http\Server\RequestHandlerInterface; use Psr\Log\LoggerInterface; use Symfony\Component\Uid\Uuid; @@ -36,19 +38,22 @@ class StreamableHttpTransport extends BaseTransport /** @var array */ private array $corsHeaders; + /** @var list */ + private array $middlewares = []; + /** * @param array $corsHeaders + * @param iterable $middlewares */ public function __construct( - private readonly ServerRequestInterface $request, + private ServerRequestInterface $request, ?ResponseFactoryInterface $responseFactory = null, ?StreamFactoryInterface $streamFactory = null, array $corsHeaders = [], ?LoggerInterface $logger = null, + iterable $middlewares = [], ) { parent::__construct($logger); - $sessionIdString = $this->request->getHeaderLine('Mcp-Session-Id'); - $this->sessionId = $sessionIdString ? Uuid::fromString($sessionIdString) : null; $this->responseFactory = $responseFactory ?? Psr17FactoryDiscovery::findResponseFactory(); $this->streamFactory = $streamFactory ?? Psr17FactoryDiscovery::findStreamFactory(); @@ -59,6 +64,13 @@ public function __construct( 'Access-Control-Allow-Headers' => 'Content-Type, Mcp-Session-Id, Mcp-Protocol-Version, Last-Event-ID, Authorization, Accept', 'Access-Control-Expose-Headers' => 'Mcp-Session-Id', ], $corsHeaders); + + foreach ($middlewares as $middleware) { + if (!$middleware instanceof MiddlewareInterface) { + throw new \InvalidArgumentException('Streamable HTTP middleware must implement Psr\\Http\\Server\\MiddlewareInterface.'); + } + $this->middlewares[] = $middleware; + } } public function send(string $data, array $context): void @@ -69,17 +81,15 @@ public function send(string $data, array $context): void public function listen(): ResponseInterface { - return match ($this->request->getMethod()) { - 'OPTIONS' => $this->handleOptionsRequest(), - 'POST' => $this->handlePostRequest(), - 'DELETE' => $this->handleDeleteRequest(), - default => $this->createErrorResponse(Error::forInvalidRequest('Method Not Allowed'), 405), - }; + $handler = $this->createRequestHandler(); + $response = $handler->handle($this->request); + + return $this->withCorsHeaders($response); } protected function handleOptionsRequest(): ResponseInterface { - return $this->withCorsHeaders($this->responseFactory->createResponse(204)); + return $this->responseFactory->createResponse(204); } protected function handlePostRequest(): ResponseInterface @@ -92,7 +102,7 @@ protected function handlePostRequest(): ResponseInterface ->withHeader('Content-Type', 'application/json') ->withBody($this->streamFactory->createStream($this->immediateResponse)); - return $this->withCorsHeaders($response); + return $response; } if (null !== $this->sessionFiber) { @@ -112,7 +122,7 @@ protected function handleDeleteRequest(): ResponseInterface $this->handleSessionEnd($this->sessionId); - return $this->withCorsHeaders($this->responseFactory->createResponse(200)); + return $this->responseFactory->createResponse(200); } protected function createJsonResponse(): ResponseInterface @@ -120,7 +130,7 @@ protected function createJsonResponse(): ResponseInterface $outgoingMessages = $this->getOutgoingMessages($this->sessionId); if (empty($outgoingMessages)) { - return $this->withCorsHeaders($this->responseFactory->createResponse(202)); + return $this->responseFactory->createResponse(202); } $messages = array_column($outgoingMessages, 'message'); @@ -134,7 +144,7 @@ protected function createJsonResponse(): ResponseInterface $response = $response->withHeader('Mcp-Session-Id', $this->sessionId->toRfc4122()); } - return $this->withCorsHeaders($response); + return $response; } protected function createStreamedResponse(): ResponseInterface @@ -201,7 +211,7 @@ protected function createStreamedResponse(): ResponseInterface $response = $response->withHeader('Mcp-Session-Id', $this->sessionId->toRfc4122()); } - return $this->withCorsHeaders($response); + return $response; } protected function handleFiberTermination(): void @@ -242,15 +252,65 @@ protected function createErrorResponse(Error $jsonRpcError, int $statusCode): Re ->withHeader('Content-Type', 'application/json') ->withBody($this->streamFactory->createStream($payload)); - return $this->withCorsHeaders($response); + return $response; } protected function withCorsHeaders(ResponseInterface $response): ResponseInterface { foreach ($this->corsHeaders as $name => $value) { - $response = $response->withHeader($name, $value); + if (!$response->hasHeader($name)) { + $response = $response->withHeader($name, $value); + } } return $response; } + + private function handleRequest(ServerRequestInterface $request): ResponseInterface + { + $this->request = $request; + $sessionIdString = $request->getHeaderLine('Mcp-Session-Id'); + $this->sessionId = $sessionIdString ? Uuid::fromString($sessionIdString) : null; + + return match ($request->getMethod()) { + 'OPTIONS' => $this->handleOptionsRequest(), + 'POST' => $this->handlePostRequest(), + 'DELETE' => $this->handleDeleteRequest(), + default => $this->createErrorResponse(Error::forInvalidRequest('Method Not Allowed'), 405), + }; + } + + private function createRequestHandler(): RequestHandlerInterface + { + /** + * @see self::handleRequest + */ + $handler = new class(\Closure::fromCallable([$this, 'handleRequest'])) implements RequestHandlerInterface { + public function __construct(private \Closure $handler) + { + } + + public function handle(ServerRequestInterface $request): ResponseInterface + { + return ($this->handler)($request); + } + }; + + foreach (array_reverse($this->middlewares) as $middleware) { + $handler = new class($middleware, $handler) implements RequestHandlerInterface { + public function __construct( + private MiddlewareInterface $middleware, + private RequestHandlerInterface $handler, + ) { + } + + public function handle(ServerRequestInterface $request): ResponseInterface + { + return $this->middleware->process($request, $this->handler); + } + }; + } + + return $handler; + } } diff --git a/tests/Unit/Server/Transport/StreamableHttpTransportTest.php b/tests/Unit/Server/Transport/StreamableHttpTransportTest.php new file mode 100644 index 00000000..51af6272 --- /dev/null +++ b/tests/Unit/Server/Transport/StreamableHttpTransportTest.php @@ -0,0 +1,164 @@ + ['GET', false, 401]; + yield 'POST (middleware returns 401)' => ['POST', false, 401]; + yield 'DELETE (middleware returns 401)' => ['DELETE', false, 401]; + yield 'OPTIONS (middleware delegates -> transport handles preflight)' => ['OPTIONS', true, 204]; + yield 'GET (middleware delegates -> transport handles preflight)' => ['GET', true, 405]; + yield 'POST (middleware delegates -> transport handles preflight)' => ['POST', true, 202]; + yield 'DELETE (middleware delegates -> transport handles preflight)' => ['DELETE', true, 400]; + } + + #[DataProvider('corsHeaderProvider')] + #[TestDox('CORS headers are always applied')] + public function testCorsHeader(string $method, bool $middlewareDelegatesToTransport, int $expectedStatusCode): void + { + $factory = new Psr17Factory(); + $request = $factory->createServerRequest($method, 'https://example.com'); + + $middleware = new class($factory, $expectedStatusCode, $middlewareDelegatesToTransport) implements MiddlewareInterface { + public function __construct( + private ResponseFactoryInterface $responseFactory, + private int $expectedStatusCode, + private bool $middlewareDelegatesToTransport, + ) { + } + + public function process(ServerRequestInterface $request, RequestHandlerInterface $handler): ResponseInterface + { + if ($this->middlewareDelegatesToTransport) { + return $handler->handle($request); + } + + return $this->responseFactory->createResponse($this->expectedStatusCode); + } + }; + + $transport = new StreamableHttpTransport( + $request, + $factory, + $factory, + [], + null, + [$middleware], + ); + + $response = $transport->listen(); + + $this->assertSame($expectedStatusCode, $response->getStatusCode(), $response->getBody()->getContents()); + $this->assertTrue($response->hasHeader('Access-Control-Allow-Origin')); + $this->assertTrue($response->hasHeader('Access-Control-Allow-Methods')); + $this->assertTrue($response->hasHeader('Access-Control-Allow-Headers')); + $this->assertTrue($response->hasHeader('Access-Control-Expose-Headers')); + + $this->assertSame('*', $response->getHeaderLine('Access-Control-Allow-Origin')); + $this->assertSame('GET, POST, DELETE, OPTIONS', $response->getHeaderLine('Access-Control-Allow-Methods')); + $this->assertSame( + 'Content-Type, Mcp-Session-Id, Mcp-Protocol-Version, Last-Event-ID, Authorization, Accept', + $response->getHeaderLine('Access-Control-Allow-Headers') + ); + $this->assertSame('Mcp-Session-Id', $response->getHeaderLine('Access-Control-Expose-Headers')); + } + + #[TestDox('transport replaces existing CORS headers on the response')] + public function testCorsHeadersAreReplacedWhenAlreadyPresent(): void + { + $factory = new Psr17Factory(); + $request = $factory->createServerRequest('GET', 'https://example.com'); + + $middleware = new class($factory) implements MiddlewareInterface { + public function __construct(private ResponseFactoryInterface $responses) + { + } + + public function process(ServerRequestInterface $request, RequestHandlerInterface $handler): ResponseInterface + { + return $this->responses->createResponse(200) + ->withHeader('Access-Control-Allow-Origin', 'https://another.com'); + } + }; + + $transport = new StreamableHttpTransport( + $request, + $factory, + $factory, + [], + null, + [$middleware], + ); + + $response = $transport->listen(); + + $this->assertSame(200, $response->getStatusCode()); + + $this->assertSame('https://another.com', $response->getHeaderLine('Access-Control-Allow-Origin')); + $this->assertSame('GET, POST, DELETE, OPTIONS', $response->getHeaderLine('Access-Control-Allow-Methods')); + $this->assertSame( + 'Content-Type, Mcp-Session-Id, Mcp-Protocol-Version, Last-Event-ID, Authorization, Accept', + $response->getHeaderLine('Access-Control-Allow-Headers') + ); + $this->assertSame('Mcp-Session-Id', $response->getHeaderLine('Access-Control-Expose-Headers')); + } + + #[TestDox('middleware runs before transport handles the request')] + public function testMiddlewareRunsBeforeTransportHandlesRequest(): void + { + $factory = new Psr17Factory(); + $request = $factory->createServerRequest('OPTIONS', 'https://example.com'); + + $state = new \stdClass(); + $state->called = false; + $middleware = new class($state) implements MiddlewareInterface { + public function __construct(private \stdClass $state) + { + } + + public function process(ServerRequestInterface $request, RequestHandlerInterface $handler): ResponseInterface + { + $this->state->called = true; + + return $handler->handle($request); + } + }; + + $transport = new StreamableHttpTransport( + $request, + $factory, + $factory, + [], + null, + [$middleware], + ); + + $response = $transport->listen(); + + $this->assertTrue($state->called); + $this->assertSame(204, $response->getStatusCode()); + } +}