From 796be47987202681516b0b8b378c5faf63b9c3fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Sat, 14 Feb 2026 22:42:02 -0500 Subject: [PATCH 01/33] feat: add MSSQL (SQL Server) support via tiberius driver Introduces a new `sqlx-mssql` crate that wraps the tiberius TDS protocol driver (v0.12) behind sqlx's trait system, enabling SQL Server connectivity with the same ergonomic API as the existing MySQL, PostgreSQL, and SQLite drivers. Key implementation details: - Full Database/Connection/Executor/Row/Value trait implementations - Connection via tiberius::Client with runtime-agnostic socket bridging - Eager result collection (tiberius QueryStream borrows &mut Client) - Transaction support with MSSQL-specific savepoint syntax - Type mappings: bool, u8, i8, i16, i32, i64, f32, f64, String, Vec - Any driver integration and migration support - Test infrastructure with TestSupport for #[sqlx::test] - sp_describe_first_result_set for statement metadata and nullability - Fixes to existing MSSQL test files (imports, placeholders, lifetimes) Author: Pablo Carrera --- Cargo.lock | 122 +++++++++ Cargo.toml | 48 +++- sqlx-macros-core/Cargo.toml | 14 +- sqlx-macros/Cargo.toml | 1 + sqlx-mssql/Cargo.toml | 66 +++++ sqlx-mssql/src/any.rs | 220 ++++++++++++++++ sqlx-mssql/src/arguments.rs | 47 ++++ sqlx-mssql/src/column.rs | 32 +++ sqlx-mssql/src/connection/establish.rs | 44 ++++ sqlx-mssql/src/connection/executor.rs | 340 +++++++++++++++++++++++++ sqlx-mssql/src/connection/mod.rs | 96 +++++++ sqlx-mssql/src/database.rs | 56 ++++ sqlx-mssql/src/error.rs | 127 +++++++++ sqlx-mssql/src/io.rs | 72 ++++++ sqlx-mssql/src/lib.rs | 71 ++++++ sqlx-mssql/src/migrate.rs | 278 ++++++++++++++++++++ sqlx-mssql/src/options/connect.rs | 35 +++ sqlx-mssql/src/options/mod.rs | 201 +++++++++++++++ sqlx-mssql/src/options/parse.rs | 126 +++++++++ sqlx-mssql/src/query_result.rs | 30 +++ sqlx-mssql/src/row.rs | 54 ++++ sqlx-mssql/src/statement.rs | 57 +++++ sqlx-mssql/src/testing/mod.rs | 194 ++++++++++++++ sqlx-mssql/src/transaction.rs | 126 +++++++++ sqlx-mssql/src/type_checking.rs | 35 +++ sqlx-mssql/src/type_info.rs | 68 +++++ sqlx-mssql/src/types/bool.rs | 41 +++ sqlx-mssql/src/types/bytes.rs | 73 ++++++ sqlx-mssql/src/types/float.rs | 74 ++++++ sqlx-mssql/src/types/int.rs | 188 ++++++++++++++ sqlx-mssql/src/types/mod.rs | 29 +++ sqlx-mssql/src/types/str.rs | 73 ++++++ sqlx-mssql/src/value.rs | 109 ++++++++ src/any/mod.rs | 2 + src/lib.rs | 7 + tests/mssql/describe.rs | 6 +- tests/mssql/mssql.rs | 32 +-- 37 files changed, 3159 insertions(+), 35 deletions(-) create mode 100644 sqlx-mssql/Cargo.toml create mode 100644 sqlx-mssql/src/any.rs create mode 100644 sqlx-mssql/src/arguments.rs create mode 100644 sqlx-mssql/src/column.rs create mode 100644 sqlx-mssql/src/connection/establish.rs create mode 100644 sqlx-mssql/src/connection/executor.rs create mode 100644 sqlx-mssql/src/connection/mod.rs create mode 100644 sqlx-mssql/src/database.rs create mode 100644 sqlx-mssql/src/error.rs create mode 100644 sqlx-mssql/src/io.rs create mode 100644 sqlx-mssql/src/lib.rs create mode 100644 sqlx-mssql/src/migrate.rs create mode 100644 sqlx-mssql/src/options/connect.rs create mode 100644 sqlx-mssql/src/options/mod.rs create mode 100644 sqlx-mssql/src/options/parse.rs create mode 100644 sqlx-mssql/src/query_result.rs create mode 100644 sqlx-mssql/src/row.rs create mode 100644 sqlx-mssql/src/statement.rs create mode 100644 sqlx-mssql/src/testing/mod.rs create mode 100644 sqlx-mssql/src/transaction.rs create mode 100644 sqlx-mssql/src/type_checking.rs create mode 100644 sqlx-mssql/src/type_info.rs create mode 100644 sqlx-mssql/src/types/bool.rs create mode 100644 sqlx-mssql/src/types/bytes.rs create mode 100644 sqlx-mssql/src/types/float.rs create mode 100644 sqlx-mssql/src/types/int.rs create mode 100644 sqlx-mssql/src/types/mod.rs create mode 100644 sqlx-mssql/src/types/str.rs create mode 100644 sqlx-mssql/src/value.rs diff --git a/Cargo.lock b/Cargo.lock index fe01e11720..2a4075bff7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -373,6 +373,19 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "asynchronous-codec" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4057f2c32adbb2fc158e22fb38433c8e9bbf76b75a4732c7c0cbaf695fb65568" +dependencies = [ + "bytes", + "futures-sink", + "futures-util", + "memchr", + "pin-project-lite", +] + [[package]] name = "atoi" version = "2.0.0" @@ -950,6 +963,12 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "connection-string" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "510ca239cf13b7f8d16a2b48f263de7b4f8c566f0af58d901031473c76afb1e3" + [[package]] name = "console" version = "0.15.11" @@ -1261,6 +1280,35 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" +[[package]] +name = "encoding_rs" +version = "0.8.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "enumflags2" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1027f7680c853e056ebcec683615fb6fbbc07dbaa13b4d5d9442b146ded4ecef" +dependencies = [ + "enumflags2_derive", +] + +[[package]] +name = "enumflags2_derive" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67c78a4d8fdf9953a5c9d458f9efe940fd97a0cab0941c075a813ac594733827" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.104", +] + [[package]] name = "env_filter" version = "0.1.3" @@ -2788,6 +2836,12 @@ dependencies = [ "termtree", ] +[[package]] +name = "pretty-hex" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6fa0831dd7cc608c38a5e323422a0077678fa5744aa2be4ad91c4ece8eec8d5" + [[package]] name = "prettyplease" version = "0.2.35" @@ -3519,6 +3573,7 @@ dependencies = [ "serde_json", "sqlx-core", "sqlx-macros", + "sqlx-mssql", "sqlx-mysql", "sqlx-postgres", "sqlx-sqlite", @@ -3881,6 +3936,7 @@ dependencies = [ "sha2", "smol", "sqlx-core", + "sqlx-mssql", "sqlx-mysql", "sqlx-postgres", "sqlx-sqlite", @@ -3890,6 +3946,35 @@ dependencies = [ "url", ] +[[package]] +name = "sqlx-mssql" +version = "0.9.0-alpha.1" +dependencies = [ + "async-std", + "atoi", + "bigdecimal", + "bytes", + "chrono", + "dotenvy", + "either", + "futures-core", + "futures-io", + "futures-util", + "log", + "percent-encoding", + "rust_decimal", + "serde", + "sqlx", + "sqlx-core", + "thiserror 2.0.17", + "tiberius", + "time", + "tokio", + "tokio-util", + "tracing", + "uuid", +] + [[package]] name = "sqlx-mysql" version = "0.9.0-alpha.1" @@ -4221,6 +4306,29 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "tiberius" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1446cb4198848d1562301a3340424b4f425ef79f35ef9ee034769a9dd92c10d" +dependencies = [ + "async-trait", + "asynchronous-codec", + "byteorder", + "bytes", + "connection-string", + "encoding_rs", + "enumflags2", + "futures-util", + "num-traits", + "once_cell", + "pin-project-lite", + "pretty-hex", + "thiserror 1.0.69", + "tracing", + "uuid", +] + [[package]] name = "time" version = "0.3.41" @@ -4329,6 +4437,20 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-util" +version = "0.7.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" +dependencies = [ + "bytes", + "futures-core", + "futures-io", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "toml" version = "0.8.23" diff --git a/Cargo.toml b/Cargo.toml index c88ab231e2..19220e1316 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ members = [ "sqlx-test", "sqlx-cli", # "sqlx-bench", + "sqlx-mssql", "sqlx-mysql", "sqlx-postgres", "sqlx-sqlite", @@ -63,14 +64,14 @@ rustdoc-args = ["--cfg", "docsrs"] default = ["any", "macros", "migrate", "json"] derive = ["sqlx-macros/derive"] -macros = ["derive", "sqlx-macros/macros", "sqlx-core/offline", "sqlx-mysql?/offline", "sqlx-postgres?/offline", "sqlx-sqlite?/offline"] -migrate = ["sqlx-core/migrate", "sqlx-macros?/migrate", "sqlx-mysql?/migrate", "sqlx-postgres?/migrate", "sqlx-sqlite?/migrate"] +macros = ["derive", "sqlx-macros/macros", "sqlx-core/offline", "sqlx-mssql?/offline", "sqlx-mysql?/offline", "sqlx-postgres?/offline", "sqlx-sqlite?/offline"] +migrate = ["sqlx-core/migrate", "sqlx-macros?/migrate", "sqlx-mssql?/migrate", "sqlx-mysql?/migrate", "sqlx-postgres?/migrate", "sqlx-sqlite?/migrate"] # Enable parsing of `sqlx.toml` for configuring macros and migrations. sqlx-toml = ["sqlx-core/sqlx-toml", "sqlx-macros?/sqlx-toml", "sqlx-sqlite?/sqlx-toml"] # intended mainly for CI and docs -all-databases = ["mysql", "sqlite", "postgres", "any"] +all-databases = ["mssql", "mysql", "sqlite", "postgres", "any"] _unstable-all-types = [ "bigdecimal", "rust_decimal", @@ -117,7 +118,8 @@ _rt-tokio = [] _sqlite = [] # database -any = ["sqlx-core/any", "sqlx-mysql?/any", "sqlx-postgres?/any", "sqlx-sqlite?/any"] +any = ["sqlx-core/any", "sqlx-mssql?/any", "sqlx-mysql?/any", "sqlx-postgres?/any", "sqlx-sqlite?/any"] +mssql = ["sqlx-mssql", "sqlx-macros?/mssql"] postgres = ["sqlx-postgres", "sqlx-macros?/postgres"] mysql = ["sqlx-mysql", "sqlx-macros?/mysql"] sqlite = ["sqlite-bundled", "sqlite-deserialize", "sqlite-load-extension", "sqlite-unlock-notify"] @@ -147,17 +149,17 @@ sqlite-preupdate-hook = ["sqlx-sqlite/preupdate-hook"] sqlite-unlock-notify = ["sqlx-sqlite/unlock-notify"] # types -json = ["sqlx-core/json", "sqlx-macros?/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sqlite?/json"] +json = ["sqlx-core/json", "sqlx-macros?/json", "sqlx-mssql?/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sqlite?/json"] -bigdecimal = ["sqlx-core/bigdecimal", "sqlx-macros?/bigdecimal", "sqlx-mysql?/bigdecimal", "sqlx-postgres?/bigdecimal"] +bigdecimal = ["sqlx-core/bigdecimal", "sqlx-macros?/bigdecimal", "sqlx-mssql?/bigdecimal", "sqlx-mysql?/bigdecimal", "sqlx-postgres?/bigdecimal"] bit-vec = ["sqlx-core/bit-vec", "sqlx-macros?/bit-vec", "sqlx-postgres?/bit-vec"] -chrono = ["sqlx-core/chrono", "sqlx-macros?/chrono", "sqlx-mysql?/chrono", "sqlx-postgres?/chrono", "sqlx-sqlite?/chrono"] +chrono = ["sqlx-core/chrono", "sqlx-macros?/chrono", "sqlx-mssql?/chrono", "sqlx-mysql?/chrono", "sqlx-postgres?/chrono", "sqlx-sqlite?/chrono"] ipnet = ["sqlx-core/ipnet", "sqlx-macros?/ipnet", "sqlx-postgres?/ipnet"] ipnetwork = ["sqlx-core/ipnetwork", "sqlx-macros?/ipnetwork", "sqlx-postgres?/ipnetwork"] mac_address = ["sqlx-core/mac_address", "sqlx-macros?/mac_address", "sqlx-postgres?/mac_address"] -rust_decimal = ["sqlx-core/rust_decimal", "sqlx-macros?/rust_decimal", "sqlx-mysql?/rust_decimal", "sqlx-postgres?/rust_decimal"] -time = ["sqlx-core/time", "sqlx-macros?/time", "sqlx-mysql?/time", "sqlx-postgres?/time", "sqlx-sqlite?/time"] -uuid = ["sqlx-core/uuid", "sqlx-macros?/uuid", "sqlx-mysql?/uuid", "sqlx-postgres?/uuid", "sqlx-sqlite?/uuid"] +rust_decimal = ["sqlx-core/rust_decimal", "sqlx-macros?/rust_decimal", "sqlx-mssql?/rust_decimal", "sqlx-mysql?/rust_decimal", "sqlx-postgres?/rust_decimal"] +time = ["sqlx-core/time", "sqlx-macros?/time", "sqlx-mssql?/time", "sqlx-mysql?/time", "sqlx-postgres?/time", "sqlx-sqlite?/time"] +uuid = ["sqlx-core/uuid", "sqlx-macros?/uuid", "sqlx-mssql?/uuid", "sqlx-mysql?/uuid", "sqlx-postgres?/uuid", "sqlx-sqlite?/uuid"] regexp = ["sqlx-sqlite?/regexp"] bstr = ["sqlx-core/bstr"] @@ -168,6 +170,7 @@ sqlx-macros-core = { version = "=0.9.0-alpha.1", path = "sqlx-macros-core" } sqlx-macros = { version = "=0.9.0-alpha.1", path = "sqlx-macros" } # Driver crates +sqlx-mssql = { version = "=0.9.0-alpha.1", path = "sqlx-mssql" } sqlx-mysql = { version = "=0.9.0-alpha.1", path = "sqlx-mysql" } sqlx-postgres = { version = "=0.9.0-alpha.1", path = "sqlx-postgres" } sqlx-sqlite = { version = "=0.9.0-alpha.1", path = "sqlx-sqlite" } @@ -214,6 +217,7 @@ default-features = false sqlx-core = { workspace = true, features = ["migrate"] } sqlx-macros = { workspace = true, optional = true } +sqlx-mssql = { workspace = true, optional = true } sqlx-mysql = { workspace = true, optional = true } sqlx-postgres = { workspace = true, optional = true } sqlx-sqlite = { workspace = true, optional = true } @@ -454,3 +458,27 @@ required-features = ["postgres"] name = "postgres-rustsec" path = "tests/postgres/rustsec.rs" required-features = ["postgres", "macros", "migrate"] + +# +# MSSQL +# + +[[test]] +name = "mssql" +path = "tests/mssql/mssql.rs" +required-features = ["mssql"] + +[[test]] +name = "mssql-types" +path = "tests/mssql/types.rs" +required-features = ["mssql"] + +[[test]] +name = "mssql-describe" +path = "tests/mssql/describe.rs" +required-features = ["mssql"] + +[[test]] +name = "mssql-macros" +path = "tests/mssql/macros.rs" +required-features = ["mssql", "macros"] diff --git a/sqlx-macros-core/Cargo.toml b/sqlx-macros-core/Cargo.toml index 8702555086..bafaa3bf79 100644 --- a/sqlx-macros-core/Cargo.toml +++ b/sqlx-macros-core/Cargo.toml @@ -32,6 +32,7 @@ migrate = ["sqlx-core/migrate"] sqlx-toml = ["sqlx-core/sqlx-toml", "sqlx-sqlite?/sqlx-toml"] # database +mssql = ["sqlx-mssql"] mysql = ["sqlx-mysql"] postgres = ["sqlx-postgres"] sqlite = ["_sqlite", "sqlx-sqlite/bundled"] @@ -41,20 +42,21 @@ sqlite-unbundled = ["_sqlite", "sqlx-sqlite/unbundled"] sqlite-load-extension = ["sqlx-sqlite/load-extension"] # type integrations -json = ["sqlx-core/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sqlite?/json"] +json = ["sqlx-core/json", "sqlx-mssql?/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sqlite?/json"] -bigdecimal = ["sqlx-core/bigdecimal", "sqlx-mysql?/bigdecimal", "sqlx-postgres?/bigdecimal"] +bigdecimal = ["sqlx-core/bigdecimal", "sqlx-mssql?/bigdecimal", "sqlx-mysql?/bigdecimal", "sqlx-postgres?/bigdecimal"] bit-vec = ["sqlx-core/bit-vec", "sqlx-postgres?/bit-vec"] -chrono = ["sqlx-core/chrono", "sqlx-mysql?/chrono", "sqlx-postgres?/chrono", "sqlx-sqlite?/chrono"] +chrono = ["sqlx-core/chrono", "sqlx-mssql?/chrono", "sqlx-mysql?/chrono", "sqlx-postgres?/chrono", "sqlx-sqlite?/chrono"] ipnet = ["sqlx-core/ipnet", "sqlx-postgres?/ipnet"] ipnetwork = ["sqlx-core/ipnetwork", "sqlx-postgres?/ipnetwork"] mac_address = ["sqlx-core/mac_address", "sqlx-postgres?/mac_address"] -rust_decimal = ["sqlx-core/rust_decimal", "sqlx-mysql?/rust_decimal", "sqlx-postgres?/rust_decimal"] -time = ["sqlx-core/time", "sqlx-mysql?/time", "sqlx-postgres?/time", "sqlx-sqlite?/time"] -uuid = ["sqlx-core/uuid", "sqlx-mysql?/uuid", "sqlx-postgres?/uuid", "sqlx-sqlite?/uuid"] +rust_decimal = ["sqlx-core/rust_decimal", "sqlx-mssql?/rust_decimal", "sqlx-mysql?/rust_decimal", "sqlx-postgres?/rust_decimal"] +time = ["sqlx-core/time", "sqlx-mssql?/time", "sqlx-mysql?/time", "sqlx-postgres?/time", "sqlx-sqlite?/time"] +uuid = ["sqlx-core/uuid", "sqlx-mssql?/uuid", "sqlx-mysql?/uuid", "sqlx-postgres?/uuid", "sqlx-sqlite?/uuid"] [dependencies] sqlx-core = { workspace = true, features = ["offline"] } +sqlx-mssql = { workspace = true, features = ["offline", "migrate"], optional = true } sqlx-mysql = { workspace = true, features = ["offline", "migrate"], optional = true } sqlx-postgres = { workspace = true, features = ["offline", "migrate"], optional = true } sqlx-sqlite = { workspace = true, features = ["offline", "migrate"], optional = true } diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml index 95954d72ef..51b7be7ac5 100644 --- a/sqlx-macros/Cargo.toml +++ b/sqlx-macros/Cargo.toml @@ -33,6 +33,7 @@ migrate = ["sqlx-macros-core/migrate"] sqlx-toml = ["sqlx-macros-core/sqlx-toml"] # database +mssql = ["sqlx-macros-core/mssql"] mysql = ["sqlx-macros-core/mysql"] postgres = ["sqlx-macros-core/postgres"] sqlite = ["sqlx-macros-core/sqlite"] diff --git a/sqlx-mssql/Cargo.toml b/sqlx-mssql/Cargo.toml new file mode 100644 index 0000000000..61cb280d9f --- /dev/null +++ b/sqlx-mssql/Cargo.toml @@ -0,0 +1,66 @@ +[package] +name = "sqlx-mssql" +documentation = "https://docs.rs/sqlx" +description = "MSSQL driver implementation for SQLx. Not for direct use; see the `sqlx` crate for details." +version.workspace = true +license.workspace = true +edition.workspace = true +authors.workspace = true +repository.workspace = true +rust-version.workspace = true + +[features] +json = ["sqlx-core/json", "serde"] +any = ["sqlx-core/any"] +offline = ["sqlx-core/offline", "serde"] +migrate = ["sqlx-core/migrate"] + +# Type Integration features +bigdecimal = ["dep:bigdecimal", "sqlx-core/bigdecimal"] +chrono = ["dep:chrono", "sqlx-core/chrono"] +rust_decimal = ["dep:rust_decimal", "sqlx-core/rust_decimal"] +time = ["dep:time", "sqlx-core/time"] +uuid = ["dep:uuid", "sqlx-core/uuid"] + +[dependencies] +sqlx-core = { workspace = true } + +# TDS protocol driver +tiberius = { version = "0.12", default-features = false, features = ["tds73"] } + +# Futures crates +futures-core = { version = "0.3.19", default-features = false } +futures-io = "0.3.24" +futures-util = { version = "0.3.19", default-features = false, features = ["alloc", "sink", "io"] } + +# Runtime bridging +tokio = { workspace = true, optional = true } +tokio-util = { version = "0.7", features = ["compat"], optional = true } +async-std = { workspace = true, optional = true } + +# Type Integrations (versions inherited from `[workspace.dependencies]`) +bigdecimal = { workspace = true, optional = true } +chrono = { workspace = true, optional = true } +rust_decimal = { workspace = true, optional = true } +time = { workspace = true, optional = true } +uuid = { workspace = true, optional = true } + +# Misc +bytes = "1.1.0" +either = "1.6.1" +log = "0.4.18" +tracing = { version = "0.1.37", features = ["log"] } +percent-encoding = "2.1.0" +atoi = "2.0" + +dotenvy.workspace = true +thiserror.workspace = true + +serde = { version = "1.0.144", optional = true } + +[dev-dependencies] +# FIXME: https://github.com/rust-lang/cargo/issues/15622 +sqlx = { path = "..", default-features = false, features = ["mssql"] } + +[lints] +workspace = true diff --git a/sqlx-mssql/src/any.rs b/sqlx-mssql/src/any.rs new file mode 100644 index 0000000000..9a026fda73 --- /dev/null +++ b/sqlx-mssql/src/any.rs @@ -0,0 +1,220 @@ +use crate::{ + Mssql, MssqlColumn, MssqlConnectOptions, MssqlConnection, MssqlQueryResult, MssqlRow, + MssqlTransactionManager, MssqlTypeInfo, +}; +use either::Either; +use futures_core::future::BoxFuture; +use futures_core::stream::BoxStream; +use futures_util::{stream, FutureExt, StreamExt, TryStreamExt}; +use sqlx_core::any::{ + AnyArguments, AnyColumn, AnyConnectOptions, AnyConnectionBackend, AnyQueryResult, AnyRow, + AnyStatement, AnyTypeInfo, AnyTypeInfoKind, +}; +use sqlx_core::connection::Connection; +use sqlx_core::database::Database; +use sqlx_core::executor::Executor; +use sqlx_core::sql_str::SqlStr; +use sqlx_core::transaction::TransactionManager; +use std::future; + +sqlx_core::declare_driver_with_optional_migrate!(DRIVER = Mssql); + +impl AnyConnectionBackend for MssqlConnection { + fn name(&self) -> &str { + ::NAME + } + + fn close(self: Box) -> BoxFuture<'static, sqlx_core::Result<()>> { + Connection::close(*self).boxed() + } + + fn close_hard(self: Box) -> BoxFuture<'static, sqlx_core::Result<()>> { + Connection::close_hard(*self).boxed() + } + + fn ping(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { + Connection::ping(self).boxed() + } + + fn begin(&mut self, statement: Option) -> BoxFuture<'_, sqlx_core::Result<()>> { + MssqlTransactionManager::begin(self, statement).boxed() + } + + fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { + MssqlTransactionManager::commit(self).boxed() + } + + fn rollback(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { + MssqlTransactionManager::rollback(self).boxed() + } + + fn start_rollback(&mut self) { + MssqlTransactionManager::start_rollback(self) + } + + fn get_transaction_depth(&self) -> usize { + MssqlTransactionManager::get_transaction_depth(self) + } + + fn shrink_buffers(&mut self) { + Connection::shrink_buffers(self); + } + + fn flush(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { + Connection::flush(self).boxed() + } + + fn should_flush(&self) -> bool { + Connection::should_flush(self) + } + + #[cfg(feature = "migrate")] + fn as_migrate( + &mut self, + ) -> sqlx_core::Result<&mut (dyn sqlx_core::migrate::Migrate + Send + 'static)> { + Ok(self) + } + + fn fetch_many( + &mut self, + query: SqlStr, + _persistent: bool, + arguments: Option, + ) -> BoxStream<'_, sqlx_core::Result>> { + let arguments = match arguments.map(AnyArguments::convert_into).transpose() { + Ok(arguments) => arguments, + Err(error) => { + return stream::once(future::ready(Err(sqlx_core::Error::Encode(error)))).boxed() + } + }; + + Box::pin( + stream::once(async move { + let results = self.run(query.as_str(), arguments).await?; + Ok::<_, sqlx_core::Error>(results) + }) + .map_ok(|results| { + futures_util::stream::iter(results.into_iter().map(|res| { + Ok(match res { + Either::Left(result) => Either::Left(map_result(result)), + Either::Right(row) => Either::Right(AnyRow::try_from(&row)?), + }) + })) + }) + .try_flatten(), + ) + } + + fn fetch_optional( + &mut self, + query: SqlStr, + _persistent: bool, + arguments: Option, + ) -> BoxFuture<'_, sqlx_core::Result>> { + let arguments = arguments + .map(AnyArguments::convert_into) + .transpose() + .map_err(sqlx_core::Error::Encode); + + Box::pin(async move { + let arguments = arguments?; + let results = self.run(query.as_str(), arguments).await?; + + for result in results { + if let Either::Right(row) = result { + return Ok(Some(AnyRow::try_from(&row)?)); + } + } + + Ok(None) + }) + } + + fn prepare_with<'c, 'q: 'c>( + &'c mut self, + sql: SqlStr, + _parameters: &[AnyTypeInfo], + ) -> BoxFuture<'c, sqlx_core::Result> { + Box::pin(async move { + let statement = Executor::prepare_with(self, sql, &[]).await?; + let column_names = statement.metadata.column_names.clone(); + AnyStatement::try_from_statement(statement, column_names) + }) + } + + #[cfg(feature = "offline")] + fn describe( + &mut self, + sql: SqlStr, + ) -> BoxFuture<'_, sqlx_core::Result>> { + Box::pin(async move { + let describe = Executor::describe(self, sql).await?; + describe.try_into_any() + }) + } +} + +impl<'a> TryFrom<&'a MssqlTypeInfo> for AnyTypeInfo { + type Error = sqlx_core::Error; + + fn try_from(type_info: &'a MssqlTypeInfo) -> Result { + Ok(AnyTypeInfo { + kind: match type_info.name.as_str() { + "TINYINT" => AnyTypeInfoKind::SmallInt, + "SMALLINT" => AnyTypeInfoKind::SmallInt, + "INT" => AnyTypeInfoKind::Integer, + "BIGINT" => AnyTypeInfoKind::BigInt, + "REAL" => AnyTypeInfoKind::Real, + "FLOAT" => AnyTypeInfoKind::Double, + "VARBINARY" | "BINARY" | "IMAGE" => AnyTypeInfoKind::Blob, + "NVARCHAR" | "VARCHAR" | "NCHAR" | "CHAR" | "NTEXT" | "TEXT" | "XML" => { + AnyTypeInfoKind::Text + } + _ => { + return Err(sqlx_core::Error::AnyDriverError( + format!("Any driver does not support MSSQL type {type_info:?}").into(), + )) + } + }, + }) + } +} + +impl<'a> TryFrom<&'a MssqlColumn> for AnyColumn { + type Error = sqlx_core::Error; + + fn try_from(column: &'a MssqlColumn) -> Result { + let type_info = AnyTypeInfo::try_from(&column.type_info)?; + + Ok(AnyColumn { + ordinal: column.ordinal, + name: column.name.clone(), + type_info, + }) + } +} + +impl<'a> TryFrom<&'a MssqlRow> for AnyRow { + type Error = sqlx_core::Error; + + fn try_from(row: &'a MssqlRow) -> Result { + AnyRow::map_from(row, row.column_names.clone()) + } +} + +impl<'a> TryFrom<&'a AnyConnectOptions> for MssqlConnectOptions { + type Error = sqlx_core::Error; + + fn try_from(any_opts: &'a AnyConnectOptions) -> Result { + let mut opts = Self::parse_from_url(&any_opts.database_url)?; + opts.log_settings = any_opts.log_settings.clone(); + Ok(opts) + } +} + +fn map_result(result: MssqlQueryResult) -> AnyQueryResult { + AnyQueryResult { + rows_affected: result.rows_affected, + last_insert_id: None, + } +} diff --git a/sqlx-mssql/src/arguments.rs b/sqlx-mssql/src/arguments.rs new file mode 100644 index 0000000000..7294d09f32 --- /dev/null +++ b/sqlx-mssql/src/arguments.rs @@ -0,0 +1,47 @@ +use crate::database::MssqlArgumentValue; +use crate::encode::Encode; +use crate::types::Type; +use crate::Mssql; +pub(crate) use sqlx_core::arguments::*; +use sqlx_core::error::BoxDynError; + +/// Implementation of [`Arguments`] for MSSQL. +#[derive(Debug, Default, Clone)] +pub struct MssqlArguments { + pub(crate) values: Vec, +} + +impl MssqlArguments { + pub(crate) fn add<'q, T>(&mut self, value: T) -> Result<(), BoxDynError> + where + T: Encode<'q, Mssql> + Type, + { + let is_null = value.encode(&mut self.values)?; + if is_null.is_null() { + // If the encoder signaled null but didn't push a value, push a Null + if self.values.last().map_or(true, |v| !matches!(v, MssqlArgumentValue::Null)) { + self.values.push(MssqlArgumentValue::Null); + } + } + Ok(()) + } +} + +impl Arguments for MssqlArguments { + type Database = Mssql; + + fn reserve(&mut self, len: usize, _size: usize) { + self.values.reserve(len); + } + + fn add<'t, T>(&mut self, value: T) -> Result<(), BoxDynError> + where + T: Encode<'t, Self::Database> + Type, + { + self.add(value) + } + + fn len(&self) -> usize { + self.values.len() + } +} diff --git a/sqlx-mssql/src/column.rs b/sqlx-mssql/src/column.rs new file mode 100644 index 0000000000..aac9df3a4e --- /dev/null +++ b/sqlx-mssql/src/column.rs @@ -0,0 +1,32 @@ +use crate::ext::ustr::UStr; +use crate::{Mssql, MssqlTypeInfo}; +pub(crate) use sqlx_core::column::*; + +#[derive(Debug, Clone)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] +pub struct MssqlColumn { + pub(crate) ordinal: usize, + pub(crate) name: UStr, + pub(crate) type_info: MssqlTypeInfo, +} + +impl Column for MssqlColumn { + type Database = Mssql; + + fn ordinal(&self) -> usize { + self.ordinal + } + + fn name(&self) -> &str { + &self.name + } + + fn type_info(&self) -> &MssqlTypeInfo { + &self.type_info + } + + fn origin(&self) -> ColumnOrigin { + // tiberius doesn't expose table origin information + ColumnOrigin::Expression + } +} diff --git a/sqlx-mssql/src/connection/establish.rs b/sqlx-mssql/src/connection/establish.rs new file mode 100644 index 0000000000..5f1191c98b --- /dev/null +++ b/sqlx-mssql/src/connection/establish.rs @@ -0,0 +1,44 @@ +use crate::common::StatementCache; +use crate::connection::MssqlConnectionInner; +use crate::error::{tiberius_err, Error}; +use crate::io::SocketAdapter; +use crate::{MssqlConnectOptions, MssqlConnection}; +use sqlx_core::net::{Socket, WithSocket}; + +impl MssqlConnection { + pub(crate) async fn establish(options: &MssqlConnectOptions) -> Result { + let config = options.to_tiberius_config(); + let log_settings = options.log_settings.clone(); + let cache_capacity = options.statement_cache_capacity; + + let handler = EstablishHandler { config }; + + crate::net::connect_tcp(&options.host, options.port, handler) + .await? + .map(|client| MssqlConnection { + inner: Box::new(MssqlConnectionInner { + client, + transaction_depth: 0, + pending_rollback: false, + log_settings, + cache_statement: StatementCache::new(cache_capacity), + }), + }) + } +} + +struct EstablishHandler { + config: tiberius::Config, +} + +impl WithSocket for EstablishHandler { + type Output = Result>>, Error>; + + async fn with_socket(self, socket: S) -> Self::Output { + let boxed: Box = Box::new(socket); + let adapter = SocketAdapter::new(boxed); + tiberius::Client::connect(self.config, adapter) + .await + .map_err(tiberius_err) + } +} diff --git a/sqlx-mssql/src/connection/executor.rs b/sqlx-mssql/src/connection/executor.rs new file mode 100644 index 0000000000..bad31eb13b --- /dev/null +++ b/sqlx-mssql/src/connection/executor.rs @@ -0,0 +1,340 @@ +use crate::database::MssqlArgumentValue; +use crate::error::{tiberius_err, Error}; +use crate::executor::{Execute, Executor}; +use crate::ext::ustr::UStr; +use crate::logger::QueryLogger; +use crate::statement::{MssqlStatement, MssqlStatementMetadata}; +use crate::type_info::{type_name_for_tiberius, MssqlTypeInfo}; +use crate::value::{column_data_to_mssql_data, MssqlData}; +use crate::HashMap; +use crate::{ + Mssql, MssqlArguments, MssqlColumn, MssqlConnection, MssqlQueryResult, MssqlRow, +}; +use either::Either; +use futures_core::future::BoxFuture; +use futures_core::stream::BoxStream; +use futures_util::TryStreamExt; +use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr, SqlStr}; +use std::sync::Arc; + +impl MssqlConnection { + /// Execute a query, eagerly collecting all results. + /// + /// We collect eagerly because `tiberius::QueryStream` borrows `&mut Client`, + /// which prevents us from holding it across yield points alongside `&mut self`. + pub(crate) async fn run( + &mut self, + sql: &str, + arguments: Option, + ) -> Result>, Error> { + // Resolve any pending rollback first + crate::transaction::resolve_pending_rollback(self).await?; + + let mut logger = QueryLogger::new( + AssertSqlSafe(sql).into_sql_str(), + self.inner.log_settings.clone(), + ); + + let mut results = Vec::new(); + + if let Some(args) = arguments { + // Parameterized query using tiberius::Query + let mut query = tiberius::Query::new(sql); + + for arg in &args.values { + match arg { + MssqlArgumentValue::Null => { + query.bind(Option::<&str>::None); + } + MssqlArgumentValue::Bool(v) => { + query.bind(*v); + } + MssqlArgumentValue::U8(v) => { + query.bind(*v); + } + MssqlArgumentValue::I16(v) => { + query.bind(*v); + } + MssqlArgumentValue::I32(v) => { + query.bind(*v); + } + MssqlArgumentValue::I64(v) => { + query.bind(*v); + } + MssqlArgumentValue::F32(v) => { + query.bind(*v); + } + MssqlArgumentValue::F64(v) => { + query.bind(*v); + } + MssqlArgumentValue::String(v) => { + query.bind(v.as_str()); + } + MssqlArgumentValue::Binary(v) => { + query.bind(v.as_slice()); + } + } + } + + let stream = query.query(&mut self.inner.client).await.map_err(tiberius_err)?; + collect_results(stream, &mut results, &mut logger).await?; + } else { + // Simple query (no parameters) + let stream = self + .inner + .client + .simple_query(sql) + .await + .map_err(tiberius_err)?; + collect_results(stream, &mut results, &mut logger).await?; + } + + Ok(results) + } +} + +/// Collect all results from a tiberius QueryStream into a Vec. +async fn collect_results<'a>( + mut stream: tiberius::QueryStream<'a>, + results: &mut Vec>, + logger: &mut QueryLogger, +) -> Result<(), Error> { + // Process all result sets + let mut columns: Option>> = None; + let mut column_names: Option>> = None; + let mut rows_affected: u64 = 0; + + while let Some(item) = stream.try_next().await.map_err(tiberius_err)? { + match item { + tiberius::QueryItem::Metadata(meta) => { + // Build column info from metadata + let cols: Vec = meta + .columns() + .iter() + .enumerate() + .map(|(ordinal, col)| { + let name = UStr::new(col.name()); + let type_info = + MssqlTypeInfo::new(type_name_for_tiberius(&col.column_type())); + MssqlColumn { + ordinal, + name, + type_info, + } + }) + .collect(); + + let names: HashMap = cols + .iter() + .enumerate() + .map(|(i, col)| (col.name.clone(), i)) + .collect(); + + columns = Some(Arc::new(cols)); + column_names = Some(Arc::new(names)); + } + tiberius::QueryItem::Row(row) => { + let cols = columns.as_ref().expect("row received before metadata"); + let names = column_names.as_ref().expect("row received before metadata"); + + // Convert tiberius row to MssqlRow by iterating over cells + let values: Vec = row + .into_iter() + .map(|data| column_data_to_mssql_data(&data)) + .collect(); + + rows_affected += 1; + logger.increment_rows_returned(); + results.push(Either::Right(MssqlRow { + values, + columns: Arc::clone(cols), + column_names: Arc::clone(names), + })); + } + } + } + + // Report query result with total rows + logger.increase_rows_affected(rows_affected); + results.push(Either::Left(MssqlQueryResult { rows_affected })); + + Ok(()) +} + +impl<'c> Executor<'c> for &'c mut MssqlConnection { + type Database = Mssql; + + fn fetch_many<'e, 'q, E>( + self, + mut query: E, + ) -> BoxStream<'e, Result, Error>> + where + 'c: 'e, + E: Execute<'q, Self::Database>, + 'q: 'e, + E: 'q, + { + let arguments = query.take_arguments().map_err(Error::Encode); + let _persistent = query.persistent(); + let sql = query.sql(); + + Box::pin(futures_util::stream::once(async move { + let arguments = arguments?; + let results = self.run(sql.as_str(), arguments).await?; + Ok::<_, Error>(results) + }) + .map_ok(|results| futures_util::stream::iter(results.into_iter().map(Ok))) + .try_flatten()) + } + + fn fetch_optional<'e, 'q, E>( + self, + query: E, + ) -> BoxFuture<'e, Result, Error>> + where + 'c: 'e, + E: Execute<'q, Self::Database>, + 'q: 'e, + E: 'q, + { + let mut s = self.fetch_many(query); + + Box::pin(async move { + while let Some(v) = s.try_next().await? { + if let Either::Right(r) = v { + return Ok(Some(r)); + } + } + + Ok(None) + }) + } + + fn prepare_with<'e>( + self, + sql: SqlStr, + _parameters: &'e [MssqlTypeInfo], + ) -> BoxFuture<'e, Result> + where + 'c: 'e, + { + Box::pin(async move { + // Use sp_describe_first_result_set to get column metadata + let describe_sql = format!( + "EXEC sp_describe_first_result_set @tsql = N'{}'", + sql.as_str().replace('\'', "''") + ); + + let mut columns = Vec::new(); + let mut column_names = HashMap::new(); + + let stream = self + .inner + .client + .simple_query(&describe_sql) + .await + .map_err(tiberius_err)?; + + let rows: Vec = stream.into_first_result().await.map_err(tiberius_err)?; + + for (ordinal, row) in rows.iter().enumerate() { + let name: &str = row.get("name").unwrap_or(""); + let type_name: &str = row.get("system_type_name").unwrap_or("UNKNOWN"); + // Extract the base type name (before any parenthesized length/precision) + let base_type = type_name.split('(').next().unwrap_or(type_name).trim(); + let type_info = MssqlTypeInfo::new(base_type.to_uppercase()); + + let ustr_name = UStr::new(name); + column_names.insert(ustr_name.clone(), ordinal); + columns.push(MssqlColumn { + ordinal, + name: ustr_name, + type_info, + }); + } + + Ok(MssqlStatement { + sql, + metadata: MssqlStatementMetadata { + columns: Arc::new(columns), + column_names: Arc::new(column_names), + parameters: 0, + }, + }) + }) + } + + #[doc(hidden)] + #[cfg(feature = "offline")] + fn describe<'e>( + self, + sql: SqlStr, + ) -> BoxFuture<'e, Result, Error>> + where + 'c: 'e, + { + Box::pin(async move { + // Query sp_describe_first_result_set directly so we can extract nullable info + let describe_sql = format!( + "EXEC sp_describe_first_result_set @tsql = N'{}'", + sql.as_str().replace('\'', "''") + ); + + let stream = self + .inner + .client + .simple_query(&describe_sql) + .await + .map_err(tiberius_err)?; + + let rows: Vec = + stream.into_first_result().await.map_err(tiberius_err)?; + + let mut columns = Vec::new(); + let mut column_names = HashMap::new(); + let mut nullable = Vec::new(); + + for (ordinal, row) in rows.iter().enumerate() { + let name: &str = row.get("name").unwrap_or(""); + let type_name: &str = row.get("system_type_name").unwrap_or("UNKNOWN"); + let base_type = type_name.split('(').next().unwrap_or(type_name).trim(); + let type_info = MssqlTypeInfo::new(base_type.to_uppercase()); + let is_nullable: Option = row.get("is_nullable"); + + let ustr_name = UStr::new(name); + column_names.insert(ustr_name.clone(), ordinal); + columns.push(MssqlColumn { + ordinal, + name: ustr_name, + type_info, + }); + nullable.push(is_nullable); + } + + // Count parameters using sp_describe_undeclared_parameters + let param_sql = format!( + "EXEC sp_describe_undeclared_parameters @tsql = N'{}'", + sql.as_str().replace('\'', "''") + ); + let param_count = match self + .inner + .client + .simple_query(¶m_sql) + .await + { + Ok(stream) => stream + .into_first_result() + .await + .map_err(tiberius_err)? + .len(), + Err(_) => 0, + }; + + Ok(crate::describe::Describe { + parameters: Some(Either::Right(param_count)), + columns, + nullable, + }) + }) + } +} diff --git a/sqlx-mssql/src/connection/mod.rs b/sqlx-mssql/src/connection/mod.rs new file mode 100644 index 0000000000..2e4f06bc79 --- /dev/null +++ b/sqlx-mssql/src/connection/mod.rs @@ -0,0 +1,96 @@ +use std::fmt::{self, Debug, Formatter}; + +pub(crate) use sqlx_core::connection::*; +use sqlx_core::net::Socket; +use sqlx_core::sql_str::SqlSafeStr; + +use crate::common::StatementCache; +use crate::error::Error; +use crate::executor::Executor; +use crate::io::SocketAdapter; +use crate::statement::MssqlStatementMetadata; +use crate::transaction::Transaction; +use crate::{Mssql, MssqlConnectOptions}; + +mod establish; +mod executor; + +/// A connection to a MSSQL database. +pub struct MssqlConnection { + pub(crate) inner: Box, +} + +pub(crate) struct MssqlConnectionInner { + pub(crate) client: tiberius::Client>>, + pub(crate) transaction_depth: usize, + pub(crate) pending_rollback: bool, + pub(crate) log_settings: LogSettings, + pub(crate) cache_statement: StatementCache, +} + +impl Debug for MssqlConnection { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("MssqlConnection").finish() + } +} + +impl Connection for MssqlConnection { + type Database = Mssql; + + type Options = MssqlConnectOptions; + + async fn close(self) -> Result<(), Error> { + // tiberius doesn't have an explicit close; dropping the client closes the connection. + drop(self); + Ok(()) + } + + async fn close_hard(self) -> Result<(), Error> { + drop(self); + Ok(()) + } + + async fn ping(&mut self) -> Result<(), Error> { + self.execute("SELECT 1").await?; + Ok(()) + } + + fn flush(&mut self) -> impl std::future::Future> + Send + '_ { + // No-op for MSSQL since tiberius handles buffering internally. + std::future::ready(Ok(())) + } + + fn cached_statements_size(&self) -> usize { + self.inner.cache_statement.len() + } + + async fn clear_cached_statements(&mut self) -> Result<(), Error> { + self.inner.cache_statement.clear(); + Ok(()) + } + + fn should_flush(&self) -> bool { + false + } + + fn begin( + &mut self, + ) -> impl std::future::Future, Error>> + Send + '_ + { + Transaction::begin(self, None) + } + + fn begin_with( + &mut self, + statement: impl SqlSafeStr, + ) -> impl std::future::Future, Error>> + Send + '_ + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into_sql_str())) + } + + fn shrink_buffers(&mut self) { + // No-op for MSSQL + } +} diff --git a/sqlx-mssql/src/database.rs b/sqlx-mssql/src/database.rs new file mode 100644 index 0000000000..45a5ece196 --- /dev/null +++ b/sqlx-mssql/src/database.rs @@ -0,0 +1,56 @@ +use crate::value::{MssqlValue, MssqlValueRef}; +use crate::{ + MssqlArguments, MssqlColumn, MssqlConnection, MssqlQueryResult, MssqlRow, MssqlStatement, + MssqlTransactionManager, MssqlTypeInfo, +}; +pub(crate) use sqlx_core::database::{Database, HasStatementCache}; + +/// MSSQL (SQL Server) database driver. +#[derive(Debug)] +pub struct Mssql; + +impl Database for Mssql { + type Connection = MssqlConnection; + + type TransactionManager = MssqlTransactionManager; + + type Row = MssqlRow; + + type QueryResult = MssqlQueryResult; + + type Column = MssqlColumn; + + type TypeInfo = MssqlTypeInfo; + + type Value = MssqlValue; + type ValueRef<'r> = MssqlValueRef<'r>; + + type Arguments = MssqlArguments; + type ArgumentBuffer = Vec; + + type Statement = MssqlStatement; + + const NAME: &'static str = "MSSQL"; + + const URL_SCHEMES: &'static [&'static str] = &["mssql", "sqlserver"]; +} + +impl HasStatementCache for Mssql {} + +/// A single argument value for MSSQL queries. +/// +/// Unlike MySQL/Postgres which use a byte buffer, MSSQL arguments are stored +/// as typed enum variants because tiberius requires typed `bind()` calls. +#[derive(Debug, Clone)] +pub enum MssqlArgumentValue { + Null, + Bool(bool), + U8(u8), + I16(i16), + I32(i32), + I64(i64), + F32(f32), + F64(f64), + String(String), + Binary(Vec), +} diff --git a/sqlx-mssql/src/error.rs b/sqlx-mssql/src/error.rs new file mode 100644 index 0000000000..7f2e6b1eaa --- /dev/null +++ b/sqlx-mssql/src/error.rs @@ -0,0 +1,127 @@ +use std::borrow::Cow; +use std::error::Error as StdError; +use std::fmt::{self, Debug, Display, Formatter}; + +pub(crate) use sqlx_core::error::*; + +/// An error returned from the MSSQL database. +pub struct MssqlDatabaseError { + pub(crate) number: u32, + pub(crate) state: u8, + pub(crate) class: u8, + pub(crate) message: String, + pub(crate) server: Option, + pub(crate) procedure: Option, +} + +impl MssqlDatabaseError { + /// The error number returned by SQL Server. + pub fn number(&self) -> u32 { + self.number + } + + /// The error state. + pub fn state(&self) -> u8 { + self.state + } + + /// The severity class of the error. + pub fn class(&self) -> u8 { + self.class + } + + /// The human-readable error message. + pub fn server(&self) -> Option<&str> { + self.server.as_deref() + } + + /// The stored procedure name, if applicable. + pub fn procedure(&self) -> Option<&str> { + self.procedure.as_deref() + } +} + +impl Debug for MssqlDatabaseError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("MssqlDatabaseError") + .field("number", &self.number) + .field("state", &self.state) + .field("class", &self.class) + .field("message", &self.message) + .finish() + } +} + +impl Display for MssqlDatabaseError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "(number {}, state {}): {}", self.number, self.state, self.message) + } +} + +impl StdError for MssqlDatabaseError {} + +impl DatabaseError for MssqlDatabaseError { + #[inline] + fn message(&self) -> &str { + &self.message + } + + fn code(&self) -> Option> { + Some(Cow::Owned(self.number.to_string())) + } + + #[doc(hidden)] + fn as_error(&self) -> &(dyn StdError + Send + Sync + 'static) { + self + } + + #[doc(hidden)] + fn as_error_mut(&mut self) -> &mut (dyn StdError + Send + Sync + 'static) { + self + } + + #[doc(hidden)] + fn into_error(self: Box) -> Box { + self + } + + fn kind(&self) -> ErrorKind { + match self.number { + // Cannot insert duplicate key + 2601 | 2627 => ErrorKind::UniqueViolation, + // Foreign key constraint violation + 547 => ErrorKind::ForeignKeyViolation, + // Cannot insert NULL + 515 => ErrorKind::NotNullViolation, + // Check constraint violation + 2628 => ErrorKind::CheckViolation, + _ => ErrorKind::Other, + } + } +} + +/// Convert a tiberius error into an sqlx Error. +pub(crate) fn tiberius_err(err: tiberius::error::Error) -> Error { + match err { + tiberius::error::Error::Server(token_error) => { + Error::Database(Box::new(MssqlDatabaseError { + number: token_error.code(), + state: token_error.state(), + class: token_error.class(), + message: token_error.message().to_string(), + server: { + let s = token_error.server(); + if s.is_empty() { None } else { Some(s.to_string()) } + }, + procedure: { + let s = token_error.procedure(); + if s.is_empty() { None } else { Some(s.to_string()) } + }, + })) + } + tiberius::error::Error::Io { kind, message } => { + Error::Io(std::io::Error::new(kind, message)) + } + other => Error::Protocol(other.to_string()), + } +} diff --git a/sqlx-mssql/src/io.rs b/sqlx-mssql/src/io.rs new file mode 100644 index 0000000000..88baca0c52 --- /dev/null +++ b/sqlx-mssql/src/io.rs @@ -0,0 +1,72 @@ +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use sqlx_core::net::Socket; + +/// Adapter that wraps an sqlx-core `Socket` to implement `futures_io::AsyncRead + AsyncWrite`, +/// which is what tiberius requires. +pub(crate) struct SocketAdapter { + inner: S, +} + +impl SocketAdapter { + pub fn new(socket: S) -> Self { + Self { inner: socket } + } +} + +impl futures_io::AsyncRead for SocketAdapter { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + loop { + match self.inner.try_read(&mut &mut *buf) { + Ok(n) => return Poll::Ready(Ok(n)), + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + match self.inner.poll_read_ready(cx) { + Poll::Ready(Ok(())) => continue, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } + } + Err(e) => return Poll::Ready(Err(e)), + } + } + } +} + +impl futures_io::AsyncWrite for SocketAdapter { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + loop { + match self.inner.try_write(buf) { + Ok(n) => return Poll::Ready(Ok(n)), + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + match self.inner.poll_write_ready(cx) { + Poll::Ready(Ok(())) => continue, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } + } + Err(e) => return Poll::Ready(Err(e)), + } + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_shutdown(cx) + } +} + +// Implement Unpin since we only access the inner socket through &mut self +impl Unpin for SocketAdapter {} diff --git a/sqlx-mssql/src/lib.rs b/sqlx-mssql/src/lib.rs new file mode 100644 index 0000000000..3a71eb9e72 --- /dev/null +++ b/sqlx-mssql/src/lib.rs @@ -0,0 +1,71 @@ +//! **MSSQL** (SQL Server) database driver. +#![deny(clippy::cast_possible_truncation)] +#![deny(clippy::cast_possible_wrap)] +#![deny(clippy::cast_sign_loss)] + +#[macro_use] +extern crate sqlx_core; + +use crate::executor::Executor; + +pub(crate) use sqlx_core::driver_prelude::*; + +#[cfg(feature = "any")] +pub mod any; + +mod arguments; +mod column; +mod connection; +mod database; +mod error; +mod io; +mod options; +mod query_result; +mod row; +mod statement; +mod transaction; +mod type_checking; +mod type_info; +pub mod types; +mod value; + +#[cfg(feature = "migrate")] +mod migrate; + +#[cfg(feature = "migrate")] +mod testing; + +pub use arguments::MssqlArguments; +pub use column::MssqlColumn; +pub use connection::MssqlConnection; +pub use database::Mssql; +pub use error::MssqlDatabaseError; +pub use options::MssqlConnectOptions; +pub use query_result::MssqlQueryResult; +pub use row::MssqlRow; +pub use statement::MssqlStatement; +pub use transaction::MssqlTransactionManager; +pub use type_info::MssqlTypeInfo; +pub use value::{MssqlValue, MssqlValueRef}; + +/// An alias for [`Pool`][crate::pool::Pool], specialized for MSSQL. +pub type MssqlPool = crate::pool::Pool; + +/// An alias for [`PoolOptions`][crate::pool::PoolOptions], specialized for MSSQL. +pub type MssqlPoolOptions = crate::pool::PoolOptions; + +/// An alias for [`Executor<'_, Database = Mssql>`][Executor]. +pub trait MssqlExecutor<'c>: Executor<'c, Database = Mssql> {} +impl<'c, T: Executor<'c, Database = Mssql>> MssqlExecutor<'c> for T {} + +/// An alias for [`Transaction`][crate::transaction::Transaction], specialized for MSSQL. +pub type MssqlTransaction<'c> = crate::transaction::Transaction<'c, Mssql>; + +// NOTE: required due to the lack of lazy normalization +impl_into_arguments_for_arguments!(MssqlArguments); +impl_acquire!(Mssql, MssqlConnection); +impl_column_index_for_row!(MssqlRow); +impl_column_index_for_statement!(MssqlStatement); + +// required because some databases have a different handling of NULL +impl_encode_for_option!(Mssql); diff --git a/sqlx-mssql/src/migrate.rs b/sqlx-mssql/src/migrate.rs new file mode 100644 index 0000000000..5b0a22cf2c --- /dev/null +++ b/sqlx-mssql/src/migrate.rs @@ -0,0 +1,278 @@ +use std::str::FromStr; +use std::time::Duration; +use std::time::Instant; + +use futures_core::future::BoxFuture; +pub(crate) use sqlx_core::migrate::*; +use sqlx_core::sql_str::AssertSqlSafe; + +use crate::connection::{ConnectOptions, Connection}; +use crate::error::Error; +use crate::executor::Executor; +use crate::query::query; +use crate::query_as::query_as; +use crate::query_scalar::query_scalar; +use crate::{Mssql, MssqlConnectOptions, MssqlConnection}; + +fn parse_for_maintenance(url: &str) -> Result<(MssqlConnectOptions, String), Error> { + let mut options = MssqlConnectOptions::from_str(url)?; + + let database = if let Some(database) = &options.database { + database.to_owned() + } else { + return Err(Error::Configuration( + "DATABASE_URL does not specify a database".into(), + )); + }; + + // switch us to master database for create/drop commands + options.database = Some("master".to_owned()); + + Ok((options, database)) +} + +impl MigrateDatabase for Mssql { + async fn create_database(url: &str) -> Result<(), Error> { + let (options, database) = parse_for_maintenance(url)?; + let mut conn = options.connect().await?; + + let _ = conn + .execute(AssertSqlSafe(format!( + "CREATE DATABASE [{database}]" + ))) + .await?; + + Ok(()) + } + + async fn database_exists(url: &str) -> Result { + let (options, database) = parse_for_maintenance(url)?; + let mut conn = options.connect().await?; + + let exists: bool = query_scalar( + "SELECT CASE WHEN DB_ID(@p1) IS NOT NULL THEN 1 ELSE 0 END", + ) + .bind(database) + .fetch_one(&mut conn) + .await?; + + Ok(exists) + } + + async fn drop_database(url: &str) -> Result<(), Error> { + let (options, database) = parse_for_maintenance(url)?; + let mut conn = options.connect().await?; + + // Force close existing connections before dropping + let _ = conn + .execute(AssertSqlSafe(format!( + "IF DB_ID('{database}') IS NOT NULL \ + BEGIN \ + ALTER DATABASE [{database}] SET SINGLE_USER WITH ROLLBACK IMMEDIATE; \ + DROP DATABASE [{database}]; \ + END" + ))) + .await?; + + Ok(()) + } +} + +impl Migrate for MssqlConnection { + fn create_schema_if_not_exists<'e>( + &'e mut self, + schema_name: &'e str, + ) -> BoxFuture<'e, Result<(), MigrateError>> { + Box::pin(async move { + self.execute(AssertSqlSafe(format!( + r#"IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '{schema_name}') + EXEC('CREATE SCHEMA [{schema_name}]')"# + ))) + .await?; + + Ok(()) + }) + } + + fn ensure_migrations_table<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result<(), MigrateError>> { + Box::pin(async move { + self.execute(AssertSqlSafe(format!( + r#" +IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = '{table_name}') +CREATE TABLE {table_name} ( + version BIGINT PRIMARY KEY, + description NVARCHAR(MAX) NOT NULL, + installed_on DATETIME2 NOT NULL DEFAULT SYSUTCDATETIME(), + success BIT NOT NULL, + checksum VARBINARY(MAX) NOT NULL, + execution_time BIGINT NOT NULL +); + "# + ))) + .await?; + + Ok(()) + }) + } + + fn dirty_version<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result, MigrateError>> { + Box::pin(async move { + let row: Option<(i64,)> = query_as(AssertSqlSafe(format!( + "SELECT TOP 1 version FROM {table_name} WHERE success = 0 ORDER BY version" + ))) + .fetch_optional(self) + .await?; + + Ok(row.map(|r| r.0)) + }) + } + + fn list_applied_migrations<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result, MigrateError>> { + Box::pin(async move { + let rows: Vec<(i64, Vec)> = query_as(AssertSqlSafe(format!( + "SELECT version, checksum FROM {table_name} ORDER BY version" + ))) + .fetch_all(self) + .await?; + + let migrations = rows + .into_iter() + .map(|(version, checksum)| AppliedMigration { + version, + checksum: checksum.into(), + }) + .collect(); + + Ok(migrations) + }) + } + + fn lock(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> { + Box::pin(async move { + // Use sp_getapplock for advisory locking in MSSQL + let _ = self + .execute( + "EXEC sp_getapplock @Resource = 'sqlx_migrations', @LockMode = 'Exclusive', @LockOwner = 'Session', @LockTimeout = -1" + ) + .await?; + + Ok(()) + }) + } + + fn unlock(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> { + Box::pin(async move { + let _ = self + .execute( + "EXEC sp_releaseapplock @Resource = 'sqlx_migrations', @LockOwner = 'Session'" + ) + .await?; + + Ok(()) + }) + } + + fn apply<'e>( + &'e mut self, + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result> { + Box::pin(async move { + let mut tx = self.begin().await?; + let start = Instant::now(); + + let _ = query(AssertSqlSafe(format!( + r#" + INSERT INTO {table_name} ( version, description, success, checksum, execution_time ) + VALUES ( @p1, @p2, 0, @p3, -1 ) + "# + ))) + .bind(migration.version) + .bind(&*migration.description) + .bind(&*migration.checksum) + .execute(&mut *tx) + .await?; + + let _ = tx + .execute(migration.sql.clone()) + .await + .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; + + let _ = query(AssertSqlSafe(format!( + r#" + UPDATE {table_name} + SET success = 1 + WHERE version = @p1 + "# + ))) + .bind(migration.version) + .execute(&mut *tx) + .await?; + + tx.commit().await?; + + let elapsed = start.elapsed(); + + #[allow(clippy::cast_possible_truncation)] + let _ = query(AssertSqlSafe(format!( + r#" + UPDATE {table_name} + SET execution_time = @p1 + WHERE version = @p2 + "# + ))) + .bind(elapsed.as_nanos() as i64) + .bind(migration.version) + .execute(self) + .await?; + + Ok(elapsed) + }) + } + + fn revert<'e>( + &'e mut self, + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result> { + Box::pin(async move { + let mut tx = self.begin().await?; + let start = Instant::now(); + + let _ = query(AssertSqlSafe(format!( + r#" + UPDATE {table_name} + SET success = 0 + WHERE version = @p1 + "# + ))) + .bind(migration.version) + .execute(&mut *tx) + .await?; + + tx.execute(migration.sql.clone()).await?; + + let _ = query(AssertSqlSafe(format!( + r#"DELETE FROM {table_name} WHERE version = @p1"# + ))) + .bind(migration.version) + .execute(&mut *tx) + .await?; + + tx.commit().await?; + + let elapsed = start.elapsed(); + + Ok(elapsed) + }) + } +} diff --git a/sqlx-mssql/src/options/connect.rs b/sqlx-mssql/src/options/connect.rs new file mode 100644 index 0000000000..f8c6dc04af --- /dev/null +++ b/sqlx-mssql/src/options/connect.rs @@ -0,0 +1,35 @@ +use crate::connection::ConnectOptions; +use crate::error::Error; +use crate::{MssqlConnectOptions, MssqlConnection}; +use log::LevelFilter; +use sqlx_core::Url; +use std::time::Duration; + +impl ConnectOptions for MssqlConnectOptions { + type Connection = MssqlConnection; + + fn from_url(url: &Url) -> Result { + Self::parse_from_url(url) + } + + fn to_url_lossy(&self) -> Url { + self.build_url() + } + + async fn connect(&self) -> Result + where + Self::Connection: Sized, + { + MssqlConnection::establish(self).await + } + + fn log_statements(mut self, level: LevelFilter) -> Self { + self.log_settings.log_statements(level); + self + } + + fn log_slow_statements(mut self, level: LevelFilter, duration: Duration) -> Self { + self.log_settings.log_slow_statements(level, duration); + self + } +} diff --git a/sqlx-mssql/src/options/mod.rs b/sqlx-mssql/src/options/mod.rs new file mode 100644 index 0000000000..36f064def6 --- /dev/null +++ b/sqlx-mssql/src/options/mod.rs @@ -0,0 +1,201 @@ +mod connect; +mod parse; + +use crate::connection::LogSettings; + +/// Options and flags which can be used to configure a MSSQL connection. +/// +/// A value of `MssqlConnectOptions` can be parsed from a connection URL, +/// as described below. +/// +/// The generic format of the connection URL: +/// +/// ```text +/// mssql://[user[:password]@]host[:port][/database][?properties] +/// ``` +/// +/// ## Properties +/// +/// |Parameter|Default|Description| +/// |---------|-------|-----------| +/// | `encrypt` | `false` | Whether to use TLS encryption. | +/// | `trust_server_certificate` | `false` | Whether to trust the server certificate without validation. | +/// | `statement-cache-capacity` | `100` | The maximum number of prepared statements stored in the cache. | +/// | `app_name` | `sqlx` | The application name sent to the server. | +/// | `instance` | `None` | The SQL Server instance name. | +/// +/// # Example +/// +/// ```rust,no_run +/// # async fn example() -> sqlx::Result<()> { +/// use sqlx::{Connection, ConnectOptions}; +/// use sqlx::mssql::{MssqlConnectOptions, MssqlConnection}; +/// +/// // URL connection string +/// let conn = MssqlConnection::connect("mssql://sa:password@localhost/master").await?; +/// +/// // Manually-constructed options +/// let conn = MssqlConnectOptions::new() +/// .host("localhost") +/// .username("sa") +/// .password("password") +/// .database("master") +/// .connect().await?; +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug, Clone)] +pub struct MssqlConnectOptions { + pub(crate) host: String, + pub(crate) port: u16, + pub(crate) username: String, + pub(crate) password: Option, + pub(crate) database: Option, + pub(crate) instance: Option, + pub(crate) encrypt: bool, + pub(crate) trust_server_certificate: bool, + pub(crate) statement_cache_capacity: usize, + pub(crate) app_name: String, + pub(crate) log_settings: LogSettings, +} + +impl Default for MssqlConnectOptions { + fn default() -> Self { + Self::new() + } +} + +impl MssqlConnectOptions { + /// Creates a new, default set of options ready for configuration. + pub fn new() -> Self { + Self { + port: 1433, + host: String::from("localhost"), + username: String::from("sa"), + password: None, + database: None, + instance: None, + encrypt: false, + trust_server_certificate: false, + statement_cache_capacity: 100, + app_name: String::from("sqlx"), + log_settings: Default::default(), + } + } + + /// Sets the name of the host to connect to. + pub fn host(mut self, host: &str) -> Self { + host.clone_into(&mut self.host); + self + } + + /// Sets the port to connect to at the server host. + /// + /// The default port for MSSQL is `1433`. + pub fn port(mut self, port: u16) -> Self { + self.port = port; + self + } + + /// Sets the username to connect as. + pub fn username(mut self, username: &str) -> Self { + username.clone_into(&mut self.username); + self + } + + /// Sets the password to connect with. + pub fn password(mut self, password: &str) -> Self { + self.password = Some(password.to_owned()); + self + } + + /// Sets the database name. + pub fn database(mut self, database: &str) -> Self { + self.database = Some(database.to_owned()); + self + } + + /// Sets the SQL Server instance name. + pub fn instance(mut self, instance: &str) -> Self { + self.instance = Some(instance.to_owned()); + self + } + + /// Sets whether to use TLS encryption. + pub fn encrypt(mut self, encrypt: bool) -> Self { + self.encrypt = encrypt; + self + } + + /// Sets whether to trust the server certificate without validation. + pub fn trust_server_certificate(mut self, trust: bool) -> Self { + self.trust_server_certificate = trust; + self + } + + /// Sets the capacity of the connection's statement cache. + pub fn statement_cache_capacity(mut self, capacity: usize) -> Self { + self.statement_cache_capacity = capacity; + self + } + + /// Sets the application name sent to the server. + pub fn app_name(mut self, app_name: &str) -> Self { + app_name.clone_into(&mut self.app_name); + self + } + + /// Get the current host. + pub fn get_host(&self) -> &str { + &self.host + } + + /// Get the server's port. + pub fn get_port(&self) -> u16 { + self.port + } + + /// Get the current username. + pub fn get_username(&self) -> &str { + &self.username + } + + /// Get the current database name. + pub fn get_database(&self) -> Option<&str> { + self.database.as_deref() + } + + /// Build a `tiberius::Config` from these options. + pub(crate) fn to_tiberius_config(&self) -> tiberius::Config { + let mut config = tiberius::Config::new(); + + config.host(&self.host); + config.port(self.port); + config.application_name(&self.app_name); + + if let Some(database) = &self.database { + config.database(database); + } + + if let Some(instance) = &self.instance { + config.instance_name(instance); + } + + config.authentication(tiberius::AuthMethod::sql_server( + &self.username, + self.password.as_deref().unwrap_or(""), + )); + + if self.trust_server_certificate { + config.trust_cert(); + } + + if self.encrypt { + config.encryption(tiberius::EncryptionLevel::Required); + } else { + config.encryption(tiberius::EncryptionLevel::NotSupported); + } + + config + } +} diff --git a/sqlx-mssql/src/options/parse.rs b/sqlx-mssql/src/options/parse.rs new file mode 100644 index 0000000000..5dedf39661 --- /dev/null +++ b/sqlx-mssql/src/options/parse.rs @@ -0,0 +1,126 @@ +use std::str::FromStr; + +use percent_encoding::percent_decode_str; +use sqlx_core::Url; + +use crate::error::Error; + +use super::MssqlConnectOptions; + +impl MssqlConnectOptions { + pub(crate) fn parse_from_url(url: &Url) -> Result { + let mut options = Self::new(); + + if let Some(host) = url.host_str() { + options = options.host(host); + } + + if let Some(port) = url.port() { + options = options.port(port); + } + + let username = url.username(); + if !username.is_empty() { + options = options.username( + &percent_decode_str(username) + .decode_utf8() + .map_err(Error::config)?, + ); + } + + if let Some(password) = url.password() { + options = options.password( + &percent_decode_str(password) + .decode_utf8() + .map_err(Error::config)?, + ); + } + + let path = url.path().trim_start_matches('/'); + if !path.is_empty() { + options = options.database( + &percent_decode_str(path) + .decode_utf8() + .map_err(Error::config)?, + ); + } + + for (key, value) in url.query_pairs().into_iter() { + match &*key { + "encrypt" => { + options = options + .encrypt(value.parse().map_err(Error::config)?); + } + + "trust_server_certificate" | "trustServerCertificate" => { + options = options + .trust_server_certificate(value.parse().map_err(Error::config)?); + } + + "instance" => { + options = options.instance(&value); + } + + "app_name" | "application-name" => { + options = options.app_name(&value); + } + + "statement-cache-capacity" => { + options = options + .statement_cache_capacity(value.parse().map_err(Error::config)?); + } + + _ => {} + } + } + + Ok(options) + } + + pub(crate) fn build_url(&self) -> Url { + let mut url = Url::parse(&format!( + "mssql://{}@{}:{}", + self.username, self.host, self.port + )) + .expect("BUG: generated un-parseable URL"); + + if let Some(password) = &self.password { + let _ = url.set_password(Some(password)); + } + + if let Some(database) = &self.database { + url.set_path(database); + } + + url + } +} + +impl FromStr for MssqlConnectOptions { + type Err = Error; + + fn from_str(s: &str) -> Result { + let url: Url = s.parse().map_err(Error::config)?; + Self::parse_from_url(&url) + } +} + +#[test] +fn it_parses_basic_mssql_url() { + let url = "mssql://sa:password@localhost:1433/master"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + + assert_eq!(opts.host, "localhost"); + assert_eq!(opts.port, 1433); + assert_eq!(opts.username, "sa"); + assert_eq!(opts.password, Some("password".into())); + assert_eq!(opts.database, Some("master".into())); +} + +#[test] +fn it_parses_url_with_instance() { + let url = "mssql://sa:password@localhost/master?instance=SQLEXPRESS"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + + assert_eq!(opts.instance, Some("SQLEXPRESS".into())); +} diff --git a/sqlx-mssql/src/query_result.rs b/sqlx-mssql/src/query_result.rs new file mode 100644 index 0000000000..de00dda5ca --- /dev/null +++ b/sqlx-mssql/src/query_result.rs @@ -0,0 +1,30 @@ +use std::iter::{Extend, IntoIterator}; + +#[derive(Debug, Default)] +pub struct MssqlQueryResult { + pub(super) rows_affected: u64, +} + +impl MssqlQueryResult { + pub fn rows_affected(&self) -> u64 { + self.rows_affected + } +} + +impl Extend for MssqlQueryResult { + fn extend>(&mut self, iter: T) { + for elem in iter { + self.rows_affected += elem.rows_affected; + } + } +} + +#[cfg(feature = "any")] +impl From for sqlx_core::any::AnyQueryResult { + fn from(done: MssqlQueryResult) -> Self { + sqlx_core::any::AnyQueryResult { + rows_affected: done.rows_affected(), + last_insert_id: None, + } + } +} diff --git a/sqlx-mssql/src/row.rs b/sqlx-mssql/src/row.rs new file mode 100644 index 0000000000..261e2ea016 --- /dev/null +++ b/sqlx-mssql/src/row.rs @@ -0,0 +1,54 @@ +use std::sync::Arc; + +pub(crate) use sqlx_core::row::*; + +use crate::column::ColumnIndex; +use crate::error::Error; +use crate::ext::ustr::UStr; +use crate::value::MssqlData; +use crate::HashMap; +use crate::{Mssql, MssqlColumn, MssqlValueRef}; + +/// Implementation of [`Row`] for MSSQL. +pub struct MssqlRow { + pub(crate) values: Vec, + pub(crate) columns: Arc>, + pub(crate) column_names: Arc>, +} + +impl Row for MssqlRow { + type Database = Mssql; + + fn columns(&self) -> &[MssqlColumn] { + &self.columns + } + + fn try_get_raw(&self, index: I) -> Result, Error> + where + I: ColumnIndex, + { + let index = index.index(self)?; + let column = &self.columns[index]; + let data = &self.values[index]; + + Ok(MssqlValueRef { + data, + type_info: column.type_info.clone(), + }) + } +} + +impl ColumnIndex for &'_ str { + fn index(&self, row: &MssqlRow) -> Result { + row.column_names + .get(*self) + .ok_or_else(|| Error::ColumnNotFound((*self).into())) + .copied() + } +} + +impl std::fmt::Debug for MssqlRow { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + debug_row(self, f) + } +} diff --git a/sqlx-mssql/src/statement.rs b/sqlx-mssql/src/statement.rs new file mode 100644 index 0000000000..ad414dfd9c --- /dev/null +++ b/sqlx-mssql/src/statement.rs @@ -0,0 +1,57 @@ +use super::MssqlColumn; +use crate::column::ColumnIndex; +use crate::error::Error; +use crate::ext::ustr::UStr; +use crate::HashMap; +use crate::{Mssql, MssqlArguments, MssqlTypeInfo}; +use either::Either; +use sqlx_core::sql_str::SqlStr; +use std::sync::Arc; + +pub(crate) use sqlx_core::statement::*; + +#[derive(Debug, Clone)] +pub struct MssqlStatement { + pub(crate) sql: SqlStr, + pub(crate) metadata: MssqlStatementMetadata, +} + +#[derive(Debug, Default, Clone)] +pub(crate) struct MssqlStatementMetadata { + pub(crate) columns: Arc>, + pub(crate) column_names: Arc>, + pub(crate) parameters: usize, +} + +impl Statement for MssqlStatement { + type Database = Mssql; + + fn into_sql(self) -> SqlStr { + self.sql + } + + fn sql(&self) -> &SqlStr { + &self.sql + } + + fn parameters(&self) -> Option> { + Some(Either::Right(self.metadata.parameters)) + } + + fn columns(&self) -> &[MssqlColumn] { + &self.metadata.columns + } + + impl_statement_query!(MssqlArguments); +} + +impl ColumnIndex for &'_ str { + fn index(&self, statement: &MssqlStatement) -> Result { + statement + .metadata + .column_names + .get(*self) + .ok_or_else(|| Error::ColumnNotFound((*self).into())) + .copied() + } +} diff --git a/sqlx-mssql/src/testing/mod.rs b/sqlx-mssql/src/testing/mod.rs new file mode 100644 index 0000000000..619728637c --- /dev/null +++ b/sqlx-mssql/src/testing/mod.rs @@ -0,0 +1,194 @@ +use std::future::Future; +use std::ops::Deref; +use std::str::FromStr; +use std::sync::OnceLock; +use std::time::Duration; + +use crate::error::Error; +use crate::executor::Executor; +use crate::pool::{Pool, PoolOptions}; +use crate::query::query; +use crate::{Mssql, MssqlConnectOptions, MssqlConnection}; +use sqlx_core::connection::Connection; +use sqlx_core::query_scalar::query_scalar; +use sqlx_core::sql_str::AssertSqlSafe; + +pub(crate) use sqlx_core::testing::*; + +// Using a blocking `OnceLock` here because the critical sections are short. +static MASTER_POOL: OnceLock> = OnceLock::new(); + +impl TestSupport for Mssql { + fn test_context( + args: &TestArgs, + ) -> impl Future, Error>> + Send + '_ { + test_context(args) + } + + async fn cleanup_test(db_name: &str) -> Result<(), Error> { + let mut conn = MASTER_POOL + .get() + .expect("cleanup_test() invoked outside `#[sqlx::test]`") + .acquire() + .await?; + + do_cleanup(&mut conn, db_name).await + } + + async fn cleanup_test_dbs() -> Result, Error> { + let url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set"); + + let mut conn = MssqlConnection::connect(&url).await?; + + let delete_db_names: Vec = + query_scalar("SELECT db_name FROM _sqlx_test_databases") + .fetch_all(&mut conn) + .await?; + + if delete_db_names.is_empty() { + return Ok(None); + } + + let mut deleted_count = 0usize; + + for db_name in &delete_db_names { + let drop_sql = format!( + "IF DB_ID('{db_name}') IS NOT NULL \ + BEGIN \ + ALTER DATABASE [{db_name}] SET SINGLE_USER WITH ROLLBACK IMMEDIATE; \ + DROP DATABASE [{db_name}]; \ + END" + ); + + match conn.execute(AssertSqlSafe(drop_sql)).await { + Ok(_deleted) => { + deleted_count += 1; + } + // Assume a database error just means the DB is still in use. + Err(Error::Database(dbe)) => { + eprintln!("could not clean test database {db_name:?}: {dbe}") + } + // Bubble up other errors + Err(e) => return Err(e), + } + } + + if deleted_count == 0 { + return Ok(None); + } + + // Clean up the tracking table + for db_name in &delete_db_names { + let _ = query("DELETE FROM _sqlx_test_databases WHERE db_name = @p1") + .bind(db_name) + .execute(&mut conn) + .await; + } + + let _ = conn.close().await; + Ok(Some(deleted_count)) + } + + async fn snapshot(_conn: &mut Self::Connection) -> Result, Error> { + todo!() + } +} + +async fn test_context(args: &TestArgs) -> Result, Error> { + let url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set"); + + let master_opts = MssqlConnectOptions::from_str(&url).expect("failed to parse DATABASE_URL"); + + let pool = PoolOptions::new() + .max_connections(20) + // Immediately close master connections. Tokio's I/O streams don't like hopping runtimes. + .after_release(|_conn, _| Box::pin(async move { Ok(false) })) + .connect_lazy_with(master_opts); + + let master_pool = match once_lock_try_insert_polyfill(&MASTER_POOL, pool) { + Ok(inserted) => inserted, + Err((existing, pool)) => { + assert_eq!( + existing.connect_options().host, + pool.connect_options().host, + "DATABASE_URL changed at runtime, host differs" + ); + + assert_eq!( + existing.connect_options().database, + pool.connect_options().database, + "DATABASE_URL changed at runtime, database differs" + ); + + existing + } + }; + + let mut conn = master_pool.acquire().await?; + + // Create tracking table if it doesn't exist + conn.execute( + r#" + IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = '_sqlx_test_databases') + CREATE TABLE _sqlx_test_databases ( + db_name NVARCHAR(200) NOT NULL PRIMARY KEY, + test_path NVARCHAR(MAX) NOT NULL, + created_at DATETIME2 NOT NULL DEFAULT SYSUTCDATETIME() + ); + "#, + ) + .await?; + + let db_name = Mssql::db_name(args); + do_cleanup(&mut conn, &db_name).await?; + + query("INSERT INTO _sqlx_test_databases(db_name, test_path) VALUES (@p1, @p2)") + .bind(&db_name) + .bind(args.test_path) + .execute(&mut *conn) + .await?; + + conn.execute(AssertSqlSafe(format!("CREATE DATABASE [{db_name}]"))) + .await?; + + eprintln!("created database {db_name}"); + + Ok(TestContext { + pool_opts: PoolOptions::new() + .max_connections(5) + .idle_timeout(Some(Duration::from_secs(1))) + .parent(master_pool.clone()), + connect_opts: master_pool + .connect_options() + .deref() + .clone() + .database(&db_name), + db_name, + }) +} + +async fn do_cleanup(conn: &mut MssqlConnection, db_name: &str) -> Result<(), Error> { + let drop_sql = format!( + "IF DB_ID('{db_name}') IS NOT NULL \ + BEGIN \ + ALTER DATABASE [{db_name}] SET SINGLE_USER WITH ROLLBACK IMMEDIATE; \ + DROP DATABASE [{db_name}]; \ + END" + ); + conn.execute(AssertSqlSafe(drop_sql)).await?; + query("DELETE FROM _sqlx_test_databases WHERE db_name = @p1") + .bind(db_name) + .execute(&mut *conn) + .await?; + + Ok(()) +} + +fn once_lock_try_insert_polyfill(this: &OnceLock, value: T) -> Result<&T, (&T, T)> { + let mut value = Some(value); + let res = this.get_or_init(|| value.take().unwrap()); + match value { + None => Ok(res), + Some(value) => Err((res, value)), + } +} diff --git a/sqlx-mssql/src/transaction.rs b/sqlx-mssql/src/transaction.rs new file mode 100644 index 0000000000..9c35160070 --- /dev/null +++ b/sqlx-mssql/src/transaction.rs @@ -0,0 +1,126 @@ +use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr, SqlStr}; + +use crate::error::{tiberius_err, Error}; +use crate::executor::Executor; +use crate::{Mssql, MssqlConnection}; + +pub(crate) use sqlx_core::transaction::*; + +/// Implementation of [`TransactionManager`] for MSSQL. +/// +/// MSSQL uses non-ANSI syntax for savepoints: +/// - depth 0 -> `BEGIN TRANSACTION` +/// - depth N -> `SAVE TRANSACTION _sqlx_savepoint_N` +/// - commit depth 1 -> `COMMIT` +/// - commit depth N -> no-op (savepoints auto-commit with parent) +/// - rollback depth 1 -> `ROLLBACK` +/// - rollback depth N -> `ROLLBACK TRANSACTION _sqlx_savepoint_N` +pub struct MssqlTransactionManager; + +impl TransactionManager for MssqlTransactionManager { + type Database = Mssql; + + async fn begin(conn: &mut MssqlConnection, statement: Option) -> Result<(), Error> { + let depth = conn.inner.transaction_depth; + + // Execute any pending rollback first + resolve_pending_rollback(conn).await?; + + let statement = match statement { + Some(_) if depth > 0 => return Err(Error::InvalidSavePointStatement), + Some(statement) => statement, + None => { + if depth == 0 { + SqlStr::from_static("BEGIN TRANSACTION") + } else { + AssertSqlSafe(format!("SAVE TRANSACTION _sqlx_savepoint_{}", depth)).into_sql_str() + } + } + }; + + conn.execute(statement).await?; + conn.inner.transaction_depth += 1; + + Ok(()) + } + + async fn commit(conn: &mut MssqlConnection) -> Result<(), Error> { + let depth = conn.inner.transaction_depth; + + if depth > 0 { + if depth == 1 { + // Only the outermost transaction actually commits + conn.execute("COMMIT").await?; + } + // Savepoints auto-commit with their parent transaction, so no-op for depth > 1 + conn.inner.transaction_depth = depth - 1; + } + + Ok(()) + } + + async fn rollback(conn: &mut MssqlConnection) -> Result<(), Error> { + let depth = conn.inner.transaction_depth; + + if depth > 0 { + if depth == 1 { + conn.execute("ROLLBACK").await?; + } else { + let savepoint = format!( + "ROLLBACK TRANSACTION _sqlx_savepoint_{}", + depth - 1 + ); + conn.execute(AssertSqlSafe(savepoint)).await?; + } + conn.inner.transaction_depth = depth - 1; + } + + Ok(()) + } + + fn start_rollback(conn: &mut MssqlConnection) { + let depth = conn.inner.transaction_depth; + if depth > 0 { + // We can't execute async SQL from a synchronous context (Drop), + // so we set a flag and execute the rollback on the next operation. + conn.inner.pending_rollback = true; + conn.inner.transaction_depth = depth - 1; + } + } + + fn get_transaction_depth(conn: &MssqlConnection) -> usize { + conn.inner.transaction_depth + } +} + +/// Execute pending rollback if one was triggered by `start_rollback`. +pub(crate) async fn resolve_pending_rollback( + conn: &mut MssqlConnection, +) -> Result<(), Error> { + if conn.inner.pending_rollback { + conn.inner.pending_rollback = false; + let depth = conn.inner.transaction_depth; + if depth == 0 { + // Rollback the entire transaction + conn.inner + .client + .simple_query("ROLLBACK") + .await + .map_err(tiberius_err)? + .into_results() + .await + .map_err(tiberius_err)?; + } else { + let savepoint = format!("ROLLBACK TRANSACTION _sqlx_savepoint_{}", depth); + conn.inner + .client + .simple_query(savepoint) + .await + .map_err(tiberius_err)? + .into_results() + .await + .map_err(tiberius_err)?; + } + } + Ok(()) +} diff --git a/sqlx-mssql/src/type_checking.rs b/sqlx-mssql/src/type_checking.rs new file mode 100644 index 0000000000..b3b78b4174 --- /dev/null +++ b/sqlx-mssql/src/type_checking.rs @@ -0,0 +1,35 @@ +// Type mappings used by the macros and `Debug` impls. + +#[allow(unused_imports)] +use sqlx_core as sqlx; + +use crate::Mssql; + +impl_type_checking!( + Mssql { + u8, + i8, + i16, + i32, + i64, + f32, + f64, + + // ordering is important here as otherwise we might infer strings to be binary + // NVARCHAR, VARCHAR, NCHAR, CHAR, NTEXT, TEXT + String, + + // VARBINARY, BINARY, IMAGE + Vec, + }, + ParamChecking::Weak, + feature-types: _info => None, + datetime-types: { + chrono: { }, + time: { }, + }, + numeric-types: { + bigdecimal: { }, + rust_decimal: { }, + }, +); diff --git a/sqlx-mssql/src/type_info.rs b/sqlx-mssql/src/type_info.rs new file mode 100644 index 0000000000..764710bc7e --- /dev/null +++ b/sqlx-mssql/src/type_info.rs @@ -0,0 +1,68 @@ +use std::fmt::{self, Display, Formatter}; + +pub(crate) use sqlx_core::type_info::*; + +/// Type information for a MSSQL type. +#[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] +pub struct MssqlTypeInfo { + pub(crate) name: String, +} + +impl MssqlTypeInfo { + pub(crate) fn new(name: impl Into) -> Self { + Self { name: name.into() } + } +} + +impl Display for MssqlTypeInfo { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.pad(&self.name) + } +} + +impl TypeInfo for MssqlTypeInfo { + fn is_null(&self) -> bool { + false + } + + fn name(&self) -> &str { + &self.name + } +} + +/// Map a tiberius column type to a MSSQL type name string. +pub(crate) fn type_name_for_tiberius(col_type: &tiberius::ColumnType) -> &'static str { + match col_type { + tiberius::ColumnType::Null => "NULL", + tiberius::ColumnType::Bit => "BIT", + tiberius::ColumnType::Int1 => "TINYINT", + tiberius::ColumnType::Int2 => "SMALLINT", + tiberius::ColumnType::Int4 => "INT", + tiberius::ColumnType::Int8 => "BIGINT", + tiberius::ColumnType::Float4 => "REAL", + tiberius::ColumnType::Float8 => "FLOAT", + tiberius::ColumnType::Datetime | tiberius::ColumnType::Datetimen => "DATETIME", + tiberius::ColumnType::Datetime2 => "DATETIME2", + tiberius::ColumnType::Datetime4 | tiberius::ColumnType::DatetimeOffsetn => { + "DATETIMEOFFSET" + } + tiberius::ColumnType::Daten => "DATE", + tiberius::ColumnType::Timen => "TIME", + tiberius::ColumnType::Decimaln | tiberius::ColumnType::Numericn => "DECIMAL", + tiberius::ColumnType::Money | tiberius::ColumnType::Money4 => "MONEY", + tiberius::ColumnType::BigVarChar | tiberius::ColumnType::NVarchar => "NVARCHAR", + tiberius::ColumnType::BigChar | tiberius::ColumnType::NChar => "NCHAR", + tiberius::ColumnType::BigVarBin => "VARBINARY", + tiberius::ColumnType::BigBinary => "BINARY", + tiberius::ColumnType::Text | tiberius::ColumnType::NText => "NTEXT", + tiberius::ColumnType::Image => "IMAGE", + tiberius::ColumnType::Xml => "XML", + tiberius::ColumnType::Guid => "UNIQUEIDENTIFIER", + tiberius::ColumnType::Intn => "INT", + tiberius::ColumnType::Bitn => "BIT", + tiberius::ColumnType::Floatn => "FLOAT", + tiberius::ColumnType::SSVariant => "SQL_VARIANT", + _ => "UNKNOWN", + } +} diff --git a/sqlx-mssql/src/types/bool.rs b/sqlx-mssql/src/types/bool.rs new file mode 100644 index 0000000000..af0cae774f --- /dev/null +++ b/sqlx-mssql/src/types/bool.rs @@ -0,0 +1,41 @@ +use crate::database::MssqlArgumentValue; +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::value::MssqlData; +use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; + +impl Type for bool { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("BIT") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!(ty.name.as_str(), "BIT" | "TINYINT" | "INT" | "SMALLINT" | "BIGINT") + } +} + +impl Encode<'_, Mssql> for bool { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + buf.push(MssqlArgumentValue::Bool(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for bool { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::Bool(v) => Ok(*v), + MssqlData::U8(v) => Ok(*v != 0), + MssqlData::I16(v) => Ok(*v != 0), + MssqlData::I32(v) => Ok(*v != 0), + MssqlData::I64(v) => Ok(*v != 0), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected bool-compatible type, got {:?}", value.data).into()), + } + } +} diff --git a/sqlx-mssql/src/types/bytes.rs b/sqlx-mssql/src/types/bytes.rs new file mode 100644 index 0000000000..a1133e8dc6 --- /dev/null +++ b/sqlx-mssql/src/types/bytes.rs @@ -0,0 +1,73 @@ +use std::borrow::Cow; +use std::rc::Rc; +use std::sync::Arc; + +use crate::database::MssqlArgumentValue; +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; + +fn bytes_compatible(ty: &MssqlTypeInfo) -> bool { + matches!( + ty.name.as_str(), + "VARBINARY" | "BINARY" | "IMAGE" + ) +} + +impl Type for [u8] { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("VARBINARY") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + bytes_compatible(ty) + } +} + +impl Encode<'_, Mssql> for &'_ [u8] { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + buf.push(MssqlArgumentValue::Binary(self.to_vec())); + Ok(IsNull::No) + } +} + +impl<'r> Decode<'r, Mssql> for &'r [u8] { + fn decode(value: MssqlValueRef<'r>) -> Result { + value.as_bytes() + } +} + +impl Type for Vec { + fn type_info() -> MssqlTypeInfo { + <[u8] as Type>::type_info() + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + <[u8] as Type>::compatible(ty) + } +} + +impl Encode<'_, Mssql> for Vec { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + <&[u8] as Encode>::encode(&**self, buf) + } +} + +impl Decode<'_, Mssql> for Vec { + fn decode(value: MssqlValueRef<'_>) -> Result { + <&[u8] as Decode>::decode(value).map(ToOwned::to_owned) + } +} + +forward_encode_impl!(Arc<[u8]>, &[u8], Mssql); +forward_encode_impl!(Rc<[u8]>, &[u8], Mssql); +forward_encode_impl!(Box<[u8]>, &[u8], Mssql); +forward_encode_impl!(Cow<'_, [u8]>, &[u8], Mssql); diff --git a/sqlx-mssql/src/types/float.rs b/sqlx-mssql/src/types/float.rs new file mode 100644 index 0000000000..d0f88bac59 --- /dev/null +++ b/sqlx-mssql/src/types/float.rs @@ -0,0 +1,74 @@ +use crate::database::MssqlArgumentValue; +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::value::MssqlData; +use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; + +fn real_compatible(ty: &MssqlTypeInfo) -> bool { + matches!(ty.name.as_str(), "REAL" | "FLOAT") +} + +impl Type for f32 { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("REAL") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + real_compatible(ty) + } +} + +impl Encode<'_, Mssql> for f32 { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + buf.push(MssqlArgumentValue::F32(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for f32 { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::F32(v) => Ok(*v), + #[allow(clippy::cast_possible_truncation)] + MssqlData::F64(v) => Ok(*v as f32), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected float, got {:?}", value.data).into()), + } + } +} + +impl Type for f64 { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("FLOAT") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + real_compatible(ty) + } +} + +impl Encode<'_, Mssql> for f64 { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + buf.push(MssqlArgumentValue::F64(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for f64 { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::F32(v) => Ok(f64::from(*v)), + MssqlData::F64(v) => Ok(*v), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected float, got {:?}", value.data).into()), + } + } +} diff --git a/sqlx-mssql/src/types/int.rs b/sqlx-mssql/src/types/int.rs new file mode 100644 index 0000000000..92acd6916d --- /dev/null +++ b/sqlx-mssql/src/types/int.rs @@ -0,0 +1,188 @@ +use crate::database::MssqlArgumentValue; +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::value::MssqlData; +use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; + +fn int_compatible(ty: &MssqlTypeInfo) -> bool { + matches!( + ty.name.as_str(), + "TINYINT" | "SMALLINT" | "INT" | "BIGINT" + ) +} + +// u8 - MSSQL's TINYINT is unsigned (0-255) +impl Type for u8 { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("TINYINT") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + int_compatible(ty) + } +} + +impl Encode<'_, Mssql> for u8 { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + buf.push(MssqlArgumentValue::U8(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for u8 { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::U8(v) => Ok(*v), + MssqlData::I16(v) => Ok((*v).try_into()?), + MssqlData::I32(v) => Ok((*v).try_into()?), + MssqlData::I64(v) => Ok((*v).try_into()?), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected integer, got {:?}", value.data).into()), + } + } +} + +// i8 - maps to TINYINT but only 0-127 range +impl Type for i8 { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("TINYINT") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + int_compatible(ty) + } +} + +impl Encode<'_, Mssql> for i8 { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + if *self < 0 { + return Err("MSSQL TINYINT is unsigned; cannot encode negative i8".into()); + } + #[allow(clippy::cast_sign_loss)] + buf.push(MssqlArgumentValue::U8(*self as u8)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for i8 { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::U8(v) => Ok((*v).try_into()?), + MssqlData::I16(v) => Ok((*v).try_into()?), + MssqlData::I32(v) => Ok((*v).try_into()?), + MssqlData::I64(v) => Ok((*v).try_into()?), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected integer, got {:?}", value.data).into()), + } + } +} + +// i16 +impl Type for i16 { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("SMALLINT") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + int_compatible(ty) + } +} + +impl Encode<'_, Mssql> for i16 { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + buf.push(MssqlArgumentValue::I16(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for i16 { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::U8(v) => Ok(i16::from(*v)), + MssqlData::I16(v) => Ok(*v), + MssqlData::I32(v) => Ok((*v).try_into()?), + MssqlData::I64(v) => Ok((*v).try_into()?), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected integer, got {:?}", value.data).into()), + } + } +} + +// i32 +impl Type for i32 { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("INT") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + int_compatible(ty) + } +} + +impl Encode<'_, Mssql> for i32 { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + buf.push(MssqlArgumentValue::I32(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for i32 { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::U8(v) => Ok(i32::from(*v)), + MssqlData::I16(v) => Ok(i32::from(*v)), + MssqlData::I32(v) => Ok(*v), + MssqlData::I64(v) => Ok((*v).try_into()?), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected integer, got {:?}", value.data).into()), + } + } +} + +// i64 +impl Type for i64 { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("BIGINT") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + int_compatible(ty) + } +} + +impl Encode<'_, Mssql> for i64 { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + buf.push(MssqlArgumentValue::I64(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for i64 { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::U8(v) => Ok(i64::from(*v)), + MssqlData::I16(v) => Ok(i64::from(*v)), + MssqlData::I32(v) => Ok(i64::from(*v)), + MssqlData::I64(v) => Ok(*v), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected integer, got {:?}", value.data).into()), + } + } +} diff --git a/sqlx-mssql/src/types/mod.rs b/sqlx-mssql/src/types/mod.rs new file mode 100644 index 0000000000..9bca3e1e90 --- /dev/null +++ b/sqlx-mssql/src/types/mod.rs @@ -0,0 +1,29 @@ +//! Conversions between Rust and **MSSQL** types. +//! +//! # Types +//! +//! | Rust type | MSSQL type(s) | +//! |---------------------------------------|------------------------------------------------------| +//! | `bool` | BIT | +//! | `u8` | TINYINT (unsigned, 0-255) | +//! | `i8` | TINYINT (0-127 only) | +//! | `i16` | SMALLINT | +//! | `i32` | INT | +//! | `i64` | BIGINT | +//! | `f32` | REAL | +//! | `f64` | FLOAT | +//! | `&str`, [`String`] | NVARCHAR | +//! | `&[u8]`, `Vec` | VARBINARY | +//! +//! # Nullable +//! +//! In addition, `Option` is supported where `T` implements `Type`. An `Option` represents +//! a potentially `NULL` value from MSSQL. + +pub(crate) use sqlx_core::types::*; + +mod bool; +mod bytes; +mod float; +mod int; +mod str; diff --git a/sqlx-mssql/src/types/str.rs b/sqlx-mssql/src/types/str.rs new file mode 100644 index 0000000000..4995160a9b --- /dev/null +++ b/sqlx-mssql/src/types/str.rs @@ -0,0 +1,73 @@ +use std::borrow::Cow; +use std::rc::Rc; +use std::sync::Arc; + +use crate::database::MssqlArgumentValue; +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; + +fn str_compatible(ty: &MssqlTypeInfo) -> bool { + matches!( + ty.name.as_str(), + "NVARCHAR" | "VARCHAR" | "NCHAR" | "CHAR" | "NTEXT" | "TEXT" | "XML" + ) +} + +impl Type for str { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("NVARCHAR") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + str_compatible(ty) + } +} + +impl Encode<'_, Mssql> for &'_ str { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + buf.push(MssqlArgumentValue::String((*self).to_owned())); + Ok(IsNull::No) + } +} + +impl<'r> Decode<'r, Mssql> for &'r str { + fn decode(value: MssqlValueRef<'r>) -> Result { + value.as_str() + } +} + +impl Type for String { + fn type_info() -> MssqlTypeInfo { + >::type_info() + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + >::compatible(ty) + } +} + +impl Encode<'_, Mssql> for String { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + <&str as Encode>::encode(self.as_str(), buf) + } +} + +impl Decode<'_, Mssql> for String { + fn decode(value: MssqlValueRef<'_>) -> Result { + <&str as Decode>::decode(value).map(ToOwned::to_owned) + } +} + +forward_encode_impl!(Arc, &str, Mssql); +forward_encode_impl!(Rc, &str, Mssql); +forward_encode_impl!(Cow<'_, str>, &str, Mssql); +forward_encode_impl!(Box, &str, Mssql); diff --git a/sqlx-mssql/src/value.rs b/sqlx-mssql/src/value.rs new file mode 100644 index 0000000000..ae9c4cdd97 --- /dev/null +++ b/sqlx-mssql/src/value.rs @@ -0,0 +1,109 @@ +use std::borrow::Cow; + +pub(crate) use sqlx_core::value::*; + +use crate::error::BoxDynError; +use crate::{Mssql, MssqlTypeInfo}; + +/// Internal storage for an MSSQL value, decoupled from tiberius lifetimes. +#[derive(Debug, Clone)] +pub(crate) enum MssqlData { + Null, + Bool(bool), + U8(u8), + I16(i16), + I32(i32), + I64(i64), + F32(f32), + F64(f64), + String(String), + Binary(Vec), +} + +/// Implementation of [`Value`] for MSSQL. +#[derive(Debug, Clone)] +pub struct MssqlValue { + pub(crate) data: MssqlData, + pub(crate) type_info: MssqlTypeInfo, +} + +/// Implementation of [`ValueRef`] for MSSQL. +#[derive(Debug, Clone)] +pub struct MssqlValueRef<'r> { + pub(crate) data: &'r MssqlData, + pub(crate) type_info: MssqlTypeInfo, +} + +impl<'r> MssqlValueRef<'r> { + pub(crate) fn as_str(&self) -> Result<&'r str, BoxDynError> { + match self.data { + MssqlData::String(ref s) => Ok(s.as_str()), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected string, got {:?}", self.data).into()), + } + } + + pub(crate) fn as_bytes(&self) -> Result<&'r [u8], BoxDynError> { + match self.data { + MssqlData::Binary(ref b) => Ok(b.as_slice()), + MssqlData::String(ref s) => Ok(s.as_bytes()), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected binary, got {:?}", self.data).into()), + } + } +} + +impl Value for MssqlValue { + type Database = Mssql; + + fn as_ref(&self) -> MssqlValueRef<'_> { + MssqlValueRef { + data: &self.data, + type_info: self.type_info.clone(), + } + } + + fn type_info(&self) -> Cow<'_, MssqlTypeInfo> { + Cow::Borrowed(&self.type_info) + } + + fn is_null(&self) -> bool { + matches!(self.data, MssqlData::Null) + } +} + +impl<'r> ValueRef<'r> for MssqlValueRef<'r> { + type Database = Mssql; + + fn to_owned(&self) -> MssqlValue { + MssqlValue { + data: self.data.clone(), + type_info: self.type_info.clone(), + } + } + + fn type_info(&self) -> Cow<'_, MssqlTypeInfo> { + Cow::Borrowed(&self.type_info) + } + + fn is_null(&self) -> bool { + matches!(self.data, MssqlData::Null) + } +} + +/// Convert a `tiberius::ColumnData` into our owned `MssqlData`. +pub(crate) fn column_data_to_mssql_data(data: &tiberius::ColumnData<'_>) -> MssqlData { + match data { + tiberius::ColumnData::U8(Some(v)) => MssqlData::U8(*v), + tiberius::ColumnData::I16(Some(v)) => MssqlData::I16(*v), + tiberius::ColumnData::I32(Some(v)) => MssqlData::I32(*v), + tiberius::ColumnData::I64(Some(v)) => MssqlData::I64(*v), + tiberius::ColumnData::F32(Some(v)) => MssqlData::F32(*v), + tiberius::ColumnData::F64(Some(v)) => MssqlData::F64(*v), + tiberius::ColumnData::Bit(Some(v)) => MssqlData::Bool(*v), + tiberius::ColumnData::String(Some(v)) => MssqlData::String(v.to_string()), + tiberius::ColumnData::Binary(Some(v)) => MssqlData::Binary(v.to_vec()), + // All None variants and unhandled types map to Null + _ => MssqlData::Null, + } +} diff --git a/src/any/mod.rs b/src/any/mod.rs index 434d255573..fa6a8f5498 100644 --- a/src/any/mod.rs +++ b/src/any/mod.rs @@ -37,6 +37,8 @@ pub fn install_default_drivers() { ONCE.call_once(|| { install_drivers(&[ + #[cfg(feature = "mssql")] + sqlx_mssql::any::DRIVER, #[cfg(feature = "mysql")] sqlx_mysql::any::DRIVER, #[cfg(feature = "postgres")] diff --git a/src/lib.rs b/src/lib.rs index 438463210d..e0cd7dd164 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -44,6 +44,13 @@ pub use sqlx_core::error::{self, Error, Result}; #[cfg(feature = "migrate")] pub use sqlx_core::migrate; +#[cfg(feature = "mssql")] +#[cfg_attr(docsrs, doc(cfg(feature = "mssql")))] +#[doc(inline)] +pub use sqlx_mssql::{ + self as mssql, Mssql, MssqlConnection, MssqlExecutor, MssqlPool, MssqlTransaction, +}; + #[cfg(feature = "mysql")] #[cfg_attr(docsrs, doc(cfg(feature = "mysql")))] #[doc(inline)] diff --git a/tests/mssql/describe.rs b/tests/mssql/describe.rs index 3717829b41..64f97102d9 100644 --- a/tests/mssql/describe.rs +++ b/tests/mssql/describe.rs @@ -1,12 +1,12 @@ use sqlx::mssql::Mssql; -use sqlx::{Column, Executor, TypeInfo}; +use sqlx::{Column, Executor, SqlSafeStr, TypeInfo}; use sqlx_test::new; #[sqlx_macros::test] async fn it_describes_simple() -> anyhow::Result<()> { let mut conn = new::().await?; - let d = conn.describe("SELECT * FROM tweet").await?; + let d = conn.describe("SELECT * FROM tweet".into_sql_str()).await?; assert_eq!(d.columns()[0].name(), "id"); assert_eq!(d.columns()[1].name(), "text"); @@ -31,7 +31,7 @@ async fn it_describes_with_params() -> anyhow::Result<()> { let mut conn = new::().await?; let d = conn - .describe("SELECT text FROM tweet WHERE id = @p1") + .describe("SELECT text FROM tweet WHERE id = @p1".into_sql_str()) .await?; assert_eq!(d.columns()[0].name(), "text"); diff --git a/tests/mssql/mssql.rs b/tests/mssql/mssql.rs index 0986ef1bbd..2aa26d3e8e 100644 --- a/tests/mssql/mssql.rs +++ b/tests/mssql/mssql.rs @@ -1,7 +1,7 @@ use futures_util::TryStreamExt; use sqlx::mssql::{Mssql, MssqlPoolOptions}; -use sqlx::{Column, Connection, Executor, MssqlConnection, Row, Statement, TypeInfo}; -use sqlx_core::mssql::MssqlRow; +use sqlx::{Column, Connection, Executor, MssqlConnection, Row, SqlSafeStr, Statement, TypeInfo}; +use sqlx::mssql::MssqlRow; use sqlx_test::new; use std::sync::atomic::{AtomicI32, Ordering}; use std::time::Duration; @@ -195,9 +195,9 @@ async fn it_can_work_with_transactions() -> anyhow::Result<()> { let mut tx = conn.begin().await?; - sqlx::query("INSERT INTO _sqlx_users_1922 (id) VALUES ($1)") + sqlx::query("INSERT INTO _sqlx_users_1922 (id) VALUES (@p1)") .bind(10_i32) - .execute(&mut tx) + .execute(&mut *tx) .await?; tx.rollback().await?; @@ -214,7 +214,7 @@ async fn it_can_work_with_transactions() -> anyhow::Result<()> { sqlx::query("INSERT INTO _sqlx_users_1922 (id) VALUES (@p1)") .bind(10_i32) - .execute(&mut tx) + .execute(&mut *tx) .await?; tx.commit().await?; @@ -232,7 +232,7 @@ async fn it_can_work_with_transactions() -> anyhow::Result<()> { sqlx::query("INSERT INTO _sqlx_users_1922 (id) VALUES (@p1)") .bind(20_i32) - .execute(&mut tx) + .execute(&mut *tx) .await?; } @@ -262,7 +262,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // insert a user sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES (@p1)") .bind(50_i32) - .execute(&mut tx) + .execute(&mut *tx) .await?; // begin once more @@ -271,7 +271,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // insert another user sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES (@p1)") .bind(10_i32) - .execute(&mut tx2) + .execute(&mut *tx2) .await?; // never mind, rollback @@ -279,7 +279,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // did we really? let (count,): (i32,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523") - .fetch_one(&mut tx) + .fetch_one(&mut *tx) .await?; assert_eq!(count, 1); @@ -305,10 +305,10 @@ async fn it_can_prepare_then_execute() -> anyhow::Result<()> { let tweet_id: i64 = sqlx::query_scalar( "INSERT INTO tweet ( id, text ) OUTPUT INSERTED.id VALUES ( 50, 'Hello, World' )", ) - .fetch_one(&mut tx) + .fetch_one(&mut *tx) .await?; - let statement = tx.prepare("SELECT * FROM tweet WHERE id = @p1").await?; + let statement = tx.prepare("SELECT * FROM tweet WHERE id = @p1".into_sql_str()).await?; assert_eq!(statement.column(0).name(), "id"); assert_eq!(statement.column(1).name(), "text"); @@ -320,7 +320,7 @@ async fn it_can_prepare_then_execute() -> anyhow::Result<()> { assert_eq!(statement.column(2).type_info().name(), "TINYINT"); assert_eq!(statement.column(3).type_info().name(), "BIGINT"); - let row = statement.query().bind(tweet_id).fetch_one(&mut tx).await?; + let row = statement.query().bind(tweet_id).fetch_one(&mut *tx).await?; let tweet_text: String = row.try_get("text")?; assert_eq!(tweet_text, "Hello, World"); @@ -359,7 +359,7 @@ async fn test_pool_callbacks() -> anyhow::Result<()> { CREATE TABLE #conn_stats( id int primary key, before_acquire_calls int default 0, - after_release_calls int default 0 + after_release_calls int default 0 ); INSERT INTO #conn_stats(id) VALUES ({}); "#, @@ -367,7 +367,7 @@ async fn test_pool_callbacks() -> anyhow::Result<()> { id ); - conn.execute(&statement[..]).await?; + conn.execute(sqlx::AssertSqlSafe(statement)).await?; Ok(()) }) }) @@ -380,7 +380,7 @@ async fn test_pool_callbacks() -> anyhow::Result<()> { // MSSQL doesn't support UPDATE ... RETURNING either sqlx::query( r#" - UPDATE #conn_stats + UPDATE #conn_stats SET before_acquire_calls = before_acquire_calls + 1 "#, ) @@ -404,7 +404,7 @@ async fn test_pool_callbacks() -> anyhow::Result<()> { Box::pin(async move { sqlx::query( r#" - UPDATE #conn_stats + UPDATE #conn_stats SET after_release_calls = after_release_calls + 1 "#, ) From 6dc8265d5a507d2a1c48802da0d16a3ac8a337fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Sun, 15 Feb 2026 01:23:52 -0500 Subject: [PATCH 02/33] feat: add chrono datetime Encode/Decode support to sqlx-mssql Enable native chrono type handling for MSSQL datetime columns, eliminating the need for CONVERT(VARCHAR) workarounds. Supports NaiveDateTime, NaiveDate, NaiveTime, and DateTime by converting between tiberius internal datetime structs and chrono types. Author: Pablo Carrera --- Cargo.lock | 1 + sqlx-mssql/Cargo.toml | 2 +- sqlx-mssql/src/connection/executor.rs | 12 +++ sqlx-mssql/src/database.rs | 6 ++ sqlx-mssql/src/type_checking.rs | 7 +- sqlx-mssql/src/types/chrono.rs | 137 ++++++++++++++++++++++++++ sqlx-mssql/src/types/mod.rs | 2 + sqlx-mssql/src/value.rs | 63 ++++++++++++ 8 files changed, 228 insertions(+), 2 deletions(-) create mode 100644 sqlx-mssql/src/types/chrono.rs diff --git a/Cargo.lock b/Cargo.lock index 2a4075bff7..ffe1e893ee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4316,6 +4316,7 @@ dependencies = [ "asynchronous-codec", "byteorder", "bytes", + "chrono", "connection-string", "encoding_rs", "enumflags2", diff --git a/sqlx-mssql/Cargo.toml b/sqlx-mssql/Cargo.toml index 61cb280d9f..90fdc7cdf8 100644 --- a/sqlx-mssql/Cargo.toml +++ b/sqlx-mssql/Cargo.toml @@ -17,7 +17,7 @@ migrate = ["sqlx-core/migrate"] # Type Integration features bigdecimal = ["dep:bigdecimal", "sqlx-core/bigdecimal"] -chrono = ["dep:chrono", "sqlx-core/chrono"] +chrono = ["dep:chrono", "sqlx-core/chrono", "tiberius/chrono"] rust_decimal = ["dep:rust_decimal", "sqlx-core/rust_decimal"] time = ["dep:time", "sqlx-core/time"] uuid = ["dep:uuid", "sqlx-core/uuid"] diff --git a/sqlx-mssql/src/connection/executor.rs b/sqlx-mssql/src/connection/executor.rs index bad31eb13b..d6301a729b 100644 --- a/sqlx-mssql/src/connection/executor.rs +++ b/sqlx-mssql/src/connection/executor.rs @@ -73,6 +73,18 @@ impl MssqlConnection { MssqlArgumentValue::Binary(v) => { query.bind(v.as_slice()); } + #[cfg(feature = "chrono")] + MssqlArgumentValue::NaiveDateTime(v) => { + query.bind(*v); + } + #[cfg(feature = "chrono")] + MssqlArgumentValue::NaiveDate(v) => { + query.bind(*v); + } + #[cfg(feature = "chrono")] + MssqlArgumentValue::NaiveTime(v) => { + query.bind(*v); + } } } diff --git a/sqlx-mssql/src/database.rs b/sqlx-mssql/src/database.rs index 45a5ece196..d8cf8e05c5 100644 --- a/sqlx-mssql/src/database.rs +++ b/sqlx-mssql/src/database.rs @@ -53,4 +53,10 @@ pub enum MssqlArgumentValue { F64(f64), String(String), Binary(Vec), + #[cfg(feature = "chrono")] + NaiveDateTime(chrono::NaiveDateTime), + #[cfg(feature = "chrono")] + NaiveDate(chrono::NaiveDate), + #[cfg(feature = "chrono")] + NaiveTime(chrono::NaiveTime), } diff --git a/sqlx-mssql/src/type_checking.rs b/sqlx-mssql/src/type_checking.rs index b3b78b4174..78365e5ec3 100644 --- a/sqlx-mssql/src/type_checking.rs +++ b/sqlx-mssql/src/type_checking.rs @@ -25,7 +25,12 @@ impl_type_checking!( ParamChecking::Weak, feature-types: _info => None, datetime-types: { - chrono: { }, + chrono: { + sqlx::types::chrono::NaiveTime, + sqlx::types::chrono::NaiveDate, + sqlx::types::chrono::NaiveDateTime, + sqlx::types::chrono::DateTime, + }, time: { }, }, numeric-types: { diff --git a/sqlx-mssql/src/types/chrono.rs b/sqlx-mssql/src/types/chrono.rs new file mode 100644 index 0000000000..62ba995b12 --- /dev/null +++ b/sqlx-mssql/src/types/chrono.rs @@ -0,0 +1,137 @@ +use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; + +use crate::database::MssqlArgumentValue; +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::value::MssqlData; +use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; + +// ── NaiveDateTime ─────────────────────────────────────────────────────────── + +impl Type for NaiveDateTime { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("DATETIME2") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!( + ty.name.as_str(), + "DATETIME2" | "DATETIME" | "SMALLDATETIME" + ) + } +} + +impl Encode<'_, Mssql> for NaiveDateTime { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + buf.push(MssqlArgumentValue::NaiveDateTime(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for NaiveDateTime { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::NaiveDateTime(v) => Ok(*v), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected datetime, got {:?}", value.data).into()), + } + } +} + +// ── NaiveDate ─────────────────────────────────────────────────────────────── + +impl Type for NaiveDate { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("DATE") + } +} + +impl Encode<'_, Mssql> for NaiveDate { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + buf.push(MssqlArgumentValue::NaiveDate(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for NaiveDate { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::NaiveDate(v) => Ok(*v), + MssqlData::NaiveDateTime(v) => Ok(v.date()), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected date, got {:?}", value.data).into()), + } + } +} + +// ── NaiveTime ─────────────────────────────────────────────────────────────── + +impl Type for NaiveTime { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("TIME") + } +} + +impl Encode<'_, Mssql> for NaiveTime { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + buf.push(MssqlArgumentValue::NaiveTime(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for NaiveTime { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::NaiveTime(v) => Ok(*v), + MssqlData::NaiveDateTime(v) => Ok(v.time()), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected time, got {:?}", value.data).into()), + } + } +} + +// ── DateTime ─────────────────────────────────────────────────────────── + +impl Type for DateTime { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("DATETIME2") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!( + ty.name.as_str(), + "DATETIME2" | "DATETIMEOFFSET" + ) + } +} + +impl Encode<'_, Mssql> for DateTime { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + buf.push(MssqlArgumentValue::NaiveDateTime(self.naive_utc())); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for DateTime { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::NaiveDateTime(v) => Ok(v.and_utc()), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected datetime, got {:?}", value.data).into()), + } + } +} diff --git a/sqlx-mssql/src/types/mod.rs b/sqlx-mssql/src/types/mod.rs index 9bca3e1e90..bcec206b1e 100644 --- a/sqlx-mssql/src/types/mod.rs +++ b/sqlx-mssql/src/types/mod.rs @@ -24,6 +24,8 @@ pub(crate) use sqlx_core::types::*; mod bool; mod bytes; +#[cfg(feature = "chrono")] +mod chrono; mod float; mod int; mod str; diff --git a/sqlx-mssql/src/value.rs b/sqlx-mssql/src/value.rs index ae9c4cdd97..052e393b67 100644 --- a/sqlx-mssql/src/value.rs +++ b/sqlx-mssql/src/value.rs @@ -18,6 +18,12 @@ pub(crate) enum MssqlData { F64(f64), String(String), Binary(Vec), + #[cfg(feature = "chrono")] + NaiveDateTime(chrono::NaiveDateTime), + #[cfg(feature = "chrono")] + NaiveDate(chrono::NaiveDate), + #[cfg(feature = "chrono")] + NaiveTime(chrono::NaiveTime), } /// Implementation of [`Value`] for MSSQL. @@ -103,7 +109,64 @@ pub(crate) fn column_data_to_mssql_data(data: &tiberius::ColumnData<'_>) -> Mssq tiberius::ColumnData::Bit(Some(v)) => MssqlData::Bool(*v), tiberius::ColumnData::String(Some(v)) => MssqlData::String(v.to_string()), tiberius::ColumnData::Binary(Some(v)) => MssqlData::Binary(v.to_vec()), + + #[cfg(feature = "chrono")] + tiberius::ColumnData::DateTime2(Some(dt2)) => { + let date = chrono_date_from_days(dt2.date().days() as i64, 1); + let ns = dt2.time().increments() as i64 + * 10i64.pow(9u32.saturating_sub(dt2.time().scale() as u32)); + let time = chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap() + + chrono::Duration::nanoseconds(ns); + MssqlData::NaiveDateTime(chrono::NaiveDateTime::new(date, time)) + } + #[cfg(feature = "chrono")] + tiberius::ColumnData::DateTime(Some(dt)) => { + let date = chrono_date_from_days(dt.days() as i64, 1900); + let ns = dt.seconds_fragments() as i64 * 1_000_000_000i64 / 300; + let time = chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap() + + chrono::Duration::nanoseconds(ns); + MssqlData::NaiveDateTime(chrono::NaiveDateTime::new(date, time)) + } + #[cfg(feature = "chrono")] + tiberius::ColumnData::SmallDateTime(Some(dt)) => { + let date = chrono_date_from_days(dt.days() as i64, 1900); + let seconds = dt.seconds_fragments() as u32 * 60; + let time = + chrono::NaiveTime::from_num_seconds_from_midnight_opt(seconds, 0).unwrap(); + MssqlData::NaiveDateTime(chrono::NaiveDateTime::new(date, time)) + } + #[cfg(feature = "chrono")] + tiberius::ColumnData::Date(Some(d)) => { + MssqlData::NaiveDate(chrono_date_from_days(d.days() as i64, 1)) + } + #[cfg(feature = "chrono")] + tiberius::ColumnData::Time(Some(t)) => { + let ns = + t.increments() as i64 * 10i64.pow(9u32.saturating_sub(t.scale() as u32)); + let time = chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap() + + chrono::Duration::nanoseconds(ns); + MssqlData::NaiveTime(time) + } + #[cfg(feature = "chrono")] + tiberius::ColumnData::DateTimeOffset(Some(dto)) => { + let date = chrono_date_from_days(dto.datetime2().date().days() as i64, 1); + let ns = dto.datetime2().time().increments() as i64 + * 10i64.pow(9u32.saturating_sub(dto.datetime2().time().scale() as u32)); + let time = chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap() + + chrono::Duration::nanoseconds(ns); + // Subtract the offset to convert to UTC + let naive = chrono::NaiveDateTime::new(date, time) + - chrono::Duration::minutes(dto.offset() as i64); + MssqlData::NaiveDateTime(naive) + } + // All None variants and unhandled types map to Null _ => MssqlData::Null, } } + +/// Convert days since `start_year`-01-01 to a `chrono::NaiveDate`. +#[cfg(feature = "chrono")] +fn chrono_date_from_days(days: i64, start_year: i32) -> chrono::NaiveDate { + chrono::NaiveDate::from_ymd_opt(start_year, 1, 1).unwrap() + chrono::Duration::days(days) +} From fd9e436c8b37f7ab09d5dc176e8fb9c9232e48a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Sun, 15 Feb 2026 02:12:53 -0500 Subject: [PATCH 03/33] feat: add UUID and rust_decimal Encode/Decode support to sqlx-mssql Enable UNIQUEIDENTIFIER columns via uuid::Uuid and DECIMAL/NUMERIC/MONEY columns via rust_decimal::Decimal, both feature-gated behind their respective cargo features. Author: Pablo Carrera --- Cargo.lock | 1 + sqlx-mssql/Cargo.toml | 2 +- sqlx-mssql/src/connection/executor.rs | 19 +++++++++ sqlx-mssql/src/database.rs | 4 ++ sqlx-mssql/src/type_checking.rs | 7 +++- sqlx-mssql/src/types/mod.rs | 11 +++++ sqlx-mssql/src/types/rust_decimal.rs | 46 ++++++++++++++++++++ sqlx-mssql/src/types/uuid.rs | 60 +++++++++++++++++++++++++++ sqlx-mssql/src/value.rs | 15 +++++++ 9 files changed, 163 insertions(+), 2 deletions(-) create mode 100644 sqlx-mssql/src/types/rust_decimal.rs create mode 100644 sqlx-mssql/src/types/uuid.rs diff --git a/Cargo.lock b/Cargo.lock index ffe1e893ee..1a18141114 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4325,6 +4325,7 @@ dependencies = [ "once_cell", "pin-project-lite", "pretty-hex", + "rust_decimal", "thiserror 1.0.69", "tracing", "uuid", diff --git a/sqlx-mssql/Cargo.toml b/sqlx-mssql/Cargo.toml index 90fdc7cdf8..cbff784593 100644 --- a/sqlx-mssql/Cargo.toml +++ b/sqlx-mssql/Cargo.toml @@ -18,7 +18,7 @@ migrate = ["sqlx-core/migrate"] # Type Integration features bigdecimal = ["dep:bigdecimal", "sqlx-core/bigdecimal"] chrono = ["dep:chrono", "sqlx-core/chrono", "tiberius/chrono"] -rust_decimal = ["dep:rust_decimal", "sqlx-core/rust_decimal"] +rust_decimal = ["dep:rust_decimal", "sqlx-core/rust_decimal", "tiberius/rust_decimal"] time = ["dep:time", "sqlx-core/time"] uuid = ["dep:uuid", "sqlx-core/uuid"] diff --git a/sqlx-mssql/src/connection/executor.rs b/sqlx-mssql/src/connection/executor.rs index d6301a729b..7d2c73cdf8 100644 --- a/sqlx-mssql/src/connection/executor.rs +++ b/sqlx-mssql/src/connection/executor.rs @@ -85,6 +85,25 @@ impl MssqlConnection { MssqlArgumentValue::NaiveTime(v) => { query.bind(*v); } + #[cfg(feature = "uuid")] + MssqlArgumentValue::Uuid(v) => { + query.bind(v); + } + #[cfg(feature = "rust_decimal")] + MssqlArgumentValue::Decimal(v) => { + let unpacked = v.unpack(); + let mut value = (((unpacked.hi as u128) << 64) + + ((unpacked.mid as u128) << 32) + + unpacked.lo as u128) + as i128; + if v.is_sign_negative() { + value = -value; + } + query.bind(tiberius::numeric::Numeric::new_with_scale( + value, + v.scale() as u8, + )); + } } } diff --git a/sqlx-mssql/src/database.rs b/sqlx-mssql/src/database.rs index d8cf8e05c5..8f57323ce4 100644 --- a/sqlx-mssql/src/database.rs +++ b/sqlx-mssql/src/database.rs @@ -59,4 +59,8 @@ pub enum MssqlArgumentValue { NaiveDate(chrono::NaiveDate), #[cfg(feature = "chrono")] NaiveTime(chrono::NaiveTime), + #[cfg(feature = "uuid")] + Uuid(uuid::Uuid), + #[cfg(feature = "rust_decimal")] + Decimal(rust_decimal::Decimal), } diff --git a/sqlx-mssql/src/type_checking.rs b/sqlx-mssql/src/type_checking.rs index 78365e5ec3..86dc7f621a 100644 --- a/sqlx-mssql/src/type_checking.rs +++ b/sqlx-mssql/src/type_checking.rs @@ -21,6 +21,9 @@ impl_type_checking!( // VARBINARY, BINARY, IMAGE Vec, + + #[cfg(feature = "uuid")] + sqlx::types::Uuid, }, ParamChecking::Weak, feature-types: _info => None, @@ -35,6 +38,8 @@ impl_type_checking!( }, numeric-types: { bigdecimal: { }, - rust_decimal: { }, + rust_decimal: { + sqlx::types::Decimal, + }, }, ); diff --git a/sqlx-mssql/src/types/mod.rs b/sqlx-mssql/src/types/mod.rs index bcec206b1e..d487f902c0 100644 --- a/sqlx-mssql/src/types/mod.rs +++ b/sqlx-mssql/src/types/mod.rs @@ -15,6 +15,13 @@ //! | `&str`, [`String`] | NVARCHAR | //! | `&[u8]`, `Vec` | VARBINARY | //! +//! ### Feature-gated +//! +//! | Rust type | MSSQL type(s) | +//! |---------------------------------------|------------------------------------------------------| +//! | `uuid::Uuid` | UNIQUEIDENTIFIER | +//! | `rust_decimal::Decimal` | DECIMAL, NUMERIC, MONEY | +//! //! # Nullable //! //! In addition, `Option` is supported where `T` implements `Type`. An `Option` represents @@ -28,4 +35,8 @@ mod bytes; mod chrono; mod float; mod int; +#[cfg(feature = "rust_decimal")] +mod rust_decimal; mod str; +#[cfg(feature = "uuid")] +mod uuid; diff --git a/sqlx-mssql/src/types/rust_decimal.rs b/sqlx-mssql/src/types/rust_decimal.rs new file mode 100644 index 0000000000..c167f28769 --- /dev/null +++ b/sqlx-mssql/src/types/rust_decimal.rs @@ -0,0 +1,46 @@ +use rust_decimal::Decimal; + +use crate::database::MssqlArgumentValue; +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::value::MssqlData; +use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; + +impl Type for Decimal { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("DECIMAL") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!(ty.name.as_str(), "DECIMAL" | "NUMERIC" | "MONEY") + } +} + +impl Encode<'_, Mssql> for Decimal { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + buf.push(MssqlArgumentValue::Decimal(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for Decimal { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::Decimal(v) => Ok(*v), + MssqlData::I32(v) => Ok(Decimal::from(*v)), + MssqlData::I64(v) => Ok(Decimal::from(*v)), + MssqlData::F64(v) => Decimal::try_from(*v) + .map_err(|e| format!("failed to convert f64 to Decimal: {e}").into()), + MssqlData::String(ref s) => s + .parse::() + .map_err(|e| format!("failed to parse Decimal from string: {e}").into()), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected DECIMAL, got {:?}", value.data).into()), + } + } +} diff --git a/sqlx-mssql/src/types/uuid.rs b/sqlx-mssql/src/types/uuid.rs new file mode 100644 index 0000000000..f36d259590 --- /dev/null +++ b/sqlx-mssql/src/types/uuid.rs @@ -0,0 +1,60 @@ +use uuid::Uuid; + +use crate::database::MssqlArgumentValue; +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::value::MssqlData; +use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; + +impl Type for Uuid { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("UNIQUEIDENTIFIER") + } +} + +impl Encode<'_, Mssql> for Uuid { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + buf.push(MssqlArgumentValue::Uuid(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for Uuid { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::Uuid(v) => Ok(*v), + MssqlData::String(ref s) => Ok(Uuid::parse_str(s)?), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected UNIQUEIDENTIFIER, got {:?}", value.data).into()), + } + } +} + +impl Type for uuid::fmt::Hyphenated { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("UNIQUEIDENTIFIER") + } +} + +impl Encode<'_, Mssql> for uuid::fmt::Hyphenated { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + let uuid = Uuid::parse_str(&self.to_string())?; + buf.push(MssqlArgumentValue::Uuid(uuid)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for uuid::fmt::Hyphenated { + fn decode(value: MssqlValueRef<'_>) -> Result { + let uuid = Uuid::decode(value)?; + Ok(uuid.hyphenated()) + } +} diff --git a/sqlx-mssql/src/value.rs b/sqlx-mssql/src/value.rs index 052e393b67..67930cf244 100644 --- a/sqlx-mssql/src/value.rs +++ b/sqlx-mssql/src/value.rs @@ -24,6 +24,10 @@ pub(crate) enum MssqlData { NaiveDate(chrono::NaiveDate), #[cfg(feature = "chrono")] NaiveTime(chrono::NaiveTime), + #[cfg(feature = "uuid")] + Uuid(uuid::Uuid), + #[cfg(feature = "rust_decimal")] + Decimal(rust_decimal::Decimal), } /// Implementation of [`Value`] for MSSQL. @@ -160,6 +164,17 @@ pub(crate) fn column_data_to_mssql_data(data: &tiberius::ColumnData<'_>) -> Mssq MssqlData::NaiveDateTime(naive) } + #[cfg(feature = "uuid")] + tiberius::ColumnData::Guid(Some(v)) => MssqlData::Uuid(*v), + + #[cfg(feature = "rust_decimal")] + tiberius::ColumnData::Numeric(Some(n)) => { + MssqlData::Decimal(rust_decimal::Decimal::from_i128_with_scale( + n.value(), + n.scale() as u32, + )) + } + // All None variants and unhandled types map to Null _ => MssqlData::Null, } From 04f556dfb8f22e3b3c06bb50116bd2f77b6e317e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Sun, 15 Feb 2026 02:35:06 -0500 Subject: [PATCH 04/33] feat: add time, bigdecimal, and JSON Encode/Decode support to sqlx-mssql Wire up tiberius/time and tiberius/bigdecimal feature flags and implement Type/Encode/Decode for time::{Date, Time, PrimitiveDateTime, OffsetDateTime}, bigdecimal::BigDecimal, and Json. Uses a ColumnDataWrapper newtype to bridge tiberius's IntoSql gap for time types and BigDecimal (version mismatch). JSON is stored as NVARCHAR since SQL Server has no native JSON column type. Author: Pablo Carrera --- Cargo.lock | 21 +++- sqlx-mssql/Cargo.toml | 4 +- sqlx-mssql/src/connection/executor.rs | 99 ++++++++++++++++++ sqlx-mssql/src/database.rs | 10 ++ sqlx-mssql/src/type_checking.rs | 14 ++- sqlx-mssql/src/types/bigdecimal.rs | 46 +++++++++ sqlx-mssql/src/types/json.rs | 42 ++++++++ sqlx-mssql/src/types/mod.rs | 12 +++ sqlx-mssql/src/types/time.rs | 140 ++++++++++++++++++++++++++ sqlx-mssql/src/value.rs | 84 ++++++++++++++++ 10 files changed, 464 insertions(+), 8 deletions(-) create mode 100644 sqlx-mssql/src/types/bigdecimal.rs create mode 100644 sqlx-mssql/src/types/json.rs create mode 100644 sqlx-mssql/src/types/time.rs diff --git a/Cargo.lock b/Cargo.lock index 1a18141114..ddbad9d857 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -537,6 +537,17 @@ version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba" +[[package]] +name = "bigdecimal" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6773ddc0eafc0e509fb60e48dff7f450f8e674a0686ae8605e8d9901bd5eefa" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + [[package]] name = "bigdecimal" version = "0.4.8" @@ -3619,7 +3630,7 @@ dependencies = [ "async-std", "async-task", "base64 0.22.1", - "bigdecimal", + "bigdecimal 0.4.8", "bit-vec", "bstr", "bytes", @@ -3952,7 +3963,7 @@ version = "0.9.0-alpha.1" dependencies = [ "async-std", "atoi", - "bigdecimal", + "bigdecimal 0.4.8", "bytes", "chrono", "dotenvy", @@ -3981,7 +3992,7 @@ version = "0.9.0-alpha.1" dependencies = [ "atoi", "base64 0.22.1", - "bigdecimal", + "bigdecimal 0.4.8", "bitflags 2.9.1", "byteorder", "bytes", @@ -4025,7 +4036,7 @@ version = "0.9.0-alpha.1" dependencies = [ "atoi", "base64 0.22.1", - "bigdecimal", + "bigdecimal 0.4.8", "bit-vec", "bitflags 2.9.1", "byteorder", @@ -4314,6 +4325,7 @@ checksum = "a1446cb4198848d1562301a3340424b4f425ef79f35ef9ee034769a9dd92c10d" dependencies = [ "async-trait", "asynchronous-codec", + "bigdecimal 0.3.1", "byteorder", "bytes", "chrono", @@ -4327,6 +4339,7 @@ dependencies = [ "pretty-hex", "rust_decimal", "thiserror 1.0.69", + "time", "tracing", "uuid", ] diff --git a/sqlx-mssql/Cargo.toml b/sqlx-mssql/Cargo.toml index cbff784593..eb9f7f1b0d 100644 --- a/sqlx-mssql/Cargo.toml +++ b/sqlx-mssql/Cargo.toml @@ -16,10 +16,10 @@ offline = ["sqlx-core/offline", "serde"] migrate = ["sqlx-core/migrate"] # Type Integration features -bigdecimal = ["dep:bigdecimal", "sqlx-core/bigdecimal"] +bigdecimal = ["dep:bigdecimal", "sqlx-core/bigdecimal", "tiberius/bigdecimal"] chrono = ["dep:chrono", "sqlx-core/chrono", "tiberius/chrono"] rust_decimal = ["dep:rust_decimal", "sqlx-core/rust_decimal", "tiberius/rust_decimal"] -time = ["dep:time", "sqlx-core/time"] +time = ["dep:time", "sqlx-core/time", "tiberius/time"] uuid = ["dep:uuid", "sqlx-core/uuid"] [dependencies] diff --git a/sqlx-mssql/src/connection/executor.rs b/sqlx-mssql/src/connection/executor.rs index 7d2c73cdf8..997414e6bc 100644 --- a/sqlx-mssql/src/connection/executor.rs +++ b/sqlx-mssql/src/connection/executor.rs @@ -17,6 +17,22 @@ use futures_util::TryStreamExt; use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr, SqlStr}; use std::sync::Arc; +/// Newtype wrapper to bridge `tiberius::ColumnData` into `tiberius::IntoSql`. +/// +/// tiberius implements `ToSql` but not `IntoSql` for some types (e.g. `time` +/// crate types, and `BigDecimal` due to version mismatch). `Query::bind()` +/// requires `IntoSql`, so this wrapper lets us construct `ColumnData` manually +/// and pass it to `bind()`. +#[cfg(any(feature = "time", feature = "bigdecimal"))] +struct ColumnDataWrapper<'a>(tiberius::ColumnData<'a>); + +#[cfg(any(feature = "time", feature = "bigdecimal"))] +impl<'a> tiberius::IntoSql<'a> for ColumnDataWrapper<'a> { + fn into_sql(self) -> tiberius::ColumnData<'a> { + self.0 + } +} + impl MssqlConnection { /// Execute a query, eagerly collecting all results. /// @@ -104,6 +120,89 @@ impl MssqlConnection { v.scale() as u8, )); } + #[cfg(feature = "time")] + MssqlArgumentValue::TimeDate(v) => { + let epoch = time::Date::from_ordinal_date(1, 1).unwrap(); + let days = (*v - epoch).whole_days() as u32; + let cd = tiberius::ColumnData::Date(Some( + tiberius::time::Date::new(days), + )); + query.bind(ColumnDataWrapper(cd)); + } + #[cfg(feature = "time")] + MssqlArgumentValue::TimeTime(v) => { + let (h, m, s, ns) = v.as_hms_nano(); + let total_ns = h as u64 * 3_600_000_000_000 + + m as u64 * 60_000_000_000 + + s as u64 * 1_000_000_000 + + ns as u64; + // Scale 7 = 100ns increments + let increments = total_ns / 100; + let cd = tiberius::ColumnData::Time(Some( + tiberius::time::Time::new(increments, 7), + )); + query.bind(ColumnDataWrapper(cd)); + } + #[cfg(feature = "time")] + MssqlArgumentValue::TimePrimitiveDateTime(v) => { + let date = v.date(); + let time = v.time(); + let epoch = time::Date::from_ordinal_date(1, 1).unwrap(); + let days = (date - epoch).whole_days() as u32; + let (h, m, s, ns) = time.as_hms_nano(); + let total_ns = h as u64 * 3_600_000_000_000 + + m as u64 * 60_000_000_000 + + s as u64 * 1_000_000_000 + + ns as u64; + let increments = total_ns / 100; + let cd = tiberius::ColumnData::DateTime2(Some( + tiberius::time::DateTime2::new( + tiberius::time::Date::new(days), + tiberius::time::Time::new(increments, 7), + ), + )); + query.bind(ColumnDataWrapper(cd)); + } + #[cfg(feature = "time")] + MssqlArgumentValue::TimeOffsetDateTime(v) => { + let epoch = time::Date::from_ordinal_date(1, 1).unwrap(); + let offset_minutes = v.offset().whole_seconds() / 60; + let date = v.date(); + let time = v.time(); + let days = (date - epoch).whole_days() as u32; + let (h, m, s, ns) = time.as_hms_nano(); + let total_ns = h as u64 * 3_600_000_000_000 + + m as u64 * 60_000_000_000 + + s as u64 * 1_000_000_000 + + ns as u64; + let increments = total_ns / 100; + let dt2 = tiberius::time::DateTime2::new( + tiberius::time::Date::new(days), + tiberius::time::Time::new(increments, 7), + ); + let cd = tiberius::ColumnData::DateTimeOffset(Some( + tiberius::time::DateTimeOffset::new( + dt2, + offset_minutes as i16, + ), + )); + query.bind(ColumnDataWrapper(cd)); + } + #[cfg(feature = "bigdecimal")] + MssqlArgumentValue::BigDecimal(v) => { + use bigdecimal::ToPrimitive; + // Convert BigDecimal to tiberius Numeric + let (bigint, exponent) = v.as_bigint_and_exponent(); + let scale = exponent.max(0) as u8; + // Convert to i128 for Numeric — panics if too large + let value: i128 = bigint + .to_i128() + .expect("BigDecimal value too large for SQL NUMERIC"); + let cd = tiberius::ColumnData::Numeric(Some( + tiberius::numeric::Numeric::new_with_scale(value, scale), + )); + query.bind(ColumnDataWrapper(cd)); + } } } diff --git a/sqlx-mssql/src/database.rs b/sqlx-mssql/src/database.rs index 8f57323ce4..379b35d824 100644 --- a/sqlx-mssql/src/database.rs +++ b/sqlx-mssql/src/database.rs @@ -63,4 +63,14 @@ pub enum MssqlArgumentValue { Uuid(uuid::Uuid), #[cfg(feature = "rust_decimal")] Decimal(rust_decimal::Decimal), + #[cfg(feature = "time")] + TimeDate(time::Date), + #[cfg(feature = "time")] + TimeTime(time::Time), + #[cfg(feature = "time")] + TimePrimitiveDateTime(time::PrimitiveDateTime), + #[cfg(feature = "time")] + TimeOffsetDateTime(time::OffsetDateTime), + #[cfg(feature = "bigdecimal")] + BigDecimal(bigdecimal::BigDecimal), } diff --git a/sqlx-mssql/src/type_checking.rs b/sqlx-mssql/src/type_checking.rs index 86dc7f621a..f3544207ee 100644 --- a/sqlx-mssql/src/type_checking.rs +++ b/sqlx-mssql/src/type_checking.rs @@ -24,6 +24,9 @@ impl_type_checking!( #[cfg(feature = "uuid")] sqlx::types::Uuid, + + #[cfg(feature = "json")] + sqlx::types::JsonValue, }, ParamChecking::Weak, feature-types: _info => None, @@ -34,10 +37,17 @@ impl_type_checking!( sqlx::types::chrono::NaiveDateTime, sqlx::types::chrono::DateTime, }, - time: { }, + time: { + sqlx::types::time::Time, + sqlx::types::time::Date, + sqlx::types::time::PrimitiveDateTime, + sqlx::types::time::OffsetDateTime, + }, }, numeric-types: { - bigdecimal: { }, + bigdecimal: { + sqlx::types::BigDecimal, + }, rust_decimal: { sqlx::types::Decimal, }, diff --git a/sqlx-mssql/src/types/bigdecimal.rs b/sqlx-mssql/src/types/bigdecimal.rs new file mode 100644 index 0000000000..330d978d8e --- /dev/null +++ b/sqlx-mssql/src/types/bigdecimal.rs @@ -0,0 +1,46 @@ +use bigdecimal::BigDecimal; + +use crate::database::MssqlArgumentValue; +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::value::MssqlData; +use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; + +impl Type for BigDecimal { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("DECIMAL") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!(ty.name.as_str(), "DECIMAL" | "NUMERIC" | "MONEY") + } +} + +impl Encode<'_, Mssql> for BigDecimal { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + buf.push(MssqlArgumentValue::BigDecimal(self.clone())); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for BigDecimal { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::BigDecimal(ref v) => Ok(v.clone()), + MssqlData::I32(v) => Ok(BigDecimal::from(*v)), + MssqlData::I64(v) => Ok(BigDecimal::from(*v)), + MssqlData::F64(v) => bigdecimal::FromPrimitive::from_f64(*v) + .ok_or_else(|| format!("failed to convert f64 {v} to BigDecimal").into()), + MssqlData::String(ref s) => s + .parse::() + .map_err(|e| format!("failed to parse BigDecimal from string: {e}").into()), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected DECIMAL, got {:?}", value.data).into()), + } + } +} diff --git a/sqlx-mssql/src/types/json.rs b/sqlx-mssql/src/types/json.rs new file mode 100644 index 0000000000..5dd5b76e09 --- /dev/null +++ b/sqlx-mssql/src/types/json.rs @@ -0,0 +1,42 @@ +use serde::{Deserialize, Serialize}; + +use crate::database::MssqlArgumentValue; +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::{Json, Type}; +use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; + +impl Type for Json { + fn type_info() -> MssqlTypeInfo { + // SQL Server has no native JSON type; JSON is stored as NVARCHAR + MssqlTypeInfo::new("NVARCHAR") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + <&str as Type>::compatible(ty) + } +} + +impl Encode<'_, Mssql> for Json +where + T: Serialize, +{ + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + let json_string = self.encode_to_string()?; + buf.push(MssqlArgumentValue::String(json_string)); + Ok(IsNull::No) + } +} + +impl<'r, T> Decode<'r, Mssql> for Json +where + T: Deserialize<'r> + 'r, +{ + fn decode(value: MssqlValueRef<'r>) -> Result { + Json::decode_from_string(value.as_str()?) + } +} diff --git a/sqlx-mssql/src/types/mod.rs b/sqlx-mssql/src/types/mod.rs index d487f902c0..b830b018aa 100644 --- a/sqlx-mssql/src/types/mod.rs +++ b/sqlx-mssql/src/types/mod.rs @@ -21,6 +21,12 @@ //! |---------------------------------------|------------------------------------------------------| //! | `uuid::Uuid` | UNIQUEIDENTIFIER | //! | `rust_decimal::Decimal` | DECIMAL, NUMERIC, MONEY | +//! | `bigdecimal::BigDecimal` | DECIMAL, NUMERIC, MONEY | +//! | `time::Date` | DATE | +//! | `time::Time` | TIME | +//! | `time::PrimitiveDateTime` | DATETIME2, DATETIME, SMALLDATETIME | +//! | `time::OffsetDateTime` | DATETIMEOFFSET, DATETIME2 | +//! | `serde_json::Value` (`Json`) | NVARCHAR (JSON stored as string) | //! //! # Nullable //! @@ -29,14 +35,20 @@ pub(crate) use sqlx_core::types::*; +#[cfg(feature = "bigdecimal")] +mod bigdecimal; mod bool; mod bytes; #[cfg(feature = "chrono")] mod chrono; mod float; mod int; +#[cfg(feature = "json")] +mod json; #[cfg(feature = "rust_decimal")] mod rust_decimal; mod str; +#[cfg(feature = "time")] +mod time; #[cfg(feature = "uuid")] mod uuid; diff --git a/sqlx-mssql/src/types/time.rs b/sqlx-mssql/src/types/time.rs new file mode 100644 index 0000000000..d1b79ae251 --- /dev/null +++ b/sqlx-mssql/src/types/time.rs @@ -0,0 +1,140 @@ +use time::{Date, OffsetDateTime, PrimitiveDateTime, Time}; + +use crate::database::MssqlArgumentValue; +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::value::MssqlData; +use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; + +// ── Date ─────────────────────────────────────────────────────────────────── + +impl Type for Date { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("DATE") + } +} + +impl Encode<'_, Mssql> for Date { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + buf.push(MssqlArgumentValue::TimeDate(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for Date { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::TimeDate(v) => Ok(*v), + MssqlData::TimePrimitiveDateTime(v) => Ok(v.date()), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected date, got {:?}", value.data).into()), + } + } +} + +// ── Time ─────────────────────────────────────────────────────────────────── + +impl Type for Time { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("TIME") + } +} + +impl Encode<'_, Mssql> for Time { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + buf.push(MssqlArgumentValue::TimeTime(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for Time { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::TimeTime(v) => Ok(*v), + MssqlData::TimePrimitiveDateTime(v) => Ok(v.time()), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected time, got {:?}", value.data).into()), + } + } +} + +// ── PrimitiveDateTime ────────────────────────────────────────────────────── + +impl Type for PrimitiveDateTime { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("DATETIME2") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!( + ty.name.as_str(), + "DATETIME2" | "DATETIME" | "SMALLDATETIME" + ) + } +} + +impl Encode<'_, Mssql> for PrimitiveDateTime { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + buf.push(MssqlArgumentValue::TimePrimitiveDateTime(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for PrimitiveDateTime { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::TimePrimitiveDateTime(v) => Ok(*v), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected datetime, got {:?}", value.data).into()), + } + } +} + +// ── OffsetDateTime ───────────────────────────────────────────────────────── + +impl Type for OffsetDateTime { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("DATETIMEOFFSET") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!( + ty.name.as_str(), + "DATETIMEOFFSET" | "DATETIME2" + ) + } +} + +impl Encode<'_, Mssql> for OffsetDateTime { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + buf.push(MssqlArgumentValue::TimeOffsetDateTime(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for OffsetDateTime { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::TimeOffsetDateTime(v) => Ok(*v), + MssqlData::TimePrimitiveDateTime(v) => { + Ok(v.assume_utc()) + } + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected datetimeoffset, got {:?}", value.data).into()), + } + } +} diff --git a/sqlx-mssql/src/value.rs b/sqlx-mssql/src/value.rs index 67930cf244..c44834d2eb 100644 --- a/sqlx-mssql/src/value.rs +++ b/sqlx-mssql/src/value.rs @@ -28,6 +28,16 @@ pub(crate) enum MssqlData { Uuid(uuid::Uuid), #[cfg(feature = "rust_decimal")] Decimal(rust_decimal::Decimal), + #[cfg(feature = "time")] + TimeDate(time::Date), + #[cfg(feature = "time")] + TimeTime(time::Time), + #[cfg(feature = "time")] + TimePrimitiveDateTime(time::PrimitiveDateTime), + #[cfg(feature = "time")] + TimeOffsetDateTime(time::OffsetDateTime), + #[cfg(feature = "bigdecimal")] + BigDecimal(bigdecimal::BigDecimal), } /// Implementation of [`Value`] for MSSQL. @@ -175,11 +185,85 @@ pub(crate) fn column_data_to_mssql_data(data: &tiberius::ColumnData<'_>) -> Mssq )) } + #[cfg(all(feature = "time", not(feature = "chrono")))] + tiberius::ColumnData::Date(Some(d)) => { + MssqlData::TimeDate(time_date_from_days(d.days() as u64, 1)) + } + #[cfg(all(feature = "time", not(feature = "chrono")))] + tiberius::ColumnData::Time(Some(t)) => { + let ns = t.increments() as u64 + * 10u64.pow(9u32.saturating_sub(t.scale() as u32)); + MssqlData::TimeTime(time_from_sec_fragments(ns)) + } + #[cfg(all(feature = "time", not(feature = "chrono")))] + tiberius::ColumnData::DateTime2(Some(dt2)) => { + let date = time_date_from_days(dt2.date().days() as u64, 1); + let ns = dt2.time().increments() as u64 + * 10u64.pow(9u32.saturating_sub(dt2.time().scale() as u32)); + let time = time_from_sec_fragments(ns); + MssqlData::TimePrimitiveDateTime(time::PrimitiveDateTime::new(date, time)) + } + #[cfg(all(feature = "time", not(feature = "chrono")))] + tiberius::ColumnData::DateTime(Some(dt)) => { + let date = time_date_from_days(dt.days() as u64, 1900); + let ns = dt.seconds_fragments() as u64 * 1_000_000_000u64 / 300; + let time = time_from_sec_fragments(ns); + MssqlData::TimePrimitiveDateTime(time::PrimitiveDateTime::new(date, time)) + } + #[cfg(all(feature = "time", not(feature = "chrono")))] + tiberius::ColumnData::SmallDateTime(Some(dt)) => { + let date = time_date_from_days(dt.days() as u64, 1900); + let seconds = dt.seconds_fragments() as u64 * 60; + let time = time_from_sec_fragments(seconds * 1_000_000_000); + MssqlData::TimePrimitiveDateTime(time::PrimitiveDateTime::new(date, time)) + } + #[cfg(all(feature = "time", not(feature = "chrono")))] + tiberius::ColumnData::DateTimeOffset(Some(dto)) => { + let date = time_date_from_days(dto.datetime2().date().days() as u64, 1); + let ns = dto.datetime2().time().increments() as u64 + * 10u64.pow(9u32.saturating_sub(dto.datetime2().time().scale() as u32)); + let time = time_from_sec_fragments(ns); + let naive = time::PrimitiveDateTime::new(date, time); + let offset = time::UtcOffset::from_whole_seconds(dto.offset() as i32 * 60) + .expect("valid UTC offset from tiberius"); + MssqlData::TimeOffsetDateTime(naive.assume_offset(offset)) + } + + #[cfg(all(feature = "bigdecimal", not(feature = "rust_decimal")))] + tiberius::ColumnData::Numeric(Some(n)) => { + use bigdecimal::num_bigint::BigInt; + MssqlData::BigDecimal(bigdecimal::BigDecimal::new( + BigInt::from(n.value()), + n.scale() as i64, + )) + } + // All None variants and unhandled types map to Null _ => MssqlData::Null, } } +/// Convert days since `start_year`-01-01 to a `time::Date`. +#[cfg(feature = "time")] +fn time_date_from_days(days: u64, start_year: i32) -> time::Date { + let start = time::Date::from_ordinal_date(start_year, 1).expect("valid start date"); + start + .checked_add(time::Duration::days(days as i64)) + .expect("valid date from days offset") +} + +/// Convert nanoseconds-since-midnight to a `time::Time`. +#[cfg(feature = "time")] +fn time_from_sec_fragments(nanoseconds: u64) -> time::Time { + let hours = (nanoseconds / 3_600_000_000_000) as u8; + let remaining = nanoseconds % 3_600_000_000_000; + let minutes = (remaining / 60_000_000_000) as u8; + let remaining = remaining % 60_000_000_000; + let seconds = (remaining / 1_000_000_000) as u8; + let nanos = (remaining % 1_000_000_000) as u32; + time::Time::from_hms_nano(hours, minutes, seconds, nanos).expect("valid time") +} + /// Convert days since `start_year`-01-01 to a `chrono::NaiveDate`. #[cfg(feature = "chrono")] fn chrono_date_from_days(days: i64, start_year: i32) -> chrono::NaiveDate { From df5b9af7fc000e76ddb42219a5cc78b5d2b64055 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Sun, 15 Feb 2026 02:54:39 -0500 Subject: [PATCH 05/33] feat: add comprehensive test coverage for sqlx-mssql types, errors, derives, and test attributes Add tests for all implemented Encode/Decode types (uuid, chrono, time, rust_decimal, bigdecimal, json, bytes), error kind mapping, derive macros (weak enums, transparent types), and #[sqlx::test] attribute with migrations and fixtures. Author: Pablo Carrera --- Cargo.toml | 15 +++ tests/mssql/derives.rs | 110 ++++++++++++++++++++++ tests/mssql/error.rs | 80 ++++++++++++++++ tests/mssql/fixtures/comments.sql | 4 + tests/mssql/fixtures/posts.sql | 5 + tests/mssql/fixtures/users.sql | 3 + tests/mssql/migrations/1_user.sql | 4 + tests/mssql/migrations/2_post.sql | 7 ++ tests/mssql/migrations/3_comment.sql | 8 ++ tests/mssql/setup.sql | 25 +++++ tests/mssql/test-attr.rs | 100 ++++++++++++++++++++ tests/mssql/types.rs | 134 +++++++++++++++++++++++++++ 12 files changed, 495 insertions(+) create mode 100644 tests/mssql/derives.rs create mode 100644 tests/mssql/error.rs create mode 100644 tests/mssql/fixtures/comments.sql create mode 100644 tests/mssql/fixtures/posts.sql create mode 100644 tests/mssql/fixtures/users.sql create mode 100644 tests/mssql/migrations/1_user.sql create mode 100644 tests/mssql/migrations/2_post.sql create mode 100644 tests/mssql/migrations/3_comment.sql create mode 100644 tests/mssql/test-attr.rs diff --git a/Cargo.toml b/Cargo.toml index 19220e1316..58d2265417 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -482,3 +482,18 @@ required-features = ["mssql"] name = "mssql-macros" path = "tests/mssql/macros.rs" required-features = ["mssql", "macros"] + +[[test]] +name = "mssql-derives" +path = "tests/mssql/derives.rs" +required-features = ["mssql", "derive"] + +[[test]] +name = "mssql-error" +path = "tests/mssql/error.rs" +required-features = ["mssql"] + +[[test]] +name = "mssql-test-attr" +path = "tests/mssql/test-attr.rs" +required-features = ["mssql", "macros", "migrate"] diff --git a/tests/mssql/derives.rs b/tests/mssql/derives.rs new file mode 100644 index 0000000000..a9e91d439c --- /dev/null +++ b/tests/mssql/derives.rs @@ -0,0 +1,110 @@ +use sqlx::mssql::Mssql; +use sqlx_test::{new, test_type}; + +#[sqlx::test] +async fn test_derive_weak_enum() -> anyhow::Result<()> { + #[derive(sqlx::Type, Debug, PartialEq, Eq)] + #[repr(i16)] + enum WeakEnumI16 { + Foo = i16::MIN, + Bar = 0, + Baz = i16::MAX, + } + + #[derive(sqlx::Type, Debug, PartialEq, Eq)] + #[repr(i32)] + enum WeakEnumI32 { + Foo = i32::MIN, + Bar = 0, + Baz = i32::MAX, + } + + #[derive(sqlx::Type, Debug, PartialEq, Eq)] + #[repr(i64)] + enum WeakEnumI64 { + Foo = i64::MIN, + Bar = 0, + Baz = i64::MAX, + } + + #[derive(sqlx::FromRow, Debug, PartialEq, Eq)] + struct WeakEnumRow { + i16: WeakEnumI16, + i32: WeakEnumI32, + i64: WeakEnumI64, + } + + let mut conn = new::().await?; + + sqlx::raw_sql( + r#" + CREATE TABLE #weak_enum ( + i16 SMALLINT, + i32 INT, + i64 BIGINT + ) + "#, + ) + .execute(&mut conn) + .await?; + + let rows_in = vec![ + WeakEnumRow { + i16: WeakEnumI16::Foo, + i32: WeakEnumI32::Foo, + i64: WeakEnumI64::Foo, + }, + WeakEnumRow { + i16: WeakEnumI16::Bar, + i32: WeakEnumI32::Bar, + i64: WeakEnumI64::Bar, + }, + WeakEnumRow { + i16: WeakEnumI16::Baz, + i32: WeakEnumI32::Baz, + i64: WeakEnumI64::Baz, + }, + ]; + + for row in &rows_in { + sqlx::query( + r#" + INSERT INTO #weak_enum(i16, i32, i64) + VALUES (@p1, @p2, @p3) + "#, + ) + .bind(&row.i16) + .bind(&row.i32) + .bind(&row.i64) + .execute(&mut conn) + .await?; + } + + let rows_out: Vec = sqlx::query_as("SELECT * FROM #weak_enum") + .fetch_all(&mut conn) + .await?; + + assert_eq!(rows_in, rows_out); + + Ok(()) +} + +#[derive(PartialEq, Eq, Debug, sqlx::Type)] +#[sqlx(transparent)] +struct TransparentTuple(i64); + +#[derive(PartialEq, Eq, Debug, sqlx::Type)] +#[sqlx(transparent)] +struct TransparentNamed { + field: i64, +} + +test_type!(transparent_tuple(Mssql, + "CAST(0 AS BIGINT)" == TransparentTuple(0), + "CAST(23523 AS BIGINT)" == TransparentTuple(23523) +)); + +test_type!(transparent_named(Mssql, + "CAST(0 AS BIGINT)" == TransparentNamed { field: 0 }, + "CAST(23523 AS BIGINT)" == TransparentNamed { field: 23523 }, +)); diff --git a/tests/mssql/error.rs b/tests/mssql/error.rs new file mode 100644 index 0000000000..6dc9e66931 --- /dev/null +++ b/tests/mssql/error.rs @@ -0,0 +1,80 @@ +use sqlx::error::ErrorKind; +use sqlx::mssql::Mssql; +use sqlx::Connection; +use sqlx_test::new; + +#[sqlx_macros::test] +async fn it_fails_with_unique_violation() -> anyhow::Result<()> { + let mut conn = new::().await?; + let mut tx = conn.begin().await?; + + sqlx::query("INSERT INTO tweet(id, text, owner_id) VALUES (1, 'Foo', 1)") + .execute(&mut *tx) + .await?; + + let res: Result<_, sqlx::Error> = + sqlx::query("INSERT INTO tweet(id, text, owner_id) VALUES (1, 'Bar', 1)") + .execute(&mut *tx) + .await; + let err = res.unwrap_err(); + + let err = err.into_database_error().unwrap(); + + assert_eq!(err.kind(), ErrorKind::UniqueViolation); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_fails_with_foreign_key_violation() -> anyhow::Result<()> { + let mut conn = new::().await?; + let mut tx = conn.begin().await?; + + let res: Result<_, sqlx::Error> = + sqlx::query("INSERT INTO tweet_reply (tweet_id, text) VALUES (999, 'Reply!')") + .execute(&mut *tx) + .await; + let err = res.unwrap_err(); + + let err = err.into_database_error().unwrap(); + + assert_eq!(err.kind(), ErrorKind::ForeignKeyViolation); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_fails_with_not_null_violation() -> anyhow::Result<()> { + let mut conn = new::().await?; + let mut tx = conn.begin().await?; + + let res: Result<_, sqlx::Error> = + sqlx::query("INSERT INTO tweet (id, text) VALUES (1, NULL)") + .execute(&mut *tx) + .await; + let err = res.unwrap_err(); + + let err = err.into_database_error().unwrap(); + + assert_eq!(err.kind(), ErrorKind::NotNullViolation); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_fails_with_check_violation() -> anyhow::Result<()> { + let mut conn = new::().await?; + let mut tx = conn.begin().await?; + + let res: Result<_, sqlx::Error> = + sqlx::query("INSERT INTO products (product_no, name, price) VALUES (1, 'Product 1', 0)") + .execute(&mut *tx) + .await; + let err = res.unwrap_err(); + + let err = err.into_database_error().unwrap(); + + assert_eq!(err.kind(), ErrorKind::CheckViolation); + + Ok(()) +} diff --git a/tests/mssql/fixtures/comments.sql b/tests/mssql/fixtures/comments.sql new file mode 100644 index 0000000000..93ea9941f6 --- /dev/null +++ b/tests/mssql/fixtures/comments.sql @@ -0,0 +1,4 @@ +INSERT INTO comment(comment_id, post_id, user_id, content, created_at) +VALUES (1, 1, 2, 'lol bet ur still bad, 1v1 me', DATEADD(MINUTE, -50, SYSUTCDATETIME())), + (2, 1, 1, 'you''re on!', DATEADD(MINUTE, -45, SYSUTCDATETIME())), + (3, 2, 1, 'lol you''re just mad you lost :P', DATEADD(MINUTE, -15, SYSUTCDATETIME())); diff --git a/tests/mssql/fixtures/posts.sql b/tests/mssql/fixtures/posts.sql new file mode 100644 index 0000000000..e75d0d9381 --- /dev/null +++ b/tests/mssql/fixtures/posts.sql @@ -0,0 +1,5 @@ +SET IDENTITY_INSERT post ON; +INSERT INTO post(post_id, user_id, content, created_at) +VALUES (1, 1, 'This new computer is lightning-fast!', DATEADD(HOUR, -1, SYSUTCDATETIME())), + (2, 2, '@alice is a haxxor :(', DATEADD(MINUTE, -30, SYSUTCDATETIME())); +SET IDENTITY_INSERT post OFF; diff --git a/tests/mssql/fixtures/users.sql b/tests/mssql/fixtures/users.sql new file mode 100644 index 0000000000..0d4270c282 --- /dev/null +++ b/tests/mssql/fixtures/users.sql @@ -0,0 +1,3 @@ +SET IDENTITY_INSERT [user] ON; +INSERT INTO [user](user_id, username) VALUES (1, 'alice'), (2, 'bob'); +SET IDENTITY_INSERT [user] OFF; diff --git a/tests/mssql/migrations/1_user.sql b/tests/mssql/migrations/1_user.sql new file mode 100644 index 0000000000..ed0bb63cdf --- /dev/null +++ b/tests/mssql/migrations/1_user.sql @@ -0,0 +1,4 @@ +CREATE TABLE [user] ( + user_id INT NOT NULL IDENTITY(1,1) PRIMARY KEY, + username NVARCHAR(16) NOT NULL UNIQUE +); diff --git a/tests/mssql/migrations/2_post.sql b/tests/mssql/migrations/2_post.sql new file mode 100644 index 0000000000..cbdd07cdd4 --- /dev/null +++ b/tests/mssql/migrations/2_post.sql @@ -0,0 +1,7 @@ +CREATE TABLE post ( + post_id INT NOT NULL IDENTITY(1,1) PRIMARY KEY, + user_id INT NOT NULL REFERENCES [user](user_id), + content NVARCHAR(MAX) NOT NULL, + created_at DATETIME2 NOT NULL DEFAULT SYSUTCDATETIME() +); +CREATE INDEX post_created_at ON post (created_at DESC); diff --git a/tests/mssql/migrations/3_comment.sql b/tests/mssql/migrations/3_comment.sql new file mode 100644 index 0000000000..8f168a2e3d --- /dev/null +++ b/tests/mssql/migrations/3_comment.sql @@ -0,0 +1,8 @@ +CREATE TABLE comment ( + comment_id INT NOT NULL PRIMARY KEY, + post_id INT NOT NULL REFERENCES post(post_id), + user_id INT NOT NULL REFERENCES [user](user_id), + content NVARCHAR(MAX) NOT NULL, + created_at DATETIME2 NOT NULL DEFAULT SYSUTCDATETIME() +); +CREATE INDEX comment_created_at ON comment (created_at DESC); diff --git a/tests/mssql/setup.sql b/tests/mssql/setup.sql index a033227b75..4a78cccfa7 100644 --- a/tests/mssql/setup.sql +++ b/tests/mssql/setup.sql @@ -18,3 +18,28 @@ IF OBJECT_ID('tweet') IS NULL ); END; GO + +IF OBJECT_ID('tweet_reply') IS NULL + BEGIN + CREATE TABLE tweet_reply + ( + id BIGINT NOT NULL IDENTITY(1,1) PRIMARY KEY, + tweet_id BIGINT NOT NULL, + text NVARCHAR(4000) NOT NULL, + owner_id BIGINT, + CONSTRAINT tweet_id_fk FOREIGN KEY (tweet_id) REFERENCES tweet(id) + ); + END; +GO + +IF OBJECT_ID('products') IS NULL + BEGIN + CREATE TABLE products + ( + product_no INT, + name NVARCHAR(200), + price DECIMAL(10,2), + CONSTRAINT chk_price CHECK (price > 0) + ); + END; +GO diff --git a/tests/mssql/test-attr.rs b/tests/mssql/test-attr.rs new file mode 100644 index 0000000000..25854a4cff --- /dev/null +++ b/tests/mssql/test-attr.rs @@ -0,0 +1,100 @@ +// The no-arg variant is covered by other tests already. + +use sqlx::MssqlPool; + +const MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("tests/mssql/migrations"); + +#[sqlx::test] +async fn it_gets_a_pool(pool: MssqlPool) -> sqlx::Result<()> { + let mut conn = pool.acquire().await?; + + let db_name: String = sqlx::query_scalar("SELECT DB_NAME()") + .fetch_one(&mut *conn) + .await?; + + assert!( + db_name.starts_with("_sqlx_test_"), + "db_name: {:?}", + db_name + ); + + Ok(()) +} + +// This should apply migrations and then `fixtures/users.sql` +#[sqlx::test(migrations = "tests/mssql/migrations", fixtures("users"))] +async fn it_gets_users(pool: MssqlPool) -> sqlx::Result<()> { + let usernames: Vec = + sqlx::query_scalar(r#"SELECT username FROM [user] ORDER BY username"#) + .fetch_all(&pool) + .await?; + + assert_eq!(usernames, ["alice", "bob"]); + + let post_count: i32 = sqlx::query_scalar("SELECT COUNT(*) FROM post") + .fetch_one(&pool) + .await?; + + assert_eq!(post_count, 0); + + let comment_count: i32 = sqlx::query_scalar("SELECT COUNT(*) FROM comment") + .fetch_one(&pool) + .await?; + + assert_eq!(comment_count, 0); + + Ok(()) +} + +#[sqlx::test(migrations = "tests/mssql/migrations", fixtures("users", "posts"))] +async fn it_gets_posts(pool: MssqlPool) -> sqlx::Result<()> { + let post_contents: Vec = + sqlx::query_scalar("SELECT content FROM post ORDER BY created_at") + .fetch_all(&pool) + .await?; + + assert_eq!( + post_contents, + [ + "This new computer is lightning-fast!", + "@alice is a haxxor :(" + ] + ); + + let comment_count: i32 = sqlx::query_scalar("SELECT COUNT(*) FROM comment") + .fetch_one(&pool) + .await?; + + assert_eq!(comment_count, 0); + + Ok(()) +} + +// Try `migrator` +#[sqlx::test(migrator = "MIGRATOR", fixtures("users", "posts", "comments"))] +async fn it_gets_comments(pool: MssqlPool) -> sqlx::Result<()> { + let post_1_comments: Vec = + sqlx::query_scalar( + "SELECT content FROM comment WHERE post_id = @p1 ORDER BY created_at", + ) + .bind(&1) + .fetch_all(&pool) + .await?; + + assert_eq!( + post_1_comments, + ["lol bet ur still bad, 1v1 me", "you're on!"] + ); + + let post_2_comments: Vec = + sqlx::query_scalar( + "SELECT content FROM comment WHERE post_id = @p1 ORDER BY created_at", + ) + .bind(&2) + .fetch_all(&pool) + .await?; + + assert_eq!(post_2_comments, ["lol you're just mad you lost :P"]); + + Ok(()) +} diff --git a/tests/mssql/types.rs b/tests/mssql/types.rs index 3e9ae7395b..097a0dbaa3 100644 --- a/tests/mssql/types.rs +++ b/tests/mssql/types.rs @@ -1,3 +1,5 @@ +extern crate time_ as time; + use sqlx::mssql::Mssql; use sqlx_test::test_type; @@ -48,3 +50,135 @@ test_type!(bool( "CAST(1 as BIT)" == true, "CAST(0 as BIT)" == false )); + +test_type!(bytes>(Mssql, + "CAST(0xDEADBEEF AS VARBINARY(MAX))" + == vec![0xDE_u8, 0xAD, 0xBE, 0xEF], + "CAST(0x AS VARBINARY(MAX))" + == Vec::::new(), + "CAST(0x0000000000000000 AS VARBINARY(MAX))" + == vec![0_u8; 8], +)); + +#[cfg(feature = "uuid")] +test_type!(uuid(Mssql, + "CAST('00000000-0000-0000-0000-000000000000' AS UNIQUEIDENTIFIER)" + == sqlx::types::Uuid::nil(), + "CAST('936da01f-9abd-4d9d-80c7-02af85c822a8' AS UNIQUEIDENTIFIER)" + == sqlx::types::Uuid::parse_str("936DA01F-9ABD-4D9D-80C7-02AF85C822A8").unwrap(), +)); + +#[cfg(feature = "chrono")] +mod chrono { + use sqlx::mssql::Mssql; + use sqlx_test::test_type; + + type NaiveDate = sqlx::types::chrono::NaiveDate; + type NaiveTime = sqlx::types::chrono::NaiveTime; + type NaiveDateTime = sqlx::types::chrono::NaiveDateTime; + type DateTimeUtc = sqlx::types::chrono::DateTime; + + test_type!(chrono_naive_date(Mssql, + "CAST('2001-01-05' AS DATE)" + == NaiveDate::from_ymd_opt(2001, 1, 5).unwrap(), + "CAST('2050-11-23' AS DATE)" + == NaiveDate::from_ymd_opt(2050, 11, 23).unwrap(), + )); + + test_type!(chrono_naive_time(Mssql, + "CAST('05:10:20' AS TIME)" + == NaiveTime::from_hms_opt(5, 10, 20).unwrap(), + "CAST('00:00:00' AS TIME)" + == NaiveTime::from_hms_opt(0, 0, 0).unwrap(), + )); + + test_type!(chrono_naive_date_time(Mssql, + "CAST('2019-01-02 05:10:20' AS DATETIME2)" + == NaiveDateTime::new( + NaiveDate::from_ymd_opt(2019, 1, 2).unwrap(), + NaiveTime::from_hms_opt(5, 10, 20).unwrap(), + ), + )); + + test_type!(chrono_date_time_utc(Mssql, + "CAST('2019-01-02 05:10:20.000 +00:00' AS DATETIMEOFFSET)" + == NaiveDate::from_ymd_opt(2019, 1, 2) + .unwrap() + .and_hms_opt(5, 10, 20) + .unwrap() + .and_utc(), + )); +} + +#[cfg(feature = "time")] +mod time_tests { + use sqlx::mssql::Mssql; + use sqlx_test::test_type; + + type TimeDate = sqlx::types::time::Date; + type TimeTime = sqlx::types::time::Time; + type TimePrimitiveDateTime = sqlx::types::time::PrimitiveDateTime; + type TimeOffsetDateTime = sqlx::types::time::OffsetDateTime; + + use time::macros::{date, time as time_macro, datetime}; + + test_type!(time_date(Mssql, + "CAST('2001-01-05' AS DATE)" + == date!(2001-01-05), + "CAST('2050-11-23' AS DATE)" + == date!(2050-11-23), + )); + + test_type!(time_time(Mssql, + "CAST('05:10:20' AS TIME)" + == time_macro!(05:10:20), + "CAST('00:00:00' AS TIME)" + == time_macro!(00:00:00), + )); + + test_type!(time_primitive_date_time(Mssql, + "CAST('2019-01-02 05:10:20' AS DATETIME2)" + == datetime!(2019-01-02 05:10:20), + )); + + test_type!(time_offset_date_time(Mssql, + "CAST('2019-01-02 05:10:20.000 +00:00' AS DATETIMEOFFSET)" + == datetime!(2019-01-02 05:10:20 UTC), + )); +} + +#[cfg(feature = "rust_decimal")] +test_type!(rust_decimal(Mssql, + "CAST('0' AS DECIMAL(10,2))" == sqlx::types::Decimal::ZERO, + "CAST('1.23' AS DECIMAL(10,2))" == sqlx::types::Decimal::new(123, 2), + "CAST('-1.23' AS DECIMAL(10,2))" == sqlx::types::Decimal::new(-123, 2), +)); + +#[cfg(feature = "bigdecimal")] +test_type!(bigdecimal(Mssql, + "CAST('0' AS DECIMAL(10,2))" == "0.00".parse::().unwrap(), + "CAST('1.23' AS DECIMAL(10,2))" == "1.23".parse::().unwrap(), + "CAST('-1.23' AS DECIMAL(10,2))" == "-1.23".parse::().unwrap(), +)); + +#[cfg(feature = "json")] +mod json_tests { + use sqlx::mssql::Mssql; + use sqlx::types::Json; + use sqlx_test::test_type; + + #[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq)] + struct Friend { + name: String, + age: u32, + } + + test_type!(json>(Mssql, + "CAST('{\"name\":\"Joe\",\"age\":33}' AS NVARCHAR(MAX))" + == Json(Friend { name: "Joe".to_string(), age: 33 }), + )); + + test_type!(json_value(Mssql, + "CAST('null' AS NVARCHAR(MAX))" == serde_json::Value::Null, + )); +} From 809fbf5d94089b7645387ad56569b9b52fcba5a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Sun, 15 Feb 2026 03:27:31 -0500 Subject: [PATCH 06/33] feat: add SSL mode enum, advisory locks, and MONEY type support to sqlx-mssql Add MssqlSslMode with 4 variants (Disabled, LoginOnly, Preferred, Required) mapping to all tiberius EncryptionLevel options. Add MssqlAdvisoryLock API using sp_getapplock/sp_releaseapplock for application-level distributed locking. Enable f64/Decimal/BigDecimal decoding from MONEY/SMALLMONEY columns. Author: Pablo Carrera --- Cargo.toml | 5 + sqlx-mssql/src/advisory_lock.rs | 192 +++++++++++++++++++++++++++++ sqlx-mssql/src/lib.rs | 4 + sqlx-mssql/src/options/mod.rs | 35 ++++-- sqlx-mssql/src/options/parse.rs | 80 ++++++++++++ sqlx-mssql/src/options/ssl_mode.rs | 18 +++ sqlx-mssql/src/types/float.rs | 2 +- sqlx-mssql/src/types/mod.rs | 4 +- tests/mssql/advisory-lock.rs | 85 +++++++++++++ tests/mssql/types.rs | 26 ++++ 10 files changed, 439 insertions(+), 12 deletions(-) create mode 100644 sqlx-mssql/src/advisory_lock.rs create mode 100644 sqlx-mssql/src/options/ssl_mode.rs create mode 100644 tests/mssql/advisory-lock.rs diff --git a/Cargo.toml b/Cargo.toml index 58d2265417..6b488a9e27 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -497,3 +497,8 @@ required-features = ["mssql"] name = "mssql-test-attr" path = "tests/mssql/test-attr.rs" required-features = ["mssql", "macros", "migrate"] + +[[test]] +name = "mssql-advisory-lock" +path = "tests/mssql/advisory-lock.rs" +required-features = ["mssql"] diff --git a/sqlx-mssql/src/advisory_lock.rs b/sqlx-mssql/src/advisory_lock.rs new file mode 100644 index 0000000000..60857ec0ca --- /dev/null +++ b/sqlx-mssql/src/advisory_lock.rs @@ -0,0 +1,192 @@ +use crate::error::Error; +use crate::query_scalar::query_scalar; +use crate::MssqlConnection; + +/// The lock mode for a MSSQL advisory lock. +/// +/// Maps to the `@LockMode` parameter of `sp_getapplock`. +#[derive(Debug, Clone, Copy, Default)] +pub enum MssqlAdvisoryLockMode { + /// A shared lock, compatible with other `Shared` and `Update` locks. + Shared, + + /// An update lock, compatible with `Shared` but not with other `Update` or `Exclusive`. + Update, + + /// An exclusive lock, incompatible with all other lock modes. + #[default] + Exclusive, +} + +impl MssqlAdvisoryLockMode { + fn as_str(&self) -> &'static str { + match self { + MssqlAdvisoryLockMode::Shared => "Shared", + MssqlAdvisoryLockMode::Update => "Update", + MssqlAdvisoryLockMode::Exclusive => "Exclusive", + } + } +} + +/// A session-scoped advisory lock backed by SQL Server's `sp_getapplock` / +/// `sp_releaseapplock`. +/// +/// Advisory locks are cooperative: they don't block access to any database +/// object; instead, all participants must explicitly acquire the same named +/// lock. The lock is scoped to the database session (connection). +/// +/// Unlike the Postgres advisory-lock API, there is **no RAII drop guard**. +/// You must call [`release`][Self::release] explicitly when you are done with +/// the lock. +/// +/// # Resource Name +/// +/// SQL Server limits resource names to 255 characters. The name is passed as a +/// query parameter, so SQL injection is not possible. +/// +/// # Example +/// +/// ```rust,no_run +/// # async fn example(conn: &mut sqlx::mssql::MssqlConnection) -> sqlx::Result<()> { +/// use sqlx::mssql::MssqlAdvisoryLock; +/// +/// let lock = MssqlAdvisoryLock::new("my_app_lock"); +/// lock.acquire(conn).await?; +/// +/// // ... do work under the lock ... +/// +/// lock.release(conn).await?; +/// # Ok(()) +/// # } +/// ``` +pub struct MssqlAdvisoryLock { + resource: String, + mode: MssqlAdvisoryLockMode, +} + +impl MssqlAdvisoryLock { + /// Create a new advisory lock with the given resource name and the default + /// [`Exclusive`][MssqlAdvisoryLockMode::Exclusive] mode. + pub fn new(resource: impl Into) -> Self { + Self { + resource: resource.into(), + mode: MssqlAdvisoryLockMode::default(), + } + } + + /// Create a new advisory lock with the given resource name and lock mode. + pub fn with_mode(resource: impl Into, mode: MssqlAdvisoryLockMode) -> Self { + Self { + resource: resource.into(), + mode, + } + } + + /// Returns the resource name of this lock. + pub fn resource(&self) -> &str { + &self.resource + } + + /// Returns the lock mode. + pub fn mode(&self) -> &MssqlAdvisoryLockMode { + &self.mode + } + + /// Acquire the lock, waiting indefinitely until it is available. + /// + /// # Errors + /// + /// Returns an error if `sp_getapplock` returns a negative status code + /// (e.g. lock request was cancelled or a deadlock was detected). + pub async fn acquire(&self, conn: &mut MssqlConnection) -> Result<(), Error> { + let mode = self.mode.as_str(); + let sql = format!( + "DECLARE @r INT; \ + EXEC @r = sp_getapplock @Resource = @p1, @LockMode = '{mode}', \ + @LockOwner = 'Session', @LockTimeout = -1; \ + SELECT @r;" + ); + + let status: i32 = query_scalar(sqlx_core::sql_str::AssertSqlSafe(sql)) + .bind(&self.resource) + .fetch_one(&mut *conn) + .await?; + + if status < 0 { + return Err(Error::Protocol(format!( + "sp_getapplock failed for resource '{}': status {status}{}", + self.resource, + applock_error_message(status), + ))); + } + + Ok(()) + } + + /// Try to acquire the lock without waiting. + /// + /// Returns `Ok(true)` if the lock was acquired, `Ok(false)` if it was not + /// available (timeout). + pub async fn try_acquire(&self, conn: &mut MssqlConnection) -> Result { + let mode = self.mode.as_str(); + let sql = format!( + "DECLARE @r INT; \ + EXEC @r = sp_getapplock @Resource = @p1, @LockMode = '{mode}', \ + @LockOwner = 'Session', @LockTimeout = 0; \ + SELECT @r;" + ); + + let status: i32 = query_scalar(sqlx_core::sql_str::AssertSqlSafe(sql)) + .bind(&self.resource) + .fetch_one(&mut *conn) + .await?; + + if status >= 0 { + // 0 = granted synchronously, 1 = granted after wait + Ok(true) + } else if status == -1 { + // -1 = timed out + Ok(false) + } else { + Err(Error::Protocol(format!( + "sp_getapplock failed for resource '{}': status {status}{}", + self.resource, + applock_error_message(status), + ))) + } + } + + /// Release the lock. + /// + /// Returns `Ok(true)` if the lock was successfully released, `Ok(false)` + /// if the lock was not held by this session. + pub async fn release(&self, conn: &mut MssqlConnection) -> Result { + let sql = "DECLARE @r INT; \ + EXEC @r = sp_releaseapplock @Resource = @p1, @LockOwner = 'Session'; \ + SELECT @r;"; + + let status: i32 = query_scalar(sql) + .bind(&self.resource) + .fetch_one(&mut *conn) + .await?; + + match status { + 0 => Ok(true), + -999 => Ok(false), + _ => Err(Error::Protocol(format!( + "sp_releaseapplock failed for resource '{}': status {status}", + self.resource, + ))), + } + } +} + +fn applock_error_message(status: i32) -> &'static str { + match status { + -1 => " (timed out)", + -2 => " (lock request cancelled)", + -3 => " (deadlock victim)", + -999 => " (parameter validation or other call error)", + _ => "", + } +} diff --git a/sqlx-mssql/src/lib.rs b/sqlx-mssql/src/lib.rs index 3a71eb9e72..d4016d6b34 100644 --- a/sqlx-mssql/src/lib.rs +++ b/sqlx-mssql/src/lib.rs @@ -10,6 +10,8 @@ use crate::executor::Executor; pub(crate) use sqlx_core::driver_prelude::*; +pub mod advisory_lock; + #[cfg(feature = "any")] pub mod any; @@ -35,11 +37,13 @@ mod migrate; #[cfg(feature = "migrate")] mod testing; +pub use advisory_lock::{MssqlAdvisoryLock, MssqlAdvisoryLockMode}; pub use arguments::MssqlArguments; pub use column::MssqlColumn; pub use connection::MssqlConnection; pub use database::Mssql; pub use error::MssqlDatabaseError; +pub use options::ssl_mode::MssqlSslMode; pub use options::MssqlConnectOptions; pub use query_result::MssqlQueryResult; pub use row::MssqlRow; diff --git a/sqlx-mssql/src/options/mod.rs b/sqlx-mssql/src/options/mod.rs index 36f064def6..a69deada4f 100644 --- a/sqlx-mssql/src/options/mod.rs +++ b/sqlx-mssql/src/options/mod.rs @@ -1,7 +1,9 @@ mod connect; mod parse; +pub mod ssl_mode; use crate::connection::LogSettings; +use ssl_mode::MssqlSslMode; /// Options and flags which can be used to configure a MSSQL connection. /// @@ -18,7 +20,8 @@ use crate::connection::LogSettings; /// /// |Parameter|Default|Description| /// |---------|-------|-----------| -/// | `encrypt` | `false` | Whether to use TLS encryption. | +/// | `sslmode` / `ssl_mode` | `preferred` | SSL encryption mode: `disabled`, `login_only`, `preferred`, `required`. | +/// | `encrypt` | (none) | Legacy alias: `true` maps to `required`, `false` to `disabled`. | /// | `trust_server_certificate` | `false` | Whether to trust the server certificate without validation. | /// | `statement-cache-capacity` | `100` | The maximum number of prepared statements stored in the cache. | /// | `app_name` | `sqlx` | The application name sent to the server. | @@ -52,7 +55,7 @@ pub struct MssqlConnectOptions { pub(crate) password: Option, pub(crate) database: Option, pub(crate) instance: Option, - pub(crate) encrypt: bool, + pub(crate) ssl_mode: MssqlSslMode, pub(crate) trust_server_certificate: bool, pub(crate) statement_cache_capacity: usize, pub(crate) app_name: String, @@ -75,7 +78,7 @@ impl MssqlConnectOptions { password: None, database: None, instance: None, - encrypt: false, + ssl_mode: MssqlSslMode::default(), trust_server_certificate: false, statement_cache_capacity: 100, app_name: String::from("sqlx"), @@ -121,9 +124,22 @@ impl MssqlConnectOptions { self } + /// Sets the SSL encryption mode. + pub fn ssl_mode(mut self, mode: MssqlSslMode) -> Self { + self.ssl_mode = mode; + self + } + /// Sets whether to use TLS encryption. + /// + /// This is a legacy convenience method. + /// `true` maps to [`MssqlSslMode::Required`], `false` to [`MssqlSslMode::Disabled`]. pub fn encrypt(mut self, encrypt: bool) -> Self { - self.encrypt = encrypt; + self.ssl_mode = if encrypt { + MssqlSslMode::Required + } else { + MssqlSslMode::Disabled + }; self } @@ -190,11 +206,12 @@ impl MssqlConnectOptions { config.trust_cert(); } - if self.encrypt { - config.encryption(tiberius::EncryptionLevel::Required); - } else { - config.encryption(tiberius::EncryptionLevel::NotSupported); - } + config.encryption(match self.ssl_mode { + MssqlSslMode::Disabled => tiberius::EncryptionLevel::NotSupported, + MssqlSslMode::LoginOnly => tiberius::EncryptionLevel::Off, + MssqlSslMode::Preferred => tiberius::EncryptionLevel::On, + MssqlSslMode::Required => tiberius::EncryptionLevel::Required, + }); config } diff --git a/sqlx-mssql/src/options/parse.rs b/sqlx-mssql/src/options/parse.rs index 5dedf39661..3ea75af7db 100644 --- a/sqlx-mssql/src/options/parse.rs +++ b/sqlx-mssql/src/options/parse.rs @@ -5,6 +5,7 @@ use sqlx_core::Url; use crate::error::Error; +use super::ssl_mode::MssqlSslMode; use super::MssqlConnectOptions; impl MssqlConnectOptions { @@ -47,6 +48,20 @@ impl MssqlConnectOptions { for (key, value) in url.query_pairs().into_iter() { match &*key { + "sslmode" | "ssl_mode" => { + options = options.ssl_mode(match &*value { + "disabled" => MssqlSslMode::Disabled, + "login_only" => MssqlSslMode::LoginOnly, + "preferred" => MssqlSslMode::Preferred, + "required" => MssqlSslMode::Required, + _ => { + return Err(Error::Configuration( + format!("unknown sslmode value: {value}").into(), + )) + } + }); + } + "encrypt" => { options = options .encrypt(value.parse().map_err(Error::config)?); @@ -92,6 +107,14 @@ impl MssqlConnectOptions { url.set_path(database); } + let sslmode = match self.ssl_mode { + MssqlSslMode::Disabled => "disabled", + MssqlSslMode::LoginOnly => "login_only", + MssqlSslMode::Preferred => "preferred", + MssqlSslMode::Required => "required", + }; + url.query_pairs_mut().append_pair("sslmode", sslmode); + url } } @@ -124,3 +147,60 @@ fn it_parses_url_with_instance() { assert_eq!(opts.instance, Some("SQLEXPRESS".into())); } + +#[test] +fn it_parses_sslmode_disabled() { + let url = "mssql://sa:password@localhost/master?sslmode=disabled"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + assert!(matches!(opts.ssl_mode, MssqlSslMode::Disabled)); +} + +#[test] +fn it_parses_sslmode_login_only() { + let url = "mssql://sa:password@localhost/master?ssl_mode=login_only"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + assert!(matches!(opts.ssl_mode, MssqlSslMode::LoginOnly)); +} + +#[test] +fn it_parses_sslmode_preferred() { + let url = "mssql://sa:password@localhost/master?sslmode=preferred"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + assert!(matches!(opts.ssl_mode, MssqlSslMode::Preferred)); +} + +#[test] +fn it_parses_sslmode_required() { + let url = "mssql://sa:password@localhost/master?sslmode=required"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + assert!(matches!(opts.ssl_mode, MssqlSslMode::Required)); +} + +#[test] +fn it_parses_encrypt_true_as_required() { + let url = "mssql://sa:password@localhost/master?encrypt=true"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + assert!(matches!(opts.ssl_mode, MssqlSslMode::Required)); +} + +#[test] +fn it_parses_encrypt_false_as_disabled() { + let url = "mssql://sa:password@localhost/master?encrypt=false"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + assert!(matches!(opts.ssl_mode, MssqlSslMode::Disabled)); +} + +#[test] +fn it_rejects_invalid_sslmode() { + let url = "mssql://sa:password@localhost/master?sslmode=bogus"; + assert!(MssqlConnectOptions::from_str(url).is_err()); +} + +#[test] +fn it_roundtrips_sslmode_in_url() { + let url = "mssql://sa:password@localhost/master?sslmode=login_only"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + let built = opts.build_url(); + let opts2 = MssqlConnectOptions::parse_from_url(&built).unwrap(); + assert!(matches!(opts2.ssl_mode, MssqlSslMode::LoginOnly)); +} diff --git a/sqlx-mssql/src/options/ssl_mode.rs b/sqlx-mssql/src/options/ssl_mode.rs new file mode 100644 index 0000000000..09519bdcb3 --- /dev/null +++ b/sqlx-mssql/src/options/ssl_mode.rs @@ -0,0 +1,18 @@ +/// The SSL mode to use when connecting to MSSQL. +/// +/// Maps to the tiberius `EncryptionLevel` variants. +#[derive(Debug, Clone, Copy, Default)] +pub enum MssqlSslMode { + /// No encryption at all (`EncryptionLevel::NotSupported`). + Disabled, + + /// Only encrypt the login packet (`EncryptionLevel::Off`). + LoginOnly, + + /// Encrypt if the server supports it (`EncryptionLevel::On`). + #[default] + Preferred, + + /// Always encrypt; fail if the server doesn't support it (`EncryptionLevel::Required`). + Required, +} diff --git a/sqlx-mssql/src/types/float.rs b/sqlx-mssql/src/types/float.rs index d0f88bac59..66a9a7bb4b 100644 --- a/sqlx-mssql/src/types/float.rs +++ b/sqlx-mssql/src/types/float.rs @@ -7,7 +7,7 @@ use crate::value::MssqlData; use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; fn real_compatible(ty: &MssqlTypeInfo) -> bool { - matches!(ty.name.as_str(), "REAL" | "FLOAT") + matches!(ty.name.as_str(), "REAL" | "FLOAT" | "MONEY" | "SMALLMONEY") } impl Type for f32 { diff --git a/sqlx-mssql/src/types/mod.rs b/sqlx-mssql/src/types/mod.rs index b830b018aa..a0ca56c925 100644 --- a/sqlx-mssql/src/types/mod.rs +++ b/sqlx-mssql/src/types/mod.rs @@ -10,8 +10,8 @@ //! | `i16` | SMALLINT | //! | `i32` | INT | //! | `i64` | BIGINT | -//! | `f32` | REAL | -//! | `f64` | FLOAT | +//! | `f32` | REAL, FLOAT | +//! | `f64` | REAL, FLOAT, MONEY, SMALLMONEY | //! | `&str`, [`String`] | NVARCHAR | //! | `&[u8]`, `Vec` | VARBINARY | //! diff --git a/tests/mssql/advisory-lock.rs b/tests/mssql/advisory-lock.rs new file mode 100644 index 0000000000..1277b002fc --- /dev/null +++ b/tests/mssql/advisory-lock.rs @@ -0,0 +1,85 @@ +use sqlx::mssql::{Mssql, MssqlAdvisoryLock, MssqlAdvisoryLockMode}; +use sqlx_test::new; + +#[sqlx_macros::test] +async fn it_acquires_and_releases() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let lock = MssqlAdvisoryLock::new("sqlx_test_acquire_release"); + + lock.acquire(&mut conn).await?; + let released = lock.release(&mut conn).await?; + assert!(released, "lock should have been held and released"); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_try_acquire_succeeds_when_free() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let lock = MssqlAdvisoryLock::new("sqlx_test_try_free"); + + let acquired = lock.try_acquire(&mut conn).await?; + assert!(acquired, "lock should be free and acquired"); + + lock.release(&mut conn).await?; + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_try_acquire_fails_when_held() -> anyhow::Result<()> { + let mut conn1 = new::().await?; + let mut conn2 = new::().await?; + + let lock = MssqlAdvisoryLock::new("sqlx_test_try_held"); + + // Conn1 holds the exclusive lock + lock.acquire(&mut conn1).await?; + + // Conn2 should fail to acquire it immediately + let acquired = lock.try_acquire(&mut conn2).await?; + assert!(!acquired, "lock should not be available"); + + // Release from conn1 + lock.release(&mut conn1).await?; + + // Now conn2 should be able to acquire + let acquired = lock.try_acquire(&mut conn2).await?; + assert!(acquired, "lock should now be free"); + + lock.release(&mut conn2).await?; + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_supports_shared_locks() -> anyhow::Result<()> { + let mut conn1 = new::().await?; + let mut conn2 = new::().await?; + + let lock = MssqlAdvisoryLock::with_mode("sqlx_test_shared", MssqlAdvisoryLockMode::Shared); + + // Both connections should be able to acquire a shared lock + lock.acquire(&mut conn1).await?; + let acquired = lock.try_acquire(&mut conn2).await?; + assert!(acquired, "shared lock should be acquirable by second connection"); + + lock.release(&mut conn1).await?; + lock.release(&mut conn2).await?; + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_release_returns_false_when_not_held() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let lock = MssqlAdvisoryLock::new("sqlx_test_not_held"); + + let released = lock.release(&mut conn).await?; + assert!(!released, "release should return false when lock is not held"); + + Ok(()) +} diff --git a/tests/mssql/types.rs b/tests/mssql/types.rs index 097a0dbaa3..de5d840dd6 100644 --- a/tests/mssql/types.rs +++ b/tests/mssql/types.rs @@ -36,6 +36,20 @@ test_type!(f64( "CAST(939399419.1225182 AS FLOAT)" == 939399419.1225182_f64 )); +test_type!(f64_money( + Mssql, + "CAST(922337203685477.5807 AS MONEY)" == 922337203685477.5807_f64, + "CAST(0 AS MONEY)" == 0.0_f64, + "CAST(-1234.5678 AS MONEY)" == -1234.5678_f64, +)); + +test_type!(f64_smallmoney( + Mssql, + "CAST(214748.3647 AS SMALLMONEY)" == 214748.3647_f64, + "CAST(0 AS SMALLMONEY)" == 0.0_f64, + "CAST(-1234.5678 AS SMALLMONEY)" == -1234.5678_f64, +)); + test_type!(str_nvarchar(Mssql, "CAST('this is foo' as NVARCHAR)" == "this is foo", )); @@ -154,6 +168,12 @@ test_type!(rust_decimal(Mssql, "CAST('-1.23' AS DECIMAL(10,2))" == sqlx::types::Decimal::new(-123, 2), )); +#[cfg(feature = "rust_decimal")] +test_type!(rust_decimal_money(Mssql, + "CAST(1234.5678 AS MONEY)" == sqlx::types::Decimal::new(12345678, 4), + "CAST(0 AS MONEY)" == sqlx::types::Decimal::ZERO, +)); + #[cfg(feature = "bigdecimal")] test_type!(bigdecimal(Mssql, "CAST('0' AS DECIMAL(10,2))" == "0.00".parse::().unwrap(), @@ -161,6 +181,12 @@ test_type!(bigdecimal(Mssql, "CAST('-1.23' AS DECIMAL(10,2))" == "-1.23".parse::().unwrap(), )); +#[cfg(feature = "bigdecimal")] +test_type!(bigdecimal_money(Mssql, + "CAST(1234.5678 AS MONEY)" == "1234.5678".parse::().unwrap(), + "CAST(0 AS MONEY)" == "0".parse::().unwrap(), +)); + #[cfg(feature = "json")] mod json_tests { use sqlx::mssql::Mssql; From 119107dd6f1b19db5416e0c1e4f18c660582ec9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Sun, 15 Feb 2026 03:48:59 -0500 Subject: [PATCH 07/33] feat: add isolation levels, connection options, and bulk insert to sqlx-mssql Add MssqlIsolationLevel enum with begin_with_isolation() for typed transaction isolation control. Surface tiberius application_intent (read-only routing) and trust_cert_ca options in MssqlConnectOptions. Wrap tiberius BulkLoadRequest as MssqlBulkInsert for high-performance data loading via the TDS INSERT BULK protocol. Author: Pablo Carrera --- Cargo.toml | 10 ++++ sqlx-mssql/src/bulk_insert.rs | 51 ++++++++++++++++++ sqlx-mssql/src/connection/mod.rs | 72 +++++++++++++++++++++++-- sqlx-mssql/src/isolation_level.rs | 55 +++++++++++++++++++ sqlx-mssql/src/lib.rs | 7 +++ sqlx-mssql/src/options/mod.rs | 40 +++++++++++++- sqlx-mssql/src/options/parse.rs | 88 +++++++++++++++++++++++++++++++ tests/mssql/bulk-insert.rs | 79 +++++++++++++++++++++++++++ tests/mssql/isolation-level.rs | 62 ++++++++++++++++++++++ 9 files changed, 460 insertions(+), 4 deletions(-) create mode 100644 sqlx-mssql/src/bulk_insert.rs create mode 100644 sqlx-mssql/src/isolation_level.rs create mode 100644 tests/mssql/bulk-insert.rs create mode 100644 tests/mssql/isolation-level.rs diff --git a/Cargo.toml b/Cargo.toml index 6b488a9e27..a0de7fe7c5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -502,3 +502,13 @@ required-features = ["mssql", "macros", "migrate"] name = "mssql-advisory-lock" path = "tests/mssql/advisory-lock.rs" required-features = ["mssql"] + +[[test]] +name = "mssql-isolation-level" +path = "tests/mssql/isolation-level.rs" +required-features = ["mssql"] + +[[test]] +name = "mssql-bulk-insert" +path = "tests/mssql/bulk-insert.rs" +required-features = ["mssql"] diff --git a/sqlx-mssql/src/bulk_insert.rs b/sqlx-mssql/src/bulk_insert.rs new file mode 100644 index 0000000000..dac9a56943 --- /dev/null +++ b/sqlx-mssql/src/bulk_insert.rs @@ -0,0 +1,51 @@ +use crate::error::{tiberius_err, Error}; +use crate::io::SocketAdapter; +use sqlx_core::net::Socket; + +/// A bulk insert operation for high-performance data loading into SQL Server. +/// +/// Wraps the tiberius [`BulkLoadRequest`](tiberius::BulkLoadRequest) to provide +/// efficient bulk data insertion using the TDS `INSERT BULK` protocol. +/// +/// # Example +/// +/// ```rust,no_run +/// # async fn example(conn: &mut sqlx::mssql::MssqlConnection) -> sqlx::Result<()> { +/// use sqlx::mssql::IntoRow; +/// +/// let mut bulk = conn.bulk_insert("#my_temp_table").await?; +/// bulk.send(("hello", 42i32).into_row()).await?; +/// bulk.send(("world", 99i32).into_row()).await?; +/// let total = bulk.finalize().await?; +/// assert_eq!(total, 2); +/// # Ok(()) +/// # } +/// ``` +pub struct MssqlBulkInsert<'c> { + inner: tiberius::BulkLoadRequest<'c, SocketAdapter>>, +} + +impl<'c> MssqlBulkInsert<'c> { + pub(crate) fn new( + inner: tiberius::BulkLoadRequest<'c, SocketAdapter>>, + ) -> Self { + Self { inner } + } + + /// Send a single row to the bulk insert operation. + /// + /// The row is a [`tiberius::TokenRow`] — use [`tiberius::IntoRow::into_row()`] + /// to convert tuples of up to 10 elements into a `TokenRow`. + pub async fn send(&mut self, row: tiberius::TokenRow<'c>) -> Result<(), Error> { + self.inner.send(row).await.map_err(tiberius_err) + } + + /// Finalize the bulk insert, flushing all buffered data to the server. + /// + /// Returns the total number of rows inserted. This **must** be called + /// after all rows have been sent — otherwise data will be lost. + pub async fn finalize(self) -> Result { + let result = self.inner.finalize().await.map_err(tiberius_err)?; + Ok(result.total()) + } +} diff --git a/sqlx-mssql/src/connection/mod.rs b/sqlx-mssql/src/connection/mod.rs index 2e4f06bc79..77a2656dc5 100644 --- a/sqlx-mssql/src/connection/mod.rs +++ b/sqlx-mssql/src/connection/mod.rs @@ -2,14 +2,16 @@ use std::fmt::{self, Debug, Formatter}; pub(crate) use sqlx_core::connection::*; use sqlx_core::net::Socket; -use sqlx_core::sql_str::SqlSafeStr; +use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr}; +use crate::bulk_insert::MssqlBulkInsert; use crate::common::StatementCache; -use crate::error::Error; +use crate::error::{tiberius_err, Error}; use crate::executor::Executor; use crate::io::SocketAdapter; +use crate::isolation_level::MssqlIsolationLevel; use crate::statement::MssqlStatementMetadata; -use crate::transaction::Transaction; +use crate::transaction::{resolve_pending_rollback, Transaction}; use crate::{Mssql, MssqlConnectOptions}; mod establish; @@ -94,3 +96,67 @@ impl Connection for MssqlConnection { // No-op for MSSQL } } + +impl MssqlConnection { + /// Begin a transaction with a specific isolation level. + /// + /// SQL Server requires `SET TRANSACTION ISOLATION LEVEL` to be issued + /// **before** `BEGIN TRANSACTION`. This method generates: + /// + /// ```sql + /// SET TRANSACTION ISOLATION LEVEL ; BEGIN TRANSACTION + /// ``` + /// + /// # Example + /// + /// ```rust,no_run + /// # async fn example(conn: &mut sqlx::mssql::MssqlConnection) -> sqlx::Result<()> { + /// use sqlx::mssql::MssqlIsolationLevel; + /// + /// let mut tx = conn.begin_with_isolation(MssqlIsolationLevel::Snapshot).await?; + /// // ... use tx ... + /// tx.commit().await?; + /// # Ok(()) + /// # } + /// ``` + pub fn begin_with_isolation( + &mut self, + level: MssqlIsolationLevel, + ) -> impl std::future::Future, Error>> + Send + '_ { + let sql = AssertSqlSafe(format!( + "SET TRANSACTION ISOLATION LEVEL {level}; BEGIN TRANSACTION" + )); + Transaction::begin(self, Some(sql.into_sql_str())) + } + + /// Start a bulk insert operation for high-performance data loading. + /// + /// The table must already exist. Tiberius executes `SELECT TOP 0 * FROM ` + /// to discover column metadata, then uses the TDS `INSERT BULK` protocol. + /// + /// # Example + /// + /// ```rust,no_run + /// # async fn example(conn: &mut sqlx::mssql::MssqlConnection) -> sqlx::Result<()> { + /// use sqlx::mssql::IntoRow; + /// + /// let mut bulk = conn.bulk_insert("#temp").await?; + /// bulk.send(("hello", 42i32).into_row()).await?; + /// let total = bulk.finalize().await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn bulk_insert<'c>( + &'c mut self, + table: &'c str, + ) -> Result, Error> { + resolve_pending_rollback(self).await?; + let req = self + .inner + .client + .bulk_insert(table) + .await + .map_err(tiberius_err)?; + Ok(MssqlBulkInsert::new(req)) + } +} diff --git a/sqlx-mssql/src/isolation_level.rs b/sqlx-mssql/src/isolation_level.rs new file mode 100644 index 0000000000..1409f084de --- /dev/null +++ b/sqlx-mssql/src/isolation_level.rs @@ -0,0 +1,55 @@ +use std::fmt; + +/// SQL Server transaction isolation levels. +/// +/// SQL Server supports five isolation levels. The `SET TRANSACTION ISOLATION LEVEL` +/// statement must be issued **before** `BEGIN TRANSACTION`, unlike PostgreSQL which +/// accepts it inside the `BEGIN` block. +/// +/// See [SQL Server documentation](https://learn.microsoft.com/en-us/sql/t-sql/statements/set-transaction-isolation-level-transact-sql) +/// for details on each level. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub enum MssqlIsolationLevel { + /// Allows dirty reads. Statements can read rows modified by other + /// transactions that have not yet been committed. + ReadUncommitted, + + /// The default isolation level. Statements cannot read data modified + /// by other transactions that have not been committed. + #[default] + ReadCommitted, + + /// Statements cannot read data modified by other transactions that + /// have not been committed, and no other transactions can modify + /// data read by the current transaction until it completes. + RepeatableRead, + + /// Uses row versioning to provide transaction-level read consistency. + /// Requires the `ALLOW_SNAPSHOT_ISOLATION` database option to be `ON`. + Snapshot, + + /// Statements cannot read data modified by other transactions that + /// have not been committed. No other transactions can modify data + /// read by the current transaction, and no other transactions can + /// insert new rows matching the current transaction's search conditions. + Serializable, +} + +impl MssqlIsolationLevel { + /// Returns the SQL Server syntax for this isolation level. + pub fn as_str(&self) -> &'static str { + match self { + Self::ReadUncommitted => "READ UNCOMMITTED", + Self::ReadCommitted => "READ COMMITTED", + Self::RepeatableRead => "REPEATABLE READ", + Self::Snapshot => "SNAPSHOT", + Self::Serializable => "SERIALIZABLE", + } + } +} + +impl fmt::Display for MssqlIsolationLevel { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} diff --git a/sqlx-mssql/src/lib.rs b/sqlx-mssql/src/lib.rs index d4016d6b34..75f7708ec8 100644 --- a/sqlx-mssql/src/lib.rs +++ b/sqlx-mssql/src/lib.rs @@ -11,6 +11,8 @@ use crate::executor::Executor; pub(crate) use sqlx_core::driver_prelude::*; pub mod advisory_lock; +mod bulk_insert; +mod isolation_level; #[cfg(feature = "any")] pub mod any; @@ -39,10 +41,12 @@ mod testing; pub use advisory_lock::{MssqlAdvisoryLock, MssqlAdvisoryLockMode}; pub use arguments::MssqlArguments; +pub use bulk_insert::MssqlBulkInsert; pub use column::MssqlColumn; pub use connection::MssqlConnection; pub use database::Mssql; pub use error::MssqlDatabaseError; +pub use isolation_level::MssqlIsolationLevel; pub use options::ssl_mode::MssqlSslMode; pub use options::MssqlConnectOptions; pub use query_result::MssqlQueryResult; @@ -52,6 +56,9 @@ pub use transaction::MssqlTransactionManager; pub use type_info::MssqlTypeInfo; pub use value::{MssqlValue, MssqlValueRef}; +// Re-export tiberius types needed for bulk insert row construction. +pub use tiberius::{IntoRow, IntoSql, TokenRow}; + /// An alias for [`Pool`][crate::pool::Pool], specialized for MSSQL. pub type MssqlPool = crate::pool::Pool; diff --git a/sqlx-mssql/src/options/mod.rs b/sqlx-mssql/src/options/mod.rs index a69deada4f..16303c18a7 100644 --- a/sqlx-mssql/src/options/mod.rs +++ b/sqlx-mssql/src/options/mod.rs @@ -23,6 +23,8 @@ use ssl_mode::MssqlSslMode; /// | `sslmode` / `ssl_mode` | `preferred` | SSL encryption mode: `disabled`, `login_only`, `preferred`, `required`. | /// | `encrypt` | (none) | Legacy alias: `true` maps to `required`, `false` to `disabled`. | /// | `trust_server_certificate` | `false` | Whether to trust the server certificate without validation. | +/// | `trust_server_certificate_ca` | (none) | Path to a CA certificate file to validate the server certificate against. Mutually exclusive with `trust_server_certificate`. | +/// | `application_intent` | `read_write` | Application intent: `read_write` or `read_only`. `read_only` routes to Always On read replicas. | /// | `statement-cache-capacity` | `100` | The maximum number of prepared statements stored in the cache. | /// | `app_name` | `sqlx` | The application name sent to the server. | /// | `instance` | `None` | The SQL Server instance name. | @@ -57,6 +59,8 @@ pub struct MssqlConnectOptions { pub(crate) instance: Option, pub(crate) ssl_mode: MssqlSslMode, pub(crate) trust_server_certificate: bool, + pub(crate) trust_server_certificate_ca: Option, + pub(crate) application_intent_read_only: bool, pub(crate) statement_cache_capacity: usize, pub(crate) app_name: String, pub(crate) log_settings: LogSettings, @@ -80,6 +84,8 @@ impl MssqlConnectOptions { instance: None, ssl_mode: MssqlSslMode::default(), trust_server_certificate: false, + trust_server_certificate_ca: None, + application_intent_read_only: false, statement_cache_capacity: 100, app_name: String::from("sqlx"), log_settings: Default::default(), @@ -149,6 +155,31 @@ impl MssqlConnectOptions { self } + /// Sets a CA certificate file path to validate the server certificate against. + /// + /// Accepts `.pem`, `.crt`, or `.der` certificate files. + /// + /// This is mutually exclusive with [`trust_server_certificate`](Self::trust_server_certificate). + /// When a CA path is set, `trust_server_certificate` is ignored. + pub fn trust_server_certificate_ca(mut self, path: &str) -> Self { + self.trust_server_certificate_ca = Some(path.to_owned()); + self + } + + /// Sets the application intent to read-only. + /// + /// When `true`, sets `ApplicationIntent=ReadOnly` in the TDS login packet, + /// which routes connections to Always On Availability Group read replicas. + pub fn application_intent_read_only(mut self, read_only: bool) -> Self { + self.application_intent_read_only = read_only; + self + } + + /// Get whether the application intent is set to read-only. + pub fn get_application_intent_read_only(&self) -> bool { + self.application_intent_read_only + } + /// Sets the capacity of the connection's statement cache. pub fn statement_cache_capacity(mut self, capacity: usize) -> Self { self.statement_cache_capacity = capacity; @@ -202,10 +233,17 @@ impl MssqlConnectOptions { self.password.as_deref().unwrap_or(""), )); - if self.trust_server_certificate { + if let Some(ca_path) = &self.trust_server_certificate_ca { + // trust_cert_ca and trust_cert are mutually exclusive in tiberius + config.trust_cert_ca(ca_path); + } else if self.trust_server_certificate { config.trust_cert(); } + if self.application_intent_read_only { + config.readonly(true); + } + config.encryption(match self.ssl_mode { MssqlSslMode::Disabled => tiberius::EncryptionLevel::NotSupported, MssqlSslMode::LoginOnly => tiberius::EncryptionLevel::Off, diff --git a/sqlx-mssql/src/options/parse.rs b/sqlx-mssql/src/options/parse.rs index 3ea75af7db..6488069718 100644 --- a/sqlx-mssql/src/options/parse.rs +++ b/sqlx-mssql/src/options/parse.rs @@ -85,6 +85,26 @@ impl MssqlConnectOptions { .statement_cache_capacity(value.parse().map_err(Error::config)?); } + "application_intent" | "applicationIntent" => { + match &*value { + "read_only" | "ReadOnly" => { + options = options.application_intent_read_only(true); + } + "read_write" | "ReadWrite" => { + options = options.application_intent_read_only(false); + } + _ => { + return Err(Error::Configuration( + format!("unknown application_intent value: {value}").into(), + )) + } + } + } + + "trust_server_certificate_ca" | "trustServerCertificateCa" => { + options = options.trust_server_certificate_ca(&value); + } + _ => {} } } @@ -115,6 +135,16 @@ impl MssqlConnectOptions { }; url.query_pairs_mut().append_pair("sslmode", sslmode); + if self.application_intent_read_only { + url.query_pairs_mut() + .append_pair("application_intent", "read_only"); + } + + if let Some(ca_path) = &self.trust_server_certificate_ca { + url.query_pairs_mut() + .append_pair("trust_server_certificate_ca", ca_path); + } + url } } @@ -204,3 +234,61 @@ fn it_roundtrips_sslmode_in_url() { let opts2 = MssqlConnectOptions::parse_from_url(&built).unwrap(); assert!(matches!(opts2.ssl_mode, MssqlSslMode::LoginOnly)); } + +#[test] +fn it_parses_application_intent_read_only() { + let url = "mssql://sa:password@localhost/master?application_intent=read_only"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + assert!(opts.application_intent_read_only); +} + +#[test] +fn it_parses_application_intent_read_write() { + let url = "mssql://sa:password@localhost/master?application_intent=read_write"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + assert!(!opts.application_intent_read_only); +} + +#[test] +fn it_parses_application_intent_camel_case() { + let url = "mssql://sa:password@localhost/master?applicationIntent=ReadOnly"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + assert!(opts.application_intent_read_only); +} + +#[test] +fn it_rejects_invalid_application_intent() { + let url = "mssql://sa:password@localhost/master?application_intent=bogus"; + assert!(MssqlConnectOptions::from_str(url).is_err()); +} + +#[test] +fn it_parses_trust_server_certificate_ca() { + let url = "mssql://sa:password@localhost/master?trust_server_certificate_ca=/path/to/ca.pem"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + assert_eq!(opts.trust_server_certificate_ca, Some("/path/to/ca.pem".into())); +} + +#[test] +fn it_roundtrips_application_intent_in_url() { + let opts = MssqlConnectOptions::new() + .host("localhost") + .username("sa") + .password("password") + .application_intent_read_only(true); + let built = opts.build_url(); + let opts2 = MssqlConnectOptions::parse_from_url(&built).unwrap(); + assert!(opts2.application_intent_read_only); +} + +#[test] +fn it_roundtrips_trust_cert_ca_in_url() { + let opts = MssqlConnectOptions::new() + .host("localhost") + .username("sa") + .password("password") + .trust_server_certificate_ca("/etc/ssl/ca.pem"); + let built = opts.build_url(); + let opts2 = MssqlConnectOptions::parse_from_url(&built).unwrap(); + assert_eq!(opts2.trust_server_certificate_ca, Some("/etc/ssl/ca.pem".into())); +} diff --git a/tests/mssql/bulk-insert.rs b/tests/mssql/bulk-insert.rs new file mode 100644 index 0000000000..b233621570 --- /dev/null +++ b/tests/mssql/bulk-insert.rs @@ -0,0 +1,79 @@ +use sqlx::mssql::{IntoRow, Mssql}; +use sqlx::Row; +use sqlx_test::new; + +#[sqlx_macros::test] +async fn it_bulk_inserts_rows() -> anyhow::Result<()> { + let mut conn = new::().await?; + + sqlx::query( + "CREATE TABLE #bulk_test (name NVARCHAR(50) NOT NULL, value INT NOT NULL)" + ) + .execute(&mut conn) + .await?; + + let mut bulk = conn.bulk_insert("#bulk_test").await?; + bulk.send(("hello", 1i32).into_row()).await?; + bulk.send(("world", 2i32).into_row()).await?; + bulk.send(("foo", 3i32).into_row()).await?; + let total = bulk.finalize().await?; + assert_eq!(total, 3); + + let rows = sqlx::query("SELECT name, value FROM #bulk_test ORDER BY value") + .fetch_all(&mut conn) + .await?; + + assert_eq!(rows.len(), 3); + assert_eq!(rows[0].get::("name"), "hello"); + assert_eq!(rows[0].get::("value"), 1); + assert_eq!(rows[1].get::("name"), "world"); + assert_eq!(rows[1].get::("value"), 2); + assert_eq!(rows[2].get::("name"), "foo"); + assert_eq!(rows[2].get::("value"), 3); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_bulk_inserts_empty() -> anyhow::Result<()> { + let mut conn = new::().await?; + + sqlx::query("CREATE TABLE #bulk_empty (id INT NOT NULL)") + .execute(&mut conn) + .await?; + + let bulk = conn.bulk_insert("#bulk_empty").await?; + let total = bulk.finalize().await?; + assert_eq!(total, 0); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_bulk_inserts_various_types() -> anyhow::Result<()> { + let mut conn = new::().await?; + + sqlx::query( + "CREATE TABLE #bulk_types (id INT NOT NULL, label NVARCHAR(100) NOT NULL, score FLOAT NOT NULL)" + ) + .execute(&mut conn) + .await?; + + let mut bulk = conn.bulk_insert("#bulk_types").await?; + bulk.send((1i32, "alpha", 1.5f64).into_row()).await?; + bulk.send((2i32, "beta", 2.7f64).into_row()).await?; + let total = bulk.finalize().await?; + assert_eq!(total, 2); + + let rows = sqlx::query("SELECT id, label, score FROM #bulk_types ORDER BY id") + .fetch_all(&mut conn) + .await?; + + assert_eq!(rows.len(), 2); + assert_eq!(rows[0].get::("id"), 1); + assert_eq!(rows[0].get::("label"), "alpha"); + assert_eq!(rows[1].get::("id"), 2); + assert_eq!(rows[1].get::("label"), "beta"); + + Ok(()) +} diff --git a/tests/mssql/isolation-level.rs b/tests/mssql/isolation-level.rs new file mode 100644 index 0000000000..670b1f1e52 --- /dev/null +++ b/tests/mssql/isolation-level.rs @@ -0,0 +1,62 @@ +use sqlx::mssql::{Mssql, MssqlIsolationLevel}; +use sqlx::Row; +use sqlx_test::new; + +#[sqlx_macros::test] +async fn it_begins_with_read_uncommitted() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let mut tx = conn + .begin_with_isolation(MssqlIsolationLevel::ReadUncommitted) + .await?; + + let row = sqlx::query("SELECT 1 AS val") + .fetch_one(&mut *tx) + .await?; + let val: i32 = row.get("val"); + assert_eq!(val, 1); + + tx.commit().await?; + Ok(()) +} + +#[sqlx_macros::test] +async fn it_begins_with_snapshot() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // Enable snapshot isolation on the database first + sqlx::query("ALTER DATABASE CURRENT SET ALLOW_SNAPSHOT_ISOLATION ON") + .execute(&mut conn) + .await?; + + let mut tx = conn + .begin_with_isolation(MssqlIsolationLevel::Snapshot) + .await?; + + let row = sqlx::query("SELECT 1 AS val") + .fetch_one(&mut *tx) + .await?; + let val: i32 = row.get("val"); + assert_eq!(val, 1); + + tx.commit().await?; + Ok(()) +} + +#[sqlx_macros::test] +async fn it_begins_with_serializable() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let mut tx = conn + .begin_with_isolation(MssqlIsolationLevel::Serializable) + .await?; + + let row = sqlx::query("SELECT 1 AS val") + .fetch_one(&mut *tx) + .await?; + let val: i32 = row.get("val"); + assert_eq!(val, 1); + + tx.commit().await?; + Ok(()) +} From e0e67f4b8c4e9df02f3115866b1f8b905cd67645 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Sun, 15 Feb 2026 10:26:18 -0500 Subject: [PATCH 08/33] feat: add auth methods, QueryBuilder placeholder fix, and XML type to sqlx-mssql - Add Windows, Integrated, and AAD token authentication support with cfg-gated features matching tiberius (winauth, integrated-auth-gssapi) - Override format_placeholder to produce @p1, @p2, ... instead of ? so QueryBuilder works correctly with MSSQL parameterized queries - Add MssqlXml newtype wrapper for SQL Server XML columns with Type, Encode, and Decode implementations - Add URL parsing for auth and token query parameters with roundtrip tests Author: Pablo Carrera --- Cargo.lock | 326 ++++++++++++++++++++++++++++---- Cargo.toml | 5 + sqlx-mssql/Cargo.toml | 4 + sqlx-mssql/src/arguments.rs | 8 + sqlx-mssql/src/lib.rs | 1 + sqlx-mssql/src/options/mod.rs | 73 ++++++- sqlx-mssql/src/options/parse.rs | 75 ++++++++ sqlx-mssql/src/types/mod.rs | 1 + sqlx-mssql/src/types/xml.rs | 81 ++++++++ tests/mssql/query_builder.rs | 108 +++++++++++ 10 files changed, 639 insertions(+), 43 deletions(-) create mode 100644 sqlx-mssql/src/types/xml.rs create mode 100644 tests/mssql/query_builder.rs diff --git a/Cargo.lock b/Cargo.lock index ddbad9d857..9ac87ff8b3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -64,6 +64,15 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" +[[package]] +name = "ansi_term" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2" +dependencies = [ + "winapi", +] + [[package]] name = "anstream" version = "0.6.19" @@ -401,6 +410,17 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "atty" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" +dependencies = [ + "hermit-abi 0.1.19", + "libc", + "winapi", +] + [[package]] name = "autocfg" version = "1.5.0" @@ -423,7 +443,7 @@ version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61b1d86e7705efe1be1b569bab41d4fa1e14e220b60a160f78de2db687add079" dependencies = [ - "bindgen", + "bindgen 0.69.5", "cc", "cmake", "dunce", @@ -500,7 +520,7 @@ dependencies = [ "getrandom 0.2.16", "instant", "pin-project-lite", - "rand", + "rand 0.8.5", "tokio", ] @@ -561,6 +581,29 @@ dependencies = [ "num-traits", ] +[[package]] +name = "bindgen" +version = "0.59.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bd2a9a458e8f4304c52c43ebb0cfbd520289f8379a52e329a38afda99bf8eb8" +dependencies = [ + "bitflags 1.3.2", + "cexpr", + "clang-sys", + "clap 2.34.0", + "env_logger 0.9.3", + "lazy_static", + "lazycell", + "log", + "peeking_take_while", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "which", +] + [[package]] name = "bindgen" version = "0.69.5" @@ -860,6 +903,21 @@ dependencies = [ "libloading", ] +[[package]] +name = "clap" +version = "2.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c" +dependencies = [ + "ansi_term", + "atty", + "bitflags 1.3.2", + "strsim 0.8.0", + "textwrap", + "unicode-width 0.1.14", + "vec_map", +] + [[package]] name = "clap" version = "4.5.40" @@ -879,7 +937,7 @@ dependencies = [ "anstream", "anstyle", "clap_lex", - "strsim", + "strsim 0.11.1", "terminal_size", ] @@ -889,7 +947,7 @@ version = "4.5.54" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aad5b1b4de04fead402672b48897030eec1f3bfe1550776322f59f6d6e6a5677" dependencies = [ - "clap", + "clap 4.5.40", ] [[package]] @@ -1058,7 +1116,7 @@ dependencies = [ "anes", "cast", "ciborium", - "clap", + "clap 4.5.40", "criterion-plot", "futures", "is-terminal", @@ -1131,7 +1189,7 @@ dependencies = [ "crossterm_winapi", "libc", "mio 0.8.11", - "parking_lot", + "parking_lot 0.12.4", "signal-hook", "signal-hook-mio", "winapi", @@ -1182,7 +1240,7 @@ dependencies = [ "ident_case", "proc-macro2", "quote", - "strsim", + "strsim 0.11.1", "syn 2.0.104", ] @@ -1330,6 +1388,19 @@ dependencies = [ "regex", ] +[[package]] +name = "env_logger" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a12e6657c4c97ebab115a42dcee77225f7f482cdd841cf7088c657a42e9e00e7" +dependencies = [ + "atty", + "humantime", + "log", + "regex", + "termcolor", +] + [[package]] name = "env_logger" version = "0.11.8" @@ -1554,7 +1625,7 @@ checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" dependencies = [ "futures-core", "lock_api", - "parking_lot", + "parking_lot 0.12.4", ] [[package]] @@ -1626,6 +1697,17 @@ dependencies = [ "version_check", ] +[[package]] +name = "getrandom" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.9.0+wasi-snapshot-preview1", +] + [[package]] name = "getrandom" version = "0.2.16" @@ -1735,6 +1817,15 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" +dependencies = [ + "libc", +] + [[package]] name = "hermit-abi" version = "0.5.2" @@ -1814,6 +1905,12 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" +[[package]] +name = "humantime" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "135b12329e5e3ce057a9f972339ea52bc954fe1e9358ef27f95e89716fbc5424" + [[package]] name = "hyper" version = "0.14.32" @@ -2055,7 +2152,7 @@ version = "0.4.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9" dependencies = [ - "hermit-abi", + "hermit-abi 0.5.2", "libc", "windows-sys 0.59.0", ] @@ -2173,6 +2270,28 @@ version = "0.2.174" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776" +[[package]] +name = "libgssapi" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "724dbcd1f871da9c67983537a47ac510c278656f6392418ad67c7a52720e54b2" +dependencies = [ + "bitflags 1.3.2", + "bytes", + "lazy_static", + "libgssapi-sys", + "parking_lot 0.11.2", +] + +[[package]] +name = "libgssapi-sys" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1dd7d65e409c889f6c9d81ff079371d0d8fd88d7dca702ff187ef96fb0450fb7" +dependencies = [ + "bindgen 0.59.2", +] + [[package]] name = "libloading" version = "0.8.8" @@ -2206,7 +2325,7 @@ version = "0.30.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" dependencies = [ - "bindgen", + "bindgen 0.69.5", "cc", "pkg-config", "vcpkg", @@ -2284,6 +2403,12 @@ dependencies = [ "digest", ] +[[package]] +name = "md5" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e6bcd6433cff03a4bfc3d9834d504467db1f1cf6d0ea765d37d330249ed629d" + [[package]] name = "memchr" version = "2.7.5" @@ -2448,7 +2573,7 @@ dependencies = [ "num-integer", "num-iter", "num-traits", - "rand", + "rand 0.8.5", "smallvec", "zeroize", ] @@ -2588,6 +2713,17 @@ version = "2.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" +[[package]] +name = "parking_lot" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" +dependencies = [ + "instant", + "lock_api", + "parking_lot_core 0.8.6", +] + [[package]] name = "parking_lot" version = "0.12.4" @@ -2595,7 +2731,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "70d58bf43669b5795d1576d0641cfb6fbb2057bf629506267a92807158584a13" dependencies = [ "lock_api", - "parking_lot_core", + "parking_lot_core 0.9.11", +] + +[[package]] +name = "parking_lot_core" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc" +dependencies = [ + "cfg-if", + "instant", + "libc", + "redox_syscall 0.2.16", + "smallvec", + "winapi", ] [[package]] @@ -2618,7 +2768,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700" dependencies = [ "base64ct", - "rand_core", + "rand_core 0.6.4", "subtle", ] @@ -2629,7 +2779,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166" dependencies = [ "base64ct", - "rand_core", + "rand_core 0.6.4", "subtle", ] @@ -2639,6 +2789,12 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "peeking_take_while" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" + [[package]] name = "pem-rfc7468" version = "0.7.0" @@ -2760,7 +2916,7 @@ checksum = "b53a684391ad002dd6a596ceb6c74fd004fdce75f4be2e3f615068abbea5fd50" dependencies = [ "cfg-if", "concurrent-queue", - "hermit-abi", + "hermit-abi 0.5.2", "pin-project-lite", "rustix 1.0.7", "tracing", @@ -2946,6 +3102,19 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" +[[package]] +name = "rand" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" +dependencies = [ + "getrandom 0.1.16", + "libc", + "rand_chacha 0.2.2", + "rand_core 0.5.1", + "rand_hc", +] + [[package]] name = "rand" version = "0.8.5" @@ -2953,8 +3122,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha", - "rand_core", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402" +dependencies = [ + "ppv-lite86", + "rand_core 0.5.1", ] [[package]] @@ -2964,7 +3143,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_core" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" +dependencies = [ + "getrandom 0.1.16", ] [[package]] @@ -2976,13 +3164,22 @@ dependencies = [ "getrandom 0.2.16", ] +[[package]] +name = "rand_hc" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" +dependencies = [ + "rand_core 0.5.1", +] + [[package]] name = "rand_xoshiro" version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa" dependencies = [ - "rand_core", + "rand_core 0.6.4", ] [[package]] @@ -3026,6 +3223,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "redox_syscall" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" +dependencies = [ + "bitflags 1.3.2", +] + [[package]] name = "redox_syscall" version = "0.5.13" @@ -3138,7 +3344,7 @@ dependencies = [ "num-traits", "pkcs1", "pkcs8", - "rand_core", + "rand_core 0.6.4", "signature", "spki", "subtle", @@ -3155,7 +3361,7 @@ dependencies = [ "borsh", "bytes", "num-traits", - "rand", + "rand 0.8.5", "rkyv", "serde", "serde_json", @@ -3495,7 +3701,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" dependencies = [ "digest", - "rand_core", + "rand_core 0.6.4", ] [[package]] @@ -3573,12 +3779,12 @@ dependencies = [ "async-std", "criterion", "dotenvy", - "env_logger", + "env_logger 0.11.8", "futures-util", "hex", "libsqlite3-sys", "paste", - "rand", + "rand 0.8.5", "rand_xoshiro", "serde", "serde_json", @@ -3605,7 +3811,7 @@ dependencies = [ "backoff", "cargo_metadata", "chrono", - "clap", + "clap 4.5.40", "clap_complete", "console", "dialoguer", @@ -3679,7 +3885,7 @@ name = "sqlx-example-mysql-todos" version = "0.1.0" dependencies = [ "anyhow", - "clap", + "clap 4.5.40", "sqlx", "tokio", ] @@ -3692,7 +3898,7 @@ dependencies = [ "argon2 0.4.1", "axum", "dotenvy", - "rand", + "rand 0.8.5", "regex", "serde", "serde_json", @@ -3733,7 +3939,7 @@ name = "sqlx-example-postgres-json" version = "0.1.0" dependencies = [ "anyhow", - "clap", + "clap 4.5.40", "dotenvy", "serde", "serde_json", @@ -3756,7 +3962,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", - "clap", + "clap 4.5.40", "dotenvy", "mockall", "sqlx", @@ -3769,7 +3975,7 @@ version = "0.9.0-alpha.1" dependencies = [ "color-eyre", "dotenvy", - "rand", + "rand 0.8.5", "rust_decimal", "sqlx", "sqlx-example-postgres-multi-database-accounts", @@ -3784,7 +3990,7 @@ version = "0.1.0" dependencies = [ "argon2 0.5.3", "password-hash 0.5.0", - "rand", + "rand 0.8.5", "serde", "sqlx", "thiserror 1.0.69", @@ -3810,7 +4016,7 @@ version = "0.9.0-alpha.1" dependencies = [ "color-eyre", "dotenvy", - "rand", + "rand 0.8.5", "rust_decimal", "sqlx", "sqlx-example-postgres-multi-tenant-accounts", @@ -3825,7 +4031,7 @@ version = "0.1.0" dependencies = [ "argon2 0.5.3", "password-hash 0.5.0", - "rand", + "rand 0.8.5", "serde", "sqlx", "thiserror 1.0.69", @@ -3885,7 +4091,7 @@ name = "sqlx-example-postgres-todos" version = "0.1.0" dependencies = [ "anyhow", - "clap", + "clap 4.5.40", "dotenvy", "sqlx", "tokio", @@ -3913,7 +4119,7 @@ name = "sqlx-example-sqlite-todos" version = "0.1.0" dependencies = [ "anyhow", - "clap", + "clap 4.5.40", "sqlx", "tokio", ] @@ -4014,7 +4220,7 @@ dependencies = [ "md-5", "memchr", "percent-encoding", - "rand", + "rand 0.8.5", "rsa", "rust_decimal", "serde", @@ -4059,7 +4265,7 @@ dependencies = [ "md-5", "memchr", "num-bigint", - "rand", + "rand 0.8.5", "rust_decimal", "serde", "serde_json", @@ -4108,7 +4314,7 @@ version = "0.1.0" dependencies = [ "anyhow", "dotenvy", - "env_logger", + "env_logger 0.11.8", "sqlx", ] @@ -4145,6 +4351,12 @@ dependencies = [ "unicode-properties", ] +[[package]] +name = "strsim" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" + [[package]] name = "strsim" version = "0.11.1" @@ -4268,6 +4480,15 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683" +[[package]] +name = "textwrap" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" +dependencies = [ + "unicode-width 0.1.14", +] + [[package]] name = "thiserror" version = "1.0.69" @@ -4333,6 +4554,7 @@ dependencies = [ "encoding_rs", "enumflags2", "futures-util", + "libgssapi", "num-traits", "once_cell", "pin-project-lite", @@ -4342,6 +4564,7 @@ dependencies = [ "time", "tracing", "uuid", + "winauth", ] [[package]] @@ -4421,7 +4644,7 @@ dependencies = [ "io-uring", "libc", "mio 1.0.4", - "parking_lot", + "parking_lot 0.12.4", "pin-project-lite", "signal-hook-registry", "slab", @@ -4805,6 +5028,12 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" +[[package]] +name = "vec_map" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191" + [[package]] name = "version_check" version = "0.9.5" @@ -4839,6 +5068,12 @@ dependencies = [ "try-lock", ] +[[package]] +name = "wasi" +version = "0.9.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" + [[package]] name = "wasi" version = "0.11.1+wasi-snapshot-preview1" @@ -4996,6 +5231,19 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "winauth" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f820cd208ce9c6b050812dc2d724ba98c6c1e9db5ce9b3f58d925ae5723a5e6" +dependencies = [ + "bitflags 1.3.2", + "byteorder", + "md5", + "rand 0.7.3", + "winapi", +] + [[package]] name = "windows-core" version = "0.61.2" diff --git a/Cargo.toml b/Cargo.toml index a0de7fe7c5..8115fe643d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -508,6 +508,11 @@ name = "mssql-isolation-level" path = "tests/mssql/isolation-level.rs" required-features = ["mssql"] +[[test]] +name = "mssql-query-builder" +path = "tests/mssql/query_builder.rs" +required-features = ["mssql"] + [[test]] name = "mssql-bulk-insert" path = "tests/mssql/bulk-insert.rs" diff --git a/sqlx-mssql/Cargo.toml b/sqlx-mssql/Cargo.toml index eb9f7f1b0d..607d977d79 100644 --- a/sqlx-mssql/Cargo.toml +++ b/sqlx-mssql/Cargo.toml @@ -15,6 +15,10 @@ any = ["sqlx-core/any"] offline = ["sqlx-core/offline", "serde"] migrate = ["sqlx-core/migrate"] +# Authentication features +winauth = ["tiberius/winauth"] +integrated-auth-gssapi = ["tiberius/integrated-auth-gssapi"] + # Type Integration features bigdecimal = ["dep:bigdecimal", "sqlx-core/bigdecimal", "tiberius/bigdecimal"] chrono = ["dep:chrono", "sqlx-core/chrono", "tiberius/chrono"] diff --git a/sqlx-mssql/src/arguments.rs b/sqlx-mssql/src/arguments.rs index 7294d09f32..cc7bc66e42 100644 --- a/sqlx-mssql/src/arguments.rs +++ b/sqlx-mssql/src/arguments.rs @@ -1,3 +1,5 @@ +use std::fmt::{self, Write}; + use crate::database::MssqlArgumentValue; use crate::encode::Encode; use crate::types::Type; @@ -44,4 +46,10 @@ impl Arguments for MssqlArguments { fn len(&self) -> usize { self.values.len() } + + fn format_placeholder(&self, writer: &mut W) -> fmt::Result { + // MSSQL uses @p1, @p2, ... for parameterized queries. + // This is called after the bind is added, so len() is the correct 1-based index. + write!(writer, "@p{}", self.values.len()) + } } diff --git a/sqlx-mssql/src/lib.rs b/sqlx-mssql/src/lib.rs index 75f7708ec8..90855c4ffa 100644 --- a/sqlx-mssql/src/lib.rs +++ b/sqlx-mssql/src/lib.rs @@ -54,6 +54,7 @@ pub use row::MssqlRow; pub use statement::MssqlStatement; pub use transaction::MssqlTransactionManager; pub use type_info::MssqlTypeInfo; +pub use types::xml::MssqlXml; pub use value::{MssqlValue, MssqlValueRef}; // Re-export tiberius types needed for bulk insert row construction. diff --git a/sqlx-mssql/src/options/mod.rs b/sqlx-mssql/src/options/mod.rs index 16303c18a7..9fb549cc33 100644 --- a/sqlx-mssql/src/options/mod.rs +++ b/sqlx-mssql/src/options/mod.rs @@ -28,6 +28,8 @@ use ssl_mode::MssqlSslMode; /// | `statement-cache-capacity` | `100` | The maximum number of prepared statements stored in the cache. | /// | `app_name` | `sqlx` | The application name sent to the server. | /// | `instance` | `None` | The SQL Server instance name. | +/// | `auth` | `sql_server` | Authentication method: `sql_server`, `windows` (cfg-gated), `integrated` (cfg-gated), `aad_token`. | +/// | `token` | (none) | Azure AD bearer token (used when `auth=aad_token`). | /// /// # Example /// @@ -64,6 +66,15 @@ pub struct MssqlConnectOptions { pub(crate) statement_cache_capacity: usize, pub(crate) app_name: String, pub(crate) log_settings: LogSettings, + /// When `true`, use Windows (NTLM) authentication instead of SQL Server auth. + /// The username can use `domain\user` syntax which tiberius parses internally. + #[cfg(all(windows, feature = "winauth"))] + pub(crate) windows_auth: bool, + /// When `true`, use integrated authentication (SSPI on Windows / Kerberos on Unix). + #[cfg(any(all(windows, feature = "winauth"), all(unix, feature = "integrated-auth-gssapi")))] + pub(crate) integrated_auth: bool, + /// Azure AD bearer token for AAD authentication. + pub(crate) aad_token: Option, } impl Default for MssqlConnectOptions { @@ -89,6 +100,11 @@ impl MssqlConnectOptions { statement_cache_capacity: 100, app_name: String::from("sqlx"), log_settings: Default::default(), + #[cfg(all(windows, feature = "winauth"))] + windows_auth: false, + #[cfg(any(all(windows, feature = "winauth"), all(unix, feature = "integrated-auth-gssapi")))] + integrated_auth: false, + aad_token: None, } } @@ -192,6 +208,31 @@ impl MssqlConnectOptions { self } + /// Sets whether to use Windows (NTLM) authentication. + /// + /// When enabled, the username can use `domain\user` syntax + /// which tiberius parses internally. + #[cfg(all(windows, feature = "winauth"))] + pub fn windows_auth(mut self, enabled: bool) -> Self { + self.windows_auth = enabled; + self + } + + /// Sets whether to use integrated authentication (SSPI on Windows / Kerberos on Unix). + #[cfg(any(all(windows, feature = "winauth"), all(unix, feature = "integrated-auth-gssapi")))] + pub fn integrated_auth(mut self, enabled: bool) -> Self { + self.integrated_auth = enabled; + self + } + + /// Sets an Azure AD bearer token for authentication. + /// + /// When set, AAD token authentication takes precedence over other auth methods. + pub fn aad_token(mut self, token: &str) -> Self { + self.aad_token = Some(token.to_owned()); + self + } + /// Get the current host. pub fn get_host(&self) -> &str { &self.host @@ -228,10 +269,34 @@ impl MssqlConnectOptions { config.instance_name(instance); } - config.authentication(tiberius::AuthMethod::sql_server( - &self.username, - self.password.as_deref().unwrap_or(""), - )); + if let Some(token) = &self.aad_token { + config.authentication(tiberius::AuthMethod::aad_token(token)); + } else { + #[allow(unused_mut)] + let mut handled = false; + + #[cfg(any(all(windows, feature = "winauth"), all(unix, feature = "integrated-auth-gssapi")))] + if !handled && self.integrated_auth { + config.authentication(tiberius::AuthMethod::Integrated); + handled = true; + } + + #[cfg(all(windows, feature = "winauth"))] + if !handled && self.windows_auth { + config.authentication(tiberius::AuthMethod::windows( + &self.username, + self.password.as_deref().unwrap_or(""), + )); + handled = true; + } + + if !handled { + config.authentication(tiberius::AuthMethod::sql_server( + &self.username, + self.password.as_deref().unwrap_or(""), + )); + } + } if let Some(ca_path) = &self.trust_server_certificate_ca { // trust_cert_ca and trust_cert are mutually exclusive in tiberius diff --git a/sqlx-mssql/src/options/parse.rs b/sqlx-mssql/src/options/parse.rs index 6488069718..e481e71692 100644 --- a/sqlx-mssql/src/options/parse.rs +++ b/sqlx-mssql/src/options/parse.rs @@ -105,6 +105,32 @@ impl MssqlConnectOptions { options = options.trust_server_certificate_ca(&value); } + "auth" => { + match &*value { + "sql_server" => {} + #[cfg(all(windows, feature = "winauth"))] + "windows" => { + options.windows_auth = true; + } + #[cfg(any(all(windows, feature = "winauth"), all(unix, feature = "integrated-auth-gssapi")))] + "integrated" => { + options.integrated_auth = true; + } + "aad_token" => { + // token value is set via the separate `token` parameter + } + _ => { + return Err(Error::Configuration( + format!("unknown auth value: {value}").into(), + )) + } + } + } + + "token" => { + options.aad_token = Some(value.into_owned()); + } + _ => {} } } @@ -145,6 +171,24 @@ impl MssqlConnectOptions { .append_pair("trust_server_certificate_ca", ca_path); } + if let Some(token) = &self.aad_token { + url.query_pairs_mut() + .append_pair("auth", "aad_token") + .append_pair("token", token); + } else { + #[cfg(any(all(windows, feature = "winauth"), all(unix, feature = "integrated-auth-gssapi")))] + if self.integrated_auth { + url.query_pairs_mut() + .append_pair("auth", "integrated"); + } + + #[cfg(all(windows, feature = "winauth"))] + if self.windows_auth && !self.integrated_auth { + url.query_pairs_mut() + .append_pair("auth", "windows"); + } + } + url } } @@ -292,3 +336,34 @@ fn it_roundtrips_trust_cert_ca_in_url() { let opts2 = MssqlConnectOptions::parse_from_url(&built).unwrap(); assert_eq!(opts2.trust_server_certificate_ca, Some("/etc/ssl/ca.pem".into())); } + +#[test] +fn it_parses_aad_token_auth() { + let url = "mssql://sa@localhost/master?auth=aad_token&token=my-bearer-token"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + assert_eq!(opts.aad_token, Some("my-bearer-token".into())); +} + +#[test] +fn it_roundtrips_aad_token_in_url() { + let opts = MssqlConnectOptions::new() + .host("localhost") + .username("sa") + .aad_token("my-bearer-token"); + let built = opts.build_url(); + let opts2 = MssqlConnectOptions::parse_from_url(&built).unwrap(); + assert_eq!(opts2.aad_token, Some("my-bearer-token".into())); +} + +#[test] +fn it_parses_sql_server_auth_explicitly() { + let url = "mssql://sa:password@localhost/master?auth=sql_server"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + assert_eq!(opts.aad_token, None); +} + +#[test] +fn it_rejects_invalid_auth() { + let url = "mssql://sa:password@localhost/master?auth=bogus"; + assert!(MssqlConnectOptions::from_str(url).is_err()); +} diff --git a/sqlx-mssql/src/types/mod.rs b/sqlx-mssql/src/types/mod.rs index a0ca56c925..cb82c2f267 100644 --- a/sqlx-mssql/src/types/mod.rs +++ b/sqlx-mssql/src/types/mod.rs @@ -52,3 +52,4 @@ mod str; mod time; #[cfg(feature = "uuid")] mod uuid; +pub mod xml; diff --git a/sqlx-mssql/src/types/xml.rs b/sqlx-mssql/src/types/xml.rs new file mode 100644 index 0000000000..65271c003a --- /dev/null +++ b/sqlx-mssql/src/types/xml.rs @@ -0,0 +1,81 @@ +use std::fmt; + +use crate::database::MssqlArgumentValue; +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; + +/// SQL Server `XML` column type. +/// +/// A newtype wrapper around [`String`] that maps to the MSSQL `XML` type. +/// This allows sqlx macros to distinguish `XML` columns from `NVARCHAR`. +/// +/// # Example +/// +/// ```rust,no_run +/// # async fn example() -> sqlx::Result<()> { +/// use sqlx::mssql::MssqlXml; +/// +/// let xml = MssqlXml::from("hello".to_owned()); +/// assert_eq!(xml.as_ref(), "hello"); +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct MssqlXml(pub String); + +impl Type for MssqlXml { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("XML") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!( + ty.name.as_str(), + "XML" | "NVARCHAR" | "VARCHAR" | "NTEXT" | "TEXT" + ) + } +} + +impl Encode<'_, Mssql> for MssqlXml { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + buf.push(MssqlArgumentValue::String(self.0.clone())); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for MssqlXml { + fn decode(value: MssqlValueRef<'_>) -> Result { + let s = value.as_str()?; + Ok(MssqlXml(s.to_owned())) + } +} + +impl From for MssqlXml { + fn from(s: String) -> Self { + MssqlXml(s) + } +} + +impl From for String { + fn from(xml: MssqlXml) -> Self { + xml.0 + } +} + +impl AsRef for MssqlXml { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl fmt::Display for MssqlXml { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&self.0) + } +} diff --git a/tests/mssql/query_builder.rs b/tests/mssql/query_builder.rs new file mode 100644 index 0000000000..2e938d4847 --- /dev/null +++ b/tests/mssql/query_builder.rs @@ -0,0 +1,108 @@ +use sqlx::mssql::Mssql; +use sqlx::query_builder::QueryBuilder; +use sqlx::Execute; + +#[test] +fn test_new() { + let qb: QueryBuilder = QueryBuilder::new("SELECT * FROM users"); + assert_eq!(qb.sql(), "SELECT * FROM users"); +} + +#[test] +fn test_push() { + let mut qb: QueryBuilder = QueryBuilder::new("SELECT * FROM users"); + let second_line = " WHERE last_name LIKE '[A-N]%';"; + qb.push(second_line); + + assert_eq!( + qb.sql(), + "SELECT * FROM users WHERE last_name LIKE '[A-N]%';".to_string(), + ); +} + +#[test] +#[should_panic] +fn test_push_panics_after_build_without_reset() { + let mut qb: QueryBuilder = QueryBuilder::new("SELECT * FROM users;"); + + let _query = qb.build(); + + qb.push("SELECT * FROM users;"); +} + +#[test] +fn test_push_bind() { + let mut qb: QueryBuilder = QueryBuilder::new("SELECT * FROM users WHERE id = "); + + qb.push_bind(42i32) + .push(" OR membership_level = ") + .push_bind(3i32); + + assert_eq!( + qb.sql(), + "SELECT * FROM users WHERE id = @p1 OR membership_level = @p2" + ); +} + +#[test] +fn test_build() { + let mut qb: QueryBuilder = QueryBuilder::new("SELECT * FROM users"); + + qb.push(" WHERE id = ").push_bind(42i32); + let query = qb.build(); + + assert!(Execute::persistent(&query)); + assert_eq!(query.sql(), "SELECT * FROM users WHERE id = @p1"); +} + +#[test] +fn test_reset() { + let mut qb: QueryBuilder = QueryBuilder::new(""); + + { + let _query = qb + .push("SELECT * FROM users WHERE id = ") + .push_bind(42i32) + .build(); + } + + qb.reset(); + + assert_eq!(qb.sql(), ""); +} + +#[test] +fn test_query_builder_reuse() { + let mut qb: QueryBuilder = QueryBuilder::new(""); + + let _query = qb + .push("SELECT * FROM users WHERE id = ") + .push_bind(42i32) + .build(); + + qb.reset(); + + let query = qb.push("SELECT * FROM users WHERE id = 99").build(); + + assert_eq!(query.sql(), "SELECT * FROM users WHERE id = 99"); +} + +#[test] +fn test_query_builder_with_args() { + let mut qb: QueryBuilder = QueryBuilder::new(""); + + let mut query = qb + .push("SELECT * FROM users WHERE id = ") + .push_bind(42i32) + .build(); + + let args = query.take_arguments().unwrap().unwrap(); + + let mut qb: QueryBuilder = QueryBuilder::with_arguments(query.sql().as_str(), args); + let query = qb.push(" OR membership_level = ").push_bind(3i32).build(); + + assert_eq!( + query.sql(), + "SELECT * FROM users WHERE id = @p1 OR membership_level = @p2" + ); +} From ac12b9571c964f7df85d2389552d512e20444085 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Sun, 15 Feb 2026 11:01:45 -0500 Subject: [PATCH 09/33] feat: add compile-time query macros and Any type mappings for MSSQL Register MSSQL in the macro system (cfg gate, impl_database_ext, FOSS_DRIVERS) so query!()/query_as!() work with MSSQL databases. Add missing Any driver type mappings for NULL, BIT, MONEY, SMALLMONEY, and DECIMAL/NUMERIC to prevent AnyDriverError at runtime. Author: Pablo Carrera --- sqlx-macros-core/src/database/impls.rs | 9 +++++++++ sqlx-macros-core/src/database/mod.rs | 2 +- sqlx-macros-core/src/lib.rs | 2 ++ sqlx-mssql/src/any.rs | 5 +++++ 4 files changed, 17 insertions(+), 1 deletion(-) diff --git a/sqlx-macros-core/src/database/impls.rs b/sqlx-macros-core/src/database/impls.rs index 523b85cc14..d51f9745cf 100644 --- a/sqlx-macros-core/src/database/impls.rs +++ b/sqlx-macros-core/src/database/impls.rs @@ -46,6 +46,9 @@ mod sqlx { #[cfg(feature = "postgres")] pub use sqlx_postgres as postgres; + #[cfg(feature = "mssql")] + pub use sqlx_mssql as mssql; + #[cfg(feature = "_sqlite")] pub use sqlx_sqlite as sqlite; } @@ -63,6 +66,12 @@ impl_database_ext! { row: sqlx::postgres::PgRow, } +#[cfg(feature = "mssql")] +impl_database_ext! { + sqlx::mssql::Mssql, + row: sqlx::mssql::MssqlRow, +} + #[cfg(feature = "_sqlite")] impl_database_ext! { sqlx::sqlite::Sqlite, diff --git a/sqlx-macros-core/src/database/mod.rs b/sqlx-macros-core/src/database/mod.rs index 0885b3cca8..c108b70d50 100644 --- a/sqlx-macros-core/src/database/mod.rs +++ b/sqlx-macros-core/src/database/mod.rs @@ -10,7 +10,7 @@ use std::collections::hash_map; use std::collections::HashMap; use std::sync::{LazyLock, Mutex}; -#[cfg(any(feature = "postgres", feature = "mysql", feature = "_sqlite"))] +#[cfg(any(feature = "postgres", feature = "mysql", feature = "mssql", feature = "_sqlite"))] mod impls; pub trait DatabaseExt: Database + TypeChecking { diff --git a/sqlx-macros-core/src/lib.rs b/sqlx-macros-core/src/lib.rs index 7722bf2e02..55a6f4be25 100644 --- a/sqlx-macros-core/src/lib.rs +++ b/sqlx-macros-core/src/lib.rs @@ -50,6 +50,8 @@ pub const FOSS_DRIVERS: &[QueryDriver] = &[ QueryDriver::new::(), #[cfg(feature = "postgres")] QueryDriver::new::(), + #[cfg(feature = "mssql")] + QueryDriver::new::(), #[cfg(feature = "_sqlite")] QueryDriver::new::(), ]; diff --git a/sqlx-mssql/src/any.rs b/sqlx-mssql/src/any.rs index 9a026fda73..f32b575d39 100644 --- a/sqlx-mssql/src/any.rs +++ b/sqlx-mssql/src/any.rs @@ -167,6 +167,11 @@ impl<'a> TryFrom<&'a MssqlTypeInfo> for AnyTypeInfo { "REAL" => AnyTypeInfoKind::Real, "FLOAT" => AnyTypeInfoKind::Double, "VARBINARY" | "BINARY" | "IMAGE" => AnyTypeInfoKind::Blob, + "NULL" => AnyTypeInfoKind::Null, + "BIT" => AnyTypeInfoKind::Bool, + "MONEY" => AnyTypeInfoKind::Double, + "SMALLMONEY" => AnyTypeInfoKind::Real, + "DECIMAL" | "NUMERIC" => AnyTypeInfoKind::Text, "NVARCHAR" | "VARCHAR" | "NCHAR" | "CHAR" | "NTEXT" | "TEXT" | "XML" => { AnyTypeInfoKind::Text } From 68942d63f527f8f532130cf6c611c2fbe2e6d6dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Sun, 15 Feb 2026 12:31:26 -0500 Subject: [PATCH 10/33] feat: add Docker/CI, no_tx migrations, ColumnOrigin, Any types, and DECIMAL precision for MSSQL Wire MSSQL into docker-compose and CI workflow, support no_tx migrations following the Postgres pattern, extract column origin from sp_describe_first_result_set, map date/time/UUID types to Text in the Any driver, and preserve DECIMAL precision/scale by using a base_name() helper for type matching. Author: Pablo Carrera --- .github/workflows/sqlx.yml | 81 ++++++++++++++++++ sqlx-mssql/src/any.rs | 5 +- sqlx-mssql/src/column.rs | 4 +- sqlx-mssql/src/connection/executor.rs | 47 ++++++++-- sqlx-mssql/src/migrate.rs | 119 ++++++++++++++------------ sqlx-mssql/src/type_info.rs | 7 ++ sqlx-mssql/src/types/bigdecimal.rs | 2 +- sqlx-mssql/src/types/bool.rs | 2 +- sqlx-mssql/src/types/bytes.rs | 2 +- sqlx-mssql/src/types/chrono.rs | 4 +- sqlx-mssql/src/types/float.rs | 2 +- sqlx-mssql/src/types/int.rs | 2 +- sqlx-mssql/src/types/rust_decimal.rs | 2 +- sqlx-mssql/src/types/str.rs | 2 +- sqlx-mssql/src/types/time.rs | 4 +- sqlx-mssql/src/types/xml.rs | 2 +- tests/docker-compose.yml | 28 ++++++ 17 files changed, 242 insertions(+), 73 deletions(-) diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index b2f81b75ad..c8cc0035e9 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -104,6 +104,12 @@ jobs: -p sqlx-sqlite --all-features + - name: Test sqlx-mssql + run: > + cargo test + -p sqlx-mssql + --all-features + - name: Test sqlx-macros-core run: > cargo test @@ -514,3 +520,78 @@ jobs: env: DATABASE_URL: mysql://root@localhost:3306/sqlx?sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt RUSTFLAGS: --cfg mariadb="${{ matrix.mariadb }}" + + mssql: + name: MSSQL + runs-on: ubuntu-24.04 + strategy: + matrix: + mssql: [ 2022, 2019 ] + runtime: [ async-global-executor, smol, tokio ] + tls: [ native-tls, rustls-aws-lc-rs, rustls-ring, none ] + needs: check + timeout-minutes: 30 + steps: + - uses: actions/checkout@v4 + + - name: Setup Rust + run: rustup show active-toolchain || rustup toolchain install + + - uses: Swatinem/rust-cache@v2 + + - run: > + cargo build + --no-default-features + --features mssql,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros,migrate + + - run: docker compose -f tests/docker-compose.yml run -d -p 1433:1433 --name mssql_${{ matrix.mssql }} mssql_${{ matrix.mssql }} + - run: sleep 60 + + # Create data dir for offline mode + - run: mkdir .sqlx + + - run: > + cargo test + --no-default-features + --features any,mssql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + env: + DATABASE_URL: mssql://sa:YourStrong!Passw0rd@localhost:1433/sqlx + SQLX_OFFLINE_DIR: .sqlx + RUSTFLAGS: --cfg mssql_${{ matrix.mssql }} + + # Run the `test-attr` test again to cover cleanup. + - run: > + cargo test + --test mssql-test-attr + --no-default-features + --features any,mssql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + env: + DATABASE_URL: mssql://sa:YourStrong!Passw0rd@localhost:1433/sqlx + SQLX_OFFLINE_DIR: .sqlx + RUSTFLAGS: --cfg mssql_${{ matrix.mssql }} + + # Remove test artifacts + - run: cargo clean -p sqlx + + # Build the macros-test in offline mode (omit DATABASE_URL) + - run: > + cargo build + --no-default-features + --test mssql-macros + --features any,mssql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + env: + SQLX_OFFLINE: true + SQLX_OFFLINE_DIR: .sqlx + RUSTFLAGS: -D warnings --cfg mssql_${{ matrix.mssql }} + + # Test macros in offline mode (still needs DATABASE_URL to run) + - run: > + cargo test + --no-default-features + --test mssql-macros + --features any,mssql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + env: + DATABASE_URL: mssql://sa:YourStrong!Passw0rd@localhost:1433/sqlx + SQLX_OFFLINE: true + SQLX_OFFLINE_DIR: .sqlx + RUSTFLAGS: --cfg mssql_${{ matrix.mssql }} diff --git a/sqlx-mssql/src/any.rs b/sqlx-mssql/src/any.rs index f32b575d39..63cd7e3b62 100644 --- a/sqlx-mssql/src/any.rs +++ b/sqlx-mssql/src/any.rs @@ -159,7 +159,7 @@ impl<'a> TryFrom<&'a MssqlTypeInfo> for AnyTypeInfo { fn try_from(type_info: &'a MssqlTypeInfo) -> Result { Ok(AnyTypeInfo { - kind: match type_info.name.as_str() { + kind: match type_info.base_name() { "TINYINT" => AnyTypeInfoKind::SmallInt, "SMALLINT" => AnyTypeInfoKind::SmallInt, "INT" => AnyTypeInfoKind::Integer, @@ -175,6 +175,9 @@ impl<'a> TryFrom<&'a MssqlTypeInfo> for AnyTypeInfo { "NVARCHAR" | "VARCHAR" | "NCHAR" | "CHAR" | "NTEXT" | "TEXT" | "XML" => { AnyTypeInfoKind::Text } + "UNIQUEIDENTIFIER" => AnyTypeInfoKind::Text, + "DATE" | "TIME" | "DATETIME" | "DATETIME2" | "SMALLDATETIME" + | "DATETIMEOFFSET" => AnyTypeInfoKind::Text, _ => { return Err(sqlx_core::Error::AnyDriverError( format!("Any driver does not support MSSQL type {type_info:?}").into(), diff --git a/sqlx-mssql/src/column.rs b/sqlx-mssql/src/column.rs index aac9df3a4e..e721a78b6c 100644 --- a/sqlx-mssql/src/column.rs +++ b/sqlx-mssql/src/column.rs @@ -8,6 +8,7 @@ pub struct MssqlColumn { pub(crate) ordinal: usize, pub(crate) name: UStr, pub(crate) type_info: MssqlTypeInfo, + pub(crate) origin: ColumnOrigin, } impl Column for MssqlColumn { @@ -26,7 +27,6 @@ impl Column for MssqlColumn { } fn origin(&self) -> ColumnOrigin { - // tiberius doesn't expose table origin information - ColumnOrigin::Expression + self.origin.clone() } } diff --git a/sqlx-mssql/src/connection/executor.rs b/sqlx-mssql/src/connection/executor.rs index 997414e6bc..a5ee25710d 100644 --- a/sqlx-mssql/src/connection/executor.rs +++ b/sqlx-mssql/src/connection/executor.rs @@ -14,6 +14,7 @@ use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_util::TryStreamExt; +use sqlx_core::column::{ColumnOrigin, TableColumn}; use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr, SqlStr}; use std::sync::Arc; @@ -250,6 +251,7 @@ async fn collect_results<'a>( ordinal, name, type_info, + origin: ColumnOrigin::Unknown, } }) .collect(); @@ -370,9 +372,25 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection { for (ordinal, row) in rows.iter().enumerate() { let name: &str = row.get("name").unwrap_or(""); let type_name: &str = row.get("system_type_name").unwrap_or("UNKNOWN"); - // Extract the base type name (before any parenthesized length/precision) - let base_type = type_name.split('(').next().unwrap_or(type_name).trim(); - let type_info = MssqlTypeInfo::new(base_type.to_uppercase()); + let type_info = MssqlTypeInfo::new(type_name.to_uppercase()); + + let source_table: Option<&str> = row.get("source_table"); + let source_schema: Option<&str> = row.get("source_schema"); + let source_column: Option<&str> = row.get("source_column"); + + let origin = match (source_table, source_column) { + (Some(table), Some(col)) if !table.is_empty() && !col.is_empty() => { + let table_str = match source_schema { + Some(s) if !s.is_empty() => format!("{s}.{table}"), + _ => table.to_string(), + }; + ColumnOrigin::Table(TableColumn { + table: table_str.into(), + name: col.into(), + }) + } + _ => ColumnOrigin::Expression, + }; let ustr_name = UStr::new(name); column_names.insert(ustr_name.clone(), ordinal); @@ -380,6 +398,7 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection { ordinal, name: ustr_name, type_info, + origin, }); } @@ -427,16 +446,34 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection { for (ordinal, row) in rows.iter().enumerate() { let name: &str = row.get("name").unwrap_or(""); let type_name: &str = row.get("system_type_name").unwrap_or("UNKNOWN"); - let base_type = type_name.split('(').next().unwrap_or(type_name).trim(); - let type_info = MssqlTypeInfo::new(base_type.to_uppercase()); + let type_info = MssqlTypeInfo::new(type_name.to_uppercase()); let is_nullable: Option = row.get("is_nullable"); + let source_table: Option<&str> = row.get("source_table"); + let source_schema: Option<&str> = row.get("source_schema"); + let source_column: Option<&str> = row.get("source_column"); + + let origin = match (source_table, source_column) { + (Some(table), Some(col)) if !table.is_empty() && !col.is_empty() => { + let table_str = match source_schema { + Some(s) if !s.is_empty() => format!("{s}.{table}"), + _ => table.to_string(), + }; + ColumnOrigin::Table(TableColumn { + table: table_str.into(), + name: col.into(), + }) + } + _ => ColumnOrigin::Expression, + }; + let ustr_name = UStr::new(name); column_names.insert(ustr_name.clone(), ordinal); columns.push(MssqlColumn { ordinal, name: ustr_name, type_info, + origin, }); nullable.push(is_nullable); } diff --git a/sqlx-mssql/src/migrate.rs b/sqlx-mssql/src/migrate.rs index 5b0a22cf2c..c7836f57cf 100644 --- a/sqlx-mssql/src/migrate.rs +++ b/sqlx-mssql/src/migrate.rs @@ -187,39 +187,22 @@ CREATE TABLE {table_name} ( migration: &'e Migration, ) -> BoxFuture<'e, Result> { Box::pin(async move { - let mut tx = self.begin().await?; let start = Instant::now(); - let _ = query(AssertSqlSafe(format!( - r#" - INSERT INTO {table_name} ( version, description, success, checksum, execution_time ) - VALUES ( @p1, @p2, 0, @p3, -1 ) - "# - ))) - .bind(migration.version) - .bind(&*migration.description) - .bind(&*migration.checksum) - .execute(&mut *tx) - .await?; - - let _ = tx - .execute(migration.sql.clone()) - .await - .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; - - let _ = query(AssertSqlSafe(format!( - r#" - UPDATE {table_name} - SET success = 1 - WHERE version = @p1 - "# - ))) - .bind(migration.version) - .execute(&mut *tx) - .await?; - - tx.commit().await?; - + if migration.no_tx { + execute_migration(self, table_name, migration).await?; + } else { + // Use a single transaction for the actual migration script and the essential + // bookkeeping so we never execute migrations twice. + // See https://github.com/launchbadge/sqlx/issues/1966. + let mut tx = self.begin().await?; + execute_migration(&mut tx, table_name, migration).await?; + tx.commit().await?; + } + + // Update `execution_time`. + // NOTE: The process may disconnect/die at this point, so the elapsed time value + // might be lost. We accept this small risk since this value is not super important. let elapsed = start.elapsed(); #[allow(clippy::cast_possible_truncation)] @@ -245,30 +228,15 @@ CREATE TABLE {table_name} ( migration: &'e Migration, ) -> BoxFuture<'e, Result> { Box::pin(async move { - let mut tx = self.begin().await?; let start = Instant::now(); - let _ = query(AssertSqlSafe(format!( - r#" - UPDATE {table_name} - SET success = 0 - WHERE version = @p1 - "# - ))) - .bind(migration.version) - .execute(&mut *tx) - .await?; - - tx.execute(migration.sql.clone()).await?; - - let _ = query(AssertSqlSafe(format!( - r#"DELETE FROM {table_name} WHERE version = @p1"# - ))) - .bind(migration.version) - .execute(&mut *tx) - .await?; - - tx.commit().await?; + if migration.no_tx { + revert_migration(self, table_name, migration).await?; + } else { + let mut tx = self.begin().await?; + revert_migration(&mut tx, table_name, migration).await?; + tx.commit().await?; + } let elapsed = start.elapsed(); @@ -276,3 +244,48 @@ CREATE TABLE {table_name} ( }) } } + +async fn execute_migration( + conn: &mut MssqlConnection, + table_name: &str, + migration: &Migration, +) -> Result<(), MigrateError> { + let _ = conn + .execute(migration.sql.clone()) + .await + .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; + + let _ = query(AssertSqlSafe(format!( + r#" + INSERT INTO {table_name} ( version, description, success, checksum, execution_time ) + VALUES ( @p1, @p2, 1, @p3, -1 ) + "# + ))) + .bind(migration.version) + .bind(&*migration.description) + .bind(&*migration.checksum) + .execute(conn) + .await?; + + Ok(()) +} + +async fn revert_migration( + conn: &mut MssqlConnection, + table_name: &str, + migration: &Migration, +) -> Result<(), MigrateError> { + let _ = conn + .execute(migration.sql.clone()) + .await + .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; + + let _ = query(AssertSqlSafe(format!( + r#"DELETE FROM {table_name} WHERE version = @p1"# + ))) + .bind(migration.version) + .execute(conn) + .await?; + + Ok(()) +} diff --git a/sqlx-mssql/src/type_info.rs b/sqlx-mssql/src/type_info.rs index 764710bc7e..80a44223d8 100644 --- a/sqlx-mssql/src/type_info.rs +++ b/sqlx-mssql/src/type_info.rs @@ -13,6 +13,13 @@ impl MssqlTypeInfo { pub(crate) fn new(name: impl Into) -> Self { Self { name: name.into() } } + + /// Return the base type name without any parenthesized precision/scale. + /// + /// e.g. `"DECIMAL(10,2)"` → `"DECIMAL"`, `"NVARCHAR(4000)"` → `"NVARCHAR"` + pub(crate) fn base_name(&self) -> &str { + self.name.split('(').next().unwrap_or(&self.name).trim() + } } impl Display for MssqlTypeInfo { diff --git a/sqlx-mssql/src/types/bigdecimal.rs b/sqlx-mssql/src/types/bigdecimal.rs index 330d978d8e..175c9dbed2 100644 --- a/sqlx-mssql/src/types/bigdecimal.rs +++ b/sqlx-mssql/src/types/bigdecimal.rs @@ -14,7 +14,7 @@ impl Type for BigDecimal { } fn compatible(ty: &MssqlTypeInfo) -> bool { - matches!(ty.name.as_str(), "DECIMAL" | "NUMERIC" | "MONEY") + matches!(ty.base_name(), "DECIMAL" | "NUMERIC" | "MONEY") } } diff --git a/sqlx-mssql/src/types/bool.rs b/sqlx-mssql/src/types/bool.rs index af0cae774f..171f961dc4 100644 --- a/sqlx-mssql/src/types/bool.rs +++ b/sqlx-mssql/src/types/bool.rs @@ -12,7 +12,7 @@ impl Type for bool { } fn compatible(ty: &MssqlTypeInfo) -> bool { - matches!(ty.name.as_str(), "BIT" | "TINYINT" | "INT" | "SMALLINT" | "BIGINT") + matches!(ty.base_name(), "BIT" | "TINYINT" | "INT" | "SMALLINT" | "BIGINT") } } diff --git a/sqlx-mssql/src/types/bytes.rs b/sqlx-mssql/src/types/bytes.rs index a1133e8dc6..2c35ec41e6 100644 --- a/sqlx-mssql/src/types/bytes.rs +++ b/sqlx-mssql/src/types/bytes.rs @@ -11,7 +11,7 @@ use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; fn bytes_compatible(ty: &MssqlTypeInfo) -> bool { matches!( - ty.name.as_str(), + ty.base_name(), "VARBINARY" | "BINARY" | "IMAGE" ) } diff --git a/sqlx-mssql/src/types/chrono.rs b/sqlx-mssql/src/types/chrono.rs index 62ba995b12..bf34605e5e 100644 --- a/sqlx-mssql/src/types/chrono.rs +++ b/sqlx-mssql/src/types/chrono.rs @@ -17,7 +17,7 @@ impl Type for NaiveDateTime { fn compatible(ty: &MssqlTypeInfo) -> bool { matches!( - ty.name.as_str(), + ty.base_name(), "DATETIME2" | "DATETIME" | "SMALLDATETIME" ) } @@ -110,7 +110,7 @@ impl Type for DateTime { fn compatible(ty: &MssqlTypeInfo) -> bool { matches!( - ty.name.as_str(), + ty.base_name(), "DATETIME2" | "DATETIMEOFFSET" ) } diff --git a/sqlx-mssql/src/types/float.rs b/sqlx-mssql/src/types/float.rs index 66a9a7bb4b..df1a5534f3 100644 --- a/sqlx-mssql/src/types/float.rs +++ b/sqlx-mssql/src/types/float.rs @@ -7,7 +7,7 @@ use crate::value::MssqlData; use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; fn real_compatible(ty: &MssqlTypeInfo) -> bool { - matches!(ty.name.as_str(), "REAL" | "FLOAT" | "MONEY" | "SMALLMONEY") + matches!(ty.base_name(), "REAL" | "FLOAT" | "MONEY" | "SMALLMONEY") } impl Type for f32 { diff --git a/sqlx-mssql/src/types/int.rs b/sqlx-mssql/src/types/int.rs index 92acd6916d..8de3b73749 100644 --- a/sqlx-mssql/src/types/int.rs +++ b/sqlx-mssql/src/types/int.rs @@ -8,7 +8,7 @@ use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; fn int_compatible(ty: &MssqlTypeInfo) -> bool { matches!( - ty.name.as_str(), + ty.base_name(), "TINYINT" | "SMALLINT" | "INT" | "BIGINT" ) } diff --git a/sqlx-mssql/src/types/rust_decimal.rs b/sqlx-mssql/src/types/rust_decimal.rs index c167f28769..5951f93f7e 100644 --- a/sqlx-mssql/src/types/rust_decimal.rs +++ b/sqlx-mssql/src/types/rust_decimal.rs @@ -14,7 +14,7 @@ impl Type for Decimal { } fn compatible(ty: &MssqlTypeInfo) -> bool { - matches!(ty.name.as_str(), "DECIMAL" | "NUMERIC" | "MONEY") + matches!(ty.base_name(), "DECIMAL" | "NUMERIC" | "MONEY") } } diff --git a/sqlx-mssql/src/types/str.rs b/sqlx-mssql/src/types/str.rs index 4995160a9b..8694bc4ff8 100644 --- a/sqlx-mssql/src/types/str.rs +++ b/sqlx-mssql/src/types/str.rs @@ -11,7 +11,7 @@ use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; fn str_compatible(ty: &MssqlTypeInfo) -> bool { matches!( - ty.name.as_str(), + ty.base_name(), "NVARCHAR" | "VARCHAR" | "NCHAR" | "CHAR" | "NTEXT" | "TEXT" | "XML" ) } diff --git a/sqlx-mssql/src/types/time.rs b/sqlx-mssql/src/types/time.rs index d1b79ae251..0225fe1e66 100644 --- a/sqlx-mssql/src/types/time.rs +++ b/sqlx-mssql/src/types/time.rs @@ -75,7 +75,7 @@ impl Type for PrimitiveDateTime { fn compatible(ty: &MssqlTypeInfo) -> bool { matches!( - ty.name.as_str(), + ty.base_name(), "DATETIME2" | "DATETIME" | "SMALLDATETIME" ) } @@ -110,7 +110,7 @@ impl Type for OffsetDateTime { fn compatible(ty: &MssqlTypeInfo) -> bool { matches!( - ty.name.as_str(), + ty.base_name(), "DATETIMEOFFSET" | "DATETIME2" ) } diff --git a/sqlx-mssql/src/types/xml.rs b/sqlx-mssql/src/types/xml.rs index 65271c003a..82d409ea15 100644 --- a/sqlx-mssql/src/types/xml.rs +++ b/sqlx-mssql/src/types/xml.rs @@ -33,7 +33,7 @@ impl Type for MssqlXml { fn compatible(ty: &MssqlTypeInfo) -> bool { matches!( - ty.name.as_str(), + ty.base_name(), "XML" | "NVARCHAR" | "VARCHAR" | "NTEXT" | "TEXT" ) } diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index c2ccdabef6..a6fc025a8a 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -198,6 +198,34 @@ services: MARIADB_DATABASE: sqlx MARIADB_ALLOW_EMPTY_ROOT_PASSWORD: 1 # + # Microsoft SQL Server 2022, 2019 + # + + mssql_2022: + build: + context: . + dockerfile: mssql/Dockerfile + args: + VERSION: 2022-latest + ports: + - 1433 + environment: + ACCEPT_EULA: "Y" + SA_PASSWORD: "YourStrong!Passw0rd" + + mssql_2019: + build: + context: . + dockerfile: mssql/Dockerfile + args: + VERSION: 2019-latest + ports: + - 1433 + environment: + ACCEPT_EULA: "Y" + SA_PASSWORD: "YourStrong!Passw0rd" + + # # PostgreSQL 17.x, 16.x, 15.x, 14.x, 13.x # https://www.postgresql.org/support/versioning/ # From 252bf142c953ca83d3c0263ca5a1f79d34bd398c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Mon, 16 Feb 2026 16:03:10 -0500 Subject: [PATCH 11/33] feat: add RAII advisory lock guard, DateTime support, and MSSQL examples Add MssqlAdvisoryLockGuard with Deref/DerefMut, release_now(), and leak() methods, mirroring the Postgres advisory lock guard pattern. The guard logs a warning on drop if not explicitly released, since MSSQL connections cannot queue deferred commands. Add chrono::DateTime Type/Encode/Decode for DATETIMEOFFSET columns, preserving timezone offset information instead of converting to UTC. Existing NaiveDateTime and DateTime decoders remain backward compatible by handling the new internal variant. Add examples/mssql/todos/ with a CLI app demonstrating basic CRUD operations using MSSQL parameter syntax and OUTPUT INSERTED. Author: Pablo Carrera --- Cargo.lock | 11 + Cargo.toml | 1 + examples/mssql/todos/Cargo.toml | 12 ++ .../20250101000000_create_todos.sql | 5 + examples/mssql/todos/src/main.rs | 83 ++++++++ sqlx-mssql/src/advisory_lock.rs | 188 +++++++++++++++++- sqlx-mssql/src/connection/executor.rs | 28 ++- sqlx-mssql/src/connection/mod.rs | 14 ++ sqlx-mssql/src/database.rs | 2 + sqlx-mssql/src/lib.rs | 2 +- sqlx-mssql/src/type_checking.rs | 1 + sqlx-mssql/src/types/chrono.rs | 45 ++++- sqlx-mssql/src/value.rs | 11 +- 13 files changed, 390 insertions(+), 13 deletions(-) create mode 100644 examples/mssql/todos/Cargo.toml create mode 100644 examples/mssql/todos/migrations/20250101000000_create_todos.sql create mode 100644 examples/mssql/todos/src/main.rs diff --git a/Cargo.lock b/Cargo.lock index 9ac87ff8b3..3afe2bf76e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3880,6 +3880,17 @@ dependencies = [ "webpki-roots", ] +[[package]] +name = "sqlx-example-mssql-todos" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap 4.5.40", + "dotenvy", + "sqlx", + "tokio", +] + [[package]] name = "sqlx-example-mysql-todos" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 8115fe643d..d2d66ed1e6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ members = [ "sqlx-mysql", "sqlx-postgres", "sqlx-sqlite", + "examples/mssql/todos", "examples/mysql/todos", "examples/postgres/axum-social-with-tests", "examples/postgres/chat", diff --git a/examples/mssql/todos/Cargo.toml b/examples/mssql/todos/Cargo.toml new file mode 100644 index 0000000000..f7298d42b7 --- /dev/null +++ b/examples/mssql/todos/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "sqlx-example-mssql-todos" +version = "0.1.0" +edition = "2021" +workspace = "../../../" + +[dependencies] +anyhow = "1.0" +sqlx = { path = "../../../", features = [ "mssql", "runtime-tokio", "tls-native-tls" ] } +clap = { version = "4", features = ["derive"] } +tokio = { version = "1.20.0", features = ["rt", "macros"] } +dotenvy = "0.15.0" diff --git a/examples/mssql/todos/migrations/20250101000000_create_todos.sql b/examples/mssql/todos/migrations/20250101000000_create_todos.sql new file mode 100644 index 0000000000..aca157b7da --- /dev/null +++ b/examples/mssql/todos/migrations/20250101000000_create_todos.sql @@ -0,0 +1,5 @@ +CREATE TABLE todos ( + id BIGINT IDENTITY(1,1) PRIMARY KEY, + description NVARCHAR(MAX) NOT NULL, + done BIT NOT NULL DEFAULT 0 +); diff --git a/examples/mssql/todos/src/main.rs b/examples/mssql/todos/src/main.rs new file mode 100644 index 0000000000..c3c355b393 --- /dev/null +++ b/examples/mssql/todos/src/main.rs @@ -0,0 +1,83 @@ +use clap::{Parser, Subcommand}; +use sqlx::mssql::MssqlPool; +use sqlx::Row; +use std::env; + +#[derive(Parser)] +struct Args { + #[command(subcommand)] + cmd: Option, +} + +#[derive(Subcommand)] +enum Command { + Add { description: String }, + Done { id: i64 }, +} + +#[tokio::main(flavor = "current_thread")] +async fn main() -> anyhow::Result<()> { + let args = Args::parse(); + let pool = MssqlPool::connect(&env::var("DATABASE_URL")?).await?; + + match args.cmd { + Some(Command::Add { description }) => { + println!("Adding new todo with description '{description}'"); + let todo_id = add_todo(&pool, description).await?; + println!("Added new todo with id {todo_id}"); + } + Some(Command::Done { id }) => { + println!("Marking todo {id} as done"); + if complete_todo(&pool, id).await? { + println!("Todo {id} is marked as done"); + } else { + println!("Invalid id {id}"); + } + } + None => { + println!("Printing list of all todos"); + list_todos(&pool).await?; + } + } + + Ok(()) +} + +async fn add_todo(pool: &MssqlPool, description: String) -> anyhow::Result { + // MSSQL uses OUTPUT INSERTED instead of RETURNING + let rec = sqlx::query( + "INSERT INTO todos (description) OUTPUT INSERTED.id VALUES (@p1)", + ) + .bind(&description) + .fetch_one(pool) + .await?; + + Ok(rec.get::("id")) +} + +async fn complete_todo(pool: &MssqlPool, id: i64) -> anyhow::Result { + let rows_affected = sqlx::query("UPDATE todos SET done = 1 WHERE id = @p1") + .bind(id) + .execute(pool) + .await? + .rows_affected(); + + Ok(rows_affected > 0) +} + +async fn list_todos(pool: &MssqlPool) -> anyhow::Result<()> { + let recs = sqlx::query("SELECT id, description, done FROM todos ORDER BY id") + .fetch_all(pool) + .await?; + + for rec in recs { + println!( + "- [{}] {}: {}", + if rec.get::("done") { "x" } else { " " }, + rec.get::("id"), + rec.get::("description"), + ); + } + + Ok(()) +} diff --git a/sqlx-mssql/src/advisory_lock.rs b/sqlx-mssql/src/advisory_lock.rs index 60857ec0ca..a0a71d7b85 100644 --- a/sqlx-mssql/src/advisory_lock.rs +++ b/sqlx-mssql/src/advisory_lock.rs @@ -1,5 +1,8 @@ +use std::ops::{Deref, DerefMut}; + use crate::error::Error; use crate::query_scalar::query_scalar; +use crate::Either; use crate::MssqlConnection; /// The lock mode for a MSSQL advisory lock. @@ -35,9 +38,21 @@ impl MssqlAdvisoryLockMode { /// object; instead, all participants must explicitly acquire the same named /// lock. The lock is scoped to the database session (connection). /// -/// Unlike the Postgres advisory-lock API, there is **no RAII drop guard**. -/// You must call [`release`][Self::release] explicitly when you are done with -/// the lock. +/// # RAII Guard +/// +/// Use [`acquire_guard`][Self::acquire_guard] or +/// [`try_acquire_guard`][Self::try_acquire_guard] to get an +/// [`MssqlAdvisoryLockGuard`] that provides access to the underlying connection +/// and can release the lock via [`release_now()`][MssqlAdvisoryLockGuard::release_now]. +/// +/// Unlike PostgreSQL, MSSQL connections cannot queue commands for deferred +/// execution, so the lock **cannot** be released automatically on drop. +/// If the guard is dropped without calling `release_now()` or `leak()`, a +/// warning is logged. The lock will still be released when the connection +/// is closed or returned to the pool. +/// +/// For manual lock management without a guard, use [`acquire`][Self::acquire], +/// [`try_acquire`][Self::try_acquire], and [`release`][Self::release]. /// /// # Resource Name /// @@ -51,19 +66,43 @@ impl MssqlAdvisoryLockMode { /// use sqlx::mssql::MssqlAdvisoryLock; /// /// let lock = MssqlAdvisoryLock::new("my_app_lock"); -/// lock.acquire(conn).await?; /// -/// // ... do work under the lock ... +/// // Using the RAII guard (preferred): +/// let guard = lock.acquire_guard(conn).await?; +/// // ... do work under the lock, using `&mut *guard` as a connection ... +/// guard.release_now().await?; /// +/// // Or manual management: +/// lock.acquire(conn).await?; +/// // ... do work ... /// lock.release(conn).await?; /// # Ok(()) /// # } /// ``` +#[derive(Debug, Clone)] pub struct MssqlAdvisoryLock { resource: String, mode: MssqlAdvisoryLockMode, } +/// A wrapper for a connection that represents a held MSSQL advisory lock. +/// +/// Can be acquired by [`MssqlAdvisoryLock::acquire_guard()`] or +/// [`MssqlAdvisoryLock::try_acquire_guard()`]. +/// +/// ### Note: Release is NOT automatic on drop! +/// +/// Unlike PostgreSQL, MSSQL connections cannot queue commands for deferred +/// execution. If this guard is dropped without calling +/// [`release_now()`][Self::release_now], a warning is logged and the lock +/// remains held until the connection is closed or returned to the pool. +/// +/// Always prefer calling `.release_now().await` when you are done with the lock. +pub struct MssqlAdvisoryLockGuard> { + lock: MssqlAdvisoryLock, + conn: Option, +} + impl MssqlAdvisoryLock { /// Create a new advisory lock with the given resource name and the default /// [`Exclusive`][MssqlAdvisoryLockMode::Exclusive] mode. @@ -179,6 +218,145 @@ impl MssqlAdvisoryLock { ))), } } + + /// Acquire the lock and return an RAII guard that provides access to the + /// underlying connection. + /// + /// The guard does **not** release the lock on drop (see + /// [`MssqlAdvisoryLockGuard`] for details). Call + /// [`release_now()`][MssqlAdvisoryLockGuard::release_now] to release the + /// lock and recover the connection. + /// + /// A connection-like type is required to execute the call. Allowed types + /// include `MssqlConnection`, `PoolConnection`, and mutable + /// references to either. + pub async fn acquire_guard>( + &self, + mut conn: C, + ) -> Result, Error> { + self.acquire(conn.as_mut()).await?; + Ok(MssqlAdvisoryLockGuard::new(self.clone(), conn)) + } + + /// Try to acquire the lock without waiting, returning an RAII guard on + /// success. + /// + /// Returns `Ok(Left(guard))` if the lock was acquired, or + /// `Ok(Right(conn))` if it was not available. + pub async fn try_acquire_guard>( + &self, + mut conn: C, + ) -> Result, C>, Error> { + if self.try_acquire(conn.as_mut()).await? { + Ok(Either::Left(MssqlAdvisoryLockGuard::new( + self.clone(), + conn, + ))) + } else { + Ok(Either::Right(conn)) + } + } + + /// Execute `sp_releaseapplock` for this lock's resource on the given + /// connection. + /// + /// This is provided for manually releasing the lock from connections + /// returned by [`MssqlAdvisoryLockGuard::leak()`]. + /// + /// Returns `Ok((conn, true))` if released, `Ok((conn, false))` if the lock + /// was not held. + pub async fn force_release>( + &self, + mut conn: C, + ) -> Result<(C, bool), Error> { + let released = self.release(conn.as_mut()).await?; + Ok((conn, released)) + } +} + +const NONE_ERR: &str = "BUG: MssqlAdvisoryLockGuard.conn taken"; + +impl> MssqlAdvisoryLockGuard { + fn new(lock: MssqlAdvisoryLock, conn: C) -> Self { + MssqlAdvisoryLockGuard { + lock, + conn: Some(conn), + } + } + + /// Release the advisory lock immediately and return the connection. + /// + /// This is the preferred way to release the lock. An error should only be + /// returned if there is something wrong with the connection, in which case + /// the lock will be automatically released when the connection is closed. + pub async fn release_now(mut self) -> Result { + let (conn, released) = self + .lock + .force_release(self.conn.take().expect(NONE_ERR)) + .await?; + + if !released { + tracing::warn!( + resource = %self.lock.resource(), + "MssqlAdvisoryLockGuard: advisory lock was not held by the contained connection", + ); + } + + Ok(conn) + } + + /// Cancel the release of the advisory lock, keeping it held until the + /// connection is closed. + /// + /// To manually release the lock later, see + /// [`MssqlAdvisoryLock::force_release()`]. + pub fn leak(mut self) -> C { + self.conn.take().expect(NONE_ERR) + } +} + +impl + AsRef> Deref for MssqlAdvisoryLockGuard { + type Target = MssqlConnection; + + fn deref(&self) -> &Self::Target { + self.conn.as_ref().expect(NONE_ERR).as_ref() + } +} + +impl + AsRef> DerefMut for MssqlAdvisoryLockGuard { + fn deref_mut(&mut self) -> &mut Self::Target { + self.conn.as_mut().expect(NONE_ERR).as_mut() + } +} + +impl> AsRef for MssqlAdvisoryLockGuard +where + C: AsRef, +{ + fn as_ref(&self) -> &MssqlConnection { + self.conn.as_ref().expect(NONE_ERR).as_ref() + } +} + +impl> AsMut for MssqlAdvisoryLockGuard { + fn as_mut(&mut self) -> &mut MssqlConnection { + self.conn.as_mut().expect(NONE_ERR).as_mut() + } +} + +/// Logs a warning if dropped without calling `release_now()` or `leak()`. +/// +/// The lock remains held until the connection is closed or returned to the pool. +impl> Drop for MssqlAdvisoryLockGuard { + fn drop(&mut self) { + if self.conn.is_some() { + tracing::warn!( + resource = %self.lock.resource(), + "MssqlAdvisoryLockGuard dropped without calling release_now() or leak(). \ + The lock will be released when the connection is closed.", + ); + } + } } fn applock_error_message(status: i32) -> &'static str { diff --git a/sqlx-mssql/src/connection/executor.rs b/sqlx-mssql/src/connection/executor.rs index a5ee25710d..49bae2ee37 100644 --- a/sqlx-mssql/src/connection/executor.rs +++ b/sqlx-mssql/src/connection/executor.rs @@ -24,10 +24,10 @@ use std::sync::Arc; /// crate types, and `BigDecimal` due to version mismatch). `Query::bind()` /// requires `IntoSql`, so this wrapper lets us construct `ColumnData` manually /// and pass it to `bind()`. -#[cfg(any(feature = "time", feature = "bigdecimal"))] +#[cfg(any(feature = "chrono", feature = "time", feature = "bigdecimal"))] struct ColumnDataWrapper<'a>(tiberius::ColumnData<'a>); -#[cfg(any(feature = "time", feature = "bigdecimal"))] +#[cfg(any(feature = "chrono", feature = "time", feature = "bigdecimal"))] impl<'a> tiberius::IntoSql<'a> for ColumnDataWrapper<'a> { fn into_sql(self) -> tiberius::ColumnData<'a> { self.0 @@ -102,6 +102,30 @@ impl MssqlConnection { MssqlArgumentValue::NaiveTime(v) => { query.bind(*v); } + #[cfg(feature = "chrono")] + MssqlArgumentValue::DateTimeFixedOffset(v) => { + use chrono::Timelike as _; + let epoch = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap(); + let naive = v.naive_local(); + let days = (naive.date() - epoch).num_days() as u32; + let total_ns = naive.time().num_seconds_from_midnight() as u64 + * 1_000_000_000 + + naive.time().nanosecond() as u64 % 1_000_000_000; + let increments = total_ns / 100; + let offset_minutes = + v.offset().local_minus_utc() / 60; + let dt2 = tiberius::time::DateTime2::new( + tiberius::time::Date::new(days), + tiberius::time::Time::new(increments, 7), + ); + let cd = tiberius::ColumnData::DateTimeOffset(Some( + tiberius::time::DateTimeOffset::new( + dt2, + offset_minutes as i16, + ), + )); + query.bind(ColumnDataWrapper(cd)); + } #[cfg(feature = "uuid")] MssqlArgumentValue::Uuid(v) => { query.bind(v); diff --git a/sqlx-mssql/src/connection/mod.rs b/sqlx-mssql/src/connection/mod.rs index 77a2656dc5..d5ea091aa1 100644 --- a/sqlx-mssql/src/connection/mod.rs +++ b/sqlx-mssql/src/connection/mod.rs @@ -97,6 +97,20 @@ impl Connection for MssqlConnection { } } +// Implement `AsMut` so that `MssqlConnection` can be wrapped in +// a `MssqlAdvisoryLockGuard`. +impl AsMut for MssqlConnection { + fn as_mut(&mut self) -> &mut MssqlConnection { + self + } +} + +impl AsRef for MssqlConnection { + fn as_ref(&self) -> &MssqlConnection { + self + } +} + impl MssqlConnection { /// Begin a transaction with a specific isolation level. /// diff --git a/sqlx-mssql/src/database.rs b/sqlx-mssql/src/database.rs index 379b35d824..69fa61a469 100644 --- a/sqlx-mssql/src/database.rs +++ b/sqlx-mssql/src/database.rs @@ -59,6 +59,8 @@ pub enum MssqlArgumentValue { NaiveDate(chrono::NaiveDate), #[cfg(feature = "chrono")] NaiveTime(chrono::NaiveTime), + #[cfg(feature = "chrono")] + DateTimeFixedOffset(chrono::DateTime), #[cfg(feature = "uuid")] Uuid(uuid::Uuid), #[cfg(feature = "rust_decimal")] diff --git a/sqlx-mssql/src/lib.rs b/sqlx-mssql/src/lib.rs index 90855c4ffa..6cb0bf905a 100644 --- a/sqlx-mssql/src/lib.rs +++ b/sqlx-mssql/src/lib.rs @@ -39,7 +39,7 @@ mod migrate; #[cfg(feature = "migrate")] mod testing; -pub use advisory_lock::{MssqlAdvisoryLock, MssqlAdvisoryLockMode}; +pub use advisory_lock::{MssqlAdvisoryLock, MssqlAdvisoryLockGuard, MssqlAdvisoryLockMode}; pub use arguments::MssqlArguments; pub use bulk_insert::MssqlBulkInsert; pub use column::MssqlColumn; diff --git a/sqlx-mssql/src/type_checking.rs b/sqlx-mssql/src/type_checking.rs index f3544207ee..aa4cbcffe7 100644 --- a/sqlx-mssql/src/type_checking.rs +++ b/sqlx-mssql/src/type_checking.rs @@ -36,6 +36,7 @@ impl_type_checking!( sqlx::types::chrono::NaiveDate, sqlx::types::chrono::NaiveDateTime, sqlx::types::chrono::DateTime, + sqlx::types::chrono::DateTime, }, time: { sqlx::types::time::Time, diff --git a/sqlx-mssql/src/types/chrono.rs b/sqlx-mssql/src/types/chrono.rs index bf34605e5e..6a849622d3 100644 --- a/sqlx-mssql/src/types/chrono.rs +++ b/sqlx-mssql/src/types/chrono.rs @@ -1,4 +1,4 @@ -use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; +use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, Utc}; use crate::database::MssqlArgumentValue; use crate::decode::Decode; @@ -37,6 +37,7 @@ impl Decode<'_, Mssql> for NaiveDateTime { fn decode(value: MssqlValueRef<'_>) -> Result { match value.data { MssqlData::NaiveDateTime(v) => Ok(*v), + MssqlData::DateTimeFixedOffset(v) => Ok(v.naive_utc()), MssqlData::Null => Err("unexpected NULL".into()), _ => Err(format!("expected datetime, got {:?}", value.data).into()), } @@ -66,6 +67,7 @@ impl Decode<'_, Mssql> for NaiveDate { match value.data { MssqlData::NaiveDate(v) => Ok(*v), MssqlData::NaiveDateTime(v) => Ok(v.date()), + MssqlData::DateTimeFixedOffset(v) => Ok(v.naive_utc().date()), MssqlData::Null => Err("unexpected NULL".into()), _ => Err(format!("expected date, got {:?}", value.data).into()), } @@ -130,8 +132,49 @@ impl Decode<'_, Mssql> for DateTime { fn decode(value: MssqlValueRef<'_>) -> Result { match value.data { MssqlData::NaiveDateTime(v) => Ok(v.and_utc()), + MssqlData::DateTimeFixedOffset(v) => Ok(v.with_timezone(&Utc)), MssqlData::Null => Err("unexpected NULL".into()), _ => Err(format!("expected datetime, got {:?}", value.data).into()), } } } + +// ── DateTime ─────────────────────────────────────────────────── + +impl Type for DateTime { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("DATETIMEOFFSET") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!( + ty.base_name(), + "DATETIMEOFFSET" | "DATETIME2" + ) + } +} + +impl Encode<'_, Mssql> for DateTime { + fn encode_by_ref( + &self, + buf: &mut Vec, + ) -> Result { + buf.push(MssqlArgumentValue::DateTimeFixedOffset(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for DateTime { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::DateTimeFixedOffset(v) => Ok(*v), + MssqlData::NaiveDateTime(v) => { + // Assume UTC if no offset information + let utc = v.and_utc(); + Ok(utc.with_timezone(&FixedOffset::east_opt(0).unwrap())) + } + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected datetimeoffset, got {:?}", value.data).into()), + } + } +} diff --git a/sqlx-mssql/src/value.rs b/sqlx-mssql/src/value.rs index c44834d2eb..2fffd1f647 100644 --- a/sqlx-mssql/src/value.rs +++ b/sqlx-mssql/src/value.rs @@ -24,6 +24,8 @@ pub(crate) enum MssqlData { NaiveDate(chrono::NaiveDate), #[cfg(feature = "chrono")] NaiveTime(chrono::NaiveTime), + #[cfg(feature = "chrono")] + DateTimeFixedOffset(chrono::DateTime), #[cfg(feature = "uuid")] Uuid(uuid::Uuid), #[cfg(feature = "rust_decimal")] @@ -168,10 +170,11 @@ pub(crate) fn column_data_to_mssql_data(data: &tiberius::ColumnData<'_>) -> Mssq * 10i64.pow(9u32.saturating_sub(dto.datetime2().time().scale() as u32)); let time = chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap() + chrono::Duration::nanoseconds(ns); - // Subtract the offset to convert to UTC - let naive = chrono::NaiveDateTime::new(date, time) - - chrono::Duration::minutes(dto.offset() as i64); - MssqlData::NaiveDateTime(naive) + let naive = chrono::NaiveDateTime::new(date, time); + let offset_secs = dto.offset() as i32 * 60; + let fixed_offset = chrono::FixedOffset::east_opt(offset_secs) + .expect("valid offset from tiberius"); + MssqlData::DateTimeFixedOffset(naive.and_local_timezone(fixed_offset).unwrap()) } #[cfg(feature = "uuid")] From 926722e092d2aa1530e461c094b7641fade488b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Mon, 16 Feb 2026 16:24:26 -0500 Subject: [PATCH 12/33] feat: fix SMALLMONEY type mapping, expand test coverage, and document _persistent flag for MSSQL Fix Money4 (SMALLMONEY) being incorrectly mapped to "MONEY" instead of "SMALLMONEY", enabling the Any driver's existing SMALLMONEY handling. Add comprehensive type tests (integer edge cases, Unicode, large strings, binary, NULL, XML, DateTime, SMALLMONEY) and integration tests (multiple result sets, column metadata, error recovery, many parameters, isolation levels, advisory lock guard). Author: Pablo Carrera --- sqlx-mssql/src/any.rs | 3 + sqlx-mssql/src/connection/executor.rs | 3 + sqlx-mssql/src/type_info.rs | 3 +- tests/mssql/mssql.rs | 216 ++++++++++++++++++++++++++ tests/mssql/types.rs | 122 ++++++++++++++- 5 files changed, 343 insertions(+), 4 deletions(-) diff --git a/sqlx-mssql/src/any.rs b/sqlx-mssql/src/any.rs index 63cd7e3b62..a1f6ace02a 100644 --- a/sqlx-mssql/src/any.rs +++ b/sqlx-mssql/src/any.rs @@ -78,6 +78,8 @@ impl AnyConnectionBackend for MssqlConnection { fn fetch_many( &mut self, query: SqlStr, + // MSSQL always sends parameterized queries via tiberius (no server-side + // prepared statement caching), so the persistent flag has no effect. _persistent: bool, arguments: Option, ) -> BoxStream<'_, sqlx_core::Result>> { @@ -108,6 +110,7 @@ impl AnyConnectionBackend for MssqlConnection { fn fetch_optional( &mut self, query: SqlStr, + // See fetch_many: MSSQL has no server-side prepared statement caching. _persistent: bool, arguments: Option, ) -> BoxFuture<'_, sqlx_core::Result>> { diff --git a/sqlx-mssql/src/connection/executor.rs b/sqlx-mssql/src/connection/executor.rs index 49bae2ee37..55fdb51587 100644 --- a/sqlx-mssql/src/connection/executor.rs +++ b/sqlx-mssql/src/connection/executor.rs @@ -331,6 +331,9 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection { E: 'q, { let arguments = query.take_arguments().map_err(Error::Encode); + // MSSQL always sends parameterized queries via tiberius — there is no + // server-side prepared statement caching like PostgreSQL's, so this + // flag is intentionally unused. let _persistent = query.persistent(); let sql = query.sql(); diff --git a/sqlx-mssql/src/type_info.rs b/sqlx-mssql/src/type_info.rs index 80a44223d8..dbe612209a 100644 --- a/sqlx-mssql/src/type_info.rs +++ b/sqlx-mssql/src/type_info.rs @@ -57,7 +57,8 @@ pub(crate) fn type_name_for_tiberius(col_type: &tiberius::ColumnType) -> &'stati tiberius::ColumnType::Daten => "DATE", tiberius::ColumnType::Timen => "TIME", tiberius::ColumnType::Decimaln | tiberius::ColumnType::Numericn => "DECIMAL", - tiberius::ColumnType::Money | tiberius::ColumnType::Money4 => "MONEY", + tiberius::ColumnType::Money => "MONEY", + tiberius::ColumnType::Money4 => "SMALLMONEY", tiberius::ColumnType::BigVarChar | tiberius::ColumnType::NVarchar => "NVARCHAR", tiberius::ColumnType::BigChar | tiberius::ColumnType::NChar => "NCHAR", tiberius::ColumnType::BigVarBin => "VARBINARY", diff --git a/tests/mssql/mssql.rs b/tests/mssql/mssql.rs index 2aa26d3e8e..dcbfeaccf2 100644 --- a/tests/mssql/mssql.rs +++ b/tests/mssql/mssql.rs @@ -5,6 +5,7 @@ use sqlx::mssql::MssqlRow; use sqlx_test::new; use std::sync::atomic::{AtomicI32, Ordering}; use std::time::Duration; +use sqlx::mssql::{MssqlAdvisoryLock, MssqlIsolationLevel}; #[sqlx_macros::test] async fn it_connects() -> anyhow::Result<()> { @@ -459,3 +460,218 @@ async fn test_pool_callbacks() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_can_query_multiple_result_sets() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // A batch that produces two result sets + let results = conn + .run("SELECT 1 AS a; SELECT 2 AS b, 3 AS c;", None) + .await?; + + // First result set: one row with column "a" + let mut rows_first = Vec::new(); + let mut rows_second = Vec::new(); + let mut result_count = 0; + + for item in &results { + match item { + either::Either::Left(_) => { + result_count += 1; + } + either::Either::Right(row) => { + if result_count == 0 { + rows_first.push(row); + } else { + rows_second.push(row); + } + } + } + } + + assert_eq!(rows_first.len(), 1); + assert_eq!(rows_first[0].try_get::("a")?, 1); + + assert_eq!(rows_second.len(), 1); + assert_eq!(rows_second[0].try_get::("b")?, 2); + assert_eq!(rows_second[0].try_get::("c")?, 3); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_can_inspect_column_metadata() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let statement = conn + .prepare("SELECT CAST(1 AS INT) AS int_col, CAST('hello' AS NVARCHAR(50)) AS str_col, CAST(NULL AS BIGINT) AS nullable_col".into_sql_str()) + .await?; + + assert_eq!(statement.column(0).name(), "int_col"); + assert_eq!(statement.column(1).name(), "str_col"); + assert_eq!(statement.column(2).name(), "nullable_col"); + + assert_eq!(statement.column(0).type_info().name(), "INT"); + // sp_describe_first_result_set returns "NVARCHAR(50)" for typed NVARCHAR + assert!(statement.column(1).type_info().name().starts_with("NVARCHAR")); + assert_eq!(statement.column(2).type_info().name(), "BIGINT"); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_can_reuse_connection_after_error() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // Cause an error + let res: Result<_, sqlx::Error> = + sqlx::query("SELECT * FROM this_table_does_not_exist_12345") + .execute(&mut conn) + .await; + assert!(res.is_err()); + + // Connection should still be usable + let val: (i32,) = sqlx::query_as("SELECT 42").fetch_one(&mut conn).await?; + assert_eq!(val.0, 42); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_can_bind_many_parameters() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // Build a query with 100 parameters: SELECT @p1 + @p2 + ... + @p100 + let param_refs: Vec = (1..=100).map(|i| format!("@p{i}")).collect(); + let sql = format!("SELECT {}", param_refs.join(" + ")); + + let mut query = sqlx::query_scalar::<_, i32>(&sql); + for _ in 0..100 { + query = query.bind(1_i32); + } + + let result: i32 = query.fetch_one(&mut conn).await?; + assert_eq!(result, 100); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_handles_special_characters_in_strings() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // Single quotes + let val: (String,) = sqlx::query_as("SELECT @p1") + .bind("it's a test") + .fetch_one(&mut conn) + .await?; + assert_eq!(val.0, "it's a test"); + + // Backslashes + let val: (String,) = sqlx::query_as("SELECT @p1") + .bind(r"C:\Users\test") + .fetch_one(&mut conn) + .await?; + assert_eq!(val.0, r"C:\Users\test"); + + // Unicode + let val: (String,) = sqlx::query_as("SELECT @p1") + .bind("\u{1F600} hello \u{4E16}\u{754C}") + .fetch_one(&mut conn) + .await?; + assert_eq!(val.0, "\u{1F600} hello \u{4E16}\u{754C}"); + + // Newlines and tabs + let val: (String,) = sqlx::query_as("SELECT @p1") + .bind("line1\nline2\ttab") + .fetch_one(&mut conn) + .await?; + assert_eq!(val.0, "line1\nline2\ttab"); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_can_use_transaction_isolation_levels() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // Start a transaction with READ UNCOMMITTED isolation + let mut tx = conn + .begin_with_isolation(MssqlIsolationLevel::ReadUncommitted) + .await?; + + // Verify we can do work inside the transaction + let val: (i32,) = sqlx::query_as("SELECT 1").fetch_one(&mut *tx).await?; + assert_eq!(val.0, 1); + + tx.commit().await?; + + // Start a transaction with SERIALIZABLE isolation + let mut tx = conn + .begin_with_isolation(MssqlIsolationLevel::Serializable) + .await?; + + let val: (i32,) = sqlx::query_as("SELECT 2").fetch_one(&mut *tx).await?; + assert_eq!(val.0, 2); + + tx.rollback().await?; + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_can_use_advisory_lock_guard() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // Need a transaction context for sp_getapplock with Session owner + // Actually, Session-scoped locks work outside transactions too. + let lock = MssqlAdvisoryLock::new("sqlx_test_lock_guard"); + + // Acquire the lock via the RAII guard + let mut guard = lock.acquire_guard(&mut conn).await?; + + // Use the connection through the guard + let val: (i32,) = sqlx::query_as("SELECT 99") + .fetch_one(&mut *guard) + .await?; + assert_eq!(val.0, 99); + + // Release the lock and get the connection back + let conn = guard.release_now().await?; + + // Verify we can still use the connection + let val: (i32,) = sqlx::query_as("SELECT 100") + .fetch_one(conn) + .await?; + assert_eq!(val.0, 100); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_can_try_acquire_advisory_lock() -> anyhow::Result<()> { + let mut conn1 = new::().await?; + let mut conn2 = new::().await?; + + let lock = MssqlAdvisoryLock::new("sqlx_test_try_lock"); + + // Acquire on conn1 + lock.acquire(&mut conn1).await?; + + // Try to acquire on conn2 — should fail (return false) since it's exclusive + let acquired = lock.try_acquire(&mut conn2).await?; + assert!(!acquired); + + // Release on conn1 + let released = lock.release(&mut conn1).await?; + assert!(released); + + // Now conn2 should be able to acquire + let acquired = lock.try_acquire(&mut conn2).await?; + assert!(acquired); + + lock.release(&mut conn2).await?; + + Ok(()) +} diff --git a/tests/mssql/types.rs b/tests/mssql/types.rs index de5d840dd6..c15785eab3 100644 --- a/tests/mssql/types.rs +++ b/tests/mssql/types.rs @@ -20,11 +20,26 @@ test_type!(i8( "CAST(0 AS TINYINT)" == 0_i8 )); -test_type!(i16(Mssql, "CAST(21415 AS SMALLINT)" == 21415_i16)); +test_type!(i16( + Mssql, + "CAST(21415 AS SMALLINT)" == 21415_i16, + "CAST(-32768 AS SMALLINT)" == i16::MIN, + "CAST(32767 AS SMALLINT)" == i16::MAX, +)); -test_type!(i32(Mssql, "CAST(2141512 AS INT)" == 2141512_i32)); +test_type!(i32( + Mssql, + "CAST(2141512 AS INT)" == 2141512_i32, + "CAST(-2147483648 AS INT)" == i32::MIN, + "CAST(2147483647 AS INT)" == i32::MAX, +)); -test_type!(i64(Mssql, "CAST(32324324432 AS BIGINT)" == 32324324432_i64)); +test_type!(i64( + Mssql, + "CAST(32324324432 AS BIGINT)" == 32324324432_i64, + "CAST(-9223372036854775808 AS BIGINT)" == i64::MIN, + "CAST(9223372036854775807 AS BIGINT)" == i64::MAX, +)); test_type!(f32( Mssql, @@ -50,6 +65,12 @@ test_type!(f64_smallmoney( "CAST(-1234.5678 AS SMALLMONEY)" == -1234.5678_f64, )); +#[cfg(feature = "rust_decimal")] +test_type!(rust_decimal_smallmoney(Mssql, + "CAST(214748.3647 AS SMALLMONEY)" == sqlx::types::Decimal::new(2147483647, 4), + "CAST(0 AS SMALLMONEY)" == sqlx::types::Decimal::ZERO, +)); + test_type!(str_nvarchar(Mssql, "CAST('this is foo' as NVARCHAR)" == "this is foo", )); @@ -74,6 +95,58 @@ test_type!(bytes>(Mssql, == vec![0_u8; 8], )); +test_type!(bytes_single>(Mssql, + "CAST(0xFF AS VARBINARY(MAX))" == vec![0xFF_u8], +)); + +test_type!(bytes_large>(Mssql, + "CAST(REPLICATE(CAST(0xAB AS VARBINARY(MAX)), 10000) AS VARBINARY(MAX))" + == vec![0xAB_u8; 10000], +)); + +test_type!(str_nchar(Mssql, + "CAST('hello' AS NCHAR(5))" == "hello", +)); + +test_type!(str_varchar(Mssql, + "CAST('hello varchar' AS VARCHAR(50))" == "hello varchar", +)); + +test_type!(str_unicode(Mssql, + "CAST(N'\u{1F600}\u{1F680}\u{2764}' AS NVARCHAR(MAX))" == "\u{1F600}\u{1F680}\u{2764}", + "CAST(N'\u{4F60}\u{597D}\u{4E16}\u{754C}' AS NVARCHAR(MAX))" == "\u{4F60}\u{597D}\u{4E16}\u{754C}", +)); + +test_type!(str_nvarchar_max_large(Mssql, + "REPLICATE(CAST(N'x' AS NVARCHAR(MAX)), 10000)" + == "x".repeat(10000), +)); + +test_type!(null_bool>(Mssql, + "CAST(NULL AS BIT)" == None::, +)); + +test_type!(null_string>(Mssql, + "CAST(NULL AS NVARCHAR(100))" == None::, +)); + +test_type!(null_i64>(Mssql, + "CAST(NULL AS BIGINT)" == None::, +)); + +test_type!(null_f64>(Mssql, + "CAST(NULL AS FLOAT)" == None::, +)); + +test_type!(null_bytes>>(Mssql, + "CAST(NULL AS VARBINARY(MAX))" == None::>, +)); + +test_type!(xml(Mssql, + "CAST('hello' AS XML)" + == sqlx::mssql::MssqlXml::from("hello".to_owned()), +)); + #[cfg(feature = "uuid")] test_type!(uuid(Mssql, "CAST('00000000-0000-0000-0000-000000000000' AS UNIQUEIDENTIFIER)" @@ -91,6 +164,8 @@ mod chrono { type NaiveTime = sqlx::types::chrono::NaiveTime; type NaiveDateTime = sqlx::types::chrono::NaiveDateTime; type DateTimeUtc = sqlx::types::chrono::DateTime; + type DateTimeFixed = sqlx::types::chrono::DateTime; + type FixedOffset = sqlx::types::chrono::FixedOffset; test_type!(chrono_naive_date(Mssql, "CAST('2001-01-05' AS DATE)" @@ -122,6 +197,47 @@ mod chrono { .unwrap() .and_utc(), )); + + test_type!(chrono_date_time_fixed_utc(Mssql, + "CAST('2019-01-02 05:10:20.000 +00:00' AS DATETIMEOFFSET)" + == NaiveDate::from_ymd_opt(2019, 1, 2) + .unwrap() + .and_hms_opt(5, 10, 20) + .unwrap() + .and_local_timezone(FixedOffset::east_opt(0).unwrap()) + .unwrap(), + )); + + test_type!(chrono_date_time_fixed_positive(Mssql, + "CAST('2024-06-15 14:30:00.000 +05:30' AS DATETIMEOFFSET)" + == NaiveDate::from_ymd_opt(2024, 6, 15) + .unwrap() + .and_hms_opt(14, 30, 0) + .unwrap() + .and_local_timezone(FixedOffset::east_opt(5 * 3600 + 30 * 60).unwrap()) + .unwrap(), + )); + + test_type!(chrono_date_time_fixed_negative(Mssql, + "CAST('2024-12-25 08:00:00.000 -08:00' AS DATETIMEOFFSET)" + == NaiveDate::from_ymd_opt(2024, 12, 25) + .unwrap() + .and_hms_opt(8, 0, 0) + .unwrap() + .and_local_timezone(FixedOffset::west_opt(8 * 3600).unwrap()) + .unwrap(), + )); + + // Verify DateTime can decode from DATETIMEOFFSET with non-zero offset + // (the value should be converted to UTC) + test_type!(chrono_date_time_utc_from_offset(Mssql, + "CAST('2024-06-15 14:30:00.000 +05:30' AS DATETIMEOFFSET)" + == NaiveDate::from_ymd_opt(2024, 6, 15) + .unwrap() + .and_hms_opt(9, 0, 0) + .unwrap() + .and_utc(), + )); } #[cfg(feature = "time")] From 4c1827d932fd262809a945f229fecab52e7c7656 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Wed, 18 Feb 2026 21:47:09 -0500 Subject: [PATCH 13/33] docs: mssql docs support markdown --- MSSQL_SUPPORT.md | 514 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 514 insertions(+) create mode 100644 MSSQL_SUPPORT.md diff --git a/MSSQL_SUPPORT.md b/MSSQL_SUPPORT.md new file mode 100644 index 0000000000..bbc9b4d09a --- /dev/null +++ b/MSSQL_SUPPORT.md @@ -0,0 +1,514 @@ +# MSSQL (SQL Server) Support for SQLx + +This document covers all MSSQL/SQL Server additions in the `feat/mssql-support` branch, built on top of the [Tiberius](https://github.com/prisma/tiberius) TDS driver. + +--- + +## Table of Contents + +- [Overview](#overview) +- [Getting Started](#getting-started) +- [Connection & Authentication](#connection--authentication) +- [SSL/TLS](#ssltls) +- [Type Mappings](#type-mappings) +- [Compile-Time Query Macros](#compile-time-query-macros) +- [Any Driver Support](#any-driver-support) +- [Migrations](#migrations) +- [Transactions & Isolation Levels](#transactions--isolation-levels) +- [Advisory Locks](#advisory-locks) +- [Bulk Insert](#bulk-insert) +- [QueryBuilder](#querybuilder) +- [XML Type](#xml-type) +- [Examples](#examples) +- [Docker & CI](#docker--ci) +- [Test Coverage](#test-coverage) +- [Feature Flags](#feature-flags) + +--- + +## Overview + +Full SQL Server support has been added to SQLx, bringing feature parity with PostgreSQL, MySQL, and SQLite where applicable. The implementation provides: + +- Complete type system mapping between Rust and SQL Server types +- Four authentication methods (SQL Server, Windows/NTLM, Integrated/GSSAPI, Azure AD) +- SSL/TLS with configurable modes +- Compile-time checked queries via macros +- Runtime-polymorphic `Any` driver support +- Database migrations with `sqlx migrate` +- RAII advisory locks via `sp_getapplock`/`sp_releaseapplock` +- Bulk insert via the TDS `INSERT BULK` protocol +- Transaction isolation levels +- Testing infrastructure with Docker Compose (MSSQL 2019 & 2022) + +**URL schemes:** `mssql://` and `sqlserver://` + +--- + +## Getting Started + +Add SQLx with the `mssql` feature to your `Cargo.toml`: + +```toml +[dependencies] +sqlx = { version = "0.8", features = ["mssql", "runtime-tokio"] } +``` + +Connect to a database: + +```rust +use sqlx::mssql::MssqlPool; + +let pool = MssqlPool::connect("mssql://sa:YourPassword@localhost/mydb").await?; + +let row: (i32,) = sqlx::query_as("SELECT @p1") + .bind(42i32) + .fetch_one(&pool) + .await?; +``` + +--- + +## Connection & Authentication + +**Connection string format:** + +``` +mssql://[user[:password]@]host[:port][/database][?properties] +``` + +**Connection options:** + +| Option | Default | Description | +|--------|---------|-------------| +| `host` | `localhost` | Database server hostname | +| `port` | `1433` | Port number | +| `username` | `sa` | Username | +| `password` | — | Password | +| `database` | — | Database name | +| `instance` | — | SQL Server named instance | +| `app_name` | `sqlx` | Application name sent to server | +| `statement-cache-capacity` | `100` | Max cached prepared statements | +| `application_intent` | `read_write` | `read_write` or `read_only` (Always On replicas) | + +### Authentication Methods + +**1. SQL Server Auth (default)** + +Standard username/password authentication. + +```rust +let pool = MssqlPool::connect("mssql://sa:password@localhost/mydb").await?; +``` + +**2. Windows/NTLM Auth** (feature: `winauth`) + +Supports `domain\user` syntax. + +```rust +let opts = MssqlConnectOptions::new() + .host("localhost") + .windows_auth(true); +``` + +**3. Integrated Auth / GSSAPI** (feature: `integrated-auth-gssapi`) + +Uses SSPI on Windows and Kerberos on Unix. + +```rust +let opts = MssqlConnectOptions::new() + .host("localhost") + .integrated_auth(true); +``` + +**4. Azure AD Token Auth** + +Pass a bearer token for Azure Active Directory authentication. + +```rust +let opts = MssqlConnectOptions::new() + .host("your-server.database.windows.net") + .aad_token("eyJ0eX..."); +``` + +--- + +## SSL/TLS + +Configurable encryption modes for the TDS connection. + +| Mode | Description | +|------|-------------| +| `Disabled` | No encryption | +| `LoginOnly` | Encrypt login packet only | +| `Preferred` (default) | Encrypt if server supports it | +| `Required` | Always encrypt, fail otherwise | + +**Connection string parameters:** + +| Parameter | Description | +|-----------|-------------| +| `sslmode` / `ssl_mode` | `disabled`, `login_only`, `preferred`, `required` | +| `encrypt` | Legacy alias: `true` = required, `false` = disabled | +| `trust_server_certificate` | Trust without validation (default: `false`) | +| `trust_server_certificate_ca` | Path to CA certificate file (`.pem`, `.crt`, `.der`) | + +``` +mssql://sa:password@localhost/mydb?sslmode=required&trust_server_certificate=true +``` + +--- + +## Type Mappings + +### Primitive Types + +| Rust Type | SQL Server Type(s) | +|-----------|-------------------| +| `bool` | `BIT` | +| `u8` | `TINYINT` (0–255) | +| `i8` | `TINYINT` (0–127) | +| `i16` | `SMALLINT` | +| `i32` | `INT` | +| `i64` | `BIGINT` | +| `f32` | `REAL`, `FLOAT` | +| `f64` | `REAL`, `FLOAT`, `MONEY`, `SMALLMONEY` | +| `&str` / `String` | `NVARCHAR` | +| `&[u8]` / `Vec` | `VARBINARY` | + +### Feature-Gated Types + +#### `uuid` + +| Rust Type | SQL Server Type | +|-----------|----------------| +| `uuid::Uuid` | `UNIQUEIDENTIFIER` | + +#### `rust_decimal` + +| Rust Type | SQL Server Type(s) | +|-----------|-------------------| +| `rust_decimal::Decimal` | `DECIMAL`, `NUMERIC`, `MONEY`, `SMALLMONEY` | + +#### `bigdecimal` + +| Rust Type | SQL Server Type(s) | +|-----------|-------------------| +| `bigdecimal::BigDecimal` | `DECIMAL`, `NUMERIC`, `MONEY` | + +#### `chrono` + +| Rust Type | SQL Server Type(s) | +|-----------|-------------------| +| `chrono::NaiveDate` | `DATE` | +| `chrono::NaiveTime` | `TIME` | +| `chrono::NaiveDateTime` | `DATETIME2`, `DATETIME`, `SMALLDATETIME` | +| `chrono::DateTime` | `DATETIME2`, `DATETIMEOFFSET` | +| `chrono::DateTime` | `DATETIMEOFFSET`, `DATETIME2` | + +#### `time` + +| Rust Type | SQL Server Type(s) | +|-----------|-------------------| +| `time::Date` | `DATE` | +| `time::Time` | `TIME` | +| `time::PrimitiveDateTime` | `DATETIME2`, `DATETIME`, `SMALLDATETIME` | +| `time::OffsetDateTime` | `DATETIMEOFFSET`, `DATETIME2` | + +#### `json` + +| Rust Type | SQL Server Type | +|-----------|----------------| +| `serde_json::Value` / `Json` | `NVARCHAR` (stored as JSON string) | + +#### XML + +| Rust Type | SQL Server Type | +|-----------|----------------| +| `MssqlXml` | `XML` | + +### Nullable Types + +All types above support `Option` for nullable columns. + +--- + +## Compile-Time Query Macros + +The standard SQLx macros work with MSSQL when `DATABASE_URL` is set to an `mssql://` connection string: + +```rust +// Compile-time checked query +let row = sqlx::query!("SELECT @p1 AS value", 42i32) + .fetch_one(&pool) + .await?; + +// With custom return type +#[derive(sqlx::FromRow)] +struct User { + id: i32, + name: String, +} + +let user = sqlx::query_as!(User, "SELECT id, name FROM users WHERE id = @p1", 1i32) + .fetch_one(&pool) + .await?; + +// Scalar queries +let count = sqlx::query_scalar!("SELECT COUNT(*) FROM users") + .fetch_one(&pool) + .await?; +``` + +**Offline mode** is also supported — run `cargo sqlx prepare` to generate query metadata for CI builds without a live database. + +--- + +## Any Driver Support + +MSSQL is fully integrated with the `Any` runtime-polymorphic driver, enabled via the `any` feature flag. + +```rust +use sqlx::any::AnyPool; + +// Connects to whichever database the URL points to +let pool = AnyPool::connect("mssql://sa:password@localhost/mydb").await?; + +let rows = sqlx::query("SELECT 1 + 1 AS result") + .fetch_all(&pool) + .await?; +``` + +All standard operations work through `Any`: queries, transactions, ping, close, and prepared statements. + +--- + +## Migrations + +MSSQL supports the full `sqlx migrate` workflow. + +```bash +# Create a new migration +sqlx migrate add create_users_table + +# Run pending migrations +sqlx migrate run + +# Revert the last migration +sqlx migrate revert +``` + +**Programmatic usage:** + +```rust +sqlx::migrate!("./migrations") + .run(&pool) + .await?; +``` + +**Database lifecycle:** + +- `create_database(url)` — Creates a database via `CREATE DATABASE [name]` +- `database_exists(url)` — Checks existence via `DB_ID()` +- `drop_database(url)` — Drops with `ALTER DATABASE SET SINGLE_USER WITH ROLLBACK IMMEDIATE` for cleanup + +**No-transaction migrations** are supported for DDL operations that cannot run inside a transaction. + +--- + +## Transactions & Isolation Levels + +Standard transaction support with configurable isolation levels. + +```rust +let mut tx = pool.begin().await?; + +sqlx::query("INSERT INTO users (name) VALUES (@p1)") + .bind("Alice") + .execute(&mut *tx) + .await?; + +tx.commit().await?; +``` + +### Isolation Levels + +| Level | Description | +|-------|-------------| +| `ReadUncommitted` | Dirty reads allowed | +| `ReadCommitted` | Default SQL Server isolation | +| `RepeatableRead` | Prevents non-repeatable reads | +| `Snapshot` | Row versioning-based isolation | +| `Serializable` | Strictest isolation | + +```rust +use sqlx::mssql::MssqlIsolationLevel; + +let mut tx = pool + .begin_with_isolation(MssqlIsolationLevel::Snapshot) + .await?; +``` + +--- + +## Advisory Locks + +Application-level named locks using SQL Server's `sp_getapplock` and `sp_releaseapplock`, with an RAII guard pattern. + +### Lock Modes + +| Mode | Compatible With | +|------|----------------| +| `Shared` | Shared, Update | +| `Update` | Shared only | +| `Exclusive` (default) | None | + +### Usage + +```rust +use sqlx::mssql::{MssqlAdvisoryLock, MssqlAdvisoryLockMode}; + +// Create an exclusive lock +let lock = MssqlAdvisoryLock::new("my_resource"); + +// Or with a specific mode +let lock = MssqlAdvisoryLock::with_mode("my_resource", MssqlAdvisoryLockMode::Shared); + +// RAII guard (preferred) — lock released when guard is dropped +let guard = lock.acquire_guard(&mut conn).await?; +// ... do work while lock is held ... +let conn = guard.release_now().await?; // explicit release + +// Non-blocking attempt +if let Some(guard) = lock.try_acquire_guard(&mut conn).await? { + // lock acquired +} +``` + +--- + +## Bulk Insert + +High-performance data loading via the TDS `INSERT BULK` protocol. + +```rust +let mut bulk = conn.bulk_insert("my_table").await?; + +for item in &data { + bulk.send(tiberius::IntoRow::into_row(item)).await?; +} + +let rows_affected = bulk.finalize().await?; +``` + +Supports tuples up to 10 elements via `tiberius::IntoRow`. + +--- + +## QueryBuilder + +MSSQL uses `@p1`, `@p2`, etc. as parameter placeholders. The `QueryBuilder` handles this automatically: + +```rust +let mut qb = QueryBuilder::::new("SELECT * FROM users WHERE "); +qb.push("name = ").push_bind("Alice"); +qb.push(" AND age > ").push_bind(21); +// Produces: SELECT * FROM users WHERE name = @p1 AND age > @p2 +``` + +--- + +## XML Type + +A dedicated `MssqlXml` wrapper type distinguishes XML columns from regular strings. + +```rust +use sqlx::mssql::MssqlXml; + +let xml = MssqlXml::from("hello".to_string()); + +sqlx::query("INSERT INTO docs (content) VALUES (@p1)") + .bind(&xml) + .execute(&pool) + .await?; + +let result: MssqlXml = sqlx::query_scalar("SELECT content FROM docs") + .fetch_one(&pool) + .await?; +``` + +--- + +## Examples + +A full CRUD Todo application is available at `examples/mssql/todos/`, demonstrating: + +- Connection pooling +- Migrations +- Query execution +- Error handling + +--- + +## Docker & CI + +### Docker Compose + +The test suite includes Docker Compose configurations for MSSQL 2019 and 2022: + +```bash +docker compose -f tests/docker-compose.yml up mssql_2022 -d +``` + +**Services:** + +| Service | Image | Port | +|---------|-------|------| +| `mssql_2022` | `mcr.microsoft.com/mssql/server:2022-latest` | 1433 | +| `mssql_2019` | `mcr.microsoft.com/mssql/server:2019-latest` | 1433 | + +### CI Matrix + +The GitHub Actions workflow tests across: + +- **MSSQL versions:** 2019, 2022 +- **Async runtimes:** tokio, async-global-executor, smol +- **TLS backends:** native-tls, rustls-aws-lc-rs, rustls-ring, none + +--- + +## Test Coverage + +Comprehensive test suite in `tests/mssql/`: + +| Area | File | What's Tested | +|------|------|---------------| +| Core queries | `mssql.rs` | Connections, SELECT, INSERT, parameters, large result sets, error handling | +| Type round-trips | `types.rs` | All primitive and feature-gated types with boundary values, NULLs, Unicode, large data | +| Test attribute | `test-attr.rs` | `#[sqlx_macros::test]` macro with automatic test DB setup | +| Isolation levels | `isolation-level.rs` | All five isolation level configurations | +| Advisory locks | `advisory-lock.rs` | Acquire, release, guard pattern, all lock modes | +| Bulk insert | `bulk-insert.rs` | High-performance loading, multi-row operations | +| Derives | `derives.rs` | `#[derive(FromRow)]`, custom field mappings | +| Query builder | `query_builder.rs` | Dynamic query construction, parameter handling | +| Error handling | `error.rs` | Database error inspection, error details | +| Compile-time macros | `tests/mssql-macros/` | Online and offline macro verification | + +--- + +## Feature Flags + +| Feature | Description | +|---------|-------------| +| `mssql` | Enable the MSSQL driver | +| `any` | Enable runtime-polymorphic `Any` driver | +| `migrate` | Enable database migrations | +| `json` | JSON type support via `serde_json` | +| `uuid` | `uuid::Uuid` type support | +| `chrono` | `chrono` datetime types | +| `time` | `time` crate datetime types | +| `rust_decimal` | `rust_decimal::Decimal` support | +| `bigdecimal` | `bigdecimal::BigDecimal` support | +| `winauth` | Windows/NTLM authentication | +| `integrated-auth-gssapi` | Integrated auth (Kerberos on Unix, SSPI on Windows) | +| `offline` | Offline mode for compile-time macros | From 87cc6959f256f326e5db7316936d66db4121f3f0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Sun, 22 Feb 2026 23:05:51 -0500 Subject: [PATCH 14/33] fix: escape SQL identifiers, add SMALLMONEY compat, and harden migration lock - Escape `]` and `'` in database/schema names interpolated into DDL in migrate.rs and testing/mod.rs to prevent SQL injection - Add SMALLMONEY to compatible() for rust_decimal and bigdecimal types - Wrap sp_getapplock in DECLARE/THROW to surface lock failures as SQL errors - Add compatible() overrides for DATE, TIME, and UNIQUEIDENTIFIER types - Fix swapped Real/Double null type mappings in Any arguments - Replace panicking expects with proper Error returns in executor - Add migration test harness and sample migrations for MSSQL Author: Pablo Carrera --- Cargo.toml | 5 ++ sqlx-core/src/any/arguments.rs | 4 +- sqlx-mssql/issues/mssql-sp-return-value.md | 47 +++++++++++ sqlx-mssql/src/connection/executor.rs | 18 ++-- sqlx-mssql/src/migrate.rs | 23 ++++-- sqlx-mssql/src/testing/mod.rs | 14 ++-- sqlx-mssql/src/types/bigdecimal.rs | 2 +- sqlx-mssql/src/types/chrono.rs | 8 ++ sqlx-mssql/src/types/rust_decimal.rs | 2 +- sqlx-mssql/src/types/time.rs | 8 ++ sqlx-mssql/src/types/uuid.rs | 8 ++ tests/mssql/migrate.rs | 82 +++++++++++++++++++ .../20220721124650_add_table.down.sql | 1 + .../20220721124650_add_table.up.sql | 7 ++ .../20220721125033_modify_column.down.sql | 2 + .../20220721125033_modify_column.up.sql | 2 + .../20220721115250_add_test_table.sql | 7 ++ .../20220721115524_convert_type.sql | 34 ++++++++ 18 files changed, 250 insertions(+), 24 deletions(-) create mode 100644 sqlx-mssql/issues/mssql-sp-return-value.md create mode 100644 tests/mssql/migrate.rs create mode 100644 tests/mssql/migrations_reversible/20220721124650_add_table.down.sql create mode 100644 tests/mssql/migrations_reversible/20220721124650_add_table.up.sql create mode 100644 tests/mssql/migrations_reversible/20220721125033_modify_column.down.sql create mode 100644 tests/mssql/migrations_reversible/20220721125033_modify_column.up.sql create mode 100644 tests/mssql/migrations_simple/20220721115250_add_test_table.sql create mode 100644 tests/mssql/migrations_simple/20220721115524_convert_type.sql diff --git a/Cargo.toml b/Cargo.toml index d2d66ed1e6..f663e04f09 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -518,3 +518,8 @@ required-features = ["mssql"] name = "mssql-bulk-insert" path = "tests/mssql/bulk-insert.rs" required-features = ["mssql"] + +[[test]] +name = "mssql-migrate" +path = "tests/mssql/migrate.rs" +required-features = ["mssql", "macros", "migrate"] diff --git a/sqlx-core/src/any/arguments.rs b/sqlx-core/src/any/arguments.rs index 59d6f4d6e0..abb7098072 100644 --- a/sqlx-core/src/any/arguments.rs +++ b/sqlx-core/src/any/arguments.rs @@ -67,8 +67,8 @@ impl AnyArguments { AnyValueKind::Null(AnyTypeInfoKind::SmallInt) => out.add(Option::::None), AnyValueKind::Null(AnyTypeInfoKind::Integer) => out.add(Option::::None), AnyValueKind::Null(AnyTypeInfoKind::BigInt) => out.add(Option::::None), - AnyValueKind::Null(AnyTypeInfoKind::Real) => out.add(Option::::None), - AnyValueKind::Null(AnyTypeInfoKind::Double) => out.add(Option::::None), + AnyValueKind::Null(AnyTypeInfoKind::Real) => out.add(Option::::None), + AnyValueKind::Null(AnyTypeInfoKind::Double) => out.add(Option::::None), AnyValueKind::Null(AnyTypeInfoKind::Text) => out.add(Option::::None), AnyValueKind::Null(AnyTypeInfoKind::Blob) => out.add(Option::>::None), AnyValueKind::Bool(b) => out.add(b), diff --git a/sqlx-mssql/issues/mssql-sp-return-value.md b/sqlx-mssql/issues/mssql-sp-return-value.md new file mode 100644 index 0000000000..3392247b8e --- /dev/null +++ b/sqlx-mssql/issues/mssql-sp-return-value.md @@ -0,0 +1,47 @@ +# MSSQL: surface stored procedure return values through the executor + +## Context + +`sp_getapplock` communicates success/failure through its **return value**, not through SQL errors: + +| Return code | Meaning | +|---|---| +| `0` | Lock granted immediately | +| `1` | Lock granted after waiting | +| `-1` | Timed out | +| `-2` | Cancelled | +| `-3` | Deadlock victim | +| `-999` | Parameter validation error | + +## Current workaround + +We wrap the call in a `DECLARE @r / IF @r < 0 THROW` pattern so that a failed lock becomes a SQL error that `execute` can catch: + +```sql +DECLARE @r INT; +EXEC @r = sp_getapplock @Resource = 'sqlx_migrations', + @LockMode = 'Exclusive', @LockOwner = 'Session', @LockTimeout = -1; +IF @r < 0 THROW 50000, 'Failed to acquire migration lock', 1; +``` + +This is sufficient for production use — the lock works correctly in all realistic scenarios, and failures are now surfaced as errors instead of being silently ignored. + +## Ideal long-term solution + +The proper fix is for the MSSQL executor to capture the TDS `RETURNSTATUS` token that SQL Server sends after stored procedure execution, and expose it through the driver's result types. + +### What would need to change + +1. **`collect_results` in `executor.rs`** — currently only handles `QueryItem::Metadata` and `QueryItem::Row`. The TDS return status token is not surfaced by tiberius's `QueryStream`. Investigate whether tiberius exposes this via `ExecuteResult` (from `.execute()`) or if it requires upstream changes. + +2. **`MssqlQueryResult`** — currently only holds `rows_affected: u64`. Would need an additional field like `return_status: Option` to carry the stored procedure return value. + +3. **`Migrate::lock` trait** — the signature is `Result<(), MigrateError>`, which is fine (we'd just check the return status and map negatives to `Err`). No trait changes needed. + +### Why this matters beyond migrations + +Any user calling stored procedures via `execute` today cannot inspect return values. This is a general limitation of the MSSQL driver, not specific to migrations. The `THROW` workaround only works when you control the SQL — it doesn't help when calling third-party procedures that use return codes for flow control. + +## Priority + +**Low** — the THROW workaround fully covers the migration lock case, and stored procedure return values are a niche use case. This is a correctness/completeness improvement, not a bug fix. diff --git a/sqlx-mssql/src/connection/executor.rs b/sqlx-mssql/src/connection/executor.rs index 55fdb51587..7322feeaec 100644 --- a/sqlx-mssql/src/connection/executor.rs +++ b/sqlx-mssql/src/connection/executor.rs @@ -219,10 +219,12 @@ impl MssqlConnection { // Convert BigDecimal to tiberius Numeric let (bigint, exponent) = v.as_bigint_and_exponent(); let scale = exponent.max(0) as u8; - // Convert to i128 for Numeric — panics if too large - let value: i128 = bigint - .to_i128() - .expect("BigDecimal value too large for SQL NUMERIC"); + // Convert to i128 for Numeric + let value: i128 = bigint.to_i128().ok_or_else(|| { + Error::Encode( + format!("BigDecimal value too large for SQL NUMERIC: {v}").into(), + ) + })?; let cd = tiberius::ColumnData::Numeric(Some( tiberius::numeric::Numeric::new_with_scale(value, scale), )); @@ -290,8 +292,12 @@ async fn collect_results<'a>( column_names = Some(Arc::new(names)); } tiberius::QueryItem::Row(row) => { - let cols = columns.as_ref().expect("row received before metadata"); - let names = column_names.as_ref().expect("row received before metadata"); + let cols = columns.as_ref().ok_or_else(|| { + Error::Protocol("row received before metadata".into()) + })?; + let names = column_names.as_ref().ok_or_else(|| { + Error::Protocol("row received before metadata".into()) + })?; // Convert tiberius row to MssqlRow by iterating over cells let values: Vec = row diff --git a/sqlx-mssql/src/migrate.rs b/sqlx-mssql/src/migrate.rs index c7836f57cf..d593fe0038 100644 --- a/sqlx-mssql/src/migrate.rs +++ b/sqlx-mssql/src/migrate.rs @@ -36,9 +36,10 @@ impl MigrateDatabase for Mssql { let (options, database) = parse_for_maintenance(url)?; let mut conn = options.connect().await?; + let escaped = database.replace(']', "]]"); let _ = conn .execute(AssertSqlSafe(format!( - "CREATE DATABASE [{database}]" + "CREATE DATABASE [{escaped}]" ))) .await?; @@ -64,12 +65,13 @@ impl MigrateDatabase for Mssql { let mut conn = options.connect().await?; // Force close existing connections before dropping + let escaped = database.replace('\'', "''").replace(']', "]]"); let _ = conn .execute(AssertSqlSafe(format!( - "IF DB_ID('{database}') IS NOT NULL \ + "IF DB_ID('{escaped}') IS NOT NULL \ BEGIN \ - ALTER DATABASE [{database}] SET SINGLE_USER WITH ROLLBACK IMMEDIATE; \ - DROP DATABASE [{database}]; \ + ALTER DATABASE [{escaped}] SET SINGLE_USER WITH ROLLBACK IMMEDIATE; \ + DROP DATABASE [{escaped}]; \ END" ))) .await?; @@ -84,9 +86,10 @@ impl Migrate for MssqlConnection { schema_name: &'e str, ) -> BoxFuture<'e, Result<(), MigrateError>> { Box::pin(async move { + let escaped = schema_name.replace('\'', "''").replace(']', "]]"); self.execute(AssertSqlSafe(format!( - r#"IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '{schema_name}') - EXEC('CREATE SCHEMA [{schema_name}]')"# + r#"IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '{escaped}') + EXEC('CREATE SCHEMA [{escaped}]')"# ))) .await?; @@ -158,10 +161,14 @@ CREATE TABLE {table_name} ( fn lock(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> { Box::pin(async move { - // Use sp_getapplock for advisory locking in MSSQL + // sp_getapplock returns a status code (0/1 = success, negative = failure) + // but `execute` only surfaces SQL errors, not return values. + // We use THROW to convert a failed lock acquisition into a SQL error. let _ = self .execute( - "EXEC sp_getapplock @Resource = 'sqlx_migrations', @LockMode = 'Exclusive', @LockOwner = 'Session', @LockTimeout = -1" + "DECLARE @r INT; \ + EXEC @r = sp_getapplock @Resource = 'sqlx_migrations', @LockMode = 'Exclusive', @LockOwner = 'Session', @LockTimeout = -1; \ + IF @r < 0 THROW 50000, 'Failed to acquire migration lock', 1;" ) .await?; diff --git a/sqlx-mssql/src/testing/mod.rs b/sqlx-mssql/src/testing/mod.rs index 619728637c..05216e6fcd 100644 --- a/sqlx-mssql/src/testing/mod.rs +++ b/sqlx-mssql/src/testing/mod.rs @@ -52,11 +52,12 @@ impl TestSupport for Mssql { let mut deleted_count = 0usize; for db_name in &delete_db_names { + let escaped = db_name.replace('\'', "''").replace(']', "]]"); let drop_sql = format!( - "IF DB_ID('{db_name}') IS NOT NULL \ + "IF DB_ID('{escaped}') IS NOT NULL \ BEGIN \ - ALTER DATABASE [{db_name}] SET SINGLE_USER WITH ROLLBACK IMMEDIATE; \ - DROP DATABASE [{db_name}]; \ + ALTER DATABASE [{escaped}] SET SINGLE_USER WITH ROLLBACK IMMEDIATE; \ + DROP DATABASE [{escaped}]; \ END" ); @@ -168,11 +169,12 @@ async fn test_context(args: &TestArgs) -> Result, Error> { } async fn do_cleanup(conn: &mut MssqlConnection, db_name: &str) -> Result<(), Error> { + let escaped = db_name.replace('\'', "''").replace(']', "]]"); let drop_sql = format!( - "IF DB_ID('{db_name}') IS NOT NULL \ + "IF DB_ID('{escaped}') IS NOT NULL \ BEGIN \ - ALTER DATABASE [{db_name}] SET SINGLE_USER WITH ROLLBACK IMMEDIATE; \ - DROP DATABASE [{db_name}]; \ + ALTER DATABASE [{escaped}] SET SINGLE_USER WITH ROLLBACK IMMEDIATE; \ + DROP DATABASE [{escaped}]; \ END" ); conn.execute(AssertSqlSafe(drop_sql)).await?; diff --git a/sqlx-mssql/src/types/bigdecimal.rs b/sqlx-mssql/src/types/bigdecimal.rs index 175c9dbed2..b2fd93c655 100644 --- a/sqlx-mssql/src/types/bigdecimal.rs +++ b/sqlx-mssql/src/types/bigdecimal.rs @@ -14,7 +14,7 @@ impl Type for BigDecimal { } fn compatible(ty: &MssqlTypeInfo) -> bool { - matches!(ty.base_name(), "DECIMAL" | "NUMERIC" | "MONEY") + matches!(ty.base_name(), "DECIMAL" | "NUMERIC" | "MONEY" | "SMALLMONEY") } } diff --git a/sqlx-mssql/src/types/chrono.rs b/sqlx-mssql/src/types/chrono.rs index 6a849622d3..93932de833 100644 --- a/sqlx-mssql/src/types/chrono.rs +++ b/sqlx-mssql/src/types/chrono.rs @@ -50,6 +50,10 @@ impl Type for NaiveDate { fn type_info() -> MssqlTypeInfo { MssqlTypeInfo::new("DATE") } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + ty.base_name() == "DATE" + } } impl Encode<'_, Mssql> for NaiveDate { @@ -80,6 +84,10 @@ impl Type for NaiveTime { fn type_info() -> MssqlTypeInfo { MssqlTypeInfo::new("TIME") } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + ty.base_name() == "TIME" + } } impl Encode<'_, Mssql> for NaiveTime { diff --git a/sqlx-mssql/src/types/rust_decimal.rs b/sqlx-mssql/src/types/rust_decimal.rs index 5951f93f7e..e71d96aae1 100644 --- a/sqlx-mssql/src/types/rust_decimal.rs +++ b/sqlx-mssql/src/types/rust_decimal.rs @@ -14,7 +14,7 @@ impl Type for Decimal { } fn compatible(ty: &MssqlTypeInfo) -> bool { - matches!(ty.base_name(), "DECIMAL" | "NUMERIC" | "MONEY") + matches!(ty.base_name(), "DECIMAL" | "NUMERIC" | "MONEY" | "SMALLMONEY") } } diff --git a/sqlx-mssql/src/types/time.rs b/sqlx-mssql/src/types/time.rs index 0225fe1e66..ad420b3d87 100644 --- a/sqlx-mssql/src/types/time.rs +++ b/sqlx-mssql/src/types/time.rs @@ -14,6 +14,10 @@ impl Type for Date { fn type_info() -> MssqlTypeInfo { MssqlTypeInfo::new("DATE") } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + ty.base_name() == "DATE" + } } impl Encode<'_, Mssql> for Date { @@ -43,6 +47,10 @@ impl Type for Time { fn type_info() -> MssqlTypeInfo { MssqlTypeInfo::new("TIME") } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + ty.base_name() == "TIME" + } } impl Encode<'_, Mssql> for Time { diff --git a/sqlx-mssql/src/types/uuid.rs b/sqlx-mssql/src/types/uuid.rs index f36d259590..b06f50898a 100644 --- a/sqlx-mssql/src/types/uuid.rs +++ b/sqlx-mssql/src/types/uuid.rs @@ -12,6 +12,10 @@ impl Type for Uuid { fn type_info() -> MssqlTypeInfo { MssqlTypeInfo::new("UNIQUEIDENTIFIER") } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + ty.base_name() == "UNIQUEIDENTIFIER" + } } impl Encode<'_, Mssql> for Uuid { @@ -39,6 +43,10 @@ impl Type for uuid::fmt::Hyphenated { fn type_info() -> MssqlTypeInfo { MssqlTypeInfo::new("UNIQUEIDENTIFIER") } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + ty.base_name() == "UNIQUEIDENTIFIER" + } } impl Encode<'_, Mssql> for uuid::fmt::Hyphenated { diff --git a/tests/mssql/migrate.rs b/tests/mssql/migrate.rs new file mode 100644 index 0000000000..63ba618151 --- /dev/null +++ b/tests/mssql/migrate.rs @@ -0,0 +1,82 @@ +use sqlx::migrate::Migrator; +use sqlx::mssql::{Mssql, MssqlConnection}; +use sqlx::pool::PoolConnection; +use sqlx::Executor; +use sqlx::Row; +use std::path::Path; + +#[sqlx::test(migrations = false)] +async fn simple(mut conn: PoolConnection) -> anyhow::Result<()> { + clean_up(&mut conn).await?; + + let migrator = Migrator::new(Path::new("tests/mssql/migrations_simple")).await?; + + // run migration + migrator.run(&mut conn).await?; + + // check outcome + let res: String = conn + .fetch_one("SELECT some_payload FROM migrations_simple_test") + .await? + .get(0); + assert_eq!(res, "110_suffix"); + + // running it a 2nd time should still work + migrator.run(&mut conn).await?; + + Ok(()) +} + +#[sqlx::test(migrations = false)] +async fn reversible(mut conn: PoolConnection) -> anyhow::Result<()> { + clean_up(&mut conn).await?; + + let migrator = Migrator::new(Path::new("tests/mssql/migrations_reversible")).await?; + + // run migration + migrator.run(&mut conn).await?; + + // check outcome + let res: i64 = conn + .fetch_one("SELECT some_payload FROM migrations_reversible_test") + .await? + .get(0); + assert_eq!(res, 101); + + // roll back nothing (last version) + migrator.undo(&mut conn, 20220721125033).await?; + + // check outcome + let res: i64 = conn + .fetch_one("SELECT some_payload FROM migrations_reversible_test") + .await? + .get(0); + assert_eq!(res, 101); + + // roll back one version + migrator.undo(&mut conn, 20220721124650).await?; + + // check outcome + let res: i64 = conn + .fetch_one("SELECT some_payload FROM migrations_reversible_test") + .await? + .get(0); + assert_eq!(res, 100); + + Ok(()) +} + +/// Ensure that we have a clean initial state. +async fn clean_up(conn: &mut MssqlConnection) -> anyhow::Result<()> { + conn.execute("IF OBJECT_ID('migrations_simple_test', 'U') IS NOT NULL DROP TABLE migrations_simple_test") + .await + .ok(); + conn.execute("IF OBJECT_ID('migrations_reversible_test', 'U') IS NOT NULL DROP TABLE migrations_reversible_test") + .await + .ok(); + conn.execute("IF OBJECT_ID('_sqlx_migrations', 'U') IS NOT NULL DROP TABLE _sqlx_migrations") + .await + .ok(); + + Ok(()) +} diff --git a/tests/mssql/migrations_reversible/20220721124650_add_table.down.sql b/tests/mssql/migrations_reversible/20220721124650_add_table.down.sql new file mode 100644 index 0000000000..5505859725 --- /dev/null +++ b/tests/mssql/migrations_reversible/20220721124650_add_table.down.sql @@ -0,0 +1 @@ +DROP TABLE migrations_reversible_test; diff --git a/tests/mssql/migrations_reversible/20220721124650_add_table.up.sql b/tests/mssql/migrations_reversible/20220721124650_add_table.up.sql new file mode 100644 index 0000000000..9dfc757954 --- /dev/null +++ b/tests/mssql/migrations_reversible/20220721124650_add_table.up.sql @@ -0,0 +1,7 @@ +CREATE TABLE migrations_reversible_test ( + some_id BIGINT NOT NULL PRIMARY KEY, + some_payload BIGINT NOT NUll +); + +INSERT INTO migrations_reversible_test (some_id, some_payload) +VALUES (1, 100); diff --git a/tests/mssql/migrations_reversible/20220721125033_modify_column.down.sql b/tests/mssql/migrations_reversible/20220721125033_modify_column.down.sql new file mode 100644 index 0000000000..3f71737b8c --- /dev/null +++ b/tests/mssql/migrations_reversible/20220721125033_modify_column.down.sql @@ -0,0 +1,2 @@ +UPDATE migrations_reversible_test +SET some_payload = some_payload - 1; diff --git a/tests/mssql/migrations_reversible/20220721125033_modify_column.up.sql b/tests/mssql/migrations_reversible/20220721125033_modify_column.up.sql new file mode 100644 index 0000000000..bbb176cf41 --- /dev/null +++ b/tests/mssql/migrations_reversible/20220721125033_modify_column.up.sql @@ -0,0 +1,2 @@ +UPDATE migrations_reversible_test +SET some_payload = some_payload + 1; diff --git a/tests/mssql/migrations_simple/20220721115250_add_test_table.sql b/tests/mssql/migrations_simple/20220721115250_add_test_table.sql new file mode 100644 index 0000000000..d5ba291914 --- /dev/null +++ b/tests/mssql/migrations_simple/20220721115250_add_test_table.sql @@ -0,0 +1,7 @@ +CREATE TABLE migrations_simple_test ( + some_id BIGINT NOT NULL PRIMARY KEY, + some_payload BIGINT NOT NUll +); + +INSERT INTO migrations_simple_test (some_id, some_payload) +VALUES (1, 100); diff --git a/tests/mssql/migrations_simple/20220721115524_convert_type.sql b/tests/mssql/migrations_simple/20220721115524_convert_type.sql new file mode 100644 index 0000000000..c437c39d02 --- /dev/null +++ b/tests/mssql/migrations_simple/20220721115524_convert_type.sql @@ -0,0 +1,34 @@ +-- Perform a tricky conversion of the payload. +-- +-- This script will only succeed once and will fail if executed twice. + +-- set up temporary target column +ALTER TABLE migrations_simple_test +ADD some_payload_tmp NVARCHAR(MAX); + +-- perform conversion +-- This will fail if `some_payload` is already a string column due to the addition. +-- We add a suffix after the addition to ensure that the SQL database does not silently cast the string back to an +-- integer. +UPDATE migrations_simple_test +SET some_payload_tmp = CONCAT(CAST((some_payload + 10) AS VARCHAR(3)), '_suffix'); + +-- remove original column including the content +ALTER TABLE migrations_simple_test +DROP COLUMN some_payload; + +-- prepare new payload column (nullable, so we can copy over the data) +ALTER TABLE migrations_simple_test +ADD some_payload NVARCHAR(MAX); + +-- copy new values +UPDATE migrations_simple_test +SET some_payload = some_payload_tmp; + +-- "freeze" column: MSSQL uses sp_rename + re-add or ALTER COLUMN for NOT NULL +ALTER TABLE migrations_simple_test +ALTER COLUMN some_payload NVARCHAR(MAX) NOT NULL; + +-- clean up +ALTER TABLE migrations_simple_test +DROP COLUMN some_payload_tmp; From a75870773b968871bb60990b2c00079c8786c4d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Tue, 3 Mar 2026 18:35:37 -0500 Subject: [PATCH 15/33] fix: add explicit parentheses for operator precedence in nanosecond calculation Closes #1 Author: Pablo Carrera --- sqlx-mssql/src/connection/executor.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlx-mssql/src/connection/executor.rs b/sqlx-mssql/src/connection/executor.rs index 7322feeaec..6e0087e092 100644 --- a/sqlx-mssql/src/connection/executor.rs +++ b/sqlx-mssql/src/connection/executor.rs @@ -110,7 +110,7 @@ impl MssqlConnection { let days = (naive.date() - epoch).num_days() as u32; let total_ns = naive.time().num_seconds_from_midnight() as u64 * 1_000_000_000 - + naive.time().nanosecond() as u64 % 1_000_000_000; + + (naive.time().nanosecond() as u64 % 1_000_000_000); let increments = total_ns / 100; let offset_minutes = v.offset().local_minus_utc() / 60; From 9e85668e5347c3bf28ab29cbd16685dbb2b10de1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Tue, 3 Mar 2026 18:56:38 -0500 Subject: [PATCH 16/33] refactor: extract time binding and document infallible unwrap Author: Pablo Carrera --- sqlx-mssql/src/connection/executor.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sqlx-mssql/src/connection/executor.rs b/sqlx-mssql/src/connection/executor.rs index 6e0087e092..ca9a55c18b 100644 --- a/sqlx-mssql/src/connection/executor.rs +++ b/sqlx-mssql/src/connection/executor.rs @@ -105,12 +105,14 @@ impl MssqlConnection { #[cfg(feature = "chrono")] MssqlArgumentValue::DateTimeFixedOffset(v) => { use chrono::Timelike as _; + // Year 1 is always a valid date let epoch = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap(); let naive = v.naive_local(); let days = (naive.date() - epoch).num_days() as u32; - let total_ns = naive.time().num_seconds_from_midnight() as u64 + let time = naive.time(); + let total_ns = time.num_seconds_from_midnight() as u64 * 1_000_000_000 - + (naive.time().nanosecond() as u64 % 1_000_000_000); + + (time.nanosecond() as u64 % 1_000_000_000); let increments = total_ns / 100; let offset_minutes = v.offset().local_minus_utc() / 60; From a904b91a50fce75c92b1e642a0005e7b528c1f34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Tue, 3 Mar 2026 19:26:20 -0500 Subject: [PATCH 17/33] =?UTF-8?q?fix:=20replace=20unchecked=20i64=E2=86=92?= =?UTF-8?q?u32=20casts=20in=20date=20encoding=20with=20validated=20convers?= =?UTF-8?q?ion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Dates before the TDS epoch (0001-01-01) produced negative day counts that silently wrapped via `as u32`, causing corrupt data or a panic inside `tiberius::time::Date::new()`. Add `days_since_epoch_to_u32()` helper that returns `Error::Encode` for negative or out-of-range values, and replace all four unsafe cast sites. Closes pabl-o-ce/sqlx#2 Author: Pablo Carrera --- sqlx-mssql/src/connection/executor.rs | 65 +++++++++++++++++++++++++-- 1 file changed, 61 insertions(+), 4 deletions(-) diff --git a/sqlx-mssql/src/connection/executor.rs b/sqlx-mssql/src/connection/executor.rs index ca9a55c18b..f04599fcb8 100644 --- a/sqlx-mssql/src/connection/executor.rs +++ b/sqlx-mssql/src/connection/executor.rs @@ -34,6 +34,29 @@ impl<'a> tiberius::IntoSql<'a> for ColumnDataWrapper<'a> { } } +/// Maximum days-since-epoch (0001-01-01) that fits in the 3-byte TDS date +/// encoding. `tiberius::time::Date::new()` panics if `days > 0x00FFFFFF`. +#[cfg(any(feature = "chrono", feature = "time"))] +const MAX_DAYS: u32 = 0x00FF_FFFF; + +/// Convert a signed days-since-epoch count to `u32`, returning +/// `Error::Encode` if negative or exceeding the TDS 3-byte limit. +#[cfg(any(feature = "chrono", feature = "time"))] +fn days_since_epoch_to_u32(days: i64) -> Result { + u32::try_from(days) + .ok() + .filter(|&d| d <= MAX_DAYS) + .ok_or_else(|| { + Error::Encode( + format!( + "date out of range for SQL Server: {days} days since epoch \ + (must be 0..={MAX_DAYS})" + ) + .into(), + ) + }) +} + impl MssqlConnection { /// Execute a query, eagerly collecting all results. /// @@ -108,7 +131,7 @@ impl MssqlConnection { // Year 1 is always a valid date let epoch = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap(); let naive = v.naive_local(); - let days = (naive.date() - epoch).num_days() as u32; + let days = days_since_epoch_to_u32((naive.date() - epoch).num_days())?; let time = naive.time(); let total_ns = time.num_seconds_from_midnight() as u64 * 1_000_000_000 @@ -150,7 +173,7 @@ impl MssqlConnection { #[cfg(feature = "time")] MssqlArgumentValue::TimeDate(v) => { let epoch = time::Date::from_ordinal_date(1, 1).unwrap(); - let days = (*v - epoch).whole_days() as u32; + let days = days_since_epoch_to_u32((*v - epoch).whole_days())?; let cd = tiberius::ColumnData::Date(Some( tiberius::time::Date::new(days), )); @@ -175,7 +198,7 @@ impl MssqlConnection { let date = v.date(); let time = v.time(); let epoch = time::Date::from_ordinal_date(1, 1).unwrap(); - let days = (date - epoch).whole_days() as u32; + let days = days_since_epoch_to_u32((date - epoch).whole_days())?; let (h, m, s, ns) = time.as_hms_nano(); let total_ns = h as u64 * 3_600_000_000_000 + m as u64 * 60_000_000_000 @@ -196,7 +219,7 @@ impl MssqlConnection { let offset_minutes = v.offset().whole_seconds() / 60; let date = v.date(); let time = v.time(); - let days = (date - epoch).whole_days() as u32; + let days = days_since_epoch_to_u32((date - epoch).whole_days())?; let (h, m, s, ns) = time.as_hms_nano(); let total_ns = h as u64 * 3_600_000_000_000 + m as u64 * 60_000_000_000 @@ -540,3 +563,37 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection { }) } } + +#[cfg(test)] +#[cfg(any(feature = "chrono", feature = "time"))] +mod tests { + use super::*; + + #[test] + fn days_since_epoch_zero() { + assert_eq!(days_since_epoch_to_u32(0).unwrap(), 0); + } + + #[test] + fn days_since_epoch_max_date() { + // 9999-12-31 is 3_652_058 days from 0001-01-01 + assert_eq!(days_since_epoch_to_u32(3_652_058).unwrap(), 3_652_058); + } + + #[test] + fn days_since_epoch_negative() { + let err = days_since_epoch_to_u32(-1).unwrap_err(); + assert!(matches!(err, Error::Encode(_))); + } + + #[test] + fn days_since_epoch_overflow() { + let err = days_since_epoch_to_u32(i64::from(MAX_DAYS) + 1).unwrap_err(); + assert!(matches!(err, Error::Encode(_))); + } + + #[test] + fn days_since_epoch_at_max() { + assert_eq!(days_since_epoch_to_u32(i64::from(MAX_DAYS)).unwrap(), MAX_DAYS); + } +} From 2f816bc60f3a503a940820c0da7435e415147e27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Tue, 3 Mar 2026 19:53:13 -0500 Subject: [PATCH 18/33] =?UTF-8?q?fix:=20replace=20unchecked=20i32=E2=86=92?= =?UTF-8?q?i16=20cast=20in=20timezone=20offset=20encoding=20with=20validat?= =?UTF-8?q?ed=20conversion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds offset_minutes_to_i16 helper that validates the offset fits within SQL Server's -840..=840 minute range, returning Error::Encode instead of silently truncating. Follows the same pattern as days_since_epoch_to_u32. Author: Pablo Carrera --- sqlx-mssql/src/connection/executor.rs | 53 ++++++++++++++++++++++++++- 1 file changed, 51 insertions(+), 2 deletions(-) diff --git a/sqlx-mssql/src/connection/executor.rs b/sqlx-mssql/src/connection/executor.rs index f04599fcb8..b1946083f1 100644 --- a/sqlx-mssql/src/connection/executor.rs +++ b/sqlx-mssql/src/connection/executor.rs @@ -57,6 +57,26 @@ fn days_since_epoch_to_u32(days: i64) -> Result { }) } +/// Convert a signed offset-in-minutes to `i16`, returning +/// `Error::Encode` if outside the SQL Server range (-840..=840). +#[cfg(any(feature = "chrono", feature = "time"))] +fn offset_minutes_to_i16(offset_minutes: i32) -> Result { + const MIN_OFFSET: i32 = -840; + const MAX_OFFSET: i32 = 840; + if (MIN_OFFSET..=MAX_OFFSET).contains(&offset_minutes) { + // -840..=840 fits in i16, so this cast is infallible. + Ok(offset_minutes as i16) + } else { + Err(Error::Encode( + format!( + "timezone offset out of range for SQL Server: {offset_minutes} minutes \ + (must be {MIN_OFFSET}..={MAX_OFFSET})" + ) + .into(), + )) + } +} + impl MssqlConnection { /// Execute a query, eagerly collecting all results. /// @@ -146,7 +166,7 @@ impl MssqlConnection { let cd = tiberius::ColumnData::DateTimeOffset(Some( tiberius::time::DateTimeOffset::new( dt2, - offset_minutes as i16, + offset_minutes_to_i16(offset_minutes)?, ), )); query.bind(ColumnDataWrapper(cd)); @@ -233,7 +253,7 @@ impl MssqlConnection { let cd = tiberius::ColumnData::DateTimeOffset(Some( tiberius::time::DateTimeOffset::new( dt2, - offset_minutes as i16, + offset_minutes_to_i16(offset_minutes)?, ), )); query.bind(ColumnDataWrapper(cd)); @@ -596,4 +616,33 @@ mod tests { fn days_since_epoch_at_max() { assert_eq!(days_since_epoch_to_u32(i64::from(MAX_DAYS)).unwrap(), MAX_DAYS); } + + #[test] + fn offset_minutes_zero() { + assert_eq!(offset_minutes_to_i16(0).unwrap(), 0); + } + + #[test] + fn offset_minutes_positive_max() { + assert_eq!(offset_minutes_to_i16(840).unwrap(), 840); + } + + #[test] + fn offset_minutes_negative_max() { + assert_eq!(offset_minutes_to_i16(-840).unwrap(), -840); + } + + #[test] + fn offset_minutes_out_of_sql_range() { + let err = offset_minutes_to_i16(841).unwrap_err(); + assert!(matches!(err, Error::Encode(_))); + let err = offset_minutes_to_i16(-841).unwrap_err(); + assert!(matches!(err, Error::Encode(_))); + } + + #[test] + fn offset_minutes_i16_overflow() { + let err = offset_minutes_to_i16(i32::MAX).unwrap_err(); + assert!(matches!(err, Error::Encode(_))); + } } From adace6d0997fc5f4de2b60d1d5e87a7fb92fd6e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Tue, 3 Mar 2026 22:40:01 -0500 Subject: [PATCH 19/33] fix: escape table_name in MSSQL migration SQL to prevent injection Apply bracket-quoting for identifier contexts and single-quote escaping for string literal contexts across all 6 interpolation sites, consistent with existing escaping in create_database, drop_database, and create_schema_if_not_exists. Closes #4 Author: Pablo Carrera --- sqlx-mssql/src/migrate.rs | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/sqlx-mssql/src/migrate.rs b/sqlx-mssql/src/migrate.rs index d593fe0038..0e334f1ed4 100644 --- a/sqlx-mssql/src/migrate.rs +++ b/sqlx-mssql/src/migrate.rs @@ -14,6 +14,11 @@ use crate::query_as::query_as; use crate::query_scalar::query_scalar; use crate::{Mssql, MssqlConnectOptions, MssqlConnection}; +/// Escape a table name for safe use as an MSSQL bracket-quoted identifier (`[...]`). +fn escape_table_name(table_name: &str) -> String { + format!("[{}]", table_name.replace(']', "]]")) +} + fn parse_for_maintenance(url: &str) -> Result<(MssqlConnectOptions, String), Error> { let mut options = MssqlConnectOptions::from_str(url)?; @@ -102,10 +107,12 @@ impl Migrate for MssqlConnection { table_name: &'e str, ) -> BoxFuture<'e, Result<(), MigrateError>> { Box::pin(async move { + let lit = table_name.replace('\'', "''"); + let ident = escape_table_name(table_name); self.execute(AssertSqlSafe(format!( r#" -IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = '{table_name}') -CREATE TABLE {table_name} ( +IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = '{lit}') +CREATE TABLE {ident} ( version BIGINT PRIMARY KEY, description NVARCHAR(MAX) NOT NULL, installed_on DATETIME2 NOT NULL DEFAULT SYSUTCDATETIME(), @@ -126,8 +133,9 @@ CREATE TABLE {table_name} ( table_name: &'e str, ) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { + let ident = escape_table_name(table_name); let row: Option<(i64,)> = query_as(AssertSqlSafe(format!( - "SELECT TOP 1 version FROM {table_name} WHERE success = 0 ORDER BY version" + "SELECT TOP 1 version FROM {ident} WHERE success = 0 ORDER BY version" ))) .fetch_optional(self) .await?; @@ -141,8 +149,9 @@ CREATE TABLE {table_name} ( table_name: &'e str, ) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { + let ident = escape_table_name(table_name); let rows: Vec<(i64, Vec)> = query_as(AssertSqlSafe(format!( - "SELECT version, checksum FROM {table_name} ORDER BY version" + "SELECT version, checksum FROM {ident} ORDER BY version" ))) .fetch_all(self) .await?; @@ -212,10 +221,12 @@ CREATE TABLE {table_name} ( // might be lost. We accept this small risk since this value is not super important. let elapsed = start.elapsed(); + let ident = escape_table_name(table_name); + #[allow(clippy::cast_possible_truncation)] let _ = query(AssertSqlSafe(format!( r#" - UPDATE {table_name} + UPDATE {ident} SET execution_time = @p1 WHERE version = @p2 "# @@ -262,9 +273,10 @@ async fn execute_migration( .await .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; + let ident = escape_table_name(table_name); let _ = query(AssertSqlSafe(format!( r#" - INSERT INTO {table_name} ( version, description, success, checksum, execution_time ) + INSERT INTO {ident} ( version, description, success, checksum, execution_time ) VALUES ( @p1, @p2, 1, @p3, -1 ) "# ))) @@ -287,8 +299,9 @@ async fn revert_migration( .await .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; + let ident = escape_table_name(table_name); let _ = query(AssertSqlSafe(format!( - r#"DELETE FROM {table_name} WHERE version = @p1"# + r#"DELETE FROM {ident} WHERE version = @p1"# ))) .bind(migration.version) .execute(conn) From 597166bd4753ab7e1e6d77265e7347a3f3142c92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Tue, 3 Mar 2026 22:57:04 -0500 Subject: [PATCH 20/33] fix: use context-specific escaping in drop_database and create_schema_if_not_exists Split the single `escaped` variable into separate variables for each SQL quoting context, preventing cross-contamination between bracket-quoted identifiers ([...]) and single-quoted string literals ('...'). Closes #5 Author: Pablo Carrera --- sqlx-mssql/src/migrate.rs | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/sqlx-mssql/src/migrate.rs b/sqlx-mssql/src/migrate.rs index 0e334f1ed4..b71c9fcac5 100644 --- a/sqlx-mssql/src/migrate.rs +++ b/sqlx-mssql/src/migrate.rs @@ -70,13 +70,16 @@ impl MigrateDatabase for Mssql { let mut conn = options.connect().await?; // Force close existing connections before dropping - let escaped = database.replace('\'', "''").replace(']', "]]"); + // Use separate escaped values for the two different quoting contexts: + // bracket-quoted identifiers ([...]) vs single-quoted string literals ('...') + let bracket_escaped = database.replace(']', "]]"); + let quote_escaped = database.replace('\'', "''"); let _ = conn .execute(AssertSqlSafe(format!( - "IF DB_ID('{escaped}') IS NOT NULL \ + "IF DB_ID('{quote_escaped}') IS NOT NULL \ BEGIN \ - ALTER DATABASE [{escaped}] SET SINGLE_USER WITH ROLLBACK IMMEDIATE; \ - DROP DATABASE [{escaped}]; \ + ALTER DATABASE [{bracket_escaped}] SET SINGLE_USER WITH ROLLBACK IMMEDIATE; \ + DROP DATABASE [{bracket_escaped}]; \ END" ))) .await?; @@ -91,10 +94,13 @@ impl Migrate for MssqlConnection { schema_name: &'e str, ) -> BoxFuture<'e, Result<(), MigrateError>> { Box::pin(async move { - let escaped = schema_name.replace('\'', "''").replace(']', "]]"); + let quote_escaped = schema_name.replace('\'', "''"); + // Inside EXEC('...'), the identifier must be bracket-escaped AND + // single-quote-escaped (since it's nested inside a string literal). + let exec_escaped = schema_name.replace(']', "]]").replace('\'', "''"); self.execute(AssertSqlSafe(format!( - r#"IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '{escaped}') - EXEC('CREATE SCHEMA [{escaped}]')"# + r#"IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '{quote_escaped}') + EXEC('CREATE SCHEMA [{exec_escaped}]')"# ))) .await?; From 7b1d04da183f34fa132333f2e35599bd4df70c1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Tue, 3 Mar 2026 23:25:52 -0500 Subject: [PATCH 21/33] fix: replace panicking unwrap/expect calls in column_data_to_mssql_data with Result propagation Return Result instead of panicking on invalid data from tiberius. Fixes and_local_timezone().unwrap() panic on ambiguous/invalid DateTimeOffset, and guards against silent u8 truncation in time_from_sec_fragments with an upfront bounds check. All helper functions (chrono_date_from_days, time_date_from_days, time_from_sec_fragments) now return Result and use checked arithmetic. Author: Pablo Carrera --- sqlx-mssql/src/connection/executor.rs | 2 +- sqlx-mssql/src/value.rs | 167 +++++++++++++++++--------- 2 files changed, 114 insertions(+), 55 deletions(-) diff --git a/sqlx-mssql/src/connection/executor.rs b/sqlx-mssql/src/connection/executor.rs index b1946083f1..f11253953e 100644 --- a/sqlx-mssql/src/connection/executor.rs +++ b/sqlx-mssql/src/connection/executor.rs @@ -348,7 +348,7 @@ async fn collect_results<'a>( let values: Vec = row .into_iter() .map(|data| column_data_to_mssql_data(&data)) - .collect(); + .collect::, _>>()?; rows_affected += 1; logger.increment_rows_returned(); diff --git a/sqlx-mssql/src/value.rs b/sqlx-mssql/src/value.rs index 2fffd1f647..64dc9df11e 100644 --- a/sqlx-mssql/src/value.rs +++ b/sqlx-mssql/src/value.rs @@ -2,7 +2,7 @@ use std::borrow::Cow; pub(crate) use sqlx_core::value::*; -use crate::error::BoxDynError; +use crate::error::{BoxDynError, Error}; use crate::{Mssql, MssqlTypeInfo}; /// Internal storage for an MSSQL value, decoupled from tiberius lifetimes. @@ -114,161 +114,220 @@ impl<'r> ValueRef<'r> for MssqlValueRef<'r> { } /// Convert a `tiberius::ColumnData` into our owned `MssqlData`. -pub(crate) fn column_data_to_mssql_data(data: &tiberius::ColumnData<'_>) -> MssqlData { +pub(crate) fn column_data_to_mssql_data( + data: &tiberius::ColumnData<'_>, +) -> Result { match data { - tiberius::ColumnData::U8(Some(v)) => MssqlData::U8(*v), - tiberius::ColumnData::I16(Some(v)) => MssqlData::I16(*v), - tiberius::ColumnData::I32(Some(v)) => MssqlData::I32(*v), - tiberius::ColumnData::I64(Some(v)) => MssqlData::I64(*v), - tiberius::ColumnData::F32(Some(v)) => MssqlData::F32(*v), - tiberius::ColumnData::F64(Some(v)) => MssqlData::F64(*v), - tiberius::ColumnData::Bit(Some(v)) => MssqlData::Bool(*v), - tiberius::ColumnData::String(Some(v)) => MssqlData::String(v.to_string()), - tiberius::ColumnData::Binary(Some(v)) => MssqlData::Binary(v.to_vec()), + tiberius::ColumnData::U8(Some(v)) => Ok(MssqlData::U8(*v)), + tiberius::ColumnData::I16(Some(v)) => Ok(MssqlData::I16(*v)), + tiberius::ColumnData::I32(Some(v)) => Ok(MssqlData::I32(*v)), + tiberius::ColumnData::I64(Some(v)) => Ok(MssqlData::I64(*v)), + tiberius::ColumnData::F32(Some(v)) => Ok(MssqlData::F32(*v)), + tiberius::ColumnData::F64(Some(v)) => Ok(MssqlData::F64(*v)), + tiberius::ColumnData::Bit(Some(v)) => Ok(MssqlData::Bool(*v)), + tiberius::ColumnData::String(Some(v)) => Ok(MssqlData::String(v.to_string())), + tiberius::ColumnData::Binary(Some(v)) => Ok(MssqlData::Binary(v.to_vec())), #[cfg(feature = "chrono")] tiberius::ColumnData::DateTime2(Some(dt2)) => { - let date = chrono_date_from_days(dt2.date().days() as i64, 1); + let date = chrono_date_from_days(dt2.date().days() as i64, 1)?; let ns = dt2.time().increments() as i64 * 10i64.pow(9u32.saturating_sub(dt2.time().scale() as u32)); + // infallible: (0,0,0) is always valid let time = chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap() + chrono::Duration::nanoseconds(ns); - MssqlData::NaiveDateTime(chrono::NaiveDateTime::new(date, time)) + Ok(MssqlData::NaiveDateTime(chrono::NaiveDateTime::new(date, time))) } #[cfg(feature = "chrono")] tiberius::ColumnData::DateTime(Some(dt)) => { - let date = chrono_date_from_days(dt.days() as i64, 1900); + let date = chrono_date_from_days(dt.days() as i64, 1900)?; let ns = dt.seconds_fragments() as i64 * 1_000_000_000i64 / 300; + // infallible: (0,0,0) is always valid let time = chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap() + chrono::Duration::nanoseconds(ns); - MssqlData::NaiveDateTime(chrono::NaiveDateTime::new(date, time)) + Ok(MssqlData::NaiveDateTime(chrono::NaiveDateTime::new(date, time))) } #[cfg(feature = "chrono")] tiberius::ColumnData::SmallDateTime(Some(dt)) => { - let date = chrono_date_from_days(dt.days() as i64, 1900); + let date = chrono_date_from_days(dt.days() as i64, 1900)?; let seconds = dt.seconds_fragments() as u32 * 60; - let time = - chrono::NaiveTime::from_num_seconds_from_midnight_opt(seconds, 0).unwrap(); - MssqlData::NaiveDateTime(chrono::NaiveDateTime::new(date, time)) + let time = chrono::NaiveTime::from_num_seconds_from_midnight_opt(seconds, 0) + .ok_or_else(|| { + Error::Protocol( + format!( + "invalid SmallDateTime seconds: {seconds} exceeds seconds-in-a-day" + ) + .into(), + ) + })?; + Ok(MssqlData::NaiveDateTime(chrono::NaiveDateTime::new(date, time))) } #[cfg(feature = "chrono")] tiberius::ColumnData::Date(Some(d)) => { - MssqlData::NaiveDate(chrono_date_from_days(d.days() as i64, 1)) + Ok(MssqlData::NaiveDate(chrono_date_from_days(d.days() as i64, 1)?)) } #[cfg(feature = "chrono")] tiberius::ColumnData::Time(Some(t)) => { let ns = t.increments() as i64 * 10i64.pow(9u32.saturating_sub(t.scale() as u32)); + // infallible: (0,0,0) is always valid let time = chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap() + chrono::Duration::nanoseconds(ns); - MssqlData::NaiveTime(time) + Ok(MssqlData::NaiveTime(time)) } #[cfg(feature = "chrono")] tiberius::ColumnData::DateTimeOffset(Some(dto)) => { - let date = chrono_date_from_days(dto.datetime2().date().days() as i64, 1); + let date = chrono_date_from_days(dto.datetime2().date().days() as i64, 1)?; let ns = dto.datetime2().time().increments() as i64 * 10i64.pow(9u32.saturating_sub(dto.datetime2().time().scale() as u32)); + // infallible: (0,0,0) is always valid let time = chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap() + chrono::Duration::nanoseconds(ns); let naive = chrono::NaiveDateTime::new(date, time); let offset_secs = dto.offset() as i32 * 60; - let fixed_offset = chrono::FixedOffset::east_opt(offset_secs) - .expect("valid offset from tiberius"); - MssqlData::DateTimeFixedOffset(naive.and_local_timezone(fixed_offset).unwrap()) + let fixed_offset = chrono::FixedOffset::east_opt(offset_secs).ok_or_else(|| { + Error::Protocol( + format!("invalid timezone offset: {offset_secs} seconds").into(), + ) + })?; + let dt = naive.and_local_timezone(fixed_offset).single().ok_or_else(|| { + Error::Protocol( + format!( + "ambiguous or invalid local time for offset {offset_secs}s" + ) + .into(), + ) + })?; + Ok(MssqlData::DateTimeFixedOffset(dt)) } #[cfg(feature = "uuid")] - tiberius::ColumnData::Guid(Some(v)) => MssqlData::Uuid(*v), + tiberius::ColumnData::Guid(Some(v)) => Ok(MssqlData::Uuid(*v)), #[cfg(feature = "rust_decimal")] tiberius::ColumnData::Numeric(Some(n)) => { - MssqlData::Decimal(rust_decimal::Decimal::from_i128_with_scale( + Ok(MssqlData::Decimal(rust_decimal::Decimal::from_i128_with_scale( n.value(), n.scale() as u32, - )) + ))) } #[cfg(all(feature = "time", not(feature = "chrono")))] tiberius::ColumnData::Date(Some(d)) => { - MssqlData::TimeDate(time_date_from_days(d.days() as u64, 1)) + Ok(MssqlData::TimeDate(time_date_from_days(d.days() as u64, 1)?)) } #[cfg(all(feature = "time", not(feature = "chrono")))] tiberius::ColumnData::Time(Some(t)) => { let ns = t.increments() as u64 * 10u64.pow(9u32.saturating_sub(t.scale() as u32)); - MssqlData::TimeTime(time_from_sec_fragments(ns)) + Ok(MssqlData::TimeTime(time_from_sec_fragments(ns)?)) } #[cfg(all(feature = "time", not(feature = "chrono")))] tiberius::ColumnData::DateTime2(Some(dt2)) => { - let date = time_date_from_days(dt2.date().days() as u64, 1); + let date = time_date_from_days(dt2.date().days() as u64, 1)?; let ns = dt2.time().increments() as u64 * 10u64.pow(9u32.saturating_sub(dt2.time().scale() as u32)); - let time = time_from_sec_fragments(ns); - MssqlData::TimePrimitiveDateTime(time::PrimitiveDateTime::new(date, time)) + let time = time_from_sec_fragments(ns)?; + Ok(MssqlData::TimePrimitiveDateTime(time::PrimitiveDateTime::new(date, time))) } #[cfg(all(feature = "time", not(feature = "chrono")))] tiberius::ColumnData::DateTime(Some(dt)) => { - let date = time_date_from_days(dt.days() as u64, 1900); + let date = time_date_from_days(dt.days() as u64, 1900)?; let ns = dt.seconds_fragments() as u64 * 1_000_000_000u64 / 300; - let time = time_from_sec_fragments(ns); - MssqlData::TimePrimitiveDateTime(time::PrimitiveDateTime::new(date, time)) + let time = time_from_sec_fragments(ns)?; + Ok(MssqlData::TimePrimitiveDateTime(time::PrimitiveDateTime::new(date, time))) } #[cfg(all(feature = "time", not(feature = "chrono")))] tiberius::ColumnData::SmallDateTime(Some(dt)) => { - let date = time_date_from_days(dt.days() as u64, 1900); + let date = time_date_from_days(dt.days() as u64, 1900)?; let seconds = dt.seconds_fragments() as u64 * 60; - let time = time_from_sec_fragments(seconds * 1_000_000_000); - MssqlData::TimePrimitiveDateTime(time::PrimitiveDateTime::new(date, time)) + let time = time_from_sec_fragments(seconds * 1_000_000_000)?; + Ok(MssqlData::TimePrimitiveDateTime(time::PrimitiveDateTime::new(date, time))) } #[cfg(all(feature = "time", not(feature = "chrono")))] tiberius::ColumnData::DateTimeOffset(Some(dto)) => { - let date = time_date_from_days(dto.datetime2().date().days() as u64, 1); + let date = time_date_from_days(dto.datetime2().date().days() as u64, 1)?; let ns = dto.datetime2().time().increments() as u64 * 10u64.pow(9u32.saturating_sub(dto.datetime2().time().scale() as u32)); - let time = time_from_sec_fragments(ns); + let time = time_from_sec_fragments(ns)?; let naive = time::PrimitiveDateTime::new(date, time); - let offset = time::UtcOffset::from_whole_seconds(dto.offset() as i32 * 60) - .expect("valid UTC offset from tiberius"); - MssqlData::TimeOffsetDateTime(naive.assume_offset(offset)) + let offset_secs = dto.offset() as i32 * 60; + let offset = time::UtcOffset::from_whole_seconds(offset_secs).map_err(|_| { + Error::Protocol( + format!("invalid UTC offset: {offset_secs} seconds").into(), + ) + })?; + Ok(MssqlData::TimeOffsetDateTime(naive.assume_offset(offset))) } #[cfg(all(feature = "bigdecimal", not(feature = "rust_decimal")))] tiberius::ColumnData::Numeric(Some(n)) => { use bigdecimal::num_bigint::BigInt; - MssqlData::BigDecimal(bigdecimal::BigDecimal::new( + Ok(MssqlData::BigDecimal(bigdecimal::BigDecimal::new( BigInt::from(n.value()), n.scale() as i64, - )) + ))) } // All None variants and unhandled types map to Null - _ => MssqlData::Null, + _ => Ok(MssqlData::Null), } } /// Convert days since `start_year`-01-01 to a `time::Date`. #[cfg(feature = "time")] -fn time_date_from_days(days: u64, start_year: i32) -> time::Date { - let start = time::Date::from_ordinal_date(start_year, 1).expect("valid start date"); +fn time_date_from_days(days: u64, start_year: i32) -> Result { + let start = time::Date::from_ordinal_date(start_year, 1).map_err(|_| { + Error::Protocol(format!("invalid start year for date: {start_year}").into()) + })?; start .checked_add(time::Duration::days(days as i64)) - .expect("valid date from days offset") + .ok_or_else(|| { + Error::Protocol( + format!("date overflow: {days} days from {start_year}-01-01").into(), + ) + }) } /// Convert nanoseconds-since-midnight to a `time::Time`. #[cfg(feature = "time")] -fn time_from_sec_fragments(nanoseconds: u64) -> time::Time { +fn time_from_sec_fragments(nanoseconds: u64) -> Result { + const NANOS_PER_DAY: u64 = 86_400_000_000_000; + if nanoseconds >= NANOS_PER_DAY { + return Err(Error::Protocol( + format!( + "time nanoseconds out of range: {nanoseconds} (must be < {NANOS_PER_DAY})" + ) + .into(), + )); + } + // After the bounds check, hours is 0..=23, minutes 0..=59, seconds 0..=59, + // so the `as u8` casts and `from_hms_nano` are all infallible. let hours = (nanoseconds / 3_600_000_000_000) as u8; let remaining = nanoseconds % 3_600_000_000_000; let minutes = (remaining / 60_000_000_000) as u8; let remaining = remaining % 60_000_000_000; let seconds = (remaining / 1_000_000_000) as u8; let nanos = (remaining % 1_000_000_000) as u32; - time::Time::from_hms_nano(hours, minutes, seconds, nanos).expect("valid time") + time::Time::from_hms_nano(hours, minutes, seconds, nanos).map_err(|_| { + Error::Protocol( + format!("invalid time: {hours:02}:{minutes:02}:{seconds:02}.{nanos:09}") + .into(), + ) + }) } /// Convert days since `start_year`-01-01 to a `chrono::NaiveDate`. #[cfg(feature = "chrono")] -fn chrono_date_from_days(days: i64, start_year: i32) -> chrono::NaiveDate { - chrono::NaiveDate::from_ymd_opt(start_year, 1, 1).unwrap() + chrono::Duration::days(days) +fn chrono_date_from_days(days: i64, start_year: i32) -> Result { + let start = chrono::NaiveDate::from_ymd_opt(start_year, 1, 1).ok_or_else(|| { + Error::Protocol(format!("invalid start year for date: {start_year}").into()) + })?; + start + .checked_add_signed(chrono::Duration::days(days)) + .ok_or_else(|| { + Error::Protocol( + format!("date overflow: {days} days from {start_year}-01-01").into(), + ) + }) } From eea0c7dcfda88c447ad6e3b8593e2555eda05def Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Tue, 3 Mar 2026 23:41:44 -0500 Subject: [PATCH 22/33] fix: eliminate client-side SQL escaping in DDL functions using parameterized queries + QUOTENAME() Replace fragile client-side .replace() escaping in create_database, drop_database, and create_schema_if_not_exists with parameterized queries (@p1) and SQL Server's built-in QUOTENAME() function for server-side identifier escaping via sp_executesql. Author: Pablo Carrera --- sqlx-mssql/src/migrate.rs | 57 ++++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/sqlx-mssql/src/migrate.rs b/sqlx-mssql/src/migrate.rs index b71c9fcac5..e600e43db8 100644 --- a/sqlx-mssql/src/migrate.rs +++ b/sqlx-mssql/src/migrate.rs @@ -41,12 +41,13 @@ impl MigrateDatabase for Mssql { let (options, database) = parse_for_maintenance(url)?; let mut conn = options.connect().await?; - let escaped = database.replace(']', "]]"); - let _ = conn - .execute(AssertSqlSafe(format!( - "CREATE DATABASE [{escaped}]" - ))) - .await?; + query( + "DECLARE @sql NVARCHAR(MAX) = N'CREATE DATABASE ' + QUOTENAME(@p1); \ + EXEC sp_executesql @sql;" + ) + .bind(database) + .execute(&mut conn) + .await?; Ok(()) } @@ -69,20 +70,19 @@ impl MigrateDatabase for Mssql { let (options, database) = parse_for_maintenance(url)?; let mut conn = options.connect().await?; - // Force close existing connections before dropping - // Use separate escaped values for the two different quoting contexts: - // bracket-quoted identifiers ([...]) vs single-quoted string literals ('...') - let bracket_escaped = database.replace(']', "]]"); - let quote_escaped = database.replace('\'', "''"); - let _ = conn - .execute(AssertSqlSafe(format!( - "IF DB_ID('{quote_escaped}') IS NOT NULL \ - BEGIN \ - ALTER DATABASE [{bracket_escaped}] SET SINGLE_USER WITH ROLLBACK IMMEDIATE; \ - DROP DATABASE [{bracket_escaped}]; \ - END" - ))) - .await?; + query( + "IF DB_ID(@p1) IS NOT NULL \ + BEGIN \ + DECLARE @sql NVARCHAR(MAX); \ + SET @sql = N'ALTER DATABASE ' + QUOTENAME(@p1) + N' SET SINGLE_USER WITH ROLLBACK IMMEDIATE'; \ + EXEC sp_executesql @sql; \ + SET @sql = N'DROP DATABASE ' + QUOTENAME(@p1); \ + EXEC sp_executesql @sql; \ + END" + ) + .bind(database) + .execute(&mut conn) + .await?; Ok(()) } @@ -94,14 +94,15 @@ impl Migrate for MssqlConnection { schema_name: &'e str, ) -> BoxFuture<'e, Result<(), MigrateError>> { Box::pin(async move { - let quote_escaped = schema_name.replace('\'', "''"); - // Inside EXEC('...'), the identifier must be bracket-escaped AND - // single-quote-escaped (since it's nested inside a string literal). - let exec_escaped = schema_name.replace(']', "]]").replace('\'', "''"); - self.execute(AssertSqlSafe(format!( - r#"IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = '{quote_escaped}') - EXEC('CREATE SCHEMA [{exec_escaped}]')"# - ))) + query( + "IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = @p1) \ + BEGIN \ + DECLARE @sql NVARCHAR(MAX) = N'CREATE SCHEMA ' + QUOTENAME(@p1); \ + EXEC sp_executesql @sql; \ + END" + ) + .bind(schema_name) + .execute(&mut *self) .await?; Ok(()) From bd2756d81f3e8aa4b2d6d4fb75317aba007f6105 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Wed, 4 Mar 2026 00:03:32 -0500 Subject: [PATCH 23/33] fix: replace silent NULL catch-all with explicit None enumeration in column_data_to_mssql_data The wildcard `_ => Ok(MssqlData::Null)` silently coerced unhandled `Some(...)` variants (e.g. Xml, Guid without uuid, Numeric without rust_decimal) to NULL, causing silent data loss. Enumerate all 18 `None` variants explicitly and error on unhandled `Some(...)` values. Author: Pablo Carrera --- sqlx-mssql/src/value.rs | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/sqlx-mssql/src/value.rs b/sqlx-mssql/src/value.rs index 64dc9df11e..ffcc54e871 100644 --- a/sqlx-mssql/src/value.rs +++ b/sqlx-mssql/src/value.rs @@ -269,8 +269,34 @@ pub(crate) fn column_data_to_mssql_data( ))) } - // All None variants and unhandled types map to Null - _ => Ok(MssqlData::Null), + // All None variants represent SQL NULL + tiberius::ColumnData::U8(None) + | tiberius::ColumnData::I16(None) + | tiberius::ColumnData::I32(None) + | tiberius::ColumnData::I64(None) + | tiberius::ColumnData::F32(None) + | tiberius::ColumnData::F64(None) + | tiberius::ColumnData::Bit(None) + | tiberius::ColumnData::String(None) + | tiberius::ColumnData::Guid(None) + | tiberius::ColumnData::Binary(None) + | tiberius::ColumnData::Numeric(None) + | tiberius::ColumnData::Xml(None) + | tiberius::ColumnData::DateTime(None) + | tiberius::ColumnData::SmallDateTime(None) + | tiberius::ColumnData::DateTime2(None) + | tiberius::ColumnData::DateTimeOffset(None) + | tiberius::ColumnData::Date(None) + | tiberius::ColumnData::Time(None) => Ok(MssqlData::Null), + + // Unhandled Some(...) variant — real data the driver can't convert + other => { + let debug = format!("{other:?}"); + let truncated = if debug.len() > 200 { &debug[..200] } else { &debug }; + Err(Error::Protocol( + format!("unsupported tiberius ColumnData variant: {truncated}").into(), + )) + } } } From 192eb59f2410b1d7c66b3b6fbfd1cdb9f8916dcd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Wed, 4 Mar 2026 00:22:15 -0500 Subject: [PATCH 24/33] fix: validate BigDecimal scale before u8 cast to prevent silent truncation Closes #9 Author: Pablo Carrera --- sqlx-mssql/src/connection/executor.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sqlx-mssql/src/connection/executor.rs b/sqlx-mssql/src/connection/executor.rs index f11253953e..e2223b009a 100644 --- a/sqlx-mssql/src/connection/executor.rs +++ b/sqlx-mssql/src/connection/executor.rs @@ -263,7 +263,13 @@ impl MssqlConnection { use bigdecimal::ToPrimitive; // Convert BigDecimal to tiberius Numeric let (bigint, exponent) = v.as_bigint_and_exponent(); - let scale = exponent.max(0) as u8; + let scale = exponent.max(0); + if scale > 38 { + return Err(Error::Encode( + format!("BigDecimal scale {scale} exceeds SQL Server maximum of 38").into(), + )); + } + let scale = scale as u8; // Convert to i128 for Numeric let value: i128 = bigint.to_i128().ok_or_else(|| { Error::Encode( From 46cc1550f43114f6e75124d13aec8f5faa466e00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Wed, 4 Mar 2026 00:47:46 -0500 Subject: [PATCH 25/33] fix: handle BigDecimal negative exponents and correct scale boundary check Extract bigdecimal_to_numeric() helper that normalizes negative exponents via with_scale(0) (matching tiberius's own to_sql! pattern) and tightens the scale limit from >38 to >37 to match tiberius's assert!(scale < 38). Add 10 unit tests covering negative exponents, boundary values, silent truncation at scale=256, and the off-by-one at scale=38. Author: Pablo Carrera --- sqlx-mssql/src/connection/executor.rs | 154 +++++++++++++++++++++++--- 1 file changed, 138 insertions(+), 16 deletions(-) diff --git a/sqlx-mssql/src/connection/executor.rs b/sqlx-mssql/src/connection/executor.rs index e2223b009a..edf8bd4dd0 100644 --- a/sqlx-mssql/src/connection/executor.rs +++ b/sqlx-mssql/src/connection/executor.rs @@ -77,6 +77,44 @@ fn offset_minutes_to_i16(offset_minutes: i32) -> Result { } } +/// Convert a `BigDecimal` into the `(i128, u8)` pair that +/// `tiberius::numeric::Numeric::new_with_scale` expects. +/// +/// Handles two edge cases: +/// - **Negative exponents** (e.g. `BigDecimal(9, -3)` = 9000): rescales to +/// exponent 0 so SQL Server receives the correct magnitude. +/// - **Scale > 37**: SQL Server NUMERIC max scale is 37, and tiberius +/// asserts `scale < 38`. Returns `Error::Encode` instead of panicking. +#[cfg(feature = "bigdecimal")] +fn bigdecimal_to_numeric(v: &bigdecimal::BigDecimal) -> Result<(i128, u8), Error> { + use bigdecimal::ToPrimitive; + + let (bigint, exponent) = v.as_bigint_and_exponent(); + let (bigint, exponent) = if exponent < 0 { + v.with_scale(0).into_bigint_and_exponent() + } else { + (bigint, exponent) + }; + + if exponent > 37 { + return Err(Error::Encode( + format!( + "BigDecimal scale {exponent} exceeds SQL Server maximum of 37" + ) + .into(), + )); + } + let scale = exponent as u8; + + let value: i128 = bigint.to_i128().ok_or_else(|| { + Error::Encode( + format!("BigDecimal value too large for SQL NUMERIC: {v}").into(), + ) + })?; + + Ok((value, scale)) +} + impl MssqlConnection { /// Execute a query, eagerly collecting all results. /// @@ -260,22 +298,7 @@ impl MssqlConnection { } #[cfg(feature = "bigdecimal")] MssqlArgumentValue::BigDecimal(v) => { - use bigdecimal::ToPrimitive; - // Convert BigDecimal to tiberius Numeric - let (bigint, exponent) = v.as_bigint_and_exponent(); - let scale = exponent.max(0); - if scale > 38 { - return Err(Error::Encode( - format!("BigDecimal scale {scale} exceeds SQL Server maximum of 38").into(), - )); - } - let scale = scale as u8; - // Convert to i128 for Numeric - let value: i128 = bigint.to_i128().ok_or_else(|| { - Error::Encode( - format!("BigDecimal value too large for SQL NUMERIC: {v}").into(), - ) - })?; + let (value, scale) = bigdecimal_to_numeric(&v)?; let cd = tiberius::ColumnData::Numeric(Some( tiberius::numeric::Numeric::new_with_scale(value, scale), )); @@ -652,3 +675,102 @@ mod tests { assert!(matches!(err, Error::Encode(_))); } } + +#[cfg(test)] +#[cfg(feature = "bigdecimal")] +mod bigdecimal_tests { + use super::*; + use std::str::FromStr; + + #[test] + fn positive_scale_simple() { + // 123.45 → bigint=12345, exponent=2 → scale=2 + let bd = bigdecimal::BigDecimal::from_str("123.45").unwrap(); + let (value, scale) = bigdecimal_to_numeric(&bd).unwrap(); + assert_eq!(value, 12345); + assert_eq!(scale, 2); + } + + #[test] + fn zero_scale() { + // 42 → bigint=42, exponent=0 → scale=0 + let bd = bigdecimal::BigDecimal::from_str("42").unwrap(); + let (value, scale) = bigdecimal_to_numeric(&bd).unwrap(); + assert_eq!(value, 42); + assert_eq!(scale, 0); + } + + #[test] + fn negative_exponent_rescales() { + // Explicitly construct BigDecimal(123, -3) = 123 * 10^3 = 123000. + // This is the internal form that triggers the negative-exponent path. + let bd = bigdecimal::BigDecimal::new(123.into(), -3); + let (bigint_raw, exp_raw) = bd.as_bigint_and_exponent(); + assert_eq!(exp_raw, -3, "precondition: exponent must be negative"); + assert_eq!(bigint_raw, 123.into(), "precondition: raw bigint is 123"); + + let (value, scale) = bigdecimal_to_numeric(&bd).unwrap(); + // After rescaling: 123000 with scale 0 + assert_eq!(value, 123000); + assert_eq!(scale, 0); + } + + #[test] + fn negative_exponent_large_magnitude() { + // 5e10 = 50_000_000_000 → internally (5, -10) + let bd = bigdecimal::BigDecimal::from_str("5e10").unwrap(); + let (value, scale) = bigdecimal_to_numeric(&bd).unwrap(); + assert_eq!(value, 50_000_000_000); + assert_eq!(scale, 0); + } + + #[test] + fn scale_at_max_37() { + // Scale exactly 37 is the maximum tiberius allows + let bd = bigdecimal::BigDecimal::new(1.into(), 37); + let (value, scale) = bigdecimal_to_numeric(&bd).unwrap(); + assert_eq!(value, 1); + assert_eq!(scale, 37); + } + + #[test] + fn scale_38_rejected() { + // Scale 38 triggers tiberius assert!(scale < 38); must be rejected + let bd = bigdecimal::BigDecimal::new(1.into(), 38); + let err = bigdecimal_to_numeric(&bd).unwrap_err(); + assert!(matches!(err, Error::Encode(_))); + } + + #[test] + fn scale_39_rejected() { + let bd = bigdecimal::BigDecimal::new(1.into(), 39); + let err = bigdecimal_to_numeric(&bd).unwrap_err(); + assert!(matches!(err, Error::Encode(_))); + } + + #[test] + fn scale_256_rejected_not_truncated() { + // The original bug: `as u8` would silently truncate 256 → 0. + // Must return an error, not scale=0. + let bd = bigdecimal::BigDecimal::new(1.into(), 256); + let err = bigdecimal_to_numeric(&bd).unwrap_err(); + assert!(matches!(err, Error::Encode(_))); + } + + #[test] + fn negative_value() { + // -99.9 → bigint=-999, scale=1 + let bd = bigdecimal::BigDecimal::from_str("-99.9").unwrap(); + let (value, scale) = bigdecimal_to_numeric(&bd).unwrap(); + assert_eq!(value, -999); + assert_eq!(scale, 1); + } + + #[test] + fn zero_value() { + let bd = bigdecimal::BigDecimal::from_str("0").unwrap(); + let (value, scale) = bigdecimal_to_numeric(&bd).unwrap(); + assert_eq!(value, 0); + assert_eq!(scale, 0); + } +} From 470362a2633617759db433112eaadfbbd5272d64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Wed, 4 Mar 2026 01:05:34 -0500 Subject: [PATCH 26/33] fix: use parameterized queries for sp_describe stored procedure calls Replace string-interpolated SQL with tiberius::Query parameterized bindings in prepare_with/describe to eliminate Unicode homoglyph injection edge case. Also log sp_describe_undeclared_parameters errors instead of silently swallowing them. Closes #10 Author: Pablo Carrera --- sqlx-mssql/src/connection/executor.rs | 41 +++++++++++++-------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/sqlx-mssql/src/connection/executor.rs b/sqlx-mssql/src/connection/executor.rs index edf8bd4dd0..330e7cbc48 100644 --- a/sqlx-mssql/src/connection/executor.rs +++ b/sqlx-mssql/src/connection/executor.rs @@ -459,18 +459,16 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection { { Box::pin(async move { // Use sp_describe_first_result_set to get column metadata - let describe_sql = format!( - "EXEC sp_describe_first_result_set @tsql = N'{}'", - sql.as_str().replace('\'', "''") + let mut describe_query = tiberius::Query::new( + "EXEC sp_describe_first_result_set @tsql = @p1", ); + describe_query.bind(sql.as_str()); let mut columns = Vec::new(); let mut column_names = HashMap::new(); - let stream = self - .inner - .client - .simple_query(&describe_sql) + let stream = describe_query + .query(&mut self.inner.client) .await .map_err(tiberius_err)?; @@ -531,15 +529,13 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection { { Box::pin(async move { // Query sp_describe_first_result_set directly so we can extract nullable info - let describe_sql = format!( - "EXEC sp_describe_first_result_set @tsql = N'{}'", - sql.as_str().replace('\'', "''") + let mut describe_query = tiberius::Query::new( + "EXEC sp_describe_first_result_set @tsql = @p1", ); + describe_query.bind(sql.as_str()); - let stream = self - .inner - .client - .simple_query(&describe_sql) + let stream = describe_query + .query(&mut self.inner.client) .await .map_err(tiberius_err)?; @@ -586,14 +582,12 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection { } // Count parameters using sp_describe_undeclared_parameters - let param_sql = format!( - "EXEC sp_describe_undeclared_parameters @tsql = N'{}'", - sql.as_str().replace('\'', "''") + let mut param_query = tiberius::Query::new( + "EXEC sp_describe_undeclared_parameters @tsql = @p1", ); - let param_count = match self - .inner - .client - .simple_query(¶m_sql) + param_query.bind(sql.as_str()); + let param_count = match param_query + .query(&mut self.inner.client) .await { Ok(stream) => stream @@ -601,7 +595,10 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection { .await .map_err(tiberius_err)? .len(), - Err(_) => 0, + Err(e) => { + tracing::debug!("sp_describe_undeclared_parameters failed: {e}"); + 0 + } }; Ok(crate::describe::Describe { From cd02ddab479f81d20f6c289f5fb0afd802a6b4e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Wed, 4 Mar 2026 17:16:22 -0500 Subject: [PATCH 27/33] fix: correct Datetime4 mapping, guard rust_decimal scale, and remove SQL escaping in test harness - Map Datetime4 to SMALLDATETIME instead of DATETIMEOFFSET - Validate rust_decimal scale <= 37 before u8 cast to prevent silent truncation - Replace client-side .replace(']',"]]") escaping with parameterized QUOTENAME(@p1) + sp_executesql in testing/mod.rs Author: Pablo Carrera --- sqlx-mssql/src/connection/executor.rs | 8 ++++- sqlx-mssql/src/testing/mod.rs | 52 ++++++++++++++++----------- sqlx-mssql/src/type_info.rs | 5 ++- 3 files changed, 41 insertions(+), 24 deletions(-) diff --git a/sqlx-mssql/src/connection/executor.rs b/sqlx-mssql/src/connection/executor.rs index 330e7cbc48..2d438fadb8 100644 --- a/sqlx-mssql/src/connection/executor.rs +++ b/sqlx-mssql/src/connection/executor.rs @@ -223,9 +223,15 @@ impl MssqlConnection { if v.is_sign_negative() { value = -value; } + let scale = v.scale(); + if scale > 37 { + return Err(Error::Encode( + format!("rust_decimal scale {scale} exceeds SQL Server maximum of 37").into(), + )); + } query.bind(tiberius::numeric::Numeric::new_with_scale( value, - v.scale() as u8, + scale as u8, )); } #[cfg(feature = "time")] diff --git a/sqlx-mssql/src/testing/mod.rs b/sqlx-mssql/src/testing/mod.rs index 05216e6fcd..a668d687e7 100644 --- a/sqlx-mssql/src/testing/mod.rs +++ b/sqlx-mssql/src/testing/mod.rs @@ -11,7 +11,6 @@ use crate::query::query; use crate::{Mssql, MssqlConnectOptions, MssqlConnection}; use sqlx_core::connection::Connection; use sqlx_core::query_scalar::query_scalar; -use sqlx_core::sql_str::AssertSqlSafe; pub(crate) use sqlx_core::testing::*; @@ -52,16 +51,20 @@ impl TestSupport for Mssql { let mut deleted_count = 0usize; for db_name in &delete_db_names { - let escaped = db_name.replace('\'', "''").replace(']', "]]"); - let drop_sql = format!( - "IF DB_ID('{escaped}') IS NOT NULL \ + match query( + "IF DB_ID(@p1) IS NOT NULL \ BEGIN \ - ALTER DATABASE [{escaped}] SET SINGLE_USER WITH ROLLBACK IMMEDIATE; \ - DROP DATABASE [{escaped}]; \ - END" - ); - - match conn.execute(AssertSqlSafe(drop_sql)).await { + DECLARE @sql NVARCHAR(MAX); \ + SET @sql = N'ALTER DATABASE ' + QUOTENAME(@p1) + N' SET SINGLE_USER WITH ROLLBACK IMMEDIATE'; \ + EXEC sp_executesql @sql; \ + SET @sql = N'DROP DATABASE ' + QUOTENAME(@p1); \ + EXEC sp_executesql @sql; \ + END", + ) + .bind(db_name) + .execute(&mut conn) + .await + { Ok(_deleted) => { deleted_count += 1; } @@ -149,8 +152,13 @@ async fn test_context(args: &TestArgs) -> Result, Error> { .execute(&mut *conn) .await?; - conn.execute(AssertSqlSafe(format!("CREATE DATABASE [{db_name}]"))) - .await?; + query( + "DECLARE @sql NVARCHAR(MAX) = N'CREATE DATABASE ' + QUOTENAME(@p1); \ + EXEC sp_executesql @sql;", + ) + .bind(&db_name) + .execute(&mut *conn) + .await?; eprintln!("created database {db_name}"); @@ -169,15 +177,19 @@ async fn test_context(args: &TestArgs) -> Result, Error> { } async fn do_cleanup(conn: &mut MssqlConnection, db_name: &str) -> Result<(), Error> { - let escaped = db_name.replace('\'', "''").replace(']', "]]"); - let drop_sql = format!( - "IF DB_ID('{escaped}') IS NOT NULL \ + query( + "IF DB_ID(@p1) IS NOT NULL \ BEGIN \ - ALTER DATABASE [{escaped}] SET SINGLE_USER WITH ROLLBACK IMMEDIATE; \ - DROP DATABASE [{escaped}]; \ - END" - ); - conn.execute(AssertSqlSafe(drop_sql)).await?; + DECLARE @sql NVARCHAR(MAX); \ + SET @sql = N'ALTER DATABASE ' + QUOTENAME(@p1) + N' SET SINGLE_USER WITH ROLLBACK IMMEDIATE'; \ + EXEC sp_executesql @sql; \ + SET @sql = N'DROP DATABASE ' + QUOTENAME(@p1); \ + EXEC sp_executesql @sql; \ + END", + ) + .bind(db_name) + .execute(&mut *conn) + .await?; query("DELETE FROM _sqlx_test_databases WHERE db_name = @p1") .bind(db_name) .execute(&mut *conn) diff --git a/sqlx-mssql/src/type_info.rs b/sqlx-mssql/src/type_info.rs index dbe612209a..76907aacbb 100644 --- a/sqlx-mssql/src/type_info.rs +++ b/sqlx-mssql/src/type_info.rs @@ -51,9 +51,8 @@ pub(crate) fn type_name_for_tiberius(col_type: &tiberius::ColumnType) -> &'stati tiberius::ColumnType::Float8 => "FLOAT", tiberius::ColumnType::Datetime | tiberius::ColumnType::Datetimen => "DATETIME", tiberius::ColumnType::Datetime2 => "DATETIME2", - tiberius::ColumnType::Datetime4 | tiberius::ColumnType::DatetimeOffsetn => { - "DATETIMEOFFSET" - } + tiberius::ColumnType::Datetime4 => "SMALLDATETIME", + tiberius::ColumnType::DatetimeOffsetn => "DATETIMEOFFSET", tiberius::ColumnType::Daten => "DATE", tiberius::ColumnType::Timen => "TIME", tiberius::ColumnType::Decimaln | tiberius::ColumnType::Numericn => "DECIMAL", From d51b828389b287917d3d720ada8fc8e4e0408742 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Thu, 5 Mar 2026 00:06:41 -0500 Subject: [PATCH 28/33] fix: resolve clippy deny violations, parameterize SQL, and clean up warnings - Add #[allow] annotations with SAFETY comments for all guarded integer casts that violated crate-level deny lints (cast_possible_truncation, cast_possible_wrap, cast_sign_loss) - Parameterize advisory lock mode via @p2 instead of string interpolation - Parameterize migrate table existence check via @p1 while preserving atomic single-statement DDL - Change time_date_from_days to accept i64 with i64::from() at call sites to eliminate cast_sign_loss under --features time - Remove unnecessary as u64 casts on Time::increments() calls - Remove useless .into() on Error::Protocol(format!(...)) calls - Extract build_columns_from_describe_rows helper to deduplicate prepare_with/describe column-building logic - Fix minor issues: needless_borrow, needless_lifetimes, doc-test reborrows, todo!() -> error, uuid encode simplification, unused atoi dep, incorrect doc comment, redundant .into_iter() Author: Pablo Carrera --- Cargo.lock | 1 - sqlx-mssql/Cargo.toml | 1 - sqlx-mssql/src/advisory_lock.rs | 40 ++++--- sqlx-mssql/src/arguments.rs | 2 +- sqlx-mssql/src/connection/executor.rs | 147 ++++++++++++-------------- sqlx-mssql/src/error.rs | 2 +- sqlx-mssql/src/migrate.rs | 29 ++--- sqlx-mssql/src/options/parse.rs | 2 +- sqlx-mssql/src/testing/mod.rs | 2 +- sqlx-mssql/src/types/uuid.rs | 3 +- sqlx-mssql/src/value.rs | 100 +++++++++--------- 11 files changed, 153 insertions(+), 176 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bcd187ade7..559c2211d7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4179,7 +4179,6 @@ name = "sqlx-mssql" version = "0.9.0-alpha.1" dependencies = [ "async-std", - "atoi", "bigdecimal 0.4.8", "bytes", "chrono", diff --git a/sqlx-mssql/Cargo.toml b/sqlx-mssql/Cargo.toml index 607d977d79..40f039df05 100644 --- a/sqlx-mssql/Cargo.toml +++ b/sqlx-mssql/Cargo.toml @@ -55,7 +55,6 @@ either = "1.6.1" log = "0.4.18" tracing = { version = "0.1.37", features = ["log"] } percent-encoding = "2.1.0" -atoi = "2.0" dotenvy.workspace = true thiserror.workspace = true diff --git a/sqlx-mssql/src/advisory_lock.rs b/sqlx-mssql/src/advisory_lock.rs index a0a71d7b85..1899565f07 100644 --- a/sqlx-mssql/src/advisory_lock.rs +++ b/sqlx-mssql/src/advisory_lock.rs @@ -68,12 +68,12 @@ impl MssqlAdvisoryLockMode { /// let lock = MssqlAdvisoryLock::new("my_app_lock"); /// /// // Using the RAII guard (preferred): -/// let guard = lock.acquire_guard(conn).await?; +/// let guard = lock.acquire_guard(&mut *conn).await?; /// // ... do work under the lock, using `&mut *guard` as a connection ... /// guard.release_now().await?; /// /// // Or manual management: -/// lock.acquire(conn).await?; +/// lock.acquire(&mut *conn).await?; /// // ... do work ... /// lock.release(conn).await?; /// # Ok(()) @@ -138,18 +138,16 @@ impl MssqlAdvisoryLock { /// Returns an error if `sp_getapplock` returns a negative status code /// (e.g. lock request was cancelled or a deadlock was detected). pub async fn acquire(&self, conn: &mut MssqlConnection) -> Result<(), Error> { - let mode = self.mode.as_str(); - let sql = format!( + let status: i32 = query_scalar( "DECLARE @r INT; \ - EXEC @r = sp_getapplock @Resource = @p1, @LockMode = '{mode}', \ + EXEC @r = sp_getapplock @Resource = @p1, @LockMode = @p2, \ @LockOwner = 'Session', @LockTimeout = -1; \ - SELECT @r;" - ); - - let status: i32 = query_scalar(sqlx_core::sql_str::AssertSqlSafe(sql)) - .bind(&self.resource) - .fetch_one(&mut *conn) - .await?; + SELECT @r;", + ) + .bind(&self.resource) + .bind(self.mode.as_str()) + .fetch_one(&mut *conn) + .await?; if status < 0 { return Err(Error::Protocol(format!( @@ -167,18 +165,16 @@ impl MssqlAdvisoryLock { /// Returns `Ok(true)` if the lock was acquired, `Ok(false)` if it was not /// available (timeout). pub async fn try_acquire(&self, conn: &mut MssqlConnection) -> Result { - let mode = self.mode.as_str(); - let sql = format!( + let status: i32 = query_scalar( "DECLARE @r INT; \ - EXEC @r = sp_getapplock @Resource = @p1, @LockMode = '{mode}', \ + EXEC @r = sp_getapplock @Resource = @p1, @LockMode = @p2, \ @LockOwner = 'Session', @LockTimeout = 0; \ - SELECT @r;" - ); - - let status: i32 = query_scalar(sqlx_core::sql_str::AssertSqlSafe(sql)) - .bind(&self.resource) - .fetch_one(&mut *conn) - .await?; + SELECT @r;", + ) + .bind(&self.resource) + .bind(self.mode.as_str()) + .fetch_one(&mut *conn) + .await?; if status >= 0 { // 0 = granted synchronously, 1 = granted after wait diff --git a/sqlx-mssql/src/arguments.rs b/sqlx-mssql/src/arguments.rs index cc7bc66e42..7c438f5c1b 100644 --- a/sqlx-mssql/src/arguments.rs +++ b/sqlx-mssql/src/arguments.rs @@ -21,7 +21,7 @@ impl MssqlArguments { let is_null = value.encode(&mut self.values)?; if is_null.is_null() { // If the encoder signaled null but didn't push a value, push a Null - if self.values.last().map_or(true, |v| !matches!(v, MssqlArgumentValue::Null)) { + if self.values.last().is_none_or(|v| !matches!(v, MssqlArgumentValue::Null)) { self.values.push(MssqlArgumentValue::Null); } } diff --git a/sqlx-mssql/src/connection/executor.rs b/sqlx-mssql/src/connection/executor.rs index 2d438fadb8..1d5080349d 100644 --- a/sqlx-mssql/src/connection/executor.rs +++ b/sqlx-mssql/src/connection/executor.rs @@ -15,7 +15,7 @@ use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_util::TryStreamExt; use sqlx_core::column::{ColumnOrigin, TableColumn}; -use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr, SqlStr}; +use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr as _, SqlStr}; use std::sync::Arc; /// Newtype wrapper to bridge `tiberius::ColumnData` into `tiberius::IntoSql`. @@ -64,7 +64,8 @@ fn offset_minutes_to_i16(offset_minutes: i32) -> Result { const MIN_OFFSET: i32 = -840; const MAX_OFFSET: i32 = 840; if (MIN_OFFSET..=MAX_OFFSET).contains(&offset_minutes) { - // -840..=840 fits in i16, so this cast is infallible. + // SAFETY: range check above guarantees -840..=840, which fits in i16. + #[allow(clippy::cast_possible_truncation)] Ok(offset_minutes as i16) } else { Err(Error::Encode( @@ -104,6 +105,8 @@ fn bigdecimal_to_numeric(v: &bigdecimal::BigDecimal) -> Result<(i128, u8), Error .into(), )); } + // SAFETY: guarded by `exponent > 37` check above; 0..=37 fits in u8. + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] let scale = exponent as u8; let value: i128 = bigint.to_i128().ok_or_else(|| { @@ -216,6 +219,8 @@ impl MssqlConnection { #[cfg(feature = "rust_decimal")] MssqlArgumentValue::Decimal(v) => { let unpacked = v.unpack(); + // SAFETY: rust_decimal mantissa is ≤96 bits (hi:mid:lo are u32s), fits in i128. + #[allow(clippy::cast_possible_wrap)] let mut value = (((unpacked.hi as u128) << 64) + ((unpacked.mid as u128) << 32) + unpacked.lo as u128) @@ -229,9 +234,12 @@ impl MssqlConnection { format!("rust_decimal scale {scale} exceeds SQL Server maximum of 37").into(), )); } + // SAFETY: guarded by `scale > 37` check above; 0..=37 fits in u8. + #[allow(clippy::cast_possible_truncation)] + let scale_u8 = scale as u8; query.bind(tiberius::numeric::Numeric::new_with_scale( value, - scale as u8, + scale_u8, )); } #[cfg(feature = "time")] @@ -304,7 +312,7 @@ impl MssqlConnection { } #[cfg(feature = "bigdecimal")] MssqlArgumentValue::BigDecimal(v) => { - let (value, scale) = bigdecimal_to_numeric(&v)?; + let (value, scale) = bigdecimal_to_numeric(v)?; let cd = tiberius::ColumnData::Numeric(Some( tiberius::numeric::Numeric::new_with_scale(value, scale), )); @@ -331,8 +339,8 @@ impl MssqlConnection { } /// Collect all results from a tiberius QueryStream into a Vec. -async fn collect_results<'a>( - mut stream: tiberius::QueryStream<'a>, +async fn collect_results( + mut stream: tiberius::QueryStream<'_>, results: &mut Vec>, logger: &mut QueryLogger, ) -> Result<(), Error> { @@ -403,6 +411,55 @@ async fn collect_results<'a>( Ok(()) } +/// Build column metadata from `sp_describe_first_result_set` result rows. +/// +/// Returns `(columns, column_names, nullable)` where `nullable` contains one +/// `Option` per column (extracted from the `is_nullable` field). +fn build_columns_from_describe_rows( + rows: &[tiberius::Row], +) -> (Vec, HashMap, Vec>) { + let mut columns = Vec::with_capacity(rows.len()); + let mut column_names = HashMap::with_capacity(rows.len()); + let mut nullable = Vec::with_capacity(rows.len()); + + for (ordinal, row) in rows.iter().enumerate() { + let name: &str = row.get("name").unwrap_or(""); + let type_name: &str = row.get("system_type_name").unwrap_or("UNKNOWN"); + let type_info = MssqlTypeInfo::new(type_name.to_uppercase()); + let is_nullable: Option = row.get("is_nullable"); + + let source_table: Option<&str> = row.get("source_table"); + let source_schema: Option<&str> = row.get("source_schema"); + let source_column: Option<&str> = row.get("source_column"); + + let origin = match (source_table, source_column) { + (Some(table), Some(col)) if !table.is_empty() && !col.is_empty() => { + let table_str = match source_schema { + Some(s) if !s.is_empty() => format!("{s}.{table}"), + _ => table.to_string(), + }; + ColumnOrigin::Table(TableColumn { + table: table_str.into(), + name: col.into(), + }) + } + _ => ColumnOrigin::Expression, + }; + + let ustr_name = UStr::new(name); + column_names.insert(ustr_name.clone(), ordinal); + columns.push(MssqlColumn { + ordinal, + name: ustr_name, + type_info, + origin, + }); + nullable.push(is_nullable); + } + + (columns, column_names, nullable) +} + impl<'c> Executor<'c> for &'c mut MssqlConnection { type Database = Mssql; @@ -464,54 +521,18 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection { 'c: 'e, { Box::pin(async move { - // Use sp_describe_first_result_set to get column metadata let mut describe_query = tiberius::Query::new( "EXEC sp_describe_first_result_set @tsql = @p1", ); describe_query.bind(sql.as_str()); - let mut columns = Vec::new(); - let mut column_names = HashMap::new(); - let stream = describe_query .query(&mut self.inner.client) .await .map_err(tiberius_err)?; let rows: Vec = stream.into_first_result().await.map_err(tiberius_err)?; - - for (ordinal, row) in rows.iter().enumerate() { - let name: &str = row.get("name").unwrap_or(""); - let type_name: &str = row.get("system_type_name").unwrap_or("UNKNOWN"); - let type_info = MssqlTypeInfo::new(type_name.to_uppercase()); - - let source_table: Option<&str> = row.get("source_table"); - let source_schema: Option<&str> = row.get("source_schema"); - let source_column: Option<&str> = row.get("source_column"); - - let origin = match (source_table, source_column) { - (Some(table), Some(col)) if !table.is_empty() && !col.is_empty() => { - let table_str = match source_schema { - Some(s) if !s.is_empty() => format!("{s}.{table}"), - _ => table.to_string(), - }; - ColumnOrigin::Table(TableColumn { - table: table_str.into(), - name: col.into(), - }) - } - _ => ColumnOrigin::Expression, - }; - - let ustr_name = UStr::new(name); - column_names.insert(ustr_name.clone(), ordinal); - columns.push(MssqlColumn { - ordinal, - name: ustr_name, - type_info, - origin, - }); - } + let (columns, column_names, _nullable) = build_columns_from_describe_rows(&rows); Ok(MssqlStatement { sql, @@ -534,7 +555,6 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection { 'c: 'e, { Box::pin(async move { - // Query sp_describe_first_result_set directly so we can extract nullable info let mut describe_query = tiberius::Query::new( "EXEC sp_describe_first_result_set @tsql = @p1", ); @@ -548,44 +568,7 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection { let rows: Vec = stream.into_first_result().await.map_err(tiberius_err)?; - let mut columns = Vec::new(); - let mut column_names = HashMap::new(); - let mut nullable = Vec::new(); - - for (ordinal, row) in rows.iter().enumerate() { - let name: &str = row.get("name").unwrap_or(""); - let type_name: &str = row.get("system_type_name").unwrap_or("UNKNOWN"); - let type_info = MssqlTypeInfo::new(type_name.to_uppercase()); - let is_nullable: Option = row.get("is_nullable"); - - let source_table: Option<&str> = row.get("source_table"); - let source_schema: Option<&str> = row.get("source_schema"); - let source_column: Option<&str> = row.get("source_column"); - - let origin = match (source_table, source_column) { - (Some(table), Some(col)) if !table.is_empty() && !col.is_empty() => { - let table_str = match source_schema { - Some(s) if !s.is_empty() => format!("{s}.{table}"), - _ => table.to_string(), - }; - ColumnOrigin::Table(TableColumn { - table: table_str.into(), - name: col.into(), - }) - } - _ => ColumnOrigin::Expression, - }; - - let ustr_name = UStr::new(name); - column_names.insert(ustr_name.clone(), ordinal); - columns.push(MssqlColumn { - ordinal, - name: ustr_name, - type_info, - origin, - }); - nullable.push(is_nullable); - } + let (columns, _column_names, nullable) = build_columns_from_describe_rows(&rows); // Count parameters using sp_describe_undeclared_parameters let mut param_query = tiberius::Query::new( diff --git a/sqlx-mssql/src/error.rs b/sqlx-mssql/src/error.rs index 7f2e6b1eaa..dcfcae3718 100644 --- a/sqlx-mssql/src/error.rs +++ b/sqlx-mssql/src/error.rs @@ -30,7 +30,7 @@ impl MssqlDatabaseError { self.class } - /// The human-readable error message. + /// The server name that generated the error, if available. pub fn server(&self) -> Option<&str> { self.server.as_deref() } diff --git a/sqlx-mssql/src/migrate.rs b/sqlx-mssql/src/migrate.rs index e600e43db8..d32eac092d 100644 --- a/sqlx-mssql/src/migrate.rs +++ b/sqlx-mssql/src/migrate.rs @@ -114,21 +114,24 @@ impl Migrate for MssqlConnection { table_name: &'e str, ) -> BoxFuture<'e, Result<(), MigrateError>> { Box::pin(async move { - let lit = table_name.replace('\'', "''"); let ident = escape_table_name(table_name); - self.execute(AssertSqlSafe(format!( - r#" -IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = '{lit}') -CREATE TABLE {ident} ( - version BIGINT PRIMARY KEY, - description NVARCHAR(MAX) NOT NULL, - installed_on DATETIME2 NOT NULL DEFAULT SYSUTCDATETIME(), - success BIT NOT NULL, - checksum VARBINARY(MAX) NOT NULL, - execution_time BIGINT NOT NULL -); - "# + // Atomic check-and-create: the IF NOT EXISTS and CREATE TABLE run + // in a single batch so concurrent migrators cannot race. + // The WHERE clause is parameterized; the identifier must use + // bracket-escaping because DDL identifiers can't be parameterized. + query(AssertSqlSafe(format!( + "IF NOT EXISTS (SELECT 1 FROM sys.tables WHERE name = @p1) \ + CREATE TABLE {ident} ( \ + version BIGINT PRIMARY KEY, \ + description NVARCHAR(MAX) NOT NULL, \ + installed_on DATETIME2 NOT NULL DEFAULT SYSUTCDATETIME(), \ + success BIT NOT NULL, \ + checksum VARBINARY(MAX) NOT NULL, \ + execution_time BIGINT NOT NULL \ + );" ))) + .bind(table_name) + .execute(&mut *self) .await?; Ok(()) diff --git a/sqlx-mssql/src/options/parse.rs b/sqlx-mssql/src/options/parse.rs index e481e71692..9ac76eb5e3 100644 --- a/sqlx-mssql/src/options/parse.rs +++ b/sqlx-mssql/src/options/parse.rs @@ -46,7 +46,7 @@ impl MssqlConnectOptions { ); } - for (key, value) in url.query_pairs().into_iter() { + for (key, value) in url.query_pairs() { match &*key { "sslmode" | "ssl_mode" => { options = options.ssl_mode(match &*value { diff --git a/sqlx-mssql/src/testing/mod.rs b/sqlx-mssql/src/testing/mod.rs index a668d687e7..d225139676 100644 --- a/sqlx-mssql/src/testing/mod.rs +++ b/sqlx-mssql/src/testing/mod.rs @@ -94,7 +94,7 @@ impl TestSupport for Mssql { } async fn snapshot(_conn: &mut Self::Connection) -> Result, Error> { - todo!() + Err(Error::Configuration("snapshots are not yet supported for MSSQL".into())) } } diff --git a/sqlx-mssql/src/types/uuid.rs b/sqlx-mssql/src/types/uuid.rs index b06f50898a..6d315b0ec6 100644 --- a/sqlx-mssql/src/types/uuid.rs +++ b/sqlx-mssql/src/types/uuid.rs @@ -54,8 +54,7 @@ impl Encode<'_, Mssql> for uuid::fmt::Hyphenated { &self, buf: &mut Vec, ) -> Result { - let uuid = Uuid::parse_str(&self.to_string())?; - buf.push(MssqlArgumentValue::Uuid(uuid)); + buf.push(MssqlArgumentValue::Uuid(*self.as_uuid())); Ok(IsNull::No) } } diff --git a/sqlx-mssql/src/value.rs b/sqlx-mssql/src/value.rs index ffcc54e871..0219afd1a4 100644 --- a/sqlx-mssql/src/value.rs +++ b/sqlx-mssql/src/value.rs @@ -131,6 +131,8 @@ pub(crate) fn column_data_to_mssql_data( #[cfg(feature = "chrono")] tiberius::ColumnData::DateTime2(Some(dt2)) => { let date = chrono_date_from_days(dt2.date().days() as i64, 1)?; + // SAFETY: TDS time increments at scale 7 max out at 863_999_999_999, well within i64. + #[allow(clippy::cast_possible_wrap)] let ns = dt2.time().increments() as i64 * 10i64.pow(9u32.saturating_sub(dt2.time().scale() as u32)); // infallible: (0,0,0) is always valid @@ -153,12 +155,9 @@ pub(crate) fn column_data_to_mssql_data( let seconds = dt.seconds_fragments() as u32 * 60; let time = chrono::NaiveTime::from_num_seconds_from_midnight_opt(seconds, 0) .ok_or_else(|| { - Error::Protocol( - format!( - "invalid SmallDateTime seconds: {seconds} exceeds seconds-in-a-day" - ) - .into(), - ) + Error::Protocol(format!( + "invalid SmallDateTime seconds: {seconds} exceeds seconds-in-a-day" + )) })?; Ok(MssqlData::NaiveDateTime(chrono::NaiveDateTime::new(date, time))) } @@ -168,6 +167,8 @@ pub(crate) fn column_data_to_mssql_data( } #[cfg(feature = "chrono")] tiberius::ColumnData::Time(Some(t)) => { + // SAFETY: TDS time increments at scale 7 max out at 863_999_999_999, well within i64. + #[allow(clippy::cast_possible_wrap)] let ns = t.increments() as i64 * 10i64.pow(9u32.saturating_sub(t.scale() as u32)); // infallible: (0,0,0) is always valid @@ -178,6 +179,8 @@ pub(crate) fn column_data_to_mssql_data( #[cfg(feature = "chrono")] tiberius::ColumnData::DateTimeOffset(Some(dto)) => { let date = chrono_date_from_days(dto.datetime2().date().days() as i64, 1)?; + // SAFETY: TDS time increments at scale 7 max out at 863_999_999_999, well within i64. + #[allow(clippy::cast_possible_wrap)] let ns = dto.datetime2().time().increments() as i64 * 10i64.pow(9u32.saturating_sub(dto.datetime2().time().scale() as u32)); // infallible: (0,0,0) is always valid @@ -186,17 +189,14 @@ pub(crate) fn column_data_to_mssql_data( let naive = chrono::NaiveDateTime::new(date, time); let offset_secs = dto.offset() as i32 * 60; let fixed_offset = chrono::FixedOffset::east_opt(offset_secs).ok_or_else(|| { - Error::Protocol( - format!("invalid timezone offset: {offset_secs} seconds").into(), - ) + Error::Protocol(format!( + "invalid timezone offset: {offset_secs} seconds" + )) })?; let dt = naive.and_local_timezone(fixed_offset).single().ok_or_else(|| { - Error::Protocol( - format!( - "ambiguous or invalid local time for offset {offset_secs}s" - ) - .into(), - ) + Error::Protocol(format!( + "ambiguous or invalid local time for offset {offset_secs}s" + )) })?; Ok(MssqlData::DateTimeFixedOffset(dt)) } @@ -214,48 +214,46 @@ pub(crate) fn column_data_to_mssql_data( #[cfg(all(feature = "time", not(feature = "chrono")))] tiberius::ColumnData::Date(Some(d)) => { - Ok(MssqlData::TimeDate(time_date_from_days(d.days() as u64, 1)?)) + Ok(MssqlData::TimeDate(time_date_from_days(i64::from(d.days()), 1)?)) } #[cfg(all(feature = "time", not(feature = "chrono")))] tiberius::ColumnData::Time(Some(t)) => { - let ns = t.increments() as u64 + let ns = t.increments() * 10u64.pow(9u32.saturating_sub(t.scale() as u32)); Ok(MssqlData::TimeTime(time_from_sec_fragments(ns)?)) } #[cfg(all(feature = "time", not(feature = "chrono")))] tiberius::ColumnData::DateTime2(Some(dt2)) => { - let date = time_date_from_days(dt2.date().days() as u64, 1)?; - let ns = dt2.time().increments() as u64 + let date = time_date_from_days(i64::from(dt2.date().days()), 1)?; + let ns = dt2.time().increments() * 10u64.pow(9u32.saturating_sub(dt2.time().scale() as u32)); let time = time_from_sec_fragments(ns)?; Ok(MssqlData::TimePrimitiveDateTime(time::PrimitiveDateTime::new(date, time))) } #[cfg(all(feature = "time", not(feature = "chrono")))] tiberius::ColumnData::DateTime(Some(dt)) => { - let date = time_date_from_days(dt.days() as u64, 1900)?; + let date = time_date_from_days(i64::from(dt.days()), 1900)?; let ns = dt.seconds_fragments() as u64 * 1_000_000_000u64 / 300; let time = time_from_sec_fragments(ns)?; Ok(MssqlData::TimePrimitiveDateTime(time::PrimitiveDateTime::new(date, time))) } #[cfg(all(feature = "time", not(feature = "chrono")))] tiberius::ColumnData::SmallDateTime(Some(dt)) => { - let date = time_date_from_days(dt.days() as u64, 1900)?; + let date = time_date_from_days(i64::from(dt.days()), 1900)?; let seconds = dt.seconds_fragments() as u64 * 60; let time = time_from_sec_fragments(seconds * 1_000_000_000)?; Ok(MssqlData::TimePrimitiveDateTime(time::PrimitiveDateTime::new(date, time))) } #[cfg(all(feature = "time", not(feature = "chrono")))] tiberius::ColumnData::DateTimeOffset(Some(dto)) => { - let date = time_date_from_days(dto.datetime2().date().days() as u64, 1)?; - let ns = dto.datetime2().time().increments() as u64 + let date = time_date_from_days(i64::from(dto.datetime2().date().days()), 1)?; + let ns = dto.datetime2().time().increments() * 10u64.pow(9u32.saturating_sub(dto.datetime2().time().scale() as u32)); let time = time_from_sec_fragments(ns)?; let naive = time::PrimitiveDateTime::new(date, time); let offset_secs = dto.offset() as i32 * 60; let offset = time::UtcOffset::from_whole_seconds(offset_secs).map_err(|_| { - Error::Protocol( - format!("invalid UTC offset: {offset_secs} seconds").into(), - ) + Error::Protocol(format!("invalid UTC offset: {offset_secs} seconds")) })?; Ok(MssqlData::TimeOffsetDateTime(naive.assume_offset(offset))) } @@ -293,25 +291,25 @@ pub(crate) fn column_data_to_mssql_data( other => { let debug = format!("{other:?}"); let truncated = if debug.len() > 200 { &debug[..200] } else { &debug }; - Err(Error::Protocol( - format!("unsupported tiberius ColumnData variant: {truncated}").into(), - )) + Err(Error::Protocol(format!( + "unsupported tiberius ColumnData variant: {truncated}" + ))) } } } /// Convert days since `start_year`-01-01 to a `time::Date`. #[cfg(feature = "time")] -fn time_date_from_days(days: u64, start_year: i32) -> Result { +fn time_date_from_days(days: i64, start_year: i32) -> Result { let start = time::Date::from_ordinal_date(start_year, 1).map_err(|_| { - Error::Protocol(format!("invalid start year for date: {start_year}").into()) + Error::Protocol(format!("invalid start year for date: {start_year}")) })?; start - .checked_add(time::Duration::days(days as i64)) + .checked_add(time::Duration::days(days)) .ok_or_else(|| { - Error::Protocol( - format!("date overflow: {days} days from {start_year}-01-01").into(), - ) + Error::Protocol(format!( + "date overflow: {days} days from {start_year}-01-01" + )) }) } @@ -320,26 +318,26 @@ fn time_date_from_days(days: u64, start_year: i32) -> Result fn time_from_sec_fragments(nanoseconds: u64) -> Result { const NANOS_PER_DAY: u64 = 86_400_000_000_000; if nanoseconds >= NANOS_PER_DAY { - return Err(Error::Protocol( - format!( - "time nanoseconds out of range: {nanoseconds} (must be < {NANOS_PER_DAY})" - ) - .into(), - )); + return Err(Error::Protocol(format!( + "time nanoseconds out of range: {nanoseconds} (must be < {NANOS_PER_DAY})" + ))); } - // After the bounds check, hours is 0..=23, minutes 0..=59, seconds 0..=59, - // so the `as u8` casts and `from_hms_nano` are all infallible. + // SAFETY: bounds check above guarantees nanoseconds < 86_400_000_000_000, + // so hours ≤ 23, minutes ≤ 59, seconds ≤ 59 — all fit in u8. + #[allow(clippy::cast_possible_truncation)] let hours = (nanoseconds / 3_600_000_000_000) as u8; let remaining = nanoseconds % 3_600_000_000_000; + #[allow(clippy::cast_possible_truncation)] let minutes = (remaining / 60_000_000_000) as u8; let remaining = remaining % 60_000_000_000; + #[allow(clippy::cast_possible_truncation)] let seconds = (remaining / 1_000_000_000) as u8; + #[allow(clippy::cast_possible_truncation)] let nanos = (remaining % 1_000_000_000) as u32; time::Time::from_hms_nano(hours, minutes, seconds, nanos).map_err(|_| { - Error::Protocol( - format!("invalid time: {hours:02}:{minutes:02}:{seconds:02}.{nanos:09}") - .into(), - ) + Error::Protocol(format!( + "invalid time: {hours:02}:{minutes:02}:{seconds:02}.{nanos:09}" + )) }) } @@ -347,13 +345,13 @@ fn time_from_sec_fragments(nanoseconds: u64) -> Result { #[cfg(feature = "chrono")] fn chrono_date_from_days(days: i64, start_year: i32) -> Result { let start = chrono::NaiveDate::from_ymd_opt(start_year, 1, 1).ok_or_else(|| { - Error::Protocol(format!("invalid start year for date: {start_year}").into()) + Error::Protocol(format!("invalid start year for date: {start_year}")) })?; start .checked_add_signed(chrono::Duration::days(days)) .ok_or_else(|| { - Error::Protocol( - format!("date overflow: {days} days from {start_year}-01-01").into(), - ) + Error::Protocol(format!( + "date overflow: {days} days from {start_year}-01-01" + )) }) } From 0e42f0ed79e7cbba7532f94eace6f9cfcd9c2bf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Thu, 5 Mar 2026 01:30:30 -0500 Subject: [PATCH 29/33] fix: safe UTF-8 truncation, handle XML, take ColumnData by value, use lossless casts - Fix panic on multi-byte UTF-8 in catch-all error truncation by using is_char_boundary to retreat to a valid boundary - Handle ColumnData::Xml(Some(...)) explicitly, converting to MssqlData::String via XmlData::into_string() - Take ColumnData by value instead of by reference, eliminating allocations for String/Binary/Xml variants (into_owned + move instead of to_string/to_vec) - Replace `as u64` with u64::from() for lossless widening casts in time encoding - Extract repeated .time() calls into local bindings for clarity Author: Pablo Carrera --- sqlx-mssql/src/connection/executor.rs | 30 ++++++------- sqlx-mssql/src/value.rs | 64 +++++++++++++++++---------- 2 files changed, 56 insertions(+), 38 deletions(-) diff --git a/sqlx-mssql/src/connection/executor.rs b/sqlx-mssql/src/connection/executor.rs index 1d5080349d..b6c8fdacc5 100644 --- a/sqlx-mssql/src/connection/executor.rs +++ b/sqlx-mssql/src/connection/executor.rs @@ -194,9 +194,9 @@ impl MssqlConnection { let naive = v.naive_local(); let days = days_since_epoch_to_u32((naive.date() - epoch).num_days())?; let time = naive.time(); - let total_ns = time.num_seconds_from_midnight() as u64 + let total_ns = u64::from(time.num_seconds_from_midnight()) * 1_000_000_000 - + (time.nanosecond() as u64 % 1_000_000_000); + + (u64::from(time.nanosecond()) % 1_000_000_000); let increments = total_ns / 100; let offset_minutes = v.offset().local_minus_utc() / 60; @@ -254,10 +254,10 @@ impl MssqlConnection { #[cfg(feature = "time")] MssqlArgumentValue::TimeTime(v) => { let (h, m, s, ns) = v.as_hms_nano(); - let total_ns = h as u64 * 3_600_000_000_000 - + m as u64 * 60_000_000_000 - + s as u64 * 1_000_000_000 - + ns as u64; + let total_ns = u64::from(h) * 3_600_000_000_000 + + u64::from(m) * 60_000_000_000 + + u64::from(s) * 1_000_000_000 + + u64::from(ns); // Scale 7 = 100ns increments let increments = total_ns / 100; let cd = tiberius::ColumnData::Time(Some( @@ -272,10 +272,10 @@ impl MssqlConnection { let epoch = time::Date::from_ordinal_date(1, 1).unwrap(); let days = days_since_epoch_to_u32((date - epoch).whole_days())?; let (h, m, s, ns) = time.as_hms_nano(); - let total_ns = h as u64 * 3_600_000_000_000 - + m as u64 * 60_000_000_000 - + s as u64 * 1_000_000_000 - + ns as u64; + let total_ns = u64::from(h) * 3_600_000_000_000 + + u64::from(m) * 60_000_000_000 + + u64::from(s) * 1_000_000_000 + + u64::from(ns); let increments = total_ns / 100; let cd = tiberius::ColumnData::DateTime2(Some( tiberius::time::DateTime2::new( @@ -293,10 +293,10 @@ impl MssqlConnection { let time = v.time(); let days = days_since_epoch_to_u32((date - epoch).whole_days())?; let (h, m, s, ns) = time.as_hms_nano(); - let total_ns = h as u64 * 3_600_000_000_000 - + m as u64 * 60_000_000_000 - + s as u64 * 1_000_000_000 - + ns as u64; + let total_ns = u64::from(h) * 3_600_000_000_000 + + u64::from(m) * 60_000_000_000 + + u64::from(s) * 1_000_000_000 + + u64::from(ns); let increments = total_ns / 100; let dt2 = tiberius::time::DateTime2::new( tiberius::time::Date::new(days), @@ -390,7 +390,7 @@ async fn collect_results( // Convert tiberius row to MssqlRow by iterating over cells let values: Vec = row .into_iter() - .map(|data| column_data_to_mssql_data(&data)) + .map(column_data_to_mssql_data) .collect::, _>>()?; rows_affected += 1; diff --git a/sqlx-mssql/src/value.rs b/sqlx-mssql/src/value.rs index 0219afd1a4..395e004763 100644 --- a/sqlx-mssql/src/value.rs +++ b/sqlx-mssql/src/value.rs @@ -115,26 +115,30 @@ impl<'r> ValueRef<'r> for MssqlValueRef<'r> { /// Convert a `tiberius::ColumnData` into our owned `MssqlData`. pub(crate) fn column_data_to_mssql_data( - data: &tiberius::ColumnData<'_>, + data: tiberius::ColumnData<'_>, ) -> Result { match data { - tiberius::ColumnData::U8(Some(v)) => Ok(MssqlData::U8(*v)), - tiberius::ColumnData::I16(Some(v)) => Ok(MssqlData::I16(*v)), - tiberius::ColumnData::I32(Some(v)) => Ok(MssqlData::I32(*v)), - tiberius::ColumnData::I64(Some(v)) => Ok(MssqlData::I64(*v)), - tiberius::ColumnData::F32(Some(v)) => Ok(MssqlData::F32(*v)), - tiberius::ColumnData::F64(Some(v)) => Ok(MssqlData::F64(*v)), - tiberius::ColumnData::Bit(Some(v)) => Ok(MssqlData::Bool(*v)), - tiberius::ColumnData::String(Some(v)) => Ok(MssqlData::String(v.to_string())), - tiberius::ColumnData::Binary(Some(v)) => Ok(MssqlData::Binary(v.to_vec())), + tiberius::ColumnData::U8(Some(v)) => Ok(MssqlData::U8(v)), + tiberius::ColumnData::I16(Some(v)) => Ok(MssqlData::I16(v)), + tiberius::ColumnData::I32(Some(v)) => Ok(MssqlData::I32(v)), + tiberius::ColumnData::I64(Some(v)) => Ok(MssqlData::I64(v)), + tiberius::ColumnData::F32(Some(v)) => Ok(MssqlData::F32(v)), + tiberius::ColumnData::F64(Some(v)) => Ok(MssqlData::F64(v)), + tiberius::ColumnData::Bit(Some(v)) => Ok(MssqlData::Bool(v)), + tiberius::ColumnData::String(Some(v)) => Ok(MssqlData::String(v.into_owned())), + tiberius::ColumnData::Binary(Some(v)) => Ok(MssqlData::Binary(v.into_owned())), + tiberius::ColumnData::Xml(Some(xml)) => { + Ok(MssqlData::String(xml.into_owned().into_string())) + } #[cfg(feature = "chrono")] tiberius::ColumnData::DateTime2(Some(dt2)) => { let date = chrono_date_from_days(dt2.date().days() as i64, 1)?; // SAFETY: TDS time increments at scale 7 max out at 863_999_999_999, well within i64. + let t = dt2.time(); #[allow(clippy::cast_possible_wrap)] - let ns = dt2.time().increments() as i64 - * 10i64.pow(9u32.saturating_sub(dt2.time().scale() as u32)); + let ns = t.increments() as i64 + * 10i64.pow(9u32.saturating_sub(t.scale() as u32)); // infallible: (0,0,0) is always valid let time = chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap() + chrono::Duration::nanoseconds(ns); @@ -169,8 +173,8 @@ pub(crate) fn column_data_to_mssql_data( tiberius::ColumnData::Time(Some(t)) => { // SAFETY: TDS time increments at scale 7 max out at 863_999_999_999, well within i64. #[allow(clippy::cast_possible_wrap)] - let ns = - t.increments() as i64 * 10i64.pow(9u32.saturating_sub(t.scale() as u32)); + let ns = t.increments() as i64 + * 10i64.pow(9u32.saturating_sub(t.scale() as u32)); // infallible: (0,0,0) is always valid let time = chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap() + chrono::Duration::nanoseconds(ns); @@ -180,9 +184,10 @@ pub(crate) fn column_data_to_mssql_data( tiberius::ColumnData::DateTimeOffset(Some(dto)) => { let date = chrono_date_from_days(dto.datetime2().date().days() as i64, 1)?; // SAFETY: TDS time increments at scale 7 max out at 863_999_999_999, well within i64. + let t = dto.datetime2().time(); #[allow(clippy::cast_possible_wrap)] - let ns = dto.datetime2().time().increments() as i64 - * 10i64.pow(9u32.saturating_sub(dto.datetime2().time().scale() as u32)); + let ns = t.increments() as i64 + * 10i64.pow(9u32.saturating_sub(t.scale() as u32)); // infallible: (0,0,0) is always valid let time = chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap() + chrono::Duration::nanoseconds(ns); @@ -202,7 +207,7 @@ pub(crate) fn column_data_to_mssql_data( } #[cfg(feature = "uuid")] - tiberius::ColumnData::Guid(Some(v)) => Ok(MssqlData::Uuid(*v)), + tiberius::ColumnData::Guid(Some(v)) => Ok(MssqlData::Uuid(v)), #[cfg(feature = "rust_decimal")] tiberius::ColumnData::Numeric(Some(n)) => { @@ -225,8 +230,9 @@ pub(crate) fn column_data_to_mssql_data( #[cfg(all(feature = "time", not(feature = "chrono")))] tiberius::ColumnData::DateTime2(Some(dt2)) => { let date = time_date_from_days(i64::from(dt2.date().days()), 1)?; - let ns = dt2.time().increments() - * 10u64.pow(9u32.saturating_sub(dt2.time().scale() as u32)); + let t = dt2.time(); + let ns = t.increments() + * 10u64.pow(9u32.saturating_sub(t.scale() as u32)); let time = time_from_sec_fragments(ns)?; Ok(MssqlData::TimePrimitiveDateTime(time::PrimitiveDateTime::new(date, time))) } @@ -247,8 +253,9 @@ pub(crate) fn column_data_to_mssql_data( #[cfg(all(feature = "time", not(feature = "chrono")))] tiberius::ColumnData::DateTimeOffset(Some(dto)) => { let date = time_date_from_days(i64::from(dto.datetime2().date().days()), 1)?; - let ns = dto.datetime2().time().increments() - * 10u64.pow(9u32.saturating_sub(dto.datetime2().time().scale() as u32)); + let t = dto.datetime2().time(); + let ns = t.increments() + * 10u64.pow(9u32.saturating_sub(t.scale() as u32)); let time = time_from_sec_fragments(ns)?; let naive = time::PrimitiveDateTime::new(date, time); let offset_secs = dto.offset() as i32 * 60; @@ -287,10 +294,21 @@ pub(crate) fn column_data_to_mssql_data( | tiberius::ColumnData::Date(None) | tiberius::ColumnData::Time(None) => Ok(MssqlData::Null), - // Unhandled Some(...) variant — real data the driver can't convert + // Unhandled Some(...) variant — real data the driver can't convert. + // Currently unreachable with all features enabled, but kept for forward + // compatibility when tiberius adds new variants. + #[allow(unreachable_patterns)] other => { let debug = format!("{other:?}"); - let truncated = if debug.len() > 200 { &debug[..200] } else { &debug }; + let truncated = if debug.len() > 200 { + let mut end = 200; + while !debug.is_char_boundary(end) { + end -= 1; + } + &debug[..end] + } else { + &debug + }; Err(Error::Protocol(format!( "unsupported tiberius ColumnData variant: {truncated}" ))) From ce69e2a67dd59aabdbf7f1be2244b7ab0f0577ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Tue, 10 Mar 2026 00:49:01 -0500 Subject: [PATCH 30/33] docs: rewrite MSSQL_SUPPORT.md as a complete developer guide Transform the feature overview into a self-contained guide covering connection pooling, query patterns, FromRow/derive macros, error handling, OUTPUT INSERTED, stored procedures, nested transactions, and pool callbacks. Restructure sections to follow a developer's learning path, add recommended feature flag sets, and correct begin_with_isolation usage and advisory lock drop behavior. Author: Pablo Carrera --- MSSQL_SUPPORT.md | 881 +++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 777 insertions(+), 104 deletions(-) diff --git a/MSSQL_SUPPORT.md b/MSSQL_SUPPORT.md index bbc9b4d09a..7decf7644d 100644 --- a/MSSQL_SUPPORT.md +++ b/MSSQL_SUPPORT.md @@ -1,6 +1,6 @@ # MSSQL (SQL Server) Support for SQLx -This document covers all MSSQL/SQL Server additions in the `feat/mssql-support` branch, built on top of the [Tiberius](https://github.com/prisma/tiberius) TDS driver. +A complete developer guide for using SQLx with Microsoft SQL Server, built on the [Tiberius](https://github.com/prisma/tiberius) TDS driver. --- @@ -8,21 +8,25 @@ This document covers all MSSQL/SQL Server additions in the `feat/mssql-support` - [Overview](#overview) - [Getting Started](#getting-started) +- [Feature Flags](#feature-flags) - [Connection & Authentication](#connection--authentication) +- [Connection Pooling](#connection-pooling) - [SSL/TLS](#ssltls) - [Type Mappings](#type-mappings) +- [Querying](#querying) - [Compile-Time Query Macros](#compile-time-query-macros) -- [Any Driver Support](#any-driver-support) -- [Migrations](#migrations) +- [FromRow & Derive Macros](#fromrow--derive-macros) +- [QueryBuilder](#querybuilder) - [Transactions & Isolation Levels](#transactions--isolation-levels) +- [Migrations](#migrations) - [Advisory Locks](#advisory-locks) - [Bulk Insert](#bulk-insert) -- [QueryBuilder](#querybuilder) - [XML Type](#xml-type) +- [Error Handling](#error-handling) +- [Any Driver Support](#any-driver-support) - [Examples](#examples) - [Docker & CI](#docker--ci) - [Test Coverage](#test-coverage) -- [Feature Flags](#feature-flags) --- @@ -34,11 +38,13 @@ Full SQL Server support has been added to SQLx, bringing feature parity with Pos - Four authentication methods (SQL Server, Windows/NTLM, Integrated/GSSAPI, Azure AD) - SSL/TLS with configurable modes - Compile-time checked queries via macros +- Connection pooling with callbacks - Runtime-polymorphic `Any` driver support - Database migrations with `sqlx migrate` - RAII advisory locks via `sp_getapplock`/`sp_releaseapplock` - Bulk insert via the TDS `INSERT BULK` protocol -- Transaction isolation levels +- Transaction isolation levels including `SNAPSHOT` +- Nested transactions via savepoints - Testing infrastructure with Docker Compose (MSSQL 2019 & 2022) **URL schemes:** `mssql://` and `sqlserver://` @@ -47,37 +53,125 @@ Full SQL Server support has been added to SQLx, bringing feature parity with Pos ## Getting Started -Add SQLx with the `mssql` feature to your `Cargo.toml`: +### Add SQLx to Your Project + +SQLx requires three choices in your feature flags: + +1. **Database driver** — `mssql` +2. **Async runtime** — one of `runtime-tokio` or `runtime-async-std` +3. **TLS backend** — one of `tls-native-tls`, `tls-rustls-aws-lc-rs`, `tls-rustls-ring`, or `tls-none` ```toml [dependencies] -sqlx = { version = "0.8", features = ["mssql", "runtime-tokio"] } +sqlx = { version = "0.9", features = [ + "mssql", # SQL Server driver + "runtime-tokio", # async runtime (or runtime-async-std) + "tls-native-tls", # TLS backend (see Feature Flags for options) +] } +tokio = { version = "1", features = ["full"] } ``` -Connect to a database: +> **Tip:** If you're unsure which TLS backend to pick, `tls-native-tls` is the safest default for SQL Server — it uses the platform's native TLS stack (SChannel on Windows, OpenSSL on Linux) and has the best compatibility with SQL Server's TLS implementation. + +### Minimal Example ```rust use sqlx::mssql::MssqlPool; -let pool = MssqlPool::connect("mssql://sa:YourPassword@localhost/mydb").await?; +#[tokio::main] +async fn main() -> Result<(), sqlx::Error> { + let pool = MssqlPool::connect("mssql://sa:YourStrong!Passw0rd@localhost/master").await?; -let row: (i32,) = sqlx::query_as("SELECT @p1") - .bind(42i32) - .fetch_one(&pool) - .await?; + let row: (i32,) = sqlx::query_as("SELECT @p1") + .bind(42i32) + .fetch_one(&pool) + .await?; + + println!("Got: {}", row.0); + Ok(()) +} +``` + +--- + +## Feature Flags + +### Required + +| Feature | Description | +|---------|-------------| +| `mssql` | Enable the MSSQL driver | + +### Async Runtime (pick one) + +| Feature | Description | +|---------|-------------| +| `runtime-tokio` | Use Tokio | +| `runtime-async-std` | Use async-std (via async-global-executor / smol) | + +### TLS Backend (pick one) + +| Feature | Description | +|---------|-------------| +| `tls-native-tls` | Platform-native TLS (recommended for SQL Server) | +| `tls-rustls-aws-lc-rs` | Rustls with AWS LC crypto | +| `tls-rustls-ring` | Rustls with ring crypto | +| `tls-none` | No TLS support | + +### Type Integrations + +| Feature | Description | +|---------|-------------| +| `json` | JSON type support via `serde_json` (stored as `NVARCHAR`) | +| `uuid` | `uuid::Uuid` ↔ `UNIQUEIDENTIFIER` | +| `chrono` | `chrono` datetime types | +| `time` | `time` crate datetime types | +| `rust_decimal` | `rust_decimal::Decimal` ↔ `DECIMAL`/`NUMERIC`/`MONEY` | +| `bigdecimal` | `bigdecimal::BigDecimal` ↔ `DECIMAL`/`NUMERIC`/`MONEY` | + +### Authentication + +| Feature | Description | +|---------|-------------| +| `winauth` | Windows/NTLM authentication | +| `integrated-auth-gssapi` | Integrated auth (Kerberos on Unix, SSPI on Windows) | + +### Functionality + +| Feature | Description | +|---------|-------------| +| `any` | Runtime-polymorphic `Any` driver | +| `migrate` | Database migrations | +| `offline` | Offline mode for compile-time macros (no live database needed in CI) | + +### Recommended Starter Set + +For most applications: + +```toml +sqlx = { version = "0.9", features = [ + "mssql", + "runtime-tokio", + "tls-native-tls", + "migrate", + "json", + "chrono", # or "time" + "uuid", + "rust_decimal", +] } ``` --- ## Connection & Authentication -**Connection string format:** +### Connection String Format ``` mssql://[user[:password]@]host[:port][/database][?properties] ``` -**Connection options:** +### Connection Options | Option | Default | Description | |--------|---------|-------------| @@ -91,6 +185,38 @@ mssql://[user[:password]@]host[:port][/database][?properties] | `statement-cache-capacity` | `100` | Max cached prepared statements | | `application_intent` | `read_write` | `read_write` or `read_only` (Always On replicas) | +### Programmatic Configuration + +Use `MssqlConnectOptions` for full control over connection settings: + +```rust +use sqlx::mssql::MssqlConnectOptions; + +let opts = MssqlConnectOptions::new() + .host("db.example.com") + .port(1433) + .username("app_user") + .password("s3cret") + .database("myapp") + .app_name("my-service") + .statement_cache_capacity(200) + .application_intent_read_only(false); + +let pool = MssqlPool::connect_with(opts).await?; +``` + +### URL-Based Configuration + +```rust +use sqlx::mssql::MssqlPool; + +let pool = MssqlPool::connect( + "mssql://app_user:s3cret@db.example.com:1433/myapp?app_name=my-service" +).await?; +``` + +Both approaches are equivalent. Use `MssqlConnectOptions` when you need to build connection parameters dynamically (e.g., from environment variables or a config file). + ### Authentication Methods **1. SQL Server Auth (default)** @@ -123,7 +249,7 @@ let opts = MssqlConnectOptions::new() **4. Azure AD Token Auth** -Pass a bearer token for Azure Active Directory authentication. +Pass a bearer token for Azure Active Directory authentication. This takes precedence over all other auth methods. ```rust let opts = MssqlConnectOptions::new() @@ -133,6 +259,113 @@ let opts = MssqlConnectOptions::new() --- +## Connection Pooling + +For production applications, always use a connection pool rather than individual connections. + +### Basic Pool + +```rust +use sqlx::mssql::MssqlPool; + +// Simple — uses default pool settings +let pool = MssqlPool::connect("mssql://sa:password@localhost/mydb").await?; +``` + +### Configuring the Pool + +```rust +use sqlx::mssql::{MssqlPool, MssqlPoolOptions}; +use std::time::Duration; + +let pool = MssqlPoolOptions::new() + .max_connections(20) + .min_connections(5) + .acquire_timeout(Duration::from_secs(10)) + .idle_timeout(Duration::from_secs(600)) + .max_lifetime(Duration::from_secs(1800)) + .test_before_acquire(true) + .connect("mssql://sa:password@localhost/mydb") + .await?; +``` + +### Pool Configuration Reference + +| Option | Default | Description | +|--------|---------|-------------| +| `max_connections` | `10` | Maximum number of connections in the pool | +| `min_connections` | `0` | Minimum idle connections maintained (best-effort) | +| `acquire_timeout` | `30s` | Max time to wait for a connection (includes all phases) | +| `idle_timeout` | `10min` | Close connections idle longer than this | +| `max_lifetime` | `30min` | Close connections older than this | +| `test_before_acquire` | `true` | Ping idle connections before returning them | +| `acquire_slow_threshold` | `2s` | Log a warning for acquires slower than this | + +### Eager vs Lazy Connection + +```rust +// connect() — opens at least one connection immediately, fails fast on bad credentials +let pool = MssqlPoolOptions::new() + .connect("mssql://sa:password@localhost/mydb") + .await?; + +// connect_lazy() — no connections opened until first use +// Useful in tests or when the database may not be available at startup +let pool = MssqlPoolOptions::new() + .connect_lazy("mssql://sa:password@localhost/mydb")?; +``` + +### Pool Callbacks + +Callbacks let you run logic at key points in a connection's lifecycle: + +```rust +let pool = MssqlPoolOptions::new() + .max_connections(10) + // Called after a new connection is established + .after_connect(|conn, _metadata| { + Box::pin(async move { + // e.g., SET session options + sqlx::query("SET ANSI_NULLS ON") + .execute(&mut *conn) + .await?; + Ok(()) + }) + }) + // Called before returning an idle connection from the pool + .before_acquire(|conn, _metadata| { + Box::pin(async move { + // Return Ok(true) to use this connection + // Return Ok(false) to close it and try another + Ok(true) + }) + }) + // Called when a connection is returned to the pool + .after_release(|conn, _metadata| { + Box::pin(async move { + // Return Ok(true) to keep in the pool + // Return Ok(false) to close it + Ok(true) + }) + }) + .connect("mssql://sa:password@localhost/mydb") + .await?; +``` + +Each callback receives a `PoolConnectionMetadata` with: +- `age` — time since the connection was first opened +- `idle_for` — time the connection has been idle (only meaningful in `before_acquire`) + +### Production Tuning Tips + +- Set `max_connections` based on your workload and SQL Server's `max worker threads` setting. A good starting point is 2× the number of CPU cores. +- Set `min_connections` to keep a warm pool and avoid cold-start latency. +- Keep `max_lifetime` at 30 minutes or less to cycle connections and pick up DNS changes. +- Use `after_connect` to set session-level options (e.g., `SET ANSI_NULLS ON`). +- Use `test_before_acquire(true)` (the default) in production. Disable only if latency is critical and you handle stale connections at the application level. + +--- + ## SSL/TLS Configurable encryption modes for the TDS connection. @@ -153,28 +386,42 @@ Configurable encryption modes for the TDS connection. | `trust_server_certificate` | Trust without validation (default: `false`) | | `trust_server_certificate_ca` | Path to CA certificate file (`.pem`, `.crt`, `.der`) | +> **Note:** `trust_server_certificate` and `trust_server_certificate_ca` are mutually exclusive. If both are set, the CA path takes precedence. + ``` mssql://sa:password@localhost/mydb?sslmode=required&trust_server_certificate=true ``` +**Programmatic configuration:** + +```rust +use sqlx::mssql::{MssqlConnectOptions, MssqlSslMode}; + +let opts = MssqlConnectOptions::new() + .host("db.example.com") + .ssl_mode(MssqlSslMode::Required) + .trust_server_certificate(false) + .trust_server_certificate_ca("/path/to/ca.pem"); +``` + --- ## Type Mappings ### Primitive Types -| Rust Type | SQL Server Type(s) | -|-----------|-------------------| -| `bool` | `BIT` | -| `u8` | `TINYINT` (0–255) | -| `i8` | `TINYINT` (0–127) | -| `i16` | `SMALLINT` | -| `i32` | `INT` | -| `i64` | `BIGINT` | -| `f32` | `REAL`, `FLOAT` | -| `f64` | `REAL`, `FLOAT`, `MONEY`, `SMALLMONEY` | -| `&str` / `String` | `NVARCHAR` | -| `&[u8]` / `Vec` | `VARBINARY` | +| Rust Type | SQL Server Type(s) | Notes | +|-----------|-------------------|-------| +| `bool` | `BIT` | | +| `u8` | `TINYINT` | Unsigned, full range 0–255 | +| `i8` | `TINYINT` | **Only 0–127** (SQL Server TINYINT is unsigned; values 128–255 don't fit in `i8`) | +| `i16` | `SMALLINT` | | +| `i32` | `INT` | | +| `i64` | `BIGINT` | | +| `f32` | `REAL`, `FLOAT` | | +| `f64` | `REAL`, `FLOAT`, `MONEY`, `SMALLMONEY` | | +| `&str` / `String` | `NVARCHAR` | | +| `&[u8]` / `Vec` | `VARBINARY` | | ### Feature-Gated Types @@ -219,7 +466,9 @@ mssql://sa:password@localhost/mydb?sslmode=required&trust_server_certificate=tru | Rust Type | SQL Server Type | |-----------|----------------| -| `serde_json::Value` / `Json` | `NVARCHAR` (stored as JSON string) | +| `serde_json::Value` / `Json` | `NVARCHAR` | + +> **Note:** SQL Server has no native JSON column type. JSON is stored as `NVARCHAR` text. You can still use SQL Server's built-in JSON functions (`JSON_VALUE`, `OPENJSON`, etc.) in your queries. #### XML @@ -231,6 +480,162 @@ mssql://sa:password@localhost/mydb?sslmode=required&trust_server_certificate=tru All types above support `Option` for nullable columns. +### Runtime Type Inspection + +Use `MssqlTypeInfo` to inspect column types at runtime: + +```rust +use sqlx::TypeInfo; + +let statement = conn.prepare("SELECT id, name FROM users".into_sql_str()).await?; +assert_eq!(statement.column(0).type_info().name(), "BIGINT"); +assert_eq!(statement.column(1).type_info().name(), "NVARCHAR"); +``` + +--- + +## Querying + +MSSQL uses `@p1`, `@p2`, `@p3`, ... as parameter placeholders (not `$1` or `?`). + +### Basic Queries + +```rust +use sqlx::Row; + +// Execute a statement (INSERT, UPDATE, DELETE) +let result = sqlx::query("UPDATE users SET active = 1 WHERE id = @p1") + .bind(42i32) + .execute(&pool) + .await?; +println!("Rows affected: {}", result.rows_affected()); + +// Fetch a single row +let row = sqlx::query("SELECT id, name FROM users WHERE id = @p1") + .bind(1i32) + .fetch_one(&pool) + .await?; +let name: String = row.get("name"); + +// Fetch all rows +let rows = sqlx::query("SELECT id, name FROM users") + .fetch_all(&pool) + .await?; + +// Fetch optional (returns None if no rows) +let maybe_row = sqlx::query("SELECT id FROM users WHERE email = @p1") + .bind("alice@example.com") + .fetch_optional(&pool) + .await?; +``` + +### Typed Queries with `query_as` + +```rust +let user: (i32, String) = sqlx::query_as("SELECT id, name FROM users WHERE id = @p1") + .bind(1i32) + .fetch_one(&pool) + .await?; + +// Or with a named struct (see FromRow section) +let user: User = sqlx::query_as("SELECT id, name FROM users WHERE id = @p1") + .bind(1i32) + .fetch_one(&pool) + .await?; +``` + +### Scalar Queries + +```rust +let count: i32 = sqlx::query_scalar("SELECT COUNT(*) FROM users") + .fetch_one(&pool) + .await?; +``` + +### Streaming with `fetch` + +For large result sets, use `fetch` to stream rows without loading them all into memory: + +```rust +use futures::TryStreamExt; + +let mut stream = sqlx::query("SELECT id, name FROM users") + .fetch(&pool); + +while let Some(row) = stream.try_next().await? { + let id: i32 = row.get("id"); + // process row... +} +``` + +### Row Access + +```rust +use sqlx::Row; + +let row = sqlx::query("SELECT id, name FROM users") + .fetch_one(&pool) + .await?; + +// By column name +let name: String = row.get("name"); + +// By column index (0-based) +let id: i32 = row.get(0); + +// Fallible access (returns Result) +let name: String = row.try_get("name")?; +``` + +### Custom Row Mapping + +```rust +let value = sqlx::query("SELECT 1 + @p1") + .bind(5_i32) + .try_map(|row: MssqlRow| row.try_get::(0)) + .fetch_one(&pool) + .await?; +``` + +### OUTPUT INSERTED (MSSQL's RETURNING) + +SQL Server does not support the `RETURNING` clause. Use `OUTPUT INSERTED` instead to get values from inserted/updated rows: + +```rust +// Get the auto-generated ID after INSERT +let id: i64 = sqlx::query_scalar( + "INSERT INTO users (name) OUTPUT INSERTED.id VALUES (@p1)" +) + .bind("Alice") + .fetch_one(&pool) + .await?; + +// Get multiple columns +let row = sqlx::query( + "INSERT INTO users (name, email) OUTPUT INSERTED.id, INSERTED.created_at VALUES (@p1, @p2)" +) + .bind("Alice") + .bind("alice@example.com") + .fetch_one(&pool) + .await?; +``` + +### Calling Stored Procedures + +Use `EXEC` to call stored procedures: + +```rust +let rows = sqlx::query("EXEC GetUsersByRole @p1") + .bind("admin") + .fetch_all(&pool) + .await?; + +// With output parameters, use a query that captures results +let result: (i32,) = sqlx::query_as("EXEC CountUsers") + .fetch_one(&pool) + .await?; +``` + --- ## Compile-Time Query Macros @@ -260,65 +665,158 @@ let count = sqlx::query_scalar!("SELECT COUNT(*) FROM users") .await?; ``` -**Offline mode** is also supported — run `cargo sqlx prepare` to generate query metadata for CI builds without a live database. +### Offline Mode + +For CI builds without a live database, use offline mode: + +```bash +# Generate query metadata (run with DATABASE_URL set) +cargo sqlx prepare + +# This creates a .sqlx/ directory with cached query metadata. +# Commit this directory to version control. +``` + +Then build without a database: + +```bash +SQLX_OFFLINE=true cargo build +``` + +Enable the `offline` feature flag to use this capability. --- -## Any Driver Support +## FromRow & Derive Macros -MSSQL is fully integrated with the `Any` runtime-polymorphic driver, enabled via the `any` feature flag. +### Basic FromRow + +Map query results directly to a struct: ```rust -use sqlx::any::AnyPool; +#[derive(sqlx::FromRow)] +struct User { + id: i32, + name: String, + email: Option, +} -// Connects to whichever database the URL points to -let pool = AnyPool::connect("mssql://sa:password@localhost/mydb").await?; +let user: User = sqlx::query_as("SELECT id, name, email FROM users WHERE id = @p1") + .bind(1i32) + .fetch_one(&pool) + .await?; +``` -let rows = sqlx::query("SELECT 1 + 1 AS result") - .fetch_all(&pool) +### Enum Types with `#[derive(Type)]` + +**Integer-repr enums** map to SQL integer columns: + +```rust +#[derive(sqlx::Type, Debug, PartialEq)] +#[repr(i32)] +enum Status { + Active = 1, + Inactive = 0, + Banned = -1, +} + +// Works with INT columns +let status: Status = sqlx::query_scalar("SELECT status FROM users WHERE id = @p1") + .bind(1i32) + .fetch_one(&pool) .await?; ``` -All standard operations work through `Any`: queries, transactions, ping, close, and prepared statements. +**Transparent wrappers** create newtypes over existing SQL types: ---- +```rust +#[derive(sqlx::Type, Debug, PartialEq)] +#[sqlx(transparent)] +struct UserId(i64); -## Migrations +let id: UserId = sqlx::query_scalar("SELECT id FROM users WHERE id = @p1") + .bind(1i64) + .fetch_one(&pool) + .await?; +``` -MSSQL supports the full `sqlx migrate` workflow. +### Combining FromRow and Type -```bash -# Create a new migration -sqlx migrate add create_users_table +```rust +#[derive(sqlx::Type, Debug, PartialEq)] +#[repr(i16)] +enum Priority { + Low = 0, + Medium = 1, + High = 2, +} -# Run pending migrations -sqlx migrate run +#[derive(sqlx::FromRow, Debug)] +struct Task { + id: i32, + title: String, + priority: Priority, +} -# Revert the last migration -sqlx migrate revert +let task: Task = sqlx::query_as("SELECT id, title, priority FROM tasks WHERE id = @p1") + .bind(1i32) + .fetch_one(&pool) + .await?; ``` -**Programmatic usage:** +--- + +## QueryBuilder + +`QueryBuilder` generates MSSQL-style parameter placeholders (`@p1`, `@p2`, ...) automatically: ```rust -sqlx::migrate!("./migrations") - .run(&pool) +use sqlx::QueryBuilder; +use sqlx::mssql::Mssql; + +let mut qb: QueryBuilder = QueryBuilder::new("SELECT * FROM users WHERE "); +qb.push("name = ").push_bind("Alice"); +qb.push(" AND age > ").push_bind(21i32); +// Produces: SELECT * FROM users WHERE name = @p1 AND age > @p2 + +let users = qb.build_query_as::() + .fetch_all(&pool) .await?; ``` -**Database lifecycle:** +### Dynamic WHERE Clauses -- `create_database(url)` — Creates a database via `CREATE DATABASE [name]` -- `database_exists(url)` — Checks existence via `DB_ID()` -- `drop_database(url)` — Drops with `ALTER DATABASE SET SINGLE_USER WITH ROLLBACK IMMEDIATE` for cleanup +```rust +let mut qb: QueryBuilder = QueryBuilder::new("SELECT * FROM users WHERE 1=1"); -**No-transaction migrations** are supported for DDL operations that cannot run inside a transaction. +if let Some(name) = filter_name { + qb.push(" AND name = ").push_bind(name); +} +if let Some(min_age) = filter_min_age { + qb.push(" AND age >= ").push_bind(min_age); +} + +let results = qb.build_query_as::().fetch_all(&pool).await?; +``` + +### Reset and Rebuild + +```rust +let mut qb: QueryBuilder = QueryBuilder::new("SELECT * FROM users"); +let query = qb.build(); +// ... use query ... + +// Reset to build a new query with the same builder +qb.reset(); +qb.push("SELECT COUNT(*) FROM users"); +let count_query = qb.build(); +``` --- ## Transactions & Isolation Levels -Standard transaction support with configurable isolation levels. +### Basic Transactions ```rust let mut tx = pool.begin().await?; @@ -328,7 +826,43 @@ sqlx::query("INSERT INTO users (name) VALUES (@p1)") .execute(&mut *tx) .await?; +sqlx::query("INSERT INTO audit_log (action) VALUES (@p1)") + .bind("user_created") + .execute(&mut *tx) + .await?; + tx.commit().await?; +// Or: tx.rollback().await?; +``` + +### Nested Transactions (Savepoints) + +Calling `begin()` on an existing transaction creates a savepoint: + +```rust +let mut tx = pool.begin().await?; + +sqlx::query("INSERT INTO users (id, name) VALUES (@p1, @p2)") + .bind(1i32) + .bind("Alice") + .execute(&mut *tx) + .await?; + +// Nested transaction — creates a savepoint +let mut savepoint = tx.begin().await?; + +sqlx::query("INSERT INTO users (id, name) VALUES (@p1, @p2)") + .bind(2i32) + .bind("Bob") + .execute(&mut *savepoint) + .await?; + +// Roll back only the inner transaction +savepoint.rollback().await?; +// Bob's insert is undone, but Alice's remains + +tx.commit().await?; +// Alice is committed, Bob is not ``` ### Isolation Levels @@ -341,14 +875,64 @@ tx.commit().await?; | `Snapshot` | Row versioning-based isolation | | `Serializable` | Strictest isolation | +> **Important:** `begin_with_isolation` is a method on `MssqlConnection`, not on `Pool`. You must acquire a connection first: + ```rust use sqlx::mssql::MssqlIsolationLevel; -let mut tx = pool +let mut conn = pool.acquire().await?; +let mut tx = conn .begin_with_isolation(MssqlIsolationLevel::Snapshot) .await?; + +sqlx::query("SELECT * FROM accounts WHERE id = @p1") + .bind(1i32) + .fetch_one(&mut *tx) + .await?; + +tx.commit().await?; +``` + +> **Note:** `Snapshot` isolation requires the database to have `ALLOW_SNAPSHOT_ISOLATION` enabled: +> ```sql +> ALTER DATABASE [mydb] SET ALLOW_SNAPSHOT_ISOLATION ON; +> ``` + +--- + +## Migrations + +MSSQL supports the full `sqlx migrate` workflow. + +```bash +# Create a new migration +sqlx migrate add create_users_table + +# Run pending migrations +sqlx migrate run + +# Revert the last migration +sqlx migrate revert +``` + +**Programmatic usage:** + +```rust +sqlx::migrate!("./migrations") + .run(&pool) + .await?; ``` +**Database lifecycle:** + +- `create_database(url)` — Creates a database via `CREATE DATABASE [name]` +- `database_exists(url)` — Checks existence via `DB_ID()` +- `drop_database(url)` — Drops with `ALTER DATABASE SET SINGLE_USER WITH ROLLBACK IMMEDIATE` for cleanup + +**No-transaction migrations** are supported for DDL operations that cannot run inside a transaction. + +Migration files use standard SQL Server syntax. Use bracket-quoted identifiers (`[schema].[table]`) for schema-qualified objects. + --- ## Advisory Locks @@ -374,47 +958,52 @@ let lock = MssqlAdvisoryLock::new("my_resource"); // Or with a specific mode let lock = MssqlAdvisoryLock::with_mode("my_resource", MssqlAdvisoryLockMode::Shared); -// RAII guard (preferred) — lock released when guard is dropped +// RAII guard — acquire and release let guard = lock.acquire_guard(&mut conn).await?; // ... do work while lock is held ... let conn = guard.release_now().await?; // explicit release // Non-blocking attempt -if let Some(guard) = lock.try_acquire_guard(&mut conn).await? { - // lock acquired +match lock.try_acquire_guard(&mut conn).await? { + either::Either::Left(guard) => { + // Lock acquired + let conn = guard.release_now().await?; + } + either::Either::Right(conn) => { + // Lock not available + } } + +// Manual acquire/release (without guard) +lock.acquire(&mut conn).await?; +// ... do work ... +lock.release(&mut conn).await?; ``` +> **Warning:** Unlike PostgreSQL advisory locks, MSSQL advisory lock guards do **NOT** auto-release on drop. If you drop the guard without calling `release_now()` or `leak()`, a warning is logged and the lock remains held until the connection is closed or returned to the pool. Always call `release_now()` explicitly. + --- ## Bulk Insert -High-performance data loading via the TDS `INSERT BULK` protocol. +High-performance data loading via the TDS `INSERT BULK` protocol. The target table must already exist. ```rust +use sqlx::mssql::IntoRow; + let mut bulk = conn.bulk_insert("my_table").await?; -for item in &data { - bulk.send(tiberius::IntoRow::into_row(item)).await?; -} +bulk.send(("Alice", 30_i32).into_row()).await?; +bulk.send(("Bob", 25_i32).into_row()).await?; +bulk.send(("Carol", 28_i32).into_row()).await?; let rows_affected = bulk.finalize().await?; +assert_eq!(rows_affected, 3); ``` -Supports tuples up to 10 elements via `tiberius::IntoRow`. - ---- - -## QueryBuilder +> **Important:** You **must** call `finalize()` to flush buffered data. If the `MssqlBulkInsert` is dropped without calling `finalize()`, buffered rows are lost. -MSSQL uses `@p1`, `@p2`, etc. as parameter placeholders. The `QueryBuilder` handles this automatically: - -```rust -let mut qb = QueryBuilder::::new("SELECT * FROM users WHERE "); -qb.push("name = ").push_bind("Alice"); -qb.push(" AND age > ").push_bind(21); -// Produces: SELECT * FROM users WHERE name = @p1 AND age > @p2 -``` +Tuple elements map to table columns in order. Tuples up to **10 elements** are supported via `tiberius::IntoRow`. --- @@ -439,14 +1028,117 @@ let result: MssqlXml = sqlx::query_scalar("SELECT content FROM docs") --- -## Examples +## Error Handling + +### Error Types + +All SQLx operations return `sqlx::Error`. For database-specific errors, downcast to `MssqlDatabaseError`: + +```rust +use sqlx::error::ErrorKind; + +let result = sqlx::query("INSERT INTO users (id, name) VALUES (@p1, @p2)") + .bind(1i32) + .bind("Alice") + .execute(&pool) + .await; + +match result { + Ok(r) => println!("Inserted {} rows", r.rows_affected()), + Err(sqlx::Error::Database(db_err)) => { + // Classify the error + match db_err.kind() { + ErrorKind::UniqueViolation => { + println!("Duplicate key: {}", db_err.message()); + } + ErrorKind::ForeignKeyViolation => { + println!("Foreign key constraint failed"); + } + ErrorKind::NotNullViolation => { + println!("Required field is null"); + } + ErrorKind::CheckViolation => { + println!("Check constraint failed"); + } + _ => { + println!("Database error: {}", db_err.message()); + } + } + } + Err(e) => println!("Other error: {}", e), +} +``` + +### MssqlDatabaseError Fields -A full CRUD Todo application is available at `examples/mssql/todos/`, demonstrating: +When you need SQL Server-specific error details, downcast further: -- Connection pooling -- Migrations -- Query execution -- Error handling +```rust +use sqlx::mssql::MssqlDatabaseError; + +if let sqlx::Error::Database(db_err) = &err { + if let Some(mssql_err) = db_err.try_downcast_ref::() { + println!("Error number: {}", mssql_err.number()); // SQL Server error number + println!("State: {}", mssql_err.state()); // Error state + println!("Class: {}", mssql_err.class()); // Severity class + println!("Message: {}", mssql_err.message()); // Error message + println!("Server: {:?}", mssql_err.server()); // Server name (Option) + println!("Procedure: {:?}", mssql_err.procedure()); // Stored procedure name (Option) + } +} +``` + +### ErrorKind Mapping + +| SQL Server Error Number | ErrorKind | +|------------------------|-----------| +| 2601, 2627 | `UniqueViolation` | +| 547 | `ForeignKeyViolation` | +| 515 | `NotNullViolation` | +| 2628 | `CheckViolation` | +| All others | `Other` | + +### Connection Recovery + +Connections remain usable after query errors: + +```rust +// This query fails +let result = sqlx::query("SELECT * FROM nonexistent_table") + .execute(&mut conn) + .await; +assert!(result.is_err()); + +// Connection is still valid +let val: (i32,) = sqlx::query_as("SELECT 42") + .fetch_one(&mut conn) + .await?; +``` + +--- + +## Any Driver Support + +MSSQL is fully integrated with the `Any` runtime-polymorphic driver, enabled via the `any` feature flag. + +```rust +use sqlx::any::AnyPool; + +// Connects to whichever database the URL points to +let pool = AnyPool::connect("mssql://sa:password@localhost/mydb").await?; + +let rows = sqlx::query("SELECT 1 + 1 AS result") + .fetch_all(&pool) + .await?; +``` + +All standard operations work through `Any`: queries, transactions, ping, close, and prepared statements. + +--- + +## Examples + +A full CRUD Todo application is available at `examples/mssql/todos/`, demonstrating connection pooling, migrations, query execution, and error handling. --- @@ -493,22 +1185,3 @@ Comprehensive test suite in `tests/mssql/`: | Query builder | `query_builder.rs` | Dynamic query construction, parameter handling | | Error handling | `error.rs` | Database error inspection, error details | | Compile-time macros | `tests/mssql-macros/` | Online and offline macro verification | - ---- - -## Feature Flags - -| Feature | Description | -|---------|-------------| -| `mssql` | Enable the MSSQL driver | -| `any` | Enable runtime-polymorphic `Any` driver | -| `migrate` | Enable database migrations | -| `json` | JSON type support via `serde_json` | -| `uuid` | `uuid::Uuid` type support | -| `chrono` | `chrono` datetime types | -| `time` | `time` crate datetime types | -| `rust_decimal` | `rust_decimal::Decimal` support | -| `bigdecimal` | `bigdecimal::BigDecimal` support | -| `winauth` | Windows/NTLM authentication | -| `integrated-auth-gssapi` | Integrated auth (Kerberos on Unix, SSPI on Windows) | -| `offline` | Offline mode for compile-time macros | From 3ed869d0b8a5c5da83068bb5757455527831e57f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Mon, 23 Mar 2026 23:16:19 -0500 Subject: [PATCH 31/33] fix: address code review findings across MSSQL driver Return Result from build_url(), document expect() invariants for UTC offsets and epoch dates, safe UTF-8 truncation, lossless numeric casts, parameterized sp_describe calls, and resolve all clippy deny violations. Author: Pablo Carrera --- examples/mssql/todos/src/main.rs | 10 +- sqlx-macros-core/src/database/mod.rs | 7 +- sqlx-mssql/src/any.rs | 5 +- sqlx-mssql/src/arguments.rs | 6 +- sqlx-mssql/src/connection/executor.rs | 128 ++++++++++++-------------- sqlx-mssql/src/error.rs | 18 +++- sqlx-mssql/src/migrate.rs | 17 ++-- sqlx-mssql/src/options/connect.rs | 1 + sqlx-mssql/src/options/mod.rs | 20 +++- sqlx-mssql/src/options/parse.rs | 77 +++++++++------- sqlx-mssql/src/testing/mod.rs | 11 ++- sqlx-mssql/src/transaction.rs | 12 +-- sqlx-mssql/src/types/bigdecimal.rs | 10 +- sqlx-mssql/src/types/bool.rs | 10 +- sqlx-mssql/src/types/bytes.rs | 15 +-- sqlx-mssql/src/types/chrono.rs | 44 +++------ sqlx-mssql/src/types/float.rs | 10 +- sqlx-mssql/src/types/int.rs | 30 ++---- sqlx-mssql/src/types/json.rs | 5 +- sqlx-mssql/src/types/rust_decimal.rs | 10 +- sqlx-mssql/src/types/str.rs | 10 +- sqlx-mssql/src/types/time.rs | 34 ++----- sqlx-mssql/src/types/uuid.rs | 10 +- sqlx-mssql/src/types/xml.rs | 5 +- sqlx-mssql/src/value.rs | 106 +++++++++++---------- sqlx-test/src/lib.rs | 2 +- tests/any/any.rs | 2 +- tests/mssql/advisory-lock.rs | 10 +- tests/mssql/bulk-insert.rs | 8 +- tests/mssql/error.rs | 7 +- tests/mssql/isolation-level.rs | 12 +-- tests/mssql/migrate.rs | 8 +- tests/mssql/mssql.rs | 29 +++--- tests/mssql/test-attr.rs | 26 ++---- tests/mssql/types.rs | 2 +- 35 files changed, 323 insertions(+), 394 deletions(-) diff --git a/examples/mssql/todos/src/main.rs b/examples/mssql/todos/src/main.rs index c3c355b393..27bf395c04 100644 --- a/examples/mssql/todos/src/main.rs +++ b/examples/mssql/todos/src/main.rs @@ -45,12 +45,10 @@ async fn main() -> anyhow::Result<()> { async fn add_todo(pool: &MssqlPool, description: String) -> anyhow::Result { // MSSQL uses OUTPUT INSERTED instead of RETURNING - let rec = sqlx::query( - "INSERT INTO todos (description) OUTPUT INSERTED.id VALUES (@p1)", - ) - .bind(&description) - .fetch_one(pool) - .await?; + let rec = sqlx::query("INSERT INTO todos (description) OUTPUT INSERTED.id VALUES (@p1)") + .bind(&description) + .fetch_one(pool) + .await?; Ok(rec.get::("id")) } diff --git a/sqlx-macros-core/src/database/mod.rs b/sqlx-macros-core/src/database/mod.rs index c108b70d50..f6f864d95b 100644 --- a/sqlx-macros-core/src/database/mod.rs +++ b/sqlx-macros-core/src/database/mod.rs @@ -10,7 +10,12 @@ use std::collections::hash_map; use std::collections::HashMap; use std::sync::{LazyLock, Mutex}; -#[cfg(any(feature = "postgres", feature = "mysql", feature = "mssql", feature = "_sqlite"))] +#[cfg(any( + feature = "postgres", + feature = "mysql", + feature = "mssql", + feature = "_sqlite" +))] mod impls; pub trait DatabaseExt: Database + TypeChecking { diff --git a/sqlx-mssql/src/any.rs b/sqlx-mssql/src/any.rs index a1f6ace02a..218bac33fc 100644 --- a/sqlx-mssql/src/any.rs +++ b/sqlx-mssql/src/any.rs @@ -179,8 +179,9 @@ impl<'a> TryFrom<&'a MssqlTypeInfo> for AnyTypeInfo { AnyTypeInfoKind::Text } "UNIQUEIDENTIFIER" => AnyTypeInfoKind::Text, - "DATE" | "TIME" | "DATETIME" | "DATETIME2" | "SMALLDATETIME" - | "DATETIMEOFFSET" => AnyTypeInfoKind::Text, + "DATE" | "TIME" | "DATETIME" | "DATETIME2" | "SMALLDATETIME" | "DATETIMEOFFSET" => { + AnyTypeInfoKind::Text + } _ => { return Err(sqlx_core::Error::AnyDriverError( format!("Any driver does not support MSSQL type {type_info:?}").into(), diff --git a/sqlx-mssql/src/arguments.rs b/sqlx-mssql/src/arguments.rs index 7c438f5c1b..9b71e8e20a 100644 --- a/sqlx-mssql/src/arguments.rs +++ b/sqlx-mssql/src/arguments.rs @@ -21,7 +21,11 @@ impl MssqlArguments { let is_null = value.encode(&mut self.values)?; if is_null.is_null() { // If the encoder signaled null but didn't push a value, push a Null - if self.values.last().is_none_or(|v| !matches!(v, MssqlArgumentValue::Null)) { + if self + .values + .last() + .is_none_or(|v| !matches!(v, MssqlArgumentValue::Null)) + { self.values.push(MssqlArgumentValue::Null); } } diff --git a/sqlx-mssql/src/connection/executor.rs b/sqlx-mssql/src/connection/executor.rs index b6c8fdacc5..254086f4da 100644 --- a/sqlx-mssql/src/connection/executor.rs +++ b/sqlx-mssql/src/connection/executor.rs @@ -7,9 +7,7 @@ use crate::statement::{MssqlStatement, MssqlStatementMetadata}; use crate::type_info::{type_name_for_tiberius, MssqlTypeInfo}; use crate::value::{column_data_to_mssql_data, MssqlData}; use crate::HashMap; -use crate::{ - Mssql, MssqlArguments, MssqlColumn, MssqlConnection, MssqlQueryResult, MssqlRow, -}; +use crate::{Mssql, MssqlArguments, MssqlColumn, MssqlConnection, MssqlQueryResult, MssqlRow}; use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; @@ -99,10 +97,7 @@ fn bigdecimal_to_numeric(v: &bigdecimal::BigDecimal) -> Result<(i128, u8), Error if exponent > 37 { return Err(Error::Encode( - format!( - "BigDecimal scale {exponent} exceeds SQL Server maximum of 37" - ) - .into(), + format!("BigDecimal scale {exponent} exceeds SQL Server maximum of 37").into(), )); } // SAFETY: guarded by `exponent > 37` check above; 0..=37 fits in u8. @@ -110,9 +105,7 @@ fn bigdecimal_to_numeric(v: &bigdecimal::BigDecimal) -> Result<(i128, u8), Error let scale = exponent as u8; let value: i128 = bigint.to_i128().ok_or_else(|| { - Error::Encode( - format!("BigDecimal value too large for SQL NUMERIC: {v}").into(), - ) + Error::Encode(format!("BigDecimal value too large for SQL NUMERIC: {v}").into()) })?; Ok((value, scale)) @@ -189,17 +182,15 @@ impl MssqlConnection { #[cfg(feature = "chrono")] MssqlArgumentValue::DateTimeFixedOffset(v) => { use chrono::Timelike as _; - // Year 1 is always a valid date - let epoch = chrono::NaiveDate::from_ymd_opt(1, 1, 1).unwrap(); + let epoch = chrono::NaiveDate::from_ymd_opt(1, 1, 1) + .expect("epoch 0001-01-01 is always valid"); let naive = v.naive_local(); let days = days_since_epoch_to_u32((naive.date() - epoch).num_days())?; let time = naive.time(); - let total_ns = u64::from(time.num_seconds_from_midnight()) - * 1_000_000_000 + let total_ns = u64::from(time.num_seconds_from_midnight()) * 1_000_000_000 + (u64::from(time.nanosecond()) % 1_000_000_000); let increments = total_ns / 100; - let offset_minutes = - v.offset().local_minus_utc() / 60; + let offset_minutes = v.offset().local_minus_utc() / 60; let dt2 = tiberius::time::DateTime2::new( tiberius::time::Date::new(days), tiberius::time::Time::new(increments, 7), @@ -223,32 +214,30 @@ impl MssqlConnection { #[allow(clippy::cast_possible_wrap)] let mut value = (((unpacked.hi as u128) << 64) + ((unpacked.mid as u128) << 32) - + unpacked.lo as u128) - as i128; + + unpacked.lo as u128) as i128; if v.is_sign_negative() { value = -value; } let scale = v.scale(); if scale > 37 { return Err(Error::Encode( - format!("rust_decimal scale {scale} exceeds SQL Server maximum of 37").into(), + format!( + "rust_decimal scale {scale} exceeds SQL Server maximum of 37" + ) + .into(), )); } // SAFETY: guarded by `scale > 37` check above; 0..=37 fits in u8. #[allow(clippy::cast_possible_truncation)] let scale_u8 = scale as u8; - query.bind(tiberius::numeric::Numeric::new_with_scale( - value, - scale_u8, - )); + query.bind(tiberius::numeric::Numeric::new_with_scale(value, scale_u8)); } #[cfg(feature = "time")] MssqlArgumentValue::TimeDate(v) => { - let epoch = time::Date::from_ordinal_date(1, 1).unwrap(); + let epoch = time::Date::from_ordinal_date(1, 1) + .expect("epoch 0001-01-01 is always valid"); let days = days_since_epoch_to_u32((*v - epoch).whole_days())?; - let cd = tiberius::ColumnData::Date(Some( - tiberius::time::Date::new(days), - )); + let cd = tiberius::ColumnData::Date(Some(tiberius::time::Date::new(days))); query.bind(ColumnDataWrapper(cd)); } #[cfg(feature = "time")] @@ -260,16 +249,17 @@ impl MssqlConnection { + u64::from(ns); // Scale 7 = 100ns increments let increments = total_ns / 100; - let cd = tiberius::ColumnData::Time(Some( - tiberius::time::Time::new(increments, 7), - )); + let cd = tiberius::ColumnData::Time(Some(tiberius::time::Time::new( + increments, 7, + ))); query.bind(ColumnDataWrapper(cd)); } #[cfg(feature = "time")] MssqlArgumentValue::TimePrimitiveDateTime(v) => { let date = v.date(); let time = v.time(); - let epoch = time::Date::from_ordinal_date(1, 1).unwrap(); + let epoch = time::Date::from_ordinal_date(1, 1) + .expect("epoch 0001-01-01 is always valid"); let days = days_since_epoch_to_u32((date - epoch).whole_days())?; let (h, m, s, ns) = time.as_hms_nano(); let total_ns = u64::from(h) * 3_600_000_000_000 @@ -277,17 +267,17 @@ impl MssqlConnection { + u64::from(s) * 1_000_000_000 + u64::from(ns); let increments = total_ns / 100; - let cd = tiberius::ColumnData::DateTime2(Some( - tiberius::time::DateTime2::new( + let cd = + tiberius::ColumnData::DateTime2(Some(tiberius::time::DateTime2::new( tiberius::time::Date::new(days), tiberius::time::Time::new(increments, 7), - ), - )); + ))); query.bind(ColumnDataWrapper(cd)); } #[cfg(feature = "time")] MssqlArgumentValue::TimeOffsetDateTime(v) => { - let epoch = time::Date::from_ordinal_date(1, 1).unwrap(); + let epoch = time::Date::from_ordinal_date(1, 1) + .expect("epoch 0001-01-01 is always valid"); let offset_minutes = v.offset().whole_seconds() / 60; let date = v.date(); let time = v.time(); @@ -321,7 +311,10 @@ impl MssqlConnection { } } - let stream = query.query(&mut self.inner.client).await.map_err(tiberius_err)?; + let stream = query + .query(&mut self.inner.client) + .await + .map_err(tiberius_err)?; collect_results(stream, &mut results, &mut logger).await?; } else { // Simple query (no parameters) @@ -380,12 +373,12 @@ async fn collect_results( column_names = Some(Arc::new(names)); } tiberius::QueryItem::Row(row) => { - let cols = columns.as_ref().ok_or_else(|| { - Error::Protocol("row received before metadata".into()) - })?; - let names = column_names.as_ref().ok_or_else(|| { - Error::Protocol("row received before metadata".into()) - })?; + let cols = columns + .as_ref() + .ok_or_else(|| Error::Protocol("row received before metadata".into()))?; + let names = column_names + .as_ref() + .ok_or_else(|| Error::Protocol("row received before metadata".into()))?; // Convert tiberius row to MssqlRow by iterating over cells let values: Vec = row @@ -480,19 +473,18 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection { let _persistent = query.persistent(); let sql = query.sql(); - Box::pin(futures_util::stream::once(async move { - let arguments = arguments?; - let results = self.run(sql.as_str(), arguments).await?; - Ok::<_, Error>(results) - }) - .map_ok(|results| futures_util::stream::iter(results.into_iter().map(Ok))) - .try_flatten()) + Box::pin( + futures_util::stream::once(async move { + let arguments = arguments?; + let results = self.run(sql.as_str(), arguments).await?; + Ok::<_, Error>(results) + }) + .map_ok(|results| futures_util::stream::iter(results.into_iter().map(Ok))) + .try_flatten(), + ) } - fn fetch_optional<'e, 'q, E>( - self, - query: E, - ) -> BoxFuture<'e, Result, Error>> + fn fetch_optional<'e, 'q, E>(self, query: E) -> BoxFuture<'e, Result, Error>> where 'c: 'e, E: Execute<'q, Self::Database>, @@ -521,9 +513,8 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection { 'c: 'e, { Box::pin(async move { - let mut describe_query = tiberius::Query::new( - "EXEC sp_describe_first_result_set @tsql = @p1", - ); + let mut describe_query = + tiberius::Query::new("EXEC sp_describe_first_result_set @tsql = @p1"); describe_query.bind(sql.as_str()); let stream = describe_query @@ -531,7 +522,8 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection { .await .map_err(tiberius_err)?; - let rows: Vec = stream.into_first_result().await.map_err(tiberius_err)?; + let rows: Vec = + stream.into_first_result().await.map_err(tiberius_err)?; let (columns, column_names, _nullable) = build_columns_from_describe_rows(&rows); Ok(MssqlStatement { @@ -555,9 +547,8 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection { 'c: 'e, { Box::pin(async move { - let mut describe_query = tiberius::Query::new( - "EXEC sp_describe_first_result_set @tsql = @p1", - ); + let mut describe_query = + tiberius::Query::new("EXEC sp_describe_first_result_set @tsql = @p1"); describe_query.bind(sql.as_str()); let stream = describe_query @@ -571,14 +562,10 @@ impl<'c> Executor<'c> for &'c mut MssqlConnection { let (columns, _column_names, nullable) = build_columns_from_describe_rows(&rows); // Count parameters using sp_describe_undeclared_parameters - let mut param_query = tiberius::Query::new( - "EXEC sp_describe_undeclared_parameters @tsql = @p1", - ); + let mut param_query = + tiberius::Query::new("EXEC sp_describe_undeclared_parameters @tsql = @p1"); param_query.bind(sql.as_str()); - let param_count = match param_query - .query(&mut self.inner.client) - .await - { + let param_count = match param_query.query(&mut self.inner.client).await { Ok(stream) => stream .into_first_result() .await @@ -629,7 +616,10 @@ mod tests { #[test] fn days_since_epoch_at_max() { - assert_eq!(days_since_epoch_to_u32(i64::from(MAX_DAYS)).unwrap(), MAX_DAYS); + assert_eq!( + days_since_epoch_to_u32(i64::from(MAX_DAYS)).unwrap(), + MAX_DAYS + ); } #[test] diff --git a/sqlx-mssql/src/error.rs b/sqlx-mssql/src/error.rs index dcfcae3718..f61fd968a7 100644 --- a/sqlx-mssql/src/error.rs +++ b/sqlx-mssql/src/error.rs @@ -54,7 +54,11 @@ impl Debug for MssqlDatabaseError { impl Display for MssqlDatabaseError { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!(f, "(number {}, state {}): {}", self.number, self.state, self.message) + write!( + f, + "(number {}, state {}): {}", + self.number, self.state, self.message + ) } } @@ -111,11 +115,19 @@ pub(crate) fn tiberius_err(err: tiberius::error::Error) -> Error { message: token_error.message().to_string(), server: { let s = token_error.server(); - if s.is_empty() { None } else { Some(s.to_string()) } + if s.is_empty() { + None + } else { + Some(s.to_string()) + } }, procedure: { let s = token_error.procedure(); - if s.is_empty() { None } else { Some(s.to_string()) } + if s.is_empty() { + None + } else { + Some(s.to_string()) + } }, })) } diff --git a/sqlx-mssql/src/migrate.rs b/sqlx-mssql/src/migrate.rs index d32eac092d..a4623e11e3 100644 --- a/sqlx-mssql/src/migrate.rs +++ b/sqlx-mssql/src/migrate.rs @@ -43,7 +43,7 @@ impl MigrateDatabase for Mssql { query( "DECLARE @sql NVARCHAR(MAX) = N'CREATE DATABASE ' + QUOTENAME(@p1); \ - EXEC sp_executesql @sql;" + EXEC sp_executesql @sql;", ) .bind(database) .execute(&mut conn) @@ -56,12 +56,11 @@ impl MigrateDatabase for Mssql { let (options, database) = parse_for_maintenance(url)?; let mut conn = options.connect().await?; - let exists: bool = query_scalar( - "SELECT CASE WHEN DB_ID(@p1) IS NOT NULL THEN 1 ELSE 0 END", - ) - .bind(database) - .fetch_one(&mut conn) - .await?; + let exists: bool = + query_scalar("SELECT CASE WHEN DB_ID(@p1) IS NOT NULL THEN 1 ELSE 0 END") + .bind(database) + .fetch_one(&mut conn) + .await?; Ok(exists) } @@ -99,7 +98,7 @@ impl Migrate for MssqlConnection { BEGIN \ DECLARE @sql NVARCHAR(MAX) = N'CREATE SCHEMA ' + QUOTENAME(@p1); \ EXEC sp_executesql @sql; \ - END" + END", ) .bind(schema_name) .execute(&mut *self) @@ -199,7 +198,7 @@ impl Migrate for MssqlConnection { Box::pin(async move { let _ = self .execute( - "EXEC sp_releaseapplock @Resource = 'sqlx_migrations', @LockOwner = 'Session'" + "EXEC sp_releaseapplock @Resource = 'sqlx_migrations', @LockOwner = 'Session'", ) .await?; diff --git a/sqlx-mssql/src/options/connect.rs b/sqlx-mssql/src/options/connect.rs index f8c6dc04af..0a4cb94809 100644 --- a/sqlx-mssql/src/options/connect.rs +++ b/sqlx-mssql/src/options/connect.rs @@ -14,6 +14,7 @@ impl ConnectOptions for MssqlConnectOptions { fn to_url_lossy(&self) -> Url { self.build_url() + .expect("BUG: MssqlConnectOptions generated an un-parseable URL") } async fn connect(&self) -> Result diff --git a/sqlx-mssql/src/options/mod.rs b/sqlx-mssql/src/options/mod.rs index 9fb549cc33..e1126dbdcc 100644 --- a/sqlx-mssql/src/options/mod.rs +++ b/sqlx-mssql/src/options/mod.rs @@ -71,7 +71,10 @@ pub struct MssqlConnectOptions { #[cfg(all(windows, feature = "winauth"))] pub(crate) windows_auth: bool, /// When `true`, use integrated authentication (SSPI on Windows / Kerberos on Unix). - #[cfg(any(all(windows, feature = "winauth"), all(unix, feature = "integrated-auth-gssapi")))] + #[cfg(any( + all(windows, feature = "winauth"), + all(unix, feature = "integrated-auth-gssapi") + ))] pub(crate) integrated_auth: bool, /// Azure AD bearer token for AAD authentication. pub(crate) aad_token: Option, @@ -102,7 +105,10 @@ impl MssqlConnectOptions { log_settings: Default::default(), #[cfg(all(windows, feature = "winauth"))] windows_auth: false, - #[cfg(any(all(windows, feature = "winauth"), all(unix, feature = "integrated-auth-gssapi")))] + #[cfg(any( + all(windows, feature = "winauth"), + all(unix, feature = "integrated-auth-gssapi") + ))] integrated_auth: false, aad_token: None, } @@ -219,7 +225,10 @@ impl MssqlConnectOptions { } /// Sets whether to use integrated authentication (SSPI on Windows / Kerberos on Unix). - #[cfg(any(all(windows, feature = "winauth"), all(unix, feature = "integrated-auth-gssapi")))] + #[cfg(any( + all(windows, feature = "winauth"), + all(unix, feature = "integrated-auth-gssapi") + ))] pub fn integrated_auth(mut self, enabled: bool) -> Self { self.integrated_auth = enabled; self @@ -275,7 +284,10 @@ impl MssqlConnectOptions { #[allow(unused_mut)] let mut handled = false; - #[cfg(any(all(windows, feature = "winauth"), all(unix, feature = "integrated-auth-gssapi")))] + #[cfg(any( + all(windows, feature = "winauth"), + all(unix, feature = "integrated-auth-gssapi") + ))] if !handled && self.integrated_auth { config.authentication(tiberius::AuthMethod::Integrated); handled = true; diff --git a/sqlx-mssql/src/options/parse.rs b/sqlx-mssql/src/options/parse.rs index 9ac76eb5e3..8eedb359b2 100644 --- a/sqlx-mssql/src/options/parse.rs +++ b/sqlx-mssql/src/options/parse.rs @@ -63,13 +63,12 @@ impl MssqlConnectOptions { } "encrypt" => { - options = options - .encrypt(value.parse().map_err(Error::config)?); + options = options.encrypt(value.parse().map_err(Error::config)?); } "trust_server_certificate" | "trustServerCertificate" => { - options = options - .trust_server_certificate(value.parse().map_err(Error::config)?); + options = + options.trust_server_certificate(value.parse().map_err(Error::config)?); } "instance" => { @@ -81,25 +80,23 @@ impl MssqlConnectOptions { } "statement-cache-capacity" => { - options = options - .statement_cache_capacity(value.parse().map_err(Error::config)?); + options = + options.statement_cache_capacity(value.parse().map_err(Error::config)?); } - "application_intent" | "applicationIntent" => { - match &*value { - "read_only" | "ReadOnly" => { - options = options.application_intent_read_only(true); - } - "read_write" | "ReadWrite" => { - options = options.application_intent_read_only(false); - } - _ => { - return Err(Error::Configuration( - format!("unknown application_intent value: {value}").into(), - )) - } + "application_intent" | "applicationIntent" => match &*value { + "read_only" | "ReadOnly" => { + options = options.application_intent_read_only(true); } - } + "read_write" | "ReadWrite" => { + options = options.application_intent_read_only(false); + } + _ => { + return Err(Error::Configuration( + format!("unknown application_intent value: {value}").into(), + )) + } + }, "trust_server_certificate_ca" | "trustServerCertificateCa" => { options = options.trust_server_certificate_ca(&value); @@ -112,7 +109,10 @@ impl MssqlConnectOptions { "windows" => { options.windows_auth = true; } - #[cfg(any(all(windows, feature = "winauth"), all(unix, feature = "integrated-auth-gssapi")))] + #[cfg(any( + all(windows, feature = "winauth"), + all(unix, feature = "integrated-auth-gssapi") + ))] "integrated" => { options.integrated_auth = true; } @@ -138,12 +138,12 @@ impl MssqlConnectOptions { Ok(options) } - pub(crate) fn build_url(&self) -> Url { + pub(crate) fn build_url(&self) -> Result { let mut url = Url::parse(&format!( "mssql://{}@{}:{}", self.username, self.host, self.port )) - .expect("BUG: generated un-parseable URL"); + .map_err(|e| Error::Configuration(e.to_string().into()))?; if let Some(password) = &self.password { let _ = url.set_password(Some(password)); @@ -176,20 +176,21 @@ impl MssqlConnectOptions { .append_pair("auth", "aad_token") .append_pair("token", token); } else { - #[cfg(any(all(windows, feature = "winauth"), all(unix, feature = "integrated-auth-gssapi")))] + #[cfg(any( + all(windows, feature = "winauth"), + all(unix, feature = "integrated-auth-gssapi") + ))] if self.integrated_auth { - url.query_pairs_mut() - .append_pair("auth", "integrated"); + url.query_pairs_mut().append_pair("auth", "integrated"); } #[cfg(all(windows, feature = "winauth"))] if self.windows_auth && !self.integrated_auth { - url.query_pairs_mut() - .append_pair("auth", "windows"); + url.query_pairs_mut().append_pair("auth", "windows"); } } - url + Ok(url) } } @@ -274,7 +275,7 @@ fn it_rejects_invalid_sslmode() { fn it_roundtrips_sslmode_in_url() { let url = "mssql://sa:password@localhost/master?sslmode=login_only"; let opts = MssqlConnectOptions::from_str(url).unwrap(); - let built = opts.build_url(); + let built = opts.build_url().unwrap(); let opts2 = MssqlConnectOptions::parse_from_url(&built).unwrap(); assert!(matches!(opts2.ssl_mode, MssqlSslMode::LoginOnly)); } @@ -310,7 +311,10 @@ fn it_rejects_invalid_application_intent() { fn it_parses_trust_server_certificate_ca() { let url = "mssql://sa:password@localhost/master?trust_server_certificate_ca=/path/to/ca.pem"; let opts = MssqlConnectOptions::from_str(url).unwrap(); - assert_eq!(opts.trust_server_certificate_ca, Some("/path/to/ca.pem".into())); + assert_eq!( + opts.trust_server_certificate_ca, + Some("/path/to/ca.pem".into()) + ); } #[test] @@ -320,7 +324,7 @@ fn it_roundtrips_application_intent_in_url() { .username("sa") .password("password") .application_intent_read_only(true); - let built = opts.build_url(); + let built = opts.build_url().unwrap(); let opts2 = MssqlConnectOptions::parse_from_url(&built).unwrap(); assert!(opts2.application_intent_read_only); } @@ -332,9 +336,12 @@ fn it_roundtrips_trust_cert_ca_in_url() { .username("sa") .password("password") .trust_server_certificate_ca("/etc/ssl/ca.pem"); - let built = opts.build_url(); + let built = opts.build_url().unwrap(); let opts2 = MssqlConnectOptions::parse_from_url(&built).unwrap(); - assert_eq!(opts2.trust_server_certificate_ca, Some("/etc/ssl/ca.pem".into())); + assert_eq!( + opts2.trust_server_certificate_ca, + Some("/etc/ssl/ca.pem".into()) + ); } #[test] @@ -350,7 +357,7 @@ fn it_roundtrips_aad_token_in_url() { .host("localhost") .username("sa") .aad_token("my-bearer-token"); - let built = opts.build_url(); + let built = opts.build_url().unwrap(); let opts2 = MssqlConnectOptions::parse_from_url(&built).unwrap(); assert_eq!(opts2.aad_token, Some("my-bearer-token".into())); } diff --git a/sqlx-mssql/src/testing/mod.rs b/sqlx-mssql/src/testing/mod.rs index d225139676..54032da0ad 100644 --- a/sqlx-mssql/src/testing/mod.rs +++ b/sqlx-mssql/src/testing/mod.rs @@ -39,10 +39,9 @@ impl TestSupport for Mssql { let mut conn = MssqlConnection::connect(&url).await?; - let delete_db_names: Vec = - query_scalar("SELECT db_name FROM _sqlx_test_databases") - .fetch_all(&mut conn) - .await?; + let delete_db_names: Vec = query_scalar("SELECT db_name FROM _sqlx_test_databases") + .fetch_all(&mut conn) + .await?; if delete_db_names.is_empty() { return Ok(None); @@ -94,7 +93,9 @@ impl TestSupport for Mssql { } async fn snapshot(_conn: &mut Self::Connection) -> Result, Error> { - Err(Error::Configuration("snapshots are not yet supported for MSSQL".into())) + Err(Error::Configuration( + "snapshots are not yet supported for MSSQL".into(), + )) } } diff --git a/sqlx-mssql/src/transaction.rs b/sqlx-mssql/src/transaction.rs index 9c35160070..b74862235c 100644 --- a/sqlx-mssql/src/transaction.rs +++ b/sqlx-mssql/src/transaction.rs @@ -33,7 +33,8 @@ impl TransactionManager for MssqlTransactionManager { if depth == 0 { SqlStr::from_static("BEGIN TRANSACTION") } else { - AssertSqlSafe(format!("SAVE TRANSACTION _sqlx_savepoint_{}", depth)).into_sql_str() + AssertSqlSafe(format!("SAVE TRANSACTION _sqlx_savepoint_{}", depth)) + .into_sql_str() } } }; @@ -66,10 +67,7 @@ impl TransactionManager for MssqlTransactionManager { if depth == 1 { conn.execute("ROLLBACK").await?; } else { - let savepoint = format!( - "ROLLBACK TRANSACTION _sqlx_savepoint_{}", - depth - 1 - ); + let savepoint = format!("ROLLBACK TRANSACTION _sqlx_savepoint_{}", depth - 1); conn.execute(AssertSqlSafe(savepoint)).await?; } conn.inner.transaction_depth = depth - 1; @@ -94,9 +92,7 @@ impl TransactionManager for MssqlTransactionManager { } /// Execute pending rollback if one was triggered by `start_rollback`. -pub(crate) async fn resolve_pending_rollback( - conn: &mut MssqlConnection, -) -> Result<(), Error> { +pub(crate) async fn resolve_pending_rollback(conn: &mut MssqlConnection) -> Result<(), Error> { if conn.inner.pending_rollback { conn.inner.pending_rollback = false; let depth = conn.inner.transaction_depth; diff --git a/sqlx-mssql/src/types/bigdecimal.rs b/sqlx-mssql/src/types/bigdecimal.rs index b2fd93c655..c0a67851ef 100644 --- a/sqlx-mssql/src/types/bigdecimal.rs +++ b/sqlx-mssql/src/types/bigdecimal.rs @@ -14,15 +14,15 @@ impl Type for BigDecimal { } fn compatible(ty: &MssqlTypeInfo) -> bool { - matches!(ty.base_name(), "DECIMAL" | "NUMERIC" | "MONEY" | "SMALLMONEY") + matches!( + ty.base_name(), + "DECIMAL" | "NUMERIC" | "MONEY" | "SMALLMONEY" + ) } } impl Encode<'_, Mssql> for BigDecimal { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(MssqlArgumentValue::BigDecimal(self.clone())); Ok(IsNull::No) } diff --git a/sqlx-mssql/src/types/bool.rs b/sqlx-mssql/src/types/bool.rs index 171f961dc4..a4eb3b0904 100644 --- a/sqlx-mssql/src/types/bool.rs +++ b/sqlx-mssql/src/types/bool.rs @@ -12,15 +12,15 @@ impl Type for bool { } fn compatible(ty: &MssqlTypeInfo) -> bool { - matches!(ty.base_name(), "BIT" | "TINYINT" | "INT" | "SMALLINT" | "BIGINT") + matches!( + ty.base_name(), + "BIT" | "TINYINT" | "INT" | "SMALLINT" | "BIGINT" + ) } } impl Encode<'_, Mssql> for bool { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(MssqlArgumentValue::Bool(*self)); Ok(IsNull::No) } diff --git a/sqlx-mssql/src/types/bytes.rs b/sqlx-mssql/src/types/bytes.rs index 2c35ec41e6..f1a2d40322 100644 --- a/sqlx-mssql/src/types/bytes.rs +++ b/sqlx-mssql/src/types/bytes.rs @@ -10,10 +10,7 @@ use crate::types::Type; use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; fn bytes_compatible(ty: &MssqlTypeInfo) -> bool { - matches!( - ty.base_name(), - "VARBINARY" | "BINARY" | "IMAGE" - ) + matches!(ty.base_name(), "VARBINARY" | "BINARY" | "IMAGE") } impl Type for [u8] { @@ -27,10 +24,7 @@ impl Type for [u8] { } impl Encode<'_, Mssql> for &'_ [u8] { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(MssqlArgumentValue::Binary(self.to_vec())); Ok(IsNull::No) } @@ -53,10 +47,7 @@ impl Type for Vec { } impl Encode<'_, Mssql> for Vec { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { <&[u8] as Encode>::encode(&**self, buf) } } diff --git a/sqlx-mssql/src/types/chrono.rs b/sqlx-mssql/src/types/chrono.rs index 93932de833..d50a7a1587 100644 --- a/sqlx-mssql/src/types/chrono.rs +++ b/sqlx-mssql/src/types/chrono.rs @@ -16,18 +16,12 @@ impl Type for NaiveDateTime { } fn compatible(ty: &MssqlTypeInfo) -> bool { - matches!( - ty.base_name(), - "DATETIME2" | "DATETIME" | "SMALLDATETIME" - ) + matches!(ty.base_name(), "DATETIME2" | "DATETIME" | "SMALLDATETIME") } } impl Encode<'_, Mssql> for NaiveDateTime { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(MssqlArgumentValue::NaiveDateTime(*self)); Ok(IsNull::No) } @@ -57,10 +51,7 @@ impl Type for NaiveDate { } impl Encode<'_, Mssql> for NaiveDate { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(MssqlArgumentValue::NaiveDate(*self)); Ok(IsNull::No) } @@ -91,10 +82,7 @@ impl Type for NaiveTime { } impl Encode<'_, Mssql> for NaiveTime { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(MssqlArgumentValue::NaiveTime(*self)); Ok(IsNull::No) } @@ -119,18 +107,12 @@ impl Type for DateTime { } fn compatible(ty: &MssqlTypeInfo) -> bool { - matches!( - ty.base_name(), - "DATETIME2" | "DATETIMEOFFSET" - ) + matches!(ty.base_name(), "DATETIME2" | "DATETIMEOFFSET") } } impl Encode<'_, Mssql> for DateTime { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(MssqlArgumentValue::NaiveDateTime(self.naive_utc())); Ok(IsNull::No) } @@ -155,18 +137,12 @@ impl Type for DateTime { } fn compatible(ty: &MssqlTypeInfo) -> bool { - matches!( - ty.base_name(), - "DATETIMEOFFSET" | "DATETIME2" - ) + matches!(ty.base_name(), "DATETIMEOFFSET" | "DATETIME2") } } impl Encode<'_, Mssql> for DateTime { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(MssqlArgumentValue::DateTimeFixedOffset(*self)); Ok(IsNull::No) } @@ -179,7 +155,9 @@ impl Decode<'_, Mssql> for DateTime { MssqlData::NaiveDateTime(v) => { // Assume UTC if no offset information let utc = v.and_utc(); - Ok(utc.with_timezone(&FixedOffset::east_opt(0).unwrap())) + Ok(utc.with_timezone( + &FixedOffset::east_opt(0).expect("UTC offset 0 is always valid"), + )) } MssqlData::Null => Err("unexpected NULL".into()), _ => Err(format!("expected datetimeoffset, got {:?}", value.data).into()), diff --git a/sqlx-mssql/src/types/float.rs b/sqlx-mssql/src/types/float.rs index df1a5534f3..98114e3836 100644 --- a/sqlx-mssql/src/types/float.rs +++ b/sqlx-mssql/src/types/float.rs @@ -21,10 +21,7 @@ impl Type for f32 { } impl Encode<'_, Mssql> for f32 { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(MssqlArgumentValue::F32(*self)); Ok(IsNull::No) } @@ -53,10 +50,7 @@ impl Type for f64 { } impl Encode<'_, Mssql> for f64 { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(MssqlArgumentValue::F64(*self)); Ok(IsNull::No) } diff --git a/sqlx-mssql/src/types/int.rs b/sqlx-mssql/src/types/int.rs index 8de3b73749..ffe4e9c2da 100644 --- a/sqlx-mssql/src/types/int.rs +++ b/sqlx-mssql/src/types/int.rs @@ -7,10 +7,7 @@ use crate::value::MssqlData; use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; fn int_compatible(ty: &MssqlTypeInfo) -> bool { - matches!( - ty.base_name(), - "TINYINT" | "SMALLINT" | "INT" | "BIGINT" - ) + matches!(ty.base_name(), "TINYINT" | "SMALLINT" | "INT" | "BIGINT") } // u8 - MSSQL's TINYINT is unsigned (0-255) @@ -25,10 +22,7 @@ impl Type for u8 { } impl Encode<'_, Mssql> for u8 { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(MssqlArgumentValue::U8(*self)); Ok(IsNull::No) } @@ -59,10 +53,7 @@ impl Type for i8 { } impl Encode<'_, Mssql> for i8 { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { if *self < 0 { return Err("MSSQL TINYINT is unsigned; cannot encode negative i8".into()); } @@ -97,10 +88,7 @@ impl Type for i16 { } impl Encode<'_, Mssql> for i16 { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(MssqlArgumentValue::I16(*self)); Ok(IsNull::No) } @@ -131,10 +119,7 @@ impl Type for i32 { } impl Encode<'_, Mssql> for i32 { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(MssqlArgumentValue::I32(*self)); Ok(IsNull::No) } @@ -165,10 +150,7 @@ impl Type for i64 { } impl Encode<'_, Mssql> for i64 { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(MssqlArgumentValue::I64(*self)); Ok(IsNull::No) } diff --git a/sqlx-mssql/src/types/json.rs b/sqlx-mssql/src/types/json.rs index 5dd5b76e09..4557de0b06 100644 --- a/sqlx-mssql/src/types/json.rs +++ b/sqlx-mssql/src/types/json.rs @@ -22,10 +22,7 @@ impl Encode<'_, Mssql> for Json where T: Serialize, { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { let json_string = self.encode_to_string()?; buf.push(MssqlArgumentValue::String(json_string)); Ok(IsNull::No) diff --git a/sqlx-mssql/src/types/rust_decimal.rs b/sqlx-mssql/src/types/rust_decimal.rs index e71d96aae1..c942f549f7 100644 --- a/sqlx-mssql/src/types/rust_decimal.rs +++ b/sqlx-mssql/src/types/rust_decimal.rs @@ -14,15 +14,15 @@ impl Type for Decimal { } fn compatible(ty: &MssqlTypeInfo) -> bool { - matches!(ty.base_name(), "DECIMAL" | "NUMERIC" | "MONEY" | "SMALLMONEY") + matches!( + ty.base_name(), + "DECIMAL" | "NUMERIC" | "MONEY" | "SMALLMONEY" + ) } } impl Encode<'_, Mssql> for Decimal { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(MssqlArgumentValue::Decimal(*self)); Ok(IsNull::No) } diff --git a/sqlx-mssql/src/types/str.rs b/sqlx-mssql/src/types/str.rs index 8694bc4ff8..816a516167 100644 --- a/sqlx-mssql/src/types/str.rs +++ b/sqlx-mssql/src/types/str.rs @@ -27,10 +27,7 @@ impl Type for str { } impl Encode<'_, Mssql> for &'_ str { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(MssqlArgumentValue::String((*self).to_owned())); Ok(IsNull::No) } @@ -53,10 +50,7 @@ impl Type for String { } impl Encode<'_, Mssql> for String { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { <&str as Encode>::encode(self.as_str(), buf) } } diff --git a/sqlx-mssql/src/types/time.rs b/sqlx-mssql/src/types/time.rs index ad420b3d87..e86f6a3061 100644 --- a/sqlx-mssql/src/types/time.rs +++ b/sqlx-mssql/src/types/time.rs @@ -21,10 +21,7 @@ impl Type for Date { } impl Encode<'_, Mssql> for Date { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(MssqlArgumentValue::TimeDate(*self)); Ok(IsNull::No) } @@ -54,10 +51,7 @@ impl Type for Time { } impl Encode<'_, Mssql> for Time { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(MssqlArgumentValue::TimeTime(*self)); Ok(IsNull::No) } @@ -82,18 +76,12 @@ impl Type for PrimitiveDateTime { } fn compatible(ty: &MssqlTypeInfo) -> bool { - matches!( - ty.base_name(), - "DATETIME2" | "DATETIME" | "SMALLDATETIME" - ) + matches!(ty.base_name(), "DATETIME2" | "DATETIME" | "SMALLDATETIME") } } impl Encode<'_, Mssql> for PrimitiveDateTime { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(MssqlArgumentValue::TimePrimitiveDateTime(*self)); Ok(IsNull::No) } @@ -117,18 +105,12 @@ impl Type for OffsetDateTime { } fn compatible(ty: &MssqlTypeInfo) -> bool { - matches!( - ty.base_name(), - "DATETIMEOFFSET" | "DATETIME2" - ) + matches!(ty.base_name(), "DATETIMEOFFSET" | "DATETIME2") } } impl Encode<'_, Mssql> for OffsetDateTime { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(MssqlArgumentValue::TimeOffsetDateTime(*self)); Ok(IsNull::No) } @@ -138,9 +120,7 @@ impl Decode<'_, Mssql> for OffsetDateTime { fn decode(value: MssqlValueRef<'_>) -> Result { match value.data { MssqlData::TimeOffsetDateTime(v) => Ok(*v), - MssqlData::TimePrimitiveDateTime(v) => { - Ok(v.assume_utc()) - } + MssqlData::TimePrimitiveDateTime(v) => Ok(v.assume_utc()), MssqlData::Null => Err("unexpected NULL".into()), _ => Err(format!("expected datetimeoffset, got {:?}", value.data).into()), } diff --git a/sqlx-mssql/src/types/uuid.rs b/sqlx-mssql/src/types/uuid.rs index 6d315b0ec6..e382fe6a3f 100644 --- a/sqlx-mssql/src/types/uuid.rs +++ b/sqlx-mssql/src/types/uuid.rs @@ -19,10 +19,7 @@ impl Type for Uuid { } impl Encode<'_, Mssql> for Uuid { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(MssqlArgumentValue::Uuid(*self)); Ok(IsNull::No) } @@ -50,10 +47,7 @@ impl Type for uuid::fmt::Hyphenated { } impl Encode<'_, Mssql> for uuid::fmt::Hyphenated { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(MssqlArgumentValue::Uuid(*self.as_uuid())); Ok(IsNull::No) } diff --git a/sqlx-mssql/src/types/xml.rs b/sqlx-mssql/src/types/xml.rs index 82d409ea15..ace1b70940 100644 --- a/sqlx-mssql/src/types/xml.rs +++ b/sqlx-mssql/src/types/xml.rs @@ -40,10 +40,7 @@ impl Type for MssqlXml { } impl Encode<'_, Mssql> for MssqlXml { - fn encode_by_ref( - &self, - buf: &mut Vec, - ) -> Result { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { buf.push(MssqlArgumentValue::String(self.0.clone())); Ok(IsNull::No) } diff --git a/sqlx-mssql/src/value.rs b/sqlx-mssql/src/value.rs index 395e004763..ee60ad1414 100644 --- a/sqlx-mssql/src/value.rs +++ b/sqlx-mssql/src/value.rs @@ -30,15 +30,15 @@ pub(crate) enum MssqlData { Uuid(uuid::Uuid), #[cfg(feature = "rust_decimal")] Decimal(rust_decimal::Decimal), - #[cfg(feature = "time")] + #[cfg(all(feature = "time", not(feature = "chrono")))] TimeDate(time::Date), - #[cfg(feature = "time")] + #[cfg(all(feature = "time", not(feature = "chrono")))] TimeTime(time::Time), - #[cfg(feature = "time")] + #[cfg(all(feature = "time", not(feature = "chrono")))] TimePrimitiveDateTime(time::PrimitiveDateTime), - #[cfg(feature = "time")] + #[cfg(all(feature = "time", not(feature = "chrono")))] TimeOffsetDateTime(time::OffsetDateTime), - #[cfg(feature = "bigdecimal")] + #[cfg(all(feature = "bigdecimal", not(feature = "rust_decimal")))] BigDecimal(bigdecimal::BigDecimal), } @@ -137,12 +137,13 @@ pub(crate) fn column_data_to_mssql_data( // SAFETY: TDS time increments at scale 7 max out at 863_999_999_999, well within i64. let t = dt2.time(); #[allow(clippy::cast_possible_wrap)] - let ns = t.increments() as i64 - * 10i64.pow(9u32.saturating_sub(t.scale() as u32)); + let ns = t.increments() as i64 * 10i64.pow(9u32.saturating_sub(t.scale() as u32)); // infallible: (0,0,0) is always valid let time = chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap() + chrono::Duration::nanoseconds(ns); - Ok(MssqlData::NaiveDateTime(chrono::NaiveDateTime::new(date, time))) + Ok(MssqlData::NaiveDateTime(chrono::NaiveDateTime::new( + date, time, + ))) } #[cfg(feature = "chrono")] tiberius::ColumnData::DateTime(Some(dt)) => { @@ -151,7 +152,9 @@ pub(crate) fn column_data_to_mssql_data( // infallible: (0,0,0) is always valid let time = chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap() + chrono::Duration::nanoseconds(ns); - Ok(MssqlData::NaiveDateTime(chrono::NaiveDateTime::new(date, time))) + Ok(MssqlData::NaiveDateTime(chrono::NaiveDateTime::new( + date, time, + ))) } #[cfg(feature = "chrono")] tiberius::ColumnData::SmallDateTime(Some(dt)) => { @@ -163,18 +166,20 @@ pub(crate) fn column_data_to_mssql_data( "invalid SmallDateTime seconds: {seconds} exceeds seconds-in-a-day" )) })?; - Ok(MssqlData::NaiveDateTime(chrono::NaiveDateTime::new(date, time))) + Ok(MssqlData::NaiveDateTime(chrono::NaiveDateTime::new( + date, time, + ))) } #[cfg(feature = "chrono")] - tiberius::ColumnData::Date(Some(d)) => { - Ok(MssqlData::NaiveDate(chrono_date_from_days(d.days() as i64, 1)?)) - } + tiberius::ColumnData::Date(Some(d)) => Ok(MssqlData::NaiveDate(chrono_date_from_days( + d.days() as i64, + 1, + )?)), #[cfg(feature = "chrono")] tiberius::ColumnData::Time(Some(t)) => { // SAFETY: TDS time increments at scale 7 max out at 863_999_999_999, well within i64. #[allow(clippy::cast_possible_wrap)] - let ns = t.increments() as i64 - * 10i64.pow(9u32.saturating_sub(t.scale() as u32)); + let ns = t.increments() as i64 * 10i64.pow(9u32.saturating_sub(t.scale() as u32)); // infallible: (0,0,0) is always valid let time = chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap() + chrono::Duration::nanoseconds(ns); @@ -186,23 +191,23 @@ pub(crate) fn column_data_to_mssql_data( // SAFETY: TDS time increments at scale 7 max out at 863_999_999_999, well within i64. let t = dto.datetime2().time(); #[allow(clippy::cast_possible_wrap)] - let ns = t.increments() as i64 - * 10i64.pow(9u32.saturating_sub(t.scale() as u32)); + let ns = t.increments() as i64 * 10i64.pow(9u32.saturating_sub(t.scale() as u32)); // infallible: (0,0,0) is always valid let time = chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap() + chrono::Duration::nanoseconds(ns); let naive = chrono::NaiveDateTime::new(date, time); let offset_secs = dto.offset() as i32 * 60; let fixed_offset = chrono::FixedOffset::east_opt(offset_secs).ok_or_else(|| { - Error::Protocol(format!( - "invalid timezone offset: {offset_secs} seconds" - )) - })?; - let dt = naive.and_local_timezone(fixed_offset).single().ok_or_else(|| { - Error::Protocol(format!( - "ambiguous or invalid local time for offset {offset_secs}s" - )) + Error::Protocol(format!("invalid timezone offset: {offset_secs} seconds")) })?; + let dt = naive + .and_local_timezone(fixed_offset) + .single() + .ok_or_else(|| { + Error::Protocol(format!( + "ambiguous or invalid local time for offset {offset_secs}s" + )) + })?; Ok(MssqlData::DateTimeFixedOffset(dt)) } @@ -210,52 +215,53 @@ pub(crate) fn column_data_to_mssql_data( tiberius::ColumnData::Guid(Some(v)) => Ok(MssqlData::Uuid(v)), #[cfg(feature = "rust_decimal")] - tiberius::ColumnData::Numeric(Some(n)) => { - Ok(MssqlData::Decimal(rust_decimal::Decimal::from_i128_with_scale( - n.value(), - n.scale() as u32, - ))) - } + tiberius::ColumnData::Numeric(Some(n)) => Ok(MssqlData::Decimal( + rust_decimal::Decimal::from_i128_with_scale(n.value(), n.scale() as u32), + )), #[cfg(all(feature = "time", not(feature = "chrono")))] - tiberius::ColumnData::Date(Some(d)) => { - Ok(MssqlData::TimeDate(time_date_from_days(i64::from(d.days()), 1)?)) - } + tiberius::ColumnData::Date(Some(d)) => Ok(MssqlData::TimeDate(time_date_from_days( + i64::from(d.days()), + 1, + )?)), #[cfg(all(feature = "time", not(feature = "chrono")))] tiberius::ColumnData::Time(Some(t)) => { - let ns = t.increments() - * 10u64.pow(9u32.saturating_sub(t.scale() as u32)); + let ns = t.increments() * 10u64.pow(9u32.saturating_sub(t.scale() as u32)); Ok(MssqlData::TimeTime(time_from_sec_fragments(ns)?)) } #[cfg(all(feature = "time", not(feature = "chrono")))] tiberius::ColumnData::DateTime2(Some(dt2)) => { let date = time_date_from_days(i64::from(dt2.date().days()), 1)?; let t = dt2.time(); - let ns = t.increments() - * 10u64.pow(9u32.saturating_sub(t.scale() as u32)); + let ns = t.increments() * 10u64.pow(9u32.saturating_sub(t.scale() as u32)); let time = time_from_sec_fragments(ns)?; - Ok(MssqlData::TimePrimitiveDateTime(time::PrimitiveDateTime::new(date, time))) + Ok(MssqlData::TimePrimitiveDateTime( + time::PrimitiveDateTime::new(date, time), + )) } #[cfg(all(feature = "time", not(feature = "chrono")))] tiberius::ColumnData::DateTime(Some(dt)) => { let date = time_date_from_days(i64::from(dt.days()), 1900)?; let ns = dt.seconds_fragments() as u64 * 1_000_000_000u64 / 300; let time = time_from_sec_fragments(ns)?; - Ok(MssqlData::TimePrimitiveDateTime(time::PrimitiveDateTime::new(date, time))) + Ok(MssqlData::TimePrimitiveDateTime( + time::PrimitiveDateTime::new(date, time), + )) } #[cfg(all(feature = "time", not(feature = "chrono")))] tiberius::ColumnData::SmallDateTime(Some(dt)) => { let date = time_date_from_days(i64::from(dt.days()), 1900)?; let seconds = dt.seconds_fragments() as u64 * 60; let time = time_from_sec_fragments(seconds * 1_000_000_000)?; - Ok(MssqlData::TimePrimitiveDateTime(time::PrimitiveDateTime::new(date, time))) + Ok(MssqlData::TimePrimitiveDateTime( + time::PrimitiveDateTime::new(date, time), + )) } #[cfg(all(feature = "time", not(feature = "chrono")))] tiberius::ColumnData::DateTimeOffset(Some(dto)) => { let date = time_date_from_days(i64::from(dto.datetime2().date().days()), 1)?; let t = dto.datetime2().time(); - let ns = t.increments() - * 10u64.pow(9u32.saturating_sub(t.scale() as u32)); + let ns = t.increments() * 10u64.pow(9u32.saturating_sub(t.scale() as u32)); let time = time_from_sec_fragments(ns)?; let naive = time::PrimitiveDateTime::new(date, time); let offset_secs = dto.offset() as i32 * 60; @@ -317,11 +323,10 @@ pub(crate) fn column_data_to_mssql_data( } /// Convert days since `start_year`-01-01 to a `time::Date`. -#[cfg(feature = "time")] +#[cfg(all(feature = "time", not(feature = "chrono")))] fn time_date_from_days(days: i64, start_year: i32) -> Result { - let start = time::Date::from_ordinal_date(start_year, 1).map_err(|_| { - Error::Protocol(format!("invalid start year for date: {start_year}")) - })?; + let start = time::Date::from_ordinal_date(start_year, 1) + .map_err(|_| Error::Protocol(format!("invalid start year for date: {start_year}")))?; start .checked_add(time::Duration::days(days)) .ok_or_else(|| { @@ -332,7 +337,7 @@ fn time_date_from_days(days: i64, start_year: i32) -> Result } /// Convert nanoseconds-since-midnight to a `time::Time`. -#[cfg(feature = "time")] +#[cfg(all(feature = "time", not(feature = "chrono")))] fn time_from_sec_fragments(nanoseconds: u64) -> Result { const NANOS_PER_DAY: u64 = 86_400_000_000_000; if nanoseconds >= NANOS_PER_DAY { @@ -362,9 +367,8 @@ fn time_from_sec_fragments(nanoseconds: u64) -> Result { /// Convert days since `start_year`-01-01 to a `chrono::NaiveDate`. #[cfg(feature = "chrono")] fn chrono_date_from_days(days: i64, start_year: i32) -> Result { - let start = chrono::NaiveDate::from_ymd_opt(start_year, 1, 1).ok_or_else(|| { - Error::Protocol(format!("invalid start year for date: {start_year}")) - })?; + let start = chrono::NaiveDate::from_ymd_opt(start_year, 1, 1) + .ok_or_else(|| Error::Protocol(format!("invalid start year for date: {start_year}")))?; start .checked_add_signed(chrono::Duration::days(days)) .ok_or_else(|| { diff --git a/sqlx-test/src/lib.rs b/sqlx-test/src/lib.rs index 3744724c12..d7ffdaeb27 100644 --- a/sqlx-test/src/lib.rs +++ b/sqlx-test/src/lib.rs @@ -17,7 +17,7 @@ where let db_url = env::var("DATABASE_URL").map_err(|e| Error::Configuration(Box::new(e)))?; - Ok(DB::Connection::connect(&db_url).await?) + DB::Connection::connect(&db_url).await } // Make a new pool diff --git a/tests/any/any.rs b/tests/any/any.rs index 71c561cadb..0069dbf99d 100644 --- a/tests/any/any.rs +++ b/tests/any/any.rs @@ -152,7 +152,7 @@ async fn it_can_query_by_string_args() -> sqlx::Result<()> { let mut conn = new::().await?; let string = "Hello, world!".to_string(); - let ref tuple = ("Hello, world!".to_string(),); + let tuple = &("Hello, world!".to_string(),); #[cfg(feature = "postgres")] const SQL: &str = diff --git a/tests/mssql/advisory-lock.rs b/tests/mssql/advisory-lock.rs index 1277b002fc..c3dac4e924 100644 --- a/tests/mssql/advisory-lock.rs +++ b/tests/mssql/advisory-lock.rs @@ -64,7 +64,10 @@ async fn it_supports_shared_locks() -> anyhow::Result<()> { // Both connections should be able to acquire a shared lock lock.acquire(&mut conn1).await?; let acquired = lock.try_acquire(&mut conn2).await?; - assert!(acquired, "shared lock should be acquirable by second connection"); + assert!( + acquired, + "shared lock should be acquirable by second connection" + ); lock.release(&mut conn1).await?; lock.release(&mut conn2).await?; @@ -79,7 +82,10 @@ async fn it_release_returns_false_when_not_held() -> anyhow::Result<()> { let lock = MssqlAdvisoryLock::new("sqlx_test_not_held"); let released = lock.release(&mut conn).await?; - assert!(!released, "release should return false when lock is not held"); + assert!( + !released, + "release should return false when lock is not held" + ); Ok(()) } diff --git a/tests/mssql/bulk-insert.rs b/tests/mssql/bulk-insert.rs index b233621570..232f576752 100644 --- a/tests/mssql/bulk-insert.rs +++ b/tests/mssql/bulk-insert.rs @@ -6,11 +6,9 @@ use sqlx_test::new; async fn it_bulk_inserts_rows() -> anyhow::Result<()> { let mut conn = new::().await?; - sqlx::query( - "CREATE TABLE #bulk_test (name NVARCHAR(50) NOT NULL, value INT NOT NULL)" - ) - .execute(&mut conn) - .await?; + sqlx::query("CREATE TABLE #bulk_test (name NVARCHAR(50) NOT NULL, value INT NOT NULL)") + .execute(&mut conn) + .await?; let mut bulk = conn.bulk_insert("#bulk_test").await?; bulk.send(("hello", 1i32).into_row()).await?; diff --git a/tests/mssql/error.rs b/tests/mssql/error.rs index 6dc9e66931..e0dd1ed6fd 100644 --- a/tests/mssql/error.rs +++ b/tests/mssql/error.rs @@ -48,10 +48,9 @@ async fn it_fails_with_not_null_violation() -> anyhow::Result<()> { let mut conn = new::().await?; let mut tx = conn.begin().await?; - let res: Result<_, sqlx::Error> = - sqlx::query("INSERT INTO tweet (id, text) VALUES (1, NULL)") - .execute(&mut *tx) - .await; + let res: Result<_, sqlx::Error> = sqlx::query("INSERT INTO tweet (id, text) VALUES (1, NULL)") + .execute(&mut *tx) + .await; let err = res.unwrap_err(); let err = err.into_database_error().unwrap(); diff --git a/tests/mssql/isolation-level.rs b/tests/mssql/isolation-level.rs index 670b1f1e52..de6ac5ecb2 100644 --- a/tests/mssql/isolation-level.rs +++ b/tests/mssql/isolation-level.rs @@ -10,9 +10,7 @@ async fn it_begins_with_read_uncommitted() -> anyhow::Result<()> { .begin_with_isolation(MssqlIsolationLevel::ReadUncommitted) .await?; - let row = sqlx::query("SELECT 1 AS val") - .fetch_one(&mut *tx) - .await?; + let row = sqlx::query("SELECT 1 AS val").fetch_one(&mut *tx).await?; let val: i32 = row.get("val"); assert_eq!(val, 1); @@ -33,9 +31,7 @@ async fn it_begins_with_snapshot() -> anyhow::Result<()> { .begin_with_isolation(MssqlIsolationLevel::Snapshot) .await?; - let row = sqlx::query("SELECT 1 AS val") - .fetch_one(&mut *tx) - .await?; + let row = sqlx::query("SELECT 1 AS val").fetch_one(&mut *tx).await?; let val: i32 = row.get("val"); assert_eq!(val, 1); @@ -51,9 +47,7 @@ async fn it_begins_with_serializable() -> anyhow::Result<()> { .begin_with_isolation(MssqlIsolationLevel::Serializable) .await?; - let row = sqlx::query("SELECT 1 AS val") - .fetch_one(&mut *tx) - .await?; + let row = sqlx::query("SELECT 1 AS val").fetch_one(&mut *tx).await?; let val: i32 = row.get("val"); assert_eq!(val, 1); diff --git a/tests/mssql/migrate.rs b/tests/mssql/migrate.rs index 63ba618151..b76a2a4eb1 100644 --- a/tests/mssql/migrate.rs +++ b/tests/mssql/migrate.rs @@ -68,9 +68,11 @@ async fn reversible(mut conn: PoolConnection) -> anyhow::Result<()> { /// Ensure that we have a clean initial state. async fn clean_up(conn: &mut MssqlConnection) -> anyhow::Result<()> { - conn.execute("IF OBJECT_ID('migrations_simple_test', 'U') IS NOT NULL DROP TABLE migrations_simple_test") - .await - .ok(); + conn.execute( + "IF OBJECT_ID('migrations_simple_test', 'U') IS NOT NULL DROP TABLE migrations_simple_test", + ) + .await + .ok(); conn.execute("IF OBJECT_ID('migrations_reversible_test', 'U') IS NOT NULL DROP TABLE migrations_reversible_test") .await .ok(); diff --git a/tests/mssql/mssql.rs b/tests/mssql/mssql.rs index dcbfeaccf2..f2e1627515 100644 --- a/tests/mssql/mssql.rs +++ b/tests/mssql/mssql.rs @@ -1,11 +1,11 @@ use futures_util::TryStreamExt; +use sqlx::mssql::MssqlRow; use sqlx::mssql::{Mssql, MssqlPoolOptions}; +use sqlx::mssql::{MssqlAdvisoryLock, MssqlIsolationLevel}; use sqlx::{Column, Connection, Executor, MssqlConnection, Row, SqlSafeStr, Statement, TypeInfo}; -use sqlx::mssql::MssqlRow; use sqlx_test::new; use std::sync::atomic::{AtomicI32, Ordering}; use std::time::Duration; -use sqlx::mssql::{MssqlAdvisoryLock, MssqlIsolationLevel}; #[sqlx_macros::test] async fn it_connects() -> anyhow::Result<()> { @@ -309,7 +309,9 @@ async fn it_can_prepare_then_execute() -> anyhow::Result<()> { .fetch_one(&mut *tx) .await?; - let statement = tx.prepare("SELECT * FROM tweet WHERE id = @p1".into_sql_str()).await?; + let statement = tx + .prepare("SELECT * FROM tweet WHERE id = @p1".into_sql_str()) + .await?; assert_eq!(statement.column(0).name(), "id"); assert_eq!(statement.column(1).name(), "text"); @@ -514,7 +516,11 @@ async fn it_can_inspect_column_metadata() -> anyhow::Result<()> { assert_eq!(statement.column(0).type_info().name(), "INT"); // sp_describe_first_result_set returns "NVARCHAR(50)" for typed NVARCHAR - assert!(statement.column(1).type_info().name().starts_with("NVARCHAR")); + assert!(statement + .column(1) + .type_info() + .name() + .starts_with("NVARCHAR")); assert_eq!(statement.column(2).type_info().name(), "BIGINT"); Ok(()) @@ -525,10 +531,9 @@ async fn it_can_reuse_connection_after_error() -> anyhow::Result<()> { let mut conn = new::().await?; // Cause an error - let res: Result<_, sqlx::Error> = - sqlx::query("SELECT * FROM this_table_does_not_exist_12345") - .execute(&mut conn) - .await; + let res: Result<_, sqlx::Error> = sqlx::query("SELECT * FROM this_table_does_not_exist_12345") + .execute(&mut conn) + .await; assert!(res.is_err()); // Connection should still be usable @@ -632,18 +637,14 @@ async fn it_can_use_advisory_lock_guard() -> anyhow::Result<()> { let mut guard = lock.acquire_guard(&mut conn).await?; // Use the connection through the guard - let val: (i32,) = sqlx::query_as("SELECT 99") - .fetch_one(&mut *guard) - .await?; + let val: (i32,) = sqlx::query_as("SELECT 99").fetch_one(&mut *guard).await?; assert_eq!(val.0, 99); // Release the lock and get the connection back let conn = guard.release_now().await?; // Verify we can still use the connection - let val: (i32,) = sqlx::query_as("SELECT 100") - .fetch_one(conn) - .await?; + let val: (i32,) = sqlx::query_as("SELECT 100").fetch_one(conn).await?; assert_eq!(val.0, 100); Ok(()) diff --git a/tests/mssql/test-attr.rs b/tests/mssql/test-attr.rs index 25854a4cff..81b3d62660 100644 --- a/tests/mssql/test-attr.rs +++ b/tests/mssql/test-attr.rs @@ -12,11 +12,7 @@ async fn it_gets_a_pool(pool: MssqlPool) -> sqlx::Result<()> { .fetch_one(&mut *conn) .await?; - assert!( - db_name.starts_with("_sqlx_test_"), - "db_name: {:?}", - db_name - ); + assert!(db_name.starts_with("_sqlx_test_"), "db_name: {:?}", db_name); Ok(()) } @@ -74,12 +70,10 @@ async fn it_gets_posts(pool: MssqlPool) -> sqlx::Result<()> { #[sqlx::test(migrator = "MIGRATOR", fixtures("users", "posts", "comments"))] async fn it_gets_comments(pool: MssqlPool) -> sqlx::Result<()> { let post_1_comments: Vec = - sqlx::query_scalar( - "SELECT content FROM comment WHERE post_id = @p1 ORDER BY created_at", - ) - .bind(&1) - .fetch_all(&pool) - .await?; + sqlx::query_scalar("SELECT content FROM comment WHERE post_id = @p1 ORDER BY created_at") + .bind(&1) + .fetch_all(&pool) + .await?; assert_eq!( post_1_comments, @@ -87,12 +81,10 @@ async fn it_gets_comments(pool: MssqlPool) -> sqlx::Result<()> { ); let post_2_comments: Vec = - sqlx::query_scalar( - "SELECT content FROM comment WHERE post_id = @p1 ORDER BY created_at", - ) - .bind(&2) - .fetch_all(&pool) - .await?; + sqlx::query_scalar("SELECT content FROM comment WHERE post_id = @p1 ORDER BY created_at") + .bind(&2) + .fetch_all(&pool) + .await?; assert_eq!(post_2_comments, ["lol you're just mad you lost :P"]); diff --git a/tests/mssql/types.rs b/tests/mssql/types.rs index c15785eab3..5818533f9e 100644 --- a/tests/mssql/types.rs +++ b/tests/mssql/types.rs @@ -250,7 +250,7 @@ mod time_tests { type TimePrimitiveDateTime = sqlx::types::time::PrimitiveDateTime; type TimeOffsetDateTime = sqlx::types::time::OffsetDateTime; - use time::macros::{date, time as time_macro, datetime}; + use time::macros::{date, datetime, time as time_macro}; test_type!(time_date(Mssql, "CAST('2001-01-05' AS DATE)" From 7f4790e94faca6158cd84359b579d27ac7162684 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Mon, 23 Mar 2026 23:35:31 -0500 Subject: [PATCH 32/33] docs: move MSSQL_SUPPORT.md into sqlx-mssql crate and fix test coverage table Relocate documentation to live alongside the driver code. Update the test coverage table to fix an incorrect path and add missing test file entries. Author: Pablo Carrera --- MSSQL_SUPPORT.md => sqlx-mssql/MSSQL_SUPPORT.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) rename MSSQL_SUPPORT.md => sqlx-mssql/MSSQL_SUPPORT.md (99%) diff --git a/MSSQL_SUPPORT.md b/sqlx-mssql/MSSQL_SUPPORT.md similarity index 99% rename from MSSQL_SUPPORT.md rename to sqlx-mssql/MSSQL_SUPPORT.md index 7decf7644d..e58c1cc66f 100644 --- a/MSSQL_SUPPORT.md +++ b/sqlx-mssql/MSSQL_SUPPORT.md @@ -1184,4 +1184,6 @@ Comprehensive test suite in `tests/mssql/`: | Derives | `derives.rs` | `#[derive(FromRow)]`, custom field mappings | | Query builder | `query_builder.rs` | Dynamic query construction, parameter handling | | Error handling | `error.rs` | Database error inspection, error details | -| Compile-time macros | `tests/mssql-macros/` | Online and offline macro verification | +| Compile-time macros | `macros.rs` | Online and offline macro verification | +| Describe | `describe.rs` | `sp_describe` column metadata and type inference | +| Migrations | `migrate.rs` | Migration lifecycle: create, run, revert | From be3a8e9431f92a0c4f96431984624a3b67776a33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=CE=94BL=C3=98=20=E1=84=83=CE=9E?= Date: Sun, 17 May 2026 17:35:00 -0500 Subject: [PATCH 33/33] feat(mssql): implement Migrate::skip for skip-migrations parity Mirrors the upstream skip-migrations addition from launchbadge/sqlx PR #3846 (commit 45ba990). The new `skip` method records a migration in the `_sqlx_migrations` table without executing its SQL body, marking it as successfully applied. This is used by `sqlx migrate override skip` in the CLI and is routed through the Any driver. The TSQL INSERT mirrors `execute_migration`'s parameter binding (`@p1, @p2, @p3`) and uses `1` for the `success` BIT column (MSSQL doesn't have a TRUE literal). The `escape_table_name` helper provides the same identifier-injection protection used by existing MSSQL migration writes. Author: Pablo Carrera --- sqlx-mssql/src/migrate.rs | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/sqlx-mssql/src/migrate.rs b/sqlx-mssql/src/migrate.rs index a4623e11e3..631f8cc061 100644 --- a/sqlx-mssql/src/migrate.rs +++ b/sqlx-mssql/src/migrate.rs @@ -270,6 +270,30 @@ impl Migrate for MssqlConnection { Ok(elapsed) }) } + + fn skip<'e>( + &'e mut self, + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result<(), MigrateError>> { + Box::pin(async move { + let ident = escape_table_name(table_name); + // language=TSQL + let _ = query(AssertSqlSafe(format!( + r#" + INSERT INTO {ident} ( version, description, success, checksum, execution_time ) + VALUES ( @p1, @p2, 1, @p3, -1 ) + "# + ))) + .bind(migration.version) + .bind(&*migration.description) + .bind(&*migration.checksum) + .execute(self) + .await?; + + Ok(()) + }) + } } async fn execute_migration(