diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index b2f81b75ad..55700a46d0 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -360,6 +360,26 @@ jobs: SQLX_OFFLINE_DIR: .sqlx RUSTFLAGS: --cfg mysql_${{ matrix.mysql }} + # Run tests to validate zstd compression for traffic + - run: > + cargo test + --no-default-features + --features any,mysql,mysql-zstd-compression,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + env: + DATABASE_URL: mysql://root:password@localhost:3306/sqlx?ssl-mode=disabled&compression=zstd:1 + SQLX_OFFLINE_DIR: .sqlx + RUSTFLAGS: --cfg mysql_${{ matrix.mysql }} + + # Run tests to validate zlib compression for traffic + - run: > + cargo test + --no-default-features + --features any,mysql,mysql-zlib-compression,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + env: + DATABASE_URL: mysql://root:password@localhost:3306/sqlx?ssl-mode=disabled&compression=zlib:1 + SQLX_OFFLINE_DIR: .sqlx + RUSTFLAGS: --cfg mysql_${{ matrix.mysql }} + # Run the `test-attr` test again to cover cleanup. 
- run: > cargo test @@ -426,6 +446,27 @@ jobs: DATABASE_URL: mysql://root@localhost:3306/sqlx?sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt RUSTFLAGS: --cfg mysql_${{ matrix.mysql }} + # Run tests to validate zstd compression for traffic with tls + - if: ${{ matrix.tls != 'none' }} + run: > + cargo test + --no-default-features + --features any,mysql,mysql-zstd-compression,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + env: + DATABASE_URL: mysql://root@localhost:3306/sqlx?sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt&compression=zstd:1 + RUSTFLAGS: --cfg mysql_${{ matrix.mysql }} + + # Run tests to validate zlib compression for traffic with tls + - if: ${{ matrix.tls != 'none' }} + run: > + cargo test + --no-default-features + --features any,mysql,mysql-zlib-compression,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + env: + DATABASE_URL: mysql://root@localhost:3306/sqlx?sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt&compression=zlib:1 + RUSTFLAGS: --cfg mysql_${{ matrix.mysql }} + + mariadb: name: MariaDB runs-on: ubuntu-24.04 @@ -461,6 +502,16 @@ jobs: SQLX_OFFLINE_DIR: .sqlx RUSTFLAGS: --cfg mariadb="${{ matrix.mariadb }}" + # Run tests to validate zlib compression for traffic + - run: > + cargo test + --no-default-features + --features any,mysql,mysql-zlib-compression,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + env: + DATABASE_URL: mysql://root:password@localhost:3306/sqlx?compression=zlib:1 + SQLX_OFFLINE_DIR: .sqlx + RUSTFLAGS: --cfg mariadb="${{ matrix.mariadb }}" + # Run the `test-attr` test again to cover cleanup. 
- run: > cargo test @@ -514,3 +565,14 @@ jobs: env: DATABASE_URL: mysql://root@localhost:3306/sqlx?sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt RUSTFLAGS: --cfg mariadb="${{ matrix.mariadb }}" + + + # Run tests to validate zlib compression for traffic with tls + - if: ${{ matrix.tls != 'none' }} + run: > + cargo test + --no-default-features + --features any,mysql,mysql-zlib-compression,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + env: + DATABASE_URL: mysql://root@localhost:3306/sqlx?sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt&compression=zlib:1 + RUSTFLAGS: --cfg mariadb="${{ matrix.mariadb }}" diff --git a/Cargo.lock b/Cargo.lock index 78e40f0c12..cc33848e1a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1020,6 +1020,15 @@ version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + [[package]] name = "criterion" version = "0.5.1" @@ -1373,6 +1382,17 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "flate2" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb" +dependencies = [ + "crc32fast", + "libz-sys", + "miniz_oxide", +] + [[package]] name = "float-cmp" version = "0.9.0" @@ -2160,6 +2180,17 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "libz-sys" +version = "1.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"15d118bbf3771060e7311cc7bb0545b01d08a8b4a7de949198dec1fa0ca1c0f7" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linux-raw-sys" version = "0.4.15" @@ -2266,6 +2297,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" dependencies = [ "adler2", + "simd-adler32", ] [[package]] @@ -3431,6 +3463,12 @@ dependencies = [ "rand_core", ] +[[package]] +name = "simd-adler32" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" + [[package]] name = "simdutf8" version = "0.1.5" @@ -3903,6 +3941,7 @@ dependencies = [ "digest", "dotenvy", "either", + "flate2", "futures-channel", "futures-core", "futures-io", @@ -3931,6 +3970,7 @@ dependencies = [ "tracing", "uuid", "whoami", + "zstd", ] [[package]] @@ -5290,3 +5330,31 @@ dependencies = [ "quote", "syn 2.0.104", ] + +[[package]] +name = "zstd" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.16+zstd.1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/Cargo.toml b/Cargo.toml index 00d5d656c1..92bd9cee77 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -161,6 +161,10 @@ uuid = ["sqlx-core/uuid", "sqlx-macros?/uuid", "sqlx-mysql?/uuid", "sqlx-postgre regexp = ["sqlx-sqlite?/regexp"] bstr = ["sqlx-core/bstr"] +# compression 
+mysql-zstd-compression = ["sqlx-mysql/zstd-compression"] +mysql-zlib-compression = ["sqlx-mysql/zlib-compression"] + [workspace.dependencies] # Core Crates sqlx-core = { version = "=0.9.0-alpha.1", path = "sqlx-core" } @@ -401,6 +405,12 @@ name = "mysql-rustsec" path = "tests/mysql/rustsec.rs" required-features = ["mysql"] +[[test]] +name = "mysql-compression" +path = "tests/mysql/compression.rs" +required-features = ["mysql"] + + # # PostgreSQL # diff --git a/README.md b/README.md index f1e53cdced..099fa0093b 100644 --- a/README.md +++ b/README.md @@ -177,6 +177,10 @@ be removed in the future. - `mysql`: Add support for the MySQL/MariaDB database server. +- `mysql-zlib-compression`: Add zlib compression support for the MySQL/MariaDB database server. + +- `mysql-zstd-compression`: Add zstd compression support for the MySQL database server. + - `mssql`: Add support for the MSSQL database server. - `sqlite`: Add support for the self-contained [SQLite](https://sqlite.org/) database engine with SQLite bundled and statically-linked. 
diff --git a/sqlx-mysql/Cargo.toml b/sqlx-mysql/Cargo.toml index ee9512b61e..57eb7db9cd 100644 --- a/sqlx-mysql/Cargo.toml +++ b/sqlx-mysql/Cargo.toml @@ -14,6 +14,8 @@ json = ["sqlx-core/json", "serde"] any = ["sqlx-core/any"] offline = ["sqlx-core/offline", "serde/derive"] migrate = ["sqlx-core/migrate"] +zstd-compression = ["zstd"] +zlib-compression = ["flate2"] # Type Integration features bigdecimal = ["dep:bigdecimal", "sqlx-core/bigdecimal"] @@ -67,6 +69,10 @@ stringprep = "0.1.2" tracing = { version = "0.1.37", features = ["log"] } whoami = { version = "1.2.1", default-features = false } +# Compression +zstd = { version = "0.13.3", optional = true, default-features = false, features = ["zdict_builder"] } +flate2 = { version = "1.1.5", optional = true, default-features = false, features = ["rust_backend", "zlib"] } + dotenvy.workspace = true thiserror.workspace = true diff --git a/sqlx-mysql/src/connection/auth.rs b/sqlx-mysql/src/connection/auth.rs index 613f8e702f..9739bd57a1 100644 --- a/sqlx-mysql/src/connection/auth.rs +++ b/sqlx-mysql/src/connection/auth.rs @@ -53,7 +53,7 @@ impl AuthPlugin { 0x04 => { let payload = encrypt_rsa(stream, 0x02, password, nonce).await?; - stream.write_packet(&*payload)?; + stream.write_packet(&*payload).await?; stream.flush().await?; Ok(false) @@ -143,7 +143,7 @@ async fn encrypt_rsa<'s>( } // client sends a public key request - stream.write_packet(&[public_key_request_id][..])?; + stream.write_packet(&[public_key_request_id][..]).await?; stream.flush().await?; // server sends a public key response diff --git a/sqlx-mysql/src/connection/compression.rs b/sqlx-mysql/src/connection/compression.rs new file mode 100644 index 0000000000..fd547b35e7 --- /dev/null +++ b/sqlx-mysql/src/connection/compression.rs @@ -0,0 +1,594 @@ +use crate::protocol::Capabilities; +use crate::CompressionConfig; +use sqlx_core::io::{ProtocolDecode, ProtocolEncode}; +use sqlx_core::net::{BufferedSocket, Socket}; +use sqlx_core::Error; 
+#[cfg(any(feature = "zlib-compression", feature = "zstd-compression"))] +use {crate::Compression, compressed_stream::CompressedStream}; + +pub(crate) struct CompressionMySqlStream> { + mode: CompressionMode, + pub(crate) socket: BufferedSocket, +} + +impl CompressionMySqlStream { + pub(crate) fn not_compressed(socket: BufferedSocket) -> Self { + let mode = CompressionMode::NotCompressed; + Self { mode, socket } + } + + #[cfg(any(feature = "zstd-compression", feature = "zlib-compression"))] + fn compressed(socket: BufferedSocket, compression: CompressionConfig) -> Self { + let mode = CompressionMode::Compressed(CompressedStream::new(compression)); + Self { mode, socket } + } + + pub(crate) fn create( + socket: BufferedSocket, + #[cfg_attr( + not(all(feature = "zstd-compression", feature = "zlib-compression")), + allow(unused_variables) + )] + capabilities: &Capabilities, + compression_configs: &[CompressionConfig], + ) -> Self { + let supported_compression = compression_configs.iter().find(|c| { + let is_supported = match c.0 { + #[cfg(feature = "zlib-compression")] + Compression::Zlib => capabilities.contains(Capabilities::COMPRESS), + #[cfg(feature = "zstd-compression")] + Compression::Zstd => { + capabilities.contains(Capabilities::ZSTD_COMPRESSION_ALGORITHM) + } + #[cfg(not(any(feature = "zstd-compression", feature = "zlib-compression")))] + _ => false, + }; + if !is_supported { + tracing::warn!("server doesn't support '{:?}' compression", c.0); + } + is_supported + }); + match supported_compression { + #[cfg(any(feature = "zstd-compression", feature = "zlib-compression"))] + Some(c) => CompressionMySqlStream::compressed(socket, *c), + _ => CompressionMySqlStream::not_compressed(socket), + } + } + + pub(crate) fn boxed(self) -> CompressionMySqlStream> { + CompressionMySqlStream { + socket: self.socket.boxed(), + mode: self.mode, + } + } + + pub(crate) async fn read_with<'de, T, C>( + &mut self, + byte_len: usize, + context: C, + ) -> Result + where + T: 
ProtocolDecode<'de, C>, + { + match self.mode { + CompressionMode::NotCompressed => self.socket.read_with(byte_len, context).await, + #[cfg(any(feature = "zstd-compression", feature = "zlib-compression"))] + CompressionMode::Compressed(ref mut s) => { + s.read_with(byte_len, context, &mut self.socket).await + } + } + } + + pub(crate) async fn write_with<'en, 'stream, T>( + &mut self, + value: T, + context: (Capabilities, &'stream mut u8), + ) -> Result<(), Error> + where + T: ProtocolEncode<'en, (Capabilities, &'stream mut u8)>, + { + match self.mode { + CompressionMode::NotCompressed => self.socket.write_with(value, context), + #[cfg(any(feature = "zstd-compression", feature = "zlib-compression"))] + CompressionMode::Compressed(ref mut s) => { + s.write_with(value, context, &mut self.socket).await + } + } + } + + pub(crate) fn uncompressed_write_with<'en, 'stream, T>( + &mut self, + value: T, + context: (Capabilities, &'stream mut u8), + ) -> Result<(), Error> + where + T: ProtocolEncode<'en, (Capabilities, &'stream mut u8)>, + { + match self.mode { + CompressionMode::NotCompressed => self.socket.write_with(value, context), + #[cfg(any(feature = "zstd-compression", feature = "zlib-compression"))] + CompressionMode::Compressed(ref mut s) => { + s.uncompressed_write_with(value, context, &mut self.socket) + } + } + } +} + +enum CompressionMode { + NotCompressed, + #[cfg(any(feature = "zstd-compression", feature = "zlib-compression"))] + Compressed(CompressedStream), +} + +#[cfg(any(feature = "zstd-compression", feature = "zlib-compression"))] +mod compressed_stream { + use crate::{Compression, CompressionConfig}; + use bytes::{Buf, BufMut, Bytes, BytesMut}; + #[cfg(feature = "zlib-compression")] + use flate2::{ + write::ZlibEncoder, Compression as ZlibCompression, Decompress as ZlibDecompressor, + FlushDecompress, Status, + }; + use sqlx_core::io::{ProtocolDecode, ProtocolEncode}; + use sqlx_core::net::{BufferedSocket, Socket}; + use sqlx_core::rt::yield_now; + use 
sqlx_core::Error; + use std::cmp::min; + use std::io::{Cursor, Write}; + #[cfg(feature = "zstd-compression")] + use zstd::stream::{ + raw::{Decoder as ZstdDecoder, InBuffer, Operation, OutBuffer}, + Encoder as ZstdEncoder, + }; + + pub(crate) struct CompressedStream { + compression_config: CompressionConfig, + sequence_id: u8, + packet_reader: Option, + } + + impl CompressedStream { + pub(crate) fn new(compression_config: CompressionConfig) -> Self { + Self { + sequence_id: 0, + packet_reader: None, + compression_config, + } + } + + pub(crate) async fn read_with<'de, T, C, S: Socket>( + &mut self, + byte_len: usize, + context: C, + buffered_socket: &mut BufferedSocket, + ) -> Result + where + T: ProtocolDecode<'de, C>, + { + let mut result_buffer = BytesMut::with_capacity(byte_len); + while result_buffer.len() != byte_len { + let compressed_packet_reader = match self.packet_reader.as_mut() { + None => { + let packet_reader = + CompressedPacketReader::new(buffered_socket, &self.compression_config) + .await?; + self.sequence_id = packet_reader.sequence_id.wrapping_add(1); + self.packet_reader = Some(packet_reader); + self.packet_reader.as_mut().unwrap() + } + Some(p) => p, + }; + + let required_bytes_count = byte_len.saturating_sub(result_buffer.len()); + let chunk = compressed_packet_reader + .read(buffered_socket, required_bytes_count) + .await?; + result_buffer.put_slice(&chunk); + + if !compressed_packet_reader.is_available() { + self.packet_reader = None + } + } + + T::decode_with(result_buffer.freeze(), context) + } + + pub(crate) async fn write_with<'en, T, C, S: Socket>( + &mut self, + packet: T, + context: C, + buffered_socket: &mut BufferedSocket, + ) -> Result<(), Error> + where + T: ProtocolEncode<'en, C>, + { + self.sequence_id = 0; + let mut uncompressed_payload = Vec::with_capacity(0xFF_FF_FF); + packet.encode_with(&mut uncompressed_payload, context)?; + + let mut uncompressed_chunks = uncompressed_payload.chunks(0xFF_FF_FF); + for uncompressed_chunk 
in uncompressed_chunks.by_ref() { + let mut compressed_payload = Vec::with_capacity(uncompressed_chunk.len() + 7); + Self::add_compressed_packet( + self.sequence_id, + &self.compression_config, + &mut compressed_payload, + uncompressed_chunk, + ) + .await?; + + buffered_socket.write_with(compressed_payload.as_slice(), ())?; + + self.sequence_id = self.sequence_id.wrapping_add(1); + } + + Ok(()) + } + + pub(crate) fn uncompressed_write_with<'en, T, C, S: Socket>( + &mut self, + packet: T, + context: C, + buffered_socket: &mut BufferedSocket, + ) -> Result<(), Error> + where + T: ProtocolEncode<'en, C>, + { + self.sequence_id = 0; + let mut uncompressed_payload = Vec::with_capacity(0xFF_FF_FF); + packet.encode_with(&mut uncompressed_payload, context)?; + + let mut uncompressed_chunks = uncompressed_payload.chunks(0xFF_FF_FF); + for uncompressed_chunk in uncompressed_chunks.by_ref() { + let mut header = Vec::with_capacity(7); + header.put_uint_le(uncompressed_chunk.len() as u64, 3); + header.put_u8(self.sequence_id); + header.put_uint_le(0, 3); + + buffered_socket.write_with(header.as_slice(), ())?; + buffered_socket.write_with(uncompressed_chunk, ())?; + + self.sequence_id = self.sequence_id.wrapping_add(1); + } + + Ok(()) + } + + async fn add_compressed_packet( + sequence_id: u8, + compression: &CompressionConfig, + compressed_chunk: &mut Vec, + uncompressed_chunk: &[u8], + ) -> Result<(), Error> { + compressed_chunk.extend_from_slice(&[0; 7]); + + let compressed_payload_length = + Self::compress_chunk(compression, compressed_chunk, uncompressed_chunk).await?; + + let mut header = &mut compressed_chunk[0..7]; + header.put_uint_le(compressed_payload_length as u64, 3); + header.put_u8(sequence_id); + header.put_uint_le(uncompressed_chunk.len() as u64, 3); + + Ok(()) + } + + async fn compress_chunk( + compression: &CompressionConfig, + output: &mut Vec, + uncompressed_chunk: &[u8], + ) -> Result { + let offset = output.len(); + let mut cursor = Cursor::new(output); + 
cursor.set_position(offset as u64); + + let mut encoder = Encoder::new(compression, cursor)?; + + for chunk in uncompressed_chunk.chunks(encoder.get_chunk_size()) { + encoder.write_all(chunk)?; + yield_now().await; + } + let cursor = encoder.finish()?; + Ok(cursor.get_ref().len().saturating_sub(offset)) + } + } + + enum Encoder<'en> { + #[cfg(feature = "zlib-compression")] + Zlib(ZlibEncoder>>, u8), + #[cfg(feature = "zstd-compression")] + Zstd(ZstdEncoder<'en, Cursor<&'en mut Vec>>, u8), + } + + impl<'en> Encoder<'en> { + fn new( + compression_config: &CompressionConfig, + cursor: Cursor<&'en mut Vec>, + ) -> Result, Error> { + let encoder = match compression_config { + #[cfg(feature = "zlib-compression")] + CompressionConfig(Compression::Zlib, level) => Encoder::Zlib( + ZlibEncoder::new(cursor, ZlibCompression::new(*level as u32)), + *level, + ), + #[cfg(feature = "zstd-compression")] + CompressionConfig(Compression::Zstd, level) => { + Encoder::Zstd(ZstdEncoder::new(cursor, *level as i32)?, *level) + } + }; + Ok(encoder) + } + + fn write_all(&mut self, buf: &'en [u8]) -> Result<(), Error> { + match self { + #[cfg(feature = "zlib-compression")] + Encoder::Zlib(encoder, _) => encoder.write_all(buf)?, + #[cfg(feature = "zstd-compression")] + Encoder::Zstd(encoder, _) => encoder.write_all(buf)?, + } + Ok(()) + } + + fn finish(self) -> Result>, Error> { + let cursor = match self { + #[cfg(feature = "zlib-compression")] + Encoder::Zlib(encoder, _) => encoder.finish()?, + #[cfg(feature = "zstd-compression")] + Encoder::Zstd(encoder, _) => encoder.finish()?, + }; + Ok(cursor) + } + + // Chunk size is chosen based on lzbench benchmarks: + // https://github.com/inikep/lzbench?tab=readme-ov-file#benchmarks + // The target is to keep runtime under 50 ms. 
+ fn get_chunk_size(&self) -> usize { + match self { + #[cfg(feature = "zlib-compression")] + Encoder::Zlib(_, level) => match level { + 1 => 4 * 1024, + 2..=4 => 2 * 1024, + 5..=6 => 1024, + _ => 512, + }, + #[cfg(feature = "zstd-compression")] + Encoder::Zstd(_, level) => match level { + 1..=2 => 16 * 1024, + 3..=4 => 8 * 1024, + 5..=6 => 4 * 1024, + 7..=10 => 2 * 1024, + 11..=12 => 1024, + 13..=14 => 512, + 15..=16 => 256, + 17..=20 => 128, + _ => 64, + }, + } + } + } + + struct CompressedPacketReader { + sequence_id: u8, + remaining_bytes: usize, + is_compressed: bool, + + decoder: Decoder, + input_buffer: Bytes, + input_buffer_pos: usize, + output_buffer: BytesMut, + } + + impl CompressedPacketReader { + async fn new( + buffered_socket: &mut BufferedSocket, + compression_config: &CompressionConfig, + ) -> Result { + let mut header: Bytes = buffered_socket.read(7).await?; + #[allow(clippy::cast_possible_truncation)] + let compressed_payload_length = header.get_uint_le(3) as usize; + let sequence_id = header.get_u8(); + #[allow(clippy::cast_possible_truncation)] + let uncompressed_payload_length = header.get_uint_le(3) as usize; + let decoder = Decoder::new(compression_config)?; + + Ok(CompressedPacketReader { + sequence_id, + remaining_bytes: compressed_payload_length, + is_compressed: uncompressed_payload_length > 0, + decoder, + + input_buffer: Bytes::new(), + input_buffer_pos: 0, + output_buffer: BytesMut::with_capacity(uncompressed_payload_length), + }) + } + + fn is_available(&self) -> bool { + !self.output_buffer.is_empty() + || self.input_buffer_pos < self.input_buffer.len() + || self.remaining_bytes > 0 + } + + async fn read( + &mut self, + buffered_socket: &mut BufferedSocket, + bytes_count: usize, + ) -> Result { + let chunk = if self.is_compressed { + self.decompress(buffered_socket, bytes_count).await? 
+ } else { + let available_bytes_count = min(self.remaining_bytes, bytes_count); + let result: Bytes = buffered_socket.read(available_bytes_count).await?; + self.remaining_bytes = self.remaining_bytes.saturating_sub(result.len()); + result + }; + + Ok(chunk) + } + + async fn decompress( + &mut self, + buffered_socket: &mut BufferedSocket, + output_bytes_count: usize, + ) -> Result { + if self.output_buffer.len() >= output_bytes_count { + return Ok(self.output_buffer.split_to(output_bytes_count).freeze()); + } + + while self.output_buffer.len() < output_bytes_count { + let mut is_refill_required = self.input_buffer_pos >= self.input_buffer.len(); + + if !is_refill_required { + let input = &self.input_buffer[self.input_buffer_pos..]; + let (consumed_bytes_total_count, produced_bytes_total_count) = + self.decoder.decompress(input, &mut self.output_buffer)?; + + self.input_buffer_pos += consumed_bytes_total_count; + + if produced_bytes_total_count == 0 { + is_refill_required = true; + } + } + + if is_refill_required { + if self.remaining_bytes == 0 { + break; + } + let available_bytes = min(self.remaining_bytes, self.decoder.get_chunk_size()); + + self.input_buffer = buffered_socket.read(available_bytes).await?; + self.input_buffer_pos = 0; + self.remaining_bytes = + self.remaining_bytes.saturating_sub(self.input_buffer.len()); + + if self.input_buffer.is_empty() { + return Err(err_protocol!("Compressed input ended unexpectedly")); + } + } + } + + let available_bytes = min(self.output_buffer.len(), output_bytes_count); + Ok(self.output_buffer.split_to(available_bytes).freeze()) + } + } + + enum Decoder { + #[cfg(feature = "zlib-compression")] + Zlib(ZlibDecompressor), + #[cfg(feature = "zstd-compression")] + Zstd(ZstdDecoder<'static>), + } + impl Decoder { + // Chunk size is chosen based on lzbench benchmarks: + // https://github.com/inikep/lzbench?tab=readme-ov-file#benchmarks + // The target is to keep runtime under 50 ms. 
+ fn get_chunk_size(&self) -> usize { + match self { + #[cfg(feature = "zlib-compression")] + Decoder::Zlib(_) => 16 * 1024, + #[cfg(feature = "zstd-compression")] + Decoder::Zstd(_) => 32 * 1024, + } + } + + fn new(compression_config: &CompressionConfig) -> Result { + let decoder = match compression_config.0 { + #[cfg(feature = "zlib-compression")] + Compression::Zlib => Decoder::Zlib(ZlibDecompressor::new(true)), + #[cfg(feature = "zstd-compression")] + Compression::Zstd => Decoder::Zstd(ZstdDecoder::new()?), + }; + Ok(decoder) + } + + fn decompress( + &mut self, + input: &[u8], + output: &mut BytesMut, + ) -> Result<(usize, usize), Error> { + let mut produced_bytes_total_count = 0; + let mut consumed_bytes_total_count = 0; + + match self { + #[cfg(feature = "zlib-compression")] + Decoder::Zlib(decoder) => { + let mut output_buffer = [0u8; 16 * 1024]; + while consumed_bytes_total_count < input.len() { + let consumed_bytes_count_before = decoder.total_in(); + let produced_bytes_count_before = decoder.total_out(); + + let status = decoder + .decompress( + &input[consumed_bytes_total_count..], + &mut output_buffer, + FlushDecompress::None, + ) + .map_err(|e| err_protocol!("Decompression error: {}", e))?; + + #[allow(clippy::cast_possible_truncation)] + let consumed_bytes_count = + (decoder.total_in() - consumed_bytes_count_before) as usize; + #[allow(clippy::cast_possible_truncation)] + let produced_bytes_count = + (decoder.total_out() - produced_bytes_count_before) as usize; + + if produced_bytes_count > 0 { + output.extend_from_slice(&output_buffer[..produced_bytes_count]); + } + + consumed_bytes_total_count += consumed_bytes_count; + produced_bytes_total_count += produced_bytes_count; + + match status { + // Not enough input data to continue decompression + Status::BufError => break, + Status::StreamEnd => { + if consumed_bytes_total_count < input.len() { + return Err(err_protocol!("Unexpected stream end")); + } else { + break; + } + } + Status::Ok => {} + } + } 
+ } + #[cfg(feature = "zstd-compression")] + Decoder::Zstd(decoder) => { + let mut input_chunk = input; + let mut output_buffer = [0u8; 16 * 1024]; + + while !input_chunk.is_empty() { + let mut in_buf = InBuffer::around(input_chunk); + let mut out_buf = OutBuffer::around(&mut output_buffer[..]); + + let result = decoder.run(&mut in_buf, &mut out_buf)?; + + let consumed_bytes_count = in_buf.pos(); + let produced_bytes_count = out_buf.pos(); + + input_chunk = &input_chunk[consumed_bytes_count..]; + + if produced_bytes_count > 0 { + output.extend_from_slice(&output_buffer[..produced_bytes_count]); + } + + consumed_bytes_total_count += consumed_bytes_count; + produced_bytes_total_count += produced_bytes_count; + + // No progress made; waiting for the next input chunk + if consumed_bytes_count == 0 && produced_bytes_count == 0 { + break; + } + + if result == 0 && !input_chunk.is_empty() { + return Err(err_protocol!("Unexpected stream end")); + } + } + } + }; + + Ok((consumed_bytes_total_count, produced_bytes_total_count)) + } + } +} diff --git a/sqlx-mysql/src/connection/establish.rs b/sqlx-mysql/src/connection/establish.rs index f61654d876..3fd6643a6b 100644 --- a/sqlx-mysql/src/connection/establish.rs +++ b/sqlx-mysql/src/connection/establish.rs @@ -1,6 +1,3 @@ -use bytes::buf::Buf; -use bytes::Bytes; - use crate::common::StatementCache; use crate::connection::{tls, MySqlConnectionInner, MySqlStream, MAX_PACKET_SIZE}; use crate::error::Error; @@ -10,6 +7,8 @@ use crate::protocol::connect::{ }; use crate::protocol::Capabilities; use crate::{MySqlConnectOptions, MySqlConnection, MySqlSslMode}; +use bytes::buf::Buf; +use bytes::Bytes; impl MySqlConnection { pub(crate) async fn establish(options: &MySqlConnectOptions) -> Result { @@ -105,14 +104,17 @@ impl<'a> DoHandshake<'a> { None }; - stream.write_packet(HandshakeResponse { - charset: super::INITIAL_CHARSET, - max_packet_size: MAX_PACKET_SIZE, - username: &options.username, - database: options.database.as_deref(), - 
auth_plugin: plugin, - auth_response: auth_response.as_deref(), - })?; + stream + .write_packet(HandshakeResponse { + charset: super::INITIAL_CHARSET, + max_packet_size: MAX_PACKET_SIZE, + username: &options.username, + database: options.database.as_deref(), + auth_plugin: plugin, + auth_response: auth_response.as_deref(), + compression_configs: options.get_compression(), + }) + .await?; stream.flush().await?; @@ -121,7 +123,7 @@ impl<'a> DoHandshake<'a> { match packet[0] { 0x00 => { let _ok = packet.ok()?; - + stream = stream.maybe_enable_compression(options); break; } @@ -141,7 +143,7 @@ impl<'a> DoHandshake<'a> { ) .await?; - stream.write_packet(AuthSwitchResponse(response))?; + stream.write_packet(AuthSwitchResponse(response)).await?; stream.flush().await?; } diff --git a/sqlx-mysql/src/connection/mod.rs b/sqlx-mysql/src/connection/mod.rs index 569ad32722..8d4a69db34 100644 --- a/sqlx-mysql/src/connection/mod.rs +++ b/sqlx-mysql/src/connection/mod.rs @@ -16,6 +16,7 @@ use crate::transaction::Transaction; use crate::{MySql, MySqlConnectOptions}; mod auth; +mod compression; mod establish; mod executor; mod stream; diff --git a/sqlx-mysql/src/connection/stream.rs b/sqlx-mysql/src/connection/stream.rs index ff931b2f46..d7cd0074f1 100644 --- a/sqlx-mysql/src/connection/stream.rs +++ b/sqlx-mysql/src/connection/stream.rs @@ -1,19 +1,21 @@ use std::collections::VecDeque; use std::ops::{Deref, DerefMut}; -use bytes::{Buf, Bytes, BytesMut}; - +use crate::connection::compression::CompressionMySqlStream; use crate::error::Error; use crate::io::MySqlBufExt; use crate::io::{ProtocolDecode, ProtocolEncode}; use crate::net::{BufferedSocket, Socket}; +#[cfg(any(feature = "zlib-compression", feature = "zstd-compression"))] +use crate::options::Compression; use crate::protocol::response::{EofPacket, ErrPacket, OkPacket, Status}; use crate::protocol::{Capabilities, Packet}; use crate::{MySqlConnectOptions, MySqlDatabaseError}; +use bytes::{Buf, Bytes, BytesMut}; pub struct 
MySqlStream> { // Wrapping the socket in `Box` allows us to unsize in-place. - pub(crate) socket: BufferedSocket, + pub(crate) compression_stream: CompressionMySqlStream, pub(crate) server_version: (u16, u16, u16), pub(super) capabilities: Capabilities, pub(crate) sequence_id: u8, @@ -49,19 +51,27 @@ impl MySqlStream { capabilities |= Capabilities::CONNECT_WITH_DB; } + #[cfg(any(feature = "zstd-compression", feature = "zlib-compression"))] + options.compression_configs.iter().for_each(|c| match c.0 { + #[cfg(feature = "zlib-compression")] + Compression::Zlib => capabilities |= Capabilities::COMPRESS, + #[cfg(feature = "zstd-compression")] + Compression::Zstd => capabilities |= Capabilities::ZSTD_COMPRESSION_ALGORITHM, + }); + Self { waiting: VecDeque::new(), capabilities, server_version: (0, 0, 0), sequence_id: 0, - socket: BufferedSocket::new(socket), + compression_stream: CompressionMySqlStream::not_compressed(BufferedSocket::new(socket)), is_tls: false, } } pub(crate) async fn wait_until_ready(&mut self) -> Result<(), Error> { - if !self.socket.write_buffer().is_empty() { - self.socket.flush().await?; + if !self.write_buffer().is_empty() { + self.flush().await?; } while !self.waiting.is_empty() { @@ -103,24 +113,33 @@ impl MySqlStream { T: ProtocolEncode<'en, Capabilities>, { self.sequence_id = 0; - self.write_packet(payload)?; + self.write_packet(payload).await?; self.flush().await?; Ok(()) } - pub(crate) fn write_packet<'en, T>(&mut self, payload: T) -> Result<(), Error> + pub(crate) async fn write_packet<'en, T>(&mut self, payload: T) -> Result<(), Error> where T: ProtocolEncode<'en, Capabilities>, { - self.socket + self.compression_stream .write_with(Packet(payload), (self.capabilities, &mut self.sequence_id)) + .await + } + + pub(crate) fn write_uncompressed_packet<'en, T>(&mut self, payload: T) -> Result<(), Error> + where + T: ProtocolEncode<'en, Capabilities>, + { + self.compression_stream + .uncompressed_write_with(Packet(payload), (self.capabilities, 
&mut self.sequence_id)) } async fn recv_packet_part(&mut self) -> Result { // https://dev.mysql.com/doc/dev/mysql-server/8.0.12/page_protocol_basic_packets.html // https://mariadb.com/kb/en/library/0-packet/#standard-packet - let mut header: Bytes = self.socket.read(4).await?; + let mut header: Bytes = self.compression_stream.read_with(4, ()).await?; // cannot overflow #[allow(clippy::cast_possible_truncation)] @@ -129,9 +148,7 @@ impl MySqlStream { self.sequence_id = sequence_id.wrapping_add(1); - let payload: Bytes = self.socket.read(packet_size).await?; - - // TODO: packet compression + let payload: Bytes = self.compression_stream.read_with(packet_size, ()).await?; Ok(payload) } @@ -207,7 +224,22 @@ impl MySqlStream { pub fn boxed_socket(self) -> MySqlStream { MySqlStream { - socket: self.socket.boxed(), + compression_stream: self.compression_stream.boxed(), + server_version: self.server_version, + capabilities: self.capabilities, + sequence_id: self.sequence_id, + waiting: self.waiting, + is_tls: self.is_tls, + } + } + + pub fn maybe_enable_compression(self, options: &MySqlConnectOptions) -> Self { + MySqlStream { + compression_stream: CompressionMySqlStream::create( + self.compression_stream.socket, + &self.capabilities, + options.get_compression(), + ), server_version: self.server_version, capabilities: self.capabilities, sequence_id: self.sequence_id, @@ -221,12 +253,12 @@ impl Deref for MySqlStream { type Target = BufferedSocket; fn deref(&self) -> &Self::Target { - &self.socket + &self.compression_stream.socket } } impl DerefMut for MySqlStream { fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.socket + &mut self.compression_stream.socket } } diff --git a/sqlx-mysql/src/connection/tls.rs b/sqlx-mysql/src/connection/tls.rs index 9034fbd63a..582d7676bf 100644 --- a/sqlx-mysql/src/connection/tls.rs +++ b/sqlx-mysql/src/connection/tls.rs @@ -1,3 +1,4 @@ +use crate::connection::compression::CompressionMySqlStream; use crate::connection::{MySqlStream, 
Waiting}; use crate::error::Error; use crate::net::tls::TlsConfig; @@ -66,15 +67,17 @@ pub(super) async fn maybe_upgrade( }; // Request TLS upgrade - stream.write_packet(SslRequest { - max_packet_size: super::MAX_PACKET_SIZE, - charset: super::INITIAL_CHARSET, - })?; + stream + .write_packet(SslRequest { + max_packet_size: super::MAX_PACKET_SIZE, + charset: super::INITIAL_CHARSET, + }) + .await?; stream.flush().await?; tls::handshake( - stream.socket.into_inner(), + stream.compression_stream.socket.into_inner(), tls_config, MapStream { server_version: stream.server_version, @@ -91,7 +94,9 @@ impl WithSocket for MapStream { async fn with_socket(self, socket: S) -> Self::Output { MySqlStream { - socket: BufferedSocket::new(Box::new(socket)), + compression_stream: CompressionMySqlStream::not_compressed(BufferedSocket::new( + Box::new(socket), + )), server_version: self.server_version, capabilities: self.capabilities, sequence_id: self.sequence_id, diff --git a/sqlx-mysql/src/lib.rs b/sqlx-mysql/src/lib.rs index 7aa14256f3..da4b7ae715 100644 --- a/sqlx-mysql/src/lib.rs +++ b/sqlx-mysql/src/lib.rs @@ -42,7 +42,7 @@ pub use column::MySqlColumn; pub use connection::MySqlConnection; pub use database::MySql; pub use error::MySqlDatabaseError; -pub use options::{MySqlConnectOptions, MySqlSslMode}; +pub use options::{Compression, CompressionConfig, MySqlConnectOptions, MySqlSslMode}; pub use query_result::MySqlQueryResult; pub use row::MySqlRow; pub use statement::MySqlStatement; diff --git a/sqlx-mysql/src/options/mod.rs b/sqlx-mysql/src/options/mod.rs index 421bfb700e..83f655c942 100644 --- a/sqlx-mysql/src/options/mod.rs +++ b/sqlx-mysql/src/options/mod.rs @@ -1,3 +1,5 @@ +#[cfg(any(feature = "zlib-compression", feature = "zstd-compression"))] +use sqlx_core::Error; use std::path::{Path, PathBuf}; mod connect; @@ -80,6 +82,119 @@ pub struct MySqlConnectOptions { pub(crate) no_engine_substitution: bool, pub(crate) timezone: Option, pub(crate) set_names: bool, + pub(crate) 
compression_configs: Vec, +} + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub struct CompressionConfig( + pub(crate) Compression, + #[cfg_attr( + not(all(feature = "zlib-compression", feature = "zstd-compression")), + allow(dead_code) + )] + pub(crate) u8, +); + +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub enum Compression { + #[cfg(feature = "zlib-compression")] + Zlib, + #[cfg(feature = "zstd-compression")] + Zstd, +} + +#[cfg(any(feature = "zlib-compression", feature = "zstd-compression"))] +impl Compression { + /// Selects a default compression level optimized for both encoding speed and output size. + pub fn default(self) -> CompressionConfig { + match self { + #[cfg(feature = "zlib-compression")] + Compression::Zlib => CompressionConfig(self, 6), + #[cfg(feature = "zstd-compression")] + Compression::Zstd => CompressionConfig(self, 3), + } + } + + /// Optimize for the best speed of encoding. + pub fn fast(self) -> CompressionConfig { + CompressionConfig(self, 1) + } + + /// Optimize for maximum compression ratio. + /// + /// This mode favors smaller output size at the cost of significantly slower + /// compression speed. At high levels, compression itself may become the main + /// bottleneck rather than I/O or network transfer. + /// + /// Recommended only for offline or non-latency-sensitive workloads. + pub fn best(self) -> CompressionConfig { + match self { + #[cfg(feature = "zlib-compression")] + Compression::Zlib => CompressionConfig(self, 9), + #[cfg(feature = "zstd-compression")] + Compression::Zstd => CompressionConfig(self, 22), + } + } + + /// Sets the compression level for the current algorithm. + /// + /// Each compression method supports its own valid range of levels: + /// + /// - **Zstd:** `1` to `22` + /// - **Zlib:** `1` to `9` + /// + /// For **Zstd**, the configured level is applied on the server side. + /// + /// For **Zlib**, this setting affects only outgoing packets. 
Incoming data is + /// always decompressed using the server-defined compression level, which is + /// fixed at `6` and cannot be changed. + /// + /// If the provided level is valid for the selected algorithm, a new + /// [`CompressionConfig`] is returned. + /// If the level is out of range, an [`Error::Configuration`] is returned. + /// + /// # Returns + /// + /// - `Ok(CompressionConfig)` if the level is valid + /// - `Err(Error)` if the level is invalid + /// + /// # Examples + /// + /// ```rust + /// # use sqlx_mysql::Compression; + /// # #[cfg(feature = "zstd-compression")] + /// # { + /// let good = Compression::Zstd.level(5); + /// assert!(good.is_ok()); + /// # } + /// # #[cfg(feature = "zlib-compression")] + /// # { + /// let bad = Compression::Zlib.level(42); + /// assert!(bad.is_err()); + /// # } + /// ``` + pub fn level(self, value: u8) -> Result { + let range = match self { + #[cfg(feature = "zstd-compression")] + Compression::Zstd => 1..=22, + #[cfg(feature = "zlib-compression")] + Compression::Zlib => 1..=9, + }; + + range + .contains(&value) + .then_some(CompressionConfig(self, value)) + .ok_or_else(|| { + Error::Configuration( + format!( + "Illegal compression level for {self:?}: expected {}..={}, got {value}", + range.start(), + range.end() + ) + .into(), + ) + }) + } } impl Default for MySqlConnectOptions { @@ -111,6 +226,7 @@ impl MySqlConnectOptions { no_engine_substitution: true, timezone: Some(String::from("+00:00")), set_names: true, + compression_configs: vec![], } } @@ -414,6 +530,32 @@ impl MySqlConnectOptions { self.set_names = flag_val; self } + + /// Sets the compression configuration for the connection. + /// + /// Compression is disabled by default. + /// + /// The client will negotiate compression with the server using the provided + /// configurations, in the given order. The first compression algorithm + /// supported by the server will be selected. 
If none of the specified + /// algorithms are supported, the connection falls back to uncompressed mode. + /// + /// # Example + /// + /// ```rust + /// # use sqlx_mysql::{MySqlConnectOptions, Compression}; + /// let options = MySqlConnectOptions::new() + /// .compression(vec![ + ///# #[cfg(feature = "zlib-compression")] + /// Compression::Zlib.fast(), + ///# #[cfg(feature = "zstd-compression")] + /// Compression::Zstd.default(), + /// ]); + /// ``` + pub fn compression(mut self, compression: Vec) -> Self { + self.compression_configs = compression; + self + } } impl MySqlConnectOptions { @@ -526,4 +668,22 @@ impl MySqlConnectOptions { pub fn get_collation(&self) -> Option<&str> { self.collation.as_deref() } + + /// Get compression + /// + /// # Example + /// + /// ```rust + /// # #[cfg(feature = "zlib-compression")] + /// # { + /// # use sqlx_mysql::{Compression, CompressionConfig, MySqlConnectOptions}; + /// let options = MySqlConnectOptions::new() + /// .compression(vec![Compression::Zlib.fast()]); + /// + /// assert_eq!(options.get_compression(), &[Compression::Zlib.fast()]); + /// # } + /// ``` + pub fn get_compression(&self) -> &[CompressionConfig] { + &self.compression_configs + } } diff --git a/sqlx-mysql/src/options/parse.rs b/sqlx-mysql/src/options/parse.rs index e31ddc46d4..44f8e27f22 100644 --- a/sqlx-mysql/src/options/parse.rs +++ b/sqlx-mysql/src/options/parse.rs @@ -1,11 +1,11 @@ -use std::str::FromStr; - +use super::MySqlConnectOptions; +use crate::error::Error; +use crate::MySqlSslMode; +#[cfg(any(feature = "zlib-compression", feature = "zstd-compression"))] +use crate::{Compression, CompressionConfig}; use percent_encoding::{percent_decode_str, utf8_percent_encode, NON_ALPHANUMERIC}; use sqlx_core::Url; - -use crate::{error::Error, MySqlSslMode}; - -use super::MySqlConnectOptions; +use std::str::FromStr; impl MySqlConnectOptions { pub(crate) fn parse_from_url(url: &Url) -> Result { @@ -80,6 +80,36 @@ impl MySqlConnectOptions { options = 
options.timezone(Some(value.to_string())); } + #[cfg(any(feature = "zlib-compression", feature = "zstd-compression"))] + "compression" => { + let mut configs: Vec = vec![]; + for c in value.split(",") { + let (algorithm, level) = c + .split_once(":") + .ok_or_else(|| { + Error::Configuration( + format!( + "Invalid compression parameter. Expected algorithm:level, but got '{}'", + value + ).into(), + ) + })?; + let compression = match algorithm { + #[cfg(feature = "zlib-compression")] + "zlib" => Ok(Compression::Zlib), + #[cfg(feature = "zstd-compression")] + "zstd" => Ok(Compression::Zstd), + _ => Err(Error::Configuration( + format!("Unknown compression algorithm: {}", algorithm).into(), + )), + }?; + let compression_config = + compression.level(level.parse().map_err(Error::config)?)?; + configs.push(compression_config); + } + options = options.compression(configs); + } + _ => {} } } @@ -143,6 +173,23 @@ impl MySqlConnectOptions { .append_pair("socket", &socket.to_string_lossy()); } + #[cfg(any(feature = "zlib-compression", feature = "zstd-compression"))] + if !&self.compression_configs.is_empty() { + let values = self + .compression_configs + .iter() + .map(|c| match c { + #[cfg(feature = "zstd-compression")] + CompressionConfig(Compression::Zstd, level) => format!("zstd:{}", level), + #[cfg(feature = "zlib-compression")] + CompressionConfig(Compression::Zlib, level) => format!("zlib:{}", level), + }) + .collect::>() + .join(","); + + url.query_pairs_mut().append_pair("compression", &values); + } + url } } @@ -185,6 +232,44 @@ fn it_returns_the_parsed_url() { assert_eq!(expected_url, opts.build_url()); } +#[test] +#[cfg(feature = "zstd-compression")] +fn it_returns_the_build_url_with_zstd_compression_param() { + let url = "mysql://username:p@ssw0rd@hostname:3306/database"; + let opts = MySqlConnectOptions::from_str(url) + .unwrap() + .compression(vec![Compression::Zstd.fast()]); + + let mut expected_url = Url::parse(url).unwrap(); + let mut query_string = 
String::new(); + // MySqlConnectOptions defaults + query_string += "ssl-mode=PREFERRED&charset=utf8mb4&statement-cache-capacity=100"; + query_string += "&compression=zstd%3A1"; + + expected_url.set_query(Some(&query_string)); + + assert_eq!(expected_url, opts.build_url()); +} + +#[test] +#[cfg(feature = "zlib-compression")] +fn it_returns_the_build_url_with_compression_params() { + let url = "mysql://username:p@ssw0rd@hostname:3306/database"; + let opts = MySqlConnectOptions::from_str(url) + .unwrap() + .compression(vec![Compression::Zlib.best()]); + + let mut expected_url = Url::parse(url).unwrap(); + let mut query_string = String::new(); + // MySqlConnectOptions defaults + query_string += "ssl-mode=PREFERRED&charset=utf8mb4&statement-cache-capacity=100"; + query_string += "&compression=zlib%3A9"; + + expected_url.set_query(Some(&query_string)); + + assert_eq!(expected_url, opts.build_url()); +} + #[test] fn it_parses_timezone() { let opts: MySqlConnectOptions = "mysql://user:password@hostname/database?timezone=%2B08:00" @@ -197,3 +282,46 @@ fn it_parses_timezone() { .unwrap(); assert_eq!(opts.timezone.as_deref(), Some("+08:00")); } + +#[test] +#[cfg(feature = "zstd-compression")] +fn it_parses_compression() { + let opts: MySqlConnectOptions = "mysql://user:password@hostname/database?compression=zstd:10" + .parse() + .unwrap(); + + assert_eq!( + opts.get_compression(), + &[Compression::Zstd.level(10).unwrap()] + ); +} + +#[test] +#[cfg(feature = "zlib-compression")] +fn it_parses_zlib_compression() { + let opts: MySqlConnectOptions = "mysql://user:password@hostname/database?compression=zlib:2" + .parse() + .unwrap(); + + assert_eq!( + opts.get_compression(), + &[Compression::Zlib.level(2).unwrap()] + ); +} + +#[test] +#[cfg(all(feature = "zlib-compression", feature = "zstd-compression"))] +fn it_parses_list_of_compression_algorithms() { + let opts: MySqlConnectOptions = + "mysql://user:password@hostname/database?compression=zlib:1,zstd:2" + .parse() + .unwrap(); + 
+ assert_eq!( + opts.get_compression(), + &[ + Compression::Zlib.level(1).unwrap(), + Compression::Zstd.level(2).unwrap() + ] + ); +} diff --git a/sqlx-mysql/src/protocol/connect/handshake_response.rs b/sqlx-mysql/src/protocol/connect/handshake_response.rs index 6911419d98..a1999ab852 100644 --- a/sqlx-mysql/src/protocol/connect/handshake_response.rs +++ b/sqlx-mysql/src/protocol/connect/handshake_response.rs @@ -1,9 +1,11 @@ use crate::io::MySqlBufMutExt; use crate::io::{BufMutExt, ProtocolEncode}; +#[cfg(feature = "zstd-compression")] +use crate::options::Compression; use crate::protocol::auth::AuthPlugin; use crate::protocol::connect::ssl_request::SslRequest; use crate::protocol::Capabilities; - +use crate::CompressionConfig; // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse // https://mariadb.com/kb/en/connection/#client-handshake-response @@ -25,6 +27,10 @@ pub struct HandshakeResponse<'a> { /// Opaque authentication response pub auth_response: Option<&'a [u8]>, + + /// compression configurations + #[cfg_attr(not(feature = "zstd-compression"), allow(dead_code))] + pub compression_configs: &'a [CompressionConfig], } impl ProtocolEncode<'_, Capabilities> for HandshakeResponse<'_> { @@ -77,6 +83,18 @@ impl ProtocolEncode<'_, Capabilities> for HandshakeResponse<'_> { } } + #[cfg(feature = "zstd-compression")] + if context.contains(Capabilities::ZSTD_COMPRESSION_ALGORITHM) { + let compression_config = self + .compression_configs + .iter() + .find(|c| c.0 == Compression::Zstd); + + if let Some(CompressionConfig(Compression::Zstd, level)) = compression_config { + buf.push(*level) + } + } + Ok(()) } } diff --git a/sqlx-mysql/src/transaction.rs b/sqlx-mysql/src/transaction.rs index 18db30b183..37c1e7ef42 100644 --- a/sqlx-mysql/src/transaction.rs +++ b/sqlx-mysql/src/transaction.rs @@ -63,7 +63,7 @@ impl TransactionManager for MySqlTransactionManager { conn.inner.stream.sequence_id = 0; conn.inner .stream - 
.write_packet(Query(rollback_ansi_transaction_sql(depth).as_str())) + .write_uncompressed_packet(Query(rollback_ansi_transaction_sql(depth).as_str())) .expect("BUG: unexpected error queueing ROLLBACK"); conn.inner.transaction_depth = depth - 1; diff --git a/tests/mysql/compression.rs b/tests/mysql/compression.rs new file mode 100644 index 0000000000..f8a218aaeb --- /dev/null +++ b/tests/mysql/compression.rs @@ -0,0 +1,25 @@ +#[cfg(any(feature = "mysql-zstd-compression", feature = "mysql-zlib-compression"))] +mod compression_tests { + use sqlx::Row; + use sqlx_mysql::MySql; + use sqlx_test::new; + + #[sqlx_macros::test] + async fn it_connects_with_compression() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let rows = sqlx::raw_sql(r#"SHOW SESSION STATUS LIKE 'Compression'"#) + .fetch_all(&mut conn) + .await?; + + let result = rows + .first() + .map(|r| r.try_get::(1).unwrap_or_default()) + .unwrap_or_default(); + + assert!(!rows.is_empty()); + assert_eq!(result, "ON"); + + Ok(()) + } +} diff --git a/tests/mysql/rustsec.rs b/tests/mysql/rustsec.rs index 8d8db0c250..41ad56753c 100644 --- a/tests/mysql/rustsec.rs +++ b/tests/mysql/rustsec.rs @@ -1,4 +1,5 @@ use sqlx::{Error, MySql}; +use sqlx_mysql::MySqlDatabaseError; use std::io; use sqlx_test::new; @@ -29,8 +30,8 @@ async fn rustsec_2024_0363() -> anyhow::Result<()> { "CREATE TEMPORARY TABLE injection_target(id INTEGER PRIMARY KEY AUTO_INCREMENT, message TEXT);\n\ INSERT INTO injection_target(message) VALUES ('existing message');", ) - .execute(&mut conn) - .await?; + .execute(&mut conn) + .await?; // We can't concatenate a query string together like the other tests // because it would just demonstrate a regular old SQL injection. @@ -42,16 +43,22 @@ async fn rustsec_2024_0363() -> anyhow::Result<()> { if let Err(e) = res { // Connection rejected the query; we're happy. 
// - // Current observed behavior is that `mysqld` closes the connection before we're even done - // sending the message, giving us a "Broken pipe" error. + // If a packet exceeds `max_allowed_packet`, MySQL returns ER_NET_PACKET_TOO_LARGE + // and closes the connection. Depending on timing, the client may instead observe + // "Lost connection to MySQL server during query" or a local "Broken pipe" error. // - // As it turns out, MySQL has a tight limit on packet sizes (even after splitting) - // by default: https://dev.mysql.com/doc/refman/8.4/en/packet-too-large.html - if matches!(e, Error::Io(ref ioe) if ioe.kind() == io::ErrorKind::BrokenPipe) { - return Ok(()); + // See: https://dev.mysql.com/doc/refman/8.4/en/packet-too-large.html + match e { + Error::Database(ref dbe) => { + let err_net_packet_too_large = 1153; + return match dbe.try_downcast_ref::() { + Some(error) if error.number() == err_net_packet_too_large => Ok(()), + _ => panic!("unexpected error: {e:?}"), + }; + } + Error::Io(ref ioe) if ioe.kind() == io::ErrorKind::BrokenPipe => return Ok(()), + _ => panic!("unexpected error: {e:?}"), } - - panic!("unexpected error: {e:?}"); } let messages: Vec = diff --git a/tests/x.py b/tests/x.py index e1308f2fa4..34e06c474d 100755 --- a/tests/x.py +++ b/tests/x.py @@ -81,13 +81,13 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, data environ["RUSTFLAGS"] = "--cfg sqlite_ipaddr" if platform.system() == "Linux": if os.environ.get("LD_LIBRARY_PATH"): - environ["LD_LIBRARY_PATH"]= os.environ.get("LD_LIBRARY_PATH") + ":"+ os.getcwd() + environ["LD_LIBRARY_PATH"] = os.environ.get("LD_LIBRARY_PATH") + ":" + os.getcwd() else: - environ["LD_LIBRARY_PATH"]=os.getcwd() - + environ["LD_LIBRARY_PATH"] = os.getcwd() if service is not None: - database_url = start_database(service, database="sqlite/sqlite.db" if service == "sqlite" else "sqlx", cwd=dir_tests) + database_url = start_database(service, database="sqlite/sqlite.db" if service == "sqlite" 
else "sqlx", + cwd=dir_tests) if database_url_args: database_url += "?" + database_url_args @@ -209,16 +209,30 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, data for version in ["8", "5_7"]: # Since docker mysql 5.7 using yaSSL(It only supports TLSv1.1), avoid running when using rustls. # https://github.com/docker-library/mysql/issues/567 - if not(version == "5_7" and tls == "rustls"): + if not (version == "5_7" and tls == "rustls"): run( f"cargo test --no-default-features --features any,mysql,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", comment=f"test mysql {version}", service=f"mysql_{version}", tag=f"mysql_{version}" if runtime == "async-std" else f"mysql_{version}_{runtime}", ) + run( + f"cargo test --no-default-features --features any,mysql,mysql-zlib-compression,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", + comment=f"test mysql {version} zlib-compression", + database_url_args="compression=zlib:1", + service=f"mysql_{version}", + tag=f"mysql_{version}" if runtime == "async-std" else f"mysql_{version}_{runtime}", + ) + run( + f"cargo test --no-default-features --features any,mysql,mysql-zstd-compression,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", + comment=f"test mysql {version} zstd-compression", + database_url_args="compression=zstd:1", + service=f"mysql_{version}", + tag=f"mysql_{version}" if runtime == "async-std" else f"mysql_{version}_{runtime}", + ) ## +client-ssl - if tls != "none" and not(version == "5_7" and tls == "rustls"): + if tls != "none" and not (version == "5_7" and tls == "rustls"): run( f"cargo test --no-default-features --features any,mysql,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", comment=f"test mysql {version}_client_ssl no-password", @@ -226,6 +240,20 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, data service=f"mysql_{version}_client_ssl", tag=f"mysql_{version}_client_ssl_no_password" if runtime == "async-std" else 
f"mysql_{version}_client_ssl_no_password_{runtime}", ) + run( + f"cargo test --no-default-features --features any,mysql,mysql-zlib-compression,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", + comment=f"test mysql {version}_client_ssl no-password zlib-compression", + database_url_args="sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt&compression=zlib:1", + service=f"mysql_{version}_client_ssl", + tag=f"mysql_{version}_client_ssl_no_password" if runtime == "async-std" else f"mysql_{version}_client_ssl_no_password_{runtime}", + ) + run( + f"cargo test --no-default-features --features any,mysql,mysql-zstd-compression,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", + comment=f"test mysql {version}_client_ssl no-password zstd-compression", + database_url_args="sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt&compression=zstd:1", + service=f"mysql_{version}_client_ssl", + tag=f"mysql_{version}_client_ssl_no_password" if runtime == "async-std" else f"mysql_{version}_client_ssl_no_password_{runtime}", + ) # # mariadb @@ -238,6 +266,13 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, data service=f"mariadb_{version}", tag=f"mariadb_{version}" if runtime == "async-std" else f"mariadb_{version}_{runtime}", ) + run( + f"cargo test --no-default-features --features any,mysql,mysql-zlib-compression,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", + comment=f"test mariadb {version} zlib-compression", + database_url_args="compression=zlib:1", + service=f"mariadb_{version}", + tag=f"mariadb_{version}" if runtime == "async-std" else f"mariadb_{version}_{runtime}", + ) ## +client-ssl if tls != "none": @@ -248,6 +283,13 @@ def run(command, comment=None, env=None, service=None, tag=None, args=None, data service=f"mariadb_{version}_client_ssl", 
tag=f"mariadb_{version}_client_ssl_no_password" if runtime == "async-std" else f"mariadb_{version}_client_ssl_no_password_{runtime}", ) + run( + f"cargo test --no-default-features --features any,mysql,mysql-zlib-compression,macros,_unstable-all-types,runtime-{runtime},tls-{tls}", + comment=f"test mariadb {version}_client_ssl no-password zlib-compression", + database_url_args="sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt&compression=zlib:1", + service=f"mariadb_{version}_client_ssl", + tag=f"mariadb_{version}_client_ssl_no_password" if runtime == "async-std" else f"mariadb_{version}_client_ssl_no_password_{runtime}", + ) # TODO: Use [grcov] if available # ~/.cargo/bin/grcov tests/.cache/target/debug -s sqlx-core/ -t html --llvm --branch -o ./target/debug/coverage