From d494169fdd1222ad87d15ce3b7f5668cb257ab13 Mon Sep 17 00:00:00 2001 From: joshua-spacetime Date: Wed, 8 Apr 2026 13:34:58 -0700 Subject: [PATCH 1/4] v3 websocket protocol --- crates/client-api-messages/DEVELOP.md | 9 ++ .../examples/get_ws_schema_v3.rs | 13 ++ crates/client-api-messages/src/websocket.rs | 1 + .../client-api-messages/src/websocket/v3.rs | 28 ++++ crates/client-api/src/routes/subscribe.rs | 132 +++++++++++++----- crates/core/src/client.rs | 1 + crates/core/src/client/client_connection.rs | 3 +- crates/core/src/client/message_handlers.rs | 1 + crates/core/src/client/message_handlers_v2.rs | 8 ++ crates/core/src/client/message_handlers_v3.rs | 32 +++++ crates/core/src/client/messages.rs | 62 +++++--- .../subscription/module_subscription_actor.rs | 2 +- 12 files changed, 237 insertions(+), 55 deletions(-) create mode 100644 crates/client-api-messages/examples/get_ws_schema_v3.rs create mode 100644 crates/client-api-messages/src/websocket/v3.rs create mode 100644 crates/core/src/client/message_handlers_v3.rs diff --git a/crates/client-api-messages/DEVELOP.md b/crates/client-api-messages/DEVELOP.md index 47868d4d3ce..48341bd80aa 100644 --- a/crates/client-api-messages/DEVELOP.md +++ b/crates/client-api-messages/DEVELOP.md @@ -19,3 +19,12 @@ spacetime generate -p spacetimedb-cli --lang \ --out-dir \ --module-def ws_schema_v2.json ``` + +For the v3 WebSocket transport schema: + +```sh +cargo run --example get_ws_schema_v3 > ws_schema_v3.json +spacetime generate -p spacetimedb-cli --lang \ + --out-dir \ + --module-def ws_schema_v3.json +``` diff --git a/crates/client-api-messages/examples/get_ws_schema_v3.rs b/crates/client-api-messages/examples/get_ws_schema_v3.rs new file mode 100644 index 00000000000..b4a752a5664 --- /dev/null +++ b/crates/client-api-messages/examples/get_ws_schema_v3.rs @@ -0,0 +1,13 @@ +use spacetimedb_client_api_messages::websocket::v3::{ClientFrame, ServerFrame}; +use spacetimedb_lib::ser::serde::SerializeWrapper; +use spacetimedb_lib::{RawModuleDef, RawModuleDefV8}; + +fn main() -> Result<(), serde_json::Error> { + let module = RawModuleDefV8::with_builder(|module| { + module.add_type::(); + module.add_type::(); + }); + let module = RawModuleDef::V8BackCompat(module); + + serde_json::to_writer(std::io::stdout().lock(), SerializeWrapper::from_ref(&module)) +} diff --git a/crates/client-api-messages/src/websocket.rs b/crates/client-api-messages/src/websocket.rs index 0935d2e3c55..14ec394670f 100644 --- a/crates/client-api-messages/src/websocket.rs +++ b/crates/client-api-messages/src/websocket.rs @@ -17,3 +17,4 @@ pub mod common; pub mod v1; pub mod v2; +pub mod v3; diff --git a/crates/client-api-messages/src/websocket/v3.rs b/crates/client-api-messages/src/websocket/v3.rs new file mode 100644 index 00000000000..5be37768299 --- /dev/null +++ b/crates/client-api-messages/src/websocket/v3.rs @@ -0,0 +1,28 @@ +use bytes::Bytes; +pub use spacetimedb_sats::SpacetimeType; + +pub const BIN_PROTOCOL: &str = "v3.bsatn.spacetimedb"; + +/// Transport envelopes sent by the client over the v3 websocket protocol. +/// +/// The inner bytes are BSATN-encoded v2 [`crate::websocket::v2::ClientMessage`] values. +#[derive(SpacetimeType, Debug)] +#[sats(crate = spacetimedb_lib)] +pub enum ClientFrame { + /// A single logical client message. + Single(Bytes), + /// Multiple logical client messages that should be processed in-order. + Batch(Box<[Bytes]>), +} + +/// Transport envelopes sent by the server over the v3 websocket protocol. +/// +/// The inner bytes are BSATN-encoded v2 [`crate::websocket::v2::ServerMessage`] values. +#[derive(SpacetimeType, Debug)] +#[sats(crate = spacetimedb_lib)] +pub enum ServerFrame { + /// A single logical server message. + Single(Bytes), + /// Multiple logical server messages that should be processed in-order. + Batch(Box<[Bytes]>), +} diff --git a/crates/client-api/src/routes/subscribe.rs b/crates/client-api/src/routes/subscribe.rs index d1bb1d2b11f..868e7425eda 100644 --- a/crates/client-api/src/routes/subscribe.rs +++ b/crates/client-api/src/routes/subscribe.rs @@ -23,8 +23,8 @@ use prometheus::{Histogram, IntGauge}; use scopeguard::{defer, ScopeGuard}; use serde::Deserialize; use spacetimedb::client::messages::{ - serialize, serialize_v2, IdentityTokenMessage, InUseSerializeBuffer, SerializeBuffer, SwitchedServerMessage, - ToProtocol, + serialize, serialize_v2, serialize_v3, IdentityTokenMessage, InUseSerializeBuffer, SerializeBuffer, + SwitchedServerMessage, ToProtocol, }; use spacetimedb::client::{ ClientActorId, ClientConfig, ClientConnection, ClientConnectionReceiver, DataMessage, MessageExecutionError, @@ -38,6 +38,7 @@ use spacetimedb::worker_metrics::WORKER_METRICS; use spacetimedb::Identity; use spacetimedb_client_api_messages::websocket::v1 as ws_v1; use spacetimedb_client_api_messages::websocket::v2 as ws_v2; +use spacetimedb_client_api_messages::websocket::v3 as ws_v3; use spacetimedb_datastore::execution_context::WorkloadType; use spacetimedb_lib::connection_id::{ConnectionId, ConnectionIdForUrl}; use tokio::sync::{mpsc, watch}; @@ -62,6 +63,8 @@ pub const TEXT_PROTOCOL: HeaderValue = HeaderValue::from_static(ws_v1::TEXT_PROT pub const BIN_PROTOCOL: HeaderValue = HeaderValue::from_static(ws_v1::BIN_PROTOCOL); #[allow(clippy::declare_interior_mutable_const)] pub const V2_BIN_PROTOCOL: HeaderValue = HeaderValue::from_static(ws_v2::BIN_PROTOCOL); +#[allow(clippy::declare_interior_mutable_const)] +pub const V3_BIN_PROTOCOL: HeaderValue = HeaderValue::from_static(ws_v3::BIN_PROTOCOL); pub trait HasWebSocketOptions { fn websocket_options(&self) -> WebSocketOptions; @@ -101,7 +104,7 @@ fn resolve_confirmed_reads_default(version: WsVersion, confirmed: Option) } match version { WsVersion::V1 => false, - WsVersion::V2 => crate::DEFAULT_CONFIRMED_READS, + WsVersion::V2 | WsVersion::V3 => crate::DEFAULT_CONFIRMED_READS, } } @@ -151,6 +154,13 @@ where } let (res, ws_upgrade, protocol) = ws.select_protocol([ + ( + V3_BIN_PROTOCOL, + NegotiatedProtocol { + protocol: Protocol::Binary, + version: WsVersion::V3, + }, + ), ( V2_BIN_PROTOCOL, NegotiatedProtocol { @@ -284,7 +294,7 @@ where }; client.send_message(None, OutboundMessage::V1(message.into())) } - WsVersion::V2 => { + WsVersion::V2 | WsVersion::V3 => { let message = ws_v2::ServerMessage::InitialConnection(ws_v2::InitialConnection { identity: client_identity, connection_id, @@ -1293,10 +1303,15 @@ async fn ws_encode_task( // copied to the wire. Since we don't know when that will happen, we prepare // for a few messages to be in-flight, i.e. encoded but not yet sent. const BUF_POOL_CAPACITY: usize = 16; + let binary_message_serializer = match config.version { + WsVersion::V1 => None, + WsVersion::V2 => Some(serialize_v2 as BinarySerializeFn), + WsVersion::V3 => Some(serialize_v3 as BinarySerializeFn), + }; let buf_pool = ArrayQueue::new(BUF_POOL_CAPACITY); let mut in_use_bufs: Vec> = Vec::with_capacity(BUF_POOL_CAPACITY); - while let Some(message) = messages.recv().await { + 'send: while let Some(message) = messages.recv().await { // Drop serialize buffers with no external referent, // returning them to the pool. in_use_bufs.retain(|in_use| !in_use.is_unique()); @@ -1306,16 +1321,22 @@ async fn ws_encode_task( let in_use_buf = match message { OutboundWsMessage::Error(message) => { - if config.version == WsVersion::V2 { - log::error!("dropping v1 error message sent to a v2 client: {:?}", message); + if config.version != WsVersion::V1 { + log::error!( + "dropping v1 error message sent to a binary websocket client: {:?}", + message + ); continue; } - let (stats, in_use, mut frames) = ws_encode_message(config, buf, message, false, &bsatn_rlb_pool).await; - metrics.report(None, None, stats); - if frames.try_for_each(|frame| outgoing_frames.send(frame)).is_err() { - break; - } - + let Ok(in_use) = ws_forward_frames( + &metrics, + &outgoing_frames, + None, + None, + ws_encode_message(config, buf, message, false, &bsatn_rlb_pool).await, + ) else { + break 'send; + }; in_use } OutboundWsMessage::Message(message) => { @@ -1323,38 +1344,47 @@ async fn ws_encode_task( let num_rows = message.num_rows(); match message { OutboundMessage::V2(server_message) => { - if config.version != WsVersion::V2 { + if config.version == WsVersion::V1 { log::error!("dropping v2 message on v1 connection"); continue; } - let (stats, in_use, mut frames) = - ws_encode_message_v2(config, buf, server_message, false, &bsatn_rlb_pool).await; - metrics.report(workload, num_rows, stats); - if frames.try_for_each(|frame| outgoing_frames.send(frame)).is_err() { - break; - } - + let Ok(in_use) = ws_forward_frames( + &metrics, + &outgoing_frames, + workload, + num_rows, + ws_encode_binary_message( + config, + buf, + server_message, + binary_message_serializer.expect("v2 message should not be sent on a v1 connection"), + false, + &bsatn_rlb_pool, + ) + .await, + ) else { + break 'send; + }; in_use } OutboundMessage::V1(message) => { - if config.version == WsVersion::V2 { - log::error!( - "dropping v1 message for v2 connection until v2 serialization is implemented: {:?}", - message - ); + if config.version != WsVersion::V1 { + log::error!("dropping v1 message for a binary websocket connection: {:?}", message); continue; } let is_large = num_rows.is_some_and(|n| n > 1024); - let (stats, in_use, mut frames) = - ws_encode_message(config, buf, message, is_large, &bsatn_rlb_pool).await; - metrics.report(workload, num_rows, stats); - if frames.try_for_each(|frame| outgoing_frames.send(frame)).is_err() { - break; - } - + let Ok(in_use) = ws_forward_frames( + &metrics, + &outgoing_frames, + workload, + num_rows, + ws_encode_message(config, buf, message, is_large, &bsatn_rlb_pool).await, + ) else { + break 'send; + }; in_use } } @@ -1370,6 +1400,24 @@ async fn ws_encode_task( } } +/// Reports encode metrics for an already-encoded message and forwards all of +/// its frames to the websocket send task. +fn ws_forward_frames( + metrics: &SendMetrics, + outgoing_frames: &mpsc::UnboundedSender, + workload: Option, + num_rows: Option, + encoded: (EncodeMetrics, InUseSerializeBuffer, I), +) -> Result> +where + I: Iterator, +{ + let (stats, in_use, frames) = encoded; + metrics.report(workload, num_rows, stats); + frames.into_iter().try_for_each(|frame| outgoing_frames.send(frame))?; + Ok(in_use) +} + /// Some stats about serialization and compression. /// /// Returned by [`ws_encode_message`]. @@ -1443,21 +1491,29 @@ async fn ws_encode_message( (metrics, msg_alloc, frames) } -#[allow(dead_code, unused_variables)] -async fn ws_encode_message_v2( +type BinarySerializeFn = fn( + &BsatnRowListBuilderPool, + SerializeBuffer, + ws_v2::ServerMessage, + ws_v1::Compression, +) -> (InUseSerializeBuffer, Bytes); + +async fn ws_encode_binary_message( config: ClientConfig, buf: SerializeBuffer, message: ws_v2::ServerMessage, + serialize_message: BinarySerializeFn, is_large_message: bool, bsatn_rlb_pool: &BsatnRowListBuilderPool, ) -> (EncodeMetrics, InUseSerializeBuffer, impl Iterator + use<>) { let start = Instant::now(); + let compression = config.compression; let (in_use, data) = if is_large_message { let bsatn_rlb_pool = bsatn_rlb_pool.clone(); - spawn_rayon(move || serialize_v2(&bsatn_rlb_pool, buf, message, config.compression)).await + spawn_rayon(move || serialize_message(&bsatn_rlb_pool, buf, message, compression)).await } else { - serialize_v2(bsatn_rlb_pool, buf, message, config.compression) + serialize_message(bsatn_rlb_pool, buf, message, compression) }; let metrics = EncodeMetrics { @@ -2298,9 +2354,11 @@ mod tests { #[test] fn confirmed_reads_default_depends_on_ws_version() { + assert!(resolve_confirmed_reads_default(WsVersion::V3, None)); assert!(resolve_confirmed_reads_default(WsVersion::V2, None)); assert!(!resolve_confirmed_reads_default(WsVersion::V1, None)); assert!(resolve_confirmed_reads_default(WsVersion::V1, Some(true))); + assert!(!resolve_confirmed_reads_default(WsVersion::V3, Some(false))); assert!(!resolve_confirmed_reads_default(WsVersion::V2, Some(false))); } diff --git a/crates/core/src/client.rs b/crates/core/src/client.rs index cad4f79adcf..4411192c625 100644 --- a/crates/core/src/client.rs +++ b/crates/core/src/client.rs @@ -7,6 +7,7 @@ pub mod consume_each_list; mod message_handlers; mod message_handlers_v1; mod message_handlers_v2; +mod message_handlers_v3; pub mod messages; pub use client_connection::{ diff --git a/crates/core/src/client/client_connection.rs b/crates/core/src/client/client_connection.rs index 6fb8d8e1623..0a7a7f1a11b 100644 --- a/crates/core/src/client/client_connection.rs +++ b/crates/core/src/client/client_connection.rs @@ -47,6 +47,7 @@ pub enum Protocol { pub enum WsVersion { V1, V2, + V3, } impl Protocol { @@ -384,7 +385,7 @@ impl ClientConnectionSender { debug_assert!( matches!( (&self.config.version, &message), - (WsVersion::V1, OutboundMessage::V1(_)) | (WsVersion::V2, OutboundMessage::V2(_)) + (WsVersion::V1, OutboundMessage::V1(_)) | (WsVersion::V2 | WsVersion::V3, OutboundMessage::V2(_)) ), "attempted to send message variant that does not match client websocket version" ); diff --git a/crates/core/src/client/message_handlers.rs b/crates/core/src/client/message_handlers.rs index 76f5fa53afa..fb85730c11c 100644 --- a/crates/core/src/client/message_handlers.rs +++ b/crates/core/src/client/message_handlers.rs @@ -23,5 +23,6 @@ pub async fn handle(client: &ClientConnection, message: DataMessage, timer: Inst match client.config.version { WsVersion::V1 => super::message_handlers_v1::handle(client, message, timer).await, WsVersion::V2 => super::message_handlers_v2::handle(client, message, timer).await, + WsVersion::V3 => super::message_handlers_v3::handle(client, message, timer).await, } } diff --git a/crates/core/src/client/message_handlers_v2.rs b/crates/core/src/client/message_handlers_v2.rs index 5dd2f80d01b..2db523e472d 100644 --- a/crates/core/src/client/message_handlers_v2.rs +++ b/crates/core/src/client/message_handlers_v2.rs @@ -20,6 +20,14 @@ pub async fn handle(client: &ClientConnection, message: DataMessage, timer: Inst ))) } }; + handle_decoded_message(client, message, timer).await +} + +pub(super) async fn handle_decoded_message( + client: &ClientConnection, + message: ws_v2::ClientMessage, + timer: Instant, +) -> Result<(), MessageHandleError> { let module = client.module(); let mod_info = module.info(); let mod_metrics = &mod_info.metrics; diff --git a/crates/core/src/client/message_handlers_v3.rs b/crates/core/src/client/message_handlers_v3.rs new file mode 100644 index 00000000000..696e7337ed0 --- /dev/null +++ b/crates/core/src/client/message_handlers_v3.rs @@ -0,0 +1,32 @@ +use super::{ClientConnection, DataMessage, MessageHandleError}; +use serde::de::Error as _; +use spacetimedb_client_api_messages::websocket::{v2 as ws_v2, v3 as ws_v3}; +use spacetimedb_lib::bsatn; +use std::time::Instant; + +pub async fn handle(client: &ClientConnection, message: DataMessage, timer: Instant) -> Result<(), MessageHandleError> { + client.observe_websocket_request_message(&message); + let frame = match message { + DataMessage::Binary(message_buf) => bsatn::from_slice::(&message_buf)?, + DataMessage::Text(_) => { + return Err(MessageHandleError::TextDecode(serde_json::Error::custom( + "v3 websocket does not support text messages", + ))) + } + }; + + match frame { + ws_v3::ClientFrame::Single(message) => { + let message = bsatn::from_slice::(&message)?; + super::message_handlers_v2::handle_decoded_message(client, message, timer).await?; + } + ws_v3::ClientFrame::Batch(messages) => { + for message in messages { + let message = bsatn::from_slice::(&message)?; + super::message_handlers_v2::handle_decoded_message(client, message, timer).await?; + } + } + } + + Ok(()) +} diff --git a/crates/core/src/client/messages.rs b/crates/core/src/client/messages.rs index ed65e092d0e..38c5fadb260 100644 --- a/crates/core/src/client/messages.rs +++ b/crates/core/src/client/messages.rs @@ -10,6 +10,7 @@ use derive_more::From; use spacetimedb_client_api_messages::websocket::common::{self as ws_common, RowListLen as _}; use spacetimedb_client_api_messages::websocket::v1::{self as ws_v1}; use spacetimedb_client_api_messages::websocket::v2 as ws_v2; +use spacetimedb_client_api_messages::websocket::v3 as ws_v3; use spacetimedb_datastore::execution_context::WorkloadType; use spacetimedb_lib::identity::RequestId; use spacetimedb_lib::ser::serde::SerializeWrapper; @@ -97,6 +98,20 @@ impl SerializeBuffer { } } +fn finalize_binary_serialize_buffer( + buffer: SerializeBuffer, + uncompressed_len: usize, + compression: ws_v1::Compression, +) -> (InUseSerializeBuffer, Bytes) { + match decide_compression(uncompressed_len, compression) { + ws_v1::Compression::None => buffer.uncompressed(), + ws_v1::Compression::Brotli => { + buffer.compress_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_BROTLI, brotli_compress) + } + ws_v1::Compression::Gzip => buffer.compress_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_GZIP, gzip_compress), + } +} + type BytesMutWriter<'a> = bytes::buf::Writer<&'a mut BytesMut>; pub enum InUseSerializeBuffer { @@ -159,21 +174,14 @@ pub fn serialize( let srv_msg = buffer.write_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_NONE, |w| { bsatn::to_writer(w.into_inner(), &msg).unwrap() }); + let srv_msg_len = srv_msg.len(); // At this point, we no longer have a use for `msg`, // so try to reclaim its buffers. msg.consume_each_list(&mut |buffer| bsatn_rlb_pool.try_put(buffer)); // Conditionally compress the message. - let (in_use, msg_bytes) = match decide_compression(srv_msg.len(), config.compression) { - ws_v1::Compression::None => buffer.uncompressed(), - ws_v1::Compression::Brotli => { - buffer.compress_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_BROTLI, brotli_compress) - } - ws_v1::Compression::Gzip => { - buffer.compress_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_GZIP, gzip_compress) - } - }; + let (in_use, msg_bytes) = finalize_binary_serialize_buffer(buffer, srv_msg_len, config.compression); (in_use, msg_bytes.into()) } } @@ -192,18 +200,40 @@ pub fn serialize_v2( let srv_msg = buffer.write_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_NONE, |w| { bsatn::to_writer(w.into_inner(), &msg).expect("should be able to bsatn encode v2 message"); }); + let srv_msg_len = srv_msg.len(); // At this point, we no longer have a use for `msg`, // so try to reclaim its buffers. msg.consume_each_list(&mut |buffer| bsatn_rlb_pool.try_put(buffer)); - match decide_compression(srv_msg.len(), compression) { - ws_v1::Compression::None => buffer.uncompressed(), - ws_v1::Compression::Brotli => { - buffer.compress_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_BROTLI, brotli_compress) - } - ws_v1::Compression::Gzip => buffer.compress_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_GZIP, gzip_compress), - } + finalize_binary_serialize_buffer(buffer, srv_msg_len, compression) +} + +/// Serialize `msg` into a [`DataMessage`] containing a [`ws_v3::ServerFrame::Single`] +/// whose payload is a BSATN-encoded [`ws_v2::ServerMessage`]. +/// +/// This mirrors the v2 framing by prepending the compression tag and applying +/// conditional compression when configured. +pub fn serialize_v3( + bsatn_rlb_pool: &BsatnRowListBuilderPool, + mut buffer: SerializeBuffer, + msg: ws_v2::ServerMessage, + compression: ws_v1::Compression, +) -> (InUseSerializeBuffer, Bytes) { + let mut inner = BytesMut::with_capacity(SERIALIZE_BUFFER_INIT_CAP); + bsatn::to_writer((&mut inner).writer().into_inner(), &msg).expect("should be able to bsatn encode v2 message"); + + // At this point, we no longer have a use for `msg`, + // so try to reclaim its buffers. + msg.consume_each_list(&mut |buffer| bsatn_rlb_pool.try_put(buffer)); + + let frame = ws_v3::ServerFrame::Single(inner.freeze()); + let srv_msg = buffer.write_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_NONE, |w| { + bsatn::to_writer(w.into_inner(), &frame).expect("should be able to bsatn encode v3 server frame"); + }); + let srv_msg_len = srv_msg.len(); + + finalize_binary_serialize_buffer(buffer, srv_msg_len, compression) } #[derive(Debug, From)] diff --git a/crates/core/src/subscription/module_subscription_actor.rs b/crates/core/src/subscription/module_subscription_actor.rs index 92f296f3b8c..4ab8b0c28a7 100644 --- a/crates/core/src/subscription/module_subscription_actor.rs +++ b/crates/core/src/subscription/module_subscription_actor.rs @@ -1639,7 +1639,7 @@ impl ModuleSubscriptions { message, ); } - WsVersion::V2 => { + WsVersion::V2 | WsVersion::V3 => { if let Some(request_id) = event.request_id { self.send_reducer_failure_result_v2(client, &event, request_id); } From d7b8577bbe88e9ee5284691c6e5255d39d9cf3e4 Mon Sep 17 00:00:00 2001 From: joshua-spacetime Date: Wed, 8 Apr 2026 17:17:44 -0700 Subject: [PATCH 2/4] Update C# sdk to use v3 websocket api --- sdks/csharp/src/CompressionHelpers.cs | 20 +- sdks/csharp/src/Plugins/WebSocket.jslib | 14 +- .../src/SpacetimeDB/ClientApi/ClientFrame.cs | 10 + .../src/SpacetimeDB/ClientApi/ServerFrame.cs | 10 + sdks/csharp/src/SpacetimeDBClient.cs | 46 +++-- sdks/csharp/src/WebSocket.cs | 191 +++++++++++++++--- sdks/csharp/src/WebSocketProtocols.cs | 26 +++ sdks/csharp/src/WebSocketV3Frames.cs | 101 +++++++++ sdks/csharp/tests~/SnapshotTests.cs | 66 ++++++ sdks/csharp/tests~/Tests.cs | 44 +++- 10 files changed, 478 insertions(+), 50 deletions(-) create mode 100644 sdks/csharp/src/SpacetimeDB/ClientApi/ClientFrame.cs create mode 100644 sdks/csharp/src/SpacetimeDB/ClientApi/ServerFrame.cs create mode 100644 sdks/csharp/src/WebSocketProtocols.cs create mode 100644 sdks/csharp/src/WebSocketV3Frames.cs diff --git a/sdks/csharp/src/CompressionHelpers.cs b/sdks/csharp/src/CompressionHelpers.cs index 832208938ed..cb37c5e25c6 100644 --- a/sdks/csharp/src/CompressionHelpers.cs +++ b/sdks/csharp/src/CompressionHelpers.cs @@ -49,13 +49,12 @@ internal static GZipStream GzipReader(Stream stream) /// /// The compressed and encoded server message as a byte array. /// The deserialized object. - internal static ServerMessage DecompressDecodeMessage(byte[] bytes) + internal static byte[] DecompressMessagePayload(byte[] bytes) { using var stream = new MemoryStream(bytes); // The stream will never be empty. It will at least contain the compression algo. var compression = (CompressionAlgos)stream.ReadByte(); - // Conditionally decompress and decode. Stream decompressedStream = compression switch { CompressionAlgos.None => stream, @@ -67,10 +66,21 @@ internal static ServerMessage DecompressDecodeMessage(byte[] bytes) // TODO: consider pooling these. // DO NOT TRY TO TAKE THIS OUT. The BrotliStream ReadByte() implementation allocates an array // PER BYTE READ. You have to do it all at once to avoid that problem. - MemoryStream memoryStream = new MemoryStream(); + using var memoryStream = new MemoryStream(); decompressedStream.CopyTo(memoryStream); - memoryStream.Seek(0, SeekOrigin.Begin); - return new ServerMessage.BSATN().Read(new BinaryReader(memoryStream)); + return memoryStream.ToArray(); + } + + internal static ServerMessage DecodeServerMessage(byte[] bytes) + { + using var stream = new MemoryStream(bytes); + using var reader = new BinaryReader(stream); + return new ServerMessage.BSATN().Read(reader); + } + + internal static ServerMessage DecompressDecodeMessage(byte[] bytes) + { + return DecodeServerMessage(DecompressMessagePayload(bytes)); } /// diff --git a/sdks/csharp/src/Plugins/WebSocket.jslib b/sdks/csharp/src/Plugins/WebSocket.jslib index d2427954bb8..34d62dd6613 100644 --- a/sdks/csharp/src/Plugins/WebSocket.jslib +++ b/sdks/csharp/src/Plugins/WebSocket.jslib @@ -24,6 +24,9 @@ mergeInto(LibraryManager.library, { var host = UTF8ToString(baseUriPtr); var uri = UTF8ToString(uriPtr); var protocol = UTF8ToString(protocolPtr); + // The C# WebGL bridge can only pass one string argument here, so + // multiple offered subprotocols are marshalled as a comma-separated string. + var offeredProtocols = protocol.indexOf(',') === -1 ? protocol : protocol.split(','); var authToken = UTF8ToString(authTokenPtr); if (authToken) { @@ -46,7 +49,7 @@ mergeInto(LibraryManager.library, { } } - var socket = new window.WebSocket(uri, protocol); + var socket = new window.WebSocket(uri, offeredProtocols); socket.binaryType = "arraybuffer"; var socketId = manager.nextId++; @@ -54,7 +57,12 @@ mergeInto(LibraryManager.library, { socket.onopen = function() { if (manager.callbacks.open) { - dynCall('vi', manager.callbacks.open, [socketId]); + var protocolStr = socket.protocol || ""; + var protocolArray = intArrayFromString(protocolStr); + var protocolPtr = _malloc(protocolArray.length); + HEAP8.set(protocolArray, protocolPtr); + dynCall('vii', manager.callbacks.open, [socketId, protocolPtr]); + _free(protocolPtr); } }; @@ -115,4 +123,4 @@ mergeInto(LibraryManager.library, { socket.close(code, reason); delete manager.instances[socketId]; } -}); \ No newline at end of file +}); diff --git a/sdks/csharp/src/SpacetimeDB/ClientApi/ClientFrame.cs b/sdks/csharp/src/SpacetimeDB/ClientApi/ClientFrame.cs new file mode 100644 index 00000000000..7cb7980a3f2 --- /dev/null +++ b/sdks/csharp/src/SpacetimeDB/ClientApi/ClientFrame.cs @@ -0,0 +1,10 @@ +#nullable enable + +namespace SpacetimeDB.ClientApi +{ + [SpacetimeDB.Type] + internal partial record ClientFrame : SpacetimeDB.TaggedEnum<( + byte[] Single, + byte[][] Batch + )>; +} diff --git a/sdks/csharp/src/SpacetimeDB/ClientApi/ServerFrame.cs b/sdks/csharp/src/SpacetimeDB/ClientApi/ServerFrame.cs new file mode 100644 index 00000000000..9660ac589ef --- /dev/null +++ b/sdks/csharp/src/SpacetimeDB/ClientApi/ServerFrame.cs @@ -0,0 +1,10 @@ +#nullable enable + +namespace SpacetimeDB.ClientApi +{ + [SpacetimeDB.Type] + internal partial record ServerFrame : SpacetimeDB.TaggedEnum<( + byte[] Single, + byte[][] Batch + )>; +} diff --git a/sdks/csharp/src/SpacetimeDBClient.cs b/sdks/csharp/src/SpacetimeDBClient.cs index ef202ed168b..639ad9f1d54 100644 --- a/sdks/csharp/src/SpacetimeDBClient.cs +++ b/sdks/csharp/src/SpacetimeDBClient.cs @@ -168,6 +168,8 @@ public abstract class DbConnectionBase : IDbConne protected abstract IErrorContext ToErrorContext(Exception errorContext); protected abstract IProcedureEventContext ToProcedureEventContext(ProcedureEvent procedureEvent); + private Func decodeTransportMessages = DecodeV2TransportMessages; + private readonly ConcurrentDictionary> waitingOneOffQueries = new(); private readonly ConcurrentDictionary pendingReducerCalls = new(); @@ -219,10 +221,16 @@ protected DbConnectionBase() { var options = new WebSocket.ConnectOptions { - Protocol = "v2.bsatn.spacetimedb" + Protocols = WebSocketProtocols.Preferred }; webSocket = new WebSocket(options); webSocket.OnMessage += OnMessageReceived; + webSocket.OnProtocolNegotiated += protocolVersion => + { + decodeTransportMessages = protocolVersion == WebSocketProtocolVersion.V3 + ? WebSocketV3Frames.DecodeServerMessages + : DecodeV2TransportMessages; + }; webSocket.OnSendError += a => onSendError?.Invoke(a); #if UNITY_5_3_OR_NEWER webSocket.OnClose += (e) => @@ -289,6 +297,8 @@ internal struct ParsedMessage private static readonly Status Committed = new Status.Committed(default); + private static byte[][] DecodeV2TransportMessages(byte[] payload) => new[] { payload }; + /// /// Get a description of a message suitable for storing in the tracker metadata. /// @@ -427,9 +437,18 @@ void ParseOneOffQuery(OneOffQueryResult resp) #endif try { - var message = _parseQueue.Take(_parseCancellationToken); - var parsedMessage = ParseMessage(message); - _applyQueue.Add(parsedMessage, _parseCancellationToken); + var unparsed = _parseQueue.Take(_parseCancellationToken); + var payload = CompressionHelpers.DecompressMessagePayload(unparsed.bytes); + var decodedMessages = decodeTransportMessages(payload); + stats.ParseMessageQueueTracker.FinishTrackingRequest( + unparsed.parseQueueTrackerId, + $"type=ws_frame,count={decodedMessages.Length}" + ); + foreach (var messageBytes in decodedMessages) + { + var parsedMessage = ParseMessage(messageBytes, unparsed.timestamp); + _applyQueue.Add(parsedMessage, _parseCancellationToken); + } } catch (OperationCanceledException) { @@ -452,13 +471,11 @@ void ParseOneOffQuery(OneOffQueryResult resp) } } - ParsedMessage ParseMessage(UnparsedMessage unparsed) + ParsedMessage ParseMessage(byte[] messageBytes, DateTime timestamp) { var dbOps = ParsedDatabaseUpdate.New(); - var message = CompressionHelpers.DecompressDecodeMessage(unparsed.bytes); + var message = CompressionHelpers.DecodeServerMessage(messageBytes); var trackerMetadata = TrackerMetadataForMessage(message); - - stats.ParseMessageQueueTracker.FinishTrackingRequest(unparsed.parseQueueTrackerId, trackerMetadata); var parseStart = DateTime.UtcNow; ReducerEvent? reducerEvent = default; @@ -469,11 +486,11 @@ ParsedMessage ParseMessage(UnparsedMessage unparsed) case ServerMessage.InitialConnection: break; case ServerMessage.SubscribeApplied(var subscribeApplied): - stats.SubscriptionRequestTracker.FinishTrackingRequest(subscribeApplied.RequestId, unparsed.timestamp); + stats.SubscriptionRequestTracker.FinishTrackingRequest(subscribeApplied.RequestId, timestamp); dbOps = ParseSubscribeRows(subscribeApplied.Rows); break; case ServerMessage.UnsubscribeApplied(var unsubscribeApplied): - stats.SubscriptionRequestTracker.FinishTrackingRequest(unsubscribeApplied.RequestId, unparsed.timestamp); + stats.SubscriptionRequestTracker.FinishTrackingRequest(unsubscribeApplied.RequestId, timestamp); if (unsubscribeApplied.Rows != null) { dbOps = ParseUnsubscribeRows(unsubscribeApplied.Rows); @@ -482,7 +499,7 @@ ParsedMessage ParseMessage(UnparsedMessage unparsed) case ServerMessage.SubscriptionError(var subscriptionError): if (subscriptionError.RequestId.HasValue) { - stats.SubscriptionRequestTracker.FinishTrackingRequest(subscriptionError.RequestId.Value, unparsed.timestamp); + stats.SubscriptionRequestTracker.FinishTrackingRequest(subscriptionError.RequestId.Value, timestamp); } break; case ServerMessage.TransactionUpdate(var transactionUpdate): @@ -492,7 +509,7 @@ ParsedMessage ParseMessage(UnparsedMessage unparsed) ParseOneOffQuery(resp); break; case ServerMessage.ReducerResult(var reducerResult): - if (!stats.ReducerRequestTracker.FinishTrackingRequest(reducerResult.RequestId, unparsed.timestamp)) + if (!stats.ReducerRequestTracker.FinishTrackingRequest(reducerResult.RequestId, timestamp)) { Log.Warn($"Failed to finish tracking reducer request: {reducerResult.RequestId}"); } @@ -545,7 +562,7 @@ ParsedMessage ParseMessage(UnparsedMessage unparsed) procedureResult.RequestId ); - if (!stats.ProcedureRequestTracker.FinishTrackingRequest(procedureResult.RequestId, unparsed.timestamp)) + if (!stats.ProcedureRequestTracker.FinishTrackingRequest(procedureResult.RequestId, timestamp)) { Log.Warn($"Failed to finish tracking procedure request: {procedureResult.RequestId}"); } @@ -558,7 +575,7 @@ ParsedMessage ParseMessage(UnparsedMessage unparsed) stats.ParseMessageTracker.InsertRequest(parseStart, trackerMetadata); var applyTracker = stats.ApplyMessageQueueTracker.StartTrackingRequest(trackerMetadata); - return new ParsedMessage { message = message, dbOps = dbOps, receiveTimestamp = unparsed.timestamp, applyQueueTrackerId = applyTracker, reducerEvent = reducerEvent, procedureEvent = procedureEvent }; + return new ParsedMessage { message = message, dbOps = dbOps, receiveTimestamp = timestamp, applyQueueTrackerId = applyTracker, reducerEvent = reducerEvent, procedureEvent = procedureEvent }; } } @@ -609,6 +626,7 @@ void IDbConnection.Connect(string? token, string uri, string addressOrName, Comp { isClosing = false; connectionClosed = false; + decodeTransportMessages = DecodeV2TransportMessages; Identity = null; initialConnectionId = null; onConnectInvoked = false; diff --git a/sdks/csharp/src/WebSocket.cs b/sdks/csharp/src/WebSocket.cs index 26ce87127ba..95749f08694 100644 --- a/sdks/csharp/src/WebSocket.cs +++ b/sdks/csharp/src/WebSocket.cs @@ -3,7 +3,7 @@ using System; using System.Collections.Concurrent; -using System.Linq; +using System.Collections.Generic; using System.Net.Sockets; using System.Net.WebSockets; using System.Runtime.InteropServices; @@ -15,6 +15,8 @@ namespace SpacetimeDB { internal class WebSocket { + private delegate (byte[] EncodedMessage, bool ShouldYield) DequeueSendWork(); + public delegate void OpenEventHandler(); public delegate void MessageEventHandler(byte[] message, DateTime timestamp); @@ -26,7 +28,7 @@ internal class WebSocket public struct ConnectOptions { - public string Protocol; + public string[] Protocols; } // WebSocket buffer for incoming messages @@ -36,13 +38,16 @@ public struct ConnectOptions private readonly ConnectOptions _options; private readonly byte[] _receiveBuffer = new byte[MAXMessageSize]; private readonly ConcurrentQueue dispatchQueue = new(); + private static readonly ClientMessage.BSATN clientMessageBsatn = new(); protected ClientWebSocket Ws = new(); private CancellationTokenSource? _connectCts; + private DequeueSendWork dequeueSendWork; public WebSocket(ConnectOptions options) { _options = options; + dequeueSendWork = DequeueV2SendWork; #if UNITY_WEBGL && !UNITY_EDITOR InitializeWebGL(); #endif @@ -57,6 +62,14 @@ public WebSocket(ConnectOptions options) /// public event MessageEventHandler? OnMessage; public event CloseEventHandler? OnClose; + public event Action? OnProtocolNegotiated; + + private WebSocketProtocolVersion protocolVersion = WebSocketProtocolVersion.V2; + public WebSocketProtocolVersion ProtocolVersion + { + get => protocolVersion; + internal set => SetProtocolVersion(value); + } #if UNITY_WEBGL && !UNITY_EDITOR private bool _isConnected = false; @@ -88,10 +101,11 @@ IntPtr errorCallback [DllImport("__Internal")] private static extern void WebSocket_Close(int socketId, int code, string reason); - [AOT.MonoPInvokeCallback(typeof(Action))] - private static void WebGLOnOpen(int socketId) + [AOT.MonoPInvokeCallback(typeof(Action))] + private static void WebGLOnOpen(int socketId, IntPtr protocolPtr) { - Instance?.HandleWebGLOpen(socketId); + var protocol = Marshal.PtrToStringUTF8(protocolPtr) ?? string.Empty; + Instance?.HandleWebGLOpen(socketId, protocol); } [AOT.MonoPInvokeCallback(typeof(Action))] @@ -137,7 +151,7 @@ private void InitializeWebGL() { Instance = this; // Convert callbacks to function pointers - var openPtr = Marshal.GetFunctionPointerForDelegate((Action)WebGLOnOpen); + var openPtr = Marshal.GetFunctionPointerForDelegate((Action)WebGLOnOpen); var messagePtr = Marshal.GetFunctionPointerForDelegate((Action)WebGLOnMessage); var closePtr = Marshal.GetFunctionPointerForDelegate((Action)WebGLOnClose); var errorPtr = Marshal.GetFunctionPointerForDelegate((Action)WebGLOnError); @@ -148,6 +162,7 @@ private void InitializeWebGL() public async Task Connect(string? auth, string host, string nameOrAddress, ConnectionId connectionId, Compression compression, bool light, bool? confirmedReads) { + ResetProtocolVersion(); #if UNITY_WEBGL && !UNITY_EDITOR if (_isConnecting || _isConnected) return; @@ -166,7 +181,7 @@ public async Task Connect(string? auth, string host, string nameOrAddress, Conne _socketId = new TaskCompletionSource(); var callbackPtr = Marshal.GetFunctionPointerForDelegate((Action)OnSocketIdReceived); - WebSocket_Connect(host, uri, _options.Protocol, auth, callbackPtr); + WebSocket_Connect(host, uri, WebSocketProtocols.SerializeOfferedProtocols(_options.Protocols), auth, callbackPtr); _webglSocketId = await _socketId.Task; if (_webglSocketId == -1) { @@ -189,6 +204,7 @@ public async Task Connect(string? auth, string host, string nameOrAddress, Conne } // Events will be handled via UnitySendMessage callbacks #else + Ws = new ClientWebSocket(); var uri = $"{host}/v1/database/{nameOrAddress}/subscribe?connection_id={connectionId}&compression={compression}"; if (light) { @@ -201,7 +217,10 @@ public async Task Connect(string? auth, string host, string nameOrAddress, Conne uri += $"&confirmed={enabled}"; } var url = new Uri(uri); - Ws.Options.AddSubProtocol(_options.Protocol); + foreach (var protocol in _options.Protocols) + { + Ws.Options.AddSubProtocol(protocol); + } _connectCts = new CancellationTokenSource(10000); if (!string.IsNullOrEmpty(auth)) @@ -218,6 +237,7 @@ public async Task Connect(string? auth, string host, string nameOrAddress, Conne await Ws.ConnectAsync(url, _connectCts.Token); if (Ws.State == WebSocketState.Open) { + SetProtocolVersion(WebSocketProtocols.Normalize(Ws.SubProtocol)); if (OnConnect != null) { dispatchQueue.Enqueue(() => OnConnect()); @@ -373,7 +393,8 @@ await Ws.CloseAsync(WebSocketCloseStatus.MessageTooBig, closeMessage, if (OnMessage != null) { - var message = _receiveBuffer.Take(count).ToArray(); + var message = new byte[count]; + Buffer.BlockCopy(_receiveBuffer, 0, message, 0, count); // directly invoke message handling OnMessage(message, startReceive); } @@ -454,8 +475,8 @@ public void Abort() #endif } - private Task? senderTask; - private readonly ConcurrentQueue messageSendQueue = new(); + private bool senderActive; + private readonly Queue messageSendQueue = new(); /// /// This sender guarantees that that messages are sent out in the order they are received. Our websocket @@ -465,25 +486,66 @@ public void Abort() /// The message to send public void Send(ClientMessage message) { -#if UNITY_WEBGL && !UNITY_EDITOR try { - var messageBSATN = new ClientMessage.BSATN(); - var encodedMessage = IStructuralReadWrite.ToBytes(messageBSATN, message); - WebSocket_Send(_webglSocketId, encodedMessage, encodedMessage.Length); + var encodedMessage = IStructuralReadWrite.ToBytes(clientMessageBsatn, message); + var startProcessor = false; + lock (messageSendQueue) + { + messageSendQueue.Enqueue(encodedMessage); + if (!senderActive) + { + senderActive = true; + startProcessor = true; + } + } + + if (startProcessor) + { + _ = StartProcessSendQueue(); + } } catch (Exception e) { - UnityEngine.Debug.LogError($"WebSocket send error: {e}"); dispatchQueue.Enqueue(() => OnSendError?.Invoke(e)); } + } + + private Task StartProcessSendQueue() + { +#if UNITY_WEBGL && !UNITY_EDITOR + return ProcessSendQueue(); #else + return Task.Run(ProcessSendQueue); +#endif + } + + private void ScheduleSendQueueContinuation() + { +#if UNITY_WEBGL && !UNITY_EDITOR + dispatchQueue.Enqueue(TryStartSendQueueProcessor); +#else + _ = Task.Run(() => + { + TryStartSendQueueProcessor(); + return Task.CompletedTask; + }); +#endif + } + + private void TryStartSendQueueProcessor() + { lock (messageSendQueue) { - messageSendQueue.Enqueue(message); - senderTask ??= Task.Run(ProcessSendQueue); + if (senderActive || messageSendQueue.Count == 0) + { + return; + } + + senderActive = true; } -#endif + + _ = StartProcessSendQueue(); } private async Task ProcessSendQueue() @@ -492,37 +554,111 @@ private async Task ProcessSendQueue() { while (true) { - ClientMessage message; + byte[] encodedMessage; + bool shouldYield; lock (messageSendQueue) { - if (!messageSendQueue.TryDequeue(out message)) + if (messageSendQueue.Count == 0) { // We are out of messages to send - senderTask = null; + senderActive = false; return; } + + (encodedMessage, shouldYield) = dequeueSendWork(); } - var messageBSATN = new ClientMessage.BSATN(); - var encodedMessage = IStructuralReadWrite.ToBytes(messageBSATN, message); - await Ws!.SendAsync(encodedMessage, WebSocketMessageType.Binary, true, CancellationToken.None); + await SendEncodedMessage(encodedMessage); + + if (shouldYield) + { + // After sending one capped v3 frame, stop this queue pump and + // schedule a follow-up pass using the same runtime primitives + // this SDK already uses for send processing on each platform. + lock (messageSendQueue) + { + senderActive = false; + } + ScheduleSendQueueContinuation(); + return; + } } } catch (Exception e) { - senderTask = null; + lock (messageSendQueue) + { + senderActive = false; + } if (OnSendError != null) dispatchQueue.Enqueue(() => OnSendError(e)); } } + private byte[][] DequeueMessagesForV3Frame() + { + var messageCount = WebSocketV3Frames.CountClientMessagesThatFitInFrame(messageSendQueue); + if (messageCount <= 0) + { + throw new InvalidOperationException("Expected at least one queued v2 message when building a v3 frame."); + } + + var messages = new byte[messageCount][]; + for (var i = 0; i < messageCount; i++) + { + messages[i] = messageSendQueue.Dequeue(); + } + return messages; + } + + private (byte[] EncodedMessage, bool ShouldYield) DequeueV2SendWork() => + (messageSendQueue.Dequeue(), false); + + private (byte[] EncodedMessage, bool ShouldYield) DequeueV3SendWork() + { + var queuedMessages = DequeueMessagesForV3Frame(); + return (WebSocketV3Frames.EncodeClientMessages(queuedMessages), messageSendQueue.Count > 0); + } + + private void ResetProtocolVersion() + { + protocolVersion = WebSocketProtocolVersion.V2; + dequeueSendWork = DequeueV2SendWork; + } + + private void SetProtocolVersion(WebSocketProtocolVersion protocolVersion) + { + // Protocol selection is a transport concern: changing it swaps the + // active send strategy and notifies higher layers to swap their + // receive decoder as well. + this.protocolVersion = protocolVersion; + dequeueSendWork = protocolVersion == WebSocketProtocolVersion.V3 + ? DequeueV3SendWork + : DequeueV2SendWork; + OnProtocolNegotiated?.Invoke(protocolVersion); + } + + private Task SendEncodedMessage(byte[] encodedMessage) + { +#if UNITY_WEBGL && !UNITY_EDITOR + var result = WebSocket_Send(_webglSocketId, encodedMessage, encodedMessage.Length); + if (result != 0) + { + throw new InvalidOperationException("WebSocket send failed."); + } + return Task.CompletedTask; +#else + return Ws!.SendAsync(new ArraySegment(encodedMessage), WebSocketMessageType.Binary, true, CancellationToken.None); +#endif + } + public WebSocketState GetState() { return Ws!.State; } #if UNITY_WEBGL && !UNITY_EDITOR - public void HandleWebGLOpen(int socketId) + public void HandleWebGLOpen(int socketId, string protocol) { if (socketId == _webglSocketId) { @@ -535,6 +671,7 @@ public void HandleWebGLOpen(int socketId) _cancelConnectRequested = false; return; } + SetProtocolVersion(WebSocketProtocols.Normalize(protocol)); _isConnected = true; if (OnConnect != null) dispatchQueue.Enqueue(() => OnConnect()); diff --git a/sdks/csharp/src/WebSocketProtocols.cs b/sdks/csharp/src/WebSocketProtocols.cs new file mode 100644 index 00000000000..0e98ec0c48c --- /dev/null +++ b/sdks/csharp/src/WebSocketProtocols.cs @@ -0,0 +1,26 @@ +namespace SpacetimeDB +{ + internal enum WebSocketProtocolVersion + { + V2, + V3, + } + + internal static class WebSocketProtocols + { + internal const string V2 = "v2.bsatn.spacetimedb"; + internal const string V3 = "v3.bsatn.spacetimedb"; + + internal static readonly string[] Preferred = new[] { V3, V2 }; + + internal static WebSocketProtocolVersion Normalize(string? protocol) + { + // Treat an empty negotiated subprotocol as legacy v2 defensively. + return protocol == V3 ? WebSocketProtocolVersion.V3 : WebSocketProtocolVersion.V2; + } + +#if UNITY_WEBGL && !UNITY_EDITOR + internal static string SerializeOfferedProtocols(string[] protocols) => string.Join(",", protocols); +#endif + } +} diff --git a/sdks/csharp/src/WebSocketV3Frames.cs b/sdks/csharp/src/WebSocketV3Frames.cs new file mode 100644 index 00000000000..7f715df7b09 --- /dev/null +++ b/sdks/csharp/src/WebSocketV3Frames.cs @@ -0,0 +1,101 @@ +using SpacetimeDB.BSATN; +using SpacetimeDB.ClientApi; + +using System; +using System.Collections.Generic; +using System.IO; + +namespace SpacetimeDB +{ + internal static class WebSocketV3Frames + { + internal const int MaxFrameBytes = 256 * 1024; + + private const int EnumTagBytes = 1; + private const int CollectionLengthBytes = 4; + private const int ByteArrayLengthBytes = 4; + + private static readonly ClientFrame.BSATN clientFrameBsatn = new(); + private static readonly ServerFrame.BSATN serverFrameBsatn = new(); + + // v3 is only a transport envelope around already-encoded v2 messages, + // so batching works in terms of raw byte payloads rather than logical messages. + internal static byte[] EncodeClientMessages(IReadOnlyList messages) + { + if (messages.Count == 0) + { + throw new InvalidOperationException("Cannot encode an empty v3 client frame."); + } + + ClientFrame frame = messages.Count == 1 + ? new ClientFrame.Single(messages[0]) + : new ClientFrame.Batch(ToArray(messages)); + + return IStructuralReadWrite.ToBytes(clientFrameBsatn, frame); + } + + internal static byte[][] DecodeServerMessages(byte[] encodedFrame) + { + using var stream = new MemoryStream(encodedFrame); + using var reader = new BinaryReader(stream); + var frame = serverFrameBsatn.Read(reader); + return frame switch + { + ServerFrame.Single(var message) => new[] { message }, + ServerFrame.Batch(var messages) => messages, + _ => throw new InvalidOperationException("Unknown v3 server frame variant."), + }; + } + + // Count the maximal prefix of already-encoded client messages that fits in + // one v3 frame using BSATN framing sizes directly instead of trial serialization. + internal static int CountClientMessagesThatFitInFrame( + IEnumerable messages, + int maxFrameBytes = MaxFrameBytes + ) + { + var messageCount = 0; + var payloadBytes = 0; + + foreach (var message in messages) + { + if (messageCount == 0) + { + if (EncodedSingleFrameSize(message.Length) > maxFrameBytes) + { + return 1; + } + } + else + { + var batchSize = EncodedBatchFrameSize(messageCount + 1, payloadBytes + message.Length); + if (batchSize > maxFrameBytes) + { + break; + } + } + + messageCount++; + payloadBytes += message.Length; + } + + return messageCount; + } + + private static int EncodedSingleFrameSize(int messageBytes) => + EnumTagBytes + ByteArrayLengthBytes + messageBytes; + + private static int EncodedBatchFrameSize(int messageCount, int payloadBytes) => + EnumTagBytes + CollectionLengthBytes + (messageCount * ByteArrayLengthBytes) + payloadBytes; + + private static byte[][] ToArray(IReadOnlyList messages) + { + var array = new byte[messages.Count][]; + for (var i = 0; i < messages.Count; i++) + { + array[i] = messages[i]; + } + return array; + } + } +} diff --git a/sdks/csharp/tests~/SnapshotTests.cs b/sdks/csharp/tests~/SnapshotTests.cs index e083928111e..fcaed905f5d 100644 --- a/sdks/csharp/tests~/SnapshotTests.cs +++ b/sdks/csharp/tests~/SnapshotTests.cs @@ -381,6 +381,72 @@ public static IEnumerable SampleDump() } + [Fact] + public void V3BatchedServerFrameIsProcessedInOrder() + { + DbConnection.IsTesting = true; + + var client = + DbConnection.Builder() + .WithUri("wss://spacetimedb.com") + .WithDatabaseName("example") + .Build(); + + client.webSocket.ProtocolVersion = WebSocketProtocolVersion.V3; + + ServerMessage initialConnection = SampleId( + "j5DMlKmWjfbSl7qmZQOok7HDSwsAJopRSJjdlUsNogs=", + "token", + "Vd4dFzcEzhLHJ6uNL8VXFg==" + ); + ServerMessage transactionUpdate = SampleTransactionUpdate( + 1, + [SampleUserInsert("l0qzG1GPRtC1mwr+54q98tv0325gozLc6cNzq4vrzqY=", "A", true)] + ); + + ServerFrame frame = new ServerFrame.Batch(new[] + { + IStructuralReadWrite.ToBytes(new ServerMessage.BSATN(), initialConnection), + IStructuralReadWrite.ToBytes(new ServerMessage.BSATN(), transactionUpdate), + }); + var payload = IStructuralReadWrite.ToBytes(new ServerFrame.BSATN(), frame); + + var transportMessage = new byte[payload.Length + 1]; + transportMessage[0] = 0; + Buffer.BlockCopy(payload, 0, transportMessage, 1, payload.Length); + + client.OnMessageReceived(transportMessage, DateTime.UtcNow); + + var deadline = DateTime.UtcNow.AddSeconds(2); + List? users = null; + while (true) + { + client.FrameTick(); + users = client.Db.User.Iter().ToList(); + if (users.Count == 1) + { + break; + } + + if (DateTime.UtcNow >= deadline) + { + throw new TimeoutException("Timed out waiting for a v3 batched frame to be applied."); + } + Thread.Sleep(1); + } + + Assert.Equal( + Identity.From(Convert.FromBase64String("j5DMlKmWjfbSl7qmZQOok7HDSwsAJopRSJjdlUsNogs=")), + client.Identity + ); + + Assert.Single(users); + Assert.Equal("A", users[0].Name); + Assert.True(users[0].Online); + + client.Disconnect(); + } + [Theory] [MemberData(nameof(SampleDump))] public async Task VerifySampleDump(string dumpName, ServerMessage[] sampleDumpParsed) diff --git a/sdks/csharp/tests~/Tests.cs b/sdks/csharp/tests~/Tests.cs index 3adb4970cba..1559317a382 100644 --- a/sdks/csharp/tests~/Tests.cs +++ b/sdks/csharp/tests~/Tests.cs @@ -2,6 +2,7 @@ using CsCheck; using SpacetimeDB; using SpacetimeDB.BSATN; +using SpacetimeDB.ClientApi; using SpacetimeDB.Types; public class Tests @@ -128,4 +129,45 @@ public static void ListstreamWorks() } }); } -} \ No newline at end of file + + [Fact] + public static void V3BatchSizingCapsAt256KiB() + { + var messages = new[] + { + new byte[100_000], + new byte[100_000], + new byte[100_000], + }; + + Assert.Equal(2, WebSocketV3Frames.CountClientMessagesThatFitInFrame(messages)); + Assert.Equal(1, WebSocketV3Frames.CountClientMessagesThatFitInFrame(new[] { new byte[300_000] })); + Assert.Equal(0, WebSocketV3Frames.CountClientMessagesThatFitInFrame(Array.Empty())); + } + + [Fact] + public static void V3ServerFrameDecodeHandlesSingleAndBatch() + { + static byte[] EncodeFrame(ServerFrame frame) => + IStructuralReadWrite.ToBytes(new ServerFrame.BSATN(), frame); + + var singlePayload = new byte[] { 1, 2, 3 }; + var single = WebSocketV3Frames.DecodeServerMessages( + EncodeFrame(new ServerFrame.Single(singlePayload)) + ); + Assert.Single(single); + Assert.Equal(singlePayload, single[0]); + + var batchPayloads = new[] + { + new byte[] { 4, 5 }, + new byte[] { 6, 7, 8 }, + }; + var batch = WebSocketV3Frames.DecodeServerMessages( + EncodeFrame(new ServerFrame.Batch(batchPayloads)) + ); + Assert.Equal(2, batch.Length); + Assert.Equal(batchPayloads[0], batch[0]); + Assert.Equal(batchPayloads[1], batch[1]); + } +} From 20f5a4f034817da81e31cfa991cf7f059dbb5f0e Mon Sep 17 00:00:00 2001 From: joshua-spacetime Date: Wed, 8 Apr 2026 17:36:16 -0700 Subject: [PATCH 3/4] Clarify WebGL protocol marshalling --- sdks/csharp/src/Plugins/WebSocket.jslib | 21 ++++++++++++++++----- sdks/csharp/src/WebSocket.cs | 2 ++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/sdks/csharp/src/Plugins/WebSocket.jslib b/sdks/csharp/src/Plugins/WebSocket.jslib index 34d62dd6613..4b4d39a8e49 100644 --- a/sdks/csharp/src/Plugins/WebSocket.jslib +++ b/sdks/csharp/src/Plugins/WebSocket.jslib @@ -58,11 +58,22 @@ mergeInto(LibraryManager.library, { socket.onopen = function() { if (manager.callbacks.open) { var protocolStr = socket.protocol || ""; - var protocolArray = intArrayFromString(protocolStr); - var protocolPtr = _malloc(protocolArray.length); - HEAP8.set(protocolArray, protocolPtr); - dynCall('vii', manager.callbacks.open, [socketId, protocolPtr]); - _free(protocolPtr); + // Marshal the negotiated subprotocol to C# just for the duration of + // this callback. We use stack allocation because the pointer only + // needs to remain valid while dynCall is executing synchronously. + var protocolLength = lengthBytesUTF8(protocolStr) + 1; + var stack = stackSave(); + try { + var protocolPtr = stackAlloc(protocolLength); + // Write a temporary null-terminated UTF-8 string into the + // Emscripten stack frame so the C# callback can copy it. + stringToUTF8(protocolStr, protocolPtr, protocolLength); + dynCall('vii', manager.callbacks.open, [socketId, protocolPtr]); + } finally { + // Release the temporary stack allocation immediately after + // the callback returns; C# must not retain the pointer. + stackRestore(stack); + } } }; diff --git a/sdks/csharp/src/WebSocket.cs b/sdks/csharp/src/WebSocket.cs index 95749f08694..96e61763f91 100644 --- a/sdks/csharp/src/WebSocket.cs +++ b/sdks/csharp/src/WebSocket.cs @@ -104,6 +104,8 @@ IntPtr errorCallback [AOT.MonoPInvokeCallback(typeof(Action))] private static void WebGLOnOpen(int socketId, IntPtr protocolPtr) { + // The JS bridge passes a temporary UTF-8 pointer that is only valid for + // this callback, so copy it into a managed string immediately. var protocol = Marshal.PtrToStringUTF8(protocolPtr) ?? string.Empty; Instance?.HandleWebGLOpen(socketId, protocol); } From ee0c8fb6ab5d2fbf82a4bdd46e39ce6c6e8bfcb5 Mon Sep 17 00:00:00 2001 From: joshua-spacetime Date: Wed, 8 Apr 2026 17:49:06 -0700 Subject: [PATCH 4/4] Add C# v2 fallback websocket test --- sdks/csharp/tests~/Tests.cs | 84 +++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/sdks/csharp/tests~/Tests.cs b/sdks/csharp/tests~/Tests.cs index 1559317a382..cd92f84bb1d 100644 --- a/sdks/csharp/tests~/Tests.cs +++ b/sdks/csharp/tests~/Tests.cs @@ -1,4 +1,7 @@ using System.Diagnostics; +using System.Net; +using System.Net.Sockets; +using System.Net.WebSockets; using CsCheck; using SpacetimeDB; using SpacetimeDB.BSATN; @@ -170,4 +173,85 @@ static byte[] EncodeFrame(ServerFrame frame) => Assert.Equal(batchPayloads[0], batch[0]); Assert.Equal(batchPayloads[1], batch[1]); } + + [Fact] + public static async Task WebSocketFallsBackToV2WhenServerOnlyNegotiatesV2() + { + static int GetFreePort() + { + using var listener = new TcpListener(IPAddress.Loopback, 0); + listener.Start(); + return ((IPEndPoint)listener.LocalEndpoint).Port; + } + + static async Task WaitForAsync(Task task, SpacetimeDB.WebSocket ws, string error) + { + var deadline = DateTime.UtcNow.AddSeconds(5); + while (!task.IsCompleted) + { + ws.Update(); + if (DateTime.UtcNow >= deadline) + { + throw new TimeoutException(error); + } + await Task.Delay(10); + } + + await task; + } + + var port = GetFreePort(); + using var listener = new HttpListener(); + listener.Prefixes.Add($"http://127.0.0.1:{port}/"); + listener.Start(); + + var serverObservedProtocols = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var serverTask = Task.Run(async () => + { + var context = await listener.GetContextAsync(); + serverObservedProtocols.TrySetResult(context.Request.Headers["Sec-WebSocket-Protocol"] ?? string.Empty); + + var webSocketContext = await context.AcceptWebSocketAsync(WebSocketProtocols.V2); + await Task.Delay(100); + await webSocketContext.WebSocket.CloseAsync( + WebSocketCloseStatus.NormalClosure, + "done", + CancellationToken.None + ); + }); + + var ws = new SpacetimeDB.WebSocket(new SpacetimeDB.WebSocket.ConnectOptions + { + Protocols = WebSocketProtocols.Preferred, + }); + + var connected = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var closed = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + ws.OnConnect += () => connected.TrySetResult(); + ws.OnClose += _ => closed.TrySetResult(); + + var clientTask = Task.Run(() => ws.Connect( + "test-token", + $"ws://127.0.0.1:{port}", + "example", + ConnectionId.Random(), + Compression.None, + false, + null + )); + + await WaitForAsync(connected.Task, ws, "Timed out waiting for websocket connection."); + + Assert.Equal(WebSocketProtocolVersion.V2, ws.ProtocolVersion); + + var offeredProtocols = await serverObservedProtocols.Task.WaitAsync(TimeSpan.FromSeconds(5)); + Assert.Contains(WebSocketProtocols.V3, offeredProtocols); + Assert.Contains(WebSocketProtocols.V2, offeredProtocols); + + await WaitForAsync(closed.Task, ws, "Timed out waiting for websocket close."); + await serverTask.WaitAsync(TimeSpan.FromSeconds(5)); + await clientTask.WaitAsync(TimeSpan.FromSeconds(5)); + } }