From 08bbc2886fca34de9802aa69bc1285c2102b5cb3 Mon Sep 17 00:00:00 2001 From: Yufeng He <40085740+he-yufeng@users.noreply.github.com> Date: Mon, 1 Jun 2026 04:03:54 +0800 Subject: [PATCH] fix: surface stateful HTTP session crash cause --- src/mcp/server/streamable_http.py | 25 ++++++++-- src/mcp/server/streamable_http_manager.py | 3 +- tests/server/test_streamable_http_manager.py | 51 ++++++++++++++++++++ 3 files changed, 74 insertions(+), 5 deletions(-) diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index f2f4407ce..439fc9d5c 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -10,7 +10,7 @@ import re from abc import ABC, abstractmethod from collections.abc import AsyncGenerator, Awaitable, Callable -from contextlib import asynccontextmanager +from contextlib import asynccontextmanager, suppress from dataclasses import dataclass from http import HTTPStatus from typing import Any @@ -171,6 +171,7 @@ def __init__( ] = {} self._sse_stream_writers: dict[RequestId, MemoryObjectSendStream[dict[str, str]]] = {} self._terminated = False + self._session_run_error: BaseException | None = None # Idle timeout cancel scope; managed by the session manager. self.idle_scope: anyio.CancelScope | None = None @@ -179,6 +180,16 @@ def is_terminated(self) -> bool: """Check if this transport has been explicitly terminated.""" return self._terminated + def note_session_run_error(self, exc: BaseException) -> None: + self._session_run_error = exc + + def _post_error_message(self, err: BaseException) -> str: + display_error = self._session_run_error or err + display_error_text = str(display_error) or type(display_error).__name__ + if display_error is not err and str(display_error): + display_error_text = f"{type(display_error).__name__}: {display_error_text}" + return f"Error handling POST request: {display_error_text}" + def close_sse_stream(self, request_id: RequestId) -> None: """Close SSE connection for a specific request without terminating the stream. @@ -363,6 +374,10 @@ async def _clean_up_memory_streams(self, request_id: RequestId) -> None: # Remove the request stream from the mapping self._request_streams.pop(request_id, None) + async def _clean_up_post_request_stream(self, request_id: RequestId | None) -> None: + if request_id is not None: + await self._clean_up_memory_streams(request_id) + async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None: """Application entry point that handles all HTTP requests.""" request = Request(scope, receive) @@ -443,6 +458,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re writer = self._read_stream_writer if writer is None: # pragma: no cover raise ValueError("No read stream writer available. Ensure connect() is called first.") + request_id: RequestId | None = None try: # Validate Accept header if not await self._validate_accept_header(request, scope, send): @@ -637,15 +653,16 @@ async def sse_writer(): except Exception as err: logger.exception("Error handling POST request") + await self._clean_up_post_request_stream(request_id) response = self._create_error_response( - f"Error handling POST request: {err}", + self._post_error_message(err), HTTPStatus.INTERNAL_SERVER_ERROR, INTERNAL_ERROR, ) await response(scope, receive, send) - if writer: # pragma: no cover + with suppress(anyio.BrokenResourceError, anyio.ClosedResourceError): await writer.send(Exception(err)) - return # pragma: no cover + return async def _handle_get_request(self, request: Request, send: Send) -> None: """Handle GET request to establish SSE. diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 81350a8f2..45fb2cd97 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -274,7 +274,8 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE self._server_instances.pop(http_transport.mcp_session_id, None) self._session_owners.pop(http_transport.mcp_session_id, None) await http_transport.terminate() - except Exception: + except Exception as exc: + http_transport.note_session_run_error(exc) logger.exception(f"Session {http_transport.mcp_session_id} crashed") finally: if ( # pragma: no branch diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index ba7554796..bf94f7701 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -209,6 +209,57 @@ async def mock_receive(): # pragma: no cover assert not manager._server_instances, "No sessions should be tracked after the only session crashes" +def test_post_error_message_prefers_session_run_error(): + transport = StreamableHTTPServerTransport(mcp_session_id="session-id", is_json_response_enabled=True) + + assert transport._post_error_message(anyio.ClosedResourceError()) == ( + "Error handling POST request: ClosedResourceError" + ) + + transport.note_session_run_error(RuntimeError("BOOM-distinctive-root-cause")) + assert transport._post_error_message(anyio.ClosedResourceError()) == ( + "Error handling POST request: RuntimeError: BOOM-distinctive-root-cause" + ) + + +@pytest.mark.anyio +async def test_stateful_json_response_includes_session_crash_cause(): + app = Server("test-crash-cause") + app.run = AsyncMock(side_effect=RuntimeError("BOOM-distinctive-root-cause")) + manager = StreamableHTTPSessionManager(app=app, json_response=True) + + sent_messages: list[Message] = [] + response_body = b"" + + async def mock_send(message: Message) -> None: + nonlocal response_body + sent_messages.append(message) + if message["type"] == "http.response.body": + response_body += message.get("body", b"") + + async def mock_receive() -> Message: + body = { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2025-06-18", + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.0"}, + }, + } + return {"type": "http.request", "body": json.dumps(body).encode(), "more_body": False} + + async with manager.run(): + await manager.handle_request(_request_scope(), mock_receive, mock_send) + response_start = next(msg for msg in sent_messages if msg["type"] == "http.response.start") + assert response_start["status"] == 500 + + error_data = json.loads(response_body) + assert error_data["error"]["code"] == -32603 + assert "RuntimeError: BOOM-distinctive-root-cause" in error_data["error"]["message"] + + @pytest.mark.anyio async def test_stateless_requests_memory_cleanup(): """Test that stateless requests actually clean up resources using real transports."""