@@ -2,6 +2,7 @@ import { createServer, type IncomingMessage, type Server } from "http";
22import { AddressInfo } from "net" ;
33import { JSONRPCMessage } from "../types.js" ;
44import { SSEClientTransport } from "./sse.js" ;
5+ import { auth , OAuthClientProvider } from "./auth.js" ;
56
67describe ( "SSEClientTransport" , ( ) => {
78 let server : Server ;
@@ -284,4 +285,180 @@ describe("SSEClientTransport", () => {
284285 expect ( calledHeaders . get ( "content-type" ) ) . toBe ( "application/json" ) ;
285286 } ) ;
286287 } ) ;
288+
289+ describe ( "auth handling" , ( ) => {
290+ let mockAuthProvider : jest . Mocked < OAuthClientProvider > ;
291+
292+ beforeEach ( ( ) => {
293+ mockAuthProvider = {
294+ get redirectUrl ( ) { return "http://localhost/callback" ; } ,
295+ get clientMetadata ( ) { return { redirect_uris : [ "http://localhost/callback" ] } ; } ,
296+ clientInformation : jest . fn ( ( ) => ( { client_id : "test-client-id" } ) ) ,
297+ tokens : jest . fn ( ) ,
298+ saveTokens : jest . fn ( ) ,
299+ redirectToAuthorization : jest . fn ( ) ,
300+ saveCodeVerifier : jest . fn ( ) ,
301+ codeVerifier : jest . fn ( ) ,
302+ } ;
303+ } ) ;
304+
305+ it ( "attaches auth header from provider on SSE connection" , async ( ) => {
306+ mockAuthProvider . tokens . mockResolvedValue ( {
307+ access_token : "test-token" ,
308+ token_type : "Bearer"
309+ } ) ;
310+
311+ transport = new SSEClientTransport ( baseUrl , {
312+ authProvider : mockAuthProvider ,
313+ } ) ;
314+
315+ await transport . start ( ) ;
316+
317+ expect ( lastServerRequest . headers . authorization ) . toBe ( "Bearer test-token" ) ;
318+ expect ( mockAuthProvider . tokens ) . toHaveBeenCalled ( ) ;
319+ } ) ;
320+
321+ it ( "attaches auth header from provider on POST requests" , async ( ) => {
322+ mockAuthProvider . tokens . mockResolvedValue ( {
323+ access_token : "test-token" ,
324+ token_type : "Bearer"
325+ } ) ;
326+
327+ transport = new SSEClientTransport ( baseUrl , {
328+ authProvider : mockAuthProvider ,
329+ } ) ;
330+
331+ await transport . start ( ) ;
332+
333+ const message : JSONRPCMessage = {
334+ jsonrpc : "2.0" ,
335+ id : "1" ,
336+ method : "test" ,
337+ params : { } ,
338+ } ;
339+
340+ await transport . send ( message ) ;
341+
342+ expect ( lastServerRequest . headers . authorization ) . toBe ( "Bearer test-token" ) ;
343+ expect ( mockAuthProvider . tokens ) . toHaveBeenCalled ( ) ;
344+ } ) ;
345+
346+ it ( "attempts auth flow on 401 during SSE connection" , async ( ) => {
347+ // Create server that returns 401s
348+ server . close ( ) ;
349+ await new Promise ( resolve => server . on ( "close" , resolve ) ) ;
350+
351+ server = createServer ( ( req , res ) => {
352+ lastServerRequest = req ;
353+ if ( req . url !== "/" ) {
354+ res . writeHead ( 404 ) . end ( ) ;
355+ } else {
356+ res . writeHead ( 401 ) . end ( ) ;
357+ }
358+ } ) ;
359+
360+ await new Promise < void > ( resolve => {
361+ server . listen ( 0 , "127.0.0.1" , ( ) => {
362+ const addr = server . address ( ) as AddressInfo ;
363+ baseUrl = new URL ( `http://127.0.0.1:${ addr . port } ` ) ;
364+ resolve ( ) ;
365+ } ) ;
366+ } ) ;
367+
368+ transport = new SSEClientTransport ( baseUrl , {
369+ authProvider : mockAuthProvider ,
370+ } ) ;
371+
372+ await expect ( ( ) => transport . start ( ) ) . rejects . toThrow ( "Unauthorized" ) ;
373+ expect ( mockAuthProvider . redirectToAuthorization . mock . calls ) . toHaveLength ( 1 ) ;
374+ } ) ;
375+
376+ it ( "attempts auth flow on 401 during POST request" , async ( ) => {
377+ // Create server that accepts SSE but returns 401 on POST
378+ server . close ( ) ;
379+ await new Promise ( resolve => server . on ( "close" , resolve ) ) ;
380+
381+ server = createServer ( ( req , res ) => {
382+ lastServerRequest = req ;
383+
384+ switch ( req . method ) {
385+ case "GET" :
386+ if ( req . url !== "/" ) {
387+ res . writeHead ( 404 ) . end ( ) ;
388+ return ;
389+ }
390+
391+ res . writeHead ( 200 , {
392+ "Content-Type" : "text/event-stream" ,
393+ "Cache-Control" : "no-cache" ,
394+ Connection : "keep-alive" ,
395+ } ) ;
396+ res . write ( "event: endpoint\n" ) ;
397+ res . write ( `data: ${ baseUrl . href } \n\n` ) ;
398+ break ;
399+
400+ case "POST" :
401+ res . writeHead ( 401 ) ;
402+ res . end ( ) ;
403+ break ;
404+ }
405+ } ) ;
406+
407+ await new Promise < void > ( resolve => {
408+ server . listen ( 0 , "127.0.0.1" , ( ) => {
409+ const addr = server . address ( ) as AddressInfo ;
410+ baseUrl = new URL ( `http://127.0.0.1:${ addr . port } ` ) ;
411+ resolve ( ) ;
412+ } ) ;
413+ } ) ;
414+
415+ transport = new SSEClientTransport ( baseUrl , {
416+ authProvider : mockAuthProvider ,
417+ } ) ;
418+
419+ await transport . start ( ) ;
420+
421+ const message : JSONRPCMessage = {
422+ jsonrpc : "2.0" ,
423+ id : "1" ,
424+ method : "test" ,
425+ params : { } ,
426+ } ;
427+
428+ await expect ( ( ) => transport . send ( message ) ) . rejects . toThrow ( "Unauthorized" ) ;
429+ expect ( mockAuthProvider . redirectToAuthorization . mock . calls ) . toHaveLength ( 1 ) ;
430+ } ) ;
431+
432+ it ( "respects custom headers when using auth provider" , async ( ) => {
433+ mockAuthProvider . tokens . mockResolvedValue ( {
434+ access_token : "test-token" ,
435+ token_type : "Bearer"
436+ } ) ;
437+
438+ const customHeaders = {
439+ "X-Custom-Header" : "custom-value" ,
440+ } ;
441+
442+ transport = new SSEClientTransport ( baseUrl , {
443+ authProvider : mockAuthProvider ,
444+ requestInit : {
445+ headers : customHeaders ,
446+ } ,
447+ } ) ;
448+
449+ await transport . start ( ) ;
450+
451+ const message : JSONRPCMessage = {
452+ jsonrpc : "2.0" ,
453+ id : "1" ,
454+ method : "test" ,
455+ params : { } ,
456+ } ;
457+
458+ await transport . send ( message ) ;
459+
460+ expect ( lastServerRequest . headers . authorization ) . toBe ( "Bearer test-token" ) ;
461+ expect ( lastServerRequest . headers [ "x-custom-header" ] ) . toBe ( "custom-value" ) ;
462+ } ) ;
463+ } ) ;
287464} ) ;
0 commit comments