Skip to content

Commit 2ab0e25

Browse files
Phani PemmarajuPhani Pemmaraju
authored andcommitted
fix(session): always dispose in closeGracefully
1 parent a0afdcd commit 2ab0e25

File tree

3 files changed

+272
-2
lines changed

3 files changed

+272
-2
lines changed

mcp-core/src/main/java/io/modelcontextprotocol/spec/DefaultMcpTransportSession.java

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import org.slf4j.LoggerFactory;
1010
import reactor.core.Disposable;
1111
import reactor.core.Disposables;
12+
import reactor.core.Exceptions;
1213
import reactor.core.publisher.Mono;
1314

1415
import java.util.Optional;
@@ -77,8 +78,36 @@ public void close() {
7778

7879
@Override
7980
public Mono<Void> closeGracefully() {
80-
return Mono.from(this.onClose.apply(this.sessionId.get()))
81-
.then(Mono.fromRunnable(this.openConnections::dispose));
81+
return Mono.defer(() -> {
82+
final String sessionId = this.sessionId.get();
83+
84+
final AtomicReference<Throwable> primary = new AtomicReference<>(null);
85+
86+
// Subscribe to onClose publisher and capture any error
87+
return Mono.from(this.onClose.apply(sessionId)).onErrorResume(err -> {
88+
primary.set(err);
89+
return Mono.empty();
90+
})
91+
// Always dispose openConnections
92+
.then(Mono.defer(() -> {
93+
try {
94+
this.openConnections.dispose();
95+
}
96+
catch (Throwable disposeEx) {
97+
if (primary.get() != null) {
98+
primary.get().addSuppressed(disposeEx);
99+
}
100+
else {
101+
primary.set(disposeEx);
102+
}
103+
}
104+
105+
// Re-emit the original error (with suppressed dispose error),
106+
// complete
107+
Throwable throwable = primary.get();
108+
return (throwable == null) ? Mono.empty() : Mono.error(Exceptions.propagate(throwable));
109+
}));
110+
});
82111
}
83112

84113
}
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
package io.modelcontextprotocol.spec;
2+
3+
import org.assertj.core.api.Assertions;
4+
import org.junit.jupiter.api.Test;
5+
import org.mockito.Mockito;
6+
import org.reactivestreams.Publisher;
7+
import org.springframework.util.ReflectionUtils;
8+
import reactor.core.Disposable;
9+
import reactor.core.publisher.Mono;
10+
11+
import java.lang.reflect.Field;
12+
import java.util.concurrent.atomic.AtomicBoolean;
13+
import java.util.concurrent.atomic.AtomicReference;
14+
import java.util.function.Function;
15+
16+
import static org.assertj.core.api.Assertions.assertThat;
17+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
18+
19+
/**
20+
* Tests for {@link DefaultMcpTransportSession}.
21+
*
22+
* @author Phani Pemmaraju
23+
*/
24+
class DefaultMcpTransportSessionTests {
25+
26+
/** Minimal Disposable to flag that dispose() was called. */
27+
static final class FlagDisposable implements Disposable {
28+
29+
final AtomicBoolean disposed = new AtomicBoolean(false);
30+
31+
@Override
32+
public void dispose() {
33+
disposed.set(true);
34+
}
35+
36+
@Override
37+
public boolean isDisposed() {
38+
return disposed.get();
39+
}
40+
41+
}
42+
43+
@Test
44+
void closeGracefully_disposes_when_onClose_throws() {
45+
@SuppressWarnings("unchecked")
46+
Function<String, Publisher<Void>> onClose = Mockito.mock(Function.class);
47+
Mockito.when(onClose.apply(Mockito.any())).thenReturn(Mono.error(new RuntimeException("runtime-exception")));
48+
49+
// construct session with required ctor
50+
var session = new DefaultMcpTransportSession(onClose);
51+
52+
// seed session id
53+
setField(session, "sessionId", new AtomicReference<>("sessionId-123"));
54+
55+
// get the existing final composite and add a child flag-disposable
56+
Disposable.Composite composite = (Disposable.Composite) getField(session, "openConnections");
57+
FlagDisposable flag = new FlagDisposable();
58+
composite.add(flag);
59+
60+
// act + assert: original onClose error is propagated
61+
assertThatThrownBy(() -> session.closeGracefully().block()).isInstanceOf(RuntimeException.class)
62+
.hasMessageContaining("runtime-exception");
63+
64+
// and the child disposable was disposed => proves composite.dispose() executed
65+
assertThat(flag.isDisposed()).isTrue();
66+
}
67+
68+
@Test
69+
void closeGracefully_propagates_onClose_error_and_disposes_children() {
70+
// onClose fails again
71+
@SuppressWarnings("unchecked")
72+
Function<String, Publisher<Void>> onClose = Mockito.mock(Function.class);
73+
Mockito.when(onClose.apply(Mockito.any())).thenReturn(Mono.error(new RuntimeException("runtime-exception")));
74+
75+
var session = new DefaultMcpTransportSession(onClose);
76+
setField(session, "sessionId", new AtomicReference<>("sessionId-xyz"));
77+
78+
Disposable.Composite composite = (Disposable.Composite) getField(session, "openConnections");
79+
FlagDisposable a = new FlagDisposable();
80+
FlagDisposable b = new FlagDisposable();
81+
composite.add(a);
82+
composite.add(b);
83+
84+
Throwable thrown = Assertions.catchThrowable(() -> session.closeGracefully().block());
85+
86+
// primary error is from onClose
87+
assertThat(thrown).isInstanceOf(RuntimeException.class).hasMessageContaining("runtime-exception");
88+
89+
// both children disposed
90+
assertThat(a.isDisposed()).isTrue();
91+
assertThat(b.isDisposed()).isTrue();
92+
}
93+
94+
private static void setField(Object target, String fieldName, Object value) {
95+
Field f = ReflectionUtils.findField(target.getClass(), fieldName);
96+
if (f == null)
97+
throw new IllegalArgumentException("No such field: " + fieldName);
98+
ReflectionUtils.makeAccessible(f);
99+
ReflectionUtils.setField(f, target, value);
100+
}
101+
102+
private static Object getField(Object target, String fieldName) {
103+
Field f = ReflectionUtils.findField(target.getClass(), fieldName);
104+
if (f == null)
105+
throw new IllegalArgumentException("No such field: " + fieldName);
106+
ReflectionUtils.makeAccessible(f);
107+
return ReflectionUtils.getField(f, target);
108+
}
109+
110+
}
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
package io.modelcontextprotocol;
2+
3+
import java.time.Duration;
4+
import java.util.Map;
5+
import java.util.stream.Stream;
6+
7+
import org.assertj.core.api.Assertions;
8+
import org.junit.jupiter.api.*;
9+
import org.junit.jupiter.params.ParameterizedTest;
10+
import org.junit.jupiter.params.provider.Arguments;
11+
import org.junit.jupiter.params.provider.MethodSource;
12+
13+
import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter;
14+
import org.springframework.web.reactive.function.client.WebClient;
15+
import org.springframework.web.reactive.function.server.RouterFunctions;
16+
import org.springframework.web.reactive.function.server.ServerRequest;
17+
18+
import io.modelcontextprotocol.client.McpClient;
19+
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
20+
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
21+
import io.modelcontextprotocol.common.McpTransportContext;
22+
import io.modelcontextprotocol.server.McpServer;
23+
import io.modelcontextprotocol.server.McpServer.AsyncSpecification;
24+
import io.modelcontextprotocol.server.McpServer.SingleSessionSyncSpecification;
25+
import io.modelcontextprotocol.server.McpTransportContextExtractor;
26+
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
27+
import reactor.core.publisher.Hooks;
28+
import reactor.netty.DisposableServer;
29+
import reactor.netty.http.server.HttpServer;
30+
31+
@Timeout(15)
32+
public class WebFluxSseCloseGracefullyIntegrationTests extends AbstractMcpClientServerIntegrationTests {
33+
34+
private int port;
35+
36+
private static final String CUSTOM_SSE_ENDPOINT = "/somePath/sse";
37+
38+
private static final String DEFAULT_MESSAGE_ENDPOINT = "/mcp/message";
39+
40+
private DisposableServer httpServer;
41+
42+
private WebFluxSseServerTransportProvider mcpServerTransportProvider;
43+
44+
static McpTransportContextExtractor<ServerRequest> TEST_CONTEXT_EXTRACTOR = (r) -> McpTransportContext
45+
.create(Map.of("important", "value"));
46+
47+
static Stream<Arguments> clientsForTesting() {
48+
return Stream.of(Arguments.of("httpclient"), Arguments.of("webflux"));
49+
}
50+
51+
@Override
52+
protected void prepareClients(int port, String mcpEndpoint) {
53+
clientBuilders
54+
.put("httpclient",
55+
McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + port)
56+
.sseEndpoint(CUSTOM_SSE_ENDPOINT)
57+
.build()).requestTimeout(Duration.ofSeconds(10)));
58+
59+
clientBuilders.put("webflux", McpClient
60+
.sync(WebFluxSseClientTransport.builder(org.springframework.web.reactive.function.client.WebClient.builder()
61+
.baseUrl("http://localhost:" + port)).sseEndpoint(CUSTOM_SSE_ENDPOINT).build())
62+
.requestTimeout(Duration.ofSeconds(10)));
63+
}
64+
65+
@Override
66+
protected AsyncSpecification<?> prepareAsyncServerBuilder() {
67+
return McpServer.async(mcpServerTransportProvider);
68+
}
69+
70+
@Override
71+
protected SingleSessionSyncSpecification prepareSyncServerBuilder() {
72+
return McpServer.sync(mcpServerTransportProvider);
73+
}
74+
75+
@BeforeEach
76+
void before() {
77+
// Build the transport provider with BOTH endpoints (message required)
78+
this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider.Builder()
79+
.messageEndpoint(DEFAULT_MESSAGE_ENDPOINT)
80+
.sseEndpoint(CUSTOM_SSE_ENDPOINT)
81+
.contextExtractor(TEST_CONTEXT_EXTRACTOR)
82+
.build();
83+
84+
// Wire session factory
85+
prepareSyncServerBuilder().build();
86+
87+
// Bind on ephemeral port and discover the actual port
88+
var httpHandler = RouterFunctions.toHttpHandler(mcpServerTransportProvider.getRouterFunction());
89+
var adapter = new ReactorHttpHandlerAdapter(httpHandler);
90+
this.httpServer = HttpServer.create().port(0).handle(adapter).bindNow();
91+
this.port = httpServer.port();
92+
93+
// Build clients using the discovered port
94+
prepareClients(this.port, null);
95+
96+
// keep your onErrorDropped suppression if you need it for noisy Reactor paths
97+
Hooks.onErrorDropped(e -> {
98+
});
99+
}
100+
101+
@AfterEach
102+
void after() {
103+
if (httpServer != null)
104+
httpServer.disposeNow();
105+
Hooks.resetOnErrorDropped();
106+
}
107+
108+
@ParameterizedTest(name = "closeGracefully after outage: {0}")
109+
@MethodSource("clientsForTesting")
110+
@DisplayName("closeGracefully() signals failure after server outage (WebFlux/SSE, sync client)")
111+
void closeGracefully_disposes_after_server_unavailable(String clientKey) {
112+
var reactiveClient = io.modelcontextprotocol.client.McpClient
113+
.async(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + this.port))
114+
.sseEndpoint(CUSTOM_SSE_ENDPOINT)
115+
.build())
116+
.requestTimeout(Duration.ofSeconds(10))
117+
.build();
118+
119+
reactiveClient.initialize().block(Duration.ofSeconds(5));
120+
121+
httpServer.disposeNow();
122+
123+
Assertions.assertThatCode(() -> reactiveClient.closeGracefully().block(Duration.ofSeconds(5)))
124+
.doesNotThrowAnyException();
125+
126+
Assertions.assertThatThrownBy(() -> reactiveClient.initialize().block(Duration.ofSeconds(3)))
127+
.isInstanceOf(Exception.class);
128+
129+
}
130+
131+
}

0 commit comments

Comments
 (0)