Skip to content

Commit 616476f

Browse files
authored
Bind transport sessions to the authenticated principal (#2718)
1 parent 2472563 commit 616476f

7 files changed

Lines changed: 647 additions & 42 deletions

File tree

src/mcp/server/auth/middleware/bearer_auth.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
import time
3-
from typing import Any
3+
from typing import Any, TypedDict
44

55
from pydantic import AnyHttpUrl
66
from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser
@@ -19,6 +19,30 @@ def __init__(self, auth_info: AccessToken):
1919
self.scopes = auth_info.scopes
2020

2121

22+
class AuthorizationContext(TypedDict):
23+
client_id: str
24+
issuer: str | None
25+
subject: str | None
26+
27+
28+
def authorization_context(user: AuthenticatedUser) -> AuthorizationContext:
29+
"""Identify the principal `user` represents, for transports to compare
30+
against the principal that created a session. Components the token
31+
verifier does not supply are `None`, so the comparison degrades to the
32+
remaining components.
33+
34+
See `examples/servers/simple-auth/mcp_simple_auth/token_verifier.py` for
35+
a verifier that populates `subject` and `claims` from an introspection
36+
response."""
37+
token = user.access_token
38+
issuer = (token.claims or {}).get("iss")
39+
return AuthorizationContext(
40+
client_id=token.client_id,
41+
issuer=str(issuer) if issuer is not None else None,
42+
subject=token.subject,
43+
)
44+
45+
2246
class BearerAuthBackend(AuthenticationBackend):
2347
"""Authentication backend that validates Bearer tokens using a TokenVerifier."""
2448

src/mcp/server/sse.py

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ async def handle_sse(request):
5050
from starlette.types import Receive, Scope, Send
5151

5252
from mcp import types
53+
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context
5354
from mcp.server.transport_security import (
5455
TransportSecurityMiddleware,
5556
TransportSecuritySettings,
@@ -73,6 +74,9 @@ class SseServerTransport:
7374

7475
_endpoint: str
7576
_read_stream_writers: dict[UUID, ContextSendStream[SessionMessage | Exception]]
77+
# Identity of the credential that created each session; requests for a
78+
# session must present the same credential.
79+
_session_owners: dict[UUID, AuthorizationContext]
7680
_security: TransportSecurityMiddleware
7781

7882
def __init__(self, endpoint: str, security_settings: TransportSecuritySettings | None = None) -> None:
@@ -112,19 +116,20 @@ def __init__(self, endpoint: str, security_settings: TransportSecuritySettings |
112116

113117
self._endpoint = endpoint
114118
self._read_stream_writers = {}
119+
self._session_owners = {}
115120
self._security = TransportSecurityMiddleware(security_settings)
116121
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
117122

118123
@asynccontextmanager
119124
async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
120-
if scope["type"] != "http": # pragma: no cover
125+
if scope["type"] != "http":
121126
logger.error("connect_sse received non-HTTP request")
122127
raise ValueError("connect_sse can only handle HTTP requests")
123128

124129
# Validate request headers for DNS rebinding protection
125130
request = Request(scope, receive)
126131
error_response = await self._security.validate_request(request, is_post=False)
127-
if error_response: # pragma: no cover
132+
if error_response:
128133
await error_response(scope, receive, send)
129134
raise ValueError("Request validation failed")
130135

@@ -134,6 +139,9 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
134139
write_stream, write_stream_reader = create_context_streams[SessionMessage](0)
135140

136141
session_id = uuid4()
142+
user = scope.get("user")
143+
if isinstance(user, AuthenticatedUser):
144+
self._session_owners[session_id] = authorization_context(user)
137145
self._read_stream_writers[session_id] = read_stream_writer
138146
logger.debug(f"Created new session with ID: {session_id}")
139147

@@ -169,35 +177,38 @@ async def sse_writer():
169177
}
170178
)
171179

172-
async with anyio.create_task_group() as tg:
173-
174-
async def response_wrapper(scope: Scope, receive: Receive, send: Send):
175-
"""The EventSourceResponse returning signals a client close / disconnect.
176-
In this case we close our side of the streams to signal the client that
177-
the connection has been closed.
178-
"""
179-
await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)(
180-
scope, receive, send
181-
)
182-
await sse_stream_reader.aclose()
183-
await read_stream_writer.aclose()
184-
await write_stream_reader.aclose()
185-
self._read_stream_writers.pop(session_id, None)
186-
logging.debug(f"Client session disconnected {session_id}")
180+
try:
181+
async with anyio.create_task_group() as tg:
182+
183+
async def response_wrapper(scope: Scope, receive: Receive, send: Send):
184+
"""The EventSourceResponse returning signals a client close / disconnect.
185+
In this case we close our side of the streams to signal the client that
186+
the connection has been closed.
187+
"""
188+
await EventSourceResponse(content=sse_stream_reader, data_sender_callable=sse_writer)(
189+
scope, receive, send
190+
)
191+
await read_stream_writer.aclose()
192+
await write_stream_reader.aclose()
193+
await sse_stream_reader.aclose()
194+
logging.debug(f"Client session disconnected {session_id}")
187195

