From d494169fdd1222ad87d15ce3b7f5668cb257ab13 Mon Sep 17 00:00:00 2001 From: joshua-spacetime Date: Wed, 8 Apr 2026 13:34:58 -0700 Subject: [PATCH 1/3] v3 websocket protocol --- crates/client-api-messages/DEVELOP.md | 9 ++ .../examples/get_ws_schema_v3.rs | 13 ++ crates/client-api-messages/src/websocket.rs | 1 + .../client-api-messages/src/websocket/v3.rs | 28 ++++ crates/client-api/src/routes/subscribe.rs | 132 +++++++++++++----- crates/core/src/client.rs | 1 + crates/core/src/client/client_connection.rs | 3 +- crates/core/src/client/message_handlers.rs | 1 + crates/core/src/client/message_handlers_v2.rs | 8 ++ crates/core/src/client/message_handlers_v3.rs | 32 +++++ crates/core/src/client/messages.rs | 62 +++++--- .../subscription/module_subscription_actor.rs | 2 +- 12 files changed, 237 insertions(+), 55 deletions(-) create mode 100644 crates/client-api-messages/examples/get_ws_schema_v3.rs create mode 100644 crates/client-api-messages/src/websocket/v3.rs create mode 100644 crates/core/src/client/message_handlers_v3.rs diff --git a/crates/client-api-messages/DEVELOP.md b/crates/client-api-messages/DEVELOP.md index 47868d4d3ce..48341bd80aa 100644 --- a/crates/client-api-messages/DEVELOP.md +++ b/crates/client-api-messages/DEVELOP.md @@ -19,3 +19,12 @@ spacetime generate -p spacetimedb-cli --lang \ --out-dir \ --module-def ws_schema_v2.json ``` + +For the v3 WebSocket transport schema: + +```sh +cargo run --example get_ws_schema_v3 > ws_schema_v3.json +spacetime generate -p spacetimedb-cli --lang \ + --out-dir \ + --module-def ws_schema_v3.json +``` diff --git a/crates/client-api-messages/examples/get_ws_schema_v3.rs b/crates/client-api-messages/examples/get_ws_schema_v3.rs new file mode 100644 index 00000000000..b4a752a5664 --- /dev/null +++ b/crates/client-api-messages/examples/get_ws_schema_v3.rs @@ -0,0 +1,13 @@ +use spacetimedb_client_api_messages::websocket::v3::{ClientFrame, ServerFrame}; +use spacetimedb_lib::ser::serde::SerializeWrapper; +use spacetimedb_lib::{RawModuleDef, RawModuleDefV8}; + +fn main() -> Result<(), serde_json::Error> { + let module = RawModuleDefV8::with_builder(|module| { + module.add_type::(); + module.add_type::(); + }); + let module = RawModuleDef::V8BackCompat(module); + + serde_json::to_writer(std::io::stdout().lock(), SerializeWrapper::from_ref(&module)) +} diff --git a/crates/client-api-messages/src/websocket.rs b/crates/client-api-messages/src/websocket.rs index 0935d2e3c55..14ec394670f 100644 --- a/crates/client-api-messages/src/websocket.rs +++ b/crates/client-api-messages/src/websocket.rs @@ -17,3 +17,4 @@ pub mod common; pub mod v1; pub mod v2; +pub mod v3; diff --git a/crates/client-api-messages/src/websocket/v3.rs b/crates/client-api-messages/src/websocket/v3.rs new file mode 100644 index 00000000000..5be37768299 --- /dev/null +++ b/crates/client-api-messages/src/websocket/v3.rs @@ -0,0 +1,28 @@ +use bytes::Bytes; +pub use spacetimedb_sats::SpacetimeType; + +pub const BIN_PROTOCOL: &str = "v3.bsatn.spacetimedb"; + +/// Transport envelopes sent by the client over the v3 websocket protocol. +/// +/// The inner bytes are BSATN-encoded v2 [`crate::websocket::v2::ClientMessage`] values. +#[derive(SpacetimeType, Debug)] +#[sats(crate = spacetimedb_lib)] +pub enum ClientFrame { + /// A single logical client message. + Single(Bytes), + /// Multiple logical client messages that should be processed in-order. + Batch(Box<[Bytes]>), +} + +/// Transport envelopes sent by the server over the v3 websocket protocol. +/// +/// The inner bytes are BSATN-encoded v2 [`crate::websocket::v2::ServerMessage`] values. +#[derive(SpacetimeType, Debug)] +#[sats(crate = spacetimedb_lib)] +pub enum ServerFrame { + /// A single logical server message. + Single(Bytes), + /// Multiple logical server messages that should be processed in-order. + Batch(Box<[Bytes]>), +} diff --git a/crates/client-api/src/routes/subscribe.rs b/crates/client-api/src/routes/subscribe.rs index d1bb1d2b11f..868e7425eda 100644 --- a/crates/client-api/src/routes/subscribe.rs +++ b/crates/client-api/src/routes/subscribe.rs @@ -23,8 +23,8 @@ use prometheus::{Histogram, IntGauge}; use scopeguard::{defer, ScopeGuard}; use serde::Deserialize; use spacetimedb::client::messages::{ - serialize, serialize_v2, IdentityTokenMessage, InUseSerializeBuffer, SerializeBuffer, SwitchedServerMessage, - ToProtocol, + serialize, serialize_v2, serialize_v3, IdentityTokenMessage, InUseSerializeBuffer, SerializeBuffer, + SwitchedServerMessage, ToProtocol, }; use spacetimedb::client::{ ClientActorId, ClientConfig, ClientConnection, ClientConnectionReceiver, DataMessage, MessageExecutionError, @@ -38,6 +38,7 @@ use spacetimedb::worker_metrics::WORKER_METRICS; use spacetimedb::Identity; use spacetimedb_client_api_messages::websocket::v1 as ws_v1; use spacetimedb_client_api_messages::websocket::v2 as ws_v2; +use spacetimedb_client_api_messages::websocket::v3 as ws_v3; use spacetimedb_datastore::execution_context::WorkloadType; use spacetimedb_lib::connection_id::{ConnectionId, ConnectionIdForUrl}; use tokio::sync::{mpsc, watch}; @@ -62,6 +63,8 @@ pub const TEXT_PROTOCOL: HeaderValue = HeaderValue::from_static(ws_v1::TEXT_PROT pub const BIN_PROTOCOL: HeaderValue = HeaderValue::from_static(ws_v1::BIN_PROTOCOL); #[allow(clippy::declare_interior_mutable_const)] pub const V2_BIN_PROTOCOL: HeaderValue = HeaderValue::from_static(ws_v2::BIN_PROTOCOL); +#[allow(clippy::declare_interior_mutable_const)] +pub const V3_BIN_PROTOCOL: HeaderValue = HeaderValue::from_static(ws_v3::BIN_PROTOCOL); pub trait HasWebSocketOptions { fn websocket_options(&self) -> WebSocketOptions; @@ -101,7 +104,7 @@ fn resolve_confirmed_reads_default(version: WsVersion, confirmed: Option) } match version { WsVersion::V1 => false, - WsVersion::V2 => crate::DEFAULT_CONFIRMED_READS, + WsVersion::V2 | WsVersion::V3 => crate::DEFAULT_CONFIRMED_READS, } } @@ -151,6 +154,13 @@ where } let (res, ws_upgrade, protocol) = ws.select_protocol([ + ( + V3_BIN_PROTOCOL, + NegotiatedProtocol { + protocol: Protocol::Binary, + version: WsVersion::V3, + }, + ), ( V2_BIN_PROTOCOL, NegotiatedProtocol { @@ -284,7 +294,7 @@ where }; client.send_message(None, OutboundMessage::V1(message.into())) } - WsVersion::V2 => { + WsVersion::V2 | WsVersion::V3 => { let message = ws_v2::ServerMessage::InitialConnection(ws_v2::InitialConnection { identity: client_identity, connection_id, @@ -1293,10 +1303,15 @@ async fn ws_encode_task( // copied to the wire. Since we don't know when that will happen, we prepare // for a few messages to be in-flight, i.e. encoded but not yet sent. const BUF_POOL_CAPACITY: usize = 16; + let binary_message_serializer = match config.version { + WsVersion::V1 => None, + WsVersion::V2 => Some(serialize_v2 as BinarySerializeFn), + WsVersion::V3 => Some(serialize_v3 as BinarySerializeFn), + }; let buf_pool = ArrayQueue::new(BUF_POOL_CAPACITY); let mut in_use_bufs: Vec> = Vec::with_capacity(BUF_POOL_CAPACITY); - while let Some(message) = messages.recv().await { + 'send: while let Some(message) = messages.recv().await { // Drop serialize buffers with no external referent, // returning them to the pool. in_use_bufs.retain(|in_use| !in_use.is_unique()); @@ -1306,16 +1321,22 @@ async fn ws_encode_task( let in_use_buf = match message { OutboundWsMessage::Error(message) => { - if config.version == WsVersion::V2 { - log::error!("dropping v1 error message sent to a v2 client: {:?}", message); + if config.version != WsVersion::V1 { + log::error!( + "dropping v1 error message sent to a binary websocket client: {:?}", + message + ); continue; } - let (stats, in_use, mut frames) = ws_encode_message(config, buf, message, false, &bsatn_rlb_pool).await; - metrics.report(None, None, stats); - if frames.try_for_each(|frame| outgoing_frames.send(frame)).is_err() { - break; - } - + let Ok(in_use) = ws_forward_frames( + &metrics, + &outgoing_frames, + None, + None, + ws_encode_message(config, buf, message, false, &bsatn_rlb_pool).await, + ) else { + break 'send; + }; in_use } OutboundWsMessage::Message(message) => { @@ -1323,38 +1344,47 @@ async fn ws_encode_task( let num_rows = message.num_rows(); match message { OutboundMessage::V2(server_message) => { - if config.version != WsVersion::V2 { + if config.version == WsVersion::V1 { log::error!("dropping v2 message on v1 connection"); continue; } - let (stats, in_use, mut frames) = - ws_encode_message_v2(config, buf, server_message, false, &bsatn_rlb_pool).await; - metrics.report(workload, num_rows, stats); - if frames.try_for_each(|frame| outgoing_frames.send(frame)).is_err() { - break; - } - + let Ok(in_use) = ws_forward_frames( + &metrics, + &outgoing_frames, + workload, + num_rows, + ws_encode_binary_message( + config, + buf, + server_message, + binary_message_serializer.expect("v2 message should not be sent on a v1 connection"), + false, + &bsatn_rlb_pool, + ) + .await, + ) else { + break 'send; + }; in_use } OutboundMessage::V1(message) => { - if config.version == WsVersion::V2 { - log::error!( - "dropping v1 message for v2 connection until v2 serialization is implemented: {:?}", - message - ); + if config.version != WsVersion::V1 { + log::error!("dropping v1 message for a binary websocket connection: {:?}", message); continue; } let is_large = num_rows.is_some_and(|n| n > 1024); - let (stats, in_use, mut frames) = - ws_encode_message(config, buf, message, is_large, &bsatn_rlb_pool).await; - metrics.report(workload, num_rows, stats); - if frames.try_for_each(|frame| outgoing_frames.send(frame)).is_err() { - break; - } - + let Ok(in_use) = ws_forward_frames( + &metrics, + &outgoing_frames, + workload, + num_rows, + ws_encode_message(config, buf, message, is_large, &bsatn_rlb_pool).await, + ) else { + break 'send; + }; in_use } } @@ -1370,6 +1400,24 @@ async fn ws_encode_task( } } +/// Reports encode metrics for an already-encoded message and forwards all of +/// its frames to the websocket send task. +fn ws_forward_frames( + metrics: &SendMetrics, + outgoing_frames: &mpsc::UnboundedSender, + workload: Option, + num_rows: Option, + encoded: (EncodeMetrics, InUseSerializeBuffer, I), +) -> Result> +where + I: Iterator, +{ + let (stats, in_use, frames) = encoded; + metrics.report(workload, num_rows, stats); + frames.into_iter().try_for_each(|frame| outgoing_frames.send(frame))?; + Ok(in_use) +} + /// Some stats about serialization and compression. /// /// Returned by [`ws_encode_message`]. @@ -1443,21 +1491,29 @@ async fn ws_encode_message( (metrics, msg_alloc, frames) } -#[allow(dead_code, unused_variables)] -async fn ws_encode_message_v2( +type BinarySerializeFn = fn( + &BsatnRowListBuilderPool, + SerializeBuffer, + ws_v2::ServerMessage, + ws_v1::Compression, +) -> (InUseSerializeBuffer, Bytes); + +async fn ws_encode_binary_message( config: ClientConfig, buf: SerializeBuffer, message: ws_v2::ServerMessage, + serialize_message: BinarySerializeFn, is_large_message: bool, bsatn_rlb_pool: &BsatnRowListBuilderPool, ) -> (EncodeMetrics, InUseSerializeBuffer, impl Iterator + use<>) { let start = Instant::now(); + let compression = config.compression; let (in_use, data) = if is_large_message { let bsatn_rlb_pool = bsatn_rlb_pool.clone(); - spawn_rayon(move || serialize_v2(&bsatn_rlb_pool, buf, message, config.compression)).await + spawn_rayon(move || serialize_message(&bsatn_rlb_pool, buf, message, compression)).await } else { - serialize_v2(bsatn_rlb_pool, buf, message, config.compression) + serialize_message(bsatn_rlb_pool, buf, message, compression) }; let metrics = EncodeMetrics { @@ -2298,9 +2354,11 @@ mod tests { #[test] fn confirmed_reads_default_depends_on_ws_version() { + assert!(resolve_confirmed_reads_default(WsVersion::V3, None)); assert!(resolve_confirmed_reads_default(WsVersion::V2, None)); assert!(!resolve_confirmed_reads_default(WsVersion::V1, None)); assert!(resolve_confirmed_reads_default(WsVersion::V1, Some(true))); + assert!(!resolve_confirmed_reads_default(WsVersion::V3, Some(false))); assert!(!resolve_confirmed_reads_default(WsVersion::V2, Some(false))); } diff --git a/crates/core/src/client.rs b/crates/core/src/client.rs index cad4f79adcf..4411192c625 100644 --- a/crates/core/src/client.rs +++ b/crates/core/src/client.rs @@ -7,6 +7,7 @@ pub mod consume_each_list; mod message_handlers; mod message_handlers_v1; mod message_handlers_v2; +mod message_handlers_v3; pub mod messages; pub use client_connection::{ diff --git a/crates/core/src/client/client_connection.rs b/crates/core/src/client/client_connection.rs index 6fb8d8e1623..0a7a7f1a11b 100644 --- a/crates/core/src/client/client_connection.rs +++ b/crates/core/src/client/client_connection.rs @@ -47,6 +47,7 @@ pub enum Protocol { pub enum WsVersion { V1, V2, + V3, } impl Protocol { @@ -384,7 +385,7 @@ impl ClientConnectionSender { debug_assert!( matches!( (&self.config.version, &message), - (WsVersion::V1, OutboundMessage::V1(_)) | (WsVersion::V2, OutboundMessage::V2(_)) + (WsVersion::V1, OutboundMessage::V1(_)) | (WsVersion::V2 | WsVersion::V3, OutboundMessage::V2(_)) ), "attempted to send message variant that does not match client websocket version" ); diff --git a/crates/core/src/client/message_handlers.rs b/crates/core/src/client/message_handlers.rs index 76f5fa53afa..fb85730c11c 100644 --- a/crates/core/src/client/message_handlers.rs +++ b/crates/core/src/client/message_handlers.rs @@ -23,5 +23,6 @@ pub async fn handle(client: &ClientConnection, message: DataMessage, timer: Inst match client.config.version { WsVersion::V1 => super::message_handlers_v1::handle(client, message, timer).await, WsVersion::V2 => super::message_handlers_v2::handle(client, message, timer).await, + WsVersion::V3 => super::message_handlers_v3::handle(client, message, timer).await, } } diff --git a/crates/core/src/client/message_handlers_v2.rs b/crates/core/src/client/message_handlers_v2.rs index 5dd2f80d01b..2db523e472d 100644 --- a/crates/core/src/client/message_handlers_v2.rs +++ b/crates/core/src/client/message_handlers_v2.rs @@ -20,6 +20,14 @@ pub async fn handle(client: &ClientConnection, message: DataMessage, timer: Inst ))) } }; + handle_decoded_message(client, message, timer).await +} + +pub(super) async fn handle_decoded_message( + client: &ClientConnection, + message: ws_v2::ClientMessage, + timer: Instant, +) -> Result<(), MessageHandleError> { let module = client.module(); let mod_info = module.info(); let mod_metrics = &mod_info.metrics; diff --git a/crates/core/src/client/message_handlers_v3.rs b/crates/core/src/client/message_handlers_v3.rs new file mode 100644 index 00000000000..696e7337ed0 --- /dev/null +++ b/crates/core/src/client/message_handlers_v3.rs @@ -0,0 +1,32 @@ +use super::{ClientConnection, DataMessage, MessageHandleError}; +use serde::de::Error as _; +use spacetimedb_client_api_messages::websocket::{v2 as ws_v2, v3 as ws_v3}; +use spacetimedb_lib::bsatn; +use std::time::Instant; + +pub async fn handle(client: &ClientConnection, message: DataMessage, timer: Instant) -> Result<(), MessageHandleError> { + client.observe_websocket_request_message(&message); + let frame = match message { + DataMessage::Binary(message_buf) => bsatn::from_slice::(&message_buf)?, + DataMessage::Text(_) => { + return Err(MessageHandleError::TextDecode(serde_json::Error::custom( + "v3 websocket does not support text messages", + ))) + } + }; + + match frame { + ws_v3::ClientFrame::Single(message) => { + let message = bsatn::from_slice::(&message)?; + super::message_handlers_v2::handle_decoded_message(client, message, timer).await?; + } + ws_v3::ClientFrame::Batch(messages) => { + for message in messages { + let message = bsatn::from_slice::(&message)?; + super::message_handlers_v2::handle_decoded_message(client, message, timer).await?; + } + } + } + + Ok(()) +} diff --git a/crates/core/src/client/messages.rs b/crates/core/src/client/messages.rs index ed65e092d0e..38c5fadb260 100644 --- a/crates/core/src/client/messages.rs +++ b/crates/core/src/client/messages.rs @@ -10,6 +10,7 @@ use derive_more::From; use spacetimedb_client_api_messages::websocket::common::{self as ws_common, RowListLen as _}; use spacetimedb_client_api_messages::websocket::v1::{self as ws_v1}; use spacetimedb_client_api_messages::websocket::v2 as ws_v2; +use spacetimedb_client_api_messages::websocket::v3 as ws_v3; use spacetimedb_datastore::execution_context::WorkloadType; use spacetimedb_lib::identity::RequestId; use spacetimedb_lib::ser::serde::SerializeWrapper; @@ -97,6 +98,20 @@ impl SerializeBuffer { } } +fn finalize_binary_serialize_buffer( + buffer: SerializeBuffer, + uncompressed_len: usize, + compression: ws_v1::Compression, +) -> (InUseSerializeBuffer, Bytes) { + match decide_compression(uncompressed_len, compression) { + ws_v1::Compression::None => buffer.uncompressed(), + ws_v1::Compression::Brotli => { + buffer.compress_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_BROTLI, brotli_compress) + } + ws_v1::Compression::Gzip => buffer.compress_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_GZIP, gzip_compress), + } +} + type BytesMutWriter<'a> = bytes::buf::Writer<&'a mut BytesMut>; pub enum InUseSerializeBuffer { @@ -159,21 +174,14 @@ pub fn serialize( let srv_msg = buffer.write_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_NONE, |w| { bsatn::to_writer(w.into_inner(), &msg).unwrap() }); + let srv_msg_len = srv_msg.len(); // At this point, we no longer have a use for `msg`, // so try to reclaim its buffers. msg.consume_each_list(&mut |buffer| bsatn_rlb_pool.try_put(buffer)); // Conditionally compress the message. - let (in_use, msg_bytes) = match decide_compression(srv_msg.len(), config.compression) { - ws_v1::Compression::None => buffer.uncompressed(), - ws_v1::Compression::Brotli => { - buffer.compress_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_BROTLI, brotli_compress) - } - ws_v1::Compression::Gzip => { - buffer.compress_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_GZIP, gzip_compress) - } - }; + let (in_use, msg_bytes) = finalize_binary_serialize_buffer(buffer, srv_msg_len, config.compression); (in_use, msg_bytes.into()) } } @@ -192,18 +200,40 @@ pub fn serialize_v2( let srv_msg = buffer.write_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_NONE, |w| { bsatn::to_writer(w.into_inner(), &msg).expect("should be able to bsatn encode v2 message"); }); + let srv_msg_len = srv_msg.len(); // At this point, we no longer have a use for `msg`, // so try to reclaim its buffers. msg.consume_each_list(&mut |buffer| bsatn_rlb_pool.try_put(buffer)); - match decide_compression(srv_msg.len(), compression) { - ws_v1::Compression::None => buffer.uncompressed(), - ws_v1::Compression::Brotli => { - buffer.compress_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_BROTLI, brotli_compress) - } - ws_v1::Compression::Gzip => buffer.compress_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_GZIP, gzip_compress), - } + finalize_binary_serialize_buffer(buffer, srv_msg_len, compression) +} + +/// Serialize `msg` into a [`DataMessage`] containing a [`ws_v3::ServerFrame::Single`] +/// whose payload is a BSATN-encoded [`ws_v2::ServerMessage`]. +/// +/// This mirrors the v2 framing by prepending the compression tag and applying +/// conditional compression when configured. +pub fn serialize_v3( + bsatn_rlb_pool: &BsatnRowListBuilderPool, + mut buffer: SerializeBuffer, + msg: ws_v2::ServerMessage, + compression: ws_v1::Compression, +) -> (InUseSerializeBuffer, Bytes) { + let mut inner = BytesMut::with_capacity(SERIALIZE_BUFFER_INIT_CAP); + bsatn::to_writer((&mut inner).writer().into_inner(), &msg).expect("should be able to bsatn encode v2 message"); + + // At this point, we no longer have a use for `msg`, + // so try to reclaim its buffers. + msg.consume_each_list(&mut |buffer| bsatn_rlb_pool.try_put(buffer)); + + let frame = ws_v3::ServerFrame::Single(inner.freeze()); + let srv_msg = buffer.write_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_NONE, |w| { + bsatn::to_writer(w.into_inner(), &frame).expect("should be able to bsatn encode v3 server frame"); + }); + let srv_msg_len = srv_msg.len(); + + finalize_binary_serialize_buffer(buffer, srv_msg_len, compression) } #[derive(Debug, From)] diff --git a/crates/core/src/subscription/module_subscription_actor.rs b/crates/core/src/subscription/module_subscription_actor.rs index 92f296f3b8c..4ab8b0c28a7 100644 --- a/crates/core/src/subscription/module_subscription_actor.rs +++ b/crates/core/src/subscription/module_subscription_actor.rs @@ -1639,7 +1639,7 @@ impl ModuleSubscriptions { message, ); } - WsVersion::V2 => { + WsVersion::V2 | WsVersion::V3 => { if let Some(request_id) = event.request_id { self.send_reducer_failure_result_v2(client, &event, request_id); } From eb50885e7ee55a32deefd9a4268bffc3e97f77a9 Mon Sep 17 00:00:00 2001 From: joshua-spacetime Date: Wed, 8 Apr 2026 17:17:09 -0700 Subject: [PATCH 2/3] Update unreal sdk to use v3 websocket api --- sdks/unreal/DEVELOP.md | 4 +- .../Private/Connection/DbConnectionBase.cpp | 324 +++++++++++++----- .../Private/Connection/Websocket.cpp | 292 +++++++++++----- .../Private/Tests/SpacetimeDBBSATNTestOrg.cpp | 19 + .../Public/Connection/DbConnectionBase.h | 26 +- .../Public/Connection/Websocket.h | 56 +-- .../Public/Connection/WebsocketV3Frames.h | 290 ++++++++++++++++ 7 files changed, 809 insertions(+), 202 deletions(-) create mode 100644 sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/WebsocketV3Frames.h diff --git a/sdks/unreal/DEVELOP.md b/sdks/unreal/DEVELOP.md index 8fd9925fafe..81f264b9fbc 100644 --- a/sdks/unreal/DEVELOP.md +++ b/sdks/unreal/DEVELOP.md @@ -9,6 +9,9 @@ come from SpacetimeDB codegen (`--lang unrealcpp`) and websocket schema definiti This is not automated; regenerate manually whenever websocket message schemas or Unreal codegen behavior changes. +The Unreal SDK still uses generated WS v2 logical message bindings (`ClientMessageType.g.h`, `ServerMessageType.g.h`). +The WS v3 transport envelope is implemented manually in `Connection/WebsocketV3Frames.h`, because it is only a thin wrapper around already-encoded v2 messages rather than a new logical message schema. + ## WS v2 websocket schema regeneration workflow Run from repo root: @@ -79,4 +82,3 @@ Here's an example of how to include `AdditionalPluginDirectories` in your `.upro ``` - diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBase.cpp b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBase.cpp index 34d878e42cb..8d93366cdb1 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBase.cpp +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBase.cpp @@ -9,6 +9,7 @@ #include "Async/Async.h" #include "BSATN/UEBSATNHelpers.h" #include "Connection/ProcedureFlags.h" +#include "Connection/WebsocketV3Frames.h" namespace { @@ -63,6 +64,15 @@ static FString DecodeReducerErrorMessage(const TArray& ErrorBytes) } return UE::SpacetimeDB::Deserialize(ErrorBytes); } + +static void AppendMovedServerMessages(TArray& Target, TArray& Source) +{ + for (FServerMessageType& Message : Source) + { + Target.Add(MoveTemp(Message)); + } + Source.Reset(); +} } UDbConnectionBase::UDbConnectionBase(const FObjectInitializer& ObjectInitializer) @@ -73,13 +83,14 @@ UDbConnectionBase::UDbConnectionBase(const FObjectInitializer& ObjectInitializer ProcedureCallbacks = CreateDefaultSubobject(TEXT("ProcedureCallbacks")); } -void UDbConnectionBase::Disconnect() -{ - if (WebSocket) - { - WebSocket->Disconnect(); - } -} +void UDbConnectionBase::Disconnect() +{ + ClearOutboundQueue(); + if (WebSocket) + { + WebSocket->Disconnect(); + } +} bool UDbConnectionBase::IsActive() const { @@ -109,10 +120,30 @@ bool UDbConnectionBase::SendRawMessage(const FString& Message) return WebSocket && WebSocket->SendMessage(Message); } -bool UDbConnectionBase::SendRawMessage(const TArray& Message) -{ - return WebSocket && WebSocket->SendMessage(Message); -} +bool UDbConnectionBase::SendRawMessage(const TArray& Message) +{ + if (!WebSocket) + { + return false; + } + + // Binary messages reaching this layer are already BSATN-encoded v2 logical + // websocket messages. v3 batching only wraps those bytes in a transport + // envelope; it does not re-materialize higher-level client message objects. + if (WebSocket->GetActiveProtocol() != ESpacetimeDBWsProtocol::V3) + { + return WebSocket->SendMessage(Message); + } + + if (!WebSocket->IsConnected()) + { + return WebSocket->SendMessage(Message); + } + + TArray QueuedMessage = Message; + QueueOutboundMessageV3(MoveTemp(QueuedMessage)); + return true; +} USubscriptionBuilderBase* UDbConnectionBase::SubscriptionBuilderBase() { @@ -162,25 +193,29 @@ void UDbConnectionBase::HandleProtocolViolation(const FString& ErrorMessage) } } -void UDbConnectionBase::HandleWSBinaryMessage(const TArray& Message) -{ - //tag for arrival order - const int32 Id = NextPreprocessId.GetValue(); - NextPreprocessId.Increment(); - - //do expensive work off-thread - TWeakObjectPtr WeakThis(this); - Async(EAsyncExecution::Thread, [WeakThis, Message, Id]() - { - if (!WeakThis.IsValid()) - { - return; +void UDbConnectionBase::HandleWSBinaryMessage(const TArray& Message) +{ + //tag for arrival order + const int32 Id = NextPreprocessId.GetValue(); + NextPreprocessId.Increment(); + // Capture the transport protocol before handing work to the background + // preprocess thread so reconnect/disconnect state changes cannot alter how + // this raw websocket frame is decoded. + const ESpacetimeDBWsProtocol Protocol = WebSocket ? WebSocket->GetActiveProtocol() : ESpacetimeDBWsProtocol::V2; + + //do expensive work off-thread + TWeakObjectPtr WeakThis(this); + Async(EAsyncExecution::Thread, [WeakThis, Message, Id, Protocol]() + { + if (!WeakThis.IsValid()) + { + return; } UDbConnectionBase* This = WeakThis.Get(); //parse the message, decompress if needed - FServerMessageType Parsed; - if (!This->PreProcessMessage(Message, Parsed)) + TArray ParsedMessages; + if (!This->PreProcessMessage(Protocol, Message, ParsedMessages)) { AsyncTask(ENamedThreads::GameThread, [WeakThis]() { @@ -195,26 +230,28 @@ void UDbConnectionBase::HandleWSBinaryMessage(const TArray& Message) } //queue: re-order buffer - TArray Ready; - { - FScopeLock Lock(&This->PreprocessMutex); - // Move the parsed message into the map to avoid copying - This->PreprocessedMessages.Add(Id, MoveTemp(Parsed)); - //check if we can release any messages in order - while (This->PreprocessedMessages.Contains(This->NextReleaseId)) - { - Ready.Add(This->PreprocessedMessages.FindAndRemoveChecked(This->NextReleaseId)); - ++This->NextReleaseId; - } - } + TArray Ready; + { + FScopeLock Lock(&This->PreprocessMutex); + // Move the parsed frame into the map to avoid copying and release + // websocket frames in arrival order. + This->PreprocessedMessages.Add(Id, MoveTemp(ParsedMessages)); + //check if we can release any messages in order + while (This->PreprocessedMessages.Contains(This->NextReleaseId)) + { + TArray Released = This->PreprocessedMessages.FindAndRemoveChecked(This->NextReleaseId); + AppendMovedServerMessages(Ready, Released); + ++This->NextReleaseId; + } + } //if we have any ready messages, append them to the pending messages list that is processed in Tick - if (Ready.Num() > 0) - { - FScopeLock Lock(&This->PendingMessagesMutex); - This->PendingMessages.Append(MoveTemp(Ready)); - } - }); -} + if (Ready.Num() > 0) + { + FScopeLock Lock(&This->PendingMessagesMutex); + AppendMovedServerMessages(This->PendingMessages, Ready); + } + }); +} void UDbConnectionBase::FrameTick() { @@ -257,12 +294,95 @@ bool UDbConnectionBase::IsTickable() const return bIsAutoTicking; } -bool UDbConnectionBase::IsTickableInEditor() const -{ - return bIsAutoTicking; -} - - +bool UDbConnectionBase::IsTickableInEditor() const +{ + return bIsAutoTicking; +} + +void UDbConnectionBase::QueueOutboundMessageV3(TArray Message) +{ + { + FScopeLock Lock(&PendingOutboundMessagesMutex); + PendingOutboundMessages.Add(MoveTemp(Message)); + } + ScheduleOutboundFlush(); +} + +void UDbConnectionBase::FlushOutboundQueueV3() +{ + if (!WebSocket || !WebSocket->IsConnected() || WebSocket->GetActiveProtocol() != ESpacetimeDBWsProtocol::V3) + { + FScopeLock Lock(&PendingOutboundMessagesMutex); + bIsOutboundFlushScheduled = false; + return; + } + + TArray> PendingFrameMessages; + bool bHasRemainingMessages = false; + { + FScopeLock Lock(&PendingOutboundMessagesMutex); + bIsOutboundFlushScheduled = false; + if (PendingOutboundMessages.Num() == 0) + { + return; + } + + // Emit at most one bounded v3 transport frame per flush. If more encoded + // v2 messages remain, they are sent by a later scheduled task so inbound + // websocket work and other game-thread tasks can run between writes. + const int32 BatchSize = UE::SpacetimeDB::V3::CountClientMessagesForFrame( + PendingOutboundMessages, + UE::SpacetimeDB::V3::MaxOutboundFrameBytes + ); + PendingFrameMessages.Reserve(BatchSize); + for (int32 Index = 0; Index < BatchSize; ++Index) + { + PendingFrameMessages.Add(MoveTemp(PendingOutboundMessages[Index])); + } + PendingOutboundMessages.RemoveAt(0, BatchSize, EAllowShrinking::No); + bHasRemainingMessages = PendingOutboundMessages.Num() > 0; + } + + WebSocket->SendMessage(UE::SpacetimeDB::V3::EncodeClientMessages(PendingFrameMessages)); + if (bHasRemainingMessages) + { + ScheduleOutboundFlush(); + } +} + +void UDbConnectionBase::ScheduleOutboundFlush() +{ + { + FScopeLock Lock(&PendingOutboundMessagesMutex); + if (bIsOutboundFlushScheduled) + { + return; + } + bIsOutboundFlushScheduled = true; + } + + const TWeakObjectPtr WeakThis(this); + // Run the follow-up flush on a later game-thread task instead of draining + // multiple oversized batches back-to-back in one turn. That matches the + // yield-and-flush-later behavior used in the other SDKs. + AsyncTask(ENamedThreads::GameThread, [WeakThis]() + { + if (!WeakThis.IsValid()) + { + return; + } + WeakThis->FlushOutboundQueueV3(); + }); +} + +void UDbConnectionBase::ClearOutboundQueue() +{ + FScopeLock Lock(&PendingOutboundMessagesMutex); + PendingOutboundMessages.Reset(); + bIsOutboundFlushScheduled = false; +} + + void UDbConnectionBase::ProcessServerMessage(const FServerMessageType& Message) { switch (Message.Tag) @@ -525,6 +645,7 @@ void UDbConnectionBase::ClearPendingOperations(const FString& Reason) { UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("Cleared pending operations due to connection issue: %s"), *Reason); } + ClearOutboundQueue(); } void UDbConnectionBase::PreProcessDatabaseUpdate(const FDatabaseUpdateType& Update) @@ -561,10 +682,52 @@ void UDbConnectionBase::PreProcessDatabaseUpdate(const FDatabaseUpdateType& Upda { UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Skipping table %s updates due to missing deserializer"), *TableUpdate.TableName); } - } -} - -bool UDbConnectionBase::PreProcessMessage(const TArray& Message, FServerMessageType& OutMessage) + } +} + +void UDbConnectionBase::PreProcessDecodedServerMessage(const FServerMessageType& Message) +{ + switch (Message.Tag) + { + case EServerMessageTag::SubscribeApplied: + { + const FSubscribeAppliedType Payload = Message.GetAsSubscribeApplied(); + PreProcessDatabaseUpdate(QueryRowsToDatabaseUpdate(Payload.Rows, false)); + break; + } + case EServerMessageTag::UnsubscribeApplied: + { + const FUnsubscribeAppliedType Payload = Message.GetAsUnsubscribeApplied(); + if (Payload.Rows.IsSet()) + { + PreProcessDatabaseUpdate(QueryRowsToDatabaseUpdate(Payload.Rows.Value, true)); + } + break; + } + case EServerMessageTag::TransactionUpdate: + { + const FTransactionUpdateType Payload = Message.GetAsTransactionUpdate(); + PreProcessDatabaseUpdate(TransactionUpdateToDatabaseUpdate(Payload)); + break; + } + case EServerMessageTag::ReducerResult: + { + const FReducerResultType Payload = Message.GetAsReducerResult(); + if (Payload.Result.IsOk()) + { + PreProcessDatabaseUpdate(TransactionUpdateToDatabaseUpdate(Payload.Result.GetAsOk().TransactionUpdate)); + } + break; + } + default: + break; + } +} + +bool UDbConnectionBase::PreProcessMessage( + ESpacetimeDBWsProtocol Protocol, + const TArray& Message, + TArray& OutMessages) { if (Message.Num() == 0) { @@ -584,45 +747,30 @@ bool UDbConnectionBase::PreProcessMessage(const TArray& Message, FServerM return false; } - // Deserialize the decompressed data into a UServerMessageType object - OutMessage = UE::SpacetimeDB::Deserialize(Decompressed); - - // Preprocess row-bearing payloads for table deserializers. - switch (OutMessage.Tag) + OutMessages.Reset(); + if (Protocol == ESpacetimeDBWsProtocol::V3) { - case EServerMessageTag::SubscribeApplied: + TArray> EncodedMessages; + UE::SpacetimeDB::V3::DecodeServerMessages(Decompressed, EncodedMessages); + if (EncodedMessages.Num() == 0) { - const FSubscribeAppliedType Payload = OutMessage.GetAsSubscribeApplied(); - PreProcessDatabaseUpdate(QueryRowsToDatabaseUpdate(Payload.Rows, false)); - break; + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Received empty v3 websocket frame")); + return false; } - case EServerMessageTag::UnsubscribeApplied: - { - const FUnsubscribeAppliedType Payload = OutMessage.GetAsUnsubscribeApplied(); - if (Payload.Rows.IsSet()) - { - PreProcessDatabaseUpdate(QueryRowsToDatabaseUpdate(Payload.Rows.Value, true)); - } - break; - } - case EServerMessageTag::TransactionUpdate: + + OutMessages.Reserve(EncodedMessages.Num()); + for (const TArray& EncodedMessage : EncodedMessages) { - const FTransactionUpdateType Payload = OutMessage.GetAsTransactionUpdate(); - PreProcessDatabaseUpdate(TransactionUpdateToDatabaseUpdate(Payload)); - break; + FServerMessageType ParsedMessage = UE::SpacetimeDB::Deserialize(EncodedMessage); + PreProcessDecodedServerMessage(ParsedMessage); + OutMessages.Add(MoveTemp(ParsedMessage)); } - case EServerMessageTag::ReducerResult: - { - const FReducerResultType Payload = OutMessage.GetAsReducerResult(); - if (Payload.Result.IsOk()) - { - PreProcessDatabaseUpdate(TransactionUpdateToDatabaseUpdate(Payload.Result.GetAsOk().TransactionUpdate)); - } - break; - } - default: - break; + return true; } + + FServerMessageType ParsedMessage = UE::SpacetimeDB::Deserialize(Decompressed); + PreProcessDecodedServerMessage(ParsedMessage); + OutMessages.Add(MoveTemp(ParsedMessage)); return true; } diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/Websocket.cpp b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/Websocket.cpp index 7b4bbe53f40..fc5a7befc9d 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/Websocket.cpp +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/Websocket.cpp @@ -5,8 +5,19 @@ #include "ModuleBindings/Types/ServerMessageType.g.h" #include "Dom/JsonObject.h" -#include "Serialization/JsonWriter.h" -#include "Serialization/JsonSerializer.h" +#include "Serialization/JsonWriter.h" +#include "Serialization/JsonSerializer.h" + +namespace +{ +const FString V2Protocol = TEXT("v2.bsatn.spacetimedb"); +const FString V3Protocol = TEXT("v3.bsatn.spacetimedb"); + +const FString& GetProtocolName(ESpacetimeDBWsProtocol Protocol) +{ + return Protocol == ESpacetimeDBWsProtocol::V3 ? V3Protocol : V2Protocol; +} +} UWebsocketManager::UWebsocketManager() { @@ -24,73 +35,53 @@ void UWebsocketManager::BeginDestroy() Super::BeginDestroy(); } -void UWebsocketManager::Connect(const FString& ServerUrl) -{ - if (IsConnected()) - { +void UWebsocketManager::Connect(const FString& ServerUrl) +{ + if (IsConnected()) + { UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("UWebsocketManager::Connect: Already connected. Disconnect first.")); return; } - if (ServerUrl.IsEmpty()) - { - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("UWebsocketManager::Connect called with empty URL")); - OnConnectionError.Broadcast(TEXT("Invalid server URL")); - return; - } - - // append InitToken to the connection headers if provided - TMap UpgradeHeaders; - if (!InitToken.IsEmpty()) - { - FString HeaderToken = FString::Printf(TEXT("Bearer %s"), - *InitToken); - UpgradeHeaders.Add("Authorization", HeaderToken); - } - - // Use websocket protocol v2 - const FString Protocol = "v2.bsatn.spacetimedb"; - - // Create the WebSocket connection - WebSocket = FWebSocketsModule::Get().CreateWebSocket(ServerUrl, Protocol, UpgradeHeaders); - - if (!WebSocket.IsValid()) - { - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("UWebsocketManager::Connect: Failed to create WebSocket connection to %s."), *ServerUrl); - OnConnectionError.Broadcast(TEXT("Failed to create WebSocket.")); - return; - } - - // Bind event handlers - WebSocket->OnConnected().AddUObject(this, &UWebsocketManager::HandleConnected); - WebSocket->OnConnectionError().AddUObject(this, &UWebsocketManager::HandleConnectionError); - WebSocket->OnMessage().AddUObject(this, &UWebsocketManager::HandleMessageReceived); - WebSocket->OnBinaryMessage().AddUObject(this, &UWebsocketManager::HandleBinaryMessageReceived); - WebSocket->OnClosed().AddUObject(this, &UWebsocketManager::HandleClosed); - - UE_LOG(LogSpacetimeDb_Connection, Log, TEXT("UWebsocketManager::Connect: Connecting to %s..."), *ServerUrl); - // Start the connection process - WebSocket->Connect(); -} - -void UWebsocketManager::Disconnect() -{ - if (!WebSocket.IsValid()) - { - return; - } - - if (IsConnected()) - { - UE_LOG(LogSpacetimeDb_Connection, Log, TEXT("UWebsocketManager::Disconnect: Closing WebSocket connection.")); - WebSocket->Close(); - } - - // Reset the WebSocket to allow for reconnection attempts - WebSocket.Reset(); -} + if (ServerUrl.IsEmpty()) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("UWebsocketManager::Connect called with empty URL")); + OnConnectionError.Broadcast(TEXT("Invalid server URL")); + return; + } + + PendingServerUrl = ServerUrl; + bHasEstablishedConnection = false; + bHasAttemptedV2Fallback = false; + // Unreal's websocket API accepts one subprotocol string per connection, so + // we prefer v3 first and retry with v2 only if the initial handshake fails. + ConnectWithProtocol(ServerUrl, ESpacetimeDBWsProtocol::V3); +} -bool UWebsocketManager::SendMessage(const FString& Message) +void UWebsocketManager::Disconnect() +{ + if (!WebSocket.IsValid()) + { + PendingServerUrl.Empty(); + bHasEstablishedConnection = false; + bHasAttemptedV2Fallback = false; + return; + } + + if (IsConnected()) + { + UE_LOG(LogSpacetimeDb_Connection, Log, TEXT("UWebsocketManager::Disconnect: Closing WebSocket connection.")); + WebSocket->Close(); + } + + PendingServerUrl.Empty(); + bHasEstablishedConnection = false; + bHasAttemptedV2Fallback = false; + // Reset the WebSocket to allow for reconnection attempts + ResetSocket(); +} + +bool UWebsocketManager::SendMessage(const FString& Message) { if (!IsConnected()) { @@ -138,19 +129,29 @@ void UWebsocketManager::SetInitToken(FString Token) InitToken = Token; } -void UWebsocketManager::HandleConnected() -{ - UE_LOG(LogSpacetimeDb_Connection, Log, TEXT("UWebsocketManager: WebSocket Connected.")); - OnConnected.Broadcast(); -} - -void UWebsocketManager::HandleConnectionError(const FString& Error) -{ - UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("UWebsocketManager: WebSocket Connection Error: %s"), *Error); - OnConnectionError.Broadcast(Error); - // Reset on error to allow reconnection attempts - WebSocket.Reset(); -} +void UWebsocketManager::HandleConnected() +{ + bHasEstablishedConnection = true; + UE_LOG( + LogSpacetimeDb_Connection, + Log, + TEXT("UWebsocketManager: WebSocket Connected using %s."), + *GetProtocolName(ActiveProtocol) + ); + OnConnected.Broadcast(); +} + +void UWebsocketManager::HandleConnectionError(const FString& Error) +{ + if (TryFallbackToV2(Error)) + { + return; + } + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("UWebsocketManager: WebSocket Connection Error: %s"), *Error); + OnConnectionError.Broadcast(Error); + // Reset on error to allow reconnection attempts + ResetSocket(); +} void UWebsocketManager::HandleMessageReceived(const FString& Message) { @@ -188,12 +189,133 @@ void UWebsocketManager::HandleBinaryMessageReceived(const void* Data, SIZE_T Siz } } -void UWebsocketManager::HandleClosed(int32 StatusCode, const FString& Reason, bool bWasClean) -{ - UE_LOG(LogSpacetimeDb_Connection, Log, TEXT("UWebsocketManager: WebSocket Closed. Status: %d, Reason: %s, Clean: %s"), - StatusCode, *Reason, bWasClean ? TEXT("true") : TEXT("false")); - // Notify listeners about the closure - OnClosed.Broadcast(StatusCode, Reason, bWasClean); - // Reset on close to allow reconnection attempts - WebSocket.Reset(); +void UWebsocketManager::HandleClosed(int32 StatusCode, const FString& Reason, bool bWasClean) +{ + if (TryFallbackToV2(Reason)) + { + return; + } + UE_LOG(LogSpacetimeDb_Connection, Log, TEXT("UWebsocketManager: WebSocket Closed. Status: %d, Reason: %s, Clean: %s"), + StatusCode, *Reason, bWasClean ? TEXT("true") : TEXT("false")); + // Notify listeners about the closure + OnClosed.Broadcast(StatusCode, Reason, bWasClean); + // Reset on close to allow reconnection attempts + ResetSocket(); +} + +void UWebsocketManager::ConnectWithProtocol(const FString& ServerUrl, ESpacetimeDBWsProtocol Protocol) +{ + ActiveProtocol = Protocol; + ++ConnectAttemptId; + const uint32 AttemptId = ConnectAttemptId; + + TMap UpgradeHeaders; + if (!InitToken.IsEmpty()) + { + const FString HeaderToken = FString::Printf(TEXT("Bearer %s"), *InitToken); + UpgradeHeaders.Add(TEXT("Authorization"), HeaderToken); + } + + WebSocket = FWebSocketsModule::Get().CreateWebSocket(ServerUrl, GetProtocolName(Protocol), UpgradeHeaders); + if (!WebSocket.IsValid()) + { + UE_LOG( + LogSpacetimeDb_Connection, + Error, + TEXT("UWebsocketManager::Connect: Failed to create WebSocket connection to %s."), + *ServerUrl + ); + if (TryFallbackToV2(TEXT("failed to create websocket"))) + { + return; + } + OnConnectionError.Broadcast(TEXT("Failed to create WebSocket.")); + return; + } + + const TWeakObjectPtr WeakThis(this); + WebSocket->OnConnected().AddLambda([WeakThis, AttemptId]() + { + UWebsocketManager* This = WeakThis.Get(); + if (!This || This->ConnectAttemptId != AttemptId) + { + return; + } + This->HandleConnected(); + }); + WebSocket->OnConnectionError().AddLambda([WeakThis, AttemptId](const FString& Error) + { + UWebsocketManager* This = WeakThis.Get(); + if (!This || This->ConnectAttemptId != AttemptId) + { + return; + } + This->HandleConnectionError(Error); + }); + WebSocket->OnMessage().AddLambda([WeakThis, AttemptId](const FString& Message) + { + UWebsocketManager* This = WeakThis.Get(); + if (!This || This->ConnectAttemptId != AttemptId) + { + return; + } + This->HandleMessageReceived(Message); + }); + WebSocket->OnBinaryMessage().AddLambda([WeakThis, AttemptId](const void* Data, SIZE_T Size, bool bIsLastFragment) + { + UWebsocketManager* This = WeakThis.Get(); + if (!This || This->ConnectAttemptId != AttemptId) + { + return; + } + This->HandleBinaryMessageReceived(Data, Size, bIsLastFragment); + }); + WebSocket->OnClosed().AddLambda([WeakThis, AttemptId](int32 StatusCode, const FString& Reason, bool bWasClean) + { + UWebsocketManager* This = WeakThis.Get(); + if (!This || This->ConnectAttemptId != AttemptId) + { + return; + } + This->HandleClosed(StatusCode, Reason, bWasClean); + }); + + UE_LOG( + LogSpacetimeDb_Connection, + Log, + TEXT("UWebsocketManager::Connect: Connecting to %s with %s..."), + *ServerUrl, + *GetProtocolName(Protocol) + ); + WebSocket->Connect(); +} + +bool UWebsocketManager::TryFallbackToV2(const FString& FailureReason) +{ + // Only downgrade during the initial connect path. Once a websocket session + // has been established we preserve the chosen transport version across later + // disconnect/error handling instead of silently switching protocols. + if (bHasEstablishedConnection || bHasAttemptedV2Fallback || ActiveProtocol != ESpacetimeDBWsProtocol::V3 || PendingServerUrl.IsEmpty()) + { + return false; + } + + bHasAttemptedV2Fallback = true; + UE_LOG( + LogSpacetimeDb_Connection, + Warning, + TEXT("v3 websocket connection failed (%s). Retrying with %s."), + *FailureReason, + *GetProtocolName(ESpacetimeDBWsProtocol::V2) + ); + ResetSocket(); + ConnectWithProtocol(PendingServerUrl, ESpacetimeDBWsProtocol::V2); + return true; +} + +void UWebsocketManager::ResetSocket() +{ + IncompleteMessage.Reset(); + bAwaitingBinaryFragments = false; + WebSocket.Reset(); } diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Tests/SpacetimeDBBSATNTestOrg.cpp b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Tests/SpacetimeDBBSATNTestOrg.cpp index 20583c7b92a..4bd884360a4 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Tests/SpacetimeDBBSATNTestOrg.cpp +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Tests/SpacetimeDBBSATNTestOrg.cpp @@ -37,6 +37,7 @@ #include "ModuleBindings/Types/UnsubscribeFlagsType.g.h" #include "ModuleBindings/Types/UnsubscribeType.g.h" #include "ModuleBindings/Optionals/SpacetimeDbSdkOptionalQueryRows.g.h" +#include "Connection/WebsocketV3Frames.h" // ────────────────────────────────────────────────────────────────────────────── @@ -247,6 +248,16 @@ IMPLEMENT_SIMPLE_AUTOMATION_TEST( FClientMessageType ClientMessageUnsubscribe = FClientMessageType::Unsubscribe(Unsubscribe); TEST_ROUNDTRIP(FClientMessageType, ClientMessageUnsubscribe, "FClientMessageType::Unsubscribe Variant"); + LOG_Category("Client API WS v3"); + UE::SpacetimeDB::V3::FClientFrame ClientFrameSingle = + UE::SpacetimeDB::V3::FClientFrame::Single(UE::SpacetimeDB::Serialize(ClientMessageCallReducer)); + TEST_ROUNDTRIP(UE::SpacetimeDB::V3::FClientFrame, ClientFrameSingle, "FClientFrame::Single"); + TArray> BatchedClientMessages; + BatchedClientMessages.Add(UE::SpacetimeDB::Serialize(ClientMessageCallReducer)); + BatchedClientMessages.Add(UE::SpacetimeDB::Serialize(ClientMessageCallProcedure)); + UE::SpacetimeDB::V3::FClientFrame ClientFrameBatch = UE::SpacetimeDB::V3::FClientFrame::Batch(BatchedClientMessages); + TEST_ROUNDTRIP(UE::SpacetimeDB::V3::FClientFrame, ClientFrameBatch, "FClientFrame::Batch"); + FPersistentTableRowsType PersistentRows; PersistentRows.Inserts = BsatnRowsFixed; PersistentRows.Deletes = BsatnRowsOffsets; @@ -357,6 +368,14 @@ IMPLEMENT_SIMPLE_AUTOMATION_TEST( TEST_ROUNDTRIP(FServerMessageType, MessageReducerResult, "FServerMessageType::ReducerResult Variant"); FServerMessageType MessageProcedureResult = FServerMessageType::ProcedureResult(ProcedureResult); TEST_ROUNDTRIP(FServerMessageType, MessageProcedureResult, "FServerMessageType::ProcedureResult Variant"); + UE::SpacetimeDB::V3::FServerFrame ServerFrameSingle = + UE::SpacetimeDB::V3::FServerFrame::Single(UE::SpacetimeDB::Serialize(MessageInitialConnection)); + TEST_ROUNDTRIP(UE::SpacetimeDB::V3::FServerFrame, ServerFrameSingle, "FServerFrame::Single"); + TArray> BatchedServerMessages; + BatchedServerMessages.Add(UE::SpacetimeDB::Serialize(MessageInitialConnection)); + BatchedServerMessages.Add(UE::SpacetimeDB::Serialize(MessageTransactionUpdate)); + UE::SpacetimeDB::V3::FServerFrame ServerFrameBatch = UE::SpacetimeDB::V3::FServerFrame::Batch(BatchedServerMessages); + TEST_ROUNDTRIP(UE::SpacetimeDB::V3::FServerFrame, ServerFrameBatch, "FServerFrame::Batch"); return true; } diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/DbConnectionBase.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/DbConnectionBase.h index e8c26b3a988..fdb3067541d 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/DbConnectionBase.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/DbConnectionBase.h @@ -286,23 +286,28 @@ class SPACETIMEDBSDK_API UDbConnectionBase : public UObject, public FTickableGam /** Internal handler that processes a single server message. */ void ProcessServerMessage(const FServerMessageType& Message); + void PreProcessDecodedServerMessage(const FServerMessageType& Message); void PreProcessDatabaseUpdate(const FDatabaseUpdateType& Update); /** Decompress and parse a raw message. */ - bool PreProcessMessage(const TArray& Message, FServerMessageType& OutMessage); + bool PreProcessMessage(ESpacetimeDBWsProtocol Protocol, const TArray& Message, TArray& OutMessages); bool DecompressPayload(uint8 Variant, const TArray& In, TArray& Out); bool DecompressGzip(const TArray& InData, TArray& OutData); bool DecompressBrotli(const TArray& InData, TArray& OutData); void ClearPendingOperations(const FString& Reason); void HandleProtocolViolation(const FString& ErrorMessage); - - /** Pending messages awaiting processing on the game thread. */ - TArray PendingMessages; + void QueueOutboundMessageV3(TArray Message); + void FlushOutboundQueueV3(); + void ScheduleOutboundFlush(); + void ClearOutboundQueue(); + + /** Pending messages awaiting processing on the game thread. */ + TArray PendingMessages; /** Mutex protecting access to PendingMessages. */ FCriticalSection PendingMessagesMutex; - /** Map of preprocessed messages keyed by their sequential id. */ - TMap PreprocessedMessages; + /** Map of preprocessed websocket frames keyed by their sequential id. */ + TMap> PreprocessedMessages; /** Protects PreprocessedMessages and PendingMessages ordering state. */ FCriticalSection PreprocessMutex; @@ -310,8 +315,13 @@ class SPACETIMEDBSDK_API UDbConnectionBase : public UObject, public FTickableGam /** Counter for assigning ids to incoming messages. */ FThreadSafeCounter NextPreprocessId; - /** Id of the next message expected to be released. */ - int32 NextReleaseId = 0; + /** Id of the next message expected to be released. */ + int32 NextReleaseId = 0; + + /** Already-serialized v2 client messages waiting to be wrapped in v3 frames. */ + TArray> PendingOutboundMessages; + FCriticalSection PendingOutboundMessagesMutex; + FThreadSafeBool bIsOutboundFlushScheduled = false; // Map of table name to row deserializer TMap> TableDeserializers; diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/Websocket.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/Websocket.h index b9e6f91378d..e0c55cff4e9 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/Websocket.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/Websocket.h @@ -10,10 +10,16 @@ #include "LogCategory.h" -#include "Websocket.generated.h" - -/** Delegate broadcast when a connection is successfully established */ -DECLARE_DYNAMIC_MULTICAST_DELEGATE(FOnWebSocketConnected); +#include "Websocket.generated.h" + +enum class ESpacetimeDBWsProtocol : uint8 +{ + V2, + V3, +}; + +/** Delegate broadcast when a connection is successfully established */ +DECLARE_DYNAMIC_MULTICAST_DELEGATE(FOnWebSocketConnected); /** Delegate broadcast on connection error */ DECLARE_DYNAMIC_MULTICAST_DELEGATE_OneParam(FOnWebSocketConnectionError, const FString&, ErrorMessage); /** Delegate broadcast when a text message is received */ @@ -67,7 +73,10 @@ class SPACETIMEDBSDK_API UWebsocketManager : public UObject * Checks if the WebSocket connection is currently active. * @return True if connected, false otherwise. */ - bool IsConnected() const; + bool IsConnected() const; + + /** Returns the websocket protocol currently in use for this connection. */ + ESpacetimeDBWsProtocol GetActiveProtocol() const { return ActiveProtocol; } /** * Sets the initial auth token used when connecting. @@ -99,27 +108,34 @@ class SPACETIMEDBSDK_API UWebsocketManager : public UObject /**Underlying WebSocket implementation */ TSharedPtr WebSocket; - /** Handler for successful connection */ - void HandleConnected(); + /** Handler for successful connection */ + void HandleConnected(); /** Handler for connection errors */ void HandleConnectionError(const FString& Error); /** Handler for incoming text messages */ void HandleMessageReceived(const FString& Message); /** Handler for incoming binary messages */ void HandleBinaryMessageReceived(const void* Data, SIZE_T Size, bool bIsLastFragment); - /** Handler for socket close */ - void HandleClosed(int32 StatusCode, const FString& Reason, bool bWasClean); - - FString InitToken; - - /** Buffer used to accumulate binary fragments until a complete message - * is received. */ - TArray IncompleteMessage; - - /** Tracks if we are waiting for additional binary fragments. */ - bool bAwaitingBinaryFragments = false; - -}; + /** Handler for socket close */ + void HandleClosed(int32 StatusCode, const FString& Reason, bool bWasClean); + void ConnectWithProtocol(const FString& ServerUrl, ESpacetimeDBWsProtocol Protocol); + bool TryFallbackToV2(const FString& FailureReason); + void ResetSocket(); + + FString InitToken; + FString PendingServerUrl; + + /** Buffer used to accumulate binary fragments until a complete message + * is received. */ + TArray IncompleteMessage; + + /** Tracks if we are waiting for additional binary fragments. */ + bool bAwaitingBinaryFragments = false; + bool bHasEstablishedConnection = false; + bool bHasAttemptedV2Fallback = false; + uint32 ConnectAttemptId = 0; + ESpacetimeDBWsProtocol ActiveProtocol = ESpacetimeDBWsProtocol::V3; +}; // Helper function to log a struct as JSON, expanding any transient objects template diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/WebsocketV3Frames.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/WebsocketV3Frames.h new file mode 100644 index 00000000000..abb1cff2ed3 --- /dev/null +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/WebsocketV3Frames.h @@ -0,0 +1,290 @@ +#pragma once + +#include "CoreMinimal.h" +#include "BSATN/UESpacetimeDB.h" + +namespace UE::SpacetimeDB::V3 +{ + +// v3 is only a transport envelope. The inner payloads are already-encoded v2 +// websocket messages, so these helpers intentionally operate on raw bytes. +constexpr int32 MaxOutboundFrameBytes = 256 * 1024; + +enum class EClientFrameTag : uint8 +{ + Single = 0, + Batch = 1, +}; + +struct FClientFrame +{ + EClientFrameTag Tag = EClientFrameTag::Single; + TVariant, TArray>> FrameData; + + static FClientFrame Single(const TArray& Value) + { + FClientFrame Frame; + Frame.Tag = EClientFrameTag::Single; + Frame.FrameData.Set>(Value); + return Frame; + } + + static FClientFrame Batch(const TArray>& Value) + { + FClientFrame Frame; + Frame.Tag = EClientFrameTag::Batch; + Frame.FrameData.Set>>(Value); + return Frame; + } + + bool IsSingle() const + { + return Tag == EClientFrameTag::Single; + } + + bool IsBatch() const + { + return Tag == EClientFrameTag::Batch; + } + + TArray GetAsSingle() const + { + check(IsSingle()); + return FrameData.Get>(); + } + + TArray> GetAsBatch() const + { + check(IsBatch()); + return FrameData.Get>>(); + } + + bool operator==(const FClientFrame& Other) const + { + if (Tag != Other.Tag) + { + return false; + } + return IsSingle() ? GetAsSingle() == Other.GetAsSingle() : GetAsBatch() == Other.GetAsBatch(); + } + + bool operator!=(const FClientFrame& Other) const + { + return !(*this == Other); + } +}; + +enum class EServerFrameTag : uint8 +{ + Single = 0, + Batch = 1, +}; + +struct FServerFrame +{ + EServerFrameTag Tag = EServerFrameTag::Single; + TVariant, TArray>> FrameData; + + static FServerFrame Single(const TArray& Value) + { + FServerFrame Frame; + Frame.Tag = EServerFrameTag::Single; + Frame.FrameData.Set>(Value); + return Frame; + } + + static FServerFrame Batch(const TArray>& Value) + { + FServerFrame Frame; + Frame.Tag = EServerFrameTag::Batch; + Frame.FrameData.Set>>(Value); + return Frame; + } + + bool IsSingle() const + { + return Tag == EServerFrameTag::Single; + } + + bool IsBatch() const + { + return Tag == EServerFrameTag::Batch; + } + + TArray GetAsSingle() const + { + check(IsSingle()); + return FrameData.Get>(); + } + + TArray> GetAsBatch() const + { + check(IsBatch()); + return FrameData.Get>>(); + } + + bool operator==(const FServerFrame& Other) const + { + if (Tag != Other.Tag) + { + return false; + } + return IsSingle() ? GetAsSingle() == Other.GetAsSingle() : GetAsBatch() == Other.GetAsBatch(); + } + + bool operator!=(const FServerFrame& Other) const + { + return !(*this == Other); + } +}; + +constexpr int32 BsatnEnumTagBytes = 1; +constexpr int32 BsatnLengthPrefixBytes = 4; + +inline int32 EncodedSingleFrameSize(const TArray& Message) +{ + return BsatnEnumTagBytes + BsatnLengthPrefixBytes + Message.Num(); +} + +inline int32 EncodedBatchFrameSizeForFirstMessage(const TArray& Message) +{ + return BsatnEnumTagBytes + BsatnLengthPrefixBytes + BsatnLengthPrefixBytes + Message.Num(); +} + +inline int32 EncodedBatchElementSize(const TArray& Message) +{ + return BsatnLengthPrefixBytes + Message.Num(); +} + +// Compute the largest prefix of already-encoded v2 client messages that fits in +// one v3 transport frame without trial-serializing candidate batches. The +// queue already stores encoded payload bytes, so a length-based fit check is +// enough here. +inline int32 CountClientMessagesForFrame(const TArray>& Messages, int32 MaxFrameBytes) +{ + check(Messages.Num() > 0); + + const TArray& FirstMessage = Messages[0]; + if (EncodedSingleFrameSize(FirstMessage) > MaxFrameBytes) + { + return 1; + } + + int32 Count = 1; + int32 BatchSize = EncodedBatchFrameSizeForFirstMessage(FirstMessage); + while (Count < Messages.Num()) + { + const TArray& NextMessage = Messages[Count]; + const int32 NextBatchSize = BatchSize + EncodedBatchElementSize(NextMessage); + if (NextBatchSize > MaxFrameBytes) + { + break; + } + BatchSize = NextBatchSize; + ++Count; + } + + return Count; +} + +inline TArray EncodeClientMessages(const TArray>& Messages) +{ + check(Messages.Num() > 0); + return UE::SpacetimeDB::Serialize( + Messages.Num() == 1 ? FClientFrame::Single(Messages[0]) : FClientFrame::Batch(Messages) + ); +} + +inline TArray EncodeServerMessages(const TArray>& Messages) +{ + check(Messages.Num() > 0); + return UE::SpacetimeDB::Serialize( + Messages.Num() == 1 ? FServerFrame::Single(Messages[0]) : FServerFrame::Batch(Messages) + ); +} + +inline void DecodeServerMessages(const TArray& Data, TArray>& OutMessages) +{ + const FServerFrame Frame = UE::SpacetimeDB::Deserialize(Data); + if (Frame.IsSingle()) + { + OutMessages.Reset(1); + OutMessages.Add(Frame.GetAsSingle()); + return; + } + + OutMessages = Frame.GetAsBatch(); +} + +} // namespace UE::SpacetimeDB::V3 + +namespace UE::SpacetimeDB +{ + +inline void serialize(UEWriter& Writer, const V3::FClientFrame& Value) +{ + Writer.write_u8(static_cast(Value.Tag)); + switch (Value.Tag) + { + case V3::EClientFrameTag::Single: + serialize(Writer, Value.FrameData.Get>()); + break; + case V3::EClientFrameTag::Batch: + serialize(Writer, Value.FrameData.Get>>()); + break; + default: + ensureMsgf(false, TEXT("Unknown v3 client-frame tag")); + break; + } +} + +template<> +inline V3::FClientFrame deserialize(UEReader& Reader) +{ + const V3::EClientFrameTag Tag = static_cast(Reader.read_u8()); + switch (Tag) + { + case V3::EClientFrameTag::Single: + return V3::FClientFrame::Single(Reader.read_array_u8()); + case V3::EClientFrameTag::Batch: + return V3::FClientFrame::Batch(Reader.read_array>()); + default: + ensureMsgf(false, TEXT("Unknown v3 client-frame tag")); + return V3::FClientFrame(); + } +} + +inline void serialize(UEWriter& Writer, const V3::FServerFrame& Value) +{ + Writer.write_u8(static_cast(Value.Tag)); + switch (Value.Tag) + { + case V3::EServerFrameTag::Single: + serialize(Writer, Value.FrameData.Get>()); + break; + case V3::EServerFrameTag::Batch: + serialize(Writer, Value.FrameData.Get>>()); + break; + default: + ensureMsgf(false, TEXT("Unknown v3 server-frame tag")); + break; + } +} + +template<> +inline V3::FServerFrame deserialize(UEReader& Reader) +{ + const V3::EServerFrameTag Tag = static_cast(Reader.read_u8()); + switch (Tag) + { + case V3::EServerFrameTag::Single: + return V3::FServerFrame::Single(Reader.read_array_u8()); + case V3::EServerFrameTag::Batch: + return V3::FServerFrame::Batch(Reader.read_array>()); + default: + ensureMsgf(false, TEXT("Unknown v3 server-frame tag")); + return V3::FServerFrame(); + } +} + +} // namespace UE::SpacetimeDB From 311518ebda5bf048e3d835e3478c291f16ad3189 Mon Sep 17 00:00:00 2001 From: joshua-spacetime Date: Wed, 8 Apr 2026 18:48:32 -0700 Subject: [PATCH 3/3] Add Unreal v3 websocket integration tests --- .../Public/Connection/DbConnectionBase.h | 12 ++- .../Tests/SpacetimeFullClientTests.cpp | 77 +++++++++++++++++++ .../TestClient/Private/Tests/TestHandler.cpp | 27 +++++++ .../Public/Tests/CommonTestFunctions.h | 2 +- .../Public/Tests/SpacetimeFullClientTests.h | 5 ++ .../TestClient/Public/Tests/TestHandler.h | 12 +++ .../Public/Tests/CommonTestFunctions.h | 2 +- sdks/unreal/tests/test.rs | 13 +++- 8 files changed, 144 insertions(+), 6 deletions(-) diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/DbConnectionBase.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/DbConnectionBase.h index fdb3067541d..6bd0dc22f43 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/DbConnectionBase.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/DbConnectionBase.h @@ -128,9 +128,15 @@ class SPACETIMEDBSDK_API UDbConnectionBase : public UObject, public FTickableGam UFUNCTION(BlueprintCallable, Category="SpacetimeDB") void Disconnect(); - /** Check if the underlying WebSocket is connected. */ - UFUNCTION(BlueprintPure, Category="SpacetimeDB") - bool IsActive() const; + /** Check if the underlying WebSocket is connected. */ + UFUNCTION(BlueprintPure, Category="SpacetimeDB") + bool IsActive() const; + + /** Returns the websocket transport currently in use for this connection. */ + ESpacetimeDBWsProtocol GetActiveWebSocketProtocol() const + { + return WebSocket ? WebSocket->GetActiveProtocol() : ESpacetimeDBWsProtocol::V2; + } UFUNCTION(BlueprintCallable, Category="SpacetimeDB") void FrameTick(); diff --git a/sdks/unreal/tests/TestClient/Source/TestClient/Private/Tests/SpacetimeFullClientTests.cpp b/sdks/unreal/tests/TestClient/Source/TestClient/Private/Tests/SpacetimeFullClientTests.cpp index 4d201baeee3..89e560fe2cc 100644 --- a/sdks/unreal/tests/TestClient/Source/TestClient/Private/Tests/SpacetimeFullClientTests.cpp +++ b/sdks/unreal/tests/TestClient/Source/TestClient/Private/Tests/SpacetimeFullClientTests.cpp @@ -16,6 +16,7 @@ #include "Tests/PrimitiveHandlerList.def" #include "Connection/Credentials.h" +#include "Connection/Websocket.h" #include "ModuleBindings/Tables/ResultEveryPrimitiveStructStringTable.g.h" #include "ModuleBindings/Tables/ResultI32StringTable.g.h" #include "ModuleBindings/Tables/ResultIdentityStringTable.g.h" @@ -2428,3 +2429,79 @@ bool FInsertCallUuidV7Test::RunTest(const FString &Parameters) ADD_LATENT_AUTOMATION_COMMAND(FWaitForTestCounter(*this, TestName, Handler->Counter, FPlatformTime::Seconds())); return true; } + +bool FWebsocketV3ProtocolHappyPathTest::RunTest(const FString& Parameters) +{ + TestName = "WebsocketV3ProtocolHappyPath"; + + if (!ValidateParameterConfig(this)) + { + return false; + } + + TSharedPtr Counter = MakeShared(); + Counter->Register(TEXT("protocol_is_v3")); + + UDbConnection* Connection = ConnectThen(Counter, TestName, [Counter](UDbConnection* Conn) + { + if (Conn->GetActiveWebSocketProtocol() == ESpacetimeDBWsProtocol::V3) + { + Counter->MarkSuccess(TEXT("protocol_is_v3")); + } + else + { + Counter->MarkFailure(TEXT("protocol_is_v3"), TEXT("Expected connection to negotiate websocket v3")); + } + }); + + ADD_LATENT_AUTOMATION_COMMAND(FWaitForTestCounter(*this, TestName, Counter, FPlatformTime::Seconds())); + return true; +} + +bool FWebsocketV3InboundOrderingTest::RunTest(const FString& Parameters) +{ + TestName = "WebsocketV3InboundOrdering"; + + if (!ValidateParameterConfig(this)) + { + return false; + } + + UOrderedInsertHandler* Handler = CreateTestHandler(); + Handler->ExpectedValues = { 1, 2, 3, 4 }; + Handler->Counter->Register(TEXT("protocol_is_v3")); + Handler->Counter->Register(TEXT("OrderedU8Inserts")); + + UDbConnection* Connection = ConnectThen(Handler->Counter, TestName, [this, Handler](UDbConnection* Conn) + { + if (Conn->GetActiveWebSocketProtocol() != ESpacetimeDBWsProtocol::V3) + { + Handler->Counter->MarkFailure(TEXT("protocol_is_v3"), TEXT("Expected connection to negotiate websocket v3")); + Handler->Counter->Abort(); + return; + } + Handler->Counter->MarkSuccess(TEXT("protocol_is_v3")); + + Conn->Db->OneU8->OnInsert.AddDynamic(Handler, &UOrderedInsertHandler::OnInsertOneU8); + + SubscribeAllThen(Conn, [this, Handler](FSubscriptionEventContext Ctx) + { + if (Ctx.Db->OneU8->Count() != 0) + { + Handler->Counter->MarkFailure(TEXT("OrderedU8Inserts"), TEXT("Expected OneU8 to be empty before ordered insert test")); + Handler->Counter->Abort(); + return; + } + + // Issue a same-turn burst so the client receives ordered inbound work + // while the v3 transport is active. + Ctx.Reducers->InsertOneU8(1); + Ctx.Reducers->InsertOneU8(2); + Ctx.Reducers->InsertOneU8(3); + Ctx.Reducers->InsertOneU8(4); + }); + }); + + ADD_LATENT_AUTOMATION_COMMAND(FWaitForTestCounter(*this, TestName, Handler->Counter, FPlatformTime::Seconds())); + return true; +} diff --git a/sdks/unreal/tests/TestClient/Source/TestClient/Private/Tests/TestHandler.cpp b/sdks/unreal/tests/TestClient/Source/TestClient/Private/Tests/TestHandler.cpp index a44863b5003..2f20e1fa808 100644 --- a/sdks/unreal/tests/TestClient/Source/TestClient/Private/Tests/TestHandler.cpp +++ b/sdks/unreal/tests/TestClient/Source/TestClient/Private/Tests/TestHandler.cpp @@ -16,6 +16,33 @@ void UInsertPrimitiveHandler::OnInsertOne##Suffix(const FEventContext& Context, FOREACH_PRIMITIVE(DEFINE_UFUNC) #undef DEFINE_UFUNC +void UOrderedInsertHandler::OnInsertOneU8(const FEventContext& Context, const FOneU8Type& Value) +{ + static const FString Name(TEXT("OrderedU8Inserts")); + + const int32 NextIndex = ReceivedValues.Num(); + if (NextIndex >= ExpectedValues.Num()) + { + Counter->MarkFailure(Name, TEXT("Received more inserts than expected")); + return; + } + + ReceivedValues.Add(Value.N); + if (Value.N != ExpectedValues[NextIndex]) + { + Counter->MarkFailure( + Name, + FString::Printf(TEXT("Out-of-order insert at index %d: expected %d, got %d"), NextIndex, ExpectedValues[NextIndex], Value.N) + ); + return; + } + + if (ReceivedValues.Num() == ExpectedValues.Num()) + { + Counter->MarkSuccess(Name); + } +} + /* DeletePrimitive handler functions ------------------------------------ */ #define DEFINE_DELETE_UNIQUE(Suffix, Field, Literal, Expected, RowStructType) \ void UDeletePrimitiveHandler::OnInsertUnique##Suffix(const FEventContext& Context, const RowStructType& Value) \ diff --git a/sdks/unreal/tests/TestClient/Source/TestClient/Public/Tests/CommonTestFunctions.h b/sdks/unreal/tests/TestClient/Source/TestClient/Public/Tests/CommonTestFunctions.h index 29bc21270b2..bde08e6ca4a 100644 --- a/sdks/unreal/tests/TestClient/Source/TestClient/Public/Tests/CommonTestFunctions.h +++ b/sdks/unreal/tests/TestClient/Source/TestClient/Public/Tests/CommonTestFunctions.h @@ -154,4 +154,4 @@ T* CreateTestHandler() Handler->AddToRoot(); Handler->Counter = MakeShared(); return Handler; -} \ No newline at end of file +} diff --git a/sdks/unreal/tests/TestClient/Source/TestClient/Public/Tests/SpacetimeFullClientTests.h b/sdks/unreal/tests/TestClient/Source/TestClient/Public/Tests/SpacetimeFullClientTests.h index 02b26831c9f..8c32f01f0a9 100644 --- a/sdks/unreal/tests/TestClient/Source/TestClient/Public/Tests/SpacetimeFullClientTests.h +++ b/sdks/unreal/tests/TestClient/Source/TestClient/Public/Tests/SpacetimeFullClientTests.h @@ -132,3 +132,8 @@ IMPLEMENT_SIMPLE_AUTOMATION_TEST(FOverlappingSubscriptionsTest, "SpacetimeDB.Tes IMPLEMENT_SIMPLE_AUTOMATION_TEST(FInsertCallUuidV4Test, "SpacetimeDB.TestClient.InsertCallUuidV4Test", EAutomationTestFlags::EditorContext | EAutomationTestFlags::EngineFilter) IMPLEMENT_SIMPLE_AUTOMATION_TEST(FInsertCallUuidV7Test, "SpacetimeDB.TestClient.InsertCallUuidV7Test", EAutomationTestFlags::EditorContext | EAutomationTestFlags::EngineFilter) + +/** Tests that the Unreal client prefers the v3 websocket transport when available. */ +IMPLEMENT_SIMPLE_AUTOMATION_TEST(FWebsocketV3ProtocolHappyPathTest, "SpacetimeDB.TestClient.WebsocketV3ProtocolHappyPathTest", EAutomationTestFlags::EditorContext | EAutomationTestFlags::EngineFilter) +/** Tests that inbound messages remain ordered when the client is running over the v3 websocket transport. */ +IMPLEMENT_SIMPLE_AUTOMATION_TEST(FWebsocketV3InboundOrderingTest, "SpacetimeDB.TestClient.WebsocketV3InboundOrderingTest", EAutomationTestFlags::EditorContext | EAutomationTestFlags::EngineFilter) diff --git a/sdks/unreal/tests/TestClient/Source/TestClient/Public/Tests/TestHandler.h b/sdks/unreal/tests/TestClient/Source/TestClient/Public/Tests/TestHandler.h index 45b80ed7dac..7975d5992da 100644 --- a/sdks/unreal/tests/TestClient/Source/TestClient/Public/Tests/TestHandler.h +++ b/sdks/unreal/tests/TestClient/Source/TestClient/Public/Tests/TestHandler.h @@ -69,6 +69,18 @@ class UInsertPrimitiveHandler : public UTestHandler TArray ExpectedStrings; }; +UCLASS() +class UOrderedInsertHandler : public UTestHandler +{ + GENERATED_BODY() +public: + TArray ExpectedValues; + TArray ReceivedValues; + + UFUNCTION() + void OnInsertOneU8(const FEventContext& Context, const FOneU8Type& Value); +}; + /** Handler used for delete-primitive tests. */ UCLASS() class UDeletePrimitiveHandler : public UTestHandler diff --git a/sdks/unreal/tests/TestProcClient/Source/TestProcClient/Public/Tests/CommonTestFunctions.h b/sdks/unreal/tests/TestProcClient/Source/TestProcClient/Public/Tests/CommonTestFunctions.h index c4ebc251608..ad67f1ddbe6 100644 --- a/sdks/unreal/tests/TestProcClient/Source/TestProcClient/Public/Tests/CommonTestFunctions.h +++ b/sdks/unreal/tests/TestProcClient/Source/TestProcClient/Public/Tests/CommonTestFunctions.h @@ -154,4 +154,4 @@ T* CreateTestHandler() Handler->AddToRoot(); Handler->Counter = MakeShared(); return Handler; -} \ No newline at end of file +} diff --git a/sdks/unreal/tests/test.rs b/sdks/unreal/tests/test.rs index 2700f10ec14..34d2563cc74 100644 --- a/sdks/unreal/tests/test.rs +++ b/sdks/unreal/tests/test.rs @@ -2,7 +2,6 @@ mod sdk_unreal_harness; use sdk_unreal_harness::{make_test_with_suite, TestSuite}; use serial_test::serial; -use std::env; const SDK_TEST_SUITE: TestSuite = TestSuite { module: "sdk-test", @@ -309,3 +308,15 @@ fn unreal_overlapping_subscriptions() { fn unreal_insert_result_okay() { make_test("InsertResultOkTest").run(); } + +#[test] +#[serial(Group6)] +fn unreal_websocket_v3_protocol_happy_path() { + make_test("WebsocketV3ProtocolHappyPathTest").run(); +} + +#[test] +#[serial(Group6)] +fn unreal_websocket_v3_inbound_ordering() { + make_test("WebsocketV3InboundOrderingTest").run(); +}