diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index d34e438fc9..341df0abb8 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -5,7 +5,6 @@ on: branches: ["main", "v1.x"] tags: ["v*.*.*"] pull_request: - branches: ["main", "v1.x"] permissions: contents: read diff --git a/docs/migration.md b/docs/migration.md index 9850f74cd4..3ba27cf826 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -442,7 +442,7 @@ async def handle_set_logging_level(level: str) -> None: mcp._mcp_server.subscribe_resource()(handle_subscribe) # pyright: ignore[reportPrivateUsage] ``` -In v2, the lowlevel `Server` no longer has decorator methods (handlers are constructor-only), so the equivalent workaround is `_add_request_handler`: +In v2, the lowlevel `Server` no longer has decorator methods (handlers are constructor-only), so the equivalent workaround is `add_request_handler`: **After (v2):** @@ -461,11 +461,11 @@ async def handle_subscribe(ctx: ServerRequestContext, params: SubscribeRequestPa return EmptyResult() -mcp._lowlevel_server._add_request_handler("logging/setLevel", handle_set_logging_level) # pyright: ignore[reportPrivateUsage] -mcp._lowlevel_server._add_request_handler("resources/subscribe", handle_subscribe) # pyright: ignore[reportPrivateUsage] +mcp._lowlevel_server.add_request_handler("logging/setLevel", SetLevelRequestParams, handle_set_logging_level) # pyright: ignore[reportPrivateUsage] +mcp._lowlevel_server.add_request_handler("resources/subscribe", SubscribeRequestParams, handle_subscribe) # pyright: ignore[reportPrivateUsage] ``` -This is a private API and may change. A public way to register these handlers on `MCPServer` is planned; until then, use this workaround or use the lowlevel `Server` directly. +`_lowlevel_server` is private and may change. A public way to register these handlers on `MCPServer` is planned; until then, use this workaround or use the lowlevel `Server` directly. ### `MCPServer`'s `Context` logging: `message` renamed to `data`, `extra` removed @@ -620,6 +620,8 @@ ctx: ClientRequestContext server_ctx: ServerRequestContext[LifespanContextT, RequestT] ``` +`ServerRequestContext` is now a standalone dataclass — it no longer subclasses `RequestContext[ServerSession]`. It carries the same fields (`session`, `request_id`, `meta`, `lifespan_context`, `request`, `close_sse_stream`, `close_standalone_sse_stream`), so handler code is unaffected, but `isinstance(ctx, RequestContext)` checks and `RequestContext[ServerSession]` annotations need updating to `ServerRequestContext`. + The high-level `Context` class (injected into `@mcp.tool()` etc.) similarly dropped its `ServerSessionT` parameter: `Context[ServerSessionT, LifespanContextT, RequestT]` → `Context[LifespanContextT, RequestT]`. Both remaining parameters have defaults, so bare `Context` is usually sufficient: **Before (v1):** @@ -813,6 +815,55 @@ server = Server("my-server", on_list_tools=handle_list_tools) If you need to check whether a handler is registered, track this yourself — there is currently no public introspection API. +### Lowlevel `Server`: `add_request_handler` is now public and takes `params_type` + +The private `_add_request_handler(method, handler)` escape hatch is now the public `add_request_handler(method, params_type, handler)`, alongside a matching `add_notification_handler`. Each takes a `params_type` model that incoming params are validated against before the handler runs. + +```python +# Before (v1 / earlier v2 prereleases) +server._add_request_handler("custom/method", my_handler) + +# After (v2) +server.add_request_handler("custom/method", MyParams, my_handler) +server.add_notification_handler("notifications/custom", MyNotifyParams, my_notify_handler) +``` + +### Lowlevel `Server`: private `_handle_*` dispatch methods removed + +`Server._handle_message`, `_handle_request`, and `_handle_notification` have been removed. The receive loop and per-message dispatch now live in `JSONRPCDispatcher` and `ServerRunner`, which `Server.run()` drives internally. + +These were private, but some users subclassed `Server` and overrode them to intercept requests. Use middleware instead: + +```python +from typing import Any + +from pydantic import BaseModel + +from mcp.server import Server, ServerRequestContext +from mcp.server.context import CallNext, HandlerResult + + +async def logging_middleware( + ctx: ServerRequestContext[Any, Any], method: str, params: BaseModel, call_next: CallNext +) -> HandlerResult: + print(f"handling {method}") + result = await call_next() + print(f"done {method}") + return result + + +server = Server("my-server", on_call_tool=...) +server.middleware.append(logging_middleware) +``` + +For lower-level interception (raw method/params before validation, including unknown methods), use `DispatchMiddleware` from `mcp.shared.dispatcher`. + +### Lowlevel `Server.run(raise_exceptions=True)`: transport errors no longer re-raised + +`raise_exceptions=True` now only governs handler exceptions: an exception raised by an `on_*` handler propagates out of `run()` instead of being converted to a JSON-RPC error response. + +Previously it also re-raised exceptions yielded by the transport onto the read stream (e.g. JSON parse errors). Those are now debug-logged and dropped regardless of `raise_exceptions`. If you relied on `run()` exiting on a transport-level parse error, that no longer happens. + ### Lowlevel `Server`: decorator-based handlers replaced with constructor `on_*` params The lowlevel `Server` class no longer uses decorator methods for handler registration. Instead, handlers are passed as `on_*` keyword arguments to the constructor. @@ -1039,6 +1090,39 @@ from mcp.server import ServerRequestContext # but None in notification handlers ``` +### `ServerSession` is now a thin proxy (no longer a `BaseSession`) + +`ServerSession` no longer subclasses `BaseSession`. It is now a small connection-scoped proxy that exposes `send_request`, `send_notification`, the typed convenience helpers (`create_message`, `elicit_form`, `send_log_message`, `send_tool_list_changed`, ...), `client_params`, and `check_client_capability`. The receive loop, `initialize` handling, and per-request task isolation that previously lived in `ServerSession` have moved to `JSONRPCDispatcher` and `ServerRunner`. + +`ServerSession` is normally constructed for you by `Server.run()` and reached via `ctx.session` in handlers, so most servers are unaffected. If you were constructing or subclassing it directly: + +**Constructor change:** + +```python +# Before (v1) +session = ServerSession(read_stream, write_stream, init_options, stateless=False) + +# After (v2) +session = ServerSession(dispatcher, connection, stateless=False) +# where `dispatcher` is a JSONRPCDispatcher and `connection` is a Connection +``` + +In practice, replace direct `ServerSession` use with `Server.run(read_stream, write_stream, init_options)` and let the framework wire it up. + +**Removed from `mcp.server.session`:** + +- `InitializationState` enum and `ServerSession._initialization_state` — initialization tracking is now on `Connection` (`connection.initialized` is an `anyio.Event`, `connection.client_params` holds the init params). +- `ServerRequestResponder` type alias. +- `ServerSession.incoming_messages` stream — there is no longer a public stream of inbound messages to iterate. Register handlers via the `on_*` constructor params (or `add_request_handler`) and use `Server.middleware` to observe every request. +- `ServerSession.__aenter__` / `__aexit__` — `ServerSession` is no longer an async context manager. +- The private `_receive_loop`, `_received_request`, `_received_notification`, and `_handle_incoming` overrides — there is nothing to override on `ServerSession` anymore. To intercept inbound messages, use `Server.middleware` or `DispatchMiddleware` (see the `_handle_*` removal section above). + +### `BaseSession` / `RequestResponder`: server-side cancellation tracking removed + +`BaseSession._in_flight` and the `RequestResponder` members that supported it (`cancel()`, the `cancelled` and `in_flight` properties, the `on_complete` constructor argument, and the internal `CancelScope`) have been removed. These existed to let `ServerSession` cancel a handler when a `CancelledNotification` arrived; `ServerSession` no longer drives a receive loop, so they were dead code. Inbound-cancellation handling for the server now lives in `JSONRPCDispatcher`. + +`BaseSession` is still used by `ClientSession`, which never relied on these members. `RequestResponder.respond()` is unchanged. + ### Experimental Tasks support removed Tasks (SEP-1686) have been removed from the MCP specification and are no longer part of this SDK. The `mcp.client.experimental`, `mcp.server.experimental`, `mcp.shared.experimental`, and `mcp.server.lowlevel.experimental` modules have been removed, along with all `Task*` types, the `tasks` capability fields, `Tool.execution`, and the `experimental` properties on `ClientSession`, `ServerSession`, `Server`, and `ServerRequestContext`. diff --git a/examples/servers/everything-server/mcp_everything_server/server.py b/examples/servers/everything-server/mcp_everything_server/server.py index a0620b9c1d..b37ff3e950 100644 --- a/examples/servers/everything-server/mcp_everything_server/server.py +++ b/examples/servers/everything-server/mcp_everything_server/server.py @@ -417,9 +417,15 @@ async def handle_unsubscribe(ctx: ServerRequestContext, params: UnsubscribeReque return EmptyResult() -mcp._lowlevel_server._add_request_handler("logging/setLevel", handle_set_logging_level) # pyright: ignore[reportPrivateUsage] -mcp._lowlevel_server._add_request_handler("resources/subscribe", handle_subscribe) # pyright: ignore[reportPrivateUsage] -mcp._lowlevel_server._add_request_handler("resources/unsubscribe", handle_unsubscribe) # pyright: ignore[reportPrivateUsage] +mcp._lowlevel_server.add_request_handler( # pyright: ignore[reportPrivateUsage] + "logging/setLevel", SetLevelRequestParams, handle_set_logging_level +) +mcp._lowlevel_server.add_request_handler( # pyright: ignore[reportPrivateUsage] + "resources/subscribe", SubscribeRequestParams, handle_subscribe +) +mcp._lowlevel_server.add_request_handler( # pyright: ignore[reportPrivateUsage] + "resources/unsubscribe", UnsubscribeRequestParams, handle_unsubscribe +) @mcp.completion() diff --git a/pyproject.toml b/pyproject.toml index 6d2319621a..c3b2bd92b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,11 @@ classifiers = [ "Programming Language :: Python :: 3.14", ] dependencies = [ - "anyio>=4.9", + # anyio < 4.10 triggers a compile-time SyntaxWarning on Python 3.14 (PEP 765, + # "'return' in a 'finally' block"); for stdio servers it lands on the child's + # stderr (agronholm/anyio#816, fixed in 4.10). + "anyio>=4.10; python_version >= '3.14'", + "anyio>=4.9; python_version < '3.14'", "httpx>=0.27.1,<1.0.0", "httpx-sse>=0.4", "pydantic>=2.12.0", @@ -91,6 +95,7 @@ dev = [ "pillow>=12.0", "strict-no-cover", "logfire>=3.0.0", + "opentelemetry-sdk>=1.39.1", ] docs = [ "mkdocs>=1.6.1", diff --git a/src/mcp/client/_memory.py b/src/mcp/client/_memory.py index e6e9386731..187131e380 100644 --- a/src/mcp/client/_memory.py +++ b/src/mcp/client/_memory.py @@ -14,6 +14,9 @@ from mcp.server.mcpserver import MCPServer from mcp.shared.memory import create_client_server_memory_streams +SERVER_SHUTDOWN_GRACE = 2.0 +"""Seconds to wait for the in-process server to exit on EOF before cancelling.""" + class InMemoryTransport: """In-memory transport for testing MCP servers without network overhead. @@ -48,21 +51,49 @@ async def _connect(self) -> AsyncIterator[TransportStreams]: client_read, client_write = client_streams server_read, server_write = server_streams - async with anyio.create_task_group() as tg: - # Start server in background - tg.start_soon( - lambda: actual_server.run( + server_done = anyio.Event() + + async def _run_server() -> None: + try: + await actual_server.run( server_read, server_write, actual_server.create_initialization_options(), raise_exceptions=self._raise_exceptions, ) - ) + finally: + server_done.set() + + async with anyio.create_task_group() as tg: + tg.start_soon(_run_server) try: yield client_read, client_write finally: - tg.cancel_scope.cancel() + # EOF the server (and our own read side) instead of + # cancelling outright. The dispatcher's run() cancels its + # own in-flight handlers on read-stream EOF, so for a + # well-behaved server the task exits naturally and the + # task-group join below is immediate. Cancelling here + # unconditionally would `coro.throw()` into this task, + # which on CPython 3.11 (gh-106749) drops `'call'` trace + # events for the outer await chain and desyncs coverage's + # CTracer past the test frame. + await client_write.aclose() + await server_write.aclose() + # Backstop: the dispatcher exits on EOF, but the server's + # own teardown (lifespan __aexit__, connection.exit_stack + # callbacks) runs after that and is user code. If it never + # completes the join would hang forever, so bound the wait + # and fall back to cancelling. The healthy path returns + # from wait() without the timeout firing, so the cancel is + # never reached and gh-106749 stays avoided. If the cancel + # does fire, the checkpoint at the end of + # `create_client_server_memory_streams` resyncs the tracer. + with anyio.move_on_after(SERVER_SHUTDOWN_GRACE): + await server_done.wait() + if not server_done.is_set(): + tg.cancel_scope.cancel() async def __aenter__(self) -> TransportStreams: """Connect to the server and return streams for communication.""" diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 74e5ba8062..d217ae42ff 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -5,6 +5,7 @@ from urllib.parse import parse_qs, urljoin, urlparse import anyio +import anyio.lowlevel import httpx from anyio.abc import TaskStatus from httpx_sse import SSEError, aconnect_sse @@ -157,3 +158,10 @@ async def _send_message(session_message: SessionMessage) -> None: yield read_stream, write_stream tg.cancel_scope.cancel() + # The cancel above is delivered via `coro.throw()` into this task at + # the task-group join; on CPython 3.11 (gh-106749) that drops `'call'` + # trace events for the outer await chain and desyncs coverage's CTracer + # past the caller's frame. Yielding once here resumes via `.send()`, + # which re-stamps the missing `'call'` events and resyncs the tracer. + # Shielded so a pending outer cancel is not re-delivered at this point. + await anyio.lowlevel.cancel_shielded_checkpoint() diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9cdf717c73..78130f2f8e 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -9,6 +9,7 @@ from dataclasses import dataclass import anyio +import anyio.lowlevel import httpx from anyio.abc import TaskGroup from httpx_sse import EventSource, ServerSentEvent, aconnect_sse @@ -586,3 +587,10 @@ def start_get_stream() -> None: if transport.session_id and terminate_on_close: await transport.terminate_session(client) tg.cancel_scope.cancel() + # The cancel above is delivered via `coro.throw()` into this task at + # the task-group join; on CPython 3.11 (gh-106749) that drops `'call'` + # trace events for the outer await chain and desyncs coverage's CTracer + # past the caller's frame. Yielding once here resumes via `.send()`, + # which re-stamps the missing `'call'` events and resyncs the tracer. + # Shielded so a pending outer cancel is not re-delivered at this point. + await anyio.lowlevel.cancel_shielded_checkpoint() diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index de473f36d3..c3423c3c98 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -3,6 +3,7 @@ from contextlib import asynccontextmanager import anyio +import anyio.lowlevel from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import ValidationError from websockets.asyncio.client import connect as ws_connect @@ -83,3 +84,10 @@ async def ws_writer(): # Once the caller's 'async with' block exits, we shut down tg.cancel_scope.cancel() + # The cancel above is delivered via `coro.throw()` into this task at + # the task-group join; on CPython 3.11 (gh-106749) that drops `'call'` + # trace events for the outer await chain and desyncs coverage's CTracer + # past the caller's frame. Yielding once here resumes via `.send()`, + # which re-stamps the missing `'call'` events and resyncs the tracer. + # Shielded so a pending outer cancel is not re-delivered at this point. + await anyio.lowlevel.cancel_shielded_checkpoint() diff --git a/src/mcp/server/__main__.py b/src/mcp/server/__main__.py index dbc50b8a79..4305b87e22 100644 --- a/src/mcp/server/__main__.py +++ b/src/mcp/server/__main__.py @@ -1,14 +1,11 @@ -import importlib.metadata import logging import sys import warnings import anyio -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession +from mcp.server.lowlevel.server import Server from mcp.server.stdio import stdio_server -from mcp.types import ServerCapabilities if not sys.warnoptions: warnings.simplefilter("ignore") @@ -17,32 +14,10 @@ logger = logging.getLogger("server") -async def receive_loop(session: ServerSession): - logger.info("Starting receive loop") - async for message in session.incoming_messages: - if isinstance(message, Exception): - logger.error("Error: %s", message) - continue - - logger.info("Received message from client: %s", message) - - -async def main(): - version = importlib.metadata.version("mcp") +async def main() -> None: + server: Server[dict[str, object]] = Server("mcp") async with stdio_server() as (read_stream, write_stream): - async with ( - ServerSession( - read_stream, - write_stream, - InitializationOptions( - server_name="mcp", - server_version=version, - capabilities=ServerCapabilities(), - ), - ) as session, - write_stream, - ): - await receive_loop(session) + await server.run(read_stream, write_stream, server.create_initialization_options()) if __name__ == "__main__": diff --git a/src/mcp/server/_typed_request.py b/src/mcp/server/_typed_request.py new file mode 100644 index 0000000000..64b8b8119a --- /dev/null +++ b/src/mcp/server/_typed_request.py @@ -0,0 +1,86 @@ +"""Typed `send_request` for server-to-client requests. + +`TypedServerRequestMixin` provides a typed `send_request(req) -> Result` over +the host's raw `Outbound.send_raw_request`. Spec server-to-client request types +have their result type inferred via per-type overloads; custom requests pass +`result_type=` explicitly. + +If the spec's request set grows substantially, consider declaring the result +mapping on the request types themselves (a `__mcp_result__` ClassVar read via +a structural protocol) so this overload ladder doesn't need maintaining +per-host-class. +""" + +from typing import Any, TypeVar, overload + +from pydantic import BaseModel + +from mcp.shared.dispatcher import CallOptions, Outbound +from mcp.shared.peer import dump_params +from mcp.types import ( + CreateMessageRequest, + CreateMessageResult, + ElicitRequest, + ElicitResult, + EmptyResult, + ListRootsRequest, + ListRootsResult, + PingRequest, + Request, +) + +__all__ = ["TypedServerRequestMixin"] + +ResultT = TypeVar("ResultT", bound=BaseModel) + +_RESULT_FOR: dict[type[Request[Any, Any]], type[BaseModel]] = { + CreateMessageRequest: CreateMessageResult, + ElicitRequest: ElicitResult, + ListRootsRequest: ListRootsResult, + PingRequest: EmptyResult, +} + + +class TypedServerRequestMixin: + """Typed `send_request` for the server-to-client request set. + + Mixed into `Connection` and the server `Context`. Each method constrains + `self` to `Outbound` so any host with `send_raw_request` works. + """ + + @overload + async def send_request( + self: Outbound, req: CreateMessageRequest, *, opts: CallOptions | None = None + ) -> CreateMessageResult: ... + @overload + async def send_request(self: Outbound, req: ElicitRequest, *, opts: CallOptions | None = None) -> ElicitResult: ... + @overload + async def send_request( + self: Outbound, req: ListRootsRequest, *, opts: CallOptions | None = None + ) -> ListRootsResult: ... + @overload + async def send_request(self: Outbound, req: PingRequest, *, opts: CallOptions | None = None) -> EmptyResult: ... + @overload + async def send_request( + self: Outbound, req: Request[Any, Any], *, result_type: type[ResultT], opts: CallOptions | None = None + ) -> ResultT: ... + async def send_request( + self: Outbound, + req: Request[Any, Any], + *, + result_type: type[BaseModel] | None = None, + opts: CallOptions | None = None, + ) -> BaseModel: + """Send a typed server-to-client request and return its typed result. + + For spec request types the result type is inferred. For custom requests + pass `result_type=` explicitly. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + KeyError: `result_type` omitted for a non-spec request type. + """ + raw = await self.send_raw_request(req.method, dump_params(req.params), opts) + cls = result_type if result_type is not None else _RESULT_FOR[type(req)] + return cls.model_validate(raw, by_name=False) diff --git a/src/mcp/server/connection.py b/src/mcp/server/connection.py new file mode 100644 index 0000000000..849a74b28b --- /dev/null +++ b/src/mcp/server/connection.py @@ -0,0 +1,177 @@ +"""`Connection` - per-client connection state and the standalone outbound channel. + +Always present on `Context` (never `None`), even in stateless deployments. +Holds peer info populated at `initialize` time, per-connection scratch +`state` and an `exit_stack` for teardown, and an `Outbound` for the +standalone stream (the SSE GET stream in streamable HTTP, or the single duplex +stream in stdio). + +`notify` is best-effort: it never raises. If there's no standalone channel +(stateless HTTP) or the stream has been dropped, the notification is +debug-logged and silently discarded - server-initiated notifications are +inherently advisory. `send_raw_request` *does* raise `NoBackChannelError` when +there's no channel; `ping` is the only spec-sanctioned standalone request. +""" + +import logging +from collections.abc import Mapping +from contextlib import AsyncExitStack +from typing import Any + +import anyio + +from mcp.server._typed_request import TypedServerRequestMixin +from mcp.shared.dispatcher import CallOptions, Outbound +from mcp.shared.exceptions import NoBackChannelError +from mcp.shared.peer import Meta, dump_params +from mcp.types import ClientCapabilities, Implementation, InitializeRequestParams, LoggingLevel + +__all__ = ["Connection"] + +logger = logging.getLogger(__name__) + + +def _notification_params(payload: dict[str, Any] | None, meta: Meta | None) -> dict[str, Any] | None: + if not meta: + return payload + out = dict(payload or {}) + out["_meta"] = meta + return out + + +class Connection(TypedServerRequestMixin): + """Per-client connection state and standalone-stream `Outbound`. + + Constructed by `ServerRunner` once per connection. The peer-info fields are + `None` until `initialize` completes; `initialized` is set then. In + stateless deployments the runner sets `initialized` immediately and + peer-info remains `None` (no handshake reaches a stateless connection). + """ + + def __init__(self, outbound: Outbound, *, has_standalone_channel: bool, session_id: str | None = None) -> None: + self._outbound = outbound + self.has_standalone_channel = has_standalone_channel + self.session_id: str | None = session_id + + self.client_params: InitializeRequestParams | None = None + """The full `initialize` request params; `None` before initialization.""" + self.protocol_version: str | None = None + self.initialized: anyio.Event = anyio.Event() + + self.state: dict[str, Any] = {} + """Per-connection scratch state. Handlers and middleware may read and + write freely; persists across requests on this connection.""" + + self.exit_stack: AsyncExitStack = AsyncExitStack() + """Cleanup stack unwound by `ServerRunner` when the connection closes. + + Push context managers (`await exit_stack.enter_async_context(...)`) + or callbacks (`exit_stack.push_async_callback(...)`) from handlers or + middleware to register per-connection teardown. Unwound LIFO after + `dispatcher.run()` returns, shielded from cancellation. Exceptions + raised by callbacks are logged and swallowed; they never propagate + out of `ServerRunner.run()`.""" + + @property + def client_info(self) -> Implementation | None: + """The client's `Implementation` from `initialize`; `None` before initialization.""" + return self.client_params.client_info if self.client_params is not None else None + + @property + def client_capabilities(self) -> ClientCapabilities | None: + """The client's `ClientCapabilities` from `initialize`; `None` before initialization.""" + return self.client_params.capabilities if self.client_params is not None else None + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + """Send a raw request on the standalone stream. + + Low-level `Outbound` channel. Prefer the typed `send_request` (from + `TypedServerRequestMixin`) or the convenience methods below; use this + directly only for off-spec messages. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: `has_standalone_channel` is `False`. + """ + if not self.has_standalone_channel: + raise NoBackChannelError(method) + return await self._outbound.send_raw_request(method, params, opts) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + """Send a best-effort notification on the standalone stream. + + Never raises. If there's no standalone channel or the stream is broken, + the notification is dropped and debug-logged. + """ + if not self.has_standalone_channel: + logger.debug("dropped %s: no standalone channel", method) + return + try: + await self._outbound.notify(method, params) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("dropped %s: standalone stream closed", method) + + async def ping(self, *, meta: Meta | None = None, opts: CallOptions | None = None) -> None: + """Send a `ping` request on the standalone stream. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: `has_standalone_channel` is `False`. + """ + await self.send_raw_request("ping", dump_params(None, meta), opts) + + async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None: + """Send a `notifications/message` log entry on the standalone stream. Best-effort.""" + params: dict[str, Any] = {"level": level, "data": data} + if logger is not None: + params["logger"] = logger + await self.notify("notifications/message", _notification_params(params, meta)) + + async def send_tool_list_changed(self, *, meta: Meta | None = None) -> None: + await self.notify("notifications/tools/list_changed", _notification_params(None, meta)) + + async def send_prompt_list_changed(self, *, meta: Meta | None = None) -> None: + await self.notify("notifications/prompts/list_changed", _notification_params(None, meta)) + + async def send_resource_list_changed(self, *, meta: Meta | None = None) -> None: + await self.notify("notifications/resources/list_changed", _notification_params(None, meta)) + + async def send_resource_updated(self, uri: str, *, meta: Meta | None = None) -> None: + await self.notify("notifications/resources/updated", _notification_params({"uri": uri}, meta)) + + def check_capability(self, capability: ClientCapabilities) -> bool: + """Return whether the connected client declared the given capability. + + Returns `False` if `initialize` hasn't completed yet. + """ + # TODO: redesign - mirrors v1 ServerSession.check_client_capability + # verbatim for parity. + if self.client_capabilities is None: + return False + have = self.client_capabilities + if capability.roots is not None: + if have.roots is None: + return False + if capability.roots.list_changed and not have.roots.list_changed: + return False + if capability.sampling is not None: + if have.sampling is None: + return False + if capability.sampling.context is not None and have.sampling.context is None: + return False + if capability.sampling.tools is not None and have.sampling.tools is None: + return False + if capability.elicitation is not None and have.elicitation is None: + return False + if capability.experimental is not None: + if have.experimental is None: + return False + for k, v in capability.experimental.items(): + if k not in have.experimental or have.experimental[k] != v: + return False + return True diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index bc54c5d2eb..7ca4c63c53 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -1,21 +1,147 @@ from __future__ import annotations +from collections.abc import Awaitable, Callable, Mapping from dataclasses import dataclass -from typing import Any, Generic +from typing import Any, Generic, Protocol +from pydantic import BaseModel from typing_extensions import TypeVar +from mcp.server._typed_request import TypedServerRequestMixin +from mcp.server.connection import Connection from mcp.server.session import ServerSession -from mcp.shared._context import RequestContext +from mcp.shared.context import BaseContext +from mcp.shared.dispatcher import DispatchContext from mcp.shared.message import CloseSSEStreamCallback +from mcp.shared.peer import Meta, PeerMixin +from mcp.shared.transport_context import TransportContext +from mcp.types import LoggingLevel, RequestId, RequestParamsMeta LifespanContextT = TypeVar("LifespanContextT", default=dict[str, Any]) RequestT = TypeVar("RequestT", default=Any) @dataclass(kw_only=True) -class ServerRequestContext(RequestContext[ServerSession], Generic[LifespanContextT, RequestT]): +class ServerRequestContext(Generic[LifespanContextT, RequestT]): + """Per-request context handed to lowlevel request and notification handlers. + + Built by `ServerRunner._make_context` for each inbound message. Carries the + connection-scoped `ServerSession` (server-to-client requests and + notifications), per-request metadata, and any per-message data the + transport attached (the HTTP request, SSE stream-close callbacks). + """ + + session: ServerSession lifespan_context: LifespanContextT + request_id: RequestId | None = None + meta: RequestParamsMeta | None = None request: RequestT | None = None close_sse_stream: CloseSSEStreamCallback | None = None close_standalone_sse_stream: CloseSSEStreamCallback | None = None + + +LifespanT = TypeVar("LifespanT", default=Any, covariant=True) + + +class Context(BaseContext[TransportContext], PeerMixin, TypedServerRequestMixin, Generic[LifespanT]): + """Server-side per-request context. + + Composes `BaseContext` (forwards to `DispatchContext`, satisfies `Outbound`), + `PeerMixin` (kwarg-style `sample`/`elicit_*`/`list_roots`/`ping`), + and `TypedServerRequestMixin` (typed `send_request(req) -> Result`). Adds + `lifespan` and `connection`. + + Constructed by `ServerRunner` per inbound request and handed to the user's + handler. + """ + + def __init__( + self, + dctx: DispatchContext[TransportContext], + *, + lifespan: LifespanT, + connection: Connection, + meta: RequestParamsMeta | None = None, + ) -> None: + super().__init__(dctx, meta=meta) + self._lifespan = lifespan + self._connection = connection + + @property + def lifespan(self) -> LifespanT: + """The server-wide lifespan output (what `Server(..., lifespan=...)` yielded).""" + return self._lifespan + + @property + def connection(self) -> Connection: + """The per-client `Connection` for this request's connection.""" + return self._connection + + @property + def session_id(self) -> str | None: + """The transport's session id for this connection, when one exists. + + Convenience for `ctx.connection.session_id`. `None` on stdio and + stateless HTTP. + """ + return self._connection.session_id + + @property + def headers(self) -> Mapping[str, str] | None: + """Request headers carried by this message, when the transport has them. + + Convenience for `ctx.transport.headers`. `None` on stdio. + """ + return self.transport.headers + + async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None: + """Send a request-scoped `notifications/message` log entry. + + Uses this request's back-channel (so the entry rides the request's SSE + stream in streamable HTTP), not the standalone stream - use + `ctx.connection.log(...)` for that. + """ + params: dict[str, Any] = {"level": level, "data": data} + if logger is not None: + params["logger"] = logger + if meta: + params["_meta"] = meta + await self.notify("notifications/message", params) + + +HandlerResult = BaseModel | dict[str, Any] | None +"""What a request handler (or middleware) may return. `ServerRunner` serializes +all three to a result dict.""" + +CallNext = Callable[[], Awaitable[HandlerResult]] + +_MwLifespanT = TypeVar("_MwLifespanT") + + +class ServerMiddleware(Protocol[_MwLifespanT]): + """Context-tier middleware: `(ctx, method, typed_params, call_next) -> result`. + + Runs *inside* `ServerRunner._on_request` after params validation and + context construction. Wraps registered handlers (including `ping`) but + not `initialize`, `METHOD_NOT_FOUND`, or validation failures. Listed + outermost-first on `Server.middleware`. + + `Server[L].middleware` holds `ServerMiddleware[L]`, so an app-specific + middleware sees `ctx.lifespan_context: L`. While the context is the + mutable `ServerRequestContext` dataclass it is invariant in `L`, so a + reusable middleware should be typed `ServerMiddleware[Any]` to register on + any `Server[L]`. + """ + + # TODO(maxisbey): once `_make_context` returns the (covariant) `Context[L]` + # again, restore `_MwLifespanT` to `contravariant=True` and retype `ctx` + # below to `Context[_MwLifespanT]` so reusable middleware can be + # `ServerMiddleware[object]` instead of `ServerMiddleware[Any]`. + + async def __call__( + self, + ctx: ServerRequestContext[_MwLifespanT, Any], + method: str, + params: BaseModel | None, + call_next: CallNext, + ) -> HandlerResult: ... diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 37127c5621..5650db7199 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -36,16 +36,14 @@ async def main(): from __future__ import annotations -import contextvars import logging -import warnings from collections.abc import AsyncIterator, Awaitable, Callable -from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from dataclasses import dataclass from importlib.metadata import version as importlib_version -from typing import Any, Generic, cast +from typing import Any, Generic -import anyio -from opentelemetry.trace import SpanKind, StatusCode +from pydantic import BaseModel from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.middleware.authentication import AuthenticationMiddleware @@ -58,22 +56,45 @@ async def main(): from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes from mcp.server.auth.settings import AuthSettings -from mcp.server.context import ServerRequestContext +from mcp.server.context import HandlerResult, ServerMiddleware, ServerRequestContext from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession +from mcp.server.runner import ServerRunner, otel_middleware from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings -from mcp.shared._otel import extract_trace_context, otel_span from mcp.shared._stream_protocols import ReadStream, WriteStream -from mcp.shared.exceptions import MCPError -from mcp.shared.message import ServerMessageMetadata, SessionMessage -from mcp.shared.session import RequestResponder +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.message import SessionMessage +from mcp.shared.transport_context import TransportContext logger = logging.getLogger(__name__) LifespanResultT = TypeVar("LifespanResultT", default=Any) +_ParamsT = TypeVar("_ParamsT", bound=BaseModel, default=BaseModel) + +RequestHandler = Callable[[ServerRequestContext[LifespanResultT], _ParamsT], Awaitable[HandlerResult]] +"""A registered request handler: `(ctx, params) -> result`.""" + +NotificationHandler = Callable[[ServerRequestContext[LifespanResultT], _ParamsT], Awaitable[None]] +"""A registered notification handler: `(ctx, params) -> None`.""" + + +@dataclass(frozen=True, slots=True) +class HandlerEntry(Generic[LifespanResultT]): + """A registered handler and the params model to validate incoming params against. + + Stored in `Server._request_handlers` / `_notification_handlers` and consumed + by `ServerRunner` to validate, build `Context`, and invoke. The handler's + second-argument type is erased to `Any` in storage (each entry has a + different concrete params type and `Callable` parameters are contravariant); + the precise type is recoverable via `params_type`. The correlation is + enforced at registration time by `Server.add_request_handler`. + """ + + params_type: type[BaseModel] + handler: RequestHandler[LifespanResultT, Any] + class NotificationOptions: def __init__(self, prompts_changed: bool = False, resources_changed: bool = False, tools_changed: bool = False): @@ -83,7 +104,7 @@ def __init__(self, prompts_changed: bool = False, resources_changed: bool = Fals @asynccontextmanager -async def lifespan(_: Server[LifespanResultT]) -> AsyncIterator[dict[str, Any]]: +async def lifespan(_: Server[Any]) -> AsyncIterator[dict[str, Any]]: """Default lifespan context manager that does nothing. Returns: @@ -191,53 +212,75 @@ def __init__( self.website_url = website_url self.icons = icons self.lifespan = lifespan - self._request_handlers: dict[str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]]] = {} - self._notification_handlers: dict[ - str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[None]] - ] = {} + self._request_handlers: dict[str, HandlerEntry[LifespanResultT]] = {} + self._notification_handlers: dict[str, HandlerEntry[LifespanResultT]] = {} self._session_manager: StreamableHTTPSessionManager | None = None + # Context-tier middleware consumed by `ServerRunner`. Additive; the + # existing `run()` path ignores it. + self.middleware: list[ServerMiddleware[LifespanResultT]] = [] logger.debug("Initializing server %r", name) - # Populate internal handler dicts from on_* kwargs - self._request_handlers.update( - { - method: handler - for method, handler in { - "ping": on_ping, - "prompts/list": on_list_prompts, - "prompts/get": on_get_prompt, - "resources/list": on_list_resources, - "resources/templates/list": on_list_resource_templates, - "resources/read": on_read_resource, - "resources/subscribe": on_subscribe_resource, - "resources/unsubscribe": on_unsubscribe_resource, - "tools/list": on_list_tools, - "tools/call": on_call_tool, - "logging/setLevel": on_set_logging_level, - "completion/complete": on_completion, - }.items() - if handler is not None - } - ) + _spec_requests: list[tuple[str, type[BaseModel], RequestHandler[LifespanResultT, Any] | None]] = [ + ("ping", types.RequestParams, on_ping), + ("prompts/list", types.PaginatedRequestParams, on_list_prompts), + ("prompts/get", types.GetPromptRequestParams, on_get_prompt), + ("resources/list", types.PaginatedRequestParams, on_list_resources), + ("resources/templates/list", types.PaginatedRequestParams, on_list_resource_templates), + ("resources/read", types.ReadResourceRequestParams, on_read_resource), + ("resources/subscribe", types.SubscribeRequestParams, on_subscribe_resource), + ("resources/unsubscribe", types.UnsubscribeRequestParams, on_unsubscribe_resource), + ("tools/list", types.PaginatedRequestParams, on_list_tools), + ("tools/call", types.CallToolRequestParams, on_call_tool), + ("logging/setLevel", types.SetLevelRequestParams, on_set_logging_level), + ("completion/complete", types.CompleteRequestParams, on_completion), + ] + self._request_handlers.update({m: HandlerEntry(pt, h) for m, pt, h in _spec_requests if h is not None}) + _spec_notifications: list[tuple[str, type[BaseModel], NotificationHandler[LifespanResultT, Any] | None]] = [ + ("notifications/roots/list_changed", types.NotificationParams, on_roots_list_changed), + ("notifications/progress", types.ProgressNotificationParams, on_progress), + ] self._notification_handlers.update( - { - method: handler - for method, handler in { - "notifications/roots/list_changed": on_roots_list_changed, - "notifications/progress": on_progress, - }.items() - if handler is not None - } + {m: HandlerEntry(pt, h) for m, pt, h in _spec_notifications if h is not None} ) - def _add_request_handler( + def add_request_handler( self, method: str, - handler: Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]], + params_type: type[_ParamsT], + handler: RequestHandler[LifespanResultT, _ParamsT], ) -> None: - """Add a request handler, silently replacing any existing handler for the same method.""" - self._request_handlers[method] = handler + """Register a request handler for `method`. + + `params_type` is the model incoming params are validated against + before the handler is invoked. It should subclass `RequestParams` so + `_meta` parses uniformly. Replaces any existing handler for the same + method (no collision guard against spec methods). + """ + self._request_handlers[method] = HandlerEntry(params_type, handler) + + def add_notification_handler( + self, + method: str, + params_type: type[_ParamsT], + handler: NotificationHandler[LifespanResultT, _ParamsT], + ) -> None: + """Register a notification handler for `method`. + + `params_type` should subclass `NotificationParams` so `_meta` + parses uniformly. Replaces any existing handler. + """ + self._notification_handlers[method] = HandlerEntry(params_type, handler) + + # --- ServerRegistry protocol (consumed by ServerRunner) ------------------ + + def get_request_handler(self, method: str) -> HandlerEntry[LifespanResultT] | None: + """Return the registered entry for a request method, or `None`.""" + return self._request_handlers.get(method) + + def get_notification_handler(self, method: str) -> HandlerEntry[LifespanResultT] | None: + """Return the registered entry for a notification method, or `None`.""" + return self._notification_handlers.get(method) # TODO: Rethink capabilities API. Currently capabilities are derived from registered # handlers but require NotificationOptions to be passed externally for list_changed @@ -347,167 +390,31 @@ async def run( # the initialization lifecycle, but can do so with any available node # rather than requiring initialization for each connection. stateless: bool = False, - ): - async with AsyncExitStack() as stack: - lifespan_context = await stack.enter_async_context(self.lifespan(self)) - session = await stack.enter_async_context( - ServerSession( - read_stream, - write_stream, - initialization_options, - stateless=stateless, - ) - ) - - async with anyio.create_task_group() as tg: - try: - async for message in session.incoming_messages: - logger.debug("Received message: %s", message) - - if isinstance(message, RequestResponder) and message.context is not None: - context = message.context - else: - context = contextvars.copy_context() - - context.run( - tg.start_soon, - self._handle_message, - message, - session, - lifespan_context, - raise_exceptions, - ) - finally: - # Transport closed: cancel in-flight handlers. Without this the - # TG join waits for them, and when they eventually try to - # respond they hit a closed write stream (the session's - # _receive_loop closed it when the read stream ended). - tg.cancel_scope.cancel() - - async def _handle_message( - self, - message: RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception, - session: ServerSession, - lifespan_context: LifespanResultT, - raise_exceptions: bool = False, - ): - with warnings.catch_warnings(record=True) as w: - match message: - case RequestResponder() as responder: - with responder: - await self._handle_request( - message, responder.request, session, lifespan_context, raise_exceptions - ) - case Exception(): - logger.error(f"Received exception from stream: {message}") - if raise_exceptions: - raise message - case _: - await self._handle_notification(message, session, lifespan_context) - - for warning in w: # pragma: lax no cover - logger.info("Warning: %s: %s", warning.category.__name__, warning.message) - - async def _handle_request( - self, - message: RequestResponder[types.ClientRequest, types.ServerResult], - req: types.ClientRequest, - session: ServerSession, - lifespan_context: LifespanResultT, - raise_exceptions: bool, - ): - logger.info("Processing request of type %s", type(req).__name__) - - target = getattr(req.params, "name", None) if req.params else None - span_name = f"MCP handle {req.method} {target}" if target else f"MCP handle {req.method}" - - # Extract W3C trace context from _meta (SEP-414). - meta = cast(dict[str, Any] | None, getattr(req.params, "meta", None)) if req.params else None - parent_context = extract_trace_context(meta) if meta is not None else None - - with otel_span( - span_name, - kind=SpanKind.SERVER, - attributes={"mcp.method.name": req.method, "jsonrpc.request.id": message.request_id}, - context=parent_context, - ) as span: - if handler := self._request_handlers.get(req.method): - logger.debug("Dispatching request of type %s", type(req).__name__) - - try: - # Extract request context and close_sse_stream from message metadata - request_data = None - close_sse_stream_cb = None - close_standalone_sse_stream_cb = None - if message.message_metadata is not None and isinstance( - message.message_metadata, ServerMessageMetadata - ): - request_data = message.message_metadata.request_context - close_sse_stream_cb = message.message_metadata.close_sse_stream - close_standalone_sse_stream_cb = message.message_metadata.close_standalone_sse_stream - - ctx = ServerRequestContext( - request_id=message.request_id, - meta=message.request_meta, - session=session, - lifespan_context=lifespan_context, - request=request_data, - close_sse_stream=close_sse_stream_cb, - close_standalone_sse_stream=close_standalone_sse_stream_cb, - ) - response = await handler(ctx, req.params) - except MCPError as err: - response = err.error - except anyio.get_cancelled_exc_class(): - if message.cancelled: - # Client sent CancelledNotification; responder.cancel() already - # sent an error response, so skip the duplicate. - logger.info("Request %s cancelled - duplicate response suppressed", message.request_id) - return - # Transport-close cancellation from the TG in run(); re-raise so the - # TG swallows its own cancellation. - raise - except Exception as err: - if raise_exceptions: # pragma: no cover - raise err - response = types.ErrorData(code=0, message=str(err)) - else: - response = types.ErrorData(code=types.METHOD_NOT_FOUND, message="Method not found") - - if isinstance(response, types.ErrorData) and span is not None: - span.set_status(StatusCode.ERROR, response.message) - - try: - await message.respond(response) - except (anyio.BrokenResourceError, anyio.ClosedResourceError): - # Transport closed between handler unblocking and respond. Happens - # when _receive_loop's finally wakes a handler blocked on - # send_request: the handler runs to respond() before run()'s TG - # cancel fires, but after the write stream closed. Closed if our - # end closed (_receive_loop's async-with exit); Broken if the peer - # end closed first (streamable_http terminate()). - logger.debug("Response for %s dropped - transport closed", message.request_id) - return - - logger.debug("Response sent") - - async def _handle_notification( - self, - notify: types.ClientNotification, - session: ServerSession, - lifespan_context: LifespanResultT, ) -> None: - if handler := self._notification_handlers.get(notify.method): - logger.debug("Dispatching notification of type %s", type(notify).__name__) - - try: - ctx = ServerRequestContext( - session=session, - lifespan_context=lifespan_context, - ) - await handler(ctx, notify.params) - except Exception: # pragma: no cover - logger.exception("Uncaught exception in notification handler") + async with self.lifespan(self) as lifespan_context: + dispatcher: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher( + read_stream, + write_stream, + raise_handler_exceptions=raise_exceptions, + # Handle `initialize` inline so a client that pipelines it with + # the next request (spec says SHOULD NOT, not MUST NOT) sees + # the initialized state instead of failing the init-gate. + inline_methods=frozenset({"initialize"}), + ) + runner = ServerRunner( + server=self, + dispatcher=dispatcher, + lifespan_state=lifespan_context, + init_options=initialization_options, + # Stateless HTTP has no standalone GET stream, so server-initiated + # requests on `runner.connection` must fail fast with + # `NoBackChannelError` rather than write to a channel that will + # never deliver a response. + has_standalone_channel=not stateless, + stateless=stateless, + dispatch_middleware=[otel_middleware], + ) + await runner.run() def streamable_http_app( self, diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index ec2365810e..fdb69571d8 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -607,11 +607,7 @@ async def handler( completion=result if result is not None else Completion(values=[], total=None, has_more=None), ) - # TODO(maxisbey): remove private access — completion needs post-construction - # handler registration, find a better pattern for this - self._lowlevel_server._add_request_handler( # pyright: ignore[reportPrivateUsage] - "completion/complete", handler - ) + self._lowlevel_server.add_request_handler("completion/complete", CompleteRequestParams, handler) return func return decorator diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py new file mode 100644 index 0000000000..2ec5173e84 --- /dev/null +++ b/src/mcp/server/runner.py @@ -0,0 +1,338 @@ +"""`ServerRunner` - per-connection orchestrator over a `Dispatcher`. + +`ServerRunner` is the bridge between the dispatcher layer (`on_request` / +`on_notify`, untyped dicts) and the user's handler layer (typed `Context`, +typed params). One instance per client connection. It: + +* handles the `initialize` handshake and populates `Connection` +* gates requests until initialized (`ping` exempt) +* looks up the handler in the server's registry, validates params, builds + `Context`, runs the middleware chain, returns the result dict +* drives `dispatcher.run()` and the per-connection lifespan + +`ServerRunner` holds a `Server` directly - `Server` is the registry. +""" + +from __future__ import annotations + +import logging +from collections.abc import Mapping +from dataclasses import dataclass, field +from functools import partial, reduce +from typing import TYPE_CHECKING, Any, Generic, cast, get_args + +import anyio.abc +from opentelemetry.trace import SpanKind, StatusCode +from pydantic import BaseModel, ValidationError +from typing_extensions import TypeVar + +from mcp.server.connection import Connection +from mcp.server.context import CallNext, ServerMiddleware, ServerRequestContext +from mcp.server.models import InitializationOptions +from mcp.server.session import ServerSession +from mcp.shared._otel import extract_trace_context, otel_span +from mcp.shared.dispatcher import DispatchContext, DispatchMiddleware, OnRequest +from mcp.shared.exceptions import MCPError +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.message import ServerMessageMetadata +from mcp.shared.transport_context import TransportContext +from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS +from mcp.types import ( + INVALID_PARAMS, + LATEST_PROTOCOL_VERSION, + METHOD_NOT_FOUND, + ClientRequest, + ErrorData, + Implementation, + InitializeRequestParams, + InitializeResult, + RequestParams, + client_request_adapter, +) + +if TYPE_CHECKING: + from mcp.server.lowlevel.server import Server + +__all__ = ["CallNext", "ServerMiddleware", "ServerRunner", "otel_middleware"] + +logger = logging.getLogger(__name__) + +LifespanT = TypeVar("LifespanT", default=Any) + + +_INIT_EXEMPT: frozenset[str] = frozenset({"ping"}) + +_SPEC_CLIENT_METHODS: frozenset[str] = frozenset( + cast(type[BaseModel], arm).model_fields["method"].default for arm in get_args(ClientRequest) +) +"""Method names in the spec `ClientRequest` union, derived from the +discriminator literal on each arm. Used to gate upfront validation so custom +methods registered via `add_request_handler` are not rejected.""" + + +def otel_middleware(next_on_request: OnRequest) -> OnRequest: + """Dispatch-tier middleware that wraps each request in an OpenTelemetry span. + + Mirrors the span shape of the existing `Server._handle_request`: span name + `"MCP handle []"`, `mcp.method.name` attribute, W3C + trace context extracted from `params._meta` (SEP-414), and an ERROR + status if the handler raises. + """ + + async def wrapped( + dctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + target: str | None + match params: + case {"name": str() as target}: + pass + case _: + target = None + parent: Any | None + match params: + case {"_meta": {**meta}}: + parent = extract_trace_context(meta) + case _: + parent = None + span_name = f"MCP handle {method}{f' {target}' if target else ''}" + # `otel_middleware` wraps `on_request` only, so `request_id` is always set. + attributes = {"mcp.method.name": method, "jsonrpc.request.id": str(dctx.request_id)} + with otel_span( + span_name, + kind=SpanKind.SERVER, + attributes=attributes, + context=parent, + record_exception=False, + set_status_on_exception=False, + ) as span: + try: + return await next_on_request(dctx, method, params) + except MCPError as e: + span.set_status(StatusCode.ERROR, e.error.message) + raise + except Exception as e: + span.record_exception(e) + span.set_status(StatusCode.ERROR, str(e)) + raise + + return wrapped + + +def _dump_result(result: Any) -> dict[str, Any]: + if result is None: + return {} + if isinstance(result, ErrorData): + # The existing `BaseSession._send_response` treats a handler-returned + # `ErrorData` as a JSON-RPC error, not a success result. Re-raise as + # `MCPError` so the dispatcher's exception boundary emits `JSONRPCError`. + raise MCPError(code=result.code, message=result.message, data=result.data) + if isinstance(result, BaseModel): + return result.model_dump(by_alias=True, mode="json", exclude_none=True) + if isinstance(result, dict): + return cast(dict[str, Any], result) + raise TypeError(f"handler returned {type(result).__name__}; expected BaseModel, dict, or None") + + +@dataclass +class ServerRunner(Generic[LifespanT]): + """Per-connection orchestrator. One instance per client connection.""" + + server: Server[LifespanT] + dispatcher: JSONRPCDispatcher[Any] + lifespan_state: LifespanT + has_standalone_channel: bool + init_options: InitializationOptions | None = None + """`InitializeResult` payload. Defaults to `server.create_initialization_options()`.""" + session_id: str | None = None + stateless: bool = False + dispatch_middleware: list[DispatchMiddleware] = field(default_factory=list[DispatchMiddleware]) + + connection: Connection = field(init=False) + session: ServerSession = field(init=False) + """Connection-scoped: the same instance reaches every request as `ctx.session`.""" + _initialized: bool = field(init=False) + + def __post_init__(self) -> None: + self._initialized = self.stateless + if self.init_options is None: + self.init_options = self.server.create_initialization_options() + self.connection = Connection( + self.dispatcher, has_standalone_channel=self.has_standalone_channel, session_id=self.session_id + ) + if self.stateless: + # Keep the public event in lockstep with the gate flag so a handler + # awaiting `connection.initialized` does not hang on a stateless + # connection (where no `initialize` exchange ever arrives). + self.connection.initialized.set() + self.session = ServerSession(self.dispatcher, self.connection, stateless=self.stateless) + + async def run(self, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None: + """Drive the dispatcher until the underlying channel closes. + + Composes `dispatch_middleware` over `_on_request` and hands the result + to `dispatcher.run()`. `task_status.started()` is forwarded so callers + can `await tg.start(runner.run)` and resume once the dispatcher is + ready to accept requests. Once the dispatcher exits, + `connection.exit_stack` is unwound (shielded) so any per-connection + cleanup registered by handlers or middleware runs to completion. + """ + try: + await self.dispatcher.run(self._compose_on_request(), self._on_notify, task_status=task_status) + finally: + with anyio.CancelScope(shield=True): + try: + await self.connection.exit_stack.aclose() + except Exception: + # Top-level boundary: a cleanup callback raising must not + # escape `run()` - it would crash stdio servers on a normal + # disconnect and, via raise-in-finally, mask the original + # exception from `dispatcher.run()` (including the + # CancelledError that SHTTP idle-timeout teardown checks). + logger.exception("connection exit_stack cleanup raised") + + def _compose_on_request(self) -> OnRequest: + """Wrap `_on_request` in `dispatch_middleware`, outermost-first. + + Dispatch-tier middleware sees raw `(dctx, method, params) -> dict` + and wraps everything - initialize, METHOD_NOT_FOUND, validation + failures included. `run()` calls this once and hands the result to + `dispatcher.run()`. + """ + return reduce(lambda h, mw: mw(h), reversed(self.dispatch_middleware), self._on_request) + + async def _on_request( + self, + dctx: DispatchContext[TransportContext], + method: str, + params: Mapping[str, Any] | None, + ) -> dict[str, Any]: + # TODO(maxisbey): pinned compat. `BaseSession._receive_loop` validates + # every inbound request against the spec `ClientRequest` discriminated + # union *before* handler lookup, so a spec method with malformed params + # surfaces as INVALID_PARAMS via the dispatcher's ValidationError + # boundary even when no handler is registered. v2 wanted to decouple + # the runner from the spec union; revisit once the suite's divergence + # entry is resolved. Gated on spec methods so custom methods registered + # via `add_request_handler` still route (the existing server rejects + # those too, but nothing pins that and routing them is strictly better). + if method in _SPEC_CLIENT_METHODS: + payload: dict[str, Any] = {"method": method} + if params is not None: + payload["params"] = dict(params) + client_request_adapter.validate_python(payload, by_name=False) + if method == "initialize": + return self._handle_initialize(params) + if not self._initialized and method not in _INIT_EXEMPT: + # TODO(maxisbey): pinned compat. The existing server has no + # dedicated pre-init check; the request dies in ClientRequest + # validation, so the client sees the generic invalid-params shape. + raise MCPError(code=INVALID_PARAMS, message="Invalid request parameters", data="") + entry = self.server.get_request_handler(method) + if entry is None: + raise MCPError(code=METHOD_NOT_FOUND, message="Method not found") + # ValidationError propagates; the dispatcher's exception boundary maps + # it to INVALID_PARAMS. Absent wire params reach the handler as None + # (matches the existing `Server._handle_request`, where `req.params` + # is None for optional-params requests like tools/list); the empty-dict + # validate is a required-field check so a required-params model still + # surfaces as INVALID_PARAMS rather than reaching the handler as None. + if params is None: + entry.params_type.model_validate({}, by_name=False) + typed_params = None + else: + typed_params = entry.params_type.model_validate(params, by_name=False) + ctx = self._make_context(dctx, typed_params) + call: CallNext = partial(entry.handler, ctx, typed_params) + for mw in reversed(self.server.middleware): + call = partial(mw, ctx, method, typed_params, call) + return _dump_result(await call()) + + async def _on_notify( + self, + dctx: DispatchContext[TransportContext], + method: str, + params: Mapping[str, Any] | None, + ) -> None: + if method == "notifications/initialized": + self._initialized = True + self.connection.initialized.set() + return + if not self._initialized: + logger.debug("dropped %s: received before initialization", method) + return + entry = self.server.get_notification_handler(method) + if entry is None: + logger.debug("no handler for notification %s", method) + return + # Absent wire params reach the handler as None, not an empty model + # (matches the existing `Server._handle_notification`). The empty-dict + # validate is a required-field check: a required-params model (e.g. + # ProgressNotificationParams) takes the malformed-params drop path + # instead of reaching a non-Optional handler as None. + try: + if params is None: + entry.params_type.model_validate({}, by_name=False) + typed_params = None + else: + typed_params = entry.params_type.model_validate(params, by_name=False) + except ValidationError: + logger.warning("dropped %r: malformed params", method) + return + ctx = self._make_context(dctx, typed_params) + try: + await entry.handler(ctx, typed_params) + except Exception: + # Top-level boundary: a notification handler crashing must not + # tear down the connection (it runs as a bare task in the + # dispatcher's task group; an uncaught exception would cancel + # every sibling, including the read loop and in-flight requests). + logger.exception("notification handler for %r raised", method) + + def _make_context( + self, dctx: DispatchContext[TransportContext], typed_params: BaseModel | None + ) -> ServerRequestContext[LifespanT, Any]: + meta = typed_params.meta if isinstance(typed_params, RequestParams) else None + # TODO(maxisbey): remove for Context rework. Reads the SHTTP per-request + # data off the raw `dctx.message_metadata` carrier; replace with the + # per-transport context once that lands. + md = dctx.message_metadata + if isinstance(md, ServerMessageMetadata): + request = md.request_context + close_sse_stream = md.close_sse_stream + close_standalone_sse_stream = md.close_standalone_sse_stream + else: + request = close_sse_stream = close_standalone_sse_stream = None + return ServerRequestContext( + session=self.session, + lifespan_context=self.lifespan_state, + request_id=dctx.request_id, + meta=meta, + request=request, + close_sse_stream=close_sse_stream, + close_standalone_sse_stream=close_standalone_sse_stream, + ) + + def _handle_initialize(self, params: Mapping[str, Any] | None) -> dict[str, Any]: + init = InitializeRequestParams.model_validate(params or {}, by_name=False) + self.connection.client_params = init + requested = init.protocol_version + negotiated = requested if requested in SUPPORTED_PROTOCOL_VERSIONS else LATEST_PROTOCOL_VERSION + self.connection.protocol_version = negotiated + self._initialized = True + self.connection.initialized.set() + assert self.init_options is not None + opts = self.init_options + result = InitializeResult( + protocol_version=negotiated, + capabilities=opts.capabilities, + server_info=Implementation( + name=opts.server_name, + title=opts.title, + description=opts.description, + version=opts.server_version, + website_url=opts.website_url, + icons=opts.icons, + ), + instructions=opts.instructions, + ) + return _dump_result(result) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 3fc7bbf0d3..9016f05a0c 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -1,188 +1,97 @@ -"""ServerSession Module - -This module provides the ServerSession class, which manages communication between the -server and client in the MCP (Model Context Protocol) framework. It is most commonly -used in MCP servers to interact with the client. - -Common usage pattern: -``` - async def handle_call_tool(ctx: RequestContext, params: CallToolRequestParams) -> CallToolResult: - # Check client capabilities before proceeding - if ctx.session.check_client_capability( - types.ClientCapabilities(experimental={"advanced_tools": dict()}) - ): - result = await perform_advanced_tool_operation(params.arguments) - else: - result = await perform_basic_tool_operation(params.arguments) - return result - - async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult: - if ctx.session.client_params: - return ListPromptsResult(prompts=generate_custom_prompts(ctx.session.client_params)) - return ListPromptsResult(prompts=default_prompts) - - server = Server(name, on_call_tool=handle_call_tool, on_list_prompts=handle_list_prompts) -``` - -The ServerSession class is typically used internally by the Server class and should not -be instantiated directly by users of the MCP framework. +"""`ServerSession`: server-to-client requests and notifications. + +A thin proxy over `JSONRPCDispatcher` and `Connection`. One instance per +client connection (built by `ServerRunner`). Handlers reach it as +`ctx.session` and use the typed helpers (`create_message`, `elicit_form`, +`send_log_message`, ...) to call back to the client. + +The receive-loop, initialize handling, and per-request task isolation that +used to live here are now owned by `JSONRPCDispatcher` and `ServerRunner`. """ -from enum import Enum -from typing import Any, TypeVar, overload +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, TypeVar, overload -import anyio -import anyio.lowlevel -from anyio.streams.memory import MemoryObjectReceiveStream -from pydantic import AnyUrl, TypeAdapter +from pydantic import AnyUrl, BaseModel from mcp import types -from mcp.server.models import InitializationOptions from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages -from mcp.shared._stream_protocols import ReadStream, WriteStream +from mcp.shared.dispatcher import CallOptions, ProgressFnT from mcp.shared.exceptions import StatelessModeNotSupported -from mcp.shared.message import ServerMessageMetadata, SessionMessage -from mcp.shared.session import ( - BaseSession, - RequestResponder, -) -from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.message import MessageMetadata, ServerMessageMetadata +if TYPE_CHECKING: + from mcp.server.connection import Connection -class InitializationState(Enum): - NotInitialized = 1 - Initializing = 2 - Initialized = 3 +__all__ = ["ServerSession"] +ResultT = TypeVar("ResultT", bound=BaseModel) -ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession") -ServerRequestResponder = ( - RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception -) +class ServerSession: + """Connection-scoped proxy for server-to-client requests and notifications. - -class ServerSession( - BaseSession[ - types.ServerRequest, - types.ServerNotification, - types.ServerResult, - types.ClientRequest, - types.ClientNotification, - ] -): - _initialized: InitializationState = InitializationState.NotInitialized - _client_params: types.InitializeRequestParams | None = None + `send_request` / `send_notification` model-dump their argument and forward + to the dispatcher; the typed helpers below are unchanged from the previous + implementation and only call those two methods. + """ def __init__( self, - read_stream: ReadStream[SessionMessage | Exception], - write_stream: WriteStream[SessionMessage], - init_options: InitializationOptions, + dispatcher: JSONRPCDispatcher[Any], + connection: Connection, + *, stateless: bool = False, ) -> None: - super().__init__(read_stream, write_stream) + self._dispatcher = dispatcher + self._connection = connection self._stateless = stateless - self._initialization_state = ( - InitializationState.Initialized if stateless else InitializationState.NotInitialized - ) - - self._init_options = init_options - self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[ - ServerRequestResponder - ](0) - self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose()) @property - def _receive_request_adapter(self) -> TypeAdapter[types.ClientRequest]: - return types.client_request_adapter + def client_params(self) -> types.InitializeRequestParams | None: + """The client's `initialize` request params; `None` before initialization.""" + return self._connection.client_params - @property - def _receive_notification_adapter(self) -> TypeAdapter[types.ClientNotification]: - return types.client_notification_adapter + async def send_request( + self, + request: types.ServerRequest, + result_type: type[ResultT], + request_read_timeout_seconds: float | None = None, + metadata: MessageMetadata = None, + progress_callback: ProgressFnT | None = None, + ) -> ResultT: + """Send a typed server-to-client request and validate the result. + + `metadata.related_request_id` (when supplied) routes the outgoing + message onto the originating request's response stream over + streamable HTTP. + """ + data = request.model_dump(by_alias=True, mode="json", exclude_none=True) + opts: CallOptions = {} + if request_read_timeout_seconds is not None: + opts["timeout"] = request_read_timeout_seconds + if progress_callback is not None: + opts["on_progress"] = progress_callback + related = metadata.related_request_id if isinstance(metadata, ServerMessageMetadata) else None + result = await self._dispatcher.send_raw_request( + data["method"], data.get("params"), opts or None, _related_request_id=related + ) + return result_type.model_validate(result, by_name=False) - @property - def client_params(self) -> types.InitializeRequestParams | None: - return self._client_params + async def send_notification( + self, + notification: types.ServerNotification, + related_request_id: types.RequestId | None = None, + ) -> None: + """Send a typed server-to-client notification.""" + data = notification.model_dump(by_alias=True, mode="json", exclude_none=True) + await self._dispatcher.notify(data["method"], data.get("params"), _related_request_id=related_request_id) def check_client_capability(self, capability: types.ClientCapabilities) -> bool: """Check if the client supports a specific capability.""" - if self._client_params is None: # pragma: lax no cover - return False - - client_caps = self._client_params.capabilities - - if capability.roots is not None: # pragma: lax no cover - if client_caps.roots is None: - return False - if capability.roots.list_changed and not client_caps.roots.list_changed: - return False - - if capability.sampling is not None: # pragma: lax no cover - if client_caps.sampling is None: - return False - if capability.sampling.context is not None and client_caps.sampling.context is None: - return False - if capability.sampling.tools is not None and client_caps.sampling.tools is None: - return False - - if capability.elicitation is not None and client_caps.elicitation is None: # pragma: lax no cover - return False - - if capability.experimental is not None: # pragma: lax no cover - if client_caps.experimental is None: - return False - for exp_key, exp_value in capability.experimental.items(): - if exp_key not in client_caps.experimental or client_caps.experimental[exp_key] != exp_value: - return False - - return True - - async def _receive_loop(self) -> None: - async with self._incoming_message_stream_writer: - await super()._receive_loop() - - async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]): - match responder.request: - case types.InitializeRequest(params=params): - requested_version = params.protocol_version - self._initialization_state = InitializationState.Initializing - self._client_params = params - with responder: - await responder.respond( - types.InitializeResult( - protocol_version=requested_version - if requested_version in SUPPORTED_PROTOCOL_VERSIONS - else types.LATEST_PROTOCOL_VERSION, - capabilities=self._init_options.capabilities, - server_info=types.Implementation( - name=self._init_options.server_name, - title=self._init_options.title, - description=self._init_options.description, - version=self._init_options.server_version, - website_url=self._init_options.website_url, - icons=self._init_options.icons, - ), - instructions=self._init_options.instructions, - ) - ) - self._initialization_state = InitializationState.Initialized - case types.PingRequest(): - # Ping requests are allowed at any time - pass - case _: - if self._initialization_state != InitializationState.Initialized: - raise RuntimeError("Received request before initialization was complete") - - async def _received_notification(self, notification: types.ClientNotification) -> None: - # Need this to avoid ASYNC910 - await anyio.lowlevel.checkpoint() - match notification: - case types.InitializedNotification(): - self._initialization_state = InitializationState.Initialized - case _: - if self._initialization_state != InitializationState.Initialized: # pragma: no cover - raise RuntimeError("Received notification before initialization was complete") + return self._connection.check_capability(capability) async def send_log_message( self, @@ -293,7 +202,7 @@ async def create_message( """ if self._stateless: raise StatelessModeNotSupported(method="sampling") - client_caps = self._client_params.capabilities if self._client_params else None + client_caps = self.client_params.capabilities if self.client_params else None validate_sampling_tools(client_caps, tools, tool_choice) validate_tool_use_result_messages(messages) @@ -313,7 +222,6 @@ async def create_message( ) metadata_obj = ServerMessageMetadata(related_request_id=related_request_id) - # Use different result types based on whether tools are provided if tools is not None: return await self.send_request( request=request, @@ -488,10 +396,3 @@ async def send_elicit_complete( ), related_request_id, ) - - async def _handle_incoming(self, req: ServerRequestResponder) -> None: - await self._incoming_message_stream_writer.send(req) - - @property - def incoming_messages(self) -> MemoryObjectReceiveStream[ServerRequestResponder]: - return self._incoming_message_stream_reader diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index f2f4407cea..217444793a 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -635,7 +635,10 @@ async def sse_writer(): finally: await sse_stream_reader.aclose() - except Exception as err: + except Exception as err: # pragma: lax no cover + # Reached only when something raises during POST handling outside + # the per-SSE-stream guard above; whether tests reach this depends + # on client teardown timing. logger.exception("Error handling POST request") response = self._create_error_response( f"Error handling POST request: {err}", @@ -643,7 +646,7 @@ async def sse_writer(): INTERNAL_ERROR, ) await response(scope, receive, send) - if writer: # pragma: no cover + if writer: await writer.send(Exception(err)) return # pragma: no cover diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 81350a8f24..9bcf3cb883 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -9,6 +9,7 @@ from uuid import uuid4 import anyio +import anyio.lowlevel from anyio.abc import TaskStatus from starlette.requests import Request from starlette.responses import Response @@ -139,6 +140,13 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: # Clear any remaining server instances self._server_instances.clear() self._session_owners.clear() + # The cancel above is delivered via `coro.throw()` into this task at + # the task-group join; on CPython 3.11 (gh-106749) that drops `'call'` + # trace events for the outer await chain and desyncs coverage's CTracer + # past the caller's frame. Yielding once here resumes via `.send()`, + # which re-stamps the missing `'call'` events and resyncs the tracer. + # Shielded so a pending outer cancel is not re-delivered at this point. + await anyio.lowlevel.cancel_shielded_checkpoint() async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: """Process ASGI request with proper session handling and transport setup. diff --git a/src/mcp/shared/_otel.py b/src/mcp/shared/_otel.py index 170e873a0f..553b8a0bce 100644 --- a/src/mcp/shared/_otel.py +++ b/src/mcp/shared/_otel.py @@ -20,9 +20,18 @@ def otel_span( kind: SpanKind, attributes: dict[str, Any] | None = None, context: Context | None = None, + record_exception: bool = True, + set_status_on_exception: bool = True, ) -> Iterator[Any]: """Create an OTel span.""" - with _tracer.start_as_current_span(name, kind=kind, attributes=attributes, context=context) as span: + with _tracer.start_as_current_span( + name, + kind=kind, + attributes=attributes, + context=context, + record_exception=record_exception, + set_status_on_exception=set_status_on_exception, + ) as span: yield span diff --git a/src/mcp/shared/context.py b/src/mcp/shared/context.py new file mode 100644 index 0000000000..849054dda0 --- /dev/null +++ b/src/mcp/shared/context.py @@ -0,0 +1,86 @@ +"""`BaseContext` - the user-facing per-request context. + +Composition over a `DispatchContext`: forwards the transport metadata, the +back-channel (`send_raw_request`/`notify`), progress reporting, and the cancel +event. Adds `meta` (the inbound request's `_meta` field). + +Satisfies `Outbound`, so `PeerMixin` works on it (the server-side `Context` +mixes that in directly). Shared between client and server: the server's +`Context` extends this with `lifespan`/`connection`; `ClientContext` is just an +alias. +""" + +from collections.abc import Mapping +from typing import Any, Generic + +import anyio +from typing_extensions import TypeVar + +from mcp.shared.dispatcher import CallOptions, DispatchContext +from mcp.shared.transport_context import TransportContext +from mcp.types import RequestParamsMeta + +__all__ = ["BaseContext"] + +TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext, covariant=True) + + +class BaseContext(Generic[TransportT]): + """Per-request context wrapping a `DispatchContext`. + + `ServerRunner` constructs one per inbound request and passes it to the + user's handler. + """ + + def __init__(self, dctx: DispatchContext[TransportT], meta: RequestParamsMeta | None = None) -> None: + self._dctx = dctx + self._meta = meta + + @property + def transport(self) -> TransportT: + """Transport-specific metadata for this inbound request.""" + return self._dctx.transport + + @property + def cancel_requested(self) -> anyio.Event: + """Set when the peer sends `notifications/cancelled` for this request.""" + return self._dctx.cancel_requested + + @property + def can_send_request(self) -> bool: + """Whether the back-channel can currently deliver server-initiated requests. + + `False` when the transport has no back-channel, or when the underlying + dispatch context has been closed because the inbound request finished. + """ + return self._dctx.can_send_request + + @property + def meta(self) -> RequestParamsMeta | None: + """The inbound request's `_meta` field, if present.""" + return self._meta + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + """Send a request to the peer on the back-channel. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: `can_send_request` is `False`. + """ + return await self._dctx.send_raw_request(method, params, opts) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + """Send a notification to the peer on the back-channel.""" + await self._dctx.notify(method, params) + + async def report_progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + """Report progress for this request, if the peer supplied a progress token. + + A no-op when no token was supplied. + """ + await self._dctx.progress(progress, total, message) diff --git a/src/mcp/shared/direct_dispatcher.py b/src/mcp/shared/direct_dispatcher.py new file mode 100644 index 0000000000..1b07b87d77 --- /dev/null +++ b/src/mcp/shared/direct_dispatcher.py @@ -0,0 +1,197 @@ +"""In-memory `Dispatcher` that wires two peers together with no transport. + +`DirectDispatcher` is the simplest possible `Dispatcher` implementation: a +request on one side directly invokes the other side's `on_request`. There is no +serialization, no JSON-RPC framing, and no streams. It exists to: + +* prove the `Dispatcher` Protocol is implementable without JSON-RPC +* provide a fast substrate for testing the layers above the dispatcher + (`ServerRunner`, `Context`, `Connection`) without wire-level moving parts +* embed a server in-process when the JSON-RPC overhead is unnecessary + +Unlike `JSONRPCDispatcher`, exceptions raised in a handler propagate directly +to the caller - there is no exception-to-`ErrorData` boundary here. +""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable, Mapping +from dataclasses import dataclass, field +from typing import Any + +import anyio +import anyio.abc +from pydantic import ValidationError + +from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest, ProgressFnT +from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.message import MessageMetadata +from mcp.shared.transport_context import TransportContext +from mcp.types import INTERNAL_ERROR, INVALID_PARAMS, REQUEST_TIMEOUT, RequestId + +__all__ = ["DirectDispatcher", "create_direct_dispatcher_pair"] + +DIRECT_TRANSPORT_KIND = "direct" + + +_Request = Callable[[str, Mapping[str, Any] | None, CallOptions | None], Awaitable[dict[str, Any]]] +_Notify = Callable[[str, Mapping[str, Any] | None], Awaitable[None]] + + +@dataclass +class _DirectDispatchContext: + """`DispatchContext` for an inbound request on a `DirectDispatcher`. + + The back-channel callables target the *originating* side, so a handler's + `send_raw_request` reaches the peer that made the inbound request. + """ + + transport: TransportContext + _back_request: _Request + _back_notify: _Notify + request_id: RequestId | None = None + """Always `None`: direct dispatch has no wire-level request id.""" + message_metadata: MessageMetadata = None # TODO(maxisbey): remove for Context rework + """Always `None`: in-memory dispatch attaches no transport metadata.""" + _on_progress: ProgressFnT | None = None + cancel_requested: anyio.Event = field(default_factory=anyio.Event) + + @property + def can_send_request(self) -> bool: + return self.transport.can_send_request + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + await self._back_notify(method, params) + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + if not self.can_send_request: + raise NoBackChannelError(method) + return await self._back_request(method, params, opts) + + async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + if self._on_progress is not None: + await self._on_progress(progress, total, message) + + +class DirectDispatcher: + """A `Dispatcher` that calls a peer's handlers directly, in-process. + + Two instances are wired together with `create_direct_dispatcher_pair`; each + holds a reference to the other. `send_raw_request` on one awaits the peer's + `on_request`. `run` parks until `close` is called. + """ + + def __init__(self, transport_ctx: TransportContext): + self._transport_ctx = transport_ctx + self._peer: DirectDispatcher | None = None + self._on_request: OnRequest | None = None + self._on_notify: OnNotify | None = None + self._ready = anyio.Event() + self._closed = anyio.Event() + + def connect_to(self, peer: DirectDispatcher) -> None: + self._peer = peer + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + if self._peer is None: + raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") + return await self._peer._dispatch_request(method, params, opts) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + if self._peer is None: + raise RuntimeError("DirectDispatcher has no peer; use create_direct_dispatcher_pair()") + await self._peer._dispatch_notify(method, params) + + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: + self._on_request = on_request + self._on_notify = on_notify + self._ready.set() + task_status.started() + await self._closed.wait() + + def close(self) -> None: + self._closed.set() + + def _make_context(self, on_progress: ProgressFnT | None = None) -> _DirectDispatchContext: + assert self._peer is not None + peer = self._peer + return _DirectDispatchContext( + transport=self._transport_ctx, + _back_request=lambda m, p, o: peer._dispatch_request(m, p, o), + _back_notify=lambda m, p: peer._dispatch_notify(m, p), + _on_progress=on_progress, + ) + + async def _dispatch_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None, + ) -> dict[str, Any]: + await self._ready.wait() + assert self._on_request is not None + opts = opts or {} + dctx = self._make_context(on_progress=opts.get("on_progress")) + try: + with anyio.fail_after(opts.get("timeout")): + try: + return await self._on_request(dctx, method, params) + except MCPError: + raise + except ValidationError as e: + # Same shape JSONRPCDispatcher writes, so runner-over-direct + # tests see what runner-over-JSONRPC would. + raise MCPError(code=INVALID_PARAMS, message="Invalid request parameters", data="") from e + except Exception as e: + raise MCPError(code=INTERNAL_ERROR, message=str(e)) from e + except TimeoutError: + raise MCPError( + code=REQUEST_TIMEOUT, + message=f"Timed out after {opts.get('timeout')}s waiting for {method!r}", + ) from None + + async def _dispatch_notify(self, method: str, params: Mapping[str, Any] | None) -> None: + await self._ready.wait() + assert self._on_notify is not None + dctx = self._make_context() + await self._on_notify(dctx, method, params) + + +def create_direct_dispatcher_pair( + *, + can_send_request: bool = True, + headers: Mapping[str, str] | None = None, +) -> tuple[DirectDispatcher, DirectDispatcher]: + """Create two `DirectDispatcher` instances wired to each other. + + Args: + can_send_request: Sets `TransportContext.can_send_request` on both + sides. Pass `False` to simulate a transport with no back-channel. + headers: Sets `TransportContext.headers` on both sides. + + Returns: + A `(left, right)` pair. Conventionally `left` is the client side + and `right` is the server side, but the wiring is symmetric. + """ + ctx = TransportContext(kind=DIRECT_TRANSPORT_KIND, can_send_request=can_send_request, headers=headers) + left = DirectDispatcher(ctx) + right = DirectDispatcher(ctx) + left.connect_to(right) + right.connect_to(left) + return left, right diff --git a/src/mcp/shared/dispatcher.py b/src/mcp/shared/dispatcher.py new file mode 100644 index 0000000000..9dfb24940a --- /dev/null +++ b/src/mcp/shared/dispatcher.py @@ -0,0 +1,192 @@ +"""Dispatcher Protocol - the call/return boundary between transports and handlers. + +A Dispatcher turns a duplex message channel into two things: + +* an outbound API: `send_raw_request(method, params)` and `notify(method, params)` +* an inbound pump: `run(on_request, on_notify)` that drives the receive loop + and invokes the supplied handlers for each incoming request/notification + +It is deliberately *not* MCP-aware. Method names are strings, params and +results are `dict[str, Any]`. The MCP type layer (request/result models, +capability negotiation, `Context`) sits above this; the wire encoding +(JSON-RPC, gRPC, in-process direct calls) sits below it. + +See `JSONRPCDispatcher` for the production implementation and +`DirectDispatcher` for an in-memory implementation used in tests and for +embedding a server in-process. +""" + +from collections.abc import Awaitable, Callable, Mapping +from typing import Any, Protocol, TypedDict, TypeVar, runtime_checkable + +import anyio +import anyio.abc + +from mcp.shared.message import MessageMetadata +from mcp.shared.transport_context import TransportContext +from mcp.types import RequestId + +__all__ = [ + "CallOptions", + "DispatchContext", + "DispatchMiddleware", + "Dispatcher", + "OnNotify", + "OnRequest", + "Outbound", + "ProgressFnT", +] + +TransportT_co = TypeVar("TransportT_co", bound=TransportContext, covariant=True) + + +class ProgressFnT(Protocol): + """Callback invoked when a progress notification arrives for a pending request.""" + + async def __call__(self, progress: float, total: float | None, message: str | None) -> None: ... + + +class CallOptions(TypedDict, total=False): + """Per-call options for `Outbound.send_raw_request`. + + All keys are optional. Dispatchers ignore keys they do not understand. + """ + + timeout: float + """Seconds to wait for a result before raising and sending `notifications/cancelled`.""" + + on_progress: ProgressFnT + """Receive `notifications/progress` updates for this request.""" + + resumption_token: str + """Opaque token to resume a previously interrupted request (transport-dependent).""" + + on_resumption_token: Callable[[str], Awaitable[None]] + """Receive a resumption token when the transport issues one.""" + + +@runtime_checkable +class Outbound(Protocol): + """Anything that can send requests and notifications to the peer. + + Both `Dispatcher` (top-level outbound) and `DispatchContext` (back-channel + during an inbound request) extend this. The MCP type layer (`PeerMixin`, + `Connection`, `Context`) builds typed `send_request` / convenience methods + on top of this raw channel. + """ + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + """Send a request and await its raw result dict. + + Raises: + MCPError: If the peer responded with an error, or the handler + raised. Implementations normalize all handler exceptions to + `MCPError` so callers see a single exception type. + """ + ... + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + """Send a fire-and-forget notification.""" + ... + + +class DispatchContext(Outbound, Protocol[TransportT_co]): + """Per-request context handed to `on_request` / `on_notify`. + + Carries the transport metadata for the inbound message and provides the + back-channel for sending requests/notifications to the peer while handling + it. `send_raw_request` raises `NoBackChannelError` if `can_send_request` + is `False`. + """ + + @property + def transport(self) -> TransportT_co: + """Transport-specific metadata for this inbound message.""" + ... + + @property + def can_send_request(self) -> bool: + """Whether the back-channel can currently deliver server-initiated requests. + + `False` when the transport has no back-channel, or when this context has + been closed (the inbound request finished). `send_raw_request` raises + `NoBackChannelError` exactly when this is `False`. + """ + ... + + @property + def request_id(self) -> RequestId | None: + """The id of the inbound request, or `None` for a notification. + + For JSON-RPC this is the wire `id` field. Handlers thread it through + as `related_request_id` on outbound notifications so HTTP transports + can route them onto the originating request's response stream. + """ + ... + + @property + def message_metadata(self) -> MessageMetadata: + """The metadata the transport attached to this inbound message, if any. + + This is `SessionMessage.metadata` passed through verbatim: HTTP + transports attach `ServerMessageMetadata` (the HTTP request, SSE + stream-close callbacks); stdio and in-memory dispatch attach nothing. + Tied to the `SessionMessage` wire format - goes away when transports + stop delivering messages that way. + """ + # TODO(maxisbey): remove for context rework + ... + + @property + def cancel_requested(self) -> anyio.Event: + """Set when the peer sends `notifications/cancelled` for this request.""" + ... + + async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + """Report progress for the inbound request, if the peer supplied a progress token. + + A no-op when no token was supplied. + """ + ... + + +OnRequest = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[dict[str, Any]]] +"""Handler for inbound requests: `(ctx, method, params) -> result`. Raise `MCPError` to send an error response.""" + +OnNotify = Callable[[DispatchContext[TransportContext], str, Mapping[str, Any] | None], Awaitable[None]] +"""Handler for inbound notifications: `(ctx, method, params)`.""" + +DispatchMiddleware = Callable[[OnRequest], OnRequest] +"""Wraps an `OnRequest` to produce another `OnRequest`. Applied outermost-first.""" + + +class Dispatcher(Outbound, Protocol[TransportT_co]): + """A duplex request/notification channel with call-return semantics. + + Implementations own correlation of outbound requests to inbound results, the + receive loop, per-request concurrency, and cancellation/progress wiring. + """ + + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: + """Drive the receive loop until the underlying channel closes. + + Each inbound request is dispatched to `on_request` in its own task; + the returned dict (or raised `MCPError`) is sent back as the response. + Inbound notifications go to `on_notify`. + + `task_status.started()` is called once the dispatcher is ready to + accept `send_request`/`notify` calls, so callers can use + `await tg.start(dispatcher.run, on_request, on_notify)`. + """ + ... diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index f153ea319d..bb4cfc0d00 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -2,7 +2,7 @@ from typing import Any, cast -from mcp.types import URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData, JSONRPCError +from mcp.types import INVALID_REQUEST, URL_ELICITATION_REQUIRED, ElicitRequestURLParams, ErrorData, JSONRPCError class MCPError(Exception): @@ -41,6 +41,25 @@ def __str__(self) -> str: return self.message +class NoBackChannelError(MCPError): + """Raised when sending a server-initiated request over a transport that cannot deliver it. + + Stateless HTTP and JSON-response-mode HTTP have no channel for the server to + push requests (sampling, elicitation, roots/list) to the client. This is + raised by `DispatchContext.send_raw_request` when `can_send_request` is + `False`, and serializes to an `INVALID_REQUEST` error response. + """ + + def __init__(self, method: str): + super().__init__( + code=INVALID_REQUEST, + message=( + f"Cannot send {method!r}: this transport context has no back-channel for server-initiated requests." + ), + ) + self.method = method + + class StatelessModeNotSupported(RuntimeError): """Raised when attempting to use a method that is not supported in stateless mode. diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py new file mode 100644 index 0000000000..55eba7486a --- /dev/null +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -0,0 +1,676 @@ +"""JSON-RPC `Dispatcher` implementation. + +Consumes the existing `SessionMessage`-based stream contract that all current +transports (stdio, SSE, streamable HTTP) speak. Owns request-id correlation, +the receive loop, per-request task isolation, cancellation/progress wiring, and +the single exception-to-wire boundary. + +The MCP type layer (`ServerRunner`, `Context`, `Client`) sits above this and +sees only `(ctx, method, params) -> dict`. Transports sit below and see only +`SessionMessage` reads/writes. + +The dispatcher is *mostly* MCP-agnostic - methods/params are opaque strings and +dicts - but it intercepts `notifications/cancelled` and +`notifications/progress` because request correlation, cancellation and +progress are exactly the wiring this layer exists to provide. Those few wire +shapes are extracted with structural `match` patterns (no casts, no +`mcp.types` model coupling); a malformed payload simply fails to match and +the correlation is skipped. +""" + +from __future__ import annotations + +import contextvars +import logging +from collections.abc import Awaitable, Callable, Mapping +from dataclasses import dataclass, field +from typing import Any, Generic, Literal, TypeVar, cast, overload + +import anyio +import anyio.abc +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from opentelemetry.trace import SpanKind +from pydantic import ValidationError + +from mcp.shared._otel import inject_trace_context, otel_span +from mcp.shared._stream_protocols import ReadStream, WriteStream +from mcp.shared.dispatcher import CallOptions, Dispatcher, OnNotify, OnRequest, ProgressFnT +from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.message import ( + ClientMessageMetadata, + MessageMetadata, + ServerMessageMetadata, + SessionMessage, +) +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + CONNECTION_CLOSED, + INVALID_PARAMS, + REQUEST_CANCELLED, + REQUEST_TIMEOUT, + ErrorData, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + ProgressToken, + RequestId, +) + +__all__ = ["JSONRPCDispatcher"] + +logger = logging.getLogger(__name__) + +TransportT = TypeVar("TransportT", bound=TransportContext) + +PeerCancelMode = Literal["interrupt", "signal"] +"""How inbound `notifications/cancelled` is applied to a running handler. + +`"interrupt"` (default) cancels the handler's scope. `"signal"` only sets +`ctx.cancel_requested` and lets the handler observe it cooperatively. +""" + + +def _coerce_id(request_id: RequestId) -> RequestId: + """Coerce a string request ID to int when it's a valid int literal. + + `_allocate_id` only ever produces `int` keys for `_pending`, but a peer + may echo the ID back as a JSON string. The TypeScript SDK and `BaseSession` + both perform this coercion at lookup time so the response still correlates. + """ + if isinstance(request_id, str): + try: + return int(request_id) + except ValueError: + pass + return request_id + + +@dataclass(slots=True) +class _Pending: + """An outbound request awaiting its response.""" + + send: MemoryObjectSendStream[dict[str, Any] | ErrorData] + receive: MemoryObjectReceiveStream[dict[str, Any] | ErrorData] + on_progress: ProgressFnT | None = None + + +@dataclass(slots=True) +class _InFlight(Generic[TransportT]): + """An inbound request currently being handled.""" + + scope: anyio.CancelScope + dctx: _JSONRPCDispatchContext[TransportT] + + +@dataclass +class _JSONRPCDispatchContext(Generic[TransportT]): + """Concrete `DispatchContext` produced for each inbound JSON-RPC message.""" + + transport: TransportT + _dispatcher: JSONRPCDispatcher[TransportT] + _request_id: RequestId | None + message_metadata: MessageMetadata = None # TODO(maxisbey): remove for Context rework + """The transport-attached `SessionMessage.metadata` for this inbound message. + + Carries `ServerMessageMetadata` (HTTP request, SSE stream-close callbacks) + that the server lifts onto its request context. `None` for transports + that attach nothing. + """ + _progress_token: ProgressToken | None = None + _closed: bool = False + cancel_requested: anyio.Event = field(default_factory=anyio.Event) + + @property + def request_id(self) -> RequestId | None: + return self._request_id + + @property + def can_send_request(self) -> bool: + return self.transport.can_send_request and not self._closed + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + if self._closed: + logger.debug("dropped %s: dispatch context closed", method) + return + await self._dispatcher.notify(method, params, _related_request_id=self._request_id) + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + if not self.can_send_request: + raise NoBackChannelError(method) + return await self._dispatcher.send_raw_request(method, params, opts, _related_request_id=self._request_id) + + async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + if self._progress_token is None: + return + params: dict[str, Any] = {"progressToken": self._progress_token, "progress": progress} + if total is not None: + params["total"] = total + if message is not None: + params["message"] = message + await self.notify("notifications/progress", params) + + def close(self) -> None: + self._closed = True + + +def _default_transport_builder(_meta: MessageMetadata) -> TransportContext: + return TransportContext(kind="jsonrpc", can_send_request=True) + + +def _shielded_progress(fn: ProgressFnT) -> ProgressFnT: + """Wrap a user progress callback so it can't crash the dispatcher. + + The callback runs as a bare task in the dispatcher's task group; an + uncaught exception would cancel every sibling (the read loop and all + in-flight requests). Swallow and log instead, matching the previous + receive-loop's behavior. + """ + + async def _wrapped(progress: float, total: float | None, message: str | None) -> None: + try: + await fn(progress, total, message) + except Exception: + logger.exception("progress callback raised") + + return _wrapped + + +def _outbound_metadata(related_request_id: RequestId | None, opts: CallOptions | None) -> MessageMetadata: + """Choose the `SessionMessage.metadata` for an outgoing request/notification. + + `ServerMessageMetadata` tags a server-to-client message with the inbound + request it belongs to (so streamable-HTTP can route it onto that request's + SSE stream). `ClientMessageMetadata` carries resumption hints to the + client transport. `None` is the common case. + """ + if related_request_id is not None: + return ServerMessageMetadata(related_request_id=related_request_id) + if opts: + token = opts.get("resumption_token") + on_token = opts.get("on_resumption_token") + if token is not None or on_token is not None: + return ClientMessageMetadata(resumption_token=token, on_resumption_token_update=on_token) + return None + + +class JSONRPCDispatcher(Dispatcher[TransportT]): + """`Dispatcher` over the existing `SessionMessage` stream contract. + + Inherits the `Dispatcher` Protocol explicitly so pyright checks + conformance at the class definition rather than at first use. + """ + + @overload + def __init__( + self: JSONRPCDispatcher[TransportContext], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], + *, + peer_cancel_mode: PeerCancelMode = "interrupt", + raise_handler_exceptions: bool = False, + inline_methods: frozenset[str] = frozenset(), + ) -> None: ... + @overload + def __init__( + self, + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], + *, + transport_builder: Callable[[MessageMetadata], TransportT], + peer_cancel_mode: PeerCancelMode = "interrupt", + raise_handler_exceptions: bool = False, + inline_methods: frozenset[str] = frozenset(), + ) -> None: ... + def __init__( + self, + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], + *, + transport_builder: Callable[[MessageMetadata], TransportT] | None = None, + peer_cancel_mode: PeerCancelMode = "interrupt", + raise_handler_exceptions: bool = False, + inline_methods: frozenset[str] = frozenset(), + ) -> None: + self._read_stream = read_stream + self._write_stream = write_stream + # The overloads guarantee that when `transport_builder` is omitted, + # `TransportT` is `TransportContext`, so the default is type-correct; + # pyright can't see across overloads, hence the cast. + self._transport_builder = cast( + "Callable[[MessageMetadata], TransportT]", + transport_builder or _default_transport_builder, + ) + self._peer_cancel_mode: PeerCancelMode = peer_cancel_mode + self._raise_handler_exceptions = raise_handler_exceptions + self._inline_methods = inline_methods + """Request methods handled inline in the read loop (awaited before the + next message is dequeued) instead of spawned concurrently. Use for + methods whose side effects must be observable to the next message, + e.g. `initialize`, so a pipelined follow-up sees the initialized state. + Only suitable for handlers that complete quickly, since inline handling + blocks dequeuing; a handler that awaits the peer (`send_raw_request`) + while inline will deadlock because the parked read loop cannot dequeue + the response.""" + + self._next_id = 0 + self._pending: dict[RequestId, _Pending] = {} + self._in_flight: dict[RequestId, _InFlight[TransportT]] = {} + self._tg: anyio.abc.TaskGroup | None = None + self._running = False + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + *, + _related_request_id: RequestId | None = None, + ) -> dict[str, Any]: + """Send a JSON-RPC request and await its response. + + `_related_request_id` is set only by `_JSONRPCDispatchContext` when a + handler makes a server-to-client request mid-flight; it routes the + outgoing message onto the correct per-request SSE stream (SHTTP) via + `ServerMessageMetadata`. Top-level callers leave it `None`. + + Raises: + MCPError: The peer responded with a JSON-RPC error; or + `REQUEST_TIMEOUT` if `opts["timeout"]` elapsed; or + `CONNECTION_CLOSED` if the dispatcher shut down while + awaiting the response. + RuntimeError: Called before `run()` has started or after it has + finished. + """ + if not self._running: + raise RuntimeError("JSONRPCDispatcher.send_raw_request called before run() / after close") + opts = opts or {} + request_id = self._allocate_id() + out_params = dict(params) if params is not None else {} + out_meta = dict(out_params.get("_meta") or {}) + on_progress = opts.get("on_progress") + if on_progress is not None: + # The caller wants progress updates. The spec mechanism is: include + # `_meta.progressToken` on the request; the peer echoes that token on + # any `notifications/progress` it sends. We use the request id as the + # token so the receive loop can find this `_Pending.on_progress` by + # `_pending[token]` without a second lookup table. + out_meta["progressToken"] = request_id + out_params["_meta"] = out_meta + + # buffer=1: at most one outcome is ever delivered. A `WouldBlock` from + # `_resolve_pending`/`_fan_out_closed` means the waiter already has an + # outcome and dropping the late/redundant signal is correct. buffer=0 + # is unsafe - there's a window between registering `_pending[id]` and + # parking in `receive()` where a close signal would be lost. + send, receive = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) + pending = _Pending(send=send, receive=receive, on_progress=on_progress) + self._pending[request_id] = pending + + metadata = _outbound_metadata(_related_request_id, opts) + target = out_params.get("name") + span_name = f"MCP send {method}{f' {target}' if isinstance(target, str) else ''}" + # TODO(maxisbey): the otel span + inject below mirror + # BaseSession.send_request for parity. They belong in an outbound + # middleware (symmetric with otel_middleware on the inbound side) once + # that seam exists; the dispatcher should not own otel. + try: + with otel_span( + span_name, + kind=SpanKind.CLIENT, + attributes={"mcp.method.name": method, "jsonrpc.request.id": request_id}, + ): + # Inject W3C trace context into _meta (SEP-414). With a no-op + # tracer this writes nothing, but `_meta` itself is still + # present on the wire (and the interaction suite pins that). + inject_trace_context(out_meta) + msg = JSONRPCRequest(jsonrpc="2.0", id=request_id, method=method, params=out_params) + await self._write(msg, metadata) + with anyio.fail_after(opts.get("timeout")): + outcome = await receive.receive() + except TimeoutError: + # Spec-recommended courtesy: tell the peer we've given up so it can + # stop work and free resources. v1's BaseSession.send_request does + # NOT do this; it's new behaviour. + await self._cancel_outbound(request_id, f"timed out after {opts.get('timeout')}s", _related_request_id) + raise MCPError(code=REQUEST_TIMEOUT, message=f"Request {method!r} timed out") from None + except anyio.get_cancelled_exc_class(): + # Our caller's scope was cancelled. We're already inside a cancelled + # scope, so any bare `await` here re-raises immediately - shield to + # let the courtesy cancel notification go out before we propagate. + with anyio.CancelScope(shield=True): + await self._cancel_outbound(request_id, "caller cancelled", _related_request_id) + raise + finally: + # Always remove the waiter, even on cancel/timeout, so a late + # response from the peer (race) hits a closed stream and is dropped + # in `_dispatch` rather than leaking. + self._pending.pop(request_id, None) + send.close() + receive.close() + + if isinstance(outcome, ErrorData): + raise MCPError(code=outcome.code, message=outcome.message, data=outcome.data) + return outcome + + async def notify( + self, + method: str, + params: Mapping[str, Any] | None, + *, + _related_request_id: RequestId | None = None, + ) -> None: + msg = JSONRPCNotification(jsonrpc="2.0", method=method, params=dict(params) if params is not None else None) + await self._write(msg, _outbound_metadata(_related_request_id, None)) + + async def run( + self, + on_request: OnRequest, + on_notify: OnNotify, + *, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, + ) -> None: + """Drive the receive loop until the read stream closes. + + Each inbound request is handled in its own task in an internal task + group; `task_status.started()` fires once that group is open, so + `await tg.start(dispatcher.run, ...)` resumes when `send_raw_request` + is usable. + """ + try: + async with anyio.create_task_group() as tg: + self._tg = tg + self._running = True + task_status.started() + try: + async with self._read_stream, self._write_stream: + try: + async for item in self._read_stream: + # Duck-typed: `_context_streams.ContextReceiveStream` + # exposes `.last_context` (the sender's contextvars + # snapshot per message). Plain memory streams don't. + sender_ctx: contextvars.Context | None = getattr( + self._read_stream, "last_context", None + ) + await self._dispatch(item, on_request, on_notify, sender_ctx) + except anyio.ClosedResourceError: + # The transport closed our receive end and we looped + # back to `__anext__` on the now-closed stream + # (stateless SHTTP teardown). Same as EOF. + logger.debug("read stream closed by transport; treating as EOF") + # Read stream EOF: wake any blocked `send_raw_request` waiters + # (callers outside this task group) with CONNECTION_CLOSED. + self._running = False + self._fan_out_closed() + finally: + # Transport closed: cancel in-flight handlers. Without this + # the task-group join waits for them, and a handler that + # outlives its caller (its request timed out client-side, or + # the client disconnected mid-call) would keep `run()` from + # returning forever. Same behaviour as `Server.run()` before + # the dispatcher rework. + tg.cancel_scope.cancel() + finally: + # Covers the cancel/crash paths where the inline fan-out above is + # never reached. Idempotent. + self._running = False + self._tg = None + self._fan_out_closed() + + async def _dispatch( + self, + item: SessionMessage | Exception, + on_request: OnRequest, + on_notify: OnNotify, + sender_ctx: contextvars.Context | None, + ) -> None: + """Route one inbound item. + + Everything here is `send_nowait` or `_spawn`; the only `await` is for + `inline_methods` requests, which deliberately block dequeuing until + handled. Any other `await` would let one slow message head-of-line + block the entire read loop. + """ + if isinstance(item, Exception): + logger.debug("transport yielded exception: %r", item) + return + metadata = item.metadata + msg = item.message + match msg: + case JSONRPCRequest(): + await self._dispatch_request(msg, metadata, on_request, sender_ctx) + case JSONRPCNotification(): + self._dispatch_notification(msg, metadata, on_notify, sender_ctx) + case JSONRPCResponse(): + self._resolve_pending(msg.id, msg.result) + case JSONRPCError(): # pragma: no branch + # `id` may be None per JSON-RPC (parse error before id known). + # The match is exhaustive over JSONRPCMessage; the no-match arc + # on this final case is unreachable. + self._resolve_pending(msg.id, msg.error) + + async def _dispatch_request( + self, + req: JSONRPCRequest, + metadata: MessageMetadata, + on_request: OnRequest, + sender_ctx: contextvars.Context | None, + ) -> None: + progress_token: ProgressToken | None + match req.params: + case {"_meta": {"progressToken": str() | int() as progress_token}}: + pass + case _: + progress_token = None + transport_ctx = self._transport_builder(metadata) + dctx = _JSONRPCDispatchContext( + transport=transport_ctx, + _dispatcher=self, + _request_id=req.id, + message_metadata=metadata, + _progress_token=progress_token, + ) + scope = anyio.CancelScope() + self._in_flight[req.id] = _InFlight(scope=scope, dctx=dctx) + if req.method in self._inline_methods: + # Spawn (so `sender_ctx` applies, matching the concurrent path) but + # park the read loop until the handler returns; that's the inline + # ordering guarantee. Because the read loop is parked, a handler + # that awaits the peer here (e.g. `dctx.send_raw_request`) will + # deadlock: the response can never be dequeued. + done = anyio.Event() + + async def _run_inline() -> None: + try: + await self._handle_request(req, dctx, scope, on_request) + finally: + done.set() + + self._spawn(_run_inline, sender_ctx=sender_ctx) + await done.wait() + else: + self._spawn(self._handle_request, req, dctx, scope, on_request, sender_ctx=sender_ctx) + + def _dispatch_notification( + self, + msg: JSONRPCNotification, + metadata: MessageMetadata, + on_notify: OnNotify, + sender_ctx: contextvars.Context | None, + ) -> None: + if msg.method == "notifications/cancelled": + match msg.params: + case {"requestId": str() | int() as rid} if (in_flight := self._in_flight.get(rid)) is not None: + in_flight.dctx.cancel_requested.set() + if self._peer_cancel_mode == "interrupt": + in_flight.scope.cancel() + case _: + pass + return + if msg.method == "notifications/progress": + match msg.params: + case {"progressToken": str() | int() as token, "progress": int() | float() as progress} if ( + pending := self._pending.get(_coerce_id(token)) + ) is not None and pending.on_progress is not None: + total = msg.params.get("total") + message = msg.params.get("message") + self._spawn( + _shielded_progress(pending.on_progress), + float(progress), + float(total) if isinstance(total, int | float) else None, + message if isinstance(message, str) else None, + sender_ctx=sender_ctx, + ) + case _: + pass + # fall through: progress is also teed to on_notify + transport_ctx = self._transport_builder(metadata) + dctx = _JSONRPCDispatchContext( + transport=transport_ctx, _dispatcher=self, _request_id=None, message_metadata=metadata + ) + self._spawn(on_notify, dctx, msg.method, msg.params, sender_ctx=sender_ctx) + + def _resolve_pending(self, request_id: RequestId | None, outcome: dict[str, Any] | ErrorData) -> None: + pending = self._pending.get(_coerce_id(request_id)) if request_id is not None else None + if pending is None: + logger.debug("dropping response for unknown/late request id %r", request_id) + return + try: + pending.send.send_nowait(outcome) + except (anyio.WouldBlock, anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("waiter for request id %r already gone", request_id) + + def _spawn( + self, + fn: Callable[..., Awaitable[Any]], + *args: object, + sender_ctx: contextvars.Context | None, + ) -> None: + """Schedule `fn(*args)` in the run() task group, propagating the sender's contextvars. + + ASGI middleware (auth, OTel) sets contextvars on the request task that + wrote into the read stream. `Context.run(tg.start_soon, ...)` makes + the spawned handler inherit *that* context instead of the receive + loop's, so `auth_context_var` and OTel spans survive. + """ + assert self._tg is not None + if sender_ctx is not None: + sender_ctx.run(self._tg.start_soon, fn, *args) + else: + self._tg.start_soon(fn, *args) + + def _fan_out_closed(self) -> None: + """Wake every pending `send_raw_request` waiter with `CONNECTION_CLOSED`. + + Synchronous (uses `send_nowait`) because it's called from `finally` + which may be inside a cancelled scope. Idempotent. + """ + closed = ErrorData(code=CONNECTION_CLOSED, message="connection closed") + for pending in self._pending.values(): + try: + pending.send.send_nowait(closed) + except (anyio.WouldBlock, anyio.BrokenResourceError, anyio.ClosedResourceError): + pass + self._pending.clear() + + async def _handle_request( + self, + req: JSONRPCRequest, + dctx: _JSONRPCDispatchContext[TransportT], + scope: anyio.CancelScope, + on_request: OnRequest, + ) -> None: + """Run `on_request` for one inbound request and write its response. + + This is the single exception-to-wire boundary: handler exceptions are + caught here and serialized to `JSONRPCError`. Nothing above this in + the stack constructs wire errors. + """ + try: + with scope: + try: + result = await on_request(dctx, req.method, req.params) + finally: + # Handler done: close the back-channel (detached work that + # later calls `dctx.send_raw_request()` should see + # `NoBackChannelError`) and drop from `_in_flight` so a + # late `notifications/cancelled` is a no-op rather than + # racing the result write below. No checkpoint between + # handler return and the pop, so the cancel can't + # interleave there. + dctx.close() + self._in_flight.pop(req.id, None) + await self._write_result(req.id, result) + if scope.cancel_called: + # Peer-cancel: `_dispatch_notification` cancelled this scope + # while the handler was running. anyio swallows a scope's *own* + # cancel at __exit__, so execution lands here rather than the + # `except cancelled` arm below. + # TODO(maxisbey): spec says SHOULD NOT respond after cancel. + # The existing server always has, so match that for now. + await self._write_error(req.id, ErrorData(code=0, message="Request cancelled")) + except anyio.get_cancelled_exc_class(): + # Outer-cancel: run()'s task group is shutting down. Any bare + # `await` here re-raises immediately, so shield the courtesy write. + with anyio.CancelScope(shield=True): + await self._write_error(req.id, ErrorData(code=REQUEST_CANCELLED, message="Request cancelled")) + raise + except MCPError as e: + await self._write_error(req.id, e.error) + except ValidationError: + # TODO(maxisbey): data="" is pinned compat with the existing + # server (which never leaked pydantic error text onto the wire). + # Consider putting the validation detail in `data` once the + # interaction suite's divergence entry is resolved. + await self._write_error( + req.id, ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") + ) + except Exception as e: + logger.exception("handler for %r raised", req.method) + # TODO(maxisbey): code=0 is pinned compat with the existing + # server's `_handle_request`. JSON-RPC says INTERNAL_ERROR + # (-32603); revisit once the suite's divergence entry is resolved. + await self._write_error(req.id, ErrorData(code=0, message=str(e))) + if self._raise_handler_exceptions: + raise + finally: + self._in_flight.pop(req.id, None) + + def _allocate_id(self) -> int: + self._next_id += 1 + return self._next_id + + async def _write(self, message: JSONRPCMessage, metadata: MessageMetadata = None) -> None: + await self._write_stream.send(SessionMessage(message=message, metadata=metadata)) + + async def _write_result(self, request_id: RequestId, result: dict[str, Any]) -> None: + try: + await self._write(JSONRPCResponse(jsonrpc="2.0", id=request_id, result=result)) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("dropped result for %r: write stream closed", request_id) + + async def _write_error(self, request_id: RequestId, error: ErrorData) -> None: + try: + await self._write(JSONRPCError(jsonrpc="2.0", id=request_id, error=error)) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logger.debug("dropped error for %r: write stream closed", request_id) + + async def _cancel_outbound(self, request_id: RequestId, reason: str, related_request_id: RequestId | None) -> None: + # Thread `related_request_id` so streamable-HTTP routes the cancel onto + # the same per-request SSE stream as the request it cancels; without it + # the notification falls through to the standalone GET stream and is + # dropped when no GET stream is open. + try: + await self.notify( + "notifications/cancelled", + {"requestId": request_id, "reason": reason}, + _related_request_id=related_request_id, + ) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + pass diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index 468590d095..b20bfa793e 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -5,6 +5,8 @@ from collections.abc import AsyncGenerator from contextlib import asynccontextmanager +import anyio.lowlevel + from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams from mcp.shared.message import SessionMessage @@ -28,3 +30,11 @@ async def create_client_server_memory_streams() -> AsyncGenerator[tuple[MessageS async with server_to_client_receive, client_to_server_send, client_to_server_receive, server_to_client_send: yield client_streams, server_streams + # Callers routinely cancel a task group wrapped around these streams just + # before this context exits; that cancel is delivered via `coro.throw()`, + # which on CPython 3.11 (gh-106749) drops `'call'` trace events for the + # outer await chain and desyncs coverage's CTracer past the caller's frame. + # Closing memory streams never suspends, so this is the last chance to + # resync: yielding once resumes via `.send()`, which re-stamps the missing + # `'call'` events. Shielded so a pending outer cancel is not re-delivered. + await anyio.lowlevel.cancel_shielded_checkpoint() diff --git a/src/mcp/shared/peer.py b/src/mcp/shared/peer.py new file mode 100644 index 0000000000..25ec112b02 --- /dev/null +++ b/src/mcp/shared/peer.py @@ -0,0 +1,216 @@ +"""Typed MCP request sugar over an `Outbound`. + +`PeerMixin` defines the server-to-client request methods (sampling, elicitation, +roots, ping) once. Any class that satisfies `Outbound` (i.e. has +`send_raw_request` and `notify`) can mix it in and get the typed methods for +free - `Context`, `Connection`, `Client`, or the bare `Peer` wrapper below. + +The mixin does no capability gating: it builds the params, calls +`self.send_raw_request(method, params)`, and parses the result into the typed +model. Gating (and `NoBackChannelError`) is the host's `send_raw_request`'s job. +""" + +from collections.abc import Mapping +from typing import Any, overload + +from pydantic import BaseModel + +from mcp.shared.dispatcher import CallOptions, Outbound +from mcp.types import ( + CreateMessageRequestParams, + CreateMessageResult, + CreateMessageResultWithTools, + ElicitRequestedSchema, + ElicitRequestFormParams, + ElicitRequestURLParams, + ElicitResult, + IncludeContext, + ListRootsResult, + ModelPreferences, + SamplingMessage, + Tool, + ToolChoice, +) + +__all__ = ["Meta", "Peer", "PeerMixin", "dump_params"] + +Meta = dict[str, Any] +"""Type alias for the `_meta` field carried on request/notification params.""" + + +def dump_params(model: BaseModel | None, meta: Meta | None = None) -> dict[str, Any] | None: + """Serialize a params model to a wire dict, merging `meta` into `_meta`. + + Shared by `PeerMixin`, `Connection`, and `TypedServerRequestMixin` so every + typed convenience method gets the same `_meta` handling. `meta` keys take + precedence over any `_meta` already present on the model. + """ + out = model.model_dump(by_alias=True, mode="json", exclude_none=True) if model is not None else None + if meta: + out = dict(out or {}) + out["_meta"] = {**out.get("_meta", {}), **meta} + return out + + +class PeerMixin: + """Typed server-to-client request methods. + + Each method constrains `self` to `Outbound` so the mixin can be applied + to anything with `send_raw_request`/`notify` - pyright checks the host + class structurally at the call site. + """ + + @overload + async def sample( + self: Outbound, + messages: list[SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: ModelPreferences | None = None, + tools: None = None, + tool_choice: ToolChoice | None = None, + meta: Meta | None = None, + opts: CallOptions | None = None, + ) -> CreateMessageResult: ... + @overload + async def sample( + self: Outbound, + messages: list[SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: ModelPreferences | None = None, + tools: list[Tool], + tool_choice: ToolChoice | None = None, + meta: Meta | None = None, + opts: CallOptions | None = None, + ) -> CreateMessageResultWithTools: ... + async def sample( + self: Outbound, + messages: list[SamplingMessage], + *, + max_tokens: int, + system_prompt: str | None = None, + include_context: IncludeContext | None = None, + temperature: float | None = None, + stop_sequences: list[str] | None = None, + metadata: dict[str, Any] | None = None, + model_preferences: ModelPreferences | None = None, + tools: list[Tool] | None = None, + tool_choice: ToolChoice | None = None, + meta: Meta | None = None, + opts: CallOptions | None = None, + ) -> CreateMessageResult | CreateMessageResultWithTools: + """Send a `sampling/createMessage` request to the peer. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: The host's transport context has no + back-channel for server-initiated requests. + """ + params = CreateMessageRequestParams( + messages=messages, + system_prompt=system_prompt, + include_context=include_context, + temperature=temperature, + max_tokens=max_tokens, + stop_sequences=stop_sequences, + metadata=metadata, + model_preferences=model_preferences, + tools=tools, + tool_choice=tool_choice, + ) + result = await self.send_raw_request("sampling/createMessage", dump_params(params, meta), opts) + if tools is not None: + return CreateMessageResultWithTools.model_validate(result, by_name=False) + return CreateMessageResult.model_validate(result, by_name=False) + + async def elicit_form( + self: Outbound, + message: str, + requested_schema: ElicitRequestedSchema, + *, + meta: Meta | None = None, + opts: CallOptions | None = None, + ) -> ElicitResult: + """Send a form-mode `elicitation/create` request. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + """ + params = ElicitRequestFormParams(message=message, requested_schema=requested_schema) + result = await self.send_raw_request("elicitation/create", dump_params(params, meta), opts) + return ElicitResult.model_validate(result, by_name=False) + + async def elicit_url( + self: Outbound, + message: str, + url: str, + elicitation_id: str, + *, + meta: Meta | None = None, + opts: CallOptions | None = None, + ) -> ElicitResult: + """Send a URL-mode `elicitation/create` request. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + """ + params = ElicitRequestURLParams(message=message, url=url, elicitation_id=elicitation_id) + result = await self.send_raw_request("elicitation/create", dump_params(params, meta), opts) + return ElicitResult.model_validate(result, by_name=False) + + async def list_roots( + self: Outbound, *, meta: Meta | None = None, opts: CallOptions | None = None + ) -> ListRootsResult: + """Send a `roots/list` request. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + """ + result = await self.send_raw_request("roots/list", dump_params(None, meta), opts) + return ListRootsResult.model_validate(result, by_name=False) + + async def ping(self: Outbound, *, meta: Meta | None = None, opts: CallOptions | None = None) -> None: + """Send a `ping` request and ignore the result. + + Raises: + MCPError: The peer responded with an error. + NoBackChannelError: No back-channel for server-initiated requests. + """ + await self.send_raw_request("ping", dump_params(None, meta), opts) + + +class Peer(PeerMixin): + """Standalone wrapper that gives any `Outbound` the `PeerMixin` sugar. + + `Context` and `Connection` mix `PeerMixin` in directly; use `Peer` when + you have a bare dispatcher (or any `Outbound`) and want the typed methods + without writing your own host class. + """ + + def __init__(self, outbound: Outbound) -> None: + self._outbound = outbound + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + return await self._outbound.send_raw_request(method, params, opts) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + await self._outbound.notify(method, params) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index ea5d8833bd..afed6d54f1 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -2,12 +2,12 @@ import contextvars import logging -from collections.abc import Callable from contextlib import AsyncExitStack from types import TracebackType from typing import Any, Generic, Protocol, TypeVar import anyio +import anyio.lowlevel from anyio.streams.memory import MemoryObjectSendStream from opentelemetry.trace import SpanKind from pydantic import BaseModel, TypeAdapter @@ -80,7 +80,6 @@ def __init__( request_meta: RequestParamsMeta | None, request: ReceiveRequestT, session: BaseSession[SendRequestT, SendNotificationT, SendResultT, ReceiveRequestT, ReceiveNotificationT], - on_complete: Callable[[RequestResponder[ReceiveRequestT, SendResultT]], Any], message_metadata: MessageMetadata = None, context: contextvars.Context | None = None, ) -> None: @@ -91,15 +90,10 @@ def __init__( self.context = context self._session = session self._completed = False - self._cancel_scope = anyio.CancelScope() - self._on_complete = on_complete self._entered = False # Track if we're in a context manager def __enter__(self) -> RequestResponder[ReceiveRequestT, SendResultT]: - """Enter the context manager, enabling request cancellation tracking.""" self._entered = True - self._cancel_scope = anyio.CancelScope() - self._cancel_scope.__enter__() return self def __exit__( @@ -108,15 +102,7 @@ def __exit__( exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: - """Exit the context manager, performing cleanup and notifying completion.""" - try: - if self._completed: - self._on_complete(self) - finally: - self._entered = False - if not self._cancel_scope: # pragma: no cover - raise RuntimeError("No active cancel scope") - self._cancel_scope.__exit__(exc_type, exc_val, exc_tb) + self._entered = False async def respond(self, response: SendResultT | ErrorData) -> None: """Send a response for this request. @@ -130,37 +116,11 @@ async def respond(self, response: SendResultT | ErrorData) -> None: if not self._entered: # pragma: no cover raise RuntimeError("RequestResponder must be used as a context manager") assert not self._completed, "Request already responded to" - - if not self.cancelled: # pragma: no branch - self._completed = True - - await self._session._send_response( # type: ignore[reportPrivateUsage] - request_id=self.request_id, response=response - ) - - async def cancel(self) -> None: - """Cancel this request and mark it as completed.""" - if not self._entered: # pragma: no cover - raise RuntimeError("RequestResponder must be used as a context manager") - if not self._cancel_scope: # pragma: no cover - raise RuntimeError("No active cancel scope") - - self._cancel_scope.cancel() - self._completed = True # Mark as completed so it's removed from in_flight - # Send an error response to indicate cancellation + self._completed = True await self._session._send_response( # type: ignore[reportPrivateUsage] - request_id=self.request_id, - response=ErrorData(code=0, message="Request cancelled"), + request_id=self.request_id, response=response ) - @property - def in_flight(self) -> bool: # pragma: no cover - return not self._completed and not self.cancelled - - @property - def cancelled(self) -> bool: - return self._cancel_scope.cancel_called - class BaseSession( Generic[ @@ -180,7 +140,6 @@ class BaseSession( _response_streams: dict[RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]] _request_id: int - _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] _progress_callbacks: dict[RequestId, ProgressFnT] def __init__( @@ -195,7 +154,6 @@ def __init__( self._response_streams = {} self._request_id = 0 self._session_read_timeout_seconds = read_timeout_seconds - self._in_flight = {} self._progress_callbacks = {} self._exit_stack = AsyncExitStack() @@ -216,7 +174,15 @@ async def __aexit__( # would be very surprising behavior), so make sure to cancel the tasks # in the task group. self._task_group.cancel_scope.cancel() - return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + result = await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + # The cancel above is delivered via `coro.throw()` into this task; on + # CPython 3.11 (gh-106749) that drops `'call'` trace events for the + # outer await chain and desyncs coverage's CTracer past the caller's + # frame. Yielding once here resumes via `.send()`, which re-stamps the + # missing `'call'` events and resyncs the tracer. Shielded so a pending + # outer cancel is not re-delivered at this point. + await anyio.lowlevel.cancel_shielded_checkpoint() + return result async def send_request( self, @@ -347,15 +313,10 @@ async def _handle_session_message(message: SessionMessage) -> None: request_meta=validated_request.params.meta if validated_request.params else None, request=validated_request, session=self, - on_complete=lambda r: self._in_flight.pop(r.request_id, None), message_metadata=message.metadata, context=sender_context, ) - self._in_flight[responder.request_id] = responder await self._received_request(responder) - - if not responder._completed: # type: ignore[reportPrivateUsage] - await self._handle_incoming(responder) except Exception: # For request validation errors, send a proper JSON-RPC error # response instead of crashing the server @@ -375,33 +336,36 @@ async def _handle_session_message(message: SessionMessage) -> None: message.message.model_dump(by_alias=True, mode="json", exclude_none=True), by_name=False, ) - # Handle cancellation notifications if isinstance(notification, CancelledNotification): - cancelled_id = notification.params.request_id - if cancelled_id in self._in_flight: # pragma: no branch - await self._in_flight[cancelled_id].cancel() - else: - # Handle progress notifications callback - if isinstance(notification, ProgressNotification): - progress_token = notification.params.progress_token - # If there is a progress callback for this token, - # call it with the progress information - if progress_token in self._progress_callbacks: - callback = self._progress_callbacks[progress_token] - try: - await callback( - notification.params.progress, - notification.params.total, - notification.params.message, - ) - except Exception: - logging.exception("Progress callback raised an exception") - await self._received_notification(notification) - await self._handle_incoming(notification) + # ClientSession runs server-initiated requests + # inline in this loop, so by the time a peer + # cancellation is read there is nothing left to + # cancel. Consume it here so message_handler + # keeps the contract it had before the + # dispatcher swap removed _in_flight. + return + # Handle progress notifications callback + if isinstance(notification, ProgressNotification): + progress_token = notification.params.progress_token + # If there is a progress callback for this token, + # call it with the progress information + if progress_token in self._progress_callbacks: + callback = self._progress_callbacks[progress_token] + try: + await callback( + notification.params.progress, + notification.params.total, + notification.params.message, + ) + except Exception: + logging.exception("Progress callback raised an exception") + await self._received_notification(notification) + await self._handle_incoming(notification) except Exception: # For other validation errors, log and continue - logging.warning( # pragma: no cover - f"Failed to validate notification:. Message was: {message.message}", + logging.warning( + "Failed to validate notification: %s", + message.message, exc_info=True, ) else: # Response or error diff --git a/src/mcp/shared/transport_context.py b/src/mcp/shared/transport_context.py new file mode 100644 index 0000000000..55e5f6bc5f --- /dev/null +++ b/src/mcp/shared/transport_context.py @@ -0,0 +1,38 @@ +"""Transport-specific metadata attached to each inbound message. + +`TransportContext` is the base; each transport defines its own subclass with +whatever fields make sense (HTTP request id, ASGI scope, stdio process handle, +etc.). The dispatcher passes it through opaquely; only the layers above the +dispatcher (`ServerRunner`, `Context`, user handlers) read its concrete fields. +""" + +from collections.abc import Mapping +from dataclasses import dataclass + +__all__ = ["TransportContext"] + + +@dataclass(kw_only=True, frozen=True) +class TransportContext: + """Base transport metadata for an inbound message. + + Subclass per transport and add fields as needed. Instances are immutable. + """ + + kind: str + """Short identifier for the transport (e.g. `"stdio"`, `"streamable-http"`).""" + + can_send_request: bool + """Whether the transport can deliver server-initiated requests to the peer. + + `False` for stateless HTTP and HTTP with JSON response mode; `True` for + stdio, SSE, and stateful streamable HTTP. When `False`, + `DispatchContext.send_raw_request` raises `NoBackChannelError`. + """ + + headers: Mapping[str, str] | None = None + """Request headers carried by this message, when the transport has them. + + Populated by HTTP-based transports; `None` on stdio. Handlers should + None-check before use. + """ diff --git a/src/mcp/types/__init__.py b/src/mcp/types/__init__.py index b2d537fb70..cb49ff29db 100644 --- a/src/mcp/types/__init__.py +++ b/src/mcp/types/__init__.py @@ -152,6 +152,7 @@ INVALID_REQUEST, METHOD_NOT_FOUND, PARSE_ERROR, + REQUEST_CANCELLED, REQUEST_TIMEOUT, URL_ELICITATION_REQUIRED, ErrorData, @@ -319,6 +320,7 @@ "INVALID_REQUEST", "METHOD_NOT_FOUND", "PARSE_ERROR", + "REQUEST_CANCELLED", "REQUEST_TIMEOUT", "URL_ELICITATION_REQUIRED", "ErrorData", diff --git a/src/mcp/types/jsonrpc.py b/src/mcp/types/jsonrpc.py index 84304a37c1..14743c33b0 100644 --- a/src/mcp/types/jsonrpc.py +++ b/src/mcp/types/jsonrpc.py @@ -43,6 +43,7 @@ class JSONRPCResponse(BaseModel): # SDK error codes CONNECTION_CLOSED = -32000 REQUEST_TIMEOUT = -32001 +REQUEST_CANCELLED = -32002 # Standard JSON-RPC error codes PARSE_ERROR = -32700 diff --git a/tests/client/test_session.py b/tests/client/test_session.py index f25c964f03..28d212d007 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -1,6 +1,11 @@ from __future__ import annotations +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any + import anyio +import anyio.streams.memory import pytest from mcp import types @@ -10,12 +15,14 @@ from mcp.shared.session import RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types import ( + INVALID_PARAMS, LATEST_PROTOCOL_VERSION, CallToolResult, Implementation, InitializedNotification, InitializeRequest, InitializeResult, + JSONRPCError, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, @@ -26,6 +33,29 @@ client_request_adapter, ) +_SendToClient = anyio.streams.memory.MemoryObjectSendStream[SessionMessage | Exception] +_RecvFromClient = anyio.streams.memory.MemoryObjectReceiveStream[SessionMessage] + + +@asynccontextmanager +async def raw_client_session( + **kwargs: Any, +) -> AsyncIterator[tuple[ClientSession, _SendToClient, _RecvFromClient]]: + """Yield `(session, send_to_client, recv_from_client)` with the receive loop running. + + `send_to_client` accepts `SessionMessage | Exception` so tests can inject + transport-level exceptions. No initialize handshake is performed. + """ + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage](32) + async with ClientSession(s2c_recv, c2s_send, **kwargs) as session: + try: + with anyio.fail_after(5): + yield session, s2c_send, c2s_recv + finally: + s2c_send.close() + c2s_recv.close() + @pytest.mark.anyio async def test_client_session_initialize(): @@ -705,3 +735,129 @@ async def mock_server(): await session.initialize() await session.call_tool(name=mocked_tool.name, arguments={"foo": "bar"}, meta=meta) + + +@pytest.mark.anyio +async def test_receive_loop_answers_malformed_inbound_request_with_invalid_params(): + """A request that fails ServerRequest validation gets an INVALID_PARAMS error response.""" + async with raw_client_session() as (_session, to_client, from_client): + await to_client.send( + SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=7, method="sampling/createMessage", params={"broken": 1})) + ) + out = await from_client.receive() + assert isinstance(out.message, JSONRPCError) + assert out.message.id == 7 + assert out.message.error.code == INVALID_PARAMS + + +@pytest.mark.anyio +async def test_receive_loop_answers_invalid_params_when_sampling_callback_raises(): + """Same boundary catches exceptions from the request handler itself.""" + + async def boom(ctx: object, params: object) -> types.CreateMessageResult: + raise RuntimeError("sampling boom") + + params = types.CreateMessageRequestParams( + messages=[types.SamplingMessage(role="user", content=types.TextContent(type="text", text="hi"))], + max_tokens=10, + ).model_dump(by_alias=True, mode="json", exclude_none=True) + async with raw_client_session(sampling_callback=boom) as (_session, to_client, from_client): + await to_client.send( + SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=8, method="sampling/createMessage", params=params)) + ) + out = await from_client.receive() + assert isinstance(out.message, JSONRPCError) + assert out.message.error.code == INVALID_PARAMS + + +@pytest.mark.anyio +async def test_receive_loop_logs_and_drops_malformed_notification(caplog: pytest.LogCaptureFixture): + """A notification that fails ServerNotification validation is logged and dropped.""" + seen: list[object] = [] + delivered = anyio.Event() + + async def handler(msg: object) -> None: + seen.append(msg) + delivered.set() + + async with raw_client_session(message_handler=handler) as (_session, to_client, _): + await to_client.send(SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="not/a/spec/notification"))) + # Follow with a valid notification so we know the loop is still alive. + await to_client.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/tools/list_changed")) + ) + await delivered.wait() + assert isinstance(seen[0], types.ToolListChangedNotification) + assert "Failed to validate notification" in caplog.text + + +@pytest.mark.anyio +async def test_receive_loop_forwards_transport_exception_to_message_handler(): + seen: list[object] = [] + delivered = anyio.Event() + + async def handler(msg: object) -> None: + seen.append(msg) + delivered.set() + + async with raw_client_session(message_handler=handler) as (_session, to_client, _): + exc = ValueError("bad bytes") + await to_client.send(exc) + await delivered.wait() + assert seen == [exc] + + +@pytest.mark.anyio +async def test_receive_loop_consumes_server_cancelled_without_reaching_message_handler(): + """A server-sent notifications/cancelled is swallowed, matching the pre-swap contract. + + The server dispatcher now emits this on sampling/elicitation timeout, but + ClientSession has no in-flight tracking to act on it, so surfacing it would + only break user handlers that exhaustively match ServerNotification. + """ + seen: list[object] = [] + delivered = anyio.Event() + + async def handler(msg: object) -> None: + seen.append(msg) + delivered.set() + + async with raw_client_session(message_handler=handler) as (_session, to_client, _): + await to_client.send( + SessionMessage( + JSONRPCNotification( + jsonrpc="2.0", method="notifications/cancelled", params={"requestId": 1, "reason": "timed out"} + ) + ) + ) + # Follow with a notification that does reach the handler so we can + # assert ordering deterministically. + await to_client.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/tools/list_changed")) + ) + await delivered.wait() + assert len(seen) == 1 + assert isinstance(seen[0], types.ToolListChangedNotification) + + +@pytest.mark.anyio +async def test_receive_loop_swallows_progress_callback_exception(caplog: pytest.LogCaptureFixture): + delivered = anyio.Event() + + async def boom(progress: float, total: float | None, message: str | None) -> None: + raise RuntimeError("progress boom") + + async def handler(msg: object) -> None: + delivered.set() + + async with raw_client_session(message_handler=handler) as (session, to_client, _): + # Register the callback under a known token without sending a request. + session._progress_callbacks[42] = boom # pyright: ignore[reportPrivateUsage] + params = {"progressToken": 42, "progress": 0.5} + await to_client.send( + SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/progress", params=params)) + ) + # The progress notification also reaches the message handler after the + # callback runs, so this fires once the callback's exception is handled. + await delivered.wait() + assert "Progress callback raised an exception" in caplog.text diff --git a/tests/client/transports/test_memory.py b/tests/client/transports/test_memory.py index c8fc41fd5d..8baee128b5 100644 --- a/tests/client/transports/test_memory.py +++ b/tests/client/transports/test_memory.py @@ -1,8 +1,15 @@ """Tests for InMemoryTransport.""" +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any + +import anyio +import anyio.lowlevel import pytest from mcp import Client, types +from mcp.client import _memory from mcp.client._memory import InMemoryTransport from mcp.server import Server, ServerRequestContext from mcp.server.mcpserver import MCPServer @@ -95,3 +102,47 @@ async def test_raise_exceptions(mcpserver_server: MCPServer): transport = InMemoryTransport(mcpserver_server, raise_exceptions=True) async with transport as (read_stream, _write_stream): assert read_stream is not None + + +async def test_aexit_with_well_behaved_lifespan_runs_teardown_without_cancel(): + """A lifespan that finishes promptly on EOF should run to completion. + + The transport closes the streams first and waits for the server to exit + naturally, so teardown observes no cancellation. + """ + teardown_ran = anyio.Event() + + @asynccontextmanager + async def lifespan(_: Server[Any]) -> AsyncIterator[dict[str, Any]]: + yield {} + await anyio.lowlevel.checkpoint() + teardown_ran.set() + + server = Server(name="test_server", lifespan=lifespan) + with anyio.fail_after(5): + async with InMemoryTransport(server): + pass + assert teardown_ran.is_set() + + +async def test_aexit_with_blocking_lifespan_is_bounded(monkeypatch: pytest.MonkeyPatch): + """A lifespan that never returns must not hang `__aexit__` forever. + + After EOFing the server the transport waits `SERVER_SHUTDOWN_GRACE` for a + natural exit, then cancels the server task as a backstop so the + task-group join completes. + """ + monkeypatch.setattr(_memory, "SERVER_SHUTDOWN_GRACE", 0.05) + teardown_started = anyio.Event() + + @asynccontextmanager + async def blocking_lifespan(_: Server[Any]) -> AsyncIterator[dict[str, Any]]: + yield {} + teardown_started.set() + await anyio.Event().wait() + + server = Server(name="test_server", lifespan=blocking_lifespan) + with anyio.fail_after(5): + async with InMemoryTransport(server): + pass + assert teardown_started.is_set() diff --git a/tests/conftest.py b/tests/conftest.py index af7e479932..2278c9939e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,44 @@ +import os +from collections.abc import Iterator + import pytest +# OpenTelemetry's `set_tracer_provider` is set-once per process, so the suite +# uses a single span-capture mechanism: logfire's `capfire` fixture (its +# `configure()` swaps span processors on repeat calls rather than re-setting +# the provider). Logfire's default `distributed_tracing=None` emits a +# RuntimeWarning + diagnostic span when incoming W3C trace context is +# extracted; several tests exercise that propagation deliberately, so opt in +# suite-wide. Set before logfire is imported anywhere. +os.environ.setdefault("LOGFIRE_DISTRIBUTED_TRACING", "true") + +import opentelemetry.trace # noqa: E402 (env var must be set before logfire import below) +from logfire.testing import CaptureLogfire # noqa: E402 + +import mcp.shared._otel # noqa: E402 + @pytest.fixture def anyio_backend(): return "asyncio" + + +@pytest.fixture(name="capfire") +def _capfire_isolated(capfire: CaptureLogfire) -> Iterator[CaptureLogfire]: + """Override of logfire's `capfire` that scopes the MCP tracer to the test. + + `capfire` installs a real tracer provider, and logfire's proxy machinery + mutates the cached `mcp.shared._otel._tracer` to delegate to it for the + rest of the process. Without isolation, every subsequent test in the same + worker would emit real spans, and `send_raw_request` would inject a real + `traceparent` into outbound `_meta`, breaking the interaction-suite + snapshots that pin `_meta={}` under a no-op tracer. + + Setup points `_tracer` at the now-live provider so MCP spans record; + teardown replaces it with a `NoOpTracer`. + """ + mcp.shared._otel._tracer = opentelemetry.trace.get_tracer_provider().get_tracer("mcp-python-sdk") + try: + yield capfire + finally: + mcp.shared._otel._tracer = opentelemetry.trace.NoOpTracer() diff --git a/tests/issues/test_malformed_input.py b/tests/issues/test_malformed_input.py deleted file mode 100644 index da586f3098..0000000000 --- a/tests/issues/test_malformed_input.py +++ /dev/null @@ -1,151 +0,0 @@ -# Claude Debug -"""Test for HackerOne vulnerability report #3156202 - malformed input DOS.""" - -import anyio -import pytest - -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession -from mcp.shared.message import SessionMessage -from mcp.types import INVALID_PARAMS, JSONRPCError, JSONRPCMessage, JSONRPCRequest, ServerCapabilities - - -@pytest.mark.anyio -async def test_malformed_initialize_request_does_not_crash_server(): - """Test that malformed initialize requests return proper error responses - instead of crashing the server (HackerOne #3156202). - """ - # Create in-memory streams for testing - read_send_stream, read_receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](10) - write_send_stream, write_receive_stream = anyio.create_memory_object_stream[SessionMessage](10) - - try: - # Create a malformed initialize request (missing required params field) - malformed_request = JSONRPCRequest( - jsonrpc="2.0", - id="f20fe86132ed4cd197f89a7134de5685", - method="initialize", - # params=None # Missing required params field - ) - - # Wrap in session message - request_message = SessionMessage(message=malformed_request) - - # Start a server session - async with ServerSession( - read_stream=read_receive_stream, - write_stream=write_send_stream, - init_options=InitializationOptions( - server_name="test_server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), - ): - # Send the malformed request - await read_send_stream.send(request_message) - - # Give the session time to process the request - await anyio.sleep(0.1) - - # Check that we received an error response instead of a crash - try: - response_message = write_receive_stream.receive_nowait() - response = response_message.message - - # Verify it's a proper JSON-RPC error response - assert isinstance(response, JSONRPCError) - assert response.jsonrpc == "2.0" - assert response.id == "f20fe86132ed4cd197f89a7134de5685" - assert response.error.code == INVALID_PARAMS - assert "Invalid request parameters" in response.error.message - - # Verify the session is still alive and can handle more requests - # Send another malformed request to confirm server stability - another_malformed_request = JSONRPCRequest( - jsonrpc="2.0", - id="test_id_2", - method="tools/call", - # params=None # Missing required params - ) - another_request_message = SessionMessage(message=another_malformed_request) - - await read_send_stream.send(another_request_message) - await anyio.sleep(0.1) - - # Should get another error response, not a crash - second_response_message = write_receive_stream.receive_nowait() - second_response = second_response_message.message - - assert isinstance(second_response, JSONRPCError) - assert second_response.id == "test_id_2" - assert second_response.error.code == INVALID_PARAMS - - except anyio.WouldBlock: # pragma: no cover - pytest.fail("No response received - server likely crashed") - finally: # pragma: lax no cover - # Close all streams to ensure proper cleanup - await read_send_stream.aclose() - await write_send_stream.aclose() - await read_receive_stream.aclose() - await write_receive_stream.aclose() - - -@pytest.mark.anyio -async def test_multiple_concurrent_malformed_requests(): - """Test that multiple concurrent malformed requests don't crash the server.""" - # Create in-memory streams for testing - read_send_stream, read_receive_stream = anyio.create_memory_object_stream[SessionMessage | Exception](100) - write_send_stream, write_receive_stream = anyio.create_memory_object_stream[SessionMessage](100) - - try: - # Start a server session - async with ServerSession( - read_stream=read_receive_stream, - write_stream=write_send_stream, - init_options=InitializationOptions( - server_name="test_server", - server_version="1.0.0", - capabilities=ServerCapabilities(), - ), - ): - # Send multiple malformed requests concurrently - malformed_requests: list[SessionMessage] = [] - for i in range(10): - malformed_request = JSONRPCRequest( - jsonrpc="2.0", - id=f"malformed_{i}", - method="initialize", - # params=None # Missing required params - ) - request_message = SessionMessage(message=malformed_request) - malformed_requests.append(request_message) - - # Send all requests - for request in malformed_requests: - await read_send_stream.send(request) - - # Give time to process - await anyio.sleep(0.2) - - # Verify we get error responses for all requests - error_responses: list[JSONRPCMessage] = [] - try: - while True: - response_message = write_receive_stream.receive_nowait() - error_responses.append(response_message.message) - except anyio.WouldBlock: - pass # No more messages - - # Should have received 10 error responses - assert len(error_responses) == 10 - - for i, response in enumerate(error_responses): - assert isinstance(response, JSONRPCError) - assert response.id == f"malformed_{i}" - assert response.error.code == INVALID_PARAMS - finally: # pragma: lax no cover - # Close all streams to ensure proper cleanup - await read_send_stream.aclose() - await write_send_stream.aclose() - await read_receive_stream.aclose() - await write_receive_stream.aclose() diff --git a/tests/server/conftest.py b/tests/server/conftest.py new file mode 100644 index 0000000000..e0fa8ee9b0 --- /dev/null +++ b/tests/server/conftest.py @@ -0,0 +1,46 @@ +"""Shared fixtures for server-side tests.""" + +from collections.abc import Iterator + +import pytest +from logfire.testing import CaptureLogfire, TestExporter +from opentelemetry.sdk.trace import ReadableSpan + + +class SpanCapture: + """Thin adapter over logfire's `TestExporter` for asserting on MCP spans. + + `finished()` returns the raw `ReadableSpan` objects emitted by the + `mcp-python-sdk` instrumentation scope, filtered to exclude logfire's + synthetic `pending_span` markers, so tests can assert directly on + `.name`, `.kind`, `.status`, `.attributes`, `.parent`, `.events`. + """ + + def __init__(self, exporter: TestExporter) -> None: + self._exporter = exporter + + def clear(self) -> None: + self._exporter.clear() + + def finished(self) -> list[ReadableSpan]: + return [ + s + for s in self._exporter.exported_spans + if s.instrumentation_scope is not None + and s.instrumentation_scope.name == "mcp-python-sdk" + and not (s.attributes and s.attributes.get("logfire.span_type") == "pending_span") + ] + + +@pytest.fixture +def spans(capfire: CaptureLogfire) -> Iterator[SpanCapture]: + """In-memory MCP span capture, cleared before and after each test. + + Backed by the project-level `capfire` override (see `tests/conftest.py`), + which scopes `mcp.shared._otel._tracer` to the test so the real tracer + doesn't leak into later tests in the same worker. + """ + capture = SpanCapture(capfire.exporter) + capture.clear() + yield capture + capture.clear() diff --git a/tests/server/test_connection.py b/tests/server/test_connection.py new file mode 100644 index 0000000000..b8378574ef --- /dev/null +++ b/tests/server/test_connection.py @@ -0,0 +1,256 @@ +"""Tests for `Connection`. + +`Connection` wraps an `Outbound` (the standalone stream). Its `notify` is +best-effort (never raises); `send_raw_request` is gated on +`has_standalone_channel`. Tested with a stub `Outbound` so we can assert wire +shape and inject failures. +""" + +import logging +from collections.abc import Mapping +from typing import Any + +import anyio +import pytest + +from mcp.server.connection import Connection +from mcp.shared.dispatcher import CallOptions +from mcp.shared.exceptions import NoBackChannelError +from mcp.types import ( + LATEST_PROTOCOL_VERSION, + ClientCapabilities, + CreateMessageRequest, + CreateMessageRequestParams, + ElicitationCapability, + EmptyResult, + Implementation, + InitializeRequestParams, + ListRootsRequest, + ListRootsResult, + PingRequest, + RootsCapability, + SamplingCapability, + SamplingContextCapability, + SamplingToolsCapability, +) + + +def _client_params(capabilities: ClientCapabilities) -> InitializeRequestParams: + return InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=capabilities, + client_info=Implementation(name="t", version="0"), + ) + + +class StubOutbound: + def __init__( + self, *, result: dict[str, Any] | None = None, raise_on_send: type[BaseException] | None = None + ) -> None: + self.requests: list[tuple[str, Mapping[str, Any] | None]] = [] + self.notifications: list[tuple[str, Mapping[str, Any] | None]] = [] + self._result = result if result is not None else {} + self._raise_on_send = raise_on_send + + async def send_raw_request( + self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None + ) -> dict[str, Any]: + self.requests.append((method, params)) + return self._result + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + if self._raise_on_send is not None: + raise self._raise_on_send() + self.notifications.append((method, params)) + + +@pytest.mark.anyio +async def test_connection_notify_forwards_to_outbound(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.notify("notifications/message", {"level": "info", "data": "hi"}) + assert out.notifications == [("notifications/message", {"level": "info", "data": "hi"})] + + +@pytest.mark.anyio +async def test_connection_notify_swallows_broken_stream_and_debug_logs(caplog: pytest.LogCaptureFixture): + caplog.set_level(logging.DEBUG, logger="mcp.server.connection") + out = StubOutbound(raise_on_send=anyio.BrokenResourceError) + conn = Connection(out, has_standalone_channel=True) + await conn.notify("notifications/message", {"data": "x"}) # must not raise + assert "stream closed" in caplog.text.lower() + + +@pytest.mark.anyio +async def test_connection_notify_drops_when_no_standalone_channel(caplog: pytest.LogCaptureFixture): + caplog.set_level(logging.DEBUG, logger="mcp.server.connection") + out = StubOutbound() + conn = Connection(out, has_standalone_channel=False) + await conn.notify("notifications/message", {"data": "x"}) # must not raise + assert out.notifications == [] + assert "no standalone channel" in caplog.text.lower() + + +@pytest.mark.anyio +async def test_connection_send_raw_request_raises_nobackchannel_when_no_standalone_channel(): + conn = Connection(StubOutbound(), has_standalone_channel=False) + with pytest.raises(NoBackChannelError): + await conn.send_raw_request("ping", None) + + +@pytest.mark.anyio +async def test_connection_send_raw_request_forwards_when_standalone_channel_present(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + result = await conn.send_raw_request("ping", None) + assert out.requests == [("ping", None)] + assert result == {} + + +@pytest.mark.anyio +async def test_connection_send_request_with_spec_type_infers_result_type(): + out = StubOutbound(result={"roots": [{"uri": "file:///ws"}]}) + conn = Connection(out, has_standalone_channel=True) + result = await conn.send_request(ListRootsRequest()) + method, _ = out.requests[0] + assert method == "roots/list" + assert isinstance(result, ListRootsResult) + assert str(result.roots[0].uri) == "file:///ws" + + +@pytest.mark.anyio +async def test_connection_send_request_validates_result_alias_only(): + """Peer results validate alias-only; a snake_case key from the wire is + ignored as extra, not populated by Python field name.""" + snake = {"role": "assistant", "content": {"type": "text", "text": "x"}, "model": "m", "stop_reason": "endTurn"} + conn = Connection(StubOutbound(result=snake), has_standalone_channel=True) + result = await conn.send_request(CreateMessageRequest(params=CreateMessageRequestParams(messages=[], max_tokens=1))) + assert result.stop_reason is None + + +@pytest.mark.anyio +async def test_connection_send_request_with_result_type_kwarg_validates_custom_type(): + out = StubOutbound(result={}) + conn = Connection(out, has_standalone_channel=True) + result = await conn.send_request(PingRequest(), result_type=EmptyResult) + assert isinstance(result, EmptyResult) + + +@pytest.mark.anyio +async def test_connection_ping_sends_ping_on_standalone(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.ping() + assert out.requests == [("ping", None)] + + +@pytest.mark.anyio +async def test_connection_log_sends_logging_message_notification(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.log("info", {"k": "v"}, logger="my.logger") + method, params = out.notifications[0] + assert method == "notifications/message" + assert params is not None + assert params["level"] == "info" + assert params["data"] == {"k": "v"} + assert params["logger"] == "my.logger" + + +@pytest.mark.anyio +async def test_connection_log_with_meta_includes_meta_in_params(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.log("info", "x", meta={"traceId": "abc"}) + _, params = out.notifications[0] + assert params is not None + assert params["_meta"] == {"traceId": "abc"} + + +@pytest.mark.anyio +async def test_connection_list_changed_notifications_send_correct_methods(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.send_tool_list_changed() + await conn.send_prompt_list_changed() + await conn.send_resource_list_changed() + await conn.send_resource_updated("file:///workspace/a.txt") + methods = [m for m, _ in out.notifications] + assert methods == [ + "notifications/tools/list_changed", + "notifications/prompts/list_changed", + "notifications/resources/list_changed", + "notifications/resources/updated", + ] + assert out.notifications[-1][1] == {"uri": "file:///workspace/a.txt"} + + +@pytest.mark.anyio +async def test_connection_send_tool_list_changed_with_meta_includes_meta_only_params(): + out = StubOutbound() + conn = Connection(out, has_standalone_channel=True) + await conn.send_tool_list_changed(meta={"k": 1}) + assert out.notifications == [("notifications/tools/list_changed", {"_meta": {"k": 1}})] + + +def test_connection_check_capability_false_before_initialized(): + conn = Connection(StubOutbound(), has_standalone_channel=True) + assert conn.check_capability(ClientCapabilities(sampling=SamplingCapability())) is False + + +@pytest.mark.parametrize( + ("have", "want", "expected"), + [ + (ClientCapabilities(roots=None), ClientCapabilities(roots=RootsCapability()), False), + ( + ClientCapabilities(roots=RootsCapability(list_changed=False)), + ClientCapabilities(roots=RootsCapability(list_changed=True)), + False, + ), + (ClientCapabilities(sampling=None), ClientCapabilities(sampling=SamplingCapability()), False), + ( + ClientCapabilities(sampling=SamplingCapability()), + ClientCapabilities(sampling=SamplingCapability(context=SamplingContextCapability())), + False, + ), + ( + ClientCapabilities(sampling=SamplingCapability()), + ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability())), + False, + ), + ( + ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability())), + ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability())), + True, + ), + (ClientCapabilities(experimental=None), ClientCapabilities(experimental={"a": {}}), False), + (ClientCapabilities(experimental={"a": {}}), ClientCapabilities(experimental={"b": {}}), False), + (ClientCapabilities(experimental={"a": {"x": 1}}), ClientCapabilities(experimental={"a": {"x": 2}}), False), + (ClientCapabilities(experimental={"a": {}}), ClientCapabilities(experimental={"a": {}}), True), + ], +) +def test_check_capability_per_field_branches(have: ClientCapabilities, want: ClientCapabilities, expected: bool): + conn = Connection(StubOutbound(), has_standalone_channel=True) + conn.client_params = _client_params(have) + assert conn.check_capability(want) is expected + + +def test_connection_client_info_and_capabilities_derive_from_client_params(): + conn = Connection(StubOutbound(), has_standalone_channel=True) + assert conn.client_info is None + assert conn.client_capabilities is None + caps = ClientCapabilities(sampling=SamplingCapability()) + conn.client_params = _client_params(caps) + assert conn.client_info is not None and conn.client_info.name == "t" + assert conn.client_capabilities == caps + + +def test_connection_check_capability_true_when_client_declares_it(): + conn = Connection(StubOutbound(), has_standalone_channel=True) + conn.client_params = _client_params( + ClientCapabilities(sampling=SamplingCapability(), roots=RootsCapability(list_changed=True)) + ) + conn.initialized.set() + assert conn.check_capability(ClientCapabilities(sampling=SamplingCapability())) is True + assert conn.check_capability(ClientCapabilities(roots=RootsCapability(list_changed=True))) is True + assert conn.check_capability(ClientCapabilities(elicitation=ElicitationCapability())) is False diff --git a/tests/server/test_lowlevel_exception_handling.py b/tests/server/test_lowlevel_exception_handling.py index 46925916d9..015a5cbafa 100644 --- a/tests/server/test_lowlevel_exception_handling.py +++ b/tests/server/test_lowlevel_exception_handling.py @@ -1,64 +1,8 @@ -from unittest.mock import AsyncMock, Mock - import anyio import pytest -from mcp import types from mcp.server.lowlevel.server import Server -from mcp.server.session import ServerSession from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder - - -@pytest.mark.anyio -async def test_exception_handling_with_raise_exceptions_true(): - """Transport exceptions are re-raised when raise_exceptions=True.""" - server = Server("test-server") - session = Mock(spec=ServerSession) - - test_exception = RuntimeError("Test error") - - with pytest.raises(RuntimeError, match="Test error"): - await server._handle_message(test_exception, session, {}, raise_exceptions=True) - - -@pytest.mark.anyio -async def test_exception_handling_with_raise_exceptions_false(): - """Transport exceptions are logged locally but not sent to the client. - - The transport that reported the error is likely broken; writing back - through it races with stream closure (#1967, #2064). The TypeScript, - Go, and C# SDKs all log locally only. - """ - server = Server("test-server") - session = Mock(spec=ServerSession) - session.send_log_message = AsyncMock() - - await server._handle_message(RuntimeError("Test error"), session, {}, raise_exceptions=False) - - session.send_log_message.assert_not_called() - - -@pytest.mark.anyio -async def test_normal_message_handling_not_affected(): - """Test that normal messages still work correctly""" - server = Server("test-server") - session = Mock(spec=ServerSession) - - # Create a mock RequestResponder - responder = Mock(spec=RequestResponder) - responder.request = types.PingRequest(method="ping") - responder.__enter__ = Mock(return_value=responder) - responder.__exit__ = Mock(return_value=None) - - # Mock the _handle_request method to avoid complex setup - server._handle_request = AsyncMock() - - # Should handle normally without any exception handling - await server._handle_message(responder, session, {}, raise_exceptions=False) - - # Verify _handle_request was called - server._handle_request.assert_called_once() @pytest.mark.anyio @@ -71,23 +15,21 @@ async def test_server_run_exits_cleanly_when_transport_yields_exception_then_clo 1. Transport yields an Exception into the read stream (streamable_http.py does this in its broad POST-handler except). 2. Transport closes the read stream (terminate() in stateless mode). - 3. _receive_loop exits its `async with read_stream, write_stream:` block, - closing the write stream. - 4. Meanwhile _handle_message(exc) was spawned via tg.start_soon and runs - after the write stream is closed. + 3. The read loop exits and closes the write stream. - Before the fix, _handle_message tried to send_log_message through the - closed write stream, raising ClosedResourceError inside the TaskGroup - and crashing server.run(). After the fix, it only logs locally. + Before the fix, the message handler tried to send_log_message through the + closed write stream, raising ClosedResourceError and crashing server.run(). + After the fix (and now in the dispatcher), the exception is only logged + locally. """ server = Server("test-server") read_send, read_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) # Zero-buffer on the write stream forces send() to block until received. - # With no receiver, a send() sits blocked until _receive_loop exits its - # `async with self._read_stream, self._write_stream:` block and closes the - # stream, at which point the blocked send raises ClosedResourceError. - # This deterministically reproduces the race without sleeps. + # With no receiver, a send() sits blocked until the read loop exits its + # `async with read_stream, write_stream:` block and closes the stream, at + # which point the blocked send raises ClosedResourceError. This + # deterministically reproduces the race without sleeps. write_send, write_recv = anyio.create_memory_object_stream[SessionMessage](0) # What the streamable HTTP transport does: push the exception, then close. @@ -96,11 +38,11 @@ async def test_server_run_exits_cleanly_when_transport_yields_exception_then_clo with anyio.fail_after(5): # stateless=True so server.run doesn't wait for initialize handshake. - # Before this fix, this raised ExceptionGroup(ClosedResourceError). + # Before the fix, this raised ExceptionGroup(ClosedResourceError). await server.run(read_recv, write_send, server.create_initialization_options(), stateless=True) - # write_send was closed inside _receive_loop's `async with`; receive_nowait - # raises EndOfStream iff the buffer is empty (i.e., server wrote nothing). + # write_send was closed inside run's `async with`; receive_nowait raises + # EndOfStream iff the buffer is empty (i.e., server wrote nothing). with pytest.raises(anyio.EndOfStream): write_recv.receive_nowait() write_recv.close() diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py new file mode 100644 index 0000000000..403cf3d15a --- /dev/null +++ b/tests/server/test_runner.py @@ -0,0 +1,655 @@ +"""Tests for `ServerRunner`. + +End-to-end over `JSONRPCDispatcher` with a real lowlevel `Server` as the +registry. The `connected_runner` helper starts both sides and (by default) +performs the initialize handshake, so each test exercises only the behaviour +under test. +""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any, cast + +import anyio +import pytest +from opentelemetry.trace import SpanKind, StatusCode + +from mcp.server.context import ServerRequestContext +from mcp.server.lowlevel.server import NotificationOptions, Server +from mcp.server.models import InitializationOptions +from mcp.server.runner import ServerRunner, otel_middleware +from mcp.server.session import ServerSession +from mcp.shared.dispatcher import DispatchContext, DispatchMiddleware, OnRequest +from mcp.shared.exceptions import MCPError +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.transport_context import TransportContext +from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS +from mcp.types import ( + INVALID_PARAMS, + LATEST_PROTOCOL_VERSION, + METHOD_NOT_FOUND, + CallToolRequestParams, + ClientCapabilities, + ErrorData, + Implementation, + InitializeRequestParams, + ListToolsResult, + NotificationParams, + PaginatedRequestParams, + ProgressNotificationParams, + RequestParams, + SetLevelRequestParams, + Tool, +) + +from ..shared.conftest import jsonrpc_pair +from ..shared.test_dispatcher import Recorder, echo_handlers +from .conftest import SpanCapture + +Ctx = ServerRequestContext[dict[str, Any], Any] + + +def _initialize_params() -> dict[str, Any]: + return InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=ClientCapabilities(), + client_info=Implementation(name="test-client", version="1.0"), + ).model_dump(by_alias=True, exclude_none=True) + + +_seen_ctx: list[Ctx] = [] +SrvT = Server[dict[str, Any]] + + +@pytest.fixture +def server() -> SrvT: + """A lowlevel Server with one tools/list handler registered.""" + _seen_ctx.clear() + + async def list_tools(ctx: Ctx, params: PaginatedRequestParams | None) -> ListToolsResult: + _seen_ctx.append(ctx) + return ListToolsResult(tools=[Tool(name="t", input_schema={"type": "object"})]) + + return Server(name="test-server", version="0.0.1", on_list_tools=list_tools) + + +@asynccontextmanager +async def connected_runner( + server: SrvT, + *, + initialized: bool = True, + stateless: bool = False, + has_standalone_channel: bool = True, + init_options: InitializationOptions | None = None, + session_id: str | None = None, + dispatch_middleware: list[DispatchMiddleware] | None = None, +) -> AsyncIterator[tuple[JSONRPCDispatcher[TransportContext], ServerRunner[dict[str, Any]]]]: + """Yield `(client, runner)` running over an in-memory JSON-RPC dispatcher pair. + + Starts the client (echo handlers) and `runner.run()` in a task group, wraps + the body in `anyio.fail_after(5)`, and cancels on exit. When + `initialized` is true the helper performs the real `initialize` request + before yielding, so tests start past the init-gate via the public path. + """ + client, server_d, close = jsonrpc_pair() + assert isinstance(client, JSONRPCDispatcher) and isinstance(server_d, JSONRPCDispatcher) + runner = ServerRunner( + server=server, + dispatcher=server_d, + lifespan_state={}, + has_standalone_channel=has_standalone_channel, + init_options=init_options, + session_id=session_id, + stateless=stateless, + dispatch_middleware=dispatch_middleware or [], + ) + c_req, c_notify = echo_handlers(Recorder()) + body_exc: BaseException | None = None + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(runner.run) + try: + with anyio.fail_after(5): + if initialized: + await client.send_raw_request("initialize", _initialize_params()) + yield client, runner + except BaseException as e: + # Capture and re-raise outside the task group so test failures + # surface as the original exception, not an ExceptionGroup wrapper. + body_exc = e + close() + if body_exc is not None: + raise body_exc + + +@pytest.mark.anyio +async def test_connected_runner_propagates_body_exception_unwrapped(server: SrvT): + """The harness re-raises body exceptions as-is, not as `ExceptionGroup`.""" + with pytest.raises(RuntimeError, match="boom"): + async with connected_runner(server): + raise RuntimeError("boom") + + +@pytest.mark.anyio +async def test_runner_handles_initialize_and_populates_connection(server: SrvT): + async with connected_runner(server, initialized=False) as (client, runner): + result = await client.send_raw_request("initialize", _initialize_params()) + assert result["serverInfo"]["name"] == "test-server" + assert "tools" in result["capabilities"] + assert runner.connection.client_info is not None + assert runner.connection.client_info.name == "test-client" + assert runner.connection.protocol_version == LATEST_PROTOCOL_VERSION + assert runner._initialized is True + + +@pytest.mark.anyio +async def test_runner_gates_requests_before_initialize(server: SrvT): + async with connected_runner(server, initialized=False) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error == ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") + # ping is exempt from the gate + assert await client.send_raw_request("ping", None) == {} + + +@pytest.mark.anyio +async def test_runner_routes_to_handler_and_builds_context(server: SrvT): + async with connected_runner(server) as (client, runner): + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" + ctx = _seen_ctx[0] + assert isinstance(ctx, ServerRequestContext) + assert ctx.lifespan_context == {} + assert isinstance(ctx.session, ServerSession) + assert ctx.session is runner.session + assert ctx.request_id is not None + + +@pytest.mark.anyio +async def test_runner_spec_method_with_no_handler_raises_method_not_found(server: SrvT): + async with connected_runner(server) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("resources/list", None) + assert exc.value.error.code == METHOD_NOT_FOUND + + +@pytest.mark.anyio +async def test_runner_non_spec_method_with_no_handler_raises_method_not_found(server: SrvT): + """Upfront validation is gated to spec methods, so a non-spec method + skips it and reaches handler lookup.""" + async with connected_runner(server) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("nonexistent/method", None) + assert exc.value.error.code == METHOD_NOT_FOUND + + +@pytest.mark.anyio +async def test_runner_malformed_params_for_unregistered_spec_method_raises_invalid_params(server: SrvT): + """A spec method with malformed params is INVALID_PARAMS even with no handler.""" + async with connected_runner(server) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/call", {"name": 123}) + assert exc.value.error == ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") + + +@pytest.mark.anyio +async def test_runner_rejects_snake_case_initialize_params(server: SrvT): + """Inbound wire payloads validate alias-only; Python field names are not + accepted (`protocol_version` must arrive as `protocolVersion`).""" + snake = { + "protocol_version": LATEST_PROTOCOL_VERSION, + "capabilities": {}, + "client_info": {"name": "c", "version": "0"}, + } + async with connected_runner(server, initialized=False) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("initialize", snake) + assert exc.value.error.code == INVALID_PARAMS + + +@pytest.mark.anyio +async def test_runner_rejects_snake_case_params_for_custom_handler(server: SrvT): + """Custom-method handlers (which skip the spec-method gate) still validate + alias-only at the per-handler boundary.""" + + async def handler(ctx: Ctx, params: ProgressNotificationParams) -> dict[str, Any]: + return {"ok": True} + + server.add_request_handler("custom/progress", ProgressNotificationParams, handler) + async with connected_runner(server) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("custom/progress", {"progress_token": 1, "progress": 0.5}) + assert exc.value.error.code == INVALID_PARAMS + result = await client.send_raw_request("custom/progress", {"progressToken": 1, "progress": 0.5}) + assert result == {"ok": True} + + +@pytest.mark.anyio +async def test_runner_on_notify_drops_snake_case_params(server: SrvT, caplog: pytest.LogCaptureFixture): + """Notification params validate alias-only; snake_case is dropped as malformed.""" + + async def handler(ctx: Ctx, params: ProgressNotificationParams) -> None: + raise NotImplementedError + + server.add_notification_handler("notifications/roots/list_changed", ProgressNotificationParams, handler) + async with connected_runner(server) as (client, _): + await client.notify("notifications/roots/list_changed", {"progress_token": 1, "progress": 0.5}) + await client.send_raw_request("tools/list", None) + assert "dropped 'notifications/roots/list_changed': malformed params" in caplog.text + + +@pytest.mark.anyio +async def test_runner_on_notify_initialized_sets_flag_and_connection_event(server: SrvT): + async with connected_runner(server, initialized=False) as (client, runner): + await client.notify("notifications/initialized", None) + await runner.connection.initialized.wait() + assert runner._initialized is True + + +@pytest.mark.anyio +async def test_runner_on_notify_routes_to_registered_handler(server: SrvT): + seen: list[tuple[Any, Any]] = [] + delivered = anyio.Event() + + async def on_roots_changed(ctx: Ctx, params: NotificationParams | None) -> None: + seen.append((ctx, params)) + if len(seen) == 2: + delivered.set() + + server.add_notification_handler("notifications/roots/list_changed", NotificationParams, on_roots_changed) + async with connected_runner(server) as (client, _): + await client.notify("notifications/roots/list_changed", None) + await client.notify("notifications/roots/list_changed", {}) + await delivered.wait() + assert isinstance(seen[0][0], ServerRequestContext) + # Absent wire params reach the handler as None; present-but-empty validates. + assert seen[0][1] is None + assert isinstance(seen[1][1], NotificationParams) + + +@pytest.mark.anyio +async def test_runner_on_notify_handler_exception_is_swallowed_and_logged( + server: SrvT, caplog: pytest.LogCaptureFixture +): + """A notification handler crashing must not tear down the connection.""" + + async def boom(ctx: Ctx, params: NotificationParams | None) -> None: + raise RuntimeError("notification handler boom") + + server.add_notification_handler("notifications/roots/list_changed", NotificationParams, boom) + async with connected_runner(server) as (client, _): + await client.notify("notifications/roots/list_changed", None) + # Connection still alive: a request after the crashing handler succeeds. + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" + assert "notification handler for 'notifications/roots/list_changed' raised" in caplog.text + + +@pytest.mark.anyio +async def test_runner_on_notify_drops_malformed_params(server: SrvT, caplog: pytest.LogCaptureFixture): + """Malformed notification params are logged and dropped, not raised.""" + + async def on_level(ctx: Ctx, params: SetLevelRequestParams) -> None: + raise NotImplementedError + + server.add_notification_handler("notifications/roots/list_changed", SetLevelRequestParams, on_level) + async with connected_runner(server) as (client, _): + await client.notify("notifications/roots/list_changed", {"level": "not-a-level"}) + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" + assert "dropped 'notifications/roots/list_changed': malformed params" in caplog.text + + +@pytest.mark.anyio +async def test_runner_on_notify_drops_absent_params_when_model_requires_them( + server: SrvT, caplog: pytest.LogCaptureFixture +): + """A params-less progress notification is dropped, not delivered as None. + + `on_progress` is typed to receive a non-Optional `ProgressNotificationParams`; + the previous server validated the full notification union and dropped this + as malformed before dispatch. + """ + + async def on_progress(ctx: Ctx, params: ProgressNotificationParams) -> None: + raise NotImplementedError + + server.add_notification_handler("notifications/progress", ProgressNotificationParams, on_progress) + async with connected_runner(server) as (client, _): + await client.notify("notifications/progress", None) + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" + assert "dropped 'notifications/progress': malformed params" in caplog.text + assert "notification handler for" not in caplog.text + + +@pytest.mark.anyio +async def test_runner_absent_wire_params_reaches_request_handler_as_none(): + """A request with no `params` member on the wire reaches the handler as + `None`, matching the previous server and the `| None` handler annotation. + + The in-SDK client always attaches `_meta`, so a dispatch middleware + forwards `params=None` to model what an external client sends. + """ + seen: list[PaginatedRequestParams | None] = [] + + async def list_tools(ctx: Ctx, params: PaginatedRequestParams | None) -> ListToolsResult: + seen.append(params) + return ListToolsResult(tools=[]) + + def drop_params(next_on_request: OnRequest) -> OnRequest: + async def wrapped(dctx: DispatchContext[Any], method: str, params: Any) -> dict[str, Any]: + return await next_on_request(dctx, method, None if method == "tools/list" else params) + + return wrapped + + server: SrvT = Server(name="s", on_list_tools=list_tools) + async with connected_runner(server, dispatch_middleware=[drop_params]) as (client, _): + await client.send_raw_request("tools/list", None) + assert seen == [None] + + +@pytest.mark.anyio +async def test_runner_absent_wire_params_for_required_params_custom_method_is_invalid_params(): + """A custom method whose `params_type` has required fields rejects absent + wire params as INVALID_PARAMS rather than invoking the handler with None.""" + + class GreetParams(RequestParams): + name: str + + async def greet(ctx: Ctx, params: GreetParams) -> dict[str, Any]: + raise NotImplementedError + + def drop_params(next_on_request: OnRequest) -> OnRequest: + async def wrapped(dctx: DispatchContext[Any], method: str, params: Any) -> dict[str, Any]: + return await next_on_request(dctx, method, None if method == "custom/greet" else params) + + return wrapped + + server: SrvT = Server(name="s") + server.add_request_handler("custom/greet", GreetParams, greet) + async with connected_runner(server, dispatch_middleware=[drop_params]) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("custom/greet", {"name": "x"}) + assert exc.value.error.code == INVALID_PARAMS + + +@pytest.mark.anyio +async def test_runner_on_notify_drops_before_init_and_unknown_methods(server: SrvT): + seen: list[Any] = [] + + async def on_roots(ctx: Ctx, params: NotificationParams | None) -> None: + seen.append(params) + + server.add_notification_handler("notifications/roots/list_changed", NotificationParams, on_roots) + async with connected_runner(server, initialized=False) as (client, _): + await client.notify("notifications/roots/list_changed", None) # before init: dropped + await client.notify("notifications/initialized", None) + await client.notify("notifications/unknown", None) # no handler: dropped + await client.notify("notifications/roots/list_changed", None) # post-init: delivered + await anyio.wait_all_tasks_blocked() + assert seen == [None] # only the post-init one reached the handler + + +@pytest.mark.anyio +async def test_runner_dispatch_middleware_wraps_everything_including_initialize(server: SrvT): + seen_methods: list[str] = [] + + def trace_mw(next_on_request: Any) -> Any: + async def wrapped(dctx: Any, method: str, params: Any) -> Any: + seen_methods.append(method) + return await next_on_request(dctx, method, params) + + return wrapped + + async with connected_runner(server, dispatch_middleware=[trace_mw]) as (client, _): + await client.send_raw_request("tools/list", None) + assert seen_methods == ["initialize", "tools/list"] + + +@pytest.mark.anyio +async def test_runner_server_middleware_wraps_handlers_but_not_initialize(server: SrvT): + seen_methods: list[str] = [] + + async def ctx_mw(ctx: Ctx, method: str, params: Any, call_next: Any) -> Any: + seen_methods.append(method) + return await call_next() + + server.middleware.append(ctx_mw) + async with connected_runner(server) as (client, _): + await client.send_raw_request("ping", None) + await client.send_raw_request("tools/list", None) + # initialize (sent by the helper) NOT wrapped; ping and tools/list ARE. + assert seen_methods == ["ping", "tools/list"] + + +@pytest.mark.anyio +async def test_runner_server_middleware_runs_outermost_first(server: SrvT): + order: list[str] = [] + + def make_mw(tag: str) -> Any: + async def mw(ctx: Ctx, method: str, params: Any, call_next: Any) -> Any: + order.append(f"{tag}-in") + result = await call_next() + order.append(f"{tag}-out") + return result + + return mw + + server.middleware.extend([make_mw("a"), make_mw("b")]) + async with connected_runner(server) as (client, _): + await client.send_raw_request("tools/list", None) + assert order == ["a-in", "b-in", "b-out", "a-out"] + + +@pytest.mark.anyio +async def test_runner_handler_returning_none_yields_empty_result(server: SrvT): + async def set_level(ctx: Ctx, params: SetLevelRequestParams) -> None: + return None + + server.add_request_handler("logging/setLevel", SetLevelRequestParams, set_level) + async with connected_runner(server) as (client, _): + result = await client.send_raw_request("logging/setLevel", {"level": "info"}) + assert result == {} + + +@pytest.mark.anyio +async def test_runner_handler_returning_error_data_produces_jsonrpc_error(server: SrvT): + """A handler returning `ErrorData` reaches the client as a JSON-RPC error, + not a success result, matching `BaseSession._send_response`.""" + + async def set_level(ctx: Ctx, params: SetLevelRequestParams) -> ErrorData: + return ErrorData(code=INVALID_PARAMS, message="bad level", data={"got": params.level}) + + server.add_request_handler("logging/setLevel", SetLevelRequestParams, set_level) + async with connected_runner(server) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("logging/setLevel", {"level": "info"}) + assert exc.value.error == ErrorData(code=INVALID_PARAMS, message="bad level", data={"got": "info"}) + + +@pytest.mark.anyio +async def test_runner_handler_returning_unsupported_type_surfaces_as_error(server: SrvT): + async def bad_return(ctx: Ctx, params: PaginatedRequestParams | None) -> int: + return 42 + + # cast: deliberately registering a handler with a bad return type to + # exercise the runtime check; pyright would (correctly) reject it otherwise. + server.add_request_handler("tools/list", PaginatedRequestParams, cast(Any, bad_return)) + async with connected_runner(server) as (client, _): + with pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error.code == 0 + assert "int" in exc.value.error.message + + +@pytest.mark.anyio +async def test_runner_stateless_skips_init_gate(server: SrvT): + async with connected_runner(server, initialized=False, stateless=True, has_standalone_channel=False) as (client, _): + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" + + +@pytest.mark.anyio +async def test_runner_stateless_connection_initialized_event_set_on_construction(server: SrvT): + """`connection.initialized` mirrors the gate flag in stateless mode so + `await connection.initialized.wait()` does not hang when no handshake + arrives.""" + async with connected_runner(server, initialized=False, stateless=True, has_standalone_channel=False) as (_, runner): + assert runner._initialized is True + assert runner.connection.initialized.is_set() + await runner.connection.initialized.wait() + + +@pytest.mark.anyio +async def test_server_add_request_handler_routes_custom_method_with_validated_params(server: SrvT): + """Custom methods outside the spec `ClientRequest` union skip upfront + validation and route to the registered handler.""" + + class GreetParams(RequestParams): + name: str + + received: list[GreetParams] = [] + + async def greet(ctx: Ctx, params: GreetParams) -> dict[str, Any]: + received.append(params) + return {"greeting": f"hello {params.name}"} + + server.add_request_handler("custom/greet", GreetParams, greet) + async with connected_runner(server) as (client, _): + result = await client.send_raw_request("custom/greet", {"name": "world"}) + assert result == {"greeting": "hello world"} + assert isinstance(received[0], GreetParams) + assert received[0].name == "world" + + +@pytest.mark.anyio +async def test_runner_initialize_result_reflects_init_options(): + async def list_tools(ctx: Ctx, params: PaginatedRequestParams | None) -> ListToolsResult: + raise NotImplementedError + + server: SrvT = Server(name="caps-test", on_list_tools=list_tools, instructions="be nice") + init_options = server.create_initialization_options(NotificationOptions(tools_changed=True), {"ext": {"k": "v"}}) + async with connected_runner(server, initialized=False, init_options=init_options) as (client, _): + result = await client.send_raw_request("initialize", _initialize_params()) + assert result["capabilities"]["tools"]["listChanged"] is True + assert result["capabilities"]["experimental"] == {"ext": {"k": "v"}} + assert result["serverInfo"]["name"] == "caps-test" + assert result["instructions"] == "be nice" + + +@pytest.mark.anyio +async def test_runner_initialize_echoes_supported_version_and_falls_back_to_latest(server: SrvT): + oldest = SUPPORTED_PROTOCOL_VERSIONS[0] + async with connected_runner(server, initialized=False) as (client, _): + params = {**_initialize_params(), "protocolVersion": oldest} + result = await client.send_raw_request("initialize", params) + assert result["protocolVersion"] == oldest + async with connected_runner(server, initialized=False) as (client, _): + params = {**_initialize_params(), "protocolVersion": "1999-01-01"} + result = await client.send_raw_request("initialize", params) + assert result["protocolVersion"] == LATEST_PROTOCOL_VERSION + + +@pytest.mark.anyio +async def test_otel_middleware_emits_server_span_with_method_and_target(server: SrvT, spans: SpanCapture): + async def call_tool(ctx: Ctx, params: CallToolRequestParams) -> dict[str, Any]: + return {"content": [], "isError": False} + + server.add_request_handler("tools/call", CallToolRequestParams, call_tool) + async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): + spans.clear() + result = await client.send_raw_request("tools/call", {"name": "mytool", "arguments": {}}) + assert result == {"content": [], "isError": False} + finished = [s for s in spans.finished() if s.kind == SpanKind.SERVER] + [span] = finished + assert span.name == "MCP handle tools/call mytool" + assert span.attributes is not None + assert span.attributes["mcp.method.name"] == "tools/call" + assert isinstance(span.attributes["jsonrpc.request.id"], str) + assert span.status.status_code == StatusCode.UNSET + + +@pytest.mark.anyio +async def test_otel_trace_context_propagates_client_to_server(server: SrvT, spans: SpanCapture): + """The client dispatcher injects traceparent into `_meta`; the server's + `otel_middleware` extracts it, so client and server spans share a trace.""" + async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): + spans.clear() + await client.send_raw_request("tools/list", None) + [client_span] = [s for s in spans.finished() if s.kind == SpanKind.CLIENT] + [server_span] = [s for s in spans.finished() if s.kind == SpanKind.SERVER] + assert server_span.parent is not None + assert client_span.context is not None and server_span.context is not None + assert server_span.parent.span_id == client_span.context.span_id + assert server_span.context.trace_id == client_span.context.trace_id + + +@pytest.mark.anyio +async def test_otel_middleware_records_error_status_on_mcp_error(server: SrvT, spans: SpanCapture): + async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): + spans.clear() + with pytest.raises(MCPError) as exc: + await client.send_raw_request("resources/list", None) + assert exc.value.error.code == METHOD_NOT_FOUND + [span] = [s for s in spans.finished() if s.kind == SpanKind.SERVER] + assert span.status.status_code == StatusCode.ERROR + assert span.status.description == "Method not found" + # MCPError is a protocol-level response, not a crash - no traceback event. + assert not [e for e in span.events if e.name == "exception"] + + +@pytest.mark.anyio +async def test_otel_middleware_records_error_status_on_handler_exception(server: SrvT, spans: SpanCapture): + async def failing(ctx: Ctx, params: PaginatedRequestParams | None) -> Any: + raise ValueError("handler blew up") + + server.add_request_handler("tools/list", PaginatedRequestParams, failing) + async with connected_runner(server, dispatch_middleware=[otel_middleware]) as (client, _): + spans.clear() + with pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error.code == 0 + [span] = [s for s in spans.finished() if s.kind == SpanKind.SERVER] + assert span.status.status_code == StatusCode.ERROR + assert span.status.description == "handler blew up" + [event] = [e for e in span.events if e.name == "exception"] + assert event.attributes is not None + assert event.attributes["exception.type"] == "ValueError" + + +@pytest.mark.anyio +async def test_runner_connection_exit_stack_unwinds_after_run_returns(server: SrvT) -> None: + """`runner.connection.exit_stack` is closed when the dispatcher loop ends.""" + cleaned: list[int] = [] + + async def _append(i: int) -> None: + cleaned.append(i) + + async with connected_runner(server) as (client, runner): + for i in (1, 2, 3): + runner.connection.exit_stack.push_async_callback(_append, i) + await client.send_raw_request("tools/list", None) + assert cleaned == [] + assert cleaned == [3, 2, 1] + + +@pytest.mark.anyio +async def test_runner_exit_stack_cleanup_exception_is_logged_not_propagated( + server: SrvT, caplog: pytest.LogCaptureFixture +) -> None: + """A raising cleanup callback is caught and logged; `run()` exits cleanly.""" + cleaned: list[str] = [] + + async def _ok() -> None: + cleaned.append("ok") + + async def _boom() -> None: + raise RuntimeError("cleanup failed") + + async with connected_runner(server) as (client, runner): + runner.connection.exit_stack.push_async_callback(_ok) + runner.connection.exit_stack.push_async_callback(_boom) + await client.send_raw_request("tools/list", None) + assert cleaned == ["ok"] + assert "connection exit_stack cleanup raised" in caplog.text diff --git a/tests/server/test_server_context.py b/tests/server/test_server_context.py new file mode 100644 index 0000000000..8971d3d52f --- /dev/null +++ b/tests/server/test_server_context.py @@ -0,0 +1,158 @@ +"""Tests for the server-side `Context`. + +`Context` composes `BaseContext` (forwarding to a `DispatchContext`) with +`PeerMixin` (typed sample/elicit/roots/ping) plus `lifespan` and `connection`. +End-to-end tested over `DirectDispatcher`. +""" + +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any + +import anyio +import pytest + +from mcp.server.connection import Connection +from mcp.server.context import Context +from mcp.shared.dispatcher import DispatchContext +from mcp.shared.transport_context import TransportContext +from mcp.types import CreateMessageResult, ListRootsRequest, ListRootsResult, SamplingMessage, TextContent + +from ..shared.conftest import direct_pair +from ..shared.test_dispatcher import Recorder, echo_handlers, running_pair + +DCtx = DispatchContext[TransportContext] + + +@dataclass +class _Lifespan: + name: str + + +@pytest.mark.anyio +async def test_context_exposes_lifespan_and_connection_and_forwards_base_context(): + captured: list[Context[_Lifespan]] = [] + conn = Connection.__new__(Connection) # placeholder until running_pair gives us the dispatcher + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan] = Context(dctx, lifespan=_Lifespan("app"), connection=conn) + captured.append(ctx) + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request) as (client, server, *_): + # Now we have the server dispatcher; build the real Connection bound to it. + conn.__init__(server, has_standalone_channel=True, session_id="sess-1") + with anyio.fail_after(5): + await client.send_raw_request("t", None) + ctx = captured[0] + assert ctx.lifespan.name == "app" + assert ctx.connection is conn + assert ctx.transport.kind == "direct" + assert ctx.can_send_request is True + assert ctx.session_id == "sess-1" + assert ctx.headers is None + + +@pytest.mark.anyio +async def test_context_sample_round_trips_via_peer_mixin_on_base_context_outbound(): + crec = Recorder() + + async def client_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + crec.requests.append((method, params)) + return {"role": "assistant", "content": {"type": "text", "text": "ok"}, "model": "m"} + + results: list[CreateMessageResult] = [] + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan] = Context( + dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) + ) + results.append( + await ctx.sample( + [SamplingMessage(role="user", content=TextContent(type="text", text="hi"))], + max_tokens=5, + ) + ) + return {} + + async with running_pair( + direct_pair, + server_on_request=server_on_request, + client_on_request=client_on_request, + ) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("tools/call", None) + assert crec.requests[0][0] == "sampling/createMessage" + assert isinstance(results[0], CreateMessageResult) + + +@pytest.mark.anyio +async def test_context_send_request_with_spec_type_infers_result_via_typed_mixin(): + async def client_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + return {"roots": []} + + results: list[ListRootsResult] = [] + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan] = Context( + dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) + ) + results.append(await ctx.send_request(ListRootsRequest())) + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request, client_on_request=client_on_request) as ( + client, + *_, + ): + with anyio.fail_after(5): + await client.send_raw_request("t", None) + assert isinstance(results[0], ListRootsResult) + + +@pytest.mark.anyio +async def test_context_log_sends_request_scoped_message_notification(): + crec = Recorder() + _, c_notify = echo_handlers(crec) + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan] = Context( + dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) + ) + await ctx.log("debug", "hello") + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request, client_on_notify=c_notify) as ( + client, + *_, + ): + with anyio.fail_after(5): + await client.send_raw_request("t", None) + await crec.notified.wait() + method, params = crec.notifications[0] + assert method == "notifications/message" + assert params is not None and params["level"] == "debug" and params["data"] == "hello" + + +@pytest.mark.anyio +async def test_context_log_includes_logger_and_meta_when_supplied(): + crec = Recorder() + _, c_notify = echo_handlers(crec) + + async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + ctx: Context[_Lifespan] = Context( + dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) + ) + await ctx.log("info", "x", logger="my.log", meta={"traceId": "t"}) + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request, client_on_notify=c_notify) as ( + client, + *_, + ): + with anyio.fail_after(5): + await client.send_raw_request("t", None) + await crec.notified.wait() + _, params = crec.notifications[0] + assert params is not None + assert params["logger"] == "my.log" + assert params["_meta"] == {"traceId": "t"} diff --git a/tests/server/test_session.py b/tests/server/test_session.py index 6116a7c7f5..c77ac8a42c 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -1,535 +1,128 @@ -from typing import Any +"""Tests for `ServerSession`. + +`ServerSession` is a thin proxy over a dispatcher and a `Connection`. Tested +with a stub dispatcher so we can assert what reaches the wire (method, params, +`CallOptions`, related-request-id) without standing up a full transport. +""" + +from collections.abc import Mapping +from typing import Any, cast -import anyio import pytest from mcp import types -from mcp.client.session import ClientSession -from mcp.server import Server, ServerRequestContext -from mcp.server.lowlevel import NotificationOptions -from mcp.server.models import InitializationOptions +from mcp.server.connection import Connection from mcp.server.session import ServerSession -from mcp.shared.exceptions import MCPError -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder +from mcp.shared.dispatcher import CallOptions +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.message import ServerMessageMetadata from mcp.types import ( - ClientNotification, - CompletionsCapability, - InitializedNotification, - PromptsCapability, - ResourcesCapability, - ServerCapabilities, + LATEST_PROTOCOL_VERSION, + ClientCapabilities, + Implementation, + InitializeRequestParams, + SamplingCapability, + SamplingToolsCapability, ) -@pytest.mark.anyio -async def test_server_session_initialize(): - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - - # Create a message handler to catch exceptions - async def message_handler( # pragma: no cover - message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, - ) -> None: - if isinstance(message, Exception): - raise message - - received_initialized = False - - async def run_server(): - nonlocal received_initialized - - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="mcp", - server_version="0.1.0", - capabilities=ServerCapabilities(), - ), - ) as server_session: - async for message in server_session.incoming_messages: # pragma: no branch - if isinstance(message, Exception): # pragma: no cover - raise message - - if isinstance(message, ClientNotification) and isinstance( - message, InitializedNotification - ): # pragma: no branch - received_initialized = True - return - - try: - async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=message_handler, - ) as client_session, - anyio.create_task_group() as tg, - ): - tg.start_soon(run_server) - - await client_session.initialize() - except anyio.ClosedResourceError: # pragma: no cover - pass +class StubDispatcher: + """Records `send_raw_request` / `notify` calls and returns a canned result.""" - assert received_initialized + def __init__(self, result: dict[str, Any] | None = None) -> None: + self.requests: list[tuple[str, Mapping[str, Any] | None, CallOptions | None, Any]] = [] + self.result = result if result is not None else {} + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + *, + _related_request_id: Any = None, + ) -> dict[str, Any]: + self.requests.append((method, params, opts, _related_request_id)) + return self.result -@pytest.mark.anyio -async def test_check_client_capability(): - """check_client_capability reflects the capabilities sent by the client at initialize.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) - - initialized = anyio.Event() - - async def list_roots_callback(context: Any) -> types.ListRootsResult: # pragma: no cover - return types.ListRootsResult(roots=[]) - - async def run_server(server_session: ServerSession): - async for message in server_session.incoming_messages: # pragma: no branch - if isinstance(message, ClientNotification) and isinstance( - message, InitializedNotification - ): # pragma: no branch - initialized.set() - return - - async with ( - ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions(server_name="mcp", server_version="0.1.0", capabilities=ServerCapabilities()), - ) as server_session, - ClientSession( - server_to_client_receive, - client_to_server_send, - list_roots_callback=list_roots_callback, - ) as client_session, - anyio.create_task_group() as tg, - ): - tg.start_soon(run_server, server_session) - await client_session.initialize() - with anyio.fail_after(5): - await initialized.wait() - - # ClientSession advertises roots when a list_roots_callback is provided. - assert server_session.check_client_capability(types.ClientCapabilities(roots=types.RootsCapability())) - # ClientSession does not advertise sampling without a sampling_callback. - assert not server_session.check_client_capability(types.ClientCapabilities(sampling=types.SamplingCapability())) - - -@pytest.mark.anyio -async def test_server_capabilities(): - notification_options = NotificationOptions() - experimental_capabilities: dict[str, Any] = {} - - async def noop_list_prompts( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListPromptsResult: - raise NotImplementedError - - async def noop_list_resources( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListResourcesResult: - raise NotImplementedError - - async def noop_completion(ctx: ServerRequestContext, params: types.CompleteRequestParams) -> types.CompleteResult: + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: raise NotImplementedError - # No capabilities - server = Server("test") - caps = server.get_capabilities(notification_options, experimental_capabilities) - assert caps.prompts is None - assert caps.resources is None - assert caps.completions is None - - # With prompts handler - server = Server("test", on_list_prompts=noop_list_prompts) - caps = server.get_capabilities(notification_options, experimental_capabilities) - assert caps.prompts == PromptsCapability(list_changed=False) - assert caps.resources is None - assert caps.completions is None - - # With prompts + resources handlers - server = Server("test", on_list_prompts=noop_list_prompts, on_list_resources=noop_list_resources) - caps = server.get_capabilities(notification_options, experimental_capabilities) - assert caps.prompts == PromptsCapability(list_changed=False) - assert caps.resources == ResourcesCapability(subscribe=False, list_changed=False) - assert caps.completions is None - - # With prompts + resources + completion handlers - server = Server( - "test", - on_list_prompts=noop_list_prompts, - on_list_resources=noop_list_resources, - on_completion=noop_completion, - ) - caps = server.get_capabilities(notification_options, experimental_capabilities) - assert caps.prompts == PromptsCapability(list_changed=False) - assert caps.resources == ResourcesCapability(subscribe=False, list_changed=False) - assert caps.completions == CompletionsCapability() - - -@pytest.mark.anyio -async def test_server_session_initialize_with_older_protocol_version(): - """Test that server accepts and responds with older protocol (2024-11-05).""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) - - received_initialized = False - received_protocol_version = None - - async def run_server(): - nonlocal received_initialized - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="mcp", - server_version="0.1.0", - capabilities=ServerCapabilities(), - ), - ) as server_session: - async for message in server_session.incoming_messages: # pragma: no branch - if isinstance(message, Exception): # pragma: no cover - raise message - - if isinstance(message, types.ClientNotification) and isinstance( - message, InitializedNotification - ): # pragma: no branch - received_initialized = True - return - - async def mock_client(): - nonlocal received_protocol_version - - # Send initialization request with older protocol version (2024-11-05) - await client_to_server_send.send( - SessionMessage( - types.JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="initialize", - params=types.InitializeRequestParams( - protocol_version="2024-11-05", - capabilities=types.ClientCapabilities(), - client_info=types.Implementation(name="test-client", version="1.0.0"), - ).model_dump(by_alias=True, mode="json", exclude_none=True), - ) - ) +def _make_session(dispatcher: StubDispatcher, *, capabilities: ClientCapabilities | None = None) -> ServerSession: + conn = Connection(dispatcher, has_standalone_channel=True) + if capabilities is not None: + conn.client_params = InitializeRequestParams( + protocol_version=LATEST_PROTOCOL_VERSION, + capabilities=capabilities, + client_info=Implementation(name="c", version="0"), ) - - # Wait for the initialize response - init_response_message = await server_to_client_receive.receive() - assert isinstance(init_response_message.message, types.JSONRPCResponse) - result_data = init_response_message.message.result - init_result = types.InitializeResult.model_validate(result_data) - - # Check that the server responded with the requested protocol version - received_protocol_version = init_result.protocol_version - assert received_protocol_version == "2024-11-05" - - # Send initialized notification - await client_to_server_send.send( - SessionMessage(types.JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")) - ) - - async with ( - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - anyio.create_task_group() as tg, - ): - tg.start_soon(run_server) - tg.start_soon(mock_client) - - assert received_initialized - assert received_protocol_version == "2024-11-05" + # cast: `ServerSession` is typed to take `JSONRPCDispatcher` but only ever + # calls `send_raw_request` / `notify`, so the stub is structurally sufficient. + return ServerSession(cast("JSONRPCDispatcher[Any]", dispatcher), conn) @pytest.mark.anyio -async def test_ping_request_before_initialization(): - """Test that ping requests are allowed before initialization is complete.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) - - ping_response_received = False - ping_response_id = None - - async def run_server(): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="mcp", - server_version="0.1.0", - capabilities=ServerCapabilities(), - ), - ) as server_session: - async for message in server_session.incoming_messages: # pragma: no branch - if isinstance(message, Exception): # pragma: no cover - raise message - - # We should receive a ping request before initialization - if isinstance(message, RequestResponder) and isinstance( - message.request, types.PingRequest - ): # pragma: no branch - # Respond to the ping - with message: - await message.respond(types.EmptyResult()) - return - - async def mock_client(): - nonlocal ping_response_received, ping_response_id +async def test_send_request_forwards_timeout_and_progress_callback_as_call_options(): + dispatcher = StubDispatcher(result={"roots": []}) + session = _make_session(dispatcher) - # Send ping request before any initialization - await client_to_server_send.send(SessionMessage(types.JSONRPCRequest(jsonrpc="2.0", id=42, method="ping"))) - - # Wait for the ping response - ping_response_message = await server_to_client_receive.receive() - assert isinstance(ping_response_message.message, types.JSONRPCResponse) - - ping_response_received = True - ping_response_id = ping_response_message.message.id - - async with ( - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - anyio.create_task_group() as tg, - ): - tg.start_soon(run_server) - tg.start_soon(mock_client) + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + raise NotImplementedError - assert ping_response_received - assert ping_response_id == 42 + result = await session.send_request( + types.ListRootsRequest(), + types.ListRootsResult, + request_read_timeout_seconds=2.5, + metadata=ServerMessageMetadata(related_request_id=7), + progress_callback=on_progress, + ) + assert isinstance(result, types.ListRootsResult) + method, _params, opts, related = dispatcher.requests[0] + assert method == "roots/list" + assert opts == {"timeout": 2.5, "on_progress": on_progress} + assert related == 7 @pytest.mark.anyio -async def test_create_message_tool_result_validation(): - """Test tool_use/tool_result validation in create_message.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) - - async with ( - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test", - server_version="0.1.0", - capabilities=ServerCapabilities(), - ), - ) as session: - # Set up client params with sampling.tools capability for the test - session._client_params = types.InitializeRequestParams( - protocol_version=types.LATEST_PROTOCOL_VERSION, - capabilities=types.ClientCapabilities( - sampling=types.SamplingCapability(tools=types.SamplingToolsCapability()) - ), - client_info=types.Implementation(name="test", version="1.0"), - ) - - tool = types.Tool(name="test_tool", input_schema={"type": "object"}) - text = types.TextContent(type="text", text="hello") - tool_use = types.ToolUseContent(type="tool_use", id="call_1", name="test_tool", input={}) - tool_result = types.ToolResultContent(type="tool_result", tool_use_id="call_1", content=[]) - - # Case 1: tool_result mixed with other content - with pytest.raises(ValueError, match="only tool_result content"): - await session.create_message( - messages=[ - types.SamplingMessage(role="user", content=text), - types.SamplingMessage(role="assistant", content=tool_use), - types.SamplingMessage(role="user", content=[tool_result, text]), # mixed! - ], - max_tokens=100, - tools=[tool], - ) - - # Case 2: tool_result without previous message - with pytest.raises(ValueError, match="requires a previous message"): - await session.create_message( - messages=[types.SamplingMessage(role="user", content=tool_result)], - max_tokens=100, - tools=[tool], - ) - - # Case 3: tool_result without previous tool_use - with pytest.raises(ValueError, match="do not match any tool_use"): - await session.create_message( - messages=[ - types.SamplingMessage(role="user", content=text), - types.SamplingMessage(role="user", content=tool_result), - ], - max_tokens=100, - tools=[tool], - ) - - # Case 4: mismatched tool IDs - with pytest.raises(ValueError, match="ids of tool_result blocks and tool_use blocks"): - await session.create_message( - messages=[ - types.SamplingMessage(role="user", content=text), - types.SamplingMessage(role="assistant", content=tool_use), - types.SamplingMessage( - role="user", - content=types.ToolResultContent(type="tool_result", tool_use_id="wrong_id", content=[]), - ), - ], - max_tokens=100, - tools=[tool], - ) - - # Case 5: text-only message with tools (no tool_results) - passes validation - # Covers has_tool_results=False branch. - # We use move_on_after because validation happens synchronously before - # send_request, which would block indefinitely waiting for a response. - # The timeout lets validation pass, then cancels the blocked send. - with anyio.move_on_after(0.01): - await session.create_message( - messages=[types.SamplingMessage(role="user", content=text)], - max_tokens=100, - tools=[tool], - ) - - # Case 6: valid matching tool_result/tool_use IDs - passes validation - # Covers tool_use_ids == tool_result_ids branch. - # (see Case 5 comment for move_on_after explanation) - with anyio.move_on_after(0.01): - await session.create_message( - messages=[ - types.SamplingMessage(role="user", content=text), - types.SamplingMessage(role="assistant", content=tool_use), - types.SamplingMessage(role="user", content=tool_result), - ], - max_tokens=100, - tools=[tool], - ) - - # Case 7: validation runs even without `tools` parameter - # (tool loop continuation may omit tools while containing tool_result) - with pytest.raises(ValueError, match="do not match any tool_use"): - await session.create_message( - messages=[ - types.SamplingMessage(role="user", content=text), - types.SamplingMessage(role="user", content=tool_result), - ], - max_tokens=100, - # Note: no tools parameter - ) - - # Case 8: empty messages list - skips validation entirely - # Covers the `if messages:` branch (line 280->302) - with anyio.move_on_after(0.01): # pragma: no branch - await session.create_message(messages=[], max_tokens=100) +async def test_send_request_omits_call_options_when_none_given(): + dispatcher = StubDispatcher(result={"roots": []}) + session = _make_session(dispatcher) + await session.send_request(types.ListRootsRequest(), types.ListRootsResult) + _method, _params, opts, related = dispatcher.requests[0] + assert opts is None + assert related is None @pytest.mark.anyio -async def test_create_message_without_tools_capability(): - """Test that create_message raises MCPError when tools are provided without capability.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) - - async with ( - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test", - server_version="0.1.0", - capabilities=ServerCapabilities(), - ), - ) as session: - # Set up client params WITHOUT sampling.tools capability - session._client_params = types.InitializeRequestParams( - protocol_version=types.LATEST_PROTOCOL_VERSION, - capabilities=types.ClientCapabilities(sampling=types.SamplingCapability()), - client_info=types.Implementation(name="test", version="1.0"), - ) - - tool = types.Tool(name="test_tool", input_schema={"type": "object"}) - text = types.TextContent(type="text", text="hello") - - # Should raise MCPError when tools are provided but client lacks capability - with pytest.raises(MCPError) as exc_info: - await session.create_message( - messages=[types.SamplingMessage(role="user", content=text)], - max_tokens=100, - tools=[tool], - ) - assert "does not support sampling tools capability" in exc_info.value.error.message - - # Should also raise MCPError when tool_choice is provided - with pytest.raises(MCPError) as exc_info: - await session.create_message( - messages=[types.SamplingMessage(role="user", content=text)], - max_tokens=100, - tool_choice=types.ToolChoice(mode="auto"), - ) - assert "does not support sampling tools capability" in exc_info.value.error.message +async def test_send_request_validates_result_alias_only(): + """Peer results validate alias-only; a snake_case key from the wire is + ignored as extra, not populated by Python field name.""" + snake = {"role": "assistant", "content": {"type": "text", "text": "x"}, "model": "m", "stop_reason": "endTurn"} + session = _make_session(StubDispatcher(result=snake)) + request = types.CreateMessageRequest(params=types.CreateMessageRequestParams(messages=[], max_tokens=1)) + result = await session.send_request(request, types.CreateMessageResult) + assert result.stop_reason is None @pytest.mark.anyio -async def test_other_requests_blocked_before_initialization(): - """Test that non-ping requests are still blocked before initialization.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) - - error_response_received = False - error_code = None - - async def run_server(): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="mcp", - server_version="0.1.0", - capabilities=ServerCapabilities(), - ), - ): - # Server should handle the request and send an error response - # No need to process incoming_messages since the error is handled automatically - await anyio.sleep(0.1) # Give time for the request to be processed - - async def mock_client(): - nonlocal error_response_received, error_code - - # Try to send a non-ping request before initialization - await client_to_server_send.send( - SessionMessage(types.JSONRPCRequest(jsonrpc="2.0", id=1, method="prompts/list")) - ) - - # Wait for the error response - error_message = await server_to_client_receive.receive() - if isinstance(error_message.message, types.JSONRPCError): # pragma: no branch - error_response_received = True - error_code = error_message.message.error.code +async def test_create_message_with_tools_returns_with_tools_result(): + dispatcher = StubDispatcher(result={"role": "assistant", "content": [{"type": "text", "text": "ok"}], "model": "m"}) + session = _make_session( + dispatcher, capabilities=ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability())) + ) + result = await session.create_message( + messages=[types.SamplingMessage(role="user", content=types.TextContent(type="text", text="hi"))], + max_tokens=10, + tools=[types.Tool(name="t", input_schema={"type": "object"})], + ) + assert isinstance(result, types.CreateMessageResultWithTools) + method, params, _opts, _related = dispatcher.requests[0] + assert method == "sampling/createMessage" + assert params is not None and params["tools"][0]["name"] == "t" - async with ( - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - anyio.create_task_group() as tg, - ): - tg.start_soon(run_server) - tg.start_soon(mock_client) - assert error_response_received - assert error_code == types.INVALID_PARAMS +def test_check_client_capability_delegates_to_connection(): + dispatcher = StubDispatcher() + session = _make_session(dispatcher, capabilities=ClientCapabilities(sampling=SamplingCapability())) + assert session.check_client_capability(ClientCapabilities(sampling=SamplingCapability())) is True + assert session.check_client_capability(ClientCapabilities(experimental={"x": {}})) is False diff --git a/tests/server/test_session_race_condition.py b/tests/server/test_session_race_condition.py deleted file mode 100644 index 81041152bc..0000000000 --- a/tests/server/test_session_race_condition.py +++ /dev/null @@ -1,132 +0,0 @@ -"""Test for race condition fix in initialization flow. - -This test verifies that requests can be processed immediately after -responding to InitializeRequest, without waiting for InitializedNotification. - -This is critical for HTTP transport where requests can arrive in any order. -""" - -import anyio -import pytest - -from mcp import types -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder -from mcp.types import ServerCapabilities, Tool - - -@pytest.mark.anyio -async def test_request_immediately_after_initialize_response(): - """Test that requests are accepted immediately after initialize response. - - This reproduces the race condition in stateful HTTP mode where: - 1. Client sends InitializeRequest - 2. Server responds with InitializeResult - 3. Client immediately sends tools/list (before server receives InitializedNotification) - 4. Without fix: Server rejects with "Received request before initialization was complete" - 5. With fix: Server accepts and processes the request - - This test simulates the HTTP transport behavior where InitializedNotification - may arrive in a separate POST request after other requests. - """ - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](10) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](10) - - tools_list_success = False - error_received = None - - async def run_server(): - nonlocal tools_list_success - - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="test-server", - server_version="1.0.0", - capabilities=ServerCapabilities( - tools=types.ToolsCapability(list_changed=False), - ), - ), - ) as server_session: - async for message in server_session.incoming_messages: # pragma: no branch - if isinstance(message, Exception): # pragma: no cover - raise message - - # Handle tools/list request - if isinstance(message, RequestResponder): - if isinstance(message.request, types.ListToolsRequest): # pragma: no branch - tools_list_success = True - # Respond with a tool list - with message: - await message.respond( - types.ListToolsResult( - tools=[ - Tool( - name="example_tool", - description="An example tool", - input_schema={"type": "object", "properties": {}}, - ) - ] - ) - ) - - # Handle InitializedNotification - if isinstance(message, types.ClientNotification): - if isinstance(message, types.InitializedNotification): # pragma: no branch - # Done - exit gracefully - return - - async def mock_client(): - nonlocal error_received - - # Step 1: Send InitializeRequest - await client_to_server_send.send( - SessionMessage( - types.JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="initialize", - params=types.InitializeRequestParams( - protocol_version=types.LATEST_PROTOCOL_VERSION, - capabilities=types.ClientCapabilities(), - client_info=types.Implementation(name="test-client", version="1.0.0"), - ).model_dump(by_alias=True, mode="json", exclude_none=True), - ) - ) - ) - - # Step 2: Wait for InitializeResult - init_msg = await server_to_client_receive.receive() - assert isinstance(init_msg.message, types.JSONRPCResponse) - - # Step 3: Immediately send tools/list BEFORE InitializedNotification - # This is the race condition scenario - await client_to_server_send.send(SessionMessage(types.JSONRPCRequest(jsonrpc="2.0", id=2, method="tools/list"))) - - # Step 4: Check the response - tools_msg = await server_to_client_receive.receive() - if isinstance(tools_msg.message, types.JSONRPCError): # pragma: no cover - error_received = tools_msg.message.error.message - - # Step 5: Send InitializedNotification - await client_to_server_send.send( - SessionMessage(types.JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")) - ) - - async with ( - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - anyio.create_task_group() as tg, - ): - tg.start_soon(run_server) - tg.start_soon(mock_client) - - # With the PR fix: tools_list_success should be True, error_received should be None - # Without the fix: error_received would contain "Received request before initialization was complete" - assert tools_list_success, f"tools/list should have succeeded. Error received: {error_received}" - assert error_received is None, f"Expected no error, but got: {error_received}" diff --git a/tests/server/test_stateless_mode.py b/tests/server/test_stateless_mode.py index 3bfc6e674c..1b628e2388 100644 --- a/tests/server/test_stateless_mode.py +++ b/tests/server/test_stateless_mode.py @@ -7,45 +7,35 @@ See: https://github.com/modelcontextprotocol/python-sdk/issues/1097 """ -from collections.abc import AsyncGenerator from typing import Any +from unittest.mock import Mock import anyio import pytest from mcp import types -from mcp.server.models import InitializationOptions +from mcp.server.connection import Connection +from mcp.server.context import ServerRequestContext +from mcp.server.lowlevel.server import Server from mcp.server.session import ServerSession -from mcp.shared.exceptions import StatelessModeNotSupported +from mcp.shared.exceptions import NoBackChannelError, StatelessModeNotSupported +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher from mcp.shared.message import SessionMessage -from mcp.types import ServerCapabilities +from mcp.types import JSONRPCRequest, JSONRPCResponse, ListToolsResult, PaginatedRequestParams -@pytest.fixture -async def stateless_session() -> AsyncGenerator[ServerSession, None]: - """Create a stateless ServerSession for testing.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) - - init_options = InitializationOptions( - server_name="test", - server_version="0.1.0", - capabilities=ServerCapabilities(), +def _make_session(*, stateless: bool) -> ServerSession: + """A `ServerSession` with a mock dispatcher; the stateless guard fires before any send.""" + return ServerSession( + Mock(spec=JSONRPCDispatcher), + Connection(Mock(), has_standalone_channel=False), + stateless=stateless, ) - async with ( - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - init_options, - stateless=True, - ) as session: - yield session + +@pytest.fixture +def stateless_session() -> ServerSession: + return _make_session(stateless=True) @pytest.mark.anyio @@ -126,30 +116,8 @@ async def test_exception_has_method_attribute(stateless_session: ServerSession): @pytest.fixture -async def stateful_session() -> AsyncGenerator[ServerSession, None]: - """Create a stateful ServerSession for testing.""" - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage | Exception](1) - - init_options = InitializationOptions( - server_name="test", - server_version="0.1.0", - capabilities=ServerCapabilities(), - ) - - async with ( - client_to_server_send, - client_to_server_receive, - server_to_client_send, - server_to_client_receive, - ): - async with ServerSession( - client_to_server_receive, - server_to_client_send, - init_options, - stateless=False, - ) as session: - yield session +def stateful_session() -> ServerSession: + return _make_session(stateless=False) @pytest.mark.anyio @@ -175,3 +143,45 @@ async def mock_send_request(*_: Any, **__: Any) -> types.ListRootsResult: assert send_request_called assert isinstance(result, types.ListRootsResult) + + +@pytest.mark.anyio +async def test_server_run_stateless_wires_no_standalone_channel(): + """`Server.run(stateless=True)` must wire `Connection.has_standalone_channel=False`. + + Stateless HTTP has no standalone GET stream, so server-initiated requests on + the connection must fail fast with `NoBackChannelError` rather than write to + a channel that will never deliver a response. The `ServerSession` typed + helpers carry their own stateless guard (tested above); this pins the + `Connection` wiring that `Server.run` produces. + """ + captured: list[Connection] = [] + + async def list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None) -> ListToolsResult: + # `ServerRequestContext` doesn't expose `connection` directly yet (it + # will after the Context rework); reach it via the session for now. + captured.append(ctx.session._connection) # pyright: ignore[reportPrivateUsage] + return ListToolsResult(tools=[]) + + server: Server[Any] = Server("test", on_list_tools=list_tools) + + to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10) + server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10) + + async def run_server() -> None: + await server.run(server_read, server_write, server.create_initialization_options(), stateless=True) + + async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server: + tg.start_soon(run_server) + # stateless=True skips the init gate, so tools/list routes immediately. + await to_server.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/list"))) + with anyio.fail_after(5): + response = (await from_server.receive()).message + assert isinstance(response, JSONRPCResponse) + tg.cancel_scope.cancel() + + assert len(captured) == 1 + conn = captured[0] + assert conn.has_standalone_channel is False + with pytest.raises(NoBackChannelError): + await conn.ping() diff --git a/tests/server/test_validation.py b/tests/server/test_validation.py index ad97dd3fd6..19f4eb1088 100644 --- a/tests/server/test_validation.py +++ b/tests/server/test_validation.py @@ -120,6 +120,16 @@ def test_validate_tool_use_result_messages_raises_when_tool_result_without_previ validate_tool_use_result_messages(messages) +def test_validate_tool_use_result_messages_raises_when_previous_message_has_no_tool_use() -> None: + """Raises when tool_result follows a message that has content but no tool_use.""" + messages = [ + SamplingMessage(role="assistant", content=TextContent(type="text", text="just text")), + SamplingMessage(role="user", content=ToolResultContent(type="tool_result", tool_use_id="tool-1")), + ] + with pytest.raises(ValueError, match="do not match any tool_use in the previous message"): + validate_tool_use_result_messages(messages) + + def test_validate_tool_use_result_messages_raises_when_tool_result_ids_dont_match_tool_use() -> None: """Raises when tool_result IDs don't match tool_use IDs.""" messages = [ diff --git a/tests/shared/conftest.py b/tests/shared/conftest.py new file mode 100644 index 0000000000..7b53b42654 --- /dev/null +++ b/tests/shared/conftest.py @@ -0,0 +1,61 @@ +"""Shared fixtures for `Dispatcher` contract tests. + +The `pair_factory` fixture parametrizes contract tests over every `Dispatcher` +implementation, so the same behavioral assertions run against `DirectDispatcher` +(in-memory) and `JSONRPCDispatcher` (over crossed anyio memory streams). +""" + +from collections.abc import Callable + +import anyio +import pytest + +from mcp.shared.direct_dispatcher import create_direct_dispatcher_pair +from mcp.shared.dispatcher import Dispatcher +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.message import SessionMessage +from mcp.shared.transport_context import TransportContext + +DispatcherTriple = tuple[Dispatcher[TransportContext], Dispatcher[TransportContext], Callable[[], None]] +PairFactory = Callable[..., DispatcherTriple] + + +def direct_pair(*, can_send_request: bool = True) -> DispatcherTriple: + client, server = create_direct_dispatcher_pair(can_send_request=can_send_request) + + def close() -> None: + client.close() + server.close() + + return client, server, close + + +def jsonrpc_pair(*, can_send_request: bool = True) -> DispatcherTriple: + """Two `JSONRPCDispatcher`s wired over crossed in-memory streams.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + + def builder(_meta: object) -> TransportContext: + return TransportContext(kind="jsonrpc", can_send_request=can_send_request) + + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send, transport_builder=builder) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send, transport_builder=builder) + + def close() -> None: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + return client, server, close + + +@pytest.fixture( + params=[ + pytest.param(direct_pair, id="direct"), + pytest.param(jsonrpc_pair, id="jsonrpc"), + ] +) +def pair_factory(request: pytest.FixtureRequest) -> PairFactory: + return request.param + + +__all__ = ["PairFactory", "direct_pair", "jsonrpc_pair"] diff --git a/tests/shared/test_context.py b/tests/shared/test_context.py new file mode 100644 index 0000000000..68057a9e10 --- /dev/null +++ b/tests/shared/test_context.py @@ -0,0 +1,133 @@ +"""Tests for `BaseContext`. + +`BaseContext` is composition over a `DispatchContext` - it forwards +`transport`/`cancel_requested`/`send_raw_request`/`notify`/`progress` +and adds `meta`. It must satisfy `Outbound` so `PeerMixin` works on it. +""" + +from collections.abc import Mapping +from typing import Any + +import anyio +import pytest + +from mcp.shared.context import BaseContext +from mcp.shared.dispatcher import DispatchContext +from mcp.shared.peer import Peer +from mcp.shared.transport_context import TransportContext + +from .conftest import direct_pair, jsonrpc_pair +from .test_dispatcher import Recorder, echo_handlers, running_pair + +DCtx = DispatchContext[TransportContext] + + +@pytest.mark.anyio +async def test_base_context_forwards_transport_and_cancel_requested(): + captured: list[BaseContext[TransportContext]] = [] + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx) + captured.append(bctx) + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("t", None) + bctx = captured[0] + assert bctx.transport.kind == "direct" + assert isinstance(bctx.cancel_requested, anyio.Event) + assert bctx.can_send_request is True + assert bctx.meta is None + + +@pytest.mark.anyio +async def test_base_context_can_send_request_reflects_dispatch_context_closed_state(): + """`can_send_request` must track the dctx, not the static transport flag, + so it agrees with whether `send_raw_request` would raise.""" + captured: list[BaseContext[TransportContext]] = [] + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + captured.append(BaseContext(ctx)) + return {} + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("t", None) + bctx = captured[0] + assert bctx.transport.can_send_request is True + assert bctx.can_send_request is False + + +@pytest.mark.anyio +async def test_base_context_send_raw_request_and_notify_forward_to_dispatch_context(): + crec = Recorder() + c_req, c_notify = echo_handlers(crec) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx) + sample = await bctx.send_raw_request("sampling/createMessage", {"x": 1}) + await bctx.notify("notifications/message", {"level": "info"}) + return {"sample": sample} + + async with running_pair( + direct_pair, + server_on_request=server_on_request, + client_on_request=c_req, + client_on_notify=c_notify, + ) as (client, *_): + with anyio.fail_after(5): + result = await client.send_raw_request("tools/call", None) + await crec.notified.wait() + assert crec.requests == [("sampling/createMessage", {"x": 1})] + assert crec.notifications == [("notifications/message", {"level": "info"})] + assert result["sample"] == {"echoed": "sampling/createMessage", "params": {"x": 1}} + + +@pytest.mark.anyio +async def test_base_context_report_progress_invokes_caller_on_progress(): + received: list[tuple[float, float | None, str | None]] = [] + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx) + await bctx.report_progress(0.5, total=1.0, message="halfway") + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("t", None, {"on_progress": on_progress}) + assert received == [(0.5, 1.0, "halfway")] + + +@pytest.mark.anyio +async def test_base_context_satisfies_outbound_so_peer_mixin_works(): + """Wrapping a BaseContext in Peer proves it satisfies Outbound structurally.""" + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx) + await Peer(bctx).ping() + return {} + + crec = Recorder() + c_req, c_notify = echo_handlers(crec) + async with running_pair( + direct_pair, server_on_request=server_on_request, client_on_request=c_req, client_on_notify=c_notify + ) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("t", None) + assert crec.requests == [("ping", None)] + + +@pytest.mark.anyio +async def test_base_context_meta_holds_supplied_request_params_meta(): + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + bctx = BaseContext(ctx, meta={"progressToken": "abc"}) + assert bctx.meta is not None and bctx.meta.get("progressToken") == "abc" + return {} + + async with running_pair(direct_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("t", None) diff --git a/tests/shared/test_dispatcher.py b/tests/shared/test_dispatcher.py new file mode 100644 index 0000000000..745f4b3875 --- /dev/null +++ b/tests/shared/test_dispatcher.py @@ -0,0 +1,299 @@ +"""Behavioral tests for the Dispatcher Protocol. + +The contract tests are parametrized over every `Dispatcher` implementation via +the `pair_factory` fixture (see `conftest.py`); they must pass for both +`DirectDispatcher` and `JSONRPCDispatcher`. Implementation-specific tests pass +a concrete factory directly. +""" + +from collections.abc import AsyncIterator, Mapping +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING, Any + +import anyio +import pytest + +from mcp.shared.direct_dispatcher import DirectDispatcher, create_direct_dispatcher_pair +from mcp.shared.dispatcher import DispatchContext, Dispatcher, OnNotify, OnRequest, Outbound +from mcp.shared.exceptions import MCPError +from mcp.shared.transport_context import TransportContext +from mcp.types import INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST, REQUEST_TIMEOUT, ErrorData, Tool + +from .conftest import PairFactory, direct_pair + + +class Recorder: + def __init__(self) -> None: + self.requests: list[tuple[str, Mapping[str, Any] | None]] = [] + self.notifications: list[tuple[str, Mapping[str, Any] | None]] = [] + self.contexts: list[DispatchContext[TransportContext]] = [] + self.notified = anyio.Event() + + +def echo_handlers(recorder: Recorder) -> tuple[OnRequest, OnNotify]: + async def on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + # Strip `_meta` so JSON-RPC and direct dispatch record identically: + # the JSON-RPC outbound path always attaches `_meta` (otel injection). + recorded = {k: v for k, v in (params or {}).items() if k != "_meta"} if params is not None else None + recorder.requests.append((method, recorded)) + recorder.contexts.append(ctx) + return {"echoed": method, "params": recorded or {}} + + async def on_notify(ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None) -> None: + recorder.notifications.append((method, params)) + recorder.notified.set() + + return on_request, on_notify + + +@asynccontextmanager +async def running_pair( + factory: PairFactory, + *, + server_on_request: OnRequest | None = None, + server_on_notify: OnNotify | None = None, + client_on_request: OnRequest | None = None, + client_on_notify: OnNotify | None = None, + can_send_request: bool = True, +) -> AsyncIterator[tuple[Dispatcher[TransportContext], Dispatcher[TransportContext], Recorder, Recorder]]: + """Yield `(client, server, client_recorder, server_recorder)` with both `run()` loops live.""" + client, server, close = factory(can_send_request=can_send_request) + client_rec, server_rec = Recorder(), Recorder() + c_req, c_notify = echo_handlers(client_rec) + s_req, s_notify = echo_handlers(server_rec) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, client_on_request or c_req, client_on_notify or c_notify) + await tg.start(server.run, server_on_request or s_req, server_on_notify or s_notify) + try: + yield client, server, client_rec, server_rec + finally: + tg.cancel_scope.cancel() + finally: + close() + + +@pytest.mark.anyio +async def test_send_raw_request_returns_result_from_peer_on_request(pair_factory: PairFactory): + async with running_pair(pair_factory) as (client, _server, _crec, srec): + with anyio.fail_after(5): + result = await client.send_raw_request("tools/list", {"cursor": "abc"}) + assert result == {"echoed": "tools/list", "params": {"cursor": "abc"}} + assert srec.requests == [("tools/list", {"cursor": "abc"})] + + +@pytest.mark.anyio +async def test_send_raw_request_reraises_mcperror_from_handler_unchanged(pair_factory: PairFactory): + async def on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + raise MCPError(code=INVALID_PARAMS, message="bad cursor") + + async with running_pair(pair_factory, server_on_request=on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", {}) + assert exc.value.error.code == INVALID_PARAMS + assert exc.value.error.message == "bad cursor" + + +@pytest.mark.anyio +async def test_send_raw_request_maps_validation_error_to_invalid_params(pair_factory: PairFactory): + """A pydantic `ValidationError` from the handler surfaces as the + normalized INVALID_PARAMS shape on every dispatcher.""" + + async def on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + Tool.model_validate({"name": 123}) # raises ValidationError + raise NotImplementedError + + async with running_pair(pair_factory, server_on_request=on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error == ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") + + +@pytest.mark.anyio +async def test_send_raw_request_with_timeout_raises_mcperror_request_timeout(pair_factory: PairFactory): + async def on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await anyio.sleep_forever() + raise NotImplementedError + + async with running_pair(pair_factory, server_on_request=on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("slow", None, {"timeout": 0}) + assert exc.value.error.code == REQUEST_TIMEOUT + + +@pytest.mark.anyio +async def test_notify_invokes_peer_on_notify(pair_factory: PairFactory): + async with running_pair(pair_factory) as (client, _server, _crec, srec): + with anyio.fail_after(5): + await client.notify("notifications/initialized", {"v": 1}) + await srec.notified.wait() + assert srec.notifications == [("notifications/initialized", {"v": 1})] + + +@pytest.mark.anyio +async def test_ctx_send_raw_request_round_trips_to_calling_side(pair_factory: PairFactory): + """A handler's ctx.send_raw_request reaches the side that made the inbound request.""" + + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + sample = await ctx.send_raw_request("sampling/createMessage", {"prompt": "hi"}) + return {"sampled": sample} + + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, _server, crec, _srec): + with anyio.fail_after(5): + result = await client.send_raw_request("tools/call", None) + assert crec.requests == [("sampling/createMessage", {"prompt": "hi"})] + assert result == {"sampled": {"echoed": "sampling/createMessage", "params": {"prompt": "hi"}}} + + +@pytest.mark.anyio +async def test_ctx_send_raw_request_raises_nobackchannelerror_when_transport_disallows(pair_factory: PairFactory): + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + return await ctx.send_raw_request("sampling/createMessage", None) + + async with running_pair(pair_factory, server_on_request=server_on_request, can_send_request=False) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/call", None) + assert exc.value.error.code == INVALID_REQUEST + + +@pytest.mark.anyio +async def test_ctx_notify_invokes_calling_side_on_notify(pair_factory: PairFactory): + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.notify("notifications/message", {"level": "info"}) + return {} + + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, _server, crec, _srec): + with anyio.fail_after(5): + await client.send_raw_request("tools/call", None) + await crec.notified.wait() + assert crec.notifications == [("notifications/message", {"level": "info"})] + + +@pytest.mark.anyio +async def test_ctx_progress_invokes_caller_on_progress_callback(pair_factory: PairFactory): + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.progress(0.5, total=1.0, message="halfway") + return {} + + received: list[tuple[float, float | None, str | None]] = [] + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("tools/call", None, {"on_progress": on_progress}) + assert received == [(0.5, 1.0, "halfway")] + + +@pytest.mark.anyio +async def test_ctx_progress_is_noop_when_caller_supplied_no_callback(pair_factory: PairFactory): + async def server_on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + await ctx.progress(0.5) + return {"ok": True} + + async with running_pair(pair_factory, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + result = await client.send_raw_request("tools/call", None) + assert result == {"ok": True} + + +@pytest.mark.anyio +async def test_ctx_message_metadata_is_none_when_transport_attaches_nothing(pair_factory: PairFactory): + """Plain requests carry no transport metadata, so handlers see `None`.""" + async with running_pair(pair_factory) as (client, _server, _crec, srec): + with anyio.fail_after(5): + await client.send_raw_request("tools/call", None) + assert len(srec.contexts) == 1 + assert srec.contexts[0].message_metadata is None + + +@pytest.mark.anyio +async def test_ctx_request_id_exposes_inbound_id(pair_factory: PairFactory): + """JSON-RPC carries the wire id through; direct dispatch has none.""" + async with running_pair(pair_factory) as (client, _server, _crec, srec): + with anyio.fail_after(5): + await client.send_raw_request("tools/call", None) + await client.send_raw_request("tools/call", None) + a, b = (ctx.request_id for ctx in srec.contexts) + assert (a is None and b is None) or (isinstance(a, int) and isinstance(b, int) and a != b) + + +@pytest.mark.anyio +async def test_direct_send_raw_request_wraps_non_mcperror_exception_as_internal_error_with_cause(): + """DirectDispatcher-specific: the original exception is chained via __cause__.""" + + async def on_request( + ctx: DispatchContext[TransportContext], method: str, params: Mapping[str, Any] | None + ) -> dict[str, Any]: + raise ValueError("oops") + + async with running_pair(direct_pair, server_on_request=on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", {}) + assert exc.value.error.code == INTERNAL_ERROR + assert isinstance(exc.value.__cause__, ValueError) + + +@pytest.mark.anyio +async def test_direct_send_raw_request_issued_before_peer_run_blocks_until_peer_ready(): + client, server = create_direct_dispatcher_pair() + s_req, s_notify = echo_handlers(Recorder()) + c_req, c_notify = echo_handlers(Recorder()) + + async def late_start(): + await anyio.sleep(0) + await server.run(s_req, s_notify) + + async with anyio.create_task_group() as tg: + tg.start_soon(client.run, c_req, c_notify) + tg.start_soon(late_start) + with anyio.fail_after(5): + result = await client.send_raw_request("ping", None) + assert result == {"echoed": "ping", "params": {}} + client.close() + server.close() + + +@pytest.mark.anyio +async def test_direct_send_raw_request_and_notify_raise_runtimeerror_when_no_peer_connected(): + d = DirectDispatcher(TransportContext(kind="direct", can_send_request=True)) + with pytest.raises(RuntimeError, match="no peer"): + await d.send_raw_request("ping", None) + with pytest.raises(RuntimeError, match="no peer"): + await d.notify("ping", None) + + +@pytest.mark.anyio +async def test_direct_close_makes_run_return(): + client, server = create_direct_dispatcher_pair() + on_request, on_notify = echo_handlers(Recorder()) + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: + tg.start_soon(server.run, on_request, on_notify) + tg.start_soon(client.run, on_request, on_notify) + client.close() + server.close() + + +if TYPE_CHECKING: + _d: Dispatcher[TransportContext] = DirectDispatcher(TransportContext(kind="direct", can_send_request=True)) + _o: Outbound = _d diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py new file mode 100644 index 0000000000..da8f8272f8 --- /dev/null +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -0,0 +1,933 @@ +"""JSON-RPC-specific Dispatcher tests. + +Behaviors with no `DirectDispatcher` analog: request-id correlation, the +exception-to-wire boundary, peer-cancel handling, and shutdown fan-out. +The contract tests shared with `DirectDispatcher` live in +`test_dispatcher.py`. +""" + +import contextvars +from collections.abc import Mapping +from typing import Any + +import anyio +import anyio.lowlevel +import pytest + +from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream +from mcp.shared.dispatcher import CallOptions, DispatchContext +from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.jsonrpc_dispatcher import ( # pyright: ignore[reportPrivateUsage] + JSONRPCDispatcher, + _coerce_id, + _outbound_metadata, + _Pending, +) +from mcp.shared.message import ClientMessageMetadata, MessageMetadata, ServerMessageMetadata, SessionMessage +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + CONNECTION_CLOSED, + INTERNAL_ERROR, + INVALID_PARAMS, + ErrorData, + JSONRPCError, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + Tool, +) + +from .conftest import jsonrpc_pair +from .test_dispatcher import Recorder, echo_handlers, running_pair + +DCtx = DispatchContext[TransportContext] + + +@pytest.mark.anyio +async def test_concurrent_send_raw_requests_correlate_by_id_when_responses_arrive_out_of_order(): + release_first = anyio.Event() + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + if method == "first": + await release_first.wait() + return {"m": method} + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + results: dict[str, dict[str, Any]] = {} + + async def call(method: str) -> None: + results[method] = await client.send_raw_request(method, None) + + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + tg.start_soon(call, "first") + await anyio.sleep(0) + tg.start_soon(call, "second") + await anyio.sleep(0) + # second resolves while first is still parked + assert "first" not in results + release_first.set() + assert results == {"first": {"m": "first"}, "second": {"m": "second"}} + + +@pytest.mark.anyio +async def test_handler_raising_exception_sends_code_zero_with_str_message(): + """Matches the existing server's `_handle_request`: code=0, message=str(e).""" + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + raise RuntimeError("kaboom") + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("tools/list", None) + assert exc.value.error.code == 0 + assert exc.value.error.message == "kaboom" + assert exc.value.__cause__ is None # cause does not survive the wire + + +@pytest.mark.anyio +async def test_peer_cancel_interrupt_mode_writes_cancelled_error_response(): + """Matches the existing server: a peer-cancelled request is answered with code=0.""" + handler_started = anyio.Event() + handler_exited = anyio.Event() + seen_ctx: list[DCtx] = [] + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + seen_ctx.append(ctx) + handler_started.set() + try: + await anyio.sleep_forever() + finally: + handler_exited.set() + raise NotImplementedError + + seen_error: list[ErrorData] = [] + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + + async def call_then_record() -> None: + with pytest.raises(MCPError) as exc: + await client.send_raw_request("slow", None) + seen_error.append(exc.value.error) + + tg.start_soon(call_then_record) + await handler_started.wait() + await client.notify("notifications/cancelled", {"requestId": 1}) + await handler_exited.wait() + assert seen_ctx[0].cancel_requested.is_set() + assert seen_error == [ErrorData(code=0, message="Request cancelled")] + + +@pytest.mark.anyio +async def test_peer_cancel_signal_mode_sets_event_but_handler_runs_to_completion(): + handler_started = anyio.Event() + cancel_seen = anyio.Event() + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + handler_started.set() + await ctx.cancel_requested.wait() + cancel_seen.set() + return {"finished": True} + + def factory(*, can_send_request: bool = True): + client, server, close = jsonrpc_pair(can_send_request=can_send_request) + # Reach in to set signal mode on the server side. + assert isinstance(server, JSONRPCDispatcher) + server._peer_cancel_mode = "signal" # pyright: ignore[reportPrivateUsage] + return client, server, close + + result_box: list[dict[str, Any]] = [] + async with running_pair(factory, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + async with anyio.create_task_group() as tg: # pragma: no branch + + async def call() -> None: + result_box.append(await client.send_raw_request("slow", None)) + + tg.start_soon(call) + await handler_started.wait() + await client.notify("notifications/cancelled", {"requestId": 1}) + await cancel_seen.wait() + assert result_box == [{"finished": True}] + + +@pytest.mark.anyio +async def test_send_raw_request_raises_connection_closed_when_read_stream_eofs_mid_await(): + """A blocked send_raw_request is woken with CONNECTION_CLOSED when run() exits.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + + async def caller() -> None: + with pytest.raises(MCPError) as exc: + await client.send_raw_request("ping", None) + assert exc.value.error.code == CONNECTION_CLOSED + + tg.start_soon(caller) + await anyio.sleep(0) + # No server: simulate the peer dropping by closing the read side. + s2c_send.close() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_run_returns_cleanly_when_read_stream_receive_end_is_closed(): + """Iterating a closed receive end raises ClosedResourceError; run() treats it as EOF. + + Stateless SHTTP teardown closes the dispatcher's receive end after the + request is handled; the next loop iteration must not surface as a crash. + """ + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + on_request, on_notify = echo_handlers(Recorder()) + # Close the dispatcher's own receive end (not the send end) before run() + # iterates it: __anext__ on a closed stream raises ClosedResourceError. + c2s_recv.close() + with anyio.fail_after(5): + await server.run(on_request, on_notify) + for s in (c2s_send, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_run_cancels_in_flight_handlers_when_read_stream_eofs(): + """A handler that outlives its caller must not keep run() from returning. + + Without the cancel-at-EOF, the task-group join would wait on this handler + forever (over SSE that leaks the handler task and the GET request hosting + the session). + """ + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + handler_started = anyio.Event() + handler_cancelled = anyio.Event() + + async def park(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + handler_started.set() + try: + await anyio.sleep_forever() + finally: + handler_cancelled.set() + raise NotImplementedError + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + run_returned = anyio.Event() + + async def drive() -> None: + await server.run(park, on_notify) + run_returned.set() + + async with anyio.create_task_group() as tg: + tg.start_soon(drive) + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="x", params=None))) + with anyio.fail_after(5): + await handler_started.wait() + c2s_send.close() # EOF the read side; run() must cancel the parked handler + await run_returned.wait() + assert handler_cancelled.is_set() + s2c_recv.close() + + +@pytest.mark.anyio +async def test_run_closes_write_stream_on_exit(): + """run() enters both streams; the write end is released on EOF.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + on_request, on_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + c2s_send.close() # EOF the read side; run() exits + with anyio.fail_after(5), pytest.raises(anyio.EndOfStream): # pragma: no branch + # Write end was entered and released by run(); peer's receive sees EOF. + await s2c_recv.receive() + s2c_recv.close() + + +@pytest.mark.anyio +async def test_late_response_after_timeout_is_dropped_without_crashing(): + handler_started = anyio.Event() + proceed = anyio.Event() + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + handler_started.set() + await proceed.wait() + return {"late": True} + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + with pytest.raises(MCPError): # REQUEST_TIMEOUT + await client.send_raw_request("slow", None, {"timeout": 0}) + # The server handler is still running; let it finish and write a + # response for an id the client has already discarded. + await handler_started.wait() + proceed.set() + # One more round-trip proves the dispatcher is still healthy. + assert await client.send_raw_request("ping", None) == {"late": True} + + +@pytest.mark.anyio +async def test_raise_handler_exceptions_true_propagates_out_of_run(): + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + + def builder(_meta: object) -> TransportContext: + return TransportContext(kind="jsonrpc", can_send_request=True) + + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher( + c2s_recv, s2c_send, transport_builder=builder, raise_handler_exceptions=True + ) + + async def boom(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + raise RuntimeError("propagate me") + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + with pytest.raises(BaseException) as exc: + async with anyio.create_task_group() as tg: + await tg.start(server.run, boom, on_notify) + # Inject a request directly onto the server's read stream. + await c2s_send.send( + SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="x", params=None)) + ) + assert exc.group_contains(RuntimeError, match="propagate me") + # The error response was still written before re-raising. + sent = s2c_recv.receive_nowait() + assert isinstance(sent, SessionMessage) + assert isinstance(sent.message, JSONRPCError) + assert sent.message.error.code == 0 + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_ctx_send_raw_request_tags_outbound_with_server_message_metadata(): + """Server-to-client requests carry related_request_id for SHTTP routing.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + return await ctx.send_raw_request("sampling/createMessage", {"prompt": "hi"}) + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, server_on_request, on_notify) + # Kick the server with an inbound request id=7. + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="t", params=None))) + with anyio.fail_after(5): + outbound = await s2c_recv.receive() + assert isinstance(outbound, SessionMessage) + assert isinstance(outbound.message, JSONRPCRequest) + assert isinstance(outbound.metadata, ServerMessageMetadata) + assert outbound.metadata.related_request_id == 7 + # Reply so the handler completes cleanly. + await c2s_send.send( + SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=outbound.message.id, result={"ok": True})) + ) + with anyio.fail_after(5): + final = await s2c_recv.receive() + assert isinstance(final, SessionMessage) + assert isinstance(final.message, JSONRPCResponse) + assert final.message.id == 7 + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_courtesy_cancel_on_timeout_tags_outbound_with_server_message_metadata(): + """The timeout-path `notifications/cancelled` carries the originating request id. + + Streamable-HTTP's `message_router` keys on `ServerMessageMetadata.related_request_id`; + a cancel without it would fall through to the standalone GET stream and be dropped + when no GET stream is open, so the client never learns to stop work. + """ + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + with pytest.raises(MCPError): # REQUEST_TIMEOUT + await ctx.send_raw_request("sampling/createMessage", None, {"timeout": 0}) + return {"gave_up": True} + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, server_on_request, on_notify) + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=7, method="t", params=None))) + with anyio.fail_after(5): + outbound = await s2c_recv.receive() + assert isinstance(outbound, SessionMessage) + assert isinstance(outbound.message, JSONRPCRequest) + assert outbound.message.method == "sampling/createMessage" + sampling_id = outbound.message.id + # Don't respond; let the timeout fire. Next on the wire is the courtesy cancel. + with anyio.fail_after(5): + cancel = await s2c_recv.receive() + assert isinstance(cancel, SessionMessage) + assert isinstance(cancel.message, JSONRPCNotification) + assert cancel.message.method == "notifications/cancelled" + assert cancel.message.params == {"requestId": sampling_id, "reason": "timed out after 0s"} + assert isinstance(cancel.metadata, ServerMessageMetadata) + assert cancel.metadata.related_request_id == 7 + with anyio.fail_after(5): + final = await s2c_recv.receive() + assert isinstance(final, SessionMessage) + assert isinstance(final.message, JSONRPCResponse) + assert final.message.result == {"gave_up": True} + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_ctx_message_metadata_carries_inbound_request_metadata(): + """Transport-attached metadata (HTTP request, SSE close hooks) is readable off the dispatch context.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + metadata = ServerMessageMetadata(request_context="request-scoped-data") + seen: list[MessageMetadata] = [] + + async def on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + seen.append(ctx.message_metadata) + return {} + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + await c2s_send.send( + SessionMessage( + message=JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/call", params=None), + metadata=metadata, + ) + ) + with anyio.fail_after(5): + await s2c_recv.receive() # response sent => the handler has run + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + assert len(seen) == 1 + assert seen[0] is metadata # the exact object, passed through verbatim + + +@pytest.mark.anyio +async def test_ctx_message_metadata_carries_inbound_notification_metadata(): + """Notifications get the same metadata pass-through as requests.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + metadata = ServerMessageMetadata(request_context="request-scoped-data") + seen: list[MessageMetadata] = [] + notified = anyio.Event() + + async def on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + raise NotImplementedError + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + seen.append(ctx.message_metadata) + notified.set() + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + await c2s_send.send( + SessionMessage( + message=JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized", params=None), + metadata=metadata, + ) + ) + with anyio.fail_after(5): + await notified.wait() + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + assert len(seen) == 1 + assert seen[0] is metadata + + +@pytest.mark.anyio +async def test_ctx_progress_with_only_progress_value_omits_total_and_message(): + received: list[tuple[float, float | None, str | None]] = [] + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + received.append((progress, total, message)) + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + await ctx.progress(0.25) + return {} + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("t", None, {"on_progress": on_progress}) + assert received == [(0.25, None, None)] + + +@pytest.mark.anyio +async def test_ctx_after_handler_return_reports_closed_and_drops_backchannel_traffic(): + """Once `_handle_request` closes the dctx, the back-channel guard and ops agree. + + Detached work that outlives the handler must see `can_send_request == False`, + get `NoBackChannelError` from `send_raw_request`, and have `notify`/`progress` + silently dropped rather than emitted with a stale `related_request_id`. + """ + captured: list[DCtx] = [] + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + captured.append(ctx) + assert ctx.can_send_request is True + return {} + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + raise NotImplementedError + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, _server, crec, _srec): + with anyio.fail_after(5): + await client.send_raw_request("tools/call", None, {"on_progress": on_progress}) + dctx = captured[0] + assert dctx.can_send_request is False + with pytest.raises(NoBackChannelError): + await dctx.send_raw_request("sampling/createMessage", None) + await dctx.notify("notifications/message", {"level": "info"}) + await dctx.progress(0.9) + # A second round-trip flushes any notification the server might have + # written, so an empty client recorder afterwards proves the drop. + await client.send_raw_request("ping", None) + assert crec.notifications == [] + + +@pytest.mark.anyio +async def test_progress_callback_exception_is_swallowed_and_logged(caplog: pytest.LogCaptureFixture): + """A user progress callback raising must not crash the dispatcher.""" + + async def boom(progress: float, total: float | None, message: str | None) -> None: + raise RuntimeError("progress callback boom") + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + await ctx.progress(0.5) + return {"ok": True} + + opts: CallOptions = {"on_progress": boom} + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + result = await client.send_raw_request("t", None, opts) + # Request still completes; the callback's crash was swallowed. + assert result == {"ok": True} + assert "progress callback raised" in caplog.text + + +@pytest.mark.anyio +async def test_inline_methods_are_handled_before_next_message_is_dequeued(): + """A method in `inline_methods` runs to completion before subsequent + messages are dispatched, so its side effects are visible to them.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher( + c2s_recv, s2c_send, inline_methods=frozenset({"first"}) + ) + state = {"initialized": False} + seen: list[bool] = [] + + async def on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + if method == "first": + await anyio.lowlevel.checkpoint() + state["initialized"] = True + else: + seen.append(state["initialized"]) + return {} + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + # Buffer both requests before run() reads anything. + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="first", params=None))) + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=2, method="second", params=None))) + c2s_send.close() + with anyio.fail_after(5): + await server.run(on_request, on_notify) + assert seen == [True] + s2c_recv.close() + + +@pytest.mark.anyio +async def test_send_raw_request_always_carries_meta_on_the_wire(): + """Outbound requests always include `params._meta` (otel injection per SEP-414). + + Caller-supplied `_meta` keys are preserved; the progress token is merged in. + """ + seen: list[Mapping[str, Any] | None] = [] + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + seen.append(params) + return {} + + async def noop_progress(progress: float, total: float | None, message: str | None) -> None: + raise NotImplementedError + + opts: CallOptions = {"on_progress": noop_progress} + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5): + await client.send_raw_request("a", None) + await client.send_raw_request("b", {"x": 1, "_meta": {"k": "v"}}, opts) + # `_meta` is always present. Its contents depend on the active otel + # tracer (traceparent/tracestate may be injected), so assert presence + # and that anything beyond W3C keys is exactly what we expect. + w3c = {"traceparent", "tracestate"} + assert seen[0] is not None and seen[0].keys() == {"_meta"} + assert set(seen[0]["_meta"].keys()) <= w3c + assert seen[1] is not None and seen[1]["x"] == 1 + assert set(seen[1]["_meta"].keys()) - w3c == {"k", "progressToken"} + assert seen[1]["_meta"]["k"] == "v" + + +@pytest.mark.anyio +async def test_handler_raising_validation_error_sends_invalid_params(): + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + Tool.model_validate({"name": 123}) # raises ValidationError + raise NotImplementedError + + async with running_pair(jsonrpc_pair, server_on_request=server_on_request) as (client, *_): + with anyio.fail_after(5), pytest.raises(MCPError) as exc: + await client.send_raw_request("t", None) + assert exc.value.error == ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") + + +@pytest.mark.anyio +async def test_send_raw_request_before_run_raises_runtimeerror(): + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + try: + with pytest.raises(RuntimeError, match="before run"): + await d.send_raw_request("ping", None) + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_transport_exception_in_read_stream_is_logged_and_dropped(): + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, on_request, on_notify) + await c2s_send.send(ValueError("transport hiccup")) + # Dispatcher must remain healthy after the dropped exception. + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None))) + with anyio.fail_after(5): + resp = await s2c_recv.receive() + assert isinstance(resp, SessionMessage) + assert isinstance(resp.message, JSONRPCResponse) + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_progress_notification_for_unknown_token_falls_through_to_on_notify(): + async with running_pair(jsonrpc_pair) as (client, _server, _crec, srec): + with anyio.fail_after(5): + await client.notify("notifications/progress", {"progressToken": 999, "progress": 0.5}) + await srec.notified.wait() + assert srec.notifications == [("notifications/progress", {"progressToken": 999, "progress": 0.5})] + + +@pytest.mark.anyio +async def test_cancelled_notification_for_unknown_request_id_is_noop(): + async with running_pair(jsonrpc_pair) as (client, _server, _crec, srec): + with anyio.fail_after(5): + await client.notify("notifications/cancelled", {"requestId": 999}) + # No effect; dispatcher remains healthy. + assert await client.send_raw_request("t", None) == {"echoed": "t", "params": {}} + assert srec.notifications == [] # cancelled is fully consumed, never teed + + +_probe: contextvars.ContextVar[str] = contextvars.ContextVar("probe", default="unset") + + +@pytest.mark.anyio +@pytest.mark.parametrize("inline", [frozenset[str](), frozenset({"t"})], ids=["spawned", "inline"]) +async def test_handler_inherits_sender_contextvars(inline: frozenset[str]): + """The handler task sees contextvars set by the task that wrote into the + read stream, on both the spawned and the inline-method dispatch paths.""" + raw_send, raw_recv = anyio.create_memory_object_stream[tuple[contextvars.Context, SessionMessage | Exception]](4) + read_stream = ContextReceiveStream[SessionMessage | Exception](raw_recv) + write_send = ContextSendStream[SessionMessage | Exception](raw_send) + out_send, out_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(read_stream, out_send, inline_methods=inline) + + seen: list[str] = [] + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + seen.append(_probe.get()) + return {} + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, server_on_request, on_notify) + + async def sender() -> None: + _probe.set("from-sender") + await write_send.send( + SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="t", params=None)) + ) + + tg.start_soon(sender) + with anyio.fail_after(5): + resp = await out_recv.receive() + assert isinstance(resp, SessionMessage) + tg.cancel_scope.cancel() + finally: + for s in (raw_send, raw_recv, out_send, out_recv): + s.close() + assert seen == ["from-sender"] + + +@pytest.mark.anyio +async def test_response_write_after_peer_drop_is_swallowed(): + """Handler completes after the write stream is closed; the dropped write doesn't crash run().""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + server: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(c2s_recv, s2c_send) + proceed = anyio.Event() + handlers_done = anyio.Event() + + async def server_on_request(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + await proceed.wait() + if method == "raise": + handlers_done.set() + raise MCPError(code=INTERNAL_ERROR, message="x") + return {"ok": True} + + async def on_notify(ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + try: + async with anyio.create_task_group() as tg: + await tg.start(server.run, server_on_request, on_notify) + await c2s_send.send(SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=1, method="ok", params=None))) + await c2s_send.send( + SessionMessage(message=JSONRPCRequest(jsonrpc="2.0", id=2, method="raise", params=None)) + ) + await anyio.sleep(0) + # Peer drops: close the receive end so the server's writes hit BrokenResourceError. + s2c_recv.close() + proceed.set() + with anyio.fail_after(5): + await handlers_done.wait() + # run() must still be healthy - close the read side to let it exit cleanly. + c2s_send.close() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_cancel_outbound_after_write_stream_closed_is_swallowed(): + """Courtesy-cancel write hits a closed stream; the error is swallowed and cancellation propagates.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](4) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + caller_done = anyio.Event() + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + caller_scope = anyio.CancelScope() + + async def caller() -> None: + with caller_scope: + await client.send_raw_request("slow", None) + caller_done.set() + + tg.start_soon(caller) + # Deterministic proof the request write completed: pull it off the wire. + with anyio.fail_after(5): + sent = await c2s_recv.receive() + assert isinstance(sent, SessionMessage) + assert isinstance(sent.message, JSONRPCRequest) + # Now safe: close the client's write end, then cancel the caller. The + # shielded `_cancel_outbound` write hits ClosedResourceError and is + # swallowed; cancellation propagates cleanly. + c2s_send.close() + caller_scope.cancel() + with anyio.fail_after(5): + await caller_done.wait() + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +def test_resolve_pending_drops_outcome_when_waiter_stream_already_closed(): + """White-box: a response for an id still in _pending but whose waiter has gone.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + send, recv = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) + d._pending[1] = _Pending(send=send, receive=recv) # pyright: ignore[reportPrivateUsage] + recv.close() # waiter gone - send_nowait will raise BrokenResourceError + d._resolve_pending(1, {"late": True}) # pyright: ignore[reportPrivateUsage] + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv, send): + s.close() + + +def test_fan_out_closed_drops_signal_when_waiter_already_has_outcome(): + """White-box: the buffer=1 invariant - WouldBlock means waiter already has an outcome.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](1) + d: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + send, recv = anyio.create_memory_object_stream[dict[str, Any] | ErrorData](1) + # Register a fake pending and pre-fill its single buffer slot. + d._pending[1] = _Pending(send=send, receive=recv) # pyright: ignore[reportPrivateUsage] + send.send_nowait({"real": "result"}) + d._fan_out_closed() # pyright: ignore[reportPrivateUsage] + # The real result is still there; the close signal was dropped. + assert recv.receive_nowait() == {"real": "result"} + assert d._pending == {} # pyright: ignore[reportPrivateUsage] + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv, send, recv): + s.close() + + +def test_outbound_metadata_with_resumption_token_returns_client_metadata(): + md = _outbound_metadata(None, {"resumption_token": "abc"}) + assert isinstance(md, ClientMessageMetadata) + assert md.resumption_token == "abc" + assert _outbound_metadata(None, None) is None + assert _outbound_metadata(None, {}) is None + + +@pytest.mark.anyio +async def test_response_with_string_id_correlates_to_int_keyed_pending_request(): + """A peer that echoes the request ID as a JSON string still resolves the waiter.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + with anyio.fail_after(5): + + async def respond_stringly() -> None: + out = await c2s_recv.receive() + assert isinstance(out, SessionMessage) + assert isinstance(out.message, JSONRPCRequest) + rid = out.message.id + assert isinstance(rid, int) + await s2c_send.send( + SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=str(rid), result={"ok": True})) + ) + + tg.start_soon(respond_stringly) + result = await client.send_raw_request("ping", None) + assert result == {"ok": True} + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + + +@pytest.mark.anyio +async def test_progress_with_string_token_reaches_callback_for_int_keyed_request(): + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + seen: list[float] = [] + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + with anyio.fail_after(5): + + async def respond_with_string_token_progress() -> None: + out = await c2s_recv.receive() + assert isinstance(out, SessionMessage) + assert isinstance(out.message, JSONRPCRequest) + rid = out.message.id + assert isinstance(rid, int) + await s2c_send.send( + SessionMessage( + message=JSONRPCNotification( + jsonrpc="2.0", + method="notifications/progress", + params={"progressToken": str(rid), "progress": 0.5}, + ) + ) + ) + await s2c_send.send( + SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=rid, result={"ok": True})) + ) + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + seen.append(progress) + + tg.start_soon(respond_with_string_token_progress) + result = await client.send_raw_request("ping", None, {"on_progress": on_progress}) + assert result == {"ok": True} + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() + assert seen == [0.5] + + +def test_coerce_id_passes_through_non_numeric_string_and_int(): + assert _coerce_id("7") == 7 + assert _coerce_id("not-an-int") == "not-an-int" + assert _coerce_id(42) == 42 + + +@pytest.mark.anyio +async def test_jsonrpc_error_response_with_null_id_is_dropped(): + """Parse-error responses (id=null) have no waiter; they're logged and dropped.""" + c2s_send, c2s_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + s2c_send, s2c_recv = anyio.create_memory_object_stream[SessionMessage | Exception](32) + client: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher(s2c_recv, c2s_send) + on_request, on_notify = echo_handlers(Recorder()) + try: + async with anyio.create_task_group() as tg: + await tg.start(client.run, on_request, on_notify) + await s2c_send.send( + SessionMessage(message=JSONRPCError(jsonrpc="2.0", id=None, error=ErrorData(code=-32700, message="x"))) + ) + await anyio.sleep(0) + tg.cancel_scope.cancel() + finally: + for s in (c2s_send, c2s_recv, s2c_send, s2c_recv): + s.close() diff --git a/tests/shared/test_otel.py b/tests/shared/test_otel.py index ec7ff78cc1..a7df4c4294 100644 --- a/tests/shared/test_otel.py +++ b/tests/shared/test_otel.py @@ -10,9 +10,6 @@ pytestmark = pytest.mark.anyio -# Logfire warns about propagated trace context by default (distributed_tracing=None). -# This is expected here since we're testing cross-boundary context propagation. -@pytest.mark.filterwarnings("ignore::RuntimeWarning") async def test_client_and_server_spans(capfire: CaptureLogfire): """Verify that calling a tool produces client and server spans with correct attributes.""" server = MCPServer("test") diff --git a/tests/shared/test_peer.py b/tests/shared/test_peer.py new file mode 100644 index 0000000000..47277ec88e --- /dev/null +++ b/tests/shared/test_peer.py @@ -0,0 +1,180 @@ +"""Tests for `PeerMixin` and `Peer`. + +Each PeerMixin method is tested by wrapping a `DirectDispatcher` in `Peer`, +calling the typed method, and asserting (a) the right method+params went out +and (b) the return value is the typed result model. +""" + +from collections.abc import Mapping +from typing import Any + +import anyio +import pytest + +from mcp.shared.dispatcher import DispatchContext +from mcp.shared.peer import Peer, dump_params +from mcp.shared.transport_context import TransportContext +from mcp.types import ( + CreateMessageResult, + CreateMessageResultWithTools, + ElicitResult, + ListRootsResult, + SamplingMessage, + TextContent, + Tool, +) + +from .conftest import direct_pair +from .test_dispatcher import running_pair + +DCtx = DispatchContext[TransportContext] + + +class _Recorder: + def __init__(self, result: dict[str, Any]) -> None: + self.result = result + self.seen: list[tuple[str, Mapping[str, Any] | None]] = [] + + async def on_request(self, ctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: + self.seen.append((method, params)) + return self.result + + +@pytest.mark.anyio +async def test_peer_sample_sends_create_message_and_returns_typed_result(): + rec = _Recorder({"role": "assistant", "content": {"type": "text", "text": "hi"}, "model": "m"}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.sample( + [SamplingMessage(role="user", content=TextContent(type="text", text="hello"))], + max_tokens=10, + ) + method, params = rec.seen[0] + assert method == "sampling/createMessage" + assert params is not None and params["maxTokens"] == 10 + assert isinstance(result, CreateMessageResult) + assert result.model == "m" + + +@pytest.mark.anyio +async def test_peer_sample_validates_result_alias_only(): + """Peer results validate alias-only; a snake_case key from the wire is + ignored as extra, not populated by Python field name.""" + snake = {"role": "assistant", "content": {"type": "text", "text": "x"}, "model": "m", "stop_reason": "endTurn"} + rec = _Recorder(snake) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.sample( + [SamplingMessage(role="user", content=TextContent(type="text", text="q"))], max_tokens=1 + ) + assert isinstance(result, CreateMessageResult) + assert result.stop_reason is None + + +@pytest.mark.anyio +async def test_peer_sample_with_tools_returns_with_tools_result(): + rec = _Recorder({"role": "assistant", "content": [{"type": "text", "text": "x"}], "model": "m"}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.sample( + [SamplingMessage(role="user", content=TextContent(type="text", text="q"))], + max_tokens=5, + tools=[Tool(name="t", input_schema={"type": "object"})], + ) + method, params = rec.seen[0] + assert method == "sampling/createMessage" + assert params is not None and params["tools"][0]["name"] == "t" + assert isinstance(result, CreateMessageResultWithTools) + + +@pytest.mark.anyio +async def test_peer_elicit_form_sends_elicitation_create_with_form_params(): + rec = _Recorder({"action": "accept", "content": {"name": "Max"}}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.elicit_form("Your name?", requested_schema={"type": "object", "properties": {}}) + method, params = rec.seen[0] + assert method == "elicitation/create" + assert params is not None and params["mode"] == "form" + assert params["message"] == "Your name?" + assert isinstance(result, ElicitResult) + + +@pytest.mark.anyio +async def test_peer_elicit_url_sends_elicitation_create_with_url_params(): + rec = _Recorder({"action": "accept"}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.elicit_url("Auth needed", url="https://example.com/auth", elicitation_id="e1") + method, params = rec.seen[0] + assert method == "elicitation/create" + assert params is not None and params["mode"] == "url" + assert params["url"] == "https://example.com/auth" + assert isinstance(result, ElicitResult) + + +@pytest.mark.anyio +async def test_peer_list_roots_sends_roots_list_and_returns_typed_result(): + rec = _Recorder({"roots": [{"uri": "file:///workspace"}]}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.list_roots() + method, _ = rec.seen[0] + assert method == "roots/list" + assert isinstance(result, ListRootsResult) + assert len(result.roots) == 1 + assert str(result.roots[0].uri) == "file:///workspace" + + +@pytest.mark.anyio +async def test_peer_list_roots_with_meta_sends_meta_in_params(): + rec = _Recorder({"roots": []}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + await peer.list_roots(meta={"traceId": "t1"}) + method, params = rec.seen[0] + assert method == "roots/list" + assert params == {"_meta": {"traceId": "t1"}} + + +def test_dump_params_merges_meta_over_model_meta(): + out = dump_params(None, None) + assert out is None + out = dump_params(None, {"k": 1}) + assert out == {"_meta": {"k": 1}} + + +@pytest.mark.anyio +async def test_peer_notify_forwards_to_wrapped_outbound(): + sent: list[tuple[str, Mapping[str, Any] | None]] = [] + + class _Out: + async def send_raw_request( + self, method: str, params: Mapping[str, Any] | None, opts: Any = None + ) -> dict[str, Any]: + raise NotImplementedError + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + sent.append((method, params)) + + await Peer(_Out()).notify("n", {"x": 1}) + assert sent == [("n", {"x": 1})] + + +@pytest.mark.anyio +async def test_peer_ping_sends_ping_and_returns_none(): + rec = _Recorder({}) + async with running_pair(direct_pair, server_on_request=rec.on_request) as (client, *_): + peer = Peer(client) + with anyio.fail_after(5): + result = await peer.ping() + method, _ = rec.seen[0] + assert method == "ping" + assert result is None diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py deleted file mode 100644 index aad9e5d439..0000000000 --- a/tests/shared/test_progress_notifications.py +++ /dev/null @@ -1,264 +0,0 @@ -from typing import Any -from unittest.mock import patch - -import anyio -import pytest - -from mcp import Client, types -from mcp.client.session import ClientSession -from mcp.server import Server, ServerRequestContext -from mcp.server.lowlevel import NotificationOptions -from mcp.server.models import InitializationOptions -from mcp.server.session import ServerSession -from mcp.shared.message import SessionMessage -from mcp.shared.session import RequestResponder - - -@pytest.mark.anyio -async def test_bidirectional_progress_notifications(): - """Test that both client and server can send progress notifications.""" - # Create memory streams for client/server - server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](5) - client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](5) - - # Run a server session so we can send progress updates in tool - async def run_server(): - # Create a server session - async with ServerSession( - client_to_server_receive, - server_to_client_send, - InitializationOptions( - server_name="ProgressTestServer", - server_version="0.1.0", - capabilities=server.get_capabilities(NotificationOptions(), {}), - ), - ) as server_session: - async for message in server_session.incoming_messages: - try: - await server._handle_message(message, server_session, {}) - except Exception as e: # pragma: no cover - raise e - - # Track progress updates - server_progress_updates: list[dict[str, Any]] = [] - client_progress_updates: list[dict[str, Any]] = [] - - # Progress tokens - server_progress_token = "server_token_123" - client_progress_token = "client_token_456" - - # Register progress handler - async def handle_progress(ctx: ServerRequestContext, params: types.ProgressNotificationParams) -> None: - server_progress_updates.append( - { - "token": params.progress_token, - "progress": params.progress, - "total": params.total, - "message": params.message, - } - ) - - # Register list tool handler - async def handle_list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[ - types.Tool( - name="test_tool", - description="A tool that sends progress notifications types.CallToolResult: - # Make sure we received a progress token - if params.name == "test_tool": - assert params.meta is not None - progress_token = params.meta.get("progress_token") - assert progress_token is not None - assert progress_token == client_progress_token - - # Send progress notifications using ctx.session - await ctx.session.send_progress_notification( - progress_token=progress_token, - progress=0.25, - total=1.0, - message="Server progress 25%", - ) - - await ctx.session.send_progress_notification( - progress_token=progress_token, - progress=0.5, - total=1.0, - message="Server progress 50%", - ) - - await ctx.session.send_progress_notification( - progress_token=progress_token, - progress=1.0, - total=1.0, - message="Server progress 100%", - ) - - return types.CallToolResult(content=[types.TextContent(type="text", text="Tool executed successfully")]) - - raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover - - # Create a server with progress capability - server = Server( - name="ProgressTestServer", - on_progress=handle_progress, - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - ) - - # Client message handler to store progress notifications - async def handle_client_message( - message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, - ) -> None: - if isinstance(message, Exception): # pragma: no cover - raise message - - if isinstance(message, types.ServerNotification): # pragma: no branch - if isinstance(message, types.ProgressNotification): # pragma: no branch - params = message.params - client_progress_updates.append( - { - "token": params.progress_token, - "progress": params.progress, - "total": params.total, - "message": params.message, - } - ) - - # Test using client - async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - message_handler=handle_client_message, - ) as client_session, - anyio.create_task_group() as tg, - ): - # Start the server in a background task - tg.start_soon(run_server) - - # Initialize the client connection - await client_session.initialize() - - # Call list_tools with progress token - await client_session.list_tools() - - # Call test_tool with progress token - await client_session.call_tool("test_tool", meta={"progress_token": client_progress_token}) - - # Send progress notifications from client to server - await client_session.send_progress_notification( - progress_token=server_progress_token, - progress=0.33, - total=1.0, - message="Client progress 33%", - ) - - await client_session.send_progress_notification( - progress_token=server_progress_token, - progress=0.66, - total=1.0, - message="Client progress 66%", - ) - - await client_session.send_progress_notification( - progress_token=server_progress_token, - progress=1.0, - total=1.0, - message="Client progress 100%", - ) - - # Wait and exit - await anyio.sleep(0.5) - tg.cancel_scope.cancel() - - # Verify client received progress updates from server - assert len(client_progress_updates) == 3 - assert client_progress_updates[0]["token"] == client_progress_token - assert client_progress_updates[0]["progress"] == 0.25 - assert client_progress_updates[0]["message"] == "Server progress 25%" - assert client_progress_updates[2]["progress"] == 1.0 - - # Verify server received progress updates from client - assert len(server_progress_updates) == 3 - assert server_progress_updates[0]["token"] == server_progress_token - assert server_progress_updates[0]["progress"] == 0.33 - assert server_progress_updates[0]["message"] == "Client progress 33%" - assert server_progress_updates[2]["progress"] == 1.0 - - -@pytest.mark.anyio -async def test_progress_callback_exception_logging(): - """Test that exceptions in progress callbacks are logged and \ - don't crash the session.""" - # Track logged warnings - logged_errors: list[str] = [] - - def mock_log_exception(msg: str, *args: Any, **kwargs: Any) -> None: - logged_errors.append(msg % args if args else msg) - - # Create a progress callback that raises an exception - async def failing_progress_callback(progress: float, total: float | None, message: str | None) -> None: - raise ValueError("Progress callback failed!") - - # Create a server with a tool that sends progress notifications - async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: - if params.name == "progress_tool": - assert ctx.request_id is not None - # Send a progress notification - await ctx.session.send_progress_notification( - progress_token=ctx.request_id, - progress=50.0, - total=100.0, - message="Halfway done", - ) - return types.CallToolResult(content=[types.TextContent(type="text", text="progress_result")]) - raise ValueError(f"Unknown tool: {params.name}") # pragma: no cover - - async def handle_list_tools( - ctx: ServerRequestContext, params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - return types.ListToolsResult( - tools=[ - types.Tool( - name="progress_tool", - description="A tool that sends progress notifications", - input_schema={}, - ) - ] - ) - - server = Server( - name="TestProgressServer", - on_call_tool=handle_call_tool, - on_list_tools=handle_list_tools, - ) - - # Test with mocked logging - with patch("mcp.shared.session.logging.exception", side_effect=mock_log_exception): - async with Client(server) as client: - # Call tool with a failing progress callback - result = await client.call_tool( - "progress_tool", - arguments={}, - progress_callback=failing_progress_callback, - ) - - # Verify the request completed successfully despite the callback failure - assert len(result.content) == 1 - content = result.content[0] - assert isinstance(content, types.TextContent) - assert content.text == "progress_result" - - # Check that a warning was logged for the progress callback exception - assert len(logged_errors) > 0 - assert any("Progress callback raised an exception" in warning for warning in logged_errors) diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index d7c6cc3b5f..8a53b0819d 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -23,19 +23,6 @@ ) -@pytest.mark.anyio -async def test_in_flight_requests_cleared_after_completion(): - """Verify that _in_flight is empty after all requests complete.""" - server = Server(name="test server") - async with Client(server) as client: - # Send a request and wait for response - response = await client.send_ping() - assert isinstance(response, EmptyResult) - - # Verify _in_flight is empty - assert len(client.session._in_flight) == 0 - - @pytest.mark.anyio async def test_request_cancellation(): """Test that requests can be cancelled while in-flight.""" diff --git a/uv.lock b/uv.lock index df63607f40..1a0ea56b45 100644 --- a/uv.lock +++ b/uv.lock @@ -881,6 +881,7 @@ dev = [ { name = "inline-snapshot" }, { name = "logfire" }, { name = "mcp", extra = ["cli", "ws"] }, + { name = "opentelemetry-sdk" }, { name = "pillow" }, { name = "pyright" }, { name = "pytest" }, @@ -903,7 +904,8 @@ docs = [ [package.metadata] requires-dist = [ - { name = "anyio", specifier = ">=4.9" }, + { name = "anyio", marker = "python_full_version < '3.14'", specifier = ">=4.9" }, + { name = "anyio", marker = "python_full_version >= '3.14'", specifier = ">=4.10" }, { name = "httpx", specifier = ">=0.27.1,<1.0.0" }, { name = "httpx-sse", specifier = ">=0.4" }, { name = "jsonschema", specifier = ">=4.20.0" }, @@ -933,6 +935,7 @@ dev = [ { name = "inline-snapshot", specifier = ">=0.23.0" }, { name = "logfire", specifier = ">=3.0.0" }, { name = "mcp", extras = ["cli", "ws"], editable = "." }, + { name = "opentelemetry-sdk", specifier = ">=1.39.1" }, { name = "pillow", specifier = ">=12.0" }, { name = "pyright", specifier = ">=1.1.400" }, { name = "pytest", specifier = ">=8.4.0" },