Skip to content

Commit 646bb21

Browse files
committed
Validate message endpoint in SSE WebFlux client transport
Signed-off-by: Daniel Garnier-Moiroux <git@garnier.wf>
1 parent ac2f0b3 commit 646bb21

2 files changed

Lines changed: 112 additions & 11 deletions

File tree

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java

Lines changed: 61 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
package io.modelcontextprotocol.client.transport;
66

77
import java.io.IOException;
8+
import java.net.URI;
89
import java.util.List;
10+
import java.util.concurrent.atomic.AtomicReference;
911
import java.util.function.BiConsumer;
1012
import java.util.function.Function;
1113

@@ -128,7 +130,17 @@ public class WebFluxSseClientTransport implements McpClientTransport {
128130
* The SSE endpoint URI provided by the server. Used for sending outbound messages via
129131
* HTTP POST requests.
130132
*/
131-
private String sseEndpoint;
133+
private final String sseEndpoint;
134+
135+
/**
136+
* Used to capture the full SSE URI from the web client when connecting.
137+
*/
138+
private final AtomicReference<URI> sseUri = new AtomicReference<>();
139+
140+
/**
141+
* Validator for the message endpoint.
142+
*/
143+
private final SseMessageEndpointValidator messageEndpointValidator;
132144

133145
/**
134146
* Constructs a new SseClientTransport with the specified WebClient builder and
@@ -152,13 +164,30 @@ public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, McpJsonMapp
152164
* @throws IllegalArgumentException if either parameter is null
153165
*/
154166
public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, McpJsonMapper jsonMapper, String sseEndpoint) {
167+
this(webClientBuilder, jsonMapper, sseEndpoint, new DefaultSseMessageEndpointValidator());
168+
}
169+
170+
/**
171+
* Constructs a new SseClientTransport with the specified WebClient builder and
172+
* ObjectMapper. Initializes both inbound and outbound message processing pipelines.
173+
* @param webClientBuilder the WebClient.Builder to use for creating the WebClient
174+
* instance
175+
* @param jsonMapper the ObjectMapper to use for JSON processing
176+
* @param sseEndpoint the SSE endpoint URI to use for establishing the connection
177+
* @param messageEndpointValidator validator for the message endpoint
178+
* @throws IllegalArgumentException if either parameter is null
179+
*/
180+
public WebFluxSseClientTransport(WebClient.Builder webClientBuilder, McpJsonMapper jsonMapper, String sseEndpoint,
181+
SseMessageEndpointValidator messageEndpointValidator) {
155182
Assert.notNull(jsonMapper, "jsonMapper must not be null");
156183
Assert.notNull(webClientBuilder, "WebClient.Builder must not be null");
184+
Assert.notNull(messageEndpointValidator, "messageEndpointValidator must not be null");
157185
Assert.hasText(sseEndpoint, "SSE endpoint must not be null or empty");
158186

159187
this.jsonMapper = jsonMapper;
160188
this.webClient = webClientBuilder.build();
161189
this.sseEndpoint = sseEndpoint;
190+
this.messageEndpointValidator = messageEndpointValidator;
162191
}
163192

