Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions crates/rmcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,16 @@ required-features = [
]
path = "tests/test_streamable_http_stale_session.rs"

[[test]]
name = "test_streamable_http_connection_reuse"
required-features = [
"server",
"client",
"macros",
"schemars",
"transport-streamable-http-server",
"transport-streamable-http-client",
"transport-streamable-http-client-reqwest",
]
path = "tests/test_streamable_http_connection_reuse.rs"

Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ impl StreamableHttpClientTransport<reqwest::Client> {
/// This method requires the `transport-streamable-http-client-reqwest` feature.
pub fn from_uri(uri: impl Into<Arc<str>>) -> Self {
StreamableHttpClientTransport::with_client(
reqwest::Client::default(),
Self::default_http_client(),
StreamableHttpClientTransportConfig {
uri: uri.into(),
auth_header: None,
Expand All @@ -277,7 +277,19 @@ impl StreamableHttpClientTransport<reqwest::Client> {
///
/// * `config` - The config to use with this transport
pub fn from_config(config: StreamableHttpClientTransportConfig) -> Self {
StreamableHttpClientTransport::with_client(reqwest::Client::default(), config)
StreamableHttpClientTransport::with_client(Self::default_http_client(), config)
}

/// Build the default reqwest client for this transport.
///
/// Disables idle connection pooling to avoid ~40 ms stalls caused by
/// TCP Delayed ACK on Linux when the previous response body was not
/// fully consumed before the pool attempts to reuse the connection.
fn default_http_client() -> reqwest::Client {
reqwest::Client::builder()
.pool_max_idle_per_host(0)
.build()
.expect("failed to build default reqwest client")
}
}

Expand Down
118 changes: 54 additions & 64 deletions crates/rmcp/src/transport/streamable_http_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,37 @@ impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
}

impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
/// Convert a raw SSE stream into a JSON-RPC message stream without
/// reconnection logic.
fn raw_sse_to_jsonrpc(
stream: BoxedSseStream,
) -> impl Stream<Item = Result<ServerJsonRpcMessage, StreamableHttpError<C::Error>>> + Send + 'static
{
stream.filter_map(|event| async {
match event {
Err(e) => Some(Err(StreamableHttpError::Sse(e))),
Ok(sse) => {
let is_message =
matches!(sse.event.as_deref(), None | Some("") | Some("message"));
if !is_message {
return None;
}
let data = sse.data?;
if data.trim().is_empty() {
return None;
}
match serde_json::from_str::<ServerJsonRpcMessage>(&data) {
Ok(msg) => Some(Ok(msg)),
Err(e) => {
tracing::debug!("failed to deserialize server message: {e}");
None
}
}
}
}
})
}

async fn execute_sse_stream(
sse_stream: impl Stream<Item = Result<ServerJsonRpcMessage, StreamableHttpError<C::Error>>>
+ Send
Expand All @@ -303,14 +334,23 @@ impl<C: StreamableHttpClient> StreamableHttpClientWorker<C> {
let Some(message) = message.transpose()? else {
break;
};
let is_response = matches!(message, ServerJsonRpcMessage::Response(_));
let is_response = matches!(
message,
ServerJsonRpcMessage::Response(_) | ServerJsonRpcMessage::Error(_)
);
let yield_result = sse_worker_tx.send(message).await;
if yield_result.is_err() {
tracing::trace!("streamable http transport worker dropped, exiting");
break;
}
if close_on_response && is_response {
tracing::debug!("got response, closing sse stream");
tracing::debug!("got response, draining sse stream for connection reuse");
// Consume the remaining stream so the HTTP/1.1 connection
// returns to the pool cleanly.
let _ = tokio::time::timeout(std::time::Duration::from_millis(50), async {
while sse_stream.next().await.is_some() {}
})
.await;
break;
}
}
Expand Down Expand Up @@ -718,38 +758,12 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
Ok(())
}
Ok(StreamableHttpPostResponse::Sse(stream, ..)) => {
if let Some(sid) = &session_id {
let sse_stream = SseAutoReconnectStream::new(
stream,
StreamableHttpClientReconnect {
client: self.client.clone(),
session_id: sid.clone(),
uri: config.uri.clone(),
auth_header: config.auth_header.clone(),
custom_headers: protocol_headers
.clone(),
},
self.config.retry_config.clone(),
);
streams.spawn(Self::execute_sse_stream(
sse_stream,
sse_worker_tx.clone(),
true,
transport_task_ct.child_token(),
));
} else {
let sse_stream =
SseAutoReconnectStream::never_reconnect(
stream,
StreamableHttpError::<C::Error>::UnexpectedEndOfStream,
);
streams.spawn(Self::execute_sse_stream(
sse_stream,
sse_worker_tx.clone(),
true,
transport_task_ct.child_token(),
));
}
streams.spawn(Self::execute_sse_stream(
Self::raw_sse_to_jsonrpc(stream),
sse_worker_tx.clone(),
true,
transport_task_ct.child_token(),
));
tracing::trace!("got new sse stream after re-init");
Ok(())
}
Expand All @@ -769,36 +783,12 @@ impl<C: StreamableHttpClient> Worker for StreamableHttpClientWorker<C> {
Ok(())
}
Ok(StreamableHttpPostResponse::Sse(stream, ..)) => {
if let Some(session_id) = &session_id {
let sse_stream = SseAutoReconnectStream::new(
stream,
StreamableHttpClientReconnect {
client: self.client.clone(),
session_id: session_id.clone(),
uri: config.uri.clone(),
auth_header: config.auth_header.clone(),
custom_headers: protocol_headers.clone(),
},
self.config.retry_config.clone(),
);
streams.spawn(Self::execute_sse_stream(
sse_stream,
sse_worker_tx.clone(),
true,
transport_task_ct.child_token(),
));
} else {
let sse_stream = SseAutoReconnectStream::never_reconnect(
stream,
StreamableHttpError::<C::Error>::UnexpectedEndOfStream,
);
streams.spawn(Self::execute_sse_stream(
sse_stream,
sse_worker_tx.clone(),
true,
transport_task_ct.child_token(),
));
}
streams.spawn(Self::execute_sse_stream(
Self::raw_sse_to_jsonrpc(stream),
sse_worker_tx.clone(),
true,
transport_task_ct.child_token(),
));
tracing::trace!("got new sse stream");
Ok(())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ impl LocalSessionWorker {
{
OutboundChannel::RequestWise {
id: *id,
close: false,
close: true,
}
} else {
OutboundChannel::Common
Expand All @@ -483,7 +483,7 @@ impl LocalSessionWorker {
{
OutboundChannel::RequestWise {
id: *id,
close: false,
close: true,
}
} else {
OutboundChannel::Common
Expand All @@ -501,7 +501,11 @@ impl LocalSessionWorker {
if let Some(request_wise) = self.tx_router.get_mut(&id) {
request_wise.tx.send(message).await;
if close {
self.tx_router.remove(&id);
if let Some(channel) = self.tx_router.remove(&id) {
for resource in channel.resources {
self.resource_router.remove(&resource);
}
}
}
} else {
return Err(SessionError::ChannelClosed(Some(id)));
Expand Down
122 changes: 122 additions & 0 deletions crates/rmcp/tests/test_streamable_http_connection_reuse.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
#![cfg(not(feature = "local"))]

use std::time::Instant;

use rmcp::{
ServerHandler, ServiceExt,
handler::server::{router::tool::ToolRouter, wrapper::Parameters},
model::{CallToolRequestParams, ClientInfo, ServerCapabilities, ServerInfo},
schemars, tool, tool_handler, tool_router,
transport::{
StreamableHttpClientTransport,
streamable_http_client::StreamableHttpClientTransportConfig,
streamable_http_server::{
StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager,
},
},
};
use tokio_util::sync::CancellationToken;

#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
struct SumRequest {
a: i32,
b: i32,
}

#[derive(Debug, Clone)]
struct SumServer {
tool_router: ToolRouter<Self>,
}

impl SumServer {
fn new() -> Self {
Self {
tool_router: Self::tool_router(),
}
}
}

#[tool_router]
impl SumServer {
#[tool(description = "Sum two numbers")]
fn sum(&self, Parameters(SumRequest { a, b }): Parameters<SumRequest>) -> String {
(a + b).to_string()
}
}

#[tool_handler(router = self.tool_router)]
impl ServerHandler for SumServer {
fn get_info(&self) -> ServerInfo {
ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
}
}

/// Verify that subsequent tool calls do not regress in latency due to
/// HTTP/1.1 connection pool exhaustion. Before the fix, each POST SSE
/// response was dropped without fully consuming the body, preventing
/// connection reuse and forcing a new TCP connection (~40 ms) per call.
#[tokio::test]
async fn test_subsequent_tool_calls_reuse_connections() -> anyhow::Result<()> {
let ct = CancellationToken::new();

let service: StreamableHttpService<SumServer, LocalSessionManager> = StreamableHttpService::new(
|| Ok(SumServer::new()),
Default::default(),
StreamableHttpServerConfig::default()
.with_sse_keep_alive(None)
.with_cancellation_token(ct.child_token()),
);

let router = axum::Router::new().nest_service("/mcp", service);
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
let addr = listener.local_addr()?;

let server_handle = tokio::spawn({
let ct = ct.clone();
async move {
let _ = axum::serve(listener, router)
.with_graceful_shutdown(async move { ct.cancelled_owned().await })
.await;
}
});

let transport = StreamableHttpClientTransport::from_config(
StreamableHttpClientTransportConfig::with_uri(format!("http://{addr}/mcp")),
);
let client = ClientInfo::default().serve(transport).await?;

// Warm up: first call may include one-time setup costs.
let args: serde_json::Map<String, serde_json::Value> =
serde_json::from_value(serde_json::json!({"a": 1, "b": 2}))?;
let _ = client
.call_tool(CallToolRequestParams::new("sum").with_arguments(args))
.await?;

// Measure subsequent calls.
let mut durations = Vec::new();
for i in 0..5i32 {
let args: serde_json::Map<String, serde_json::Value> =
serde_json::from_value(serde_json::json!({"a": i, "b": i + 1}))?;
let start = Instant::now();
let result = client
.call_tool(CallToolRequestParams::new("sum").with_arguments(args))
.await?;
let elapsed = start.elapsed();
durations.push(elapsed);

assert!(result.is_error != Some(true));
}

let _ = client.cancel().await;
ct.cancel();
server_handle.await?;

// With connection reuse, localhost calls should complete well under 20 ms.
// Before the fix, they consistently took ~42 ms due to new TCP connections.
let max_allowed = std::time::Duration::from_millis(20);
for d in &durations {
assert!(*d < max_allowed);
}

Ok(())
}
Loading