33import logging
44import multiprocessing
55import socket
6+ from multiprocessing .connection import Connection
67
78import httpx
89import pytest
1617from mcp .server .sse import SseServerTransport
1718from mcp .server .transport_security import TransportSecuritySettings
1819from mcp .types import Tool
19- from tests .test_helpers import wait_for_server
2020
2121logger = logging .getLogger (__name__ )
2222SERVER_NAME = "test_sse_security_server"
2323
2424
25- @pytest .fixture
26- def server_port () -> int :
27- with socket .socket () as s :
28- s .bind (("127.0.0.1" , 0 ))
29- return s .getsockname ()[1 ]
30-
31-
32- @pytest .fixture
33- def server_url (server_port : int ) -> str : # pragma: no cover
34- return f"http://127.0.0.1:{ server_port } "
35-
36-
3725class SecurityTestServer (Server ): # pragma: no cover
3826 def __init__ (self ):
3927 super ().__init__ (SERVER_NAME )
@@ -42,7 +30,9 @@ async def on_list_tools(self) -> list[Tool]:
4230 return []
4331
4432
45- def run_server_with_settings (port : int , security_settings : TransportSecuritySettings | None = None ): # pragma: no cover
33+ def run_server_with_settings (
34+ port_writer : Connection , security_settings : TransportSecuritySettings | None = None
35+ ): # pragma: no cover
4636 """Run the SSE server with specified security settings."""
4737 app = SecurityTestServer ()
4838 sse_transport = SseServerTransport ("/messages/" , security_settings )
@@ -63,47 +53,65 @@ async def handle_sse(request: Request):
6353 ]
6454
6555 starlette_app = Starlette (routes = routes )
66- uvicorn .run (starlette_app , host = "127.0.0.1" , port = port , log_level = "error" )
56+ sock = socket .socket (socket .AF_INET , socket .SOCK_STREAM )
57+ sock .setsockopt (socket .SOL_SOCKET , socket .SO_REUSEADDR , 1 )
58+ sock .bind (("127.0.0.1" , 0 ))
59+ sock .listen ()
60+ port = sock .getsockname ()[1 ]
61+ port_writer .send (port )
62+ port_writer .close ()
6763
64+ server = uvicorn .Server (config = uvicorn .Config (app = starlette_app , log_level = "error" ))
65+ server .run (sockets = [sock ])
6866
69- def start_server_process (port : int , security_settings : TransportSecuritySettings | None = None ):
67+
68+ def start_server_process (
69+ security_settings : TransportSecuritySettings | None = None ,
70+ ) -> tuple [multiprocessing .Process , int ]:
7071 """Start server in a separate process."""
71- process = multiprocessing .Process (target = run_server_with_settings , args = (port , security_settings ))
72+ reader , writer = multiprocessing .Pipe (duplex = False )
73+ process = multiprocessing .Process (
74+ target = run_server_with_settings ,
75+ kwargs = {"port_writer" : writer , "security_settings" : security_settings },
76+ )
7277 process .start ()
73- # Wait for server to be ready to accept connections
74- wait_for_server (port )
75- return process
78+ writer .close ()
79+ try :
80+ port = reader .recv ()
81+ finally :
82+ reader .close ()
83+ return process , port
7684
7785
7886@pytest .mark .anyio
79- async def test_sse_security_default_settings (server_port : int ):
87+ async def test_sse_security_default_settings ():
8088 """Test SSE with default security settings (protection disabled)."""
81- process = start_server_process (server_port )
89+ process , port = start_server_process ()
8290
8391 try :
8492 headers = {"Host" : "evil.com" , "Origin" : "http://evil.com" }
8593
8694 async with httpx .AsyncClient (timeout = 5.0 ) as client :
87- async with client .stream ("GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers ) as response :
95+ async with client .stream ("GET" , f"http://127.0.0.1:{ port } /sse" , headers = headers ) as response :
8896 assert response .status_code == 200
8997 finally :
9098 process .terminate ()
9199 process .join ()
92100
93101
94102@pytest .mark .anyio
95- async def test_sse_security_invalid_host_header (server_port : int ):
103+ async def test_sse_security_invalid_host_header ():
96104 """Test SSE with invalid Host header."""
97105 # Enable security by providing settings with an empty allowed_hosts list
98106 security_settings = TransportSecuritySettings (enable_dns_rebinding_protection = True , allowed_hosts = ["example.com" ])
99- process = start_server_process (server_port , security_settings )
107+ process , port = start_server_process (security_settings )
100108
101109 try :
102110 # Test with invalid host header
103111 headers = {"Host" : "evil.com" }
104112
105113 async with httpx .AsyncClient () as client :
106- response = await client .get (f"http://127.0.0.1:{ server_port } /sse" , headers = headers )
114+ response = await client .get (f"http://127.0.0.1:{ port } /sse" , headers = headers )
107115 assert response .status_code == 421
108116 assert response .text == "Invalid Host header"
109117
@@ -113,20 +121,20 @@ async def test_sse_security_invalid_host_header(server_port: int):
113121
114122
115123@pytest .mark .anyio
116- async def test_sse_security_invalid_origin_header (server_port : int ):
124+ async def test_sse_security_invalid_origin_header ():
117125 """Test SSE with invalid Origin header."""
118126 # Configure security to allow the host but restrict origins
119127 security_settings = TransportSecuritySettings (
120128 enable_dns_rebinding_protection = True , allowed_hosts = ["127.0.0.1:*" ], allowed_origins = ["http://localhost:*" ]
121129 )
122- process = start_server_process (server_port , security_settings )
130+ process , port = start_server_process (security_settings )
123131
124132 try :
125133 # Test with invalid origin header
126134 headers = {"Origin" : "http://evil.com" }
127135
128136 async with httpx .AsyncClient () as client :
129- response = await client .get (f"http://127.0.0.1:{ server_port } /sse" , headers = headers )
137+ response = await client .get (f"http://127.0.0.1:{ port } /sse" , headers = headers )
130138 assert response .status_code == 403
131139 assert response .text == "Invalid Origin header"
132140
@@ -136,20 +144,20 @@ async def test_sse_security_invalid_origin_header(server_port: int):
136144
137145
138146@pytest .mark .anyio
139- async def test_sse_security_post_invalid_content_type (server_port : int ):
147+ async def test_sse_security_post_invalid_content_type ():
140148 """Test POST endpoint with invalid Content-Type header."""
141149 # Configure security to allow the host
142150 security_settings = TransportSecuritySettings (
143151 enable_dns_rebinding_protection = True , allowed_hosts = ["127.0.0.1:*" ], allowed_origins = ["http://127.0.0.1:*" ]
144152 )
145- process = start_server_process (server_port , security_settings )
153+ process , port = start_server_process (security_settings )
146154
147155 try :
148156 async with httpx .AsyncClient (timeout = 5.0 ) as client :
149157 # Test POST with invalid content type
150158 fake_session_id = "12345678123456781234567812345678"
151159 response = await client .post (
152- f"http://127.0.0.1:{ server_port } /messages/?session_id={ fake_session_id } " ,
160+ f"http://127.0.0.1:{ port } /messages/?session_id={ fake_session_id } " ,
153161 headers = {"Content-Type" : "text/plain" },
154162 content = "test" ,
155163 )
@@ -158,7 +166,7 @@ async def test_sse_security_post_invalid_content_type(server_port: int):
158166
159167 # Test POST with missing content type
160168 response = await client .post (
161- f"http://127.0.0.1:{ server_port } /messages/?session_id={ fake_session_id } " , content = "test"
169+ f"http://127.0.0.1:{ port } /messages/?session_id={ fake_session_id } " , content = "test"
162170 )
163171 assert response .status_code == 400
164172 assert response .text == "Invalid Content-Type header"
@@ -169,18 +177,18 @@ async def test_sse_security_post_invalid_content_type(server_port: int):
169177
170178
171179@pytest .mark .anyio
172- async def test_sse_security_disabled (server_port : int ):
180+ async def test_sse_security_disabled ():
173181 """Test SSE with security disabled."""
174182 settings = TransportSecuritySettings (enable_dns_rebinding_protection = False )
175- process = start_server_process (server_port , settings )
183+ process , port = start_server_process (settings )
176184
177185 try :
178186 # Test with invalid host header - should still work
179187 headers = {"Host" : "evil.com" }
180188
181189 async with httpx .AsyncClient (timeout = 5.0 ) as client :
182190 # For SSE endpoints, we need to use stream to avoid timeout
183- async with client .stream ("GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers ) as response :
191+ async with client .stream ("GET" , f"http://127.0.0.1:{ port } /sse" , headers = headers ) as response :
184192 # Should connect successfully even with invalid host
185193 assert response .status_code == 200
186194
@@ -190,30 +198,30 @@ async def test_sse_security_disabled(server_port: int):
190198
191199
192200@pytest .mark .anyio
193- async def test_sse_security_custom_allowed_hosts (server_port : int ):
201+ async def test_sse_security_custom_allowed_hosts ():
194202 """Test SSE with custom allowed hosts."""
195203 settings = TransportSecuritySettings (
196204 enable_dns_rebinding_protection = True ,
197205 allowed_hosts = ["localhost" , "127.0.0.1" , "custom.host" ],
198206 allowed_origins = ["http://localhost" , "http://127.0.0.1" , "http://custom.host" ],
199207 )
200- process = start_server_process (server_port , settings )
208+ process , port = start_server_process (settings )
201209
202210 try :
203211 # Test with custom allowed host
204212 headers = {"Host" : "custom.host" }
205213
206214 async with httpx .AsyncClient (timeout = 5.0 ) as client :
207215 # For SSE endpoints, we need to use stream to avoid timeout
208- async with client .stream ("GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers ) as response :
216+ async with client .stream ("GET" , f"http://127.0.0.1:{ port } /sse" , headers = headers ) as response :
209217 # Should connect successfully with custom host
210218 assert response .status_code == 200
211219
212220 # Test with non-allowed host
213221 headers = {"Host" : "evil.com" }
214222
215223 async with httpx .AsyncClient () as client :
216- response = await client .get (f"http://127.0.0.1:{ server_port } /sse" , headers = headers )
224+ response = await client .get (f"http://127.0.0.1:{ port } /sse" , headers = headers )
217225 assert response .status_code == 421
218226 assert response .text == "Invalid Host header"
219227
@@ -223,14 +231,14 @@ async def test_sse_security_custom_allowed_hosts(server_port: int):
223231
224232
225233@pytest .mark .anyio
226- async def test_sse_security_wildcard_ports (server_port : int ):
234+ async def test_sse_security_wildcard_ports ():
227235 """Test SSE with wildcard port patterns."""
228236 settings = TransportSecuritySettings (
229237 enable_dns_rebinding_protection = True ,
230238 allowed_hosts = ["localhost:*" , "127.0.0.1:*" ],
231239 allowed_origins = ["http://localhost:*" , "http://127.0.0.1:*" ],
232240 )
233- process = start_server_process (server_port , settings )
241+ process , port = start_server_process (settings )
234242
235243 try :
236244 # Test with various port numbers
@@ -239,15 +247,15 @@ async def test_sse_security_wildcard_ports(server_port: int):
239247
240248 async with httpx .AsyncClient (timeout = 5.0 ) as client :
241249 # For SSE endpoints, we need to use stream to avoid timeout
242- async with client .stream ("GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers ) as response :
250+ async with client .stream ("GET" , f"http://127.0.0.1:{ port } /sse" , headers = headers ) as response :
243251 # Should connect successfully with any port
244252 assert response .status_code == 200
245253
246254 headers = {"Origin" : f"http://localhost:{ test_port } " }
247255
248256 async with httpx .AsyncClient (timeout = 5.0 ) as client :
249257 # For SSE endpoints, we need to use stream to avoid timeout
250- async with client .stream ("GET" , f"http://127.0.0.1:{ server_port } /sse" , headers = headers ) as response :
258+ async with client .stream ("GET" , f"http://127.0.0.1:{ port } /sse" , headers = headers ) as response :
251259 # Should connect successfully with any port
252260 assert response .status_code == 200
253261
@@ -257,13 +265,13 @@ async def test_sse_security_wildcard_ports(server_port: int):
257265
258266
259267@pytest .mark .anyio
260- async def test_sse_security_post_valid_content_type (server_port : int ):
268+ async def test_sse_security_post_valid_content_type ():
261269 """Test POST endpoint with valid Content-Type headers."""
262270 # Configure security to allow the host
263271 security_settings = TransportSecuritySettings (
264272 enable_dns_rebinding_protection = True , allowed_hosts = ["127.0.0.1:*" ], allowed_origins = ["http://127.0.0.1:*" ]
265273 )
266- process = start_server_process (server_port , security_settings )
274+ process , port = start_server_process (security_settings )
267275
268276 try :
269277 async with httpx .AsyncClient () as client :
@@ -279,7 +287,7 @@ async def test_sse_security_post_valid_content_type(server_port: int):
279287 # Use a valid UUID format (even though session won't exist)
280288 fake_session_id = "12345678123456781234567812345678"
281289 response = await client .post (
282- f"http://127.0.0.1:{ server_port } /messages/?session_id={ fake_session_id } " ,
290+ f"http://127.0.0.1:{ port } /messages/?session_id={ fake_session_id } " ,
283291 headers = {"Content-Type" : content_type },
284292 json = {"test" : "data" },
285293 )
0 commit comments