diff --git a/crates/cli/src/subcommands/subscribe.rs b/crates/cli/src/subcommands/subscribe.rs index 5b22b4cd8a5..f8cbc61e0cf 100644 --- a/crates/cli/src/subcommands/subscribe.rs +++ b/crates/cli/src/subcommands/subscribe.rs @@ -1,20 +1,28 @@ use anyhow::Context; +use bytes::Bytes; use clap::{value_parser, Arg, ArgAction, ArgMatches}; use futures::{Sink, SinkExt, TryStream, TryStreamExt}; use http::header; use reqwest::Url; use serde_json::Value; -use spacetimedb_client_api_messages::websocket::v1 as ws_v1; +use spacetimedb_client_api_messages::websocket::{common as ws_common, v1 as ws_v1, v2 as ws_v2, v3 as ws_v3}; use spacetimedb_data_structures::map::HashMap; use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9; use spacetimedb_lib::de::serde::{DeserializeWrapper, SeedWrapper}; +use spacetimedb_lib::de::DeserializeSeed as BsatnDeserializeSeed; +use spacetimedb_lib::sats::WithTypespace; use spacetimedb_lib::ser::serde::SerializeWrapper; +use spacetimedb_lib::{bsatn, AlgebraicType}; +use std::collections::VecDeque; use std::io; use std::time::Duration; use thiserror::Error; use tokio::io::AsyncWriteExt; +use tokio::net::TcpStream; use tokio_tungstenite::tungstenite::client::IntoClientRequest; +use tokio_tungstenite::tungstenite::handshake::client::Request as WsRequest; use tokio_tungstenite::tungstenite::{Error as WsError, Message as WsMessage}; +use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; use crate::api::ClientApi; use crate::common_args; @@ -72,56 +80,68 @@ pub fn cli() -> clap::Command { .arg(common_args::server().help("The nickname, host name or URL of the server hosting the database")) } -fn parse_msg_json(msg: &WsMessage) -> Option> { - let WsMessage::Text(msg) = msg else { return None }; - serde_json::from_str::>>(msg) - .inspect_err(|e| eprintln!("couldn't parse message from server: {e}")) - .map(|wrapper| wrapper.0) - .ok() +#[derive(serde::Serialize, Debug)] +struct SubscriptionTable { + deletes: Vec, + inserts: Vec, } -fn 
reformat_update<'a>( - msg: &'a ws_v1::DatabaseUpdate, - schema: &RawModuleDefV9, -) -> anyhow::Result> { - msg.tables - .iter() - .map(|upd| { - let table_ty = schema.typespace.resolve( - schema - .type_ref_for_table_like(&upd.table_name) - .context("table not found in schema")?, - ); +/// Concrete websocket stream type returned by `tokio_tungstenite::connect_async`. +type SubscribeWebSocket = WebSocketStream>; - let reformat_row = |row: &str| -> anyhow::Result { - // TODO: can the following two calls be merged into a single call to reduce allocations? - let row = serde_json::from_str::(row)?; - let row = serde::de::DeserializeSeed::deserialize(SeedWrapper(table_ty), row)?; - let row = table_ty.with_value(&row); - let row = serde_json::to_value(SerializeWrapper::from_ref(&row))?; - Ok(row) - }; +/// Active websocket connection for `spacetime subscribe`. +/// +/// The command prefers the v3 transport so smoketests and normal CLI usage +/// exercise the coalesced server path, but it keeps the old v1 text transport +/// as a fallback for older servers. +enum SubscribeConnection { + /// v3 uses BSATN-encoded v2 messages, possibly coalesced in one websocket payload. + V3 { + ws: SubscribeWebSocket, + /// Decoded messages left over from a coalesced v3 websocket payload. + pending: VecDeque, + }, + /// v1 is the historical JSON text protocol. + V1 { ws: SubscribeWebSocket }, +} - let mut deletes = Vec::new(); - let mut inserts = Vec::new(); - for upd in &upd.updates { - for s in &upd.deletes { - deletes.push(reformat_row(s)?); - } - for s in &upd.inserts { - inserts.push(reformat_row(s)?); - } - } +impl SubscribeConnection { + /// Send the subscribe request using whichever protocol was negotiated. + async fn subscribe(&mut self, query_strings: Box<[Box]>) -> Result<(), Error> { + match self { + Self::V3 { ws, .. 
} => subscribe_v3(ws, query_strings).await, + Self::V1 { ws } => subscribe_v1(ws, query_strings).await, + } + } - Ok((&*upd.table_name, SubscriptionTable { deletes, inserts })) - }) - .collect() -} + /// Wait for the initial subscription result and optionally print it. + async fn await_initial_update(&mut self, module_def: Option<&RawModuleDefV9>) -> Result<(), Error> { + match self { + Self::V3 { ws, pending } => await_initial_update_v3(ws, pending, module_def).await, + Self::V1 { ws } => await_initial_update_v1(ws, module_def).await, + } + } -#[derive(serde::Serialize, Debug)] -struct SubscriptionTable { - deletes: Vec, - inserts: Vec, + /// Print transaction updates until the requested count is reached. + async fn consume_transaction_updates( + &mut self, + num: Option, + module_def: &RawModuleDefV9, + ) -> Result<(), Error> { + match self { + Self::V3 { ws, pending } => consume_transaction_updates_v3(ws, pending, num, module_def).await, + Self::V1 { ws } => consume_transaction_updates_v1(ws, num, module_def).await, + } + } + + /// Best-effort graceful websocket close. + async fn close(&mut self) { + match self { + Self::V3 { ws, .. } | Self::V1 { ws } => { + let _ = ws.close(None).await; + } + } + } } pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error> { @@ -160,36 +180,14 @@ pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error }; let api = ClientApi::new(conn); let module_def = api.module_def().await?; - - let mut url = Url::parse(&api.con.db_uri("subscribe"))?; - // Change the URI scheme from `http(s)` to `ws(s)`. - url.set_scheme(match url.scheme() { - "http" => "ws", - "https" => "wss", - unknown => unreachable!("Invalid URL scheme in `Connection::db_uri`: {unknown}"), - }) - .unwrap(); - if let Some(confirmed) = confirmed { - url.query_pairs_mut() - .append_pair("confirmed", if confirmed { "true" } else { "false" }); - } - - // Create the websocket request. 
- let mut req = url.into_client_request()?; - req.headers_mut().insert( - header::SEC_WEBSOCKET_PROTOCOL, - http::HeaderValue::from_static(ws_v1::TEXT_PROTOCOL), - ); - // Add the authorization header, if any. - if let Some(auth_header) = api.con.auth_header.to_header() { - req.headers_mut().insert(header::AUTHORIZATION, auth_header); - } - let mut ws = tokio_tungstenite::connect_async(req).await.map(|(ws, _)| ws)?; + let mut conn = connect_with_fallback(&api, confirmed).await?; let task = async { - subscribe(&mut ws, queries.iter().cloned().map(Into::into).collect()).await?; - await_initial_update(&mut ws, print_initial_update.then_some(&module_def)).await?; - consume_transaction_updates(&mut ws, num, &module_def).await + conn.subscribe(queries.iter().cloned().map(Into::into).collect()) + .await?; + conn.await_initial_update(print_initial_update.then_some(&module_def)) + .await?; + conn.consume_transaction_updates(num, &module_def).await }; let res = if let Some(timeout) = timeout { @@ -211,7 +209,7 @@ pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error // The error (if any) relevant to the user is already stored in `res`, // so we can ignore errors here -- graceful close is basically a // courtesy to the server. - let _ = ws.close(None).await; + conn.close().await; // The server closing the connection is not considered an error, // but any other error is. res.or_else(|e| { @@ -224,6 +222,90 @@ pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error .map_err(anyhow::Error::from) } +/// Connect using v3 when available, otherwise retry once with the v1 text protocol. +/// +/// Fallback is intentionally limited to connection setup and protocol +/// negotiation. After a v3 connection is accepted, malformed v3 data is a real +/// error and should not be hidden by silently reconnecting with v1. 
+async fn connect_with_fallback(api: &ClientApi, confirmed: Option) -> Result { + match connect_v3(api, confirmed).await { + Ok(conn) => Ok(conn), + Err(v3_error) => connect_v1(api, confirmed) + .await + .with_context(|| format!("v3 subscribe connection failed ({v3_error}); v1 fallback also failed")), + } +} + +/// Open a v3 subscribe websocket and validate that the server negotiated v3. +async fn connect_v3(api: &ClientApi, confirmed: Option) -> Result { + let req = subscribe_request(api, confirmed, ws_v3::BIN_PROTOCOL, true)?; + let (ws, response) = tokio_tungstenite::connect_async(req).await?; + if response + .headers() + .get(header::SEC_WEBSOCKET_PROTOCOL) + .and_then(|value| value.to_str().ok()) + != Some(ws_v3::BIN_PROTOCOL) + { + return Err(Error::Protocol { + details: "server did not negotiate the v3 websocket protocol", + } + .into()); + } + Ok(SubscribeConnection::V3 { + ws, + pending: VecDeque::new(), + }) +} + +/// Open a v1 text subscribe websocket for compatibility with older servers. +async fn connect_v1(api: &ClientApi, confirmed: Option) -> Result { + let req = subscribe_request(api, confirmed, ws_v1::TEXT_PROTOCOL, false)?; + let (ws, _) = tokio_tungstenite::connect_async(req).await?; + Ok(SubscribeConnection::V1 { ws }) +} + +/// Build a subscribe websocket request for a specific subprotocol. +/// +/// `request_uncompressed` is used only for the CLI v3 path. The CLI decodes +/// enough v3 to print JSON output, but does not implement brotli or gzip +/// decoding, so v3 asks the server for uncompressed payloads. The v1 fallback +/// leaves the query string in its historical shape. +fn subscribe_request( + api: &ClientApi, + confirmed: Option, + protocol: &'static str, + request_uncompressed: bool, +) -> Result { + let mut url = Url::parse(&api.con.db_uri("subscribe"))?; + // Change the URI scheme from `http(s)` to `ws(s)`. 
+ url.set_scheme(match url.scheme() { + "http" => "ws", + "https" => "wss", + unknown => unreachable!("Invalid URL scheme in `Connection::db_uri`: {unknown}"), + }) + .unwrap(); + { + let mut query = url.query_pairs_mut(); + if request_uncompressed { + // The CLI v3 path only needs enough support to print updates as + // JSON, so request uncompressed payloads and avoid brotli/gzip + // decoding here. The v1 fallback preserves the old URL shape. + query.append_pair("compression", "None"); + } + if let Some(confirmed) = confirmed { + query.append_pair("confirmed", if confirmed { "true" } else { "false" }); + } + } + + let mut req = url.into_client_request()?; + req.headers_mut() + .insert(header::SEC_WEBSOCKET_PROTOCOL, http::HeaderValue::from_static(protocol)); + if let Some(auth_header) = api.con.auth_header.to_header() { + req.headers_mut().insert(header::AUTHORIZATION, auth_header); + } + Ok(req) +} + #[derive(Debug, Error)] enum Error { #[error("error sending subscription queries")] @@ -247,6 +329,16 @@ enum Error { #[source] source: anyhow::Error, }, + #[error("error encoding BSATN websocket message: {source}")] + BsatnEncode { + #[source] + source: spacetimedb_lib::bsatn::EncodeError, + }, + #[error("error decoding BSATN websocket message: {source}")] + BsatnDecode { + #[source] + source: spacetimedb_lib::bsatn::DecodeError, + }, #[error(transparent)] Serde(#[from] serde_json::Error), #[error(transparent)] @@ -264,8 +356,8 @@ impl Error { } } -/// Send the subscribe message. -async fn subscribe(ws: &mut S, query_strings: Box<[Box]>) -> Result<(), Error> +/// Send a v1 JSON subscribe message. +async fn subscribe_v1(ws: &mut S, query_strings: Box<[Box]>) -> Result<(), Error> where S: Sink + Unpin, { @@ -279,9 +371,34 @@ where ws.send(msg.into()).await.map_err(|source| Error::Subscribe { source }) } -/// Await the initial [`ServerMessage::SubscriptionUpdate`]. +/// Send a v3 BSATN subscribe message. 
+async fn subscribe_v3(ws: &mut S, query_strings: Box<[Box]>) -> Result<(), Error> +where + S: Sink + Unpin, +{ + let msg = ws_v2::ClientMessage::Subscribe(ws_v2::Subscribe { + request_id: 0, + query_set_id: ws_v2::QuerySetId::new(0), + query_strings, + }); + let msg = bsatn::to_vec(&msg).map_err(|source| Error::BsatnEncode { source })?; + ws.send(WsMessage::Binary(msg.into())) + .await + .map_err(|source| Error::Subscribe { source }) +} + +/// Parse a v1 text websocket message as JSON. +fn parse_msg_json(msg: &WsMessage) -> Option> { + let WsMessage::Text(msg) = msg else { return None }; + serde_json::from_str::>>(msg) + .inspect_err(|e| eprintln!("couldn't parse message from server: {e}")) + .map(|wrapper| wrapper.0) + .ok() +} + +/// Await the initial v1 [`ws_v1::ServerMessage::InitialSubscription`]. /// If `module_def` is `Some`, print a JSON representation to stdout. -async fn await_initial_update(ws: &mut S, module_def: Option<&RawModuleDefV9>) -> Result<(), Error> +async fn await_initial_update_v1(ws: &mut S, module_def: Option<&RawModuleDefV9>) -> Result<(), Error> where S: TryStream + Unpin, { @@ -292,7 +409,7 @@ where match msg { ws_v1::ServerMessage::InitialSubscription(sub) => { if let Some(module_def) = module_def { - let output = format_output_json(&sub.database_update, module_def)?; + let output = format_output_json_v1(&sub.database_update, module_def)?; tokio::io::stdout().write_all(output.as_bytes()).await? } break; @@ -320,9 +437,49 @@ where Ok(()) } -/// Print `num` [`ServerMessage::TransactionUpdate`] messages as JSON. +/// Await the initial [`ws_v2::ServerMessage::SubscribeApplied`]. +/// If `module_def` is `Some`, print a JSON representation to stdout. 
+async fn await_initial_update_v3( + ws: &mut S, + pending: &mut VecDeque, + module_def: Option<&RawModuleDefV9>, +) -> Result<(), Error> +where + S: TryStream + Unpin, +{ + const RECV_TX_UPDATE: &str = "received transaction update before initial subscription update"; + + while let Some(msg) = next_server_message(ws, pending).await? { + match msg { + ws_v2::ServerMessage::SubscribeApplied(sub) => { + if let Some(module_def) = module_def { + let output = format_output_json_query_rows(&sub.rows, module_def)?; + tokio::io::stdout().write_all(output.as_bytes()).await? + } + break; + } + ws_v2::ServerMessage::SubscriptionError(error) => { + return Err(Error::SubscribeFailure { reason: error.error }); + } + ws_v2::ServerMessage::TransactionUpdate(_) => { + return Err(Error::Protocol { + details: RECV_TX_UPDATE, + }) + } + _ => continue, + } + } + + Ok(()) +} + +/// Print `num` v1 [`ws_v1::ServerMessage::TransactionUpdate`] messages as JSON. /// If `num` is `None`, keep going indefinitely. -async fn consume_transaction_updates(ws: &mut S, num: Option, module_def: &RawModuleDefV9) -> Result<(), Error> +async fn consume_transaction_updates_v1( + ws: &mut S, + num: Option, + module_def: &RawModuleDefV9, +) -> Result<(), Error> where S: TryStream + Unpin, { @@ -354,7 +511,50 @@ where status: ws_v1::UpdateStatus::Committed(update), .. }) => { - let output = format_output_json(&update, module_def)?; + let output = format_output_json_v1(&update, module_def)?; + stdout.write_all(output.as_bytes()).await?; + num_received += 1; + } + _ => continue, + } + } +} + +/// Print `num` [`ws_v2::ServerMessage::TransactionUpdate`] messages as JSON. +/// If `num` is `None`, keep going indefinitely. 
+async fn consume_transaction_updates_v3( + ws: &mut S, + pending: &mut VecDeque, + num: Option, + module_def: &RawModuleDefV9, +) -> Result<(), Error> +where + S: TryStream + Unpin, +{ + let mut stdout = tokio::io::stdout(); + let mut num_received = 0; + loop { + if num.is_some_and(|n| num_received >= n) { + return Ok(()); + } + let Some(msg) = next_server_message(ws, pending).await? else { + eprintln!("disconnected by server"); + return Err(Error::Websocket { + source: WsError::ConnectionClosed, + }); + }; + + match msg { + ws_v2::ServerMessage::SubscribeApplied(_) => { + return Err(Error::Protocol { + details: "received a second initial subscription update", + }) + } + ws_v2::ServerMessage::SubscriptionError(error) => { + return Err(Error::SubscribeFailure { reason: error.error }); + } + ws_v2::ServerMessage::TransactionUpdate(update) => { + let output = format_output_json_transaction_update(&update, module_def)?; stdout.write_all(output.as_bytes()).await?; num_received += 1; } @@ -363,12 +563,207 @@ where } } -fn format_output_json( +/// Return the next decoded server message from a v3 websocket stream. +/// +/// A v3 websocket payload can contain multiple consecutive BSATN-encoded v2 +/// server messages, so decoded surplus messages are queued for the next call. +/// Non-binary messages are ignored because v3 server data is binary-only. +async fn next_server_message( + ws: &mut S, + pending: &mut VecDeque, +) -> Result, Error> +where + S: TryStream + Unpin, +{ + loop { + if let Some(msg) = pending.pop_front() { + return Ok(Some(msg)); + } + + let Some(msg) = ws.try_next().await.map_err(|source| Error::Websocket { source })? else { + return Ok(None); + }; + let WsMessage::Binary(msg) = msg else { continue }; + decode_server_payload(msg, pending)?; + } +} + +/// Decode one uncompressed v3 websocket payload into queued v2 server messages. +/// +/// The server prefixes each binary payload with a compression tag. 
This CLI path +/// requests `compression=None`, so any compressed tag is treated as a protocol +/// error rather than decoded here. +fn decode_server_payload(msg: Bytes, pending: &mut VecDeque) -> Result<(), Error> { + let Some((&tag, mut remaining)) = msg.as_ref().split_first() else { + return Err(Error::Protocol { + details: "received empty v3 websocket payload", + }); + }; + if tag != ws_common::SERVER_MSG_COMPRESSION_TAG_NONE { + return Err(Error::Protocol { + details: "compressed v3 subscribe payload is not supported by this CLI path", + }); + } + if remaining.is_empty() { + return Err(Error::Protocol { + details: "received v3 websocket payload without a server message", + }); + } + + while !remaining.is_empty() { + let msg = bsatn::from_reader(&mut remaining).map_err(|source| Error::BsatnDecode { source })?; + pending.push_back(msg); + } + + Ok(()) +} + +/// Format a v1 database update using the legacy JSON row representation. +fn format_output_json_v1( msg: &ws_v1::DatabaseUpdate, schema: &RawModuleDefV9, ) -> Result { - let formatted = reformat_update(msg, schema).map_err(|source| Error::Reformat { source })?; - let output = serde_json::to_string(&formatted)? + "\n"; + let formatted = reformat_update_v1(msg, schema).map_err(|source| Error::Reformat { source })?; + format_output_json_from_tables(&formatted) +} + +/// Format initial v3 subscription rows using the CLI's existing JSON output shape. +fn format_output_json_query_rows(msg: &ws_v2::QueryRows, schema: &RawModuleDefV9) -> Result { + let formatted = reformat_query_rows(msg, schema).map_err(|source| Error::Reformat { source })?; + format_output_json_from_tables(&formatted) +} +/// Format a v3 transaction update using the CLI's existing JSON output shape. 
+fn format_output_json_transaction_update( + msg: &ws_v2::TransactionUpdate, + schema: &RawModuleDefV9, +) -> Result { + let formatted = reformat_transaction_update(msg, schema).map_err(|source| Error::Reformat { source })?; + format_output_json_from_tables(&formatted) +} + +/// Serialize the normalized table update map as one JSON object per output line. +fn format_output_json_from_tables(formatted: &HashMap<&str, SubscriptionTable>) -> Result { + let output = serde_json::to_string(formatted)? + "\n"; Ok(output) } + +/// Convert a v1 JSON-format database update to the normalized table output map. +fn reformat_update_v1<'a>( + msg: &'a ws_v1::DatabaseUpdate, + schema: &RawModuleDefV9, +) -> anyhow::Result> { + msg.tables + .iter() + .map(|upd| { + let table_ty = schema.typespace.resolve( + schema + .type_ref_for_table_like(&upd.table_name) + .context("table not found in schema")?, + ); + + let reformat_row = |row: &str| -> anyhow::Result { + // TODO: can the following two calls be merged into a single call to reduce allocations? + let row = serde_json::from_str::(row)?; + let row = serde::de::DeserializeSeed::deserialize(SeedWrapper(table_ty), row)?; + let row = table_ty.with_value(&row); + let row = serde_json::to_value(SerializeWrapper::from_ref(&row))?; + Ok(row) + }; + + let mut deletes = Vec::new(); + let mut inserts = Vec::new(); + for upd in &upd.updates { + for s in &upd.deletes { + deletes.push(reformat_row(s)?); + } + for s in &upd.inserts { + inserts.push(reformat_row(s)?); + } + } + + Ok((&*upd.table_name, SubscriptionTable { deletes, inserts })) + }) + .collect() +} + +/// Convert v3 initial subscription rows to the normalized table output map. 
+fn reformat_query_rows<'a>( + msg: &'a ws_v2::QueryRows, + schema: &RawModuleDefV9, +) -> anyhow::Result> { + let mut formatted = HashMap::default(); + + for table in &msg.tables { + let table_ty = schema.typespace.resolve( + schema + .type_ref_for_table_like(&table.table) + .context("table not found in schema")?, + ); + let table_output = formatted.entry(&*table.table).or_insert_with(|| SubscriptionTable { + deletes: Vec::new(), + inserts: Vec::new(), + }); + table_output.inserts.extend(reformat_bsatn_rows(&table.rows, table_ty)?); + } + + Ok(formatted) +} + +/// Convert a v3 transaction update to the normalized table output map. +fn reformat_transaction_update<'a>( + msg: &'a ws_v2::TransactionUpdate, + schema: &RawModuleDefV9, +) -> anyhow::Result> { + let mut formatted = HashMap::default(); + + for query_set in &msg.query_sets { + for table in &query_set.tables { + let table_ty = schema.typespace.resolve( + schema + .type_ref_for_table_like(&table.table_name) + .context("table not found in schema")?, + ); + let table_output = formatted + .entry(&*table.table_name) + .or_insert_with(|| SubscriptionTable { + deletes: Vec::new(), + inserts: Vec::new(), + }); + for rows in &table.rows { + match rows { + ws_v2::TableUpdateRows::PersistentTable(rows) => { + table_output + .deletes + .extend(reformat_bsatn_rows(&rows.deletes, table_ty)?); + table_output + .inserts + .extend(reformat_bsatn_rows(&rows.inserts, table_ty)?); + } + ws_v2::TableUpdateRows::EventTable(rows) => { + table_output + .inserts + .extend(reformat_bsatn_rows(&rows.events, table_ty)?); + } + } + } + } + } + + Ok(formatted) +} + +/// Decode BSATN row-list entries and re-encode them as schema-aware JSON values. 
+fn reformat_bsatn_rows( + rows: &ws_common::BsatnRowList, + table_ty: WithTypespace<'_, AlgebraicType>, +) -> anyhow::Result> { + rows.into_iter() + .map(|row| { + let mut row = row.as_ref(); + let row = BsatnDeserializeSeed::deserialize(table_ty, bsatn::Deserializer::new(&mut row))?; + let row = table_ty.with_value(&row); + Ok(serde_json::to_value(SerializeWrapper::from_ref(&row))?) + }) + .collect() +} diff --git a/crates/client-api/src/routes/subscribe.rs b/crates/client-api/src/routes/subscribe.rs index ce7c127be98..f29527bee42 100644 --- a/crates/client-api/src/routes/subscribe.rs +++ b/crates/client-api/src/routes/subscribe.rs @@ -23,7 +23,7 @@ use prometheus::{Histogram, IntGauge}; use scopeguard::{defer, ScopeGuard}; use serde::Deserialize; use spacetimedb::client::messages::{ - serialize, serialize_v2, IdentityTokenMessage, InUseSerializeBuffer, SerializeBuffer, SwitchedServerMessage, + serialize, serialize_v3, IdentityTokenMessage, InUseSerializeBuffer, SerializeBuffer, SwitchedServerMessage, ToProtocol, }; use spacetimedb::client::{ @@ -39,7 +39,7 @@ use spacetimedb::Identity; use spacetimedb_client_api_messages::websocket::v1 as ws_v1; use spacetimedb_client_api_messages::websocket::v2 as ws_v2; use spacetimedb_client_api_messages::websocket::v3 as ws_v3; -use spacetimedb_datastore::execution_context::WorkloadType; +use spacetimedb_lib::bsatn; use spacetimedb_lib::connection_id::{ConnectionId, ConnectionIdForUrl}; use tokio::sync::{mpsc, watch}; use tokio::task::JoinHandle; @@ -1290,13 +1290,143 @@ enum OutboundWsMessage { Message(OutboundMessage), } -/// Task that reads [`OutboundWsMessage`]s from `messages`, encodes them via -/// [`ws_encode_message`], and sends the resuling [`Frame`]s to `outgoing_frames`. +/// Controls how many binary protocol messages may be packed into a single +/// websocket payload. +/// +/// Protocol v2 requires one [`ws_v2::ServerMessage`] per websocket message. 
+/// Protocol v3 keeps the v2 message schema but permits multiple consecutive +/// v2 messages in a single websocket message. +#[derive(Clone, Copy, PartialEq, Eq)] +enum BinaryPayloadMode { + /// Flush after each binary server message. + Single, + /// Flush once after all available binary server messages are collected. + Coalesced, +} + +/// A binary websocket message plus the logical row count it contributes to +/// payload-level send metrics. +struct V2OutboundMessage { + message: ws_v2::ServerMessage, + num_rows: Option, +} + +/// Convert an outbound message into the binary websocket schema. +/// +/// v2 connections should only receive v2 server messages. +/// v1 messages are dropped. +/// +/// TODO: For better type safety, [`ClientConnectionReceiver`] should be made +/// generic over the protocol version. +fn v2_outbound_message(message: OutboundWsMessage) -> Option { + let message = match message { + OutboundWsMessage::Error(message) => { + log::error!("dropping v1 error message on v2 connection: {:?}", message); + return None; + } + OutboundWsMessage::Message(message) => message, + }; + + let num_rows = message.num_rows(); + match message { + OutboundMessage::V2(message) => Some(V2OutboundMessage { message, num_rows }), + OutboundMessage::V1(message) => { + log::error!("dropping v1 message on v2 connection: {:?}", message); + None + } + } +} + +/// Return the uncompressed payload size of `message`. +/// +/// v2 sends exactly one BSATN-encoded v2 server message per websocket payload. +/// v3 sends one or more of the same encoded messages in a coalesced payload. +fn message_size(message: &ws_v2::ServerMessage) -> usize { + bsatn::to_len(message).expect("should be able to measure bsatn-encoded v2 server message") +} + +/// Return whether appending the next message would cross the v3 coalescing cap. +/// +/// An empty payload is always allowed to accept one message, even when that +/// message alone is larger than the cap. 
+fn v3_payload_would_exceed_limit(total_bytes: usize, message_bytes: usize) -> bool { + total_bytes != 0 && total_bytes.saturating_add(message_bytes) > V3_MAX_UNCOMPRESSED_PAYLOAD_SIZE +} + +/// Return whether a binary websocket payload is large enough to encode on Rayon. +fn is_large_payload(num_bytes: usize) -> bool { + num_bytes >= V3_MAX_UNCOMPRESSED_PAYLOAD_SIZE +} + +/// Encoding receive batch size. +/// +/// This is deliberately tied to the client connection receive limit so the +/// websocket encoder can consume the batches produced by +/// [`ClientConnectionReceiver::recv_many`] without immediately re-batching +/// them to a different size. +const ENCODE_BATCH_SIZE: usize = ClientConnectionReceiver::DEFAULT_RECV_MANY_LIMIT; + +/// Target maximum uncompressed v3 payload body size. +/// +/// The v3 binary body is a sequence of BSATN-encoded v2 server messages. The +/// one-byte compression tag is not counted here. This is a target, not a hard +/// rejection limit. One logical server message may exceed it, in which case the +/// message is sent by itself. +const V3_MAX_UNCOMPRESSED_PAYLOAD_SIZE: usize = 512 * 1024; + +/// Tracks serialize buffers that may be reusable once their frames have been +/// copied to the wire. 
+struct SerializeBufferPool { + config: ClientConfig, + available: ArrayQueue, + in_use: Vec, +} + +impl SerializeBufferPool { + const CAPACITY: usize = 16; + + fn new(config: ClientConfig) -> Self { + Self { + config, + available: ArrayQueue::new(Self::CAPACITY), + in_use: Vec::with_capacity(Self::CAPACITY), + } + } + + fn get(&mut self) -> SerializeBuffer { + self.reclaim(); + self.available + .pop() + .unwrap_or_else(|| SerializeBuffer::new(self.config)) + } + + fn hold(&mut self, in_use: InUseSerializeBuffer) { + if self.in_use.len() < Self::CAPACITY { + self.in_use.push(in_use); + } + } + + fn reclaim(&mut self) { + let mut i = 0; + while i < self.in_use.len() { + if self.in_use[i].is_unique() { + let in_use = self.in_use.swap_remove(i); + let buf = in_use.try_reclaim().expect("buffer should be unique"); + let _ = self.available.push(buf); + } else { + i += 1; + } + } + } +} + +/// Task that reads [`OutboundWsMessage`]s from `messages`, encodes them, and +/// sends the resulting [`Frame`]s to `outgoing_frames`. /// /// Meant to be [`tokio::spawn`]ed. /// /// The function also takes care of reusing serialization buffers and reporting -/// metrics via [`SendMetrics`].. +/// metrics via [`SendMetrics`]. async fn ws_encode_task( metrics: SendMetrics, config: ClientConfig, @@ -1304,91 +1434,264 @@ async fn ws_encode_task( outgoing_frames: mpsc::UnboundedSender, bsatn_rlb_pool: BsatnRowListBuilderPool, ) { - // Serialize buffers can be reclaimed once all frames of a message are - // copied to the wire. Since we don't know when that will happen, we prepare - // for a few messages to be in-flight, i.e. encoded but not yet sent. - const BUF_POOL_CAPACITY: usize = 16; - let buf_pool = ArrayQueue::new(BUF_POOL_CAPACITY); - let mut in_use_bufs: Vec> = Vec::with_capacity(BUF_POOL_CAPACITY); - - 'send: while let Some(message) = messages.recv().await { - // Drop serialize buffers with no external referent, - // returning them to the pool. 
- in_use_bufs.retain(|in_use| !in_use.is_unique()); - // Get a serialize buffer from the pool, - // or create a fresh one. - let buf = buf_pool.pop().unwrap_or_else(|| SerializeBuffer::new(config)); - - let in_use_buf = match message { - OutboundWsMessage::Error(message) => { - if config.version != WsVersion::V1 { - log::error!( - "dropping v1 error message sent to a binary websocket client: {:?}", - message - ); - continue; + let mut encoder = WsEncoder { + config, + buffers: SerializeBufferPool::new(config), + metrics: &metrics, + outgoing_frames: &outgoing_frames, + bsatn_rlb_pool: &bsatn_rlb_pool, + binary_server_messages: Vec::new(), + }; + let mut message_batch = Vec::new(); + while messages.recv_many(&mut message_batch, ENCODE_BATCH_SIZE).await != 0 { + log::trace!("encoding batch of {} websocket messages", message_batch.len()); + // `encode_batch` drains `message_batch` on success. If forwarding to + // the websocket send loop fails, the receiver is gone, so the encode + // task can terminate. + if encoder.encode_batch(&mut message_batch).await.is_err() { + break; + } + } +} + +/// Stateful websocket encoder for one client connection. +/// +/// The encoder owns reusable scratch storage: +/// +/// - [`SerializeBufferPool`] reuses byte buffers once encoded frames have been +/// copied to the socket task. +/// - `binary_server_messages` reuses the vector allocation used to assemble +/// v2/v3 binary websocket payloads. +struct WsEncoder<'a> { + config: ClientConfig, + buffers: SerializeBufferPool, + metrics: &'a SendMetrics, + outgoing_frames: &'a mpsc::UnboundedSender, + bsatn_rlb_pool: &'a BsatnRowListBuilderPool, + binary_server_messages: Vec, +} + +impl WsEncoder<'_> { + /// Encode a drained batch according to the websocket version negotiated by + /// the client. 
+ async fn encode_batch( + &mut self, + message_batch: &mut Vec, + ) -> Result<(), mpsc::error::SendError> { + match self.config.version { + WsVersion::V1 => self.encode_v1_batch(message_batch).await, + WsVersion::V2 => self.encode_v2_batch(message_batch).await, + WsVersion::V3 => self.encode_v3_batch(message_batch).await, + } + } + + /// Encode a batch for the original v1 websocket protocols. + /// + /// v1 text/binary messages are encoded one logical message at a time. This + /// path also handles reducer errors, which still use the v1 message schema. + async fn encode_v1_batch( + &mut self, + message_batch: &mut Vec, + ) -> Result<(), mpsc::error::SendError> { + for message in message_batch.drain(..) { + match message { + OutboundWsMessage::Error(message) => { + self.encode_and_forward_v1_message(None, message).await?; } - let Ok(in_use) = ws_forward_frames( - &metrics, - &outgoing_frames, - None, - None, - ws_encode_message(config, buf, message, false, &bsatn_rlb_pool).await, - ) else { - break 'send; - }; - in_use - } - OutboundWsMessage::Message(message) => { - let workload = message.workload(); - let num_rows = message.num_rows(); - match message { - OutboundMessage::V2(server_message) => { - if config.version == WsVersion::V1 { + OutboundWsMessage::Message(message) => { + let num_rows = message.num_rows(); + match message { + OutboundMessage::V2(_) => { log::error!("dropping v2 message on v1 connection"); continue; } - - let Ok(in_use) = ws_forward_frames( - &metrics, - &outgoing_frames, - workload, - num_rows, - ws_encode_binary_message(config, buf, server_message, false, &bsatn_rlb_pool).await, - ) else { - break 'send; - }; - in_use - } - OutboundMessage::V1(message) => { - if config.version != WsVersion::V1 { - log::error!("dropping v1 message for a binary websocket connection: {:?}", message); - continue; + OutboundMessage::V1(message) => { + self.encode_and_forward_v1_message(num_rows, message).await?; } + } + } + } + } + Ok(()) + } + + /// Encode a 
batch for protocol v2. + /// + /// v2 uses the binary server-message schema, but each logical server + /// message must still be sent as its own websocket message. + async fn encode_v2_batch( + &mut self, + message_batch: &mut Vec, + ) -> Result<(), mpsc::error::SendError> { + self.encode_binary_batch(message_batch, BinaryPayloadMode::Single).await + } + + /// Encode a batch for protocol v3. + /// + /// v3 uses the same binary server-message schema as v2, but coalesces all + /// messages currently available from the encoder input into one websocket + /// payload. + async fn encode_v3_batch( + &mut self, + message_batch: &mut Vec, + ) -> Result<(), mpsc::error::SendError> { + self.encode_binary_batch(message_batch, BinaryPayloadMode::Coalesced) + .await + } + + /// Encode binary websocket payloads from a batch of outbound messages. + /// + /// `mode` is the only protocol-specific choice here: + /// + /// - [`BinaryPayloadMode::Single`] preserves the v2 wire format by + /// flushing after each message. + /// - [`BinaryPayloadMode::Coalesced`] uses the v3 wire format by flushing + /// after the whole batch has been accumulated. + async fn encode_binary_batch( + &mut self, + message_batch: &mut Vec, + mode: BinaryPayloadMode, + ) -> Result<(), mpsc::error::SendError> { + self.binary_server_messages.clear(); + self.binary_server_messages.reserve(match mode { + BinaryPayloadMode::Single => 1, + BinaryPayloadMode::Coalesced => message_batch.len(), + }); + let mut total_rows = None; + let mut total_bytes = 0; + + for message in message_batch.drain(..) { + // Drop messages that are not valid for a binary websocket + // connection. The conversion logs the protocol mismatch. 
+ let Some(v2_message) = v2_outbound_message(message) else { + continue; + }; - let is_large = num_rows.is_some_and(|n| n > 1024); - - let Ok(in_use) = ws_forward_frames( - &metrics, - &outgoing_frames, - workload, - num_rows, - ws_encode_message(config, buf, message, is_large, &bsatn_rlb_pool).await, - ) else { - break 'send; - }; - in_use + let message = v2_message.message; + let message_rows = v2_message.num_rows; + + let message_bytes = message_size(&message); + match mode { + BinaryPayloadMode::Coalesced => { + if v3_payload_would_exceed_limit(total_bytes, message_bytes) { + // v3 payload boundary: adding this message would cross the + // target byte limit, so flush the payload accumulated so far. + self.flush_binary_payload(&mut total_rows, &mut total_bytes).await?; } + self.append_binary_message(message, message_rows, message_bytes, &mut total_rows, &mut total_bytes); + } + BinaryPayloadMode::Single => { + // v2 payload boundary: exactly one binary server message per websocket message. + self.append_binary_message(message, message_rows, message_bytes, &mut total_rows, &mut total_bytes); + self.flush_binary_payload(&mut total_rows, &mut total_bytes).await?; } } - }; + } - if in_use_bufs.len() < BUF_POOL_CAPACITY { - in_use_bufs.push(scopeguard::guard(in_use_buf, |in_use| { - let buf = in_use.try_reclaim().expect("buffer should be unique"); - let _ = buf_pool.push(buf); - })); + // Final v3 payload boundary: flush the remaining coalesced messages. + // This is a no-op for v2 because `Single` mode flushes inside the loop. + self.flush_binary_payload(&mut total_rows, &mut total_bytes).await + } + + /// Append one v2 server message to the binary websocket payload currently being accumulated. 
+    fn append_binary_message(
+        &mut self,
+        message: ws_v2::ServerMessage,
+        message_rows: Option,
+        message_bytes: usize,
+        total_rows: &mut Option,
+        total_bytes: &mut usize,
+    ) {
+        if let Some(message_rows) = message_rows {
+            // Payload metrics are emitted at websocket-payload granularity.
+            // In v3, one payload can contain several logical messages, so row
+            // counts are accumulated across the coalesced payload.
+            *total_rows.get_or_insert(0) += message_rows;
+        }
+        self.binary_server_messages.push(message);
+        *total_bytes += message_bytes;
+    }
+
+    /// Encode and forward the accumulated binary payload, then reset its counters.
+    ///
+    /// Does nothing when no message has been appended since the last flush.
+    async fn flush_binary_payload(
+        &mut self,
+        total_rows: &mut Option,
+        total_bytes: &mut usize,
+    ) -> Result<(), mpsc::error::SendError> {
+        if self.binary_server_messages.is_empty() {
+            return Ok(());
+        }
+        let is_large = is_large_payload(*total_bytes);
+        self.encode_and_forward_binary_messages(total_rows.take(), is_large)
+            .await?;
+        *total_bytes = 0;
+        Ok(())
+    }
+
+    /// Encode and forward one v1 websocket message.
+    ///
+    /// v1 can produce either text or binary payloads depending on the client's
+    /// requested protocol, so it uses [`ws_encode_message`] rather than the
+    /// binary-only v2/v3 path.
+    ///
+    /// `num_rows` is reported with the payload metrics and also decides whether
+    /// encoding is offloaded to Rayon: only messages with more than 1024 rows
+    /// are treated as large.
+    async fn encode_and_forward_v1_message(
+        &mut self,
+        num_rows: Option,
+        message: impl ToProtocol + Send + 'static,
+    ) -> Result<(), mpsc::error::SendError> {
+        let config = self.config;
+        let bsatn_rlb_pool = self.bsatn_rlb_pool;
+        // Messages without a row count (e.g. errors) and small results are
+        // encoded inline on the async task; unconditionally passing `true`
+        // would pay Rayon scheduling overhead for every v1 message.
+        let is_large = num_rows.is_some_and(|n| n > 1024);
+        self.encode_and_forward_message(|buf| {
+            ws_encode_message(config, buf, message, is_large, bsatn_rlb_pool, num_rows)
+        })
+        .await
+    }
+
+    /// Encode and forward the currently accumulated binary server messages.
+    ///
+    /// This method is shared by v2 and v3. v2 calls it with exactly one
+    /// `binary_server_messages` entry; v3 calls it with the whole coalesced
+    /// batch. The actual bytes are produced by [`serialize_v3`], whose core
+    /// implementation also backs `serialize_v2`.
+    async fn encode_and_forward_binary_messages(
+        &mut self,
+        num_rows: Option,
+        is_large: bool,
+    ) -> Result<(), mpsc::error::SendError> {
+        let scratch = self.buffers.get();
+        let compression = self.config.compression;
+        let row_list_pool = self.bsatn_rlb_pool.clone();
+        // The encode closure must be `'static` (it may run on Rayon), so it
+        // cannot borrow the pending messages from `self`. Take the Vec out,
+        // drain it inside the closure, and hand the emptied Vec back so its
+        // allocation serves the next payload.
+        let mut batch = std::mem::take(&mut self.binary_server_messages);
+        let (batch, timing, in_use, data) = maybe_spawn_encode(is_large, move || {
+            let (timing, in_use, data) =
+                time_encode(|| serialize_v3(&row_list_pool, scratch, batch.drain(..), compression));
+            (batch, timing, in_use, data)
+        })
+        .await;
+        // Restore the scratch Vec before forwarding, so it is kept even when
+        // the send below fails.
+        self.binary_server_messages = batch;
+        let in_use = ws_forward_frames(
+            self.metrics,
+            self.outgoing_frames,
+            ws_encode_binary_frames(timing, in_use, data, num_rows),
+        )?;
+        self.buffers.hold(in_use);
+        Ok(())
+    }
+
+    /// Encode one websocket payload using a reusable serialization buffer,
+    /// forward its frames, then retain the buffer for later reuse.
+ async fn encode_and_forward_message( + &mut self, + encode: Encode, + ) -> Result<(), mpsc::error::SendError> + where + Encode: FnOnce(SerializeBuffer) -> Fut, + Fut: Future, + Frames: IntoIterator, + { + let buf = self.buffers.get(); + let in_use = ws_forward_frames(self.metrics, self.outgoing_frames, encode(buf).await)?; + self.buffers.hold(in_use); + Ok(()) } } @@ -1397,25 +1700,27 @@ async fn ws_encode_task( fn ws_forward_frames( metrics: &SendMetrics, outgoing_frames: &mpsc::UnboundedSender, - workload: Option, - num_rows: Option, - encoded: (EncodeMetrics, InUseSerializeBuffer, impl IntoIterator), + encoded: ( + EncodedPayloadMetrics, + InUseSerializeBuffer, + impl IntoIterator, + ), ) -> Result> { let (stats, in_use, frames) = encoded; - metrics.report(workload, num_rows, stats); + metrics.report(stats); frames.into_iter().try_for_each(|frame| outgoing_frames.send(frame))?; Ok(in_use) } -/// Some stats about serialization and compression. -/// -/// Returned by [`ws_encode_message`]. -struct EncodeMetrics { +/// Metrics for one encoded websocket payload. +struct EncodedPayloadMetrics { /// Time it took to serialize and (potentially) compress a message. /// Does not include scheduling overhead. timing: Duration, /// Length in bytes of the serialized and (potentially) compressed message. encoded_len: usize, + /// Number of logical rows included in the payload, if known. + num_rows: Option, } /// Encodes `message` into zero or more WebSocket [`Frame`]s. @@ -1432,7 +1737,7 @@ struct EncodeMetrics { /// of payload each, according to the rules laid out in [RFC6455], Section /// 5.4 Fragmentation. /// -/// Returns [`EncodeMetrics`], the [`InUseSerializeBuffer`] that was passed in +/// Returns [`EncodedPayloadMetrics`], the [`InUseSerializeBuffer`] that was passed in /// as `buf` for later reuse, and the [`Frame`]s. 
/// /// NOTE: When sending, the frames of a single message MUST NOT be interleaved @@ -1446,62 +1751,75 @@ async fn ws_encode_message( message: impl ToProtocol + Send + 'static, is_large_message: bool, bsatn_rlb_pool: &BsatnRowListBuilderPool, -) -> (EncodeMetrics, InUseSerializeBuffer, impl Iterator) { - const FRAGMENT_SIZE: usize = 4096; - - fn serialize_and_compress( - bsatn_rlb_pool: &BsatnRowListBuilderPool, - serialize_buf: SerializeBuffer, - message: impl ToProtocol + Send + 'static, - config: ClientConfig, - ) -> (Duration, InUseSerializeBuffer, DataMessage) { - let start = Instant::now(); - let (msg_alloc, msg_data) = serialize(bsatn_rlb_pool, serialize_buf, message, config); - (start.elapsed(), msg_alloc, msg_data) - } - let (timing, msg_alloc, msg_data) = if is_large_message { - let bsatn_rlb_pool = bsatn_rlb_pool.clone(); - spawn_rayon(move || serialize_and_compress(&bsatn_rlb_pool, buf, message, config)).await - } else { - serialize_and_compress(bsatn_rlb_pool, buf, message, config) - }; - - let metrics = EncodeMetrics { - timing, - encoded_len: msg_data.len(), - }; + num_rows: Option, +) -> (EncodedPayloadMetrics, InUseSerializeBuffer, impl Iterator) { + let bsatn_rlb_pool = bsatn_rlb_pool.clone(); + // Serialization/compression can dominate large subscription or query + // responses, so large payloads are offloaded to Rayon. + let (timing, in_use, msg_data) = maybe_spawn_encode(is_large_message, move || { + time_encode(|| serialize(&bsatn_rlb_pool, buf, message, config)) + }) + .await; + let encoded_len = msg_data.len(); let (data, ty) = match msg_data { DataMessage::Text(text) => (bytestring_to_utf8bytes(text).into(), Data::Text), DataMessage::Binary(bin) => (bin, Data::Binary), }; - let frames = fragment(data, ty, FRAGMENT_SIZE); + ws_encode_frames(timing, in_use, encoded_len, data, ty, num_rows) +} - (metrics, msg_alloc, frames) +/// Run `encode` on Rayon when the payload is expected to be large. 
+/// +/// Small payloads stay on the async task to avoid Rayon scheduling overhead. +async fn maybe_spawn_encode(is_large: bool, encode: impl FnOnce() -> T + Send + 'static) -> T { + if is_large { + spawn_rayon(encode).await + } else { + encode() + } } -async fn ws_encode_binary_message( - config: ClientConfig, - buf: SerializeBuffer, - message: ws_v2::ServerMessage, - is_large_message: bool, - bsatn_rlb_pool: &BsatnRowListBuilderPool, -) -> (EncodeMetrics, InUseSerializeBuffer, impl Iterator + use<>) { +/// Measure serialization/compression time for one websocket payload. +fn time_encode(encode: impl FnOnce() -> (InUseSerializeBuffer, T)) -> (Duration, InUseSerializeBuffer, T) { let start = Instant::now(); - let compression = config.compression; + let (in_use, data) = encode(); + (start.elapsed(), in_use, data) +} - let (in_use, data) = if is_large_message { - let bsatn_rlb_pool = bsatn_rlb_pool.clone(); - spawn_rayon(move || serialize_v2(&bsatn_rlb_pool, buf, message, compression)).await - } else { - serialize_v2(bsatn_rlb_pool, buf, message, compression) - }; +/// Build binary websocket frames and payload metrics for encoded bytes. +fn ws_encode_binary_frames( + timing: Duration, + in_use: InUseSerializeBuffer, + data: Bytes, + num_rows: Option, +) -> ( + EncodedPayloadMetrics, + InUseSerializeBuffer, + impl Iterator + use<>, +) { + ws_encode_frames(timing, in_use, data.len(), data, Data::Binary, num_rows) +} - let metrics = EncodeMetrics { - timing: start.elapsed(), - encoded_len: data.len(), +/// Build websocket frames and payload metrics for already-serialized bytes. 
+fn ws_encode_frames( + timing: Duration, + in_use: InUseSerializeBuffer, + encoded_len: usize, + data: Bytes, + ty: Data, + num_rows: Option, +) -> ( + EncodedPayloadMetrics, + InUseSerializeBuffer, + impl Iterator + use<>, +) { + let metrics = EncodedPayloadMetrics { + timing, + encoded_len, + num_rows, }; - let frames = fragment(data, Data::Binary, 4096); + let frames = fragment(data, ty, 4096); (metrics, in_use, frames) } @@ -1545,33 +1863,32 @@ impl ClientMessage { } } +/// Cached metric handles for the websocket send path. struct SendMetrics { - database: Identity, encode_timing: Histogram, + payload_size: Histogram, + payload_num_rows: Histogram, } impl SendMetrics { + /// Resolve metric handles for one database once per websocket send loop. fn new(database: Identity) -> Self { Self { encode_timing: WORKER_METRICS.websocket_serialize_secs.with_label_values(&database), - database, + payload_size: WORKER_METRICS.websocket_sent_msg_size.with_label_values(&database), + payload_num_rows: WORKER_METRICS.websocket_sent_num_rows.with_label_values(&database), } } - fn report(&self, workload: Option, num_rows: Option, encode: EncodeMetrics) { + /// Report one encoded websocket payload. + fn report(&self, encode: EncodedPayloadMetrics) { self.encode_timing.observe(encode.timing.as_secs_f64()); + self.payload_size.observe(encode.encoded_len as f64); - // These metrics should be updated together, - // or not at all. - if let (Some(workload), Some(num_rows)) = (workload, num_rows) { - WORKER_METRICS - .websocket_sent_num_rows - .with_label_values(&self.database, &workload) - .observe(num_rows as f64); - WORKER_METRICS - .websocket_sent_msg_size - .with_label_values(&self.database, &workload) - .observe(encode.encoded_len as f64); + if let Some(num_rows) = encode.num_rows { + // Some websocket payloads, such as control or error messages, do + // not correspond to a known logical row count. 
+ self.payload_num_rows.observe(num_rows as f64); } } } diff --git a/crates/core/src/client/messages.rs b/crates/core/src/client/messages.rs index 798596b5bca..123c1c75d4a 100644 --- a/crates/core/src/client/messages.rs +++ b/crates/core/src/client/messages.rs @@ -196,20 +196,46 @@ pub fn serialize( /// conditional compression when configured. pub fn serialize_v2( bsatn_rlb_pool: &BsatnRowListBuilderPool, - mut buffer: SerializeBuffer, + buffer: SerializeBuffer, msg: ws_v2::ServerMessage, compression: ws_v1::Compression, +) -> (InUseSerializeBuffer, Bytes) { + serialize_v2_messages(bsatn_rlb_pool, buffer, std::iter::once(msg), compression) +} + +/// Serialize one or more [`ws_v2::ServerMessage`]s into a v3 websocket payload. +/// +/// Protocol v3 keeps the v2 message schema, but allows the uncompressed payload +/// body to contain consecutive BSATN-encoded server messages. +pub fn serialize_v3( + bsatn_rlb_pool: &BsatnRowListBuilderPool, + buffer: SerializeBuffer, + msgs: impl IntoIterator, + compression: ws_v1::Compression, +) -> (InUseSerializeBuffer, Bytes) { + serialize_v2_messages(bsatn_rlb_pool, buffer, msgs, compression) +} + +fn serialize_v2_messages( + bsatn_rlb_pool: &BsatnRowListBuilderPool, + mut buffer: SerializeBuffer, + msgs: impl IntoIterator, + compression: ws_v1::Compression, ) -> (InUseSerializeBuffer, Bytes) { let srv_msg = buffer.write_with_tag(ws_common::SERVER_MSG_COMPRESSION_TAG_NONE, |w| { - bsatn::to_writer(w.into_inner(), &msg).expect("should be able to bsatn encode v2 message"); + let out = w.into_inner(); + for msg in msgs { + write_v2_server_message(bsatn_rlb_pool, out, msg); + } }); let srv_msg_len = srv_msg.len(); + finalize_binary_serialize_buffer(buffer, srv_msg_len, compression) +} - // At this point, we no longer have a use for `msg`, - // so try to reclaim its buffers. 
+fn write_v2_server_message(bsatn_rlb_pool: &BsatnRowListBuilderPool, out: &mut BytesMut, msg: ws_v2::ServerMessage) { + bsatn::to_writer(out, &msg).expect("should be able to bsatn encode v2 message"); + // At this point, we no longer have a use for `msg`, so try to reclaim its buffers. msg.consume_each_list(&mut |buffer| bsatn_rlb_pool.try_put(buffer)); - - finalize_binary_serialize_buffer(buffer, srv_msg_len, compression) } #[derive(Debug, From)] diff --git a/crates/core/src/worker_metrics/mod.rs b/crates/core/src/worker_metrics/mod.rs index 14ce95c5ccf..c1847fa6d1c 100644 --- a/crates/core/src/worker_metrics/mod.rs +++ b/crates/core/src/worker_metrics/mod.rs @@ -204,8 +204,8 @@ metrics_group!( pub tokio_mean_polls_per_park: GaugeVec, #[name = spacetime_websocket_sent_msg_size_bytes] - #[help = "The size of messages sent to connected sessions"] - #[labels(db: Identity, workload: WorkloadType)] + #[help = "The size of websocket payloads sent to connected sessions"] + #[labels(db: Identity)] // Prometheus histograms have default buckets, // which broadly speaking, // are tailored to measure the response time of a network service. @@ -219,8 +219,8 @@ metrics_group!( pub websocket_sent_msg_size: HistogramVec, #[name = spacetime_websocket_sent_num_rows] - #[help = "The number of rows sent to connected sessions"] - #[labels(db: Identity, workload: WorkloadType)] + #[help = "The number of rows sent in websocket payloads"] + #[labels(db: Identity)] // Prometheus histograms have default buckets, // which broadly speaking, // are tailored to measure the response time of a network service.