1+ import { AnyZodObject , ZodLiteral , ZodObject , z } from "zod" ;
12import {
23 ErrorCode ,
34 JSONRPCError ,
@@ -25,9 +26,6 @@ export type ProgressCallback = (progress: Progress) => void;
2526 * features like request/response linking, notifications, and progress.
2627 */
2728export class Protocol <
28- ReceiveRequestT extends Request ,
29- ReceiveNotificationT extends Notification ,
30- ReceiveResultT extends Result ,
3129 SendRequestT extends Request ,
3230 SendNotificationT extends Notification ,
3331 SendResultT extends Result ,
@@ -36,15 +34,15 @@ export class Protocol<
3634 private _requestMessageId = 0 ;
3735 private _requestHandlers : Map <
3836 string ,
39- ( request : ReceiveRequestT ) => Promise < SendResultT >
37+ ( request : JSONRPCRequest ) => Promise < SendResultT >
4038 > = new Map ( ) ;
4139 private _notificationHandlers : Map <
4240 string ,
43- ( notification : ReceiveNotificationT ) => Promise < void >
41+ ( notification : JSONRPCNotification ) => Promise < void >
4442 > = new Map ( ) ;
4543 private _responseHandlers : Map <
4644 number ,
47- ( response : ReceiveResultT | Error ) => void
45+ ( response : JSONRPCResponse | Error ) => void
4846 > = new Map ( ) ;
4947 private _progressHandlers : Map < number , ProgressCallback > = new Map ( ) ;
5048
@@ -65,25 +63,20 @@ export class Protocol<
6563 /**
6664 * A handler to invoke for any request types that do not have their own handler installed.
6765 */
68- fallbackRequestHandler ?: ( request : ReceiveRequestT ) => Promise < SendResultT > ;
66+ fallbackRequestHandler ?: ( request : Request ) => Promise < SendResultT > ;
6967
7068 /**
7169 * A handler to invoke for any notification types that do not have their own handler installed.
7270 */
73- fallbackNotificationHandler ?: (
74- notification : ReceiveNotificationT ,
75- ) => Promise < void > ;
71+ fallbackNotificationHandler ?: ( notification : Notification ) => Promise < void > ;
7672
7773 constructor ( ) {
78- this . setNotificationHandler (
79- ProgressNotificationSchema . shape . method . value ,
80- ( notification ) => {
81- this . _onprogress ( notification as unknown as ProgressNotification ) ;
82- } ,
83- ) ;
74+ this . setNotificationHandler ( ProgressNotificationSchema , ( notification ) => {
75+ this . _onprogress ( notification as unknown as ProgressNotification ) ;
76+ } ) ;
8477
8578 this . setRequestHandler (
86- PingRequestSchema . shape . method . value ,
79+ PingRequestSchema ,
8780 // Automatic pong by default.
8881 ( _request ) => ( { } ) as SendResultT ,
8982 ) ;
@@ -106,11 +99,11 @@ export class Protocol<
10699
107100 this . _transport . onmessage = ( message ) => {
108101 if ( ! ( "method" in message ) ) {
109- this . _onresponse ( message as JSONRPCResponse | JSONRPCError ) ;
102+ this . _onresponse ( message ) ;
110103 } else if ( "id" in message ) {
111- this . _onrequest ( message as JSONRPCRequest ) ;
104+ this . _onrequest ( message ) ;
112105 } else {
113- this . _onnotification ( message as JSONRPCNotification ) ;
106+ this . _onnotification ( message ) ;
114107 }
115108 } ;
116109 }
@@ -142,7 +135,7 @@ export class Protocol<
142135 return ;
143136 }
144137
145- handler ( notification as unknown as ReceiveNotificationT ) . catch ( ( error ) =>
138+ handler ( notification ) . catch ( ( error ) =>
146139 this . _onerror (
147140 new Error ( `Uncaught error in notification handler: ${ error } ` ) ,
148141 ) ,
@@ -171,7 +164,7 @@ export class Protocol<
171164 return ;
172165 }
173166
174- handler ( request as unknown as ReceiveRequestT )
167+ handler ( request )
175168 . then (
176169 ( result ) => {
177170 this . _transport ?. send ( {
@@ -228,7 +221,7 @@ export class Protocol<
228221 this . _responseHandlers . delete ( Number ( messageId ) ) ;
229222 this . _progressHandlers . delete ( Number ( messageId ) ) ;
230223 if ( "result" in response ) {
231- handler ( response . result as ReceiveResultT ) ;
224+ handler ( response ) ;
232225 } else {
233226 const error = new McpError (
234227 response . error . code ,
@@ -255,11 +248,11 @@ export class Protocol<
255248 *
256249 * Do not use this method to emit notifications! Use notification() instead.
257250 */
258- // TODO: This could infer a better response type based on the method
259- request (
251+ request < T extends AnyZodObject > (
260252 request : SendRequestT ,
253+ resultSchema : T ,
261254 onprogress ?: ProgressCallback ,
262- ) : Promise < ReceiveResultT > {
255+ ) : Promise < z . infer < T > > {
263256 return new Promise ( ( resolve , reject ) => {
264257 if ( ! this . _transport ) {
265258 reject ( new Error ( "Not connected" ) ) ;
@@ -283,9 +276,14 @@ export class Protocol<
283276
284277 this . _responseHandlers . set ( messageId , ( response ) => {
285278 if ( response instanceof Error ) {
286- reject ( response ) ;
287- } else {
288- resolve ( response ) ;
279+ return reject ( response ) ;
280+ }
281+
282+ try {
283+ const result = resultSchema . parse ( response . result ) ;
284+ resolve ( result ) ;
285+ } catch ( error ) {
286+ reject ( error ) ;
289287 }
290288 } ) ;
291289
@@ -314,13 +312,16 @@ export class Protocol<
314312 *
315313 * Note that this will replace any previous request handler for the same method.
316314 */
317- // TODO: This could infer a better request type based on the method.
318- setRequestHandler (
319- method : string ,
320- handler : ( request : ReceiveRequestT ) => SendResultT | Promise < SendResultT > ,
315+ setRequestHandler <
316+ T extends ZodObject < {
317+ method : ZodLiteral < string > ;
318+ } > ,
319+ > (
320+ requestSchema : T ,
321+ handler : ( request : z . infer < T > ) => SendResultT | Promise < SendResultT > ,
321322 ) : void {
322- this . _requestHandlers . set ( method , ( request ) =>
323- Promise . resolve ( handler ( request ) ) ,
323+ this . _requestHandlers . set ( requestSchema . shape . method . value , ( request ) =>
324+ Promise . resolve ( handler ( requestSchema . parse ( request ) ) ) ,
324325 ) ;
325326 }
326327
@@ -336,13 +337,18 @@ export class Protocol<
336337 *
337338 * Note that this will replace any previous notification handler for the same method.
338339 */
339- // TODO: This could infer a better notification type based on the method.
340- setNotificationHandler < T extends ReceiveNotificationT > (
341- method : string ,
342- handler : ( notification : T ) => void | Promise < void > ,
340+ setNotificationHandler <
341+ T extends ZodObject < {
342+ method : ZodLiteral < string > ;
343+ } > ,
344+ > (
345+ notificationSchema : T ,
346+ handler : ( notification : z . infer < T > ) => void | Promise < void > ,
343347 ) : void {
344- this . _notificationHandlers . set ( method , ( notification ) =>
345- Promise . resolve ( handler ( notification as T ) ) ,
348+ this . _notificationHandlers . set (
349+ notificationSchema . shape . method . value ,
350+ ( notification ) =>
351+ Promise . resolve ( handler ( notificationSchema . parse ( notification ) ) ) ,
346352 ) ;
347353 }
348354
0 commit comments