Skip to content

Commit c7e71b3

Browse files
committed
test: replace delay based concurrency test with thread swarm
1 parent 2715ab8 commit c7e71b3

File tree

1 file changed

+106
-26
lines changed

1 file changed

+106
-26
lines changed

mcp-test/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java

Lines changed: 106 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,26 @@
44

55
package io.modelcontextprotocol.server.transport;
66

7+
import java.io.BufferedReader;
78
import java.io.ByteArrayInputStream;
89
import java.io.ByteArrayOutputStream;
910
import java.io.InputStream;
11+
import java.io.InputStreamReader;
12+
import java.io.PipedInputStream;
13+
import java.io.PipedOutputStream;
1014
import java.io.PrintStream;
1115
import java.nio.charset.StandardCharsets;
1216
import java.util.Map;
17+
import java.util.Set;
18+
import java.util.concurrent.CompletableFuture;
19+
import java.util.concurrent.ConcurrentHashMap;
1320
import java.util.concurrent.CountDownLatch;
21+
import java.util.concurrent.ExecutorService;
22+
import java.util.concurrent.Executors;
1423
import java.util.concurrent.TimeUnit;
24+
import java.util.concurrent.atomic.AtomicInteger;
1525
import java.util.concurrent.atomic.AtomicReference;
26+
import java.util.function.Consumer;
1627

1728
import io.modelcontextprotocol.json.McpJsonDefaults;
1829
import io.modelcontextprotocol.spec.McpError;
@@ -23,7 +34,10 @@
2334
import org.junit.jupiter.api.BeforeEach;
2435
import org.junit.jupiter.api.Disabled;
2536
import org.junit.jupiter.api.Test;
37+
import org.junit.jupiter.api.Timeout;
2638
import reactor.core.publisher.Mono;
39+
import reactor.core.scheduler.Scheduler;
40+
import reactor.core.scheduler.Schedulers;
2741
import reactor.test.StepVerifier;
2842

2943
import static org.assertj.core.api.Assertions.assertThat;
@@ -221,21 +235,73 @@ void shouldHandleSessionClose() throws Exception {
221235
}
222236

