Skip to content

Commit b1fbc90

Browse files
committed
Bind transport sessions to the authenticated principal
Both HTTP transports now record the principal that created each session — the OAuth client together with the issuer and subject when the token verifier supplies them — and serve subsequent requests for that session only when they present the same principal. Requests presenting a different principal receive the same 404 response as for an unknown session ID, and SSE session entries are removed when the connection ends. Servers without authentication, and authentication backends other than the built-in BearerAuthBackend, are unaffected: no principal is recorded and the comparison always passes. The new in-process SSE tests bring connect_sse, handle_post_message, and TransportSecurityMiddleware under tracked coverage, so the corresponding no-cover pragmas are removed.
1 parent 3eb5799 commit b1fbc90

7 files changed

Lines changed: 640 additions & 44 deletions

File tree

src/mcp/server/auth/middleware/bearer_auth.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
import time
3-
from typing import Any
3+
from typing import Any, TypedDict
44

55
from pydantic import AnyHttpUrl
66
from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser
@@ -19,6 +19,30 @@ def __init__(self, auth_info: AccessToken):
1919
self.scopes = auth_info.scopes
2020

2121

22+
class AuthorizationContext(TypedDict):
23+
client_id: str
24+
issuer: str | None
25+
subject: str | None
26+
27+
28+
def authorization_context(user: AuthenticatedUser) -> AuthorizationContext:
29+
"""Identify the principal `user` represents, for transports to compare
30+
against the principal that created a session. Components the token
31+
verifier does not supply are `None`, so the comparison degrades to the
32+
remaining components.
33+
34+
See `examples/servers/simple-auth/mcp_simple_auth/token_verifier.py` for
35+
a verifier that populates `subject` and `claims` from an introspection
36+
response."""
37+
token = user.access_token
38+
issuer = (token.claims or {}).get("iss")
39+
return AuthorizationContext(
40+
client_id=token.client_id,
41+
issuer=str(issuer) if issuer is not None else None,
42+
subject=token.subject,
43+
)
44+
45+
2246
class BearerAuthBackend(AuthenticationBackend):
2347
"""Authentication backend that validates Bearer tokens using a TokenVerifier."""
2448

src/mcp/server/sse.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ async def handle_sse(request):
5050
from starlette.types import Receive, Scope, Send
5151

5252
from mcp import types
53+
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context
5354
from mcp.server.transport_security import (
5455
TransportSecurityMiddleware,
5556
TransportSecuritySettings,
@@ -73,6 +74,9 @@ class SseServerTransport:
7374

7475
_endpoint: str
7576
_read_stream_writers: dict[UUID, ContextSendStream[SessionMessage | Exception]]
77+
# Identity of the credential that created each session; requests for a
78+
# session must present the same credential.
79+
_session_owners: dict[UUID, AuthorizationContext]
7680
_security: TransportSecurityMiddleware
7781

7882
def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None:
@@ -112,11 +116,12 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings |
112116

113117
self._endpoint = endpoint
114118
self._read_stream_writers = {}
119+
self._session_owners = {}
115120
self._security = TransportSecurityMiddleware(security_settings)
116121
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
117122

118123
@asynccontextmanager
119-
async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # pragma: no cover
124+
async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
120125
if scope["type"] != "http":
121126
logger.error("connect_sse received non-HTTP request")
122127
raise ValueError("connect_sse can only handle HTTP requests")
@@ -134,6 +139,9 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): # prag
134139
write_stream, write_stream_reader = create_context_streams[SessionMessage](0)
135140

136141
session_id = uuid4()
142+
user = scope.get("user")
143+
if isinstance(user, AuthenticatedUser):
144+
self._session_owners[session_id] = authorization_context(user)
137145
self._read_stream_writers[session_id] = read_stream_writer
138146
logger.debug(f"Created new session with ID: {session_id}")
139147

@@ -169,28 +177,32 @@ async def sse_writer():
169177
}
170178
)
171179

