Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ async def main():
import contextvars
import logging
import warnings
from collections.abc import AsyncIterator, Awaitable, Callable
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from importlib.metadata import version as importlib_version
from typing import Any, Generic, cast
Expand Down Expand Up @@ -85,7 +85,7 @@ def __init__(self, prompts_changed: bool = False, resources_changed: bool = Fals


@asynccontextmanager
async def lifespan(_: Server[LifespanResultT]) -> AsyncIterator[dict[str, Any]]:
async def lifespan(_: Server[LifespanResultT]) -> AsyncGenerator[dict[str, Any]]:
"""Default lifespan context manager that does nothing.

Returns:
Expand Down Expand Up @@ -371,6 +371,10 @@ async def run(
# the initialization lifecycle, but can do so with any available node
# rather than requiring initialization for each connection.
stateless: bool = False,
# When True, stdin/file-style EOF is treated as "no more inbound messages";
# accepted request handlers are allowed to finish and flush their responses.
drain_in_flight_on_read_eof: bool = False,
drain_in_flight_on_read_eof_timeout_seconds: float = 5.0,
):
async with AsyncExitStack() as stack:
lifespan_context = await stack.enter_async_context(self.lifespan(self))
Expand All @@ -380,6 +384,7 @@ async def run(
write_stream,
initialization_options,
stateless=stateless,
close_write_stream_on_read_end=not drain_in_flight_on_read_eof,
)
)

Expand Down Expand Up @@ -408,11 +413,14 @@ async def run(
raise_exceptions,
)
finally:
# Transport closed: cancel in-flight handlers. Without this the
# TG join waits for them, and when they eventually try to
# respond they hit a closed write stream (the session's
# _receive_loop closed it when the read stream ended).
tg.cancel_scope.cancel()
if not drain_in_flight_on_read_eof:
# Transport closed: cancel in-flight handlers. Without this the
# TG join waits for them, and when they eventually try to
# respond they hit a closed write stream (the session's
# _receive_loop closed it when the read stream ended).
tg.cancel_scope.cancel()
else:
tg.cancel_scope.deadline = anyio.current_time() + drain_in_flight_on_read_eof_timeout_seconds

async def _handle_message(
self,
Expand Down
8 changes: 6 additions & 2 deletions src/mcp/server/mcpserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import inspect
import json
import re
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Sequence
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import Any, Generic, Literal, TypeVar, overload

Expand Down Expand Up @@ -74,6 +74,8 @@

logger = get_logger(__name__)

STDIO_EOF_DRAIN_TIMEOUT_SECONDS = 5.0

_CallableT = TypeVar("_CallableT", bound=Callable[..., Any])


Expand Down Expand Up @@ -119,7 +121,7 @@ def lifespan_wrapper(
lifespan: Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]],
) -> Callable[[Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]]:
@asynccontextmanager
async def wrap(_: Server[LifespanResultT]) -> AsyncIterator[LifespanResultT]:
async def wrap(_: Server[LifespanResultT]) -> AsyncGenerator[LifespanResultT]:
async with lifespan(app) as context:
yield context

Expand Down Expand Up @@ -852,6 +854,8 @@ async def run_stdio_async(self) -> None:
read_stream,
write_stream,
self._lowlevel_server.create_initialization_options(),
drain_in_flight_on_read_eof=True,
drain_in_flight_on_read_eof_timeout_seconds=STDIO_EOF_DRAIN_TIMEOUT_SECONDS,
)

async def run_sse_async( # pragma: no cover
Expand Down
7 changes: 6 additions & 1 deletion src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,13 @@ def __init__(
write_stream: WriteStream[SessionMessage],
init_options: InitializationOptions,
stateless: bool = False,
close_write_stream_on_read_end: bool = True,
) -> None:
super().__init__(read_stream, write_stream)
super().__init__(
read_stream,
write_stream,
close_write_stream_on_read_end=close_write_stream_on_read_end,
)
self._stateless = stateless
self._initialization_state = (
InitializationState.Initialized if stateless else InitializationState.NotInitialized
Expand Down
13 changes: 11 additions & 2 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,11 @@ def __init__(
write_stream: WriteStream[SessionMessage],
# If none, reading will never time out
read_timeout_seconds: float | None = None,
close_write_stream_on_read_end: bool = True,
) -> None:
self._read_stream = read_stream
self._write_stream = write_stream
self._close_write_stream_on_read_end = close_write_stream_on_read_end
self._response_streams = {}
self._request_id = 0
self._session_read_timeout_seconds = read_timeout_seconds
Expand Down Expand Up @@ -234,7 +236,11 @@ async def __aexit__(
# would be very surprising behavior), so make sure to cancel the tasks
# in the task group.
self._task_group.cancel_scope.cancel()
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
try:
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
finally:
if not self._close_write_stream_on_read_end:
await self._write_stream.aclose()

async def send_request(
self,
Expand Down Expand Up @@ -349,7 +355,10 @@ def _receive_notification_adapter(self) -> TypeAdapter[ReceiveNotificationT]:
raise NotImplementedError

async def _receive_loop(self) -> None:
async with self._read_stream, self._write_stream:
async with AsyncExitStack() as stack:
await stack.enter_async_context(self._read_stream)
if self._close_write_stream_on_read_end:
await stack.enter_async_context(self._write_stream)
try:

async def _handle_session_message(message: SessionMessage) -> None:
Expand Down
83 changes: 83 additions & 0 deletions tests/issues/test_2678_stdio_eof_drain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import json
import subprocess
import sys
import textwrap
from pathlib import Path


def test_stdio_redirected_stdin_eof_drains_accepted_tool_responses(tmp_path: Path) -> None:
    """Check that a stdio server answers every accepted request after stdin EOF.

    A small MCPServer script with one slow tool is written to ``tmp_path`` and
    run as a subprocess whose stdin is a regular file. The file holds an
    ``initialize`` request, the ``notifications/initialized`` notification and
    two ``tools/call`` requests, so the server hits EOF right after reading the
    last request while both tool handlers are still sleeping.

    The test passes when the process exits with status 0 and stdout contains a
    JSON-RPC message carrying each of the request ids 0, 1 and 2.
    """
    server_py = tmp_path / "server.py"
    payload_jsonl = tmp_path / "payload.jsonl"
    response_jsonl = tmp_path / "response.jsonl"

    # The script is indented here for readability; textwrap.dedent strips the
    # common leading whitespace so the file on disk starts at column 0.
    server_py.write_text(
        textwrap.dedent(
            """
            import asyncio

            from mcp.server.mcpserver import MCPServer

            mcp = MCPServer("repro")

            @mcp.tool()
            async def slow_echo(text: str) -> str:
                await asyncio.sleep(0.05)
                return text

            if __name__ == "__main__":
                mcp.run(transport="stdio")
            """
        ),
        encoding="utf-8",
    )

    # One JSON-RPC message per line: handshake first, then two tool calls that
    # are still in flight when the reader reaches end-of-file.
    messages = [
        {
            "jsonrpc": "2.0",
            "id": 0,
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {"name": "repro", "version": "0.1"},
            },
        },
        {"jsonrpc": "2.0", "method": "notifications/initialized", "params": {}},
        {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/call",
            "params": {"name": "slow_echo", "arguments": {"text": "first"}},
        },
        {
            "jsonrpc": "2.0",
            "id": 2,
            "method": "tools/call",
            "params": {"name": "slow_echo", "arguments": {"text": "second"}},
        },
    ]
    payload_jsonl.write_text(
        "\n".join(json.dumps(message) for message in messages) + "\n",
        encoding="utf-8",
    )

    # stdin/stdout are real files (not pipes) to reproduce shell redirection;
    # stderr is captured so it can be shown if the server fails.
    with payload_jsonl.open("rb") as stdin, response_jsonl.open("wb") as stdout:
        completed = subprocess.run(
            [sys.executable, str(server_py)],
            stdin=stdin,
            stdout=stdout,
            stderr=subprocess.PIPE,
            timeout=10,
            check=False,
        )

    stderr_text = completed.stderr.decode("utf-8", errors="replace")
    assert completed.returncode == 0, stderr_text

    # Collect ids only from messages that have one: blank lines and
    # server-initiated notifications (which carry no "id") must not crash the
    # check with JSONDecodeError / KeyError.
    response_ids = set()
    for line in response_jsonl.read_text(encoding="utf-8").splitlines():
        if not line.strip():
            continue
        message = json.loads(line)
        if isinstance(message, dict) and "id" in message:
            response_ids.add(message["id"])

    assert {0, 1, 2}.issubset(response_ids), (
        f"expected responses for ids 0, 1, 2; got {response_ids!r}\n{stderr_text}"
    )
Loading