188-
logger.debug("Starting SSE response task")
189-
tg.start_soon(response_wrapper, scope, receive, send)
196+
logger.debug("Starting SSE response task")
197+
tg.start_soon(response_wrapper, scope, receive, send)
190198

191-
logger.debug("Yielding read and write streams")
192-
yield (read_stream, write_stream)
199+
logger.debug("Yielding read and write streams")
200+
yield (read_stream, write_stream)
201+
finally:
202+
self._read_stream_writers.pop(session_id, None)
203+
self._session_owners.pop(session_id, None)
193204

194205
async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) -> None:
195206
logger.debug("Handling POST message")
196207
request = Request(scope, receive)
197208

198209
# Validate request headers for DNS rebinding protection
199210
error_response = await self._security.validate_request(request, is_post=True)
200-
if error_response: # pragma: no cover
211+
if error_response:
201212
return await error_response(scope, receive, send)
202213

203214
session_id_param = request.query_params.get("session_id")
@@ -220,13 +231,22 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send)
220231
response = Response("Could not find session", status_code=404)
221232
return await response(scope, receive, send)
222233

234+
user = scope.get("user")
235+
requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None
236+
if requestor != self._session_owners.get(session_id):
237+
# A session can only be used with the credential that created it.
238+
# Respond exactly as if the session did not exist.
239+
logger.warning("Rejecting message for session %s: credential does not match", session_id)
240+
response = Response("Could not find session", status_code=404)
241+
return await response(scope, receive, send)
242+
223243
body = await request.body()
224244
logger.debug(f"Received JSON: {body}")
225245

226246
try:
227247
message = types.jsonrpc_message_adapter.validate_json(body, by_name=False)
228248
logger.debug(f"Validated client message: {message}")
229-
except ValidationError as err: # pragma: no cover
249+
except ValidationError as err:
230250
logger.exception("Failed to parse message")
231251
response = Response("Could not parse message", status_code=400)
232252
await response(scope, receive, send)

src/mcp/server/streamable_http_manager.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import contextlib
66
import logging
77
from collections.abc import AsyncIterator
8-
from http import HTTPStatus
98
from typing import TYPE_CHECKING, Any
109
from uuid import uuid4
1110

@@ -15,6 +14,7 @@
1514
from starlette.responses import Response
1615
from starlette.types import Receive, Scope, Send
1716