172-
async with anyio.create_task_group() as tg:
173-
174-
async def response_wrapper(scope: Scope, receive: Receive, send: Send):
175-
"""The EventSourceResponse returning signals a client close / disconnect.
176-
In this case we close our side of the streams to signal the client that
177-
the connection has been closed.
178-
"""
179-
await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)(
180-
scope, receive, send
181-
)
182-
await read_stream_writer.aclose()
183-
await write_stream_reader.aclose()
184-
self._read_stream_writers.pop(session_id, None)
185-
logging.debug(f"Client session disconnected {session_id}")
180+
try:
181+
async with anyio.create_task_group() as tg:
182+
183+
async def response_wrapper(scope: Scope, receive: Receive, send: Send):
184+
"""The EventSourceResponse returning signals a client close / disconnect.
185+
In this case we close our side of the streams to signal the client that
186+
the connection has been closed.
187+
"""
188+
await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)(
189+
scope, receive, send
190+
)
191+
await read_stream_writer.aclose()
192+
await write_stream_reader.aclose()
193+
await sse_stream_reader.aclose()
194+
logging.debug(f"Client session disconnected {session_id}")
186195

187-
logger.debug("Starting SSE response task")
188-
tg.start_soon(response_wrapper, scope, receive, send)
196+
logger.debug("Starting SSE response task")
197+
tg.start_soon(response_wrapper, scope, receive, send)
189198

190-
logger.debug("Yielding read and write streams")
191-
yield (read_stream, write_stream)
199+
logger.debug("Yielding read and write streams")
200+
yield (read_stream, write_stream)
201+
finally:
202+
self._read_stream_writers.pop(session_id, None)
203+
self._session_owners.pop(session_id, None)
192204

193-
async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None: # pragma: no cover
205+
async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None:
194206
logger.debug("Handling POST message")
195207
request = Request(scope, receive)
196208

@@ -219,6 +231,15 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send)
219231
response = Response("Could not find session", status_code=404)
220232
return await response(scope, receive, send)
221233

234+
user = scope.get("user")
235+
requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None
236+
if requestor != self._session_owners.get(session_id):
237+
# A session can only be used with the credential that created it.
238+
# Respond exactly as if the session did not exist.
239+
logger.warning("Rejecting message for session %s: credential does not match", session_id)
240+
response = Response("Could not find session", status_code=404)
241+
return await response(scope, receive, send)
242+
222243
body = await request.body()
223244
logger.debug(f"Received JSON: {body}")
224245

src/mcp/server/streamable_http_manager.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import contextlib
66
import logging
77
from collections.abc import AsyncIterator
8-
from http import HTTPStatus
98
from typing import TYPE_CHECKING, Any
109
from uuid import uuid4
1110

@@ -15,6 +14,7 @@
1514
from starlette.responses import Response
1615
from starlette.types import Receive, Scope, Send
1716