164193
@Override
@@ -195,6 +224,14 @@ public Mono<Void> connect(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> h
195224
this.inboundSubscription = events.concatMap(event -> Mono.just(event).<JSONRPCMessage>handle((e, s) -> {
196225
if (ENDPOINT_EVENT_TYPE.equals(event.event())) {
197226
String messageEndpointUri = event.data();
227+
try {
228+
this.messageEndpointValidator.validate(this.sseUri.get(), messageEndpointUri);
229+
}
230+
catch (InvalidSseMessageEndpointException ex) {
231+
messageEndpointSink.tryEmitError(ex);
232+
s.error(ex);
233+
return;
234+
}
198235
if (messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) {
199236
s.complete();
200237
}
@@ -276,16 +313,17 @@ public Mono<Void> sendMessage(JSONRPCMessage message) {
276313
* Includes automatic retry logic for handling transient connection failures.
277314
*/
278315
// visible for tests
279-
protected Flux<ServerSentEvent<String>> eventStream() {// @formatter:off
280-
return this.webClient
281-
.get()
316+
protected Flux<ServerSentEvent<String>> eventStream() {
317+
return this.webClient.get()
282318
.uri(this.sseEndpoint)
283319
.accept(MediaType.TEXT_EVENT_STREAM)
284320
.header(HttpHeaders.PROTOCOL_VERSION, MCP_PROTOCOL_VERSION)
285-
.retrieve()
286-
.bodyToFlux(SSE_TYPE)
321+
.exchangeToFlux(exchange -> {
322+
this.sseUri.set(exchange.request().getURI());
323+
return exchange.bodyToFlux(SSE_TYPE);
324+
})
287325
.retryWhen(Retry.from(retrySignal -> retrySignal.handle(inboundRetryHandler)));
288-
} // @formatter:on
326+
}
289327

290328
/**
291329
* Retry handler for the inbound SSE stream. Implements the retry logic for handling
@@ -368,6 +406,8 @@ public static class Builder {
368406

369407
private McpJsonMapper jsonMapper;
370408

409+
private SseMessageEndpointValidator messageEndpointValidator = new DefaultSseMessageEndpointValidator();
410+
371411
/**
372412
* Creates a new builder with the specified WebClient.Builder.
373413
* @param webClientBuilder the WebClient.Builder to use
@@ -399,13 +439,26 @@ public Builder jsonMapper(McpJsonMapper jsonMapper) {
399439
return this;
400440
}
401441

442+
/**
443+
* Sets the validator that ensure the message endpoint returned over the SSE
444+
* connection is valid.
445+
* @param messageEndpointValidator the validator
446+
* @return this builder
447+
*/
448+
public Builder messageEndpointValidator(SseMessageEndpointValidator messageEndpointValidator) {
449+
Assert.notNull(messageEndpointValidator, "messageEndpointValidator must not be null");
450+
this.messageEndpointValidator = messageEndpointValidator;
451+
return this;
452+
}
453+
402454
/**
403455
* Builds a new {@link WebFluxSseClientTransport} instance.
404456
* @return a new transport instance
405457
*/
406458
public WebFluxSseClientTransport build() {
407459
return new WebFluxSseClientTransport(webClientBuilder,
408-
jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, sseEndpoint);
460+
jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, sseEndpoint,
461+
messageEndpointValidator);
409462
}
410463

411464
}

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
package io.modelcontextprotocol.client.transport;
66

7+
import java.net.URI;
78
import java.time.Duration;
89
import java.util.Map;
910
import java.util.concurrent.CopyOnWriteArrayList;
@@ -21,6 +22,7 @@
2122
import org.junit.jupiter.api.BeforeEach;
2223
import org.junit.jupiter.api.Test;
2324
import org.junit.jupiter.api.Timeout;
25+
import org.mockito.ArgumentCaptor;
2426
import org.testcontainers.containers.GenericContainer;
2527
import org.testcontainers.containers.wait.strategy.Wait;
2628
import reactor.core.publisher.Flux;
@@ -35,6 +37,9 @@
3537
import static org.assertj.core.api.Assertions.assertThat;
3638
import static org.assertj.core.api.Assertions.assertThatCode;
3739
import static org.assertj.core.api.Assertions.assertThatThrownBy;
40+
import static org.mockito.ArgumentMatchers.matches;
41+
import static org.mockito.Mockito.mock;
42+
import static org.mockito.Mockito.verify;
3843

3944
/**
4045
* Tests for the {@link WebFluxSseClientTransport} class.
@@ -57,15 +62,18 @@ class WebFluxSseClientTransportTests {
5762

5863
private WebClient.Builder webClientBuilder;
5964

65+
private SseMessageEndpointValidator sseMessageEndpointValidator = mock(SseMessageEndpointValidator.class);
66+
6067
// Test class to access protected methods
6168
static class TestSseClientTransport extends WebFluxSseClientTransport {
6269

6370
private final AtomicInteger inboundMessageCount = new AtomicInteger(0);
6471

6572
private Sinks.Many<ServerSentEvent<String>> events = Sinks.many().unicast().onBackpressureBuffer();
6673

67-
public TestSseClientTransport(WebClient.Builder webClientBuilder, McpJsonMapper jsonMapper) {
68-
super(webClientBuilder, jsonMapper);
74+
public TestSseClientTransport(WebClient.Builder webClientBuilder, McpJsonMapper jsonMapper,
75+
SseMessageEndpointValidator sseMessageEndpointValidator) {
76+
super(webClientBuilder, jsonMapper, "/sse", sseMessageEndpointValidator);
6977
}
7078

7179
@Override
@@ -113,7 +121,7 @@ static void cleanup() {
113121
@BeforeEach
114122
void setUp() {
115123
webClientBuilder = WebClient.builder().baseUrl(host);
116-
transport = new TestSseClientTransport(webClientBuilder, JSON_MAPPER);
124+
transport = new TestSseClientTransport(webClientBuilder, JSON_MAPPER, sseMessageEndpointValidator);
117125
transport.connect(Function.identity()).block();
118126
}
119127

@@ -368,4 +376,44 @@ void testMessageOrderPreservation() {
368376
assertThat(transport.getInboundMessageCount()).isEqualTo(3);
369377
}
370378

379+
@Test
380+
void testMessageEndpointValidation() throws InvalidSseMessageEndpointException {
381+
var uriCaptor = ArgumentCaptor.forClass(URI.class);
382+
verify(sseMessageEndpointValidator).validate(uriCaptor.capture(), matches("/message\\?sessionId=[a-z0-9-]+"));
383+
assertThat(uriCaptor.getValue().toString()).matches("http://localhost:\\d+/sse");
384+
}
385+
386+
@Test
387+
void testMessageEndpointValidationRejects() {
388+
TestSseClientTransport transport = new TestSseClientTransport(webClientBuilder, JSON_MAPPER,
389+
(sseUri, messageEndpoint) -> {
390+
throw new InvalidSseMessageEndpointException("boom", messageEndpoint);
391+
});
392+
393+
try {
394+
// fails to connect
395+
StepVerifier.create(transport.connect(Function.identity()))
396+
.verifyErrorMatches(WebFluxSseClientTransportTests::isInvalidEndpointError);
397+
398+
// Since connection failed, there is no message endpoint, and no message can
399+
// be sent
400+
JSONRPCRequest testMessage = new JSONRPCRequest(McpSchema.JSONRPC_VERSION, "test-method", "test-id",
401+
Map.of("key", "value"));
402+
403+
StepVerifier.create(transport.sendMessage(testMessage))
404+
.verifyErrorMatches(WebFluxSseClientTransportTests::isInvalidEndpointError);
405+
}
406+
finally {
407+
transport.closeGracefully();
408+
}
409+
}
410+
411+
private static boolean isInvalidEndpointError(Throwable e) {
412+
if (e instanceof InvalidSseMessageEndpointException ismee) {
413+
return ismee.getMessageEndpoint().matches("/message\\?sessionId=[a-z0-9-]+")
414+
&& ismee.getMessage().equals("boom");
415+
}
416+
return false;
417+
}
418+
371419
}

0 commit comments

Comments
 (0)