223237
@Test
238+
@Timeout(15)
224239
void shouldHandleConcurrentMessages() throws Exception {
225-
java.io.PipedOutputStream pipedOut = new java.io.PipedOutputStream();
226-
java.io.PipedInputStream pipedIn = new java.io.PipedInputStream(pipedOut);
240+
int messageCount = 100;
241+
int workerCount = 32;
242+
PipedOutputStream pipedRequestOut = new PipedOutputStream();
243+
PipedInputStream pipedRequestIn = new PipedInputStream(pipedRequestOut);
244+
PipedInputStream pipedResponseIn = new PipedInputStream();
245+
PipedOutputStream pipedResponseOut = new PipedOutputStream(pipedResponseIn);
246+
247+
CountDownLatch inFlightReached = new CountDownLatch(workerCount);
248+
CountDownLatch startSignal = new CountDownLatch(1);
249+
CompletableFuture<Void> responsesDone = new CompletableFuture<>();
250+
Set<Integer> responseIds = ConcurrentHashMap.newKeySet();
251+
AtomicInteger inFlight = new AtomicInteger();
252+
AtomicInteger maxInFlight = new AtomicInteger();
253+
ExecutorService swarmExecutor = Executors.newFixedThreadPool(workerCount);
254+
Scheduler swarmScheduler = Schedulers.fromExecutorService(swarmExecutor);
255+
Consumer<Throwable> fail = t -> {
256+
responsesDone.completeExceptionally(t);
257+
startSignal.countDown();
258+
};
259+
260+
Thread responseReader = new Thread(() -> {
261+
try (BufferedReader reader = new BufferedReader(
262+
new InputStreamReader(pipedResponseIn, StandardCharsets.UTF_8))) {
263+
for (int received = 0; received < messageCount; received++) {
264+
String line = reader.readLine();
265+
if (line == null) {
266+
throw new AssertionError("Stream closed before all responses were received");
267+
}
268+
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(McpJsonDefaults.getMapper(),
269+
line);
270+
if (!(message instanceof McpSchema.JSONRPCResponse response)) {
271+
throw new AssertionError("Expected JSONRPCResponse");
272+
}
273+
if (!(response.id() instanceof Number idNumber)) {
274+
throw new AssertionError("Expected numeric response id");
275+
}
276+
responseIds.add(idNumber.intValue());
277+
}
278+
responsesDone.complete(null);
279+
}
280+
catch (Throwable t) {
281+
fail.accept(t);
282+
}
283+
});
284+
responseReader.start();
227285

228-
ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
229-
transportProvider = new StdioServerTransportProvider(McpJsonDefaults.getMapper(), pipedIn, outputStream);
286+
transportProvider = new StdioServerTransportProvider(McpJsonDefaults.getMapper(), pipedRequestIn,
287+
pipedResponseOut);
230288

231289
McpServerSession.Factory realSessionFactory = transport -> {
232290
McpServerSession session = mock(McpServerSession.class);
233291
when(session.handle(any())).thenAnswer(invocation -> {
234-
McpSchema.JSONRPCMessage incomingMessage = invocation.getArgument(0);
235-
// Simulate async tool call processing with a delay
236-
return Mono.delay(java.time.Duration.ofMillis(50))
237-
.then(transport.sendMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION,
238-
((McpSchema.JSONRPCRequest) incomingMessage).id(), Map.of("result", "ok"), null)));
292+
McpSchema.JSONRPCRequest incomingMessage = invocation.getArgument(0);
293+
return Mono.fromCallable(() -> {
294+
inFlightReached.countDown();
295+
startSignal.await();
296+
return incomingMessage.id();
297+
}).subscribeOn(swarmScheduler).flatMap(id -> {
298+
int currentInFlight = inFlight.incrementAndGet();
299+
maxInFlight.accumulateAndGet(currentInFlight, Math::max);
300+
return transport
301+
.sendMessage(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, id,
302+
Map.of("result", "ok"), null))
303+
.doFinally(signalType -> inFlight.decrementAndGet());
304+
}).doOnError(fail);
239305
});
240306
when(session.closeGracefully()).thenReturn(Mono.empty());
241307
return session;
@@ -244,23 +310,37 @@ void shouldHandleConcurrentMessages() throws Exception {
244310
// Set session factory
245311
transportProvider.setSessionFactory(realSessionFactory);
246312

247-
String jsonMessage1 = "{\"jsonrpc\":\"2.0\",\"method\":\"test1\",\"params\":{},\"id\":1}\n";
248-
String jsonMessage2 = "{\"jsonrpc\":\"2.0\",\"method\":\"test2\",\"params\":{},\"id\":2}\n";
249-
pipedOut.write(jsonMessage1.getBytes(StandardCharsets.UTF_8));
250-
pipedOut.write(jsonMessage2.getBytes(StandardCharsets.UTF_8));
251-
pipedOut.flush();
252-
253-
// Verify both concurrent responses complete without error
254-
StepVerifier
255-
.create(Mono.delay(java.time.Duration.ofSeconds(2))
256-
.then(Mono.fromCallable(() -> outputStream.toString(StandardCharsets.UTF_8))))
257-
.assertNext(output -> {
258-
assertThat(output).contains("\"id\":1");
259-
assertThat(output).contains("\"id\":2");
260-
})
261-
.verifyComplete();
262-
263-
pipedOut.close();
313+
try {
314+
for (int i = 1; i <= messageCount; i++) {
315+
String jsonMessage = "{\"jsonrpc\":\"2.0\",\"method\":\"test" + i + "\",\"params\":{},\"id\":" + i
316+
+ "}\n";
317+
pipedRequestOut.write(jsonMessage.getBytes(StandardCharsets.UTF_8));
318+
}
319+
pipedRequestOut.flush();
320+
inFlightReached.await();
321+
startSignal.countDown();
322+
responsesDone.get();
323+
324+
// Verify that all messages were recevied
325+
assertThat(responseIds).hasSize(messageCount);
326+
// Verify that concurrency happened
327+
assertThat(maxInFlight.get()).isGreaterThan(1);
328+
// Verify that every responseId exists
329+
for (int i = 1; i <= messageCount; i++) {
330+
assertThat(responseIds).contains(i);
331+
}
332+
}
333+
finally {
334+
startSignal.countDown();
335+
swarmScheduler.dispose();
336+
swarmExecutor.shutdownNow();
337+
pipedRequestOut.close();
338+
pipedResponseOut.close();
339+
pipedResponseIn.close();
340+
responseReader.join(TimeUnit.SECONDS.toMillis(2));
341+
342+
assertThat(responseReader.isAlive()).isFalse();
343+
}
264344
}
265345

266346
}

0 commit comments

Comments
 (0)