17+
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context
1818
from mcp.server.streamable_http import (
1919
MCP_SESSION_ID_HEADER,
2020
EventStore,
@@ -89,6 +89,9 @@ def __init__(
8989
# Session tracking (only used if not stateless)
9090
self._session_creation_lock = anyio.Lock()
9191
self._server_instances: dict[str, StreamableHTTPServerTransport] = {}
92+
# Identity of the credential that created each session; requests for a
93+
# session must present the same credential.
94+
self._session_owners: dict[str, AuthorizationContext] = {}
9295

9396
# The task group will be set during lifespan
9497
self._task_group = None
@@ -135,6 +138,7 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
135138
self._task_group = None
136139
# Clear any remaining server instances
137140
self._server_instances.clear()
141+
self._session_owners.clear()
138142

139143
async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> None:
140144
"""Process ASGI request with proper session handling and transport setup.
@@ -192,9 +196,29 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S
192196
request = Request(scope, receive)
193197
request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)
194198

199+
user = scope.get("user")
200+
requestor = authorization_context(user) if isinstance(user, AuthenticatedUser) else None
201+
195202
# Existing session case
196203
if request_mcp_session_id is not None and request_mcp_session_id in self._server_instances:
197204
transport = self._server_instances[request_mcp_session_id]
205+
if requestor != self._session_owners.get(request_mcp_session_id):
206+
# A session can only be used with the credential that created
207+
# it. Respond exactly as if the session did not exist.
208+
logger.warning(
209+
"Rejecting request for session %s: credential does not match the one that created the session",
210+
request_mcp_session_id[:64],
211+
)
212+
body = JSONRPCError(
213+
jsonrpc="2.0", id=None, error=ErrorData(code=INVALID_REQUEST, message="Session not found")
214+
)
215+
response = Response(
216+
body.model_dump_json(by_alias=True, exclude_unset=True),
217+
status_code=404,
218+
media_type="application/json",
219+
)
220+
await response(scope, receive, send)
221+
return
198222
logger.debug("Session already exists, handling request directly")
199223
# Push back idle deadline on activity
200224
if transport.idle_scope is not None and self.session_idle_timeout is not None:
@@ -216,6 +240,8 @@ async def _handle_stateful_request(self, scope: Scope, receive: Receive, send: S
216240
)
217241

218242
assert http_transport.mcp_session_id is not None
243+
if requestor is not None:
244+
self._session_owners[http_transport.mcp_session_id] = requestor
219245
self._server_instances[http_transport.mcp_session_id] = http_transport
220246
logger.info(f"Created new transport with session ID: {new_session_id}")
221247

@@ -246,6 +272,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
246272
assert http_transport.mcp_session_id is not None
247273
logger.info(f"Session {http_transport.mcp_session_id} idle timeout")
248274
self._server_instances.pop(http_transport.mcp_session_id, None)
275+
self._session_owners.pop(http_transport.mcp_session_id, None)
249276
await http_transport.terminate()
250277
except Exception:
251278
logger.exception(f"Session {http_transport.mcp_session_id} crashed")
@@ -260,6 +287,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
260287
f"{http_transport.mcp_session_id} from active instances."
261288
)
262289
del self._server_instances[http_transport.mcp_session_id]
290+
self._session_owners.pop(http_transport.mcp_session_id, None)
263291

264292
# Assert task group is not None for type checking
265293
assert self._task_group is not None
@@ -273,15 +301,11 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
273301
# TODO: Align error code once spec clarifies
274302
# See: https://github.com/modelcontextprotocol/python-sdk/issues/1821
275303
logger.info(f"Rejected request with unknown or expired session ID: {request_mcp_session_id[:64]}")
276-
error_response = JSONRPCError(
277-
jsonrpc="2.0",
278-
id=None,
279-
error=ErrorData(code=INVALID_REQUEST, message="Session not found"),
304+
body = JSONRPCError(
305+
jsonrpc="2.0", id=None, error=ErrorData(code=INVALID_REQUEST, message="Session not found")
280306
)
281307
response = Response(
282-
content=error_response.model_dump_json(by_alias=True, exclude_unset=True),
283-
status_code=HTTPStatus.NOT_FOUND,
284-
media_type="application/json",
308+
body.model_dump_json(by_alias=True, exclude_unset=True), status_code=404, media_type="application/json"
285309
)
286310
await response(scope, receive, send)
287311

src/mcp/server/transport_security.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,17 +42,17 @@ def __init__(self, settings: TransportSecuritySettings | None = None):
4242

4343
def _validate_host(self, host: str | None) -> bool:
4444
"""Validate the Host header against allowed values."""
45-
if not host: # pragma: no cover
45+
if not host:
4646
logger.warning("Missing Host header in request")
4747
return False
4848

4949
# Check exact match first
50-
if host in self.settings.allowed_hosts: # pragma: no cover
50+
if host in self.settings.allowed_hosts:
5151
return True
5252

5353
# Check wildcard port patterns
5454
for allowed in self.settings.allowed_hosts:
55-
if allowed.endswith(":*"): # pragma: no branch
55+
if allowed.endswith(":*"):
5656
# Extract base host from pattern
5757
base_host = allowed[:-2]
5858
# Check if the actual host starts with base host and has a port
@@ -65,16 +65,16 @@ def _validate_host(self, host: str | None) -> bool:
6565
def _validate_origin(self, origin: str | None) -> bool:
6666
"""Validate the Origin header against allowed values."""
6767
# Origin can be absent for same-origin requests
68-
if not origin: # pragma: no cover
68+
if not origin:
6969
return True
7070

7171
# Check exact match first
72-
if origin in self.settings.allowed_origins: # pragma: no cover
72+
if origin in self.settings.allowed_origins:
7373
return True
7474

7575
# Check wildcard port patterns
7676
for allowed in self.settings.allowed_origins:
77-
if allowed.endswith(":*"): # pragma: no branch
77+
if allowed.endswith(":*"):
7878
# Extract base origin from pattern
7979
base_origin = allowed[:-2]
8080
# Check if the actual origin starts with base origin and has a port
@@ -94,7 +94,7 @@ async def validate_request(self, request: Request, is_post: bool = False) -> Res
9494
Returns None if validation passes, or an error Response if validation fails.
9595
"""
9696
# Always validate Content-Type for POST requests
97-
if is_post: # pragma: no branch
97+
if is_post:
9898
content_type = request.headers.get("content-type")
9999
if not self._validate_content_type(content_type):
100100
return Response("Invalid Content-Type header", status_code=400)

0 commit comments

Comments
 (0)