Skip to content

Commit d764062

Browse files
committed
tests: resolve TOCTOU port race conditions for streamable-HTTP/SSE tests
1 parent 616476f commit d764062

6 files changed

Lines changed: 270 additions & 358 deletions

File tree

tests/client/test_http_unicode.py

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
(server→client and client→server) using the streamable HTTP transport.
55
"""
66

7-
import multiprocessing
87
import socket
98
from collections.abc import AsyncGenerator, Generator
109
from contextlib import asynccontextmanager
10+
from multiprocessing.connection import Connection
1111

1212
import pytest
1313
from starlette.applications import Starlette
@@ -19,7 +19,7 @@
1919
from mcp.server import Server, ServerRequestContext
2020
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
2121
from mcp.types import TextContent, Tool
22-
from tests.test_helpers import wait_for_server
22+
from tests.test_helpers import running_server
2323

2424
# Test constants with various Unicode characters
2525
UNICODE_TEST_STRINGS = {
@@ -41,7 +41,7 @@
4141
}
4242

4343

44-
def run_unicode_server(port: int) -> None: # pragma: no cover
44+
def run_unicode_server(port_writer: Connection) -> None: # pragma: no cover
4545
"""Run the Unicode test server in a separate process."""
4646
import uvicorn
4747

@@ -137,43 +137,28 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
137137
lifespan=lifespan,
138138
)
139139

140+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
141+
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
142+
sock.bind(("127.0.0.1", 0))
143+
sock.listen()
144+
port = sock.getsockname()[1]
145+
port_writer.send(port)
146+
port_writer.close()
147+
140148
# Run the server
141149
config = uvicorn.Config(
142150
app=app,
143-
host="127.0.0.1",
144-
port=port,
145151
log_level="error",
146152
)
147153
uvicorn_server = uvicorn.Server(config)
148-
uvicorn_server.run()
149-
150-
151-
@pytest.fixture
152-
def unicode_server_port() -> int:
153-
"""Find an available port for the Unicode test server."""
154-
with socket.socket() as s:
155-
s.bind(("127.0.0.1", 0))
156-
return s.getsockname()[1]
154+
uvicorn_server.run(sockets=[sock])
157155

158156

159157
@pytest.fixture
160-
def running_unicode_server(unicode_server_port: int) -> Generator[str, None, None]:
158+
def running_unicode_server() -> Generator[str, None, None]:
161159
"""Start a Unicode test server in a separate process."""
162-
proc = multiprocessing.Process(target=run_unicode_server, kwargs={"port": unicode_server_port}, daemon=True)
163-
proc.start()
164-
165-
# Wait for server to be ready
166-
wait_for_server(unicode_server_port)
167-
168-
try:
169-
yield f"http://127.0.0.1:{unicode_server_port}"
170-
finally:
171-
# Clean up - try graceful termination first
172-
proc.terminate()
173-
proc.join(timeout=2)
174-
if proc.is_alive(): # pragma: no cover
175-
proc.kill()
176-
proc.join(timeout=1)
160+
with running_server(run_unicode_server) as url:
161+
yield url
177162

178163

179164
@pytest.mark.anyio

tests/server/test_sse_security.py

Lines changed: 55 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import multiprocessing
55
import re
66
import socket
7+
from multiprocessing.connection import Connection
78

89
import anyio
910
import httpx
@@ -24,7 +25,6 @@
2425
from mcp.shared._stream_protocols import WriteStream
2526
from mcp.shared.message import SessionMessage
2627
from mcp.types import JSONRPCRequest, JSONRPCResponse, Tool
27-
from tests.test_helpers import wait_for_server
2828

2929
logger = logging.getLogger(__name__)
3030
SERVER_NAME = "test_sse_security_server"
@@ -39,18 +39,6 @@ def reset_sse_starlette_exit_event() -> None:
3939
app_status.should_exit_event = None
4040

4141

42-
@pytest.fixture
43-
def server_port() -> int:
44-
with socket.socket() as s:
45-
s.bind(("127.0.0.1", 0))
46-
return s.getsockname()[1]
47-
48-
49-
@pytest.fixture
50-
def server_url(server_port: int) -> str: # pragma: no cover
51-
return f"http://127.0.0.1:{server_port}"
52-
53-
5442
class SecurityTestServer(Server): # pragma: no cover
5543
def __init__(self):
5644
super().__init__(SERVER_NAME)
@@ -59,7 +47,9 @@ async def on_list_tools(self) -> list[Tool]:
5947
return []
6048

6149

62-
def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover
50+
def run_server_with_settings(
51+
port_writer: Connection, security_settings: TransportSecuritySettings | None = None
52+
): # pragma: no cover
6353
"""Run the SSE server with specified security settings."""
6454
app = SecurityTestServer()
6555
sse_transport = SseServerTransport("/messages/", security_settings)
@@ -80,47 +70,65 @@ async def handle_sse(request: Request):
8070
]
8171

8272
starlette_app = Starlette(routes=routes)
83-
uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error")
73+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
74+
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
75+
sock.bind(("127.0.0.1", 0))
76+
sock.listen()
77+
port = sock.getsockname()[1]
78+
port_writer.send(port)
79+
port_writer.close()
8480

81+
server = uvicorn.Server(config=uvicorn.Config(app=starlette_app, log_level="error"))
82+
server.run(sockets=[sock])
8583

86-
def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None):
84+
85+
def start_server_process(
86+
security_settings: TransportSecuritySettings | None = None,
87+
) -> tuple[multiprocessing.Process, int]:
8788
"""Start server in a separate process."""
88-
process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings))
89+
reader, writer = multiprocessing.Pipe(duplex=False)
90+
process = multiprocessing.Process(
91+
target=run_server_with_settings,
92+
kwargs={"port_writer": writer, "security_settings": security_settings},
93+
)
8994
process.start()
90-
# Wait for server to be ready to accept connections
91-
wait_for_server(port)
92-
return process
95+
writer.close()
96+
try:
97+
port = reader.recv()
98+
finally:
99+
reader.close()
100+
return process, port
93101

94102

95103
@pytest.mark.anyio
96-
async def test_sse_security_default_settings(server_port: int):
104+
async def test_sse_security_default_settings():
97105
"""Test SSE with default security settings (protection disabled)."""
98-
process = start_server_process(server_port)
106+
process, port = start_server_process()
99107

100108
try:
101109
headers = {"Host": "evil.com", "Origin": "http://evil.com"}
102110

103111
async with httpx.AsyncClient(timeout=5.0) as client:
104-
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
112+
async with client.stream("GET", f"http://127.0.0.1:{port}/sse", headers=headers) as response:
105113
assert response.status_code == 200
106114
finally:
107115
process.terminate()
108116
process.join()
109117

110118

111119
@pytest.mark.anyio
112-
async def test_sse_security_invalid_host_header(server_port: int):
120+
async def test_sse_security_invalid_host_header():
113121
"""Test SSE with invalid Host header."""
114122
# Enable security by providing settings with an empty allowed_hosts list
115123
security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"])
116-
process = start_server_process(server_port, security_settings)
124+
process, port = start_server_process(security_settings)
117125

118126
try:
119127
# Test with invalid host header
120128
headers = {"Host": "evil.com"}
121129

122130
async with httpx.AsyncClient() as client:
123-
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
131+
response = await client.get(f"http://127.0.0.1:{port}/sse", headers=headers)
124132
assert response.status_code == 421
125133
assert response.text == "Invalid Host header"
126134

@@ -130,20 +138,20 @@ async def test_sse_security_invalid_host_header(server_port: int):
130138

131139

132140
@pytest.mark.anyio
133-
async def test_sse_security_invalid_origin_header(server_port: int):
141+
async def test_sse_security_invalid_origin_header():
134142
"""Test SSE with invalid Origin header."""
135143
# Configure security to allow the host but restrict origins
136144
security_settings = TransportSecuritySettings(
137145
enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://localhost:*"]
138146
)
139-
process = start_server_process(server_port, security_settings)
147+
process, port = start_server_process(security_settings)
140148

141149
try:
142150
# Test with invalid origin header
143151
headers = {"Origin": "http://evil.com"}
144152

145153
async with httpx.AsyncClient() as client:
146-
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
154+
response = await client.get(f"http://127.0.0.1:{port}/sse", headers=headers)
147155
assert response.status_code == 403
148156
assert response.text == "Invalid Origin header"
149157

@@ -153,20 +161,20 @@ async def test_sse_security_invalid_origin_header(server_port: int):
153161

154162

155163
@pytest.mark.anyio
156-
async def test_sse_security_post_invalid_content_type(server_port: int):
164+
async def test_sse_security_post_invalid_content_type():
157165
"""Test POST endpoint with invalid Content-Type header."""
158166
# Configure security to allow the host
159167
security_settings = TransportSecuritySettings(
160168
enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"]
161169
)
162-
process = start_server_process(server_port, security_settings)
170+
process, port = start_server_process(security_settings)
163171

164172
try:
165173
async with httpx.AsyncClient(timeout=5.0) as client:
166174
# Test POST with invalid content type
167175
fake_session_id = "12345678123456781234567812345678"
168176
response = await client.post(
169-
f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}",
177+
f"http://127.0.0.1:{port}/messages/?session_id={fake_session_id}",
170178
headers={"Content-Type": "text/plain"},
171179
content="test",
172180
)
@@ -175,7 +183,7 @@ async def test_sse_security_post_invalid_content_type(server_port: int):
175183

176184
# Test POST with missing content type
177185
response = await client.post(
178-
f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", content="test"
186+
f"http://127.0.0.1:{port}/messages/?session_id={fake_session_id}", content="test"
179187
)
180188
assert response.status_code == 400
181189
assert response.text == "Invalid Content-Type header"
@@ -186,18 +194,18 @@ async def test_sse_security_post_invalid_content_type(server_port: int):
186194

187195

188196
@pytest.mark.anyio
189-
async def test_sse_security_disabled(server_port: int):
197+
async def test_sse_security_disabled():
190198
"""Test SSE with security disabled."""
191199
settings = TransportSecuritySettings(enable_dns_rebinding_protection=False)
192-
process = start_server_process(server_port, settings)
200+
process, port = start_server_process(settings)
193201

194202
try:
195203
# Test with invalid host header - should still work
196204
headers = {"Host": "evil.com"}
197205

198206
async with httpx.AsyncClient(timeout=5.0) as client:
199207
# For SSE endpoints, we need to use stream to avoid timeout
200-
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
208+
async with client.stream("GET", f"http://127.0.0.1:{port}/sse", headers=headers) as response:
201209
# Should connect successfully even with invalid host
202210
assert response.status_code == 200
203211

@@ -207,30 +215,30 @@ async def test_sse_security_disabled(server_port: int):
207215

208216

209217
@pytest.mark.anyio
210-
async def test_sse_security_custom_allowed_hosts(server_port: int):
218+
async def test_sse_security_custom_allowed_hosts():
211219
"""Test SSE with custom allowed hosts."""
212220
settings = TransportSecuritySettings(
213221
enable_dns_rebinding_protection=True,
214222
allowed_hosts=["localhost", "127.0.0.1", "custom.host"],
215223
allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"],
216224
)
217-
process = start_server_process(server_port, settings)
225+
process, port = start_server_process(settings)
218226

219227
try:
220228
# Test with custom allowed host
221229
headers = {"Host": "custom.host"}
222230

223231
async with httpx.AsyncClient(timeout=5.0) as client:
224232
# For SSE endpoints, we need to use stream to avoid timeout
225-
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
233+
async with client.stream("GET", f"http://127.0.0.1:{port}/sse", headers=headers) as response:
226234
# Should connect successfully with custom host
227235
assert response.status_code == 200
228236

229237
# Test with non-allowed host
230238
headers = {"Host": "evil.com"}
231239

232240
async with httpx.AsyncClient() as client:
233-
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
241+
response = await client.get(f"http://127.0.0.1:{port}/sse", headers=headers)
234242
assert response.status_code == 421
235243
assert response.text == "Invalid Host header"
236244

@@ -240,14 +248,14 @@ async def test_sse_security_custom_allowed_hosts(server_port: int):
240248

241249

242250
@pytest.mark.anyio
243-
async def test_sse_security_wildcard_ports(server_port: int):
251+
async def test_sse_security_wildcard_ports():
244252
"""Test SSE with wildcard port patterns."""
245253
settings = TransportSecuritySettings(
246254
enable_dns_rebinding_protection=True,
247255
allowed_hosts=["localhost:*", "127.0.0.1:*"],
248256
allowed_origins=["http://localhost:*", "http://127.0.0.1:*"],
249257
)
250-
process = start_server_process(server_port, settings)
258+
process, port = start_server_process(settings)
251259

252260
try:
253261
# Test with various port numbers
@@ -256,15 +264,15 @@ async def test_sse_security_wildcard_ports(server_port: int):
256264

257265
async with httpx.AsyncClient(timeout=5.0) as client:
258266
# For SSE endpoints, we need to use stream to avoid timeout
259-
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
267+
async with client.stream("GET", f"http://127.0.0.1:{port}/sse", headers=headers) as response:
260268
# Should connect successfully with any port
261269
assert response.status_code == 200
262270

263271
headers = {"Origin": f"http://localhost:{test_port}"}
264272

265273
async with httpx.AsyncClient(timeout=5.0) as client:
266274
# For SSE endpoints, we need to use stream to avoid timeout
267-
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
275+
async with client.stream("GET", f"http://127.0.0.1:{port}/sse", headers=headers) as response:
268276
# Should connect successfully with any port
269277
assert response.status_code == 200
270278

@@ -274,13 +282,13 @@ async def test_sse_security_wildcard_ports(server_port: int):
274282

275283

276284
@pytest.mark.anyio
277-
async def test_sse_security_post_valid_content_type(server_port: int):
285+
async def test_sse_security_post_valid_content_type():
278286
"""Test POST endpoint with valid Content-Type headers."""
279287
# Configure security to allow the host
280288
security_settings = TransportSecuritySettings(
281289
enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"]
282290
)
283-
process = start_server_process(server_port, security_settings)
291+
process, port = start_server_process(security_settings)
284292

285293
try:
286294
async with httpx.AsyncClient() as client:
@@ -296,7 +304,7 @@ async def test_sse_security_post_valid_content_type(server_port: int):
296304
# Use a valid UUID format (even though session won't exist)
297305
fake_session_id = "12345678123456781234567812345678"
298306
response = await client.post(
299-
f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}",
307+
f"http://127.0.0.1:{port}/messages/?session_id={fake_session_id}",
300308
headers={"Content-Type": content_type},
301309
json={"test": "data"},
302310
)

0 commit comments

Comments
 (0)