|
1 | 1 | import io |
2 | 2 | import sys |
| 3 | +import tempfile |
3 | 4 | import threading |
4 | 5 | from collections.abc import AsyncIterator |
5 | 6 | from contextlib import asynccontextmanager |
@@ -67,33 +68,51 @@ async def test_stdio_server_round_trips_messages_over_injected_streams() -> None |
67 | 68 |
|
68 | 69 |
|
69 | 70 | @pytest.mark.anyio |
70 | | -async def test_stdio_server_invalid_utf8(monkeypatch: pytest.MonkeyPatch) -> None: |
| 71 | +async def test_stdio_server_invalid_utf8() -> None: |
71 | 72 | """Non-UTF-8 stdin bytes surface as an in-stream exception without killing the stream. |
72 | 73 |
|
73 | 74 | Invalid bytes are replaced with U+FFFD, fail JSON parsing, and arrive as an in-stream |
74 | 75 | exception; subsequent valid messages are still processed. |
75 | 76 | """ |
76 | 77 | # \xff\xfe are invalid UTF-8 start bytes. |
77 | 78 | valid = JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") |
78 | | - raw_stdin = io.BytesIO(b"\xff\xfe\n" + valid.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n") |
79 | | - |
80 | | - # Replace sys.stdin with a wrapper whose .buffer is our raw bytes, so that |
81 | | - # stdio_server()'s default path wraps it with errors='replace'. |
82 | | - monkeypatch.setattr(sys, "stdin", TextIOWrapper(raw_stdin, encoding="utf-8")) |
83 | | - monkeypatch.setattr(sys, "stdout", TextIOWrapper(io.BytesIO(), encoding="utf-8")) |
84 | | - |
85 | | - with anyio.fail_after(5): |
86 | | - async with stdio_server() as (read_stream, write_stream): |
87 | | - await write_stream.aclose() |
88 | | - async with read_stream: # pragma: no branch |
89 | | - # First line: \xff\xfe -> U+FFFD U+FFFD -> JSON parse fails -> exception in stream |
90 | | - first = await read_stream.receive() |
91 | | - assert isinstance(first, Exception) |
92 | | - |
93 | | - # Second line: valid message still comes through |
94 | | - second = await read_stream.receive() |
95 | | - assert isinstance(second, SessionMessage) |
96 | | - assert second.message == valid |
| 79 | + raw_stdin = tempfile.TemporaryFile() |
| 80 | + raw_stdin.write(b"\xff\xfe\n" + valid.model_dump_json(by_alias=True, exclude_none=True).encode() + b"\n") |
| 81 | + raw_stdin.seek(0) |
| 82 | + raw_stdout = tempfile.TemporaryFile() |
| 83 | + |
| 84 | + # Replace sys.stdin/stdout with wrappers backed by real file descriptors so |
| 85 | + # stdio_server()'s default path can duplicate them without closing the |
| 86 | + # original process-level streams. |
| 87 | + original_stdin = sys.stdin |
| 88 | + original_stdout = sys.stdout |
| 89 | + test_stdin = TextIOWrapper(raw_stdin, encoding="utf-8") |
| 90 | + test_stdout = TextIOWrapper(raw_stdout, encoding="utf-8") |
| 91 | + sys.stdin = test_stdin |
| 92 | + sys.stdout = test_stdout |
| 93 | + |
| 94 | + try: |
| 95 | + with anyio.fail_after(5): |
| 96 | + async with stdio_server() as (read_stream, write_stream): |
| 97 | + await write_stream.aclose() |
| 98 | + async with read_stream: # pragma: no branch |
| 99 | + # First line: \xff\xfe -> U+FFFD U+FFFD -> JSON parse fails -> exception in stream |
| 100 | + first = await read_stream.receive() |
| 101 | + assert isinstance(first, Exception) |
| 102 | + |
| 103 | + # Second line: valid message still comes through |
| 104 | + second = await read_stream.receive() |
| 105 | + assert isinstance(second, SessionMessage) |
| 106 | + assert second.message == valid |
| 107 | + |
| 108 | + assert not sys.stdin.closed |
| 109 | + assert not sys.stdout.closed |
| 110 | + sys.stdout.write("stdio still open") |
| 111 | + finally: |
| 112 | + sys.stdin = original_stdin |
| 113 | + sys.stdout = original_stdout |
| 114 | + test_stdin.close() |
| 115 | + test_stdout.close() |
97 | 116 |
|
98 | 117 |
|
99 | 118 | class _KeepOpenBytesIO(io.BytesIO): |
|
0 commit comments