44import multiprocessing
55import re
66import socket
7+ from multiprocessing .connection import Connection
78
89import anyio
910import httpx
2425from mcp .shared ._stream_protocols import WriteStream
2526from mcp .shared .message import SessionMessage
2627from mcp .types import JSONRPCRequest , JSONRPCResponse , Tool
27- from tests .test_helpers import wait_for_server
2828
2929logger = logging .getLogger (__name__ )
3030SERVER_NAME = "test_sse_security_server"
@@ -39,18 +39,6 @@ def reset_sse_starlette_exit_event() -> None:
3939 app_status .should_exit_event = None
4040
4141
42- @pytest .fixture
43- def server_port () -> int :
44- with socket .socket () as s :
45- s .bind (("127.0.0.1" , 0 ))
46- return s .getsockname ()[1 ]
47-
48-
49- @pytest .fixture
50- def server_url (server_port : int ) -> str : # pragma: no cover
51- return f"http://127.0.0.1:{ server_port } "
52-
53-
5442class SecurityTestServer (Server ): # pragma: no cover
5543 def __init__ (self ):
5644 super ().__init__ (SERVER_NAME )
@@ -59,7 +47,9 @@ async def on_list_tools(self) -> list[Tool]:
5947 return []
6048
6149
62- def run_server_with_settings (port : int , security_settings : TransportSecuritySettings | None = None ): # pragma: no cover
50+ def run_server_with_settings (
51+ port_writer : Connection , security_settings : TransportSecuritySettings | None = None
52+ ): # pragma: no cover
6353 """Run the SSE server with specified security settings."""
6454 app = SecurityTestServer ()
6555 sse_transport = SseServerTransport ("/messages/" , security_settings )
@@ -80,47 +70,65 @@ async def handle_sse(request: Request):
8070 ]
8171
8272 starlette_app = Starlette (routes = routes )
83- uvicorn .run (starlette_app , host = "127.0.0.1" , port = port , log_level = "error" )
73+ sock = socket .socket (socket .AF_INET , socket .SOCK_STREAM )
74+ sock .setsockopt (socket .SOL_SOCKET , socket .SO_REUSEADDR , 1 )
75+ sock .bind (("127.0.0.1" , 0 ))
76+ sock .listen ()
77+ port = sock .getsockname ()[1 ]
78+ port_writer .send (port )
79+ port_writer .close ()
8480
81+ server = uvicorn .Server (config = uvicorn .Config (app = starlette_app , log_level = "error" ))
82+ server .run (sockets = [sock ])
8583
86- def start_server_process (port : int , security_settings : TransportSecuritySettings | None = None ):
84+
85+ def start_server_process (
86+ security_settings : TransportSecuritySettings | None = None ,
87+ ) -> tuple [multiprocessing .Process , int ]:
8788 """Start server in a separate process."""
88- process = multiprocessing .Process (target = run_server_with_settings , args = (port , security_settings ))
89+ reader , writer = multiprocessing .Pipe (duplex = False )
90+ process = multiprocessing .Process (
91+ target = run_server_with_settings ,
92+ kwargs = {"port_writer" : writer , "security_settings" : security_settings },
93+ )
8994 process .start ()
90- # Wait for server to be ready to accept connections
91- wait_for_server (port )
92- return process
95+ writer .close ()
96+ try :
97+ port = reader .recv ()
98+ finally :
99+ reader .close ()
100+ return process , port
93101
94102
95103@pytest .mark .anyio
96- async def test_sse_security_default_settings (server_port : int ):
104+ async def test_sse_security_default_settings ():
97105 """Test SSE with default security settings (protection disabled)."""
98- process = start_server_process (server_port )
106+ process , port = start_server_process ()
99107
100108 try :
101109 headers = {"Host" : "evil.com" , "Origin" : "http://evil.com" }
102110
103111 async with httpx .AsyncClient (timeout = 5.0 ) as client :
104- async with client .stream ("GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers ) as response :
112+ async with client .stream ("GET" , f"http://127.0.0.1:{ port } /sse" , headers = headers ) as response :
105113 assert response .status_code == 200
106114 finally :
107115 process .terminate ()
108116 process .join ()
109117
110118
111119@pytest .mark .anyio
112- async def test_sse_security_invalid_host_header (server_port : int ):
120+ async def test_sse_security_invalid_host_header ():
113121 """Test SSE with invalid Host header."""
114122 # Enable security by providing settings with an empty allowed_hosts list
115123 security_settings = TransportSecuritySettings (enable_dns_rebinding_protection = True , allowed_hosts = ["example.com" ])
116- process = start_server_process (server_port , security_settings )
124+ process , port = start_server_process (security_settings )
117125
118126 try :
119127 # Test with invalid host header
120128 headers = {"Host" : "evil.com" }
121129
122130 async with httpx .AsyncClient () as client :
123- response = await client .get (f"http://127.0.0.1:{ server_port } /sse" , headers = headers )
131+ response = await client .get (f"http://127.0.0.1:{ port } /sse" , headers = headers )
124132 assert response .status_code == 421
125133 assert response .text == "Invalid Host header"
126134
@@ -130,20 +138,20 @@ async def test_sse_security_invalid_host_header(server_port: int):
130138
131139
132140@pytest .mark .anyio
133- async def test_sse_security_invalid_origin_header (server_port : int ):
141+ async def test_sse_security_invalid_origin_header ():
134142 """Test SSE with invalid Origin header."""
135143 # Configure security to allow the host but restrict origins
136144 security_settings = TransportSecuritySettings (
137145 enable_dns_rebinding_protection = True , allowed_hosts = ["127.0.0.1:*" ], allowed_origins = ["http://localhost:*" ]
138146 )
139- process = start_server_process (server_port , security_settings )
147+ process , port = start_server_process (security_settings )
140148
141149 try :
142150 # Test with invalid origin header
143151 headers = {"Origin" : "http://evil.com" }
144152
145153 async with httpx .AsyncClient () as client :
146- response = await client .get (f"http://127.0.0.1:{ server_port } /sse" , headers = headers )
154+ response = await client .get (f"http://127.0.0.1:{ port } /sse" , headers = headers )
147155 assert response .status_code == 403
148156 assert response .text == "Invalid Origin header"
149157
@@ -153,20 +161,20 @@ async def test_sse_security_invalid_origin_header(server_port: int):
153161
154162
155163@pytest .mark .anyio
156- async def test_sse_security_post_invalid_content_type (server_port : int ):
164+ async def test_sse_security_post_invalid_content_type ():
157165 """Test POST endpoint with invalid Content-Type header."""
158166 # Configure security to allow the host
159167 security_settings = TransportSecuritySettings (
160168 enable_dns_rebinding_protection = True , allowed_hosts = ["127.0.0.1:*" ], allowed_origins = ["http://127.0.0.1:*" ]
161169 )
162- process = start_server_process (server_port , security_settings )
170+ process , port = start_server_process (security_settings )
163171
164172 try :
165173 async with httpx .AsyncClient (timeout = 5.0 ) as client :
166174 # Test POST with invalid content type
167175 fake_session_id = "12345678123456781234567812345678"
168176 response = await client .post (
169- f"http://127.0.0.1:{ server_port } /messages/?session_id={ fake_session_id } " ,
177+ f"http://127.0.0.1:{ port } /messages/?session_id={ fake_session_id } " ,
170178 headers = {"Content-Type" : "text/plain" },
171179 content = "test" ,
172180 )
@@ -175,7 +183,7 @@ async def test_sse_security_post_invalid_content_type(server_port: int):
175183
176184 # Test POST with missing content type
177185 response = await client .post (
178- f"http://127.0.0.1:{ server_port } /messages/?session_id={ fake_session_id } " , content = "test"
186+ f"http://127.0.0.1:{ port } /messages/?session_id={ fake_session_id } " , content = "test"
179187 )
180188 assert response .status_code == 400
181189 assert response .text == "Invalid Content-Type header"
@@ -186,18 +194,18 @@ async def test_sse_security_post_invalid_content_type(server_port: int):
186194
187195
188196@pytest .mark .anyio
189- async def test_sse_security_disabled (server_port : int ):
197+ async def test_sse_security_disabled ():
190198 """Test SSE with security disabled."""
191199 settings = TransportSecuritySettings (enable_dns_rebinding_protection = False )
192- process = start_server_process (server_port , settings )
200+ process , port = start_server_process (settings )
193201
194202 try :
195203 # Test with invalid host header - should still work
196204 headers = {"Host" : "evil.com" }
197205
198206 async with httpx .AsyncClient (timeout = 5.0 ) as client :
199207 # For SSE endpoints, we need to use stream to avoid timeout
200- async with client .stream ("GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers ) as response :
208+ async with client .stream ("GET" , f"http://127.0.0.1:{ port } /sse" , headers = headers ) as response :
201209 # Should connect successfully even with invalid host
202210 assert response .status_code == 200
203211
@@ -207,30 +215,30 @@ async def test_sse_security_disabled(server_port: int):
207215
208216
209217@pytest .mark .anyio
210- async def test_sse_security_custom_allowed_hosts (server_port : int ):
218+ async def test_sse_security_custom_allowed_hosts ():
211219 """Test SSE with custom allowed hosts."""
212220 settings = TransportSecuritySettings (
213221 enable_dns_rebinding_protection = True ,
214222 allowed_hosts = ["localhost" , "127.0.0.1" , "custom.host" ],
215223 allowed_origins = ["http://localhost" , "http://127.0.0.1" , "http://custom.host" ],
216224 )
217- process = start_server_process (server_port , settings )
225+ process , port = start_server_process (settings )
218226
219227 try :
220228 # Test with custom allowed host
221229 headers = {"Host" : "custom.host" }
222230
223231 async with httpx .AsyncClient (timeout = 5.0 ) as client :
224232 # For SSE endpoints, we need to use stream to avoid timeout
225- async with client .stream ("GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers ) as response :
233+ async with client .stream ("GET" , f"http://127.0.0.1:{ port } /sse" , headers = headers ) as response :
226234 # Should connect successfully with custom host
227235 assert response .status_code == 200
228236
229237 # Test with non-allowed host
230238 headers = {"Host" : "evil.com" }
231239
232240 async with httpx .AsyncClient () as client :
233- response = await client .get (f"http://127.0.0.1:{ server_port } /sse" , headers = headers )
241+ response = await client .get (f"http://127.0.0.1:{ port } /sse" , headers = headers )
234242 assert response .status_code == 421
235243 assert response .text == "Invalid Host header"
236244
@@ -240,14 +248,14 @@ async def test_sse_security_custom_allowed_hosts(server_port: int):
240248
241249
242250@pytest .mark .anyio
243- async def test_sse_security_wildcard_ports (server_port : int ):
251+ async def test_sse_security_wildcard_ports ():
244252 """Test SSE with wildcard port patterns."""
245253 settings = TransportSecuritySettings (
246254 enable_dns_rebinding_protection = True ,
247255 allowed_hosts = ["localhost:*" , "127.0.0.1:*" ],
248256 allowed_origins = ["http://localhost:*" , "http://127.0.0.1:*" ],
249257 )
250- process = start_server_process (server_port , settings )
258+ process , port = start_server_process (settings )
251259
252260 try :
253261 # Test with various port numbers
@@ -256,15 +264,15 @@ async def test_sse_security_wildcard_ports(server_port: int):
256264
257265 async with httpx .AsyncClient (timeout = 5.0 ) as client :
258266 # For SSE endpoints, we need to use stream to avoid timeout
259- async with client .stream ("GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers ) as response :
267+ async with client .stream ("GET" , f"http://127.0.0.1:{ port } /sse" , headers = headers ) as response :
260268 # Should connect successfully with any port
261269 assert response .status_code == 200
262270
263271 headers = {"Origin" : f"http://localhost:{ test_port } " }
264272
265273 async with httpx .AsyncClient (timeout = 5.0 ) as client :
266274 # For SSE endpoints, we need to use stream to avoid timeout
267- async with client .stream ("GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers ) as response :
275+ async with client .stream ("GET" , f"http://127.0.0.1:{ port } /sse" , headers = headers ) as response :
268276 # Should connect successfully with any port
269277 assert response .status_code == 200
270278
@@ -274,13 +282,13 @@ async def test_sse_security_wildcard_ports(server_port: int):
274282
275283
276284@pytest .mark .anyio
277- async def test_sse_security_post_valid_content_type (server_port : int ):
285+ async def test_sse_security_post_valid_content_type ():
278286 """Test POST endpoint with valid Content-Type headers."""
279287 # Configure security to allow the host
280288 security_settings = TransportSecuritySettings (
281289 enable_dns_rebinding_protection = True , allowed_hosts = ["127.0.0.1:*" ], allowed_origins = ["http://127.0.0.1:*" ]
282290 )
283- process = start_server_process (server_port , security_settings )
291+ process , port = start_server_process (security_settings )
284292
285293 try :
286294 async with httpx .AsyncClient () as client :
@@ -296,7 +304,7 @@ async def test_sse_security_post_valid_content_type(server_port: int):
296304 # Use a valid UUID format (even though session won't exist)
297305 fake_session_id = "12345678123456781234567812345678"
298306 response = await client .post (
299- f"http://127.0.0.1:{ server_port } /messages/?session_id={ fake_session_id } " ,
307+ f"http://127.0.0.1:{ port } /messages/?session_id={ fake_session_id } " ,
300308 headers = {"Content-Type" : content_type },
301309 json = {"test" : "data" },
302310 )
0 commit comments