Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import re
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from contextlib import asynccontextmanager, suppress
from dataclasses import dataclass
from http import HTTPStatus
from typing import Any
Expand Down Expand Up @@ -171,6 +171,7 @@ def __init__(
] = {}
self._sse_stream_writers: dict[RequestId, MemoryObjectSendStream[dict[str, str]]] = {}
self._terminated = False
self._session_run_error: BaseException | None = None
# Idle timeout cancel scope; managed by the session manager.
self.idle_scope: anyio.CancelScope | None = None

Expand All @@ -179,6 +180,16 @@ def is_terminated(self) -> bool:
"""Check if this transport has been explicitly terminated."""
return self._terminated

def note_session_run_error(self, exc: BaseException) -> None:
self._session_run_error = exc

def _post_error_message(self, err: BaseException) -> str:
display_error = self._session_run_error or err
display_error_text = str(display_error) or type(display_error).__name__
if display_error is not err and str(display_error):
display_error_text = f"{type(display_error).__name__}: {display_error_text}"
return f"Error handling POST request: {display_error_text}"

def close_sse_stream(self, request_id: RequestId) -> None:
"""Close SSE connection for a specific request without terminating the stream.

Expand Down Expand Up @@ -363,6 +374,10 @@ async def _clean_up_memory_streams(self, request_id: RequestId) -> None:
# Remove the request stream from the mapping
self._request_streams.pop(request_id, None)

async def _clean_up_post_request_stream(self, request_id: RequestId | None) -> None:
if request_id is not None:
await self._clean_up_memory_streams(request_id)

async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Application entry point that handles all HTTP requests."""
request = Request(scope, receive)
Expand Down Expand Up @@ -443,6 +458,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
writer = self._read_stream_writer
if writer is None: # pragma: no cover
raise ValueError("No read stream writer available. Ensure connect() is called first.")
request_id: RequestId | None = None
try:
# Validate Accept header
if not await self._validate_accept_header(request, scope, send):
Expand Down Expand Up @@ -637,15 +653,16 @@ async def sse_writer():

except Exception as err:
logger.exception("Error handling POST request")
await self._clean_up_post_request_stream(request_id)
response = self._create_error_response(
f"Error handling POST request: {err}",
self._post_error_message(err),
HTTPStatus.INTERNAL_SERVER_ERROR,
INTERNAL_ERROR,
)
await response(scope, receive, send)
if writer: # pragma: no cover
with suppress(anyio.BrokenResourceError, anyio.ClosedResourceError):
await writer.send(Exception(err))
return # pragma: no cover
return

async def _handle_get_request(self, request: Request, send: Send) -> None:
"""Handle GET request to establish SSE.
Expand Down
3 changes: 2 additions & 1 deletion src/mcp/server/streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,8 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
self._server_instances.pop(http_transport.mcp_session_id, None)
self._session_owners.pop(http_transport.mcp_session_id, None)
await http_transport.terminate()
except Exception:
except Exception as exc:
http_transport.note_session_run_error(exc)
logger.exception(f"Session {http_transport.mcp_session_id} crashed")
finally:
if ( # pragma: no branch
Expand Down
51 changes: 51 additions & 0 deletions tests/server/test_streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,57 @@ async def mock_receive(): # pragma: no cover
assert not manager._server_instances, "No sessions should be tracked after the only session crashes"


def test_post_error_message_prefers_session_run_error():
transport = StreamableHTTPServerTransport(mcp_session_id="session-id", is_json_response_enabled=True)

assert transport._post_error_message(anyio.ClosedResourceError()) == (
"Error handling POST request: ClosedResourceError"
)

transport.note_session_run_error(RuntimeError("BOOM-distinctive-root-cause"))
assert transport._post_error_message(anyio.ClosedResourceError()) == (
"Error handling POST request: RuntimeError: BOOM-distinctive-root-cause"
)


@pytest.mark.anyio
async def test_stateful_json_response_includes_session_crash_cause():
app = Server("test-crash-cause")
app.run = AsyncMock(side_effect=RuntimeError("BOOM-distinctive-root-cause"))
manager = StreamableHTTPSessionManager(app=app, json_response=True)

sent_messages: list[Message] = []
response_body = b""

async def mock_send(message: Message) -> None:
nonlocal response_body
sent_messages.append(message)
if message["type"] == "http.response.body":
response_body += message.get("body", b"")

async def mock_receive() -> Message:
body = {
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {
"protocolVersion": "2025-06-18",
"capabilities": {},
"clientInfo": {"name": "test-client", "version": "1.0"},
},
}
return {"type": "http.request", "body": json.dumps(body).encode(), "more_body": False}

async with manager.run():
await manager.handle_request(_request_scope(), mock_receive, mock_send)
response_start = next(msg for msg in sent_messages if msg["type"] == "http.response.start")
assert response_start["status"] == 500

error_data = json.loads(response_body)
assert error_data["error"]["code"] == -32603
assert "RuntimeError: BOOM-distinctive-root-cause" in error_data["error"]["message"]


@pytest.mark.anyio
async def test_stateless_requests_memory_cleanup():
"""Test that stateless requests actually clean up resources using real transports."""
Expand Down
Loading