Skip to content

Commit f7301aa

Browse files
committed
tests: resolve TOCTOU port race conditions for streamable-HTTP/SSE tests
1 parent 3eb5799 commit f7301aa

6 files changed

Lines changed: 270 additions & 358 deletions

File tree

tests/client/test_http_unicode.py

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
(server→client and client→server) using the streamable HTTP transport.
55
"""
66

7-
import multiprocessing
87
import socket
98
from collections.abc import AsyncGenerator, Generator
109
from contextlib import asynccontextmanager
10+
from multiprocessing.connection import Connection
1111

1212
import pytest
1313
from starlette.applications import Starlette
@@ -19,7 +19,7 @@
1919
from mcp.server import Server, ServerRequestContext
2020
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
2121
from mcp.types import TextContent, Tool
22-
from tests.test_helpers import wait_for_server
22+
from tests.test_helpers import running_server
2323

2424
# Test constants with various Unicode characters
2525
UNICODE_TEST_STRINGS = {
@@ -41,7 +41,7 @@
4141
}
4242

4343

44-
def run_unicode_server(port: int) -> None: # pragma: no cover
44+
def run_unicode_server(port_writer: Connection) -> None: # pragma: no cover
4545
"""Run the Unicode test server in a separate process."""
4646
import uvicorn
4747

@@ -137,43 +137,28 @@ async def lifespan(app: Starlette) -> AsyncGenerator[None, None]:
137137
lifespan=lifespan,
138138
)
139139

140+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
141+
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
142+
sock.bind(("127.0.0.1", 0))
143+
sock.listen()
144+
port = sock.getsockname()[1]
145+
port_writer.send(port)
146+
port_writer.close()
147+
140148
# Run the server
141149
config = uvicorn.Config(
142150
app=app,
143-
host="127.0.0.1",
144-
port=port,
145151
log_level="error",
146152
)
147153
uvicorn_server = uvicorn.Server(config)
148-
uvicorn_server.run()
149-
150-
151-
@pytest.fixture
152-
def unicode_server_port() -> int:
153-
"""Find an available port for the Unicode test server."""
154-
with socket.socket() as s:
155-
s.bind(("127.0.0.1", 0))
156-
return s.getsockname()[1]
154+
uvicorn_server.run(sockets=[sock])
157155

158156

159157
@pytest.fixture
160-
def running_unicode_server(unicode_server_port: int) -> Generator[str, None, None]:
158+
def running_unicode_server() -> Generator[str, None, None]:
161159
"""Start a Unicode test server in a separate process."""
162-
proc = multiprocessing.Process(target=run_unicode_server, kwargs={"port": unicode_server_port}, daemon=True)
163-
proc.start()
164-
165-
# Wait for server to be ready
166-
wait_for_server(unicode_server_port)
167-
168-
try:
169-
yield f"http://127.0.0.1:{unicode_server_port}"
170-
finally:
171-
# Clean up - try graceful termination first
172-
proc.terminate()
173-
proc.join(timeout=2)
174-
if proc.is_alive(): # pragma: no cover
175-
proc.kill()
176-
proc.join(timeout=1)
160+
with running_server(run_unicode_server) as url:
161+
yield url
177162

178163

179164
@pytest.mark.anyio

tests/server/test_sse_security.py

Lines changed: 55 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
import multiprocessing
55
import socket
6+
from multiprocessing.connection import Connection
67

78
import httpx
89
import pytest
@@ -16,24 +17,11 @@
1617
from mcp.server.sse import SseServerTransport
1718
from mcp.server.transport_security import TransportSecuritySettings
1819
from mcp.types import Tool
19-
from tests.test_helpers import wait_for_server
2020

2121
logger = logging.getLogger(__name__)
2222
SERVER_NAME = "test_sse_security_server"
2323

2424

25-
@pytest.fixture
26-
def server_port() -> int:
27-
with socket.socket() as s:
28-
s.bind(("127.0.0.1", 0))
29-
return s.getsockname()[1]
30-
31-
32-
@pytest.fixture
33-
def server_url(server_port: int) -> str: # pragma: no cover
34-
return f"http://127.0.0.1:{server_port}"
35-
36-
3725
class SecurityTestServer(Server): # pragma: no cover
3826
def __init__(self):
3927
super().__init__(SERVER_NAME)
@@ -42,7 +30,9 @@ async def on_list_tools(self) -> list[Tool]:
4230
return []
4331

4432

45-
def run_server_with_settings(port: int, security_settings: TransportSecuritySettings | None = None): # pragma: no cover
33+
def run_server_with_settings(
34+
port_writer: Connection, security_settings: TransportSecuritySettings | None = None
35+
): # pragma: no cover
4636
"""Run the SSE server with specified security settings."""
4737
app = SecurityTestServer()
4838
sse_transport = SseServerTransport("/messages/", security_settings)
@@ -63,47 +53,65 @@ async def handle_sse(request: Request):
6353
]
6454

6555
starlette_app = Starlette(routes=routes)
66-
uvicorn.run(starlette_app, host="127.0.0.1", port=port, log_level="error")
56+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
57+
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
58+
sock.bind(("127.0.0.1", 0))
59+
sock.listen()
60+
port = sock.getsockname()[1]
61+
port_writer.send(port)
62+
port_writer.close()
6763

64+
server = uvicorn.Server(config=uvicorn.Config(app=starlette_app, log_level="error"))
65+
server.run(sockets=[sock])
6866

69-
def start_server_process(port: int, security_settings: TransportSecuritySettings | None = None):
67+
68+
def start_server_process(
69+
security_settings: TransportSecuritySettings | None = None,
70+
) -> tuple[multiprocessing.Process, int]:
7071
"""Start server in a separate process."""
71-
process = multiprocessing.Process(target=run_server_with_settings, args=(port, security_settings))
72+
reader, writer = multiprocessing.Pipe(duplex=False)
73+
process = multiprocessing.Process(
74+
target=run_server_with_settings,
75+
kwargs={"port_writer": writer, "security_settings": security_settings},
76+
)
7277
process.start()
73-
# Wait for server to be ready to accept connections
74-
wait_for_server(port)
75-
return process
78+
writer.close()
79+
try:
80+
port = reader.recv()
81+
finally:
82+
reader.close()
83+
return process, port
7684

7785

7886
@pytest.mark.anyio
79-
async def test_sse_security_default_settings(server_port: int):
87+
async def test_sse_security_default_settings():
8088
"""Test SSE with default security settings (protection disabled)."""
81-
process = start_server_process(server_port)
89+
process, port = start_server_process()
8290

8391
try:
8492
headers = {"Host": "evil.com", "Origin": "http://evil.com"}
8593

8694
async with httpx.AsyncClient(timeout=5.0) as client:
87-
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
95+
async with client.stream("GET", f"http://127.0.0.1:{port}/sse", headers=headers) as response:
8896
assert response.status_code == 200
8997
finally:
9098
process.terminate()
9199
process.join()
92100

93101

94102
@pytest.mark.anyio
95-
async def test_sse_security_invalid_host_header(server_port: int):
103+
async def test_sse_security_invalid_host_header():
96104
"""Test SSE with invalid Host header."""
97105
# Enable security by providing settings with an empty allowed_hosts list
98106
security_settings = TransportSecuritySettings(enable_dns_rebinding_protection=True, allowed_hosts=["example.com"])
99-
process = start_server_process(server_port, security_settings)
107+
process, port = start_server_process(security_settings)
100108

101109
try:
102110
# Test with invalid host header
103111
headers = {"Host": "evil.com"}
104112

105113
async with httpx.AsyncClient() as client:
106-
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
114+
response = await client.get(f"http://127.0.0.1:{port}/sse", headers=headers)
107115
assert response.status_code == 421
108116
assert response.text == "Invalid Host header"
109117

@@ -113,20 +121,20 @@ async def test_sse_security_invalid_host_header(server_port: int):
113121

114122

115123
@pytest.mark.anyio
116-
async def test_sse_security_invalid_origin_header(server_port: int):
124+
async def test_sse_security_invalid_origin_header():
117125
"""Test SSE with invalid Origin header."""
118126
# Configure security to allow the host but restrict origins
119127
security_settings = TransportSecuritySettings(
120128
enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://localhost:*"]
121129
)
122-
process = start_server_process(server_port, security_settings)
130+
process, port = start_server_process(security_settings)
123131

124132
try:
125133
# Test with invalid origin header
126134
headers = {"Origin": "http://evil.com"}
127135

128136
async with httpx.AsyncClient() as client:
129-
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
137+
response = await client.get(f"http://127.0.0.1:{port}/sse", headers=headers)
130138
assert response.status_code == 403
131139
assert response.text == "Invalid Origin header"
132140

@@ -136,20 +144,20 @@ async def test_sse_security_invalid_origin_header(server_port: int):
136144

137145

138146
@pytest.mark.anyio
139-
async def test_sse_security_post_invalid_content_type(server_port: int):
147+
async def test_sse_security_post_invalid_content_type():
140148
"""Test POST endpoint with invalid Content-Type header."""
141149
# Configure security to allow the host
142150
security_settings = TransportSecuritySettings(
143151
enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"]
144152
)
145-
process = start_server_process(server_port, security_settings)
153+
process, port = start_server_process(security_settings)
146154

147155
try:
148156
async with httpx.AsyncClient(timeout=5.0) as client:
149157
# Test POST with invalid content type
150158
fake_session_id = "12345678123456781234567812345678"
151159
response = await client.post(
152-
f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}",
160+
f"http://127.0.0.1:{port}/messages/?session_id={fake_session_id}",
153161
headers={"Content-Type": "text/plain"},
154162
content="test",
155163
)
@@ -158,7 +166,7 @@ async def test_sse_security_post_invalid_content_type(server_port: int):
158166

159167
# Test POST with missing content type
160168
response = await client.post(
161-
f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}", content="test"
169+
f"http://127.0.0.1:{port}/messages/?session_id={fake_session_id}", content="test"
162170
)
163171
assert response.status_code == 400
164172
assert response.text == "Invalid Content-Type header"
@@ -169,18 +177,18 @@ async def test_sse_security_post_invalid_content_type(server_port: int):
169177

170178

171179
@pytest.mark.anyio
172-
async def test_sse_security_disabled(server_port: int):
180+
async def test_sse_security_disabled():
173181
"""Test SSE with security disabled."""
174182
settings = TransportSecuritySettings(enable_dns_rebinding_protection=False)
175-
process = start_server_process(server_port, settings)
183+
process, port = start_server_process(settings)
176184

177185
try:
178186
# Test with invalid host header - should still work
179187
headers = {"Host": "evil.com"}
180188

181189
async with httpx.AsyncClient(timeout=5.0) as client:
182190
# For SSE endpoints, we need to use stream to avoid timeout
183-
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
191+
async with client.stream("GET", f"http://127.0.0.1:{port}/sse", headers=headers) as response:
184192
# Should connect successfully even with invalid host
185193
assert response.status_code == 200
186194

@@ -190,30 +198,30 @@ async def test_sse_security_disabled(server_port: int):
190198

191199

192200
@pytest.mark.anyio
193-
async def test_sse_security_custom_allowed_hosts(server_port: int):
201+
async def test_sse_security_custom_allowed_hosts():
194202
"""Test SSE with custom allowed hosts."""
195203
settings = TransportSecuritySettings(
196204
enable_dns_rebinding_protection=True,
197205
allowed_hosts=["localhost", "127.0.0.1", "custom.host"],
198206
allowed_origins=["http://localhost", "http://127.0.0.1", "http://custom.host"],
199207
)
200-
process = start_server_process(server_port, settings)
208+
process, port = start_server_process(settings)
201209

202210
try:
203211
# Test with custom allowed host
204212
headers = {"Host": "custom.host"}
205213

206214
async with httpx.AsyncClient(timeout=5.0) as client:
207215
# For SSE endpoints, we need to use stream to avoid timeout
208-
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
216+
async with client.stream("GET", f"http://127.0.0.1:{port}/sse", headers=headers) as response:
209217
# Should connect successfully with custom host
210218
assert response.status_code == 200
211219

212220
# Test with non-allowed host
213221
headers = {"Host": "evil.com"}
214222

215223
async with httpx.AsyncClient() as client:
216-
response = await client.get(f"http://127.0.0.1:{server_port}/sse", headers=headers)
224+
response = await client.get(f"http://127.0.0.1:{port}/sse", headers=headers)
217225
assert response.status_code == 421
218226
assert response.text == "Invalid Host header"
219227

@@ -223,14 +231,14 @@ async def test_sse_security_custom_allowed_hosts(server_port: int):
223231

224232

225233
@pytest.mark.anyio
226-
async def test_sse_security_wildcard_ports(server_port: int):
234+
async def test_sse_security_wildcard_ports():
227235
"""Test SSE with wildcard port patterns."""
228236
settings = TransportSecuritySettings(
229237
enable_dns_rebinding_protection=True,
230238
allowed_hosts=["localhost:*", "127.0.0.1:*"],
231239
allowed_origins=["http://localhost:*", "http://127.0.0.1:*"],
232240
)
233-
process = start_server_process(server_port, settings)
241+
process, port = start_server_process(settings)
234242

235243
try:
236244
# Test with various port numbers
@@ -239,15 +247,15 @@ async def test_sse_security_wildcard_ports(server_port: int):
239247

240248
async with httpx.AsyncClient(timeout=5.0) as client:
241249
# For SSE endpoints, we need to use stream to avoid timeout
242-
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
250+
async with client.stream("GET", f"http://127.0.0.1:{port}/sse", headers=headers) as response:
243251
# Should connect successfully with any port
244252
assert response.status_code == 200
245253

246254
headers = {"Origin": f"http://localhost:{test_port}"}
247255

248256
async with httpx.AsyncClient(timeout=5.0) as client:
249257
# For SSE endpoints, we need to use stream to avoid timeout
250-
async with client.stream("GET", f"http://127.0.0.1:{server_port}/sse", headers=headers) as response:
258+
async with client.stream("GET", f"http://127.0.0.1:{port}/sse", headers=headers) as response:
251259
# Should connect successfully with any port
252260
assert response.status_code == 200
253261

@@ -257,13 +265,13 @@ async def test_sse_security_wildcard_ports(server_port: int):
257265

258266

259267
@pytest.mark.anyio
260-
async def test_sse_security_post_valid_content_type(server_port: int):
268+
async def test_sse_security_post_valid_content_type():
261269
"""Test POST endpoint with valid Content-Type headers."""
262270
# Configure security to allow the host
263271
security_settings = TransportSecuritySettings(
264272
enable_dns_rebinding_protection=True, allowed_hosts=["127.0.0.1:*"], allowed_origins=["http://127.0.0.1:*"]
265273
)
266-
process = start_server_process(server_port, security_settings)
274+
process, port = start_server_process(security_settings)
267275

268276
try:
269277
async with httpx.AsyncClient() as client:
@@ -279,7 +287,7 @@ async def test_sse_security_post_valid_content_type(server_port: int):
279287
# Use a valid UUID format (even though session won't exist)
280288
fake_session_id = "12345678123456781234567812345678"
281289
response = await client.post(
282-
f"http://127.0.0.1:{server_port}/messages/?session_id={fake_session_id}",
290+
f"http://127.0.0.1:{port}/messages/?session_id={fake_session_id}",
283291
headers={"Content-Type": content_type},
284292
json={"test": "data"},
285293
)

0 commit comments

Comments
 (0)