|
| 1 | +package io.modelcontextprotocol; |
| 2 | + |
| 3 | +import java.time.Duration; |
| 4 | +import java.util.Map; |
| 5 | +import java.util.stream.Stream; |
| 6 | + |
| 7 | +import org.assertj.core.api.Assertions; |
| 8 | +import org.junit.jupiter.api.*; |
| 9 | +import org.junit.jupiter.params.ParameterizedTest; |
| 10 | +import org.junit.jupiter.params.provider.Arguments; |
| 11 | +import org.junit.jupiter.params.provider.MethodSource; |
| 12 | + |
| 13 | +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; |
| 14 | +import org.springframework.web.reactive.function.client.WebClient; |
| 15 | +import org.springframework.web.reactive.function.server.RouterFunctions; |
| 16 | +import org.springframework.web.reactive.function.server.ServerRequest; |
| 17 | + |
| 18 | +import io.modelcontextprotocol.client.McpClient; |
| 19 | +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; |
| 20 | +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; |
| 21 | +import io.modelcontextprotocol.common.McpTransportContext; |
| 22 | +import io.modelcontextprotocol.server.McpServer; |
| 23 | +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; |
| 24 | +import io.modelcontextprotocol.server.McpServer.SingleSessionSyncSpecification; |
| 25 | +import io.modelcontextprotocol.server.McpTransportContextExtractor; |
| 26 | +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; |
| 27 | +import reactor.core.publisher.Hooks; |
| 28 | +import reactor.netty.DisposableServer; |
| 29 | +import reactor.netty.http.server.HttpServer; |
| 30 | + |
| 31 | +@Timeout(15) |
| 32 | +public class WebFluxSseCloseGracefullyIntegrationTests extends AbstractMcpClientServerIntegrationTests { |
| 33 | + |
| 34 | + private int port; |
| 35 | + |
| 36 | + private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse"; |
| 37 | + |
| 38 | + private static final String DEFAULT_MESSAGE_ENDPOINT = "/mcp/message"; |
| 39 | + |
| 40 | + private DisposableServer httpServer; |
| 41 | + |
| 42 | + private WebFluxSseServerTransportProvider mcpServerTransportProvider; |
| 43 | + |
| 44 | + static McpTransportContextExtractor<ServerRequest> TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext |
| 45 | + .create(Map.of("important", "value")); |
| 46 | + |
| 47 | + static Stream<Arguments> clientsForTesting() { |
| 48 | + return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux")); |
| 49 | + } |
| 50 | + |
| 51 | + @Override |
| 52 | + protected void prepareClients(int port, String mcpEndpoint) { |
| 53 | + clientBuilders |
| 54 | + .put("httpclient", |
| 55 | + McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + port) |
| 56 | + .sseEndpoint(CUSTOM_SSE_ENDPOINT) |
| 57 | + .build()).requestTimeout(Duration.ofSeconds(10))); |
| 58 | + |
| 59 | + clientBuilders.put("webflux", McpClient |
| 60 | + .sync(WebFluxSseClientTransport.builder(org.springframework.web.reactive.function.client.WebClient.builder() |
| 61 | + .baseUrl("http://localhost:" + port)).sseEndpoint(CUSTOM_SSE_ENDPOINT).build()) |
| 62 | + .requestTimeout(Duration.ofSeconds(10))); |
| 63 | + } |
| 64 | + |
| 65 | + @Override |
| 66 | + protected AsyncSpecification<?> prepareAsyncServerBuilder() { |
| 67 | + return McpServer.async(mcpServerTransportProvider); |
| 68 | + } |
| 69 | + |
| 70 | + @Override |
| 71 | + protected SingleSessionSyncSpecification prepareSyncServerBuilder() { |
| 72 | + return McpServer.sync(mcpServerTransportProvider); |
| 73 | + } |
| 74 | + |
| 75 | + @BeforeEach |
| 76 | + void before() { |
| 77 | + // Build the transport provider with BOTH endpoints (message required) |
| 78 | + this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider.Builder() |
| 79 | + .messageEndpoint(DEFAULT_MESSAGE_ENDPOINT) |
| 80 | + .sseEndpoint(CUSTOM_SSE_ENDPOINT) |
| 81 | + .contextExtractor(TEST_CONTEXT_EXTRACTOR) |
| 82 | + .build(); |
| 83 | + |
| 84 | + // Wire session factory |
| 85 | + prepareSyncServerBuilder().build(); |
| 86 | + |
| 87 | + // Bind on ephemeral port and discover the actual port |
| 88 | + var httpHandler = RouterFunctions.toHttpHandler(mcpServerTransportProvider.getRouterFunction()); |
| 89 | + var adapter = new ReactorHttpHandlerAdapter(httpHandler); |
| 90 | + this.httpServer = HttpServer.create().port(0).handle(adapter).bindNow(); |
| 91 | + this.port = httpServer.port(); |
| 92 | + |
| 93 | + // Build clients using the discovered port |
| 94 | + prepareClients(this.port, null); |
| 95 | + |
| 96 | + // keep your onErrorDropped suppression if you need it for noisy Reactor paths |
| 97 | + Hooks.onErrorDropped(e -> { |
| 98 | + }); |
| 99 | + } |
| 100 | + |
| 101 | + @AfterEach |
| 102 | + void after() { |
| 103 | + if (httpServer != null) |
| 104 | + httpServer.disposeNow(); |
| 105 | + Hooks.resetOnErrorDropped(); |
| 106 | + } |
| 107 | + |
| 108 | + @ParameterizedTest(name = "closeGracefully after outage: {0}") |
| 109 | + @MethodSource("clientsForTesting") |
| 110 | + @DisplayName("closeGracefully() signals failure after server outage (WebFlux/SSE, sync client)") |
| 111 | + void closeGracefully_disposes_after_server_unavailable(String clientKey) { |
| 112 | + var reactiveClient = io.modelcontextprotocol.client.McpClient |
| 113 | + .async(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + this.port)) |
| 114 | + .sseEndpoint(CUSTOM_SSE_ENDPOINT) |
| 115 | + .build()) |
| 116 | + .requestTimeout(Duration.ofSeconds(10)) |
| 117 | + .build(); |
| 118 | + |
| 119 | + reactiveClient.initialize().block(Duration.ofSeconds(5)); |
| 120 | + |
| 121 | + httpServer.disposeNow(); |
| 122 | + |
| 123 | + Assertions.assertThatCode(() -> reactiveClient.closeGracefully().block(Duration.ofSeconds(5))) |
| 124 | + .doesNotThrowAnyException(); |
| 125 | + |
| 126 | + Assertions.assertThatThrownBy(() -> reactiveClient.initialize().block(Duration.ofSeconds(3))) |
| 127 | + .isInstanceOf(Exception.class); |
| 128 | + |
| 129 | + } |
| 130 | + |
| 131 | +} |
0 commit comments