@@ -78,22 +78,52 @@ export type RequestOptions = {
7878 * If not specified, there is no maximum total timeout.
7979 */
8080 maxTotalTimeout ?: number ;
81+
82+ /**
83+ * May be used to indicate to the transport which incoming request to associate this outgoing request with.
84+ */
85+ relatedRequestId ?: RequestId ;
8186} ;
8287
8388/**
84- * Extra data given to request handlers .
89+ * Options that can be given per notification .
8590 */
86- export type RequestHandlerExtra = {
91+ export type NotificationOptions = {
8792 /**
88- * An abort signal used to communicate if the request was cancelled from the sender's side .
93+ * May be used to indicate to the transport which incoming request to associate this outgoing notification with .
8994 */
90- signal : AbortSignal ;
95+ relatedRequestId ?: RequestId ;
96+ }
9197
92- /**
93- * The session ID from the transport, if available.
94- */
95- sessionId ?: string ;
96- } ;
98+ /**
99+ * Extra data given to request handlers.
100+ */
101+ export type RequestHandlerExtra < SendRequestT extends Request ,
102+ SendNotificationT extends Notification > = {
103+ /**
104+ * An abort signal used to communicate if the request was cancelled from the sender's side.
105+ */
106+ signal : AbortSignal ;
107+
108+ /**
109+ * The session ID from the transport, if available.
110+ */
111+ sessionId ?: string ;
112+
113+ /**
114+ * Sends a notification that relates to the current request being handled.
115+ *
116+ * This is used by certain transports to correctly associate related messages.
117+ */
118+ sendNotification : ( notification : SendNotificationT ) => Promise < void > ;
119+
120+ /**
121+ * Sends a request that relates to the current request being handled.
122+ *
123+ * This is used by certain transports to correctly associate related messages.
124+ */
125+ sendRequest : < U extends ZodType < object > > ( request : SendRequestT , resultSchema : U , options ?: RequestOptions ) => Promise < z . infer < U > > ;
126+ } ;
97127
98128/**
99129 * Information about a request's timeout state
@@ -122,7 +152,7 @@ export abstract class Protocol<
122152 string ,
123153 (
124154 request : JSONRPCRequest ,
125- extra : RequestHandlerExtra ,
155+ extra : RequestHandlerExtra < SendRequestT , SendNotificationT > ,
126156 ) => Promise < SendResultT >
127157 > = new Map ( ) ;
128158 private _requestHandlerAbortControllers : Map < RequestId , AbortController > =
@@ -316,9 +346,14 @@ export abstract class Protocol<
316346 this . _requestHandlerAbortControllers . set ( request . id , abortController ) ;
317347
318348 // Create extra object with both abort signal and sessionId from transport
319- const extra : RequestHandlerExtra = {
349+ const extra : RequestHandlerExtra < SendRequestT , SendNotificationT > = {
320350 signal : abortController . signal ,
321351 sessionId : this . _transport ?. sessionId ,
352+ sendNotification :
353+ ( notification ) =>
354+ this . notification ( notification , { relatedRequestId : request . id } ) ,
355+ sendRequest : ( r , resultSchema , options ?) =>
356+ this . request ( r , resultSchema , { ...options , relatedRequestId : request . id } )
322357 } ;
323358
324359 // Starting with Promise.resolve() puts any synchronous errors into the monad as well.
@@ -364,7 +399,7 @@ export abstract class Protocol<
364399 private _onprogress ( notification : ProgressNotification ) : void {
365400 const { progressToken, ...params } = notification . params ;
366401 const messageId = Number ( progressToken ) ;
367-
402+
368403 const handler = this . _progressHandlers . get ( messageId ) ;
369404 if ( ! handler ) {
370405 this . _onerror ( new Error ( `Received a progress notification for an unknown token: ${ JSON . stringify ( notification ) } ` ) ) ;
@@ -373,7 +408,7 @@ export abstract class Protocol<
373408
374409 const responseHandler = this . _responseHandlers . get ( messageId ) ;
375410 const timeoutInfo = this . _timeoutInfo . get ( messageId ) ;
376-
411+
377412 if ( timeoutInfo && responseHandler && timeoutInfo . resetTimeoutOnProgress ) {
378413 try {
379414 this . _resetTimeout ( messageId ) ;
@@ -460,6 +495,8 @@ export abstract class Protocol<
460495 resultSchema : T ,
461496 options ?: RequestOptions ,
462497 ) : Promise < z . infer < T > > {
498+ const { relatedRequestId } = options ?? { } ;
499+
463500 return new Promise ( ( resolve , reject ) => {
464501 if ( ! this . _transport ) {
465502 reject ( new Error ( "Not connected" ) ) ;
@@ -500,7 +537,7 @@ export abstract class Protocol<
500537 requestId : messageId ,
501538 reason : String ( reason ) ,
502539 } ,
503- } )
540+ } , { relatedRequestId } )
504541 . catch ( ( error ) =>
505542 this . _onerror ( new Error ( `Failed to send cancellation: ${ error } ` ) ) ,
506543 ) ;
@@ -538,7 +575,7 @@ export abstract class Protocol<
538575
539576 this . _setupTimeout ( messageId , timeout , options ?. maxTotalTimeout , timeoutHandler , options ?. resetTimeoutOnProgress ?? false ) ;
540577
541- this . _transport . send ( jsonrpcRequest ) . catch ( ( error ) => {
578+ this . _transport . send ( jsonrpcRequest , { relatedRequestId } ) . catch ( ( error ) => {
542579 this . _cleanupTimeout ( messageId ) ;
543580 reject ( error ) ;
544581 } ) ;
@@ -548,7 +585,7 @@ export abstract class Protocol<
548585 /**
549586 * Emits a notification, which is a one-way message that does not expect a response.
550587 */
551- async notification ( notification : SendNotificationT ) : Promise < void > {
588+ async notification ( notification : SendNotificationT , options ?: NotificationOptions ) : Promise < void > {
552589 if ( ! this . _transport ) {
553590 throw new Error ( "Not connected" ) ;
554591 }
@@ -560,7 +597,7 @@ export abstract class Protocol<
560597 jsonrpc : "2.0" ,
561598 } ;
562599
563- await this . _transport . send ( jsonrpcNotification ) ;
600+ await this . _transport . send ( jsonrpcNotification , options ) ;
564601 }
565602
566603 /**
@@ -576,14 +613,15 @@ export abstract class Protocol<
576613 requestSchema : T ,
577614 handler : (
578615 request : z . infer < T > ,
579- extra : RequestHandlerExtra ,
616+ extra : RequestHandlerExtra < SendRequestT , SendNotificationT > ,
580617 ) => SendResultT | Promise < SendResultT > ,
581618 ) : void {
582619 const method = requestSchema . shape . method . value ;
583620 this . assertRequestHandlerCapability ( method ) ;
584- this . _requestHandlers . set ( method , ( request , extra ) =>
585- Promise . resolve ( handler ( requestSchema . parse ( request ) , extra ) ) ,
586- ) ;
621+
622+ this . _requestHandlers . set ( method , ( request , extra ) => {
623+ return Promise . resolve ( handler ( requestSchema . parse ( request ) , extra ) ) ;
624+ } ) ;
587625 }
588626
589627 /**
0 commit comments