17+
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context
1818
from mcp.server.streamable_http import (
1919
MCP_SESSION_ID_HEADER,
2020
EventStore,
@@ -89,6 +89,9 @@ def __init__(
8989
# Session tracking (only used if not stateless)
9090
self._session_creation_lock = anyio.Lock()
9191
self._server_instances: dict[str, StreamableHTTPServerTransport] = {}
92+
# Identity of the credential that created each session; requests for a
93+
# session must present the same credential.
94+
self._session_owners: dict[str, AuthorizationContext] = {}
9295

9396
# The task group will be set during lifespan
9497
self._task_group = None
@@ -135,6 +138,7 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
135138
self._task_group = None
136139
# Clear any remaining server instances
137140
self._server_instances.clear()
141+
self._session_owners.clear()
138142

139143
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
140144
"""Process ASGI request with proper session handling and transport setup.
@@ -192,9 +196,29 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S
192196
request = Request(scope, receive)
193197
request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)
194198

199+
user = scope.get("user")
200+
requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None
201+
195202
# Existing session case
196203
if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances:
197204
transport = self._server_instances[request_mcp_session_id]
205+
if requestor != self._session_owners.get(request_mcp_session_id):
206+
# A session can only be used with the credential that created
207+
# it. Respond exactly as if the session did not exist.
208+
logger.warning(
209+
"Rejecting request for session %s: credential does not match the one that created the session",
210+
request_mcp_session_id[:64],
211+
)
212+
body = JSONRPCError(
213+
jsonrpc="2.0", id=None, error=ErrorData(code=INVALID_REQUEST, message="Session not found")
214+
)
215+
response = Response(
216+
body.model_dump_json(by_alias=True, exclude_unset=True),
217+
status_code=404,
218+
media_type="application/json",
219+
)
220+
await response(scope, receive, send)
221+
return
198222
logger.debug("Session already exists, handling request directly")
199223
# Push back idle deadline on activity
200224
if transport.idle_scope is not None and self.session_idle_timeout is not None:
@@ -216,6 +240,8 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S
216240
)
217241

218242
assert http_transport.mcp_session_id is not None
243+
if requestor is not None:
244+
self._session_owners[http_transport.mcp_session_id] = requestor
219245
self._server_instances[http_transport.mcp_session_id] = http_transport
220246
logger.info(f"Created new transport with session ID: {new_session_id}")
221247

@@ -246,6 +272,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
246272
assert http_transport.mcp_session_id is not None
247273
logger.info(f"Session {http_transport.mcp_session_id} idle timeout")
248274
self._server_instances.pop(http_transport.mcp_session_id, None)
275+
self._session_owners.pop(http_transport.mcp_session_id, None)
249276
await http_transport.terminate()
250277
except Exception:
251278
logger.exception(f"Session {http_transport.mcp_session_id} crashed")
@@ -260,6 +287,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
260287
f"{http_transport.mcp_session_id} from active instances."
261288
)
262289
del self._server_instances[http_transport.mcp_session_id]
290+
self._session_owners.pop(http_transport.mcp_session_id, None)
263291

264292
# Assert task group is not None for type checking
265293
assert self._task_group is not None
@@ -273,15 +301,11 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
273301
# TODO: Align error code once spec clarifies
274302
# See: https://github.com/modelcontextprotocol/python-sdk/issues/1821
275303
logger.info(f"Rejected request with unknown or expired session ID: {request_mcp_session_id[:64]}")
276-
error_response = JSONRPCError(
277-
jsonrpc="2.0",
278-
id=None,
279-
error=ErrorData(code=INVALID_REQUEST, message="Session not found"),
304+
body = JSONRPCError(
305+
jsonrpc="2.0", id=None, error=ErrorData(code=INVALID_REQUEST, message="Session not found")
280306
)
281307
response = Response(
282-
content=error_response.model_dump_json(by_alias=True, exclude_unset=True),
283-
status_code=HTTPStatus.NOT_FOUND,
284-
media_type="application/json",
308+
body.model_dump_json(by_alias=True, exclude_unset=True), status_code=404, media_type="application/json"
285309
)
286310
await response(scope, receive, send)
287311

src/mcp/server/transport_security.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(self, settings: TransportSecuritySettings | None = None):
4040
# If not specified, disable DNS rebinding protection by default for backwards compatibility
4141
self.settings = settings or TransportSecuritySettings(enable_dns_rebinding_protection=False)
4242

43-
def _validate_host(self, host: str | None) -> bool: # pragma: no cover
43+
def _validate_host(self, host: str | None) -> bool:
4444
"""Validate the Host header against allowed values."""
4545
if not host:
4646
logger.warning("Missing Host header in request")
@@ -62,7 +62,7 @@ def _validate_host(self, host: str | None) -> bool: # pragma: no cover
6262
logger.warning(f"Invalid Host header: {host}")
6363
return False
6464

65-
def _validate_origin(self, origin: str | None) -> bool: # pragma: no cover
65+
def _validate_origin(self, origin: str | None) -> bool:
6666
"""Validate the Origin header against allowed values."""
6767
# Origin can be absent for same-origin requests
6868
if not origin:
@@ -94,7 +94,7 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res
9494
Returns None if validation passes, or an error Response if validation fails.
9595
"""
9696
# Always validate Content-Type for POST requests
97-
if is_post: # pragma: no branch
97+
if is_post:
9898
content_type = request.headers.get("content-type")
9999
if not self._validate_content_type(content_type):
100100
return Response("Invalid Content-Type header", status_code=400)
@@ -103,14 +103,14 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res
103103
if not self.settings.enable_dns_rebinding_protection:
104104
return None
105105

106-
# Validate Host header # pragma: no cover
107-
host = request.headers.get("host") # pragma: no cover
108-
if not self._validate_host(host): # pragma: no cover
109-
return Response("Invalid Host header", status_code=421) # pragma: no cover
106+
# Validate Host header
107+
host = request.headers.get("host")
108+
if not self._validate_host(host):
109+
return Response("Invalid Host header", status_code=421)
110110

111-
# Validate Origin header # pragma: no cover
112-
origin = request.headers.get("origin") # pragma: no cover
113-
if not self._validate_origin(origin): # pragma: no cover
114-
return Response("Invalid Origin header", status_code=403) # pragma: no cover
111+
# Validate Origin header
112+
origin = request.headers.get("origin")
113+
if not self._validate_origin(origin):
114+
return Response("Invalid Origin header", status_code=403)
115115

116-
return None # pragma: no cover
116+
return None

0 commit comments

Comments
 (0)