Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 47 additions & 19 deletions crates/rmcp/src/transport/streamable_http_server/session/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{
time::Duration,
};

use futures::Stream;
use futures::{Stream, StreamExt};
use thiserror::Error;
use tokio::sync::{
mpsc::{Receiver, Sender},
Expand Down Expand Up @@ -86,10 +86,17 @@ impl SessionManager for LocalSessionManager {
.get(id)
.ok_or(LocalSessionManagerError::SessionNotFound(id.clone()))?;
let receiver = handle.establish_request_wise_channel().await?;
handle
.push_message(message, receiver.http_request_id)
.await?;
Ok(ReceiverStream::new(receiver.inner))
let http_request_id = receiver.http_request_id;
handle.push_message(message, http_request_id).await?;

let priming = self.session_config.sse_retry.map(|retry| {
let event_id = match http_request_id {
Some(id) => format!("0/{id}"),
None => "0".into(),
};
ServerSseMessage::priming(event_id, retry)
});
Ok(futures::stream::iter(priming).chain(ReceiverStream::new(receiver.inner)))
}

async fn create_standalone_stream(
Expand Down Expand Up @@ -188,23 +195,29 @@ struct CachedTx {
cache: VecDeque<ServerSseMessage>,
http_request_id: Option<HttpRequestId>,
capacity: usize,
starting_index: usize,
}

impl CachedTx {
fn new(tx: Sender<ServerSseMessage>, http_request_id: Option<HttpRequestId>) -> Self {
/// Construct a cached sender wrapping an SSE channel.
///
/// `starting_index` is the event index reported while the replay cache is
/// still empty (non-zero when a priming event occupies index 0), and the
/// cache capacity mirrors the channel's capacity so replay can cover every
/// in-flight message.
fn new(
    tx: Sender<ServerSseMessage>,
    http_request_id: Option<HttpRequestId>,
    starting_index: usize,
) -> Self {
    // Read the channel capacity once; nothing is sent between uses.
    let capacity = tx.capacity();
    Self {
        tx,
        http_request_id,
        starting_index,
        capacity,
        cache: VecDeque::with_capacity(capacity),
    }
}
/// Construct the session-wide (common) cached sender: no HTTP request id
/// and event indices starting at 0 (no priming event on the common stream).
fn new_common(tx: Sender<ServerSseMessage>) -> Self {
    // Keep only the current three-argument call; the stale two-argument
    // form no longer matches `CachedTx::new`'s signature.
    Self::new(tx, None, 0)
}

fn next_event_id(&self) -> EventId {
let index = self.cache.back().map_or(0, |m| {
let index = self.cache.back().map_or(self.starting_index, |m| {
m.event_id
.as_deref()
.unwrap_or_default()
Expand Down Expand Up @@ -350,10 +363,15 @@ impl LocalSessionWorker {
if channel.resources.is_empty() || matches!(resource, ResourceKey::McpRequestId(_))
{
tracing::debug!(http_request_id, "close http request wise channel");
if let Some(channel) = self.tx_router.remove(&http_request_id) {
for resource in channel.resources {
if let Some(channel) = self.tx_router.get_mut(&http_request_id) {
for resource in channel.resources.drain() {
self.resource_router.remove(&resource);
}
// Replace the sender with a closed dummy so no new
// messages are routed here, but the cache stays alive
// for late resume requests.
let (closed_tx, _) = tokio::sync::mpsc::channel(1);
channel.tx.tx = closed_tx;
}
}
} else {
Expand Down Expand Up @@ -403,13 +421,15 @@ impl LocalSessionWorker {
async fn establish_request_wise_channel(
&mut self,
) -> Result<StreamableHttpMessageReceiver, SessionError> {
self.tx_router.retain(|_, rw| !rw.tx.tx.is_closed());
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, the code assumes that once a new stream is created we can discard streams that were closed (including those that experienced a disconnect). This can lead to a scenario where a stream is discarded before it has been fully consumed — for example, if the client performed another request while it was waiting to issue the resume GET request. Example flow:

Time 0: Client issues req A (long running task that takes for example 10 seconds)
Time 9: Client receives disconnect and now will wait 3 second before performing GET resume request
Time 11: Client issues req B. Req A is discarded as `tx.is_closed()` 
Time 12: Client sends a GET request to resume the stream from req A. But the stream is `not found`.

Additionally, there is a memory risk here. If a client doesn't create a new stream, the router is never cleaned up. This may lead to unnecessary memory consumption when many clients maintain a session but are not active.

Perhaps a different approach is needed, such as cleaning up each HttpRequestWise entry after a timeout period once it has completed.

let http_request_id = self.next_http_request_id();
let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
let starting_index = usize::from(self.session_config.sse_retry.is_some());
self.tx_router.insert(
http_request_id,
HttpRequestWise {
resources: Default::default(),
tx: CachedTx::new(tx, Some(http_request_id)),
tx: CachedTx::new(tx, Some(http_request_id), starting_index),
},
);
tracing::debug!(http_request_id, "establish new request wise channel");
Expand Down Expand Up @@ -521,24 +541,25 @@ impl LocalSessionWorker {
match last_event_id.http_request_id {
Some(http_request_id) => {
if let Some(request_wise) = self.tx_router.get_mut(&http_request_id) {
// Resume existing request-wise channel
let channel = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
let (tx, rx) = channel;
let was_completed = request_wise.tx.tx.is_closed();
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, `tx.is_closed()` doesn't necessarily mean the request completed. A disconnect from the client will also cause `tx.is_closed()`. If there was a disconnect and the request was not completed, the stream should be left active to send messages as the request continues processing. I think an additional `completed` field needs to be added to `HttpRequestWise`, set at `unregister_resource`.

let (tx, rx) = tokio::sync::mpsc::channel(self.session_config.channel_capacity);
request_wise.tx.tx = tx;
let index = last_event_id.index;
// sync messages after index
request_wise.tx.sync(index).await?;
if was_completed {
// Close the sender after replaying so the stream ends
// instead of hanging indefinitely.
let (closed_tx, _) = tokio::sync::mpsc::channel(1);
request_wise.tx.tx = closed_tx;
}
Ok(StreamableHttpMessageReceiver {
http_request_id: Some(http_request_id),
inner: rx,
})
} else {
// Request-wise channel completed (POST response already delivered).
// The client's EventSource is reconnecting after the POST SSE stream
// ended. Fall through to common channel handling below.
tracing::debug!(
http_request_id,
"Request-wise channel completed, falling back to common channel"
"Request-wise channel not found, falling back to common channel"
);
self.resume_or_shadow_common(last_event_id.index).await
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it correct to fall back to `resume_or_shadow_common`? If an `http_request_id` was provided and it is not found, shouldn't we return an error? It seems to me that this will deliver messages from a different stream than the one the client expects.

From the spec:

The server MUST NOT replay messages that would have been delivered on a different stream.

}
Expand Down Expand Up @@ -1072,18 +1093,25 @@ pub struct SessionConfig {
/// Defaults to 5 minutes. Set to `None` to disable (not recommended
/// for long-running servers behind proxies).
pub keep_alive: Option<Duration>,
/// SSE retry interval for priming events on request-wise streams.
/// When set, the session layer prepends a priming event with the correct
/// stream-identifying event ID to each request-wise SSE stream.
/// Default is 3 seconds, matching `StreamableHttpServerConfig::default()`.
pub sse_retry: Option<Duration>,
}

impl SessionConfig {
/// Default buffer size for each SSE message channel (16 messages).
pub const DEFAULT_CHANNEL_CAPACITY: usize = 16;
/// Default keep-alive interval (5 minutes) used by `SessionConfig::keep_alive`.
pub const DEFAULT_KEEP_ALIVE: Duration = Duration::from_secs(300);
/// Default SSE retry interval (3 seconds) used by `SessionConfig::sse_retry`
/// for priming events on request-wise streams.
pub const DEFAULT_SSE_RETRY: Duration = Duration::from_secs(3);
}

impl Default for SessionConfig {
fn default() -> Self {
Self {
channel_capacity: Self::DEFAULT_CHANNEL_CAPACITY,
keep_alive: Some(Self::DEFAULT_KEEP_ALIVE),
sse_retry: Some(Self::DEFAULT_SSE_RETRY),
}
}
}
Expand Down
12 changes: 3 additions & 9 deletions crates/rmcp/src/transport/streamable_http_server/tower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -598,20 +598,14 @@ where

match message {
ClientJsonRpcMessage::Request(_) => {
// Priming for request-wise streams is handled by the
// session layer (SessionManager::create_stream) which
// has access to the http_request_id for correct event IDs.
let stream = self
.session_manager
.create_stream(&session_id, message)
.await
.map_err(internal_error_response("get session"))?;
// Prepend priming event if sse_retry configured
let stream = if let Some(retry) = self.config.sse_retry {
let priming = ServerSseMessage::priming("0", retry);
futures::stream::once(async move { priming })
.chain(stream)
.left_stream()
} else {
stream.right_stream()
};
Ok(sse_stream_response(
stream,
self.config.sse_keep_alive,
Expand Down
Loading
Loading