diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index 8d69cff731..d99e2c8c6f 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -117,6 +117,12 @@ jobs: -p sqlx-sqlite --all-features + - name: Test sqlx-mssql + run: > + cargo test + -p sqlx-mssql + --all-features + - name: Test sqlx-macros-core run: > cargo test @@ -716,3 +722,78 @@ jobs: env: DATABASE_URL: mysql://root@localhost:3306/sqlx?sslmode=verify_ca&ssl-ca=.%2Ftests%2Fcerts%2Fca.crt&ssl-key=.%2Ftests%2Fcerts%2Fkeys%2Fclient.key&ssl-cert=.%2Ftests%2Fcerts%2Fclient.crt RUSTFLAGS: --cfg mariadb="${{ matrix.mariadb }}" + + mssql: + name: MSSQL + runs-on: ubuntu-24.04 + strategy: + matrix: + mssql: [ 2022, 2019 ] + runtime: [ async-global-executor, smol, tokio ] + tls: [ native-tls, rustls-aws-lc-rs, rustls-ring, none ] + needs: check + timeout-minutes: 30 + steps: + - uses: actions/checkout@v4 + + - name: Setup Rust + run: rustup show active-toolchain || rustup toolchain install + + - uses: Swatinem/rust-cache@v2 + + - run: > + cargo build + --no-default-features + --features mssql,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }},macros,migrate + + - run: docker compose -f tests/docker-compose.yml run -d -p 1433:1433 --name mssql_${{ matrix.mssql }} mssql_${{ matrix.mssql }} + - run: sleep 60 + + # Create data dir for offline mode + - run: mkdir .sqlx + + - run: > + cargo test + --no-default-features + --features any,mssql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + env: + DATABASE_URL: mssql://sa:YourStrong!Passw0rd@localhost:1433/sqlx + SQLX_OFFLINE_DIR: .sqlx + RUSTFLAGS: --cfg mssql_${{ matrix.mssql }} + + # Run the `test-attr` test again to cover cleanup. + - run: > + cargo test + --test mssql-test-attr + --no-default-features + --features any,mssql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + env: + DATABASE_URL: mssql://sa:YourStrong!Passw0rd@localhost:1433/sqlx + SQLX_OFFLINE_DIR: .sqlx + RUSTFLAGS: --cfg mssql_${{ matrix.mssql }} + + # Remove test artifacts + - run: cargo clean -p sqlx + + # Build the macros-test in offline mode (omit DATABASE_URL) + - run: > + cargo build + --no-default-features + --test mssql-macros + --features any,mssql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + env: + SQLX_OFFLINE: true + SQLX_OFFLINE_DIR: .sqlx + RUSTFLAGS: -D warnings --cfg mssql_${{ matrix.mssql }} + + # Test macros in offline mode (still needs DATABASE_URL to run) + - run: > + cargo test + --no-default-features + --test mssql-macros + --features any,mssql,macros,migrate,_unstable-all-types,runtime-${{ matrix.runtime }},tls-${{ matrix.tls }} + env: + DATABASE_URL: mssql://sa:YourStrong!Passw0rd@localhost:1433/sqlx + SQLX_OFFLINE: true + SQLX_OFFLINE_DIR: .sqlx + RUSTFLAGS: --cfg mssql_${{ matrix.mssql }} diff --git a/Cargo.lock b/Cargo.lock index 66d1bd7a60..67ae6e0d91 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -64,6 +64,15 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" +[[package]] +name = "ansi_term" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2" +dependencies = [ + "winapi", +] + [[package]] name = "anstream" version = "0.6.21" @@ -372,6 +381,19 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "asynchronous-codec" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4057f2c32adbb2fc158e22fb38433c8e9bbf76b75a4732c7c0cbaf695fb65568" +dependencies = [ + "bytes", + "futures-sink", + "futures-util", + "memchr", + "pin-project-lite", +] + [[package]] name = "atoi" version = "2.0.0" @@ -387,6 +409,17 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "atty" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" +dependencies = [ + "hermit-abi 0.1.19", + "libc", + "winapi", +] + [[package]] name = "autocfg" version = "1.5.0" @@ -490,7 +523,7 @@ dependencies = [ "getrandom 0.2.17", "instant", "pin-project-lite", - "rand", + "rand 0.8.5", "tokio", ] @@ -527,6 +560,17 @@ version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" +[[package]] +name = "bigdecimal" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6773ddc0eafc0e509fb60e48dff7f450f8e674a0686ae8605e8d9901bd5eefa" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + [[package]] name = "bigdecimal" version = "0.4.0" @@ -539,6 +583,29 @@ dependencies = [ "num-traits", ] +[[package]] +name = "bindgen" +version = "0.59.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bd2a9a458e8f4304c52c43ebb0cfbd520289f8379a52e329a38afda99bf8eb8" +dependencies = [ + "bitflags 1.3.2", + "cexpr", + "clang-sys", + "clap 2.34.0", + "env_logger 0.9.3", + "lazy_static", + "lazycell", + "log", + "peeking_take_while", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "which", +] + [[package]] name = "bindgen" version = "0.69.5" @@ -838,6 +905,21 @@ dependencies = [ "libloading", ] +[[package]] +name = "clap" +version = "2.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c" +dependencies = [ + "ansi_term", + "atty", + "bitflags 1.3.2", + "strsim 0.8.0", + "textwrap", + "unicode-width", + "vec_map", +] + [[package]] name = "clap" version = "4.4.7" @@ -867,7 +949,7 @@ version = "4.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f6b5c519bab3ea61843a7923d074b04245624bb84a64a8c150f5deb014e388b" dependencies = [ - "clap", + "clap 4.4.7", ] [[package]] @@ -952,6 +1034,12 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "connection-string" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "510ca239cf13b7f8d16a2b48f263de7b4f8c566f0af58d901031473c76afb1e3" + [[package]] name = "console" version = "0.15.0" @@ -1022,7 +1110,7 @@ dependencies = [ "anes", "cast", "ciborium", - "clap", + "clap 4.4.7", "criterion-plot", "futures", "is-terminal", @@ -1095,8 +1183,8 @@ dependencies = [ "bitflags 2.4.0", "crossterm_winapi", "libc", - "mio", - "parking_lot", + "mio 0.8.11", + "parking_lot 0.12.5", "signal-hook", "signal-hook-mio", "winapi", @@ -1256,6 +1344,35 @@ version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" +[[package]] +name = "encoding_rs" +version = "0.8.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "enumflags2" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1027f7680c853e056ebcec683615fb6fbbc07dbaa13b4d5d9442b146ded4ecef" +dependencies = [ + "enumflags2_derive", +] + +[[package]] +name = "enumflags2_derive" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67c78a4d8fdf9953a5c9d458f9efe940fd97a0cab0941c075a813ac594733827" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + [[package]] name = "env_filter" version = "0.1.4" @@ -1266,6 +1383,19 @@ dependencies = [ "regex", ] +[[package]] +name = "env_logger" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a12e6657c4c97ebab115a42dcee77225f7f482cdd841cf7088c657a42e9e00e7" +dependencies = [ + "atty", + "humantime", + "log", + "regex", + "termcolor", +] + [[package]] name = "env_logger" version = "0.11.0" @@ -1508,7 +1638,7 @@ checksum = "1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" dependencies = [ "futures-core", "lock_api", - "parking_lot", + "parking_lot 0.12.5", ] [[package]] @@ -1579,6 +1709,17 @@ dependencies = [ "version_check", ] +[[package]] +name = "getrandom" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.9.0+wasi-snapshot-preview1", +] + [[package]] name = "getrandom" version = "0.2.17" @@ -1587,7 +1728,7 @@ checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" dependencies = [ "cfg-if", "libc", - "wasi", + "wasi 0.11.1+wasi-snapshot-preview1", ] [[package]] @@ -1695,6 +1836,15 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "hermit-abi" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" +dependencies = [ + "libc", +] + [[package]] name = "hermit-abi" version = "0.5.2" @@ -1727,11 +1877,11 @@ dependencies = [ [[package]] name = "home" -version = "0.5.12" +version = "0.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc627f471c528ff0c4a49e1d5e60450c8f6461dd6d10ba9dcd3a61d3dff7728d" +checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" dependencies = [ - "windows-sys 0.61.2", + "windows-sys 0.59.0", ] [[package]] @@ -2019,7 +2169,7 @@ version = "0.4.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" dependencies = [ - "hermit-abi", + "hermit-abi 0.5.2", "libc", "windows-sys 0.61.2", ] @@ -2115,6 +2265,28 @@ version = "0.2.186" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" +[[package]] +name = "libgssapi" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "724dbcd1f871da9c67983537a47ac510c278656f6392418ad67c7a52720e54b2" +dependencies = [ + "bitflags 1.3.2", + "bytes", + "lazy_static", + "libgssapi-sys", + "parking_lot 0.11.2", +] + +[[package]] +name = "libgssapi-sys" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1dd7d65e409c889f6c9d81ff079371d0d8fd88d7dca702ff187ef96fb0450fb7" +dependencies = [ + "bindgen 0.59.2", +] + [[package]] name = "libloading" version = "0.8.9" @@ -2149,7 +2321,7 @@ version = "0.30.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" dependencies = [ - "bindgen", + "bindgen 0.69.5", "cc", "pkg-config", "vcpkg", @@ -2231,6 +2403,12 @@ dependencies = [ "digest", ] +[[package]] +name = "md5" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e6bcd6433cff03a4bfc3d9834d504467db1f1cf6d0ea765d37d330249ed629d" + [[package]] name = "memchr" version = "2.5.0" @@ -2275,10 +2453,21 @@ checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" dependencies = [ "libc", "log", - "wasi", + "wasi 0.11.1+wasi-snapshot-preview1", "windows-sys 0.48.0", ] +[[package]] +name = "mio" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" +dependencies = [ + "libc", + "wasi 0.11.1+wasi-snapshot-preview1", + "windows-sys 0.61.2", +] + [[package]] name = "mockall" version = "0.11.0" @@ -2385,7 +2574,7 @@ dependencies = [ "num-integer", "num-iter", "num-traits", - "rand", + "rand 0.8.5", "smallvec", "zeroize", ] @@ -2426,16 +2615,6 @@ dependencies = [ "libm", ] -[[package]] -name = "num_cpus" -version = "1.17.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" -dependencies = [ - "hermit-abi", - "libc", -] - [[package]] name = "object" version = "0.32.2" @@ -2535,6 +2714,17 @@ version = "2.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" +[[package]] +name = "parking_lot" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" +dependencies = [ + "instant", + "lock_api", + "parking_lot_core 0.8.6", +] + [[package]] name = "parking_lot" version = "0.12.5" @@ -2542,7 +2732,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" dependencies = [ "lock_api", - "parking_lot_core", + "parking_lot_core 0.9.12", +] + +[[package]] +name = "parking_lot_core" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc" +dependencies = [ + "cfg-if", + "instant", + "libc", + "redox_syscall 0.2.16", + "smallvec", + "winapi", ] [[package]] @@ -2565,7 +2769,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700" dependencies = [ "base64ct", - "rand_core", + "rand_core 0.6.4", "subtle", ] @@ -2576,7 +2780,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166" dependencies = [ "base64ct", - "rand_core", + "rand_core 0.6.4", "subtle", ] @@ -2586,6 +2790,12 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0744126afe1a6dd7f394cb50a716dbe086cb06e255e53d8d0185d82828358fb5" +[[package]] +name = "peeking_take_while" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" + [[package]] name = "pem-rfc7468" version = "0.7.0" @@ -2713,7 +2923,7 @@ checksum = "5d0e4f59085d47d8241c88ead0f274e8a0cb551f3625263c05eb8dd897c34218" dependencies = [ "cfg-if", "concurrent-queue", - "hermit-abi", + "hermit-abi 0.5.2", "pin-project-lite", "rustix 1.1.4", "windows-sys 0.61.2", @@ -2784,6 +2994,12 @@ dependencies = [ "termtree", ] +[[package]] +name = "pretty-hex" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6fa0831dd7cc608c38a5e323422a0077678fa5744aa2be4ad91c4ece8eec8d5" + [[package]] name = "proc-macro-crate" version = "3.2.0" @@ -2865,6 +3081,19 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" +[[package]] +name = "rand" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" +dependencies = [ + "getrandom 0.1.16", + "libc", + "rand_chacha 0.2.2", + "rand_core 0.5.1", + "rand_hc", +] + [[package]] name = "rand" version = "0.8.5" @@ -2872,8 +3101,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha", - "rand_core", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402" +dependencies = [ + "ppv-lite86", + "rand_core 0.5.1", ] [[package]] @@ -2883,7 +3122,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_core" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" +dependencies = [ + "getrandom 0.1.16", ] [[package]] @@ -2895,13 +3143,22 @@ dependencies = [ "getrandom 0.2.17", ] +[[package]] +name = "rand_hc" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" +dependencies = [ + "rand_core 0.5.1", +] + [[package]] name = "rand_xoshiro" version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6f97cdb2a36ed4183de61b2f824cc45c9f1037f28afe0a322e9fff4c108b5aaa" dependencies = [ - "rand_core", + "rand_core 0.6.4", ] [[package]] @@ -2945,6 +3202,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "redox_syscall" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" +dependencies = [ + "bitflags 1.3.2", +] + [[package]] name = "redox_syscall" version = "0.5.18" @@ -3052,7 +3318,7 @@ dependencies = [ "num-traits", "pkcs1", "pkcs8", - "rand_core", + "rand_core 0.6.4", "signature", "subtle", "zeroize", @@ -3068,7 +3334,7 @@ dependencies = [ "borsh", "bytes", "num-traits", - "rand", + "rand 0.8.5", "rkyv", "serde", "serde_json", @@ -3395,7 +3661,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b75a19a7a740b25bc7944bdee6172368f988763b744e3d4dfe753f6b4ece40cc" dependencies = [ "libc", - "mio", + "mio 0.8.11", "signal-hook", ] @@ -3416,7 +3682,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" dependencies = [ "digest", - "rand_core", + "rand_core 0.6.4", ] [[package]] @@ -3459,12 +3725,12 @@ dependencies = [ [[package]] name = "socket2" -version = "0.4.10" +version = "0.5.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f7916fc008ca5542385b89a3d3ce689953c143e9304a9bf8beec1de48994c0d" +checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" dependencies = [ "libc", - "winapi", + "windows-sys 0.52.0", ] [[package]] @@ -3494,17 +3760,18 @@ dependencies = [ "async-std", "criterion", "dotenvy", - "env_logger", + "env_logger 0.11.0", "futures-util", "hex", "libsqlite3-sys", "paste", - "rand", + "rand 0.8.5", "rand_xoshiro", "serde", "serde_json", "sqlx-core", "sqlx-macros", + "sqlx-mssql", "sqlx-mysql", "sqlx-postgres", "sqlx-sqlite", @@ -3525,7 +3792,7 @@ dependencies = [ "backoff", "cargo_metadata", "chrono", - "clap", + "clap 4.4.7", "clap_complete", "console", "dialoguer", @@ -3550,7 +3817,7 @@ dependencies = [ "async-std", "async-task", "base64 0.22.0", - "bigdecimal", + "bigdecimal 0.4.0", "bit-vec", "bstr", "bytes", @@ -3594,12 +3861,23 @@ dependencies = [ "webpki-roots", ] +[[package]] +name = "sqlx-example-mssql-todos" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap 4.4.7", + "dotenvy", + "sqlx", + "tokio", +] + [[package]] name = "sqlx-example-mysql-todos" version = "0.1.0" dependencies = [ "anyhow", - "clap", + "clap 4.4.7", "sqlx", "tokio", ] @@ -3613,7 +3891,7 @@ dependencies = [ "axum", "dotenvy", "http-body-util", - "rand", + "rand 0.8.5", "regex", "serde", "serde_json", @@ -3654,7 +3932,7 @@ name = "sqlx-example-postgres-json" version = "0.1.0" dependencies = [ "anyhow", - "clap", + "clap 4.4.7", "dotenvy", "serde", "serde_json", @@ -3677,7 +3955,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", - "clap", + "clap 4.4.7", "dotenvy", "mockall", "sqlx", @@ -3690,7 +3968,7 @@ version = "0.9.0-alpha.1" dependencies = [ "color-eyre", "dotenvy", - "rand", + "rand 0.8.5", "rust_decimal", "sqlx", "sqlx-example-postgres-multi-database-accounts", @@ -3705,7 +3983,7 @@ version = "0.1.0" dependencies = [ "argon2 0.5.3", "password-hash 0.5.0", - "rand", + "rand 0.8.5", "serde", "sqlx", "thiserror 1.0.40", @@ -3731,7 +4009,7 @@ version = "0.9.0-alpha.1" dependencies = [ "color-eyre", "dotenvy", - "rand", + "rand 0.8.5", "rust_decimal", "sqlx", "sqlx-example-postgres-multi-tenant-accounts", @@ -3746,7 +4024,7 @@ version = "0.1.0" dependencies = [ "argon2 0.5.3", "password-hash 0.5.0", - "rand", + "rand 0.8.5", "serde", "sqlx", "thiserror 1.0.40", @@ -3806,7 +4084,7 @@ name = "sqlx-example-postgres-todos" version = "0.1.0" dependencies = [ "anyhow", - "clap", + "clap 4.4.7", "dotenvy", "sqlx", "tokio", @@ -3843,7 +4121,7 @@ name = "sqlx-example-sqlite-todos" version = "0.1.0" dependencies = [ "anyhow", - "clap", + "clap 4.4.7", "sqlx", "tokio", ] @@ -3877,6 +4155,7 @@ dependencies = [ "sha2", "smol", "sqlx-core", + "sqlx-mssql", "sqlx-mysql", "sqlx-postgres", "sqlx-sqlite", @@ -3886,11 +4165,39 @@ dependencies = [ "url", ] +[[package]] +name = "sqlx-mssql" +version = "0.9.0-alpha.1" +dependencies = [ + "async-std", + "bigdecimal 0.4.0", + "bytes", + "chrono", + "dotenvy", + "either", + "futures-core", + "futures-io", + "futures-util", + "log", + "percent-encoding", + "rust_decimal", + "serde", + "sqlx", + "sqlx-core", + "thiserror 2.0.17", + "tiberius", + "time", + "tokio", + "tokio-util", + "tracing", + "uuid", +] + [[package]] name = "sqlx-mysql" version = "0.9.0-alpha.1" dependencies = [ - "bigdecimal", + "bigdecimal 0.4.0", "bitflags 2.4.0", "byteorder", "bytes", @@ -3904,7 +4211,7 @@ dependencies = [ "generic-array", "log", "percent-encoding", - "rand", + "rand 0.8.5", "rsa", "rust_decimal", "serde", @@ -3924,7 +4231,7 @@ version = "0.9.0-alpha.1" dependencies = [ "atoi", "base64 0.22.0", - "bigdecimal", + "bigdecimal 0.4.0", "bit-vec", "bitflags 2.4.0", "byteorder", @@ -3946,7 +4253,7 @@ dependencies = [ "md-5", "memchr", "num-bigint", - "rand", + "rand 0.8.5", "rust_decimal", "serde", "serde_json", @@ -3995,7 +4302,7 @@ version = "0.1.0" dependencies = [ "anyhow", "dotenvy", - "env_logger", + "env_logger 0.11.0", "sqlx", ] @@ -4031,6 +4338,12 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "strsim" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" + [[package]] name = "strsim" version = "0.10.0" @@ -4163,6 +4476,15 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683" +[[package]] +name = "textwrap" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" +dependencies = [ + "unicode-width", +] + [[package]] name = "thiserror" version = "1.0.40" @@ -4212,6 +4534,35 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "tiberius" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1446cb4198848d1562301a3340424b4f425ef79f35ef9ee034769a9dd92c10d" +dependencies = [ + "async-trait", + "asynchronous-codec", + "bigdecimal 0.3.1", + "byteorder", + "bytes", + "chrono", + "connection-string", + "encoding_rs", + "enumflags2", + "futures-util", + "libgssapi", + "num-traits", + "once_cell", + "pin-project-lite", + "pretty-hex", + "rust_decimal", + "thiserror 1.0.40", + "time", + "tracing", + "uuid", + "winauth", +] + [[package]] name = "time" version = "0.3.37" @@ -4280,33 +4631,31 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.25.0" +version = "1.43.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8e00990ebabbe4c14c08aca901caed183ecd5c09562a12c824bb53d3c3fd3af" +checksum = "333f1ce734dbc263af1106964dba1f8c993a91d1857910fb542d45179c3d3da5" dependencies = [ - "autocfg", + "backtrace", "bytes", "libc", - "memchr", - "mio", - "num_cpus", - "parking_lot", + "mio 1.2.0", + "parking_lot 0.12.5", "pin-project-lite", "signal-hook-registry", "socket2", "tokio-macros", - "windows-sys 0.42.0", + "windows-sys 0.52.0", ] [[package]] name = "tokio-macros" -version = "1.8.2" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d266c00fde287f55d3f1c3e96c500c362a2b8c695076ec180f27918820bc6df8" +checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.87", ] [[package]] @@ -4320,6 +4669,20 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-util" +version = "0.7.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078" +dependencies = [ + "bytes", + "futures-core", + "futures-io", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "toml" version = "0.5.11" @@ -4626,6 +4989,12 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" +[[package]] +name = "vec_map" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191" + [[package]] name = "version_check" version = "0.9.5" @@ -4651,6 +5020,12 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "wasi" +version = "0.9.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" + [[package]] name = "wasi" version = "0.11.1+wasi-snapshot-preview1" @@ -4740,6 +5115,18 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix 0.38.44", +] + [[package]] name = "whoami" version = "2.0.2" @@ -4780,6 +5167,19 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "winauth" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f820cd208ce9c6b050812dc2d724ba98c6c1e9db5ce9b3f58d925ae5723a5e6" +dependencies = [ + "bitflags 1.3.2", + "byteorder", + "md5", + "rand 0.7.3", + "winapi", +] + [[package]] name = "windows-core" version = "0.62.2" @@ -4839,21 +5239,6 @@ dependencies = [ "windows-link", ] -[[package]] -name = "windows-sys" -version = "0.42.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" -dependencies = [ - "windows_aarch64_gnullvm 0.42.2", - "windows_aarch64_msvc 0.42.2", - "windows_i686_gnu 0.42.2", - "windows_i686_msvc 0.42.2", - "windows_x86_64_gnu 0.42.2", - "windows_x86_64_gnullvm 0.42.2", - "windows_x86_64_msvc 0.42.2", -] - [[package]] name = "windows-sys" version = "0.48.0" @@ -4921,12 +5306,6 @@ dependencies = [ "windows_x86_64_msvc 0.52.6", ] -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" - [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" @@ -4939,12 +5318,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" -[[package]] -name = "windows_aarch64_msvc" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" - [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -4957,12 +5330,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" -[[package]] -name = "windows_i686_gnu" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" - [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -4981,12 +5348,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" -[[package]] -name = "windows_i686_msvc" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" - [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -4999,12 +5360,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" -[[package]] -name = "windows_x86_64_gnu" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" - [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -5017,12 +5372,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" - [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" @@ -5035,12 +5384,6 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" -[[package]] -name = "windows_x86_64_msvc" -version = "0.42.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" - [[package]] name = "windows_x86_64_msvc" version = "0.48.5" diff --git a/Cargo.toml b/Cargo.toml index 961dc3cc84..50e553ca05 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,9 +7,11 @@ members = [ "sqlx-test", "sqlx-cli", # "sqlx-bench", + "sqlx-mssql", "sqlx-mysql", "sqlx-postgres", "sqlx-sqlite", + "examples/mssql/todos", "examples/mysql/todos", "examples/postgres/axum-social-with-tests", "examples/postgres/chat", @@ -64,14 +66,14 @@ rustdoc-args = ["--cfg", "docsrs"] default = ["any", "macros", "migrate", "json"] derive = ["sqlx-macros/derive"] -macros = ["derive", "sqlx-macros/macros", "sqlx-core/offline", "sqlx-mysql?/offline", "sqlx-postgres?/offline", "sqlx-sqlite?/offline"] -migrate = ["sqlx-core/migrate", "sqlx-macros?/migrate", "sqlx-mysql?/migrate", "sqlx-postgres?/migrate", "sqlx-sqlite?/migrate"] +macros = ["derive", "sqlx-macros/macros", "sqlx-core/offline", "sqlx-mssql?/offline", "sqlx-mysql?/offline", "sqlx-postgres?/offline", "sqlx-sqlite?/offline"] +migrate = ["sqlx-core/migrate", "sqlx-macros?/migrate", "sqlx-mssql?/migrate", "sqlx-mysql?/migrate", "sqlx-postgres?/migrate", "sqlx-sqlite?/migrate"] # Enable parsing of `sqlx.toml` for configuring macros and migrations. sqlx-toml = ["sqlx-core/sqlx-toml", "sqlx-macros?/sqlx-toml", "sqlx-sqlite?/sqlx-toml"] # intended mainly for CI and docs -all-databases = ["mysql", "sqlite", "postgres", "any"] +all-databases = ["mssql", "mysql", "sqlite", "postgres", "any"] _unstable-all-types = [ "bigdecimal", "rust_decimal", @@ -118,7 +120,8 @@ _rt-tokio = [] _sqlite = [] # database -any = ["sqlx-core/any", "sqlx-mysql?/any", "sqlx-postgres?/any", "sqlx-sqlite?/any"] +any = ["sqlx-core/any", "sqlx-mssql?/any", "sqlx-mysql?/any", "sqlx-postgres?/any", "sqlx-sqlite?/any"] +mssql = ["sqlx-mssql", "sqlx-macros?/mssql"] postgres = ["sqlx-postgres", "sqlx-macros?/postgres"] mysql = ["sqlx-mysql", "sqlx-macros?/mysql"] mysql-rsa = ["mysql", "sqlx-mysql/rsa", "sqlx-macros?/mysql-rsa"] @@ -149,17 +152,17 @@ sqlite-preupdate-hook = ["sqlx-sqlite/preupdate-hook"] sqlite-unlock-notify = ["sqlx-sqlite/unlock-notify"] # types -json = ["sqlx-core/json", "sqlx-macros?/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sqlite?/json"] +json = ["sqlx-core/json", "sqlx-macros?/json", "sqlx-mssql?/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sqlite?/json"] -bigdecimal = ["sqlx-core/bigdecimal", "sqlx-macros?/bigdecimal", "sqlx-mysql?/bigdecimal", "sqlx-postgres?/bigdecimal"] +bigdecimal = ["sqlx-core/bigdecimal", "sqlx-macros?/bigdecimal", "sqlx-mssql?/bigdecimal", "sqlx-mysql?/bigdecimal", "sqlx-postgres?/bigdecimal"] bit-vec = ["sqlx-core/bit-vec", "sqlx-macros?/bit-vec", "sqlx-postgres?/bit-vec"] -chrono = ["sqlx-core/chrono", "sqlx-macros?/chrono", "sqlx-mysql?/chrono", "sqlx-postgres?/chrono", "sqlx-sqlite?/chrono"] +chrono = ["sqlx-core/chrono", "sqlx-macros?/chrono", "sqlx-mssql?/chrono", "sqlx-mysql?/chrono", "sqlx-postgres?/chrono", "sqlx-sqlite?/chrono"] ipnet = ["sqlx-core/ipnet", "sqlx-macros?/ipnet", "sqlx-postgres?/ipnet"] ipnetwork = ["sqlx-core/ipnetwork", "sqlx-macros?/ipnetwork", "sqlx-postgres?/ipnetwork"] mac_address = ["sqlx-core/mac_address", "sqlx-macros?/mac_address", "sqlx-postgres?/mac_address"] -rust_decimal = ["sqlx-core/rust_decimal", "sqlx-macros?/rust_decimal", "sqlx-mysql?/rust_decimal", "sqlx-postgres?/rust_decimal"] -time = ["sqlx-core/time", "sqlx-macros?/time", "sqlx-mysql?/time", "sqlx-postgres?/time", "sqlx-sqlite?/time"] -uuid = ["sqlx-core/uuid", "sqlx-macros?/uuid", "sqlx-mysql?/uuid", "sqlx-postgres?/uuid", "sqlx-sqlite?/uuid"] +rust_decimal = ["sqlx-core/rust_decimal", "sqlx-macros?/rust_decimal", "sqlx-mssql?/rust_decimal", "sqlx-mysql?/rust_decimal", "sqlx-postgres?/rust_decimal"] +time = ["sqlx-core/time", "sqlx-macros?/time", "sqlx-mssql?/time", "sqlx-mysql?/time", "sqlx-postgres?/time", "sqlx-sqlite?/time"] +uuid = ["sqlx-core/uuid", "sqlx-macros?/uuid", "sqlx-mssql?/uuid", "sqlx-mysql?/uuid", "sqlx-postgres?/uuid", "sqlx-sqlite?/uuid"] regexp = ["sqlx-sqlite?/regexp"] bstr = ["sqlx-core/bstr"] @@ -170,6 +173,7 @@ sqlx-macros-core = { version = "=0.9.0-alpha.1", path = "sqlx-macros-core" } sqlx-macros = { version = "=0.9.0-alpha.1", path = "sqlx-macros" } # Driver crates +sqlx-mssql = { version = "=0.9.0-alpha.1", path = "sqlx-mssql" } sqlx-mysql = { version = "=0.9.0-alpha.1", path = "sqlx-mysql", default-features = false } sqlx-postgres = { version = "=0.9.0-alpha.1", path = "sqlx-postgres" } sqlx-sqlite = { version = "=0.9.0-alpha.1", path = "sqlx-sqlite" } @@ -216,6 +220,7 @@ default-features = false sqlx-core = { workspace = true, features = ["migrate"] } sqlx-macros = { workspace = true, optional = true } +sqlx-mssql = { workspace = true, optional = true } sqlx-mysql = { workspace = true, optional = true, default-features = false } sqlx-postgres = { workspace = true, optional = true } sqlx-sqlite = { workspace = true, optional = true } @@ -456,3 +461,67 @@ required-features = ["postgres"] name = "postgres-rustsec" path = "tests/postgres/rustsec.rs" required-features = ["postgres", "macros", "migrate"] + +# +# MSSQL +# + +[[test]] +name = "mssql" +path = "tests/mssql/mssql.rs" +required-features = ["mssql"] + +[[test]] +name = "mssql-types" +path = "tests/mssql/types.rs" +required-features = ["mssql"] + +[[test]] +name = "mssql-describe" +path = "tests/mssql/describe.rs" +required-features = ["mssql"] + +[[test]] +name = "mssql-macros" +path = "tests/mssql/macros.rs" +required-features = ["mssql", "macros"] + +[[test]] +name = "mssql-derives" +path = "tests/mssql/derives.rs" +required-features = ["mssql", "derive"] + +[[test]] +name = "mssql-error" +path = "tests/mssql/error.rs" +required-features = ["mssql"] + +[[test]] +name = "mssql-test-attr" +path = "tests/mssql/test-attr.rs" +required-features = ["mssql", "macros", "migrate"] + +[[test]] +name = "mssql-advisory-lock" +path = "tests/mssql/advisory-lock.rs" +required-features = ["mssql"] + +[[test]] +name = "mssql-isolation-level" +path = "tests/mssql/isolation-level.rs" +required-features = ["mssql"] + +[[test]] +name = "mssql-query-builder" +path = "tests/mssql/query_builder.rs" +required-features = ["mssql"] + +[[test]] +name = "mssql-bulk-insert" +path = "tests/mssql/bulk-insert.rs" +required-features = ["mssql"] + +[[test]] +name = "mssql-migrate" +path = "tests/mssql/migrate.rs" +required-features = ["mssql", "macros", "migrate"] diff --git a/examples/mssql/todos/Cargo.toml b/examples/mssql/todos/Cargo.toml new file mode 100644 index 0000000000..f7298d42b7 --- /dev/null +++ b/examples/mssql/todos/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "sqlx-example-mssql-todos" +version = "0.1.0" +edition = "2021" +workspace = "../../../" + +[dependencies] +anyhow = "1.0" +sqlx = { path = "../../../", features = [ "mssql", "runtime-tokio", "tls-native-tls" ] } +clap = { version = "4", features = ["derive"] } +tokio = { version = "1.20.0", features = ["rt", "macros"] } +dotenvy = "0.15.0" diff --git a/examples/mssql/todos/migrations/20250101000000_create_todos.sql b/examples/mssql/todos/migrations/20250101000000_create_todos.sql new file mode 100644 index 0000000000..aca157b7da --- /dev/null +++ b/examples/mssql/todos/migrations/20250101000000_create_todos.sql @@ -0,0 +1,5 @@ +CREATE TABLE todos ( + id BIGINT IDENTITY(1,1) PRIMARY KEY, + description NVARCHAR(MAX) NOT NULL, + done BIT NOT NULL DEFAULT 0 +); diff --git a/examples/mssql/todos/src/main.rs b/examples/mssql/todos/src/main.rs new file mode 100644 index 0000000000..27bf395c04 --- /dev/null +++ b/examples/mssql/todos/src/main.rs @@ -0,0 +1,81 @@ +use clap::{Parser, Subcommand}; +use sqlx::mssql::MssqlPool; +use sqlx::Row; +use std::env; + +#[derive(Parser)] +struct Args { + #[command(subcommand)] + cmd: Option, +} + +#[derive(Subcommand)] +enum Command { + Add { description: String }, + Done { id: i64 }, +} + +#[tokio::main(flavor = "current_thread")] +async fn main() -> anyhow::Result<()> { + let args = Args::parse(); + let pool = MssqlPool::connect(&env::var("DATABASE_URL")?).await?; + + match args.cmd { + Some(Command::Add { description }) => { + println!("Adding new todo with description '{description}'"); + let todo_id = add_todo(&pool, description).await?; + println!("Added new todo with id {todo_id}"); + } + Some(Command::Done { id }) => { + println!("Marking todo {id} as done"); + if complete_todo(&pool, id).await? { + println!("Todo {id} is marked as done"); + } else { + println!("Invalid id {id}"); + } + } + None => { + println!("Printing list of all todos"); + list_todos(&pool).await?; + } + } + + Ok(()) +} + +async fn add_todo(pool: &MssqlPool, description: String) -> anyhow::Result { + // MSSQL uses OUTPUT INSERTED instead of RETURNING + let rec = sqlx::query("INSERT INTO todos (description) OUTPUT INSERTED.id VALUES (@p1)") + .bind(&description) + .fetch_one(pool) + .await?; + + Ok(rec.get::("id")) +} + +async fn complete_todo(pool: &MssqlPool, id: i64) -> anyhow::Result { + let rows_affected = sqlx::query("UPDATE todos SET done = 1 WHERE id = @p1") + .bind(id) + .execute(pool) + .await? + .rows_affected(); + + Ok(rows_affected > 0) +} + +async fn list_todos(pool: &MssqlPool) -> anyhow::Result<()> { + let recs = sqlx::query("SELECT id, description, done FROM todos ORDER BY id") + .fetch_all(pool) + .await?; + + for rec in recs { + println!( + "- [{}] {}: {}", + if rec.get::("done") { "x" } else { " " }, + rec.get::("id"), + rec.get::("description"), + ); + } + + Ok(()) +} diff --git a/sqlx-core/src/any/arguments.rs b/sqlx-core/src/any/arguments.rs index 59d6f4d6e0..abb7098072 100644 --- a/sqlx-core/src/any/arguments.rs +++ b/sqlx-core/src/any/arguments.rs @@ -67,8 +67,8 @@ impl AnyArguments { AnyValueKind::Null(AnyTypeInfoKind::SmallInt) => out.add(Option::::None), AnyValueKind::Null(AnyTypeInfoKind::Integer) => out.add(Option::::None), AnyValueKind::Null(AnyTypeInfoKind::BigInt) => out.add(Option::::None), - AnyValueKind::Null(AnyTypeInfoKind::Real) => out.add(Option::::None), - AnyValueKind::Null(AnyTypeInfoKind::Double) => out.add(Option::::None), + AnyValueKind::Null(AnyTypeInfoKind::Real) => out.add(Option::::None), + AnyValueKind::Null(AnyTypeInfoKind::Double) => out.add(Option::::None), AnyValueKind::Null(AnyTypeInfoKind::Text) => out.add(Option::::None), AnyValueKind::Null(AnyTypeInfoKind::Blob) => out.add(Option::>::None), AnyValueKind::Bool(b) => out.add(b), diff --git a/sqlx-macros-core/Cargo.toml b/sqlx-macros-core/Cargo.toml index 534b92764d..75502484a8 100644 --- a/sqlx-macros-core/Cargo.toml +++ b/sqlx-macros-core/Cargo.toml @@ -32,6 +32,7 @@ migrate = ["sqlx-core/migrate"] sqlx-toml = ["sqlx-core/sqlx-toml", "sqlx-sqlite?/sqlx-toml"] # database +mssql = ["sqlx-mssql"] mysql = ["sqlx-mysql"] mysql-rsa = ["mysql", "sqlx-mysql/rsa"] postgres = ["sqlx-postgres"] @@ -42,20 +43,21 @@ sqlite-unbundled = ["_sqlite", "sqlx-sqlite/unbundled"] sqlite-load-extension = ["sqlx-sqlite/load-extension"] # type integrations -json = ["sqlx-core/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sqlite?/json"] +json = ["sqlx-core/json", "sqlx-mssql?/json", "sqlx-mysql?/json", "sqlx-postgres?/json", "sqlx-sqlite?/json"] -bigdecimal = ["sqlx-core/bigdecimal", "sqlx-mysql?/bigdecimal", "sqlx-postgres?/bigdecimal"] +bigdecimal = ["sqlx-core/bigdecimal", "sqlx-mssql?/bigdecimal", "sqlx-mysql?/bigdecimal", "sqlx-postgres?/bigdecimal"] bit-vec = ["sqlx-core/bit-vec", "sqlx-postgres?/bit-vec"] -chrono = ["sqlx-core/chrono", "sqlx-mysql?/chrono", "sqlx-postgres?/chrono", "sqlx-sqlite?/chrono"] +chrono = ["sqlx-core/chrono", "sqlx-mssql?/chrono", "sqlx-mysql?/chrono", "sqlx-postgres?/chrono", "sqlx-sqlite?/chrono"] ipnet = ["sqlx-core/ipnet", "sqlx-postgres?/ipnet"] ipnetwork = ["sqlx-core/ipnetwork", "sqlx-postgres?/ipnetwork"] mac_address = ["sqlx-core/mac_address", "sqlx-postgres?/mac_address"] -rust_decimal = ["sqlx-core/rust_decimal", "sqlx-mysql?/rust_decimal", "sqlx-postgres?/rust_decimal"] -time = ["sqlx-core/time", "sqlx-mysql?/time", "sqlx-postgres?/time", "sqlx-sqlite?/time"] -uuid = ["sqlx-core/uuid", "sqlx-mysql?/uuid", "sqlx-postgres?/uuid", "sqlx-sqlite?/uuid"] +rust_decimal = ["sqlx-core/rust_decimal", "sqlx-mssql?/rust_decimal", "sqlx-mysql?/rust_decimal", "sqlx-postgres?/rust_decimal"] +time = ["sqlx-core/time", "sqlx-mssql?/time", "sqlx-mysql?/time", "sqlx-postgres?/time", "sqlx-sqlite?/time"] +uuid = ["sqlx-core/uuid", "sqlx-mssql?/uuid", "sqlx-mysql?/uuid", "sqlx-postgres?/uuid", "sqlx-sqlite?/uuid"] [dependencies] sqlx-core = { workspace = true, features = ["offline"] } +sqlx-mssql = { workspace = true, features = ["offline", "migrate"], optional = true } sqlx-mysql = { workspace = true, features = ["offline", "migrate"], optional = true, default-features = false } sqlx-postgres = { workspace = true, features = ["offline", "migrate"], optional = true } sqlx-sqlite = { workspace = true, features = ["offline", "migrate"], optional = true } diff --git a/sqlx-macros-core/src/database/impls.rs b/sqlx-macros-core/src/database/impls.rs index 523b85cc14..d51f9745cf 100644 --- a/sqlx-macros-core/src/database/impls.rs +++ b/sqlx-macros-core/src/database/impls.rs @@ -46,6 +46,9 @@ mod sqlx { #[cfg(feature = "postgres")] pub use sqlx_postgres as postgres; + #[cfg(feature = "mssql")] + pub use sqlx_mssql as mssql; + #[cfg(feature = "_sqlite")] pub use sqlx_sqlite as sqlite; } @@ -63,6 +66,12 @@ impl_database_ext! { row: sqlx::postgres::PgRow, } +#[cfg(feature = "mssql")] +impl_database_ext! { + sqlx::mssql::Mssql, + row: sqlx::mssql::MssqlRow, +} + #[cfg(feature = "_sqlite")] impl_database_ext! { sqlx::sqlite::Sqlite, diff --git a/sqlx-macros-core/src/database/mod.rs b/sqlx-macros-core/src/database/mod.rs index 0885b3cca8..f6f864d95b 100644 --- a/sqlx-macros-core/src/database/mod.rs +++ b/sqlx-macros-core/src/database/mod.rs @@ -10,7 +10,12 @@ use std::collections::hash_map; use std::collections::HashMap; use std::sync::{LazyLock, Mutex}; -#[cfg(any(feature = "postgres", feature = "mysql", feature = "_sqlite"))] +#[cfg(any( + feature = "postgres", + feature = "mysql", + feature = "mssql", + feature = "_sqlite" +))] mod impls; pub trait DatabaseExt: Database + TypeChecking { diff --git a/sqlx-macros-core/src/lib.rs b/sqlx-macros-core/src/lib.rs index db6586200f..61e0fc0729 100644 --- a/sqlx-macros-core/src/lib.rs +++ b/sqlx-macros-core/src/lib.rs @@ -53,6 +53,8 @@ pub const FOSS_DRIVERS: &[QueryDriver] = &[ QueryDriver::new::(), #[cfg(feature = "postgres")] QueryDriver::new::(), + #[cfg(feature = "mssql")] + QueryDriver::new::(), #[cfg(feature = "_sqlite")] QueryDriver::new::(), ]; diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml index bd90da9608..fabe0f2468 100644 --- a/sqlx-macros/Cargo.toml +++ b/sqlx-macros/Cargo.toml @@ -33,6 +33,7 @@ migrate = ["sqlx-macros-core/migrate"] sqlx-toml = ["sqlx-macros-core/sqlx-toml"] # database +mssql = ["sqlx-macros-core/mssql"] mysql = ["sqlx-macros-core/mysql"] mysql-rsa = ["sqlx-macros-core/mysql-rsa"] postgres = ["sqlx-macros-core/postgres"] diff --git a/sqlx-mssql/Cargo.toml b/sqlx-mssql/Cargo.toml new file mode 100644 index 0000000000..40f039df05 --- /dev/null +++ b/sqlx-mssql/Cargo.toml @@ -0,0 +1,69 @@ +[package] +name = "sqlx-mssql" +documentation = "https://docs.rs/sqlx" +description = "MSSQL driver implementation for SQLx. Not for direct use; see the `sqlx` crate for details." +version.workspace = true +license.workspace = true +edition.workspace = true +authors.workspace = true +repository.workspace = true +rust-version.workspace = true + +[features] +json = ["sqlx-core/json", "serde"] +any = ["sqlx-core/any"] +offline = ["sqlx-core/offline", "serde"] +migrate = ["sqlx-core/migrate"] + +# Authentication features +winauth = ["tiberius/winauth"] +integrated-auth-gssapi = ["tiberius/integrated-auth-gssapi"] + +# Type Integration features +bigdecimal = ["dep:bigdecimal", "sqlx-core/bigdecimal", "tiberius/bigdecimal"] +chrono = ["dep:chrono", "sqlx-core/chrono", "tiberius/chrono"] +rust_decimal = ["dep:rust_decimal", "sqlx-core/rust_decimal", "tiberius/rust_decimal"] +time = ["dep:time", "sqlx-core/time", "tiberius/time"] +uuid = ["dep:uuid", "sqlx-core/uuid"] + +[dependencies] +sqlx-core = { workspace = true } + +# TDS protocol driver +tiberius = { version = "0.12", default-features = false, features = ["tds73"] } + +# Futures crates +futures-core = { version = "0.3.19", default-features = false } +futures-io = "0.3.24" +futures-util = { version = "0.3.19", default-features = false, features = ["alloc", "sink", "io"] } + +# Runtime bridging +tokio = { workspace = true, optional = true } +tokio-util = { version = "0.7", features = ["compat"], optional = true } +async-std = { workspace = true, optional = true } + +# Type Integrations (versions inherited from `[workspace.dependencies]`) +bigdecimal = { workspace = true, optional = true } +chrono = { workspace = true, optional = true } +rust_decimal = { workspace = true, optional = true } +time = { workspace = true, optional = true } +uuid = { workspace = true, optional = true } + +# Misc +bytes = "1.1.0" +either = "1.6.1" +log = "0.4.18" +tracing = { version = "0.1.37", features = ["log"] } +percent-encoding = "2.1.0" + +dotenvy.workspace = true +thiserror.workspace = true + +serde = { version = "1.0.144", optional = true } + +[dev-dependencies] +# FIXME: https://github.com/rust-lang/cargo/issues/15622 +sqlx = { path = "..", default-features = false, features = ["mssql"] } + +[lints] +workspace = true diff --git a/sqlx-mssql/MSSQL_SUPPORT.md b/sqlx-mssql/MSSQL_SUPPORT.md new file mode 100644 index 0000000000..e58c1cc66f --- /dev/null +++ b/sqlx-mssql/MSSQL_SUPPORT.md @@ -0,0 +1,1189 @@ +# MSSQL (SQL Server) Support for SQLx + +A complete developer guide for using SQLx with Microsoft SQL Server, built on the [Tiberius](https://github.com/prisma/tiberius) TDS driver. + +--- + +## Table of Contents + +- [Overview](#overview) +- [Getting Started](#getting-started) +- [Feature Flags](#feature-flags) +- [Connection & Authentication](#connection--authentication) +- [Connection Pooling](#connection-pooling) +- [SSL/TLS](#ssltls) +- [Type Mappings](#type-mappings) +- [Querying](#querying) +- [Compile-Time Query Macros](#compile-time-query-macros) +- [FromRow & Derive Macros](#fromrow--derive-macros) +- [QueryBuilder](#querybuilder) +- [Transactions & Isolation Levels](#transactions--isolation-levels) +- [Migrations](#migrations) +- [Advisory Locks](#advisory-locks) +- [Bulk Insert](#bulk-insert) +- [XML Type](#xml-type) +- [Error Handling](#error-handling) +- [Any Driver Support](#any-driver-support) +- [Examples](#examples) +- [Docker & CI](#docker--ci) +- [Test Coverage](#test-coverage) + +--- + +## Overview + +Full SQL Server support has been added to SQLx, bringing feature parity with PostgreSQL, MySQL, and SQLite where applicable. The implementation provides: + +- Complete type system mapping between Rust and SQL Server types +- Four authentication methods (SQL Server, Windows/NTLM, Integrated/GSSAPI, Azure AD) +- SSL/TLS with configurable modes +- Compile-time checked queries via macros +- Connection pooling with callbacks +- Runtime-polymorphic `Any` driver support +- Database migrations with `sqlx migrate` +- RAII advisory locks via `sp_getapplock`/`sp_releaseapplock` +- Bulk insert via the TDS `INSERT BULK` protocol +- Transaction isolation levels including `SNAPSHOT` +- Nested transactions via savepoints +- Testing infrastructure with Docker Compose (MSSQL 2019 & 2022) + +**URL schemes:** `mssql://` and `sqlserver://` + +--- + +## Getting Started + +### Add SQLx to Your Project + +SQLx requires three choices in your feature flags: + +1. **Database driver** — `mssql` +2. **Async runtime** — one of `runtime-tokio` or `runtime-async-std` +3. **TLS backend** — one of `tls-native-tls`, `tls-rustls-aws-lc-rs`, `tls-rustls-ring`, or `tls-none` + +```toml +[dependencies] +sqlx = { version = "0.9", features = [ + "mssql", # SQL Server driver + "runtime-tokio", # async runtime (or runtime-async-std) + "tls-native-tls", # TLS backend (see Feature Flags for options) +] } +tokio = { version = "1", features = ["full"] } +``` + +> **Tip:** If you're unsure which TLS backend to pick, `tls-native-tls` is the safest default for SQL Server — it uses the platform's native TLS stack (SChannel on Windows, OpenSSL on Linux) and has the best compatibility with SQL Server's TLS implementation. + +### Minimal Example + +```rust +use sqlx::mssql::MssqlPool; + +#[tokio::main] +async fn main() -> Result<(), sqlx::Error> { + let pool = MssqlPool::connect("mssql://sa:YourStrong!Passw0rd@localhost/master").await?; + + let row: (i32,) = sqlx::query_as("SELECT @p1") + .bind(42i32) + .fetch_one(&pool) + .await?; + + println!("Got: {}", row.0); + Ok(()) +} +``` + +--- + +## Feature Flags + +### Required + +| Feature | Description | +|---------|-------------| +| `mssql` | Enable the MSSQL driver | + +### Async Runtime (pick one) + +| Feature | Description | +|---------|-------------| +| `runtime-tokio` | Use Tokio | +| `runtime-async-std` | Use async-std (via async-global-executor / smol) | + +### TLS Backend (pick one) + +| Feature | Description | +|---------|-------------| +| `tls-native-tls` | Platform-native TLS (recommended for SQL Server) | +| `tls-rustls-aws-lc-rs` | Rustls with AWS LC crypto | +| `tls-rustls-ring` | Rustls with ring crypto | +| `tls-none` | No TLS support | + +### Type Integrations + +| Feature | Description | +|---------|-------------| +| `json` | JSON type support via `serde_json` (stored as `NVARCHAR`) | +| `uuid` | `uuid::Uuid` ↔ `UNIQUEIDENTIFIER` | +| `chrono` | `chrono` datetime types | +| `time` | `time` crate datetime types | +| `rust_decimal` | `rust_decimal::Decimal` ↔ `DECIMAL`/`NUMERIC`/`MONEY` | +| `bigdecimal` | `bigdecimal::BigDecimal` ↔ `DECIMAL`/`NUMERIC`/`MONEY` | + +### Authentication + +| Feature | Description | +|---------|-------------| +| `winauth` | Windows/NTLM authentication | +| `integrated-auth-gssapi` | Integrated auth (Kerberos on Unix, SSPI on Windows) | + +### Functionality + +| Feature | Description | +|---------|-------------| +| `any` | Runtime-polymorphic `Any` driver | +| `migrate` | Database migrations | +| `offline` | Offline mode for compile-time macros (no live database needed in CI) | + +### Recommended Starter Set + +For most applications: + +```toml +sqlx = { version = "0.9", features = [ + "mssql", + "runtime-tokio", + "tls-native-tls", + "migrate", + "json", + "chrono", # or "time" + "uuid", + "rust_decimal", +] } +``` + +--- + +## Connection & Authentication + +### Connection String Format + +``` +mssql://[user[:password]@]host[:port][/database][?properties] +``` + +### Connection Options + +| Option | Default | Description | +|--------|---------|-------------| +| `host` | `localhost` | Database server hostname | +| `port` | `1433` | Port number | +| `username` | `sa` | Username | +| `password` | — | Password | +| `database` | — | Database name | +| `instance` | — | SQL Server named instance | +| `app_name` | `sqlx` | Application name sent to server | +| `statement-cache-capacity` | `100` | Max cached prepared statements | +| `application_intent` | `read_write` | `read_write` or `read_only` (Always On replicas) | + +### Programmatic Configuration + +Use `MssqlConnectOptions` for full control over connection settings: + +```rust +use sqlx::mssql::MssqlConnectOptions; + +let opts = MssqlConnectOptions::new() + .host("db.example.com") + .port(1433) + .username("app_user") + .password("s3cret") + .database("myapp") + .app_name("my-service") + .statement_cache_capacity(200) + .application_intent_read_only(false); + +let pool = MssqlPool::connect_with(opts).await?; +``` + +### URL-Based Configuration + +```rust +use sqlx::mssql::MssqlPool; + +let pool = MssqlPool::connect( + "mssql://app_user:s3cret@db.example.com:1433/myapp?app_name=my-service" +).await?; +``` + +Both approaches are equivalent. Use `MssqlConnectOptions` when you need to build connection parameters dynamically (e.g., from environment variables or a config file). + +### Authentication Methods + +**1. SQL Server Auth (default)** + +Standard username/password authentication. + +```rust +let pool = MssqlPool::connect("mssql://sa:password@localhost/mydb").await?; +``` + +**2. Windows/NTLM Auth** (feature: `winauth`) + +Supports `domain\user` syntax. + +```rust +let opts = MssqlConnectOptions::new() + .host("localhost") + .windows_auth(true); +``` + +**3. Integrated Auth / GSSAPI** (feature: `integrated-auth-gssapi`) + +Uses SSPI on Windows and Kerberos on Unix. + +```rust +let opts = MssqlConnectOptions::new() + .host("localhost") + .integrated_auth(true); +``` + +**4. Azure AD Token Auth** + +Pass a bearer token for Azure Active Directory authentication. This takes precedence over all other auth methods. + +```rust +let opts = MssqlConnectOptions::new() + .host("your-server.database.windows.net") + .aad_token("eyJ0eX..."); +``` + +--- + +## Connection Pooling + +For production applications, always use a connection pool rather than individual connections. + +### Basic Pool + +```rust +use sqlx::mssql::MssqlPool; + +// Simple — uses default pool settings +let pool = MssqlPool::connect("mssql://sa:password@localhost/mydb").await?; +``` + +### Configuring the Pool + +```rust +use sqlx::mssql::{MssqlPool, MssqlPoolOptions}; +use std::time::Duration; + +let pool = MssqlPoolOptions::new() + .max_connections(20) + .min_connections(5) + .acquire_timeout(Duration::from_secs(10)) + .idle_timeout(Duration::from_secs(600)) + .max_lifetime(Duration::from_secs(1800)) + .test_before_acquire(true) + .connect("mssql://sa:password@localhost/mydb") + .await?; +``` + +### Pool Configuration Reference + +| Option | Default | Description | +|--------|---------|-------------| +| `max_connections` | `10` | Maximum number of connections in the pool | +| `min_connections` | `0` | Minimum idle connections maintained (best-effort) | +| `acquire_timeout` | `30s` | Max time to wait for a connection (includes all phases) | +| `idle_timeout` | `10min` | Close connections idle longer than this | +| `max_lifetime` | `30min` | Close connections older than this | +| `test_before_acquire` | `true` | Ping idle connections before returning them | +| `acquire_slow_threshold` | `2s` | Log a warning for acquires slower than this | + +### Eager vs Lazy Connection + +```rust +// connect() — opens at least one connection immediately, fails fast on bad credentials +let pool = MssqlPoolOptions::new() + .connect("mssql://sa:password@localhost/mydb") + .await?; + +// connect_lazy() — no connections opened until first use +// Useful in tests or when the database may not be available at startup +let pool = MssqlPoolOptions::new() + .connect_lazy("mssql://sa:password@localhost/mydb")?; +``` + +### Pool Callbacks + +Callbacks let you run logic at key points in a connection's lifecycle: + +```rust +let pool = MssqlPoolOptions::new() + .max_connections(10) + // Called after a new connection is established + .after_connect(|conn, _metadata| { + Box::pin(async move { + // e.g., SET session options + sqlx::query("SET ANSI_NULLS ON") + .execute(&mut *conn) + .await?; + Ok(()) + }) + }) + // Called before returning an idle connection from the pool + .before_acquire(|conn, _metadata| { + Box::pin(async move { + // Return Ok(true) to use this connection + // Return Ok(false) to close it and try another + Ok(true) + }) + }) + // Called when a connection is returned to the pool + .after_release(|conn, _metadata| { + Box::pin(async move { + // Return Ok(true) to keep in the pool + // Return Ok(false) to close it + Ok(true) + }) + }) + .connect("mssql://sa:password@localhost/mydb") + .await?; +``` + +Each callback receives a `PoolConnectionMetadata` with: +- `age` — time since the connection was first opened +- `idle_for` — time the connection has been idle (only meaningful in `before_acquire`) + +### Production Tuning Tips + +- Set `max_connections` based on your workload and SQL Server's `max worker threads` setting. A good starting point is 2× the number of CPU cores. +- Set `min_connections` to keep a warm pool and avoid cold-start latency. +- Keep `max_lifetime` at 30 minutes or less to cycle connections and pick up DNS changes. +- Use `after_connect` to set session-level options (e.g., `SET ANSI_NULLS ON`). +- Use `test_before_acquire(true)` (the default) in production. Disable only if latency is critical and you handle stale connections at the application level. + +--- + +## SSL/TLS + +Configurable encryption modes for the TDS connection. + +| Mode | Description | +|------|-------------| +| `Disabled` | No encryption | +| `LoginOnly` | Encrypt login packet only | +| `Preferred` (default) | Encrypt if server supports it | +| `Required` | Always encrypt, fail otherwise | + +**Connection string parameters:** + +| Parameter | Description | +|-----------|-------------| +| `sslmode` / `ssl_mode` | `disabled`, `login_only`, `preferred`, `required` | +| `encrypt` | Legacy alias: `true` = required, `false` = disabled | +| `trust_server_certificate` | Trust without validation (default: `false`) | +| `trust_server_certificate_ca` | Path to CA certificate file (`.pem`, `.crt`, `.der`) | + +> **Note:** `trust_server_certificate` and `trust_server_certificate_ca` are mutually exclusive. If both are set, the CA path takes precedence. + +``` +mssql://sa:password@localhost/mydb?sslmode=required&trust_server_certificate=true +``` + +**Programmatic configuration:** + +```rust +use sqlx::mssql::{MssqlConnectOptions, MssqlSslMode}; + +let opts = MssqlConnectOptions::new() + .host("db.example.com") + .ssl_mode(MssqlSslMode::Required) + .trust_server_certificate(false) + .trust_server_certificate_ca("/path/to/ca.pem"); +``` + +--- + +## Type Mappings + +### Primitive Types + +| Rust Type | SQL Server Type(s) | Notes | +|-----------|-------------------|-------| +| `bool` | `BIT` | | +| `u8` | `TINYINT` | Unsigned, full range 0–255 | +| `i8` | `TINYINT` | **Only 0–127** (SQL Server TINYINT is unsigned; values 128–255 don't fit in `i8`) | +| `i16` | `SMALLINT` | | +| `i32` | `INT` | | +| `i64` | `BIGINT` | | +| `f32` | `REAL`, `FLOAT` | | +| `f64` | `REAL`, `FLOAT`, `MONEY`, `SMALLMONEY` | | +| `&str` / `String` | `NVARCHAR` | | +| `&[u8]` / `Vec` | `VARBINARY` | | + +### Feature-Gated Types + +#### `uuid` + +| Rust Type | SQL Server Type | +|-----------|----------------| +| `uuid::Uuid` | `UNIQUEIDENTIFIER` | + +#### `rust_decimal` + +| Rust Type | SQL Server Type(s) | +|-----------|-------------------| +| `rust_decimal::Decimal` | `DECIMAL`, `NUMERIC`, `MONEY`, `SMALLMONEY` | + +#### `bigdecimal` + +| Rust Type | SQL Server Type(s) | +|-----------|-------------------| +| `bigdecimal::BigDecimal` | `DECIMAL`, `NUMERIC`, `MONEY` | + +#### `chrono` + +| Rust Type | SQL Server Type(s) | +|-----------|-------------------| +| `chrono::NaiveDate` | `DATE` | +| `chrono::NaiveTime` | `TIME` | +| `chrono::NaiveDateTime` | `DATETIME2`, `DATETIME`, `SMALLDATETIME` | +| `chrono::DateTime` | `DATETIME2`, `DATETIMEOFFSET` | +| `chrono::DateTime` | `DATETIMEOFFSET`, `DATETIME2` | + +#### `time` + +| Rust Type | SQL Server Type(s) | +|-----------|-------------------| +| `time::Date` | `DATE` | +| `time::Time` | `TIME` | +| `time::PrimitiveDateTime` | `DATETIME2`, `DATETIME`, `SMALLDATETIME` | +| `time::OffsetDateTime` | `DATETIMEOFFSET`, `DATETIME2` | + +#### `json` + +| Rust Type | SQL Server Type | +|-----------|----------------| +| `serde_json::Value` / `Json` | `NVARCHAR` | + +> **Note:** SQL Server has no native JSON column type. JSON is stored as `NVARCHAR` text. You can still use SQL Server's built-in JSON functions (`JSON_VALUE`, `OPENJSON`, etc.) in your queries. + +#### XML + +| Rust Type | SQL Server Type | +|-----------|----------------| +| `MssqlXml` | `XML` | + +### Nullable Types + +All types above support `Option` for nullable columns. + +### Runtime Type Inspection + +Use `MssqlTypeInfo` to inspect column types at runtime: + +```rust +use sqlx::TypeInfo; + +let statement = conn.prepare("SELECT id, name FROM users".into_sql_str()).await?; +assert_eq!(statement.column(0).type_info().name(), "BIGINT"); +assert_eq!(statement.column(1).type_info().name(), "NVARCHAR"); +``` + +--- + +## Querying + +MSSQL uses `@p1`, `@p2`, `@p3`, ... as parameter placeholders (not `$1` or `?`). + +### Basic Queries + +```rust +use sqlx::Row; + +// Execute a statement (INSERT, UPDATE, DELETE) +let result = sqlx::query("UPDATE users SET active = 1 WHERE id = @p1") + .bind(42i32) + .execute(&pool) + .await?; +println!("Rows affected: {}", result.rows_affected()); + +// Fetch a single row +let row = sqlx::query("SELECT id, name FROM users WHERE id = @p1") + .bind(1i32) + .fetch_one(&pool) + .await?; +let name: String = row.get("name"); + +// Fetch all rows +let rows = sqlx::query("SELECT id, name FROM users") + .fetch_all(&pool) + .await?; + +// Fetch optional (returns None if no rows) +let maybe_row = sqlx::query("SELECT id FROM users WHERE email = @p1") + .bind("alice@example.com") + .fetch_optional(&pool) + .await?; +``` + +### Typed Queries with `query_as` + +```rust +let user: (i32, String) = sqlx::query_as("SELECT id, name FROM users WHERE id = @p1") + .bind(1i32) + .fetch_one(&pool) + .await?; + +// Or with a named struct (see FromRow section) +let user: User = sqlx::query_as("SELECT id, name FROM users WHERE id = @p1") + .bind(1i32) + .fetch_one(&pool) + .await?; +``` + +### Scalar Queries + +```rust +let count: i32 = sqlx::query_scalar("SELECT COUNT(*) FROM users") + .fetch_one(&pool) + .await?; +``` + +### Streaming with `fetch` + +For large result sets, use `fetch` to stream rows without loading them all into memory: + +```rust +use futures::TryStreamExt; + +let mut stream = sqlx::query("SELECT id, name FROM users") + .fetch(&pool); + +while let Some(row) = stream.try_next().await? { + let id: i32 = row.get("id"); + // process row... +} +``` + +### Row Access + +```rust +use sqlx::Row; + +let row = sqlx::query("SELECT id, name FROM users") + .fetch_one(&pool) + .await?; + +// By column name +let name: String = row.get("name"); + +// By column index (0-based) +let id: i32 = row.get(0); + +// Fallible access (returns Result) +let name: String = row.try_get("name")?; +``` + +### Custom Row Mapping + +```rust +let value = sqlx::query("SELECT 1 + @p1") + .bind(5_i32) + .try_map(|row: MssqlRow| row.try_get::(0)) + .fetch_one(&pool) + .await?; +``` + +### OUTPUT INSERTED (MSSQL's RETURNING) + +SQL Server does not support the `RETURNING` clause. Use `OUTPUT INSERTED` instead to get values from inserted/updated rows: + +```rust +// Get the auto-generated ID after INSERT +let id: i64 = sqlx::query_scalar( + "INSERT INTO users (name) OUTPUT INSERTED.id VALUES (@p1)" +) + .bind("Alice") + .fetch_one(&pool) + .await?; + +// Get multiple columns +let row = sqlx::query( + "INSERT INTO users (name, email) OUTPUT INSERTED.id, INSERTED.created_at VALUES (@p1, @p2)" +) + .bind("Alice") + .bind("alice@example.com") + .fetch_one(&pool) + .await?; +``` + +### Calling Stored Procedures + +Use `EXEC` to call stored procedures: + +```rust +let rows = sqlx::query("EXEC GetUsersByRole @p1") + .bind("admin") + .fetch_all(&pool) + .await?; + +// With output parameters, use a query that captures results +let result: (i32,) = sqlx::query_as("EXEC CountUsers") + .fetch_one(&pool) + .await?; +``` + +--- + +## Compile-Time Query Macros + +The standard SQLx macros work with MSSQL when `DATABASE_URL` is set to an `mssql://` connection string: + +```rust +// Compile-time checked query +let row = sqlx::query!("SELECT @p1 AS value", 42i32) + .fetch_one(&pool) + .await?; + +// With custom return type +#[derive(sqlx::FromRow)] +struct User { + id: i32, + name: String, +} + +let user = sqlx::query_as!(User, "SELECT id, name FROM users WHERE id = @p1", 1i32) + .fetch_one(&pool) + .await?; + +// Scalar queries +let count = sqlx::query_scalar!("SELECT COUNT(*) FROM users") + .fetch_one(&pool) + .await?; +``` + +### Offline Mode + +For CI builds without a live database, use offline mode: + +```bash +# Generate query metadata (run with DATABASE_URL set) +cargo sqlx prepare + +# This creates a .sqlx/ directory with cached query metadata. +# Commit this directory to version control. +``` + +Then build without a database: + +```bash +SQLX_OFFLINE=true cargo build +``` + +Enable the `offline` feature flag to use this capability. + +--- + +## FromRow & Derive Macros + +### Basic FromRow + +Map query results directly to a struct: + +```rust +#[derive(sqlx::FromRow)] +struct User { + id: i32, + name: String, + email: Option, +} + +let user: User = sqlx::query_as("SELECT id, name, email FROM users WHERE id = @p1") + .bind(1i32) + .fetch_one(&pool) + .await?; +``` + +### Enum Types with `#[derive(Type)]` + +**Integer-repr enums** map to SQL integer columns: + +```rust +#[derive(sqlx::Type, Debug, PartialEq)] +#[repr(i32)] +enum Status { + Active = 1, + Inactive = 0, + Banned = -1, +} + +// Works with INT columns +let status: Status = sqlx::query_scalar("SELECT status FROM users WHERE id = @p1") + .bind(1i32) + .fetch_one(&pool) + .await?; +``` + +**Transparent wrappers** create newtypes over existing SQL types: + +```rust +#[derive(sqlx::Type, Debug, PartialEq)] +#[sqlx(transparent)] +struct UserId(i64); + +let id: UserId = sqlx::query_scalar("SELECT id FROM users WHERE id = @p1") + .bind(1i64) + .fetch_one(&pool) + .await?; +``` + +### Combining FromRow and Type + +```rust +#[derive(sqlx::Type, Debug, PartialEq)] +#[repr(i16)] +enum Priority { + Low = 0, + Medium = 1, + High = 2, +} + +#[derive(sqlx::FromRow, Debug)] +struct Task { + id: i32, + title: String, + priority: Priority, +} + +let task: Task = sqlx::query_as("SELECT id, title, priority FROM tasks WHERE id = @p1") + .bind(1i32) + .fetch_one(&pool) + .await?; +``` + +--- + +## QueryBuilder + +`QueryBuilder` generates MSSQL-style parameter placeholders (`@p1`, `@p2`, ...) automatically: + +```rust +use sqlx::QueryBuilder; +use sqlx::mssql::Mssql; + +let mut qb: QueryBuilder = QueryBuilder::new("SELECT * FROM users WHERE "); +qb.push("name = ").push_bind("Alice"); +qb.push(" AND age > ").push_bind(21i32); +// Produces: SELECT * FROM users WHERE name = @p1 AND age > @p2 + +let users = qb.build_query_as::() + .fetch_all(&pool) + .await?; +``` + +### Dynamic WHERE Clauses + +```rust +let mut qb: QueryBuilder = QueryBuilder::new("SELECT * FROM users WHERE 1=1"); + +if let Some(name) = filter_name { + qb.push(" AND name = ").push_bind(name); +} +if let Some(min_age) = filter_min_age { + qb.push(" AND age >= ").push_bind(min_age); +} + +let results = qb.build_query_as::().fetch_all(&pool).await?; +``` + +### Reset and Rebuild + +```rust +let mut qb: QueryBuilder = QueryBuilder::new("SELECT * FROM users"); +let query = qb.build(); +// ... use query ... + +// Reset to build a new query with the same builder +qb.reset(); +qb.push("SELECT COUNT(*) FROM users"); +let count_query = qb.build(); +``` + +--- + +## Transactions & Isolation Levels + +### Basic Transactions + +```rust +let mut tx = pool.begin().await?; + +sqlx::query("INSERT INTO users (name) VALUES (@p1)") + .bind("Alice") + .execute(&mut *tx) + .await?; + +sqlx::query("INSERT INTO audit_log (action) VALUES (@p1)") + .bind("user_created") + .execute(&mut *tx) + .await?; + +tx.commit().await?; +// Or: tx.rollback().await?; +``` + +### Nested Transactions (Savepoints) + +Calling `begin()` on an existing transaction creates a savepoint: + +```rust +let mut tx = pool.begin().await?; + +sqlx::query("INSERT INTO users (id, name) VALUES (@p1, @p2)") + .bind(1i32) + .bind("Alice") + .execute(&mut *tx) + .await?; + +// Nested transaction — creates a savepoint +let mut savepoint = tx.begin().await?; + +sqlx::query("INSERT INTO users (id, name) VALUES (@p1, @p2)") + .bind(2i32) + .bind("Bob") + .execute(&mut *savepoint) + .await?; + +// Roll back only the inner transaction +savepoint.rollback().await?; +// Bob's insert is undone, but Alice's remains + +tx.commit().await?; +// Alice is committed, Bob is not +``` + +### Isolation Levels + +| Level | Description | +|-------|-------------| +| `ReadUncommitted` | Dirty reads allowed | +| `ReadCommitted` | Default SQL Server isolation | +| `RepeatableRead` | Prevents non-repeatable reads | +| `Snapshot` | Row versioning-based isolation | +| `Serializable` | Strictest isolation | + +> **Important:** `begin_with_isolation` is a method on `MssqlConnection`, not on `Pool`. You must acquire a connection first: + +```rust +use sqlx::mssql::MssqlIsolationLevel; + +let mut conn = pool.acquire().await?; +let mut tx = conn + .begin_with_isolation(MssqlIsolationLevel::Snapshot) + .await?; + +sqlx::query("SELECT * FROM accounts WHERE id = @p1") + .bind(1i32) + .fetch_one(&mut *tx) + .await?; + +tx.commit().await?; +``` + +> **Note:** `Snapshot` isolation requires the database to have `ALLOW_SNAPSHOT_ISOLATION` enabled: +> ```sql +> ALTER DATABASE [mydb] SET ALLOW_SNAPSHOT_ISOLATION ON; +> ``` + +--- + +## Migrations + +MSSQL supports the full `sqlx migrate` workflow. + +```bash +# Create a new migration +sqlx migrate add create_users_table + +# Run pending migrations +sqlx migrate run + +# Revert the last migration +sqlx migrate revert +``` + +**Programmatic usage:** + +```rust +sqlx::migrate!("./migrations") + .run(&pool) + .await?; +``` + +**Database lifecycle:** + +- `create_database(url)` — Creates a database via `CREATE DATABASE [name]` +- `database_exists(url)` — Checks existence via `DB_ID()` +- `drop_database(url)` — Drops with `ALTER DATABASE SET SINGLE_USER WITH ROLLBACK IMMEDIATE` for cleanup + +**No-transaction migrations** are supported for DDL operations that cannot run inside a transaction. + +Migration files use standard SQL Server syntax. Use bracket-quoted identifiers (`[schema].[table]`) for schema-qualified objects. + +--- + +## Advisory Locks + +Application-level named locks using SQL Server's `sp_getapplock` and `sp_releaseapplock`, with an RAII guard pattern. + +### Lock Modes + +| Mode | Compatible With | +|------|----------------| +| `Shared` | Shared, Update | +| `Update` | Shared only | +| `Exclusive` (default) | None | + +### Usage + +```rust +use sqlx::mssql::{MssqlAdvisoryLock, MssqlAdvisoryLockMode}; + +// Create an exclusive lock +let lock = MssqlAdvisoryLock::new("my_resource"); + +// Or with a specific mode +let lock = MssqlAdvisoryLock::with_mode("my_resource", MssqlAdvisoryLockMode::Shared); + +// RAII guard — acquire and release +let guard = lock.acquire_guard(&mut conn).await?; +// ... do work while lock is held ... +let conn = guard.release_now().await?; // explicit release + +// Non-blocking attempt +match lock.try_acquire_guard(&mut conn).await? { + either::Either::Left(guard) => { + // Lock acquired + let conn = guard.release_now().await?; + } + either::Either::Right(conn) => { + // Lock not available + } +} + +// Manual acquire/release (without guard) +lock.acquire(&mut conn).await?; +// ... do work ... +lock.release(&mut conn).await?; +``` + +> **Warning:** Unlike PostgreSQL advisory locks, MSSQL advisory lock guards do **NOT** auto-release on drop. If you drop the guard without calling `release_now()` or `leak()`, a warning is logged and the lock remains held until the connection is closed or returned to the pool. Always call `release_now()` explicitly. + +--- + +## Bulk Insert + +High-performance data loading via the TDS `INSERT BULK` protocol. The target table must already exist. + +```rust +use sqlx::mssql::IntoRow; + +let mut bulk = conn.bulk_insert("my_table").await?; + +bulk.send(("Alice", 30_i32).into_row()).await?; +bulk.send(("Bob", 25_i32).into_row()).await?; +bulk.send(("Carol", 28_i32).into_row()).await?; + +let rows_affected = bulk.finalize().await?; +assert_eq!(rows_affected, 3); +``` + +> **Important:** You **must** call `finalize()` to flush buffered data. If the `MssqlBulkInsert` is dropped without calling `finalize()`, buffered rows are lost. + +Tuple elements map to table columns in order. Tuples up to **10 elements** are supported via `tiberius::IntoRow`. + +--- + +## XML Type + +A dedicated `MssqlXml` wrapper type distinguishes XML columns from regular strings. + +```rust +use sqlx::mssql::MssqlXml; + +let xml = MssqlXml::from("hello".to_string()); + +sqlx::query("INSERT INTO docs (content) VALUES (@p1)") + .bind(&xml) + .execute(&pool) + .await?; + +let result: MssqlXml = sqlx::query_scalar("SELECT content FROM docs") + .fetch_one(&pool) + .await?; +``` + +--- + +## Error Handling + +### Error Types + +All SQLx operations return `sqlx::Error`. For database-specific errors, downcast to `MssqlDatabaseError`: + +```rust +use sqlx::error::ErrorKind; + +let result = sqlx::query("INSERT INTO users (id, name) VALUES (@p1, @p2)") + .bind(1i32) + .bind("Alice") + .execute(&pool) + .await; + +match result { + Ok(r) => println!("Inserted {} rows", r.rows_affected()), + Err(sqlx::Error::Database(db_err)) => { + // Classify the error + match db_err.kind() { + ErrorKind::UniqueViolation => { + println!("Duplicate key: {}", db_err.message()); + } + ErrorKind::ForeignKeyViolation => { + println!("Foreign key constraint failed"); + } + ErrorKind::NotNullViolation => { + println!("Required field is null"); + } + ErrorKind::CheckViolation => { + println!("Check constraint failed"); + } + _ => { + println!("Database error: {}", db_err.message()); + } + } + } + Err(e) => println!("Other error: {}", e), +} +``` + +### MssqlDatabaseError Fields + +When you need SQL Server-specific error details, downcast further: + +```rust +use sqlx::mssql::MssqlDatabaseError; + +if let sqlx::Error::Database(db_err) = &err { + if let Some(mssql_err) = db_err.try_downcast_ref::() { + println!("Error number: {}", mssql_err.number()); // SQL Server error number + println!("State: {}", mssql_err.state()); // Error state + println!("Class: {}", mssql_err.class()); // Severity class + println!("Message: {}", mssql_err.message()); // Error message + println!("Server: {:?}", mssql_err.server()); // Server name (Option) + println!("Procedure: {:?}", mssql_err.procedure()); // Stored procedure name (Option) + } +} +``` + +### ErrorKind Mapping + +| SQL Server Error Number | ErrorKind | +|------------------------|-----------| +| 2601, 2627 | `UniqueViolation` | +| 547 | `ForeignKeyViolation` | +| 515 | `NotNullViolation` | +| 2628 | `CheckViolation` | +| All others | `Other` | + +### Connection Recovery + +Connections remain usable after query errors: + +```rust +// This query fails +let result = sqlx::query("SELECT * FROM nonexistent_table") + .execute(&mut conn) + .await; +assert!(result.is_err()); + +// Connection is still valid +let val: (i32,) = sqlx::query_as("SELECT 42") + .fetch_one(&mut conn) + .await?; +``` + +--- + +## Any Driver Support + +MSSQL is fully integrated with the `Any` runtime-polymorphic driver, enabled via the `any` feature flag. + +```rust +use sqlx::any::AnyPool; + +// Connects to whichever database the URL points to +let pool = AnyPool::connect("mssql://sa:password@localhost/mydb").await?; + +let rows = sqlx::query("SELECT 1 + 1 AS result") + .fetch_all(&pool) + .await?; +``` + +All standard operations work through `Any`: queries, transactions, ping, close, and prepared statements. + +--- + +## Examples + +A full CRUD Todo application is available at `examples/mssql/todos/`, demonstrating connection pooling, migrations, query execution, and error handling. + +--- + +## Docker & CI + +### Docker Compose + +The test suite includes Docker Compose configurations for MSSQL 2019 and 2022: + +```bash +docker compose -f tests/docker-compose.yml up mssql_2022 -d +``` + +**Services:** + +| Service | Image | Port | +|---------|-------|------| +| `mssql_2022` | `mcr.microsoft.com/mssql/server:2022-latest` | 1433 | +| `mssql_2019` | `mcr.microsoft.com/mssql/server:2019-latest` | 1433 | + +### CI Matrix + +The GitHub Actions workflow tests across: + +- **MSSQL versions:** 2019, 2022 +- **Async runtimes:** tokio, async-global-executor, smol +- **TLS backends:** native-tls, rustls-aws-lc-rs, rustls-ring, none + +--- + +## Test Coverage + +Comprehensive test suite in `tests/mssql/`: + +| Area | File | What's Tested | +|------|------|---------------| +| Core queries | `mssql.rs` | Connections, SELECT, INSERT, parameters, large result sets, error handling | +| Type round-trips | `types.rs` | All primitive and feature-gated types with boundary values, NULLs, Unicode, large data | +| Test attribute | `test-attr.rs` | `#[sqlx_macros::test]` macro with automatic test DB setup | +| Isolation levels | `isolation-level.rs` | All five isolation level configurations | +| Advisory locks | `advisory-lock.rs` | Acquire, release, guard pattern, all lock modes | +| Bulk insert | `bulk-insert.rs` | High-performance loading, multi-row operations | +| Derives | `derives.rs` | `#[derive(FromRow)]`, custom field mappings | +| Query builder | `query_builder.rs` | Dynamic query construction, parameter handling | +| Error handling | `error.rs` | Database error inspection, error details | +| Compile-time macros | `macros.rs` | Online and offline macro verification | +| Describe | `describe.rs` | `sp_describe` column metadata and type inference | +| Migrations | `migrate.rs` | Migration lifecycle: create, run, revert | diff --git a/sqlx-mssql/issues/mssql-sp-return-value.md b/sqlx-mssql/issues/mssql-sp-return-value.md new file mode 100644 index 0000000000..3392247b8e --- /dev/null +++ b/sqlx-mssql/issues/mssql-sp-return-value.md @@ -0,0 +1,47 @@ +# MSSQL: surface stored procedure return values through the executor + +## Context + +`sp_getapplock` communicates success/failure through its **return value**, not through SQL errors: + +| Return code | Meaning | +|---|---| +| `0` | Lock granted immediately | +| `1` | Lock granted after waiting | +| `-1` | Timed out | +| `-2` | Cancelled | +| `-3` | Deadlock victim | +| `-999` | Parameter validation error | + +## Current workaround + +We wrap the call in a `DECLARE @r / IF @r < 0 THROW` pattern so that a failed lock becomes a SQL error that `execute` can catch: + +```sql +DECLARE @r INT; +EXEC @r = sp_getapplock @Resource = 'sqlx_migrations', + @LockMode = 'Exclusive', @LockOwner = 'Session', @LockTimeout = -1; +IF @r < 0 THROW 50000, 'Failed to acquire migration lock', 1; +``` + +This is sufficient for production use — the lock works correctly in all realistic scenarios, and failures are now surfaced as errors instead of being silently ignored. + +## Ideal long-term solution + +The proper fix is for the MSSQL executor to capture the TDS `RETURNSTATUS` token that SQL Server sends after stored procedure execution, and expose it through the driver's result types. + +### What would need to change + +1. **`collect_results` in `executor.rs`** — currently only handles `QueryItem::Metadata` and `QueryItem::Row`. The TDS return status token is not surfaced by tiberius's `QueryStream`. Investigate whether tiberius exposes this via `ExecuteResult` (from `.execute()`) or if it requires upstream changes. + +2. **`MssqlQueryResult`** — currently only holds `rows_affected: u64`. Would need an additional field like `return_status: Option` to carry the stored procedure return value. + +3. **`Migrate::lock` trait** — the signature is `Result<(), MigrateError>`, which is fine (we'd just check the return status and map negatives to `Err`). No trait changes needed. + +### Why this matters beyond migrations + +Any user calling stored procedures via `execute` today cannot inspect return values. This is a general limitation of the MSSQL driver, not specific to migrations. The `THROW` workaround only works when you control the SQL — it doesn't help when calling third-party procedures that use return codes for flow control. + +## Priority + +**Low** — the THROW workaround fully covers the migration lock case, and stored procedure return values are a niche use case. This is a correctness/completeness improvement, not a bug fix. diff --git a/sqlx-mssql/src/advisory_lock.rs b/sqlx-mssql/src/advisory_lock.rs new file mode 100644 index 0000000000..1899565f07 --- /dev/null +++ b/sqlx-mssql/src/advisory_lock.rs @@ -0,0 +1,366 @@ +use std::ops::{Deref, DerefMut}; + +use crate::error::Error; +use crate::query_scalar::query_scalar; +use crate::Either; +use crate::MssqlConnection; + +/// The lock mode for a MSSQL advisory lock. +/// +/// Maps to the `@LockMode` parameter of `sp_getapplock`. +#[derive(Debug, Clone, Copy, Default)] +pub enum MssqlAdvisoryLockMode { + /// A shared lock, compatible with other `Shared` and `Update` locks. + Shared, + + /// An update lock, compatible with `Shared` but not with other `Update` or `Exclusive`. + Update, + + /// An exclusive lock, incompatible with all other lock modes. + #[default] + Exclusive, +} + +impl MssqlAdvisoryLockMode { + fn as_str(&self) -> &'static str { + match self { + MssqlAdvisoryLockMode::Shared => "Shared", + MssqlAdvisoryLockMode::Update => "Update", + MssqlAdvisoryLockMode::Exclusive => "Exclusive", + } + } +} + +/// A session-scoped advisory lock backed by SQL Server's `sp_getapplock` / +/// `sp_releaseapplock`. +/// +/// Advisory locks are cooperative: they don't block access to any database +/// object; instead, all participants must explicitly acquire the same named +/// lock. The lock is scoped to the database session (connection). +/// +/// # RAII Guard +/// +/// Use [`acquire_guard`][Self::acquire_guard] or +/// [`try_acquire_guard`][Self::try_acquire_guard] to get an +/// [`MssqlAdvisoryLockGuard`] that provides access to the underlying connection +/// and can release the lock via [`release_now()`][MssqlAdvisoryLockGuard::release_now]. +/// +/// Unlike PostgreSQL, MSSQL connections cannot queue commands for deferred +/// execution, so the lock **cannot** be released automatically on drop. +/// If the guard is dropped without calling `release_now()` or `leak()`, a +/// warning is logged. The lock will still be released when the connection +/// is closed or returned to the pool. +/// +/// For manual lock management without a guard, use [`acquire`][Self::acquire], +/// [`try_acquire`][Self::try_acquire], and [`release`][Self::release]. +/// +/// # Resource Name +/// +/// SQL Server limits resource names to 255 characters. The name is passed as a +/// query parameter, so SQL injection is not possible. +/// +/// # Example +/// +/// ```rust,no_run +/// # async fn example(conn: &mut sqlx::mssql::MssqlConnection) -> sqlx::Result<()> { +/// use sqlx::mssql::MssqlAdvisoryLock; +/// +/// let lock = MssqlAdvisoryLock::new("my_app_lock"); +/// +/// // Using the RAII guard (preferred): +/// let guard = lock.acquire_guard(&mut *conn).await?; +/// // ... do work under the lock, using `&mut *guard` as a connection ... +/// guard.release_now().await?; +/// +/// // Or manual management: +/// lock.acquire(&mut *conn).await?; +/// // ... do work ... +/// lock.release(conn).await?; +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug, Clone)] +pub struct MssqlAdvisoryLock { + resource: String, + mode: MssqlAdvisoryLockMode, +} + +/// A wrapper for a connection that represents a held MSSQL advisory lock. +/// +/// Can be acquired by [`MssqlAdvisoryLock::acquire_guard()`] or +/// [`MssqlAdvisoryLock::try_acquire_guard()`]. +/// +/// ### Note: Release is NOT automatic on drop! +/// +/// Unlike PostgreSQL, MSSQL connections cannot queue commands for deferred +/// execution. If this guard is dropped without calling +/// [`release_now()`][Self::release_now], a warning is logged and the lock +/// remains held until the connection is closed or returned to the pool. +/// +/// Always prefer calling `.release_now().await` when you are done with the lock. +pub struct MssqlAdvisoryLockGuard> { + lock: MssqlAdvisoryLock, + conn: Option, +} + +impl MssqlAdvisoryLock { + /// Create a new advisory lock with the given resource name and the default + /// [`Exclusive`][MssqlAdvisoryLockMode::Exclusive] mode. + pub fn new(resource: impl Into) -> Self { + Self { + resource: resource.into(), + mode: MssqlAdvisoryLockMode::default(), + } + } + + /// Create a new advisory lock with the given resource name and lock mode. + pub fn with_mode(resource: impl Into, mode: MssqlAdvisoryLockMode) -> Self { + Self { + resource: resource.into(), + mode, + } + } + + /// Returns the resource name of this lock. + pub fn resource(&self) -> &str { + &self.resource + } + + /// Returns the lock mode. + pub fn mode(&self) -> &MssqlAdvisoryLockMode { + &self.mode + } + + /// Acquire the lock, waiting indefinitely until it is available. + /// + /// # Errors + /// + /// Returns an error if `sp_getapplock` returns a negative status code + /// (e.g. lock request was cancelled or a deadlock was detected). + pub async fn acquire(&self, conn: &mut MssqlConnection) -> Result<(), Error> { + let status: i32 = query_scalar( + "DECLARE @r INT; \ + EXEC @r = sp_getapplock @Resource = @p1, @LockMode = @p2, \ + @LockOwner = 'Session', @LockTimeout = -1; \ + SELECT @r;", + ) + .bind(&self.resource) + .bind(self.mode.as_str()) + .fetch_one(&mut *conn) + .await?; + + if status < 0 { + return Err(Error::Protocol(format!( + "sp_getapplock failed for resource '{}': status {status}{}", + self.resource, + applock_error_message(status), + ))); + } + + Ok(()) + } + + /// Try to acquire the lock without waiting. + /// + /// Returns `Ok(true)` if the lock was acquired, `Ok(false)` if it was not + /// available (timeout). + pub async fn try_acquire(&self, conn: &mut MssqlConnection) -> Result { + let status: i32 = query_scalar( + "DECLARE @r INT; \ + EXEC @r = sp_getapplock @Resource = @p1, @LockMode = @p2, \ + @LockOwner = 'Session', @LockTimeout = 0; \ + SELECT @r;", + ) + .bind(&self.resource) + .bind(self.mode.as_str()) + .fetch_one(&mut *conn) + .await?; + + if status >= 0 { + // 0 = granted synchronously, 1 = granted after wait + Ok(true) + } else if status == -1 { + // -1 = timed out + Ok(false) + } else { + Err(Error::Protocol(format!( + "sp_getapplock failed for resource '{}': status {status}{}", + self.resource, + applock_error_message(status), + ))) + } + } + + /// Release the lock. + /// + /// Returns `Ok(true)` if the lock was successfully released, `Ok(false)` + /// if the lock was not held by this session. + pub async fn release(&self, conn: &mut MssqlConnection) -> Result { + let sql = "DECLARE @r INT; \ + EXEC @r = sp_releaseapplock @Resource = @p1, @LockOwner = 'Session'; \ + SELECT @r;"; + + let status: i32 = query_scalar(sql) + .bind(&self.resource) + .fetch_one(&mut *conn) + .await?; + + match status { + 0 => Ok(true), + -999 => Ok(false), + _ => Err(Error::Protocol(format!( + "sp_releaseapplock failed for resource '{}': status {status}", + self.resource, + ))), + } + } + + /// Acquire the lock and return an RAII guard that provides access to the + /// underlying connection. + /// + /// The guard does **not** release the lock on drop (see + /// [`MssqlAdvisoryLockGuard`] for details). Call + /// [`release_now()`][MssqlAdvisoryLockGuard::release_now] to release the + /// lock and recover the connection. + /// + /// A connection-like type is required to execute the call. Allowed types + /// include `MssqlConnection`, `PoolConnection`, and mutable + /// references to either. + pub async fn acquire_guard>( + &self, + mut conn: C, + ) -> Result, Error> { + self.acquire(conn.as_mut()).await?; + Ok(MssqlAdvisoryLockGuard::new(self.clone(), conn)) + } + + /// Try to acquire the lock without waiting, returning an RAII guard on + /// success. + /// + /// Returns `Ok(Left(guard))` if the lock was acquired, or + /// `Ok(Right(conn))` if it was not available. + pub async fn try_acquire_guard>( + &self, + mut conn: C, + ) -> Result, C>, Error> { + if self.try_acquire(conn.as_mut()).await? { + Ok(Either::Left(MssqlAdvisoryLockGuard::new( + self.clone(), + conn, + ))) + } else { + Ok(Either::Right(conn)) + } + } + + /// Execute `sp_releaseapplock` for this lock's resource on the given + /// connection. + /// + /// This is provided for manually releasing the lock from connections + /// returned by [`MssqlAdvisoryLockGuard::leak()`]. + /// + /// Returns `Ok((conn, true))` if released, `Ok((conn, false))` if the lock + /// was not held. + pub async fn force_release>( + &self, + mut conn: C, + ) -> Result<(C, bool), Error> { + let released = self.release(conn.as_mut()).await?; + Ok((conn, released)) + } +} + +const NONE_ERR: &str = "BUG: MssqlAdvisoryLockGuard.conn taken"; + +impl> MssqlAdvisoryLockGuard { + fn new(lock: MssqlAdvisoryLock, conn: C) -> Self { + MssqlAdvisoryLockGuard { + lock, + conn: Some(conn), + } + } + + /// Release the advisory lock immediately and return the connection. + /// + /// This is the preferred way to release the lock. An error should only be + /// returned if there is something wrong with the connection, in which case + /// the lock will be automatically released when the connection is closed. + pub async fn release_now(mut self) -> Result { + let (conn, released) = self + .lock + .force_release(self.conn.take().expect(NONE_ERR)) + .await?; + + if !released { + tracing::warn!( + resource = %self.lock.resource(), + "MssqlAdvisoryLockGuard: advisory lock was not held by the contained connection", + ); + } + + Ok(conn) + } + + /// Cancel the release of the advisory lock, keeping it held until the + /// connection is closed. + /// + /// To manually release the lock later, see + /// [`MssqlAdvisoryLock::force_release()`]. + pub fn leak(mut self) -> C { + self.conn.take().expect(NONE_ERR) + } +} + +impl + AsRef> Deref for MssqlAdvisoryLockGuard { + type Target = MssqlConnection; + + fn deref(&self) -> &Self::Target { + self.conn.as_ref().expect(NONE_ERR).as_ref() + } +} + +impl + AsRef> DerefMut for MssqlAdvisoryLockGuard { + fn deref_mut(&mut self) -> &mut Self::Target { + self.conn.as_mut().expect(NONE_ERR).as_mut() + } +} + +impl> AsRef for MssqlAdvisoryLockGuard +where + C: AsRef, +{ + fn as_ref(&self) -> &MssqlConnection { + self.conn.as_ref().expect(NONE_ERR).as_ref() + } +} + +impl> AsMut for MssqlAdvisoryLockGuard { + fn as_mut(&mut self) -> &mut MssqlConnection { + self.conn.as_mut().expect(NONE_ERR).as_mut() + } +} + +/// Logs a warning if dropped without calling `release_now()` or `leak()`. +/// +/// The lock remains held until the connection is closed or returned to the pool. +impl> Drop for MssqlAdvisoryLockGuard { + fn drop(&mut self) { + if self.conn.is_some() { + tracing::warn!( + resource = %self.lock.resource(), + "MssqlAdvisoryLockGuard dropped without calling release_now() or leak(). \ + The lock will be released when the connection is closed.", + ); + } + } +} + +fn applock_error_message(status: i32) -> &'static str { + match status { + -1 => " (timed out)", + -2 => " (lock request cancelled)", + -3 => " (deadlock victim)", + -999 => " (parameter validation or other call error)", + _ => "", + } +} diff --git a/sqlx-mssql/src/any.rs b/sqlx-mssql/src/any.rs new file mode 100644 index 0000000000..218bac33fc --- /dev/null +++ b/sqlx-mssql/src/any.rs @@ -0,0 +1,232 @@ +use crate::{ + Mssql, MssqlColumn, MssqlConnectOptions, MssqlConnection, MssqlQueryResult, MssqlRow, + MssqlTransactionManager, MssqlTypeInfo, +}; +use either::Either; +use futures_core::future::BoxFuture; +use futures_core::stream::BoxStream; +use futures_util::{stream, FutureExt, StreamExt, TryStreamExt}; +use sqlx_core::any::{ + AnyArguments, AnyColumn, AnyConnectOptions, AnyConnectionBackend, AnyQueryResult, AnyRow, + AnyStatement, AnyTypeInfo, AnyTypeInfoKind, +}; +use sqlx_core::connection::Connection; +use sqlx_core::database::Database; +use sqlx_core::executor::Executor; +use sqlx_core::sql_str::SqlStr; +use sqlx_core::transaction::TransactionManager; +use std::future; + +sqlx_core::declare_driver_with_optional_migrate!(DRIVER = Mssql); + +impl AnyConnectionBackend for MssqlConnection { + fn name(&self) -> &str { + ::NAME + } + + fn close(self: Box) -> BoxFuture<'static, sqlx_core::Result<()>> { + Connection::close(*self).boxed() + } + + fn close_hard(self: Box) -> BoxFuture<'static, sqlx_core::Result<()>> { + Connection::close_hard(*self).boxed() + } + + fn ping(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { + Connection::ping(self).boxed() + } + + fn begin(&mut self, statement: Option) -> BoxFuture<'_, sqlx_core::Result<()>> { + MssqlTransactionManager::begin(self, statement).boxed() + } + + fn commit(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { + MssqlTransactionManager::commit(self).boxed() + } + + fn rollback(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { + MssqlTransactionManager::rollback(self).boxed() + } + + fn start_rollback(&mut self) { + MssqlTransactionManager::start_rollback(self) + } + + fn get_transaction_depth(&self) -> usize { + MssqlTransactionManager::get_transaction_depth(self) + } + + fn shrink_buffers(&mut self) { + Connection::shrink_buffers(self); + } + + fn flush(&mut self) -> BoxFuture<'_, sqlx_core::Result<()>> { + Connection::flush(self).boxed() + } + + fn should_flush(&self) -> bool { + Connection::should_flush(self) + } + + #[cfg(feature = "migrate")] + fn as_migrate( + &mut self, + ) -> sqlx_core::Result<&mut (dyn sqlx_core::migrate::Migrate + Send + 'static)> { + Ok(self) + } + + fn fetch_many( + &mut self, + query: SqlStr, + // MSSQL always sends parameterized queries via tiberius (no server-side + // prepared statement caching), so the persistent flag has no effect. + _persistent: bool, + arguments: Option, + ) -> BoxStream<'_, sqlx_core::Result>> { + let arguments = match arguments.map(AnyArguments::convert_into).transpose() { + Ok(arguments) => arguments, + Err(error) => { + return stream::once(future::ready(Err(sqlx_core::Error::Encode(error)))).boxed() + } + }; + + Box::pin( + stream::once(async move { + let results = self.run(query.as_str(), arguments).await?; + Ok::<_, sqlx_core::Error>(results) + }) + .map_ok(|results| { + futures_util::stream::iter(results.into_iter().map(|res| { + Ok(match res { + Either::Left(result) => Either::Left(map_result(result)), + Either::Right(row) => Either::Right(AnyRow::try_from(&row)?), + }) + })) + }) + .try_flatten(), + ) + } + + fn fetch_optional( + &mut self, + query: SqlStr, + // See fetch_many: MSSQL has no server-side prepared statement caching. + _persistent: bool, + arguments: Option, + ) -> BoxFuture<'_, sqlx_core::Result>> { + let arguments = arguments + .map(AnyArguments::convert_into) + .transpose() + .map_err(sqlx_core::Error::Encode); + + Box::pin(async move { + let arguments = arguments?; + let results = self.run(query.as_str(), arguments).await?; + + for result in results { + if let Either::Right(row) = result { + return Ok(Some(AnyRow::try_from(&row)?)); + } + } + + Ok(None) + }) + } + + fn prepare_with<'c, 'q: 'c>( + &'c mut self, + sql: SqlStr, + _parameters: &[AnyTypeInfo], + ) -> BoxFuture<'c, sqlx_core::Result> { + Box::pin(async move { + let statement = Executor::prepare_with(self, sql, &[]).await?; + let column_names = statement.metadata.column_names.clone(); + AnyStatement::try_from_statement(statement, column_names) + }) + } + + #[cfg(feature = "offline")] + fn describe( + &mut self, + sql: SqlStr, + ) -> BoxFuture<'_, sqlx_core::Result>> { + Box::pin(async move { + let describe = Executor::describe(self, sql).await?; + describe.try_into_any() + }) + } +} + +impl<'a> TryFrom<&'a MssqlTypeInfo> for AnyTypeInfo { + type Error = sqlx_core::Error; + + fn try_from(type_info: &'a MssqlTypeInfo) -> Result { + Ok(AnyTypeInfo { + kind: match type_info.base_name() { + "TINYINT" => AnyTypeInfoKind::SmallInt, + "SMALLINT" => AnyTypeInfoKind::SmallInt, + "INT" => AnyTypeInfoKind::Integer, + "BIGINT" => AnyTypeInfoKind::BigInt, + "REAL" => AnyTypeInfoKind::Real, + "FLOAT" => AnyTypeInfoKind::Double, + "VARBINARY" | "BINARY" | "IMAGE" => AnyTypeInfoKind::Blob, + "NULL" => AnyTypeInfoKind::Null, + "BIT" => AnyTypeInfoKind::Bool, + "MONEY" => AnyTypeInfoKind::Double, + "SMALLMONEY" => AnyTypeInfoKind::Real, + "DECIMAL" | "NUMERIC" => AnyTypeInfoKind::Text, + "NVARCHAR" | "VARCHAR" | "NCHAR" | "CHAR" | "NTEXT" | "TEXT" | "XML" => { + AnyTypeInfoKind::Text + } + "UNIQUEIDENTIFIER" => AnyTypeInfoKind::Text, + "DATE" | "TIME" | "DATETIME" | "DATETIME2" | "SMALLDATETIME" | "DATETIMEOFFSET" => { + AnyTypeInfoKind::Text + } + _ => { + return Err(sqlx_core::Error::AnyDriverError( + format!("Any driver does not support MSSQL type {type_info:?}").into(), + )) + } + }, + }) + } +} + +impl<'a> TryFrom<&'a MssqlColumn> for AnyColumn { + type Error = sqlx_core::Error; + + fn try_from(column: &'a MssqlColumn) -> Result { + let type_info = AnyTypeInfo::try_from(&column.type_info)?; + + Ok(AnyColumn { + ordinal: column.ordinal, + name: column.name.clone(), + type_info, + }) + } +} + +impl<'a> TryFrom<&'a MssqlRow> for AnyRow { + type Error = sqlx_core::Error; + + fn try_from(row: &'a MssqlRow) -> Result { + AnyRow::map_from(row, row.column_names.clone()) + } +} + +impl<'a> TryFrom<&'a AnyConnectOptions> for MssqlConnectOptions { + type Error = sqlx_core::Error; + + fn try_from(any_opts: &'a AnyConnectOptions) -> Result { + let mut opts = Self::parse_from_url(&any_opts.database_url)?; + opts.log_settings = any_opts.log_settings.clone(); + Ok(opts) + } +} + +fn map_result(result: MssqlQueryResult) -> AnyQueryResult { + AnyQueryResult { + rows_affected: result.rows_affected, + last_insert_id: None, + } +} diff --git a/sqlx-mssql/src/arguments.rs b/sqlx-mssql/src/arguments.rs new file mode 100644 index 0000000000..9b71e8e20a --- /dev/null +++ b/sqlx-mssql/src/arguments.rs @@ -0,0 +1,59 @@ +use std::fmt::{self, Write}; + +use crate::database::MssqlArgumentValue; +use crate::encode::Encode; +use crate::types::Type; +use crate::Mssql; +pub(crate) use sqlx_core::arguments::*; +use sqlx_core::error::BoxDynError; + +/// Implementation of [`Arguments`] for MSSQL. +#[derive(Debug, Default, Clone)] +pub struct MssqlArguments { + pub(crate) values: Vec, +} + +impl MssqlArguments { + pub(crate) fn add<'q, T>(&mut self, value: T) -> Result<(), BoxDynError> + where + T: Encode<'q, Mssql> + Type, + { + let is_null = value.encode(&mut self.values)?; + if is_null.is_null() { + // If the encoder signaled null but didn't push a value, push a Null + if self + .values + .last() + .is_none_or(|v| !matches!(v, MssqlArgumentValue::Null)) + { + self.values.push(MssqlArgumentValue::Null); + } + } + Ok(()) + } +} + +impl Arguments for MssqlArguments { + type Database = Mssql; + + fn reserve(&mut self, len: usize, _size: usize) { + self.values.reserve(len); + } + + fn add<'t, T>(&mut self, value: T) -> Result<(), BoxDynError> + where + T: Encode<'t, Self::Database> + Type, + { + self.add(value) + } + + fn len(&self) -> usize { + self.values.len() + } + + fn format_placeholder(&self, writer: &mut W) -> fmt::Result { + // MSSQL uses @p1, @p2, ... for parameterized queries. + // This is called after the bind is added, so len() is the correct 1-based index. + write!(writer, "@p{}", self.values.len()) + } +} diff --git a/sqlx-mssql/src/bulk_insert.rs b/sqlx-mssql/src/bulk_insert.rs new file mode 100644 index 0000000000..dac9a56943 --- /dev/null +++ b/sqlx-mssql/src/bulk_insert.rs @@ -0,0 +1,51 @@ +use crate::error::{tiberius_err, Error}; +use crate::io::SocketAdapter; +use sqlx_core::net::Socket; + +/// A bulk insert operation for high-performance data loading into SQL Server. +/// +/// Wraps the tiberius [`BulkLoadRequest`](tiberius::BulkLoadRequest) to provide +/// efficient bulk data insertion using the TDS `INSERT BULK` protocol. +/// +/// # Example +/// +/// ```rust,no_run +/// # async fn example(conn: &mut sqlx::mssql::MssqlConnection) -> sqlx::Result<()> { +/// use sqlx::mssql::IntoRow; +/// +/// let mut bulk = conn.bulk_insert("#my_temp_table").await?; +/// bulk.send(("hello", 42i32).into_row()).await?; +/// bulk.send(("world", 99i32).into_row()).await?; +/// let total = bulk.finalize().await?; +/// assert_eq!(total, 2); +/// # Ok(()) +/// # } +/// ``` +pub struct MssqlBulkInsert<'c> { + inner: tiberius::BulkLoadRequest<'c, SocketAdapter>>, +} + +impl<'c> MssqlBulkInsert<'c> { + pub(crate) fn new( + inner: tiberius::BulkLoadRequest<'c, SocketAdapter>>, + ) -> Self { + Self { inner } + } + + /// Send a single row to the bulk insert operation. + /// + /// The row is a [`tiberius::TokenRow`] — use [`tiberius::IntoRow::into_row()`] + /// to convert tuples of up to 10 elements into a `TokenRow`. + pub async fn send(&mut self, row: tiberius::TokenRow<'c>) -> Result<(), Error> { + self.inner.send(row).await.map_err(tiberius_err) + } + + /// Finalize the bulk insert, flushing all buffered data to the server. + /// + /// Returns the total number of rows inserted. This **must** be called + /// after all rows have been sent — otherwise data will be lost. + pub async fn finalize(self) -> Result { + let result = self.inner.finalize().await.map_err(tiberius_err)?; + Ok(result.total()) + } +} diff --git a/sqlx-mssql/src/column.rs b/sqlx-mssql/src/column.rs new file mode 100644 index 0000000000..e721a78b6c --- /dev/null +++ b/sqlx-mssql/src/column.rs @@ -0,0 +1,32 @@ +use crate::ext::ustr::UStr; +use crate::{Mssql, MssqlTypeInfo}; +pub(crate) use sqlx_core::column::*; + +#[derive(Debug, Clone)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] +pub struct MssqlColumn { + pub(crate) ordinal: usize, + pub(crate) name: UStr, + pub(crate) type_info: MssqlTypeInfo, + pub(crate) origin: ColumnOrigin, +} + +impl Column for MssqlColumn { + type Database = Mssql; + + fn ordinal(&self) -> usize { + self.ordinal + } + + fn name(&self) -> &str { + &self.name + } + + fn type_info(&self) -> &MssqlTypeInfo { + &self.type_info + } + + fn origin(&self) -> ColumnOrigin { + self.origin.clone() + } +} diff --git a/sqlx-mssql/src/connection/establish.rs b/sqlx-mssql/src/connection/establish.rs new file mode 100644 index 0000000000..5f1191c98b --- /dev/null +++ b/sqlx-mssql/src/connection/establish.rs @@ -0,0 +1,44 @@ +use crate::common::StatementCache; +use crate::connection::MssqlConnectionInner; +use crate::error::{tiberius_err, Error}; +use crate::io::SocketAdapter; +use crate::{MssqlConnectOptions, MssqlConnection}; +use sqlx_core::net::{Socket, WithSocket}; + +impl MssqlConnection { + pub(crate) async fn establish(options: &MssqlConnectOptions) -> Result { + let config = options.to_tiberius_config(); + let log_settings = options.log_settings.clone(); + let cache_capacity = options.statement_cache_capacity; + + let handler = EstablishHandler { config }; + + crate::net::connect_tcp(&options.host, options.port, handler) + .await? + .map(|client| MssqlConnection { + inner: Box::new(MssqlConnectionInner { + client, + transaction_depth: 0, + pending_rollback: false, + log_settings, + cache_statement: StatementCache::new(cache_capacity), + }), + }) + } +} + +struct EstablishHandler { + config: tiberius::Config, +} + +impl WithSocket for EstablishHandler { + type Output = Result>>, Error>; + + async fn with_socket(self, socket: S) -> Self::Output { + let boxed: Box = Box::new(socket); + let adapter = SocketAdapter::new(boxed); + tiberius::Client::connect(self.config, adapter) + .await + .map_err(tiberius_err) + } +} diff --git a/sqlx-mssql/src/connection/executor.rs b/sqlx-mssql/src/connection/executor.rs new file mode 100644 index 0000000000..254086f4da --- /dev/null +++ b/sqlx-mssql/src/connection/executor.rs @@ -0,0 +1,752 @@ +use crate::database::MssqlArgumentValue; +use crate::error::{tiberius_err, Error}; +use crate::executor::{Execute, Executor}; +use crate::ext::ustr::UStr; +use crate::logger::QueryLogger; +use crate::statement::{MssqlStatement, MssqlStatementMetadata}; +use crate::type_info::{type_name_for_tiberius, MssqlTypeInfo}; +use crate::value::{column_data_to_mssql_data, MssqlData}; +use crate::HashMap; +use crate::{Mssql, MssqlArguments, MssqlColumn, MssqlConnection, MssqlQueryResult, MssqlRow}; +use either::Either; +use futures_core::future::BoxFuture; +use futures_core::stream::BoxStream; +use futures_util::TryStreamExt; +use sqlx_core::column::{ColumnOrigin, TableColumn}; +use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr as _, SqlStr}; +use std::sync::Arc; + +/// Newtype wrapper to bridge `tiberius::ColumnData` into `tiberius::IntoSql`. +/// +/// tiberius implements `ToSql` but not `IntoSql` for some types (e.g. `time` +/// crate types, and `BigDecimal` due to version mismatch). `Query::bind()` +/// requires `IntoSql`, so this wrapper lets us construct `ColumnData` manually +/// and pass it to `bind()`. +#[cfg(any(feature = "chrono", feature = "time", feature = "bigdecimal"))] +struct ColumnDataWrapper<'a>(tiberius::ColumnData<'a>); + +#[cfg(any(feature = "chrono", feature = "time", feature = "bigdecimal"))] +impl<'a> tiberius::IntoSql<'a> for ColumnDataWrapper<'a> { + fn into_sql(self) -> tiberius::ColumnData<'a> { + self.0 + } +} + +/// Maximum days-since-epoch (0001-01-01) that fits in the 3-byte TDS date +/// encoding. `tiberius::time::Date::new()` panics if `days > 0x00FFFFFF`. +#[cfg(any(feature = "chrono", feature = "time"))] +const MAX_DAYS: u32 = 0x00FF_FFFF; + +/// Convert a signed days-since-epoch count to `u32`, returning +/// `Error::Encode` if negative or exceeding the TDS 3-byte limit. +#[cfg(any(feature = "chrono", feature = "time"))] +fn days_since_epoch_to_u32(days: i64) -> Result { + u32::try_from(days) + .ok() + .filter(|&d| d <= MAX_DAYS) + .ok_or_else(|| { + Error::Encode( + format!( + "date out of range for SQL Server: {days} days since epoch \ + (must be 0..={MAX_DAYS})" + ) + .into(), + ) + }) +} + +/// Convert a signed offset-in-minutes to `i16`, returning +/// `Error::Encode` if outside the SQL Server range (-840..=840). +#[cfg(any(feature = "chrono", feature = "time"))] +fn offset_minutes_to_i16(offset_minutes: i32) -> Result { + const MIN_OFFSET: i32 = -840; + const MAX_OFFSET: i32 = 840; + if (MIN_OFFSET..=MAX_OFFSET).contains(&offset_minutes) { + // SAFETY: range check above guarantees -840..=840, which fits in i16. + #[allow(clippy::cast_possible_truncation)] + Ok(offset_minutes as i16) + } else { + Err(Error::Encode( + format!( + "timezone offset out of range for SQL Server: {offset_minutes} minutes \ + (must be {MIN_OFFSET}..={MAX_OFFSET})" + ) + .into(), + )) + } +} + +/// Convert a `BigDecimal` into the `(i128, u8)` pair that +/// `tiberius::numeric::Numeric::new_with_scale` expects. +/// +/// Handles two edge cases: +/// - **Negative exponents** (e.g. `BigDecimal(9, -3)` = 9000): rescales to +/// exponent 0 so SQL Server receives the correct magnitude. +/// - **Scale > 37**: SQL Server NUMERIC max scale is 37, and tiberius +/// asserts `scale < 38`. Returns `Error::Encode` instead of panicking. +#[cfg(feature = "bigdecimal")] +fn bigdecimal_to_numeric(v: &bigdecimal::BigDecimal) -> Result<(i128, u8), Error> { + use bigdecimal::ToPrimitive; + + let (bigint, exponent) = v.as_bigint_and_exponent(); + let (bigint, exponent) = if exponent < 0 { + v.with_scale(0).into_bigint_and_exponent() + } else { + (bigint, exponent) + }; + + if exponent > 37 { + return Err(Error::Encode( + format!("BigDecimal scale {exponent} exceeds SQL Server maximum of 37").into(), + )); + } + // SAFETY: guarded by `exponent > 37` check above; 0..=37 fits in u8. + #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] + let scale = exponent as u8; + + let value: i128 = bigint.to_i128().ok_or_else(|| { + Error::Encode(format!("BigDecimal value too large for SQL NUMERIC: {v}").into()) + })?; + + Ok((value, scale)) +} + +impl MssqlConnection { + /// Execute a query, eagerly collecting all results. + /// + /// We collect eagerly because `tiberius::QueryStream` borrows `&mut Client`, + /// which prevents us from holding it across yield points alongside `&mut self`. + pub(crate) async fn run( + &mut self, + sql: &str, + arguments: Option, + ) -> Result>, Error> { + // Resolve any pending rollback first + crate::transaction::resolve_pending_rollback(self).await?; + + let mut logger = QueryLogger::new( + AssertSqlSafe(sql).into_sql_str(), + self.inner.log_settings.clone(), + ); + + let mut results = Vec::new(); + + if let Some(args) = arguments { + // Parameterized query using tiberius::Query + let mut query = tiberius::Query::new(sql); + + for arg in &args.values { + match arg { + MssqlArgumentValue::Null => { + query.bind(Option::<&str>::None); + } + MssqlArgumentValue::Bool(v) => { + query.bind(*v); + } + MssqlArgumentValue::U8(v) => { + query.bind(*v); + } + MssqlArgumentValue::I16(v) => { + query.bind(*v); + } + MssqlArgumentValue::I32(v) => { + query.bind(*v); + } + MssqlArgumentValue::I64(v) => { + query.bind(*v); + } + MssqlArgumentValue::F32(v) => { + query.bind(*v); + } + MssqlArgumentValue::F64(v) => { + query.bind(*v); + } + MssqlArgumentValue::String(v) => { + query.bind(v.as_str()); + } + MssqlArgumentValue::Binary(v) => { + query.bind(v.as_slice()); + } + #[cfg(feature = "chrono")] + MssqlArgumentValue::NaiveDateTime(v) => { + query.bind(*v); + } + #[cfg(feature = "chrono")] + MssqlArgumentValue::NaiveDate(v) => { + query.bind(*v); + } + #[cfg(feature = "chrono")] + MssqlArgumentValue::NaiveTime(v) => { + query.bind(*v); + } + #[cfg(feature = "chrono")] + MssqlArgumentValue::DateTimeFixedOffset(v) => { + use chrono::Timelike as _; + let epoch = chrono::NaiveDate::from_ymd_opt(1, 1, 1) + .expect("epoch 0001-01-01 is always valid"); + let naive = v.naive_local(); + let days = days_since_epoch_to_u32((naive.date() - epoch).num_days())?; + let time = naive.time(); + let total_ns = u64::from(time.num_seconds_from_midnight()) * 1_000_000_000 + + (u64::from(time.nanosecond()) % 1_000_000_000); + let increments = total_ns / 100; + let offset_minutes = v.offset().local_minus_utc() / 60; + let dt2 = tiberius::time::DateTime2::new( + tiberius::time::Date::new(days), + tiberius::time::Time::new(increments, 7), + ); + let cd = tiberius::ColumnData::DateTimeOffset(Some( + tiberius::time::DateTimeOffset::new( + dt2, + offset_minutes_to_i16(offset_minutes)?, + ), + )); + query.bind(ColumnDataWrapper(cd)); + } + #[cfg(feature = "uuid")] + MssqlArgumentValue::Uuid(v) => { + query.bind(v); + } + #[cfg(feature = "rust_decimal")] + MssqlArgumentValue::Decimal(v) => { + let unpacked = v.unpack(); + // SAFETY: rust_decimal mantissa is ≤96 bits (hi:mid:lo are u32s), fits in i128. + #[allow(clippy::cast_possible_wrap)] + let mut value = (((unpacked.hi as u128) << 64) + + ((unpacked.mid as u128) << 32) + + unpacked.lo as u128) as i128; + if v.is_sign_negative() { + value = -value; + } + let scale = v.scale(); + if scale > 37 { + return Err(Error::Encode( + format!( + "rust_decimal scale {scale} exceeds SQL Server maximum of 37" + ) + .into(), + )); + } + // SAFETY: guarded by `scale > 37` check above; 0..=37 fits in u8. + #[allow(clippy::cast_possible_truncation)] + let scale_u8 = scale as u8; + query.bind(tiberius::numeric::Numeric::new_with_scale(value, scale_u8)); + } + #[cfg(feature = "time")] + MssqlArgumentValue::TimeDate(v) => { + let epoch = time::Date::from_ordinal_date(1, 1) + .expect("epoch 0001-01-01 is always valid"); + let days = days_since_epoch_to_u32((*v - epoch).whole_days())?; + let cd = tiberius::ColumnData::Date(Some(tiberius::time::Date::new(days))); + query.bind(ColumnDataWrapper(cd)); + } + #[cfg(feature = "time")] + MssqlArgumentValue::TimeTime(v) => { + let (h, m, s, ns) = v.as_hms_nano(); + let total_ns = u64::from(h) * 3_600_000_000_000 + + u64::from(m) * 60_000_000_000 + + u64::from(s) * 1_000_000_000 + + u64::from(ns); + // Scale 7 = 100ns increments + let increments = total_ns / 100; + let cd = tiberius::ColumnData::Time(Some(tiberius::time::Time::new( + increments, 7, + ))); + query.bind(ColumnDataWrapper(cd)); + } + #[cfg(feature = "time")] + MssqlArgumentValue::TimePrimitiveDateTime(v) => { + let date = v.date(); + let time = v.time(); + let epoch = time::Date::from_ordinal_date(1, 1) + .expect("epoch 0001-01-01 is always valid"); + let days = days_since_epoch_to_u32((date - epoch).whole_days())?; + let (h, m, s, ns) = time.as_hms_nano(); + let total_ns = u64::from(h) * 3_600_000_000_000 + + u64::from(m) * 60_000_000_000 + + u64::from(s) * 1_000_000_000 + + u64::from(ns); + let increments = total_ns / 100; + let cd = + tiberius::ColumnData::DateTime2(Some(tiberius::time::DateTime2::new( + tiberius::time::Date::new(days), + tiberius::time::Time::new(increments, 7), + ))); + query.bind(ColumnDataWrapper(cd)); + } + #[cfg(feature = "time")] + MssqlArgumentValue::TimeOffsetDateTime(v) => { + let epoch = time::Date::from_ordinal_date(1, 1) + .expect("epoch 0001-01-01 is always valid"); + let offset_minutes = v.offset().whole_seconds() / 60; + let date = v.date(); + let time = v.time(); + let days = days_since_epoch_to_u32((date - epoch).whole_days())?; + let (h, m, s, ns) = time.as_hms_nano(); + let total_ns = u64::from(h) * 3_600_000_000_000 + + u64::from(m) * 60_000_000_000 + + u64::from(s) * 1_000_000_000 + + u64::from(ns); + let increments = total_ns / 100; + let dt2 = tiberius::time::DateTime2::new( + tiberius::time::Date::new(days), + tiberius::time::Time::new(increments, 7), + ); + let cd = tiberius::ColumnData::DateTimeOffset(Some( + tiberius::time::DateTimeOffset::new( + dt2, + offset_minutes_to_i16(offset_minutes)?, + ), + )); + query.bind(ColumnDataWrapper(cd)); + } + #[cfg(feature = "bigdecimal")] + MssqlArgumentValue::BigDecimal(v) => { + let (value, scale) = bigdecimal_to_numeric(v)?; + let cd = tiberius::ColumnData::Numeric(Some( + tiberius::numeric::Numeric::new_with_scale(value, scale), + )); + query.bind(ColumnDataWrapper(cd)); + } + } + } + + let stream = query + .query(&mut self.inner.client) + .await + .map_err(tiberius_err)?; + collect_results(stream, &mut results, &mut logger).await?; + } else { + // Simple query (no parameters) + let stream = self + .inner + .client + .simple_query(sql) + .await + .map_err(tiberius_err)?; + collect_results(stream, &mut results, &mut logger).await?; + } + + Ok(results) + } +} + +/// Collect all results from a tiberius QueryStream into a Vec. +async fn collect_results( + mut stream: tiberius::QueryStream<'_>, + results: &mut Vec>, + logger: &mut QueryLogger, +) -> Result<(), Error> { + // Process all result sets + let mut columns: Option>> = None; + let mut column_names: Option>> = None; + let mut rows_affected: u64 = 0; + + while let Some(item) = stream.try_next().await.map_err(tiberius_err)? { + match item { + tiberius::QueryItem::Metadata(meta) => { + // Build column info from metadata + let cols: Vec = meta + .columns() + .iter() + .enumerate() + .map(|(ordinal, col)| { + let name = UStr::new(col.name()); + let type_info = + MssqlTypeInfo::new(type_name_for_tiberius(&col.column_type())); + MssqlColumn { + ordinal, + name, + type_info, + origin: ColumnOrigin::Unknown, + } + }) + .collect(); + + let names: HashMap = cols + .iter() + .enumerate() + .map(|(i, col)| (col.name.clone(), i)) + .collect(); + + columns = Some(Arc::new(cols)); + column_names = Some(Arc::new(names)); + } + tiberius::QueryItem::Row(row) => { + let cols = columns + .as_ref() + .ok_or_else(|| Error::Protocol("row received before metadata".into()))?; + let names = column_names + .as_ref() + .ok_or_else(|| Error::Protocol("row received before metadata".into()))?; + + // Convert tiberius row to MssqlRow by iterating over cells + let values: Vec = row + .into_iter() + .map(column_data_to_mssql_data) + .collect::, _>>()?; + + rows_affected += 1; + logger.increment_rows_returned(); + results.push(Either::Right(MssqlRow { + values, + columns: Arc::clone(cols), + column_names: Arc::clone(names), + })); + } + } + } + + // Report query result with total rows + logger.increase_rows_affected(rows_affected); + results.push(Either::Left(MssqlQueryResult { rows_affected })); + + Ok(()) +} + +/// Build column metadata from `sp_describe_first_result_set` result rows. +/// +/// Returns `(columns, column_names, nullable)` where `nullable` contains one +/// `Option` per column (extracted from the `is_nullable` field). +fn build_columns_from_describe_rows( + rows: &[tiberius::Row], +) -> (Vec, HashMap, Vec>) { + let mut columns = Vec::with_capacity(rows.len()); + let mut column_names = HashMap::with_capacity(rows.len()); + let mut nullable = Vec::with_capacity(rows.len()); + + for (ordinal, row) in rows.iter().enumerate() { + let name: &str = row.get("name").unwrap_or(""); + let type_name: &str = row.get("system_type_name").unwrap_or("UNKNOWN"); + let type_info = MssqlTypeInfo::new(type_name.to_uppercase()); + let is_nullable: Option = row.get("is_nullable"); + + let source_table: Option<&str> = row.get("source_table"); + let source_schema: Option<&str> = row.get("source_schema"); + let source_column: Option<&str> = row.get("source_column"); + + let origin = match (source_table, source_column) { + (Some(table), Some(col)) if !table.is_empty() && !col.is_empty() => { + let table_str = match source_schema { + Some(s) if !s.is_empty() => format!("{s}.{table}"), + _ => table.to_string(), + }; + ColumnOrigin::Table(TableColumn { + table: table_str.into(), + name: col.into(), + }) + } + _ => ColumnOrigin::Expression, + }; + + let ustr_name = UStr::new(name); + column_names.insert(ustr_name.clone(), ordinal); + columns.push(MssqlColumn { + ordinal, + name: ustr_name, + type_info, + origin, + }); + nullable.push(is_nullable); + } + + (columns, column_names, nullable) +} + +impl<'c> Executor<'c> for &'c mut MssqlConnection { + type Database = Mssql; + + fn fetch_many<'e, 'q, E>( + self, + mut query: E, + ) -> BoxStream<'e, Result, Error>> + where + 'c: 'e, + E: Execute<'q, Self::Database>, + 'q: 'e, + E: 'q, + { + let arguments = query.take_arguments().map_err(Error::Encode); + // MSSQL always sends parameterized queries via tiberius — there is no + // server-side prepared statement caching like PostgreSQL's, so this + // flag is intentionally unused. + let _persistent = query.persistent(); + let sql = query.sql(); + + Box::pin( + futures_util::stream::once(async move { + let arguments = arguments?; + let results = self.run(sql.as_str(), arguments).await?; + Ok::<_, Error>(results) + }) + .map_ok(|results| futures_util::stream::iter(results.into_iter().map(Ok))) + .try_flatten(), + ) + } + + fn fetch_optional<'e, 'q, E>(self, query: E) -> BoxFuture<'e, Result, Error>> + where + 'c: 'e, + E: Execute<'q, Self::Database>, + 'q: 'e, + E: 'q, + { + let mut s = self.fetch_many(query); + + Box::pin(async move { + while let Some(v) = s.try_next().await? { + if let Either::Right(r) = v { + return Ok(Some(r)); + } + } + + Ok(None) + }) + } + + fn prepare_with<'e>( + self, + sql: SqlStr, + _parameters: &'e [MssqlTypeInfo], + ) -> BoxFuture<'e, Result> + where + 'c: 'e, + { + Box::pin(async move { + let mut describe_query = + tiberius::Query::new("EXEC sp_describe_first_result_set @tsql = @p1"); + describe_query.bind(sql.as_str()); + + let stream = describe_query + .query(&mut self.inner.client) + .await + .map_err(tiberius_err)?; + + let rows: Vec = + stream.into_first_result().await.map_err(tiberius_err)?; + let (columns, column_names, _nullable) = build_columns_from_describe_rows(&rows); + + Ok(MssqlStatement { + sql, + metadata: MssqlStatementMetadata { + columns: Arc::new(columns), + column_names: Arc::new(column_names), + parameters: 0, + }, + }) + }) + } + + #[doc(hidden)] + #[cfg(feature = "offline")] + fn describe<'e>( + self, + sql: SqlStr, + ) -> BoxFuture<'e, Result, Error>> + where + 'c: 'e, + { + Box::pin(async move { + let mut describe_query = + tiberius::Query::new("EXEC sp_describe_first_result_set @tsql = @p1"); + describe_query.bind(sql.as_str()); + + let stream = describe_query + .query(&mut self.inner.client) + .await + .map_err(tiberius_err)?; + + let rows: Vec = + stream.into_first_result().await.map_err(tiberius_err)?; + + let (columns, _column_names, nullable) = build_columns_from_describe_rows(&rows); + + // Count parameters using sp_describe_undeclared_parameters + let mut param_query = + tiberius::Query::new("EXEC sp_describe_undeclared_parameters @tsql = @p1"); + param_query.bind(sql.as_str()); + let param_count = match param_query.query(&mut self.inner.client).await { + Ok(stream) => stream + .into_first_result() + .await + .map_err(tiberius_err)? + .len(), + Err(e) => { + tracing::debug!("sp_describe_undeclared_parameters failed: {e}"); + 0 + } + }; + + Ok(crate::describe::Describe { + parameters: Some(Either::Right(param_count)), + columns, + nullable, + }) + }) + } +} + +#[cfg(test)] +#[cfg(any(feature = "chrono", feature = "time"))] +mod tests { + use super::*; + + #[test] + fn days_since_epoch_zero() { + assert_eq!(days_since_epoch_to_u32(0).unwrap(), 0); + } + + #[test] + fn days_since_epoch_max_date() { + // 9999-12-31 is 3_652_058 days from 0001-01-01 + assert_eq!(days_since_epoch_to_u32(3_652_058).unwrap(), 3_652_058); + } + + #[test] + fn days_since_epoch_negative() { + let err = days_since_epoch_to_u32(-1).unwrap_err(); + assert!(matches!(err, Error::Encode(_))); + } + + #[test] + fn days_since_epoch_overflow() { + let err = days_since_epoch_to_u32(i64::from(MAX_DAYS) + 1).unwrap_err(); + assert!(matches!(err, Error::Encode(_))); + } + + #[test] + fn days_since_epoch_at_max() { + assert_eq!( + days_since_epoch_to_u32(i64::from(MAX_DAYS)).unwrap(), + MAX_DAYS + ); + } + + #[test] + fn offset_minutes_zero() { + assert_eq!(offset_minutes_to_i16(0).unwrap(), 0); + } + + #[test] + fn offset_minutes_positive_max() { + assert_eq!(offset_minutes_to_i16(840).unwrap(), 840); + } + + #[test] + fn offset_minutes_negative_max() { + assert_eq!(offset_minutes_to_i16(-840).unwrap(), -840); + } + + #[test] + fn offset_minutes_out_of_sql_range() { + let err = offset_minutes_to_i16(841).unwrap_err(); + assert!(matches!(err, Error::Encode(_))); + let err = offset_minutes_to_i16(-841).unwrap_err(); + assert!(matches!(err, Error::Encode(_))); + } + + #[test] + fn offset_minutes_i16_overflow() { + let err = offset_minutes_to_i16(i32::MAX).unwrap_err(); + assert!(matches!(err, Error::Encode(_))); + } +} + +#[cfg(test)] +#[cfg(feature = "bigdecimal")] +mod bigdecimal_tests { + use super::*; + use std::str::FromStr; + + #[test] + fn positive_scale_simple() { + // 123.45 → bigint=12345, exponent=2 → scale=2 + let bd = bigdecimal::BigDecimal::from_str("123.45").unwrap(); + let (value, scale) = bigdecimal_to_numeric(&bd).unwrap(); + assert_eq!(value, 12345); + assert_eq!(scale, 2); + } + + #[test] + fn zero_scale() { + // 42 → bigint=42, exponent=0 → scale=0 + let bd = bigdecimal::BigDecimal::from_str("42").unwrap(); + let (value, scale) = bigdecimal_to_numeric(&bd).unwrap(); + assert_eq!(value, 42); + assert_eq!(scale, 0); + } + + #[test] + fn negative_exponent_rescales() { + // Explicitly construct BigDecimal(123, -3) = 123 * 10^3 = 123000. + // This is the internal form that triggers the negative-exponent path. + let bd = bigdecimal::BigDecimal::new(123.into(), -3); + let (bigint_raw, exp_raw) = bd.as_bigint_and_exponent(); + assert_eq!(exp_raw, -3, "precondition: exponent must be negative"); + assert_eq!(bigint_raw, 123.into(), "precondition: raw bigint is 123"); + + let (value, scale) = bigdecimal_to_numeric(&bd).unwrap(); + // After rescaling: 123000 with scale 0 + assert_eq!(value, 123000); + assert_eq!(scale, 0); + } + + #[test] + fn negative_exponent_large_magnitude() { + // 5e10 = 50_000_000_000 → internally (5, -10) + let bd = bigdecimal::BigDecimal::from_str("5e10").unwrap(); + let (value, scale) = bigdecimal_to_numeric(&bd).unwrap(); + assert_eq!(value, 50_000_000_000); + assert_eq!(scale, 0); + } + + #[test] + fn scale_at_max_37() { + // Scale exactly 37 is the maximum tiberius allows + let bd = bigdecimal::BigDecimal::new(1.into(), 37); + let (value, scale) = bigdecimal_to_numeric(&bd).unwrap(); + assert_eq!(value, 1); + assert_eq!(scale, 37); + } + + #[test] + fn scale_38_rejected() { + // Scale 38 triggers tiberius assert!(scale < 38); must be rejected + let bd = bigdecimal::BigDecimal::new(1.into(), 38); + let err = bigdecimal_to_numeric(&bd).unwrap_err(); + assert!(matches!(err, Error::Encode(_))); + } + + #[test] + fn scale_39_rejected() { + let bd = bigdecimal::BigDecimal::new(1.into(), 39); + let err = bigdecimal_to_numeric(&bd).unwrap_err(); + assert!(matches!(err, Error::Encode(_))); + } + + #[test] + fn scale_256_rejected_not_truncated() { + // The original bug: `as u8` would silently truncate 256 → 0. + // Must return an error, not scale=0. + let bd = bigdecimal::BigDecimal::new(1.into(), 256); + let err = bigdecimal_to_numeric(&bd).unwrap_err(); + assert!(matches!(err, Error::Encode(_))); + } + + #[test] + fn negative_value() { + // -99.9 → bigint=-999, scale=1 + let bd = bigdecimal::BigDecimal::from_str("-99.9").unwrap(); + let (value, scale) = bigdecimal_to_numeric(&bd).unwrap(); + assert_eq!(value, -999); + assert_eq!(scale, 1); + } + + #[test] + fn zero_value() { + let bd = bigdecimal::BigDecimal::from_str("0").unwrap(); + let (value, scale) = bigdecimal_to_numeric(&bd).unwrap(); + assert_eq!(value, 0); + assert_eq!(scale, 0); + } +} diff --git a/sqlx-mssql/src/connection/mod.rs b/sqlx-mssql/src/connection/mod.rs new file mode 100644 index 0000000000..d5ea091aa1 --- /dev/null +++ b/sqlx-mssql/src/connection/mod.rs @@ -0,0 +1,176 @@ +use std::fmt::{self, Debug, Formatter}; + +pub(crate) use sqlx_core::connection::*; +use sqlx_core::net::Socket; +use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr}; + +use crate::bulk_insert::MssqlBulkInsert; +use crate::common::StatementCache; +use crate::error::{tiberius_err, Error}; +use crate::executor::Executor; +use crate::io::SocketAdapter; +use crate::isolation_level::MssqlIsolationLevel; +use crate::statement::MssqlStatementMetadata; +use crate::transaction::{resolve_pending_rollback, Transaction}; +use crate::{Mssql, MssqlConnectOptions}; + +mod establish; +mod executor; + +/// A connection to a MSSQL database. +pub struct MssqlConnection { + pub(crate) inner: Box, +} + +pub(crate) struct MssqlConnectionInner { + pub(crate) client: tiberius::Client>>, + pub(crate) transaction_depth: usize, + pub(crate) pending_rollback: bool, + pub(crate) log_settings: LogSettings, + pub(crate) cache_statement: StatementCache, +} + +impl Debug for MssqlConnection { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("MssqlConnection").finish() + } +} + +impl Connection for MssqlConnection { + type Database = Mssql; + + type Options = MssqlConnectOptions; + + async fn close(self) -> Result<(), Error> { + // tiberius doesn't have an explicit close; dropping the client closes the connection. + drop(self); + Ok(()) + } + + async fn close_hard(self) -> Result<(), Error> { + drop(self); + Ok(()) + } + + async fn ping(&mut self) -> Result<(), Error> { + self.execute("SELECT 1").await?; + Ok(()) + } + + fn flush(&mut self) -> impl std::future::Future> + Send + '_ { + // No-op for MSSQL since tiberius handles buffering internally. + std::future::ready(Ok(())) + } + + fn cached_statements_size(&self) -> usize { + self.inner.cache_statement.len() + } + + async fn clear_cached_statements(&mut self) -> Result<(), Error> { + self.inner.cache_statement.clear(); + Ok(()) + } + + fn should_flush(&self) -> bool { + false + } + + fn begin( + &mut self, + ) -> impl std::future::Future, Error>> + Send + '_ + { + Transaction::begin(self, None) + } + + fn begin_with( + &mut self, + statement: impl SqlSafeStr, + ) -> impl std::future::Future, Error>> + Send + '_ + where + Self: Sized, + { + Transaction::begin(self, Some(statement.into_sql_str())) + } + + fn shrink_buffers(&mut self) { + // No-op for MSSQL + } +} + +// Implement `AsMut` so that `MssqlConnection` can be wrapped in +// a `MssqlAdvisoryLockGuard`. +impl AsMut for MssqlConnection { + fn as_mut(&mut self) -> &mut MssqlConnection { + self + } +} + +impl AsRef for MssqlConnection { + fn as_ref(&self) -> &MssqlConnection { + self + } +} + +impl MssqlConnection { + /// Begin a transaction with a specific isolation level. + /// + /// SQL Server requires `SET TRANSACTION ISOLATION LEVEL` to be issued + /// **before** `BEGIN TRANSACTION`. This method generates: + /// + /// ```sql + /// SET TRANSACTION ISOLATION LEVEL ; BEGIN TRANSACTION + /// ``` + /// + /// # Example + /// + /// ```rust,no_run + /// # async fn example(conn: &mut sqlx::mssql::MssqlConnection) -> sqlx::Result<()> { + /// use sqlx::mssql::MssqlIsolationLevel; + /// + /// let mut tx = conn.begin_with_isolation(MssqlIsolationLevel::Snapshot).await?; + /// // ... use tx ... + /// tx.commit().await?; + /// # Ok(()) + /// # } + /// ``` + pub fn begin_with_isolation( + &mut self, + level: MssqlIsolationLevel, + ) -> impl std::future::Future, Error>> + Send + '_ { + let sql = AssertSqlSafe(format!( + "SET TRANSACTION ISOLATION LEVEL {level}; BEGIN TRANSACTION" + )); + Transaction::begin(self, Some(sql.into_sql_str())) + } + + /// Start a bulk insert operation for high-performance data loading. + /// + /// The table must already exist. Tiberius executes `SELECT TOP 0 * FROM ` + /// to discover column metadata, then uses the TDS `INSERT BULK` protocol. + /// + /// # Example + /// + /// ```rust,no_run + /// # async fn example(conn: &mut sqlx::mssql::MssqlConnection) -> sqlx::Result<()> { + /// use sqlx::mssql::IntoRow; + /// + /// let mut bulk = conn.bulk_insert("#temp").await?; + /// bulk.send(("hello", 42i32).into_row()).await?; + /// let total = bulk.finalize().await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn bulk_insert<'c>( + &'c mut self, + table: &'c str, + ) -> Result, Error> { + resolve_pending_rollback(self).await?; + let req = self + .inner + .client + .bulk_insert(table) + .await + .map_err(tiberius_err)?; + Ok(MssqlBulkInsert::new(req)) + } +} diff --git a/sqlx-mssql/src/database.rs b/sqlx-mssql/src/database.rs new file mode 100644 index 0000000000..69fa61a469 --- /dev/null +++ b/sqlx-mssql/src/database.rs @@ -0,0 +1,78 @@ +use crate::value::{MssqlValue, MssqlValueRef}; +use crate::{ + MssqlArguments, MssqlColumn, MssqlConnection, MssqlQueryResult, MssqlRow, MssqlStatement, + MssqlTransactionManager, MssqlTypeInfo, +}; +pub(crate) use sqlx_core::database::{Database, HasStatementCache}; + +/// MSSQL (SQL Server) database driver. +#[derive(Debug)] +pub struct Mssql; + +impl Database for Mssql { + type Connection = MssqlConnection; + + type TransactionManager = MssqlTransactionManager; + + type Row = MssqlRow; + + type QueryResult = MssqlQueryResult; + + type Column = MssqlColumn; + + type TypeInfo = MssqlTypeInfo; + + type Value = MssqlValue; + type ValueRef<'r> = MssqlValueRef<'r>; + + type Arguments = MssqlArguments; + type ArgumentBuffer = Vec; + + type Statement = MssqlStatement; + + const NAME: &'static str = "MSSQL"; + + const URL_SCHEMES: &'static [&'static str] = &["mssql", "sqlserver"]; +} + +impl HasStatementCache for Mssql {} + +/// A single argument value for MSSQL queries. +/// +/// Unlike MySQL/Postgres which use a byte buffer, MSSQL arguments are stored +/// as typed enum variants because tiberius requires typed `bind()` calls. +#[derive(Debug, Clone)] +pub enum MssqlArgumentValue { + Null, + Bool(bool), + U8(u8), + I16(i16), + I32(i32), + I64(i64), + F32(f32), + F64(f64), + String(String), + Binary(Vec), + #[cfg(feature = "chrono")] + NaiveDateTime(chrono::NaiveDateTime), + #[cfg(feature = "chrono")] + NaiveDate(chrono::NaiveDate), + #[cfg(feature = "chrono")] + NaiveTime(chrono::NaiveTime), + #[cfg(feature = "chrono")] + DateTimeFixedOffset(chrono::DateTime), + #[cfg(feature = "uuid")] + Uuid(uuid::Uuid), + #[cfg(feature = "rust_decimal")] + Decimal(rust_decimal::Decimal), + #[cfg(feature = "time")] + TimeDate(time::Date), + #[cfg(feature = "time")] + TimeTime(time::Time), + #[cfg(feature = "time")] + TimePrimitiveDateTime(time::PrimitiveDateTime), + #[cfg(feature = "time")] + TimeOffsetDateTime(time::OffsetDateTime), + #[cfg(feature = "bigdecimal")] + BigDecimal(bigdecimal::BigDecimal), +} diff --git a/sqlx-mssql/src/error.rs b/sqlx-mssql/src/error.rs new file mode 100644 index 0000000000..f61fd968a7 --- /dev/null +++ b/sqlx-mssql/src/error.rs @@ -0,0 +1,139 @@ +use std::borrow::Cow; +use std::error::Error as StdError; +use std::fmt::{self, Debug, Display, Formatter}; + +pub(crate) use sqlx_core::error::*; + +/// An error returned from the MSSQL database. +pub struct MssqlDatabaseError { + pub(crate) number: u32, + pub(crate) state: u8, + pub(crate) class: u8, + pub(crate) message: String, + pub(crate) server: Option, + pub(crate) procedure: Option, +} + +impl MssqlDatabaseError { + /// The error number returned by SQL Server. + pub fn number(&self) -> u32 { + self.number + } + + /// The error state. + pub fn state(&self) -> u8 { + self.state + } + + /// The severity class of the error. + pub fn class(&self) -> u8 { + self.class + } + + /// The server name that generated the error, if available. + pub fn server(&self) -> Option<&str> { + self.server.as_deref() + } + + /// The stored procedure name, if applicable. + pub fn procedure(&self) -> Option<&str> { + self.procedure.as_deref() + } +} + +impl Debug for MssqlDatabaseError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("MssqlDatabaseError") + .field("number", &self.number) + .field("state", &self.state) + .field("class", &self.class) + .field("message", &self.message) + .finish() + } +} + +impl Display for MssqlDatabaseError { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!( + f, + "(number {}, state {}): {}", + self.number, self.state, self.message + ) + } +} + +impl StdError for MssqlDatabaseError {} + +impl DatabaseError for MssqlDatabaseError { + #[inline] + fn message(&self) -> &str { + &self.message + } + + fn code(&self) -> Option> { + Some(Cow::Owned(self.number.to_string())) + } + + #[doc(hidden)] + fn as_error(&self) -> &(dyn StdError + Send + Sync + 'static) { + self + } + + #[doc(hidden)] + fn as_error_mut(&mut self) -> &mut (dyn StdError + Send + Sync + 'static) { + self + } + + #[doc(hidden)] + fn into_error(self: Box) -> Box { + self + } + + fn kind(&self) -> ErrorKind { + match self.number { + // Cannot insert duplicate key + 2601 | 2627 => ErrorKind::UniqueViolation, + // Foreign key constraint violation + 547 => ErrorKind::ForeignKeyViolation, + // Cannot insert NULL + 515 => ErrorKind::NotNullViolation, + // Check constraint violation + 2628 => ErrorKind::CheckViolation, + _ => ErrorKind::Other, + } + } +} + +/// Convert a tiberius error into an sqlx Error. +pub(crate) fn tiberius_err(err: tiberius::error::Error) -> Error { + match err { + tiberius::error::Error::Server(token_error) => { + Error::Database(Box::new(MssqlDatabaseError { + number: token_error.code(), + state: token_error.state(), + class: token_error.class(), + message: token_error.message().to_string(), + server: { + let s = token_error.server(); + if s.is_empty() { + None + } else { + Some(s.to_string()) + } + }, + procedure: { + let s = token_error.procedure(); + if s.is_empty() { + None + } else { + Some(s.to_string()) + } + }, + })) + } + tiberius::error::Error::Io { kind, message } => { + Error::Io(std::io::Error::new(kind, message)) + } + other => Error::Protocol(other.to_string()), + } +} diff --git a/sqlx-mssql/src/io.rs b/sqlx-mssql/src/io.rs new file mode 100644 index 0000000000..88baca0c52 --- /dev/null +++ b/sqlx-mssql/src/io.rs @@ -0,0 +1,72 @@ +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use sqlx_core::net::Socket; + +/// Adapter that wraps an sqlx-core `Socket` to implement `futures_io::AsyncRead + AsyncWrite`, +/// which is what tiberius requires. +pub(crate) struct SocketAdapter { + inner: S, +} + +impl SocketAdapter { + pub fn new(socket: S) -> Self { + Self { inner: socket } + } +} + +impl futures_io::AsyncRead for SocketAdapter { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + loop { + match self.inner.try_read(&mut &mut *buf) { + Ok(n) => return Poll::Ready(Ok(n)), + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + match self.inner.poll_read_ready(cx) { + Poll::Ready(Ok(())) => continue, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } + } + Err(e) => return Poll::Ready(Err(e)), + } + } + } +} + +impl futures_io::AsyncWrite for SocketAdapter { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + loop { + match self.inner.try_write(buf) { + Ok(n) => return Poll::Ready(Ok(n)), + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + match self.inner.poll_write_ready(cx) { + Poll::Ready(Ok(())) => continue, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } + } + Err(e) => return Poll::Ready(Err(e)), + } + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_shutdown(cx) + } +} + +// Implement Unpin since we only access the inner socket through &mut self +impl Unpin for SocketAdapter {} diff --git a/sqlx-mssql/src/isolation_level.rs b/sqlx-mssql/src/isolation_level.rs new file mode 100644 index 0000000000..1409f084de --- /dev/null +++ b/sqlx-mssql/src/isolation_level.rs @@ -0,0 +1,55 @@ +use std::fmt; + +/// SQL Server transaction isolation levels. +/// +/// SQL Server supports five isolation levels. The `SET TRANSACTION ISOLATION LEVEL` +/// statement must be issued **before** `BEGIN TRANSACTION`, unlike PostgreSQL which +/// accepts it inside the `BEGIN` block. +/// +/// See [SQL Server documentation](https://learn.microsoft.com/en-us/sql/t-sql/statements/set-transaction-isolation-level-transact-sql) +/// for details on each level. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub enum MssqlIsolationLevel { + /// Allows dirty reads. Statements can read rows modified by other + /// transactions that have not yet been committed. + ReadUncommitted, + + /// The default isolation level. Statements cannot read data modified + /// by other transactions that have not been committed. + #[default] + ReadCommitted, + + /// Statements cannot read data modified by other transactions that + /// have not been committed, and no other transactions can modify + /// data read by the current transaction until it completes. + RepeatableRead, + + /// Uses row versioning to provide transaction-level read consistency. + /// Requires the `ALLOW_SNAPSHOT_ISOLATION` database option to be `ON`. + Snapshot, + + /// Statements cannot read data modified by other transactions that + /// have not been committed. No other transactions can modify data + /// read by the current transaction, and no other transactions can + /// insert new rows matching the current transaction's search conditions. + Serializable, +} + +impl MssqlIsolationLevel { + /// Returns the SQL Server syntax for this isolation level. + pub fn as_str(&self) -> &'static str { + match self { + Self::ReadUncommitted => "READ UNCOMMITTED", + Self::ReadCommitted => "READ COMMITTED", + Self::RepeatableRead => "REPEATABLE READ", + Self::Snapshot => "SNAPSHOT", + Self::Serializable => "SERIALIZABLE", + } + } +} + +impl fmt::Display for MssqlIsolationLevel { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} diff --git a/sqlx-mssql/src/lib.rs b/sqlx-mssql/src/lib.rs new file mode 100644 index 0000000000..6cb0bf905a --- /dev/null +++ b/sqlx-mssql/src/lib.rs @@ -0,0 +1,83 @@ +//! **MSSQL** (SQL Server) database driver. +#![deny(clippy::cast_possible_truncation)] +#![deny(clippy::cast_possible_wrap)] +#![deny(clippy::cast_sign_loss)] + +#[macro_use] +extern crate sqlx_core; + +use crate::executor::Executor; + +pub(crate) use sqlx_core::driver_prelude::*; + +pub mod advisory_lock; +mod bulk_insert; +mod isolation_level; + +#[cfg(feature = "any")] +pub mod any; + +mod arguments; +mod column; +mod connection; +mod database; +mod error; +mod io; +mod options; +mod query_result; +mod row; +mod statement; +mod transaction; +mod type_checking; +mod type_info; +pub mod types; +mod value; + +#[cfg(feature = "migrate")] +mod migrate; + +#[cfg(feature = "migrate")] +mod testing; + +pub use advisory_lock::{MssqlAdvisoryLock, MssqlAdvisoryLockGuard, MssqlAdvisoryLockMode}; +pub use arguments::MssqlArguments; +pub use bulk_insert::MssqlBulkInsert; +pub use column::MssqlColumn; +pub use connection::MssqlConnection; +pub use database::Mssql; +pub use error::MssqlDatabaseError; +pub use isolation_level::MssqlIsolationLevel; +pub use options::ssl_mode::MssqlSslMode; +pub use options::MssqlConnectOptions; +pub use query_result::MssqlQueryResult; +pub use row::MssqlRow; +pub use statement::MssqlStatement; +pub use transaction::MssqlTransactionManager; +pub use type_info::MssqlTypeInfo; +pub use types::xml::MssqlXml; +pub use value::{MssqlValue, MssqlValueRef}; + +// Re-export tiberius types needed for bulk insert row construction. +pub use tiberius::{IntoRow, IntoSql, TokenRow}; + +/// An alias for [`Pool`][crate::pool::Pool], specialized for MSSQL. +pub type MssqlPool = crate::pool::Pool; + +/// An alias for [`PoolOptions`][crate::pool::PoolOptions], specialized for MSSQL. +pub type MssqlPoolOptions = crate::pool::PoolOptions; + +/// An alias for [`Executor<'_, Database = Mssql>`][Executor]. +pub trait MssqlExecutor<'c>: Executor<'c, Database = Mssql> {} +impl<'c, T: Executor<'c, Database = Mssql>> MssqlExecutor<'c> for T {} + +/// An alias for [`Transaction`][crate::transaction::Transaction], specialized for MSSQL. +pub type MssqlTransaction<'c> = crate::transaction::Transaction<'c, Mssql>; + +// NOTE: required due to the lack of lazy normalization +impl_into_arguments_for_arguments!(MssqlArguments); +impl_acquire!(Mssql, MssqlConnection); +impl_column_index_for_row!(MssqlRow); +impl_column_index_for_statement!(MssqlStatement); + +// required because some databases have a different handling of NULL +impl_encode_for_option!(Mssql); diff --git a/sqlx-mssql/src/migrate.rs b/sqlx-mssql/src/migrate.rs new file mode 100644 index 0000000000..631f8cc061 --- /dev/null +++ b/sqlx-mssql/src/migrate.rs @@ -0,0 +1,344 @@ +use std::str::FromStr; +use std::time::Duration; +use std::time::Instant; + +use futures_core::future::BoxFuture; +pub(crate) use sqlx_core::migrate::*; +use sqlx_core::sql_str::AssertSqlSafe; + +use crate::connection::{ConnectOptions, Connection}; +use crate::error::Error; +use crate::executor::Executor; +use crate::query::query; +use crate::query_as::query_as; +use crate::query_scalar::query_scalar; +use crate::{Mssql, MssqlConnectOptions, MssqlConnection}; + +/// Escape a table name for safe use as an MSSQL bracket-quoted identifier (`[...]`). +fn escape_table_name(table_name: &str) -> String { + format!("[{}]", table_name.replace(']', "]]")) +} + +fn parse_for_maintenance(url: &str) -> Result<(MssqlConnectOptions, String), Error> { + let mut options = MssqlConnectOptions::from_str(url)?; + + let database = if let Some(database) = &options.database { + database.to_owned() + } else { + return Err(Error::Configuration( + "DATABASE_URL does not specify a database".into(), + )); + }; + + // switch us to master database for create/drop commands + options.database = Some("master".to_owned()); + + Ok((options, database)) +} + +impl MigrateDatabase for Mssql { + async fn create_database(url: &str) -> Result<(), Error> { + let (options, database) = parse_for_maintenance(url)?; + let mut conn = options.connect().await?; + + query( + "DECLARE @sql NVARCHAR(MAX) = N'CREATE DATABASE ' + QUOTENAME(@p1); \ + EXEC sp_executesql @sql;", + ) + .bind(database) + .execute(&mut conn) + .await?; + + Ok(()) + } + + async fn database_exists(url: &str) -> Result { + let (options, database) = parse_for_maintenance(url)?; + let mut conn = options.connect().await?; + + let exists: bool = + query_scalar("SELECT CASE WHEN DB_ID(@p1) IS NOT NULL THEN 1 ELSE 0 END") + .bind(database) + .fetch_one(&mut conn) + .await?; + + Ok(exists) + } + + async fn drop_database(url: &str) -> Result<(), Error> { + let (options, database) = parse_for_maintenance(url)?; + let mut conn = options.connect().await?; + + query( + "IF DB_ID(@p1) IS NOT NULL \ + BEGIN \ + DECLARE @sql NVARCHAR(MAX); \ + SET @sql = N'ALTER DATABASE ' + QUOTENAME(@p1) + N' SET SINGLE_USER WITH ROLLBACK IMMEDIATE'; \ + EXEC sp_executesql @sql; \ + SET @sql = N'DROP DATABASE ' + QUOTENAME(@p1); \ + EXEC sp_executesql @sql; \ + END" + ) + .bind(database) + .execute(&mut conn) + .await?; + + Ok(()) + } +} + +impl Migrate for MssqlConnection { + fn create_schema_if_not_exists<'e>( + &'e mut self, + schema_name: &'e str, + ) -> BoxFuture<'e, Result<(), MigrateError>> { + Box::pin(async move { + query( + "IF NOT EXISTS (SELECT * FROM sys.schemas WHERE name = @p1) \ + BEGIN \ + DECLARE @sql NVARCHAR(MAX) = N'CREATE SCHEMA ' + QUOTENAME(@p1); \ + EXEC sp_executesql @sql; \ + END", + ) + .bind(schema_name) + .execute(&mut *self) + .await?; + + Ok(()) + }) + } + + fn ensure_migrations_table<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result<(), MigrateError>> { + Box::pin(async move { + let ident = escape_table_name(table_name); + // Atomic check-and-create: the IF NOT EXISTS and CREATE TABLE run + // in a single batch so concurrent migrators cannot race. + // The WHERE clause is parameterized; the identifier must use + // bracket-escaping because DDL identifiers can't be parameterized. + query(AssertSqlSafe(format!( + "IF NOT EXISTS (SELECT 1 FROM sys.tables WHERE name = @p1) \ + CREATE TABLE {ident} ( \ + version BIGINT PRIMARY KEY, \ + description NVARCHAR(MAX) NOT NULL, \ + installed_on DATETIME2 NOT NULL DEFAULT SYSUTCDATETIME(), \ + success BIT NOT NULL, \ + checksum VARBINARY(MAX) NOT NULL, \ + execution_time BIGINT NOT NULL \ + );" + ))) + .bind(table_name) + .execute(&mut *self) + .await?; + + Ok(()) + }) + } + + fn dirty_version<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result, MigrateError>> { + Box::pin(async move { + let ident = escape_table_name(table_name); + let row: Option<(i64,)> = query_as(AssertSqlSafe(format!( + "SELECT TOP 1 version FROM {ident} WHERE success = 0 ORDER BY version" + ))) + .fetch_optional(self) + .await?; + + Ok(row.map(|r| r.0)) + }) + } + + fn list_applied_migrations<'e>( + &'e mut self, + table_name: &'e str, + ) -> BoxFuture<'e, Result, MigrateError>> { + Box::pin(async move { + let ident = escape_table_name(table_name); + let rows: Vec<(i64, Vec)> = query_as(AssertSqlSafe(format!( + "SELECT version, checksum FROM {ident} ORDER BY version" + ))) + .fetch_all(self) + .await?; + + let migrations = rows + .into_iter() + .map(|(version, checksum)| AppliedMigration { + version, + checksum: checksum.into(), + }) + .collect(); + + Ok(migrations) + }) + } + + fn lock(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> { + Box::pin(async move { + // sp_getapplock returns a status code (0/1 = success, negative = failure) + // but `execute` only surfaces SQL errors, not return values. + // We use THROW to convert a failed lock acquisition into a SQL error. + let _ = self + .execute( + "DECLARE @r INT; \ + EXEC @r = sp_getapplock @Resource = 'sqlx_migrations', @LockMode = 'Exclusive', @LockOwner = 'Session', @LockTimeout = -1; \ + IF @r < 0 THROW 50000, 'Failed to acquire migration lock', 1;" + ) + .await?; + + Ok(()) + }) + } + + fn unlock(&mut self) -> BoxFuture<'_, Result<(), MigrateError>> { + Box::pin(async move { + let _ = self + .execute( + "EXEC sp_releaseapplock @Resource = 'sqlx_migrations', @LockOwner = 'Session'", + ) + .await?; + + Ok(()) + }) + } + + fn apply<'e>( + &'e mut self, + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result> { + Box::pin(async move { + let start = Instant::now(); + + if migration.no_tx { + execute_migration(self, table_name, migration).await?; + } else { + // Use a single transaction for the actual migration script and the essential + // bookkeeping so we never execute migrations twice. + // See https://github.com/launchbadge/sqlx/issues/1966. + let mut tx = self.begin().await?; + execute_migration(&mut tx, table_name, migration).await?; + tx.commit().await?; + } + + // Update `execution_time`. + // NOTE: The process may disconnect/die at this point, so the elapsed time value + // might be lost. We accept this small risk since this value is not super important. + let elapsed = start.elapsed(); + + let ident = escape_table_name(table_name); + + #[allow(clippy::cast_possible_truncation)] + let _ = query(AssertSqlSafe(format!( + r#" + UPDATE {ident} + SET execution_time = @p1 + WHERE version = @p2 + "# + ))) + .bind(elapsed.as_nanos() as i64) + .bind(migration.version) + .execute(self) + .await?; + + Ok(elapsed) + }) + } + + fn revert<'e>( + &'e mut self, + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result> { + Box::pin(async move { + let start = Instant::now(); + + if migration.no_tx { + revert_migration(self, table_name, migration).await?; + } else { + let mut tx = self.begin().await?; + revert_migration(&mut tx, table_name, migration).await?; + tx.commit().await?; + } + + let elapsed = start.elapsed(); + + Ok(elapsed) + }) + } + + fn skip<'e>( + &'e mut self, + table_name: &'e str, + migration: &'e Migration, + ) -> BoxFuture<'e, Result<(), MigrateError>> { + Box::pin(async move { + let ident = escape_table_name(table_name); + // language=TSQL + let _ = query(AssertSqlSafe(format!( + r#" + INSERT INTO {ident} ( version, description, success, checksum, execution_time ) + VALUES ( @p1, @p2, 1, @p3, -1 ) + "# + ))) + .bind(migration.version) + .bind(&*migration.description) + .bind(&*migration.checksum) + .execute(self) + .await?; + + Ok(()) + }) + } +} + +async fn execute_migration( + conn: &mut MssqlConnection, + table_name: &str, + migration: &Migration, +) -> Result<(), MigrateError> { + let _ = conn + .execute(migration.sql.clone()) + .await + .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; + + let ident = escape_table_name(table_name); + let _ = query(AssertSqlSafe(format!( + r#" + INSERT INTO {ident} ( version, description, success, checksum, execution_time ) + VALUES ( @p1, @p2, 1, @p3, -1 ) + "# + ))) + .bind(migration.version) + .bind(&*migration.description) + .bind(&*migration.checksum) + .execute(conn) + .await?; + + Ok(()) +} + +async fn revert_migration( + conn: &mut MssqlConnection, + table_name: &str, + migration: &Migration, +) -> Result<(), MigrateError> { + let _ = conn + .execute(migration.sql.clone()) + .await + .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; + + let ident = escape_table_name(table_name); + let _ = query(AssertSqlSafe(format!( + r#"DELETE FROM {ident} WHERE version = @p1"# + ))) + .bind(migration.version) + .execute(conn) + .await?; + + Ok(()) +} diff --git a/sqlx-mssql/src/options/connect.rs b/sqlx-mssql/src/options/connect.rs new file mode 100644 index 0000000000..0a4cb94809 --- /dev/null +++ b/sqlx-mssql/src/options/connect.rs @@ -0,0 +1,36 @@ +use crate::connection::ConnectOptions; +use crate::error::Error; +use crate::{MssqlConnectOptions, MssqlConnection}; +use log::LevelFilter; +use sqlx_core::Url; +use std::time::Duration; + +impl ConnectOptions for MssqlConnectOptions { + type Connection = MssqlConnection; + + fn from_url(url: &Url) -> Result { + Self::parse_from_url(url) + } + + fn to_url_lossy(&self) -> Url { + self.build_url() + .expect("BUG: MssqlConnectOptions generated an un-parseable URL") + } + + async fn connect(&self) -> Result + where + Self::Connection: Sized, + { + MssqlConnection::establish(self).await + } + + fn log_statements(mut self, level: LevelFilter) -> Self { + self.log_settings.log_statements(level); + self + } + + fn log_slow_statements(mut self, level: LevelFilter, duration: Duration) -> Self { + self.log_settings.log_slow_statements(level, duration); + self + } +} diff --git a/sqlx-mssql/src/options/mod.rs b/sqlx-mssql/src/options/mod.rs new file mode 100644 index 0000000000..e1126dbdcc --- /dev/null +++ b/sqlx-mssql/src/options/mod.rs @@ -0,0 +1,333 @@ +mod connect; +mod parse; +pub mod ssl_mode; + +use crate::connection::LogSettings; +use ssl_mode::MssqlSslMode; + +/// Options and flags which can be used to configure a MSSQL connection. +/// +/// A value of `MssqlConnectOptions` can be parsed from a connection URL, +/// as described below. +/// +/// The generic format of the connection URL: +/// +/// ```text +/// mssql://[user[:password]@]host[:port][/database][?properties] +/// ``` +/// +/// ## Properties +/// +/// |Parameter|Default|Description| +/// |---------|-------|-----------| +/// | `sslmode` / `ssl_mode` | `preferred` | SSL encryption mode: `disabled`, `login_only`, `preferred`, `required`. | +/// | `encrypt` | (none) | Legacy alias: `true` maps to `required`, `false` to `disabled`. | +/// | `trust_server_certificate` | `false` | Whether to trust the server certificate without validation. | +/// | `trust_server_certificate_ca` | (none) | Path to a CA certificate file to validate the server certificate against. Mutually exclusive with `trust_server_certificate`. | +/// | `application_intent` | `read_write` | Application intent: `read_write` or `read_only`. `read_only` routes to Always On read replicas. | +/// | `statement-cache-capacity` | `100` | The maximum number of prepared statements stored in the cache. | +/// | `app_name` | `sqlx` | The application name sent to the server. | +/// | `instance` | `None` | The SQL Server instance name. | +/// | `auth` | `sql_server` | Authentication method: `sql_server`, `windows` (cfg-gated), `integrated` (cfg-gated), `aad_token`. | +/// | `token` | (none) | Azure AD bearer token (used when `auth=aad_token`). | +/// +/// # Example +/// +/// ```rust,no_run +/// # async fn example() -> sqlx::Result<()> { +/// use sqlx::{Connection, ConnectOptions}; +/// use sqlx::mssql::{MssqlConnectOptions, MssqlConnection}; +/// +/// // URL connection string +/// let conn = MssqlConnection::connect("mssql://sa:password@localhost/master").await?; +/// +/// // Manually-constructed options +/// let conn = MssqlConnectOptions::new() +/// .host("localhost") +/// .username("sa") +/// .password("password") +/// .database("master") +/// .connect().await?; +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug, Clone)] +pub struct MssqlConnectOptions { + pub(crate) host: String, + pub(crate) port: u16, + pub(crate) username: String, + pub(crate) password: Option, + pub(crate) database: Option, + pub(crate) instance: Option, + pub(crate) ssl_mode: MssqlSslMode, + pub(crate) trust_server_certificate: bool, + pub(crate) trust_server_certificate_ca: Option, + pub(crate) application_intent_read_only: bool, + pub(crate) statement_cache_capacity: usize, + pub(crate) app_name: String, + pub(crate) log_settings: LogSettings, + /// When `true`, use Windows (NTLM) authentication instead of SQL Server auth. + /// The username can use `domain\user` syntax which tiberius parses internally. + #[cfg(all(windows, feature = "winauth"))] + pub(crate) windows_auth: bool, + /// When `true`, use integrated authentication (SSPI on Windows / Kerberos on Unix). + #[cfg(any( + all(windows, feature = "winauth"), + all(unix, feature = "integrated-auth-gssapi") + ))] + pub(crate) integrated_auth: bool, + /// Azure AD bearer token for AAD authentication. + pub(crate) aad_token: Option, +} + +impl Default for MssqlConnectOptions { + fn default() -> Self { + Self::new() + } +} + +impl MssqlConnectOptions { + /// Creates a new, default set of options ready for configuration. + pub fn new() -> Self { + Self { + port: 1433, + host: String::from("localhost"), + username: String::from("sa"), + password: None, + database: None, + instance: None, + ssl_mode: MssqlSslMode::default(), + trust_server_certificate: false, + trust_server_certificate_ca: None, + application_intent_read_only: false, + statement_cache_capacity: 100, + app_name: String::from("sqlx"), + log_settings: Default::default(), + #[cfg(all(windows, feature = "winauth"))] + windows_auth: false, + #[cfg(any( + all(windows, feature = "winauth"), + all(unix, feature = "integrated-auth-gssapi") + ))] + integrated_auth: false, + aad_token: None, + } + } + + /// Sets the name of the host to connect to. + pub fn host(mut self, host: &str) -> Self { + host.clone_into(&mut self.host); + self + } + + /// Sets the port to connect to at the server host. + /// + /// The default port for MSSQL is `1433`. + pub fn port(mut self, port: u16) -> Self { + self.port = port; + self + } + + /// Sets the username to connect as. + pub fn username(mut self, username: &str) -> Self { + username.clone_into(&mut self.username); + self + } + + /// Sets the password to connect with. + pub fn password(mut self, password: &str) -> Self { + self.password = Some(password.to_owned()); + self + } + + /// Sets the database name. + pub fn database(mut self, database: &str) -> Self { + self.database = Some(database.to_owned()); + self + } + + /// Sets the SQL Server instance name. + pub fn instance(mut self, instance: &str) -> Self { + self.instance = Some(instance.to_owned()); + self + } + + /// Sets the SSL encryption mode. + pub fn ssl_mode(mut self, mode: MssqlSslMode) -> Self { + self.ssl_mode = mode; + self + } + + /// Sets whether to use TLS encryption. + /// + /// This is a legacy convenience method. + /// `true` maps to [`MssqlSslMode::Required`], `false` to [`MssqlSslMode::Disabled`]. + pub fn encrypt(mut self, encrypt: bool) -> Self { + self.ssl_mode = if encrypt { + MssqlSslMode::Required + } else { + MssqlSslMode::Disabled + }; + self + } + + /// Sets whether to trust the server certificate without validation. + pub fn trust_server_certificate(mut self, trust: bool) -> Self { + self.trust_server_certificate = trust; + self + } + + /// Sets a CA certificate file path to validate the server certificate against. + /// + /// Accepts `.pem`, `.crt`, or `.der` certificate files. + /// + /// This is mutually exclusive with [`trust_server_certificate`](Self::trust_server_certificate). + /// When a CA path is set, `trust_server_certificate` is ignored. + pub fn trust_server_certificate_ca(mut self, path: &str) -> Self { + self.trust_server_certificate_ca = Some(path.to_owned()); + self + } + + /// Sets the application intent to read-only. + /// + /// When `true`, sets `ApplicationIntent=ReadOnly` in the TDS login packet, + /// which routes connections to Always On Availability Group read replicas. + pub fn application_intent_read_only(mut self, read_only: bool) -> Self { + self.application_intent_read_only = read_only; + self + } + + /// Get whether the application intent is set to read-only. + pub fn get_application_intent_read_only(&self) -> bool { + self.application_intent_read_only + } + + /// Sets the capacity of the connection's statement cache. + pub fn statement_cache_capacity(mut self, capacity: usize) -> Self { + self.statement_cache_capacity = capacity; + self + } + + /// Sets the application name sent to the server. + pub fn app_name(mut self, app_name: &str) -> Self { + app_name.clone_into(&mut self.app_name); + self + } + + /// Sets whether to use Windows (NTLM) authentication. + /// + /// When enabled, the username can use `domain\user` syntax + /// which tiberius parses internally. + #[cfg(all(windows, feature = "winauth"))] + pub fn windows_auth(mut self, enabled: bool) -> Self { + self.windows_auth = enabled; + self + } + + /// Sets whether to use integrated authentication (SSPI on Windows / Kerberos on Unix). + #[cfg(any( + all(windows, feature = "winauth"), + all(unix, feature = "integrated-auth-gssapi") + ))] + pub fn integrated_auth(mut self, enabled: bool) -> Self { + self.integrated_auth = enabled; + self + } + + /// Sets an Azure AD bearer token for authentication. + /// + /// When set, AAD token authentication takes precedence over other auth methods. + pub fn aad_token(mut self, token: &str) -> Self { + self.aad_token = Some(token.to_owned()); + self + } + + /// Get the current host. + pub fn get_host(&self) -> &str { + &self.host + } + + /// Get the server's port. + pub fn get_port(&self) -> u16 { + self.port + } + + /// Get the current username. + pub fn get_username(&self) -> &str { + &self.username + } + + /// Get the current database name. + pub fn get_database(&self) -> Option<&str> { + self.database.as_deref() + } + + /// Build a `tiberius::Config` from these options. + pub(crate) fn to_tiberius_config(&self) -> tiberius::Config { + let mut config = tiberius::Config::new(); + + config.host(&self.host); + config.port(self.port); + config.application_name(&self.app_name); + + if let Some(database) = &self.database { + config.database(database); + } + + if let Some(instance) = &self.instance { + config.instance_name(instance); + } + + if let Some(token) = &self.aad_token { + config.authentication(tiberius::AuthMethod::aad_token(token)); + } else { + #[allow(unused_mut)] + let mut handled = false; + + #[cfg(any( + all(windows, feature = "winauth"), + all(unix, feature = "integrated-auth-gssapi") + ))] + if !handled && self.integrated_auth { + config.authentication(tiberius::AuthMethod::Integrated); + handled = true; + } + + #[cfg(all(windows, feature = "winauth"))] + if !handled && self.windows_auth { + config.authentication(tiberius::AuthMethod::windows( + &self.username, + self.password.as_deref().unwrap_or(""), + )); + handled = true; + } + + if !handled { + config.authentication(tiberius::AuthMethod::sql_server( + &self.username, + self.password.as_deref().unwrap_or(""), + )); + } + } + + if let Some(ca_path) = &self.trust_server_certificate_ca { + // trust_cert_ca and trust_cert are mutually exclusive in tiberius + config.trust_cert_ca(ca_path); + } else if self.trust_server_certificate { + config.trust_cert(); + } + + if self.application_intent_read_only { + config.readonly(true); + } + + config.encryption(match self.ssl_mode { + MssqlSslMode::Disabled => tiberius::EncryptionLevel::NotSupported, + MssqlSslMode::LoginOnly => tiberius::EncryptionLevel::Off, + MssqlSslMode::Preferred => tiberius::EncryptionLevel::On, + MssqlSslMode::Required => tiberius::EncryptionLevel::Required, + }); + + config + } +} diff --git a/sqlx-mssql/src/options/parse.rs b/sqlx-mssql/src/options/parse.rs new file mode 100644 index 0000000000..8eedb359b2 --- /dev/null +++ b/sqlx-mssql/src/options/parse.rs @@ -0,0 +1,376 @@ +use std::str::FromStr; + +use percent_encoding::percent_decode_str; +use sqlx_core::Url; + +use crate::error::Error; + +use super::ssl_mode::MssqlSslMode; +use super::MssqlConnectOptions; + +impl MssqlConnectOptions { + pub(crate) fn parse_from_url(url: &Url) -> Result { + let mut options = Self::new(); + + if let Some(host) = url.host_str() { + options = options.host(host); + } + + if let Some(port) = url.port() { + options = options.port(port); + } + + let username = url.username(); + if !username.is_empty() { + options = options.username( + &percent_decode_str(username) + .decode_utf8() + .map_err(Error::config)?, + ); + } + + if let Some(password) = url.password() { + options = options.password( + &percent_decode_str(password) + .decode_utf8() + .map_err(Error::config)?, + ); + } + + let path = url.path().trim_start_matches('/'); + if !path.is_empty() { + options = options.database( + &percent_decode_str(path) + .decode_utf8() + .map_err(Error::config)?, + ); + } + + for (key, value) in url.query_pairs() { + match &*key { + "sslmode" | "ssl_mode" => { + options = options.ssl_mode(match &*value { + "disabled" => MssqlSslMode::Disabled, + "login_only" => MssqlSslMode::LoginOnly, + "preferred" => MssqlSslMode::Preferred, + "required" => MssqlSslMode::Required, + _ => { + return Err(Error::Configuration( + format!("unknown sslmode value: {value}").into(), + )) + } + }); + } + + "encrypt" => { + options = options.encrypt(value.parse().map_err(Error::config)?); + } + + "trust_server_certificate" | "trustServerCertificate" => { + options = + options.trust_server_certificate(value.parse().map_err(Error::config)?); + } + + "instance" => { + options = options.instance(&value); + } + + "app_name" | "application-name" => { + options = options.app_name(&value); + } + + "statement-cache-capacity" => { + options = + options.statement_cache_capacity(value.parse().map_err(Error::config)?); + } + + "application_intent" | "applicationIntent" => match &*value { + "read_only" | "ReadOnly" => { + options = options.application_intent_read_only(true); + } + "read_write" | "ReadWrite" => { + options = options.application_intent_read_only(false); + } + _ => { + return Err(Error::Configuration( + format!("unknown application_intent value: {value}").into(), + )) + } + }, + + "trust_server_certificate_ca" | "trustServerCertificateCa" => { + options = options.trust_server_certificate_ca(&value); + } + + "auth" => { + match &*value { + "sql_server" => {} + #[cfg(all(windows, feature = "winauth"))] + "windows" => { + options.windows_auth = true; + } + #[cfg(any( + all(windows, feature = "winauth"), + all(unix, feature = "integrated-auth-gssapi") + ))] + "integrated" => { + options.integrated_auth = true; + } + "aad_token" => { + // token value is set via the separate `token` parameter + } + _ => { + return Err(Error::Configuration( + format!("unknown auth value: {value}").into(), + )) + } + } + } + + "token" => { + options.aad_token = Some(value.into_owned()); + } + + _ => {} + } + } + + Ok(options) + } + + pub(crate) fn build_url(&self) -> Result { + let mut url = Url::parse(&format!( + "mssql://{}@{}:{}", + self.username, self.host, self.port + )) + .map_err(|e| Error::Configuration(e.to_string().into()))?; + + if let Some(password) = &self.password { + let _ = url.set_password(Some(password)); + } + + if let Some(database) = &self.database { + url.set_path(database); + } + + let sslmode = match self.ssl_mode { + MssqlSslMode::Disabled => "disabled", + MssqlSslMode::LoginOnly => "login_only", + MssqlSslMode::Preferred => "preferred", + MssqlSslMode::Required => "required", + }; + url.query_pairs_mut().append_pair("sslmode", sslmode); + + if self.application_intent_read_only { + url.query_pairs_mut() + .append_pair("application_intent", "read_only"); + } + + if let Some(ca_path) = &self.trust_server_certificate_ca { + url.query_pairs_mut() + .append_pair("trust_server_certificate_ca", ca_path); + } + + if let Some(token) = &self.aad_token { + url.query_pairs_mut() + .append_pair("auth", "aad_token") + .append_pair("token", token); + } else { + #[cfg(any( + all(windows, feature = "winauth"), + all(unix, feature = "integrated-auth-gssapi") + ))] + if self.integrated_auth { + url.query_pairs_mut().append_pair("auth", "integrated"); + } + + #[cfg(all(windows, feature = "winauth"))] + if self.windows_auth && !self.integrated_auth { + url.query_pairs_mut().append_pair("auth", "windows"); + } + } + + Ok(url) + } +} + +impl FromStr for MssqlConnectOptions { + type Err = Error; + + fn from_str(s: &str) -> Result { + let url: Url = s.parse().map_err(Error::config)?; + Self::parse_from_url(&url) + } +} + +#[test] +fn it_parses_basic_mssql_url() { + let url = "mssql://sa:password@localhost:1433/master"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + + assert_eq!(opts.host, "localhost"); + assert_eq!(opts.port, 1433); + assert_eq!(opts.username, "sa"); + assert_eq!(opts.password, Some("password".into())); + assert_eq!(opts.database, Some("master".into())); +} + +#[test] +fn it_parses_url_with_instance() { + let url = "mssql://sa:password@localhost/master?instance=SQLEXPRESS"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + + assert_eq!(opts.instance, Some("SQLEXPRESS".into())); +} + +#[test] +fn it_parses_sslmode_disabled() { + let url = "mssql://sa:password@localhost/master?sslmode=disabled"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + assert!(matches!(opts.ssl_mode, MssqlSslMode::Disabled)); +} + +#[test] +fn it_parses_sslmode_login_only() { + let url = "mssql://sa:password@localhost/master?ssl_mode=login_only"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + assert!(matches!(opts.ssl_mode, MssqlSslMode::LoginOnly)); +} + +#[test] +fn it_parses_sslmode_preferred() { + let url = "mssql://sa:password@localhost/master?sslmode=preferred"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + assert!(matches!(opts.ssl_mode, MssqlSslMode::Preferred)); +} + +#[test] +fn it_parses_sslmode_required() { + let url = "mssql://sa:password@localhost/master?sslmode=required"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + assert!(matches!(opts.ssl_mode, MssqlSslMode::Required)); +} + +#[test] +fn it_parses_encrypt_true_as_required() { + let url = "mssql://sa:password@localhost/master?encrypt=true"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + assert!(matches!(opts.ssl_mode, MssqlSslMode::Required)); +} + +#[test] +fn it_parses_encrypt_false_as_disabled() { + let url = "mssql://sa:password@localhost/master?encrypt=false"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + assert!(matches!(opts.ssl_mode, MssqlSslMode::Disabled)); +} + +#[test] +fn it_rejects_invalid_sslmode() { + let url = "mssql://sa:password@localhost/master?sslmode=bogus"; + assert!(MssqlConnectOptions::from_str(url).is_err()); +} + +#[test] +fn it_roundtrips_sslmode_in_url() { + let url = "mssql://sa:password@localhost/master?sslmode=login_only"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + let built = opts.build_url().unwrap(); + let opts2 = MssqlConnectOptions::parse_from_url(&built).unwrap(); + assert!(matches!(opts2.ssl_mode, MssqlSslMode::LoginOnly)); +} + +#[test] +fn it_parses_application_intent_read_only() { + let url = "mssql://sa:password@localhost/master?application_intent=read_only"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + assert!(opts.application_intent_read_only); +} + +#[test] +fn it_parses_application_intent_read_write() { + let url = "mssql://sa:password@localhost/master?application_intent=read_write"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + assert!(!opts.application_intent_read_only); +} + +#[test] +fn it_parses_application_intent_camel_case() { + let url = "mssql://sa:password@localhost/master?applicationIntent=ReadOnly"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + assert!(opts.application_intent_read_only); +} + +#[test] +fn it_rejects_invalid_application_intent() { + let url = "mssql://sa:password@localhost/master?application_intent=bogus"; + assert!(MssqlConnectOptions::from_str(url).is_err()); +} + +#[test] +fn it_parses_trust_server_certificate_ca() { + let url = "mssql://sa:password@localhost/master?trust_server_certificate_ca=/path/to/ca.pem"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + assert_eq!( + opts.trust_server_certificate_ca, + Some("/path/to/ca.pem".into()) + ); +} + +#[test] +fn it_roundtrips_application_intent_in_url() { + let opts = MssqlConnectOptions::new() + .host("localhost") + .username("sa") + .password("password") + .application_intent_read_only(true); + let built = opts.build_url().unwrap(); + let opts2 = MssqlConnectOptions::parse_from_url(&built).unwrap(); + assert!(opts2.application_intent_read_only); +} + +#[test] +fn it_roundtrips_trust_cert_ca_in_url() { + let opts = MssqlConnectOptions::new() + .host("localhost") + .username("sa") + .password("password") + .trust_server_certificate_ca("/etc/ssl/ca.pem"); + let built = opts.build_url().unwrap(); + let opts2 = MssqlConnectOptions::parse_from_url(&built).unwrap(); + assert_eq!( + opts2.trust_server_certificate_ca, + Some("/etc/ssl/ca.pem".into()) + ); +} + +#[test] +fn it_parses_aad_token_auth() { + let url = "mssql://sa@localhost/master?auth=aad_token&token=my-bearer-token"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + assert_eq!(opts.aad_token, Some("my-bearer-token".into())); +} + +#[test] +fn it_roundtrips_aad_token_in_url() { + let opts = MssqlConnectOptions::new() + .host("localhost") + .username("sa") + .aad_token("my-bearer-token"); + let built = opts.build_url().unwrap(); + let opts2 = MssqlConnectOptions::parse_from_url(&built).unwrap(); + assert_eq!(opts2.aad_token, Some("my-bearer-token".into())); +} + +#[test] +fn it_parses_sql_server_auth_explicitly() { + let url = "mssql://sa:password@localhost/master?auth=sql_server"; + let opts = MssqlConnectOptions::from_str(url).unwrap(); + assert_eq!(opts.aad_token, None); +} + +#[test] +fn it_rejects_invalid_auth() { + let url = "mssql://sa:password@localhost/master?auth=bogus"; + assert!(MssqlConnectOptions::from_str(url).is_err()); +} diff --git a/sqlx-mssql/src/options/ssl_mode.rs b/sqlx-mssql/src/options/ssl_mode.rs new file mode 100644 index 0000000000..09519bdcb3 --- /dev/null +++ b/sqlx-mssql/src/options/ssl_mode.rs @@ -0,0 +1,18 @@ +/// The SSL mode to use when connecting to MSSQL. +/// +/// Maps to the tiberius `EncryptionLevel` variants. +#[derive(Debug, Clone, Copy, Default)] +pub enum MssqlSslMode { + /// No encryption at all (`EncryptionLevel::NotSupported`). + Disabled, + + /// Only encrypt the login packet (`EncryptionLevel::Off`). + LoginOnly, + + /// Encrypt if the server supports it (`EncryptionLevel::On`). + #[default] + Preferred, + + /// Always encrypt; fail if the server doesn't support it (`EncryptionLevel::Required`). + Required, +} diff --git a/sqlx-mssql/src/query_result.rs b/sqlx-mssql/src/query_result.rs new file mode 100644 index 0000000000..de00dda5ca --- /dev/null +++ b/sqlx-mssql/src/query_result.rs @@ -0,0 +1,30 @@ +use std::iter::{Extend, IntoIterator}; + +#[derive(Debug, Default)] +pub struct MssqlQueryResult { + pub(super) rows_affected: u64, +} + +impl MssqlQueryResult { + pub fn rows_affected(&self) -> u64 { + self.rows_affected + } +} + +impl Extend for MssqlQueryResult { + fn extend>(&mut self, iter: T) { + for elem in iter { + self.rows_affected += elem.rows_affected; + } + } +} + +#[cfg(feature = "any")] +impl From for sqlx_core::any::AnyQueryResult { + fn from(done: MssqlQueryResult) -> Self { + sqlx_core::any::AnyQueryResult { + rows_affected: done.rows_affected(), + last_insert_id: None, + } + } +} diff --git a/sqlx-mssql/src/row.rs b/sqlx-mssql/src/row.rs new file mode 100644 index 0000000000..261e2ea016 --- /dev/null +++ b/sqlx-mssql/src/row.rs @@ -0,0 +1,54 @@ +use std::sync::Arc; + +pub(crate) use sqlx_core::row::*; + +use crate::column::ColumnIndex; +use crate::error::Error; +use crate::ext::ustr::UStr; +use crate::value::MssqlData; +use crate::HashMap; +use crate::{Mssql, MssqlColumn, MssqlValueRef}; + +/// Implementation of [`Row`] for MSSQL. +pub struct MssqlRow { + pub(crate) values: Vec, + pub(crate) columns: Arc>, + pub(crate) column_names: Arc>, +} + +impl Row for MssqlRow { + type Database = Mssql; + + fn columns(&self) -> &[MssqlColumn] { + &self.columns + } + + fn try_get_raw(&self, index: I) -> Result, Error> + where + I: ColumnIndex, + { + let index = index.index(self)?; + let column = &self.columns[index]; + let data = &self.values[index]; + + Ok(MssqlValueRef { + data, + type_info: column.type_info.clone(), + }) + } +} + +impl ColumnIndex for &'_ str { + fn index(&self, row: &MssqlRow) -> Result { + row.column_names + .get(*self) + .ok_or_else(|| Error::ColumnNotFound((*self).into())) + .copied() + } +} + +impl std::fmt::Debug for MssqlRow { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + debug_row(self, f) + } +} diff --git a/sqlx-mssql/src/statement.rs b/sqlx-mssql/src/statement.rs new file mode 100644 index 0000000000..ad414dfd9c --- /dev/null +++ b/sqlx-mssql/src/statement.rs @@ -0,0 +1,57 @@ +use super::MssqlColumn; +use crate::column::ColumnIndex; +use crate::error::Error; +use crate::ext::ustr::UStr; +use crate::HashMap; +use crate::{Mssql, MssqlArguments, MssqlTypeInfo}; +use either::Either; +use sqlx_core::sql_str::SqlStr; +use std::sync::Arc; + +pub(crate) use sqlx_core::statement::*; + +#[derive(Debug, Clone)] +pub struct MssqlStatement { + pub(crate) sql: SqlStr, + pub(crate) metadata: MssqlStatementMetadata, +} + +#[derive(Debug, Default, Clone)] +pub(crate) struct MssqlStatementMetadata { + pub(crate) columns: Arc>, + pub(crate) column_names: Arc>, + pub(crate) parameters: usize, +} + +impl Statement for MssqlStatement { + type Database = Mssql; + + fn into_sql(self) -> SqlStr { + self.sql + } + + fn sql(&self) -> &SqlStr { + &self.sql + } + + fn parameters(&self) -> Option> { + Some(Either::Right(self.metadata.parameters)) + } + + fn columns(&self) -> &[MssqlColumn] { + &self.metadata.columns + } + + impl_statement_query!(MssqlArguments); +} + +impl ColumnIndex for &'_ str { + fn index(&self, statement: &MssqlStatement) -> Result { + statement + .metadata + .column_names + .get(*self) + .ok_or_else(|| Error::ColumnNotFound((*self).into())) + .copied() + } +} diff --git a/sqlx-mssql/src/testing/mod.rs b/sqlx-mssql/src/testing/mod.rs new file mode 100644 index 0000000000..54032da0ad --- /dev/null +++ b/sqlx-mssql/src/testing/mod.rs @@ -0,0 +1,209 @@ +use std::future::Future; +use std::ops::Deref; +use std::str::FromStr; +use std::sync::OnceLock; +use std::time::Duration; + +use crate::error::Error; +use crate::executor::Executor; +use crate::pool::{Pool, PoolOptions}; +use crate::query::query; +use crate::{Mssql, MssqlConnectOptions, MssqlConnection}; +use sqlx_core::connection::Connection; +use sqlx_core::query_scalar::query_scalar; + +pub(crate) use sqlx_core::testing::*; + +// Using a blocking `OnceLock` here because the critical sections are short. +static MASTER_POOL: OnceLock> = OnceLock::new(); + +impl TestSupport for Mssql { + fn test_context( + args: &TestArgs, + ) -> impl Future, Error>> + Send + '_ { + test_context(args) + } + + async fn cleanup_test(db_name: &str) -> Result<(), Error> { + let mut conn = MASTER_POOL + .get() + .expect("cleanup_test() invoked outside `#[sqlx::test]`") + .acquire() + .await?; + + do_cleanup(&mut conn, db_name).await + } + + async fn cleanup_test_dbs() -> Result, Error> { + let url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set"); + + let mut conn = MssqlConnection::connect(&url).await?; + + let delete_db_names: Vec = query_scalar("SELECT db_name FROM _sqlx_test_databases") + .fetch_all(&mut conn) + .await?; + + if delete_db_names.is_empty() { + return Ok(None); + } + + let mut deleted_count = 0usize; + + for db_name in &delete_db_names { + match query( + "IF DB_ID(@p1) IS NOT NULL \ + BEGIN \ + DECLARE @sql NVARCHAR(MAX); \ + SET @sql = N'ALTER DATABASE ' + QUOTENAME(@p1) + N' SET SINGLE_USER WITH ROLLBACK IMMEDIATE'; \ + EXEC sp_executesql @sql; \ + SET @sql = N'DROP DATABASE ' + QUOTENAME(@p1); \ + EXEC sp_executesql @sql; \ + END", + ) + .bind(db_name) + .execute(&mut conn) + .await + { + Ok(_deleted) => { + deleted_count += 1; + } + // Assume a database error just means the DB is still in use. + Err(Error::Database(dbe)) => { + eprintln!("could not clean test database {db_name:?}: {dbe}") + } + // Bubble up other errors + Err(e) => return Err(e), + } + } + + if deleted_count == 0 { + return Ok(None); + } + + // Clean up the tracking table + for db_name in &delete_db_names { + let _ = query("DELETE FROM _sqlx_test_databases WHERE db_name = @p1") + .bind(db_name) + .execute(&mut conn) + .await; + } + + let _ = conn.close().await; + Ok(Some(deleted_count)) + } + + async fn snapshot(_conn: &mut Self::Connection) -> Result, Error> { + Err(Error::Configuration( + "snapshots are not yet supported for MSSQL".into(), + )) + } +} + +async fn test_context(args: &TestArgs) -> Result, Error> { + let url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set"); + + let master_opts = MssqlConnectOptions::from_str(&url).expect("failed to parse DATABASE_URL"); + + let pool = PoolOptions::new() + .max_connections(20) + // Immediately close master connections. Tokio's I/O streams don't like hopping runtimes. + .after_release(|_conn, _| Box::pin(async move { Ok(false) })) + .connect_lazy_with(master_opts); + + let master_pool = match once_lock_try_insert_polyfill(&MASTER_POOL, pool) { + Ok(inserted) => inserted, + Err((existing, pool)) => { + assert_eq!( + existing.connect_options().host, + pool.connect_options().host, + "DATABASE_URL changed at runtime, host differs" + ); + + assert_eq!( + existing.connect_options().database, + pool.connect_options().database, + "DATABASE_URL changed at runtime, database differs" + ); + + existing + } + }; + + let mut conn = master_pool.acquire().await?; + + // Create tracking table if it doesn't exist + conn.execute( + r#" + IF NOT EXISTS (SELECT * FROM sys.tables WHERE name = '_sqlx_test_databases') + CREATE TABLE _sqlx_test_databases ( + db_name NVARCHAR(200) NOT NULL PRIMARY KEY, + test_path NVARCHAR(MAX) NOT NULL, + created_at DATETIME2 NOT NULL DEFAULT SYSUTCDATETIME() + ); + "#, + ) + .await?; + + let db_name = Mssql::db_name(args); + do_cleanup(&mut conn, &db_name).await?; + + query("INSERT INTO _sqlx_test_databases(db_name, test_path) VALUES (@p1, @p2)") + .bind(&db_name) + .bind(args.test_path) + .execute(&mut *conn) + .await?; + + query( + "DECLARE @sql NVARCHAR(MAX) = N'CREATE DATABASE ' + QUOTENAME(@p1); \ + EXEC sp_executesql @sql;", + ) + .bind(&db_name) + .execute(&mut *conn) + .await?; + + eprintln!("created database {db_name}"); + + Ok(TestContext { + pool_opts: PoolOptions::new() + .max_connections(5) + .idle_timeout(Some(Duration::from_secs(1))) + .parent(master_pool.clone()), + connect_opts: master_pool + .connect_options() + .deref() + .clone() + .database(&db_name), + db_name, + }) +} + +async fn do_cleanup(conn: &mut MssqlConnection, db_name: &str) -> Result<(), Error> { + query( + "IF DB_ID(@p1) IS NOT NULL \ + BEGIN \ + DECLARE @sql NVARCHAR(MAX); \ + SET @sql = N'ALTER DATABASE ' + QUOTENAME(@p1) + N' SET SINGLE_USER WITH ROLLBACK IMMEDIATE'; \ + EXEC sp_executesql @sql; \ + SET @sql = N'DROP DATABASE ' + QUOTENAME(@p1); \ + EXEC sp_executesql @sql; \ + END", + ) + .bind(db_name) + .execute(&mut *conn) + .await?; + query("DELETE FROM _sqlx_test_databases WHERE db_name = @p1") + .bind(db_name) + .execute(&mut *conn) + .await?; + + Ok(()) +} + +fn once_lock_try_insert_polyfill(this: &OnceLock, value: T) -> Result<&T, (&T, T)> { + let mut value = Some(value); + let res = this.get_or_init(|| value.take().unwrap()); + match value { + None => Ok(res), + Some(value) => Err((res, value)), + } +} diff --git a/sqlx-mssql/src/transaction.rs b/sqlx-mssql/src/transaction.rs new file mode 100644 index 0000000000..b74862235c --- /dev/null +++ b/sqlx-mssql/src/transaction.rs @@ -0,0 +1,122 @@ +use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr, SqlStr}; + +use crate::error::{tiberius_err, Error}; +use crate::executor::Executor; +use crate::{Mssql, MssqlConnection}; + +pub(crate) use sqlx_core::transaction::*; + +/// Implementation of [`TransactionManager`] for MSSQL. +/// +/// MSSQL uses non-ANSI syntax for savepoints: +/// - depth 0 -> `BEGIN TRANSACTION` +/// - depth N -> `SAVE TRANSACTION _sqlx_savepoint_N` +/// - commit depth 1 -> `COMMIT` +/// - commit depth N -> no-op (savepoints auto-commit with parent) +/// - rollback depth 1 -> `ROLLBACK` +/// - rollback depth N -> `ROLLBACK TRANSACTION _sqlx_savepoint_N` +pub struct MssqlTransactionManager; + +impl TransactionManager for MssqlTransactionManager { + type Database = Mssql; + + async fn begin(conn: &mut MssqlConnection, statement: Option) -> Result<(), Error> { + let depth = conn.inner.transaction_depth; + + // Execute any pending rollback first + resolve_pending_rollback(conn).await?; + + let statement = match statement { + Some(_) if depth > 0 => return Err(Error::InvalidSavePointStatement), + Some(statement) => statement, + None => { + if depth == 0 { + SqlStr::from_static("BEGIN TRANSACTION") + } else { + AssertSqlSafe(format!("SAVE TRANSACTION _sqlx_savepoint_{}", depth)) + .into_sql_str() + } + } + }; + + conn.execute(statement).await?; + conn.inner.transaction_depth += 1; + + Ok(()) + } + + async fn commit(conn: &mut MssqlConnection) -> Result<(), Error> { + let depth = conn.inner.transaction_depth; + + if depth > 0 { + if depth == 1 { + // Only the outermost transaction actually commits + conn.execute("COMMIT").await?; + } + // Savepoints auto-commit with their parent transaction, so no-op for depth > 1 + conn.inner.transaction_depth = depth - 1; + } + + Ok(()) + } + + async fn rollback(conn: &mut MssqlConnection) -> Result<(), Error> { + let depth = conn.inner.transaction_depth; + + if depth > 0 { + if depth == 1 { + conn.execute("ROLLBACK").await?; + } else { + let savepoint = format!("ROLLBACK TRANSACTION _sqlx_savepoint_{}", depth - 1); + conn.execute(AssertSqlSafe(savepoint)).await?; + } + conn.inner.transaction_depth = depth - 1; + } + + Ok(()) + } + + fn start_rollback(conn: &mut MssqlConnection) { + let depth = conn.inner.transaction_depth; + if depth > 0 { + // We can't execute async SQL from a synchronous context (Drop), + // so we set a flag and execute the rollback on the next operation. + conn.inner.pending_rollback = true; + conn.inner.transaction_depth = depth - 1; + } + } + + fn get_transaction_depth(conn: &MssqlConnection) -> usize { + conn.inner.transaction_depth + } +} + +/// Execute pending rollback if one was triggered by `start_rollback`. +pub(crate) async fn resolve_pending_rollback(conn: &mut MssqlConnection) -> Result<(), Error> { + if conn.inner.pending_rollback { + conn.inner.pending_rollback = false; + let depth = conn.inner.transaction_depth; + if depth == 0 { + // Rollback the entire transaction + conn.inner + .client + .simple_query("ROLLBACK") + .await + .map_err(tiberius_err)? + .into_results() + .await + .map_err(tiberius_err)?; + } else { + let savepoint = format!("ROLLBACK TRANSACTION _sqlx_savepoint_{}", depth); + conn.inner + .client + .simple_query(savepoint) + .await + .map_err(tiberius_err)? + .into_results() + .await + .map_err(tiberius_err)?; + } + } + Ok(()) +} diff --git a/sqlx-mssql/src/type_checking.rs b/sqlx-mssql/src/type_checking.rs new file mode 100644 index 0000000000..aa4cbcffe7 --- /dev/null +++ b/sqlx-mssql/src/type_checking.rs @@ -0,0 +1,56 @@ +// Type mappings used by the macros and `Debug` impls. + +#[allow(unused_imports)] +use sqlx_core as sqlx; + +use crate::Mssql; + +impl_type_checking!( + Mssql { + u8, + i8, + i16, + i32, + i64, + f32, + f64, + + // ordering is important here as otherwise we might infer strings to be binary + // NVARCHAR, VARCHAR, NCHAR, CHAR, NTEXT, TEXT + String, + + // VARBINARY, BINARY, IMAGE + Vec, + + #[cfg(feature = "uuid")] + sqlx::types::Uuid, + + #[cfg(feature = "json")] + sqlx::types::JsonValue, + }, + ParamChecking::Weak, + feature-types: _info => None, + datetime-types: { + chrono: { + sqlx::types::chrono::NaiveTime, + sqlx::types::chrono::NaiveDate, + sqlx::types::chrono::NaiveDateTime, + sqlx::types::chrono::DateTime, + sqlx::types::chrono::DateTime, + }, + time: { + sqlx::types::time::Time, + sqlx::types::time::Date, + sqlx::types::time::PrimitiveDateTime, + sqlx::types::time::OffsetDateTime, + }, + }, + numeric-types: { + bigdecimal: { + sqlx::types::BigDecimal, + }, + rust_decimal: { + sqlx::types::Decimal, + }, + }, +); diff --git a/sqlx-mssql/src/type_info.rs b/sqlx-mssql/src/type_info.rs new file mode 100644 index 0000000000..76907aacbb --- /dev/null +++ b/sqlx-mssql/src/type_info.rs @@ -0,0 +1,75 @@ +use std::fmt::{self, Display, Formatter}; + +pub(crate) use sqlx_core::type_info::*; + +/// Type information for a MSSQL type. +#[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "offline", derive(serde::Serialize, serde::Deserialize))] +pub struct MssqlTypeInfo { + pub(crate) name: String, +} + +impl MssqlTypeInfo { + pub(crate) fn new(name: impl Into) -> Self { + Self { name: name.into() } + } + + /// Return the base type name without any parenthesized precision/scale. + /// + /// e.g. `"DECIMAL(10,2)"` → `"DECIMAL"`, `"NVARCHAR(4000)"` → `"NVARCHAR"` + pub(crate) fn base_name(&self) -> &str { + self.name.split('(').next().unwrap_or(&self.name).trim() + } +} + +impl Display for MssqlTypeInfo { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.pad(&self.name) + } +} + +impl TypeInfo for MssqlTypeInfo { + fn is_null(&self) -> bool { + false + } + + fn name(&self) -> &str { + &self.name + } +} + +/// Map a tiberius column type to a MSSQL type name string. +pub(crate) fn type_name_for_tiberius(col_type: &tiberius::ColumnType) -> &'static str { + match col_type { + tiberius::ColumnType::Null => "NULL", + tiberius::ColumnType::Bit => "BIT", + tiberius::ColumnType::Int1 => "TINYINT", + tiberius::ColumnType::Int2 => "SMALLINT", + tiberius::ColumnType::Int4 => "INT", + tiberius::ColumnType::Int8 => "BIGINT", + tiberius::ColumnType::Float4 => "REAL", + tiberius::ColumnType::Float8 => "FLOAT", + tiberius::ColumnType::Datetime | tiberius::ColumnType::Datetimen => "DATETIME", + tiberius::ColumnType::Datetime2 => "DATETIME2", + tiberius::ColumnType::Datetime4 => "SMALLDATETIME", + tiberius::ColumnType::DatetimeOffsetn => "DATETIMEOFFSET", + tiberius::ColumnType::Daten => "DATE", + tiberius::ColumnType::Timen => "TIME", + tiberius::ColumnType::Decimaln | tiberius::ColumnType::Numericn => "DECIMAL", + tiberius::ColumnType::Money => "MONEY", + tiberius::ColumnType::Money4 => "SMALLMONEY", + tiberius::ColumnType::BigVarChar | tiberius::ColumnType::NVarchar => "NVARCHAR", + tiberius::ColumnType::BigChar | tiberius::ColumnType::NChar => "NCHAR", + tiberius::ColumnType::BigVarBin => "VARBINARY", + tiberius::ColumnType::BigBinary => "BINARY", + tiberius::ColumnType::Text | tiberius::ColumnType::NText => "NTEXT", + tiberius::ColumnType::Image => "IMAGE", + tiberius::ColumnType::Xml => "XML", + tiberius::ColumnType::Guid => "UNIQUEIDENTIFIER", + tiberius::ColumnType::Intn => "INT", + tiberius::ColumnType::Bitn => "BIT", + tiberius::ColumnType::Floatn => "FLOAT", + tiberius::ColumnType::SSVariant => "SQL_VARIANT", + _ => "UNKNOWN", + } +} diff --git a/sqlx-mssql/src/types/bigdecimal.rs b/sqlx-mssql/src/types/bigdecimal.rs new file mode 100644 index 0000000000..c0a67851ef --- /dev/null +++ b/sqlx-mssql/src/types/bigdecimal.rs @@ -0,0 +1,46 @@ +use bigdecimal::BigDecimal; + +use crate::database::MssqlArgumentValue; +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::value::MssqlData; +use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; + +impl Type for BigDecimal { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("DECIMAL") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!( + ty.base_name(), + "DECIMAL" | "NUMERIC" | "MONEY" | "SMALLMONEY" + ) + } +} + +impl Encode<'_, Mssql> for BigDecimal { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + buf.push(MssqlArgumentValue::BigDecimal(self.clone())); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for BigDecimal { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::BigDecimal(ref v) => Ok(v.clone()), + MssqlData::I32(v) => Ok(BigDecimal::from(*v)), + MssqlData::I64(v) => Ok(BigDecimal::from(*v)), + MssqlData::F64(v) => bigdecimal::FromPrimitive::from_f64(*v) + .ok_or_else(|| format!("failed to convert f64 {v} to BigDecimal").into()), + MssqlData::String(ref s) => s + .parse::() + .map_err(|e| format!("failed to parse BigDecimal from string: {e}").into()), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected DECIMAL, got {:?}", value.data).into()), + } + } +} diff --git a/sqlx-mssql/src/types/bool.rs b/sqlx-mssql/src/types/bool.rs new file mode 100644 index 0000000000..a4eb3b0904 --- /dev/null +++ b/sqlx-mssql/src/types/bool.rs @@ -0,0 +1,41 @@ +use crate::database::MssqlArgumentValue; +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::value::MssqlData; +use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; + +impl Type for bool { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("BIT") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!( + ty.base_name(), + "BIT" | "TINYINT" | "INT" | "SMALLINT" | "BIGINT" + ) + } +} + +impl Encode<'_, Mssql> for bool { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + buf.push(MssqlArgumentValue::Bool(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for bool { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::Bool(v) => Ok(*v), + MssqlData::U8(v) => Ok(*v != 0), + MssqlData::I16(v) => Ok(*v != 0), + MssqlData::I32(v) => Ok(*v != 0), + MssqlData::I64(v) => Ok(*v != 0), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected bool-compatible type, got {:?}", value.data).into()), + } + } +} diff --git a/sqlx-mssql/src/types/bytes.rs b/sqlx-mssql/src/types/bytes.rs new file mode 100644 index 0000000000..f1a2d40322 --- /dev/null +++ b/sqlx-mssql/src/types/bytes.rs @@ -0,0 +1,64 @@ +use std::borrow::Cow; +use std::rc::Rc; +use std::sync::Arc; + +use crate::database::MssqlArgumentValue; +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; + +fn bytes_compatible(ty: &MssqlTypeInfo) -> bool { + matches!(ty.base_name(), "VARBINARY" | "BINARY" | "IMAGE") +} + +impl Type for [u8] { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("VARBINARY") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + bytes_compatible(ty) + } +} + +impl Encode<'_, Mssql> for &'_ [u8] { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + buf.push(MssqlArgumentValue::Binary(self.to_vec())); + Ok(IsNull::No) + } +} + +impl<'r> Decode<'r, Mssql> for &'r [u8] { + fn decode(value: MssqlValueRef<'r>) -> Result { + value.as_bytes() + } +} + +impl Type for Vec { + fn type_info() -> MssqlTypeInfo { + <[u8] as Type>::type_info() + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + <[u8] as Type>::compatible(ty) + } +} + +impl Encode<'_, Mssql> for Vec { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + <&[u8] as Encode>::encode(&**self, buf) + } +} + +impl Decode<'_, Mssql> for Vec { + fn decode(value: MssqlValueRef<'_>) -> Result { + <&[u8] as Decode>::decode(value).map(ToOwned::to_owned) + } +} + +forward_encode_impl!(Arc<[u8]>, &[u8], Mssql); +forward_encode_impl!(Rc<[u8]>, &[u8], Mssql); +forward_encode_impl!(Box<[u8]>, &[u8], Mssql); +forward_encode_impl!(Cow<'_, [u8]>, &[u8], Mssql); diff --git a/sqlx-mssql/src/types/chrono.rs b/sqlx-mssql/src/types/chrono.rs new file mode 100644 index 0000000000..d50a7a1587 --- /dev/null +++ b/sqlx-mssql/src/types/chrono.rs @@ -0,0 +1,166 @@ +use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, Utc}; + +use crate::database::MssqlArgumentValue; +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::value::MssqlData; +use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; + +// ── NaiveDateTime ─────────────────────────────────────────────────────────── + +impl Type for NaiveDateTime { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("DATETIME2") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!(ty.base_name(), "DATETIME2" | "DATETIME" | "SMALLDATETIME") + } +} + +impl Encode<'_, Mssql> for NaiveDateTime { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + buf.push(MssqlArgumentValue::NaiveDateTime(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for NaiveDateTime { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::NaiveDateTime(v) => Ok(*v), + MssqlData::DateTimeFixedOffset(v) => Ok(v.naive_utc()), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected datetime, got {:?}", value.data).into()), + } + } +} + +// ── NaiveDate ─────────────────────────────────────────────────────────────── + +impl Type for NaiveDate { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("DATE") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + ty.base_name() == "DATE" + } +} + +impl Encode<'_, Mssql> for NaiveDate { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + buf.push(MssqlArgumentValue::NaiveDate(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for NaiveDate { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::NaiveDate(v) => Ok(*v), + MssqlData::NaiveDateTime(v) => Ok(v.date()), + MssqlData::DateTimeFixedOffset(v) => Ok(v.naive_utc().date()), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected date, got {:?}", value.data).into()), + } + } +} + +// ── NaiveTime ─────────────────────────────────────────────────────────────── + +impl Type for NaiveTime { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("TIME") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + ty.base_name() == "TIME" + } +} + +impl Encode<'_, Mssql> for NaiveTime { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + buf.push(MssqlArgumentValue::NaiveTime(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for NaiveTime { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::NaiveTime(v) => Ok(*v), + MssqlData::NaiveDateTime(v) => Ok(v.time()), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected time, got {:?}", value.data).into()), + } + } +} + +// ── DateTime ─────────────────────────────────────────────────────────── + +impl Type for DateTime { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("DATETIME2") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!(ty.base_name(), "DATETIME2" | "DATETIMEOFFSET") + } +} + +impl Encode<'_, Mssql> for DateTime { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + buf.push(MssqlArgumentValue::NaiveDateTime(self.naive_utc())); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for DateTime { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::NaiveDateTime(v) => Ok(v.and_utc()), + MssqlData::DateTimeFixedOffset(v) => Ok(v.with_timezone(&Utc)), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected datetime, got {:?}", value.data).into()), + } + } +} + +// ── DateTime ─────────────────────────────────────────────────── + +impl Type for DateTime { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("DATETIMEOFFSET") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!(ty.base_name(), "DATETIMEOFFSET" | "DATETIME2") + } +} + +impl Encode<'_, Mssql> for DateTime { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + buf.push(MssqlArgumentValue::DateTimeFixedOffset(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for DateTime { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::DateTimeFixedOffset(v) => Ok(*v), + MssqlData::NaiveDateTime(v) => { + // Assume UTC if no offset information + let utc = v.and_utc(); + Ok(utc.with_timezone( + &FixedOffset::east_opt(0).expect("UTC offset 0 is always valid"), + )) + } + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected datetimeoffset, got {:?}", value.data).into()), + } + } +} diff --git a/sqlx-mssql/src/types/float.rs b/sqlx-mssql/src/types/float.rs new file mode 100644 index 0000000000..98114e3836 --- /dev/null +++ b/sqlx-mssql/src/types/float.rs @@ -0,0 +1,68 @@ +use crate::database::MssqlArgumentValue; +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::value::MssqlData; +use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; + +fn real_compatible(ty: &MssqlTypeInfo) -> bool { + matches!(ty.base_name(), "REAL" | "FLOAT" | "MONEY" | "SMALLMONEY") +} + +impl Type for f32 { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("REAL") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + real_compatible(ty) + } +} + +impl Encode<'_, Mssql> for f32 { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + buf.push(MssqlArgumentValue::F32(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for f32 { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::F32(v) => Ok(*v), + #[allow(clippy::cast_possible_truncation)] + MssqlData::F64(v) => Ok(*v as f32), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected float, got {:?}", value.data).into()), + } + } +} + +impl Type for f64 { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("FLOAT") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + real_compatible(ty) + } +} + +impl Encode<'_, Mssql> for f64 { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + buf.push(MssqlArgumentValue::F64(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for f64 { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::F32(v) => Ok(f64::from(*v)), + MssqlData::F64(v) => Ok(*v), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected float, got {:?}", value.data).into()), + } + } +} diff --git a/sqlx-mssql/src/types/int.rs b/sqlx-mssql/src/types/int.rs new file mode 100644 index 0000000000..ffe4e9c2da --- /dev/null +++ b/sqlx-mssql/src/types/int.rs @@ -0,0 +1,170 @@ +use crate::database::MssqlArgumentValue; +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::value::MssqlData; +use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; + +fn int_compatible(ty: &MssqlTypeInfo) -> bool { + matches!(ty.base_name(), "TINYINT" | "SMALLINT" | "INT" | "BIGINT") +} + +// u8 - MSSQL's TINYINT is unsigned (0-255) +impl Type for u8 { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("TINYINT") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + int_compatible(ty) + } +} + +impl Encode<'_, Mssql> for u8 { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + buf.push(MssqlArgumentValue::U8(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for u8 { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::U8(v) => Ok(*v), + MssqlData::I16(v) => Ok((*v).try_into()?), + MssqlData::I32(v) => Ok((*v).try_into()?), + MssqlData::I64(v) => Ok((*v).try_into()?), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected integer, got {:?}", value.data).into()), + } + } +} + +// i8 - maps to TINYINT but only 0-127 range +impl Type for i8 { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("TINYINT") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + int_compatible(ty) + } +} + +impl Encode<'_, Mssql> for i8 { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + if *self < 0 { + return Err("MSSQL TINYINT is unsigned; cannot encode negative i8".into()); + } + #[allow(clippy::cast_sign_loss)] + buf.push(MssqlArgumentValue::U8(*self as u8)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for i8 { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::U8(v) => Ok((*v).try_into()?), + MssqlData::I16(v) => Ok((*v).try_into()?), + MssqlData::I32(v) => Ok((*v).try_into()?), + MssqlData::I64(v) => Ok((*v).try_into()?), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected integer, got {:?}", value.data).into()), + } + } +} + +// i16 +impl Type for i16 { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("SMALLINT") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + int_compatible(ty) + } +} + +impl Encode<'_, Mssql> for i16 { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + buf.push(MssqlArgumentValue::I16(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for i16 { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::U8(v) => Ok(i16::from(*v)), + MssqlData::I16(v) => Ok(*v), + MssqlData::I32(v) => Ok((*v).try_into()?), + MssqlData::I64(v) => Ok((*v).try_into()?), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected integer, got {:?}", value.data).into()), + } + } +} + +// i32 +impl Type for i32 { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("INT") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + int_compatible(ty) + } +} + +impl Encode<'_, Mssql> for i32 { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + buf.push(MssqlArgumentValue::I32(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for i32 { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::U8(v) => Ok(i32::from(*v)), + MssqlData::I16(v) => Ok(i32::from(*v)), + MssqlData::I32(v) => Ok(*v), + MssqlData::I64(v) => Ok((*v).try_into()?), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected integer, got {:?}", value.data).into()), + } + } +} + +// i64 +impl Type for i64 { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("BIGINT") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + int_compatible(ty) + } +} + +impl Encode<'_, Mssql> for i64 { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + buf.push(MssqlArgumentValue::I64(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for i64 { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::U8(v) => Ok(i64::from(*v)), + MssqlData::I16(v) => Ok(i64::from(*v)), + MssqlData::I32(v) => Ok(i64::from(*v)), + MssqlData::I64(v) => Ok(*v), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected integer, got {:?}", value.data).into()), + } + } +} diff --git a/sqlx-mssql/src/types/json.rs b/sqlx-mssql/src/types/json.rs new file mode 100644 index 0000000000..4557de0b06 --- /dev/null +++ b/sqlx-mssql/src/types/json.rs @@ -0,0 +1,39 @@ +use serde::{Deserialize, Serialize}; + +use crate::database::MssqlArgumentValue; +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::{Json, Type}; +use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; + +impl Type for Json { + fn type_info() -> MssqlTypeInfo { + // SQL Server has no native JSON type; JSON is stored as NVARCHAR + MssqlTypeInfo::new("NVARCHAR") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + <&str as Type>::compatible(ty) + } +} + +impl Encode<'_, Mssql> for Json +where + T: Serialize, +{ + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + let json_string = self.encode_to_string()?; + buf.push(MssqlArgumentValue::String(json_string)); + Ok(IsNull::No) + } +} + +impl<'r, T> Decode<'r, Mssql> for Json +where + T: Deserialize<'r> + 'r, +{ + fn decode(value: MssqlValueRef<'r>) -> Result { + Json::decode_from_string(value.as_str()?) + } +} diff --git a/sqlx-mssql/src/types/mod.rs b/sqlx-mssql/src/types/mod.rs new file mode 100644 index 0000000000..cb82c2f267 --- /dev/null +++ b/sqlx-mssql/src/types/mod.rs @@ -0,0 +1,55 @@ +//! Conversions between Rust and **MSSQL** types. +//! +//! # Types +//! +//! | Rust type | MSSQL type(s) | +//! |---------------------------------------|------------------------------------------------------| +//! | `bool` | BIT | +//! | `u8` | TINYINT (unsigned, 0-255) | +//! | `i8` | TINYINT (0-127 only) | +//! | `i16` | SMALLINT | +//! | `i32` | INT | +//! | `i64` | BIGINT | +//! | `f32` | REAL, FLOAT | +//! | `f64` | REAL, FLOAT, MONEY, SMALLMONEY | +//! | `&str`, [`String`] | NVARCHAR | +//! | `&[u8]`, `Vec` | VARBINARY | +//! +//! ### Feature-gated +//! +//! | Rust type | MSSQL type(s) | +//! |---------------------------------------|------------------------------------------------------| +//! | `uuid::Uuid` | UNIQUEIDENTIFIER | +//! | `rust_decimal::Decimal` | DECIMAL, NUMERIC, MONEY | +//! | `bigdecimal::BigDecimal` | DECIMAL, NUMERIC, MONEY | +//! | `time::Date` | DATE | +//! | `time::Time` | TIME | +//! | `time::PrimitiveDateTime` | DATETIME2, DATETIME, SMALLDATETIME | +//! | `time::OffsetDateTime` | DATETIMEOFFSET, DATETIME2 | +//! | `serde_json::Value` (`Json`) | NVARCHAR (JSON stored as string) | +//! +//! # Nullable +//! +//! In addition, `Option` is supported where `T` implements `Type`. An `Option` represents +//! a potentially `NULL` value from MSSQL. + +pub(crate) use sqlx_core::types::*; + +#[cfg(feature = "bigdecimal")] +mod bigdecimal; +mod bool; +mod bytes; +#[cfg(feature = "chrono")] +mod chrono; +mod float; +mod int; +#[cfg(feature = "json")] +mod json; +#[cfg(feature = "rust_decimal")] +mod rust_decimal; +mod str; +#[cfg(feature = "time")] +mod time; +#[cfg(feature = "uuid")] +mod uuid; +pub mod xml; diff --git a/sqlx-mssql/src/types/rust_decimal.rs b/sqlx-mssql/src/types/rust_decimal.rs new file mode 100644 index 0000000000..c942f549f7 --- /dev/null +++ b/sqlx-mssql/src/types/rust_decimal.rs @@ -0,0 +1,46 @@ +use rust_decimal::Decimal; + +use crate::database::MssqlArgumentValue; +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::value::MssqlData; +use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; + +impl Type for Decimal { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("DECIMAL") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!( + ty.base_name(), + "DECIMAL" | "NUMERIC" | "MONEY" | "SMALLMONEY" + ) + } +} + +impl Encode<'_, Mssql> for Decimal { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + buf.push(MssqlArgumentValue::Decimal(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for Decimal { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::Decimal(v) => Ok(*v), + MssqlData::I32(v) => Ok(Decimal::from(*v)), + MssqlData::I64(v) => Ok(Decimal::from(*v)), + MssqlData::F64(v) => Decimal::try_from(*v) + .map_err(|e| format!("failed to convert f64 to Decimal: {e}").into()), + MssqlData::String(ref s) => s + .parse::() + .map_err(|e| format!("failed to parse Decimal from string: {e}").into()), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected DECIMAL, got {:?}", value.data).into()), + } + } +} diff --git a/sqlx-mssql/src/types/str.rs b/sqlx-mssql/src/types/str.rs new file mode 100644 index 0000000000..816a516167 --- /dev/null +++ b/sqlx-mssql/src/types/str.rs @@ -0,0 +1,67 @@ +use std::borrow::Cow; +use std::rc::Rc; +use std::sync::Arc; + +use crate::database::MssqlArgumentValue; +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; + +fn str_compatible(ty: &MssqlTypeInfo) -> bool { + matches!( + ty.base_name(), + "NVARCHAR" | "VARCHAR" | "NCHAR" | "CHAR" | "NTEXT" | "TEXT" | "XML" + ) +} + +impl Type for str { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("NVARCHAR") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + str_compatible(ty) + } +} + +impl Encode<'_, Mssql> for &'_ str { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + buf.push(MssqlArgumentValue::String((*self).to_owned())); + Ok(IsNull::No) + } +} + +impl<'r> Decode<'r, Mssql> for &'r str { + fn decode(value: MssqlValueRef<'r>) -> Result { + value.as_str() + } +} + +impl Type for String { + fn type_info() -> MssqlTypeInfo { + >::type_info() + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + >::compatible(ty) + } +} + +impl Encode<'_, Mssql> for String { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + <&str as Encode>::encode(self.as_str(), buf) + } +} + +impl Decode<'_, Mssql> for String { + fn decode(value: MssqlValueRef<'_>) -> Result { + <&str as Decode>::decode(value).map(ToOwned::to_owned) + } +} + +forward_encode_impl!(Arc, &str, Mssql); +forward_encode_impl!(Rc, &str, Mssql); +forward_encode_impl!(Cow<'_, str>, &str, Mssql); +forward_encode_impl!(Box, &str, Mssql); diff --git a/sqlx-mssql/src/types/time.rs b/sqlx-mssql/src/types/time.rs new file mode 100644 index 0000000000..e86f6a3061 --- /dev/null +++ b/sqlx-mssql/src/types/time.rs @@ -0,0 +1,128 @@ +use time::{Date, OffsetDateTime, PrimitiveDateTime, Time}; + +use crate::database::MssqlArgumentValue; +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::value::MssqlData; +use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; + +// ── Date ─────────────────────────────────────────────────────────────────── + +impl Type for Date { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("DATE") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + ty.base_name() == "DATE" + } +} + +impl Encode<'_, Mssql> for Date { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + buf.push(MssqlArgumentValue::TimeDate(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for Date { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::TimeDate(v) => Ok(*v), + MssqlData::TimePrimitiveDateTime(v) => Ok(v.date()), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected date, got {:?}", value.data).into()), + } + } +} + +// ── Time ─────────────────────────────────────────────────────────────────── + +impl Type for Time { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("TIME") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + ty.base_name() == "TIME" + } +} + +impl Encode<'_, Mssql> for Time { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + buf.push(MssqlArgumentValue::TimeTime(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for Time { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::TimeTime(v) => Ok(*v), + MssqlData::TimePrimitiveDateTime(v) => Ok(v.time()), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected time, got {:?}", value.data).into()), + } + } +} + +// ── PrimitiveDateTime ────────────────────────────────────────────────────── + +impl Type for PrimitiveDateTime { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("DATETIME2") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!(ty.base_name(), "DATETIME2" | "DATETIME" | "SMALLDATETIME") + } +} + +impl Encode<'_, Mssql> for PrimitiveDateTime { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + buf.push(MssqlArgumentValue::TimePrimitiveDateTime(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for PrimitiveDateTime { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::TimePrimitiveDateTime(v) => Ok(*v), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected datetime, got {:?}", value.data).into()), + } + } +} + +// ── OffsetDateTime ───────────────────────────────────────────────────────── + +impl Type for OffsetDateTime { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("DATETIMEOFFSET") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!(ty.base_name(), "DATETIMEOFFSET" | "DATETIME2") + } +} + +impl Encode<'_, Mssql> for OffsetDateTime { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + buf.push(MssqlArgumentValue::TimeOffsetDateTime(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for OffsetDateTime { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::TimeOffsetDateTime(v) => Ok(*v), + MssqlData::TimePrimitiveDateTime(v) => Ok(v.assume_utc()), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected datetimeoffset, got {:?}", value.data).into()), + } + } +} diff --git a/sqlx-mssql/src/types/uuid.rs b/sqlx-mssql/src/types/uuid.rs new file mode 100644 index 0000000000..e382fe6a3f --- /dev/null +++ b/sqlx-mssql/src/types/uuid.rs @@ -0,0 +1,61 @@ +use uuid::Uuid; + +use crate::database::MssqlArgumentValue; +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::value::MssqlData; +use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; + +impl Type for Uuid { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("UNIQUEIDENTIFIER") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + ty.base_name() == "UNIQUEIDENTIFIER" + } +} + +impl Encode<'_, Mssql> for Uuid { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + buf.push(MssqlArgumentValue::Uuid(*self)); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for Uuid { + fn decode(value: MssqlValueRef<'_>) -> Result { + match value.data { + MssqlData::Uuid(v) => Ok(*v), + MssqlData::String(ref s) => Ok(Uuid::parse_str(s)?), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected UNIQUEIDENTIFIER, got {:?}", value.data).into()), + } + } +} + +impl Type for uuid::fmt::Hyphenated { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("UNIQUEIDENTIFIER") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + ty.base_name() == "UNIQUEIDENTIFIER" + } +} + +impl Encode<'_, Mssql> for uuid::fmt::Hyphenated { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + buf.push(MssqlArgumentValue::Uuid(*self.as_uuid())); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for uuid::fmt::Hyphenated { + fn decode(value: MssqlValueRef<'_>) -> Result { + let uuid = Uuid::decode(value)?; + Ok(uuid.hyphenated()) + } +} diff --git a/sqlx-mssql/src/types/xml.rs b/sqlx-mssql/src/types/xml.rs new file mode 100644 index 0000000000..ace1b70940 --- /dev/null +++ b/sqlx-mssql/src/types/xml.rs @@ -0,0 +1,78 @@ +use std::fmt; + +use crate::database::MssqlArgumentValue; +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::types::Type; +use crate::{Mssql, MssqlTypeInfo, MssqlValueRef}; + +/// SQL Server `XML` column type. +/// +/// A newtype wrapper around [`String`] that maps to the MSSQL `XML` type. +/// This allows sqlx macros to distinguish `XML` columns from `NVARCHAR`. +/// +/// # Example +/// +/// ```rust,no_run +/// # async fn example() -> sqlx::Result<()> { +/// use sqlx::mssql::MssqlXml; +/// +/// let xml = MssqlXml::from("hello".to_owned()); +/// assert_eq!(xml.as_ref(), "hello"); +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct MssqlXml(pub String); + +impl Type for MssqlXml { + fn type_info() -> MssqlTypeInfo { + MssqlTypeInfo::new("XML") + } + + fn compatible(ty: &MssqlTypeInfo) -> bool { + matches!( + ty.base_name(), + "XML" | "NVARCHAR" | "VARCHAR" | "NTEXT" | "TEXT" + ) + } +} + +impl Encode<'_, Mssql> for MssqlXml { + fn encode_by_ref(&self, buf: &mut Vec) -> Result { + buf.push(MssqlArgumentValue::String(self.0.clone())); + Ok(IsNull::No) + } +} + +impl Decode<'_, Mssql> for MssqlXml { + fn decode(value: MssqlValueRef<'_>) -> Result { + let s = value.as_str()?; + Ok(MssqlXml(s.to_owned())) + } +} + +impl From for MssqlXml { + fn from(s: String) -> Self { + MssqlXml(s) + } +} + +impl From for String { + fn from(xml: MssqlXml) -> Self { + xml.0 + } +} + +impl AsRef for MssqlXml { + fn as_ref(&self) -> &str { + &self.0 + } +} + +impl fmt::Display for MssqlXml { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&self.0) + } +} diff --git a/sqlx-mssql/src/value.rs b/sqlx-mssql/src/value.rs new file mode 100644 index 0000000000..ee60ad1414 --- /dev/null +++ b/sqlx-mssql/src/value.rs @@ -0,0 +1,379 @@ +use std::borrow::Cow; + +pub(crate) use sqlx_core::value::*; + +use crate::error::{BoxDynError, Error}; +use crate::{Mssql, MssqlTypeInfo}; + +/// Internal storage for an MSSQL value, decoupled from tiberius lifetimes. +#[derive(Debug, Clone)] +pub(crate) enum MssqlData { + Null, + Bool(bool), + U8(u8), + I16(i16), + I32(i32), + I64(i64), + F32(f32), + F64(f64), + String(String), + Binary(Vec), + #[cfg(feature = "chrono")] + NaiveDateTime(chrono::NaiveDateTime), + #[cfg(feature = "chrono")] + NaiveDate(chrono::NaiveDate), + #[cfg(feature = "chrono")] + NaiveTime(chrono::NaiveTime), + #[cfg(feature = "chrono")] + DateTimeFixedOffset(chrono::DateTime), + #[cfg(feature = "uuid")] + Uuid(uuid::Uuid), + #[cfg(feature = "rust_decimal")] + Decimal(rust_decimal::Decimal), + #[cfg(all(feature = "time", not(feature = "chrono")))] + TimeDate(time::Date), + #[cfg(all(feature = "time", not(feature = "chrono")))] + TimeTime(time::Time), + #[cfg(all(feature = "time", not(feature = "chrono")))] + TimePrimitiveDateTime(time::PrimitiveDateTime), + #[cfg(all(feature = "time", not(feature = "chrono")))] + TimeOffsetDateTime(time::OffsetDateTime), + #[cfg(all(feature = "bigdecimal", not(feature = "rust_decimal")))] + BigDecimal(bigdecimal::BigDecimal), +} + +/// Implementation of [`Value`] for MSSQL. +#[derive(Debug, Clone)] +pub struct MssqlValue { + pub(crate) data: MssqlData, + pub(crate) type_info: MssqlTypeInfo, +} + +/// Implementation of [`ValueRef`] for MSSQL. +#[derive(Debug, Clone)] +pub struct MssqlValueRef<'r> { + pub(crate) data: &'r MssqlData, + pub(crate) type_info: MssqlTypeInfo, +} + +impl<'r> MssqlValueRef<'r> { + pub(crate) fn as_str(&self) -> Result<&'r str, BoxDynError> { + match self.data { + MssqlData::String(ref s) => Ok(s.as_str()), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected string, got {:?}", self.data).into()), + } + } + + pub(crate) fn as_bytes(&self) -> Result<&'r [u8], BoxDynError> { + match self.data { + MssqlData::Binary(ref b) => Ok(b.as_slice()), + MssqlData::String(ref s) => Ok(s.as_bytes()), + MssqlData::Null => Err("unexpected NULL".into()), + _ => Err(format!("expected binary, got {:?}", self.data).into()), + } + } +} + +impl Value for MssqlValue { + type Database = Mssql; + + fn as_ref(&self) -> MssqlValueRef<'_> { + MssqlValueRef { + data: &self.data, + type_info: self.type_info.clone(), + } + } + + fn type_info(&self) -> Cow<'_, MssqlTypeInfo> { + Cow::Borrowed(&self.type_info) + } + + fn is_null(&self) -> bool { + matches!(self.data, MssqlData::Null) + } +} + +impl<'r> ValueRef<'r> for MssqlValueRef<'r> { + type Database = Mssql; + + fn to_owned(&self) -> MssqlValue { + MssqlValue { + data: self.data.clone(), + type_info: self.type_info.clone(), + } + } + + fn type_info(&self) -> Cow<'_, MssqlTypeInfo> { + Cow::Borrowed(&self.type_info) + } + + fn is_null(&self) -> bool { + matches!(self.data, MssqlData::Null) + } +} + +/// Convert a `tiberius::ColumnData` into our owned `MssqlData`. +pub(crate) fn column_data_to_mssql_data( + data: tiberius::ColumnData<'_>, +) -> Result { + match data { + tiberius::ColumnData::U8(Some(v)) => Ok(MssqlData::U8(v)), + tiberius::ColumnData::I16(Some(v)) => Ok(MssqlData::I16(v)), + tiberius::ColumnData::I32(Some(v)) => Ok(MssqlData::I32(v)), + tiberius::ColumnData::I64(Some(v)) => Ok(MssqlData::I64(v)), + tiberius::ColumnData::F32(Some(v)) => Ok(MssqlData::F32(v)), + tiberius::ColumnData::F64(Some(v)) => Ok(MssqlData::F64(v)), + tiberius::ColumnData::Bit(Some(v)) => Ok(MssqlData::Bool(v)), + tiberius::ColumnData::String(Some(v)) => Ok(MssqlData::String(v.into_owned())), + tiberius::ColumnData::Binary(Some(v)) => Ok(MssqlData::Binary(v.into_owned())), + tiberius::ColumnData::Xml(Some(xml)) => { + Ok(MssqlData::String(xml.into_owned().into_string())) + } + + #[cfg(feature = "chrono")] + tiberius::ColumnData::DateTime2(Some(dt2)) => { + let date = chrono_date_from_days(dt2.date().days() as i64, 1)?; + // SAFETY: TDS time increments at scale 7 max out at 863_999_999_999, well within i64. + let t = dt2.time(); + #[allow(clippy::cast_possible_wrap)] + let ns = t.increments() as i64 * 10i64.pow(9u32.saturating_sub(t.scale() as u32)); + // infallible: (0,0,0) is always valid + let time = chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap() + + chrono::Duration::nanoseconds(ns); + Ok(MssqlData::NaiveDateTime(chrono::NaiveDateTime::new( + date, time, + ))) + } + #[cfg(feature = "chrono")] + tiberius::ColumnData::DateTime(Some(dt)) => { + let date = chrono_date_from_days(dt.days() as i64, 1900)?; + let ns = dt.seconds_fragments() as i64 * 1_000_000_000i64 / 300; + // infallible: (0,0,0) is always valid + let time = chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap() + + chrono::Duration::nanoseconds(ns); + Ok(MssqlData::NaiveDateTime(chrono::NaiveDateTime::new( + date, time, + ))) + } + #[cfg(feature = "chrono")] + tiberius::ColumnData::SmallDateTime(Some(dt)) => { + let date = chrono_date_from_days(dt.days() as i64, 1900)?; + let seconds = dt.seconds_fragments() as u32 * 60; + let time = chrono::NaiveTime::from_num_seconds_from_midnight_opt(seconds, 0) + .ok_or_else(|| { + Error::Protocol(format!( + "invalid SmallDateTime seconds: {seconds} exceeds seconds-in-a-day" + )) + })?; + Ok(MssqlData::NaiveDateTime(chrono::NaiveDateTime::new( + date, time, + ))) + } + #[cfg(feature = "chrono")] + tiberius::ColumnData::Date(Some(d)) => Ok(MssqlData::NaiveDate(chrono_date_from_days( + d.days() as i64, + 1, + )?)), + #[cfg(feature = "chrono")] + tiberius::ColumnData::Time(Some(t)) => { + // SAFETY: TDS time increments at scale 7 max out at 863_999_999_999, well within i64. + #[allow(clippy::cast_possible_wrap)] + let ns = t.increments() as i64 * 10i64.pow(9u32.saturating_sub(t.scale() as u32)); + // infallible: (0,0,0) is always valid + let time = chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap() + + chrono::Duration::nanoseconds(ns); + Ok(MssqlData::NaiveTime(time)) + } + #[cfg(feature = "chrono")] + tiberius::ColumnData::DateTimeOffset(Some(dto)) => { + let date = chrono_date_from_days(dto.datetime2().date().days() as i64, 1)?; + // SAFETY: TDS time increments at scale 7 max out at 863_999_999_999, well within i64. + let t = dto.datetime2().time(); + #[allow(clippy::cast_possible_wrap)] + let ns = t.increments() as i64 * 10i64.pow(9u32.saturating_sub(t.scale() as u32)); + // infallible: (0,0,0) is always valid + let time = chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap() + + chrono::Duration::nanoseconds(ns); + let naive = chrono::NaiveDateTime::new(date, time); + let offset_secs = dto.offset() as i32 * 60; + let fixed_offset = chrono::FixedOffset::east_opt(offset_secs).ok_or_else(|| { + Error::Protocol(format!("invalid timezone offset: {offset_secs} seconds")) + })?; + let dt = naive + .and_local_timezone(fixed_offset) + .single() + .ok_or_else(|| { + Error::Protocol(format!( + "ambiguous or invalid local time for offset {offset_secs}s" + )) + })?; + Ok(MssqlData::DateTimeFixedOffset(dt)) + } + + #[cfg(feature = "uuid")] + tiberius::ColumnData::Guid(Some(v)) => Ok(MssqlData::Uuid(v)), + + #[cfg(feature = "rust_decimal")] + tiberius::ColumnData::Numeric(Some(n)) => Ok(MssqlData::Decimal( + rust_decimal::Decimal::from_i128_with_scale(n.value(), n.scale() as u32), + )), + + #[cfg(all(feature = "time", not(feature = "chrono")))] + tiberius::ColumnData::Date(Some(d)) => Ok(MssqlData::TimeDate(time_date_from_days( + i64::from(d.days()), + 1, + )?)), + #[cfg(all(feature = "time", not(feature = "chrono")))] + tiberius::ColumnData::Time(Some(t)) => { + let ns = t.increments() * 10u64.pow(9u32.saturating_sub(t.scale() as u32)); + Ok(MssqlData::TimeTime(time_from_sec_fragments(ns)?)) + } + #[cfg(all(feature = "time", not(feature = "chrono")))] + tiberius::ColumnData::DateTime2(Some(dt2)) => { + let date = time_date_from_days(i64::from(dt2.date().days()), 1)?; + let t = dt2.time(); + let ns = t.increments() * 10u64.pow(9u32.saturating_sub(t.scale() as u32)); + let time = time_from_sec_fragments(ns)?; + Ok(MssqlData::TimePrimitiveDateTime( + time::PrimitiveDateTime::new(date, time), + )) + } + #[cfg(all(feature = "time", not(feature = "chrono")))] + tiberius::ColumnData::DateTime(Some(dt)) => { + let date = time_date_from_days(i64::from(dt.days()), 1900)?; + let ns = dt.seconds_fragments() as u64 * 1_000_000_000u64 / 300; + let time = time_from_sec_fragments(ns)?; + Ok(MssqlData::TimePrimitiveDateTime( + time::PrimitiveDateTime::new(date, time), + )) + } + #[cfg(all(feature = "time", not(feature = "chrono")))] + tiberius::ColumnData::SmallDateTime(Some(dt)) => { + let date = time_date_from_days(i64::from(dt.days()), 1900)?; + let seconds = dt.seconds_fragments() as u64 * 60; + let time = time_from_sec_fragments(seconds * 1_000_000_000)?; + Ok(MssqlData::TimePrimitiveDateTime( + time::PrimitiveDateTime::new(date, time), + )) + } + #[cfg(all(feature = "time", not(feature = "chrono")))] + tiberius::ColumnData::DateTimeOffset(Some(dto)) => { + let date = time_date_from_days(i64::from(dto.datetime2().date().days()), 1)?; + let t = dto.datetime2().time(); + let ns = t.increments() * 10u64.pow(9u32.saturating_sub(t.scale() as u32)); + let time = time_from_sec_fragments(ns)?; + let naive = time::PrimitiveDateTime::new(date, time); + let offset_secs = dto.offset() as i32 * 60; + let offset = time::UtcOffset::from_whole_seconds(offset_secs).map_err(|_| { + Error::Protocol(format!("invalid UTC offset: {offset_secs} seconds")) + })?; + Ok(MssqlData::TimeOffsetDateTime(naive.assume_offset(offset))) + } + + #[cfg(all(feature = "bigdecimal", not(feature = "rust_decimal")))] + tiberius::ColumnData::Numeric(Some(n)) => { + use bigdecimal::num_bigint::BigInt; + Ok(MssqlData::BigDecimal(bigdecimal::BigDecimal::new( + BigInt::from(n.value()), + n.scale() as i64, + ))) + } + + // All None variants represent SQL NULL + tiberius::ColumnData::U8(None) + | tiberius::ColumnData::I16(None) + | tiberius::ColumnData::I32(None) + | tiberius::ColumnData::I64(None) + | tiberius::ColumnData::F32(None) + | tiberius::ColumnData::F64(None) + | tiberius::ColumnData::Bit(None) + | tiberius::ColumnData::String(None) + | tiberius::ColumnData::Guid(None) + | tiberius::ColumnData::Binary(None) + | tiberius::ColumnData::Numeric(None) + | tiberius::ColumnData::Xml(None) + | tiberius::ColumnData::DateTime(None) + | tiberius::ColumnData::SmallDateTime(None) + | tiberius::ColumnData::DateTime2(None) + | tiberius::ColumnData::DateTimeOffset(None) + | tiberius::ColumnData::Date(None) + | tiberius::ColumnData::Time(None) => Ok(MssqlData::Null), + + // Unhandled Some(...) variant — real data the driver can't convert. + // Currently unreachable with all features enabled, but kept for forward + // compatibility when tiberius adds new variants. + #[allow(unreachable_patterns)] + other => { + let debug = format!("{other:?}"); + let truncated = if debug.len() > 200 { + let mut end = 200; + while !debug.is_char_boundary(end) { + end -= 1; + } + &debug[..end] + } else { + &debug + }; + Err(Error::Protocol(format!( + "unsupported tiberius ColumnData variant: {truncated}" + ))) + } + } +} + +/// Convert days since `start_year`-01-01 to a `time::Date`. +#[cfg(all(feature = "time", not(feature = "chrono")))] +fn time_date_from_days(days: i64, start_year: i32) -> Result { + let start = time::Date::from_ordinal_date(start_year, 1) + .map_err(|_| Error::Protocol(format!("invalid start year for date: {start_year}")))?; + start + .checked_add(time::Duration::days(days)) + .ok_or_else(|| { + Error::Protocol(format!( + "date overflow: {days} days from {start_year}-01-01" + )) + }) +} + +/// Convert nanoseconds-since-midnight to a `time::Time`. +#[cfg(all(feature = "time", not(feature = "chrono")))] +fn time_from_sec_fragments(nanoseconds: u64) -> Result { + const NANOS_PER_DAY: u64 = 86_400_000_000_000; + if nanoseconds >= NANOS_PER_DAY { + return Err(Error::Protocol(format!( + "time nanoseconds out of range: {nanoseconds} (must be < {NANOS_PER_DAY})" + ))); + } + // SAFETY: bounds check above guarantees nanoseconds < 86_400_000_000_000, + // so hours ≤ 23, minutes ≤ 59, seconds ≤ 59 — all fit in u8. + #[allow(clippy::cast_possible_truncation)] + let hours = (nanoseconds / 3_600_000_000_000) as u8; + let remaining = nanoseconds % 3_600_000_000_000; + #[allow(clippy::cast_possible_truncation)] + let minutes = (remaining / 60_000_000_000) as u8; + let remaining = remaining % 60_000_000_000; + #[allow(clippy::cast_possible_truncation)] + let seconds = (remaining / 1_000_000_000) as u8; + #[allow(clippy::cast_possible_truncation)] + let nanos = (remaining % 1_000_000_000) as u32; + time::Time::from_hms_nano(hours, minutes, seconds, nanos).map_err(|_| { + Error::Protocol(format!( + "invalid time: {hours:02}:{minutes:02}:{seconds:02}.{nanos:09}" + )) + }) +} + +/// Convert days since `start_year`-01-01 to a `chrono::NaiveDate`. +#[cfg(feature = "chrono")] +fn chrono_date_from_days(days: i64, start_year: i32) -> Result { + let start = chrono::NaiveDate::from_ymd_opt(start_year, 1, 1) + .ok_or_else(|| Error::Protocol(format!("invalid start year for date: {start_year}")))?; + start + .checked_add_signed(chrono::Duration::days(days)) + .ok_or_else(|| { + Error::Protocol(format!( + "date overflow: {days} days from {start_year}-01-01" + )) + }) +} diff --git a/sqlx-test/src/lib.rs b/sqlx-test/src/lib.rs index 3744724c12..d7ffdaeb27 100644 --- a/sqlx-test/src/lib.rs +++ b/sqlx-test/src/lib.rs @@ -17,7 +17,7 @@ where let db_url = env::var("DATABASE_URL").map_err(|e| Error::Configuration(Box::new(e)))?; - Ok(DB::Connection::connect(&db_url).await?) + DB::Connection::connect(&db_url).await } // Make a new pool diff --git a/src/any/mod.rs b/src/any/mod.rs index 434d255573..fa6a8f5498 100644 --- a/src/any/mod.rs +++ b/src/any/mod.rs @@ -37,6 +37,8 @@ pub fn install_default_drivers() { ONCE.call_once(|| { install_drivers(&[ + #[cfg(feature = "mssql")] + sqlx_mssql::any::DRIVER, #[cfg(feature = "mysql")] sqlx_mysql::any::DRIVER, #[cfg(feature = "postgres")] diff --git a/src/lib.rs b/src/lib.rs index 438463210d..e0cd7dd164 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -44,6 +44,13 @@ pub use sqlx_core::error::{self, Error, Result}; #[cfg(feature = "migrate")] pub use sqlx_core::migrate; +#[cfg(feature = "mssql")] +#[cfg_attr(docsrs, doc(cfg(feature = "mssql")))] +#[doc(inline)] +pub use sqlx_mssql::{ + self as mssql, Mssql, MssqlConnection, MssqlExecutor, MssqlPool, MssqlTransaction, +}; + #[cfg(feature = "mysql")] #[cfg_attr(docsrs, doc(cfg(feature = "mysql")))] #[doc(inline)] diff --git a/tests/any/any.rs b/tests/any/any.rs index 2c57a237ec..bc49804de2 100644 --- a/tests/any/any.rs +++ b/tests/any/any.rs @@ -152,7 +152,7 @@ async fn it_can_query_by_string_args() -> sqlx::Result<()> { let mut conn = new::().await?; let string = "Hello, world!".to_string(); - let ref tuple = ("Hello, world!".to_string(),); + let tuple = &("Hello, world!".to_string(),); #[cfg(feature = "postgres")] const SQL: &str = "SELECT 'Hello, world!' \ diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml index e03eeeea67..ce3eb05511 100644 --- a/tests/docker-compose.yml +++ b/tests/docker-compose.yml @@ -200,6 +200,34 @@ services: MARIADB_DATABASE: sqlx MARIADB_ALLOW_EMPTY_ROOT_PASSWORD: 1 # + # Microsoft SQL Server 2022, 2019 + # + + mssql_2022: + build: + context: . + dockerfile: mssql/Dockerfile + args: + VERSION: 2022-latest + ports: + - 1433 + environment: + ACCEPT_EULA: "Y" + SA_PASSWORD: "YourStrong!Passw0rd" + + mssql_2019: + build: + context: . + dockerfile: mssql/Dockerfile + args: + VERSION: 2019-latest + ports: + - 1433 + environment: + ACCEPT_EULA: "Y" + SA_PASSWORD: "YourStrong!Passw0rd" + + # # PostgreSQL 17.x, 16.x, 15.x, 14.x, 13.x # https://www.postgresql.org/support/versioning/ # diff --git a/tests/mssql/advisory-lock.rs b/tests/mssql/advisory-lock.rs new file mode 100644 index 0000000000..c3dac4e924 --- /dev/null +++ b/tests/mssql/advisory-lock.rs @@ -0,0 +1,91 @@ +use sqlx::mssql::{Mssql, MssqlAdvisoryLock, MssqlAdvisoryLockMode}; +use sqlx_test::new; + +#[sqlx_macros::test] +async fn it_acquires_and_releases() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let lock = MssqlAdvisoryLock::new("sqlx_test_acquire_release"); + + lock.acquire(&mut conn).await?; + let released = lock.release(&mut conn).await?; + assert!(released, "lock should have been held and released"); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_try_acquire_succeeds_when_free() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let lock = MssqlAdvisoryLock::new("sqlx_test_try_free"); + + let acquired = lock.try_acquire(&mut conn).await?; + assert!(acquired, "lock should be free and acquired"); + + lock.release(&mut conn).await?; + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_try_acquire_fails_when_held() -> anyhow::Result<()> { + let mut conn1 = new::().await?; + let mut conn2 = new::().await?; + + let lock = MssqlAdvisoryLock::new("sqlx_test_try_held"); + + // Conn1 holds the exclusive lock + lock.acquire(&mut conn1).await?; + + // Conn2 should fail to acquire it immediately + let acquired = lock.try_acquire(&mut conn2).await?; + assert!(!acquired, "lock should not be available"); + + // Release from conn1 + lock.release(&mut conn1).await?; + + // Now conn2 should be able to acquire + let acquired = lock.try_acquire(&mut conn2).await?; + assert!(acquired, "lock should now be free"); + + lock.release(&mut conn2).await?; + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_supports_shared_locks() -> anyhow::Result<()> { + let mut conn1 = new::().await?; + let mut conn2 = new::().await?; + + let lock = MssqlAdvisoryLock::with_mode("sqlx_test_shared", MssqlAdvisoryLockMode::Shared); + + // Both connections should be able to acquire a shared lock + lock.acquire(&mut conn1).await?; + let acquired = lock.try_acquire(&mut conn2).await?; + assert!( + acquired, + "shared lock should be acquirable by second connection" + ); + + lock.release(&mut conn1).await?; + lock.release(&mut conn2).await?; + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_release_returns_false_when_not_held() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let lock = MssqlAdvisoryLock::new("sqlx_test_not_held"); + + let released = lock.release(&mut conn).await?; + assert!( + !released, + "release should return false when lock is not held" + ); + + Ok(()) +} diff --git a/tests/mssql/bulk-insert.rs b/tests/mssql/bulk-insert.rs new file mode 100644 index 0000000000..232f576752 --- /dev/null +++ b/tests/mssql/bulk-insert.rs @@ -0,0 +1,77 @@ +use sqlx::mssql::{IntoRow, Mssql}; +use sqlx::Row; +use sqlx_test::new; + +#[sqlx_macros::test] +async fn it_bulk_inserts_rows() -> anyhow::Result<()> { + let mut conn = new::().await?; + + sqlx::query("CREATE TABLE #bulk_test (name NVARCHAR(50) NOT NULL, value INT NOT NULL)") + .execute(&mut conn) + .await?; + + let mut bulk = conn.bulk_insert("#bulk_test").await?; + bulk.send(("hello", 1i32).into_row()).await?; + bulk.send(("world", 2i32).into_row()).await?; + bulk.send(("foo", 3i32).into_row()).await?; + let total = bulk.finalize().await?; + assert_eq!(total, 3); + + let rows = sqlx::query("SELECT name, value FROM #bulk_test ORDER BY value") + .fetch_all(&mut conn) + .await?; + + assert_eq!(rows.len(), 3); + assert_eq!(rows[0].get::("name"), "hello"); + assert_eq!(rows[0].get::("value"), 1); + assert_eq!(rows[1].get::("name"), "world"); + assert_eq!(rows[1].get::("value"), 2); + assert_eq!(rows[2].get::("name"), "foo"); + assert_eq!(rows[2].get::("value"), 3); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_bulk_inserts_empty() -> anyhow::Result<()> { + let mut conn = new::().await?; + + sqlx::query("CREATE TABLE #bulk_empty (id INT NOT NULL)") + .execute(&mut conn) + .await?; + + let bulk = conn.bulk_insert("#bulk_empty").await?; + let total = bulk.finalize().await?; + assert_eq!(total, 0); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_bulk_inserts_various_types() -> anyhow::Result<()> { + let mut conn = new::().await?; + + sqlx::query( + "CREATE TABLE #bulk_types (id INT NOT NULL, label NVARCHAR(100) NOT NULL, score FLOAT NOT NULL)" + ) + .execute(&mut conn) + .await?; + + let mut bulk = conn.bulk_insert("#bulk_types").await?; + bulk.send((1i32, "alpha", 1.5f64).into_row()).await?; + bulk.send((2i32, "beta", 2.7f64).into_row()).await?; + let total = bulk.finalize().await?; + assert_eq!(total, 2); + + let rows = sqlx::query("SELECT id, label, score FROM #bulk_types ORDER BY id") + .fetch_all(&mut conn) + .await?; + + assert_eq!(rows.len(), 2); + assert_eq!(rows[0].get::("id"), 1); + assert_eq!(rows[0].get::("label"), "alpha"); + assert_eq!(rows[1].get::("id"), 2); + assert_eq!(rows[1].get::("label"), "beta"); + + Ok(()) +} diff --git a/tests/mssql/derives.rs b/tests/mssql/derives.rs new file mode 100644 index 0000000000..a9e91d439c --- /dev/null +++ b/tests/mssql/derives.rs @@ -0,0 +1,110 @@ +use sqlx::mssql::Mssql; +use sqlx_test::{new, test_type}; + +#[sqlx::test] +async fn test_derive_weak_enum() -> anyhow::Result<()> { + #[derive(sqlx::Type, Debug, PartialEq, Eq)] + #[repr(i16)] + enum WeakEnumI16 { + Foo = i16::MIN, + Bar = 0, + Baz = i16::MAX, + } + + #[derive(sqlx::Type, Debug, PartialEq, Eq)] + #[repr(i32)] + enum WeakEnumI32 { + Foo = i32::MIN, + Bar = 0, + Baz = i32::MAX, + } + + #[derive(sqlx::Type, Debug, PartialEq, Eq)] + #[repr(i64)] + enum WeakEnumI64 { + Foo = i64::MIN, + Bar = 0, + Baz = i64::MAX, + } + + #[derive(sqlx::FromRow, Debug, PartialEq, Eq)] + struct WeakEnumRow { + i16: WeakEnumI16, + i32: WeakEnumI32, + i64: WeakEnumI64, + } + + let mut conn = new::().await?; + + sqlx::raw_sql( + r#" + CREATE TABLE #weak_enum ( + i16 SMALLINT, + i32 INT, + i64 BIGINT + ) + "#, + ) + .execute(&mut conn) + .await?; + + let rows_in = vec![ + WeakEnumRow { + i16: WeakEnumI16::Foo, + i32: WeakEnumI32::Foo, + i64: WeakEnumI64::Foo, + }, + WeakEnumRow { + i16: WeakEnumI16::Bar, + i32: WeakEnumI32::Bar, + i64: WeakEnumI64::Bar, + }, + WeakEnumRow { + i16: WeakEnumI16::Baz, + i32: WeakEnumI32::Baz, + i64: WeakEnumI64::Baz, + }, + ]; + + for row in &rows_in { + sqlx::query( + r#" + INSERT INTO #weak_enum(i16, i32, i64) + VALUES (@p1, @p2, @p3) + "#, + ) + .bind(&row.i16) + .bind(&row.i32) + .bind(&row.i64) + .execute(&mut conn) + .await?; + } + + let rows_out: Vec = sqlx::query_as("SELECT * FROM #weak_enum") + .fetch_all(&mut conn) + .await?; + + assert_eq!(rows_in, rows_out); + + Ok(()) +} + +#[derive(PartialEq, Eq, Debug, sqlx::Type)] +#[sqlx(transparent)] +struct TransparentTuple(i64); + +#[derive(PartialEq, Eq, Debug, sqlx::Type)] +#[sqlx(transparent)] +struct TransparentNamed { + field: i64, +} + +test_type!(transparent_tuple(Mssql, + "CAST(0 AS BIGINT)" == TransparentTuple(0), + "CAST(23523 AS BIGINT)" == TransparentTuple(23523) +)); + +test_type!(transparent_named(Mssql, + "CAST(0 AS BIGINT)" == TransparentNamed { field: 0 }, + "CAST(23523 AS BIGINT)" == TransparentNamed { field: 23523 }, +)); diff --git a/tests/mssql/describe.rs b/tests/mssql/describe.rs index 3717829b41..64f97102d9 100644 --- a/tests/mssql/describe.rs +++ b/tests/mssql/describe.rs @@ -1,12 +1,12 @@ use sqlx::mssql::Mssql; -use sqlx::{Column, Executor, TypeInfo}; +use sqlx::{Column, Executor, SqlSafeStr, TypeInfo}; use sqlx_test::new; #[sqlx_macros::test] async fn it_describes_simple() -> anyhow::Result<()> { let mut conn = new::().await?; - let d = conn.describe("SELECT * FROM tweet").await?; + let d = conn.describe("SELECT * FROM tweet".into_sql_str()).await?; assert_eq!(d.columns()[0].name(), "id"); assert_eq!(d.columns()[1].name(), "text"); @@ -31,7 +31,7 @@ async fn it_describes_with_params() -> anyhow::Result<()> { let mut conn = new::().await?; let d = conn - .describe("SELECT text FROM tweet WHERE id = @p1") + .describe("SELECT text FROM tweet WHERE id = @p1".into_sql_str()) .await?; assert_eq!(d.columns()[0].name(), "text"); diff --git a/tests/mssql/error.rs b/tests/mssql/error.rs new file mode 100644 index 0000000000..e0dd1ed6fd --- /dev/null +++ b/tests/mssql/error.rs @@ -0,0 +1,79 @@ +use sqlx::error::ErrorKind; +use sqlx::mssql::Mssql; +use sqlx::Connection; +use sqlx_test::new; + +#[sqlx_macros::test] +async fn it_fails_with_unique_violation() -> anyhow::Result<()> { + let mut conn = new::().await?; + let mut tx = conn.begin().await?; + + sqlx::query("INSERT INTO tweet(id, text, owner_id) VALUES (1, 'Foo', 1)") + .execute(&mut *tx) + .await?; + + let res: Result<_, sqlx::Error> = + sqlx::query("INSERT INTO tweet(id, text, owner_id) VALUES (1, 'Bar', 1)") + .execute(&mut *tx) + .await; + let err = res.unwrap_err(); + + let err = err.into_database_error().unwrap(); + + assert_eq!(err.kind(), ErrorKind::UniqueViolation); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_fails_with_foreign_key_violation() -> anyhow::Result<()> { + let mut conn = new::().await?; + let mut tx = conn.begin().await?; + + let res: Result<_, sqlx::Error> = + sqlx::query("INSERT INTO tweet_reply (tweet_id, text) VALUES (999, 'Reply!')") + .execute(&mut *tx) + .await; + let err = res.unwrap_err(); + + let err = err.into_database_error().unwrap(); + + assert_eq!(err.kind(), ErrorKind::ForeignKeyViolation); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_fails_with_not_null_violation() -> anyhow::Result<()> { + let mut conn = new::().await?; + let mut tx = conn.begin().await?; + + let res: Result<_, sqlx::Error> = sqlx::query("INSERT INTO tweet (id, text) VALUES (1, NULL)") + .execute(&mut *tx) + .await; + let err = res.unwrap_err(); + + let err = err.into_database_error().unwrap(); + + assert_eq!(err.kind(), ErrorKind::NotNullViolation); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_fails_with_check_violation() -> anyhow::Result<()> { + let mut conn = new::().await?; + let mut tx = conn.begin().await?; + + let res: Result<_, sqlx::Error> = + sqlx::query("INSERT INTO products (product_no, name, price) VALUES (1, 'Product 1', 0)") + .execute(&mut *tx) + .await; + let err = res.unwrap_err(); + + let err = err.into_database_error().unwrap(); + + assert_eq!(err.kind(), ErrorKind::CheckViolation); + + Ok(()) +} diff --git a/tests/mssql/fixtures/comments.sql b/tests/mssql/fixtures/comments.sql new file mode 100644 index 0000000000..93ea9941f6 --- /dev/null +++ b/tests/mssql/fixtures/comments.sql @@ -0,0 +1,4 @@ +INSERT INTO comment(comment_id, post_id, user_id, content, created_at) +VALUES (1, 1, 2, 'lol bet ur still bad, 1v1 me', DATEADD(MINUTE, -50, SYSUTCDATETIME())), + (2, 1, 1, 'you''re on!', DATEADD(MINUTE, -45, SYSUTCDATETIME())), + (3, 2, 1, 'lol you''re just mad you lost :P', DATEADD(MINUTE, -15, SYSUTCDATETIME())); diff --git a/tests/mssql/fixtures/posts.sql b/tests/mssql/fixtures/posts.sql new file mode 100644 index 0000000000..e75d0d9381 --- /dev/null +++ b/tests/mssql/fixtures/posts.sql @@ -0,0 +1,5 @@ +SET IDENTITY_INSERT post ON; +INSERT INTO post(post_id, user_id, content, created_at) +VALUES (1, 1, 'This new computer is lightning-fast!', DATEADD(HOUR, -1, SYSUTCDATETIME())), + (2, 2, '@alice is a haxxor :(', DATEADD(MINUTE, -30, SYSUTCDATETIME())); +SET IDENTITY_INSERT post OFF; diff --git a/tests/mssql/fixtures/users.sql b/tests/mssql/fixtures/users.sql new file mode 100644 index 0000000000..0d4270c282 --- /dev/null +++ b/tests/mssql/fixtures/users.sql @@ -0,0 +1,3 @@ +SET IDENTITY_INSERT [user] ON; +INSERT INTO [user](user_id, username) VALUES (1, 'alice'), (2, 'bob'); +SET IDENTITY_INSERT [user] OFF; diff --git a/tests/mssql/isolation-level.rs b/tests/mssql/isolation-level.rs new file mode 100644 index 0000000000..de6ac5ecb2 --- /dev/null +++ b/tests/mssql/isolation-level.rs @@ -0,0 +1,56 @@ +use sqlx::mssql::{Mssql, MssqlIsolationLevel}; +use sqlx::Row; +use sqlx_test::new; + +#[sqlx_macros::test] +async fn it_begins_with_read_uncommitted() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let mut tx = conn + .begin_with_isolation(MssqlIsolationLevel::ReadUncommitted) + .await?; + + let row = sqlx::query("SELECT 1 AS val").fetch_one(&mut *tx).await?; + let val: i32 = row.get("val"); + assert_eq!(val, 1); + + tx.commit().await?; + Ok(()) +} + +#[sqlx_macros::test] +async fn it_begins_with_snapshot() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // Enable snapshot isolation on the database first + sqlx::query("ALTER DATABASE CURRENT SET ALLOW_SNAPSHOT_ISOLATION ON") + .execute(&mut conn) + .await?; + + let mut tx = conn + .begin_with_isolation(MssqlIsolationLevel::Snapshot) + .await?; + + let row = sqlx::query("SELECT 1 AS val").fetch_one(&mut *tx).await?; + let val: i32 = row.get("val"); + assert_eq!(val, 1); + + tx.commit().await?; + Ok(()) +} + +#[sqlx_macros::test] +async fn it_begins_with_serializable() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let mut tx = conn + .begin_with_isolation(MssqlIsolationLevel::Serializable) + .await?; + + let row = sqlx::query("SELECT 1 AS val").fetch_one(&mut *tx).await?; + let val: i32 = row.get("val"); + assert_eq!(val, 1); + + tx.commit().await?; + Ok(()) +} diff --git a/tests/mssql/migrate.rs b/tests/mssql/migrate.rs new file mode 100644 index 0000000000..b76a2a4eb1 --- /dev/null +++ b/tests/mssql/migrate.rs @@ -0,0 +1,84 @@ +use sqlx::migrate::Migrator; +use sqlx::mssql::{Mssql, MssqlConnection}; +use sqlx::pool::PoolConnection; +use sqlx::Executor; +use sqlx::Row; +use std::path::Path; + +#[sqlx::test(migrations = false)] +async fn simple(mut conn: PoolConnection) -> anyhow::Result<()> { + clean_up(&mut conn).await?; + + let migrator = Migrator::new(Path::new("tests/mssql/migrations_simple")).await?; + + // run migration + migrator.run(&mut conn).await?; + + // check outcome + let res: String = conn + .fetch_one("SELECT some_payload FROM migrations_simple_test") + .await? + .get(0); + assert_eq!(res, "110_suffix"); + + // running it a 2nd time should still work + migrator.run(&mut conn).await?; + + Ok(()) +} + +#[sqlx::test(migrations = false)] +async fn reversible(mut conn: PoolConnection) -> anyhow::Result<()> { + clean_up(&mut conn).await?; + + let migrator = Migrator::new(Path::new("tests/mssql/migrations_reversible")).await?; + + // run migration + migrator.run(&mut conn).await?; + + // check outcome + let res: i64 = conn + .fetch_one("SELECT some_payload FROM migrations_reversible_test") + .await? + .get(0); + assert_eq!(res, 101); + + // roll back nothing (last version) + migrator.undo(&mut conn, 20220721125033).await?; + + // check outcome + let res: i64 = conn + .fetch_one("SELECT some_payload FROM migrations_reversible_test") + .await? + .get(0); + assert_eq!(res, 101); + + // roll back one version + migrator.undo(&mut conn, 20220721124650).await?; + + // check outcome + let res: i64 = conn + .fetch_one("SELECT some_payload FROM migrations_reversible_test") + .await? + .get(0); + assert_eq!(res, 100); + + Ok(()) +} + +/// Ensure that we have a clean initial state. +async fn clean_up(conn: &mut MssqlConnection) -> anyhow::Result<()> { + conn.execute( + "IF OBJECT_ID('migrations_simple_test', 'U') IS NOT NULL DROP TABLE migrations_simple_test", + ) + .await + .ok(); + conn.execute("IF OBJECT_ID('migrations_reversible_test', 'U') IS NOT NULL DROP TABLE migrations_reversible_test") + .await + .ok(); + conn.execute("IF OBJECT_ID('_sqlx_migrations', 'U') IS NOT NULL DROP TABLE _sqlx_migrations") + .await + .ok(); + + Ok(()) +} diff --git a/tests/mssql/migrations/1_user.sql b/tests/mssql/migrations/1_user.sql new file mode 100644 index 0000000000..ed0bb63cdf --- /dev/null +++ b/tests/mssql/migrations/1_user.sql @@ -0,0 +1,4 @@ +CREATE TABLE [user] ( + user_id INT NOT NULL IDENTITY(1,1) PRIMARY KEY, + username NVARCHAR(16) NOT NULL UNIQUE +); diff --git a/tests/mssql/migrations/2_post.sql b/tests/mssql/migrations/2_post.sql new file mode 100644 index 0000000000..cbdd07cdd4 --- /dev/null +++ b/tests/mssql/migrations/2_post.sql @@ -0,0 +1,7 @@ +CREATE TABLE post ( + post_id INT NOT NULL IDENTITY(1,1) PRIMARY KEY, + user_id INT NOT NULL REFERENCES [user](user_id), + content NVARCHAR(MAX) NOT NULL, + created_at DATETIME2 NOT NULL DEFAULT SYSUTCDATETIME() +); +CREATE INDEX post_created_at ON post (created_at DESC); diff --git a/tests/mssql/migrations/3_comment.sql b/tests/mssql/migrations/3_comment.sql new file mode 100644 index 0000000000..8f168a2e3d --- /dev/null +++ b/tests/mssql/migrations/3_comment.sql @@ -0,0 +1,8 @@ +CREATE TABLE comment ( + comment_id INT NOT NULL PRIMARY KEY, + post_id INT NOT NULL REFERENCES post(post_id), + user_id INT NOT NULL REFERENCES [user](user_id), + content NVARCHAR(MAX) NOT NULL, + created_at DATETIME2 NOT NULL DEFAULT SYSUTCDATETIME() +); +CREATE INDEX comment_created_at ON comment (created_at DESC); diff --git a/tests/mssql/migrations_reversible/20220721124650_add_table.down.sql b/tests/mssql/migrations_reversible/20220721124650_add_table.down.sql new file mode 100644 index 0000000000..5505859725 --- /dev/null +++ b/tests/mssql/migrations_reversible/20220721124650_add_table.down.sql @@ -0,0 +1 @@ +DROP TABLE migrations_reversible_test; diff --git a/tests/mssql/migrations_reversible/20220721124650_add_table.up.sql b/tests/mssql/migrations_reversible/20220721124650_add_table.up.sql new file mode 100644 index 0000000000..9dfc757954 --- /dev/null +++ b/tests/mssql/migrations_reversible/20220721124650_add_table.up.sql @@ -0,0 +1,7 @@ +CREATE TABLE migrations_reversible_test ( + some_id BIGINT NOT NULL PRIMARY KEY, + some_payload BIGINT NOT NUll +); + +INSERT INTO migrations_reversible_test (some_id, some_payload) +VALUES (1, 100); diff --git a/tests/mssql/migrations_reversible/20220721125033_modify_column.down.sql b/tests/mssql/migrations_reversible/20220721125033_modify_column.down.sql new file mode 100644 index 0000000000..3f71737b8c --- /dev/null +++ b/tests/mssql/migrations_reversible/20220721125033_modify_column.down.sql @@ -0,0 +1,2 @@ +UPDATE migrations_reversible_test +SET some_payload = some_payload - 1; diff --git a/tests/mssql/migrations_reversible/20220721125033_modify_column.up.sql b/tests/mssql/migrations_reversible/20220721125033_modify_column.up.sql new file mode 100644 index 0000000000..bbb176cf41 --- /dev/null +++ b/tests/mssql/migrations_reversible/20220721125033_modify_column.up.sql @@ -0,0 +1,2 @@ +UPDATE migrations_reversible_test +SET some_payload = some_payload + 1; diff --git a/tests/mssql/migrations_simple/20220721115250_add_test_table.sql b/tests/mssql/migrations_simple/20220721115250_add_test_table.sql new file mode 100644 index 0000000000..d5ba291914 --- /dev/null +++ b/tests/mssql/migrations_simple/20220721115250_add_test_table.sql @@ -0,0 +1,7 @@ +CREATE TABLE migrations_simple_test ( + some_id BIGINT NOT NULL PRIMARY KEY, + some_payload BIGINT NOT NUll +); + +INSERT INTO migrations_simple_test (some_id, some_payload) +VALUES (1, 100); diff --git a/tests/mssql/migrations_simple/20220721115524_convert_type.sql b/tests/mssql/migrations_simple/20220721115524_convert_type.sql new file mode 100644 index 0000000000..c437c39d02 --- /dev/null +++ b/tests/mssql/migrations_simple/20220721115524_convert_type.sql @@ -0,0 +1,34 @@ +-- Perform a tricky conversion of the payload. +-- +-- This script will only succeed once and will fail if executed twice. + +-- set up temporary target column +ALTER TABLE migrations_simple_test +ADD some_payload_tmp NVARCHAR(MAX); + +-- perform conversion +-- This will fail if `some_payload` is already a string column due to the addition. +-- We add a suffix after the addition to ensure that the SQL database does not silently cast the string back to an +-- integer. +UPDATE migrations_simple_test +SET some_payload_tmp = CONCAT(CAST((some_payload + 10) AS VARCHAR(3)), '_suffix'); + +-- remove original column including the content +ALTER TABLE migrations_simple_test +DROP COLUMN some_payload; + +-- prepare new payload column (nullable, so we can copy over the data) +ALTER TABLE migrations_simple_test +ADD some_payload NVARCHAR(MAX); + +-- copy new values +UPDATE migrations_simple_test +SET some_payload = some_payload_tmp; + +-- "freeze" column: MSSQL uses sp_rename + re-add or ALTER COLUMN for NOT NULL +ALTER TABLE migrations_simple_test +ALTER COLUMN some_payload NVARCHAR(MAX) NOT NULL; + +-- clean up +ALTER TABLE migrations_simple_test +DROP COLUMN some_payload_tmp; diff --git a/tests/mssql/mssql.rs b/tests/mssql/mssql.rs index 0986ef1bbd..f2e1627515 100644 --- a/tests/mssql/mssql.rs +++ b/tests/mssql/mssql.rs @@ -1,7 +1,8 @@ use futures_util::TryStreamExt; +use sqlx::mssql::MssqlRow; use sqlx::mssql::{Mssql, MssqlPoolOptions}; -use sqlx::{Column, Connection, Executor, MssqlConnection, Row, Statement, TypeInfo}; -use sqlx_core::mssql::MssqlRow; +use sqlx::mssql::{MssqlAdvisoryLock, MssqlIsolationLevel}; +use sqlx::{Column, Connection, Executor, MssqlConnection, Row, SqlSafeStr, Statement, TypeInfo}; use sqlx_test::new; use std::sync::atomic::{AtomicI32, Ordering}; use std::time::Duration; @@ -195,9 +196,9 @@ async fn it_can_work_with_transactions() -> anyhow::Result<()> { let mut tx = conn.begin().await?; - sqlx::query("INSERT INTO _sqlx_users_1922 (id) VALUES ($1)") + sqlx::query("INSERT INTO _sqlx_users_1922 (id) VALUES (@p1)") .bind(10_i32) - .execute(&mut tx) + .execute(&mut *tx) .await?; tx.rollback().await?; @@ -214,7 +215,7 @@ async fn it_can_work_with_transactions() -> anyhow::Result<()> { sqlx::query("INSERT INTO _sqlx_users_1922 (id) VALUES (@p1)") .bind(10_i32) - .execute(&mut tx) + .execute(&mut *tx) .await?; tx.commit().await?; @@ -232,7 +233,7 @@ async fn it_can_work_with_transactions() -> anyhow::Result<()> { sqlx::query("INSERT INTO _sqlx_users_1922 (id) VALUES (@p1)") .bind(20_i32) - .execute(&mut tx) + .execute(&mut *tx) .await?; } @@ -262,7 +263,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // insert a user sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES (@p1)") .bind(50_i32) - .execute(&mut tx) + .execute(&mut *tx) .await?; // begin once more @@ -271,7 +272,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // insert another user sqlx::query("INSERT INTO _sqlx_users_2523 (id) VALUES (@p1)") .bind(10_i32) - .execute(&mut tx2) + .execute(&mut *tx2) .await?; // never mind, rollback @@ -279,7 +280,7 @@ async fn it_can_work_with_nested_transactions() -> anyhow::Result<()> { // did we really? let (count,): (i32,) = sqlx::query_as("SELECT COUNT(*) FROM _sqlx_users_2523") - .fetch_one(&mut tx) + .fetch_one(&mut *tx) .await?; assert_eq!(count, 1); @@ -305,10 +306,12 @@ async fn it_can_prepare_then_execute() -> anyhow::Result<()> { let tweet_id: i64 = sqlx::query_scalar( "INSERT INTO tweet ( id, text ) OUTPUT INSERTED.id VALUES ( 50, 'Hello, World' )", ) - .fetch_one(&mut tx) + .fetch_one(&mut *tx) .await?; - let statement = tx.prepare("SELECT * FROM tweet WHERE id = @p1").await?; + let statement = tx + .prepare("SELECT * FROM tweet WHERE id = @p1".into_sql_str()) + .await?; assert_eq!(statement.column(0).name(), "id"); assert_eq!(statement.column(1).name(), "text"); @@ -320,7 +323,7 @@ async fn it_can_prepare_then_execute() -> anyhow::Result<()> { assert_eq!(statement.column(2).type_info().name(), "TINYINT"); assert_eq!(statement.column(3).type_info().name(), "BIGINT"); - let row = statement.query().bind(tweet_id).fetch_one(&mut tx).await?; + let row = statement.query().bind(tweet_id).fetch_one(&mut *tx).await?; let tweet_text: String = row.try_get("text")?; assert_eq!(tweet_text, "Hello, World"); @@ -359,7 +362,7 @@ async fn test_pool_callbacks() -> anyhow::Result<()> { CREATE TABLE #conn_stats( id int primary key, before_acquire_calls int default 0, - after_release_calls int default 0 + after_release_calls int default 0 ); INSERT INTO #conn_stats(id) VALUES ({}); "#, @@ -367,7 +370,7 @@ async fn test_pool_callbacks() -> anyhow::Result<()> { id ); - conn.execute(&statement[..]).await?; + conn.execute(sqlx::AssertSqlSafe(statement)).await?; Ok(()) }) }) @@ -380,7 +383,7 @@ async fn test_pool_callbacks() -> anyhow::Result<()> { // MSSQL doesn't support UPDATE ... RETURNING either sqlx::query( r#" - UPDATE #conn_stats + UPDATE #conn_stats SET before_acquire_calls = before_acquire_calls + 1 "#, ) @@ -404,7 +407,7 @@ async fn test_pool_callbacks() -> anyhow::Result<()> { Box::pin(async move { sqlx::query( r#" - UPDATE #conn_stats + UPDATE #conn_stats SET after_release_calls = after_release_calls + 1 "#, ) @@ -459,3 +462,217 @@ async fn test_pool_callbacks() -> anyhow::Result<()> { Ok(()) } + +#[sqlx_macros::test] +async fn it_can_query_multiple_result_sets() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // A batch that produces two result sets + let results = conn + .run("SELECT 1 AS a; SELECT 2 AS b, 3 AS c;", None) + .await?; + + // First result set: one row with column "a" + let mut rows_first = Vec::new(); + let mut rows_second = Vec::new(); + let mut result_count = 0; + + for item in &results { + match item { + either::Either::Left(_) => { + result_count += 1; + } + either::Either::Right(row) => { + if result_count == 0 { + rows_first.push(row); + } else { + rows_second.push(row); + } + } + } + } + + assert_eq!(rows_first.len(), 1); + assert_eq!(rows_first[0].try_get::("a")?, 1); + + assert_eq!(rows_second.len(), 1); + assert_eq!(rows_second[0].try_get::("b")?, 2); + assert_eq!(rows_second[0].try_get::("c")?, 3); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_can_inspect_column_metadata() -> anyhow::Result<()> { + let mut conn = new::().await?; + + let statement = conn + .prepare("SELECT CAST(1 AS INT) AS int_col, CAST('hello' AS NVARCHAR(50)) AS str_col, CAST(NULL AS BIGINT) AS nullable_col".into_sql_str()) + .await?; + + assert_eq!(statement.column(0).name(), "int_col"); + assert_eq!(statement.column(1).name(), "str_col"); + assert_eq!(statement.column(2).name(), "nullable_col"); + + assert_eq!(statement.column(0).type_info().name(), "INT"); + // sp_describe_first_result_set returns "NVARCHAR(50)" for typed NVARCHAR + assert!(statement + .column(1) + .type_info() + .name() + .starts_with("NVARCHAR")); + assert_eq!(statement.column(2).type_info().name(), "BIGINT"); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_can_reuse_connection_after_error() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // Cause an error + let res: Result<_, sqlx::Error> = sqlx::query("SELECT * FROM this_table_does_not_exist_12345") + .execute(&mut conn) + .await; + assert!(res.is_err()); + + // Connection should still be usable + let val: (i32,) = sqlx::query_as("SELECT 42").fetch_one(&mut conn).await?; + assert_eq!(val.0, 42); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_can_bind_many_parameters() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // Build a query with 100 parameters: SELECT @p1 + @p2 + ... + @p100 + let param_refs: Vec = (1..=100).map(|i| format!("@p{i}")).collect(); + let sql = format!("SELECT {}", param_refs.join(" + ")); + + let mut query = sqlx::query_scalar::<_, i32>(&sql); + for _ in 0..100 { + query = query.bind(1_i32); + } + + let result: i32 = query.fetch_one(&mut conn).await?; + assert_eq!(result, 100); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_handles_special_characters_in_strings() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // Single quotes + let val: (String,) = sqlx::query_as("SELECT @p1") + .bind("it's a test") + .fetch_one(&mut conn) + .await?; + assert_eq!(val.0, "it's a test"); + + // Backslashes + let val: (String,) = sqlx::query_as("SELECT @p1") + .bind(r"C:\Users\test") + .fetch_one(&mut conn) + .await?; + assert_eq!(val.0, r"C:\Users\test"); + + // Unicode + let val: (String,) = sqlx::query_as("SELECT @p1") + .bind("\u{1F600} hello \u{4E16}\u{754C}") + .fetch_one(&mut conn) + .await?; + assert_eq!(val.0, "\u{1F600} hello \u{4E16}\u{754C}"); + + // Newlines and tabs + let val: (String,) = sqlx::query_as("SELECT @p1") + .bind("line1\nline2\ttab") + .fetch_one(&mut conn) + .await?; + assert_eq!(val.0, "line1\nline2\ttab"); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_can_use_transaction_isolation_levels() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // Start a transaction with READ UNCOMMITTED isolation + let mut tx = conn + .begin_with_isolation(MssqlIsolationLevel::ReadUncommitted) + .await?; + + // Verify we can do work inside the transaction + let val: (i32,) = sqlx::query_as("SELECT 1").fetch_one(&mut *tx).await?; + assert_eq!(val.0, 1); + + tx.commit().await?; + + // Start a transaction with SERIALIZABLE isolation + let mut tx = conn + .begin_with_isolation(MssqlIsolationLevel::Serializable) + .await?; + + let val: (i32,) = sqlx::query_as("SELECT 2").fetch_one(&mut *tx).await?; + assert_eq!(val.0, 2); + + tx.rollback().await?; + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_can_use_advisory_lock_guard() -> anyhow::Result<()> { + let mut conn = new::().await?; + + // Need a transaction context for sp_getapplock with Session owner + // Actually, Session-scoped locks work outside transactions too. + let lock = MssqlAdvisoryLock::new("sqlx_test_lock_guard"); + + // Acquire the lock via the RAII guard + let mut guard = lock.acquire_guard(&mut conn).await?; + + // Use the connection through the guard + let val: (i32,) = sqlx::query_as("SELECT 99").fetch_one(&mut *guard).await?; + assert_eq!(val.0, 99); + + // Release the lock and get the connection back + let conn = guard.release_now().await?; + + // Verify we can still use the connection + let val: (i32,) = sqlx::query_as("SELECT 100").fetch_one(conn).await?; + assert_eq!(val.0, 100); + + Ok(()) +} + +#[sqlx_macros::test] +async fn it_can_try_acquire_advisory_lock() -> anyhow::Result<()> { + let mut conn1 = new::().await?; + let mut conn2 = new::().await?; + + let lock = MssqlAdvisoryLock::new("sqlx_test_try_lock"); + + // Acquire on conn1 + lock.acquire(&mut conn1).await?; + + // Try to acquire on conn2 — should fail (return false) since it's exclusive + let acquired = lock.try_acquire(&mut conn2).await?; + assert!(!acquired); + + // Release on conn1 + let released = lock.release(&mut conn1).await?; + assert!(released); + + // Now conn2 should be able to acquire + let acquired = lock.try_acquire(&mut conn2).await?; + assert!(acquired); + + lock.release(&mut conn2).await?; + + Ok(()) +} diff --git a/tests/mssql/query_builder.rs b/tests/mssql/query_builder.rs new file mode 100644 index 0000000000..2e938d4847 --- /dev/null +++ b/tests/mssql/query_builder.rs @@ -0,0 +1,108 @@ +use sqlx::mssql::Mssql; +use sqlx::query_builder::QueryBuilder; +use sqlx::Execute; + +#[test] +fn test_new() { + let qb: QueryBuilder = QueryBuilder::new("SELECT * FROM users"); + assert_eq!(qb.sql(), "SELECT * FROM users"); +} + +#[test] +fn test_push() { + let mut qb: QueryBuilder = QueryBuilder::new("SELECT * FROM users"); + let second_line = " WHERE last_name LIKE '[A-N]%';"; + qb.push(second_line); + + assert_eq!( + qb.sql(), + "SELECT * FROM users WHERE last_name LIKE '[A-N]%';".to_string(), + ); +} + +#[test] +#[should_panic] +fn test_push_panics_after_build_without_reset() { + let mut qb: QueryBuilder = QueryBuilder::new("SELECT * FROM users;"); + + let _query = qb.build(); + + qb.push("SELECT * FROM users;"); +} + +#[test] +fn test_push_bind() { + let mut qb: QueryBuilder = QueryBuilder::new("SELECT * FROM users WHERE id = "); + + qb.push_bind(42i32) + .push(" OR membership_level = ") + .push_bind(3i32); + + assert_eq!( + qb.sql(), + "SELECT * FROM users WHERE id = @p1 OR membership_level = @p2" + ); +} + +#[test] +fn test_build() { + let mut qb: QueryBuilder = QueryBuilder::new("SELECT * FROM users"); + + qb.push(" WHERE id = ").push_bind(42i32); + let query = qb.build(); + + assert!(Execute::persistent(&query)); + assert_eq!(query.sql(), "SELECT * FROM users WHERE id = @p1"); +} + +#[test] +fn test_reset() { + let mut qb: QueryBuilder = QueryBuilder::new(""); + + { + let _query = qb + .push("SELECT * FROM users WHERE id = ") + .push_bind(42i32) + .build(); + } + + qb.reset(); + + assert_eq!(qb.sql(), ""); +} + +#[test] +fn test_query_builder_reuse() { + let mut qb: QueryBuilder = QueryBuilder::new(""); + + let _query = qb + .push("SELECT * FROM users WHERE id = ") + .push_bind(42i32) + .build(); + + qb.reset(); + + let query = qb.push("SELECT * FROM users WHERE id = 99").build(); + + assert_eq!(query.sql(), "SELECT * FROM users WHERE id = 99"); +} + +#[test] +fn test_query_builder_with_args() { + let mut qb: QueryBuilder = QueryBuilder::new(""); + + let mut query = qb + .push("SELECT * FROM users WHERE id = ") + .push_bind(42i32) + .build(); + + let args = query.take_arguments().unwrap().unwrap(); + + let mut qb: QueryBuilder = QueryBuilder::with_arguments(query.sql().as_str(), args); + let query = qb.push(" OR membership_level = ").push_bind(3i32).build(); + + assert_eq!( + query.sql(), + "SELECT * FROM users WHERE id = @p1 OR membership_level = @p2" + ); +} diff --git a/tests/mssql/setup.sql b/tests/mssql/setup.sql index a033227b75..4a78cccfa7 100644 --- a/tests/mssql/setup.sql +++ b/tests/mssql/setup.sql @@ -18,3 +18,28 @@ IF OBJECT_ID('tweet') IS NULL ); END; GO + +IF OBJECT_ID('tweet_reply') IS NULL + BEGIN + CREATE TABLE tweet_reply + ( + id BIGINT NOT NULL IDENTITY(1,1) PRIMARY KEY, + tweet_id BIGINT NOT NULL, + text NVARCHAR(4000) NOT NULL, + owner_id BIGINT, + CONSTRAINT tweet_id_fk FOREIGN KEY (tweet_id) REFERENCES tweet(id) + ); + END; +GO + +IF OBJECT_ID('products') IS NULL + BEGIN + CREATE TABLE products + ( + product_no INT, + name NVARCHAR(200), + price DECIMAL(10,2), + CONSTRAINT chk_price CHECK (price > 0) + ); + END; +GO diff --git a/tests/mssql/test-attr.rs b/tests/mssql/test-attr.rs new file mode 100644 index 0000000000..81b3d62660 --- /dev/null +++ b/tests/mssql/test-attr.rs @@ -0,0 +1,92 @@ +// The no-arg variant is covered by other tests already. + +use sqlx::MssqlPool; + +const MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!("tests/mssql/migrations"); + +#[sqlx::test] +async fn it_gets_a_pool(pool: MssqlPool) -> sqlx::Result<()> { + let mut conn = pool.acquire().await?; + + let db_name: String = sqlx::query_scalar("SELECT DB_NAME()") + .fetch_one(&mut *conn) + .await?; + + assert!(db_name.starts_with("_sqlx_test_"), "db_name: {:?}", db_name); + + Ok(()) +} + +// This should apply migrations and then `fixtures/users.sql` +#[sqlx::test(migrations = "tests/mssql/migrations", fixtures("users"))] +async fn it_gets_users(pool: MssqlPool) -> sqlx::Result<()> { + let usernames: Vec = + sqlx::query_scalar(r#"SELECT username FROM [user] ORDER BY username"#) + .fetch_all(&pool) + .await?; + + assert_eq!(usernames, ["alice", "bob"]); + + let post_count: i32 = sqlx::query_scalar("SELECT COUNT(*) FROM post") + .fetch_one(&pool) + .await?; + + assert_eq!(post_count, 0); + + let comment_count: i32 = sqlx::query_scalar("SELECT COUNT(*) FROM comment") + .fetch_one(&pool) + .await?; + + assert_eq!(comment_count, 0); + + Ok(()) +} + +#[sqlx::test(migrations = "tests/mssql/migrations", fixtures("users", "posts"))] +async fn it_gets_posts(pool: MssqlPool) -> sqlx::Result<()> { + let post_contents: Vec = + sqlx::query_scalar("SELECT content FROM post ORDER BY created_at") + .fetch_all(&pool) + .await?; + + assert_eq!( + post_contents, + [ + "This new computer is lightning-fast!", + "@alice is a haxxor :(" + ] + ); + + let comment_count: i32 = sqlx::query_scalar("SELECT COUNT(*) FROM comment") + .fetch_one(&pool) + .await?; + + assert_eq!(comment_count, 0); + + Ok(()) +} + +// Try `migrator` +#[sqlx::test(migrator = "MIGRATOR", fixtures("users", "posts", "comments"))] +async fn it_gets_comments(pool: MssqlPool) -> sqlx::Result<()> { + let post_1_comments: Vec = + sqlx::query_scalar("SELECT content FROM comment WHERE post_id = @p1 ORDER BY created_at") + .bind(&1) + .fetch_all(&pool) + .await?; + + assert_eq!( + post_1_comments, + ["lol bet ur still bad, 1v1 me", "you're on!"] + ); + + let post_2_comments: Vec = + sqlx::query_scalar("SELECT content FROM comment WHERE post_id = @p1 ORDER BY created_at") + .bind(&2) + .fetch_all(&pool) + .await?; + + assert_eq!(post_2_comments, ["lol you're just mad you lost :P"]); + + Ok(()) +} diff --git a/tests/mssql/types.rs b/tests/mssql/types.rs index 3e9ae7395b..5818533f9e 100644 --- a/tests/mssql/types.rs +++ b/tests/mssql/types.rs @@ -1,3 +1,5 @@ +extern crate time_ as time; + use sqlx::mssql::Mssql; use sqlx_test::test_type; @@ -18,11 +20,26 @@ test_type!(i8( "CAST(0 AS TINYINT)" == 0_i8 )); -test_type!(i16(Mssql, "CAST(21415 AS SMALLINT)" == 21415_i16)); +test_type!(i16( + Mssql, + "CAST(21415 AS SMALLINT)" == 21415_i16, + "CAST(-32768 AS SMALLINT)" == i16::MIN, + "CAST(32767 AS SMALLINT)" == i16::MAX, +)); -test_type!(i32(Mssql, "CAST(2141512 AS INT)" == 2141512_i32)); +test_type!(i32( + Mssql, + "CAST(2141512 AS INT)" == 2141512_i32, + "CAST(-2147483648 AS INT)" == i32::MIN, + "CAST(2147483647 AS INT)" == i32::MAX, +)); -test_type!(i64(Mssql, "CAST(32324324432 AS BIGINT)" == 32324324432_i64)); +test_type!(i64( + Mssql, + "CAST(32324324432 AS BIGINT)" == 32324324432_i64, + "CAST(-9223372036854775808 AS BIGINT)" == i64::MIN, + "CAST(9223372036854775807 AS BIGINT)" == i64::MAX, +)); test_type!(f32( Mssql, @@ -34,6 +51,26 @@ test_type!(f64( "CAST(939399419.1225182 AS FLOAT)" == 939399419.1225182_f64 )); +test_type!(f64_money( + Mssql, + "CAST(922337203685477.5807 AS MONEY)" == 922337203685477.5807_f64, + "CAST(0 AS MONEY)" == 0.0_f64, + "CAST(-1234.5678 AS MONEY)" == -1234.5678_f64, +)); + +test_type!(f64_smallmoney( + Mssql, + "CAST(214748.3647 AS SMALLMONEY)" == 214748.3647_f64, + "CAST(0 AS SMALLMONEY)" == 0.0_f64, + "CAST(-1234.5678 AS SMALLMONEY)" == -1234.5678_f64, +)); + +#[cfg(feature = "rust_decimal")] +test_type!(rust_decimal_smallmoney(Mssql, + "CAST(214748.3647 AS SMALLMONEY)" == sqlx::types::Decimal::new(2147483647, 4), + "CAST(0 AS SMALLMONEY)" == sqlx::types::Decimal::ZERO, +)); + test_type!(str_nvarchar(Mssql, "CAST('this is foo' as NVARCHAR)" == "this is foo", )); @@ -48,3 +85,242 @@ test_type!(bool( "CAST(1 as BIT)" == true, "CAST(0 as BIT)" == false )); + +test_type!(bytes>(Mssql, + "CAST(0xDEADBEEF AS VARBINARY(MAX))" + == vec![0xDE_u8, 0xAD, 0xBE, 0xEF], + "CAST(0x AS VARBINARY(MAX))" + == Vec::::new(), + "CAST(0x0000000000000000 AS VARBINARY(MAX))" + == vec![0_u8; 8], +)); + +test_type!(bytes_single>(Mssql, + "CAST(0xFF AS VARBINARY(MAX))" == vec![0xFF_u8], +)); + +test_type!(bytes_large>(Mssql, + "CAST(REPLICATE(CAST(0xAB AS VARBINARY(MAX)), 10000) AS VARBINARY(MAX))" + == vec![0xAB_u8; 10000], +)); + +test_type!(str_nchar(Mssql, + "CAST('hello' AS NCHAR(5))" == "hello", +)); + +test_type!(str_varchar(Mssql, + "CAST('hello varchar' AS VARCHAR(50))" == "hello varchar", +)); + +test_type!(str_unicode(Mssql, + "CAST(N'\u{1F600}\u{1F680}\u{2764}' AS NVARCHAR(MAX))" == "\u{1F600}\u{1F680}\u{2764}", + "CAST(N'\u{4F60}\u{597D}\u{4E16}\u{754C}' AS NVARCHAR(MAX))" == "\u{4F60}\u{597D}\u{4E16}\u{754C}", +)); + +test_type!(str_nvarchar_max_large(Mssql, + "REPLICATE(CAST(N'x' AS NVARCHAR(MAX)), 10000)" + == "x".repeat(10000), +)); + +test_type!(null_bool>(Mssql, + "CAST(NULL AS BIT)" == None::, +)); + +test_type!(null_string>(Mssql, + "CAST(NULL AS NVARCHAR(100))" == None::, +)); + +test_type!(null_i64>(Mssql, + "CAST(NULL AS BIGINT)" == None::, +)); + +test_type!(null_f64>(Mssql, + "CAST(NULL AS FLOAT)" == None::, +)); + +test_type!(null_bytes>>(Mssql, + "CAST(NULL AS VARBINARY(MAX))" == None::>, +)); + +test_type!(xml(Mssql, + "CAST('hello' AS XML)" + == sqlx::mssql::MssqlXml::from("hello".to_owned()), +)); + +#[cfg(feature = "uuid")] +test_type!(uuid(Mssql, + "CAST('00000000-0000-0000-0000-000000000000' AS UNIQUEIDENTIFIER)" + == sqlx::types::Uuid::nil(), + "CAST('936da01f-9abd-4d9d-80c7-02af85c822a8' AS UNIQUEIDENTIFIER)" + == sqlx::types::Uuid::parse_str("936DA01F-9ABD-4D9D-80C7-02AF85C822A8").unwrap(), +)); + +#[cfg(feature = "chrono")] +mod chrono { + use sqlx::mssql::Mssql; + use sqlx_test::test_type; + + type NaiveDate = sqlx::types::chrono::NaiveDate; + type NaiveTime = sqlx::types::chrono::NaiveTime; + type NaiveDateTime = sqlx::types::chrono::NaiveDateTime; + type DateTimeUtc = sqlx::types::chrono::DateTime; + type DateTimeFixed = sqlx::types::chrono::DateTime; + type FixedOffset = sqlx::types::chrono::FixedOffset; + + test_type!(chrono_naive_date(Mssql, + "CAST('2001-01-05' AS DATE)" + == NaiveDate::from_ymd_opt(2001, 1, 5).unwrap(), + "CAST('2050-11-23' AS DATE)" + == NaiveDate::from_ymd_opt(2050, 11, 23).unwrap(), + )); + + test_type!(chrono_naive_time(Mssql, + "CAST('05:10:20' AS TIME)" + == NaiveTime::from_hms_opt(5, 10, 20).unwrap(), + "CAST('00:00:00' AS TIME)" + == NaiveTime::from_hms_opt(0, 0, 0).unwrap(), + )); + + test_type!(chrono_naive_date_time(Mssql, + "CAST('2019-01-02 05:10:20' AS DATETIME2)" + == NaiveDateTime::new( + NaiveDate::from_ymd_opt(2019, 1, 2).unwrap(), + NaiveTime::from_hms_opt(5, 10, 20).unwrap(), + ), + )); + + test_type!(chrono_date_time_utc(Mssql, + "CAST('2019-01-02 05:10:20.000 +00:00' AS DATETIMEOFFSET)" + == NaiveDate::from_ymd_opt(2019, 1, 2) + .unwrap() + .and_hms_opt(5, 10, 20) + .unwrap() + .and_utc(), + )); + + test_type!(chrono_date_time_fixed_utc(Mssql, + "CAST('2019-01-02 05:10:20.000 +00:00' AS DATETIMEOFFSET)" + == NaiveDate::from_ymd_opt(2019, 1, 2) + .unwrap() + .and_hms_opt(5, 10, 20) + .unwrap() + .and_local_timezone(FixedOffset::east_opt(0).unwrap()) + .unwrap(), + )); + + test_type!(chrono_date_time_fixed_positive(Mssql, + "CAST('2024-06-15 14:30:00.000 +05:30' AS DATETIMEOFFSET)" + == NaiveDate::from_ymd_opt(2024, 6, 15) + .unwrap() + .and_hms_opt(14, 30, 0) + .unwrap() + .and_local_timezone(FixedOffset::east_opt(5 * 3600 + 30 * 60).unwrap()) + .unwrap(), + )); + + test_type!(chrono_date_time_fixed_negative(Mssql, + "CAST('2024-12-25 08:00:00.000 -08:00' AS DATETIMEOFFSET)" + == NaiveDate::from_ymd_opt(2024, 12, 25) + .unwrap() + .and_hms_opt(8, 0, 0) + .unwrap() + .and_local_timezone(FixedOffset::west_opt(8 * 3600).unwrap()) + .unwrap(), + )); + + // Verify DateTime can decode from DATETIMEOFFSET with non-zero offset + // (the value should be converted to UTC) + test_type!(chrono_date_time_utc_from_offset(Mssql, + "CAST('2024-06-15 14:30:00.000 +05:30' AS DATETIMEOFFSET)" + == NaiveDate::from_ymd_opt(2024, 6, 15) + .unwrap() + .and_hms_opt(9, 0, 0) + .unwrap() + .and_utc(), + )); +} + +#[cfg(feature = "time")] +mod time_tests { + use sqlx::mssql::Mssql; + use sqlx_test::test_type; + + type TimeDate = sqlx::types::time::Date; + type TimeTime = sqlx::types::time::Time; + type TimePrimitiveDateTime = sqlx::types::time::PrimitiveDateTime; + type TimeOffsetDateTime = sqlx::types::time::OffsetDateTime; + + use time::macros::{date, datetime, time as time_macro}; + + test_type!(time_date(Mssql, + "CAST('2001-01-05' AS DATE)" + == date!(2001-01-05), + "CAST('2050-11-23' AS DATE)" + == date!(2050-11-23), + )); + + test_type!(time_time(Mssql, + "CAST('05:10:20' AS TIME)" + == time_macro!(05:10:20), + "CAST('00:00:00' AS TIME)" + == time_macro!(00:00:00), + )); + + test_type!(time_primitive_date_time(Mssql, + "CAST('2019-01-02 05:10:20' AS DATETIME2)" + == datetime!(2019-01-02 05:10:20), + )); + + test_type!(time_offset_date_time(Mssql, + "CAST('2019-01-02 05:10:20.000 +00:00' AS DATETIMEOFFSET)" + == datetime!(2019-01-02 05:10:20 UTC), + )); +} + +#[cfg(feature = "rust_decimal")] +test_type!(rust_decimal(Mssql, + "CAST('0' AS DECIMAL(10,2))" == sqlx::types::Decimal::ZERO, + "CAST('1.23' AS DECIMAL(10,2))" == sqlx::types::Decimal::new(123, 2), + "CAST('-1.23' AS DECIMAL(10,2))" == sqlx::types::Decimal::new(-123, 2), +)); + +#[cfg(feature = "rust_decimal")] +test_type!(rust_decimal_money(Mssql, + "CAST(1234.5678 AS MONEY)" == sqlx::types::Decimal::new(12345678, 4), + "CAST(0 AS MONEY)" == sqlx::types::Decimal::ZERO, +)); + +#[cfg(feature = "bigdecimal")] +test_type!(bigdecimal(Mssql, + "CAST('0' AS DECIMAL(10,2))" == "0.00".parse::().unwrap(), + "CAST('1.23' AS DECIMAL(10,2))" == "1.23".parse::().unwrap(), + "CAST('-1.23' AS DECIMAL(10,2))" == "-1.23".parse::().unwrap(), +)); + +#[cfg(feature = "bigdecimal")] +test_type!(bigdecimal_money(Mssql, + "CAST(1234.5678 AS MONEY)" == "1234.5678".parse::().unwrap(), + "CAST(0 AS MONEY)" == "0".parse::().unwrap(), +)); + +#[cfg(feature = "json")] +mod json_tests { + use sqlx::mssql::Mssql; + use sqlx::types::Json; + use sqlx_test::test_type; + + #[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq)] + struct Friend { + name: String, + age: u32, + } + + test_type!(json>(Mssql, + "CAST('{\"name\":\"Joe\",\"age\":33}' AS NVARCHAR(MAX))" + == Json(Friend { name: "Joe".to_string(), age: 33 }), + )); + + test_type!(json_value(Mssql, + "CAST('null' AS NVARCHAR(MAX))" == serde_json::Value::Null, + )); +}