Skip to content

Commit ee4f98d

Browse files
fix: preserve stdio streams in server transport
1 parent ac96f88 commit ee4f98d

2 files changed

Lines changed: 60 additions & 26 deletions

File tree

src/mcp/server/stdio.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ async def run_server():
1717
```
1818
"""
1919

20+
import os
2021
import sys
2122
from contextlib import asynccontextmanager
2223
from io import TextIOWrapper
@@ -38,10 +39,18 @@ async def stdio_server(stdin: anyio.AsyncFile[str] | None = None, stdout: anyio.
3839
# standard process handles. Encoding of stdin/stdout as text streams on
3940
# python is platform-dependent (Windows is particularly problematic), so we
4041
# re-wrap the underlying binary stream to ensure UTF-8.
42+
close_stdin = False
43+
close_stdout = False
4144
if not stdin:
42-
stdin = anyio.wrap_file(TextIOWrapper(sys.stdin.buffer, encoding="utf-8", errors="replace"))
45+
stdin_fd = os.dup(sys.stdin.fileno())
46+
stdin_buffer = os.fdopen(stdin_fd, "rb", closefd=True)
47+
stdin = anyio.wrap_file(TextIOWrapper(stdin_buffer, encoding="utf-8", errors="replace"))
48+
close_stdin = True
4349
if not stdout:
44-
stdout = anyio.wrap_file(TextIOWrapper(sys.stdout.buffer, encoding="utf-8"))
50+
stdout_fd = os.dup(sys.stdout.fileno())
51+
stdout_buffer = os.fdopen(stdout_fd, "wb", closefd=True)
52+
stdout = anyio.wrap_file(TextIOWrapper(stdout_buffer, encoding="utf-8"))
53+
close_stdout = True
4554

4655
read_stream_writer, read_stream = create_context_streams[SessionMessage | Exception](0)
4756
write_stream, write_stream_reader = create_context_streams[SessionMessage](0)
@@ -71,7 +80,13 @@ async def stdout_writer():
7180
except anyio.ClosedResourceError: # pragma: no cover
7281
await anyio.lowlevel.checkpoint()
7382

74-
async with anyio.create_task_group() as tg:
75-
tg.start_soon(stdin_reader)
76-
tg.start_soon(stdout_writer)
77-
yield read_stream, write_stream
83+
try:
84+
async with anyio.create_task_group() as tg:
85+
tg.start_soon(stdin_reader)
86+
tg.start_soon(stdout_writer)
87+
yield read_stream, write_stream
88+
finally:
89+
if close_stdin:
90+
await stdin.aclose()
91+
if close_stdout:
92+
await stdout.aclose()

tests/server/test_stdio.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import io
22
import sys
3+
import tempfile
34
import threading
45
from collections.abc import AsyncIterator
56
from contextlib import asynccontextmanager
@@ -67,33 +68,51 @@ async def test_stdio_server_round_trips_messages_over_injected_streams() -> None
6768

6869

6970
@pytest.mark.anyio
70-
async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch) -> None:
71+
async def test_stdio_server_invalid_utf8() -> None:
7172
"""Non-UTF-8 stdin bytes surface as an in-stream exception without killing the stream.
7273
7374
Invalid bytes are replaced with U+FFFD, fail JSON parsing, and arrive as an in-stream
7475
exception; subsequent valid messages are still processed.
7576
"""
7677
# \xff\xfe are invalid UTF-8 start bytes.
7778
valid = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
78-
raw_stdin = io.BytesIO(b"\xff\xfe\n" + valid.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n")
79-
80-
# Replace sys.stdin with a wrapper whose .buffer is our raw bytes, so that
81-
# stdio_server()'s default path wraps it with errors='replace'.
82-
monkeypatch.setattr(sys, "stdin", TextIOWrapper(raw_stdin, encoding="utf-8"))
83-
monkeypatch.setattr(sys, "stdout", TextIOWrapper(io.BytesIO(), encoding="utf-8"))
84-
85-
with anyio.fail_after(5):
86-
async with stdio_server() as (read_stream, write_stream):
87-
await write_stream.aclose()
88-
async with read_stream: # pragma: no branch
89-
# First line: \xff\xfe -> U+FFFD U+FFFD -> JSON parse fails -> exception in stream
90-
first = await read_stream.receive()
91-
assert isinstance(first, Exception)
92-
93-
# Second line: valid message still comes through
94-
second = await read_stream.receive()
95-
assert isinstance(second, SessionMessage)
96-
assert second.message == valid
79+
raw_stdin = tempfile.TemporaryFile()
80+
raw_stdin.write(b"\xff\xfe\n" + valid.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n")
81+
raw_stdin.seek(0)
82+
raw_stdout = tempfile.TemporaryFile()
83+
84+
# Replace sys.stdin/stdout with wrappers backed by real file descriptors so
85+
# stdio_server()'s default path can duplicate them without closing the
86+
# original process-level streams.
87+
original_stdin = sys.stdin
88+
original_stdout = sys.stdout
89+
test_stdin = TextIOWrapper(raw_stdin, encoding="utf-8")
90+
test_stdout = TextIOWrapper(raw_stdout, encoding="utf-8")
91+
sys.stdin = test_stdin
92+
sys.stdout = test_stdout
93+
94+
try:
95+
with anyio.fail_after(5):
96+
async with stdio_server() as (read_stream, write_stream):
97+
await write_stream.aclose()
98+
async with read_stream: # pragma: no branch
99+
# First line: \xff\xfe -> U+FFFD U+FFFD -> JSON parse fails -> exception in stream
100+
first = await read_stream.receive()
101+
assert isinstance(first, Exception)
102+
103+
# Second line: valid message still comes through
104+
second = await read_stream.receive()
105+
assert isinstance(second, SessionMessage)
106+
assert second.message == valid
107+
108+
assert not sys.stdin.closed
109+
assert not sys.stdout.closed
110+
sys.stdout.write("stdio still open")
111+
finally:
112+
sys.stdin = original_stdin
113+
sys.stdout = original_stdout
114+
test_stdin.close()
115+
test_stdout.close()
97116

98117

99118
class _KeepOpenBytesIO(io.BytesIO):

0 commit comments

Comments
 (0)