Skip to content

Commit 623b3a3

Browse files
committed
test: cover EOF drain branches
1 parent 1380ede commit 623b3a3

3 files changed

Lines changed: 98 additions & 2 deletions

File tree

src/mcp/shared/jsonrpc_dispatcher.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,9 @@ async def run(
445445
# On normal EOF, let already-received handlers drain
446446
# their responses before the task group exits.
447447
tg.cancel_scope.cancel()
448-
elif self._read_eof_drain_timeout_seconds is not None:
448+
elif self._read_eof_drain_timeout_seconds is None:
449+
pass
450+
else:
449451
tg.cancel_scope.deadline = anyio.current_time() + self._read_eof_drain_timeout_seconds
450452
finally:
451453
# Covers the cancel/crash paths where the inline fan-out above is

tests/shared/test_jsonrpc_dispatcher.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,62 @@ async def drive() -> None:
249249
s2c_recv.close()
250250

251251

252+
@pytest.mark.anyio
253+
async def test_run_closes_write_stream_after_clean_eof_without_drain_timeout():
254+
c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32)
255+
s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32)
256+
server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(
257+
c2s_recv,
258+
s2c_send,
259+
close_write_stream_on_read_close=False,
260+
read_eof_drain_timeout_seconds=None,
261+
)
262+
on_request, on_notify = echo_handlers(Recorder())
263+
264+
with anyio.fail_after(5):
265+
async with anyio.create_task_group() as tg, c2s_send, c2s_recv, s2c_send, s2c_recv:
266+
await tg.start(server.run, on_request, on_notify)
267+
c2s_send.close()
268+
with pytest.raises(anyio.EndOfStream):
269+
await s2c_recv.receive()
270+
271+
272+
@pytest.mark.anyio
273+
async def test_run_drains_in_flight_handlers_on_clean_eof_without_timeout():
274+
c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32)
275+
s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32)
276+
server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(
277+
c2s_recv,
278+
s2c_send,
279+
close_write_stream_on_read_close=False,
280+
read_eof_drain_timeout_seconds=None,
281+
)
282+
handler_started = anyio.Event()
283+
handler_allowed_to_finish = anyio.Event()
284+
285+
async def handle_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]:
286+
handler_started.set()
287+
await handler_allowed_to_finish.wait()
288+
return {"drained": True}
289+
290+
async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None:
291+
raise NotImplementedError
292+
293+
with anyio.fail_after(5):
294+
async with anyio.create_task_group() as tg, c2s_send, c2s_recv, s2c_send, s2c_recv:
295+
await tg.start(server.run, handle_request, on_notify)
296+
await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="x", params=None)))
297+
await handler_started.wait()
298+
c2s_send.close()
299+
handler_allowed_to_finish.set()
300+
301+
response = await s2c_recv.receive()
302+
assert isinstance(response, SessionMessage)
303+
assert isinstance(response.message, JSONRPCResponse)
304+
assert response.message.id == 1
305+
assert response.message.result == {"drained": True}
306+
307+
252308
@pytest.mark.anyio
253309
async def test_run_closes_write_stream_on_exit():
254310
"""run() enters both streams; the write end is released on EOF."""

tests/shared/test_session.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1+
from typing import Any
2+
13
import anyio
24
import pytest
5+
from pydantic import TypeAdapter
36

47
from mcp import Client, types
58
from mcp.client.session import ClientSession
69
from mcp.server import Server, ServerRequestContext
710
from mcp.shared.exceptions import MCPError
811
from mcp.shared.memory import create_client_server_memory_streams
912
from mcp.shared.message import SessionMessage
10-
from mcp.shared.session import RequestResponder
13+
from mcp.shared.session import BaseSession, RequestId, RequestResponder
1114
from mcp.types import (
1215
PARSE_ERROR,
1316
CancelledNotification,
@@ -16,6 +19,7 @@
1619
EmptyResult,
1720
ErrorData,
1821
JSONRPCError,
22+
JSONRPCNotification,
1923
JSONRPCRequest,
2024
JSONRPCResponse,
2125
ServerNotification,
@@ -291,6 +295,40 @@ async def mock_server():
291295
await ev_response.wait()
292296

293297

298+
@pytest.mark.anyio
299+
async def test_receive_loop_can_leave_write_stream_open_on_read_eof():
300+
class TestSession(BaseSession[Any, Any, Any, Any, Any]):
301+
async def _send_response(self, request_id: RequestId, response: Any | ErrorData) -> None:
302+
raise NotImplementedError # pragma: no cover
303+
304+
@property
305+
def _receive_request_adapter(self) -> TypeAdapter[Any]:
306+
return TypeAdapter(object) # pragma: no cover
307+
308+
@property
309+
def _receive_notification_adapter(self) -> TypeAdapter[Any]:
310+
return TypeAdapter(object) # pragma: no cover
311+
312+
read_send, read_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1)
313+
write_send, write_receive = anyio.create_memory_object_stream[SessionMessage](1)
314+
session = TestSession(read_receive, write_send, close_write_stream_on_read_close=False)
315+
receive_loop_returned = anyio.Event()
316+
317+
async def drive_receive_loop() -> None:
318+
await session._receive_loop() # pyright: ignore[reportPrivateUsage]
319+
receive_loop_returned.set()
320+
321+
with anyio.fail_after(5):
322+
async with anyio.create_task_group() as tg, read_send, read_receive, write_send, write_receive:
323+
tg.start_soon(drive_receive_loop)
324+
await read_send.aclose()
325+
await receive_loop_returned.wait()
326+
327+
marker = SessionMessage(message=JSONRPCNotification(jsonrpc="2.0", method="still-open"))
328+
await write_send.send(marker)
329+
assert await write_receive.receive() is marker
330+
331+
294332
@pytest.mark.anyio
295333
async def test_null_id_error_surfaced_via_message_handler():
296334
"""Test that a JSONRPCError with id=None is surfaced to the message handler.

0 commit comments

Comments
 (0)