From 35a3722fd7ebfecf1b7c5dec1fee55f35168bddc Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 14:38:14 +0000 Subject: [PATCH 01/24] Add courtesy-cancel controls and bound shielded dispatcher writes - New CallOptions key cancel_on_abandon (default true): abandoning a request (timeout or caller cancellation) sends notifications/cancelled unless the caller opted out or the request carries resumption hints - Bound the two shielded cancellation-path writes with a 5s deadline so a wedged transport write cannot hang shutdown or a cancelled caller - Capitalize the connection-closed fan-out message ("Connection closed") - Pin the server-seat timeout contract in the interaction suite: a timed-out server-initiated request is followed by notifications/cancelled --- src/mcp/shared/dispatcher.py | 14 +- src/mcp/shared/jsonrpc_dispatcher.py | 38 +++- tests/interaction/_requirements.py | 4 +- tests/interaction/lowlevel/test_timeouts.py | 68 ++++++- tests/shared/test_jsonrpc_dispatcher.py | 197 ++++++++++++++++++++ 5 files changed, 309 insertions(+), 12 deletions(-) diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index cffdfd22f..c6e421651 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -55,6 +55,16 @@ class CallOptions(TypedDict, total=False): timeout: float """Seconds to wait for a result before raising and sending `notifications/cancelled`.""" + cancel_on_abandon: bool + """Whether abandoning this request sends `notifications/cancelled` to the peer. + + A request is abandoned when its `timeout` elapses or the caller's scope is + cancelled while awaiting the response. Defaults to `True`. Set `False` for + requests the protocol forbids cancelling, such as `initialize`. The + notification is also suppressed when resumption hints are present: the + caller intends to resume the request, so the peer's work must keep running. + """ + on_progress: ProgressFnT """Receive `notifications/progress` updates for this request.""" @@ -97,8 +107,8 @@ async def send_raw_request( ) -> dict[str, Any]: """Send a request and await its raw result dict. - `opts` carries per-call `timeout` / `on_progress` / resumption hints; - see `CallOptions`. + `opts` carries per-call `timeout` / `on_progress` / abandon-cancellation + / resumption hints; see `CallOptions`. Raises: MCPError: If the peer responded with an error, or the handler diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 457e6b6f7..ded8bcb3d 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -63,6 +63,13 @@ logger = logging.getLogger(__name__) +_SHIELDED_WRITE_TIMEOUT: float = 5 +"""Bound for the shielded courtesy writes on the cancellation paths. + +Those writes run inside a shield because the surrounding scope is already +cancelled; without a bound, a wedged transport write would turn the shield +into an uncancellable hang (and block shutdown indefinitely).""" + TransportT = TypeVar("TransportT", bound=TransportContext) PeerCancelMode = Literal["interrupt", "signal"] @@ -323,6 +330,18 @@ async def send_raw_request( pending = _Pending(send=send, receive=receive, on_progress=on_progress) self._pending[request_id] = pending + # An abandoned request (timeout elapsed, or the caller's scope was + # cancelled while awaiting the response) sends a courtesy + # `notifications/cancelled` so the peer can stop work - unless the + # caller opted out (`initialize`, which the spec forbids cancelling), + # or the request carries resumption hints (the caller intends to + # resume it, so the peer's work must keep running). + cancel_on_abandon = ( + opts.get("cancel_on_abandon", True) + and opts.get("resumption_token") is None + and opts.get("on_resumption_token") is None + ) + metadata = _outbound_metadata(_related_request_id, opts) target = out_params.get("name") span_name = f"MCP send {method}{f' {target}' if isinstance(target, str) else ''}" @@ -348,14 +367,16 @@ async def send_raw_request( # Spec-recommended courtesy: tell the peer we've given up so it can # stop work and free resources. v1's BaseSession.send_request does # NOT do this; it's new behaviour. - await self._cancel_outbound(request_id, f"timed out after {opts.get('timeout')}s", _related_request_id) + if cancel_on_abandon: + await self._cancel_outbound(request_id, f"timed out after {opts.get('timeout')}s", _related_request_id) raise MCPError(code=REQUEST_TIMEOUT, message=f"Request {method!r} timed out") from None except anyio.get_cancelled_exc_class(): # Our caller's scope was cancelled. We're already inside a cancelled - # scope, so any bare `await` here re-raises immediately - shield to - # let the courtesy cancel notification go out before we propagate. - with anyio.CancelScope(shield=True): - await self._cancel_outbound(request_id, "caller cancelled", _related_request_id) + # scope, so any bare `await` here re-raises immediately - shield + # (bounded) to let the courtesy cancel go out before we propagate. + if cancel_on_abandon: + with anyio.move_on_after(_SHIELDED_WRITE_TIMEOUT, shield=True): + await self._cancel_outbound(request_id, "caller cancelled", _related_request_id) raise finally: # Always remove the waiter, even on cancel/timeout, so a late @@ -635,7 +656,7 @@ def _fan_out_closed(self) -> None: Synchronous (uses `send_nowait`) because it's called from `finally` which may be inside a cancelled scope. Idempotent. """ - closed = ErrorData(code=CONNECTION_CLOSED, message="connection closed") + closed = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") for pending in self._pending.values(): try: pending.send.send_nowait(closed) @@ -681,8 +702,9 @@ async def _handle_request( await self._write_error(req.id, ErrorData(code=0, message="Request cancelled")) except anyio.get_cancelled_exc_class(): # Outer-cancel: run()'s task group is shutting down. Any bare - # `await` here re-raises immediately, so shield the courtesy write. - with anyio.CancelScope(shield=True): + # `await` here re-raises immediately, so shield (bounded) the + # courtesy write. + with anyio.move_on_after(_SHIELDED_WRITE_TIMEOUT, shield=True): await self._write_error(req.id, ErrorData(code=REQUEST_CANCELLED, message="Request cancelled")) raise except MCPError as e: diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index caed8905d..e37940b8c 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -468,7 +468,9 @@ def __post_init__(self) -> None: ), divergence=Divergence( note=( - "The client only raises locally and sends nothing on timeout, so the server keeps running the handler." + "Client seat only: the client raises locally and sends nothing on timeout, so the server keeps " + "running the handler. The server seat conforms: a timed-out server-initiated request is followed " + "by notifications/cancelled on the wire." ), ), ), diff --git a/tests/interaction/lowlevel/test_timeouts.py b/tests/interaction/lowlevel/test_timeouts.py index b440f3210..7c12e1c65 100644 --- a/tests/interaction/lowlevel/test_timeouts.py +++ b/tests/interaction/lowlevel/test_timeouts.py @@ -13,9 +13,13 @@ from trio.testing import MockClock from mcp import MCPError, types +from mcp.client import ClientRequestContext +from mcp.client._memory import InMemoryTransport from mcp.client.client import Client from mcp.server import Server, ServerRequestContext -from mcp.types import REQUEST_TIMEOUT, CallToolResult, ErrorData, TextContent +from mcp.shared.message import SessionMessage +from mcp.types import REQUEST_TIMEOUT, CallToolResult, ErrorData, JSONRPCNotification, TextContent +from tests.interaction._helpers import RecordingTransport from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio @@ -56,6 +60,68 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara ) +@requirement("protocol:timeout:basic") +@requirement("protocol:timeout:sends-cancellation") +async def test_server_request_timeout_sends_cancellation_to_the_client() -> None: + """A server-initiated request that times out fails server-side and cancels the client's work. + + The server seat conforms to the spec's timeout guidance: the handler's timed-out sampling + request is followed by notifications/cancelled on the wire. The client's sampling callback + blocks until the server has already given up, then answers; the late response is discarded + and the tool call still completes. + """ + release = anyio.Event() + callback_started = anyio.Event() + errors: list[ErrorData] = [] + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="impatient", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "impatient" + request = types.CreateMessageRequest( + params=types.CreateMessageRequestParams( + messages=[types.SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=8, + ) + ) + with pytest.raises(MCPError) as exc_info: + await ctx.session.send_request(request, types.CreateMessageResult, request_read_timeout_seconds=0.000001) + errors.append(exc_info.value.error) + release.set() + return CallToolResult(content=[TextContent(text="gave up")]) + + server = Server("impatient", on_list_tools=list_tools, on_call_tool=call_tool) + recording = RecordingTransport(InMemoryTransport(server)) + + async def sampling_callback( + context: ClientRequestContext, params: types.CreateMessageRequestParams + ) -> types.CreateMessageResult: + callback_started.set() + await release.wait() + return types.CreateMessageResult(role="assistant", content=TextContent(text="too late"), model="test-model") + + async with Client(recording, sampling_callback=sampling_callback) as client: + result = await client.call_tool("impatient", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="gave up")])) + assert callback_started.is_set() + assert errors == snapshot([ErrorData(code=REQUEST_TIMEOUT, message="Request 'sampling/createMessage' timed out")]) + cancellations = [ + item.message + for item in recording.received + if isinstance(item, SessionMessage) + and isinstance(item.message, JSONRPCNotification) + and item.message.method == "notifications/cancelled" + ] + # The cancel names the sampling request (the server's first outbound request) and the reason. + assert [notification.params for notification in cancellations] == snapshot( + [{"requestId": 1, "reason": "timed out after 1e-06s"}] + ) + + @requirement("protocol:timeout:session-survives") async def test_session_serves_requests_after_timeout() -> None: """A timed-out request does not poison the session: the next request succeeds.""" diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index b2a24c87d..2df4c66fc 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -15,6 +15,7 @@ import anyio import anyio.lowlevel import pytest +from trio.testing import MockClock from mcp import Client from mcp.server import Server, ServerRequestContext @@ -33,6 +34,7 @@ CONNECTION_CLOSED, INTERNAL_ERROR, INVALID_PARAMS, + REQUEST_TIMEOUT, CallToolRequest, CallToolRequestParams, CallToolResult, @@ -413,6 +415,201 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> s.close() +@pytest.mark.anyio +async def test_caller_cancel_sends_courtesy_cancellation_on_the_wire(): + """Cancelling the scope around send_raw_request emits notifications/cancelled by default.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + + scopes: list[anyio.CancelScope] = [] + gave_up = anyio.Event() + + async def caller() -> None: + with anyio.CancelScope() as scope: + scopes.append(scope) + await client.send_raw_request("slow", None) + raise NotImplementedError # unreachable: the scope is cancelled + gave_up.set() + + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + tg.start_soon(caller) + with anyio.fail_after(5): + request = await c2s_recv.receive() + assert isinstance(request, SessionMessage) + assert isinstance(request.message, JSONRPCRequest) + scopes[0].cancel() + with anyio.fail_after(5): + await gave_up.wait() + cancel = await c2s_recv.receive() + assert isinstance(cancel, SessionMessage) + assert isinstance(cancel.message, JSONRPCNotification) + assert cancel.message.method == "notifications/cancelled" + assert cancel.message.params == {"requestId": request.message.id, "reason": "caller cancelled"} + assert cancel.metadata is None + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + assert scopes[0].cancelled_caught + + +@pytest.mark.anyio +async def test_caller_cancel_with_resumption_hints_suppresses_the_courtesy_cancellation(): + """A request sent with resumption hints is meant to be resumed; abandoning it must not stop the peer's work.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + + async def on_token(token: str) -> None: + raise NotImplementedError + + scopes: list[anyio.CancelScope] = [] + gave_up = anyio.Event() + + async def caller() -> None: + with anyio.CancelScope() as scope: + scopes.append(scope) + await client.send_raw_request("slow", None, {"on_resumption_token": on_token}) + raise NotImplementedError # unreachable: the scope is cancelled + gave_up.set() + + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + tg.start_soon(caller) + with anyio.fail_after(5): + request = await c2s_recv.receive() + assert isinstance(request, SessionMessage) + assert isinstance(request.message, JSONRPCRequest) + scopes[0].cancel() + with anyio.fail_after(5): + await gave_up.wait() + # The next write proves nothing was sent in between: a courtesy + # cancel would have to precede it on the ordered stream. + await client.notify("marker", None) + with anyio.fail_after(5): + nxt = await c2s_recv.receive() + assert isinstance(nxt, SessionMessage) + assert isinstance(nxt.message, JSONRPCNotification) + assert nxt.message.method == "marker" + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_timeout_with_resumption_hints_suppresses_the_courtesy_cancellation(): + """A timed-out request that carries resumption hints stays resumable: no cancellation is sent.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + with anyio.fail_after(5): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("slow", None, {"timeout": 0, "resumption_token": "tok"}) + assert exc.value.error.code == REQUEST_TIMEOUT + with anyio.fail_after(5): + request = await c2s_recv.receive() + assert isinstance(request, SessionMessage) + assert isinstance(request.message, JSONRPCRequest) + await client.notify("marker", None) + with anyio.fail_after(5): + nxt = await c2s_recv.receive() + assert isinstance(nxt, SessionMessage) + assert isinstance(nxt.message, JSONRPCNotification) + assert nxt.message.method == "marker" + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_cancel_on_abandon_false_suppresses_the_courtesy_cancellation_on_timeout(): + """Callers opt out per call for requests the protocol forbids cancelling (initialize).""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + with anyio.fail_after(5): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("slow", None, {"timeout": 0, "cancel_on_abandon": False}) + assert exc.value.error.code == REQUEST_TIMEOUT + with anyio.fail_after(5): + request = await c2s_recv.receive() + assert isinstance(request, SessionMessage) + assert isinstance(request.message, JSONRPCRequest) + await client.notify("marker", None) + with anyio.fail_after(5): + nxt = await c2s_recv.receive() + assert isinstance(nxt, SessionMessage) + assert isinstance(nxt.message, JSONRPCNotification) + assert nxt.message.method == "marker" + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.parametrize( + "anyio_backend", + [pytest.param(("trio", {"clock": MockClock(autojump_threshold=0)}), id="trio-mockclock")], +) +@pytest.mark.anyio +async def test_caller_cancel_courtesy_write_is_bounded_when_the_transport_is_wedged(): + """A wedged transport write cannot turn caller cancellation into an unbounded shielded hang. + + The write stream has no buffer and no reader, so the request write blocks; cancelling the + caller then routes into the shielded courtesy-cancel write, which blocks on the same wedged + stream. The bound abandons it after _SHIELDED_WRITE_TIMEOUT; trio's virtual clock makes the + wait instant. On regression (unbounded shield) the test hangs rather than failing fast: the + outer fail_after cannot cancel through the shield - that is the bug. + """ + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + + scopes: list[anyio.CancelScope] = [] + gave_up = anyio.Event() + + async def caller() -> None: + with anyio.CancelScope() as scope: + scopes.append(scope) + await client.send_raw_request("slow", None) + raise NotImplementedError # unreachable: the scope is cancelled + gave_up.set() + + try: + with anyio.fail_after(30): + async with anyio.create_task_group() as tg: # pragma: no branch + await tg.start(client.run, on_request, on_notify) + tg.start_soon(caller) + await anyio.wait_all_tasks_blocked() # the caller is parked in the request write + scopes[0].cancel() + with anyio.fail_after(20): + await gave_up.wait() + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + assert scopes[0].cancelled_caught + + @pytest.mark.anyio async def test_ctx_message_metadata_carries_inbound_request_metadata(): """Transport-attached metadata (HTTP request, SSE close hooks) is readable off the dispatch context.""" From 0ac793731c0275243e5fa987c99008c67aef90ab Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 15:11:00 +0000 Subject: [PATCH 02/24] Contain notification-handler exceptions in the dispatcher A raising notification handler ran as a bare task in the dispatcher's task group, so its exception cancelled the read loop and every in-flight request. Wrap spawned handlers in the same containment boundary progress callbacks already have: log the failure and keep the connection serving. --- src/mcp/shared/jsonrpc_dispatcher.py | 16 +++++++- tests/shared/test_jsonrpc_dispatcher.py | 50 +++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index ded8bcb3d..589d0dd31 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -34,7 +34,7 @@ from mcp.shared._otel import inject_trace_context, otel_span from mcp.shared._stream_protocols import ReadStream, WriteStream -from mcp.shared.dispatcher import CallOptions, Dispatcher, OnNotify, OnRequest, ProgressFnT +from mcp.shared.dispatcher import CallOptions, DispatchContext, Dispatcher, OnNotify, OnRequest, ProgressFnT from mcp.shared.exceptions import MCPError, NoBackChannelError from mcp.shared.message import ( ClientMessageMetadata, @@ -190,6 +190,18 @@ async def _wrapped(progress: float, total: float | None, message: str | None) -> return _wrapped +def _contained_notify(fn: OnNotify) -> OnNotify: + """Wrap a notification handler so it can't crash the dispatcher (same boundary as `_shielded_progress`).""" + + async def _wrapped(dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None) -> None: + try: + await fn(dctx, method, params) + except Exception: + logger.exception("notification handler for %r raised", method) + + return _wrapped + + def _outbound_metadata(related_request_id: RequestId | None, opts: CallOptions | None) -> MessageMetadata: """Choose the `SessionMessage.metadata` for an outgoing request/notification. @@ -619,7 +631,7 @@ def _dispatch_notification( dctx = _JSONRPCDispatchContext( transport=transport_ctx, _dispatcher=self, _request_id=None, message_metadata=metadata ) - self._spawn(on_notify, dctx, msg.method, msg.params, sender_ctx=sender_ctx) + self._spawn(_contained_notify(on_notify), dctx, msg.method, msg.params, sender_ctx=sender_ctx) def _resolve_pending(self, request_id: RequestId | None, outcome: dict[str, Any] | ErrorData) -> None: pending = self._pending.get(_coerce_id(request_id)) if request_id is not None else None diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index 2df4c66fc..7f0774d2b 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -610,6 +610,56 @@ async def caller() -> None: assert scopes[0].cancelled_caught +@pytest.mark.anyio +async def test_notification_handler_exception_is_contained(caplog: pytest.LogCaptureFixture): + """A raising notification handler costs only that notification, never the connection. + + The handler runs as a bare task in the dispatcher's task group; without containment its + exception would cancel the read loop and every in-flight request. The TypeScript, C#, and + Go engines all contain notification-handler failures the same way. + """ + + async def server_on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise RuntimeError("notify boom") + + async with running_pair(jsonrpc_pair, server_on_notify=server_on_notify) as (client, *_): + with anyio.fail_after(5): + await client.notify("boom", None) + # The connection survived: a full round-trip still works. + result = await client.send_raw_request("ping", None) + assert result == {"echoed": "ping", "params": {}} + assert "notification handler for 'boom' raised" in caplog.text + + +@pytest.mark.anyio +async def test_spawned_notification_handlers_run_concurrently(): + """Notification handlers are spawned, not serialized: a parked one does not block the next. + + The first handler waits for the second to have started - serialized dispatch would deadlock + here. This matches the TypeScript and C# engines (fire-and-forget); handlers needing + mutual ordering must coordinate themselves. + """ + second_started = anyio.Event() + completed: list[str] = [] + done = anyio.Event() + + async def server_on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + if method == "first": + await second_started.wait() + else: + second_started.set() + completed.append(method) + if len(completed) == 2: + done.set() + + async with running_pair(jsonrpc_pair, server_on_notify=server_on_notify) as (client, *_): + with anyio.fail_after(5): + await client.notify("first", None) + await client.notify("second", None) + await done.wait() + assert completed == ["second", "first"] + + @pytest.mark.anyio async def test_ctx_message_metadata_carries_inbound_request_metadata(): """Transport-attached metadata (HTTP request, SSE close hooks) is readable off the dispatch context.""" From aaacafc5b4874a51860ea688a7274d4f9f5e54fb Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 15:20:52 +0000 Subject: [PATCH 03/24] Add an on_stream_exception observer to the dispatcher Transports yield Exception items on the read stream for connection faults and parse errors; the dispatcher debug-logged and dropped them. An optional observer now receives them (awaited in the read loop, contained so a raising observer costs the item, not the connection). Unset keeps the old behavior. --- src/mcp/shared/jsonrpc_dispatcher.py | 24 ++++++++--- tests/shared/test_jsonrpc_dispatcher.py | 57 +++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 5 deletions(-) diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 589d0dd31..aec3ec45c 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -245,6 +245,7 @@ def __init__( peer_cancel_mode: PeerCancelMode = "interrupt", raise_handler_exceptions: bool = False, inline_methods: frozenset[str] = frozenset(), + on_stream_exception: Callable[[Exception], Awaitable[None]] | None = None, ) -> None: ... @overload def __init__( @@ -256,6 +257,7 @@ def __init__( peer_cancel_mode: PeerCancelMode = "interrupt", raise_handler_exceptions: bool = False, inline_methods: frozenset[str] = frozenset(), + on_stream_exception: Callable[[Exception], Awaitable[None]] | None = None, ) -> None: ... def __init__( self, @@ -266,6 +268,7 @@ def __init__( peer_cancel_mode: PeerCancelMode = "interrupt", raise_handler_exceptions: bool = False, inline_methods: frozenset[str] = frozenset(), + on_stream_exception: Callable[[Exception], Awaitable[None]] | None = None, ) -> None: self._read_stream = read_stream self._write_stream = write_stream @@ -287,6 +290,11 @@ def __init__( # while inline will deadlock because the parked read loop cannot dequeue # the response. self._inline_methods = inline_methods + # Observer for Exception items the transport yields on the read stream + # (SSE/streamable-HTTP connection faults, stdio parse errors). Without + # it they are debug-logged and dropped. Awaited in the read loop and + # contained: a raising observer costs the item, not the connection. + self._on_stream_exception = on_stream_exception self._next_id = 0 self._pending: dict[RequestId, _Pending] = {} @@ -482,13 +490,19 @@ async def _dispatch( ) -> None: """Route one inbound item. - Everything here is `send_nowait` or `_spawn`; the only `await` is for - `inline_methods` requests, which deliberately block dequeuing until - handled. Any other `await` would let one slow message head-of-line - block the entire read loop. + Everything here is `send_nowait` or `_spawn`; the only `await`s are + `inline_methods` requests and the `on_stream_exception` observer, + which deliberately block dequeuing until handled. Any other `await` + would let one slow message head-of-line block the entire read loop. """ if isinstance(item, Exception): - logger.debug("transport yielded exception: %r", item) + if self._on_stream_exception is None: + logger.debug("transport yielded exception: %r", item) + return + try: + await self._on_stream_exception(item) + except Exception: + logger.exception("on_stream_exception observer raised") return metadata = item.metadata msg = item.message diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index 7f0774d2b..8633840ec 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -912,6 +912,63 @@ async def test_transport_exception_in_read_stream_is_logged_and_dropped(): s.close() +@pytest.mark.anyio +async def test_on_stream_exception_observes_transport_exceptions(): + """With an observer set, Exception items reach it instead of being dropped; the loop stays healthy.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + + seen: list[Exception] = [] + + async def observe(exc: Exception) -> None: + seen.append(exc) + + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send, on_stream_exception=observe) + on_request, on_notify = echo_handlers(Recorder()) + hiccup = ValueError("transport hiccup") + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + await c2s_send.send(hiccup) + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None))) + with anyio.fail_after(5): + resp = await s2c_recv.receive() + assert isinstance(resp, SessionMessage) + assert isinstance(resp.message, JSONRPCResponse) + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + assert seen == [hiccup] + + +@pytest.mark.anyio +async def test_on_stream_exception_observer_raising_is_contained(caplog: pytest.LogCaptureFixture): + """A raising observer costs the item, not the connection: it runs in the read loop itself.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + + async def observe(exc: Exception) -> None: + raise RuntimeError("observer boom") + + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send, on_stream_exception=observe) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + await c2s_send.send(ValueError("transport hiccup")) + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None))) + with anyio.fail_after(5): + resp = await s2c_recv.receive() + assert isinstance(resp, SessionMessage) + assert isinstance(resp.message, JSONRPCResponse) + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + assert "on_stream_exception observer raised" in caplog.text + + @pytest.mark.anyio async def test_progress_notification_for_unknown_token_falls_through_to_on_notify(): async with running_pair(jsonrpc_pair) as (client, _server, _crec, srec): From 83c4ba734658ee9d17f344f83a03567343a59d61 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 15:23:24 +0000 Subject: [PATCH 04/24] Collapse the dispatcher constructor overloads with a defaulted TypeVar TransportT now defaults to TransportContext (PEP 696, same pattern as shared/context.py), so omitting transport_builder no longer needs a dedicated overload to pin the type parameter. --- src/mcp/shared/jsonrpc_dispatcher.py | 34 +++++----------------------- 1 file changed, 6 insertions(+), 28 deletions(-) diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index aec3ec45c..63c431950 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -24,13 +24,14 @@ import logging from collections.abc import Awaitable, Callable, Mapping from dataclasses import dataclass, field -from typing import Any, Generic, Literal, TypeVar, cast, overload +from typing import Any, Generic, Literal, cast import anyio import anyio.abc from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from opentelemetry.trace import SpanKind from pydantic import ValidationError +from typing_extensions import TypeVar from mcp.shared._otel import inject_trace_context, otel_span from mcp.shared._stream_protocols import ReadStream, WriteStream @@ -70,7 +71,7 @@ cancelled; without a bound, a wedged transport write would turn the shield into an uncancellable hang (and block shutdown indefinitely).""" -TransportT = TypeVar("TransportT", bound=TransportContext) +TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext) PeerCancelMode = Literal["interrupt", "signal"] """How inbound `notifications/cancelled` is applied to a running handler. @@ -236,29 +237,6 @@ class JSONRPCDispatcher(Dispatcher[TransportT]): conformance at the class definition rather than at first use. """ - @overload - def __init__( - self: JSONRPCDispatcher[TransportContext], - read_stream: ReadStream[SessionMessage | Exception], - write_stream: WriteStream[SessionMessage], - *, - peer_cancel_mode: PeerCancelMode = "interrupt", - raise_handler_exceptions: bool = False, - inline_methods: frozenset[str] = frozenset(), - on_stream_exception: Callable[[Exception], Awaitable[None]] | None = None, - ) -> None: ... - @overload - def __init__( - self, - read_stream: ReadStream[SessionMessage | Exception], - write_stream: WriteStream[SessionMessage], - *, - transport_builder: Callable[[MessageMetadata], TransportT], - peer_cancel_mode: PeerCancelMode = "interrupt", - raise_handler_exceptions: bool = False, - inline_methods: frozenset[str] = frozenset(), - on_stream_exception: Callable[[Exception], Awaitable[None]] | None = None, - ) -> None: ... def __init__( self, read_stream: ReadStream[SessionMessage | Exception], @@ -272,9 +250,9 @@ def __init__( ) -> None: self._read_stream = read_stream self._write_stream = write_stream - # The overloads guarantee that when `transport_builder` is omitted, - # `TransportT` is `TransportContext`, so the default is type-correct; - # pyright can't see across overloads, hence the cast. + # When `transport_builder` is omitted, `TransportT` falls back to its + # default (`TransportContext`), so the default builder is type-correct; + # pyright can't connect the two, hence the cast. self._transport_builder = cast( "Callable[[MessageMetadata], TransportT]", transport_builder or _default_transport_builder, From 12c8ef310aa99667428e57be2952c90c8d67a94c Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 15:39:56 +0000 Subject: [PATCH 05/24] Move ClientSession onto JSONRPCDispatcher and delete BaseSession ClientSession keeps its public surface (constructor, typed methods, manual initialize, context-manager lifecycle) but now owns a JSONRPCDispatcher instead of inheriting the v1 BaseSession receive loop. Server-initiated requests are answered through the existing callbacks via the closed-union parse; notifications validate-or-drop and tee to message_handler; transport exceptions reach message_handler through the dispatcher's stream-exception observer. A from_dispatcher constructor accepts a pre-built dispatcher for in-process embedding. mcp.shared.session shrinks to the surviving names: the ProgressFnT re-export and a typing-only RequestResponder stub for MessageHandlerFnT annotations. Behavior changes (deliberate, to be covered in the migration guide): - request ids count from 1; the progress token follows - timeouts use the dispatcher error text and send notifications/cancelled, so a timed-out server handler is interrupted instead of running on - responses with unknown ids are ignored per spec instead of surfacing a RuntimeError to message_handler - a raising request callback is answered with code 0 and the exception text - notification callbacks run concurrently (no completion-before-response) Three interaction-suite divergence entries are resolved and deleted, and the server-to-client cancellation requirement is now pinned by a passing test. --- src/mcp/client/session.py | 292 ++++++++--- src/mcp/shared/_context.py | 5 +- src/mcp/shared/session.py | 495 +----------------- tests/client/test_resource_cleanup.py | 66 --- tests/client/test_session.py | 104 +++- tests/interaction/_requirements.py | 28 - .../interaction/lowlevel/test_cancellation.py | 73 ++- tests/interaction/lowlevel/test_logging.py | 15 +- tests/interaction/lowlevel/test_progress.py | 4 +- tests/interaction/lowlevel/test_timeouts.py | 21 +- tests/interaction/lowlevel/test_wire.py | 4 +- tests/issues/test_88_random_error.py | 8 +- tests/server/mcpserver/test_server.py | 8 +- tests/shared/test_session.py | 447 ---------------- 14 files changed, 438 insertions(+), 1132 deletions(-) delete mode 100644 tests/client/test_resource_cleanup.py delete mode 100644 tests/shared/test_session.py diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 3a0485649..9975018fc 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -1,16 +1,26 @@ from __future__ import annotations import logging +from collections.abc import Mapping +from types import TracebackType from typing import Any, Protocol, cast, get_args +import anyio +import anyio.abc import anyio.lowlevel -from pydantic import BaseModel, TypeAdapter +from pydantic import BaseModel, TypeAdapter, ValidationError +from typing_extensions import Self, TypeVar from mcp import types from mcp.client._transport import ReadStream, WriteStream +from mcp.shared._compat import resync_tracer from mcp.shared._context import RequestContext -from mcp.shared.message import SessionMessage -from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder +from mcp.shared.dispatcher import CallOptions, DispatchContext, Dispatcher +from mcp.shared.exceptions import MCPError +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.message import ClientMessageMetadata, MessageMetadata, ServerMessageMetadata, SessionMessage +from mcp.shared.session import ProgressFnT, RequestResponder +from mcp.shared.transport_context import TransportContext from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types._types import RequestParamsMeta @@ -18,6 +28,8 @@ logger = logging.getLogger("client") +ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) + class SamplingFnT(Protocol): async def __call__( @@ -104,15 +116,16 @@ async def _default_logging_callback( answered with METHOD_NOT_FOUND instead of failing union validation.""" -class ClientSession( - BaseSession[ - types.ClientRequest, - types.ClientNotification, - types.ClientResult, - types.ServerRequest, - types.ServerNotification, - ] -): +class ClientSession: + """Client half of an MCP connection, running on `JSONRPCDispatcher`. + + Construct it over a transport's stream pair, enter it as an async context + manager, then call `initialize()`. The receive loop, request correlation, + and per-request concurrency live in the dispatcher; this class owns the + MCP type layer: typed requests, the initialize handshake, and routing + server-initiated traffic to the constructor callbacks. + """ + def __init__( self, read_stream: ReadStream[SessionMessage | Exception], @@ -127,7 +140,70 @@ def __init__( *, sampling_capabilities: types.SamplingCapability | None = None, ) -> None: - super().__init__(read_stream, write_stream, read_timeout_seconds=read_timeout_seconds) + self._init_state( + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + elicitation_callback=elicitation_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + sampling_capabilities=sampling_capabilities, + ) + # Built here (inert until run() starts in __aenter__) so notifications + # can be sent before entering the context manager, as before. + self._dispatcher: Dispatcher[Any] = JSONRPCDispatcher( + read_stream, write_stream, on_stream_exception=self._on_stream_exception + ) + + @classmethod + def from_dispatcher( + cls, + dispatcher: Dispatcher[Any], + *, + read_timeout_seconds: float | None = None, + sampling_callback: SamplingFnT | None = None, + elicitation_callback: ElicitationFnT | None = None, + list_roots_callback: ListRootsFnT | None = None, + logging_callback: LoggingFnT | None = None, + message_handler: MessageHandlerFnT | None = None, + client_info: types.Implementation | None = None, + sampling_capabilities: types.SamplingCapability | None = None, + ) -> Self: + """Build a session over a pre-built dispatcher instead of a stream pair. + + For embedding a server in-process (`DirectDispatcher`) or transports + that construct their own dispatcher. Transport-level `Exception` items + reach `message_handler` only on the stream constructor, where the + session wires the dispatcher's `on_stream_exception` itself. + """ + self = cls.__new__(cls) + self._init_state( + read_timeout_seconds=read_timeout_seconds, + sampling_callback=sampling_callback, + elicitation_callback=elicitation_callback, + list_roots_callback=list_roots_callback, + logging_callback=logging_callback, + message_handler=message_handler, + client_info=client_info, + sampling_capabilities=sampling_capabilities, + ) + self._dispatcher = dispatcher + return self + + def _init_state( + self, + *, + read_timeout_seconds: float | None, + sampling_callback: SamplingFnT | None, + elicitation_callback: ElicitationFnT | None, + list_roots_callback: ListRootsFnT | None, + logging_callback: LoggingFnT | None, + message_handler: MessageHandlerFnT | None, + client_info: types.Implementation | None, + sampling_capabilities: types.SamplingCapability | None, + ) -> None: + self._session_read_timeout_seconds = read_timeout_seconds self._client_info = client_info or DEFAULT_CLIENT_INFO self._sampling_callback = sampling_callback or _default_sampling_callback self._sampling_capabilities = sampling_capabilities @@ -137,18 +213,90 @@ def __init__( self._message_handler = message_handler or _default_message_handler self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} self._initialize_result: types.InitializeResult | None = None + self._task_group: anyio.abc.TaskGroup | None = None - @property - def _receive_request_adapter(self) -> TypeAdapter[types.ServerRequest]: - return types.server_request_adapter + async def __aenter__(self) -> Self: + self._task_group = anyio.create_task_group() + await self._task_group.__aenter__() + await self._task_group.start(self._dispatcher.run, self._on_request, self._on_notify) + return self - @property - def _receive_request_methods(self) -> frozenset[str]: - return _SERVER_REQUEST_METHODS + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + # Exit must not block: cancel the dispatcher and any in-flight + # callbacks rather than waiting for them. + assert self._task_group is not None + self._task_group.cancel_scope.cancel() + result = await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + await resync_tracer() + return result - @property - def _receive_notification_adapter(self) -> TypeAdapter[types.ServerNotification]: - return types.server_notification_adapter + async def send_request( + self, + request: types.ClientRequest, + result_type: type[ReceiveResultT], + request_read_timeout_seconds: float | None = None, + metadata: MessageMetadata = None, + progress_callback: ProgressFnT | None = None, + ) -> ReceiveResultT: + """Send a request and wait for its typed result. + + A per-request read timeout takes precedence over the session-level + one. `metadata` carries transport hints: `ClientMessageMetadata` + resumption fields (streamable HTTP), or a + `ServerMessageMetadata.related_request_id` to route the message onto + an originating request's stream. + + Raises: + MCPError: The server responded with an error, or the read timeout + elapsed, or the connection closed while waiting. + RuntimeError: Called before entering the context manager. + """ + data = request.model_dump(by_alias=True, mode="json", exclude_none=True) + method: str = data["method"] + opts: CallOptions = {} + timeout = request_read_timeout_seconds or self._session_read_timeout_seconds + if timeout is not None: + opts["timeout"] = timeout + if progress_callback is not None: + opts["on_progress"] = progress_callback + related_request_id: types.RequestId | None = None + if isinstance(metadata, ClientMessageMetadata): + if metadata.resumption_token is not None: + opts["resumption_token"] = metadata.resumption_token + if metadata.on_resumption_token_update is not None: + opts["on_resumption_token"] = metadata.on_resumption_token_update + elif isinstance(metadata, ServerMessageMetadata): + related_request_id = metadata.related_request_id + if method == "initialize": + # The spec forbids cancelling initialize; opt out of the + # dispatcher's courtesy cancel-on-abandon. + opts["cancel_on_abandon"] = False + if related_request_id is not None and isinstance(self._dispatcher, JSONRPCDispatcher): + # Related-request routing is JSON-RPC stream plumbing; other + # dispatchers have no per-request streams to route onto. + raw = await self._dispatcher.send_raw_request( + method, data.get("params"), opts, _related_request_id=related_request_id + ) + else: + raw = await self._dispatcher.send_raw_request(method, data.get("params"), opts) + return result_type.model_validate(raw, by_name=False) + + async def send_notification( + self, + notification: types.ClientNotification, + related_request_id: types.RequestId | None = None, + ) -> None: + """Send a one-way notification. Usable before entering the context manager.""" + data = notification.model_dump(by_alias=True, mode="json", exclude_none=True) + if related_request_id and isinstance(self._dispatcher, JSONRPCDispatcher): + await self._dispatcher.notify(data["method"], data.get("params"), _related_request_id=related_request_id) + else: + await self._dispatcher.notify(data["method"], data.get("params")) async def initialize(self) -> types.InitializeResult: sampling = ( @@ -397,49 +545,65 @@ async def send_roots_list_changed(self) -> None: """Send a roots/list_changed notification.""" await self.send_notification(types.RootsListChangedNotification()) - async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None: - ctx = RequestContext[ClientSession](request_id=responder.request_id, meta=responder.request_meta, session=self) - - match responder.request: - case types.CreateMessageRequest(params=params): - with responder: - response = await self._sampling_callback(ctx, params) - client_response = ClientResponse.validate_python(response) - await responder.respond(client_response) - - case types.ElicitRequest(params=params): - with responder: - response = await self._elicitation_callback(ctx, params) - client_response = ClientResponse.validate_python(response) - await responder.respond(client_response) + async def _on_request( + self, dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + """Answer a server-initiated request via the registered callbacks. + An unknown method raises `MCPError` (METHOD_NOT_FOUND), which the + dispatcher puts on the wire as-is; malformed params for a known method + raise `ValidationError`, which the dispatcher answers with + INVALID_PARAMS; an `ErrorData` returned by a callback becomes the + error response. + """ + if method not in _SERVER_REQUEST_METHODS: + # Unknown methods are METHOD_NOT_FOUND (-32601) per JSON-RPC 2.0, + # not validation failures (-32602). + raise MCPError(code=types.METHOD_NOT_FOUND, message="Method not found", data=method) + payload: dict[str, Any] = {"method": method} + if params is not None: + payload["params"] = dict(params) + request = types.server_request_adapter.validate_python(payload, by_name=False) + + ctx = RequestContext[ClientSession]( + request_id=dctx.request_id, meta=request.params.meta if request.params else None, session=self + ) + response: types.ClientResult | types.ErrorData + match request: + case types.CreateMessageRequest(params=sampling_params): + response = await self._sampling_callback(ctx, sampling_params) + case types.ElicitRequest(params=elicit_params): + response = await self._elicitation_callback(ctx, elicit_params) case types.ListRootsRequest(): - with responder: - response = await self._list_roots_callback(ctx) - client_response = ClientResponse.validate_python(response) - await responder.respond(client_response) - + response = await self._list_roots_callback(ctx) case types.PingRequest(): # pragma: no branch - with responder: - await responder.respond(types.EmptyResult()) - - async def _handle_incoming( - self, - req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + response = types.EmptyResult() + client_response = ClientResponse.validate_python(response) + if isinstance(client_response, types.ErrorData): + raise MCPError.from_error_data(client_response) + return client_response.model_dump(by_alias=True, mode="json", exclude_none=True) + + async def _on_notify( + self, dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> None: - """Handle incoming messages by forwarding to the message handler.""" - await self._message_handler(req) - - async def _received_notification(self, notification: types.ServerNotification) -> None: - """Handle notifications from the server.""" - # Process specific notification types - match notification: - case types.LoggingMessageNotification(params=params): - await self._logging_callback(params) - case types.ElicitCompleteNotification(params=params): - # Handle elicitation completion notification - # Clients MAY use this to retry requests or update UI - # The notification contains the elicitationId of the completed elicitation - pass - case _: - pass + """Route a server notification: validate, run the typed callback, tee to message_handler.""" + payload: dict[str, Any] = {"method": method} + if params is not None: + payload["params"] = dict(params) + try: + notification = types.server_notification_adapter.validate_python(payload, by_name=False) + except ValidationError: + logger.warning("Failed to validate notification: %s", payload, exc_info=True) + return + if isinstance(notification, types.CancelledNotification): + # The dispatcher already applied the cancellation to the in-flight + # request; message_handler never sees it, so handlers matching + # exhaustively over ServerNotification need no arm for it. + return + if isinstance(notification, types.LoggingMessageNotification): + await self._logging_callback(notification.params) + await self._message_handler(notification) + + async def _on_stream_exception(self, exc: Exception) -> None: + """Forward transport-level faults (connection errors, parse errors) to message_handler.""" + await self._message_handler(exc) diff --git a/src/mcp/shared/_context.py b/src/mcp/shared/_context.py index bbcee2d02..8ad4ca918 100644 --- a/src/mcp/shared/_context.py +++ b/src/mcp/shared/_context.py @@ -1,14 +1,13 @@ -"""Request context for MCP handlers.""" +"""Request context for MCP client handlers.""" from dataclasses import dataclass from typing import Any, Generic from typing_extensions import TypeVar -from mcp.shared.session import BaseSession from mcp.types import RequestId, RequestParamsMeta -SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any]) +SessionT = TypeVar("SessionT", default=Any) @dataclass(kw_only=True) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 61279ad8b..55710fa98 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,487 +1,32 @@ -from __future__ import annotations +"""Compatibility surface for the removed v1 session layer. -import contextvars -import logging -from contextlib import AsyncExitStack -from types import TracebackType -from typing import Any, Generic, Protocol, TypeVar +`BaseSession` (the v1 receive loop) is gone: `ClientSession` runs on +`JSONRPCDispatcher` and the server side on `ServerRunner`. This module keeps +the names that outlived it. +""" -import anyio -from anyio.streams.memory import MemoryObjectSendStream -from opentelemetry.trace import SpanKind -from pydantic import BaseModel, TypeAdapter -from typing_extensions import Self +from typing import Generic, TypeVar -from mcp.shared._compat import resync_tracer -from mcp.shared._otel import inject_trace_context, otel_span -from mcp.shared._stream_protocols import ReadStream, WriteStream -from mcp.shared.exceptions import MCPError -from mcp.shared.message import MessageMetadata, ServerMessageMetadata, SessionMessage -from mcp.types import ( - CONNECTION_CLOSED, - INVALID_PARAMS, - METHOD_NOT_FOUND, - REQUEST_TIMEOUT, - CancelledNotification, - ClientNotification, - ClientRequest, - ClientResult, - ErrorData, - JSONRPCError, - JSONRPCNotification, - JSONRPCRequest, - JSONRPCResponse, - ProgressNotification, - ProgressToken, - RequestParamsMeta, - ServerNotification, - ServerRequest, - ServerResult, -) - -SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest) -SendResultT = TypeVar("SendResultT", ClientResult, ServerResult) -SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification) -ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest) -ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) -ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification) +from mcp.shared.dispatcher import ProgressFnT as ProgressFnT +from mcp.shared.message import MessageMetadata +from mcp.types import RequestParamsMeta RequestId = str | int - -class ProgressFnT(Protocol): - """Protocol for progress notification callbacks.""" - - async def __call__( - self, progress: float, total: float | None, message: str | None - ) -> None: ... # pragma: no branch +ReceiveRequestT = TypeVar("ReceiveRequestT") +SendResultT = TypeVar("SendResultT") class RequestResponder(Generic[ReceiveRequestT, SendResultT]): - """Handles responding to MCP requests and manages request lifecycle. - - This class MUST be used as a context manager to ensure proper cleanup and - cancellation handling: + """Typing stub for the v1 responder. - Example: - ```python - with request_responder as resp: - await resp.respond(result) - ``` - - The context manager ensures: - 1. Proper cancellation scope setup and cleanup - 2. Request completion tracking - 3. Cleanup of in-flight requests + Never instantiated by the SDK: the client answers every server request + itself, so the `RequestResponder` arm of `MessageHandlerFnT` is + unreachable. The class remains so existing annotations and imports keep + working. """ - def __init__( - self, - request_id: RequestId, - request_meta: RequestParamsMeta | None, - request: ReceiveRequestT, - session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT], - message_metadata: MessageMetadata = None, - context: contextvars.Context | None = None, - ) -> None: - self.request_id = request_id - self.request_meta = request_meta - self.request = request - self.message_metadata = message_metadata - self.context = context - self._session = session - self._completed = False - self._entered = False # Track if we're in a context manager - - def __enter__(self) -> RequestResponder[ReceiveRequestT, SendResultT]: - self._entered = True - return self - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - self._entered = False - - async def respond(self, response: SendResultT | ErrorData) -> None: - """Send a response for this request. - - Must be called within a context manager block. - - Raises: - RuntimeError: If not used within a context manager - AssertionError: If request was already responded to - """ - if not self._entered: # pragma: no cover - raise RuntimeError("RequestResponder must be used as a context manager") - assert not self._completed, "Request already responded to" - self._completed = True - await self._session._send_response( # type: ignore[reportPrivateUsage] - request_id=self.request_id, response=response - ) - - -class BaseSession( - Generic[ - SendRequestT, - SendNotificationT, - SendResultT, - ReceiveRequestT, - ReceiveNotificationT, - ], -): - """Implements an MCP "session" on top of read/write streams, including features - like request/response linking, notifications, and progress. - - This class is an async context manager that automatically starts processing - messages when entered. - """ - - _response_streams: dict[RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]] - _request_id: int - _progress_callbacks: dict[RequestId, ProgressFnT] - - def __init__( - self, - read_stream: ReadStream[SessionMessage | Exception], - write_stream: WriteStream[SessionMessage], - # If none, reading will never time out - read_timeout_seconds: float | None = None, - ) -> None: - self._read_stream = read_stream - self._write_stream = write_stream - self._response_streams = {} - self._request_id = 0 - self._session_read_timeout_seconds = read_timeout_seconds - self._progress_callbacks = {} - self._exit_stack = AsyncExitStack() - - async def __aenter__(self) -> Self: - self._task_group = anyio.create_task_group() - await self._task_group.__aenter__() - self._task_group.start_soon(self._receive_loop) - return self - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> bool | None: - await self._exit_stack.aclose() - # Using BaseSession as a context manager should not block on exit (this - # would be very surprising behavior), so make sure to cancel the tasks - # in the task group. - self._task_group.cancel_scope.cancel() - result = await self._task_group.__aexit__(exc_type, exc_val, exc_tb) - await resync_tracer() - return result - - async def send_request( - self, - request: SendRequestT, - result_type: type[ReceiveResultT], - request_read_timeout_seconds: float | None = None, - metadata: MessageMetadata = None, - progress_callback: ProgressFnT | None = None, - ) -> ReceiveResultT: - """Sends a request and waits for a response. - - Raises an MCPError if the response contains an error. If a request read timeout is provided, it will take - precedence over the session read timeout. - - Do not use this method to emit notifications! Use send_notification() instead. - """ - request_id = self._request_id - self._request_id = request_id + 1 - - response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1) - self._response_streams[request_id] = response_stream - - # Set up progress token if progress callback is provided - request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True) - if progress_callback is not None: - # Use request_id as progress token - if "params" not in request_data: # pragma: lax no cover - request_data["params"] = {} - if "_meta" not in request_data["params"]: # pragma: lax no cover - request_data["params"]["_meta"] = {} - request_data["params"]["_meta"]["progressToken"] = request_id - # Store the callback for this request - self._progress_callbacks[request_id] = progress_callback - - try: - target = request_data.get("params", {}).get("name") - span_name = f"MCP send {request.method} {target}" if target else f"MCP send {request.method}" - - with otel_span( - span_name, - kind=SpanKind.CLIENT, - attributes={"mcp.method.name": request.method, "jsonrpc.request.id": str(request_id)}, - ): - # Inject W3C trace context into _meta (SEP-414). - meta: dict[str, Any] = request_data.setdefault("params", {}).setdefault("_meta", {}) - inject_trace_context(meta) - - jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request_data) - await self._write_stream.send(SessionMessage(message=jsonrpc_request, metadata=metadata)) - - # request read timeout takes precedence over session read timeout - timeout = request_read_timeout_seconds or self._session_read_timeout_seconds - - try: - with anyio.fail_after(timeout): - response_or_error = await response_stream_reader.receive() - except TimeoutError: - class_name = request.__class__.__name__ - message = f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds." - raise MCPError(code=REQUEST_TIMEOUT, message=message) - - if isinstance(response_or_error, JSONRPCError): - raise MCPError.from_jsonrpc_error(response_or_error) - else: - return result_type.model_validate(response_or_error.result, by_name=False) - - finally: - self._response_streams.pop(request_id, None) - self._progress_callbacks.pop(request_id, None) - await response_stream.aclose() - await response_stream_reader.aclose() - - async def send_notification( - self, - notification: SendNotificationT, - related_request_id: RequestId | None = None, - ) -> None: - """Emits a notification, which is a one-way message that does not expect a response.""" - # Some transport implementations may need to set the related_request_id - # to attribute to the notifications to the request that triggered them. - jsonrpc_notification = JSONRPCNotification( - jsonrpc="2.0", - **notification.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - session_message = SessionMessage( - message=jsonrpc_notification, - metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None, - ) - await self._write_stream.send(session_message) - - async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None: - if isinstance(response, ErrorData): - jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) - session_message = SessionMessage(message=jsonrpc_error) - await self._write_stream.send(session_message) - else: - jsonrpc_response = JSONRPCResponse( - jsonrpc="2.0", - id=request_id, - result=response.model_dump(by_alias=True, mode="json", exclude_none=True), - ) - session_message = SessionMessage(message=jsonrpc_response) - await self._write_stream.send(session_message) - - @property - def _receive_request_adapter(self) -> TypeAdapter[ReceiveRequestT]: - """Each subclass must provide its own request adapter.""" - raise NotImplementedError - - @property - def _receive_request_methods(self) -> frozenset[str]: - """Method names in the receive-request union; anything else is - answered with METHOD_NOT_FOUND before validation is attempted.""" - raise NotImplementedError - - @property - def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]: - raise NotImplementedError - - async def _receive_loop(self) -> None: - async with self._read_stream, self._write_stream: - try: - - async def _handle_session_message(message: SessionMessage) -> None: - sender_context: contextvars.Context | None = getattr(self._read_stream, "last_context", None) - if isinstance(message.message, JSONRPCRequest): - if message.message.method not in self._receive_request_methods: - # Unknown methods are METHOD_NOT_FOUND (-32601) per - # JSON-RPC 2.0, not validation failures (-32602). - error_response = JSONRPCError( - jsonrpc="2.0", - id=message.message.id, - error=ErrorData( - code=METHOD_NOT_FOUND, message="Method not found", data=message.message.method - ), - ) - await self._write_stream.send(SessionMessage(message=error_response)) - return - try: - validated_request = self._receive_request_adapter.validate_python( - message.message.model_dump(by_alias=True, mode="json", exclude_none=True), - by_name=False, - ) - responder = RequestResponder( - request_id=message.message.id, - request_meta=validated_request.params.meta if validated_request.params else None, - request=validated_request, - session=self, - message_metadata=message.metadata, - context=sender_context, - ) - await self._received_request(responder) - except Exception: - # For request validation errors, send a proper JSON-RPC error - # response instead of crashing the server - logging.warning("Failed to validate request", exc_info=True) - logging.debug(f"Message that failed validation: {message.message}") - error_response = JSONRPCError( - jsonrpc="2.0", - id=message.message.id, - error=ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data=""), - ) - session_message = SessionMessage(message=error_response) - await self._write_stream.send(session_message) - - elif isinstance(message.message, JSONRPCNotification): - try: - notification = self._receive_notification_adapter.validate_python( - message.message.model_dump(by_alias=True, mode="json", exclude_none=True), - by_name=False, - ) - if isinstance(notification, CancelledNotification): - # ClientSession runs server-initiated requests - # inline in this loop, so by the time a peer - # cancellation is read there is nothing left to - # cancel. Consume it here so message_handler - # keeps the contract it had before the - # dispatcher swap removed _in_flight. - return - # Handle progress notifications callback - if isinstance(notification, ProgressNotification): - progress_token = notification.params.progress_token - # If there is a progress callback for this token, - # call it with the progress information - if progress_token in self._progress_callbacks: - callback = self._progress_callbacks[progress_token] - try: - await callback( - notification.params.progress, - notification.params.total, - notification.params.message, - ) - except Exception: - logging.exception("Progress callback raised an exception") - await self._received_notification(notification) - await self._handle_incoming(notification) - except Exception: - # For other validation errors, log and continue - logging.warning( - "Failed to validate notification: %s", - message.message, - exc_info=True, - ) - else: # Response or error - await self._handle_response(message) - - async for message in self._read_stream: - if isinstance(message, Exception): - await self._handle_incoming(message) - continue - - await _handle_session_message(message) - - except anyio.ClosedResourceError: - # This is expected when the client disconnects abruptly. - # Without this handler, the exception would propagate up and - # crash the server's task group. - logging.debug("Read stream closed by client") - except Exception as e: - # Other exceptions are not expected and should be logged. We purposefully - # catch all exceptions here to avoid crashing the server. - logging.exception(f"Unhandled exception in receive loop: {e}") # pragma: no cover - finally: - # after the read stream is closed, we need to send errors - # to any pending requests - # Snapshot: stream.send() wakes the waiter, whose finally pops - # from _response_streams before the next __next__() call. - for id, stream in list(self._response_streams.items()): - error = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") - try: - await stream.send(JSONRPCError(jsonrpc="2.0", id=id, error=error)) - await stream.aclose() - except Exception: # pragma: lax no cover - # Stream might already be closed - pass - self._response_streams.clear() - - def _normalize_request_id(self, response_id: RequestId) -> RequestId: - """Normalize a response ID to match how request IDs are stored. - - Since the client always sends integer IDs, we normalize string IDs - to integers when possible. This matches the TypeScript SDK approach: - https://github.com/modelcontextprotocol/typescript-sdk/blob/a606fb17909ea454e83aab14c73f14ea45c04448/src/shared/protocol.ts#L861 - - Args: - response_id: The response ID from the incoming message. - - Returns: - The normalized ID (int if possible, otherwise original value). - """ - if isinstance(response_id, str): - try: - return int(response_id) - except ValueError: - logging.warning(f"Response ID {response_id!r} cannot be normalized to match pending requests") - return response_id - - async def _handle_response(self, message: SessionMessage) -> None: - """Handle an incoming response or error message.""" - # This check is always true at runtime: the caller (_receive_loop) only invokes - # this method in the else branch after checking for JSONRPCRequest and - # JSONRPCNotification. However, the type checker can't infer this from the - # method signature, so we need this guard for type narrowing. - if not isinstance(message.message, JSONRPCResponse | JSONRPCError): - return # pragma: no cover - - if message.message.id is None: - # Narrows to JSONRPCError since JSONRPCResponse.id is always RequestId - error = message.message.error - logging.warning(f"Received error with null ID: {error.message}") - await self._handle_incoming(MCPError(error.code, error.message, error.data)) - return - # Normalize response ID to handle type mismatches (e.g., "0" vs 0) - response_id = self._normalize_request_id(message.message.id) - - stream = self._response_streams.pop(response_id, None) - if stream: - await stream.send(message.message) - else: - await self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}")) - - async def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None: - """Can be overridden by subclasses to handle a request without needing to - listen on the message stream. - - If the request is responded to within this method, it will not be - forwarded on to the message stream. - """ - - async def _received_notification(self, notification: ReceiveNotificationT) -> None: - """Can be overridden by subclasses to handle a notification without needing - to listen on the message stream. - """ - - async def send_progress_notification( - self, - progress_token: ProgressToken, - progress: float, - total: float | None = None, - message: str | None = None, - ) -> None: - """Sends a progress notification for a request that is currently being processed.""" - - async def _handle_incoming( - self, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception - ) -> None: - """A generic handler for incoming messages. Overridden by subclasses.""" + request_id: RequestId + request_meta: RequestParamsMeta | None + request: ReceiveRequestT + message_metadata: MessageMetadata diff --git a/tests/client/test_resource_cleanup.py b/tests/client/test_resource_cleanup.py deleted file mode 100644 index c7bf8fafa..000000000 --- a/tests/client/test_resource_cleanup.py +++ /dev/null @@ -1,66 +0,0 @@ -from typing import Any -from unittest.mock import patch - -import anyio -import pytest -from pydantic import TypeAdapter - -from mcp.shared.message import SessionMessage -from mcp.shared.session import BaseSession, RequestId, SendResultT -from mcp.types import ClientNotification, ClientRequest, ClientResult, EmptyResult, ErrorData, PingRequest - - -@pytest.mark.anyio -async def test_send_request_stream_cleanup(): - """Test that send_request properly cleans up streams when an exception occurs. - - This test mocks out most of the session functionality to focus on stream cleanup. - """ - - # Create a mock session with the minimal required functionality - class TestSession(BaseSession[ClientRequest, ClientNotification, ClientResult, Any, Any]): - async def _send_response( - self, request_id: RequestId, response: SendResultT | ErrorData - ) -> None: # pragma: no cover - pass - - @property - def _receive_request_adapter(self) -> TypeAdapter[Any]: - return TypeAdapter(object) # pragma: no cover - - @property - def _receive_notification_adapter(self) -> TypeAdapter[Any]: - return TypeAdapter(object) # pragma: no cover - - # Create streams - write_stream_send, write_stream_receive = anyio.create_memory_object_stream[SessionMessage](1) - read_stream_send, read_stream_receive = anyio.create_memory_object_stream[SessionMessage](1) - - # Create the session - session = TestSession(read_stream_receive, write_stream_send) - - # Create a test request - request = PingRequest() - - # Patch the _write_stream.send method to raise an exception - async def mock_send(*args: Any, **kwargs: Any): - raise RuntimeError("Simulated network error") - - # Record the response streams before the test - initial_stream_count = len(session._response_streams) - - # Run the test with the patched method - with patch.object(session._write_stream, "send", mock_send): - with pytest.raises(RuntimeError): - await session.send_request(request, EmptyResult) - - # Verify that no response streams were leaked - assert len(session._response_streams) == initial_stream_count, ( - f"Expected {initial_stream_count} response streams after request, but found {len(session._response_streams)}" - ) - - # Clean up - await write_stream_send.aclose() - await write_stream_receive.aclose() - await read_stream_send.aclose() - await read_stream_receive.aclose() diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 28d212d00..9b4f20d83 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -17,6 +17,7 @@ from mcp.types import ( INVALID_PARAMS, LATEST_PROTOCOL_VERSION, + METHOD_NOT_FOUND, CallToolResult, Implementation, InitializedNotification, @@ -751,8 +752,34 @@ async def test_receive_loop_answers_malformed_inbound_request_with_invalid_param @pytest.mark.anyio -async def test_receive_loop_answers_invalid_params_when_sampling_callback_raises(): - """Same boundary catches exceptions from the request handler itself.""" +async def test_receive_loop_answers_unknown_request_method_with_method_not_found(): + """A server request whose method is not in the ServerRequest union gets -32601 + (METHOD_NOT_FOUND) on the wire, not a validation failure (-32602).""" + async with raw_client_session() as (_session, to_client, from_client): + await to_client.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=7, method="x/unknown"))) + out = await from_client.receive() + assert isinstance(out.message, JSONRPCError) + assert out.message.id == 7 + assert out.message.error == types.ErrorData(code=METHOD_NOT_FOUND, message="Method not found", data="x/unknown") + + +@pytest.mark.anyio +async def test_receive_loop_drops_unknown_notification_method_without_response(): + """An unknown notification method is dropped silently: JSON-RPC forbids + responses to notifications, and the receive loop keeps serving.""" + async with raw_client_session() as (_session, to_client, from_client): + await to_client.send(SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="x/unknown"))) + # The next wire output must be the answer to this follow-up ping, + # proving the notification produced no response and the loop survived. + await to_client.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"))) + out = await from_client.receive() + assert isinstance(out.message, JSONRPCResponse) + assert out.message.id == 1 + + +@pytest.mark.anyio +async def test_raising_sampling_callback_answers_with_code_zero(): + """A raising request callback is answered through the dispatcher's exception boundary.""" async def boom(ctx: object, params: object) -> types.CreateMessageResult: raise RuntimeError("sampling boom") @@ -767,7 +794,7 @@ async def boom(ctx: object, params: object) -> types.CreateMessageResult: ) out = await from_client.receive() assert isinstance(out.message, JSONRPCError) - assert out.message.error.code == INVALID_PARAMS + assert out.message.error == types.ErrorData(code=0, message="sampling boom") @pytest.mark.anyio @@ -841,23 +868,68 @@ async def handler(msg: object) -> None: @pytest.mark.anyio -async def test_receive_loop_swallows_progress_callback_exception(caplog: pytest.LogCaptureFixture): +async def test_progress_callback_exception_is_swallowed(caplog: pytest.LogCaptureFixture): delivered = anyio.Event() async def boom(progress: float, total: float | None, message: str | None) -> None: raise RuntimeError("progress boom") async def handler(msg: object) -> None: - delivered.set() + if isinstance(msg, types.ProgressNotification): + delivered.set() + + async with raw_client_session(message_handler=handler) as (session, to_client, from_client): + async with anyio.create_task_group() as tg: + + async def call() -> None: + await session.send_request(types.PingRequest(), types.EmptyResult, progress_callback=boom) + + tg.start_soon(call) + request = await from_client.receive() + assert isinstance(request.message, JSONRPCRequest) + # The request id doubles as the progress token. + params = {"progressToken": request.message.id, "progress": 0.5} + await to_client.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/progress", params=params)) + ) + # The progress notification also reaches the message handler; the + # raising callback was swallowed and logged. + await delivered.wait() + await to_client.send(SessionMessage(JSONRPCResponse(jsonrpc="2.0", id=request.message.id, result={}))) + assert "progress callback raised" in caplog.text - async with raw_client_session(message_handler=handler) as (session, to_client, _): - # Register the callback under a known token without sending a request. - session._progress_callbacks[42] = boom # pyright: ignore[reportPrivateUsage] - params = {"progressToken": 42, "progress": 0.5} - await to_client.send( - SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/progress", params=params)) - ) - # The progress notification also reaches the message handler after the - # callback runs, so this fires once the callback's exception is handled. - await delivered.wait() - assert "Progress callback raised an exception" in caplog.text + +@pytest.mark.anyio +async def test_from_dispatcher_runs_over_direct_dispatch(): + """A session built with from_dispatcher works without a stream pair (in-process embedding).""" + from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair + from mcp.shared.dispatcher import DispatchContext + from mcp.shared.transport_context import TransportContext + + client_side, server_side = create_direct_dispatcher_pair() + + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: dict[str, object] | None + ) -> dict[str, object]: + assert method == "ping" + return {} + + notified: list[str] = [] + + async def server_on_notify( + ctx: DispatchContext[TransportContext], method: str, params: dict[str, object] | None + ) -> None: + notified.append(method) + + session = ClientSession.from_dispatcher(client_side) + results: list[types.EmptyResult] = [] + async with anyio.create_task_group() as tg: + await tg.start(server_side.run, server_on_request, server_on_notify) + async with session: + results.append(await session.send_ping(meta=None)) + # related_request_id routing is JSON-RPC plumbing; on other + # dispatchers the notification is sent without it. + await session.send_notification(types.RootsListChangedNotification(), related_request_id=7) + server_side.close() + assert results == [types.EmptyResult()] + assert notified == ["notifications/roots/list_changed"] diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index e37940b8c..6bfa1dcbe 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -287,14 +287,6 @@ def __post_init__(self) -> None: "A response that arrives after the sender issued notifications/cancelled is ignored; the " "request stays failed and no error is raised." ), - divergence=Divergence( - note=( - "A response whose id matches no in-flight request is delivered to the message handler " - "as a RuntimeError rather than being silently ignored. The post-cancellation case is the " - "same code path; tested in its unknown-id form because that is deterministic without the " - "client-side cancellation API the SDK does not yet provide." - ), - ), ), "protocol:cancel:server-survives": Requirement( source="sdk", @@ -306,19 +298,6 @@ def __post_init__(self) -> None: "A server that abandons an in-flight server-initiated request (sampling, elicitation, roots) " "cancels it, and the client stops processing the cancelled request." ), - divergence=Divergence( - note=( - "Abandoning a server-side send_request emits no cancellation notification, and the client " - "could not act on one anyway: client callbacks run inline in the receive loop, so a " - "cancellation is not even read until the callback has finished." - ), - ), - deferred=( - "Not implemented in the SDK: abandoning a server-side send_request emits no cancellation " - "notification (the same sender-side gap recorded on protocol:timeout:sends-cancellation), and " - "the client could not act on one anyway because client callbacks run inline in the receive " - "loop, so a cancellation would not even be read until the callback had already finished." - ), ), "protocol:cancel:unknown-id-ignored": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#error-handling", @@ -466,13 +445,6 @@ def __post_init__(self) -> None: "When a request times out, the sender issues notifications/cancelled for that request before " "failing the local call." ), - divergence=Divergence( - note=( - "Client seat only: the client raises locally and sends nothing on timeout, so the server keeps " - "running the handler. The server seat conforms: a timed-out server-initiated request is followed " - "by notifications/cancelled on the wire." - ), - ), ), "protocol:timeout:session-survives": Requirement( source=f"{SPEC_BASE_URL}/basic/lifecycle#timeouts", diff --git a/tests/interaction/lowlevel/test_cancellation.py b/tests/interaction/lowlevel/test_cancellation.py index 6f1454e58..22a4c546b 100644 --- a/tests/interaction/lowlevel/test_cancellation.py +++ b/tests/interaction/lowlevel/test_cancellation.py @@ -11,7 +11,7 @@ from inline_snapshot import snapshot from mcp import MCPError, types -from mcp.client import ClientSession +from mcp.client import ClientRequestContext, ClientSession from mcp.server import Server, ServerRequestContext from mcp.shared.memory import MessageStream, create_client_server_memory_streams from mcp.shared.message import SessionMessage @@ -155,14 +155,70 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara assert result == snapshot(CallToolResult(content=[TextContent(text="unbothered")])) +@requirement("protocol:cancel:server-to-client") +async def test_abandoned_server_request_cancels_the_client_callback(connect: Connect) -> None: + """A server that abandons a sampling request cancels it, interrupting the client's callback. + + The handler gives up on its sampling request by cancelling the scope around it; the courtesy + notifications/cancelled that follows interrupts the client's sampling callback mid-await. + """ + callback_started = anyio.Event() + callback_cancelled = anyio.Event() + + async def sampling_callback( + context: ClientRequestContext, params: types.CreateMessageRequestParams + ) -> types.CreateMessageResult: + callback_started.set() + try: + await anyio.Event().wait() # blocks until the cancellation interrupts it + except anyio.get_cancelled_exc_class(): + callback_cancelled.set() + raise + raise NotImplementedError # unreachable + + async def list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="impatient", input_schema={"type": "object"})]) + + async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: + assert params.name == "impatient" + request = types.CreateMessageRequest( + params=types.CreateMessageRequestParams( + messages=[types.SamplingMessage(role="user", content=TextContent(text="Say hello."))], + max_tokens=8, + ) + ) + async with anyio.create_task_group() as abandon_scope: + + async def sample() -> None: + await ctx.session.send_request(request, types.CreateMessageResult) + raise NotImplementedError # unreachable: the scope is cancelled + + abandon_scope.start_soon(sample) + await callback_started.wait() + abandon_scope.cancel_scope.cancel() + with anyio.fail_after(5): + await callback_cancelled.wait() + return CallToolResult(content=[TextContent(text="abandoned")]) + + server = Server("abandoner", on_list_tools=list_tools, on_call_tool=call_tool) + + async with connect(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("impatient", {}) + + assert result == snapshot(CallToolResult(content=[TextContent(text="abandoned")])) + assert callback_cancelled.is_set() + + @requirement("protocol:cancel:late-response-ignored") -async def test_a_response_for_an_unknown_request_id_surfaces_to_the_message_handler() -> None: - """A response whose id matches no in-flight request is surfaced to the message handler as a RuntimeError. +async def test_a_response_for_an_unknown_request_id_is_ignored() -> None: + """A response whose id matches no in-flight request is ignored, as the spec asks. The spec says a sender SHOULD ignore a response that arrives after it issued a cancellation; that is the same client-side code path as any response with an unknown id, and that form is - deterministic to test without depending on the cancellation API the SDK does not yet provide. - See the divergence note on the requirement. + deterministic to test without depending on a client-side cancellation API. Nothing reaches + the message handler and the session keeps serving. A real Server cannot be made to answer with a fabricated id, so the test plays the server's side of the wire by hand. Reserve this pattern for behaviour no real server can produce. The @@ -228,7 +284,6 @@ async def message_handler(message: IncomingMessage) -> None: pong = await session.send_request(PingRequest(), EmptyResult) assert pong == snapshot(EmptyResult()) - assert len(incoming) == 1 - assert isinstance(incoming[0], RuntimeError) - # The full message embeds the response object's repr; only the prefix is stable. - assert str(incoming[0]).startswith("Received response with an unknown request ID:") + # The fabricated response was dropped silently: the ping after it still + # round-tripped, and nothing was surfaced to the message handler. + assert incoming == [] diff --git a/tests/interaction/lowlevel/test_logging.py b/tests/interaction/lowlevel/test_logging.py index fba632ef4..d945c8e76 100644 --- a/tests/interaction/lowlevel/test_logging.py +++ b/tests/interaction/lowlevel/test_logging.py @@ -1,12 +1,15 @@ """Logging interactions against the low-level Server, driven through the public Client API. Notification ordering: the in-memory transport delivers every server-to-client message on one -ordered stream, and the client's receive loop dispatches each incoming message to completion -before reading the next one. Over streamable HTTP that ordered single-stream guarantee holds -only for messages that carry a ``related_request_id`` (they ride the originating request's POST -stream); without it the message routes to the standalone GET stream and may arrive after the -response. These tests pass ``related_request_id`` so they can collect into a plain list and -assert after the request completes on every transport leg -- no events, no waiting. +ordered stream, and the client starts notification callbacks in arrival order. Callbacks run +concurrently with the rest of the session (no completion-before-response guarantee), but a +callback with no internal awaits runs to completion as soon as it starts, which keeps +plain-list collection deterministic here. Over streamable HTTP the ordered single-stream +guarantee holds only for messages that carry a ``related_request_id`` (they ride the +originating request's POST stream); without it the message routes to the standalone GET stream +and may arrive after the response. These tests pass ``related_request_id`` and use await-free +callbacks so they can collect into a plain list and assert after the request completes on +every transport leg -- no events, no waiting. """ import pytest diff --git a/tests/interaction/lowlevel/test_progress.py b/tests/interaction/lowlevel/test_progress.py index 6350c33a3..a89039b99 100644 --- a/tests/interaction/lowlevel/test_progress.py +++ b/tests/interaction/lowlevel/test_progress.py @@ -87,8 +87,8 @@ async def ignore(progress: float, total: float | None, message: str | None) -> N async with connect(server) as client: result = await client.call_tool("inspect", {}, progress_callback=ignore) - # The token is the request id of the tools/call request itself (initialize is request 0). - assert result == snapshot(CallToolResult(content=[TextContent(text="1")])) + # The token is the request id of the tools/call request itself (initialize is request 1). + assert result == snapshot(CallToolResult(content=[TextContent(text="2")])) @requirement("protocol:progress:no-token") diff --git a/tests/interaction/lowlevel/test_timeouts.py b/tests/interaction/lowlevel/test_timeouts.py index 7c12e1c65..62caf7e81 100644 --- a/tests/interaction/lowlevel/test_timeouts.py +++ b/tests/interaction/lowlevel/test_timeouts.py @@ -30,16 +30,21 @@ async def test_request_timeout_fails_the_pending_call() -> None: """A request whose response does not arrive within its read timeout fails with a timeout error. - No cancellation is sent to the server (see the divergence note on the requirement): the handler - starts and is still running after the caller has already given up. The test waits for the - handler to have started only after the timeout has fired, so the timeout itself races nothing. + The timeout is followed by notifications/cancelled on the wire, so the server's handler is + interrupted instead of running to completion. The test waits for the handler to have started + only after the timeout has fired, so the timeout itself races nothing. """ handler_started = anyio.Event() + handler_cancelled = anyio.Event() async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "block" handler_started.set() - await anyio.Event().wait() # blocks until the session is torn down + try: + await anyio.Event().wait() # blocks until the courtesy cancellation interrupts it + except anyio.get_cancelled_exc_class(): + handler_cancelled.set() + raise raise NotImplementedError # unreachable server = Server("blocker", on_call_tool=call_tool) @@ -48,14 +53,16 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara with pytest.raises(MCPError) as exc_info: await client.call_tool("block", {}, read_timeout_seconds=0.000001) - # The request was already on the wire: the handler still runs even though the caller gave up. + # The request was already on the wire, so the handler started; the courtesy + # cancellation that followed the timeout then interrupted it. with anyio.fail_after(5): await handler_started.wait() + await handler_cancelled.wait() assert exc_info.value.error == snapshot( ErrorData( code=REQUEST_TIMEOUT, - message="Timed out while waiting for response to CallToolRequest. Waited 1e-06 seconds.", + message="Request 'tools/call' timed out", ) ) @@ -183,6 +190,6 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara assert exc_info.value.error == snapshot( ErrorData( code=REQUEST_TIMEOUT, - message="Timed out while waiting for response to CallToolRequest. Waited 0.05 seconds.", + message="Request 'tools/call' timed out", ) ) diff --git a/tests/interaction/lowlevel/test_wire.py b/tests/interaction/lowlevel/test_wire.py index 0f9c58aa7..178c2c1c3 100644 --- a/tests/interaction/lowlevel/test_wire.py +++ b/tests/interaction/lowlevel/test_wire.py @@ -61,7 +61,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async def test_request_ids_are_unique_and_never_null() -> None: """Every request the client sends carries a distinct, non-null id. - The id sequence is pinned: sequential integers from zero, in send order. + The id sequence is pinned: sequential integers from one, in send order. """ recording = RecordingTransport(InMemoryTransport(_echo_server())) @@ -77,7 +77,7 @@ async def test_request_ids_are_unique_and_never_null() -> None: assert len(request_ids) == len(set(request_ids)) # initialize, tools/list, tools/call, tools/call, ping -- the client does not issue a # schema-cache refresh here because the explicit tools/list already populated the cache. - assert request_ids == snapshot([0, 1, 2, 3, 4]) + assert request_ids == snapshot([1, 2, 3, 4, 5]) @requirement("protocol:notifications:no-response") diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index 6b593d2a5..84c16430f 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -55,7 +55,9 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar assert params.name in ("slow", "fast"), f"Unknown tool: {params.name}" if params.name == "slow": - await slow_request_lock.wait() # it should timeout here + # The client's timeout fires while this waits; the courtesy + # cancellation then interrupts the wait. + await slow_request_lock.wait() text = f"slow {request_count}" else: text = f"fast {request_count}" @@ -95,9 +97,9 @@ async def client( # Use very small timeout to trigger quickly without waiting with pytest.raises(MCPError) as exc_info: await session.call_tool("slow", read_timeout_seconds=0.000001) # artificial timeout that always fails - assert "Timed out while waiting" in str(exc_info.value) + assert "timed out" in str(exc_info.value) - # release the slow request not to have hanging process + # No-op if the courtesy cancellation already interrupted the handler. slow_request_lock.set() # Third call should work (fast operation, no timeout), diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index 21352b5f2..60d30342c 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -1098,10 +1098,10 @@ async def logging_tool(msg: str, ctx: Context) -> str: assert "Logged messages for test" in content.text assert mock_log.call_count == 4 - mock_log.assert_any_call(level="debug", data="Debug message", logger=None, related_request_id="1") - mock_log.assert_any_call(level="info", data="Info message", logger=None, related_request_id="1") - mock_log.assert_any_call(level="warning", data="Warning message", logger=None, related_request_id="1") - mock_log.assert_any_call(level="error", data="Error message", logger=None, related_request_id="1") + mock_log.assert_any_call(level="debug", data="Debug message", logger=None, related_request_id="2") + mock_log.assert_any_call(level="info", data="Info message", logger=None, related_request_id="2") + mock_log.assert_any_call(level="warning", data="Warning message", logger=None, related_request_id="2") + mock_log.assert_any_call(level="error", data="Error message", logger=None, related_request_id="2") async def test_optional_context(self): """Test that context is optional.""" diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py deleted file mode 100644 index 38f36d82c..000000000 --- a/tests/shared/test_session.py +++ /dev/null @@ -1,447 +0,0 @@ -import anyio -import pytest - -from mcp import Client, types -from mcp.client.session import ClientSession -from mcp.server import Server, ServerRequestContext -from mcp.shared.exceptions import MCPError -from mcp.shared.memory import create_client_server_memory_streams -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder -from mcp.types import ( - METHOD_NOT_FOUND, - PARSE_ERROR, - CancelledNotification, - CancelledNotificationParams, - ClientResult, - EmptyResult, - ErrorData, - JSONRPCError, - JSONRPCNotification, - JSONRPCRequest, - JSONRPCResponse, - ServerNotification, - ServerRequest, -) - - -@pytest.mark.anyio -async def test_request_cancellation(): - """Test that requests can be cancelled while in-flight.""" - ev_tool_called = anyio.Event() - ev_cancelled = anyio.Event() - request_id = None - - # Create a server with a slow tool - async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: - nonlocal request_id, ev_tool_called - if params.name == "slow_tool": - request_id = ctx.request_id - ev_tool_called.set() - await anyio.sleep(10) # Long enough to ensure we can cancel - return types.CallToolResult(content=[]) # pragma: no cover - raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover - - async def handle_list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - raise NotImplementedError - - server = Server( - name="TestSessionServer", - on_call_tool=handle_call_tool, - on_list_tools=handle_list_tools, - ) - - async def make_request(client: Client): - nonlocal ev_cancelled - try: - await client.session.send_request( - types.CallToolRequest( - params=types.CallToolRequestParams(name="slow_tool", arguments={}), - ), - types.CallToolResult, - ) - pytest.fail("Request should have been cancelled") # pragma: no cover - except MCPError as e: - # Expected - request was cancelled - assert "Request cancelled" in str(e) - ev_cancelled.set() - - async with Client(server) as client: - async with anyio.create_task_group() as tg: # pragma: no branch - tg.start_soon(make_request, client) - - # Wait for the request to be in-flight - with anyio.fail_after(1): # Timeout after 1 second - await ev_tool_called.wait() - - # Send cancellation notification - assert request_id is not None - await client.session.send_notification( - CancelledNotification(params=CancelledNotificationParams(request_id=request_id)) - ) - - # Give cancellation time to process - with anyio.fail_after(1): # pragma: no branch - await ev_cancelled.wait() - - -@pytest.mark.anyio -async def test_response_id_type_mismatch_string_to_int(): - """Test that responses with string IDs are correctly matched to requests sent with - integer IDs. - - This handles the case where a server returns "id": "0" (string) but the client - sent "id": 0 (integer). Without ID type normalization, this would cause a timeout. - """ - ev_response_received = anyio.Event() - result_holder: list[types.EmptyResult] = [] - - async with create_client_server_memory_streams() as (client_streams, server_streams): - client_read, client_write = client_streams - server_read, server_write = server_streams - - async def mock_server(): - """Receive a request and respond with a string ID instead of integer.""" - message = await server_read.receive() - assert isinstance(message, SessionMessage) - root = message.message - assert isinstance(root, JSONRPCRequest) - # Get the original request ID (which is an integer) - request_id = root.id - assert isinstance(request_id, int), f"Expected int, got {type(request_id)}" - - # Respond with the ID as a string (simulating a buggy server) - response = JSONRPCResponse( - jsonrpc="2.0", - id=str(request_id), # Convert to string to simulate mismatch - result={}, - ) - await server_write.send(SessionMessage(message=response)) - - async def make_request(client_session: ClientSession): - nonlocal result_holder - # Send a ping request (uses integer ID internally) - result = await client_session.send_ping() - result_holder.append(result) - ev_response_received.set() - - async with ( - anyio.create_task_group() as tg, - ClientSession(read_stream=client_read, write_stream=client_write) as client_session, - ): - tg.start_soon(mock_server) - tg.start_soon(make_request, client_session) - - with anyio.fail_after(2): # pragma: no branch - await ev_response_received.wait() - - assert len(result_holder) == 1 - assert isinstance(result_holder[0], EmptyResult) - - -@pytest.mark.anyio -async def test_error_response_id_type_mismatch_string_to_int(): - """Test that error responses with string IDs are correctly matched to requests - sent with integer IDs. - - This handles the case where a server returns an error with "id": "0" (string) - but the client sent "id": 0 (integer). - """ - ev_error_received = anyio.Event() - error_holder: list[MCPError | Exception] = [] - - async with create_client_server_memory_streams() as (client_streams, server_streams): - client_read, client_write = client_streams - server_read, server_write = server_streams - - async def mock_server(): - """Receive a request and respond with an error using a string ID.""" - message = await server_read.receive() - assert isinstance(message, SessionMessage) - root = message.message - assert isinstance(root, JSONRPCRequest) - request_id = root.id - assert isinstance(request_id, int) - - # Respond with an error, using the ID as a string - error_response = JSONRPCError( - jsonrpc="2.0", - id=str(request_id), # Convert to string to simulate mismatch - error=ErrorData(code=-32600, message="Test error"), - ) - await server_write.send(SessionMessage(message=error_response)) - - async def make_request(client_session: ClientSession): - nonlocal error_holder - try: - await client_session.send_ping() - pytest.fail("Expected MCPError to be raised") # pragma: no cover - except MCPError as e: - error_holder.append(e) - ev_error_received.set() - - async with ( - anyio.create_task_group() as tg, - ClientSession(read_stream=client_read, write_stream=client_write) as client_session, - ): - tg.start_soon(mock_server) - tg.start_soon(make_request, client_session) - - with anyio.fail_after(2): # pragma: no branch - await ev_error_received.wait() - - assert len(error_holder) == 1 - assert "Test error" in str(error_holder[0]) - - -@pytest.mark.anyio -async def test_response_id_non_numeric_string_no_match(): - """Test that responses with non-numeric string IDs don't incorrectly match - integer request IDs. - - If a server returns "id": "abc" (non-numeric string), it should not match - a request sent with "id": 0 (integer). - """ - ev_timeout = anyio.Event() - - async with create_client_server_memory_streams() as (client_streams, server_streams): - client_read, client_write = client_streams - server_read, server_write = server_streams - - async def mock_server(): - """Receive a request and respond with a non-numeric string ID.""" - message = await server_read.receive() - assert isinstance(message, SessionMessage) - - # Respond with a non-numeric string ID (should not match) - response = JSONRPCResponse( - jsonrpc="2.0", - id="not_a_number", # Non-numeric string - result={}, - ) - await server_write.send(SessionMessage(message=response)) - - async def make_request(client_session: ClientSession): - try: - # Use a short timeout since we expect this to fail - await client_session.send_request( - types.PingRequest(), - types.EmptyResult, - request_read_timeout_seconds=0.5, - ) - pytest.fail("Expected timeout") # pragma: no cover - except MCPError as e: - assert "Timed out" in str(e) - ev_timeout.set() - - async with ( - anyio.create_task_group() as tg, - ClientSession(read_stream=client_read, write_stream=client_write) as client_session, - ): - tg.start_soon(mock_server) - tg.start_soon(make_request, client_session) - - with anyio.fail_after(2): # pragma: no branch - await ev_timeout.wait() - - -@pytest.mark.anyio -async def test_connection_closed(): - """Test that pending requests are cancelled when the connection is closed remotely.""" - - ev_closed = anyio.Event() - ev_response = anyio.Event() - - async with create_client_server_memory_streams() as (client_streams, server_streams): - client_read, client_write = client_streams - server_read, server_write = server_streams - - async def make_request(client_session: ClientSession): - """Send a request in a separate task""" - nonlocal ev_response - try: - # any request will do - await client_session.initialize() - pytest.fail("Request should have errored") # pragma: no cover - except MCPError as e: - # Expected - request errored - assert "Connection closed" in str(e) - ev_response.set() - - async def mock_server(): - """Wait for a request, then close the connection""" - nonlocal ev_closed - # Wait for a request - await server_read.receive() - # Close the connection, as if the server exited - server_write.close() - server_read.close() - ev_closed.set() - - async with ( - anyio.create_task_group() as tg, - ClientSession(read_stream=client_read, write_stream=client_write) as client_session, - ): - tg.start_soon(make_request, client_session) - tg.start_soon(mock_server) - - with anyio.fail_after(1): - await ev_closed.wait() - with anyio.fail_after(1): # pragma: no branch - await ev_response.wait() - - -@pytest.mark.anyio -async def test_null_id_error_surfaced_via_message_handler(): - """Test that a JSONRPCError with id=None is surfaced to the message handler. - - Per JSON-RPC 2.0, error responses use id=null when the request id could not - be determined (e.g., parse errors). These cannot be correlated to any pending - request, so they are forwarded to the message handler as MCPError. - """ - ev_error_received = anyio.Event() - error_holder: list[MCPError] = [] - - async def capture_errors( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: - assert isinstance(message, MCPError) - error_holder.append(message) - ev_error_received.set() - - sent_error = ErrorData(code=PARSE_ERROR, message="Parse error") - - async with create_client_server_memory_streams() as (client_streams, server_streams): - client_read, client_write = client_streams - _server_read, server_write = server_streams - - async def mock_server(): - """Send a null-id error (simulating a parse error).""" - error_response = JSONRPCError(jsonrpc="2.0", id=None, error=sent_error) - await server_write.send(SessionMessage(message=error_response)) - - async with ( - anyio.create_task_group() as tg, - ClientSession( - read_stream=client_read, - write_stream=client_write, - message_handler=capture_errors, - ) as _client_session, - ): - tg.start_soon(mock_server) - - with anyio.fail_after(2): # pragma: no branch - await ev_error_received.wait() - - assert len(error_holder) == 1 - assert error_holder[0].error == sent_error - - -@pytest.mark.anyio -async def test_null_id_error_does_not_affect_pending_request(): - """Test that a null-id error doesn't interfere with an in-flight request. - - When a null-id error arrives while a request is pending, the error should - go to the message handler and the pending request should still complete - normally with its own response. - """ - ev_error_received = anyio.Event() - ev_response_received = anyio.Event() - error_holder: list[MCPError] = [] - result_holder: list[EmptyResult] = [] - - async def capture_errors( - message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception, - ) -> None: - assert isinstance(message, MCPError) - error_holder.append(message) - ev_error_received.set() - - sent_error = ErrorData(code=PARSE_ERROR, message="Parse error") - - async with create_client_server_memory_streams() as (client_streams, server_streams): - client_read, client_write = client_streams - server_read, server_write = server_streams - - async def mock_server(): - """Read a request, inject a null-id error, then respond normally.""" - message = await server_read.receive() - assert isinstance(message, SessionMessage) - assert isinstance(message.message, JSONRPCRequest) - request_id = message.message.id - - # First, send a null-id error (should go to message handler) - await server_write.send(SessionMessage(message=JSONRPCError(jsonrpc="2.0", id=None, error=sent_error))) - - # Then, respond normally to the pending request - await server_write.send(SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=request_id, result={}))) - - async def make_request(client_session: ClientSession): - result = await client_session.send_ping() - result_holder.append(result) - ev_response_received.set() - - async with ( - anyio.create_task_group() as tg, - ClientSession( - read_stream=client_read, - write_stream=client_write, - message_handler=capture_errors, - ) as client_session, - ): - tg.start_soon(mock_server) - tg.start_soon(make_request, client_session) - - with anyio.fail_after(2): # pragma: no branch - await ev_error_received.wait() - await ev_response_received.wait() - - # Null-id error reached the message handler - assert len(error_holder) == 1 - assert error_holder[0].error == sent_error - - # Pending request completed successfully - assert len(result_holder) == 1 - assert isinstance(result_holder[0], EmptyResult) - - -@pytest.mark.anyio -async def test_receive_loop_answers_unknown_request_method_with_method_not_found(): - """A peer request whose method is not in the receive union gets -32601 - (METHOD_NOT_FOUND) on the wire, not a validation failure (-32602).""" - async with create_client_server_memory_streams() as (client_streams, server_streams): - client_read, client_write = client_streams - server_read, server_write = server_streams - - async with ClientSession(read_stream=client_read, write_stream=client_write): - await server_write.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="x/unknown"))) - with anyio.fail_after(5): # pragma: no branch - out = await server_read.receive() - - assert isinstance(out, SessionMessage) - assert isinstance(out.message, JSONRPCError) - assert out.message.id == 7 - assert out.message.error == ErrorData(code=METHOD_NOT_FOUND, message="Method not found", data="x/unknown") - - -@pytest.mark.anyio -async def test_receive_loop_drops_unknown_notification_method_without_response(): - """An unknown notification method is dropped silently: JSON-RPC forbids - responses to notifications, and the receive loop keeps serving.""" - async with create_client_server_memory_streams() as (client_streams, server_streams): - client_read, client_write = client_streams - server_read, server_write = server_streams - - async with ClientSession(read_stream=client_read, write_stream=client_write): - await server_write.send(SessionMessage(message=JSONRPCNotification(jsonrpc="2.0", method="x/unknown"))) - # The next wire output must be the answer to this follow-up ping, - # proving the notification produced no response and the loop survived. - await server_write.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"))) - with anyio.fail_after(5): # pragma: no branch - out = await server_read.receive() - - assert isinstance(out, SessionMessage) - assert isinstance(out.message, JSONRPCResponse) - assert out.message.id == 1 From ad471bed43a8ccfdbb7ca94ee6983b103d260f52 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 15:57:42 +0000 Subject: [PATCH 06/24] Cover the remaining client-session branches and document the migration Adds tests for ServerMessageMetadata routing, related-request-id notifications, and params-absent inbound requests over direct dispatch, plus the migration-guide entry for the ClientSession dispatcher swap. --- docs/migration.md | 18 +++++++++++++++- tests/client/test_session.py | 41 ++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/docs/migration.md b/docs/migration.md index 850e05255..d690d6a1b 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -1164,7 +1164,23 @@ In practice, replace direct `ServerSession` use with `Server.run(read_stream, wr `BaseSession._in_flight` and the `RequestResponder` members that supported it (`cancel()`, the `cancelled` and `in_flight` properties, the `on_complete` constructor argument, and the internal `CancelScope`) have been removed. These existed to let `ServerSession` cancel a handler when a `CancelledNotification` arrived; `ServerSession` no longer drives a receive loop, so they were dead code. Inbound-cancellation handling for the server now lives in `JSONRPCDispatcher`. -`BaseSession` is still used by `ClientSession`, which never relied on these members. `RequestResponder.respond()` is unchanged. +`BaseSession` itself has since been removed entirely; see the next section. + +### `ClientSession` now runs on `JSONRPCDispatcher`; `BaseSession` removed + +`ClientSession` keeps its public surface — the `(read_stream, write_stream, ...)` constructor, every typed method, manual `initialize()`, and the async context-manager lifecycle — but the v1 receive loop (`BaseSession`) underneath it is gone. A new `ClientSession.from_dispatcher(dispatcher, ...)` constructor accepts a pre-built dispatcher (for example a `DirectDispatcher` for in-process embedding). + +Behavior changes: + +- **Request ids count from 1** (previously 0). Progress tokens, which reuse the request id, shift the same way. Ids are opaque per JSON-RPC; do not assign meaning to them. +- **Timeouts**: the error message is now `Request 'tools/call' timed out` (previously `Timed out while waiting for response to CallToolRequest. Waited N seconds.`), and a timed-out or abandoned request is followed by `notifications/cancelled` on the wire, so the server stops the handler instead of leaving it running. The `initialize` request is never cancelled this way, and requests sent with resumption metadata are also exempt so they stay resumable. +- **Server-initiated requests run concurrently.** Sampling/elicitation/roots callbacks no longer serialize the receive loop: a slow callback does not block other traffic, a callback may itself send requests without deadlocking, and a server's `notifications/cancelled` now actually interrupts the callback (the request is then answered with an error response). +- **Notification callbacks are concurrent.** `logging_callback` and `message_handler` start in arrival order, but there is no completion-before-response guarantee (matching the TypeScript, C#, and Go SDKs). Callbacks that need strict sequencing must coordinate themselves. +- **Unknown-id responses are ignored**, as the spec asks. v1 surfaced them to `message_handler` as a `RuntimeError`; nothing is surfaced now. +- **A raising request callback** is answered with `code=0` and the exception text. v1 flattened every callback exception to `INVALID_PARAMS`. Callbacks that want a specific error response should return `ErrorData` (unchanged) or raise `MCPError`. +- **`send_request` before entering the context manager** raises `RuntimeError` immediately; v1 wrote to the transport and hung until the timeout. `send_notification` before entry still works. + +`mcp.shared.session` is now a compatibility module: `ProgressFnT` is re-exported (its home is `mcp.shared.dispatcher`), and `RequestResponder` remains as a typing-only stub so `MessageHandlerFnT` annotations keep importing — it has been unreachable at runtime since the server-side swap. `RequestResponder.respond()` no longer exists. ### Experimental Tasks support removed diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 9b4f20d83..bc9d7e585 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -927,9 +927,50 @@ async def server_on_notify( await tg.start(server_side.run, server_on_request, server_on_notify) async with session: results.append(await session.send_ping(meta=None)) + # Server-to-client direction: direct dispatch delivers ping with no + # params member at all (no _meta injection outside JSON-RPC). + assert await server_side.send_raw_request("ping", None) == {} # related_request_id routing is JSON-RPC plumbing; on other # dispatchers the notification is sent without it. await session.send_notification(types.RootsListChangedNotification(), related_request_id=7) server_side.close() assert results == [types.EmptyResult()] assert notified == ["notifications/roots/list_changed"] + + +@pytest.mark.anyio +async def test_send_request_with_server_metadata_routes_related_request_id(): + """ServerMessageMetadata.related_request_id is threaded onto the outgoing message.""" + from mcp.shared.message import ServerMessageMetadata + + async with raw_client_session() as (session, to_client, from_client): + async with anyio.create_task_group() as tg: + + async def call() -> None: + await session.send_request( + types.PingRequest(), types.EmptyResult, metadata=ServerMessageMetadata(related_request_id=3) + ) + + tg.start_soon(call) + out = await from_client.receive() + assert isinstance(out.metadata, ServerMessageMetadata) + assert out.metadata.related_request_id == 3 + assert isinstance(out.message, JSONRPCRequest) + await to_client.send(SessionMessage(JSONRPCResponse(jsonrpc="2.0", id=out.message.id, result={}))) + + +@pytest.mark.anyio +async def test_send_notification_with_related_request_id_attaches_metadata(): + """A related_request_id on a notification rides the originating request's stream.""" + from mcp.shared.message import ServerMessageMetadata + + async with raw_client_session() as (session, _to_client, from_client): + await session.send_notification( + types.ProgressNotification( + params=types.ProgressNotificationParams(progress_token=1, progress=0.5), + ), + related_request_id=4, + ) + out = await from_client.receive() + assert isinstance(out.metadata, ServerMessageMetadata) + assert out.metadata.related_request_id == 4 From 2d4c5cd06fef2defe475841b7b2e2d0a3e015fc6 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 15:57:52 +0000 Subject: [PATCH 07/24] Stop logging an error when the standalone SSE stream closes mid-listen Transport teardown closes the standalone stream's send side first, so a writer parked in receive() ends on a clean end-of-stream; but when teardown lands while the writer is between dequeues, the next receive() raises ClosedResourceError, which fell into the catch-all and logged a traceback at ERROR level for a routine disconnect. Catch it and end quietly. A new test pins the close ordering that keeps the parked path clean. --- src/mcp/server/streamable_http.py | 6 ++++ tests/shared/test_streamable_http.py | 42 ++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 220d46f9a..f269fc6c4 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -717,6 +717,12 @@ async def standalone_sse_writer(): # Send the message via SSE event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) + except anyio.ClosedResourceError: # pragma: lax no cover + # Teardown completed while the writer was between dequeues: + # the next receive() hits the closed stream. A writer parked + # in receive() instead sees a clean end-of-stream (cleanup + # closes the send side first), so this arm is timing-dependent. + pass except Exception: logger.exception("Error in standalone SSE writer") # pragma: no cover finally: diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 7db7e68fb..9bfe6d2d0 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -28,6 +28,7 @@ from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http import ( + GET_STREAM_KEY, MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, SESSION_ID_PATTERN, @@ -2224,3 +2225,44 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers(context_ assert "content-type" in headers_data assert headers_data["content-type"] == "application/json" + + +@pytest.mark.anyio +async def test_standalone_stream_teardown_mid_listen_is_not_an_error(caplog: pytest.LogCaptureFixture) -> None: + """Tearing down the standalone stream under its parked writer produces no error log. + + Cleanup closes the send side first, so a writer parked in receive() ends on a clean + end-of-stream. This pins that close ordering: reversing it would wake the parked writer + with ClosedResourceError on every disconnect. (The timing window where teardown lands + between dequeues is handled by the writer's ClosedResourceError arm, which cannot be + forced deterministically from the public surface.) + """ + session_manager = StreamableHTTPSessionManager( + app=_create_server(), + security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False), + ) + app = Starlette(routes=[Mount("/mcp", app=session_manager.handle_request)]) + notified = anyio.Event() + + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, + ) -> None: + if isinstance(message, types.ResourceUpdatedNotification): + notified.set() + + async with session_manager.run(): + async with ( + make_client(app) as http_client, + streamable_http_client(f"{BASE_URL}/mcp", http_client=http_client) as (read_stream, write_stream), + ClientSession(read_stream, write_stream, message_handler=message_handler) as session, + ): + await session.initialize() + # Prove the standalone GET writer is live: a notification with no + # related request rides the GET stream to the client. + await session.call_tool("test_tool_with_standalone_notification", {}) + with anyio.fail_after(5): + await notified.wait() + # Tear the standalone stream down while the writer is parked on it. + (transport,) = session_manager._server_instances.values() # pyright: ignore[reportPrivateUsage] + await transport._clean_up_memory_streams(GET_STREAM_KEY) # pyright: ignore[reportPrivateUsage] + assert "Error in standalone SSE writer" not in caplog.text From 36a091d863ace10f7785e52c790a71756c1c8295 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 17:11:24 +0000 Subject: [PATCH 08/24] Accept a pre-built dispatcher via a constructor keyword Replaces the from_dispatcher classmethod: read_stream/write_stream become optional and dispatcher is a keyword-only alternative, with mutual exclusion validated at construction. Drops the __new__-based alternate constructor and its shared state-init helper. --- docs/migration.md | 2 +- src/mcp/client/session.py | 100 ++++++++++------------------------- tests/client/test_session.py | 27 ++++++++-- 3 files changed, 53 insertions(+), 76 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index d690d6a1b..1b09b27c7 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -1168,7 +1168,7 @@ In practice, replace direct `ServerSession` use with `Server.run(read_stream, wr ### `ClientSession` now runs on `JSONRPCDispatcher`; `BaseSession` removed -`ClientSession` keeps its public surface — the `(read_stream, write_stream, ...)` constructor, every typed method, manual `initialize()`, and the async context-manager lifecycle — but the v1 receive loop (`BaseSession`) underneath it is gone. A new `ClientSession.from_dispatcher(dispatcher, ...)` constructor accepts a pre-built dispatcher (for example a `DirectDispatcher` for in-process embedding). +`ClientSession` keeps its public surface — the `(read_stream, write_stream, ...)` constructor, every typed method, manual `initialize()`, and the async context-manager lifecycle — but the v1 receive loop (`BaseSession`) underneath it is gone. A new keyword-only `dispatcher=` constructor argument accepts a pre-built dispatcher instead of the stream pair (for example a `DirectDispatcher` for in-process embedding). Behavior changes: diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 9975018fc..d79afa7e1 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -117,19 +117,25 @@ async def _default_logging_callback( class ClientSession: - """Client half of an MCP connection, running on `JSONRPCDispatcher`. - - Construct it over a transport's stream pair, enter it as an async context - manager, then call `initialize()`. The receive loop, request correlation, - and per-request concurrency live in the dispatcher; this class owns the - MCP type layer: typed requests, the initialize handshake, and routing - server-initiated traffic to the constructor callbacks. + """Client half of an MCP connection, running on a `Dispatcher`. + + Construct it over a transport's stream pair (or pass a pre-built + `dispatcher=` instead, e.g. a `DirectDispatcher` for in-process + embedding), enter it as an async context manager, then call + `initialize()`. The receive loop, request correlation, and per-request + concurrency live in the dispatcher; this class owns the MCP type layer: + typed requests, the initialize handshake, and routing server-initiated + traffic to the constructor callbacks. + + Transport-level `Exception` items reach `message_handler` only when the + session builds its own dispatcher from streams, where it wires the + dispatcher's `on_stream_exception` itself. """ def __init__( self, - read_stream: ReadStream[SessionMessage | Exception], - write_stream: WriteStream[SessionMessage], + read_stream: ReadStream[SessionMessage | Exception] | None = None, + write_stream: WriteStream[SessionMessage] | None = None, read_timeout_seconds: float | None = None, sampling_callback: SamplingFnT | None = None, elicitation_callback: ElicitationFnT | None = None, @@ -139,69 +145,7 @@ def __init__( client_info: types.Implementation | None = None, *, sampling_capabilities: types.SamplingCapability | None = None, - ) -> None: - self._init_state( - read_timeout_seconds=read_timeout_seconds, - sampling_callback=sampling_callback, - elicitation_callback=elicitation_callback, - list_roots_callback=list_roots_callback, - logging_callback=logging_callback, - message_handler=message_handler, - client_info=client_info, - sampling_capabilities=sampling_capabilities, - ) - # Built here (inert until run() starts in __aenter__) so notifications - # can be sent before entering the context manager, as before. - self._dispatcher: Dispatcher[Any] = JSONRPCDispatcher( - read_stream, write_stream, on_stream_exception=self._on_stream_exception - ) - - @classmethod - def from_dispatcher( - cls, - dispatcher: Dispatcher[Any], - *, - read_timeout_seconds: float | None = None, - sampling_callback: SamplingFnT | None = None, - elicitation_callback: ElicitationFnT | None = None, - list_roots_callback: ListRootsFnT | None = None, - logging_callback: LoggingFnT | None = None, - message_handler: MessageHandlerFnT | None = None, - client_info: types.Implementation | None = None, - sampling_capabilities: types.SamplingCapability | None = None, - ) -> Self: - """Build a session over a pre-built dispatcher instead of a stream pair. - - For embedding a server in-process (`DirectDispatcher`) or transports - that construct their own dispatcher. Transport-level `Exception` items - reach `message_handler` only on the stream constructor, where the - session wires the dispatcher's `on_stream_exception` itself. - """ - self = cls.__new__(cls) - self._init_state( - read_timeout_seconds=read_timeout_seconds, - sampling_callback=sampling_callback, - elicitation_callback=elicitation_callback, - list_roots_callback=list_roots_callback, - logging_callback=logging_callback, - message_handler=message_handler, - client_info=client_info, - sampling_capabilities=sampling_capabilities, - ) - self._dispatcher = dispatcher - return self - - def _init_state( - self, - *, - read_timeout_seconds: float | None, - sampling_callback: SamplingFnT | None, - elicitation_callback: ElicitationFnT | None, - list_roots_callback: ListRootsFnT | None, - logging_callback: LoggingFnT | None, - message_handler: MessageHandlerFnT | None, - client_info: types.Implementation | None, - sampling_capabilities: types.SamplingCapability | None, + dispatcher: Dispatcher[Any] | None = None, ) -> None: self._session_read_timeout_seconds = read_timeout_seconds self._client_info = client_info or DEFAULT_CLIENT_INFO @@ -214,6 +158,18 @@ def _init_state( self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} self._initialize_result: types.InitializeResult | None = None self._task_group: anyio.abc.TaskGroup | None = None + if dispatcher is not None: + if read_stream is not None or write_stream is not None: + raise ValueError("pass read_stream/write_stream or dispatcher, not both") + self._dispatcher: Dispatcher[Any] = dispatcher + else: + if read_stream is None or write_stream is None: + raise ValueError("read_stream and write_stream are required when no dispatcher is given") + # Built here (inert until run() starts in __aenter__) so notifications + # can be sent before entering the context manager, as before. + self._dispatcher = JSONRPCDispatcher( + read_stream, write_stream, on_stream_exception=self._on_stream_exception + ) async def __aenter__(self) -> Self: self._task_group = anyio.create_task_group() diff --git a/tests/client/test_session.py b/tests/client/test_session.py index bc9d7e585..aaaa20375 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -900,8 +900,8 @@ async def call() -> None: @pytest.mark.anyio -async def test_from_dispatcher_runs_over_direct_dispatch(): - """A session built with from_dispatcher works without a stream pair (in-process embedding).""" +async def test_dispatcher_keyword_runs_over_direct_dispatch(): + """A session built with dispatcher= works without a stream pair (in-process embedding).""" from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair from mcp.shared.dispatcher import DispatchContext from mcp.shared.transport_context import TransportContext @@ -921,7 +921,7 @@ async def server_on_notify( ) -> None: notified.append(method) - session = ClientSession.from_dispatcher(client_side) + session = ClientSession(dispatcher=client_side) results: list[types.EmptyResult] = [] async with anyio.create_task_group() as tg: await tg.start(server_side.run, server_on_request, server_on_notify) @@ -938,6 +938,27 @@ async def server_on_notify( assert notified == ["notifications/roots/list_changed"] +def test_constructor_rejects_streams_and_dispatcher_together(): + from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair + + client_side, _server_side = create_direct_dispatcher_pair() + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + with pytest.raises(ValueError, match="not both"): + ClientSession(s2c_recv, dispatcher=client_side) + s2c_send.close() + s2c_recv.close() + + +def test_constructor_requires_both_streams_without_dispatcher(): + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + with pytest.raises(ValueError, match="read_stream and write_stream are required"): + ClientSession(s2c_recv) + with pytest.raises(ValueError, match="read_stream and write_stream are required"): + ClientSession() + s2c_send.close() + s2c_recv.close() + + @pytest.mark.anyio async def test_send_request_with_server_metadata_routes_related_request_id(): """ServerMessageMetadata.related_request_id is threaded onto the outgoing message.""" From dbd1693994f1aa5218ab69fd7fb40ab977cca24b Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 17:13:11 +0000 Subject: [PATCH 09/24] Restore full coverage after the BaseSession deletion Covers MCPError.from_jsonrpc_error and the context-stream sync close() methods, whose only exercisers died with BaseSession and its tests, and restructures three test handler arms that could never take their false branch. --- tests/client/test_session.py | 5 +++-- .../interaction/lowlevel/test_cancellation.py | 7 ++----- tests/shared/test_context_streams.py | 20 +++++++++++++++++++ tests/shared/test_exceptions.py | 13 +++++++++++- tests/shared/test_streamable_http.py | 5 +++-- 5 files changed, 40 insertions(+), 10 deletions(-) create mode 100644 tests/shared/test_context_streams.py diff --git a/tests/client/test_session.py b/tests/client/test_session.py index aaaa20375..62f9eaa6f 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -875,8 +875,9 @@ async def boom(progress: float, total: float | None, message: str | None) -> Non raise RuntimeError("progress boom") async def handler(msg: object) -> None: - if isinstance(msg, types.ProgressNotification): - delivered.set() + # Only the progress notification is teed to the message handler here. + assert isinstance(msg, types.ProgressNotification) + delivered.set() async with raw_client_session(message_handler=handler) as (session, to_client, from_client): async with anyio.create_task_group() as tg: diff --git a/tests/interaction/lowlevel/test_cancellation.py b/tests/interaction/lowlevel/test_cancellation.py index 22a4c546b..60ca80b41 100644 --- a/tests/interaction/lowlevel/test_cancellation.py +++ b/tests/interaction/lowlevel/test_cancellation.py @@ -268,10 +268,8 @@ def respond(request_id: types.RequestId, result: types.Result) -> SessionMessage await server_write.send(respond(9999, EmptyResult())) await server_write.send(respond(ping.message.id, EmptyResult())) - incoming: list[IncomingMessage] = [] - async def message_handler(message: IncomingMessage) -> None: - incoming.append(message) + raise NotImplementedError # unreachable: nothing is surfaced for an unknown-id response async with ( create_client_server_memory_streams() as ((client_read, client_write), server_streams), @@ -285,5 +283,4 @@ async def message_handler(message: IncomingMessage) -> None: assert pong == snapshot(EmptyResult()) # The fabricated response was dropped silently: the ping after it still - # round-tripped, and nothing was surfaced to the message handler. - assert incoming == [] + # round-tripped, and the message handler (a tripwire) was never invoked. diff --git a/tests/shared/test_context_streams.py b/tests/shared/test_context_streams.py new file mode 100644 index 000000000..b03589230 --- /dev/null +++ b/tests/shared/test_context_streams.py @@ -0,0 +1,20 @@ +"""Tests for the contextvars-carrying memory-stream wrappers.""" + +import anyio +import pytest + +from mcp.shared._context_streams import create_context_streams + +pytestmark = pytest.mark.anyio + + +async def test_sync_close_closes_the_underlying_streams() -> None: + """The wrappers mirror anyio's memory streams: close() is the sync form of aclose().""" + send, receive = create_context_streams[str](1) + await send.send("queued") + send.close() + receive.close() + with pytest.raises(anyio.ClosedResourceError): + await send.send("after close") + with pytest.raises(anyio.ClosedResourceError): + await receive.receive() diff --git a/tests/shared/test_exceptions.py b/tests/shared/test_exceptions.py index 9a7466264..c6b575092 100644 --- a/tests/shared/test_exceptions.py +++ b/tests/shared/test_exceptions.py @@ -3,7 +3,7 @@ import pytest from mcp.shared.exceptions import MCPError, UrlElicitationRequiredError -from mcp.types import URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData +from mcp.types import URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData, JSONRPCError def test_url_elicitation_required_error_create_with_single_elicitation() -> None: @@ -162,3 +162,14 @@ def test_url_elicitation_required_error_exception_message() -> None: # The exception's string representation should match the message assert str(error) == "URL elicitation required" + + +def test_from_jsonrpc_error_preserves_code_message_and_data() -> None: + """Building an MCPError from a wire JSONRPCError keeps every error field.""" + wire = JSONRPCError( + jsonrpc="2.0", + id=3, + error=ErrorData(code=URL_ELICITATION_REQUIRED, message="go elsewhere", data={"hint": "y"}), + ) + error = MCPError.from_jsonrpc_error(wire) + assert error.error == ErrorData(code=URL_ELICITATION_REQUIRED, message="go elsewhere", data={"hint": "y"}) diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 9bfe6d2d0..c23ade097 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -2247,8 +2247,9 @@ async def test_standalone_stream_teardown_mid_listen_is_not_an_error(caplog: pyt async def message_handler( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: - if isinstance(message, types.ResourceUpdatedNotification): - notified.set() + # Only the standalone-stream notification is teed to the handler here. + assert isinstance(message, types.ResourceUpdatedNotification) + notified.set() async with session_manager.run(): async with ( From c4720d315b7acb3e4b4fd5616dcc808445d96d68 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 22:29:55 +0000 Subject: [PATCH 10/24] Tighten dispatcher abandon-path writes and response invariants - Bound the timeout-path courtesy cancel like the other abandon writes and extract a single _final_write policy: 5s for courtesy cancels, 1s for shutdown responses so closing a session stays fast on a wedged transport - Answer shutdown-interrupted requests with CONNECTION_CLOSED and retire the REQUEST_CANCELLED constant (-32002 collides with resource-not-found) - Key the peer-cancel error response on cancelled_caught so a cancel landing after the handler finished cannot produce a second answer for the same id - Decide outbound metadata and cancel-on-abandon suppression in one place: only resumption hints that actually reach the transport suppress the courtesy cancel - Never send notifications/cancelled for a request whose write never completed - Identity-guard the in-flight pop so a finished handler cannot evict a newer entry that reused its request id - Map request-write stream failures to MCPError(CONNECTION_CLOSED); warn when a bounded final write gives up; mark the Dispatcher lifecycle provisional --- src/mcp/shared/dispatcher.py | 11 +- src/mcp/shared/jsonrpc_dispatcher.py | 243 +++++++--- src/mcp/types/__init__.py | 2 - src/mcp/types/jsonrpc.py | 1 - tests/shared/test_jsonrpc_dispatcher.py | 621 +++++++++++++++++++++++- 5 files changed, 784 insertions(+), 94 deletions(-) diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index c6e421651..820422e6d 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -61,8 +61,11 @@ class CallOptions(TypedDict, total=False): A request is abandoned when its `timeout` elapses or the caller's scope is cancelled while awaiting the response. Defaults to `True`. Set `False` for requests the protocol forbids cancelling, such as `initialize`. The - notification is also suppressed when resumption hints are present: the - caller intends to resume the request, so the peer's work must keep running. + notification is also suppressed when resumption hints actually reach the + transport (the caller intends to resume the request, so the peer's work + must keep running); hints ignored in favor of dispatch-context routing do + not suppress it. No notification is sent for a request that was never + written to the transport. """ on_progress: ProgressFnT @@ -197,6 +200,10 @@ class Dispatcher(Outbound, Protocol[TransportT_co]): Implementations own correlation of outbound requests to inbound results, the receive loop, per-request concurrency, and cancellation/progress wiring. + + The protocol's lifecycle surface is provisional and expected to change + before v2 stable (`run()` may be superseded by an `open()`/`wait_closed()` + pair). """ async def run( diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 63c431950..ef09b870d 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -24,6 +24,7 @@ import logging from collections.abc import Awaitable, Callable, Mapping from dataclasses import dataclass, field +from functools import partial from typing import Any, Generic, Literal, cast import anyio @@ -48,7 +49,6 @@ CONNECTION_CLOSED, INTERNAL_ERROR, INVALID_PARAMS, - REQUEST_CANCELLED, REQUEST_TIMEOUT, ErrorData, JSONRPCError, @@ -65,11 +65,22 @@ logger = logging.getLogger(__name__) _SHIELDED_WRITE_TIMEOUT: float = 5 -"""Bound for the shielded courtesy writes on the cancellation paths. +"""Bound for the courtesy writes on the timeout and cancellation paths. -Those writes run inside a shield because the surrounding scope is already -cancelled; without a bound, a wedged transport write would turn the shield -into an uncancellable hang (and block shutdown indefinitely).""" +The cancellation-path writes run inside a shield because the surrounding +scope is already cancelled; without a bound, a wedged transport write would +turn the shield into an uncancellable hang (and block shutdown indefinitely). +The timeout-path courtesy cancel is unshielded (its scope is not cancelled) +but shares the bound so a wedged transport can't delay the timeout error +indefinitely.""" + +_SHUTDOWN_WRITE_TIMEOUT: float = 1 +"""Bound for the shutdown-arm error response write in `_handle_request`. + +Tighter than `_SHIELDED_WRITE_TIMEOUT` because session close must be quick: +the write is a courtesy answer to a request the shutdown is abandoning, so a +wedged transport may delay close by at most ~1s rather than holding teardown +for the full courtesy-cancel bound.""" TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext) @@ -85,8 +96,9 @@ def _coerce_id(request_id: RequestId) -> RequestId: """Coerce a string request ID to int when it's a valid int literal. `_allocate_id` only ever produces `int` keys for `_pending`, but a peer - may echo the ID back as a JSON string. The TypeScript SDK and `BaseSession` - both perform this coercion at lookup time so the response still correlates. + may echo the ID back as a JSON string. The TypeScript SDK performs this + coercion at lookup time (as v1's `BaseSession` did) so the response still + correlates. """ if isinstance(request_id, str): try: @@ -203,8 +215,18 @@ async def _wrapped(dctx: DispatchContext[TransportContext], method: str, params: return _wrapped -def _outbound_metadata(related_request_id: RequestId | None, opts: CallOptions | None) -> MessageMetadata: - """Choose the `SessionMessage.metadata` for an outgoing request/notification. +@dataclass(slots=True, frozen=True) +class _OutboundPlan: + """One decision about an outgoing message: what reaches the transport, and + whether abandoning the request sends a courtesy `notifications/cancelled`.""" + + metadata: MessageMetadata + cancel_on_abandon: bool + + +def _plan_outbound(related_request_id: RequestId | None, opts: CallOptions | None) -> _OutboundPlan: + """Choose the `SessionMessage.metadata` for an outgoing request/notification + and the matching abandon-cancellation policy. `ServerMessageMetadata` tags a server-to-client message with the inbound request it belongs to (so streamable-HTTP can route it onto that request's @@ -215,19 +237,30 @@ def _outbound_metadata(related_request_id: RequestId | None, opts: CallOptions | `related_request_id` is set it takes precedence and any resumption hints in `opts` are dropped (with a debug log): requests made from a dispatch context are routed onto the inbound request's stream, not resumed. + + The same decision fixes `cancel_on_abandon`: an abandoned request sends a + courtesy `notifications/cancelled` unless the caller opted out, or the + resumption hints actually reach the transport (the caller intends to + resume, so the peer's work must keep running). Hints dropped here do NOT + suppress the cancel - a request that is neither resumable nor cancelled + would leak the peer's work. """ + opts = opts or {} + cancel_on_abandon = opts.get("cancel_on_abandon", True) + token = opts.get("resumption_token") + on_token = opts.get("on_resumption_token") if related_request_id is not None: - if opts and (opts.get("resumption_token") is not None or opts.get("on_resumption_token") is not None): + if token is not None or on_token is not None: logger.debug( "dropping resumption hints: related_request_id %r takes precedence on metadata", related_request_id ) - return ServerMessageMetadata(related_request_id=related_request_id) - if opts: - token = opts.get("resumption_token") - on_token = opts.get("on_resumption_token") - if token is not None or on_token is not None: - return ClientMessageMetadata(resumption_token=token, on_resumption_token_update=on_token) - return None + return _OutboundPlan(ServerMessageMetadata(related_request_id=related_request_id), cancel_on_abandon) + if token is not None or on_token is not None: + return _OutboundPlan( + ClientMessageMetadata(resumption_token=token, on_resumption_token_update=on_token), + cancel_on_abandon=False, + ) + return _OutboundPlan(None, cancel_on_abandon) class JSONRPCDispatcher(Dispatcher[TransportT]): @@ -248,6 +281,35 @@ def __init__( inline_methods: frozenset[str] = frozenset(), on_stream_exception: Callable[[Exception], Awaitable[None]] | None = None, ) -> None: + """Wire a dispatcher over a transport's `SessionMessage` stream pair. + + Args: + read_stream: Inbound messages from the peer; `Exception` items are + transport-level read faults (see `on_stream_exception`). + write_stream: Outbound messages to the peer. + transport_builder: Builds the per-message `TransportContext` from + the inbound `SessionMessage.metadata`. Defaults to a plain + always-routable JSON-RPC context. + peer_cancel_mode: How inbound `notifications/cancelled` is applied + to a running handler; see `PeerCancelMode`. + raise_handler_exceptions: Re-raise handler exceptions out of + `run()` after the error response is written, instead of + containing them at the exception-to-wire boundary. + inline_methods: Request methods handled inline in the read loop + (awaited before the next message is dequeued) instead of + spawned concurrently. Use for methods whose side effects must + be observable to the next message, e.g. `initialize`, so a + pipelined follow-up sees the initialized state. Only suitable + for handlers that complete quickly, since inline handling + blocks dequeuing; a handler that awaits the peer + (`send_raw_request`) while inline will deadlock because the + parked read loop cannot dequeue the response. + on_stream_exception: Observer for `Exception` items the transport + yields on the read stream (SSE/streamable-HTTP connection + faults, stdio parse errors). Without it they are debug-logged + and dropped. Awaited in the read loop and contained: a raising + observer costs the item, not the connection. + """ self._read_stream = read_stream self._write_stream = write_stream # When `transport_builder` is omitted, `TransportT` falls back to its @@ -259,19 +321,7 @@ def __init__( ) self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode self._raise_handler_exceptions = raise_handler_exceptions - # Request methods handled inline in the read loop (awaited before the - # next message is dequeued) instead of spawned concurrently. Use for - # methods whose side effects must be observable to the next message, - # e.g. `initialize`, so a pipelined follow-up sees the initialized state. - # Only suitable for handlers that complete quickly, since inline handling - # blocks dequeuing; a handler that awaits the peer (`send_raw_request`) - # while inline will deadlock because the parked read loop cannot dequeue - # the response. self._inline_methods = inline_methods - # Observer for Exception items the transport yields on the read stream - # (SSE/streamable-HTTP connection faults, stdio parse errors). Without - # it they are debug-logged and dropped. Awaited in the read loop and - # contained: a raising observer costs the item, not the connection. self._on_stream_exception = on_stream_exception self._next_id = 0 @@ -298,8 +348,9 @@ async def send_raw_request( Raises: MCPError: The peer responded with a JSON-RPC error; or `REQUEST_TIMEOUT` if `opts["timeout"]` elapsed; or - `CONNECTION_CLOSED` if the dispatcher shut down while - awaiting the response. + `CONNECTION_CLOSED` if the transport closed before the request + could be written, or the dispatcher shut down while awaiting + the response. RuntimeError: Called before `run()` has started or after it has finished. """ @@ -328,19 +379,17 @@ async def send_raw_request( pending = _Pending(send=send, receive=receive, on_progress=on_progress) self._pending[request_id] = pending - # An abandoned request (timeout elapsed, or the caller's scope was - # cancelled while awaiting the response) sends a courtesy - # `notifications/cancelled` so the peer can stop work - unless the - # caller opted out (`initialize`, which the spec forbids cancelling), - # or the request carries resumption hints (the caller intends to - # resume it, so the peer's work must keep running). - cancel_on_abandon = ( - opts.get("cancel_on_abandon", True) - and opts.get("resumption_token") is None - and opts.get("on_resumption_token") is None - ) + # One decision covers both what metadata reaches the transport and + # whether abandoning this request (timeout elapsed, or the caller's + # scope cancelled while awaiting the response) sends a courtesy + # `notifications/cancelled`; see `_plan_outbound`. + plan = _plan_outbound(_related_request_id, opts) + # Spec MUST: only previously-issued requests may be cancelled, so the + # courtesy cancel is armed only once the request write completes - a + # caller cancelled mid-write must not announce a cancel for a request + # the peer never received. + request_written = False - metadata = _outbound_metadata(_related_request_id, opts) target = out_params.get("name") span_name = f"MCP send {method}{f' {target}' if isinstance(target, str) else ''}" # TODO(maxisbey): the otel span + inject below mirror @@ -358,23 +407,46 @@ async def send_raw_request( # present on the wire (and the interaction suite pins that). inject_trace_context(out_meta) msg = JSONRPCRequest(jsonrpc="2.0", id=request_id, method=method, params=out_params) - await self._write(msg, metadata) + try: + await self._write(msg, plan.metadata) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + # The transport tore down before run() noticed EOF; surface + # the documented contract, not the raw stream error. + raise MCPError(code=CONNECTION_CLOSED, message="Connection closed") from None + request_written = True with anyio.fail_after(opts.get("timeout")): outcome = await receive.receive() except TimeoutError: # Spec-recommended courtesy: tell the peer we've given up so it can # stop work and free resources. v1's BaseSession.send_request does - # NOT do this; it's new behaviour. - if cancel_on_abandon: - await self._cancel_outbound(request_id, f"timed out after {opts.get('timeout')}s", _related_request_id) + # NOT do this; it's new behaviour. Unshielded: this scope is not + # cancelled, so an outer caller cancellation must still be able to + # interrupt the write. + if plan.cancel_on_abandon and request_written: + await self._final_write( + partial( + self._cancel_outbound, + request_id, + f"timed out after {opts.get('timeout')}s", + _related_request_id, + ), + shield=False, + timeout=_SHIELDED_WRITE_TIMEOUT, + describe=f"courtesy cancel for timed-out request {request_id!r}", + ) raise MCPError(code=REQUEST_TIMEOUT, message=f"Request {method!r} timed out") from None except anyio.get_cancelled_exc_class(): # Our caller's scope was cancelled. We're already inside a cancelled - # scope, so any bare `await` here re-raises immediately - shield - # (bounded) to let the courtesy cancel go out before we propagate. - if cancel_on_abandon: - with anyio.move_on_after(_SHIELDED_WRITE_TIMEOUT, shield=True): - await self._cancel_outbound(request_id, "caller cancelled", _related_request_id) + # scope, so any bare `await` here re-raises immediately - the + # shielded (bounded) helper lets the courtesy cancel go out before + # we propagate. + if plan.cancel_on_abandon and request_written: + await self._final_write( + partial(self._cancel_outbound, request_id, "caller cancelled", _related_request_id), + shield=True, + timeout=_SHIELDED_WRITE_TIMEOUT, + describe=f"courtesy cancel for caller-cancelled request {request_id!r}", + ) raise finally: # Always remove the waiter, even on cancel/timeout, so a late @@ -403,7 +475,7 @@ async def notify( msg = JSONRPCNotification(jsonrpc="2.0", method=method, params=dict(params)) else: msg = JSONRPCNotification(jsonrpc="2.0", method=method) - await self._write(msg, _outbound_metadata(_related_request_id, None)) + await self._write(msg, _plan_outbound(_related_request_id, None).metadata) async def run( self, @@ -690,26 +762,40 @@ async def _handle_request( # later calls `dctx.send_raw_request()` should see # `NoBackChannelError`) and drop from `_in_flight` so a # late `notifications/cancelled` is a no-op rather than - # racing the result write below. No checkpoint between - # handler return and the pop, so the cancel can't - # interleave there. + # racing the result write below. Identity-guarded: a + # duplicate inbound id blind-overwrites the table entry + # (see `_dispatch_request`), and this pop must not evict + # the newer request's entry - that would leave it + # peer-uncancellable. No checkpoint between handler return + # and the pop, so the cancel can't interleave there. dctx.close() - self._in_flight.pop(_coerce_id(req.id), None) + key = _coerce_id(req.id) + if (entry := self._in_flight.get(key)) is not None and entry.dctx is dctx: + del self._in_flight[key] await self._write_result(req.id, result) - if scope.cancel_called: + if scope.cancelled_caught: # Peer-cancel: `_dispatch_notification` cancelled this scope - # while the handler was running. anyio swallows a scope's *own* - # cancel at __exit__, so execution lands here rather than the - # `except cancelled` arm below. + # and the cancellation was actually absorbed at __exit__, i.e. + # the result write above did not happen. (`cancel_called` + # alone is not enough: a cancel that lands after the handler's + # last checkpoint is never delivered, the handler completes + # and the result write can succeed - answering again here + # would put a second response for an already-answered id on + # the wire when the write stream doesn't checkpoint.) # TODO(maxisbey): spec says SHOULD NOT respond after cancel. # The existing server always has, so match that for now. await self._write_error(req.id, ErrorData(code=0, message="Request cancelled")) except anyio.get_cancelled_exc_class(): - # Outer-cancel: run()'s task group is shutting down. Any bare - # `await` here re-raises immediately, so shield (bounded) the - # courtesy write. - with anyio.move_on_after(_SHIELDED_WRITE_TIMEOUT, shield=True): - await self._write_error(req.id, ErrorData(code=REQUEST_CANCELLED, message="Request cancelled")) + # Outer-cancel: run()'s task group is shutting down. Answer the + # request so the peer is not left waiting on a connection that is + # going away; the helper shields (bounded) the write because any + # bare `await` here re-raises immediately. + await self._final_write( + partial(self._write_error, req.id, ErrorData(code=CONNECTION_CLOSED, message="Connection closed")), + shield=True, + timeout=_SHUTDOWN_WRITE_TIMEOUT, + describe=f"shutdown error response for request {req.id!r}", + ) raise except MCPError as e: await self._write_error(req.id, e.error) @@ -753,6 +839,31 @@ async def _write_error(self, request_id: RequestId, error: ErrorData) -> None: except (anyio.BrokenResourceError, anyio.ClosedResourceError): logger.debug("dropped error for %r: write stream closed", request_id) + async def _final_write( + self, + write: Callable[[], Awaitable[None]], + *, + shield: bool, + timeout: float, + describe: str, + ) -> None: + """Attempt one last write under the shared abandon/teardown policy. + + Every arm that writes to the transport after giving up on a request + (timeout courtesy cancel, caller-cancel courtesy cancel, shutdown + error response) goes through here so the bound+shield+warning policy + cannot diverge between them. `shield=True` is for arms already inside + a cancelled scope (a bare `await` would re-raise immediately); the + bound keeps a wedged transport write from turning the shield into an + uncancellable hang. An unshielded arm shares the bound so a wedged + transport can't delay its caller's error indefinitely, while staying + interruptible from outside. + """ + with anyio.move_on_after(timeout, shield=shield) as scope: + await write() + if scope.cancelled_caught: + logger.warning("%s gave up: transport write blocked", describe) + async def _cancel_outbound(self, request_id: RequestId, reason: str, related_request_id: RequestId | None) -> None: # Thread `related_request_id` so streamable-HTTP routes the cancel onto # the same per-request SSE stream as the request it cancels; without it diff --git a/src/mcp/types/__init__.py b/src/mcp/types/__init__.py index cb49ff29d..b2d537fb7 100644 --- a/src/mcp/types/__init__.py +++ b/src/mcp/types/__init__.py @@ -152,7 +152,6 @@ INVALID_REQUEST, METHOD_NOT_FOUND, PARSE_ERROR, - REQUEST_CANCELLED, REQUEST_TIMEOUT, URL_ELICITATION_REQUIRED, ErrorData, @@ -320,7 +319,6 @@ "INVALID_REQUEST", "METHOD_NOT_FOUND", "PARSE_ERROR", - "REQUEST_CANCELLED", "REQUEST_TIMEOUT", "URL_ELICITATION_REQUIRED", "ErrorData", diff --git a/src/mcp/types/jsonrpc.py b/src/mcp/types/jsonrpc.py index 14743c33b..84304a37c 100644 --- a/src/mcp/types/jsonrpc.py +++ b/src/mcp/types/jsonrpc.py @@ -43,7 +43,6 @@ class JSONRPCResponse(BaseModel): # SDK error codes CONNECTION_CLOSED = -32000 REQUEST_TIMEOUT = -32001 -REQUEST_CANCELLED = -32002 # Standard JSON-RPC error codes PARSE_ERROR = -32700 diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index 8633840ec..0b4ae48b2 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -10,6 +10,7 @@ import json import logging from collections.abc import Mapping +from types import TracebackType from typing import Any import anyio @@ -25,8 +26,9 @@ from mcp.shared.jsonrpc_dispatcher import ( # pyright: ignore[reportPrivateUsage] JSONRPCDispatcher, _coerce_id, - _outbound_metadata, + _OutboundPlan, _Pending, + _plan_outbound, ) from mcp.shared.message import ClientMessageMetadata, MessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.transport_context import TransportContext @@ -55,6 +57,36 @@ DCtx = DispatchContext[TransportContext] +class RecordingWriteStream: + """Write stream that records sends synchronously, without a checkpoint. + + Models a transport write that can complete without yielding (a memory + stream's `send` checkpoints first, which would let a pending cancellation + interrupt the write and mask the behavior under test). `__aexit__` + releases nothing, so writes during run() teardown still land. + """ + + def __init__(self) -> None: + self.sent: list[SessionMessage] = [] + + async def send(self, item: SessionMessage) -> None: + self.sent.append(item) + + async def aclose(self) -> None: + raise NotImplementedError # the dispatcher releases streams via __aexit__, never aclose + + async def __aenter__(self) -> "RecordingWriteStream": + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + return None + + @pytest.mark.anyio async def test_concurrent_send_raw_requests_correlate_by_id_when_responses_arrive_out_of_order(): release_first = anyio.Event() @@ -131,6 +163,58 @@ async def call_then_record() -> None: assert seen_error == [ErrorData(code=0, message="Request cancelled")] +@pytest.mark.anyio +async def test_peer_cancel_landing_after_handlers_last_checkpoint_writes_only_the_result(): + """A peer cancel that fails to interrupt the handler must not add a code-0 + error after the result: exactly one answer for that id goes on the wire. + + SDK-defined: the cancelled-error response belongs only when the + cancellation was actually absorbed, i.e. the result write did not happen. + The schedule is deterministic because the cancel notification itself is + the handler's wakeup: the read loop sets `ctx.cancel_requested` and then + cancels the scope in the same synchronous block, so anyio defers the + cancellation (the wakeup future is already done) and the handler runs to + completion. The recording write stream is needed because a memory + stream's `send` checkpoints, which would let the deferred cancellation + land mid-write and hide the double answer. + """ + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + recording = RecordingWriteStream() + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, recording) + handler_started = anyio.Event() + + async def on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + handler_started.set() + await ctx.cancel_requested.wait() + return {"completed": "after-cancel"} + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + pass # the cancelled notification is teed here; nothing to observe + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None))) + with anyio.fail_after(5): + await handler_started.wait() + await c2s_send.send( + SessionMessage( + message=JSONRPCNotification( + jsonrpc="2.0", method="notifications/cancelled", params={"requestId": 1} + ) + ) + ) + # Quiesce: the handler has resumed, completed, and exited its scope. + await anyio.wait_all_tasks_blocked() + tg.cancel_scope.cancel() + finally: + c2s_send.close() + c2s_recv.close() + assert [m.message for m in recording.sent] == [ + JSONRPCResponse(jsonrpc="2.0", id=1, result={"completed": "after-cancel"}) + ] + + @pytest.mark.anyio async def test_peer_cancel_signal_mode_sets_event_but_handler_runs_to_completion(): handler_started = anyio.Event() @@ -415,6 +499,57 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> s.close() +@pytest.mark.anyio +async def test_dispatch_context_request_with_dropped_resumption_hints_still_sends_courtesy_cancel(): + """Resumption hints that never reach the transport must not suppress the abandon cancel. + + For a dispatch-context request, `related_request_id` takes metadata + precedence and the hints are dropped - so the request is not resumable, + and abandoning it without a courtesy cancel would leak the peer's work + forever. One decision (`_plan_outbound`) now produces both the metadata + and the cancel policy, so they cannot disagree. + """ + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + with pytest.raises(MCPError): # REQUEST_TIMEOUT + await ctx.send_raw_request("sampling/createMessage", None, {"timeout": 0, "resumption_token": "tok"}) + return {"gave_up": True} + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, server_on_request, on_notify) + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="t", params=None))) + with anyio.fail_after(5): + outbound = await s2c_recv.receive() + assert isinstance(outbound, SessionMessage) + assert isinstance(outbound.message, JSONRPCRequest) + # The hints were dropped: dispatch-context routing won the metadata. + assert isinstance(outbound.metadata, ServerMessageMetadata) + sampling_id = outbound.message.id + # Don't respond; let the timeout fire. Next on the wire must be the courtesy cancel. + with anyio.fail_after(5): + cancel = await s2c_recv.receive() + assert isinstance(cancel, SessionMessage) + assert isinstance(cancel.message, JSONRPCNotification) + assert cancel.message.method == "notifications/cancelled" + assert cancel.message.params == {"requestId": sampling_id, "reason": "timed out after 0s"} + with anyio.fail_after(5): + final = await s2c_recv.receive() + assert isinstance(final, SessionMessage) + assert isinstance(final.message, JSONRPCResponse) + assert final.message.result == {"gave_up": True} + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + @pytest.mark.anyio async def test_caller_cancel_sends_courtesy_cancellation_on_the_wire(): """Cancelling the scope around send_raw_request emits notifications/cancelled by default.""" @@ -457,6 +592,79 @@ async def caller() -> None: assert scopes[0].cancelled_caught +@pytest.mark.anyio +async def test_caller_cancel_during_blocked_request_write_sends_no_cancelled_notification(): + """A caller cancelled while the request write is still blocked must not emit + `notifications/cancelled`: the spec only allows cancelling previously-issued + requests, and this one never reached the peer. + + The fake stream wedges only the first write (the request) and records any + later one synchronously, so a courtesy cancel - which would be the bug - + is captured even though it runs inside the bounded shield. + """ + + class FirstWriteWedgedStream: + def __init__(self) -> None: + self.sent: list[SessionMessage] = [] + self.first_write_started = anyio.Event() + + async def send(self, item: SessionMessage) -> None: + if not self.first_write_started.is_set(): + self.first_write_started.set() + await anyio.sleep_forever() # the request write wedges until the caller is cancelled + self.sent.append(item) + + async def aclose(self) -> None: + raise NotImplementedError # the dispatcher releases streams via __aexit__, never aclose + + async def __aenter__(self) -> "FirstWriteWedgedStream": + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + return None + + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + wedged = FirstWriteWedgedStream() + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, wedged) + on_request, on_notify = echo_handlers(Recorder()) + + scopes: list[anyio.CancelScope] = [] + gave_up = anyio.Event() + + async def caller() -> None: + with anyio.CancelScope() as scope: + scopes.append(scope) + await client.send_raw_request("slow", None) + raise NotImplementedError # unreachable: the scope is cancelled + gave_up.set() + + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + tg.start_soon(caller) + with anyio.fail_after(5): + await wedged.first_write_started.wait() # the caller is parked in the request write + scopes[0].cancel() + with anyio.fail_after(5): + await gave_up.wait() + # Prove the recorder is live: a marker write after the wedge IS + # captured, so a courtesy cancel would have been too. + await client.notify("notifications/marker", None) + tg.cancel_scope.cancel() + finally: + s2c_send.close() + s2c_recv.close() + assert scopes[0].cancelled_caught + # The marker is the only post-wedge write: no cancel notification went out + # for a request the peer never received. + assert [m.message for m in wedged.sent] == [JSONRPCNotification(jsonrpc="2.0", method="notifications/marker")] + + @pytest.mark.anyio async def test_caller_cancel_with_resumption_hints_suppresses_the_courtesy_cancellation(): """A request sent with resumption hints is meant to be resumed; abandoning it must not stop the peer's work.""" @@ -570,14 +778,17 @@ async def test_cancel_on_abandon_false_suppresses_the_courtesy_cancellation_on_t [pytest.param(("trio", {"clock": MockClock(autojump_threshold=0)}), id="trio-mockclock")], ) @pytest.mark.anyio -async def test_caller_cancel_courtesy_write_is_bounded_when_the_transport_is_wedged(): +async def test_caller_cancel_courtesy_write_is_bounded_when_the_transport_is_wedged( + caplog: pytest.LogCaptureFixture, +): """A wedged transport write cannot turn caller cancellation into an unbounded shielded hang. - The write stream has no buffer and no reader, so the request write blocks; cancelling the - caller then routes into the shielded courtesy-cancel write, which blocks on the same wedged - stream. The bound abandons it after _SHIELDED_WRITE_TIMEOUT; trio's virtual clock makes the - wait instant. On regression (unbounded shield) the test hangs rather than failing fast: the - outer fail_after cannot cancel through the shield - that is the bug. + The peer consumes exactly the request (arming the courtesy cancel) and never responds; + cancelling the caller then routes into the shielded courtesy-cancel write, which blocks on + the unbuffered, unread write stream. The bound abandons it after _SHIELDED_WRITE_TIMEOUT + (with a warning); trio's virtual clock makes the wait instant. On regression (unbounded + shield) the test hangs rather than failing fast: the outer fail_after cannot cancel through + the shield - that is the bug. """ c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) @@ -595,11 +806,17 @@ async def caller() -> None: gave_up.set() try: + # Both bounds must exceed the in-loop _SHIELDED_WRITE_TIMEOUT (5s) they + # wait out; the virtual clock means no wall-time cost. with anyio.fail_after(30): async with anyio.create_task_group() as tg: # pragma: no branch await tg.start(client.run, on_request, on_notify) tg.start_soon(caller) - await anyio.wait_all_tasks_blocked() # the caller is parked in the request write + # Consume exactly the request so its write completes; the later + # courtesy cancel finds no reader and wedges. + request = await c2s_recv.receive() + assert isinstance(request, SessionMessage) + assert isinstance(request.message, JSONRPCRequest) scopes[0].cancel() with anyio.fail_after(20): await gave_up.wait() @@ -608,6 +825,236 @@ async def caller() -> None: for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): s.close() assert scopes[0].cancelled_caught + # The warning proves it was the bound (not a completed write) that released the shield. + assert "courtesy cancel for caller-cancelled request" in caplog.text + + +@pytest.mark.parametrize( + "anyio_backend", + [pytest.param(("trio", {"clock": MockClock(autojump_threshold=0)}), id="trio-mockclock")], +) +@pytest.mark.anyio +async def test_timeout_courtesy_cancel_write_is_bounded_when_the_transport_is_wedged( + caplog: pytest.LogCaptureFixture, +): + """A wedged transport write cannot delay the REQUEST_TIMEOUT error indefinitely (SDK-defined bound). + + The peer consumes exactly the request and never responds; when the timeout + elapses, the courtesy cancel blocks on the unbuffered, unread write stream. + The bound abandons it after _SHIELDED_WRITE_TIMEOUT (with a warning) so the + timeout error still surfaces; trio's virtual clock makes the waits instant. + """ + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + + errors: list[MCPError] = [] + gave_up = anyio.Event() + + async def caller() -> None: + with pytest.raises(MCPError) as exc: + await client.send_raw_request("slow", None, {"timeout": 1}) + errors.append(exc.value) + gave_up.set() + + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + tg.start_soon(caller) + # Consume exactly the request so its write completes; the later + # courtesy cancel finds no reader and wedges. + with anyio.fail_after(5): + request = await c2s_recv.receive() + assert isinstance(request, SessionMessage) + assert isinstance(request.message, JSONRPCRequest) + # Must exceed the request timeout (1s) plus the in-loop + # _SHIELDED_WRITE_TIMEOUT (5s); the virtual clock means no wall-time cost. + with anyio.fail_after(10): + await gave_up.wait() + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + assert errors[0].error.code == REQUEST_TIMEOUT + assert "courtesy cancel for timed-out request" in caplog.text + + +@pytest.mark.parametrize( + "anyio_backend", + [pytest.param(("trio", {"clock": MockClock(autojump_threshold=0)}), id="trio-mockclock")], +) +@pytest.mark.anyio +async def test_shutdown_error_response_write_is_bounded_when_the_transport_is_wedged( + caplog: pytest.LogCaptureFixture, +): + """Cancelling the task group hosting run() completes even when the shutdown error write wedges (SDK-defined bound). + + The in-flight handler is parked when run() is cancelled; its shielded + connection-closed-error write blocks on a wedged transport, and only the + _SHUTDOWN_WRITE_TIMEOUT bound lets the join complete. A fake write stream is + needed because a memory stream can't express the wedge: run()'s teardown + closes its own write stream, which would wake the blocked send. On + regression (unbounded shield) the test hangs rather than failing fast: the + outer fail_after cannot cancel through the shield - that is the bug. + """ + + class WedgedWriteStream: + async def send(self, item: SessionMessage) -> None: + await anyio.sleep_forever() + + async def aclose(self) -> None: + raise NotImplementedError # the dispatcher releases streams via __aexit__, never aclose + + async def __aenter__(self) -> "WedgedWriteStream": + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + return None + + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, WedgedWriteStream()) + handler_started = anyio.Event() + + async def park(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + handler_started.set() + await anyio.sleep_forever() + raise NotImplementedError + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + # Sits above the in-loop _SHUTDOWN_WRITE_TIMEOUT (1s) it waits out but + # below _SHIELDED_WRITE_TIMEOUT (5s), so this also pins that the + # shutdown arm uses the tighter shutdown bound (session close must be + # quick); the virtual clock means no wall-time cost. + with anyio.fail_after(3): + async with anyio.create_task_group() as tg: # pragma: no branch + await tg.start(server.run, park, on_notify) + await c2s_send.send( + SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None)) + ) + await handler_started.wait() + tg.cancel_scope.cancel() + finally: + c2s_send.close() + c2s_recv.close() + # Reaching here proves the join completed; the warning proves it was the + # bound (not a completed write) that released it. + assert "shutdown error response for request" in caplog.text + + +@pytest.mark.anyio +async def test_shutdown_answers_in_flight_request_with_connection_closed(): + """Cancelling run() answers a still-running request with CONNECTION_CLOSED. + + SDK-defined contract: the peer learns its request died with the connection + (not a request-specific cancellation - -32002 belongs to the spec's + resource-not-found). The recording write stream keeps the teardown write + observable: run()'s exit would close a memory stream before the shielded + write lands. + """ + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + recording = RecordingWriteStream() + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, recording) + handler_started = anyio.Event() + + async def park(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + handler_started.set() + await anyio.sleep_forever() + raise NotImplementedError + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, park, on_notify) + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None))) + with anyio.fail_after(5): + await handler_started.wait() + tg.cancel_scope.cancel() + finally: + c2s_send.close() + c2s_recv.close() + assert [m.message for m in recording.sent] == [ + JSONRPCError(jsonrpc="2.0", id=1, error=ErrorData(code=CONNECTION_CLOSED, message="Connection closed")) + ] + + +@pytest.mark.anyio +async def test_request_write_failure_propagates_and_leaves_no_pending_entry(): + """A request whose transport write raises must not leak its `_pending` entry. + + SDK-defined: regression cover for v1's `test_send_request_stream_cleanup` + (response streams were cleaned up when the write failed). + """ + boom = RuntimeError("write failed") + + class RaisingWriteStream: + async def send(self, item: SessionMessage) -> None: + raise boom + + async def aclose(self) -> None: + raise NotImplementedError # the dispatcher releases streams via __aexit__, never aclose + + async def __aenter__(self) -> "RaisingWriteStream": + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> bool | None: + return None + + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, RaisingWriteStream()) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + with anyio.fail_after(5), pytest.raises(RuntimeError) as exc: + await client.send_raw_request("ping", None) + assert exc.value is boom + assert client._pending == {} # pyright: ignore[reportPrivateUsage] + tg.cancel_scope.cancel() + finally: + s2c_send.close() + s2c_recv.close() + + +@pytest.mark.anyio +async def test_request_write_on_torn_down_transport_raises_connection_closed(): + """The transport tearing down before run() notices EOF surfaces as MCPError, not a raw stream error. + + SDK-defined: `send_raw_request` documents MCPError(CONNECTION_CLOSED) for a + closed connection; the raw `BrokenResourceError` from the write must not leak. + """ + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + # Tear down the peer's receive end only: the client's read stream + # stays open, so run() has not observed EOF when the write fails. + c2s_recv.close() + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("ping", None) + assert exc.value.error.code == CONNECTION_CLOSED + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() @pytest.mark.anyio @@ -1170,12 +1617,14 @@ def test_fan_out_closed_drops_signal_when_waiter_already_has_outcome(): s.close() -def test_outbound_metadata_with_resumption_token_returns_client_metadata(): - md = _outbound_metadata(None, {"resumption_token": "abc"}) - assert isinstance(md, ClientMessageMetadata) - assert md.resumption_token == "abc" - assert _outbound_metadata(None, None) is None - assert _outbound_metadata(None, {}) is None +def test_plan_outbound_with_resumption_token_returns_client_metadata_and_suppresses_abandon_cancel(): + """Hints that reach the transport make the request resumable, so abandoning it must not cancel the peer's work.""" + plan = _plan_outbound(None, {"resumption_token": "abc"}) + assert isinstance(plan.metadata, ClientMessageMetadata) + assert plan.metadata.resumption_token == "abc" + assert plan.cancel_on_abandon is False + assert _plan_outbound(None, None) == _OutboundPlan(metadata=None, cancel_on_abandon=True) + assert _plan_outbound(None, {}) == _OutboundPlan(metadata=None, cancel_on_abandon=True) @pytest.mark.anyio @@ -1209,6 +1658,47 @@ async def respond_stringly() -> None: s.close() +@pytest.mark.anyio +async def test_error_response_with_string_id_correlates_to_int_keyed_pending_request(): + """A peer that echoes the request ID as a JSON string on a JSONRPCError still resolves the waiter. + + Same `_coerce_id` treatment as the success-response path: the peer's error + surfaces as MCPError instead of the request hanging until the connection closes. + """ + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + with anyio.fail_after(5): + + async def reject_stringly() -> None: + out = await c2s_recv.receive() + assert isinstance(out, SessionMessage) + assert isinstance(out.message, JSONRPCRequest) + rid = out.message.id + assert isinstance(rid, int) + await s2c_send.send( + SessionMessage( + message=JSONRPCError( + jsonrpc="2.0", id=str(rid), error=ErrorData(code=INVALID_PARAMS, message="bad cursor") + ) + ) + ) + + tg.start_soon(reject_stringly) + with pytest.raises(MCPError) as exc: + await client.send_raw_request("ping", None) + assert exc.value.error.code == INVALID_PARAMS + assert exc.value.error.message == "bad cursor" # the peer's error, passed through + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + @pytest.mark.anyio async def test_progress_with_string_token_reaches_callback_for_int_keyed_request(): c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) @@ -1261,7 +1751,7 @@ def test_coerce_id_passes_through_non_numeric_string_and_int(): @pytest.mark.anyio async def test_jsonrpc_error_response_with_null_id_is_dropped(): - """Parse-error responses (id=null) have no waiter; they're logged and dropped.""" + """Parse-error responses (id=null) have no waiter; they're dropped and the read loop stays healthy.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) @@ -1272,7 +1762,19 @@ async def test_jsonrpc_error_response_with_null_id_is_dropped(): await s2c_send.send( SessionMessage(message=JSONRPCError(jsonrpc="2.0", id=None, error=ErrorData(code=-32700, message="x"))) ) - await anyio.sleep(0) + with anyio.fail_after(5): + # The read stream is ordered: this round-trip completing proves + # the null-id error was consumed without killing the loop. + async def respond() -> None: + out = await c2s_recv.receive() + assert isinstance(out, SessionMessage) + assert isinstance(out.message, JSONRPCRequest) + await s2c_send.send( + SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=out.message.id, result={"ok": True})) + ) + + tg.start_soon(respond) + assert await client.send_raw_request("ping", None) == {"ok": True} tg.cancel_scope.cancel() finally: for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): @@ -1604,20 +2106,93 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> s.close() -def test_outbound_metadata_with_related_request_id_drops_resumption_hints_with_debug_log( +@pytest.mark.anyio +async def test_duplicate_request_id_completion_of_first_handler_keeps_second_cancellable(): + """When a duplicate inbound id overwrites `_in_flight` while the first + handler is still running, the first handler's completion must not evict + the second's entry - that would leave the second request immune to + `notifications/cancelled`. + + SDK-defined: the spec puts id uniqueness on the sender and the dispatcher + blind-overwrites on duplicates (parity with v1/TS); the in-table pop is + identity-guarded so a stale handler only removes its own entry. + """ + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + first_started = anyio.Event() + release_first = anyio.Event() + second_started = anyio.Event() + second_exited = anyio.Event() + + async def on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + if method == "first": + first_started.set() + await release_first.wait() + return {"first": True} + second_started.set() + try: + await anyio.sleep_forever() + finally: + second_exited.set() + raise NotImplementedError + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + pass # the cancelled notification is teed here; nothing to observe + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + with anyio.fail_after(5): + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="first"))) + await first_started.wait() + # Duplicate id while the first handler is still running: the + # table entry now belongs to the second request. + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="second"))) + await second_started.wait() + release_first.set() + resp1 = await s2c_recv.receive() + assert isinstance(resp1, SessionMessage) + assert isinstance(resp1.message, JSONRPCResponse) + assert resp1.message.result == {"first": True} + # Let the first handler task run past its pop entirely. + await anyio.wait_all_tasks_blocked() + assert 7 in server._in_flight # pyright: ignore[reportPrivateUsage] + # The surviving entry must still be cancellable by the peer. + await c2s_send.send( + SessionMessage( + message=JSONRPCNotification( + jsonrpc="2.0", method="notifications/cancelled", params={"requestId": 7} + ) + ) + ) + resp2 = await s2c_recv.receive() + assert isinstance(resp2, SessionMessage) + assert isinstance(resp2.message, JSONRPCError) + assert resp2.message.error == ErrorData(code=0, message="Request cancelled") + assert second_exited.is_set() + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +def test_plan_outbound_with_related_request_id_drops_resumption_hints_but_keeps_abandon_cancel( caplog: pytest.LogCaptureFixture, ): """`SessionMessage.metadata` carries one object; `related_request_id` wins - and resumption hints are dropped observably (debug log).""" + and resumption hints are dropped observably (debug log). Dropped hints do + not suppress the abandon cancel: the request is not resumable.""" with caplog.at_level(logging.DEBUG, logger="mcp.shared.jsonrpc_dispatcher"): - md = _outbound_metadata(7, {"resumption_token": "abc"}) - assert isinstance(md, ServerMessageMetadata) - assert md.related_request_id == 7 + plan = _plan_outbound(7, {"resumption_token": "abc"}) + assert isinstance(plan.metadata, ServerMessageMetadata) + assert plan.metadata.related_request_id == 7 + assert plan.cancel_on_abandon is True assert "dropping resumption hints" in caplog.text caplog.clear() with caplog.at_level(logging.DEBUG, logger="mcp.shared.jsonrpc_dispatcher"): - md = _outbound_metadata(7, {"timeout": 1.0}) - assert isinstance(md, ServerMessageMetadata) + plan = _plan_outbound(7, {"timeout": 1.0}) + assert isinstance(plan.metadata, ServerMessageMetadata) assert "dropping resumption hints" not in caplog.text From c67b8d17b54ac82b82e3028626a7337d82a2d66c Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 22:30:04 +0000 Subject: [PATCH 11/24] Harden ClientSession enter/exit and fault delivery - Unwind the entered task group when __aenter__ is cancelled while the dispatcher is starting, instead of abandoning its cancel scope - Deliver transport-level exceptions to message_handler concurrently and contained, like notifications, so a handler that awaits session I/O no longer deadlocks the read loop - Route related_request_id=0 correctly in send_notification (ids are opaque) - Document the dispatcher= constructor path in send_request's contract --- src/mcp/client/session.py | 55 +++++++-- tests/client/test_session.py | 227 +++++++++++++++++++++++++++++++---- 2 files changed, 249 insertions(+), 33 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index d79afa7e1..9b4494997 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -129,7 +129,10 @@ class ClientSession: Transport-level `Exception` items reach `message_handler` only when the session builds its own dispatcher from streams, where it wires the - dispatcher's `on_stream_exception` itself. + dispatcher's `on_stream_exception` itself. Faults are delivered + concurrently in the session's task group, like notifications — never + inline in the read loop — so the handler may await session I/O, and one + that raises costs that delivery, not the connection. """ def __init__( @@ -174,7 +177,26 @@ def __init__( async def __aenter__(self) -> Self: self._task_group = anyio.create_task_group() await self._task_group.__aenter__() - await self._task_group.start(self._dispatcher.run, self._on_request, self._on_notify) + try: + await self._task_group.start(self._dispatcher.run, self._on_request, self._on_notify) + except BaseException: + # A cancellation landing here (e.g. the caller wrapped connect in + # `move_on_after`) would abandon the entered task group, and anyio + # later raises "exited non-innermost cancel scope" instead of a + # clean timeout. Unwind the group before propagating; cancelling + # its scope first keeps __aexit__ from blocking under the + # still-active cancellation. + task_group = self._task_group + self._task_group = None + task_group.cancel_scope.cancel() + # Shield the group's own scope (not a new one: scope exits must + # stay LIFO) so a pending outer cancellation cannot re-fire + # inside __aexit__; the join is prompt because the scope is + # cancelled. The original exception then propagates from the + # `raise`; a child error supersedes it, raised by __aexit__. + task_group.cancel_scope.shield = True + await task_group.__aexit__(None, None, None) + raise return self async def __aexit__( @@ -209,8 +231,10 @@ async def send_request( Raises: MCPError: The server responded with an error, or the read timeout - elapsed, or the connection closed while waiting. - RuntimeError: Called before entering the context manager. + elapsed, or the connection closed while sending or waiting. + RuntimeError: Called before entering the context manager. Raised + by the stream-built dispatcher; a user-supplied `dispatcher=` + may not enforce this. """ data = request.model_dump(by_alias=True, mode="json", exclude_none=True) method: str = data["method"] @@ -249,7 +273,8 @@ async def send_notification( ) -> None: """Send a one-way notification. Usable before entering the context manager.""" data = notification.model_dump(by_alias=True, mode="json", exclude_none=True) - if related_request_id and isinstance(self._dispatcher, JSONRPCDispatcher): + # `is not None`, not truthiness: request ids are opaque and 0 is valid. + if related_request_id is not None and isinstance(self._dispatcher, JSONRPCDispatcher): await self._dispatcher.notify(data["method"], data.get("params"), _related_request_id=related_request_id) else: await self._dispatcher.notify(data["method"], data.get("params")) @@ -561,5 +586,21 @@ async def _on_notify( await self._message_handler(notification) async def _on_stream_exception(self, exc: Exception) -> None: - """Forward transport-level faults (connection errors, parse errors) to message_handler.""" - await self._message_handler(exc) + """Spawn delivery of a transport-level fault (connection error, parse error) to message_handler. + + The dispatcher awaits this observer inline in its read loop, so the + handler must not run here: a slow handler would head-of-line block the + session, and one that awaits session I/O (e.g. sends a ping) would + deadlock against the parked loop. Spawn it instead, with the same + containment notification deliveries get. + """ + # The dispatcher only runs inside the task group entered in + # __aenter__, so the group is always live when it calls back here. + assert self._task_group is not None + self._task_group.start_soon(self._deliver_stream_exception, exc) + + async def _deliver_stream_exception(self, exc: Exception) -> None: + try: + await self._message_handler(exc) + except Exception: + logger.exception("message_handler raised on transport exception") diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 62f9eaa6f..8de41978e 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -1,18 +1,22 @@ from __future__ import annotations -from collections.abc import AsyncIterator -from contextlib import asynccontextmanager +from collections.abc import AsyncIterator, Mapping +from contextlib import AsyncExitStack, asynccontextmanager from typing import Any import anyio +import anyio.abc import anyio.streams.memory import pytest from mcp import types from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession from mcp.shared._context import RequestContext -from mcp.shared.message import SessionMessage +from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair +from mcp.shared.dispatcher import CallOptions, DispatchContext, OnNotify, OnRequest +from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder +from mcp.shared.transport_context import TransportContext from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types import ( INVALID_PARAMS, @@ -779,7 +783,11 @@ async def test_receive_loop_drops_unknown_notification_method_without_response() @pytest.mark.anyio async def test_raising_sampling_callback_answers_with_code_zero(): - """A raising request callback is answered through the dispatcher's exception boundary.""" + """A raising sampling callback propagates out of the session's request router + into the dispatcher's exception boundary, which answers with code 0 and + `str(exc)` (SDK-defined shape, pinned at the dispatcher in + tests/shared/test_jsonrpc_dispatcher.py). Raw streams because the assertion + is the outbound `JSONRPCError` envelope itself.""" async def boom(ctx: object, params: object) -> types.CreateMessageResult: raise RuntimeError("sampling boom") @@ -799,7 +807,10 @@ async def boom(ctx: object, params: object) -> types.CreateMessageResult: @pytest.mark.anyio async def test_receive_loop_logs_and_drops_malformed_notification(caplog: pytest.LogCaptureFixture): - """A notification that fails ServerNotification validation is logged and dropped.""" + """A notification that fails `ServerNotification` validation is logged and + dropped without reaching `message_handler`, and the loop keeps serving + (SDK-defined). Scripted peer: the typed API cannot emit a method outside + the spec's notification union.""" seen: list[object] = [] delivered = anyio.Event() @@ -819,19 +830,66 @@ async def handler(msg: object) -> None: @pytest.mark.anyio -async def test_receive_loop_forwards_transport_exception_to_message_handler(): +async def test_raising_message_handler_on_transport_exception_costs_the_delivery_not_the_connection( + caplog: pytest.LogCaptureFixture, +): + """A transport-level `Exception` item on the read stream reaches + `message_handler` (SDK-defined: the stream-built session wires the + dispatcher's `on_stream_exception` to spawn handler deliveries), and a + handler that raises is contained by the session — the failure is logged + and the receive loop keeps serving, proven by a follow-up ping + round-trip. Raw streams because only a transport can put an `Exception` + item on the read stream.""" seen: list[object] = [] delivered = anyio.Event() async def handler(msg: object) -> None: seen.append(msg) delivered.set() + # No checkpoint between set() and the session's containment logging + # the raise, so once wait() resumes the log entry exists. + raise RuntimeError("handler boom") - async with raw_client_session(message_handler=handler) as (_session, to_client, _): + async with raw_client_session(message_handler=handler) as (_session, to_client, from_client): exc = ValueError("bad bytes") await to_client.send(exc) await delivered.wait() + # Loop health: a follow-up inbound ping is still answered. + await to_client.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=9, method="ping"))) + out = await from_client.receive() assert seen == [exc] + assert isinstance(out.message, JSONRPCResponse) + assert out.message.id == 9 + assert "message_handler raised on transport exception" in caplog.text + + +@pytest.mark.anyio +async def test_message_handler_awaiting_session_traffic_on_transport_exception_completes(): + """A message_handler that reacts to a transport-level `Exception` item by + awaiting session traffic (a ping round-trip) completes instead of + deadlocking: the session spawns fault deliveries into its task group + rather than running them inline in the dispatcher's read loop + (SDK-defined). Raw streams because only a transport can put an + `Exception` item on the read stream.""" + ponged = anyio.Event() + + # The constructor takes the handler, so it is defined before the session + # exists; `session` resolves at call time, after the `as` clause binds it. + async def handler(msg: object) -> None: + assert isinstance(msg, Exception) + await session.send_ping() + ponged.set() + + async with raw_client_session(message_handler=handler) as (session, to_client, from_client): + await to_client.send(ValueError("bad bytes")) + # Serve the handler's ping like a transport would. Pre-spawn this + # deadlocked: the read loop was parked inside the handler, so the + # response below could never be dequeued. + out = await from_client.receive() + assert isinstance(out.message, JSONRPCRequest) + assert out.message.method == "ping" + await to_client.send(SessionMessage(JSONRPCResponse(jsonrpc="2.0", id=out.message.id, result={}))) + await ponged.wait() @pytest.mark.anyio @@ -841,6 +899,8 @@ async def test_receive_loop_consumes_server_cancelled_without_reaching_message_h The server dispatcher now emits this on sampling/elicitation timeout, but ClientSession has no in-flight tracking to act on it, so surfacing it would only break user handlers that exhaustively match ServerNotification. + Scripted peer: the typed server API has no way to emit a bare + `notifications/cancelled`. """ seen: list[object] = [] delivered = anyio.Event() @@ -868,45 +928,59 @@ async def handler(msg: object) -> None: @pytest.mark.anyio -async def test_progress_callback_exception_is_swallowed(caplog: pytest.LogCaptureFixture): +async def test_progress_notification_reaches_request_callback_and_message_handler(): + """A `notifications/progress` for an in-flight request reaches the + `progress_callback` passed to `send_request` and still tees to + `message_handler` as a typed `ProgressNotification` — the wiring this layer + adds over the dispatcher (callback dispatch and exception containment are + pinned in tests/shared/test_jsonrpc_dispatcher.py). Scripted peer: the + progress token must echo the wire-level request id, which only raw-stream + observation exposes.""" + updates: list[tuple[float, float | None, str | None]] = [] + teed: list[types.ProgressNotification] = [] + request_id: types.RequestId | None = None + progressed = anyio.Event() delivered = anyio.Event() - async def boom(progress: float, total: float | None, message: str | None) -> None: - raise RuntimeError("progress boom") + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + updates.append((progress, total, message)) + progressed.set() async def handler(msg: object) -> None: # Only the progress notification is teed to the message handler here. assert isinstance(msg, types.ProgressNotification) + teed.append(msg) delivered.set() async with raw_client_session(message_handler=handler) as (session, to_client, from_client): async with anyio.create_task_group() as tg: async def call() -> None: - await session.send_request(types.PingRequest(), types.EmptyResult, progress_callback=boom) + await session.send_request(types.PingRequest(), types.EmptyResult, progress_callback=on_progress) tg.start_soon(call) request = await from_client.receive() assert isinstance(request.message, JSONRPCRequest) + request_id = request.message.id # The request id doubles as the progress token. - params = {"progressToken": request.message.id, "progress": 0.5} + params = {"progressToken": request_id, "progress": 0.5, "total": 1.0, "message": "halfway"} await to_client.send( SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/progress", params=params)) ) - # The progress notification also reaches the message handler; the - # raising callback was swallowed and logged. + await progressed.wait() await delivered.wait() - await to_client.send(SessionMessage(JSONRPCResponse(jsonrpc="2.0", id=request.message.id, result={}))) - assert "progress callback raised" in caplog.text + await to_client.send(SessionMessage(JSONRPCResponse(jsonrpc="2.0", id=request_id, result={}))) + assert updates == [(0.5, 1.0, "halfway")] + assert request_id is not None + assert len(teed) == 1 + assert teed[0].params == types.ProgressNotificationParams( + progress_token=request_id, progress=0.5, total=1.0, message="halfway" + ) @pytest.mark.anyio async def test_dispatcher_keyword_runs_over_direct_dispatch(): """A session built with dispatcher= works without a stream pair (in-process embedding).""" - from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair - from mcp.shared.dispatcher import DispatchContext - from mcp.shared.transport_context import TransportContext - client_side, server_side = create_direct_dispatcher_pair() async def server_on_request( @@ -939,9 +1013,55 @@ async def server_on_notify( assert notified == ["notifications/roots/list_changed"] -def test_constructor_rejects_streams_and_dispatcher_together(): - from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair +@pytest.mark.anyio +async def test_initialize_opts_out_of_cancel_on_abandon_while_other_requests_leave_it_unset(): + """`send_request` passes `cancel_on_abandon=False` to the dispatcher for + `initialize` — the spec forbids cancelling it — and leaves the option unset + for every other method. Both dispatcher abandon arms (timeout and caller + cancel) key off this one option; the suppression mechanics are pinned in + tests/shared/test_jsonrpc_dispatcher.py.""" + + class RecordingDispatcher: + """Records `send_raw_request` opts and answers with canned results.""" + + def __init__(self) -> None: + self.calls: list[tuple[str, CallOptions]] = [] + + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: + task_status.started() + await anyio.sleep_forever() + + async def send_raw_request( + self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None + ) -> dict[str, Any]: + self.calls.append((method, opts or {})) + if method == "initialize": + return InitializeResult( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities(), + server_info=Implementation(name="mock-server", version="0.1.0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True) + return {} + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + pass + + dispatcher = RecordingDispatcher() + async with ClientSession(dispatcher=dispatcher) as session: + await session.initialize() + await session.send_ping() + opts_by_method = dict(dispatcher.calls) + assert opts_by_method["initialize"].get("cancel_on_abandon") is False + assert "cancel_on_abandon" not in opts_by_method["ping"] + +def test_constructor_rejects_streams_and_dispatcher_together(): client_side, _server_side = create_direct_dispatcher_pair() s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) with pytest.raises(ValueError, match="not both"): @@ -960,11 +1080,49 @@ def test_constructor_requires_both_streams_without_dispatcher(): s2c_recv.close() +@pytest.mark.anyio +async def test_aenter_cancelled_while_dispatcher_starts_unwinds_cleanly(): + """Cancellation landing while `__aenter__` waits for the dispatcher to + start (e.g. the caller wrapped connect in `move_on_after`) unwinds the + half-entered task group: the enclosing scope exits via its own deadline + instead of anyio's "exited non-innermost cancel scope" RuntimeError, and + the session is left un-entered (SDK-defined unwind path).""" + + class NeverStartsDispatcher: + """`run()` parks without ever signalling `task_status.started()`.""" + + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: + await anyio.sleep_forever() + + async def send_raw_request( + self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None + ) -> dict[str, Any]: + raise NotImplementedError + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + session = ClientSession(dispatcher=NeverStartsDispatcher()) + async with AsyncExitStack() as stack: + # The deadline only ends the wait: `start()` is parked forever (and + # `TaskGroup.__aenter__` has no checkpoint to absorb the cancellation + # earlier), so any duration is non-racy. + with anyio.move_on_after(0.01) as scope: + await stack.enter_async_context(session) + assert scope.cancelled_caught + # The failed enter must not leave the session half-entered. + assert session._task_group is None + + @pytest.mark.anyio async def test_send_request_with_server_metadata_routes_related_request_id(): """ServerMessageMetadata.related_request_id is threaded onto the outgoing message.""" - from mcp.shared.message import ServerMessageMetadata - async with raw_client_session() as (session, to_client, from_client): async with anyio.create_task_group() as tg: @@ -984,8 +1142,6 @@ async def call() -> None: @pytest.mark.anyio async def test_send_notification_with_related_request_id_attaches_metadata(): """A related_request_id on a notification rides the originating request's stream.""" - from mcp.shared.message import ServerMessageMetadata - async with raw_client_session() as (session, _to_client, from_client): await session.send_notification( types.ProgressNotification( @@ -996,3 +1152,22 @@ async def test_send_notification_with_related_request_id_attaches_metadata(): out = await from_client.receive() assert isinstance(out.metadata, ServerMessageMetadata) assert out.metadata.related_request_id == 4 + + +@pytest.mark.anyio +async def test_send_notification_with_related_request_id_zero_attaches_metadata(): + """`related_request_id=0` still routes onto the originating request's + stream: request ids are opaque and 0 is a valid one, so the session must + check `is not None` rather than truthiness (regression pin for the + falsy-zero gap). Wire-level because the metadata attachment is only + observable on the sent `SessionMessage`.""" + async with raw_client_session() as (session, _to_client, from_client): + await session.send_notification( + types.ProgressNotification( + params=types.ProgressNotificationParams(progress_token=1, progress=0.5), + ), + related_request_id=0, + ) + out = await from_client.receive() + assert isinstance(out.metadata, ServerMessageMetadata) + assert out.metadata.related_request_id == 0 From 33a5482526ef87d41e2b9ef326adbaa0093838ed Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 22:30:25 +0000 Subject: [PATCH 12/24] Pin initialize-abandon suppression and tidy the interaction suite - Add an interaction test proving a timed-out initialize sends no notifications/cancelled, and drop the stale deferred rationale for the initialize-not-cancellable requirement - Record null-id error-response drops in the requirements manifest and describe the cancelled-request error response as applying to both seats - Record wire messages only after the inner send succeeds, so a failed or cancelled write is not logged as sent - Bound the remaining indefinite waits, refresh comments stale since the courtesy-cancel change, and assert the issue-88 timeout by error code - Rewrite the suite README's concurrency section for the dispatcher model --- tests/interaction/README.md | 23 +++---- tests/interaction/_helpers.py | 5 +- tests/interaction/_requirements.py | 29 +++++++-- .../interaction/lowlevel/test_cancellation.py | 61 ++++++++++++++++++- tests/interaction/lowlevel/test_timeouts.py | 7 ++- tests/issues/test_88_random_error.py | 11 +++- 6 files changed, 112 insertions(+), 24 deletions(-) diff --git a/tests/interaction/README.md b/tests/interaction/README.md index be68c3b0f..473e79c83 100644 --- a/tests/interaction/README.md +++ b/tests/interaction/README.md @@ -193,11 +193,13 @@ many requirements at once; if the assertions would be separate, write separate t ### Notifications and concurrency -The client's receive loop dispatches each incoming message to completion before reading the next, -and the in-memory transport delivers everything on one ordered stream. Together these guarantee -that every notification a server handler emits before its response reaches the client callback -before the originating request returns — so tests collect notifications into a plain list and -assert after the call, with no synchronisation. The exceptions: +The client's dispatcher starts a task per incoming notification in arrival order but does not +await it before reading the next message, so completion order is not structural. What still +holds: the in-memory transport delivers everything on one ordered stream, and a callback that +records synchronously (no `await` before the append) finishes its scheduling slice before the +awaited request's waiter — woken strictly later — resumes. So tests whose callbacks are plain +appends may still collect into a list and assert after the call. A callback that awaits before +recording loses that ordering and must synchronise. The other exceptions: - a notification not triggered by a request the test is awaiting needs an `anyio.Event` set in the receiving handler and awaited under `anyio.fail_after(5)`; @@ -220,9 +222,8 @@ but still inside an outer `async with`, and no restructure can avoid it. A handful of `# pragma: lax no cover` markers in `src/` cover teardown exception handlers whose execution is timing-dependent under the in-process HTTP bridge — the POST-stream and -stateless-session `except Exception` handlers in `server/streamable_http*.py`, the `_terminated` -check in `message_router`, and the response-stream double-close guard in -`BaseSession._receive_loop`. `strict-no-cover` does not check `lax` lines; do not promote them to -strict `no cover` without first making the teardown ordering deterministic. The suite also relies -on a one-line `src/mcp/server/sse.py` fix (`sse_stream_reader.aclose()`) that closes a stream the -SSE leg would otherwise leak. +stateless-session `except Exception` handlers in `server/streamable_http*.py` and the +`_terminated` check in `message_router`. `strict-no-cover` does not check `lax` lines; do not +promote them to strict `no cover` without first making the teardown ordering deterministic. The +suite also relies on a one-line `src/mcp/server/sse.py` fix (`sse_stream_reader.aclose()`) that +closes a stream the SSE leg would otherwise leak. diff --git a/tests/interaction/_helpers.py b/tests/interaction/_helpers.py index 25833b0ca..0710ae937 100644 --- a/tests/interaction/_helpers.py +++ b/tests/interaction/_helpers.py @@ -67,8 +67,11 @@ def __init__(self, inner: WriteStream[SessionMessage], log: list[SessionMessage] self._log = log async def send(self, item: SessionMessage, /) -> None: - self._log.append(item) + # Record only after the inner send returns: a send that raises (or is cancelled + # mid-write) never reached the transport, so logging it first would fabricate a + # "sent" message in wire-level assertions. await self._inner.send(item) + self._log.append(item) async def aclose(self) -> None: await self._inner.aclose() diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 6bfa1dcbe..acaef072c 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -268,18 +268,15 @@ def __post_init__(self) -> None: divergence=Divergence( note=( "The spec says receivers of a cancellation SHOULD NOT send a response for the cancelled " - "request; the server sends an error response (code 0, 'Request cancelled'), which is what " - "unblocks the SDK client's pending call." + "request; both seats send an error response (code 0, 'Request cancelled') instead — the " + "server for cancelled client requests, and the client for cancelled server-initiated " + "requests — which is what unblocks the sender's pending call." ), ), ), "protocol:cancel:initialize-not-cancellable": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", behavior="The client never sends notifications/cancelled for the initialize request.", - deferred=( - "Not implemented in the SDK: the client has no public cancellation API at all, so no pathway " - "exists that could cancel initialize; there is no distinct behaviour to pin beyond that absence." - ), ), "protocol:cancel:late-response-ignored": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/cancellation#behavior-requirements", @@ -342,6 +339,26 @@ def __post_init__(self) -> None: source=f"{SPEC_BASE_URL}/basic#responses", behavior="A request whose method has no registered handler is answered with a METHOD_NOT_FOUND error.", ), + "protocol:error:null-id": Requirement( + source="sdk", + behavior=( + "An error response carrying a null id — the JSON-RPC shape for a peer reporting a failure it " + "could not attribute to a request, such as a parse error — is surfaced to the application " + "rather than silently discarded." + ), + divergence=Divergence( + note=( + "The dispatcher drops null-id error responses with a debug log; v1 surfaced them to " + "message_handler as an MCPError. A typed fault channel restoring visibility is planned " + "before v2 stable." + ), + ), + deferred=( + "Not yet covered here: the current drop is pinned at the dispatcher level by " + "tests/shared/test_jsonrpc_dispatcher.py; an interaction-level test waits on the planned " + "fault channel." + ), + ), "protocol:meta:related-task": Requirement( source=f"{SPEC_BASE_URL}/basic/utilities/tasks#related-task-metadata", behavior="Messages may carry related-task _meta associating them with a task.", diff --git a/tests/interaction/lowlevel/test_cancellation.py b/tests/interaction/lowlevel/test_cancellation.py index 60ca80b41..03cf20a84 100644 --- a/tests/interaction/lowlevel/test_cancellation.py +++ b/tests/interaction/lowlevel/test_cancellation.py @@ -16,6 +16,7 @@ from mcp.shared.memory import MessageStream, create_client_server_memory_streams from mcp.shared.message import SessionMessage from mcp.types import ( + REQUEST_TIMEOUT, CallToolResult, EmptyResult, ErrorData, @@ -196,7 +197,8 @@ async def sample() -> None: raise NotImplementedError # unreachable: the scope is cancelled abandon_scope.start_soon(sample) - await callback_started.wait() + with anyio.fail_after(5): + await callback_started.wait() abandon_scope.cancel_scope.cancel() with anyio.fail_after(5): await callback_cancelled.wait() @@ -284,3 +286,60 @@ async def message_handler(message: IncomingMessage) -> None: assert pong == snapshot(EmptyResult()) # The fabricated response was dropped silently: the ping after it still # round-tripped, and the message handler (a tripwire) was never invoked. + + +@requirement("protocol:cancel:initialize-not-cancellable") +async def test_timed_out_initialize_sends_no_cancellation() -> None: + """An abandoned initialize is not followed by notifications/cancelled on the wire. + + Spec-mandated: the initialize request MUST NOT be cancelled. Abandoning any other request + sends a courtesy notifications/cancelled (see protocol:timeout:sends-cancellation); this + test pins that initialize opts out. A real Server always answers initialize, so the test + plays a stalling server by hand: it never answers initialize, the client's read timeout + fires, and the ping the test sends next is the marker — the in-memory stream is ordered and + a courtesy cancel goes out before the timeout error reaches the caller, so a regression + would put notifications/cancelled ahead of the ping in the recorded sequence. + """ + received_methods: list[str] = [] + + async def scripted_server(streams: MessageStream) -> None: + server_read, server_write = streams + + # Hold the initialize request unanswered until the client's read timeout fires. + init = await server_read.receive() + assert isinstance(init, SessionMessage) + assert isinstance(init.message, JSONRPCRequest) + received_methods.append(init.message.method) + + follow_up = await server_read.receive() + assert isinstance(follow_up, SessionMessage) + assert isinstance(follow_up.message, JSONRPCRequest) + received_methods.append(follow_up.message.method) + await server_write.send( + SessionMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=follow_up.message.id, + # Serialized exactly as a real server serializes results onto the wire. + result=EmptyResult().model_dump(by_alias=True, mode="json", exclude_none=True), + ) + ) + ) + + async with ( + create_client_server_memory_streams() as ((client_read, client_write), server_streams), + anyio.create_task_group() as task_group, + # The session-level read timeout is the only public pathway that abandons initialize; + # the response never arrives, so any positive value fires on the next event-loop pass. + ClientSession(client_read, client_write, read_timeout_seconds=0.000001) as session, + ): + task_group.start_soon(scripted_server, server_streams) + with anyio.fail_after(5): + with pytest.raises(MCPError) as exc_info: + await session.initialize() + assert exc_info.value.error.code == REQUEST_TIMEOUT + # Override the session-level timeout: this ping must round-trip normally. + pong = await session.send_request(PingRequest(), EmptyResult, request_read_timeout_seconds=5) + + assert pong == snapshot(EmptyResult()) + assert received_methods == snapshot(["initialize", "ping"]) diff --git a/tests/interaction/lowlevel/test_timeouts.py b/tests/interaction/lowlevel/test_timeouts.py index 62caf7e81..e79d3303b 100644 --- a/tests/interaction/lowlevel/test_timeouts.py +++ b/tests/interaction/lowlevel/test_timeouts.py @@ -107,7 +107,8 @@ async def sampling_callback( context: ClientRequestContext, params: types.CreateMessageRequestParams ) -> types.CreateMessageResult: callback_started.set() - await release.wait() + with anyio.fail_after(5): + await release.wait() return types.CreateMessageResult(role="assistant", content=TextContent(text="too late"), model="test-model") async with Client(recording, sampling_callback=sampling_callback) as client: @@ -146,7 +147,7 @@ async def list_tools( async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: if params.name == "echo": return CallToolResult(content=[TextContent(text="still alive")]) - await anyio.Event().wait() # blocks until the session is torn down + await anyio.Event().wait() # blocks until the courtesy cancellation interrupts it raise NotImplementedError # unreachable server = Server("blocker", on_list_tools=list_tools, on_call_tool=call_tool) @@ -178,7 +179,7 @@ async def test_session_level_timeout_applies_to_every_request() -> None: async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> CallToolResult: assert params.name == "block" - await anyio.Event().wait() # blocks until the session is torn down + await anyio.Event().wait() # blocks until the courtesy cancellation interrupts it raise NotImplementedError # unreachable server = Server("blocker", on_call_tool=call_tool) diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index 84c16430f..c3f38b708 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -12,7 +12,14 @@ from mcp.server import Server, ServerRequestContext from mcp.shared.exceptions import MCPError from mcp.shared.message import SessionMessage -from mcp.types import CallToolRequestParams, CallToolResult, ListToolsResult, PaginatedRequestParams, TextContent +from mcp.types import ( + REQUEST_TIMEOUT, + CallToolRequestParams, + CallToolResult, + ListToolsResult, + PaginatedRequestParams, + TextContent, +) @pytest.mark.anyio @@ -97,7 +104,7 @@ async def client( # Use very small timeout to trigger quickly without waiting with pytest.raises(MCPError) as exc_info: await session.call_tool("slow", read_timeout_seconds=0.000001) # artificial timeout that always fails - assert "timed out" in str(exc_info.value) + assert exc_info.value.error.code == REQUEST_TIMEOUT # No-op if the courtesy cancellation already interrupted the handler. slow_request_lock.set() From 8786a52efe2b75104520a595ca434710d54e5d0e Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 22:30:35 +0000 Subject: [PATCH 13/24] Exercise the standalone SSE teardown window deterministically Drive the between-dequeues teardown path directly through the transport's ASGI entry point with a gated send, so the ClosedResourceError arm is covered by a real test and no longer needs its coverage pragma. The e2e teardown test's docstring now claims only what its assertion proves. --- src/mcp/server/streamable_http.py | 4 +- tests/shared/test_streamable_http.py | 108 +++++++++++++++++++++++++-- 2 files changed, 105 insertions(+), 7 deletions(-) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index f269fc6c4..33570c02f 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -717,11 +717,11 @@ async def standalone_sse_writer(): # Send the message via SSE event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) - except anyio.ClosedResourceError: # pragma: lax no cover + except anyio.ClosedResourceError: # Teardown completed while the writer was between dequeues: # the next receive() hits the closed stream. A writer parked # in receive() instead sees a clean end-of-stream (cleanup - # closes the send side first), so this arm is timing-dependent. + # closes the send side first). pass except Exception: logger.exception("Error in standalone SSE writer") # pragma: no cover diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index c23ade097..fe6b98711 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -18,10 +18,12 @@ import anyio import httpx import pytest +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from httpx_sse import ServerSentEvent from starlette.applications import Starlette from starlette.requests import Request from starlette.routing import Mount +from starlette.types import Message, Scope from mcp import MCPError, types from mcp.client.session import ClientSession @@ -2231,11 +2233,10 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers(context_ async def test_standalone_stream_teardown_mid_listen_is_not_an_error(caplog: pytest.LogCaptureFixture) -> None: """Tearing down the standalone stream under its parked writer produces no error log. - Cleanup closes the send side first, so a writer parked in receive() ends on a clean - end-of-stream. This pins that close ordering: reversing it would wake the parked writer - with ClosedResourceError on every disconnect. (The timing window where teardown lands - between dequeues is handled by the writer's ClosedResourceError arm, which cannot be - forced deterministically from the public surface.) + SDK-defined teardown behavior, driven through the full client/server path: the writer + is parked in receive() when teardown lands, and ends quietly. The companion test + test_standalone_stream_teardown_between_dequeues_is_not_an_error forces the other + teardown window, which this path cannot reach deterministically. """ session_manager = StreamableHTTPSessionManager( app=_create_server(), @@ -2267,3 +2268,100 @@ async def message_handler( (transport,) = session_manager._server_instances.values() # pyright: ignore[reportPrivateUsage] await transport._clean_up_memory_streams(GET_STREAM_KEY) # pyright: ignore[reportPrivateUsage] assert "Error in standalone SSE writer" not in caplog.text + + +@pytest.mark.anyio +async def test_standalone_stream_teardown_between_dequeues_is_not_an_error( + caplog: pytest.LogCaptureFixture, +) -> None: + """Teardown landing while the standalone writer is between dequeues produces no error log. + + SDK-defined: after teardown, the writer's next dequeue hits its own closed stream + (ClosedResourceError), which is expected disconnect noise, not an error. The public + surface cannot force this window (the in-process client consumes SSE without + backpressure, so the writer is always parked in receive() when teardown runs), so this + drives the transport's ASGI entry point directly with a gated `send`. + + Steps: + 1. A GET establishes the standalone SSE stream; the gated ASGI send keeps the + response from consuming any SSE data. + 2. An event sent into the standalone stream rendezvouses with the writer's receive(), + which then blocks forwarding it to the un-consumed SSE stream -- the + between-dequeues window. + 3. Stream cleanup runs inside that window, closing both standalone stream ends. + 4. The gate opens: the event reaches the wire, the writer's next dequeue hits the + closed stream, and the response completes cleanly with nothing logged as an error. + """ + transport = StreamableHTTPServerTransport( + mcp_session_id=None, + security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False), + ) + # The GET handler only checks that a read-stream writer exists; the standalone + # writer never touches it. + read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) + transport._read_stream_writer = read_stream_writer # pyright: ignore[reportPrivateUsage] + + stream_registered = anyio.Event() + + class SignalingStreams( + dict[types.RequestId, tuple[MemoryObjectSendStream[EventMessage], MemoryObjectReceiveStream[EventMessage]]] + ): + # Only the GET handler inserts here, so any insert is the standalone stream + # registration the test is waiting on. + def __setitem__( + self, + key: types.RequestId, + value: tuple[MemoryObjectSendStream[EventMessage], MemoryObjectReceiveStream[EventMessage]], + ) -> None: + super().__setitem__(key, value) + stream_registered.set() + + transport._request_streams = SignalingStreams() # pyright: ignore[reportPrivateUsage] + + gate = anyio.Event() + sent: list[Message] = [] + + async def asgi_send(message: Message) -> None: + sent.append(message) + await gate.wait() + + # Never delivers anything: parks the response's disconnect listener until the + # completed response cancels it. + disconnect_send, disconnect_receive = anyio.create_memory_object_stream[Message](0) + + async def asgi_receive() -> Message: + return await disconnect_receive.receive() + + scope: Scope = { + "type": "http", + "method": "GET", + "path": "/mcp", + "query_string": b"", + "headers": [(b"accept", b"text/event-stream")], + } + notification = types.JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized") + + async with read_stream_writer, read_stream, disconnect_send, disconnect_receive: + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + tg.start_soon(transport.handle_request, scope, asgi_receive, asgi_send) + await stream_registered.wait() + standalone_send = transport._request_streams[GET_STREAM_KEY][0] # pyright: ignore[reportPrivateUsage] + # Zero-buffer rendezvous: send() returns only once the writer's receive() + # has taken the event, so the writer is now between dequeues, blocked + # forwarding to the SSE stream nothing consumes while the gate is closed. + await standalone_send.send(EventMessage(notification)) + await transport._clean_up_memory_streams(GET_STREAM_KEY) # pyright: ignore[reportPrivateUsage] + # Unblock the response: it consumes the forwarded event, and the writer's + # next dequeue hits its closed stream. + gate.set() + + # The event dequeued before teardown still reached the wire, and the response + # ended with a normal completion rather than an exception. + assert sent[0]["type"] == "http.response.start" + assert sent[0]["status"] == 200 + body_chunks = [message for message in sent if message["type"] == "http.response.body"] + assert b"notifications/initialized" in body_chunks[0]["body"] + assert body_chunks[-1] == {"type": "http.response.body", "body": b"", "more_body": False} + assert "Error in standalone SSE writer" not in caplog.text + assert "Error in standalone SSE response" not in caplog.text From 0d4762f442b374d55839583a3c99cc7e14625ce8 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 11 Jun 2026 22:30:45 +0000 Subject: [PATCH 14/24] Document the dispatcher behavior changes in the migration guide --- docs/migration.md | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index 1b09b27c7..9dba4bcc8 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -1170,14 +1170,22 @@ In practice, replace direct `ServerSession` use with `Server.run(read_stream, wr `ClientSession` keeps its public surface — the `(read_stream, write_stream, ...)` constructor, every typed method, manual `initialize()`, and the async context-manager lifecycle — but the v1 receive loop (`BaseSession`) underneath it is gone. A new keyword-only `dispatcher=` constructor argument accepts a pre-built dispatcher instead of the stream pair (for example a `DirectDispatcher` for in-process embedding). +Code that imported or subclassed `BaseSession` directly has no shim — the class is removed outright. The receive-loop engine it implemented now lives in `JSONRPCDispatcher` (`mcp.shared.jsonrpc_dispatcher`); to customize client behavior, use the `ClientSession` constructor callbacks, or supply your own engine through the `dispatcher=` keyword. + Behavior changes: - **Request ids count from 1** (previously 0). Progress tokens, which reuse the request id, shift the same way. Ids are opaque per JSON-RPC; do not assign meaning to them. - **Timeouts**: the error message is now `Request 'tools/call' timed out` (previously `Timed out while waiting for response to CallToolRequest. Waited N seconds.`), and a timed-out or abandoned request is followed by `notifications/cancelled` on the wire, so the server stops the handler instead of leaving it running. The `initialize` request is never cancelled this way, and requests sent with resumption metadata are also exempt so they stay resumable. +- **No cancellation for requests that never reached the wire.** A timed-out or caller-cancelled request whose initial write never completed is failed locally without `notifications/cancelled` — the peer never saw the id, so there is nothing to cancel. +- **The resumption exemption applies only when the hints reach the transport.** A request sent from inside a request callback carries stream-routing metadata that takes precedence, so its resumption hints are dropped — and an abandoned one gets the courtesy `notifications/cancelled` like any other request. - **Server-initiated requests run concurrently.** Sampling/elicitation/roots callbacks no longer serialize the receive loop: a slow callback does not block other traffic, a callback may itself send requests without deadlocking, and a server's `notifications/cancelled` now actually interrupts the callback (the request is then answered with an error response). -- **Notification callbacks are concurrent.** `logging_callback` and `message_handler` start in arrival order, but there is no completion-before-response guarantee (matching the TypeScript, C#, and Go SDKs). Callbacks that need strict sequencing must coordinate themselves. +- **Session shutdown answers in-flight server-initiated requests with `CONNECTION_CLOSED`** (-32000, `Connection closed`) instead of -32002. The write is bounded (about one second), so closing a session stays fast even when the transport has stopped accepting writes. +- **The `REQUEST_CANCELLED` constant is removed from `mcp.types`.** Its value (-32002) collided with the spec's resource-not-found error code, and the shutdown response above was its only use. +- **Notification callbacks are concurrent.** `logging_callback`, `progress_callback`, and `message_handler` start in arrival order, but each delivery runs as its own task with no completion-before-response guarantee (matching the TypeScript, C#, and Go SDKs): deliveries may interleave, and a `progress_callback` delivery may finish after the request it reports on has returned. Callbacks that need strict sequencing must coordinate themselves. +- **Transport-level `Exception` items are delivered concurrently too.** An `Exception` the transport places on the read stream is dispatched to `message_handler` as its own task, like notification callbacks, instead of blocking the receive loop — and a `message_handler` that raises on it is logged, not fatal to the session. - **Unknown-id responses are ignored**, as the spec asks. v1 surfaced them to `message_handler` as a `RuntimeError`; nothing is surfaced now. -- **A raising request callback** is answered with `code=0` and the exception text. v1 flattened every callback exception to `INVALID_PARAMS`. Callbacks that want a specific error response should return `ErrorData` (unchanged) or raise `MCPError`. +- **Error responses with a null `id`** — the JSON-RPC shape for a peer reporting a parse error — are now dropped with a debug log. v1 surfaced them to `message_handler` as an `MCPError`. +- **A raising request callback** is answered with `code=0` and the exception text. v1 flattened every callback exception to `INVALID_PARAMS`. Callbacks that want a specific error response should return `ErrorData` (unchanged) or raise `MCPError`. One carve-out: a callback that raises pydantic's `ValidationError` is still answered with `INVALID_PARAMS` (`"Invalid request parameters"`, empty `data`) because the dispatcher cannot distinguish it from inbound-params validation — this conflation is pre-existing v1 behavior, and a revisit is pending. - **`send_request` before entering the context manager** raises `RuntimeError` immediately; v1 wrote to the transport and hung until the timeout. `send_notification` before entry still works. `mcp.shared.session` is now a compatibility module: `ProgressFnT` is re-exported (its home is `mcp.shared.dispatcher`), and `RequestResponder` remains as a typing-only stub so `MessageHandlerFnT` annotations keep importing — it has been unreachable at runtime since the server-side swap. `RequestResponder.respond()` no longer exists. From 9d4e02e98dc2b0a71d25a8c29710e1442ce8a6a6 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 12 Jun 2026 10:22:26 +0000 Subject: [PATCH 15/24] Tighten comments and docstrings across the dispatcher swap Cut the comment and docstring volume roughly in half: single-sentence docstring summaries, Raises sections kept but shortened, inline narration replaced by one-line statements of the non-inferable constraint, and development-artifact comments removed. No code changes. --- src/mcp/client/session.py | 90 ++--- src/mcp/server/streamable_http.py | 5 +- src/mcp/shared/dispatcher.py | 21 +- src/mcp/shared/jsonrpc_dispatcher.py | 374 +++++------------- src/mcp/shared/session.py | 15 +- tests/client/test_session.py | 95 ++--- tests/interaction/_helpers.py | 4 +- .../interaction/lowlevel/test_cancellation.py | 28 +- tests/interaction/lowlevel/test_logging.py | 13 +- tests/interaction/lowlevel/test_timeouts.py | 14 +- tests/issues/test_88_random_error.py | 3 +- tests/shared/test_jsonrpc_dispatcher.py | 282 ++++--------- tests/shared/test_streamable_http.py | 52 +-- 13 files changed, 257 insertions(+), 739 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 9b4494997..30fba3a92 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -120,19 +120,11 @@ class ClientSession: """Client half of an MCP connection, running on a `Dispatcher`. Construct it over a transport's stream pair (or pass a pre-built - `dispatcher=` instead, e.g. a `DirectDispatcher` for in-process - embedding), enter it as an async context manager, then call - `initialize()`. The receive loop, request correlation, and per-request - concurrency live in the dispatcher; this class owns the MCP type layer: - typed requests, the initialize handshake, and routing server-initiated - traffic to the constructor callbacks. - - Transport-level `Exception` items reach `message_handler` only when the - session builds its own dispatcher from streams, where it wires the - dispatcher's `on_stream_exception` itself. Faults are delivered - concurrently in the session's task group, like notifications — never - inline in the read loop — so the handler may await session I/O, and one - that raises costs that delivery, not the connection. + `dispatcher=`), enter as an async context manager, then call + `initialize()`. The dispatcher owns the receive loop and request + correlation; this class owns the typed MCP layer and the constructor + callbacks. Transport `Exception` items reach `message_handler` only when + the session builds its own dispatcher from a stream pair. """ def __init__( @@ -168,8 +160,7 @@ def __init__( else: if read_stream is None or write_stream is None: raise ValueError("read_stream and write_stream are required when no dispatcher is given") - # Built here (inert until run() starts in __aenter__) so notifications - # can be sent before entering the context manager, as before. + # Built eagerly so notifications can be sent before entering the context manager. self._dispatcher = JSONRPCDispatcher( read_stream, write_stream, on_stream_exception=self._on_stream_exception ) @@ -180,20 +171,14 @@ async def __aenter__(self) -> Self: try: await self._task_group.start(self._dispatcher.run, self._on_request, self._on_notify) except BaseException: - # A cancellation landing here (e.g. the caller wrapped connect in - # `move_on_after`) would abandon the entered task group, and anyio - # later raises "exited non-innermost cancel scope" instead of a - # clean timeout. Unwind the group before propagating; cancelling - # its scope first keeps __aexit__ from blocking under the - # still-active cancellation. + # Unwind the entered task group before propagating: a cancellation + # landing here (e.g. `move_on_after` around connect) would abandon + # it and anyio would later raise "exited non-innermost cancel scope". task_group = self._task_group self._task_group = None task_group.cancel_scope.cancel() - # Shield the group's own scope (not a new one: scope exits must - # stay LIFO) so a pending outer cancellation cannot re-fire - # inside __aexit__; the join is prompt because the scope is - # cancelled. The original exception then propagates from the - # `raise`; a child error supersedes it, raised by __aexit__. + # Shield the group's own scope (a new one would break LIFO exit) + # so a pending outer cancellation cannot re-fire inside __aexit__. task_group.cancel_scope.shield = True await task_group.__aexit__(None, None, None) raise @@ -205,8 +190,7 @@ async def __aexit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> bool | None: - # Exit must not block: cancel the dispatcher and any in-flight - # callbacks rather than waiting for them. + # Exit must not block: cancel the dispatcher and in-flight callbacks. assert self._task_group is not None self._task_group.cancel_scope.cancel() result = await self._task_group.__aexit__(exc_type, exc_val, exc_tb) @@ -223,18 +207,14 @@ async def send_request( ) -> ReceiveResultT: """Send a request and wait for its typed result. - A per-request read timeout takes precedence over the session-level - one. `metadata` carries transport hints: `ClientMessageMetadata` - resumption fields (streamable HTTP), or a - `ServerMessageMetadata.related_request_id` to route the message onto - an originating request's stream. + Args: + metadata: Transport hints: `ClientMessageMetadata` resumption fields + (streamable HTTP), or a `ServerMessageMetadata.related_request_id` + routing the message onto the originating request's stream. Raises: - MCPError: The server responded with an error, or the read timeout - elapsed, or the connection closed while sending or waiting. - RuntimeError: Called before entering the context manager. Raised - by the stream-built dispatcher; a user-supplied `dispatcher=` - may not enforce this. + MCPError: Error response, read timeout, or connection closed. + RuntimeError: Called before entering the context manager. """ data = request.model_dump(by_alias=True, mode="json", exclude_none=True) method: str = data["method"] @@ -253,12 +233,10 @@ async def send_request( elif isinstance(metadata, ServerMessageMetadata): related_request_id = metadata.related_request_id if method == "initialize": - # The spec forbids cancelling initialize; opt out of the - # dispatcher's courtesy cancel-on-abandon. + # The spec forbids cancelling initialize. opts["cancel_on_abandon"] = False if related_request_id is not None and isinstance(self._dispatcher, JSONRPCDispatcher): - # Related-request routing is JSON-RPC stream plumbing; other - # dispatchers have no per-request streams to route onto. + # Only JSON-RPC dispatchers have per-request streams to route onto. raw = await self._dispatcher.send_raw_request( method, data.get("params"), opts, _related_request_id=related_request_id ) @@ -273,7 +251,7 @@ async def send_notification( ) -> None: """Send a one-way notification. Usable before entering the context manager.""" data = notification.model_dump(by_alias=True, mode="json", exclude_none=True) - # `is not None`, not truthiness: request ids are opaque and 0 is valid. + # `is not None`: request ids are opaque and 0 is valid. if related_request_id is not None and isinstance(self._dispatcher, JSONRPCDispatcher): await self._dispatcher.notify(data["method"], data.get("params"), _related_request_id=related_request_id) else: @@ -529,17 +507,8 @@ async def send_roots_list_changed(self) -> None: async def _on_request( self, dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None ) -> dict[str, Any]: - """Answer a server-initiated request via the registered callbacks. - - An unknown method raises `MCPError` (METHOD_NOT_FOUND), which the - dispatcher puts on the wire as-is; malformed params for a known method - raise `ValidationError`, which the dispatcher answers with - INVALID_PARAMS; an `ErrorData` returned by a callback becomes the - error response. - """ + """Answer a server-initiated request via the registered callbacks.""" if method not in _SERVER_REQUEST_METHODS: - # Unknown methods are METHOD_NOT_FOUND (-32601) per JSON-RPC 2.0, - # not validation failures (-32602). raise MCPError(code=types.METHOD_NOT_FOUND, message="Method not found", data=method) payload: dict[str, Any] = {"method": method} if params is not None: @@ -577,25 +546,18 @@ async def _on_notify( logger.warning("Failed to validate notification: %s", payload, exc_info=True) return if isinstance(notification, types.CancelledNotification): - # The dispatcher already applied the cancellation to the in-flight - # request; message_handler never sees it, so handlers matching - # exhaustively over ServerNotification need no arm for it. + # The dispatcher already applied the cancellation; not surfaced to message_handler. return if isinstance(notification, types.LoggingMessageNotification): await self._logging_callback(notification.params) await self._message_handler(notification) async def _on_stream_exception(self, exc: Exception) -> None: - """Spawn delivery of a transport-level fault (connection error, parse error) to message_handler. + """Deliver a transport-level fault to message_handler via a spawned task. - The dispatcher awaits this observer inline in its read loop, so the - handler must not run here: a slow handler would head-of-line block the - session, and one that awaits session I/O (e.g. sends a ping) would - deadlock against the parked loop. Spawn it instead, with the same - containment notification deliveries get. + Running the handler inline would park the dispatcher's read loop and + deadlock handlers that await session I/O. """ - # The dispatcher only runs inside the task group entered in - # __aenter__, so the group is always live when it calls back here. assert self._task_group is not None self._task_group.start_soon(self._deliver_stream_exception, exc) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 33570c02f..93904d6cc 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -718,10 +718,7 @@ async def standalone_sse_writer(): event_data = self._create_event_data(event_message) await sse_stream_writer.send(event_data) except anyio.ClosedResourceError: - # Teardown completed while the writer was between dequeues: - # the next receive() hits the closed stream. A writer parked - # in receive() instead sees a clean end-of-stream (cleanup - # closes the send side first). + # Session teardown can close the stream while the writer is between dequeues. pass except Exception: logger.exception("Error in standalone SSE writer") # pragma: no cover diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py index 820422e6d..888e55ba3 100644 --- a/src/mcp/shared/dispatcher.py +++ b/src/mcp/shared/dispatcher.py @@ -56,16 +56,10 @@ class CallOptions(TypedDict, total=False): """Seconds to wait for a result before raising and sending `notifications/cancelled`.""" cancel_on_abandon: bool - """Whether abandoning this request sends `notifications/cancelled` to the peer. - - A request is abandoned when its `timeout` elapses or the caller's scope is - cancelled while awaiting the response. Defaults to `True`. Set `False` for - requests the protocol forbids cancelling, such as `initialize`. The - notification is also suppressed when resumption hints actually reach the - transport (the caller intends to resume the request, so the peer's work - must keep running); hints ignored in favor of dispatch-context routing do - not suppress it. No notification is sent for a request that was never - written to the transport. + """Whether abandoning this request (timeout or caller cancellation) sends `notifications/cancelled`. + + Defaults to `True`. Set `False` for requests the protocol forbids cancelling, such as `initialize`. + Also suppressed when resumption hints reach the transport, or when the request was never written. """ on_progress: ProgressFnT @@ -110,9 +104,6 @@ async def send_raw_request( ) -> dict[str, Any]: """Send a request and await its raw result dict. - `opts` carries per-call `timeout` / `on_progress` / abandon-cancellation - / resumption hints; see `CallOptions`. - Raises: MCPError: If the peer responded with an error, or the handler raised. Implementations normalize all handler exceptions to @@ -201,9 +192,7 @@ class Dispatcher(Outbound, Protocol[TransportT_co]): Implementations own correlation of outbound requests to inbound results, the receive loop, per-request concurrency, and cancellation/progress wiring. - The protocol's lifecycle surface is provisional and expected to change - before v2 stable (`run()` may be superseded by an `open()`/`wait_closed()` - pair). + The lifecycle surface is provisional; `run()` may change before v2 stable. """ async def run( diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index ef09b870d..00eb50108 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -1,21 +1,8 @@ -"""JSON-RPC `Dispatcher` implementation. - -Consumes the existing `SessionMessage`-based stream contract that all current -transports (stdio, SSE, streamable HTTP) speak. Owns request-id correlation, -the receive loop, per-request task isolation, cancellation/progress wiring, and -the single exception-to-wire boundary. - -The MCP type layer (`ServerRunner`, `Context`, `Client`) sits above this and -sees only `(ctx, method, params) -> dict`. Transports sit below and see only -`SessionMessage` reads/writes. - -The dispatcher is *mostly* MCP-agnostic - methods/params are opaque strings and -dicts - but it intercepts `notifications/cancelled` and -`notifications/progress` because request correlation, cancellation and -progress are exactly the wiring this layer exists to provide. Those few wire -shapes are extracted with structural `match` patterns (no casts, no -`mcp.types` model coupling); a malformed payload simply fails to match and -the correlation is skipped. +"""JSON-RPC `Dispatcher` over the `SessionMessage` stream contract all transports speak. + +Owns request-id correlation, the receive loop, per-request task isolation, +cancellation/progress wiring, and the single exception-to-wire boundary; +methods and params are otherwise opaque strings and dicts. """ from __future__ import annotations @@ -65,41 +52,21 @@ logger = logging.getLogger(__name__) _SHIELDED_WRITE_TIMEOUT: float = 5 -"""Bound for the courtesy writes on the timeout and cancellation paths. - -The cancellation-path writes run inside a shield because the surrounding -scope is already cancelled; without a bound, a wedged transport write would -turn the shield into an uncancellable hang (and block shutdown indefinitely). -The timeout-path courtesy cancel is unshielded (its scope is not cancelled) -but shares the bound so a wedged transport can't delay the timeout error -indefinitely.""" +"""Bound for courtesy abandon-path writes; without it a wedged transport +would turn the shielded write into an uncancellable hang.""" _SHUTDOWN_WRITE_TIMEOUT: float = 1 -"""Bound for the shutdown-arm error response write in `_handle_request`. - -Tighter than `_SHIELDED_WRITE_TIMEOUT` because session close must be quick: -the write is a courtesy answer to a request the shutdown is abandoning, so a -wedged transport may delay close by at most ~1s rather than holding teardown -for the full courtesy-cancel bound.""" +"""Tighter bound for the shutdown-arm error write so a wedged transport can't hold session close.""" TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext) PeerCancelMode = Literal["interrupt", "signal"] -"""How inbound `notifications/cancelled` is applied to a running handler. - -`"interrupt"` (default) cancels the handler's scope. `"signal"` only sets -`ctx.cancel_requested` and lets the handler observe it cooperatively. -""" +"""How `notifications/cancelled` is applied: `"interrupt"` (default) cancels +the handler's scope; `"signal"` only sets `ctx.cancel_requested`.""" def _coerce_id(request_id: RequestId) -> RequestId: - """Coerce a string request ID to int when it's a valid int literal. - - `_allocate_id` only ever produces `int` keys for `_pending`, but a peer - may echo the ID back as a JSON string. The TypeScript SDK performs this - coercion at lookup time (as v1's `BaseSession` did) so the response still - correlates. - """ + """Coerce a stringified int request ID back to int so a peer-echoed ID still correlates (matches the TS SDK).""" if isinstance(request_id, str): try: return int(request_id) @@ -133,12 +100,7 @@ class _JSONRPCDispatchContext(Generic[TransportT]): _dispatcher: JSONRPCDispatcher[TransportT] _request_id: RequestId | None message_metadata: MessageMetadata = None # TODO(maxisbey): remove for Context rework - """The transport-attached `SessionMessage.metadata` for this inbound message. - - Carries `ServerMessageMetadata` (HTTP request, SSE stream-close callbacks) - that the server lifts onto its request context. `None` for transports - that attach nothing. - """ + """Transport-attached `SessionMessage.metadata` that the server lifts onto its request context.""" _progress_token: ProgressToken | None = None _closed: bool = False cancel_requested: anyio.Event = field(default_factory=anyio.Event) @@ -186,13 +148,7 @@ def _default_transport_builder(_meta: MessageMetadata) -> TransportContext: def _shielded_progress(fn: ProgressFnT) -> ProgressFnT: - """Wrap a user progress callback so it can't crash the dispatcher. - - The callback runs as a bare task in the dispatcher's task group; an - uncaught exception would cancel every sibling (the read loop and all - in-flight requests). Swallow and log instead, matching the previous - receive-loop's behavior. - """ + """Wrap a user progress callback so an exception can't cancel the dispatcher's task group.""" async def _wrapped(progress: float, total: float | None, message: str | None) -> None: try: @@ -217,33 +173,18 @@ async def _wrapped(dctx: DispatchContext[TransportContext], method: str, params: @dataclass(slots=True, frozen=True) class _OutboundPlan: - """One decision about an outgoing message: what reaches the transport, and - whether abandoning the request sends a courtesy `notifications/cancelled`.""" + """Outbound metadata plus whether abandoning the request sends a courtesy `notifications/cancelled`.""" metadata: MessageMetadata cancel_on_abandon: bool def _plan_outbound(related_request_id: RequestId | None, opts: CallOptions | None) -> _OutboundPlan: - """Choose the `SessionMessage.metadata` for an outgoing request/notification - and the matching abandon-cancellation policy. - - `ServerMessageMetadata` tags a server-to-client message with the inbound - request it belongs to (so streamable-HTTP can route it onto that request's - SSE stream). `ClientMessageMetadata` carries resumption hints to the - client transport. `None` is the common case. - - `SessionMessage.metadata` carries exactly one of these, so when - `related_request_id` is set it takes precedence and any resumption hints - in `opts` are dropped (with a debug log): requests made from a dispatch - context are routed onto the inbound request's stream, not resumed. - - The same decision fixes `cancel_on_abandon`: an abandoned request sends a - courtesy `notifications/cancelled` unless the caller opted out, or the - resumption hints actually reach the transport (the caller intends to - resume, so the peer's work must keep running). Hints dropped here do NOT - suppress the cancel - a request that is neither resumable nor cancelled - would leak the peer's work. + """Choose the outbound `SessionMessage.metadata` and the abandon-cancellation policy. + + `related_request_id` wins over resumption hints (they are dropped). Only + hints that actually reach the transport suppress the courtesy cancel - a + request that is neither resumable nor cancelled would leak the peer's work. """ opts = opts or {} cancel_on_abandon = opts.get("cancel_on_abandon", True) @@ -264,10 +205,9 @@ def _plan_outbound(related_request_id: RequestId | None, opts: CallOptions | Non class JSONRPCDispatcher(Dispatcher[TransportT]): - """`Dispatcher` over the existing `SessionMessage` stream contract. + """`Dispatcher` over the `SessionMessage` stream contract. - Inherits the `Dispatcher` Protocol explicitly so pyright checks - conformance at the class definition rather than at first use. + Explicit Protocol base so pyright checks conformance at the class definition. """ def __init__( @@ -284,37 +224,20 @@ def __init__( """Wire a dispatcher over a transport's `SessionMessage` stream pair. Args: - read_stream: Inbound messages from the peer; `Exception` items are - transport-level read faults (see `on_stream_exception`). - write_stream: Outbound messages to the peer. - transport_builder: Builds the per-message `TransportContext` from - the inbound `SessionMessage.metadata`. Defaults to a plain - always-routable JSON-RPC context. - peer_cancel_mode: How inbound `notifications/cancelled` is applied - to a running handler; see `PeerCancelMode`. + transport_builder: Builds each message's `TransportContext` from + its `SessionMessage.metadata`. raise_handler_exceptions: Re-raise handler exceptions out of - `run()` after the error response is written, instead of - containing them at the exception-to-wire boundary. - inline_methods: Request methods handled inline in the read loop - (awaited before the next message is dequeued) instead of - spawned concurrently. Use for methods whose side effects must - be observable to the next message, e.g. `initialize`, so a - pipelined follow-up sees the initialized state. Only suitable - for handlers that complete quickly, since inline handling - blocks dequeuing; a handler that awaits the peer - (`send_raw_request`) while inline will deadlock because the - parked read loop cannot dequeue the response. - on_stream_exception: Observer for `Exception` items the transport - yields on the read stream (SSE/streamable-HTTP connection - faults, stdio parse errors). Without it they are debug-logged - and dropped. Awaited in the read loop and contained: a raising - observer costs the item, not the connection. + `run()` after the error response is written. + inline_methods: Methods awaited in the read loop before the next + message is dequeued (e.g. `initialize`); an inline handler + that awaits the peer deadlocks the parked loop. + on_stream_exception: Observer for `Exception` items on the read + stream; without it they are debug-logged and dropped. """ self._read_stream = read_stream self._write_stream = write_stream - # When `transport_builder` is omitted, `TransportT` falls back to its - # default (`TransportContext`), so the default builder is type-correct; - # pyright can't connect the two, hence the cast. + # With transport_builder omitted, TransportT defaults to + # TransportContext; pyright can't connect the two, hence the cast. self._transport_builder = cast( "Callable[[MessageMetadata], TransportT]", transport_builder or _default_transport_builder, @@ -340,19 +263,14 @@ async def send_raw_request( ) -> dict[str, Any]: """Send a JSON-RPC request and await its response. - `_related_request_id` is set only by `_JSONRPCDispatchContext` when a - handler makes a server-to-client request mid-flight; it routes the - outgoing message onto the correct per-request SSE stream (SHTTP) via - `ServerMessageMetadata`. Top-level callers leave it `None`. + `_related_request_id` is set only by `_JSONRPCDispatchContext` so that + mid-handler requests route onto the inbound request's SSE stream. Raises: - MCPError: The peer responded with a JSON-RPC error; or - `REQUEST_TIMEOUT` if `opts["timeout"]` elapsed; or - `CONNECTION_CLOSED` if the transport closed before the request - could be written, or the dispatcher shut down while awaiting - the response. - RuntimeError: Called before `run()` has started or after it has - finished. + MCPError: Peer error response; `REQUEST_TIMEOUT` if + `opts["timeout"]` elapsed; `CONNECTION_CLOSED` if the + transport closed or the dispatcher shut down. + RuntimeError: Called outside `run()`. """ if not self._running: raise RuntimeError("JSONRPCDispatcher.send_raw_request called before run() / after close") @@ -362,66 +280,45 @@ async def send_raw_request( out_meta = dict(out_params.get("_meta") or {}) on_progress = opts.get("on_progress") if on_progress is not None: - # The caller wants progress updates. The spec mechanism is: include - # `_meta.progressToken` on the request; the peer echoes that token on - # any `notifications/progress` it sends. We use the request id as the - # token so the receive loop can find this `_Pending.on_progress` by - # `_pending[token]` without a second lookup table. + # The request id doubles as the progress token, so `_pending[token]` finds `on_progress` directly. out_meta["progressToken"] = request_id out_params["_meta"] = out_meta - # buffer=1: at most one outcome is ever delivered. A `WouldBlock` from - # `_resolve_pending`/`_fan_out_closed` means the waiter already has an - # outcome and dropping the late/redundant signal is correct. buffer=0 - # is unsafe - there's a window between registering `_pending[id]` and - # parking in `receive()` where a close signal would be lost. + # buffer=1: a close signal can arrive before the waiter parks in receive(); + # a WouldBlock later just means the waiter already has its one outcome. send, receive = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) pending = _Pending(send=send, receive=receive, on_progress=on_progress) self._pending[request_id] = pending - # One decision covers both what metadata reaches the transport and - # whether abandoning this request (timeout elapsed, or the caller's - # scope cancelled while awaiting the response) sends a courtesy - # `notifications/cancelled`; see `_plan_outbound`. plan = _plan_outbound(_related_request_id, opts) # Spec MUST: only previously-issued requests may be cancelled, so the - # courtesy cancel is armed only once the request write completes - a - # caller cancelled mid-write must not announce a cancel for a request - # the peer never received. + # courtesy cancel arms only once the request write completes. request_written = False target = out_params.get("name") span_name = f"MCP send {method}{f' {target}' if isinstance(target, str) else ''}" - # TODO(maxisbey): the otel span + inject below mirror - # BaseSession.send_request for parity. They belong in an outbound - # middleware (symmetric with otel_middleware on the inbound side) once - # that seam exists; the dispatcher should not own otel. + # TODO(maxisbey): move the otel span + inject into an outbound + # middleware once that seam exists; the dispatcher should not own otel. try: with otel_span( span_name, kind=SpanKind.CLIENT, attributes={"mcp.method.name": method, "jsonrpc.request.id": str(request_id)}, ): - # Inject W3C trace context into _meta (SEP-414). With a no-op - # tracer this writes nothing, but `_meta` itself is still - # present on the wire (and the interaction suite pins that). + # SEP-414: inject W3C trace context; `_meta` stays on the wire even with a no-op tracer. inject_trace_context(out_meta) msg = JSONRPCRequest(jsonrpc="2.0", id=request_id, method=method, params=out_params) try: await self._write(msg, plan.metadata) except (anyio.BrokenResourceError, anyio.ClosedResourceError): - # The transport tore down before run() noticed EOF; surface - # the documented contract, not the raw stream error. + # Transport tore down before run() noticed EOF; surface the documented contract. raise MCPError(code=CONNECTION_CLOSED, message="Connection closed") from None request_written = True with anyio.fail_after(opts.get("timeout")): outcome = await receive.receive() except TimeoutError: - # Spec-recommended courtesy: tell the peer we've given up so it can - # stop work and free resources. v1's BaseSession.send_request does - # NOT do this; it's new behaviour. Unshielded: this scope is not - # cancelled, so an outer caller cancellation must still be able to - # interrupt the write. + # Courtesy cancel (spec-recommended, new vs v1) so the peer stops work; + # unshielded so an outer caller cancellation can still interrupt the write. if plan.cancel_on_abandon and request_written: await self._final_write( partial( @@ -436,10 +333,8 @@ async def send_raw_request( ) raise MCPError(code=REQUEST_TIMEOUT, message=f"Request {method!r} timed out") from None except anyio.get_cancelled_exc_class(): - # Our caller's scope was cancelled. We're already inside a cancelled - # scope, so any bare `await` here re-raises immediately - the - # shielded (bounded) helper lets the courtesy cancel go out before - # we propagate. + # Caller cancelled: bare awaits re-raise here, so the shielded helper + # lets the courtesy cancel go out before we propagate. if plan.cancel_on_abandon and request_written: await self._final_write( partial(self._cancel_outbound, request_id, "caller cancelled", _related_request_id), @@ -449,9 +344,7 @@ async def send_raw_request( ) raise finally: - # Always remove the waiter, even on cancel/timeout, so a late - # response from the peer (race) hits a closed stream and is dropped - # in `_dispatch` rather than leaking. + # Remove the waiter on every path so a late response is dropped, not leaked. self._pending.pop(request_id, None) send.close() receive.close() @@ -467,10 +360,8 @@ async def notify( *, _related_request_id: RequestId | None = None, ) -> None: - # Leave `params` unset (not explicitly None) when there are none: - # transports serialize with `exclude_unset=True`, and an explicit None - # would survive as `"params": null`, which JSON-RPC 2.0 forbids and - # strict peers (e.g. the TypeScript SDK's zod schemas) reject. + # Leave `params` unset when None: with `exclude_unset=True` an explicit + # None would serialize as `"params": null`, which JSON-RPC 2.0 forbids. if params is not None: msg = JSONRPCNotification(jsonrpc="2.0", method=method, params=dict(params)) else: @@ -486,10 +377,7 @@ async def run( ) -> None: """Drive the receive loop until the read stream closes. - Each inbound request is handled in its own task in an internal task - group; `task_status.started()` fires once that group is open, so - `await tg.start(dispatcher.run, ...)` resumes when `send_raw_request` - is usable. + `task_status.started()` fires once `send_raw_request` is usable. """ try: async with anyio.create_task_group() as tg: @@ -500,33 +388,24 @@ async def run( async with self._read_stream, self._write_stream: try: async for item in self._read_stream: - # Duck-typed: `_context_streams.ContextReceiveStream` - # exposes `.last_context` (the sender's contextvars - # snapshot per message). Plain memory streams don't. + # Duck-typed: only `ContextReceiveStream` carries the + # sender's per-message contextvars snapshot. sender_ctx: contextvars.Context | None = getattr( self._read_stream, "last_context", None ) await self._dispatch(item, on_request, on_notify, sender_ctx) except anyio.ClosedResourceError: - # The transport closed our receive end and we looped - # back to `__anext__` on the now-closed stream - # (stateless SHTTP teardown). Same as EOF. + # Receive end closed under us (stateless SHTTP teardown); same as EOF. logger.debug("read stream closed by transport; treating as EOF") - # Read stream EOF: wake any blocked `send_raw_request` waiters - # (callers outside this task group) with CONNECTION_CLOSED. + # EOF: wake blocked `send_raw_request` waiters with CONNECTION_CLOSED. self._running = False self._fan_out_closed() finally: - # Transport closed: cancel in-flight handlers. Without this - # the task-group join waits for them, and a handler that - # outlives its caller (its request timed out client-side, or - # the client disconnected mid-call) would keep `run()` from - # returning forever. Same behaviour as `Server.run()` before - # the dispatcher rework. + # Cancel in-flight handlers; otherwise the task-group join + # waits on handlers whose callers are already gone. tg.cancel_scope.cancel() finally: - # Covers the cancel/crash paths where the inline fan-out above is - # never reached. Idempotent. + # Covers cancel/crash paths that skip the inline fan-out; idempotent. self._running = False self._tg = None self._fan_out_closed() @@ -540,10 +419,8 @@ async def _dispatch( ) -> None: """Route one inbound item. - Everything here is `send_nowait` or `_spawn`; the only `await`s are - `inline_methods` requests and the `on_stream_exception` observer, - which deliberately block dequeuing until handled. Any other `await` - would let one slow message head-of-line block the entire read loop. + Only `inline_methods` requests and the `on_stream_exception` observer + are awaited; any other `await` would head-of-line block the read loop. """ if isinstance(item, Exception): if self._on_stream_exception is None: @@ -564,9 +441,7 @@ async def _dispatch( case JSONRPCResponse(): self._resolve_pending(msg.id, msg.result) case JSONRPCError(): # pragma: no branch - # `id` may be None per JSON-RPC (parse error before id known). - # The match is exhaustive over JSONRPCMessage; the no-match arc - # on this final case is unreachable. + # Exhaustive over JSONRPCMessage, so the no-match arc is unreachable. self._resolve_pending(msg.id, msg.error) async def _dispatch_request( @@ -578,8 +453,7 @@ async def _dispatch_request( ) -> None: progress_token: ProgressToken | None match req.params: - # The bool guard matters: `int()` patterns match bool (a subclass), - # and `True == 1` would alias dict lookups to request id 1. + # bool subclasses int: without the guard True would alias request id 1. case {"_meta": {"progressToken": str() | int() as progress_token}} if not isinstance(progress_token, bool): pass case _: @@ -587,9 +461,7 @@ async def _dispatch_request( try: transport_ctx = self._transport_builder(metadata) except Exception: - # Containment boundary for the user-supplied builder: a raising - # builder must cost only this message, not the whole connection - # (the exception would otherwise escape into run()'s read loop). + # A raising builder must cost only this message, not the connection. logger.exception("transport_builder raised; rejecting request %r", req.id) self._spawn( self._write_error, @@ -606,20 +478,13 @@ async def _dispatch_request( _progress_token=progress_token, ) scope = anyio.CancelScope() - # TODO(maxisbey): the spec puts request-id uniqueness on the sender; - # neither v1 nor the TS SDK guards a duplicate id here, so for now we - # blind-overwrite (parity). Revisit rejecting with INVALID_REQUEST. - # Coerced key so `notifications/cancelled` correlates regardless of - # whether the peer stringifies the id between request and cancel - # (`_dispatch_notification` coerces at lookup; responses still echo - # `req.id` verbatim). + # TODO(maxisbey): duplicate ids blind-overwrite (v1/TS parity); revisit + # rejecting with INVALID_REQUEST. Key coerced so a stringified + # `notifications/cancelled` id still correlates. self._in_flight[_coerce_id(req.id)] = _InFlight(scope=scope, dctx=dctx) if req.method in self._inline_methods: - # Spawn (so `sender_ctx` applies, matching the concurrent path) but - # park the read loop until the handler returns; that's the inline - # ordering guarantee. Because the read loop is parked, a handler - # that awaits the peer here (e.g. `dctx.send_raw_request`) will - # deadlock: the response can never be dequeued. + # Spawn so `sender_ctx` applies, but park the read loop until the + # handler returns - that's the inline ordering guarantee. done = anyio.Event() async def _run_inline() -> None: @@ -643,17 +508,12 @@ def _dispatch_notification( """Route one inbound notification. `notifications/cancelled` and `notifications/progress` are intercepted - here because they correlate against JSON-RPC request IDs - the - `_in_flight` / `_pending` tables this layer owns - so no higher layer - can act on them. Both are still teed to `on_notify` afterwards, so - middleware and registered notification handlers observe every inbound - notification. See the module docstring for the design rationale. + here (they correlate against the `_in_flight`/`_pending` tables this + layer owns) and still teed to `on_notify` afterwards. """ if msg.method == "notifications/cancelled": match msg.params: - # The bool guards here and below matter: `int()` patterns match - # bool (a subclass), and `True == 1` would alias the dict lookup - # to the entry keyed by request id 1. + # bool subclasses int: the guards keep True from aliasing request id 1. case {"requestId": str() | int() as rid} if ( not isinstance(rid, bool) and (in_flight := self._in_flight.get(_coerce_id(rid))) is not None ): @@ -662,9 +522,6 @@ def _dispatch_notification( in_flight.scope.cancel() case _: pass - # fall through: cancelled is also teed to on_notify so middleware - # and registered handlers can observe it (matches DirectDispatcher, - # which forwards every notification). elif msg.method == "notifications/progress": match msg.params: case {"progressToken": str() | int() as token, "progress": int() | float() as progress} if ( @@ -684,12 +541,10 @@ def _dispatch_notification( ) case _: pass - # fall through: progress is also teed to on_notify try: transport_ctx = self._transport_builder(metadata) except Exception: - # Same containment boundary as `_dispatch_request`: a raising - # builder drops this notification instead of killing the read loop. + # Same containment as `_dispatch_request`: drop the notification, keep the loop. logger.exception("transport_builder raised; dropping notification %r", msg.method) return dctx = _JSONRPCDispatchContext( @@ -715,10 +570,8 @@ def _spawn( ) -> None: """Schedule `fn(*args)` in the run() task group, propagating the sender's contextvars. - ASGI middleware (auth, OTel) sets contextvars on the request task that - wrote into the read stream. `Context.run(tg.start_soon, ...)` makes - the spawned handler inherit *that* context instead of the receive - loop's, so `auth_context_var` and OTel spans survive. + ASGI middleware (auth, OTel) sets contextvars on the task that wrote the + message; `Context.run` makes the spawned handler inherit that context. """ assert self._tg is not None if sender_ctx is not None: @@ -729,8 +582,7 @@ def _spawn( def _fan_out_closed(self) -> None: """Wake every pending `send_raw_request` waiter with `CONNECTION_CLOSED`. - Synchronous (uses `send_nowait`) because it's called from `finally` - which may be inside a cancelled scope. Idempotent. + Synchronous: callers may be inside a cancelled scope. Idempotent. """ closed = ErrorData(code=CONNECTION_CLOSED, message="Connection closed") for pending in self._pending.values(): @@ -749,47 +601,31 @@ async def _handle_request( ) -> None: """Run `on_request` for one inbound request and write its response. - This is the single exception-to-wire boundary: handler exceptions are - caught here and serialized to `JSONRPCError`. Nothing above this in - the stack constructs wire errors. + The single exception-to-wire boundary: handler exceptions become `JSONRPCError` here. """ try: with scope: try: result = await on_request(dctx, req.method, req.params) finally: - # Handler done: close the back-channel (detached work that - # later calls `dctx.send_raw_request()` should see - # `NoBackChannelError`) and drop from `_in_flight` so a - # late `notifications/cancelled` is a no-op rather than - # racing the result write below. Identity-guarded: a - # duplicate inbound id blind-overwrites the table entry - # (see `_dispatch_request`), and this pop must not evict - # the newer request's entry - that would leave it - # peer-uncancellable. No checkpoint between handler return - # and the pop, so the cancel can't interleave there. + # Close the back-channel and drop from `_in_flight`; no checkpoint + # since handler return, so a peer cancel can't interleave. + # Identity guard: don't evict a duplicate id's newer entry. dctx.close() key = _coerce_id(req.id) if (entry := self._in_flight.get(key)) is not None and entry.dctx is dctx: del self._in_flight[key] await self._write_result(req.id, result) if scope.cancelled_caught: - # Peer-cancel: `_dispatch_notification` cancelled this scope - # and the cancellation was actually absorbed at __exit__, i.e. - # the result write above did not happen. (`cancel_called` - # alone is not enough: a cancel that lands after the handler's - # last checkpoint is never delivered, the handler completes - # and the result write can succeed - answering again here - # would put a second response for an already-answered id on - # the wire when the write stream doesn't checkpoint.) - # TODO(maxisbey): spec says SHOULD NOT respond after cancel. - # The existing server always has, so match that for now. + # anyio absorbs the scope's own cancel at __exit__, and + # `cancelled_caught` (unlike `cancel_called`) guarantees the + # result write above did not happen - no double response. + # TODO(maxisbey): spec says SHOULD NOT respond after cancel; + # the existing server always has, so match that for now. await self._write_error(req.id, ErrorData(code=0, message="Request cancelled")) except anyio.get_cancelled_exc_class(): - # Outer-cancel: run()'s task group is shutting down. Answer the - # request so the peer is not left waiting on a connection that is - # going away; the helper shields (bounded) the write because any - # bare `await` here re-raises immediately. + # Shutdown: answer the request so the peer isn't left waiting; the + # shielded helper is needed because bare awaits re-raise here. await self._final_write( partial(self._write_error, req.id, ErrorData(code=CONNECTION_CLOSED, message="Connection closed")), shield=True, @@ -800,25 +636,19 @@ async def _handle_request( except MCPError as e: await self._write_error(req.id, e.error) except ValidationError: - # TODO(maxisbey): data="" is pinned compat with the existing - # server (which never leaked pydantic error text onto the wire). - # Consider putting the validation detail in `data` once the - # interaction suite's divergence entry is resolved. + # TODO(maxisbey): data="" pins existing-server compat (no pydantic + # text on the wire); revisit per the suite's divergence entry. await self._write_error( req.id, ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") ) except Exception as e: logger.exception("handler for %r raised", req.method) - # TODO(maxisbey): code=0 is pinned compat with the existing - # server's `_handle_request`. JSON-RPC says INTERNAL_ERROR - # (-32603); revisit once the suite's divergence entry is resolved. + # TODO(maxisbey): code=0 pins existing-server compat; JSON-RPC says + # INTERNAL_ERROR. Revisit per the suite's divergence entry. await self._write_error(req.id, ErrorData(code=0, message=str(e))) if self._raise_handler_exceptions: raise - # No outer `_in_flight` pop here: the inner `finally` above already - # removes the entry on every path out of the handler, and a second - # pop after the awaited response writes could evict a newer request - # that reused the id during that window. + # No `_in_flight` pop here: the inner finally covers every path, and a late pop could evict a reused id. def _allocate_id(self) -> int: self._next_id += 1 @@ -849,15 +679,9 @@ async def _final_write( ) -> None: """Attempt one last write under the shared abandon/teardown policy. - Every arm that writes to the transport after giving up on a request - (timeout courtesy cancel, caller-cancel courtesy cancel, shutdown - error response) goes through here so the bound+shield+warning policy - cannot diverge between them. `shield=True` is for arms already inside - a cancelled scope (a bare `await` would re-raise immediately); the - bound keeps a wedged transport write from turning the shield into an - uncancellable hang. An unshielded arm shares the bound so a wedged - transport can't delay its caller's error indefinitely, while staying - interruptible from outside. + `shield=True` is for arms already inside a cancelled scope (a bare + `await` would re-raise); the bound keeps a wedged transport write + from becoming an uncancellable hang. """ with anyio.move_on_after(timeout, shield=shield) as scope: await write() @@ -865,10 +689,8 @@ async def _final_write( logger.warning("%s gave up: transport write blocked", describe) async def _cancel_outbound(self, request_id: RequestId, reason: str, related_request_id: RequestId | None) -> None: - # Thread `related_request_id` so streamable-HTTP routes the cancel onto - # the same per-request SSE stream as the request it cancels; without it - # the notification falls through to the standalone GET stream and is - # dropped when no GET stream is open. + # Thread `related_request_id` so streamable HTTP routes the cancel onto + # the request's own SSE stream instead of a possibly-absent GET stream. try: await self.notify( "notifications/cancelled", diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 55710fa98..b4f0beedf 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,9 +1,4 @@ -"""Compatibility surface for the removed v1 session layer. - -`BaseSession` (the v1 receive loop) is gone: `ClientSession` runs on -`JSONRPCDispatcher` and the server side on `ServerRunner`. This module keeps -the names that outlived it. -""" +"""Compatibility names that outlived the removed v1 session layer (`BaseSession`).""" from typing import Generic, TypeVar @@ -18,13 +13,7 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): - """Typing stub for the v1 responder. - - Never instantiated by the SDK: the client answers every server request - itself, so the `RequestResponder` arm of `MessageHandlerFnT` is - unreachable. The class remains so existing annotations and imports keep - working. - """ + """Typing stub for the v1 responder; the SDK never instantiates it.""" request_id: RequestId request_meta: RequestParamsMeta | None diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 8de41978e..6d7000f62 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -757,8 +757,7 @@ async def test_receive_loop_answers_malformed_inbound_request_with_invalid_param @pytest.mark.anyio async def test_receive_loop_answers_unknown_request_method_with_method_not_found(): - """A server request whose method is not in the ServerRequest union gets -32601 - (METHOD_NOT_FOUND) on the wire, not a validation failure (-32602).""" + """An unknown request method is answered with METHOD_NOT_FOUND, not INVALID_PARAMS (spec-mandated).""" async with raw_client_session() as (_session, to_client, from_client): await to_client.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=7, method="x/unknown"))) out = await from_client.receive() @@ -769,12 +768,10 @@ async def test_receive_loop_answers_unknown_request_method_with_method_not_found @pytest.mark.anyio async def test_receive_loop_drops_unknown_notification_method_without_response(): - """An unknown notification method is dropped silently: JSON-RPC forbids - responses to notifications, and the receive loop keeps serving.""" + """An unknown notification method is dropped silently: JSON-RPC forbids responses to notifications.""" async with raw_client_session() as (_session, to_client, from_client): await to_client.send(SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="x/unknown"))) - # The next wire output must be the answer to this follow-up ping, - # proving the notification produced no response and the loop survived. + # The answered follow-up ping proves no response was emitted and the loop survived. await to_client.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"))) out = await from_client.receive() assert isinstance(out.message, JSONRPCResponse) @@ -783,11 +780,8 @@ async def test_receive_loop_drops_unknown_notification_method_without_response() @pytest.mark.anyio async def test_raising_sampling_callback_answers_with_code_zero(): - """A raising sampling callback propagates out of the session's request router - into the dispatcher's exception boundary, which answers with code 0 and - `str(exc)` (SDK-defined shape, pinned at the dispatcher in - tests/shared/test_jsonrpc_dispatcher.py). Raw streams because the assertion - is the outbound `JSONRPCError` envelope itself.""" + """A raising sampling callback is answered with code 0 and `str(exc)` (SDK-defined). + Raw streams because the assertion is the outbound `JSONRPCError` envelope itself.""" async def boom(ctx: object, params: object) -> types.CreateMessageResult: raise RuntimeError("sampling boom") @@ -807,10 +801,8 @@ async def boom(ctx: object, params: object) -> types.CreateMessageResult: @pytest.mark.anyio async def test_receive_loop_logs_and_drops_malformed_notification(caplog: pytest.LogCaptureFixture): - """A notification that fails `ServerNotification` validation is logged and - dropped without reaching `message_handler`, and the loop keeps serving - (SDK-defined). Scripted peer: the typed API cannot emit a method outside - the spec's notification union.""" + """A malformed notification is logged and dropped without reaching `message_handler` (SDK-defined). + Scripted peer: the typed API cannot emit a method outside the spec's notification union.""" seen: list[object] = [] delivered = anyio.Event() @@ -833,28 +825,22 @@ async def handler(msg: object) -> None: async def test_raising_message_handler_on_transport_exception_costs_the_delivery_not_the_connection( caplog: pytest.LogCaptureFixture, ): - """A transport-level `Exception` item on the read stream reaches - `message_handler` (SDK-defined: the stream-built session wires the - dispatcher's `on_stream_exception` to spawn handler deliveries), and a - handler that raises is contained by the session — the failure is logged - and the receive loop keeps serving, proven by a follow-up ping - round-trip. Raw streams because only a transport can put an `Exception` - item on the read stream.""" + """A `message_handler` that raises on a transport-level `Exception` item is contained: the + failure is logged and the receive loop keeps serving (SDK-defined). Raw streams because + only a transport can put an `Exception` item on the read stream.""" seen: list[object] = [] delivered = anyio.Event() async def handler(msg: object) -> None: seen.append(msg) delivered.set() - # No checkpoint between set() and the session's containment logging - # the raise, so once wait() resumes the log entry exists. + # No checkpoint between set() and the containment log, so after wait() the log entry exists. raise RuntimeError("handler boom") async with raw_client_session(message_handler=handler) as (_session, to_client, from_client): exc = ValueError("bad bytes") await to_client.send(exc) await delivered.wait() - # Loop health: a follow-up inbound ping is still answered. await to_client.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=9, method="ping"))) out = await from_client.receive() assert seen == [exc] @@ -865,16 +851,12 @@ async def handler(msg: object) -> None: @pytest.mark.anyio async def test_message_handler_awaiting_session_traffic_on_transport_exception_completes(): - """A message_handler that reacts to a transport-level `Exception` item by - awaiting session traffic (a ping round-trip) completes instead of - deadlocking: the session spawns fault deliveries into its task group - rather than running them inline in the dispatcher's read loop - (SDK-defined). Raw streams because only a transport can put an - `Exception` item on the read stream.""" + """A `message_handler` that awaits session traffic on a transport `Exception` item completes: + fault deliveries are spawned into the task group, not run inline in the read loop (SDK-defined). + Raw streams because only a transport can put an `Exception` item on the read stream.""" ponged = anyio.Event() - # The constructor takes the handler, so it is defined before the session - # exists; `session` resolves at call time, after the `as` clause binds it. + # `session` resolves at call time, after the `as` clause binds it. async def handler(msg: object) -> None: assert isinstance(msg, Exception) await session.send_ping() @@ -882,9 +864,7 @@ async def handler(msg: object) -> None: async with raw_client_session(message_handler=handler) as (session, to_client, from_client): await to_client.send(ValueError("bad bytes")) - # Serve the handler's ping like a transport would. Pre-spawn this - # deadlocked: the read loop was parked inside the handler, so the - # response below could never be dequeued. + # Serve the handler's ping like a transport would; inline delivery would deadlock here. out = await from_client.receive() assert isinstance(out.message, JSONRPCRequest) assert out.message.method == "ping" @@ -899,8 +879,7 @@ async def test_receive_loop_consumes_server_cancelled_without_reaching_message_h The server dispatcher now emits this on sampling/elicitation timeout, but ClientSession has no in-flight tracking to act on it, so surfacing it would only break user handlers that exhaustively match ServerNotification. - Scripted peer: the typed server API has no way to emit a bare - `notifications/cancelled`. + Scripted peer: the typed server API cannot emit a bare `notifications/cancelled`. """ seen: list[object] = [] delivered = anyio.Event() @@ -929,13 +908,8 @@ async def handler(msg: object) -> None: @pytest.mark.anyio async def test_progress_notification_reaches_request_callback_and_message_handler(): - """A `notifications/progress` for an in-flight request reaches the - `progress_callback` passed to `send_request` and still tees to - `message_handler` as a typed `ProgressNotification` — the wiring this layer - adds over the dispatcher (callback dispatch and exception containment are - pinned in tests/shared/test_jsonrpc_dispatcher.py). Scripted peer: the - progress token must echo the wire-level request id, which only raw-stream - observation exposes.""" + """A `notifications/progress` for an in-flight request reaches both the `progress_callback` and + `message_handler` (SDK-defined). Scripted peer: the progress token must echo the wire request id.""" updates: list[tuple[float, float | None, str | None]] = [] teed: list[types.ProgressNotification] = [] request_id: types.RequestId | None = None @@ -1002,11 +976,9 @@ async def server_on_notify( await tg.start(server_side.run, server_on_request, server_on_notify) async with session: results.append(await session.send_ping(meta=None)) - # Server-to-client direction: direct dispatch delivers ping with no - # params member at all (no _meta injection outside JSON-RPC). + # Server-to-client: direct dispatch delivers ping with no params member (no _meta injection). assert await server_side.send_raw_request("ping", None) == {} - # related_request_id routing is JSON-RPC plumbing; on other - # dispatchers the notification is sent without it. + # related_request_id is JSON-RPC plumbing; other dispatchers send the notification without it. await session.send_notification(types.RootsListChangedNotification(), related_request_id=7) server_side.close() assert results == [types.EmptyResult()] @@ -1015,11 +987,8 @@ async def server_on_notify( @pytest.mark.anyio async def test_initialize_opts_out_of_cancel_on_abandon_while_other_requests_leave_it_unset(): - """`send_request` passes `cancel_on_abandon=False` to the dispatcher for - `initialize` — the spec forbids cancelling it — and leaves the option unset - for every other method. Both dispatcher abandon arms (timeout and caller - cancel) key off this one option; the suppression mechanics are pinned in - tests/shared/test_jsonrpc_dispatcher.py.""" + """`send_request` passes `cancel_on_abandon=False` for `initialize` — the spec forbids + cancelling it — and leaves the option unset for every other method.""" class RecordingDispatcher: """Records `send_raw_request` opts and answers with canned results.""" @@ -1082,11 +1051,8 @@ def test_constructor_requires_both_streams_without_dispatcher(): @pytest.mark.anyio async def test_aenter_cancelled_while_dispatcher_starts_unwinds_cleanly(): - """Cancellation landing while `__aenter__` waits for the dispatcher to - start (e.g. the caller wrapped connect in `move_on_after`) unwinds the - half-entered task group: the enclosing scope exits via its own deadline - instead of anyio's "exited non-innermost cancel scope" RuntimeError, and - the session is left un-entered (SDK-defined unwind path).""" + """Cancellation while `__aenter__` waits for the dispatcher to start unwinds the half-entered + task group cleanly, not via anyio's "exited non-innermost cancel scope" RuntimeError (SDK-defined).""" class NeverStartsDispatcher: """`run()` parks without ever signalling `task_status.started()`.""" @@ -1110,9 +1076,7 @@ async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: session = ClientSession(dispatcher=NeverStartsDispatcher()) async with AsyncExitStack() as stack: - # The deadline only ends the wait: `start()` is parked forever (and - # `TaskGroup.__aenter__` has no checkpoint to absorb the cancellation - # earlier), so any duration is non-racy. + # `start()` is parked forever, so the deadline only ends the wait — any duration is non-racy. with anyio.move_on_after(0.01) as scope: await stack.enter_async_context(session) assert scope.cancelled_caught @@ -1156,11 +1120,8 @@ async def test_send_notification_with_related_request_id_attaches_metadata(): @pytest.mark.anyio async def test_send_notification_with_related_request_id_zero_attaches_metadata(): - """`related_request_id=0` still routes onto the originating request's - stream: request ids are opaque and 0 is a valid one, so the session must - check `is not None` rather than truthiness (regression pin for the - falsy-zero gap). Wire-level because the metadata attachment is only - observable on the sent `SessionMessage`.""" + """`related_request_id=0` still attaches metadata: 0 is a valid request id, so the session checks + `is not None`, not truthiness (regression pin). Wire-level: only the sent `SessionMessage` shows it.""" async with raw_client_session() as (session, _to_client, from_client): await session.send_notification( types.ProgressNotification( diff --git a/tests/interaction/_helpers.py b/tests/interaction/_helpers.py index 0710ae937..54d41e1e7 100644 --- a/tests/interaction/_helpers.py +++ b/tests/interaction/_helpers.py @@ -67,9 +67,7 @@ def __init__(self, inner: WriteStream[SessionMessage], log: list[SessionMessage] self._log = log async def send(self, item: SessionMessage, /) -> None: - # Record only after the inner send returns: a send that raises (or is cancelled - # mid-write) never reached the transport, so logging it first would fabricate a - # "sent" message in wire-level assertions. + # Record only after the inner send returns: a failed or cancelled send never reached the transport. await self._inner.send(item) self._log.append(item) diff --git a/tests/interaction/lowlevel/test_cancellation.py b/tests/interaction/lowlevel/test_cancellation.py index 03cf20a84..4fab2f650 100644 --- a/tests/interaction/lowlevel/test_cancellation.py +++ b/tests/interaction/lowlevel/test_cancellation.py @@ -158,11 +158,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara @requirement("protocol:cancel:server-to-client") async def test_abandoned_server_request_cancels_the_client_callback(connect: Connect) -> None: - """A server that abandons a sampling request cancels it, interrupting the client's callback. - - The handler gives up on its sampling request by cancelling the scope around it; the courtesy - notifications/cancelled that follows interrupts the client's sampling callback mid-await. - """ + """A server that abandons a sampling request cancels it, interrupting the client's callback mid-await.""" callback_started = anyio.Event() callback_cancelled = anyio.Event() @@ -219,8 +215,7 @@ async def test_a_response_for_an_unknown_request_id_is_ignored() -> None: The spec says a sender SHOULD ignore a response that arrives after it issued a cancellation; that is the same client-side code path as any response with an unknown id, and that form is - deterministic to test without depending on a client-side cancellation API. Nothing reaches - the message handler and the session keeps serving. + deterministic to test without a client-side cancellation API. A real Server cannot be made to answer with a fabricated id, so the test plays the server's side of the wire by hand. Reserve this pattern for behaviour no real server can produce. The @@ -284,21 +279,13 @@ async def message_handler(message: IncomingMessage) -> None: pong = await session.send_request(PingRequest(), EmptyResult) assert pong == snapshot(EmptyResult()) - # The fabricated response was dropped silently: the ping after it still - # round-tripped, and the message handler (a tripwire) was never invoked. @requirement("protocol:cancel:initialize-not-cancellable") async def test_timed_out_initialize_sends_no_cancellation() -> None: - """An abandoned initialize is not followed by notifications/cancelled on the wire. - - Spec-mandated: the initialize request MUST NOT be cancelled. Abandoning any other request - sends a courtesy notifications/cancelled (see protocol:timeout:sends-cancellation); this - test pins that initialize opts out. A real Server always answers initialize, so the test - plays a stalling server by hand: it never answers initialize, the client's read timeout - fires, and the ping the test sends next is the marker — the in-memory stream is ordered and - a courtesy cancel goes out before the timeout error reaches the caller, so a regression - would put notifications/cancelled ahead of the ping in the recorded sequence. + """An abandoned initialize is not followed by notifications/cancelled on the wire (spec-mandated). + + A real Server always answers initialize, so the test plays a stalling server by hand. """ received_methods: list[str] = [] @@ -320,7 +307,6 @@ async def scripted_server(streams: MessageStream) -> None: JSONRPCResponse( jsonrpc="2.0", id=follow_up.message.id, - # Serialized exactly as a real server serializes results onto the wire. result=EmptyResult().model_dump(by_alias=True, mode="json", exclude_none=True), ) ) @@ -329,8 +315,7 @@ async def scripted_server(streams: MessageStream) -> None: async with ( create_client_server_memory_streams() as ((client_read, client_write), server_streams), anyio.create_task_group() as task_group, - # The session-level read timeout is the only public pathway that abandons initialize; - # the response never arrives, so any positive value fires on the next event-loop pass. + # The session-level read timeout is the only public pathway that abandons initialize. ClientSession(client_read, client_write, read_timeout_seconds=0.000001) as session, ): task_group.start_soon(scripted_server, server_streams) @@ -342,4 +327,5 @@ async def scripted_server(streams: MessageStream) -> None: pong = await session.send_request(PingRequest(), EmptyResult, request_read_timeout_seconds=5) assert pong == snapshot(EmptyResult()) + # The stream is ordered, so a courtesy cancel would have arrived ahead of the ping. assert received_methods == snapshot(["initialize", "ping"]) diff --git a/tests/interaction/lowlevel/test_logging.py b/tests/interaction/lowlevel/test_logging.py index d945c8e76..b8f9d3d77 100644 --- a/tests/interaction/lowlevel/test_logging.py +++ b/tests/interaction/lowlevel/test_logging.py @@ -1,15 +1,8 @@ """Logging interactions against the low-level Server, driven through the public Client API. -Notification ordering: the in-memory transport delivers every server-to-client message on one -ordered stream, and the client starts notification callbacks in arrival order. Callbacks run -concurrently with the rest of the session (no completion-before-response guarantee), but a -callback with no internal awaits runs to completion as soon as it starts, which keeps -plain-list collection deterministic here. Over streamable HTTP the ordered single-stream -guarantee holds only for messages that carry a ``related_request_id`` (they ride the -originating request's POST stream); without it the message routes to the standalone GET stream -and may arrive after the response. These tests pass ``related_request_id`` and use await-free -callbacks so they can collect into a plain list and assert after the request completes on -every transport leg -- no events, no waiting. +Notification ordering: await-free callbacks finish in arrival order, and passing +``related_request_id`` keeps each notification on the originating request's POST stream over +streamable HTTP, so plain-list collection is deterministic on every transport leg. """ import pytest diff --git a/tests/interaction/lowlevel/test_timeouts.py b/tests/interaction/lowlevel/test_timeouts.py index e79d3303b..903829845 100644 --- a/tests/interaction/lowlevel/test_timeouts.py +++ b/tests/interaction/lowlevel/test_timeouts.py @@ -30,9 +30,7 @@ async def test_request_timeout_fails_the_pending_call() -> None: """A request whose response does not arrive within its read timeout fails with a timeout error. - The timeout is followed by notifications/cancelled on the wire, so the server's handler is - interrupted instead of running to completion. The test waits for the handler to have started - only after the timeout has fired, so the timeout itself races nothing. + The timeout is followed by notifications/cancelled, which interrupts the server's handler. """ handler_started = anyio.Event() handler_cancelled = anyio.Event() @@ -53,8 +51,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara with pytest.raises(MCPError) as exc_info: await client.call_tool("block", {}, read_timeout_seconds=0.000001) - # The request was already on the wire, so the handler started; the courtesy - # cancellation that followed the timeout then interrupted it. + # The request was already on the wire: the handler started and was then cancelled. with anyio.fail_after(5): await handler_started.wait() await handler_cancelled.wait() @@ -72,10 +69,7 @@ async def call_tool(ctx: ServerRequestContext, params: types.CallToolRequestPara async def test_server_request_timeout_sends_cancellation_to_the_client() -> None: """A server-initiated request that times out fails server-side and cancels the client's work. - The server seat conforms to the spec's timeout guidance: the handler's timed-out sampling - request is followed by notifications/cancelled on the wire. The client's sampling callback - blocks until the server has already given up, then answers; the late response is discarded - and the tool call still completes. + The sampling callback answers only after the server gave up; the late response is discarded. """ release = anyio.Event() callback_started = anyio.Event() @@ -124,7 +118,7 @@ async def sampling_callback( and isinstance(item.message, JSONRPCNotification) and item.message.method == "notifications/cancelled" ] - # The cancel names the sampling request (the server's first outbound request) and the reason. + # requestId 1 is the sampling request, the server's first outbound request. assert [notification.params for notification in cancellations] == snapshot( [{"requestId": 1, "reason": "timed out after 1e-06s"}] ) diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index c3f38b708..b1c6a4f70 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -62,8 +62,7 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar assert params.name in ("slow", "fast"), f"Unknown tool: {params.name}" if params.name == "slow": - # The client's timeout fires while this waits; the courtesy - # cancellation then interrupts the wait. + # The client's timeout fires during this wait; the courtesy cancellation then interrupts it. await slow_request_lock.wait() text = f"slow {request_count}" else: diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index 0b4ae48b2..e14e51d93 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -1,10 +1,4 @@ -"""JSON-RPC-specific Dispatcher tests. - -Behaviors with no `DirectDispatcher` analog: request-id correlation, the -exception-to-wire boundary, peer-cancel handling, and shutdown fan-out. -The contract tests shared with `DirectDispatcher` live in -`test_dispatcher.py`. -""" +"""JSON-RPC-specific dispatcher tests; contract tests shared with `DirectDispatcher` live in `test_dispatcher.py`.""" import contextvars import json @@ -58,13 +52,7 @@ class RecordingWriteStream: - """Write stream that records sends synchronously, without a checkpoint. - - Models a transport write that can complete without yielding (a memory - stream's `send` checkpoints first, which would let a pending cancellation - interrupt the write and mask the behavior under test). `__aexit__` - releases nothing, so writes during run() teardown still land. - """ + """Records sends without a checkpoint, so a pending cancellation cannot interrupt the write or mask it.""" def __init__(self) -> None: self.sent: list[SessionMessage] = [] @@ -165,19 +153,9 @@ async def call_then_record() -> None: @pytest.mark.anyio async def test_peer_cancel_landing_after_handlers_last_checkpoint_writes_only_the_result(): - """A peer cancel that fails to interrupt the handler must not add a code-0 - error after the result: exactly one answer for that id goes on the wire. - - SDK-defined: the cancelled-error response belongs only when the - cancellation was actually absorbed, i.e. the result write did not happen. - The schedule is deterministic because the cancel notification itself is - the handler's wakeup: the read loop sets `ctx.cancel_requested` and then - cancels the scope in the same synchronous block, so anyio defers the - cancellation (the wakeup future is already done) and the handler runs to - completion. The recording write stream is needed because a memory - stream's `send` checkpoints, which would let the deferred cancellation - land mid-write and hide the double answer. - """ + """A peer cancel that fails to interrupt the handler writes only the result: one answer per + id goes on the wire (SDK-defined). The recording stream is needed because a memory stream's + `send` checkpoints, letting the deferred cancellation land mid-write and hide a double answer.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) recording = RecordingWriteStream() server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, recording) @@ -197,6 +175,7 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None))) with anyio.fail_after(5): await handler_started.wait() + # The cancel is also the handler's wakeup, so anyio defers it and the handler completes. await c2s_send.send( SessionMessage( message=JSONRPCNotification( @@ -228,7 +207,6 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | def factory(*, can_send_request: bool = True): client, server, close = jsonrpc_pair(can_send_request=can_send_request) - # Reach in to set signal mode on the server side. assert isinstance(server, JSONRPCDispatcher) server._peer_cancel_mode = "signal" # pyright: ignore[reportPrivateUsage] return client, server, close @@ -275,17 +253,12 @@ async def caller() -> None: @pytest.mark.anyio async def test_run_returns_cleanly_when_read_stream_receive_end_is_closed(): - """Iterating a closed receive end raises ClosedResourceError; run() treats it as EOF. - - Stateless SHTTP teardown closes the dispatcher's receive end after the - request is handled; the next loop iteration must not surface as a crash. - """ + """Iterating a closed receive end is EOF, not a crash (stateless SHTTP closes it during teardown).""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) on_request, on_notify = echo_handlers(Recorder()) - # Close the dispatcher's own receive end (not the send end) before run() - # iterates it: __anext__ on a closed stream raises ClosedResourceError. + # Close the receive end itself (not the send end): __anext__ then raises ClosedResourceError. c2s_recv.close() with anyio.fail_after(5): await server.run(on_request, on_notify) @@ -295,12 +268,8 @@ async def test_run_returns_cleanly_when_read_stream_receive_end_is_closed(): @pytest.mark.anyio async def test_run_cancels_in_flight_handlers_when_read_stream_eofs(): - """A handler that outlives its caller must not keep run() from returning. - - Without the cancel-at-EOF, the task-group join would wait on this handler - forever (over SSE that leaks the handler task and the GET request hosting - the session). - """ + """run() cancels still-running handlers at read-stream EOF; otherwise its join waits forever + (over SSE, leaking the handler and the GET request hosting the session).""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) @@ -346,7 +315,6 @@ async def test_run_closes_write_stream_on_exit(): await tg.start(server.run, on_request, on_notify) c2s_send.close() # EOF the read side; run() exits with anyio.fail_after(5), pytest.raises(anyio.EndOfStream): # pragma: no branch - # Write end was entered and released by run(); peer's receive sees EOF. await s2c_recv.receive() s2c_recv.close() @@ -365,8 +333,7 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | with anyio.fail_after(5): with pytest.raises(MCPError): # REQUEST_TIMEOUT await client.send_raw_request("slow", None, {"timeout": 0}) - # The server handler is still running; let it finish and write a - # response for an id the client has already discarded. + # Let the parked handler respond to an id the client has already discarded. await handler_started.wait() proceed.set() # One more round-trip proves the dispatcher is still healthy. @@ -395,7 +362,6 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> with pytest.raises(BaseException) as exc: async with anyio.create_task_group() as tg: await tg.start(server.run, boom, on_notify) - # Inject a request directly onto the server's read stream. await c2s_send.send( SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="x", params=None)) ) @@ -426,7 +392,6 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> try: async with anyio.create_task_group() as tg: await tg.start(server.run, server_on_request, on_notify) - # Kick the server with an inbound request id=7. await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="t", params=None))) with anyio.fail_after(5): outbound = await s2c_recv.receive() @@ -434,7 +399,6 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> assert isinstance(outbound.message, JSONRPCRequest) assert isinstance(outbound.metadata, ServerMessageMetadata) assert outbound.metadata.related_request_id == 7 - # Reply so the handler completes cleanly. await c2s_send.send( SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=outbound.message.id, result={"ok": True})) ) @@ -451,12 +415,8 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> @pytest.mark.anyio async def test_courtesy_cancel_on_timeout_tags_outbound_with_server_message_metadata(): - """The timeout-path `notifications/cancelled` carries the originating request id. - - Streamable-HTTP's `message_router` keys on `ServerMessageMetadata.related_request_id`; - a cancel without it would fall through to the standalone GET stream and be dropped - when no GET stream is open, so the client never learns to stop work. - """ + """The timeout-path `notifications/cancelled` carries the originating request id: SHTTP's + `message_router` keys on `related_request_id`; without it the cancel would be dropped.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) @@ -501,14 +461,8 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> @pytest.mark.anyio async def test_dispatch_context_request_with_dropped_resumption_hints_still_sends_courtesy_cancel(): - """Resumption hints that never reach the transport must not suppress the abandon cancel. - - For a dispatch-context request, `related_request_id` takes metadata - precedence and the hints are dropped - so the request is not resumable, - and abandoning it without a courtesy cancel would leak the peer's work - forever. One decision (`_plan_outbound`) now produces both the metadata - and the cancel policy, so they cannot disagree. - """ + """Resumption hints that never reach the transport must not suppress the abandon cancel: + `related_request_id` takes metadata precedence and drops the hints, so the request is not resumable.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) @@ -594,14 +548,9 @@ async def caller() -> None: @pytest.mark.anyio async def test_caller_cancel_during_blocked_request_write_sends_no_cancelled_notification(): - """A caller cancelled while the request write is still blocked must not emit - `notifications/cancelled`: the spec only allows cancelling previously-issued - requests, and this one never reached the peer. - - The fake stream wedges only the first write (the request) and records any - later one synchronously, so a courtesy cancel - which would be the bug - - is captured even though it runs inside the bounded shield. - """ + """A caller cancelled mid-request-write must not emit `notifications/cancelled` (the spec only + allows cancelling issued requests). The fake stream wedges only the first write, so a later + courtesy cancel - the bug - would still be captured.""" class FirstWriteWedgedStream: def __init__(self) -> None: @@ -615,7 +564,7 @@ async def send(self, item: SessionMessage) -> None: self.sent.append(item) async def aclose(self) -> None: - raise NotImplementedError # the dispatcher releases streams via __aexit__, never aclose + raise NotImplementedError async def __aenter__(self) -> "FirstWriteWedgedStream": return self @@ -652,16 +601,13 @@ async def caller() -> None: scopes[0].cancel() with anyio.fail_after(5): await gave_up.wait() - # Prove the recorder is live: a marker write after the wedge IS - # captured, so a courtesy cancel would have been too. await client.notify("notifications/marker", None) tg.cancel_scope.cancel() finally: s2c_send.close() s2c_recv.close() assert scopes[0].cancelled_caught - # The marker is the only post-wedge write: no cancel notification went out - # for a request the peer never received. + # Only the marker went out post-wedge: no cancel for a request the peer never received. assert [m.message for m in wedged.sent] == [JSONRPCNotification(jsonrpc="2.0", method="notifications/marker")] @@ -697,8 +643,7 @@ async def caller() -> None: scopes[0].cancel() with anyio.fail_after(5): await gave_up.wait() - # The next write proves nothing was sent in between: a courtesy - # cancel would have to precede it on the ordered stream. + # A courtesy cancel would have to precede the marker on the ordered stream. await client.notify("marker", None) with anyio.fail_after(5): nxt = await c2s_recv.receive() @@ -781,15 +726,9 @@ async def test_cancel_on_abandon_false_suppresses_the_courtesy_cancellation_on_t async def test_caller_cancel_courtesy_write_is_bounded_when_the_transport_is_wedged( caplog: pytest.LogCaptureFixture, ): - """A wedged transport write cannot turn caller cancellation into an unbounded shielded hang. - - The peer consumes exactly the request (arming the courtesy cancel) and never responds; - cancelling the caller then routes into the shielded courtesy-cancel write, which blocks on - the unbuffered, unread write stream. The bound abandons it after _SHIELDED_WRITE_TIMEOUT - (with a warning); trio's virtual clock makes the wait instant. On regression (unbounded - shield) the test hangs rather than failing fast: the outer fail_after cannot cancel through - the shield - that is the bug. - """ + """A wedged transport write cannot turn caller cancellation into an unbounded shielded hang: + `_SHIELDED_WRITE_TIMEOUT` abandons the courtesy-cancel write (SDK-defined bound). On regression + the test hangs rather than failing fast - fail_after cannot cancel through the shield.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) @@ -806,14 +745,12 @@ async def caller() -> None: gave_up.set() try: - # Both bounds must exceed the in-loop _SHIELDED_WRITE_TIMEOUT (5s) they - # wait out; the virtual clock means no wall-time cost. + # Both bounds exceed the in-loop _SHIELDED_WRITE_TIMEOUT (5s); the virtual clock makes them instant. with anyio.fail_after(30): async with anyio.create_task_group() as tg: # pragma: no branch await tg.start(client.run, on_request, on_notify) tg.start_soon(caller) - # Consume exactly the request so its write completes; the later - # courtesy cancel finds no reader and wedges. + # Consume only the request; the later courtesy cancel finds no reader and wedges. request = await c2s_recv.receive() assert isinstance(request, SessionMessage) assert isinstance(request.message, JSONRPCRequest) @@ -837,13 +774,8 @@ async def caller() -> None: async def test_timeout_courtesy_cancel_write_is_bounded_when_the_transport_is_wedged( caplog: pytest.LogCaptureFixture, ): - """A wedged transport write cannot delay the REQUEST_TIMEOUT error indefinitely (SDK-defined bound). - - The peer consumes exactly the request and never responds; when the timeout - elapses, the courtesy cancel blocks on the unbuffered, unread write stream. - The bound abandons it after _SHIELDED_WRITE_TIMEOUT (with a warning) so the - timeout error still surfaces; trio's virtual clock makes the waits instant. - """ + """A wedged transport write cannot delay the REQUEST_TIMEOUT error indefinitely (SDK-defined + bound): `_SHIELDED_WRITE_TIMEOUT` abandons the courtesy cancel so the error still surfaces.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) @@ -862,14 +794,12 @@ async def caller() -> None: async with anyio.create_task_group() as tg: await tg.start(client.run, on_request, on_notify) tg.start_soon(caller) - # Consume exactly the request so its write completes; the later - # courtesy cancel finds no reader and wedges. + # Consume only the request; the later courtesy cancel finds no reader and wedges. with anyio.fail_after(5): request = await c2s_recv.receive() assert isinstance(request, SessionMessage) assert isinstance(request.message, JSONRPCRequest) - # Must exceed the request timeout (1s) plus the in-loop - # _SHIELDED_WRITE_TIMEOUT (5s); the virtual clock means no wall-time cost. + # Exceeds the request timeout (1s) plus _SHIELDED_WRITE_TIMEOUT (5s); virtual clock, no wall time. with anyio.fail_after(10): await gave_up.wait() tg.cancel_scope.cancel() @@ -888,23 +818,16 @@ async def caller() -> None: async def test_shutdown_error_response_write_is_bounded_when_the_transport_is_wedged( caplog: pytest.LogCaptureFixture, ): - """Cancelling the task group hosting run() completes even when the shutdown error write wedges (SDK-defined bound). - - The in-flight handler is parked when run() is cancelled; its shielded - connection-closed-error write blocks on a wedged transport, and only the - _SHUTDOWN_WRITE_TIMEOUT bound lets the join complete. A fake write stream is - needed because a memory stream can't express the wedge: run()'s teardown - closes its own write stream, which would wake the blocked send. On - regression (unbounded shield) the test hangs rather than failing fast: the - outer fail_after cannot cancel through the shield - that is the bug. - """ + """Cancelling the task group hosting run() completes even when the shutdown error write wedges: + only `_SHUTDOWN_WRITE_TIMEOUT` releases the join (SDK-defined). The fake stream is needed + because run()'s teardown closes a memory stream, which would wake the blocked send.""" class WedgedWriteStream: async def send(self, item: SessionMessage) -> None: await anyio.sleep_forever() async def aclose(self) -> None: - raise NotImplementedError # the dispatcher releases streams via __aexit__, never aclose + raise NotImplementedError async def __aenter__(self) -> "WedgedWriteStream": return self @@ -930,10 +853,7 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> raise NotImplementedError try: - # Sits above the in-loop _SHUTDOWN_WRITE_TIMEOUT (1s) it waits out but - # below _SHIELDED_WRITE_TIMEOUT (5s), so this also pins that the - # shutdown arm uses the tighter shutdown bound (session close must be - # quick); the virtual clock means no wall-time cost. + # 3s sits between _SHUTDOWN_WRITE_TIMEOUT (1s) and _SHIELDED_WRITE_TIMEOUT (5s): pins the tighter bound. with anyio.fail_after(3): async with anyio.create_task_group() as tg: # pragma: no branch await tg.start(server.run, park, on_notify) @@ -945,21 +865,14 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> finally: c2s_send.close() c2s_recv.close() - # Reaching here proves the join completed; the warning proves it was the - # bound (not a completed write) that released it. + # The warning proves the bound (not a completed write) released the join. assert "shutdown error response for request" in caplog.text @pytest.mark.anyio async def test_shutdown_answers_in_flight_request_with_connection_closed(): - """Cancelling run() answers a still-running request with CONNECTION_CLOSED. - - SDK-defined contract: the peer learns its request died with the connection - (not a request-specific cancellation - -32002 belongs to the spec's - resource-not-found). The recording write stream keeps the teardown write - observable: run()'s exit would close a memory stream before the shielded - write lands. - """ + """Cancelling run() answers a still-running request with CONNECTION_CLOSED (SDK-defined). The + recording stream is needed because run()'s exit would close a memory stream before the shielded write lands.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) recording = RecordingWriteStream() server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, recording) @@ -990,11 +903,7 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> @pytest.mark.anyio async def test_request_write_failure_propagates_and_leaves_no_pending_entry(): - """A request whose transport write raises must not leak its `_pending` entry. - - SDK-defined: regression cover for v1's `test_send_request_stream_cleanup` - (response streams were cleaned up when the write failed). - """ + """A request whose transport write raises must not leak its `_pending` entry (v1 regression cover).""" boom = RuntimeError("write failed") class RaisingWriteStream: @@ -1002,7 +911,7 @@ async def send(self, item: SessionMessage) -> None: raise boom async def aclose(self) -> None: - raise NotImplementedError # the dispatcher releases streams via __aexit__, never aclose + raise NotImplementedError async def __aenter__(self) -> "RaisingWriteStream": return self @@ -1033,11 +942,7 @@ async def __aexit__( @pytest.mark.anyio async def test_request_write_on_torn_down_transport_raises_connection_closed(): - """The transport tearing down before run() notices EOF surfaces as MCPError, not a raw stream error. - - SDK-defined: `send_raw_request` documents MCPError(CONNECTION_CLOSED) for a - closed connection; the raw `BrokenResourceError` from the write must not leak. - """ + """A write onto a torn-down transport surfaces as MCPError(CONNECTION_CLOSED), not a raw `BrokenResourceError`.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) @@ -1045,8 +950,7 @@ async def test_request_write_on_torn_down_transport_raises_connection_closed(): try: async with anyio.create_task_group() as tg: await tg.start(client.run, on_request, on_notify) - # Tear down the peer's receive end only: the client's read stream - # stays open, so run() has not observed EOF when the write fails. + # Close only the peer's receive end, so run() has not observed EOF when the write fails. c2s_recv.close() with anyio.fail_after(5), pytest.raises(MCPError) as exc: await client.send_raw_request("ping", None) @@ -1059,12 +963,7 @@ async def test_request_write_on_torn_down_transport_raises_connection_closed(): @pytest.mark.anyio async def test_notification_handler_exception_is_contained(caplog: pytest.LogCaptureFixture): - """A raising notification handler costs only that notification, never the connection. - - The handler runs as a bare task in the dispatcher's task group; without containment its - exception would cancel the read loop and every in-flight request. The TypeScript, C#, and - Go engines all contain notification-handler failures the same way. - """ + """A raising notification handler costs only that notification, never the connection (parity with TS/C#/Go).""" async def server_on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: raise RuntimeError("notify boom") @@ -1080,12 +979,8 @@ async def server_on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | N @pytest.mark.anyio async def test_spawned_notification_handlers_run_concurrently(): - """Notification handlers are spawned, not serialized: a parked one does not block the next. - - The first handler waits for the second to have started - serialized dispatch would deadlock - here. This matches the TypeScript and C# engines (fire-and-forget); handlers needing - mutual ordering must coordinate themselves. - """ + """Notification handlers are spawned, not serialized (parity with TS/C#): the first handler + waits for the second to start, so serialized dispatch would deadlock here.""" second_started = anyio.Event() completed: list[str] = [] done = anyio.Event() @@ -1197,12 +1092,8 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | @pytest.mark.anyio async def test_ctx_after_handler_return_reports_closed_and_drops_backchannel_traffic(): - """Once `_handle_request` closes the dctx, the back-channel guard and ops agree. - - Detached work that outlives the handler must see `can_send_request == False`, - get `NoBackChannelError` from `send_raw_request`, and have `notify`/`progress` - silently dropped rather than emitted with a stale `related_request_id`. - """ + """After `_handle_request` closes the dctx, `can_send_request` is False, `send_raw_request` raises + NoBackChannelError, and `notify`/`progress` are dropped rather than sent with a stale `related_request_id`.""" captured: list[DCtx] = [] async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: @@ -1222,8 +1113,7 @@ async def on_progress(progress: float, total: float | None, message: str | None) await dctx.send_raw_request("sampling/createMessage", None) await dctx.notify("notifications/message", {"level": "info"}) await dctx.progress(0.9) - # A second round-trip flushes any notification the server might have - # written, so an empty client recorder afterwards proves the drop. + # A second round-trip flushes any server write; an empty recorder then proves the drop. await client.send_raw_request("ping", None) assert crec.notifications == [] @@ -1243,15 +1133,13 @@ async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): with anyio.fail_after(5): result = await client.send_raw_request("t", None, opts) - # Request still completes; the callback's crash was swallowed. assert result == {"ok": True} assert "progress callback raised" in caplog.text @pytest.mark.anyio async def test_inline_methods_are_handled_before_next_message_is_dequeued(): - """A method in `inline_methods` runs to completion before subsequent - messages are dispatched, so its side effects are visible to them.""" + """An `inline_methods` method runs to completion before the next message is dispatched.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher( @@ -1283,10 +1171,8 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> @pytest.mark.anyio async def test_send_raw_request_always_carries_meta_on_the_wire(): - """Outbound requests always include `params._meta` (otel injection per SEP-414). - - Caller-supplied `_meta` keys are preserved; the progress token is merged in. - """ + """Outbound requests always carry `params._meta` (otel injection per SEP-414); caller-supplied + keys are preserved and the progress token is merged in.""" seen: list[Mapping[str, Any] | None] = [] async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: @@ -1301,9 +1187,7 @@ async def noop_progress(progress: float, total: float | None, message: str | Non with anyio.fail_after(5): await client.send_raw_request("a", None) await client.send_raw_request("b", {"x": 1, "_meta": {"k": "v"}}, opts) - # `_meta` is always present. Its contents depend on the active otel - # tracer (traceparent/tracestate may be injected), so assert presence - # and that anything beyond W3C keys is exactly what we expect. + # `_meta` contents depend on the active otel tracer, so pin only what sits beyond the W3C keys. w3c = {"traceparent", "tracestate"} assert seen[0] is not None and seen[0].keys() == {"_meta"} assert set(seen[0]["_meta"].keys()) <= w3c @@ -1473,8 +1357,7 @@ async def call() -> None: @pytest.mark.anyio @pytest.mark.parametrize("inline", [frozenset[str](), frozenset({"t"})], ids=["spawned", "inline"]) async def test_handler_inherits_sender_contextvars(inline: frozenset[str]): - """The handler task sees contextvars set by the task that wrote into the - read stream, on both the spawned and the inline-method dispatch paths.""" + """The handler sees the sender's contextvars on both the spawned and the inline-method dispatch paths.""" raw_send, raw_recv = anyio.create_memory_object_stream[tuple[contextvars.Context, SessionMessage | Exception]](4) read_stream = ContextReceiveStream[SessionMessage | Exception](raw_recv) write_send = ContextSendStream[SessionMessage | Exception](raw_send) @@ -1574,9 +1457,7 @@ async def caller() -> None: sent = await c2s_recv.receive() assert isinstance(sent, SessionMessage) assert isinstance(sent.message, JSONRPCRequest) - # Now safe: close the client's write end, then cancel the caller. The - # shielded `_cancel_outbound` write hits ClosedResourceError and is - # swallowed; cancellation propagates cleanly. + # The shielded `_cancel_outbound` write now hits ClosedResourceError and is swallowed. c2s_send.close() caller_scope.cancel() with anyio.fail_after(5): @@ -1606,7 +1487,6 @@ def test_fan_out_closed_drops_signal_when_waiter_already_has_outcome(): s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) send, recv = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) - # Register a fake pending and pre-fill its single buffer slot. d._pending[1] = _Pending(send=send, receive=recv) # pyright: ignore[reportPrivateUsage] send.send_nowait({"real": "result"}) d._fan_out_closed() # pyright: ignore[reportPrivateUsage] @@ -1660,11 +1540,7 @@ async def respond_stringly() -> None: @pytest.mark.anyio async def test_error_response_with_string_id_correlates_to_int_keyed_pending_request(): - """A peer that echoes the request ID as a JSON string on a JSONRPCError still resolves the waiter. - - Same `_coerce_id` treatment as the success-response path: the peer's error - surfaces as MCPError instead of the request hanging until the connection closes. - """ + """A JSONRPCError echoing the request ID as a JSON string still resolves the waiter (same `_coerce_id` path).""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) @@ -1763,8 +1639,7 @@ async def test_jsonrpc_error_response_with_null_id_is_dropped(): SessionMessage(message=JSONRPCError(jsonrpc="2.0", id=None, error=ErrorData(code=-32700, message="x"))) ) with anyio.fail_after(5): - # The read stream is ordered: this round-trip completing proves - # the null-id error was consumed without killing the loop. + # Ordered stream: this round-trip completing proves the null-id error was consumed. async def respond() -> None: out = await c2s_recv.receive() assert isinstance(out, SessionMessage) @@ -1783,12 +1658,7 @@ async def respond() -> None: @pytest.mark.anyio async def test_notify_without_params_omits_params_key_on_the_wire(): - """JSON-RPC 2.0 forbids `params: null`; the member must be absent. - - Transports serialize with `exclude_unset=True`, so `notify` must leave - `params` unset on the model rather than passing an explicit None (strict - peers like the TypeScript SDK reject `"params": null`). - """ + """JSON-RPC 2.0 forbids `params: null`: `notify` leaves `params` unset (transports use `exclude_unset=True`).""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) @@ -1908,9 +1778,7 @@ async def call() -> None: tg.start_soon(call) await handler_started.wait() await client.notify("notifications/cancelled", {"requestId": True}) - # The malformed cancel is teed to on_notify; once observed, the - # correlation arm has already run - and must not have cancelled - # the request keyed by id 1. + # Once the teed notification is observed, the correlation arm has already run. await srec.notified.wait() assert not handler_exited.is_set() await client.notify("notifications/cancelled", {"requestId": 1}) @@ -1967,8 +1835,7 @@ async def on_progress(progress: float, total: float | None, message: str | None) @pytest.mark.anyio async def test_request_with_bool_meta_progress_token_is_not_adopted(): - """A bool `_meta.progressToken` is malformed; `ctx.progress()` must be a no-op - instead of emitting `progressToken: true` on the wire.""" + """A bool `_meta.progressToken` is malformed: `ctx.progress()` must be a no-op, not emit `progressToken: true`.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) @@ -2006,8 +1873,7 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> ids=["string-cancel-for-int-request", "int-cancel-for-string-request"], ) async def test_cancelled_correlates_across_string_and_int_request_id_forms(request_id: RequestId, cancel_id: object): - """A peer that stringifies the id between request and cancel still cancels - (same `_coerce_id` treatment as the response path).""" + """A peer that stringifies the id between request and cancel still cancels (same `_coerce_id` path).""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) @@ -2046,9 +1912,8 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> @pytest.mark.anyio async def test_completed_handler_does_not_evict_reused_request_id_from_in_flight(): - """The awaited response write sits after the `_in_flight` pop; a second - request reusing the id during that window must keep its own entry (a - second, post-write pop would evict it and break its peer-cancellation).""" + """A second request reusing an id while the first handler is parked in its response write + keeps its own `_in_flight` entry (a post-write pop would evict it and break peer-cancellation).""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) # buffer=0: the first handler's response write parks until the test receives. s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) @@ -2108,15 +1973,8 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> @pytest.mark.anyio async def test_duplicate_request_id_completion_of_first_handler_keeps_second_cancellable(): - """When a duplicate inbound id overwrites `_in_flight` while the first - handler is still running, the first handler's completion must not evict - the second's entry - that would leave the second request immune to - `notifications/cancelled`. - - SDK-defined: the spec puts id uniqueness on the sender and the dispatcher - blind-overwrites on duplicates (parity with v1/TS); the in-table pop is - identity-guarded so a stale handler only removes its own entry. - """ + """A duplicate inbound id overwrites `_in_flight` (parity with v1/TS); the identity-guarded pop + keeps the first handler's completion from evicting the second's entry and breaking its cancellation.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) @@ -2146,8 +2004,7 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> with anyio.fail_after(5): await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="first"))) await first_started.wait() - # Duplicate id while the first handler is still running: the - # table entry now belongs to the second request. + # Duplicate id: the table entry now belongs to the second request. await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="second"))) await second_started.wait() release_first.set() @@ -2180,9 +2037,7 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> def test_plan_outbound_with_related_request_id_drops_resumption_hints_but_keeps_abandon_cancel( caplog: pytest.LogCaptureFixture, ): - """`SessionMessage.metadata` carries one object; `related_request_id` wins - and resumption hints are dropped observably (debug log). Dropped hints do - not suppress the abandon cancel: the request is not resumable.""" + """`related_request_id` wins the metadata slot; dropped hints don't suppress the abandon cancel.""" with caplog.at_level(logging.DEBUG, logger="mcp.shared.jsonrpc_dispatcher"): plan = _plan_outbound(7, {"resumption_token": "abc"}) assert isinstance(plan.metadata, ServerMessageMetadata) @@ -2198,9 +2053,8 @@ def test_plan_outbound_with_related_request_id_drops_resumption_hints_but_keeps_ @pytest.mark.anyio async def test_server_middleware_observes_cancelled_notification(): - """End-to-end over the JSON-RPC path: `Server.middleware` wraps every inbound - notification, including `notifications/cancelled` (the dispatcher applies - the cancellation itself, then forwards the notification).""" + """`Server.middleware` wraps every inbound notification, including `notifications/cancelled` + (the dispatcher applies the cancellation itself, then forwards the notification).""" handler_started = anyio.Event() cancel_observed = anyio.Event() observed: list[tuple[str, dict[str, Any]]] = [] diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index fe6b98711..3613df18f 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -2231,13 +2231,7 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers(context_ @pytest.mark.anyio async def test_standalone_stream_teardown_mid_listen_is_not_an_error(caplog: pytest.LogCaptureFixture) -> None: - """Tearing down the standalone stream under its parked writer produces no error log. - - SDK-defined teardown behavior, driven through the full client/server path: the writer - is parked in receive() when teardown lands, and ends quietly. The companion test - test_standalone_stream_teardown_between_dequeues_is_not_an_error forces the other - teardown window, which this path cannot reach deterministically. - """ + """Standalone-stream teardown while the writer is parked in receive() logs no error (SDK-defined).""" session_manager = StreamableHTTPSessionManager( app=_create_server(), security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False), @@ -2259,8 +2253,7 @@ async def message_handler( ClientSession(read_stream, write_stream, message_handler=message_handler) as session, ): await session.initialize() - # Prove the standalone GET writer is live: a notification with no - # related request rides the GET stream to the client. + # A notification with no related request rides the GET stream, proving the writer is live. await session.call_tool("test_tool_with_standalone_notification", {}) with anyio.fail_after(5): await notified.wait() @@ -2274,30 +2267,17 @@ async def message_handler( async def test_standalone_stream_teardown_between_dequeues_is_not_an_error( caplog: pytest.LogCaptureFixture, ) -> None: - """Teardown landing while the standalone writer is between dequeues produces no error log. - - SDK-defined: after teardown, the writer's next dequeue hits its own closed stream - (ClosedResourceError), which is expected disconnect noise, not an error. The public - surface cannot force this window (the in-process client consumes SSE without - backpressure, so the writer is always parked in receive() when teardown runs), so this - drives the transport's ASGI entry point directly with a gated `send`. - - Steps: - 1. A GET establishes the standalone SSE stream; the gated ASGI send keeps the - response from consuming any SSE data. - 2. An event sent into the standalone stream rendezvouses with the writer's receive(), - which then blocks forwarding it to the un-consumed SSE stream -- the - between-dequeues window. - 3. Stream cleanup runs inside that window, closing both standalone stream ends. - 4. The gate opens: the event reaches the wire, the writer's next dequeue hits the - closed stream, and the response completes cleanly with nothing logged as an error. + """Teardown landing while the standalone writer is between dequeues logs no error. + + SDK-defined: after teardown the writer's next dequeue hits its own closed stream — expected + disconnect noise. The public surface cannot force this window (the in-process client consumes + SSE without backpressure), so the test drives the transport's ASGI entry point with a gated `send`. """ transport = StreamableHTTPServerTransport( mcp_session_id=None, security_settings=TransportSecuritySettings(enable_dns_rebinding_protection=False), ) - # The GET handler only checks that a read-stream writer exists; the standalone - # writer never touches it. + # The GET handler only checks that a read-stream writer exists; it is never written to. read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0) transport._read_stream_writer = read_stream_writer # pyright: ignore[reportPrivateUsage] @@ -2306,8 +2286,7 @@ async def test_standalone_stream_teardown_between_dequeues_is_not_an_error( class SignalingStreams( dict[types.RequestId, tuple[MemoryObjectSendStream[EventMessage], MemoryObjectReceiveStream[EventMessage]]] ): - # Only the GET handler inserts here, so any insert is the standalone stream - # registration the test is waiting on. + # Only the GET handler inserts here, so any insert is the standalone stream registration. def __setitem__( self, key: types.RequestId, @@ -2325,8 +2304,7 @@ async def asgi_send(message: Message) -> None: sent.append(message) await gate.wait() - # Never delivers anything: parks the response's disconnect listener until the - # completed response cancels it. + # Never delivers anything, parking the response's disconnect listener. disconnect_send, disconnect_receive = anyio.create_memory_object_stream[Message](0) async def asgi_receive() -> Message: @@ -2347,17 +2325,13 @@ async def asgi_receive() -> Message: tg.start_soon(transport.handle_request, scope, asgi_receive, asgi_send) await stream_registered.wait() standalone_send = transport._request_streams[GET_STREAM_KEY][0] # pyright: ignore[reportPrivateUsage] - # Zero-buffer rendezvous: send() returns only once the writer's receive() - # has taken the event, so the writer is now between dequeues, blocked - # forwarding to the SSE stream nothing consumes while the gate is closed. + # Zero-buffer rendezvous: once send() returns, the writer has dequeued the event + # and is blocked forwarding it past the closed gate — the between-dequeues window. await standalone_send.send(EventMessage(notification)) await transport._clean_up_memory_streams(GET_STREAM_KEY) # pyright: ignore[reportPrivateUsage] - # Unblock the response: it consumes the forwarded event, and the writer's - # next dequeue hits its closed stream. + # Unblock the response; the writer's next dequeue hits its closed stream. gate.set() - # The event dequeued before teardown still reached the wire, and the response - # ended with a normal completion rather than an exception. assert sent[0]["type"] == "http.response.start" assert sent[0]["status"] == 200 body_chunks = [message for message in sent if message["type"] == "http.response.body"] From 07f1b6d195da62a24492a0bdc43c12c5a4c06025 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 12 Jun 2026 11:14:49 +0000 Subject: [PATCH 16/24] Remove the client-side related_request_id surface ServerMessageMetadata.related_request_id exists for the server's streamable-HTTP transport to route outbound messages onto the originating request's SSE stream. No client transport has ever serialized it, so ClientSession's related_request_id parameter and ServerMessageMetadata acceptance were dead inheritance from the shared v1 BaseSession. - send_notification loses its related_request_id parameter - send_request's metadata narrows to ClientMessageMetadata | None (resumption hints, the live part) - the isinstance(dispatcher) downcasts those parameters forced are gone Progress and response correlation (progressToken in params, JSON-RPC id) are payload-level mechanisms and are unaffected. --- docs/migration.md | 1 + src/mcp/client/session.py | 33 +++++----------------- tests/client/test_session.py | 55 ++---------------------------------- 3 files changed, 10 insertions(+), 79 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index 9dba4bcc8..509a381a5 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -1187,6 +1187,7 @@ Behavior changes: - **Error responses with a null `id`** — the JSON-RPC shape for a peer reporting a parse error — are now dropped with a debug log. v1 surfaced them to `message_handler` as an `MCPError`. - **A raising request callback** is answered with `code=0` and the exception text. v1 flattened every callback exception to `INVALID_PARAMS`. Callbacks that want a specific error response should return `ErrorData` (unchanged) or raise `MCPError`. One carve-out: a callback that raises pydantic's `ValidationError` is still answered with `INVALID_PARAMS` (`"Invalid request parameters"`, empty `data`) because the dispatcher cannot distinguish it from inbound-params validation — this conflation is pre-existing v1 behavior, and a revisit is pending. - **`send_request` before entering the context manager** raises `RuntimeError` immediately; v1 wrote to the transport and hung until the timeout. `send_notification` before entry still works. +- **`send_notification` no longer takes `related_request_id`, and `send_request` no longer accepts `ServerMessageMetadata`.** The hint was never serialized by any client transport in v1 or v2 — it exists for the server's streamable-HTTP stream routing. Progress and response correlation via `progressToken` and the request id is unaffected. `mcp.shared.session` is now a compatibility module: `ProgressFnT` is re-exported (its home is `mcp.shared.dispatcher`), and `RequestResponder` remains as a typing-only stub so `MessageHandlerFnT` annotations keep importing — it has been unreachable at runtime since the server-side swap. `RequestResponder.respond()` no longer exists. diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 30fba3a92..796a56ea6 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -18,7 +18,7 @@ from mcp.shared.dispatcher import CallOptions, DispatchContext, Dispatcher from mcp.shared.exceptions import MCPError from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher -from mcp.shared.message import ClientMessageMetadata, MessageMetadata, ServerMessageMetadata, SessionMessage +from mcp.shared.message import ClientMessageMetadata, SessionMessage from mcp.shared.session import ProgressFnT, RequestResponder from mcp.shared.transport_context import TransportContext from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -202,15 +202,13 @@ async def send_request( request: types.ClientRequest, result_type: type[ReceiveResultT], request_read_timeout_seconds: float | None = None, - metadata: MessageMetadata = None, + metadata: ClientMessageMetadata | None = None, progress_callback: ProgressFnT | None = None, ) -> ReceiveResultT: """Send a request and wait for its typed result. Args: - metadata: Transport hints: `ClientMessageMetadata` resumption fields - (streamable HTTP), or a `ServerMessageMetadata.related_request_id` - routing the message onto the originating request's stream. + metadata: Streamable HTTP resumption hints. Raises: MCPError: Error response, read timeout, or connection closed. @@ -224,38 +222,21 @@ async def send_request( opts["timeout"] = timeout if progress_callback is not None: opts["on_progress"] = progress_callback - related_request_id: types.RequestId | None = None - if isinstance(metadata, ClientMessageMetadata): + if metadata is not None: if metadata.resumption_token is not None: opts["resumption_token"] = metadata.resumption_token if metadata.on_resumption_token_update is not None: opts["on_resumption_token"] = metadata.on_resumption_token_update - elif isinstance(metadata, ServerMessageMetadata): - related_request_id = metadata.related_request_id if method == "initialize": # The spec forbids cancelling initialize. opts["cancel_on_abandon"] = False - if related_request_id is not None and isinstance(self._dispatcher, JSONRPCDispatcher): - # Only JSON-RPC dispatchers have per-request streams to route onto. - raw = await self._dispatcher.send_raw_request( - method, data.get("params"), opts, _related_request_id=related_request_id - ) - else: - raw = await self._dispatcher.send_raw_request(method, data.get("params"), opts) + raw = await self._dispatcher.send_raw_request(method, data.get("params"), opts) return result_type.model_validate(raw, by_name=False) - async def send_notification( - self, - notification: types.ClientNotification, - related_request_id: types.RequestId | None = None, - ) -> None: + async def send_notification(self, notification: types.ClientNotification) -> None: """Send a one-way notification. Usable before entering the context manager.""" data = notification.model_dump(by_alias=True, mode="json", exclude_none=True) - # `is not None`: request ids are opaque and 0 is valid. - if related_request_id is not None and isinstance(self._dispatcher, JSONRPCDispatcher): - await self._dispatcher.notify(data["method"], data.get("params"), _related_request_id=related_request_id) - else: - await self._dispatcher.notify(data["method"], data.get("params")) + await self._dispatcher.notify(data["method"], data.get("params")) async def initialize(self) -> types.InitializeResult: sampling = ( diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 6d7000f62..8a871e0aa 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -14,7 +14,7 @@ from mcp.shared._context import RequestContext from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair from mcp.shared.dispatcher import CallOptions, DispatchContext, OnNotify, OnRequest -from mcp.shared.message import ServerMessageMetadata, SessionMessage +from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder from mcp.shared.transport_context import TransportContext from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -978,8 +978,7 @@ async def server_on_notify( results.append(await session.send_ping(meta=None)) # Server-to-client: direct dispatch delivers ping with no params member (no _meta injection). assert await server_side.send_raw_request("ping", None) == {} - # related_request_id is JSON-RPC plumbing; other dispatchers send the notification without it. - await session.send_notification(types.RootsListChangedNotification(), related_request_id=7) + await session.send_notification(types.RootsListChangedNotification()) server_side.close() assert results == [types.EmptyResult()] assert notified == ["notifications/roots/list_changed"] @@ -1082,53 +1081,3 @@ async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: assert scope.cancelled_caught # The failed enter must not leave the session half-entered. assert session._task_group is None - - -@pytest.mark.anyio -async def test_send_request_with_server_metadata_routes_related_request_id(): - """ServerMessageMetadata.related_request_id is threaded onto the outgoing message.""" - async with raw_client_session() as (session, to_client, from_client): - async with anyio.create_task_group() as tg: - - async def call() -> None: - await session.send_request( - types.PingRequest(), types.EmptyResult, metadata=ServerMessageMetadata(related_request_id=3) - ) - - tg.start_soon(call) - out = await from_client.receive() - assert isinstance(out.metadata, ServerMessageMetadata) - assert out.metadata.related_request_id == 3 - assert isinstance(out.message, JSONRPCRequest) - await to_client.send(SessionMessage(JSONRPCResponse(jsonrpc="2.0", id=out.message.id, result={}))) - - -@pytest.mark.anyio -async def test_send_notification_with_related_request_id_attaches_metadata(): - """A related_request_id on a notification rides the originating request's stream.""" - async with raw_client_session() as (session, _to_client, from_client): - await session.send_notification( - types.ProgressNotification( - params=types.ProgressNotificationParams(progress_token=1, progress=0.5), - ), - related_request_id=4, - ) - out = await from_client.receive() - assert isinstance(out.metadata, ServerMessageMetadata) - assert out.metadata.related_request_id == 4 - - -@pytest.mark.anyio -async def test_send_notification_with_related_request_id_zero_attaches_metadata(): - """`related_request_id=0` still attaches metadata: 0 is a valid request id, so the session checks - `is not None`, not truthiness (regression pin). Wire-level: only the sent `SessionMessage` shows it.""" - async with raw_client_session() as (session, _to_client, from_client): - await session.send_notification( - types.ProgressNotification( - params=types.ProgressNotificationParams(progress_token=1, progress=0.5), - ), - related_request_id=0, - ) - out = await from_client.receive() - assert isinstance(out.metadata, ServerMessageMetadata) - assert out.metadata.related_request_id == 0 From 48f2b012ea6a7a1dfcc287dd927e8800aee6d388 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 12 Jun 2026 14:23:05 +0000 Subject: [PATCH 17/24] Make ClientRequestContext a concrete class RequestContext[ClientSession] was the only instantiation of the generic left in the tree (the server seat has ServerRequestContext), so the public ClientRequestContext alias becomes the real dataclass and the private mcp.shared._context module is deleted. request_id is now always populated: the client only builds a context for inbound requests, and ping is answered before any context exists. --- docs/migration.md | 5 +- src/mcp/client/context.py | 15 +----- src/mcp/client/session.py | 52 ++++++++++++------- src/mcp/shared/_context.py | 23 -------- tests/client/test_list_roots_callback.py | 5 +- tests/client/test_sampling_callback.py | 7 ++- tests/client/test_session.py | 8 +-- tests/server/mcpserver/test_elicitation.py | 28 +++++----- tests/server/mcpserver/test_integration.py | 10 ++-- .../server/mcpserver/test_url_elicitation.py | 25 +++++---- tests/shared/test_streamable_http.py | 4 +- 11 files changed, 75 insertions(+), 107 deletions(-) delete mode 100644 src/mcp/shared/_context.py diff --git a/docs/migration.md b/docs/migration.md index 509a381a5..1b17fc5e3 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -634,11 +634,11 @@ server = Server("my-server", on_call_tool=handle_call_tool) The `mcp.shared.context` module has been removed. `RequestContext` is now split into `ClientRequestContext` (in `mcp.client.context`) and `ServerRequestContext` (in `mcp.server.context`). -The `RequestContext` class has been split to separate shared fields from server-specific fields. The shared `RequestContext` now only takes 1 type parameter (the session type) instead of 3. +The split separates shared fields from server-specific fields. There is no shared `RequestContext` generic anymore — each concrete class fixes its session type. **`RequestContext` changes:** -- Type parameters reduced from `RequestContext[SessionT, LifespanContextT, RequestT]` to `RequestContext[SessionT]` +- The `RequestContext[SessionT, LifespanContextT, RequestT]` generic no longer exists; use `ClientRequestContext` or `ServerRequestContext[LifespanContextT, RequestT]` - Server-specific fields (`lifespan_context`, `request`, `close_sse_stream`, `close_standalone_sse_stream`) moved to new `ServerRequestContext` class in `mcp.server.context` **Before (v1):** @@ -1188,6 +1188,7 @@ Behavior changes: - **A raising request callback** is answered with `code=0` and the exception text. v1 flattened every callback exception to `INVALID_PARAMS`. Callbacks that want a specific error response should return `ErrorData` (unchanged) or raise `MCPError`. One carve-out: a callback that raises pydantic's `ValidationError` is still answered with `INVALID_PARAMS` (`"Invalid request parameters"`, empty `data`) because the dispatcher cannot distinguish it from inbound-params validation — this conflation is pre-existing v1 behavior, and a revisit is pending. - **`send_request` before entering the context manager** raises `RuntimeError` immediately; v1 wrote to the transport and hung until the timeout. `send_notification` before entry still works. - **`send_notification` no longer takes `related_request_id`, and `send_request` no longer accepts `ServerMessageMetadata`.** The hint was never serialized by any client transport in v1 or v2 — it exists for the server's streamable-HTTP stream routing. Progress and response correlation via `progressToken` and the request id is unaffected. +- **The private `mcp.shared._context.RequestContext` generic is deleted.** Client callbacks now receive the concrete `mcp.client.ClientRequestContext`, whose `request_id` is always populated (the client only builds a context for inbound requests). Annotations spelled `RequestContext[ClientSession]` become `ClientRequestContext`. `mcp.shared.session` is now a compatibility module: `ProgressFnT` is re-exported (its home is `mcp.shared.dispatcher`), and `RequestResponder` remains as a typing-only stub so `MessageHandlerFnT` annotations keep importing — it has been unreachable at runtime since the server-side swap. `RequestResponder.respond()` no longer exists. diff --git a/src/mcp/client/context.py b/src/mcp/client/context.py index 2f4404e00..aecd29527 100644 --- a/src/mcp/client/context.py +++ b/src/mcp/client/context.py @@ -1,16 +1,5 @@ """Request context for MCP client handlers.""" -from mcp.client.session import ClientSession -from mcp.shared._context import RequestContext +from mcp.client.session import ClientRequestContext -ClientRequestContext = RequestContext[ClientSession] -"""Context for handling incoming requests in a client session. - -This context is passed to client-side callbacks (sampling, elicitation, list_roots) when the server sends requests -to the client. - -Attributes: - request_id: The unique identifier for this request. - meta: Optional metadata associated with the request. - session: The client session handling this request. -""" +__all__ = ["ClientRequestContext"] diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 796a56ea6..7392df525 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -2,6 +2,7 @@ import logging from collections.abc import Mapping +from dataclasses import dataclass from types import TracebackType from typing import Any, Protocol, cast, get_args @@ -14,7 +15,6 @@ from mcp import types from mcp.client._transport import ReadStream, WriteStream from mcp.shared._compat import resync_tracer -from mcp.shared._context import RequestContext from mcp.shared.dispatcher import CallOptions, DispatchContext, Dispatcher from mcp.shared.exceptions import MCPError from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher @@ -22,7 +22,7 @@ from mcp.shared.session import ProgressFnT, RequestResponder from mcp.shared.transport_context import TransportContext from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS -from mcp.types._types import RequestParamsMeta +from mcp.types import RequestId, RequestParamsMeta DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") @@ -31,10 +31,19 @@ ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) +@dataclass(kw_only=True) +class ClientRequestContext: + """Context for a server-initiated request, passed to the sampling/elicitation/list-roots callbacks.""" + + session: ClientSession + request_id: RequestId + meta: RequestParamsMeta | None = None + + class SamplingFnT(Protocol): async def __call__( self, - context: RequestContext[ClientSession], + context: ClientRequestContext, params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: ... # pragma: no branch @@ -42,14 +51,14 @@ async def __call__( class ElicitationFnT(Protocol): async def __call__( self, - context: RequestContext[ClientSession], + context: ClientRequestContext, params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: ... # pragma: no branch class ListRootsFnT(Protocol): async def __call__( - self, context: RequestContext[ClientSession] + self, context: ClientRequestContext ) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch @@ -71,7 +80,7 @@ async def _default_message_handler( async def _default_sampling_callback( - context: RequestContext[ClientSession], + context: ClientRequestContext, params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: return types.ErrorData( @@ -81,7 +90,7 @@ async def _default_sampling_callback( async def _default_elicitation_callback( - context: RequestContext[ClientSession], + context: ClientRequestContext, params: types.ElicitRequestParams, ) -> types.ElicitResult | types.ErrorData: return types.ErrorData( @@ -91,7 +100,7 @@ async def _default_elicitation_callback( async def _default_list_roots_callback( - context: RequestContext[ClientSession], + context: ClientRequestContext, ) -> types.ListRootsResult | types.ErrorData: return types.ErrorData( code=types.INVALID_REQUEST, @@ -496,19 +505,22 @@ async def _on_request( payload["params"] = dict(params) request = types.server_request_adapter.validate_python(payload, by_name=False) - ctx = RequestContext[ClientSession]( - request_id=dctx.request_id, meta=request.params.meta if request.params else None, session=self - ) response: types.ClientResult | types.ErrorData - match request: - case types.CreateMessageRequest(params=sampling_params): - response = await self._sampling_callback(ctx, sampling_params) - case types.ElicitRequest(params=elicit_params): - response = await self._elicitation_callback(ctx, elicit_params) - case types.ListRootsRequest(): - response = await self._list_roots_callback(ctx) - case types.PingRequest(): # pragma: no branch - response = types.EmptyResult() + if isinstance(request, types.PingRequest): + # Answered without a context: direct dispatch carries no request id. + response = types.EmptyResult() + else: + assert dctx.request_id is not None # the callback-driving dispatchers always assign ids + ctx = ClientRequestContext( + session=self, request_id=dctx.request_id, meta=request.params.meta if request.params else None + ) + match request: + case types.CreateMessageRequest(params=sampling_params): + response = await self._sampling_callback(ctx, sampling_params) + case types.ElicitRequest(params=elicit_params): + response = await self._elicitation_callback(ctx, elicit_params) + case types.ListRootsRequest(): # pragma: no branch + response = await self._list_roots_callback(ctx) client_response = ClientResponse.validate_python(response) if isinstance(client_response, types.ErrorData): raise MCPError.from_error_data(client_response) diff --git a/src/mcp/shared/_context.py b/src/mcp/shared/_context.py deleted file mode 100644 index 8ad4ca918..000000000 --- a/src/mcp/shared/_context.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Request context for MCP client handlers.""" - -from dataclasses import dataclass -from typing import Any, Generic - -from typing_extensions import TypeVar - -from mcp.types import RequestId, RequestParamsMeta - -SessionT = TypeVar("SessionT", default=Any) - - -@dataclass(kw_only=True) -class RequestContext(Generic[SessionT]): - """Common context for handling incoming requests. - - For request handlers, request_id is always populated. - For notification handlers, request_id is None. - """ - - session: SessionT - request_id: RequestId | None = None - meta: RequestParamsMeta | None = None diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index be4b9a97b..a26ef45b2 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -2,9 +2,8 @@ from pydantic import FileUrl from mcp import Client -from mcp.client.session import ClientSession +from mcp.client import ClientRequestContext from mcp.server.mcpserver import Context, MCPServer -from mcp.shared._context import RequestContext from mcp.types import ListRootsResult, Root, TextContent @@ -20,7 +19,7 @@ async def test_list_roots_callback(): ) async def list_roots_callback( - context: RequestContext[ClientSession], + context: ClientRequestContext, ) -> ListRootsResult: return callback_return diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index 6efcac0a5..2b90b00af 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -1,9 +1,8 @@ import pytest from mcp import Client -from mcp.client.session import ClientSession +from mcp.client import ClientRequestContext from mcp.server.mcpserver import Context, MCPServer -from mcp.shared._context import RequestContext from mcp.types import ( CreateMessageRequestParams, CreateMessageResult, @@ -26,7 +25,7 @@ async def test_sampling_callback(): ) async def sampling_callback( - context: RequestContext[ClientSession], + context: ClientRequestContext, params: CreateMessageRequestParams, ) -> CreateMessageResult: return callback_return @@ -71,7 +70,7 @@ async def test_create_message_backwards_compat_single_content(): ) async def sampling_callback( - context: RequestContext[ClientSession], + context: ClientRequestContext, params: CreateMessageRequestParams, ) -> CreateMessageResult: return callback_return diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 8a871e0aa..5e64a9250 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -10,8 +10,8 @@ import pytest from mcp import types +from mcp.client import ClientRequestContext from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession -from mcp.shared._context import RequestContext from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair from mcp.shared.dispatcher import CallOptions, DispatchContext, OnNotify, OnRequest from mcp.shared.message import SessionMessage @@ -425,7 +425,7 @@ async def test_client_capabilities_with_custom_callbacks(): received_capabilities = None async def custom_sampling_callback( # pragma: no cover - context: RequestContext[ClientSession], + context: ClientRequestContext, params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.ErrorData: return types.CreateMessageResult( @@ -435,7 +435,7 @@ async def custom_sampling_callback( # pragma: no cover ) async def custom_list_roots_callback( # pragma: no cover - context: RequestContext[ClientSession], + context: ClientRequestContext, ) -> types.ListRootsResult | types.ErrorData: return types.ListRootsResult(roots=[]) @@ -509,7 +509,7 @@ async def test_client_capabilities_with_sampling_tools(): received_capabilities = None async def custom_sampling_callback( # pragma: no cover - context: RequestContext[ClientSession], + context: ClientRequestContext, params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult | types.ErrorData: return types.CreateMessageResult( diff --git a/tests/server/mcpserver/test_elicitation.py b/tests/server/mcpserver/test_elicitation.py index 9292586b3..26908ed16 100644 --- a/tests/server/mcpserver/test_elicitation.py +++ b/tests/server/mcpserver/test_elicitation.py @@ -6,9 +6,9 @@ from pydantic import BaseModel, Field from mcp import Client, types -from mcp.client.session import ClientSession, ElicitationFnT +from mcp.client import ClientRequestContext +from mcp.client.session import ElicitationFnT from mcp.server.mcpserver import Context, MCPServer -from mcp.shared._context import RequestContext from mcp.types import ElicitRequestParams, ElicitResult, TextContent @@ -64,7 +64,7 @@ async def test_elicitation_accept_returns_the_users_answer_to_the_tool(): create_ask_user_tool(mcp) # Create a custom handler for elicitation requests - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): if params.message == "Tool wants to ask: What is your name?": return ElicitResult(action="accept", content={"answer": "Test User"}) else: # pragma: no cover @@ -81,7 +81,7 @@ async def test_elicitation_decline_reaches_the_tool_without_content(): mcp = MCPServer(name="ElicitationDeclineServer") create_ask_user_tool(mcp) - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): return ElicitResult(action="decline") await call_tool_and_assert( @@ -119,9 +119,7 @@ class InvalidNestedSchema(BaseModel): create_validation_tool("nested_model", InvalidNestedSchema) # Dummy callback (won't be called due to validation failure) - async def elicitation_callback( - context: RequestContext[ClientSession], params: ElicitRequestParams - ): # pragma: no cover + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): # pragma: no cover return ElicitResult(action="accept", content={}) async with Client(mcp, elicitation_callback=elicitation_callback) as client: @@ -176,7 +174,7 @@ async def optional_tool(ctx: Context) -> str: for content, expected in test_cases: - async def callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def callback(context: ClientRequestContext, params: ElicitRequestParams): return ElicitResult(action="accept", content=content) await call_tool_and_assert(mcp, callback, "optional_tool", {}, expected) @@ -194,9 +192,7 @@ async def invalid_optional_tool(ctx: Context) -> str: except TypeError as e: return f"Validation failed: {str(e)}" - async def elicitation_callback( - context: RequestContext[ClientSession], params: ElicitRequestParams - ): # pragma: no cover + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): # pragma: no cover return ElicitResult(action="accept", content={}) await call_tool_and_assert( @@ -219,7 +215,7 @@ async def valid_multiselect_tool(ctx: Context) -> str: return f"Name: {result.data.name}, Tags: {', '.join(result.data.tags)}" return f"User {result.action}" # pragma: no cover - async def multiselect_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def multiselect_callback(context: ClientRequestContext, params: ElicitRequestParams): if "Please provide tags" in params.message: return ElicitResult(action="accept", content={"name": "Test", "tags": ["tag1", "tag2"]}) return ElicitResult(action="decline") # pragma: no cover @@ -239,7 +235,7 @@ async def optional_multiselect_tool(ctx: Context) -> str: return f"Name: {result.data.name}, Tags: {tags_str}" return f"User {result.action}" # pragma: no cover - async def optional_multiselect_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def optional_multiselect_callback(context: ClientRequestContext, params: ElicitRequestParams): if "Please provide optional tags" in params.message: return ElicitResult(action="accept", content={"name": "Test", "tags": ["tag1", "tag2"]}) return ElicitResult(action="decline") # pragma: no cover @@ -273,7 +269,7 @@ async def defaults_tool(ctx: Context) -> str: return f"User {result.action}" # First verify that defaults are present in the JSON schema sent to clients - async def callback_schema_verify(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def callback_schema_verify(context: ClientRequestContext, params: ElicitRequestParams): # Verify the schema includes defaults assert isinstance(params, types.ElicitRequestFormParams), "Expected form mode elicitation" schema = params.requested_schema @@ -295,7 +291,7 @@ async def callback_schema_verify(context: RequestContext[ClientSession], params: ) # Test overriding defaults - async def callback_override(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def callback_override(context: ClientRequestContext, params: ElicitRequestParams): return ElicitResult( action="accept", content={"email": "john@example.com", "name": "John", "age": 25, "subscribe": False} ) @@ -371,7 +367,7 @@ async def select_color_legacy(ctx: Context) -> str: return f"User: {result.data.user_name}, Color: {result.data.color}" return f"User {result.action}" # pragma: no cover - async def enum_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def enum_callback(context: ClientRequestContext, params: ElicitRequestParams): if "colors" in params.message and "legacy" not in params.message: return ElicitResult(action="accept", content={"user_name": "Bob", "favorite_colors": ["red", "green"]}) elif "color" in params.message: diff --git a/tests/server/mcpserver/test_integration.py b/tests/server/mcpserver/test_integration.py index f71c0574c..5bac39dfe 100644 --- a/tests/server/mcpserver/test_integration.py +++ b/tests/server/mcpserver/test_integration.py @@ -26,9 +26,7 @@ structured_output, tool_progress, ) -from mcp.client import Client -from mcp.client.session import ClientSession -from mcp.shared._context import RequestContext +from mcp.client import Client, ClientRequestContext from mcp.shared.session import RequestResponder from mcp.types import ( ClientResult, @@ -80,9 +78,7 @@ async def handle_generic_notification( self.tool_notifications.append(message.params) -async def sampling_callback( - context: RequestContext[ClientSession], params: CreateMessageRequestParams -) -> CreateMessageResult: +async def sampling_callback(context: ClientRequestContext, params: CreateMessageRequestParams) -> CreateMessageResult: """Sampling callback for tests.""" return CreateMessageResult( role="assistant", @@ -94,7 +90,7 @@ async def sampling_callback( ) -async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): +async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): """Elicitation callback for tests.""" # For restaurant booking test if "No tables available" in params.message: diff --git a/tests/server/mcpserver/test_url_elicitation.py b/tests/server/mcpserver/test_url_elicitation.py index af90dc208..9ab03fcda 100644 --- a/tests/server/mcpserver/test_url_elicitation.py +++ b/tests/server/mcpserver/test_url_elicitation.py @@ -5,10 +5,9 @@ from pydantic import BaseModel, Field from mcp import Client, types -from mcp.client.session import ClientSession +from mcp.client import ClientRequestContext from mcp.server.elicitation import CancelledElicitation, DeclinedElicitation, elicit_url from mcp.server.mcpserver import Context, MCPServer -from mcp.shared._context import RequestContext from mcp.types import ElicitRequestParams, ElicitResult, TextContent @@ -28,7 +27,7 @@ async def request_api_key(ctx: Context) -> str: return f"User {result.action}" # Create elicitation callback that accepts URL mode - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): assert params.mode == "url" assert params.url == "https://example.com/api_key_setup" assert params.elicitation_id == "test-elicitation-001" @@ -57,7 +56,7 @@ async def oauth_flow(ctx: Context) -> str: # Test only checks decline path return f"User {result.action} authorization" - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): assert params.mode == "url" return ElicitResult(action="decline") @@ -83,7 +82,7 @@ async def payment_flow(ctx: Context) -> str: # Test only checks cancel path return f"User {result.action} payment" - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): assert params.mode == "url" return ElicitResult(action="cancel") @@ -110,7 +109,7 @@ async def setup_credentials(ctx: Context) -> str: # Test only checks accept path - return the type name return type(result).__name__ - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): return ElicitResult(action="accept") async with Client(mcp, elicitation_callback=elicitation_callback) as client: @@ -137,7 +136,7 @@ async def check_url_response(ctx: Context) -> str: assert result.content is None return f"Action: {result.action}, Content: {result.content}" - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): # Verify that this is URL mode assert params.mode == "url" assert isinstance(params, types.ElicitRequestURLParams) @@ -170,7 +169,7 @@ async def ask_name(ctx: Context) -> str: assert result.data is not None return f"Hello, {result.data.name}!" - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): # Verify form mode parameters assert params.mode == "form" assert isinstance(params, types.ElicitRequestFormParams) @@ -206,7 +205,7 @@ async def trigger_elicitation(ctx: Context) -> str: return "Elicitation completed" - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): return ElicitResult(action="accept") # pragma: no cover async with Client(mcp, elicitation_callback=elicitation_callback) as client: @@ -263,7 +262,7 @@ async def test_cancel(ctx: Context) -> str: return "Not cancelled" # pragma: no cover # Test declined result - async def decline_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def decline_callback(context: ClientRequestContext, params: ElicitRequestParams): return ElicitResult(action="decline") async with Client(mcp, elicitation_callback=decline_callback) as client: @@ -273,7 +272,7 @@ async def decline_callback(context: RequestContext[ClientSession], params: Elici assert result.content[0].text == "Declined" # Test cancelled result - async def cancel_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def cancel_callback(context: ClientRequestContext, params: ElicitRequestParams): return ElicitResult(action="cancel") async with Client(mcp, elicitation_callback=cancel_callback) as client: @@ -303,7 +302,7 @@ async def use_deprecated_elicit(ctx: Context) -> str: return f"Email: {result.content.get('email', 'none')}" return "No email provided" # pragma: no cover - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): # Verify this is form mode assert params.mode == "form" assert params.requested_schema is not None @@ -331,7 +330,7 @@ async def direct_elicit_url(ctx: Context) -> str: ) return f"Result: {result.action}" - async def elicitation_callback(context: RequestContext[ClientSession], params: ElicitRequestParams): + async def elicitation_callback(context: ClientRequestContext, params: ElicitRequestParams): assert params.mode == "url" assert params.elicitation_id == "ctx-test-001" return ElicitResult(action="accept") diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 3613df18f..02976656e 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -26,6 +26,7 @@ from starlette.types import Message, Scope from mcp import MCPError, types +from mcp.client import ClientRequestContext from mcp.client.session import ClientSession from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client from mcp.server import Server, ServerRequestContext @@ -44,7 +45,6 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings from mcp.shared._compat import resync_tracer -from mcp.shared._context import RequestContext from mcp.shared._context_streams import create_context_streams from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder @@ -1235,7 +1235,7 @@ async def test_streamablehttp_server_sampling(basic_app: Starlette) -> None: # Define sampling callback that returns a mock response async def sampling_callback( - context: RequestContext[ClientSession], + context: ClientRequestContext, params: types.CreateMessageRequestParams, ) -> types.CreateMessageResult: nonlocal sampling_callback_invoked, captured_message_params From 421e65ba13c12c109315de00b3631953f0addce0 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 12 Jun 2026 15:15:17 +0000 Subject: [PATCH 18/24] Synthesize request ids in DirectDispatcher Server-initiated sampling/elicitation/roots requests over a ClientSession built with dispatcher=DirectDispatcher failed before the callback ran: the session requires a populated request id and direct dispatch carried none. DirectDispatcher now assigns per-instance monotonic ids to inbound requests (notifications keep None, which is how middleware distinguishes them). Adds a non-ping direct-dispatch test and bounds the indefinite awaits in the existing dispatcher= tests. --- src/mcp/client/session.py | 2 +- src/mcp/shared/direct_dispatcher.py | 12 ++++-- tests/client/test_session.py | 60 +++++++++++++++++++++++------ tests/shared/test_dispatcher.py | 6 ++- 4 files changed, 63 insertions(+), 17 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 7392df525..e7dd1291a 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -507,7 +507,7 @@ async def _on_request( response: types.ClientResult | types.ErrorData if isinstance(request, types.PingRequest): - # Answered without a context: direct dispatch carries no request id. + # Answered without a context: ping has no callback that would need one. response = types.EmptyResult() else: assert dctx.request_id is not None # the callback-driving dispatchers always assign ids diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py index 5b3d29c8d..6bba74987 100644 --- a/src/mcp/shared/direct_dispatcher.py +++ b/src/mcp/shared/direct_dispatcher.py @@ -50,7 +50,7 @@ class _DirectDispatchContext: _back_request: _Request _back_notify: _Notify request_id: RequestId | None = None - """Always `None`: direct dispatch has no wire-level request id.""" + """A dispatcher-synthesized id for requests; `None` for notifications.""" message_metadata: MessageMetadata = None # TODO(maxisbey): remove for Context rework """Always `None`: in-memory dispatch attaches no transport metadata.""" _on_progress: ProgressFnT | None = None @@ -91,6 +91,7 @@ def __init__(self, transport_ctx: TransportContext): self._peer: DirectDispatcher | None = None self._on_request: OnRequest | None = None self._on_notify: OnNotify | None = None + self._next_id = 0 self._ready = anyio.Event() self._closed = anyio.Event() @@ -128,13 +129,16 @@ async def run( def close(self) -> None: self._closed.set() - def _make_context(self, on_progress: ProgressFnT | None = None) -> _DirectDispatchContext: + def _make_context( + self, on_progress: ProgressFnT | None = None, request_id: RequestId | None = None + ) -> _DirectDispatchContext: assert self._peer is not None peer = self._peer return _DirectDispatchContext( transport=self._transport_ctx, _back_request=lambda m, p, o: peer._dispatch_request(m, p, o), _back_notify=lambda m, p: peer._dispatch_notify(m, p), + request_id=request_id, _on_progress=on_progress, ) @@ -147,7 +151,9 @@ async def _dispatch_request( await self._ready.wait() assert self._on_request is not None opts = opts or {} - dctx = self._make_context(on_progress=opts.get("on_progress")) + # Synthesize an id: the DispatchContext contract reserves None for notifications. + self._next_id += 1 + dctx = self._make_context(on_progress=opts.get("on_progress"), request_id=self._next_id) try: with anyio.fail_after(opts.get("timeout")): try: diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 5e64a9250..48ef5bab7 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -8,6 +8,7 @@ import anyio.abc import anyio.streams.memory import pytest +from pydantic import FileUrl from mcp import types from mcp.client import ClientRequestContext @@ -972,18 +973,54 @@ async def server_on_notify( session = ClientSession(dispatcher=client_side) results: list[types.EmptyResult] = [] - async with anyio.create_task_group() as tg: - await tg.start(server_side.run, server_on_request, server_on_notify) - async with session: - results.append(await session.send_ping(meta=None)) - # Server-to-client: direct dispatch delivers ping with no params member (no _meta injection). - assert await server_side.send_raw_request("ping", None) == {} - await session.send_notification(types.RootsListChangedNotification()) - server_side.close() + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + await tg.start(server_side.run, server_on_request, server_on_notify) + async with session: + results.append(await session.send_ping(meta=None)) + # Server-to-client: direct dispatch delivers ping with no params member (no _meta injection). + assert await server_side.send_raw_request("ping", None) == {} + await session.send_notification(types.RootsListChangedNotification()) + server_side.close() assert results == [types.EmptyResult()] assert notified == ["notifications/roots/list_changed"] +@pytest.mark.anyio +async def test_direct_dispatch_roots_list_reaches_callback_with_synthesized_request_id(): + """A server-initiated roots/list over dispatcher= reaches the registered callback and round-trips + the result; the callback context carries an int request_id (SDK-defined: DirectDispatcher + synthesizes ids).""" + client_side, server_side = create_direct_dispatcher_pair() + contexts: list[ClientRequestContext] = [] + + async def list_roots(context: ClientRequestContext) -> types.ListRootsResult: + contexts.append(context) + return types.ListRootsResult(roots=[types.Root(uri=FileUrl("file:///workspace"))]) + + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: dict[str, object] | None + ) -> dict[str, object]: + raise NotImplementedError + + async def server_on_notify( + ctx: DispatchContext[TransportContext], method: str, params: dict[str, object] | None + ) -> None: + raise NotImplementedError + + session = ClientSession(dispatcher=client_side, list_roots_callback=list_roots) + result: dict[str, Any] | None = None + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + await tg.start(server_side.run, server_on_request, server_on_notify) + async with session: + result = await server_side.send_raw_request("roots/list", None) + server_side.close() + assert result == {"roots": [{"uri": "file:///workspace"}]} + assert len(contexts) == 1 + assert isinstance(contexts[0].request_id, int) + + @pytest.mark.anyio async def test_initialize_opts_out_of_cancel_on_abandon_while_other_requests_leave_it_unset(): """`send_request` passes `cancel_on_abandon=False` for `initialize` — the spec forbids @@ -1021,9 +1058,10 @@ async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: pass dispatcher = RecordingDispatcher() - async with ClientSession(dispatcher=dispatcher) as session: - await session.initialize() - await session.send_ping() + with anyio.fail_after(5): + async with ClientSession(dispatcher=dispatcher) as session: + await session.initialize() + await session.send_ping() opts_by_method = dict(dispatcher.calls) assert opts_by_method["initialize"].get("cancel_on_abandon") is False assert "cancel_on_abandon" not in opts_by_method["ping"] diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py index 745f4b387..01150a21c 100644 --- a/tests/shared/test_dispatcher.py +++ b/tests/shared/test_dispatcher.py @@ -228,13 +228,15 @@ async def test_ctx_message_metadata_is_none_when_transport_attaches_nothing(pair @pytest.mark.anyio async def test_ctx_request_id_exposes_inbound_id(pair_factory: PairFactory): - """JSON-RPC carries the wire id through; direct dispatch has none.""" + """Every dispatcher assigns each inbound request a distinct int id; JSON-RPC carries + the wire id through, DirectDispatcher synthesizes one (SDK-defined).""" async with running_pair(pair_factory) as (client, _server, _crec, srec): with anyio.fail_after(5): await client.send_raw_request("tools/call", None) await client.send_raw_request("tools/call", None) a, b = (ctx.request_id for ctx in srec.contexts) - assert (a is None and b is None) or (isinstance(a, int) and isinstance(b, int) and a != b) + assert isinstance(a, int) and isinstance(b, int) + assert a != b @pytest.mark.anyio From 2d33ade749b781e220eb6d1f4ba7c2784a36d803 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 12 Jun 2026 15:15:27 +0000 Subject: [PATCH 19/24] Raise CONNECTION_CLOSED for requests sent after the connection closed send_raw_request raised RuntimeError both before run() and after the transport closed, contradicting the documented contract that connection loss surfaces as MCPError(CONNECTION_CLOSED). A closed flag now separates the two states: RuntimeError remains for use before run(), and a request after EOF gets the same CONNECTION_CLOSED error the in-flight waiters receive. --- docs/migration.md | 2 +- src/mcp/shared/jsonrpc_dispatcher.py | 10 ++++++++-- tests/shared/test_jsonrpc_dispatcher.py | 20 ++++++++++++++++++++ 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index 1b17fc5e3..49dee8f96 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -1186,7 +1186,7 @@ Behavior changes: - **Unknown-id responses are ignored**, as the spec asks. v1 surfaced them to `message_handler` as a `RuntimeError`; nothing is surfaced now. - **Error responses with a null `id`** — the JSON-RPC shape for a peer reporting a parse error — are now dropped with a debug log. v1 surfaced them to `message_handler` as an `MCPError`. - **A raising request callback** is answered with `code=0` and the exception text. v1 flattened every callback exception to `INVALID_PARAMS`. Callbacks that want a specific error response should return `ErrorData` (unchanged) or raise `MCPError`. One carve-out: a callback that raises pydantic's `ValidationError` is still answered with `INVALID_PARAMS` (`"Invalid request parameters"`, empty `data`) because the dispatcher cannot distinguish it from inbound-params validation — this conflation is pre-existing v1 behavior, and a revisit is pending. -- **`send_request` before entering the context manager** raises `RuntimeError` immediately; v1 wrote to the transport and hung until the timeout. `send_notification` before entry still works. +- **`send_request` before entering the context manager** raises `RuntimeError` immediately; v1 wrote to the transport and hung until the timeout. After the connection has closed, `send_request` instead raises `MCPError` (`CONNECTION_CLOSED`), matching what an in-flight request receives — `RuntimeError` remains only for calls before entry. `send_notification` before entry still works. - **`send_notification` no longer takes `related_request_id`, and `send_request` no longer accepts `ServerMessageMetadata`.** The hint was never serialized by any client transport in v1 or v2 — it exists for the server's streamable-HTTP stream routing. Progress and response correlation via `progressToken` and the request id is unaffected. - **The private `mcp.shared._context.RequestContext` generic is deleted.** Client callbacks now receive the concrete `mcp.client.ClientRequestContext`, whose `request_id` is always populated (the client only builds a context for inbound requests). Annotations spelled `RequestContext[ClientSession]` become `ClientRequestContext`. diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 00eb50108..2ca08954f 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -252,6 +252,7 @@ def __init__( self._in_flight: dict[RequestId, _InFlight[TransportT]] = {} self._tg: anyio.abc.TaskGroup | None = None self._running = False + self._closed = False async def send_raw_request( self, @@ -270,10 +271,13 @@ async def send_raw_request( MCPError: Peer error response; `REQUEST_TIMEOUT` if `opts["timeout"]` elapsed; `CONNECTION_CLOSED` if the transport closed or the dispatcher shut down. - RuntimeError: Called outside `run()`. + RuntimeError: Called before `run()`. """ + # Post-close sends get the same CONNECTION_CLOSED contract as in-flight waiters. + if self._closed: + raise MCPError(code=CONNECTION_CLOSED, message="Connection closed") if not self._running: - raise RuntimeError("JSONRPCDispatcher.send_raw_request called before run() / after close") + raise RuntimeError("JSONRPCDispatcher.send_raw_request called before run()") opts = opts or {} request_id = self._allocate_id() out_params = dict(params) if params is not None else {} @@ -399,6 +403,7 @@ async def run( logger.debug("read stream closed by transport; treating as EOF") # EOF: wake blocked `send_raw_request` waiters with CONNECTION_CLOSED. self._running = False + self._closed = True self._fan_out_closed() finally: # Cancel in-flight handlers; otherwise the task-group join @@ -407,6 +412,7 @@ async def run( finally: # Covers cancel/crash paths that skip the inline fan-out; idempotent. self._running = False + self._closed = True self._tg = None self._fan_out_closed() diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index e14e51d93..f6b51dd5c 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -1221,6 +1221,26 @@ async def test_send_raw_request_before_run_raises_runtimeerror(): s.close() +@pytest.mark.anyio +async def test_send_raw_request_after_connection_close_raises_connection_closed(): + """Sending after run() saw EOF raises MCPError(CONNECTION_CLOSED) — the same contract + in-flight waiters get — not RuntimeError (SDK-defined).""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + s2c_send.close() # peer drops: run() sees immediate EOF and returns + with anyio.fail_after(5): + await client.run(on_request, on_notify) + with pytest.raises(MCPError) as exc: + await client.send_raw_request("ping", None) + assert exc.value.error.code == CONNECTION_CLOSED + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + @pytest.mark.anyio async def test_transport_exception_in_read_stream_is_logged_and_dropped(): c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) From 1703d426649ac754be4ab1cc90b5a6ad697e440f Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 12 Jun 2026 15:15:37 +0000 Subject: [PATCH 20/24] Make the late-response-ignored pin falsifiable The previous canary (a raising message_handler) could never fire: correct code drops unknown-id responses before any handler, and regressed code's delivery paths contain handler exceptions. Collect surfaced messages instead and assert only a control notification arrives; verified against a simulation of the v1 surface-as-RuntimeError behavior. --- .../interaction/lowlevel/test_cancellation.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/interaction/lowlevel/test_cancellation.py b/tests/interaction/lowlevel/test_cancellation.py index 4fab2f650..6e6c2b6f6 100644 --- a/tests/interaction/lowlevel/test_cancellation.py +++ b/tests/interaction/lowlevel/test_cancellation.py @@ -217,6 +217,11 @@ async def test_a_response_for_an_unknown_request_id_is_ignored() -> None: that is the same client-side code path as any response with an unknown id, and that form is deterministic to test without a client-side cancellation API. + "Ignored" is proved in two halves: the pong round-trip proves the read loop survived the + fabricated response (the ordered in-memory stream routed it first), and `surfaced` holding + only the control notification proves the fabricated response was never delivered to + `message_handler` (v1 surfaced it there as a RuntimeError). + A real Server cannot be made to answer with a fabricated id, so the test plays the server's side of the wire by hand. Reserve this pattern for behaviour no real server can produce. The other tests in this file run over the transport matrix; this one is in-memory only because the @@ -261,12 +266,18 @@ def respond(request_id: types.RequestId, result: types.Result) -> SessionMessage assert isinstance(ping, SessionMessage) assert isinstance(ping.message, JSONRPCRequest) assert ping.message.method == "ping" - # First answer with a fabricated id that matches nothing in flight, then the real id. + # First a fabricated id that matches nothing in flight, then a control notification that + # is surfaced to message_handler (proving the handler is live), then the real id. await server_write.send(respond(9999, EmptyResult())) + await server_write.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/tools/list_changed")) + ) await server_write.send(respond(ping.message.id, EmptyResult())) + surfaced: list[IncomingMessage] = [] + async def message_handler(message: IncomingMessage) -> None: - raise NotImplementedError # unreachable: nothing is surfaced for an unknown-id response + surfaced.append(message) async with ( create_client_server_memory_streams() as ((client_read, client_write), server_streams), @@ -279,6 +290,9 @@ async def message_handler(message: IncomingMessage) -> None: pong = await session.send_request(PingRequest(), EmptyResult) assert pong == snapshot(EmptyResult()) + # The stream is ordered, so the fabricated response was routed before the control + # notification: only the control surfaced, so the unknown-id response was dropped. + assert surfaced == snapshot([types.ToolListChangedNotification()]) @requirement("protocol:cancel:initialize-not-cancellable") From 16727881116be399f1e10ea1e25bba7dff61375c Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 12 Jun 2026 15:52:01 +0000 Subject: [PATCH 21/24] Trim the dispatcher-swap migration notes Cut the ClientSession section by about a third: merged single-concern bullets (stray responses; the timeout exemptions), dropped rationale padding, reframed the shutdown bullet as new-versus-v1 (v1 never answered at shutdown), and removed the REQUEST_CANCELLED bullet entirely - the constant never existed in v1, so it has no place in a v1-to-v2 guide. --- docs/migration.md | 37 +++++++++++++++---------------------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index 49dee8f96..01fa83c66 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -634,8 +634,6 @@ server = Server("my-server", on_call_tool=handle_call_tool) The `mcp.shared.context` module has been removed. `RequestContext` is now split into `ClientRequestContext` (in `mcp.client.context`) and `ServerRequestContext` (in `mcp.server.context`). -The split separates shared fields from server-specific fields. There is no shared `RequestContext` generic anymore — each concrete class fixes its session type. - **`RequestContext` changes:** - The `RequestContext[SessionT, LifespanContextT, RequestT]` generic no longer exists; use `ClientRequestContext` or `ServerRequestContext[LifespanContextT, RequestT]` @@ -1168,29 +1166,24 @@ In practice, replace direct `ServerSession` use with `Server.run(read_stream, wr ### `ClientSession` now runs on `JSONRPCDispatcher`; `BaseSession` removed -`ClientSession` keeps its public surface — the `(read_stream, write_stream, ...)` constructor, every typed method, manual `initialize()`, and the async context-manager lifecycle — but the v1 receive loop (`BaseSession`) underneath it is gone. A new keyword-only `dispatcher=` constructor argument accepts a pre-built dispatcher instead of the stream pair (for example a `DirectDispatcher` for in-process embedding). - -Code that imported or subclassed `BaseSession` directly has no shim — the class is removed outright. The receive-loop engine it implemented now lives in `JSONRPCDispatcher` (`mcp.shared.jsonrpc_dispatcher`); to customize client behavior, use the `ClientSession` constructor callbacks, or supply your own engine through the `dispatcher=` keyword. +`ClientSession`'s public surface is unchanged — same constructor, typed methods, manual `initialize()`, and async context-manager lifecycle — but `BaseSession`, the v1 receive loop underneath it, is removed with no shim. The engine now lives in `JSONRPCDispatcher` (`mcp.shared.jsonrpc_dispatcher`). To customize client behavior, use the `ClientSession` constructor callbacks, or pass a pre-built dispatcher via the new keyword-only `dispatcher=` constructor argument (e.g. a `DirectDispatcher` for in-process embedding). Behavior changes: -- **Request ids count from 1** (previously 0). Progress tokens, which reuse the request id, shift the same way. Ids are opaque per JSON-RPC; do not assign meaning to them. -- **Timeouts**: the error message is now `Request 'tools/call' timed out` (previously `Timed out while waiting for response to CallToolRequest. Waited N seconds.`), and a timed-out or abandoned request is followed by `notifications/cancelled` on the wire, so the server stops the handler instead of leaving it running. The `initialize` request is never cancelled this way, and requests sent with resumption metadata are also exempt so they stay resumable. -- **No cancellation for requests that never reached the wire.** A timed-out or caller-cancelled request whose initial write never completed is failed locally without `notifications/cancelled` — the peer never saw the id, so there is nothing to cancel. -- **The resumption exemption applies only when the hints reach the transport.** A request sent from inside a request callback carries stream-routing metadata that takes precedence, so its resumption hints are dropped — and an abandoned one gets the courtesy `notifications/cancelled` like any other request. -- **Server-initiated requests run concurrently.** Sampling/elicitation/roots callbacks no longer serialize the receive loop: a slow callback does not block other traffic, a callback may itself send requests without deadlocking, and a server's `notifications/cancelled` now actually interrupts the callback (the request is then answered with an error response). -- **Session shutdown answers in-flight server-initiated requests with `CONNECTION_CLOSED`** (-32000, `Connection closed`) instead of -32002. The write is bounded (about one second), so closing a session stays fast even when the transport has stopped accepting writes. -- **The `REQUEST_CANCELLED` constant is removed from `mcp.types`.** Its value (-32002) collided with the spec's resource-not-found error code, and the shutdown response above was its only use. -- **Notification callbacks are concurrent.** `logging_callback`, `progress_callback`, and `message_handler` start in arrival order, but each delivery runs as its own task with no completion-before-response guarantee (matching the TypeScript, C#, and Go SDKs): deliveries may interleave, and a `progress_callback` delivery may finish after the request it reports on has returned. Callbacks that need strict sequencing must coordinate themselves. -- **Transport-level `Exception` items are delivered concurrently too.** An `Exception` the transport places on the read stream is dispatched to `message_handler` as its own task, like notification callbacks, instead of blocking the receive loop — and a `message_handler` that raises on it is logged, not fatal to the session. -- **Unknown-id responses are ignored**, as the spec asks. v1 surfaced them to `message_handler` as a `RuntimeError`; nothing is surfaced now. -- **Error responses with a null `id`** — the JSON-RPC shape for a peer reporting a parse error — are now dropped with a debug log. v1 surfaced them to `message_handler` as an `MCPError`. -- **A raising request callback** is answered with `code=0` and the exception text. v1 flattened every callback exception to `INVALID_PARAMS`. Callbacks that want a specific error response should return `ErrorData` (unchanged) or raise `MCPError`. One carve-out: a callback that raises pydantic's `ValidationError` is still answered with `INVALID_PARAMS` (`"Invalid request parameters"`, empty `data`) because the dispatcher cannot distinguish it from inbound-params validation — this conflation is pre-existing v1 behavior, and a revisit is pending. -- **`send_request` before entering the context manager** raises `RuntimeError` immediately; v1 wrote to the transport and hung until the timeout. After the connection has closed, `send_request` instead raises `MCPError` (`CONNECTION_CLOSED`), matching what an in-flight request receives — `RuntimeError` remains only for calls before entry. `send_notification` before entry still works. -- **`send_notification` no longer takes `related_request_id`, and `send_request` no longer accepts `ServerMessageMetadata`.** The hint was never serialized by any client transport in v1 or v2 — it exists for the server's streamable-HTTP stream routing. Progress and response correlation via `progressToken` and the request id is unaffected. -- **The private `mcp.shared._context.RequestContext` generic is deleted.** Client callbacks now receive the concrete `mcp.client.ClientRequestContext`, whose `request_id` is always populated (the client only builds a context for inbound requests). Annotations spelled `RequestContext[ClientSession]` become `ClientRequestContext`. - -`mcp.shared.session` is now a compatibility module: `ProgressFnT` is re-exported (its home is `mcp.shared.dispatcher`), and `RequestResponder` remains as a typing-only stub so `MessageHandlerFnT` annotations keep importing — it has been unreachable at runtime since the server-side swap. `RequestResponder.respond()` no longer exists. +- **Request ids count from 1** (previously 0); progress tokens, which reuse the id, shift too. Ids are opaque per JSON-RPC — do not assign meaning to them. +- **Timeouts**: the error message is now `Request 'tools/call' timed out`, and a timed-out or abandoned request is followed by `notifications/cancelled` so the server stops the handler instead of leaving it running. Exempt: `initialize`, requests sent with resumption metadata (so they stay resumable), and requests whose initial write never completed (the peer never saw the id). +- **Resumption hints sent from inside a request callback are dropped** (stream-routing metadata takes precedence there), so those requests are cancelled like any other. +- **Server-initiated requests run concurrently.** A slow sampling/elicitation/roots callback no longer blocks other traffic, a callback may itself send requests without deadlocking, and a server's `notifications/cancelled` now interrupts the callback (the request is then answered with an error). +- **Session shutdown now answers in-flight server-initiated requests with `CONNECTION_CLOSED` (-32000)**; v1 left them unanswered. The write is bounded (~1s) so closing stays fast. +- **Notification callbacks are concurrent.** `logging_callback`, `progress_callback`, and `message_handler` deliveries start in arrival order but each runs as its own task: they may interleave, and a `progress_callback` delivery may finish after the request it reports on has returned. Callbacks that need strict sequencing must coordinate themselves. +- **Transport-level `Exception` items are delivered to `message_handler` the same way** — as their own task, without blocking the receive loop — and a `message_handler` that raises on one is logged, not fatal to the session. +- **Stray responses are no longer surfaced to `message_handler`.** Responses with an unknown id are ignored (as the spec asks; v1 surfaced a `RuntimeError`), and error responses with a null `id` — a peer reporting a parse error — are dropped with a debug log (v1 surfaced an `MCPError`). +- **A raising request callback** is answered with `code=0` and the exception text; v1 flattened every callback exception to `INVALID_PARAMS`. For a specific error response, return `ErrorData` (unchanged) or raise `MCPError`. One carve-out: pydantic's `ValidationError` is still answered with `INVALID_PARAMS`, as in v1. +- **`send_request` before entering the context manager** raises `RuntimeError` immediately; v1 wrote to the transport and hung until the timeout. After the connection has closed it raises `MCPError` (`CONNECTION_CLOSED`) instead. `send_notification` before entry still works. +- **`send_notification` no longer takes `related_request_id`, and `send_request` no longer accepts `ServerMessageMetadata`.** No client transport ever serialized these hints; progress and response correlation via `progressToken` and the request id is unaffected. +- **Client callbacks now receive `mcp.client.ClientRequestContext`** (its `request_id` is always populated); the private `mcp.shared._context.RequestContext` generic is deleted. Annotations spelled `RequestContext[ClientSession]` become `ClientRequestContext`. + +`mcp.shared.session` is now a compatibility module: `ProgressFnT` is re-exported (its home is `mcp.shared.dispatcher`), and `RequestResponder` remains as a typing-only stub so `MessageHandlerFnT` annotations keep importing. `RequestResponder.respond()` no longer exists. ### Experimental Tasks support removed From 47616ac9b578e4a721069de46cd98f7b24a71428 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 12 Jun 2026 16:00:12 +0000 Subject: [PATCH 22/24] Close the write stream only after the task-group join run() entered both streams inside its own task group, so at teardown the write stream closed before in-flight handlers sent their final answers: the shutdown CONNECTION_CLOSED response was deterministically dropped on the EOF path and raced the close on the cancel path. The write stream's scope now wraps the task group, so scope exits order the join strictly before the close and teardown writes always land. The shutdown-delivery test becomes a real memory-stream pin, and the wedged-shutdown test's synthetic stream is replaced by a plain unread one. --- src/mcp/shared/jsonrpc_dispatcher.py | 52 ++++++++++++----------- tests/shared/test_jsonrpc_dispatcher.py | 55 +++++++++---------------- 2 files changed, 47 insertions(+), 60 deletions(-) diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 2ca08954f..709111c7a 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -384,31 +384,33 @@ async def run( `task_status.started()` fires once `send_raw_request` is usable. """ try: - async with anyio.create_task_group() as tg: - self._tg = tg - self._running = True - task_status.started() - try: - async with self._read_stream, self._write_stream: - try: - async for item in self._read_stream: - # Duck-typed: only `ContextReceiveStream` carries the - # sender's per-message contextvars snapshot. - sender_ctx: contextvars.Context | None = getattr( - self._read_stream, "last_context", None - ) - await self._dispatch(item, on_request, on_notify, sender_ctx) - except anyio.ClosedResourceError: - # Receive end closed under us (stateless SHTTP teardown); same as EOF. - logger.debug("read stream closed by transport; treating as EOF") - # EOF: wake blocked `send_raw_request` waiters with CONNECTION_CLOSED. - self._running = False - self._closed = True - self._fan_out_closed() - finally: - # Cancel in-flight handlers; otherwise the task-group join - # waits on handlers whose callers are already gone. - tg.cancel_scope.cancel() + # LIFO exits: the write stream closes only after the task-group join, so teardown writes still land. + async with self._write_stream: + async with anyio.create_task_group() as tg: + self._tg = tg + self._running = True + task_status.started() + try: + async with self._read_stream: + try: + async for item in self._read_stream: + # Duck-typed: only `ContextReceiveStream` carries the + # sender's per-message contextvars snapshot. + sender_ctx: contextvars.Context | None = getattr( + self._read_stream, "last_context", None + ) + await self._dispatch(item, on_request, on_notify, sender_ctx) + except anyio.ClosedResourceError: + # Receive end closed under us (stateless SHTTP teardown); same as EOF. + logger.debug("read stream closed by transport; treating as EOF") + # EOF: wake blocked `send_raw_request` waiters with CONNECTION_CLOSED. + self._running = False + self._closed = True + self._fan_out_closed() + finally: + # Cancel in-flight handlers; otherwise the task-group join + # waits on handlers whose callers are already gone. + tg.cancel_scope.cancel() finally: # Covers cancel/crash paths that skip the inline fan-out; idempotent. self._running = False diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index f6b51dd5c..027585dbe 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -306,7 +306,7 @@ async def drive() -> None: @pytest.mark.anyio async def test_run_closes_write_stream_on_exit(): - """run() enters both streams; the write end is released on EOF.""" + """run() owns both streams; the write end is released once the EOF teardown completes.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) @@ -819,29 +819,11 @@ async def test_shutdown_error_response_write_is_bounded_when_the_transport_is_we caplog: pytest.LogCaptureFixture, ): """Cancelling the task group hosting run() completes even when the shutdown error write wedges: - only `_SHUTDOWN_WRITE_TIMEOUT` releases the join (SDK-defined). The fake stream is needed - because run()'s teardown closes a memory stream, which would wake the blocked send.""" - - class WedgedWriteStream: - async def send(self, item: SessionMessage) -> None: - await anyio.sleep_forever() - - async def aclose(self) -> None: - raise NotImplementedError - - async def __aenter__(self) -> "WedgedWriteStream": - return self - - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> bool | None: - return None - + only `_SHUTDOWN_WRITE_TIMEOUT` releases the join (SDK-defined). A 0-buffer stream nobody reads + expresses the wedge: run() closes its write stream only after the join, so the send stays parked.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) - server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, WedgedWriteStream()) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) handler_started = anyio.Event() async def park(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: @@ -863,19 +845,19 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> await handler_started.wait() tg.cancel_scope.cancel() finally: - c2s_send.close() - c2s_recv.close() + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() # The warning proves the bound (not a completed write) released the join. assert "shutdown error response for request" in caplog.text @pytest.mark.anyio async def test_shutdown_answers_in_flight_request_with_connection_closed(): - """Cancelling run() answers a still-running request with CONNECTION_CLOSED (SDK-defined). The - recording stream is needed because run()'s exit would close a memory stream before the shielded write lands.""" + """Read-stream EOF answers a still-running request with CONNECTION_CLOSED (SDK-defined): + run() keeps the write stream open until the task-group join, so the shielded teardown write lands.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) - recording = RecordingWriteStream() - server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, recording) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) handler_started = anyio.Event() async def park(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: @@ -892,13 +874,16 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None))) with anyio.fail_after(5): await handler_started.wait() - tg.cancel_scope.cancel() + c2s_send.close() # EOF: run() cancels the parked handler, which must still answer + with anyio.fail_after(5): + answer = await s2c_recv.receive() + assert isinstance(answer, SessionMessage) + assert answer.message == JSONRPCError( + jsonrpc="2.0", id=1, error=ErrorData(code=CONNECTION_CLOSED, message="Connection closed") + ) finally: - c2s_send.close() - c2s_recv.close() - assert [m.message for m in recording.sent] == [ - JSONRPCError(jsonrpc="2.0", id=1, error=ErrorData(code=CONNECTION_CLOSED, message="Connection closed")) - ] + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() @pytest.mark.anyio From 04778b65bc48f0b6c5d1546709cce57d92a9bff8 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 12 Jun 2026 16:00:22 +0000 Subject: [PATCH 23/24] Pin client request concurrency in both directions Two rendezvous tests: three concurrent outbound tool calls proven simultaneously in flight and resolved out of order to their own callers, and two overlapping server-initiated sampling requests serviced concurrently by the client's callbacks (the v1 receive loop serialized these). Also note in the migration guide that delivery concurrency is unbounded. --- docs/migration.md | 2 +- tests/client/test_session_concurrency.py | 141 +++++++++++++++++++++++ 2 files changed, 142 insertions(+), 1 deletion(-) create mode 100644 tests/client/test_session_concurrency.py diff --git a/docs/migration.md b/docs/migration.md index 01fa83c66..309a4feb2 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -1175,7 +1175,7 @@ Behavior changes: - **Resumption hints sent from inside a request callback are dropped** (stream-routing metadata takes precedence there), so those requests are cancelled like any other. - **Server-initiated requests run concurrently.** A slow sampling/elicitation/roots callback no longer blocks other traffic, a callback may itself send requests without deadlocking, and a server's `notifications/cancelled` now interrupts the callback (the request is then answered with an error). - **Session shutdown now answers in-flight server-initiated requests with `CONNECTION_CLOSED` (-32000)**; v1 left them unanswered. The write is bounded (~1s) so closing stays fast. -- **Notification callbacks are concurrent.** `logging_callback`, `progress_callback`, and `message_handler` deliveries start in arrival order but each runs as its own task: they may interleave, and a `progress_callback` delivery may finish after the request it reports on has returned. Callbacks that need strict sequencing must coordinate themselves. +- **Notification callbacks are concurrent.** `logging_callback`, `progress_callback`, and `message_handler` deliveries start in arrival order but each runs as its own task: they may interleave, and a `progress_callback` delivery may finish after the request it reports on has returned. Callbacks that need strict sequencing must coordinate themselves, and there is no built-in bound on concurrent deliveries (v1's inline loop processed one message at a time). - **Transport-level `Exception` items are delivered to `message_handler` the same way** — as their own task, without blocking the receive loop — and a `message_handler` that raises on one is logged, not fatal to the session. - **Stray responses are no longer surfaced to `message_handler`.** Responses with an unknown id are ignored (as the spec asks; v1 surfaced a `RuntimeError`), and error responses with a null `id` — a peer reporting a parse error — are dropped with a debug log (v1 surfaced an `MCPError`). - **A raising request callback** is answered with `code=0` and the exception text; v1 flattened every callback exception to `INVALID_PARAMS`. For a specific error response, return `ErrorData` (unchanged) or raise `MCPError`. One carve-out: pydantic's `ValidationError` is still answered with `INVALID_PARAMS`, as in v1. diff --git a/tests/client/test_session_concurrency.py b/tests/client/test_session_concurrency.py new file mode 100644 index 000000000..dc91bee25 --- /dev/null +++ b/tests/client/test_session_concurrency.py @@ -0,0 +1,141 @@ +"""Concurrency over a single client session: multiple requests in flight at once, in both directions.""" + +import anyio +import pytest +from inline_snapshot import snapshot + +from mcp import Client +from mcp.client import ClientRequestContext +from mcp.server.mcpserver import Context, MCPServer +from mcp.types import ( + CallToolResult, + CreateMessageRequestParams, + CreateMessageResult, + SamplingMessage, + TextContent, +) + +pytestmark = pytest.mark.anyio + + +async def test_concurrent_tool_calls_resolve_out_of_order_to_their_own_callers() -> None: + """Three tool calls in flight at once on one session each receive their own result, even though + the responses come back in the reverse of the order the requests were sent. + + SDK-defined contract: pins the client request machinery's support for concurrent in-flight + calls with out-of-order response correlation. Each handler parks on its own release event + after signalling it started; a session that serialized requests would never start the later + handlers and the test would time out instead. + """ + send_order = ["a", "b", "c"] + started = {tag: anyio.Event() for tag in send_order} + release = {tag: anyio.Event() for tag in send_order} + done = {tag: anyio.Event() for tag in send_order} + completion_order: list[str] = [] + results: dict[str, CallToolResult] = {} + + server = MCPServer("parking") + + @server.tool() + async def park(tag: str) -> str: + started[tag].set() + await release[tag].wait() + return f"result:{tag}" + + async with Client(server) as client: + + async def call_and_record(tag: str) -> None: + results[tag] = await client.call_tool("park", {"tag": tag}) + completion_order.append(tag) + done[tag].set() + + with anyio.fail_after(5): + async with anyio.create_task_group() as task_group: # pragma: no branch + # Waiting for each handler to start before issuing the next call fixes the send + # order, and leaves all three parked in flight together once the loop finishes. + for tag in send_order: + task_group.start_soon(call_and_record, tag) + await started[tag].wait() + + # Nothing completed yet: all three calls are genuinely concurrent. + assert completion_order == [] + + # Release in reverse, awaiting each completion so the finish order is forced. + for tag in reversed(send_order): + release[tag].set() + await done[tag].wait() + + assert completion_order == ["c", "b", "a"] + assert results == snapshot( + { + "c": CallToolResult(content=[TextContent(text="result:c")], structured_content={"result": "result:c"}), + "b": CallToolResult(content=[TextContent(text="result:b")], structured_content={"result": "result:b"}), + "a": CallToolResult(content=[TextContent(text="result:a")], structured_content={"result": "result:a"}), + } + ) + + +async def test_overlapping_sampling_requests_are_serviced_concurrently_by_the_client() -> None: + """A server tool that fans out two sampling requests at once gets both echoes back: the client + runs overlapping inbound `create_message` requests concurrently instead of serializing them in + its receive loop. + + Regression pin for https://github.com/modelcontextprotocol/python-sdk/issues/2489 -- v1's + `BaseSession` awaited each inbound request handler inline, so the second sampling callback + could not start until the first returned; here both rendezvous before either is released. + """ + sampling_started = {"x": anyio.Event(), "y": anyio.Event()} + sampling_release = anyio.Event() + tool_results: list[CallToolResult] = [] + + server = MCPServer("fan_out_server") + + @server.tool() + async def fan_out(ctx: Context) -> str: + echoes: dict[str, str] = {} + + async def sample(tag: str) -> None: + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(text=tag))], + max_tokens=10, + ) + assert isinstance(result.content, TextContent) + echoes[tag] = result.content.text + + async with anyio.create_task_group() as sampler_group: + sampler_group.start_soon(sample, "x") + sampler_group.start_soon(sample, "y") + return f"{echoes['x']} {echoes['y']}" + + async def sampling_callback( + context: ClientRequestContext, params: CreateMessageRequestParams + ) -> CreateMessageResult: + content = params.messages[0].content + assert isinstance(content, TextContent) + sampling_started[content.text].set() + await sampling_release.wait() + return CreateMessageResult( + role="assistant", + content=TextContent(text=f"echo:{content.text}"), + model="test-model", + stop_reason="endTurn", + ) + + async with Client(server, sampling_callback=sampling_callback) as client: + with anyio.fail_after(5): + async with anyio.create_task_group() as task_group: # pragma: no branch + + async def invoke_fan_out() -> None: + tool_results.append(await client.call_tool("fan_out", {})) + + task_group.start_soon(invoke_fan_out) + + # Both sampling callbacks are mid-flight before either may answer -- a client that + # serialized inbound requests would never start the second one. + await sampling_started["x"].wait() + await sampling_started["y"].wait() + sampling_release.set() + + assert tool_results == snapshot( + [CallToolResult(content=[TextContent(text="echo:x echo:y")], structured_content={"result": "echo:x echo:y"})] + ) From 0a54692eae3e1e142ea53fa5abebefa327f15e20 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Fri, 12 Jun 2026 16:25:49 +0000 Subject: [PATCH 24/24] Tidy naming and doc nits from review - Rename _SHIELDED_WRITE_TIMEOUT to _ABANDON_WRITE_TIMEOUT; the timed-out arm passes it unshielded, so the old name lied there - Document that on_stream_exception is awaited inline in the read loop - Note in run()'s docstring that the dispatcher is single-shot - Import ProgressFnT from mcp.shared.dispatcher (its home) instead of the mcp.shared.session compat shim - Migration guide: scope the optional-fields section to ServerRequestContext, and correct the null-id bullet (v1 surfaced the transport's ValidationError, not an MCPError) --- docs/migration.md | 6 +++--- src/mcp/client/client.py | 2 +- src/mcp/client/session.py | 4 ++-- src/mcp/shared/jsonrpc_dispatcher.py | 14 ++++++++------ tests/shared/test_jsonrpc_dispatcher.py | 10 +++++----- 5 files changed, 19 insertions(+), 17 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index 309a4feb2..f3c79f60d 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -1120,9 +1120,9 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar ) ``` -### `RequestContext`: request-specific fields are now optional +### `ServerRequestContext`: request-specific fields are now optional -The `RequestContext` class now uses optional fields for request-specific data (`request_id`, `meta`, etc.) so it can be used for both request and notification handlers. In notification handlers, these fields are `None`. +`ServerRequestContext` now uses optional fields for request-specific data (`request_id`, `meta`, etc.) so it can be used for both request and notification handlers. In notification handlers, these fields are `None`. ```python from mcp.server import ServerRequestContext @@ -1177,7 +1177,7 @@ Behavior changes: - **Session shutdown now answers in-flight server-initiated requests with `CONNECTION_CLOSED` (-32000)**; v1 left them unanswered. The write is bounded (~1s) so closing stays fast. - **Notification callbacks are concurrent.** `logging_callback`, `progress_callback`, and `message_handler` deliveries start in arrival order but each runs as its own task: they may interleave, and a `progress_callback` delivery may finish after the request it reports on has returned. Callbacks that need strict sequencing must coordinate themselves, and there is no built-in bound on concurrent deliveries (v1's inline loop processed one message at a time). - **Transport-level `Exception` items are delivered to `message_handler` the same way** — as their own task, without blocking the receive loop — and a `message_handler` that raises on one is logged, not fatal to the session. -- **Stray responses are no longer surfaced to `message_handler`.** Responses with an unknown id are ignored (as the spec asks; v1 surfaced a `RuntimeError`), and error responses with a null `id` — a peer reporting a parse error — are dropped with a debug log (v1 surfaced an `MCPError`). +- **Stray responses are no longer surfaced to `message_handler`.** Responses with an unknown id are ignored (as the spec asks; v1 surfaced a `RuntimeError`), and error responses with a null `id` — a peer reporting a parse error — are dropped with a debug log (v1 surfaced the transport's `ValidationError`). - **A raising request callback** is answered with `code=0` and the exception text; v1 flattened every callback exception to `INVALID_PARAMS`. For a specific error response, return `ErrorData` (unchanged) or raise `MCPError`. One carve-out: pydantic's `ValidationError` is still answered with `INVALID_PARAMS`, as in v1. - **`send_request` before entering the context manager** raises `RuntimeError` immediately; v1 wrote to the transport and hung until the timeout. After the connection has closed it raises `MCPError` (`CONNECTION_CLOSED`) instead. `send_notification` before entry still works. - **`send_notification` no longer takes `related_request_id`, and `send_request` no longer accepts `ServerMessageMetadata`.** No client transport ever serialized these hints; progress and response correlation via `progressToken` and the request id is unaffected. diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index b33fea405..3868891f2 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -12,7 +12,7 @@ from mcp.client.streamable_http import streamable_http_client from mcp.server import Server from mcp.server.mcpserver import MCPServer -from mcp.shared.session import ProgressFnT +from mcp.shared.dispatcher import ProgressFnT from mcp.types import ( CallToolResult, CompleteResult, diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index e7dd1291a..6d7472c30 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -15,11 +15,11 @@ from mcp import types from mcp.client._transport import ReadStream, WriteStream from mcp.shared._compat import resync_tracer -from mcp.shared.dispatcher import CallOptions, DispatchContext, Dispatcher +from mcp.shared.dispatcher import CallOptions, DispatchContext, Dispatcher, ProgressFnT from mcp.shared.exceptions import MCPError from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher from mcp.shared.message import ClientMessageMetadata, SessionMessage -from mcp.shared.session import ProgressFnT, RequestResponder +from mcp.shared.session import RequestResponder from mcp.shared.transport_context import TransportContext from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types import RequestId, RequestParamsMeta diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index 709111c7a..96c51287f 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -51,9 +51,9 @@ logger = logging.getLogger(__name__) -_SHIELDED_WRITE_TIMEOUT: float = 5 -"""Bound for courtesy abandon-path writes; without it a wedged transport -would turn the shielded write into an uncancellable hang.""" +_ABANDON_WRITE_TIMEOUT: float = 5 +"""Bound for courtesy-cancel writes on the abandon paths; the caller-cancel +arm shields its write, so a wedged transport would otherwise hang it uncancellably.""" _SHUTDOWN_WRITE_TIMEOUT: float = 1 """Tighter bound for the shutdown-arm error write so a wedged transport can't hold session close.""" @@ -232,7 +232,8 @@ def __init__( message is dequeued (e.g. `initialize`); an inline handler that awaits the peer deadlocks the parked loop. on_stream_exception: Observer for `Exception` items on the read - stream; without it they are debug-logged and dropped. + stream; without it they are debug-logged and dropped. Awaited + inline in the read loop, so a slow observer stalls dispatch. """ self._read_stream = read_stream self._write_stream = write_stream @@ -332,7 +333,7 @@ async def send_raw_request( _related_request_id, ), shield=False, - timeout=_SHIELDED_WRITE_TIMEOUT, + timeout=_ABANDON_WRITE_TIMEOUT, describe=f"courtesy cancel for timed-out request {request_id!r}", ) raise MCPError(code=REQUEST_TIMEOUT, message=f"Request {method!r} timed out") from None @@ -343,7 +344,7 @@ async def send_raw_request( await self._final_write( partial(self._cancel_outbound, request_id, "caller cancelled", _related_request_id), shield=True, - timeout=_SHIELDED_WRITE_TIMEOUT, + timeout=_ABANDON_WRITE_TIMEOUT, describe=f"courtesy cancel for caller-cancelled request {request_id!r}", ) raise @@ -382,6 +383,7 @@ async def run( """Drive the receive loop until the read stream closes. `task_status.started()` fires once `send_raw_request` is usable. + Single-shot: once the loop ends the dispatcher stays closed and cannot be restarted. """ try: # LIFO exits: the write stream closes only after the task-group join, so teardown writes still land. diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index 027585dbe..6d9c2e5cd 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -727,7 +727,7 @@ async def test_caller_cancel_courtesy_write_is_bounded_when_the_transport_is_wed caplog: pytest.LogCaptureFixture, ): """A wedged transport write cannot turn caller cancellation into an unbounded shielded hang: - `_SHIELDED_WRITE_TIMEOUT` abandons the courtesy-cancel write (SDK-defined bound). On regression + `_ABANDON_WRITE_TIMEOUT` abandons the courtesy-cancel write (SDK-defined bound). On regression the test hangs rather than failing fast - fail_after cannot cancel through the shield.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) @@ -745,7 +745,7 @@ async def caller() -> None: gave_up.set() try: - # Both bounds exceed the in-loop _SHIELDED_WRITE_TIMEOUT (5s); the virtual clock makes them instant. + # Both bounds exceed the in-loop _ABANDON_WRITE_TIMEOUT (5s); the virtual clock makes them instant. with anyio.fail_after(30): async with anyio.create_task_group() as tg: # pragma: no branch await tg.start(client.run, on_request, on_notify) @@ -775,7 +775,7 @@ async def test_timeout_courtesy_cancel_write_is_bounded_when_the_transport_is_we caplog: pytest.LogCaptureFixture, ): """A wedged transport write cannot delay the REQUEST_TIMEOUT error indefinitely (SDK-defined - bound): `_SHIELDED_WRITE_TIMEOUT` abandons the courtesy cancel so the error still surfaces.""" + bound): `_ABANDON_WRITE_TIMEOUT` abandons the courtesy cancel so the error still surfaces.""" c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](0) client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) @@ -799,7 +799,7 @@ async def caller() -> None: request = await c2s_recv.receive() assert isinstance(request, SessionMessage) assert isinstance(request.message, JSONRPCRequest) - # Exceeds the request timeout (1s) plus _SHIELDED_WRITE_TIMEOUT (5s); virtual clock, no wall time. + # Exceeds the request timeout (1s) plus _ABANDON_WRITE_TIMEOUT (5s); virtual clock, no wall time. with anyio.fail_after(10): await gave_up.wait() tg.cancel_scope.cancel() @@ -835,7 +835,7 @@ async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> raise NotImplementedError try: - # 3s sits between _SHUTDOWN_WRITE_TIMEOUT (1s) and _SHIELDED_WRITE_TIMEOUT (5s): pins the tighter bound. + # 3s sits between _SHUTDOWN_WRITE_TIMEOUT (1s) and _ABANDON_WRITE_TIMEOUT (5s): pins the tighter bound. with anyio.fail_after(3): async with anyio.create_task_group() as tg: # pragma: no branch await tg.start(server.run, park, on_notify)