44
55package io .modelcontextprotocol .server .transport ;
66
7+ import java .io .BufferedReader ;
78import java .io .ByteArrayInputStream ;
89import java .io .ByteArrayOutputStream ;
910import java .io .InputStream ;
11+ import java .io .InputStreamReader ;
12+ import java .io .PipedInputStream ;
13+ import java .io .PipedOutputStream ;
1014import java .io .PrintStream ;
1115import java .nio .charset .StandardCharsets ;
1216import java .util .Map ;
17+ import java .util .Set ;
18+ import java .util .concurrent .CompletableFuture ;
19+ import java .util .concurrent .ConcurrentHashMap ;
1320import java .util .concurrent .CountDownLatch ;
21+ import java .util .concurrent .ExecutorService ;
22+ import java .util .concurrent .Executors ;
1423import java .util .concurrent .TimeUnit ;
24+ import java .util .concurrent .atomic .AtomicInteger ;
1525import java .util .concurrent .atomic .AtomicReference ;
26+ import java .util .function .Consumer ;
1627
1728import io .modelcontextprotocol .json .McpJsonDefaults ;
1829import io .modelcontextprotocol .spec .McpError ;
2334import org .junit .jupiter .api .BeforeEach ;
2435import org .junit .jupiter .api .Disabled ;
2536import org .junit .jupiter .api .Test ;
37+ import org .junit .jupiter .api .Timeout ;
2638import reactor .core .publisher .Mono ;
39+ import reactor .core .scheduler .Scheduler ;
40+ import reactor .core .scheduler .Schedulers ;
2741import reactor .test .StepVerifier ;
2842
2943import static org .assertj .core .api .Assertions .assertThat ;
@@ -221,21 +235,73 @@ void shouldHandleSessionClose() throws Exception {
221235 }
222236
223237 @ Test
238+ @ Timeout (15 )
224239 void shouldHandleConcurrentMessages () throws Exception {
225- java .io .PipedOutputStream pipedOut = new java .io .PipedOutputStream ();
226- java .io .PipedInputStream pipedIn = new java .io .PipedInputStream (pipedOut );
240+ int messageCount = 100 ;
241+ int workerCount = 32 ;
242+ PipedOutputStream pipedRequestOut = new PipedOutputStream ();
243+ PipedInputStream pipedRequestIn = new PipedInputStream (pipedRequestOut );
244+ PipedInputStream pipedResponseIn = new PipedInputStream ();
245+ PipedOutputStream pipedResponseOut = new PipedOutputStream (pipedResponseIn );
246+
247+ CountDownLatch inFlightReached = new CountDownLatch (workerCount );
248+ CountDownLatch startSignal = new CountDownLatch (1 );
249+ CompletableFuture <Void > responsesDone = new CompletableFuture <>();
250+ Set <Integer > responseIds = ConcurrentHashMap .newKeySet ();
251+ AtomicInteger inFlight = new AtomicInteger ();
252+ AtomicInteger maxInFlight = new AtomicInteger ();
253+ ExecutorService swarmExecutor = Executors .newFixedThreadPool (workerCount );
254+ Scheduler swarmScheduler = Schedulers .fromExecutorService (swarmExecutor );
255+ Consumer <Throwable > fail = t -> {
256+ responsesDone .completeExceptionally (t );
257+ startSignal .countDown ();
258+ };
259+
260+ Thread responseReader = new Thread (() -> {
261+ try (BufferedReader reader = new BufferedReader (
262+ new InputStreamReader (pipedResponseIn , StandardCharsets .UTF_8 ))) {
263+ for (int received = 0 ; received < messageCount ; received ++) {
264+ String line = reader .readLine ();
265+ if (line == null ) {
266+ throw new AssertionError ("Stream closed before all responses were received" );
267+ }
268+ McpSchema .JSONRPCMessage message = McpSchema .deserializeJsonRpcMessage (McpJsonDefaults .getMapper (),
269+ line );
270+ if (!(message instanceof McpSchema .JSONRPCResponse response )) {
271+ throw new AssertionError ("Expected JSONRPCResponse" );
272+ }
273+ if (!(response .id () instanceof Number idNumber )) {
274+ throw new AssertionError ("Expected numeric response id" );
275+ }
276+ responseIds .add (idNumber .intValue ());
277+ }
278+ responsesDone .complete (null );
279+ }
280+ catch (Throwable t ) {
281+ fail .accept (t );
282+ }
283+ });
284+ responseReader .start ();
227285
228- ByteArrayOutputStream outputStream = new ByteArrayOutputStream ();
229- transportProvider = new StdioServerTransportProvider ( McpJsonDefaults . getMapper (), pipedIn , outputStream );
286+ transportProvider = new StdioServerTransportProvider ( McpJsonDefaults . getMapper (), pipedRequestIn ,
287+ pipedResponseOut );
230288
231289 McpServerSession .Factory realSessionFactory = transport -> {
232290 McpServerSession session = mock (McpServerSession .class );
233291 when (session .handle (any ())).thenAnswer (invocation -> {
234- McpSchema .JSONRPCMessage incomingMessage = invocation .getArgument (0 );
235- // Simulate async tool call processing with a delay
236- return Mono .delay (java .time .Duration .ofMillis (50 ))
237- .then (transport .sendMessage (new McpSchema .JSONRPCResponse (McpSchema .JSONRPC_VERSION ,
238- ((McpSchema .JSONRPCRequest ) incomingMessage ).id (), Map .of ("result" , "ok" ), null )));
292+ McpSchema .JSONRPCRequest incomingMessage = invocation .getArgument (0 );
293+ return Mono .fromCallable (() -> {
294+ inFlightReached .countDown ();
295+ startSignal .await ();
296+ return incomingMessage .id ();
297+ }).subscribeOn (swarmScheduler ).flatMap (id -> {
298+ int currentInFlight = inFlight .incrementAndGet ();
299+ maxInFlight .accumulateAndGet (currentInFlight , Math ::max );
300+ return transport
301+ .sendMessage (new McpSchema .JSONRPCResponse (McpSchema .JSONRPC_VERSION , id ,
302+ Map .of ("result" , "ok" ), null ))
303+ .doFinally (signalType -> inFlight .decrementAndGet ());
304+ }).doOnError (fail );
239305 });
240306 when (session .closeGracefully ()).thenReturn (Mono .empty ());
241307 return session ;
@@ -244,23 +310,37 @@ void shouldHandleConcurrentMessages() throws Exception {
244310 // Set session factory
245311 transportProvider .setSessionFactory (realSessionFactory );
246312
247- String jsonMessage1 = "{\" jsonrpc\" :\" 2.0\" ,\" method\" :\" test1\" ,\" params\" :{},\" id\" :1}\n " ;
248- String jsonMessage2 = "{\" jsonrpc\" :\" 2.0\" ,\" method\" :\" test2\" ,\" params\" :{},\" id\" :2}\n " ;
249- pipedOut .write (jsonMessage1 .getBytes (StandardCharsets .UTF_8 ));
250- pipedOut .write (jsonMessage2 .getBytes (StandardCharsets .UTF_8 ));
251- pipedOut .flush ();
252-
253- // Verify both concurrent responses complete without error
254- StepVerifier
255- .create (Mono .delay (java .time .Duration .ofSeconds (2 ))
256- .then (Mono .fromCallable (() -> outputStream .toString (StandardCharsets .UTF_8 ))))
257- .assertNext (output -> {
258- assertThat (output ).contains ("\" id\" :1" );
259- assertThat (output ).contains ("\" id\" :2" );
260- })
261- .verifyComplete ();
262-
263- pipedOut .close ();
313+ try {
314+ for (int i = 1 ; i <= messageCount ; i ++) {
315+ String jsonMessage = "{\" jsonrpc\" :\" 2.0\" ,\" method\" :\" test" + i + "\" ,\" params\" :{},\" id\" :" + i
316+ + "}\n " ;
317+ pipedRequestOut .write (jsonMessage .getBytes (StandardCharsets .UTF_8 ));
318+ }
319+ pipedRequestOut .flush ();
320+ inFlightReached .await ();
321+ startSignal .countDown ();
322+ responsesDone .get ();
323+
324+ // Verify that all messages were recevied
325+ assertThat (responseIds ).hasSize (messageCount );
326+ // Verify that concurrency happened
327+ assertThat (maxInFlight .get ()).isGreaterThan (1 );
328+ // Verify that every responseId exists
329+ for (int i = 1 ; i <= messageCount ; i ++) {
330+ assertThat (responseIds ).contains (i );
331+ }
332+ }
333+ finally {
334+ startSignal .countDown ();
335+ swarmScheduler .dispose ();
336+ swarmExecutor .shutdownNow ();
337+ pipedRequestOut .close ();
338+ pipedResponseOut .close ();
339+ pipedResponseIn .close ();
340+ responseReader .join (TimeUnit .SECONDS .toMillis (2 ));
341+
342+ assertThat (responseReader .isAlive ()).isFalse ();
343+ }
264344 }
265345
266346}
0 commit comments