diff --git a/Cargo.lock b/Cargo.lock index 7a2979fc80..7df78920e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -272,9 +272,9 @@ dependencies = [ [[package]] name = "async-lock" -version = "3.4.0" +version = "3.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff6e472cdea888a4bd64f342f09b3f50e1886d32afe8df3d663c01140b811b18" +checksum = "5fd03604047cee9b6ce9de9f70c6cd540a0520c813cbd49bae61f33ab80ed1dc" dependencies = [ "event-listener 5.4.0", "event-listener-strategy", @@ -1021,26 +1021,22 @@ checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" [[package]] name = "criterion" -version = "0.5.1" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +checksum = "e1c047a62b0cc3e145fa84415a3191f628e980b194c2755aa12300a4e6cbd928" dependencies = [ "anes", "cast", "ciborium", "clap", "criterion-plot", - "futures", - "is-terminal", - "itertools 0.10.5", + "itertools 0.13.0", "num-traits", - "once_cell", "oorandom", "plotters", "rayon", "regex", "serde", - "serde_derive", "serde_json", "tinytemplate", "tokio", @@ -1049,12 +1045,12 @@ dependencies = [ [[package]] name = "criterion-plot" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +checksum = "9b1bcc0dc7dfae599d84ad0b1a55f80cde8af3725da8313b528da95ef783e338" dependencies = [ "cast", - "itertools 0.10.5", + "itertools 0.13.0", ] [[package]] @@ -1246,6 +1242,19 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" +[[package]] +name = "ease-off" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20e90ae5e739d99dc0406f9a4e2307a999625e2414d2ecc4fbb4ded8a3945f77" +dependencies = [ + "async-io", + "pin-project", + "rand", + "thiserror 1.0.69", + "tokio", +] + [[package]] name = "either" version = "1.15.0" @@ -1383,7 +1392,7 @@ checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095" dependencies = [ "futures-core", "futures-sink", - "spin", + "spin 0.9.8", ] [[package]] @@ -1446,20 +1455,6 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" -[[package]] -name = "futures" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" -dependencies = [ - "futures-channel", - "futures-core", - "futures-io", - "futures-sink", - "futures-task", - "futures-util", -] - [[package]] name = "futures-channel" version = "0.3.31" @@ -1990,17 +1985,6 @@ version = "0.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf370abdafd54d13e54a620e8c3e1145f28e46cc9d704bc6d94414559df41763" -[[package]] -name = "is-terminal" -version = "0.4.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9" -dependencies = [ - "hermit-abi", - "libc", - "windows-sys 0.59.0", -] - [[package]] name = "is_terminal_polyfill" version = "1.70.1" @@ -2099,7 +2083,7 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" dependencies = [ - "spin", + "spin 0.9.8", ] [[package]] @@ -3481,6 +3465,15 @@ dependencies = [ "lock_api", ] +[[package]] +name = "spin" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5fe4ccb98d9c292d56fec89a5e07da7fc4cf0dc11e156b41793132775d3e591" +dependencies = [ + "lock_api", +] + [[package]] name = "spki" version = "0.7.3" @@ -3517,6 +3510,7 @@ dependencies = [ "tempfile", "time", "tokio", + "tracing", "trybuild", "url", ] @@ -3552,6 +3546,7 @@ dependencies = [ "async-fs", "async-global-executor 3.1.0", "async-io", + "async-lock", "async-std", "async-task", "base64 0.22.1", @@ -3563,10 +3558,10 @@ dependencies = [ "chrono", "crc", "crossbeam-queue", + "ease-off", "either", "event-listener 5.4.0", "futures-core", - "futures-intrusive", "futures-io", "futures-util", "hashbrown 0.16.0", @@ -3574,11 +3569,14 @@ dependencies = [ "indexmap 2.10.0", "ipnet", "ipnetwork", + "lock_api", "log", "mac_address", "memchr", "native-tls", "percent-encoding", + "pin-project-lite", + "rand", "rust_decimal", "rustls", "rustls-native-certs", @@ -3587,6 +3585,7 @@ dependencies = [ "sha2", "smallvec", "smol", + "spin 0.10.0", "sqlx", "thiserror 2.0.17", "time", @@ -4411,9 +4410,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" -version = "0.1.41" +version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ "log", "pin-project-lite", @@ -4423,9 +4422,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.30" +version = "0.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", @@ -4434,9 +4433,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.34" +version = "0.1.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" dependencies = [ "once_cell", "valuable", diff --git a/Cargo.toml b/Cargo.toml index 00d5d656c1..3c3db27d5d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -189,8 +189,10 @@ uuid = "1.1.2" # Common utility crates cfg-if = "1.0.0" -dotenvy = { version = "0.15.0", default-features = false } thiserror = { version = "2.0.17", default-features = false, features = ["std"] } +dotenvy = { version = "0.15.7", default-features = false } +ease-off = "0.1.6" +rand = "0.8.5" # Runtimes [workspace.dependencies.async-global-executor] @@ -222,7 +224,6 @@ sqlx-sqlite = { workspace = true, optional = true } anyhow = "1.0.52" time_ = { version = "0.3.2", package = "time" } futures-util = { version = "0.3.19", default-features = false, features = ["alloc"] } -env_logger = "0.11" async-std = { workspace = true, features = ["attributes"] } tokio = { version = "1.15.0", features = ["full"] } dotenvy = "0.15.0" @@ -236,9 +237,12 @@ rand = "0.8.4" rand_xoshiro = "0.6.0" hex = "0.4.3" tempfile = "3.10.1" -criterion = { version = "0.5.1", features = ["async_tokio"] } +criterion = { version = "0.7.0", features = ["async_tokio"] 
} libsqlite3-sys = { version = "0.30.1" } +tracing = "0.1.41" +tracing-subscriber = "0.3.20" + # If this is an unconditional dev-dependency then Cargo will *always* try to build `libsqlite3-sys`, # even when SQLite isn't the intended test target, and fail if the build environment is not set up for compiling C code. [target.'cfg(sqlite_test_sqlcipher)'.dev-dependencies] @@ -454,3 +458,15 @@ required-features = ["postgres"] name = "postgres-rustsec" path = "tests/postgres/rustsec.rs" required-features = ["postgres", "macros", "migrate"] + +# +# Benches +# +[[bench]] +name = "any-pool" +path = "benches/any/pool.rs" +required-features = ["runtime-tokio", "any"] +harness = false + +[profile.bench] +debug = true diff --git a/benches/any/pool.rs b/benches/any/pool.rs new file mode 100644 index 0000000000..423b2ce02b --- /dev/null +++ b/benches/any/pool.rs @@ -0,0 +1,139 @@ +use criterion::{criterion_group, criterion_main, Bencher, BenchmarkId, Criterion}; +use sqlx_core::any::AnyPoolOptions; +use std::fmt::{Display, Formatter}; +use std::thread; +use std::time::{Duration, Instant}; +use tracing::Instrument; + +#[derive(Debug)] +struct Input { + threads: usize, + tasks: usize, + pool_size: usize, +} + +impl Display for Input { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "threads: {}, tasks: {}, pool size: {}", + self.threads, self.tasks, self.pool_size + ) + } +} + +fn bench_pool(c: &mut Criterion) { + sqlx::any::install_default_drivers(); + tracing_subscriber::fmt::try_init().ok(); + + let database_url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set"); + + let inputs = [ + Input { + threads: 1, + tasks: 2, + pool_size: 20, + }, + Input { + threads: 2, + tasks: 4, + pool_size: 20, + }, + Input { + threads: 4, + tasks: 8, + pool_size: 20, + }, + Input { + threads: 8, + tasks: 16, + pool_size: 20, + }, + Input { + threads: 16, + tasks: 32, + pool_size: 64, + }, + Input { + threads: 16, + tasks: 128, + pool_size: 64, + }, + ]; + + let mut group = c.benchmark_group("Bench Pool"); + + for input in inputs { + group.bench_with_input(BenchmarkId::from_parameter(&input), &input, |b, i| { + bench_pool_with(b, i, &database_url) + }); + } + + group.finish(); +} + +fn bench_pool_with(b: &mut Bencher, input: &Input, database_url: &str) { + let _span = tracing::info_span!( + "bench_pool_with", + threads = input.threads, + tasks = input.tasks, + pool_size = input.pool_size + ) + .entered(); + + let runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .worker_threads(input.threads) + .build() + .unwrap(); + + let pool = runtime.block_on(async { + AnyPoolOptions::new() + .min_connections(input.pool_size) + .max_connections(input.pool_size) + .test_before_acquire(false) + .connect(database_url) + .await + .expect("error connecting to pool") + }); + + for num in 1..=input.tasks { + let pool = pool.clone(); + + runtime.spawn( + async move { while pool.acquire().await.is_ok() {} } + .instrument(tracing::info_span!("task", num)), + ); + } + + // Spawn the benchmark loop into the runtime so we're not accidentally including the main thread + b.to_async(&runtime).iter_custom(|iters| { + let pool = pool.clone(); + + async move { + tokio::spawn( + async move { + let start = Instant::now(); + + for _ in 0..iters { + if let Err(e) = pool.acquire().await { + panic!("failed to acquire connection: {e:?}"); + } + } + + start.elapsed() + } + .instrument(tracing::info_span!("iter")), + ) + .await + .expect("panic in task") + } + }); + + 
runtime.block_on(pool.close()); + // Give the server a second to clean up + thread::sleep(Duration::from_millis(50)); +} + +criterion_group!(benches, bench_pool,); +criterion_main!(benches); diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index fff4ef3d24..dc03c192de 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -20,13 +20,13 @@ any = [] json = ["serde", "serde_json"] # for conditional compilation -_rt-async-global-executor = ["async-global-executor", "_rt-async-io", "_rt-async-task"] -_rt-async-io = ["async-io", "async-fs"] # see note at async-fs declaration -_rt-async-std = ["async-std", "_rt-async-io"] +_rt-async-global-executor = ["async-global-executor", "_rt-async-io", "_rt-async-lock", "_rt-async-task"] +_rt-async-io = ["async-io", "async-fs", "ease-off/async-io-2"] # see note at async-fs declaration +_rt-async-lock = ["async-lock"] +_rt-async-std = ["async-std", "_rt-async-io", "_rt-async-lock"] _rt-async-task = ["async-task"] -_rt-smol = ["smol", "_rt-async-io", "_rt-async-task"] -_rt-tokio = ["tokio", "tokio-stream"] - +_rt-smol = ["smol", "_rt-async-io", "_rt-async-lock", "_rt-async-task"] +_rt-tokio = ["tokio", "tokio-stream", "ease-off/tokio"] _tls-native-tls = ["native-tls"] _tls-rustls-aws-lc-rs = ["_tls-rustls", "rustls/aws-lc-rs", "webpki-roots"] _tls-rustls-ring-webpki = ["_tls-rustls", "rustls/ring", "webpki-roots"] @@ -73,6 +73,7 @@ uuid = { workspace = true, optional = true } # work around bug in async-fs 2.0.0, which references futures-lite dependency wrongly, see https://github.com/launchbadge/sqlx/pull/3791#issuecomment-3043363281 async-fs = { version = "2.1", optional = true } async-io = { version = "2.4.1", optional = true } +async-lock = { version = "3.4.1", optional = true } async-task = { version = "4.7.1", optional = true } base64 = { version = "0.22.0", default-features = false, features = ["std"] } @@ -83,7 +84,6 @@ crossbeam-queue = "0.3.2" either = "1.6.1" futures-core = { version = "0.3.19", default-features = false } futures-io = "0.3.24" -futures-intrusive = "0.5.0" futures-util = { version = "0.3.19", default-features = false, features = ["alloc", "sink", "io"] } log = { version = "0.4.18", default-features = false } memchr = { version = "2.4.1", default-features = false } @@ -103,10 +103,32 @@ indexmap = "2.0" event-listener = "5.2.0" hashbrown = "0.16.0" +rand.workspace = true thiserror.workspace = true +ease-off = { workspace = true, default-features = false } +pin-project-lite = "0.2.14" + +# N.B. we don't actually utilize spinlocks, we just need a `Mutex` type with a few requirements: +# * Guards that are `Send` (so `parking_lot` and `std::sync` are non-starters) +# * Guards that can use `Arc` and so don't borrow (which is provided by `lock_api`) +# +# Where we actually use this (in `sqlx-core/src/pool/shard.rs`), we don't rely on the mutex itself for anything but +# safe shared mutability. The `Shard` structure has its own synchronization, and only uses `Mutex::try_lock()`. +# +# We *could* use either `tokio::sync::Mutex` or `async_lock::Mutex` for this, but those have all the code for the +# async support, which we don't need. 
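A minimal sketch of the usage pattern this combination enables (hypothetical code, not the actual `shard.rs` internals), assuming `spin`'s `lock_api` integration and `lock_api`'s `arc_lock` feature as configured below:

```rust
use std::sync::Arc;

// `lock_api::Mutex` backed by `spin`'s raw mutex.
type SlotMutex<T> = spin::lock_api::Mutex<T>;

fn main() {
    let slot: Arc<SlotMutex<Option<String>>> = Arc::new(SlotMutex::new(None));

    // `try_lock_arc()` comes from `lock_api`'s `arc_lock` feature: the returned
    // guard owns an `Arc` instead of a borrow, and `spin`'s guards are `Send`,
    // so the guard can be moved into a spawned task or held across an `.await`.
    if let Some(mut guard) = slot.try_lock_arc() {
        *guard = Some("connection".to_owned());
    }
    // Only `try_lock()`-style acquisition is ever used, so nothing actually spins.
}
```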
+[dependencies.spin] +version = "0.10.0" +default-features = false +features = ["mutex", "lock_api", "spin_mutex"] + +[dependencies.lock_api] +version = "0.4.13" +features = ["arc_lock"] + [dev-dependencies] -tokio = { version = "1", features = ["rt"] } +tokio = { version = "1", features = ["rt", "sync"] } [dev-dependencies.sqlx] # FIXME: https://github.com/rust-lang/cargo/issues/15622 diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index 8c6f424cdf..8dfcc92a99 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -1,16 +1,19 @@ //! Types for working with errors produced by SQLx. +use crate::database::Database; use std::any::type_name; use std::borrow::Cow; use std::error::Error as StdError; use std::fmt::Display; use std::io; - -use crate::database::Database; +use std::sync::Arc; use crate::type_info::TypeInfo; use crate::types::Type; +#[cfg(doc)] +use crate::pool::{PoolConnector, PoolOptions}; + /// A specialized `Result` type for SQLx. pub type Result = ::std::result::Result; @@ -101,7 +104,10 @@ pub enum Error { /// /// [`Pool::acquire`]: crate::pool::Pool::acquire #[error("pool timed out while waiting for an open connection")] - PoolTimedOut, + PoolTimedOut { + #[source] + last_connect_error: Option>, + }, /// [`Pool::close`] was called while we were waiting in [`Pool::acquire`]. /// @@ -110,6 +116,19 @@ pub enum Error { #[error("attempted to acquire a connection on a closed pool")] PoolClosed, + /// A custom error that may be returned from a [`PoolConnector`] implementation. + #[error("error returned from pool connector")] + PoolConnector { + #[source] + source: BoxDynError, + + /// If `true`, `PoolConnector::connect()` is called again in an exponential backoff loop + /// up to [`PoolOptions::connect_timeout`]. + /// + /// See [`PoolConnector::connect()`] for details. + retryable: bool, + }, + /// A background worker has crashed. #[error("attempted to communicate with a crashed background worker")] WorkerCrashed, @@ -228,11 +247,6 @@ pub trait DatabaseError: 'static + Send + Sync + StdError { #[doc(hidden)] fn into_error(self: Box) -> Box; - #[doc(hidden)] - fn is_transient_in_connect_phase(&self) -> bool { - false - } - /// Returns the name of the constraint that triggered the error, if applicable. /// If the error was caused by a conflict of a unique index, this will be the index name. /// @@ -270,6 +284,24 @@ pub trait DatabaseError: 'static + Send + Sync + StdError { fn is_check_violation(&self) -> bool { matches!(self.kind(), ErrorKind::CheckViolation) } + + /// Returns `true` if this error can be retried when connecting to the database. + /// + /// Defaults to `false`. + /// + /// For example, the Postgres driver overrides this to return `true` for the following error codes: + /// + /// * `53300 too_many_connections`: returned when the maximum connections are exceeded + /// on the server. Assumed to be the result of a temporary overcommit + /// (e.g. an extra application replica being spun up to replace one that is going down). + /// * This error being consistently logged or returned is a likely indicator of a misconfiguration; + /// the sum of [`PoolOptions::max_connections`] for all replicas should not exceed + /// the maximum connections allowed by the server. + /// * `57P03 cannot_connect_now`: returned when the database server is still starting up + /// and the tcop component is not ready to accept connections yet. 
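+    ///
+    /// # Example (sketch)
+    /// A driver might implement this override by matching on the SQLSTATE code;
+    /// this is a hypothetical illustration, not the actual Postgres implementation:
+    ///
+    /// ```rust,ignore
+    /// fn is_retryable_connect_error(&self) -> bool {
+    ///     matches!(self.code().as_deref(), Some("53300" | "57P03"))
+    /// }
+    /// ```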
+    fn is_retryable_connect_error(&self) -> bool {
+        false
+    }
 }
 
 impl dyn DatabaseError {
diff --git a/sqlx-core/src/ext/future.rs b/sqlx-core/src/ext/future.rs
new file mode 100644
index 0000000000..138f800118
--- /dev/null
+++ b/sqlx-core/src/ext/future.rs
@@ -0,0 +1,38 @@
+use pin_project_lite::pin_project;
+use std::future::Future;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+pin_project! {
+    #[project = RaceProject]
+    pub struct Race<L, R> {
+        #[pin]
+        left: L,
+        #[pin]
+        right: R,
+    }
+}
+
+impl<L, R> Future for Race<L, R>
+where
+    L: Future,
+    R: Future,
+{
+    type Output = Result<L::Output, R::Output>;
+
+    #[inline(always)]
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let mut this = self.project();
+
+        if let Poll::Ready(left) = this.left.as_mut().poll(cx) {
+            return Poll::Ready(Ok(left));
+        }
+
+        this.right.as_mut().poll(cx).map(Err)
+    }
+}
+
+#[inline(always)]
+pub fn race<L, R>(left: L, right: R) -> Race<L, R> {
+    Race { left, right }
+}
diff --git a/sqlx-core/src/ext/mod.rs b/sqlx-core/src/ext/mod.rs
index 98059f8ca0..167c68560a 100644
--- a/sqlx-core/src/ext/mod.rs
+++ b/sqlx-core/src/ext/mod.rs
@@ -2,3 +2,5 @@ pub mod ustr;
 
 #[macro_use]
 pub mod async_stream;
+
+pub mod future;
diff --git a/sqlx-core/src/pool/connect.rs b/sqlx-core/src/pool/connect.rs
new file mode 100644
index 0000000000..5adf82bf9a
--- /dev/null
+++ b/sqlx-core/src/pool/connect.rs
@@ -0,0 +1,733 @@
+use crate::connection::{ConnectOptions, Connection};
+use crate::database::Database;
+use crate::pool::connection::ConnectionInner;
+use crate::pool::inner::PoolInner;
+use crate::pool::{Pool, PoolConnection};
+use crate::rt::JoinHandle;
+use crate::{rt, Error};
+use ease_off::EaseOff;
+use event_listener::{listener, Event, EventListener};
+use std::fmt::{Display, Formatter};
+use std::future::Future;
+use std::ptr;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::{Arc, Mutex, RwLock};
+use std::time::Instant;
+
+use crate::pool::connection_set::DisconnectedSlot;
+#[cfg(doc)]
+use crate::pool::PoolOptions;
+use crate::sync::{AsyncMutex, AsyncMutexGuard};
+use ease_off::core::EaseOffCore;
+use std::io;
+use std::ops::ControlFlow;
+use std::pin::{pin, Pin};
+use std::task::{ready, Context, Poll};
+
+const EASE_OFF: EaseOffCore = ease_off::Options::new().into_core();
+
+/// Custom connect callback for [`Pool`][crate::pool::Pool].
+///
+/// Implemented for closures with the signature
+/// `Fn(PoolConnectMetadata) -> impl Future<Output = sqlx::Result<DB::Connection>>`.
+///
+/// See [`Self::connect()`] for details and implementation advice.
+///
+/// # Example: `after_connect` Replacement
+/// The `after_connect` callback was removed in 0.9.0 as it was redundant to this API.
+///
+/// This example uses Postgres but may be adapted to any driver.
+///
+/// ```rust,no_run
+/// use std::sync::Arc;
+/// use sqlx::PgConnection;
+/// use sqlx::postgres::PgPoolOptions;
+/// use sqlx::Connection;
+/// use sqlx::pool::PoolConnectMetadata;
+///
+/// async fn _example() -> sqlx::Result<()> {
+/// // `PoolConnector` is implemented for closures but this has restrictions on returning borrows
+/// // due to current language limitations. Custom implementations are not subject to this.
+/// //
+/// // This example shows how to get around this using `Arc`.
+/// let database_url: Arc<str> = "postgres://...".into();
+///
+/// let pool = PgPoolOptions::new()
+/// .min_connections(5)
+/// .max_connections(30)
+/// // Type annotation on the argument is required for the trait impl to resolve.
+/// .connect_with_connector(move |meta: PoolConnectMetadata| { +/// let database_url = database_url.clone(); +/// async move { +/// println!( +/// "opening connection {}, attempt {}; elapsed time: {:?}", +/// meta.pool_size, +/// meta.num_attempts + 1, +/// meta.start.elapsed() +/// ); +/// +/// let mut conn = PgConnection::connect(&database_url).await?; +/// +/// // Override the time zone of the connection. +/// sqlx::raw_sql("SET TIME ZONE 'Europe/Berlin'") +/// .execute(&mut conn) +/// .await?; +/// +/// Ok(conn) +/// } +/// }) +/// .await?; +/// # Ok(()) +/// # } +/// ``` +/// +/// # Example: `set_connect_options` Replacement +/// `set_connect_options` and `get_connect_options` were removed in 0.9.0 because they complicated +/// the pool internals. They can be reimplemented by capturing a mutex, or similar, in the callback. +/// +/// This example uses Postgres and [`tokio::sync::RwLock`] but may be adapted to any driver +/// or `async-std`, respectively. +/// +/// ```rust,no_run +/// use std::sync::Arc; +/// use tokio::sync::RwLock; +/// use sqlx::PgConnection; +/// use sqlx::postgres::PgConnectOptions; +/// use sqlx::postgres::PgPoolOptions; +/// use sqlx::ConnectOptions; +/// use sqlx::pool::PoolConnectMetadata; +/// +/// async fn _example() -> sqlx::Result<()> { +/// // If you do not wish to hold the lock during the connection attempt, +/// // you could use `Arc` instead. +/// let connect_opts: Arc> = Arc::new(RwLock::new("postgres://...".parse()?)); +/// // We need a copy that will be captured by the closure. +/// let connect_opts_ = connect_opts.clone(); +/// +/// let pool = PgPoolOptions::new() +/// .connect_with_connector(move |meta: PoolConnectMetadata| { +/// let connect_opts = connect_opts_.clone(); +/// async move { +/// println!( +/// "opening connection {}, attempt {}; elapsed time: {:?}", +/// meta.pool_size, +/// meta.num_attempts + 1, +/// meta.start.elapsed() +/// ); +/// +/// connect_opts.read().await.connect().await +/// } +/// }) +/// .await?; +/// +/// // Close the connection that was previously opened by `connect_with_connector()`. +/// pool.acquire().await?.close().await?; +/// +/// // Simulating a credential rotation +/// let mut write_connect_opts = connect_opts.write().await; +/// write_connect_opts +/// .set_username("new_username") +/// .set_password("new password"); +/// +/// // Should use the new credentials. +/// let mut conn = pool.acquire().await?; +/// +/// # Ok(()) +/// # } +/// ``` +/// +/// # Example: Custom Implementation +/// +/// Custom implementations of `PoolConnector` trade a little bit of boilerplate for much +/// more flexibility. Thanks to the signature of `connect()`, they can return a `Future` +/// type that borrows from `self`. +/// +/// This example uses Postgres but may be adapted to any driver. +/// +/// ```rust,no_run +/// use sqlx::{PgConnection, Postgres}; +/// use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; +/// use sqlx_core::connection::ConnectOptions; +/// use sqlx_core::pool::{PoolConnectMetadata, PoolConnector}; +/// +/// struct MyConnector { +/// // A list of servers to connect to in a high-availability configuration. +/// host_ports: Vec<(String, u16)>, +/// username: String, +/// password: String, +/// } +/// +/// impl PoolConnector for MyConnector { +/// // The desugaring of `async fn` is compatible with the signature of `connect()`. 
+/// async fn connect(&self, meta: PoolConnectMetadata) -> sqlx::Result { +/// self.get_connect_options(meta.num_attempts) +/// .connect() +/// .await +/// } +/// } +/// +/// impl MyConnector { +/// fn get_connect_options(&self, attempt: usize) -> PgConnectOptions { +/// // Select servers in a round-robin. +/// let (ref host, port) = self.host_ports[attempt % self.host_ports.len()]; +/// +/// PgConnectOptions::new() +/// .host(host) +/// .port(port) +/// .username(&self.username) +/// .password(&self.password) +/// } +/// } +/// +/// # async fn _example() -> sqlx::Result<()> { +/// let pool = PgPoolOptions::new() +/// .max_connections(25) +/// .connect_with_connector(MyConnector { +/// host_ports: vec![ +/// ("db1.postgres.cluster.local".into(), 5432), +/// ("db2.postgres.cluster.local".into(), 5432), +/// ("db3.postgres.cluster.local".into(), 5432), +/// ("db4.postgres.cluster.local".into(), 5432), +/// ], +/// username: "my_username".into(), +/// password: "my password".into(), +/// }) +/// .await?; +/// +/// let conn = pool.acquire().await?; +/// +/// # Ok(()) +/// # } +/// ``` +pub trait PoolConnector: Send + Sync + 'static { + /// Open a connection for the pool. + /// + /// Any setup that must be done on the connection should be performed before it is returned. + /// + /// If this method returns an error that is known to be retryable, it is called again + /// in an exponential backoff loop. Retryable errors include, but are not limited to: + /// + /// * [`io::Error`] + /// * Database errors for which + /// [`is_retryable_connect_error`][crate::error::DatabaseError::is_retryable_connect_error] + /// returns `true`. + /// * [`Error::PoolConnector`] with `retryable: true`. + /// This error kind is not returned internally and is designed to allow this method to return + /// arbitrary error types not otherwise supported. + /// + /// This behavior may be customized by overriding [`Self::connect_with_control_flow()`]. + /// + /// Manual implementations of this method may also use the signature: + /// ```rust,ignore + /// async fn connect( + /// &self, + /// meta: PoolConnectMetadata + /// ) -> sqlx::Result<{PgConnection, MySqlConnection, SqliteConnection, etc.}> + /// ``` + /// + /// Note: the returned future must be `Send`. + fn connect( + &self, + meta: PoolConnectMetadata, + ) -> impl Future> + Send + '_; + + /// Open a connection for the pool, or indicate what to do on an error. + /// + /// This method may return one of the following: + /// + /// * `ControlFlow::Break(Ok(_))` with a successfully established connection. + /// * `ControlFlow::Break(Err(_))` with an error to immediately return to the caller. + /// * `ControlFlow::Continue(_)` with a retryable error. + /// The pool will call this method again in an exponential backoff loop until it succeeds, + /// or the [connect timeout][PoolOptions::connect_timeout] + /// or [acquire timeout][PoolOptions::acquire_timeout] is reached. + /// + /// # Default Implementation + /// This method has a provided implementation by default which calls [`Self::connect()`] + /// and then returns `ControlFlow::Continue` if the error is any of the following: + /// + /// * [`io::Error`] + /// * Database errors for which + /// [`is_retryable_connect_error`][crate::error::DatabaseError::is_retryable_connect_error] + /// returns `true`. + /// * [`Error::PoolConnector`] with `retryable: true`. + /// This error kind is not returned internally and is designed to allow this method to return + /// arbitrary error types not otherwise supported. 
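+    ///
+    /// # Example (sketch)
+    /// A hypothetical override that treats only errors recognized by a helper
+    /// `is_transient()` (defined on the connector, not part of this API) as retryable:
+    ///
+    /// ```rust,ignore
+    /// async fn connect_with_control_flow(
+    ///     &self,
+    ///     meta: PoolConnectMetadata,
+    /// ) -> ControlFlow<sqlx::Result<PgConnection>, sqlx::Error> {
+    ///     match self.connect(meta).await {
+    ///         Err(e) if self.is_transient(&e) => ControlFlow::Continue(e),
+    ///         res => ControlFlow::Break(res),
+    ///     }
+    /// }
+    /// ```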
+ /// + /// A custom backoff loop may be implemented by overriding this method and retrying internally, + /// only returning `ControlFlow::Break` if/when an error should be propagated out to the caller. + /// + /// If this method is overridden and does not call [`Self::connect()`], then the implementation + /// of the latter can be a stub. It is not called internally. + fn connect_with_control_flow( + &self, + meta: PoolConnectMetadata, + ) -> impl Future, Error>> + Send + '_ { + async { + match self.connect(meta).await { + Err(err @ Error::Io(_)) => ControlFlow::Continue(err), + Err(Error::Database(dbe)) if dbe.is_retryable_connect_error() => { + ControlFlow::Continue(Error::Database(dbe)) + } + Err( + err @ Error::PoolConnector { + retryable: true, .. + }, + ) => ControlFlow::Continue(err), + res => ControlFlow::Break(res), + } + } + } +} + +/// # Note: Future Changes (FIXME) +/// This could theoretically be replaced with an impl over `AsyncFn` to allow lending closures, +/// except we have no way to put the `Send` bound on the returned future. +/// +/// We need Return Type Notation for that: https://github.com/rust-lang/rust/pull/138424 +impl PoolConnector for F +where + DB: Database, + F: Fn(PoolConnectMetadata) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, +{ + fn connect( + &self, + meta: PoolConnectMetadata, + ) -> impl Future> + Send + '_ { + self(meta) + } +} + +pub(crate) struct DefaultConnector( + pub <::Connection as Connection>::Options, +); + +impl PoolConnector for DefaultConnector { + fn connect( + &self, + _meta: PoolConnectMetadata, + ) -> impl Future> + Send + '_ { + self.0.connect() + } +} + +/// Metadata passed to [`PoolConnector::connect()`] for every connection attempt. +#[derive(Debug)] +#[non_exhaustive] +pub struct PoolConnectMetadata { + /// The instant at which the current connection task was started, including all attempts. + /// + /// May be used for reporting purposes, or to implement a custom backoff. + pub start: Instant, + + /// The deadline (`start` plus the [connect timeout][PoolOptions::connect_timeout], if set). + pub deadline: Option, + + /// The number of attempts that have occurred so far. + pub num_attempts: u32, + /// The current size of the pool. + pub pool_size: usize, + /// The ID of the connection, unique for the pool. + pub connection_id: ConnectionId, +} + +pub struct DynConnector { + // We want to spawn the connection attempt as a task anyway + connect: Box< + dyn Fn( + Pool, + ConnectionId, + DisconnectedSlot>, + Arc, + ) -> ConnectTask + + Send + + Sync + + 'static, + >, +} + +impl DynConnector { + pub fn new(connector: impl PoolConnector) -> Self { + let connector = Arc::new(connector); + + Self { + connect: Box::new(move |pool, id, guard, shared| { + ConnectTask::spawn(pool, id, guard, connector.clone(), shared) + }), + } + } + + pub fn connect( + &self, + pool: Pool, + id: ConnectionId, + slot: DisconnectedSlot>, + shared: Arc, + ) -> ConnectTask { + (self.connect)(pool, id, slot, shared) + } +} + +pub struct ConnectTask { + handle: JoinHandle>>, + shared: Arc, +} + +pub struct ConnectTaskShared { + cancel_event: Event, + // Using the normal `std::sync::Mutex` because the critical sections are very short; + // we only hold the lock long enough to insert or take the value. 
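+    //
+    // Lock poisoning is tolerated at each use site via
+    // `lock().unwrap_or_else(|e| e.into_inner())`, so a task that panics while
+    // holding the lock cannot wedge the connect machinery.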
+ last_error: Mutex>, +} + +impl Future for ConnectTask { + type Output = crate::Result>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.handle).poll(cx) + } +} + +impl ConnectTask { + fn spawn( + pool: Pool, + id: ConnectionId, + guard: DisconnectedSlot>, + connector: Arc>, + shared: Arc, + ) -> Self { + let handle = crate::rt::spawn(connect_with_backoff( + pool, + id, + connector, + guard, + shared.clone(), + )); + + Self { handle, shared } + } + + pub fn cancel(&self) -> Option { + self.shared.cancel_event.notify(1); + + self.shared + .last_error + .lock() + .unwrap_or_else(|e| e.into_inner()) + .take() + } +} + +impl ConnectTaskShared { + pub fn new_arc() -> Arc { + Arc::new(Self { + cancel_event: Event::new(), + last_error: Mutex::new(None), + }) + } + + pub fn take_error(&self) -> Option { + self.last_error + .lock() + .unwrap_or_else(|e| e.into_inner()) + .take() + } + + fn put_error(&self, error: Error) { + *self.last_error.lock().unwrap_or_else(|e| e.into_inner()) = Some(error); + } +} + +pub struct ConnectionCounter { + count: AtomicUsize, + next_id: AtomicUsize, + connect_available: Event, +} + +/// An opaque connection ID, unique for every connection attempt with the same pool. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct ConnectionId(usize); + +impl ConnectionId { + pub(super) fn next() -> ConnectionId { + static NEXT_ID: AtomicUsize = AtomicUsize::new(0); + + ConnectionId(NEXT_ID.fetch_add(1, Ordering::AcqRel)) + } +} + +impl ConnectionCounter { + pub fn new() -> Self { + Self { + count: AtomicUsize::new(0), + next_id: AtomicUsize::new(1), + connect_available: Event::new(), + } + } + + pub fn connections(&self) -> usize { + self.count.load(Ordering::Acquire) + } + + pub async fn drain(&self) { + while self.count.load(Ordering::Acquire) > 0 { + listener!(self.connect_available => permit_released); + permit_released.await; + } + } + + /// Attempt to acquire a permit from both this instance, and the parent pool, if applicable. + /// + /// Returns the permit, and the ID of the new connection. + pub fn try_acquire_permit( + &self, + pool: &Arc>, + ) -> Option<(ConnectionId, ConnectPermit)> { + debug_assert!(ptr::addr_eq(self, &pool.counter)); + + // Don't skip the queue. + if pool.options.fair && self.connect_available.total_listeners() > 0 { + return None; + } + + let prev_size = self + .count + .fetch_update(Ordering::Release, Ordering::Acquire, |connections| { + (connections < pool.options.max_connections).then_some(connections + 1) + }) + .ok()?; + + let size = prev_size + 1; + + tracing::trace!(target: "sqlx::pool::connect", size, "increased size"); + + Some(( + ConnectionId(self.next_id.fetch_add(1, Ordering::SeqCst)), + ConnectPermit { + pool: Some(Arc::clone(pool)), + }, + )) + } + + /// Attempt to acquire a permit from both this instance, and the parent pool, if applicable. + /// + /// Returns the permit, and the current size of the pool. + pub async fn acquire_permit( + &self, + pool: &Arc>, + ) -> (ConnectionId, ConnectPermit) { + // Check that `self` can increase size first before we check the parent. + let acquired = self.acquire_permit_self(pool).await; + + if let Some(parent) = pool.parent() { + let (_, permit) = parent.0.counter.acquire_permit_self(&parent.0).await; + + // consume the parent permit + permit.consume(); + } + + acquired + } + + // Separate method because `async fn`s cannot be recursive. + /// Attempt to acquire a [`ConnectPermit`] from this instance and this instance only. 
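+    ///
+    /// Unlike [`Self::acquire_permit()`], this does not also acquire (and consume)
+    /// a permit from the parent pool, if one is configured.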
+ async fn acquire_permit_self( + &self, + pool: &Arc>, + ) -> (ConnectionId, ConnectPermit) { + for attempt in 1usize.. { + if let Some(acquired) = self.try_acquire_permit(pool) { + return acquired; + } + + if attempt == 2 { + tracing::warn!( + "unable to acquire a connect permit after sleeping; this may indicate a bug" + ); + } + + listener!(self.connect_available => connect_available); + connect_available.await; + } + + panic!("BUG: was never able to acquire a connection despite waking many times") + } + + pub fn release_permit(&self, pool: &PoolInner) { + debug_assert!(ptr::addr_eq(self, &pool.counter)); + + self.count.fetch_sub(1, Ordering::Release); + self.connect_available.notify(1usize); + + if let Some(parent) = &pool.options.parent_pool { + parent.0.counter.release_permit(&parent.0); + } + } +} + +pub struct ConnectPermit { + pool: Option>>, +} + +impl ConnectPermit { + pub fn float_existing(pool: Arc>) -> Self { + Self { pool: Some(pool) } + } + + pub fn pool(&self) -> &Arc> { + self.pool.as_ref().unwrap() + } + + pub fn consume(mut self) { + self.pool = None; + } +} + +impl Drop for ConnectPermit { + fn drop(&mut self) { + if let Some(pool) = self.pool.take() { + pool.counter.release_permit(&pool); + } + } +} + +impl Display for ConnectionId { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Display::fmt(&self.0, f) + } +} + +#[tracing::instrument( + target = "sqlx::pool::connect", + skip_all, + fields(%connection_id), + err +)] +async fn connect_with_backoff( + pool: Pool, + connection_id: ConnectionId, + connector: Arc>, + slot: DisconnectedSlot>, + shared: Arc, +) -> crate::Result> { + listener!(pool.0.on_closed => closed); + listener!(shared.cancel_event => cancelled); + + let start = Instant::now(); + let deadline = pool + .0 + .options + .connect_timeout + .and_then(|timeout| start.checked_add(timeout)); + + for attempt in 1u32.. 
{ + let meta = PoolConnectMetadata { + start, + deadline, + num_attempts: attempt, + pool_size: pool.size(), + connection_id, + }; + + tracing::trace!( + target: "sqlx::pool::connect", + %connection_id, + attempt, + elapsed_seconds=start.elapsed().as_secs_f64(), + "beginning connection attempt" + ); + + let res = connector.connect_with_control_flow(meta).await; + + let now = Instant::now(); + let elapsed = now.duration_since(start); + let elapsed_seconds = elapsed.as_secs_f64(); + + match res { + ControlFlow::Break(Ok(conn)) => { + tracing::debug!( + target: "sqlx::pool::connect", + %connection_id, + attempt, + elapsed_seconds, + "connection established", + ); + + return Ok(PoolConnection::new(slot.put(ConnectionInner { + pool: Arc::downgrade(&pool.0), + raw: conn, + id: connection_id, + created_at: now, + last_released_at: now, + }))); + } + ControlFlow::Break(Err(e)) => { + tracing::error!( + target: "sqlx::pool::connect", + %connection_id, + attempt, + elapsed_seconds, + error=?e, + "error connecting to database", + ); + + return Err(e); + } + ControlFlow::Continue(e) => { + tracing::warn!( + target: "sqlx::pool::connect", + %connection_id, + attempt, + elapsed_seconds, + error=?e, + "error connecting to database; retrying", + ); + + shared.put_error(e); + } + } + + let wait = EASE_OFF + .nth_retry_at(attempt, now, deadline, &mut rand::thread_rng()) + .map_err(|_| { + Error::PoolTimedOut { + // This should be populated by the caller + last_connect_error: None, + } + })?; + + if let Some(wait) = wait { + tracing::trace!( + target: "sqlx::pool::connect", + %connection_id, + attempt, + elapsed_seconds, + "waiting for {:?}", + wait.duration_since(now), + ); + + let mut sleep = pin!(rt::sleep_until(wait)); + + std::future::poll_fn(|cx| { + if let Poll::Ready(()) = Pin::new(&mut closed).poll(cx) { + return Poll::Ready(Err(Error::PoolClosed)); + } + + if let Poll::Ready(()) = Pin::new(&mut cancelled).poll(cx) { + return Poll::Ready(Err(Error::PoolTimedOut { + last_connect_error: None, + })); + } + + ready!(sleep.as_mut().poll(cx)); + Poll::Ready(Ok(())) + }) + .await?; + } + } + + Err(Error::PoolTimedOut { + last_connect_error: None, + }) +} diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs index 7912b12aa1..1103374cca 100644 --- a/sqlx-core/src/pool/connection.rs +++ b/sqlx-core/src/pool/connection.rs @@ -1,51 +1,49 @@ use std::fmt::{self, Debug, Formatter}; use std::future::{self, Future}; +use std::io; use std::ops::{Deref, DerefMut}; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use std::time::{Duration, Instant}; -use crate::sync::AsyncSemaphoreReleaser; - use crate::connection::Connection; use crate::database::Database; use crate::error::Error; -use super::inner::{is_beyond_max_lifetime, DecrementSizeGuard, PoolInner}; +use super::inner::PoolInner; +use crate::pool::connect::{ConnectPermit, ConnectTaskShared, ConnectionId}; +use crate::pool::connection_set::{ConnectedSlot, DisconnectedSlot}; use crate::pool::options::PoolConnectionMetadata; +use crate::pool::{Pool, PoolOptions}; +use crate::rt; -const CLOSE_ON_DROP_TIMEOUT: Duration = Duration::from_secs(5); +const RETURN_TO_POOL_TIMEOUT: Duration = Duration::from_secs(5); +const CLOSE_TIMEOUT: Duration = Duration::from_secs(5); /// A connection managed by a [`Pool`][crate::pool::Pool]. /// /// Will be returned to the pool on-drop. 
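+///
+/// A lightweight sketch of typical usage (assuming the Postgres driver):
+///
+/// ```rust,no_run
+/// # async fn _example(pool: sqlx::PgPool) -> sqlx::Result<()> {
+/// let mut conn = pool.acquire().await?;
+/// sqlx::query("SELECT 1").execute(&mut *conn).await?;
+/// // Dropping the guard returns the connection to the pool.
+/// drop(conn);
+/// # Ok(())
+/// # }
+/// ```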
pub struct PoolConnection { - live: Option>, + conn: Option>>, close_on_drop: bool, - pub(crate) pool: Arc>, } -pub(super) struct Live { +pub(super) struct ConnectionInner { + // Note: must be `Weak` to prevent a reference cycle + pub(crate) pool: Weak>, pub(super) raw: DB::Connection, + pub(super) id: ConnectionId, pub(super) created_at: Instant, -} - -pub(super) struct Idle { - pub(super) live: Live, - pub(super) idle_since: Instant, -} - -/// RAII wrapper for connections being handled by functions that may drop them -pub(super) struct Floating { - pub(super) inner: C, - pub(super) guard: DecrementSizeGuard, + pub(super) last_released_at: Instant, } const EXPECT_MSG: &str = "BUG: inner connection already taken!"; impl Debug for PoolConnection { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - // TODO: Show the type name of the connection ? - f.debug_struct("PoolConnection").finish() + f.debug_struct("PoolConnection") + .field("database", &DB::NAME) + .field("id", &self.conn.as_ref().map(|live| live.id)) + .finish() } } @@ -53,13 +51,13 @@ impl Deref for PoolConnection { type Target = DB::Connection; fn deref(&self) -> &Self::Target { - &self.live.as_ref().expect(EXPECT_MSG).raw + &self.conn.as_ref().expect(EXPECT_MSG).raw } } impl DerefMut for PoolConnection { fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.live.as_mut().expect(EXPECT_MSG).raw + &mut self.conn.as_mut().expect(EXPECT_MSG).raw } } @@ -76,6 +74,13 @@ impl AsMut for PoolConnection { } impl PoolConnection { + pub(super) fn new(live: ConnectedSlot>) -> Self { + Self { + conn: Some(live), + close_on_drop: false, + } + } + /// Close this connection, allowing the pool to open a replacement. /// /// Equivalent to calling [`.detach()`] then [`.close()`], but the connection permit is retained @@ -84,8 +89,8 @@ impl PoolConnection { /// [`.detach()`]: PoolConnection::detach /// [`.close()`]: Connection::close pub async fn close(mut self) -> Result<(), Error> { - let floating = self.take_live().float(self.pool.clone()); - floating.inner.raw.close().await + let (res, _slot) = close(self.take_conn()).await; + res } /// Close this connection on-drop, instead of returning it to the pool. @@ -111,7 +116,8 @@ impl PoolConnection { /// [`max_connections`]: crate::pool::PoolOptions::max_connections /// [`min_connections`]: crate::pool::PoolOptions::min_connections pub fn detach(mut self) -> DB::Connection { - self.take_live().float(self.pool.clone()).detach() + let (conn, _slot) = ConnectedSlot::take(self.take_conn()); + conn.raw } /// Detach this connection from the pool, treating it as permanently checked-out. @@ -120,11 +126,13 @@ impl PoolConnection { /// /// If you don't want to impact the pool's capacity, use [`.detach()`][Self::detach] instead. pub fn leak(mut self) -> DB::Connection { - self.take_live().raw + let (conn, slot) = ConnectedSlot::take(self.take_conn()); + DisconnectedSlot::leak(slot); + conn.raw } - fn take_live(&mut self) -> Live { - self.live.take().expect(EXPECT_MSG) + fn take_conn(&mut self) -> ConnectedSlot> { + self.conn.take().expect(EXPECT_MSG) } /// Test the connection to make sure it is still live before returning it to the pool. @@ -132,46 +140,33 @@ impl PoolConnection { /// This effectively runs the drop handler eagerly instead of spawning a task to do it. 
#[doc(hidden)] pub fn return_to_pool(&mut self) -> impl Future + Send + 'static { - // float the connection in the pool before we move into the task - // in case the returned `Future` isn't executed, like if it's spawned into a dying runtime - // https://github.com/launchbadge/sqlx/issues/1396 - // Type hints seem to be broken by `Option` combinators in IntelliJ Rust right now (6/22). - let floating: Option>> = - self.live.take().map(|live| live.float(self.pool.clone())); - - let pool = self.pool.clone(); + let conn = self.conn.take(); async move { - let returned_to_pool = if let Some(floating) = floating { - floating.return_to_pool().await - } else { - false + let Some(conn) = conn else { + return; }; - if !returned_to_pool { - pool.min_connections_maintenance(None).await; - } + let Some(pool) = Weak::upgrade(&conn.pool) else { + return; + }; + + rt::timeout(RETURN_TO_POOL_TIMEOUT, return_to_pool(conn, &pool)) + .await + // Dropping of the `slot` will check if the connection must be re-established + // but only after trying to pass it to a task that needs it. + .ok(); } } fn take_and_close(&mut self) -> impl Future + Send + 'static { - // float the connection in the pool before we move into the task - // in case the returned `Future` isn't executed, like if it's spawned into a dying runtime - // https://github.com/launchbadge/sqlx/issues/1396 - // Type hints seem to be broken by `Option` combinators in IntelliJ Rust right now (6/22). - let floating = self.live.take().map(|live| live.float(self.pool.clone())); - - let pool = self.pool.clone(); + let conn = self.conn.take(); async move { - if let Some(floating) = floating { + if let Some(conn) = conn { // Don't hold the connection forever if it hangs while trying to close - crate::rt::timeout(CLOSE_ON_DROP_TIMEOUT, floating.close()) - .await - .ok(); + rt::timeout(CLOSE_TIMEOUT, close(conn)).await.ok(); } - - pool.min_connections_maintenance(None).await; } } } @@ -204,218 +199,188 @@ impl Drop for PoolConnection { } // We still need to spawn a task to maintain `min_connections`. - if self.live.is_some() || self.pool.options.min_connections > 0 { + if self.conn.is_some() { crate::rt::spawn(self.return_to_pool()); } } } -impl Live { - pub fn float(self, pool: Arc>) -> Floating { - Floating { - inner: self, - // create a new guard from a previously leaked permit - guard: DecrementSizeGuard::new_permit(pool), - } - } - - pub fn into_idle(self) -> Idle { - Idle { - live: self, - idle_since: Instant::now(), - } - } -} - -impl Deref for Idle { - type Target = Live; - - fn deref(&self) -> &Self::Target { - &self.live - } -} - -impl DerefMut for Idle { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.live - } -} - -impl Floating> { - pub fn new_live(conn: DB::Connection, guard: DecrementSizeGuard) -> Self { - Self { - inner: Live { - raw: conn, - created_at: Instant::now(), - }, - guard, +impl ConnectionInner { + pub fn metadata(&self) -> PoolConnectionMetadata { + PoolConnectionMetadata { + age: self.created_at.elapsed(), + idle_for: Duration::ZERO, } } - pub fn reattach(self) -> PoolConnection { - let Floating { inner, guard } = self; - - let pool = Arc::clone(&guard.pool); + pub fn idle_metadata(&self) -> PoolConnectionMetadata { + // Use a single `now` value for consistency. 
+ let now = Instant::now(); - guard.cancel(); - PoolConnection { - live: Some(inner), - close_on_drop: false, - pool, + PoolConnectionMetadata { + // NOTE: the receiver is the later `Instant` and the arg is the earlier + // https://github.com/launchbadge/sqlx/issues/1912 + age: now.saturating_duration_since(self.created_at), + idle_for: now.saturating_duration_since(self.last_released_at), } } - pub fn release(self) { - self.guard.pool.clone().release(self); - } - - /// Return the connection to the pool. - /// - /// Returns `true` if the connection was successfully returned, `false` if it was closed. - async fn return_to_pool(mut self) -> bool { - // Immediately close the connection. - if self.guard.pool.is_closed() { - self.close().await; - return false; - } + pub fn is_beyond_max_lifetime(&self, options: &PoolOptions) -> bool { + if let Some(max_lifetime) = options.max_lifetime { + let age = self.created_at.elapsed(); - // If the connection is beyond max lifetime, close the connection and - // immediately create a new connection - if is_beyond_max_lifetime(&self.inner, &self.guard.pool.options) { - self.close().await; - return false; - } + if age > max_lifetime { + tracing::info!( + target: "sqlx::pool", + connection_id=%self.id, + ?age, + "connection is beyond `max_lifetime`, closing" + ); - if let Some(test) = &self.guard.pool.options.after_release { - let meta = self.metadata(); - match (test)(&mut self.inner.raw, meta).await { - Ok(true) => (), - Ok(false) => { - self.close().await; - return false; - } - Err(error) => { - tracing::warn!(%error, "error from `after_release`"); - // Connection is broken, don't try to gracefully close as - // something weird might happen. - self.close_hard().await; - return false; - } + return true; } } - // test the connection on-release to ensure it is still viable, - // and flush anything time-sensitive like transaction rollbacks - // if an Executor future/stream is dropped during an `.await` call, the connection - // is likely to be left in an inconsistent state, in which case it should not be - // returned to the pool; also of course, if it was dropped due to an error - // this is simply a band-aid as SQLx-next connections should be able - // to recover from cancellations - if let Err(error) = self.raw.ping().await { - tracing::warn!( - %error, - "error occurred while testing the connection on-release", - ); - - // Connection is broken, don't try to gracefully close. 
- self.close_hard().await; - false - } else { - // if the connection is still viable, release it to the pool - self.release(); - true - } + false } - pub async fn close(self) { - // This isn't used anywhere that we care about the return value - let _ = self.inner.raw.close().await; + pub fn is_beyond_idle_timeout(&self, options: &PoolOptions) -> bool { + if let Some(idle_timeout) = options.idle_timeout { + let now = Instant::now(); - // `guard` is dropped as intended - } + let age = now.duration_since(self.created_at); + let idle_duration = now.duration_since(self.last_released_at); - pub async fn close_hard(self) { - let _ = self.inner.raw.close_hard().await; - } + if idle_duration > idle_timeout { + tracing::info!( + target: "sqlx::pool", + connection_id=%self.id, + ?age, + ?idle_duration, + "connection is beyond `idle_timeout`, closing" + ); - pub fn detach(self) -> DB::Connection { - self.inner.raw - } - - pub fn into_idle(self) -> Floating> { - Floating { - inner: self.inner.into_idle(), - guard: self.guard, + return true; + } } - } - pub fn metadata(&self) -> PoolConnectionMetadata { - PoolConnectionMetadata { - age: self.created_at.elapsed(), - idle_for: Duration::ZERO, - } + false } } -impl Floating> { - pub fn from_idle( - idle: Idle, - pool: Arc>, - permit: AsyncSemaphoreReleaser<'_>, - ) -> Self { - Self { - inner: idle, - guard: DecrementSizeGuard::from_permit(pool, permit), - } - } - - pub async fn ping(&mut self) -> Result<(), Error> { - self.live.raw.ping().await - } +pub(crate) async fn close( + conn: ConnectedSlot>, +) -> (Result<(), Error>, DisconnectedSlot>) { + let connection_id = conn.id; - pub fn into_live(self) -> Floating> { - Floating { - inner: self.inner.live, - guard: self.guard, - } - } + tracing::debug!(target: "sqlx::pool", %connection_id, "closing connection (gracefully)"); - pub async fn close(self) -> DecrementSizeGuard { - if let Err(error) = self.inner.live.raw.close().await { - tracing::debug!(%error, "error occurred while closing the pool connection"); - } - self.guard - } + let (conn, slot) = ConnectedSlot::take(conn); - pub async fn close_hard(self) -> DecrementSizeGuard { - let _ = self.inner.live.raw.close_hard().await; - - self.guard - } - - pub fn metadata(&self) -> PoolConnectionMetadata { - // Use a single `now` value for consistency. 
- let now = Instant::now(); + let res = rt::timeout(CLOSE_TIMEOUT, conn.raw.close()) + .await + .unwrap_or_else(|_| { + Err(io::Error::new(io::ErrorKind::TimedOut, "timed out sending close packet").into()) + }) + .inspect_err(|error| { + tracing::debug!( + target: "sqlx::pool", + %connection_id, + %error, + "error occurred while closing the pool connection" + ); + }); - PoolConnectionMetadata { - // NOTE: the receiver is the later `Instant` and the arg is the earlier - // https://github.com/launchbadge/sqlx/issues/1912 - age: now.saturating_duration_since(self.created_at), - idle_for: now.saturating_duration_since(self.idle_since), - } - } + (res, slot) } +pub(crate) async fn close_hard( + conn: ConnectedSlot>, +) -> (Result<(), Error>, DisconnectedSlot>) { + let connection_id = conn.id; + + tracing::debug!( + target: "sqlx::pool", + %connection_id, + "closing connection (forcefully)" + ); + + let (conn, slot) = ConnectedSlot::take(conn); + + let res = rt::timeout(CLOSE_TIMEOUT, conn.raw.close_hard()) + .await + .unwrap_or_else(|_| { + Err(io::Error::new(io::ErrorKind::TimedOut, "timed out sending close packet").into()) + }) + .inspect_err(|error| { + tracing::debug!( + target: "sqlx::pool", + %connection_id, + %error, + "error occurred while closing the pool connection" + ); + }); -impl Deref for Floating { - type Target = C; + (res, slot) +} - fn deref(&self) -> &Self::Target { - &self.inner +/// Return the connection to the pool. +/// +/// Returns `true` if the connection was successfully returned, `false` if it was closed. +async fn return_to_pool( + mut conn: ConnectedSlot>, + pool: &PoolInner, +) -> Result<(), DisconnectedSlot>> { + // Immediately close the connection. + if pool.is_closed() { + let (_res, slot) = close(conn).await; + return Err(slot); + } + + // If the connection is beyond max lifetime, close the connection and + // immediately create a new connection + if conn.is_beyond_max_lifetime(&pool.options) { + let (_res, slot) = close(conn).await; + return Err(slot); + } + + if let Some(test) = &pool.options.after_release { + let meta = conn.metadata(); + match (test)(&mut conn.raw, meta).await { + Ok(true) => (), + Ok(false) => { + let (_res, slot) = close(conn).await; + return Err(slot); + } + Err(error) => { + tracing::warn!(%error, "error from `after_release`"); + // Connection is broken, don't try to gracefully close as + // something weird might happen. + let (_res, slot) = close_hard(conn).await; + return Err(slot); + } + } } -} -impl DerefMut for Floating { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.inner + // test the connection on-release to ensure it is still viable, + // and flush anything time-sensitive like transaction rollbacks + // if an Executor future/stream is dropped during an `.await` call, the connection + // is likely to be left in an inconsistent state, in which case it should not be + // returned to the pool; also of course, if it was dropped due to an error + // this is simply a band-aid as SQLx-next connections should be able + // to recover from cancellations + if let Err(error) = conn.raw.ping().await { + tracing::warn!( + target: "sqlx::pool", + %error, + "error occurred while testing the connection on-release", + ); + + // Connection is broken, don't try to gracefully close. 
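+        // (`close_hard()` skips the graceful shutdown handshake entirely,
+        // which is the safer option when the connection state is unknown.)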
+ let (_res, slot) = close_hard(conn).await; + Err(slot) + } else { + // if the connection is still viable, release it to the pool + drop(conn); + Ok(()) } } diff --git a/sqlx-core/src/pool/connection_set.rs b/sqlx-core/src/pool/connection_set.rs new file mode 100644 index 0000000000..8683f8a902 --- /dev/null +++ b/sqlx-core/src/pool/connection_set.rs @@ -0,0 +1,543 @@ +use crate::ext::future::race; +use crate::rt; +use crate::sync::{AsyncMutex, AsyncMutexGuardArc}; +use event_listener::{listener, Event, EventListener, IntoNotification}; +use futures_core::Stream; +use futures_util::stream::FuturesUnordered; +use futures_util::{FutureExt, StreamExt}; +use std::cmp; +use std::future::Future; +use std::ops::{Deref, DerefMut, RangeInclusive, RangeToInclusive}; +use std::pin::{pin, Pin}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::Arc; +use std::task::Poll; +use std::time::Duration; + +pub struct ConnectionSet { + global: Arc, + slots: Box<[Arc>]>, +} + +pub struct ConnectedSlot(SlotGuard); + +pub struct DisconnectedSlot(SlotGuard); + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +enum AcquirePreference { + Connected, + Disconnected, + Either, +} + +struct Global { + unlock_event: Event, + disconnect_event: Event, + num_connected: AtomicUsize, + min_connections: usize, + min_connections_event: Event<()>, +} + +struct SlotGuard { + slot: Arc>, + // `Option` allows us to take the guard in the drop handler. + locked: Option>>, +} + +struct Slot { + // By having each `Slot` hold its own reference to `Global`, we can avoid extra contended clones + // which would sap performance + global: Arc, + index: usize, + // I'd love to eliminate this redundant `Arc` but it's likely not possible without `unsafe` + connection: Arc>>, + unlock_event: Event, + disconnect_event: Event, + connected: AtomicBool, + locked: AtomicBool, + leaked: AtomicBool, +} + +impl ConnectionSet { + pub fn new(size: RangeInclusive) -> Self { + let global = Arc::new(Global { + unlock_event: Event::with_tag(), + disconnect_event: Event::with_tag(), + num_connected: AtomicUsize::new(0), + min_connections: *size.start(), + min_connections_event: Event::with_tag(), + }); + + ConnectionSet { + // `vec![; size].into()` clones `` instead of repeating it, + // which is *no bueno* when wrapping something in `Arc` + slots: (0..*size.end()) + .map(|index| { + Arc::new(Slot { + global: global.clone(), + index, + connection: Arc::new(AsyncMutex::new(None)), + unlock_event: Event::with_tag(), + disconnect_event: Event::with_tag(), + connected: AtomicBool::new(false), + locked: AtomicBool::new(false), + leaked: AtomicBool::new(false), + }) + }) + .collect(), + global, + } + } + + #[inline(always)] + pub fn num_connected(&self) -> usize { + self.global.num_connected() + } + + pub fn count_idle(&self) -> usize { + self.slots.iter().filter(|slot| slot.is_locked()).count() + } + + pub async fn acquire_connected(&self) -> ConnectedSlot { + self.acquire_inner(AcquirePreference::Connected) + .await + .assert_connected() + } + + pub async fn acquire_disconnected(&self) -> DisconnectedSlot { + self.acquire_inner(AcquirePreference::Disconnected) + .await + .assert_disconnected() + } + + /// Attempt to acquire the connection associated with the current thread. 
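+    ///
+    /// The slot keyed to the current thread's ID is tried first; if it is not
+    /// available, any other available slot may be returned, connected or not.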
+ pub async fn acquire_any(&self) -> Result, DisconnectedSlot> { + self.acquire_inner(AcquirePreference::Either) + .await + .try_connected() + } + + async fn acquire_inner(&self, pref: AcquirePreference) -> SlotGuard { + /// Smallest time-step supported by [`tokio::time::sleep()`]. + /// + /// `async-io` doesn't document a minimum time-step, instead deferring to the platform. + const STEP_INTERVAL: Duration = Duration::from_millis(1); + + const SEARCH_LIMIT: usize = 5; + + let preferred_slot = current_thread_id() % self.slots.len(); + + tracing::trace!(preferred_slot, ?pref, "acquire_inner"); + + // Always try to lock the connection associated with our thread ID + let mut acquire_preferred = pin!(self.slots[preferred_slot].acquire(pref)); + + let mut step_interval = pin!(rt::interval_after(STEP_INTERVAL)); + + let mut intervals_elapsed = 0usize; + + let mut search_slots = FuturesUnordered::new(); + + let mut listen_global = pin!(self.global.listen(pref)); + + let mut search_slot = self.next_slot(preferred_slot); + + std::future::poll_fn(|cx| loop { + if let Poll::Ready(locked) = acquire_preferred.as_mut().poll(cx) { + return Poll::Ready(locked); + } + + // Don't push redundant futures for small sets. + let search_limit = cmp::min(SEARCH_LIMIT, self.slots.len()); + + if search_slots.len() < search_limit && step_interval.as_mut().poll_tick(cx).is_ready() + { + intervals_elapsed = intervals_elapsed.saturating_add(1); + + if search_slot != preferred_slot && self.slots[search_slot].matches_pref(pref) { + search_slots.push(self.slots[search_slot].lock()); + } + + search_slot = self.next_slot(search_slot); + } + + if let Poll::Ready(Some(locked)) = Pin::new(&mut search_slots).poll_next(cx) { + if locked.matches_pref(pref) { + return Poll::Ready(locked); + } + + continue; + } + + if intervals_elapsed > search_limit && search_slots.len() < search_limit { + if let Poll::Ready(slot) = listen_global.as_mut().poll(cx) { + if self.slots[slot].matches_pref(pref) { + search_slots.push(self.slots[slot].lock()); + } + + listen_global.as_mut().set(self.global.listen(pref)); + continue; + } + } + + return Poll::Pending; + }) + .await + } + + pub fn try_acquire_connected(&self) -> Option> { + Some( + self.try_acquire(AcquirePreference::Connected)? + .assert_connected(), + ) + } + + pub fn try_acquire_disconnected(&self) -> Option> { + Some( + self.try_acquire(AcquirePreference::Disconnected)? + .assert_disconnected(), + ) + } + + fn try_acquire(&self, pref: AcquirePreference) -> Option> { + let mut search_slot = current_thread_id() % self.slots.len(); + + for _ in 0..self.slots.len() { + if let Some(locked) = self.slots[search_slot].try_acquire(pref) { + return Some(locked); + } + + search_slot = self.next_slot(search_slot); + } + + None + } + + pub fn min_connections_listener(&self) -> EventListener { + self.global.min_connections_event.listen() + } + + pub fn iter_idle(&self) -> impl Iterator> + '_ { + self.slots.iter().filter_map(|slot| { + Some( + slot.try_acquire(AcquirePreference::Connected)? + .assert_connected(), + ) + }) + } + + pub async fn drain(&self, ref close: impl AsyncFn(ConnectedSlot) -> DisconnectedSlot) { + let mut closing = FuturesUnordered::new(); + + // We could try to be more efficient by only populating the `FuturesUnordered` for + // connected slots, but then we'd have to handle a disconnected slot becoming connected, + // which could happen concurrently. + // + // However, we don't *need* to be efficient when shutting down the pool. 
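+ //
+ // For reference, `PoolInner::close()` (in `inner.rs` below) drives this method
+ // with a closure that gracefully closes each connection, roughly:
+ //
+ // self.connections.drain(async |slot| {
+ // let (_res, slot) = connection::close(slot).await;
+ // slot
+ // }).await;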
+ for slot in &self.slots {
+ closing.push(async {
+ let locked = slot.lock().await;
+
+ let slot = match locked.try_connected() {
+ Ok(connected) => close(connected).await,
+ Err(disconnected) => disconnected,
+ };
+
+ // The pool is shutting down; don't wake any tasks that might have been interested
+ slot.leak();
+ });
+ }
+
+ while closing.next().await.is_some() {}
+ }
+
+ #[inline(always)]
+ fn next_slot(&self, slot: usize) -> usize {
+ // By adding a step that is coprime to `slots.len()` before taking the modulo,
+ // we visit the slots in a pseudo-random order, spreading the demand evenly.
+ //
+ // Because the step and the length are coprime, this cycles through every slot
+ // exactly once before returning to the original slot after `slots.len()`
+ // iterations: https://en.wikipedia.org/wiki/Modular_arithmetic#Congruence
+ //
+ // (547 is prime, so it is coprime to `slots.len()` unless the length is itself
+ // a multiple of 547.)
+ (slot + 547) % self.slots.len()
+ }
+}
+
+impl AcquirePreference {
+ #[inline(always)]
+ fn wants_connected(&self, is_connected: bool) -> bool {
+ match (self, is_connected) {
+ (Self::Connected, true) => true,
+ (Self::Disconnected, false) => true,
+ (Self::Either, _) => true,
+ _ => false,
+ }
+ }
+}
+
+impl<C> Slot<C> {
+ #[inline(always)]
+ fn matches_pref(&self, pref: AcquirePreference) -> bool {
+ !self.is_leaked() && pref.wants_connected(self.is_connected())
+ }
+
+ #[inline(always)]
+ fn is_connected(&self) -> bool {
+ self.connected.load(Ordering::Relaxed)
+ }
+
+ #[inline(always)]
+ fn is_locked(&self) -> bool {
+ self.locked.load(Ordering::Relaxed)
+ }
+
+ #[inline(always)]
+ fn is_leaked(&self) -> bool {
+ self.leaked.load(Ordering::Relaxed)
+ }
+
+ #[inline(always)]
+ fn set_is_connected(&self, connected: bool) {
+ let was_connected = self.connected.swap(connected, Ordering::Acquire);
+
+ match (connected, was_connected) {
+ (true, false) => {
+ // The slot just became connected;
+ // ensure the global count stays synchronized with `connected`.
+ self.global.num_connected.fetch_add(1, Ordering::Release);
+ }
+ (false, true) => {
+ // The slot just became disconnected.
+ self.global.num_connected.fetch_sub(1, Ordering::Release);
+ }
+ _ => (),
+ }
+ }
+
+ async fn acquire(self: &Arc<Self>, pref: AcquirePreference) -> SlotGuard<C> {
+ loop {
+ if self.matches_pref(pref) {
+ tracing::trace!(slot_index=%self.index, "waiting for lock");
+
+ let locked = self.lock().await;
+
+ if locked.matches_pref(pref) {
+ return locked;
+ }
+ }
+
+ match pref {
+ AcquirePreference::Connected => {
+ listener!(self.unlock_event => listener);
+ listener.await;
+ }
+ AcquirePreference::Disconnected => {
+ listener!(self.disconnect_event => listener);
+ listener.await
+ }
+ AcquirePreference::Either => {
+ listener!(self.unlock_event => unlock_listener);
+ listener!(self.disconnect_event => disconnect_listener);
+ race(unlock_listener, disconnect_listener).await.ok();
+ }
+ }
+ }
+ }
+
+ fn try_acquire(self: &Arc<Self>, pref: AcquirePreference) -> Option<SlotGuard<C>> {
+ if self.matches_pref(pref) {
+ let locked = self.try_lock()?;
+
+ if locked.matches_pref(pref) {
+ return Some(locked);
+ }
+ }
+
+ None
+ }
+
+ async fn lock(self: &Arc<Self>) -> SlotGuard<C> {
+ let locked = crate::sync::lock_arc(&self.connection).await;
+
+ self.locked.store(true, Ordering::Relaxed);
+
+ SlotGuard {
+ slot: self.clone(),
+ locked: Some(locked),
+ }
+ }
+
+ fn try_lock(self: &Arc<Self>) -> Option<SlotGuard<C>> {
+ let locked = crate::sync::try_lock_arc(&self.connection)?;
+
+ self.locked.store(true, Ordering::Relaxed);
+
+ Some(SlotGuard {
+ slot: self.clone(),
+ locked: Some(locked),
+ })
+ }
+}
+
+impl<C> SlotGuard<C> {
+ #[inline(always)]
+ fn get(&self) -> &Option<C> {
+ self.locked.as_ref().expect(EXPECT_LOCKED)
+ }
+
+ #[inline(always)]
+ fn get_mut(&mut self) -> &mut Option<C> {
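+ // `locked` is only `None` once `drop_without_notify()` has taken the guard
+ // during teardown, so the plain accessors may assume it is still populated.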
self.locked.as_mut().expect(EXPECT_LOCKED) + } + + #[inline(always)] + fn matches_pref(&self, pref: AcquirePreference) -> bool { + !self.slot.is_leaked() && pref.wants_connected(self.is_connected()) + } + + #[inline(always)] + fn is_connected(&self) -> bool { + self.get().is_some() + } + + fn try_connected(self) -> Result, DisconnectedSlot> { + if self.is_connected() { + Ok(ConnectedSlot(self)) + } else { + Err(DisconnectedSlot(self)) + } + } + + fn assert_connected(self) -> ConnectedSlot { + assert!(self.is_connected()); + ConnectedSlot(self) + } + + fn assert_disconnected(self) -> DisconnectedSlot { + assert!(!self.is_connected()); + + DisconnectedSlot(self) + } + + /// Updates `Slot::connected` without notifying the `ConnectionSet`. + /// + /// Returns `Some(connected)` or `None` if this guard was already dropped. + fn drop_without_notify(&mut self) -> Option { + self.locked.take().map(|locked| { + let connected = locked.is_some(); + self.slot.set_is_connected(connected); + self.slot.locked.store(false, Ordering::Release); + connected + }) + } +} + +const EXPECT_LOCKED: &str = "BUG: `SlotGuard::locked` should not be `None` in normal operation"; +const EXPECT_CONNECTED: &str = "BUG: `ConnectedSlot` expects `Slot::connection` to be `Some`"; + +impl ConnectedSlot { + pub fn take(mut self) -> (C, DisconnectedSlot) { + let conn = self.0.get_mut().take().expect(EXPECT_CONNECTED); + (conn, self.0.assert_disconnected()) + } +} + +impl Deref for ConnectedSlot { + type Target = C; + + #[inline(always)] + fn deref(&self) -> &Self::Target { + self.0.get().as_ref().expect(EXPECT_CONNECTED) + } +} + +impl DerefMut for ConnectedSlot { + #[inline(always)] + fn deref_mut(&mut self) -> &mut Self::Target { + self.0.get_mut().as_mut().expect(EXPECT_CONNECTED) + } +} + +impl DisconnectedSlot { + pub fn put(mut self, conn: C) -> ConnectedSlot { + *self.0.get_mut() = Some(conn); + ConnectedSlot(self.0) + } + + pub fn leak(mut self) { + self.0.slot.leaked.store(true, Ordering::Release); + self.0.drop_without_notify(); + } +} + +impl Drop for SlotGuard { + fn drop(&mut self) { + let Some(connected) = self.drop_without_notify() else { + return; + }; + + let event = if connected { + &self.slot.global.unlock_event + } else { + &self.slot.global.disconnect_event + }; + + if event.notify(1.tag(self.slot.index).additional()) != 0 { + return; + } + + if connected { + self.slot.unlock_event.notify(1); + return; + } + + if self.slot.disconnect_event.notify(1) != 0 { + return; + } + + if self.slot.global.num_connected() < self.slot.global.min_connections { + self.slot.global.min_connections_event.notify(1); + } + } +} + +impl Global { + #[inline(always)] + fn num_connected(&self) -> usize { + self.num_connected.load(Ordering::Relaxed) + } + + async fn listen(&self, pref: AcquirePreference) -> usize { + match pref { + AcquirePreference::Either => race(self.listen_unlocked(), self.listen_disconnected()) + .await + .unwrap_or_else(|slot| slot), + AcquirePreference::Connected => self.listen_unlocked().await, + AcquirePreference::Disconnected => self.listen_disconnected().await, + } + } + + async fn listen_unlocked(&self) -> usize { + listener!(self.unlock_event => listener); + listener.await + } + + async fn listen_disconnected(&self) -> usize { + listener!(self.disconnect_event => listener); + listener.await + } +} + +fn current_thread_id() -> usize { + // FIXME: this can be replaced when this is stabilized: + // https://doc.rust-lang.org/stable/std/thread/struct.ThreadId.html#method.as_u64 + static THREAD_ID: AtomicUsize = 
AtomicUsize::new(0); + + thread_local! { + // `SeqCst` is possibly too strong since we don't need synchronization with + // any other variable. I'm not confident enough in my understanding of atomics to be certain, + // especially with regards to weakly ordered architectures. + // + // However, this is literally only done once on each thread, so it doesn't really matter. + static CURRENT_THREAD_ID: usize = THREAD_ID.fetch_add(1, Ordering::SeqCst); + } + + CURRENT_THREAD_ID.with(|i| *i) +} diff --git a/sqlx-core/src/pool/idle.rs b/sqlx-core/src/pool/idle.rs new file mode 100644 index 0000000000..602ed3c5c8 --- /dev/null +++ b/sqlx-core/src/pool/idle.rs @@ -0,0 +1,100 @@ +use crate::connection::Connection; +use crate::database::Database; +use crate::pool::connection::{Floating, Idle, ConnectionInner}; +use crate::pool::inner::PoolInner; +use crossbeam_queue::ArrayQueue; +use event_listener::Event; +use futures_util::FutureExt; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +use event_listener::listener; + +pub struct IdleQueue { + queue: ArrayQueue>, + // Keep a separate count because `ArrayQueue::len()` loops until the head and tail pointers + // stop changing, which may never happen at high contention. + len: AtomicUsize, + release_event: Event, + fair: bool, +} + +impl IdleQueue { + pub fn new(fair: bool, cap: usize) -> Self { + Self { + queue: ArrayQueue::new(cap), + len: AtomicUsize::new(0), + release_event: Event::new(), + fair, + } + } + + pub fn len(&self) -> usize { + self.len.load(Ordering::Acquire) + } + + pub async fn acquire(&self, pool: &Arc>) -> Floating> { + let mut should_wait = self.fair && self.release_event.total_listeners() > 0; + + for attempt in 1usize.. { + if should_wait { + listener!(self.release_event => release_event); + release_event.await; + } + + if let Some(conn) = self.try_acquire(pool) { + return conn; + } + + should_wait = true; + + if attempt == 2 { + tracing::warn!( + "unable to acquire a connection after sleeping; this may indicate a bug" + ); + } + } + + panic!("BUG: was never able to acquire a connection despite waking many times") + } + + pub fn try_acquire(&self, pool: &Arc>) -> Option>> { + self.len + .fetch_update(Ordering::Release, Ordering::Acquire, |len| { + len.checked_sub(1) + }) + .ok() + .and_then(|_| { + let conn = self.queue.pop()?; + + Some(Floating::from_idle(conn, Arc::clone(pool))) + }) + } + + pub fn release(&self, conn: Floating>) { + let Floating { + inner: conn, + permit, + } = conn.into_idle(); + + self.queue + .push(conn) + .unwrap_or_else(|_| panic!("BUG: idle queue capacity exceeded")); + + self.len.fetch_add(1, Ordering::Release); + + self.release_event.notify(1usize); + + // Don't decrease the size. + permit.consume(); + } + + pub fn drain(&self, pool: &PoolInner) { + while let Some(conn) = self.queue.pop() { + // Hopefully will send at least a TCP FIN packet. 
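+ // (`now_or_never()` polls the close future exactly once without awaiting it;
+ // a single poll is usually enough to queue the shutdown write, and anything
+ // that would require further polling is abandoned, hence "hopefully".)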
+ conn.live.raw.close_hard().now_or_never(); + + pool.counter.release_permit(pool); + } + } +} diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index b698dc9df0..1ae687f1d1 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -1,33 +1,34 @@ -use super::connection::{Floating, Idle, Live}; -use crate::connection::ConnectOptions; -use crate::connection::Connection; +use super::connection::ConnectionInner; use crate::database::Database; use crate::error::Error; -use crate::pool::{deadline_as_timeout, CloseEvent, Pool, PoolOptions}; -use crossbeam_queue::ArrayQueue; - -use crate::sync::{AsyncSemaphore, AsyncSemaphoreReleaser}; +use crate::pool::{connection, CloseEvent, Pool, PoolConnection, PoolConnector, PoolOptions}; use std::cmp; -use std::future::{self, Future}; -use std::pin::pin; -use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering}; -use std::sync::{Arc, RwLock}; -use std::task::Poll; +use std::future::Future; +use std::ops::ControlFlow; +use std::pin::{pin, Pin}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::{Arc, Weak}; +use std::task::{Context, Poll}; +use crate::connection::Connection; +use crate::ext::future::race; use crate::logger::private_level_filter_to_trace_level; -use crate::pool::options::PoolConnectionMetadata; -use crate::private_tracing_dynamic_event; -use futures_util::FutureExt; +use crate::pool::connect::{ConnectTaskShared, ConnectionCounter, ConnectionId, DynConnector}; +use crate::pool::connection_set::{ConnectedSlot, ConnectionSet, DisconnectedSlot}; +use crate::{private_tracing_dynamic_event, rt}; +use event_listener::listener; +use futures_util::future::{self}; use std::time::{Duration, Instant}; use tracing::Level; +const GRACEFUL_CLOSE_TIMEOUT: Duration = Duration::from_secs(5); +const TEST_BEFORE_ACQUIRE_TIMEOUT: Duration = Duration::from_secs(60); + pub(crate) struct PoolInner { - pub(super) connect_options: RwLock::Options>>, - pub(super) idle_conns: ArrayQueue>, - pub(super) semaphore: AsyncSemaphore, - pub(super) size: AtomicU32, - pub(super) num_idle: AtomicUsize, + pub(super) connector: DynConnector, + pub(super) counter: ConnectionCounter, + pub(super) connections: ConnectionSet>, is_closed: AtomicBool, pub(super) on_closed: event_listener::Event, pub(super) options: PoolOptions, @@ -38,50 +39,30 @@ pub(crate) struct PoolInner { impl PoolInner { pub(super) fn new_arc( options: PoolOptions, - connect_options: ::Options, + connector: impl PoolConnector, ) -> Arc { - let capacity = options.max_connections as usize; - - let semaphore_capacity = if let Some(parent) = &options.parent_pool { - assert!(options.max_connections <= parent.options().max_connections); - assert_eq!(options.fair, parent.options().fair); - // The child pool must steal permits from the parent - 0 - } else { - capacity - }; - - let pool = Self { - connect_options: RwLock::new(Arc::new(connect_options)), - idle_conns: ArrayQueue::new(capacity), - semaphore: AsyncSemaphore::new(options.fair, semaphore_capacity), - size: AtomicU32::new(0), - num_idle: AtomicUsize::new(0), + let pool = Arc::new(Self { + connector: DynConnector::new(connector), + counter: ConnectionCounter::new(), + connections: ConnectionSet::new(options.min_connections..=options.max_connections), is_closed: AtomicBool::new(false), on_closed: event_listener::Event::new(), acquire_time_level: private_level_filter_to_trace_level(options.acquire_time_level), acquire_slow_level: private_level_filter_to_trace_level(options.acquire_slow_level), 
options, - }; - - let pool = Arc::new(pool); + }); spawn_maintenance_tasks(&pool); pool } - pub(super) fn size(&self) -> u32 { - self.size.load(Ordering::Acquire) + pub(super) fn size(&self) -> usize { + self.connections.num_connected() } pub(super) fn num_idle(&self) -> usize { - // We don't use `self.idle_conns.len()` as it waits for the internal - // head and tail pointers to stop changing for a moment before calculating the length, - // which may take a long time at high levels of churn. - // - // By maintaining our own atomic count, we avoid that issue entirely. - self.num_idle.load(Ordering::Acquire) + self.connections.count_idle() } pub(super) fn is_closed(&self) -> bool { @@ -96,26 +77,11 @@ impl PoolInner { pub(super) fn close(self: &Arc) -> impl Future + '_ { self.mark_closed(); - async move { - // For child pools, we need to acquire permits we actually have rather than - // max_connections - let permits_to_acquire = if self.options.parent_pool.is_some() { - // Child pools start with 0 permits, so we acquire based on current size - self.size() - } else { - // Parent pools can acquire all max_connections permits - self.options.max_connections - }; - - let _permits = self.semaphore.acquire(permits_to_acquire).await; - - while let Some(idle) = self.idle_conns.pop() { - let _ = idle.live.raw.close().await; - } - - self.num_idle.store(0, Ordering::Release); - self.size.store(0, Ordering::Release); - } + // Keep clearing the idle queue as connections are released until the count reaches zero. + self.connections.drain(async |slot| { + let (_res, slot) = connection::close(slot).await; + slot + }) } pub(crate) fn close_event(&self) -> CloseEvent { @@ -124,177 +90,63 @@ impl PoolInner { } } - /// Attempt to pull a permit from `self.semaphore` or steal one from the parent. - /// - /// If we steal a permit from the parent but *don't* open a connection, - /// it should be returned to the parent. - async fn acquire_permit(self: &Arc) -> Result, Error> { - let parent = self - .parent() - // If we're already at the max size, we shouldn't try to steal from the parent. - // This is just going to cause unnecessary churn in `acquire()`. - .filter(|_| self.size() < self.options.max_connections); - - let mut acquire_self = pin!(self.semaphore.acquire(1).fuse()); - let mut close_event = pin!(self.close_event()); - - if let Some(parent) = parent { - let mut acquire_parent = pin!(parent.0.semaphore.acquire(1)); - let mut parent_close_event = pin!(parent.0.close_event()); - - let mut poll_parent = false; - - future::poll_fn(|cx| { - if close_event.as_mut().poll(cx).is_ready() { - return Poll::Ready(Err(Error::PoolClosed)); - } - - if parent_close_event.as_mut().poll(cx).is_ready() { - // Propagate the parent's close event to the child. - self.mark_closed(); - return Poll::Ready(Err(Error::PoolClosed)); - } - - if let Poll::Ready(permit) = acquire_self.as_mut().poll(cx) { - return Poll::Ready(Ok(permit)); - } - - // Don't try the parent right away. 
- if poll_parent { - acquire_parent.as_mut().poll(cx).map(Ok) - } else { - poll_parent = true; - cx.waker().wake_by_ref(); - Poll::Pending - } - }) - .await - } else { - close_event.do_until(acquire_self).await - } - } - - fn parent(&self) -> Option<&Pool> { + pub(super) fn parent(&self) -> Option<&Pool> { self.options.parent_pool.as_ref() } #[inline] - pub(super) fn try_acquire(self: &Arc) -> Option>> { + pub(super) fn try_acquire(self: &Arc) -> Option>> { if self.is_closed() { return None; } - let permit = self.semaphore.try_acquire(1)?; - - self.pop_idle(permit).ok() + self.connections.try_acquire_connected() } - fn pop_idle<'a>( - self: &'a Arc, - permit: AsyncSemaphoreReleaser<'a>, - ) -> Result>, AsyncSemaphoreReleaser<'a>> { - if let Some(idle) = self.idle_conns.pop() { - self.num_idle.fetch_sub(1, Ordering::AcqRel); - Ok(Floating::from_idle(idle, (*self).clone(), permit)) - } else { - Err(permit) + pub(super) async fn acquire(self: &Arc) -> Result, Error> { + if self.is_closed() { + return Err(Error::PoolClosed); } - } - pub(super) fn release(&self, floating: Floating>) { - // `options.after_release` and other checks are in `PoolConnection::return_to_pool()`. - - let Floating { inner: idle, guard } = floating.into_idle(); + let acquire_started_at = Instant::now(); - if self.idle_conns.push(idle).is_err() { - panic!("BUG: connection queue overflow in release()"); - } + // Lazily allocated `Arc` + let mut connect_shared = None; - // NOTE: we need to make sure we drop the permit *after* we push to the idle queue - // don't decrease the size - guard.release_permit(); + let res = { + // Pinned to the stack without allocating + listener!(self.on_closed => close_listener); + let mut deadline = pin!(rt::sleep(self.options.acquire_timeout)); + let mut acquire_inner = pin!(self.acquire_inner(&mut connect_shared)); - self.num_idle.fetch_add(1, Ordering::AcqRel); - } - - /// Try to atomically increment the pool size for a new connection. - /// - /// Returns `Err` if the pool is at max capacity already or is closed. - pub(super) fn try_increment_size<'a>( - self: &'a Arc, - permit: AsyncSemaphoreReleaser<'a>, - ) -> Result, AsyncSemaphoreReleaser<'a>> { - let result = self - .size - .fetch_update(Ordering::AcqRel, Ordering::Acquire, |size| { + std::future::poll_fn(|cx| { if self.is_closed() { - return None; + return Poll::Ready(Err(Error::PoolClosed)); } - size.checked_add(1) - .filter(|size| size <= &self.options.max_connections) - }); - - match result { - // we successfully incremented the size - Ok(_) => Ok(DecrementSizeGuard::from_permit((*self).clone(), permit)), - // the pool is at max capacity or is closed - Err(_) => Err(permit), - } - } + // The result doesn't matter so much as the wakeup + let _ = Pin::new(&mut close_listener).poll(cx); - pub(super) async fn acquire(self: &Arc) -> Result>, Error> { - if self.is_closed() { - return Err(Error::PoolClosed); - } - - let acquire_started_at = Instant::now(); - let deadline = acquire_started_at + self.options.acquire_timeout; - - let acquired = crate::rt::timeout( - self.options.acquire_timeout, - async { - loop { - // Handles the close-event internally - let permit = self.acquire_permit().await?; - - - // First attempt to pop a connection from the idle queue. - let guard = match self.pop_idle(permit) { - - // Then, check that we can use it... - Ok(conn) => match check_idle_conn(conn, &self.options).await { - - // All good! 
- Ok(live) => return Ok(live), - - // if the connection isn't usable for one reason or another, - // we get the `DecrementSizeGuard` back to open a new one - Err(guard) => guard, - }, - Err(permit) => if let Ok(guard) = self.try_increment_size(permit) { - // we can open a new connection - guard - } else { - // This can happen for a child pool that's at its connection limit, - // or if the pool was closed between `acquire_permit()` and - // `try_increment_size()`. - tracing::debug!("woke but was unable to acquire idle connection or open new one; retrying"); - // If so, we're likely in the current-thread runtime if it's Tokio, - // and so we should yield to let any spawned return_to_pool() tasks - // execute. - crate::rt::yield_now().await; - continue; - } - }; - - // Attempt to connect... - return self.connect(deadline, guard).await; + if let Poll::Ready(()) = deadline.as_mut().poll(cx) { + return Poll::Ready(Err(Error::PoolTimedOut { + last_connect_error: None, + })); } - } - ) + + acquire_inner.as_mut().poll(cx) + }) .await - .map_err(|_| Error::PoolTimedOut)??; + }; + + let acquired = res.map_err(|e| match e { + Error::PoolTimedOut { + last_connect_error: None, + } => Error::PoolTimedOut { + last_connect_error: connect_shared + .and_then(|shared| Some(shared.take_error()?.into())), + }, + e => e, + })?; let acquired_after = acquire_started_at.elapsed(); @@ -322,302 +174,273 @@ impl PoolInner { Ok(acquired) } - pub(super) async fn connect( + async fn acquire_inner( self: &Arc, - deadline: Instant, - guard: DecrementSizeGuard, - ) -> Result>, Error> { - if self.is_closed() { - return Err(Error::PoolClosed); - } + connect_shared: &mut Option>, + ) -> Result, Error> { + tracing::trace!("waiting for any connection"); + + let disconnected = match self.connections.acquire_any().await { + Ok(conn) => match finish_acquire(self, conn).await { + Ok(conn) => return Ok(conn), + Err(slot) => slot, + }, + Err(slot) => slot, + }; - let mut backoff = Duration::from_millis(10); - let max_backoff = deadline_as_timeout(deadline)? / 5; + let mut connect_task = self.connector.connect( + Pool(self.clone()), + ConnectionId::next(), + disconnected, + connect_shared.insert(ConnectTaskShared::new_arc()).clone(), + ); loop { - let timeout = deadline_as_timeout(deadline)?; - - // clone the connect options arc so it can be used without holding the RwLockReadGuard - // across an async await point - let connect_options = self - .connect_options - .read() - .expect("write-lock holder panicked") - .clone(); - - // result here is `Result, TimeoutError>` - // if this block does not return, sleep for the backoff timeout and try again - match crate::rt::timeout(timeout, connect_options.connect()).await { - // successfully established connection - Ok(Ok(mut raw)) => { - // See comment on `PoolOptions::after_connect` - let meta = PoolConnectionMetadata { - age: Duration::ZERO, - idle_for: Duration::ZERO, - }; - - let res = if let Some(callback) = &self.options.after_connect { - callback(&mut raw, meta).await - } else { - Ok(()) - }; - - match res { - Ok(()) => return Ok(Floating::new_live(raw, guard)), - Err(error) => { - tracing::error!(%error, "error returned from after_connect"); - // The connection is broken, don't try to close nicely. - let _ = raw.close_hard().await; - - // Fall through to the backoff. - } - } - } - - // an IO error while connecting is assumed to be the system starting up - Ok(Err(Error::Io(e))) if e.kind() == std::io::ErrorKind::ConnectionRefused => (), - - // We got a transient database error, retry. 
- Ok(Err(Error::Database(error))) if error.is_transient_in_connect_phase() => (), - - // Any other error while connection should immediately - // terminate and bubble the error up + match race(&mut connect_task, self.connections.acquire_connected()).await { + Ok(Ok(conn)) => return Ok(conn), Ok(Err(e)) => return Err(e), - - // timed out - Err(_) => return Err(Error::PoolTimedOut), + Err(conn) => match finish_acquire(self, conn).await { + Ok(conn) => return Ok(conn), + Err(_) => continue, + }, } - - // If the connection is refused, wait in exponentially - // increasing steps for the server to come up, - // capped by a factor of the remaining time until the deadline - crate::rt::sleep(backoff).await; - backoff = cmp::min(backoff * 2, max_backoff); } } - /// Try to maintain `min_connections`, returning any errors (including `PoolTimedOut`). - pub async fn try_min_connections(self: &Arc, deadline: Instant) -> Result<(), Error> { - while self.size() < self.options.min_connections { - // Don't wait for a semaphore permit. - // - // If no extra permits are available then we shouldn't be trying to spin up - // connections anyway. - let Some(permit) = self.semaphore.try_acquire(1) else { - return Ok(()); - }; - - // We must always obey `max_connections`. - let Some(guard) = self.try_increment_size(permit).ok() else { - return Ok(()); - }; - - // We skip `after_release` since the connection was never provided to user code - // besides `after_connect`, if they set it. - self.release(self.connect(deadline, guard).await?); + pub(crate) async fn try_min_connections( + self: &Arc, + deadline: Option, + ) -> Result<(), Error> { + let shared = ConnectTaskShared::new_arc(); + + let connect_min_connections = future::try_join_all( + (self.connections.num_connected()..self.options.min_connections) + .filter_map(|_| self.connections.try_acquire_disconnected()) + .map(|slot| { + self.connector.connect( + Pool(self.clone()), + ConnectionId::next(), + slot, + shared.clone(), + ) + }), + ); + + let conns = if let Some(deadline) = deadline { + match rt::timeout_at(deadline, connect_min_connections).await { + Ok(Ok(conns)) => conns, + Err(_) | Ok(Err(Error::PoolTimedOut { .. })) => { + return Err(Error::PoolTimedOut { + last_connect_error: shared.take_error().map(Box::new), + }); + } + Ok(Err(e)) => return Err(e), + } + } else { + connect_min_connections.await? + }; + + for mut conn in conns { + // Bypass `after_release` + drop(conn.return_to_pool()); } Ok(()) } - - /// Attempt to maintain `min_connections`, logging if unable. - pub async fn min_connections_maintenance(self: &Arc, deadline: Option) { - let deadline = deadline.unwrap_or_else(|| { - // Arbitrary default deadline if the caller doesn't care. - Instant::now() + Duration::from_secs(300) - }); - - match self.try_min_connections(deadline).await { - Ok(()) => (), - Err(Error::PoolClosed) => (), - Err(Error::PoolTimedOut) => { - tracing::debug!("unable to complete `min_connections` maintenance before deadline") - } - Err(error) => tracing::debug!(%error, "error while maintaining min_connections"), - } - } } impl Drop for PoolInner { fn drop(&mut self) { self.mark_closed(); - - if let Some(parent) = &self.options.parent_pool { - // Release the stolen permits. - parent.0.semaphore.release(self.semaphore.permits()); - } } } -/// Returns `true` if the connection has exceeded `options.max_lifetime` if set, `false` otherwise. 
-pub(super) fn is_beyond_max_lifetime<DB: Database>(
- live: &Live<DB>,
- options: &PoolOptions<DB>,
-) -> bool {
- options
- .max_lifetime
- .is_some_and(|max| live.created_at.elapsed() > max)
-}
-
-/// Returns `true` if the connection has exceeded `options.idle_timeout` if set, `false` otherwise.
-fn is_beyond_idle_timeout<DB: Database>(idle: &Idle<DB>, options: &PoolOptions<DB>) -> bool {
- options
- .idle_timeout
- .is_some_and(|timeout| idle.idle_since.elapsed() > timeout)
-}
+/// Execute `test_before_acquire` and/or `before_acquire` in a background task, if applicable.
+///
+/// Otherwise, immediately returns the connection.
+async fn finish_acquire<DB: Database>(
+ pool: &Arc<PoolInner<DB>>,
+ mut conn: ConnectedSlot<ConnectionInner<DB>>,
+) -> Result<PoolConnection<DB>, DisconnectedSlot<ConnectionInner<DB>>> {
+ struct SpawnOnDrop<F: Future + Send + 'static>(Option<Pin<Box<F>>>)
+ where
+ F::Output: Send + 'static;
+
+ impl<F: Future + Send + 'static> Future for SpawnOnDrop<F>
+ where
+ F::Output: Send + 'static,
+ {
+ type Output = F::Output;
+
+ #[inline(always)]
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ let inner = self.0.as_mut().expect("BUG: inner future taken");
+ let out = std::task::ready!(inner.as_mut().poll(cx));
+ // Completed normally; clear the slot so `Drop` doesn't spawn a finished future.
+ self.0 = None;
+ Poll::Ready(out)
+ }
+ }

-async fn check_idle_conn<DB: Database>(
- mut conn: Floating<Idle<DB>>,
- options: &PoolOptions<DB>,
-) -> Result<Floating<Live<DB>>, DecrementSizeGuard<DB>> {
- if options.test_before_acquire {
- // Check that the connection is still live
- if let Err(error) = conn.ping().await {
- // an error here means the other end has hung up or we lost connectivity
- // either way we're fine to just discard the connection
- // the error itself here isn't necessarily unexpected so WARN is too strong
- tracing::info!(%error, "ping on idle connection returned error");
- // connection is broken so don't try to close nicely
- return Err(conn.close_hard().await);
+ impl<F: Future + Send + 'static> Drop for SpawnOnDrop<F>
+ where
+ F::Output: Send + 'static,
+ {
+ fn drop(&mut self) {
+ // Only spawn if the inner future was dropped before it completed.
+ if let Some(inner) = self.0.take() {
+ rt::try_spawn(inner);
+ }
+ }
 }
- }

- if let Some(test) = &options.before_acquire {
- let meta = conn.metadata();
- match test(&mut conn.live.raw, meta).await {
- Ok(false) => {
- // connection was rejected by user-defined hook, close nicely
- return Err(conn.close().await);
+ async fn finish_inner<DB: Database>(
+ conn: &mut ConnectedSlot<ConnectionInner<DB>>,
+ pool: &PoolInner<DB>,
+ ) -> ControlFlow<()> {
+ // Check that the connection is still live
+ if pool.options.test_before_acquire {
+ if let Err(error) = conn.raw.ping().await {
+ // an error here means the other end has hung up or we lost connectivity
+ // either way we're fine to just discard the connection
+ // the error itself here isn't necessarily unexpected so WARN is too strong
+ tracing::info!(%error, connection_id=%conn.id, "ping on idle connection returned error");
+ return ControlFlow::Break(());
+ }
+ }

- Err(error) => {
- tracing::warn!(%error, "error from `before_acquire`");
- // connection is broken so don't try to close nicely
- return Err(conn.close_hard().await);
- }
+ if let Some(test) = &pool.options.before_acquire {
+ let meta = conn.idle_metadata();
+ match test(&mut conn.raw, meta).await {
+ Ok(false) => {
+ // connection was rejected by user-defined hook, close nicely
+ tracing::debug!(connection_id=%conn.id, "connection rejected by `before_acquire`");
+ return ControlFlow::Break(());
+ }
+
+ Err(error) => {
+ tracing::warn!(%error, "error from `before_acquire`");
+ return ControlFlow::Break(());
+ }

- Ok(true) => {}
+ Ok(true) => (),
+ }
 }
+
+ // Checks passed
+ ControlFlow::Continue(())
+ }
+
+ if pool.options.test_before_acquire || pool.options.before_acquire.is_some() {
+ let pool = pool.clone();
+
+ // Spawn a task on-drop so the call may complete even if `acquire()` is cancelled.
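+ // (The `SpawnOnDrop` wrapper above is what makes the following `.await`
+ // cancel-safe: if `acquire()` is dropped mid-check, the remaining work is
+ // handed to the runtime via `rt::try_spawn()`, so the slot is still either
+ // returned to the pool or closed instead of being left half-checked.)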
+ conn = SpawnOnDrop(Some(Box::pin(async move { + match rt::timeout(TEST_BEFORE_ACQUIRE_TIMEOUT, finish_inner(&mut conn, &pool)).await { + Ok(ControlFlow::Continue(())) => { + Ok(conn) + } + Ok(ControlFlow::Break(())) => { + // Connection rejected by user-defined hook, attempt to close nicely + let (_res, slot) = connection::close(conn).await; + Err(slot) + } + Err(_) => { + tracing::info!(connection_id=%conn.id, "`before_acquire` checks timed out, closing connection"); + let (_res, slot) = connection::close_hard(conn).await; + Err(slot) + } + } + }))).await?; } - // No need to re-connect; connection is alive or we don't care - Ok(conn.into_live()) + tracing::debug!( + target: "sqlx::pool", + connection_id=%conn.id, + "acquired idle connection" + ); + + Ok(PoolConnection::new(conn)) } fn spawn_maintenance_tasks(pool: &Arc>) { - // NOTE: use `pool_weak` for the maintenance tasks - // so they don't keep `PoolInner` from being dropped. - let pool_weak = Arc::downgrade(pool); + if pool.options.min_connections > 0 { + // NOTE: use `pool_weak` for the maintenance tasks + // so they don't keep `PoolInner` from being dropped. + let pool_weak = Arc::downgrade(pool); + let mut close_event = pool.close_event(); + + rt::spawn(async move { + close_event + .do_until(check_min_connections(pool_weak)) + .await + .ok(); + }); + } - let period = match (pool.options.max_lifetime, pool.options.idle_timeout) { + let check_interval = match (pool.options.max_lifetime, pool.options.idle_timeout) { (Some(it), None) | (None, Some(it)) => it, - (Some(a), Some(b)) => cmp::min(a, b), - - (None, None) => { - if pool.options.min_connections > 0 { - crate::rt::spawn(async move { - if let Some(pool) = pool_weak.upgrade() { - pool.min_connections_maintenance(None).await; - } - }); - } - - return; - } + (None, None) => return, }; - // Immediately cancel this task if the pool is closed. + let pool_weak = Arc::downgrade(pool); let mut close_event = pool.close_event(); - crate::rt::spawn(async move { + rt::spawn(async move { let _ = close_event - .do_until(async { - // If the last handle to the pool was dropped while we were sleeping - while let Some(pool) = pool_weak.upgrade() { - if pool.is_closed() { - return; - } - - let next_run = Instant::now() + period; - - // Go over all idle connections, check for idleness and lifetime, - // and if we have fewer than min_connections after reaping a connection, - // open a new one immediately. Note that other connections may be popped from - // the queue in the meantime - that's fine, there is no harm in checking more - for _ in 0..pool.num_idle() { - if let Some(conn) = pool.try_acquire() { - if is_beyond_idle_timeout(&conn, &pool.options) - || is_beyond_max_lifetime(&conn, &pool.options) - { - let _ = conn.close().await; - pool.min_connections_maintenance(Some(next_run)).await; - } else { - pool.release(conn.into_live()); - } - } - } - - // Don't hold a reference to the pool while sleeping. - drop(pool); - - if let Some(duration) = next_run.checked_duration_since(Instant::now()) { - // `async-std` doesn't have a `sleep_until()` - crate::rt::sleep(duration).await; - } else { - // `next_run` is in the past, just yield. - crate::rt::yield_now().await; - } - } - }) + .do_until(check_idle_conns(pool_weak, check_interval)) .await; }); } -/// RAII guard returned by `Pool::try_increment_size()` and others. -/// -/// Will decrement the pool size if dropped, to avoid semantically "leaking" connections -/// (where the pool thinks it has more connections than it does). 
-pub(in crate::pool) struct DecrementSizeGuard { - pub(crate) pool: Arc>, - cancelled: bool, -} +async fn check_idle_conns(pool_weak: Weak>, check_interval: Duration) { + let mut interval = pin!(rt::interval_after(check_interval)); -impl DecrementSizeGuard { - /// Create a new guard that will release a semaphore permit on-drop. - pub fn new_permit(pool: Arc>) -> Self { - Self { - pool, - cancelled: false, + while let Some(pool) = pool_weak.upgrade() { + if pool.is_closed() { + return; } - } - pub fn from_permit(pool: Arc>, permit: AsyncSemaphoreReleaser<'_>) -> Self { - // here we effectively take ownership of the permit - permit.disarm(); - Self::new_permit(pool) - } + // Go over all idle connections, check for idleness and lifetime, + // and if we have fewer than min_connections after reaping a connection, + // open a new one immediately. + for conn in pool.connections.iter_idle() { + if conn.is_beyond_idle_timeout(&pool.options) + || conn.is_beyond_max_lifetime(&pool.options) + { + // Dropping the slot will check if the connection needs to be re-made. + let _ = connection::close(conn).await; + } + } - /// Release the semaphore permit without decreasing the pool size. - /// - /// If the permit was stolen from the pool's parent, it will be returned to the child's semaphore. - fn release_permit(self) { - self.pool.semaphore.release(1); - self.cancel(); - } + // Don't hold a reference to the pool while sleeping. + drop(pool); - pub fn cancel(mut self) { - self.cancelled = true; + interval.as_mut().tick().await; } } -impl Drop for DecrementSizeGuard { - fn drop(&mut self) { - if !self.cancelled { - self.pool.size.fetch_sub(1, Ordering::AcqRel); +async fn check_min_connections(pool_weak: Weak>) { + while let Some(pool) = pool_weak.upgrade() { + if pool.is_closed() { + return; + } - // and here we release the permit we got on construction - self.pool.semaphore.release(1); + match pool.try_min_connections(None).await { + Ok(()) => { + let listener = pool.connections.min_connections_listener(); + + // Important: don't hold a strong ref while sleeping + drop(pool); + + listener.await; + } + Err(e) => { + tracing::warn!( + target: "sqlx::pool::maintenance", + min_connections=pool.options.min_connections, + num_connected=pool.connections.num_connected(), + "unable to maintain `min_connections`: {e:?}", + ); + } } } } diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index f11ff1d76a..224ee8ffb6 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -56,21 +56,20 @@ use std::fmt; use std::future::Future; -use std::pin::{pin, Pin}; +use std::pin::Pin; use std::sync::Arc; use std::task::{ready, Context, Poll}; -use std::time::{Duration, Instant}; - -use event_listener::EventListener; -use futures_core::FusedFuture; -use futures_util::FutureExt; use crate::connection::Connection; use crate::database::Database; use crate::error::Error; +use crate::ext::future::race; use crate::sql_str::SqlSafeStr; use crate::transaction::Transaction; - +use event_listener::EventListener; +use futures_core::FusedFuture; +use tracing::Instrument; +pub use self::connect::{PoolConnectMetadata, PoolConnector}; pub use self::connection::PoolConnection; use self::inner::PoolInner; #[doc(hidden)] @@ -83,10 +82,17 @@ mod executor; #[macro_use] pub mod maybe; +mod connect; mod connection; mod inner; + +// mod idle; mod options; +// mod shard; + +mod connection_set; + /// An asynchronous pool of SQLx database connections. 
/// /// Create a pool with [Pool::connect] or [Pool::connect_with] and then call [Pool::acquire] @@ -356,15 +362,22 @@ impl Pool { /// returning it. pub fn acquire(&self) -> impl Future, Error>> + 'static { let shared = self.0.clone(); - async move { shared.acquire().await.map(|conn| conn.reattach()) } + async move { shared.acquire().await } + .instrument(tracing::error_span!(target: "sqlx::pool", "acquire")) } /// Attempts to retrieve a connection from the pool if there is one available. /// /// Returns `None` immediately if there are no idle connections available in the pool /// or there are tasks waiting for a connection which have yet to wake. + /// + /// # Note: Bypasses `before_acquire` + /// Since this function is not `async`, it cannot await the future returned by + /// [`before_acquire`][PoolOptions::before_acquire] without blocking. + /// + /// Instead, it simply returns the connection immediately. pub fn try_acquire(&self) -> Option> { - self.0.try_acquire().map(|conn| conn.into_live().reattach()) + self.0.try_acquire().map(|conn| PoolConnection::new(conn)) } /// Retrieves a connection and immediately begins a new transaction. @@ -532,7 +545,7 @@ impl Pool { } /// Returns the number of connections currently active. This includes idle connections. - pub fn size(&self) -> u32 { + pub fn size(&self) -> usize { self.0.size() } @@ -541,28 +554,6 @@ impl Pool { self.0.num_idle() } - /// Gets a clone of the connection options for this pool - pub fn connect_options(&self) -> Arc<::Options> { - self.0 - .connect_options - .read() - .expect("write-lock holder panicked") - .clone() - } - - /// Updates the connection options this pool will use when opening any future connections. Any - /// existing open connection in the pool will be left as-is. - pub fn set_connect_options(&self, connect_options: ::Options) { - // technically write() could also panic if the current thread already holds the lock, - // but because this method can't be re-entered by the same thread that shouldn't be a problem - let mut guard = self - .0 - .connect_options - .write() - .expect("write-lock holder panicked"); - *guard = Arc::new(connect_options); - } - /// Get the options for this pool pub fn options(&self) -> &PoolOptions { &self.0.options @@ -592,42 +583,19 @@ impl CloseEvent { /// /// Cancels the future and returns `Err(PoolClosed)` if/when the pool is closed. /// If the pool was already closed, the future is never run. + #[inline(always)] pub async fn do_until(&mut self, fut: Fut) -> Result { - // Check that the pool wasn't closed already. - // - // We use `poll_immediate()` as it will use the correct waker instead of - // a no-op one like `.now_or_never()`, but it won't actually suspend execution here. - futures_util::future::poll_immediate(&mut *self) - .await - .map_or(Ok(()), |_| Err(Error::PoolClosed))?; - - let mut fut = pin!(fut); - - // I find that this is clearer in intent than `futures_util::future::select()` - // or `futures_util::select_biased!{}` (which isn't enabled anyway). - std::future::poll_fn(|cx| { - // Poll `fut` first as the wakeup event is more likely for it than `self`. - if let Poll::Ready(ret) = fut.as_mut().poll(cx) { - return Poll::Ready(Ok(ret)); - } - - // Can't really factor out mapping to `Err(Error::PoolClosed)` though it seems like - // we should because that results in a different `Ok` type each time. - // - // Ideally we'd map to something like `Result` but using `!` as a type - // is not allowed on stable Rust yet. 
- self.poll_unpin(cx).map(|_| Err(Error::PoolClosed)) - }) - .await + race(fut, self).await.map_err(|_| Error::PoolClosed) } } impl Future for CloseEvent { type Output = (); + #[inline(always)] fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { if let Some(listener) = &mut self.listener { - ready!(listener.poll_unpin(cx)); + ready!(Pin::new(listener).poll(cx)); } // `EventListener` doesn't like being polled after it yields, and even if it did it @@ -646,15 +614,6 @@ impl FusedFuture for CloseEvent { } } -/// get the time between the deadline and now and use that as our timeout -/// -/// returns `Error::PoolTimedOut` if the deadline is in the past -fn deadline_as_timeout(deadline: Instant) -> Result { - deadline - .checked_duration_since(Instant::now()) - .ok_or(Error::PoolTimedOut) -} - #[test] #[allow(dead_code)] fn assert_pool_traits() { diff --git a/sqlx-core/src/pool/options.rs b/sqlx-core/src/pool/options.rs index 3d048f1795..975583e6f7 100644 --- a/sqlx-core/src/pool/options.rs +++ b/sqlx-core/src/pool/options.rs @@ -1,11 +1,14 @@ use crate::connection::Connection; use crate::database::Database; use crate::error::Error; +use crate::pool::connect::{ConnectTaskShared, ConnectionId, DefaultConnector}; use crate::pool::inner::PoolInner; -use crate::pool::Pool; +use crate::pool::{Pool, PoolConnector}; use futures_core::future::BoxFuture; +use futures_util::{stream, TryStreamExt}; use log::LevelFilter; use std::fmt::{self, Debug, Formatter}; +use std::num::NonZero; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -44,14 +47,6 @@ use std::time::{Duration, Instant}; /// the perspectives of both API designer and consumer. pub struct PoolOptions { pub(crate) test_before_acquire: bool, - pub(crate) after_connect: Option< - Arc< - dyn Fn(&mut DB::Connection, PoolConnectionMetadata) -> BoxFuture<'_, Result<(), Error>> - + 'static - + Send - + Sync, - >, - >, pub(crate) before_acquire: Option< Arc< dyn Fn( @@ -74,12 +69,14 @@ pub struct PoolOptions { + Sync, >, >, - pub(crate) max_connections: u32, + pub(crate) max_connections: usize, + pub(crate) shards: NonZero, pub(crate) acquire_time_level: LevelFilter, pub(crate) acquire_slow_level: LevelFilter, pub(crate) acquire_slow_threshold: Duration, pub(crate) acquire_timeout: Duration, - pub(crate) min_connections: u32, + pub(crate) connect_timeout: Option, + pub(crate) min_connections: usize, pub(crate) max_lifetime: Option, pub(crate) idle_timeout: Option, pub(crate) fair: bool, @@ -94,14 +91,15 @@ impl Clone for PoolOptions { fn clone(&self) -> Self { PoolOptions { test_before_acquire: self.test_before_acquire, - after_connect: self.after_connect.clone(), before_acquire: self.before_acquire.clone(), after_release: self.after_release.clone(), max_connections: self.max_connections, + shards: self.shards, acquire_time_level: self.acquire_time_level, acquire_slow_threshold: self.acquire_slow_threshold, acquire_slow_level: self.acquire_slow_level, acquire_timeout: self.acquire_timeout, + connect_timeout: self.connect_timeout, min_connections: self.min_connections, max_lifetime: self.max_lifetime, idle_timeout: self.idle_timeout, @@ -143,13 +141,13 @@ impl PoolOptions { pub fn new() -> Self { Self { // User-specifiable routines - after_connect: None, before_acquire: None, after_release: None, test_before_acquire: true, // A production application will want to set a higher limit than this. 
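+ // (For comparison: PostgreSQL's server-side `max_connections` defaults to
+ // 100, shared across all clients of that server.)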
max_connections: 10, min_connections: 0, + shards: NonZero::::MIN, // Logging all acquires is opt-in acquire_time_level: LevelFilter::Off, // Default to warning, because an acquire timeout will be an error @@ -158,6 +156,7 @@ impl PoolOptions { // to not flag typical time to add a new connection to a pool. acquire_slow_threshold: Duration::from_secs(2), acquire_timeout: Duration::from_secs(30), + connect_timeout: None, idle_timeout: Some(Duration::from_secs(10 * 60)), max_lifetime: Some(Duration::from_secs(30 * 60)), fair: true, @@ -170,13 +169,13 @@ impl PoolOptions { /// Be mindful of the connection limits for your database as well as other applications /// which may want to connect to the same database (or even multiple instances of the same /// application in high-availability deployments). - pub fn max_connections(mut self, max: u32) -> Self { + pub fn max_connections(mut self, max: usize) -> Self { self.max_connections = max; self } /// Get the maximum number of connections that this pool should maintain - pub fn get_max_connections(&self) -> u32 { + pub fn get_max_connections(&self) -> usize { self.max_connections } @@ -202,16 +201,68 @@ impl PoolOptions { /// [`max_lifetime`]: Self::max_lifetime /// [`idle_timeout`]: Self::idle_timeout /// [`max_connections`]: Self::max_connections - pub fn min_connections(mut self, min: u32) -> Self { + pub fn min_connections(mut self, min: usize) -> Self { self.min_connections = min; self } /// Get the minimum number of connections to maintain at all times. - pub fn get_min_connections(&self) -> u32 { + pub fn get_min_connections(&self) -> usize { self.min_connections } + /// Set the number of shards to split the internal structures into. + /// + /// The default value is dynamically determined based on the configured number of worker threads + /// in the current runtime (if that information is available), + /// or [`std::thread::available_parallelism()`], + /// or 1 otherwise. + /// + /// Each shard is assigned an equal share of [`max_connections`][Self::max_connections] + /// and its own queue of tasks waiting to acquire a connection. + /// + /// Then, when accessing the pool, each thread selects a "local" shard based on its + /// [thread ID][std::thread::Thread::id]1. + /// + /// If the number of shards equals the number of threads (which they do by default), + /// and worker threads are spawned sequentially (which they generally are), + /// each thread should access a different shard, which should significantly reduce + /// cache coherence overhead on multicore systems. + /// + /// If the number of shards does not evenly divide `max_connections`, + /// the implementation makes a best-effort to distribute them as evenly as possible + /// (if `remainder = max_connections % shards` and `remainder != 0`, + /// then `remainder` shards will get one additional connection each). + /// + /// The implementation then clamps the number of connections in a shard to the range `[1, 64]`. + /// + /// ### Details + /// When a task calls [`Pool::acquire()`] (or any other method that calls `acquire()`), + /// it will first attempt to acquire a connection from its thread-local shard, or lock an empty + /// slot to open a new connection (acquiring an idle connection and opening a new connection + /// happen concurrently to minimize acquire time). + /// + /// Failing that, it joins the wait list on the shard. Released connections are passed to + /// waiting tasks in a first-come, first-serve order per shard. 
+ ///
+ /// If the task cannot acquire a connection after a short delay,
+ /// it tries to acquire a connection from another shard.
+ ///
+ /// If the task _still_ cannot acquire a connection after a longer delay,
+ /// it joins a global wait list. Tasks in the global wait list are the highest priority
+ /// for released connections, implementing a kind of eventual fairness.
+ ///
+ /// 1 because, as of writing, [`std::thread::ThreadId::as_u64`] is unstable,
+ /// the current implementation assigns each thread its own sequential ID in a `thread_local!()`.
+ pub fn shards(mut self, shards: NonZero<usize>) -> Self {
+ self.shards = shards;
+ self
+ }
+
+ /// Get the number of shards to split the internal structures into.
+ pub fn get_shards(&self) -> usize {
+ self.shards.get()
+ }
+
 /// Enable logging of time taken to acquire a connection from the connection pool via
 /// [`Pool::acquire()`].
 ///
@@ -268,6 +319,23 @@ impl<DB: Database> PoolOptions<DB> {
 self.acquire_timeout
 }

+ /// Set the maximum amount of time to spend attempting to open a connection.
+ ///
+ /// This timeout happens independently of [`acquire_timeout`][Self::acquire_timeout].
+ ///
+ /// If shorter than `acquire_timeout`, this will cause the last connection attempt's
+ /// error to be recorded and reported as `last_connect_error` if `acquire()`
+ /// subsequently times out.
+ pub fn connect_timeout(mut self, timeout: impl Into<Option<Duration>>) -> Self {
+ self.connect_timeout = timeout.into();
+ self
+ }
+
+ /// Get the maximum amount of time to spend attempting to open a connection.
+ ///
+ /// This timeout happens independently of [`acquire_timeout`][Self::acquire_timeout].
+ pub fn get_connect_timeout(&self) -> Option<Duration> {
+ self.connect_timeout
+ }
+
 /// Set the maximum lifetime of individual connections.
 ///
 /// Any connection with a lifetime greater than this will be closed.
@@ -339,57 +407,6 @@ impl<DB: Database> PoolOptions<DB> {
 self
 }

- /// Perform an asynchronous action after connecting to the database.
- ///
- /// If the operation returns with an error then the error is logged, the connection is closed
- /// and a new one is opened in its place and the callback is invoked again.
- ///
- /// This occurs in a backoff loop to avoid high CPU usage and spamming logs during a transient
- /// error condition.
- ///
- /// Note that this may be called for internally opened connections, such as when maintaining
- /// [`min_connections`][Self::min_connections], that are then immediately returned to the pool
- /// without invoking [`after_release`][Self::after_release].
- ///
- /// # Example: Additional Parameters
- /// This callback may be used to set additional configuration parameters
- /// that are not exposed by the database's `ConnectOptions`.
- ///
- /// This example is written for PostgreSQL but can likely be adapted to other databases.
- ///
- /// ```no_run
- /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
- /// use sqlx::Executor;
- /// use sqlx::postgres::PgPoolOptions;
- ///
- /// let pool = PgPoolOptions::new()
- /// .after_connect(|conn, _meta| Box::pin(async move {
- /// // When directly invoking `Executor` methods,
- /// // it is possible to execute multiple statements with one call.
- /// conn.execute("SET application_name = 'your_app'; SET search_path = 'my_schema';")
- /// .await?;
- ///
- /// Ok(())
- /// }))
- /// .connect("postgres:// …").await?;
- /// # Ok(())
- /// # }
- /// ```
- ///
- /// For a discussion on why `Box::pin()` is required, see [the type-level docs][Self].
- pub fn after_connect<F>(mut self, callback: F) -> Self
- where
- // We're passing the `PoolConnectionMetadata` here mostly for future-proofing.
- // `age` and `idle_for` are obviously not useful for fresh connections.
- for<'c> F: Fn(&'c mut DB::Connection, PoolConnectionMetadata) -> BoxFuture<'c, Result<(), Error>> - + 'static - + Send - + Sync, - { - self.after_connect = Some(Arc::new(callback)); - self - } - /// Perform an asynchronous action on a previously idle connection before giving it out. /// /// Alongside the connection, the closure gets [`PoolConnectionMetadata`] which contains @@ -537,23 +554,30 @@ impl PoolOptions { pub async fn connect_with( self, options: ::Options, + ) -> Result, Error> { + self.connect_with_connector(DefaultConnector(options)).await + } + + /// Create a new pool from this `PoolOptions` and immediately open at least one connection. + /// + /// This ensures the configuration is correct. + /// + /// The total number of connections opened is max(1, [min_connections][Self::min_connections]). + /// + /// See [PoolConnector] for examples. + pub async fn connect_with_connector( + self, + connector: impl PoolConnector, ) -> Result, Error> { // Don't take longer than `acquire_timeout` starting from when this is called. let deadline = Instant::now() + self.acquire_timeout; - let inner = PoolInner::new_arc(self, options); + let inner = PoolInner::new_arc(self, connector); if inner.options.min_connections > 0 { - // If the idle reaper is spawned then this will race with the call from that task - // and may not report any connection errors. - inner.try_min_connections(deadline).await?; + inner.try_min_connections(Some(deadline)).await?; } - // If `min_connections` is nonzero then we'll likely just pull a connection - // from the idle queue here, but it should at least get tested first. - let conn = inner.acquire().await?; - inner.release(conn); - Ok(Pool(inner)) } @@ -578,7 +602,11 @@ impl PoolOptions { /// optimistically establish that many connections for the pool. pub fn connect_lazy_with(self, options: ::Options) -> Pool { // `min_connections` is guaranteed by the idle reaper now. - Pool(PoolInner::new_arc(self, options)) + self.connect_lazy_with_connector(DefaultConnector(options)) + } + + pub fn connect_lazy_with_connector(self, connector: impl PoolConnector) -> Pool { + Pool(PoolInner::new_arc(self, connector)) } } @@ -594,3 +622,28 @@ impl Debug for PoolOptions { .finish() } } + +fn default_shards() -> NonZero { + #[cfg(feature = "_rt-tokio")] + if let Ok(rt) = tokio::runtime::Handle::try_current() { + return rt + .metrics() + .num_workers() + .try_into() + .unwrap_or(NonZero::::MIN); + } + + #[cfg(feature = "_rt-async-std")] + if let Some(val) = std::env::var("ASYNC_STD_THREAD_COUNT") + .ok() + .and_then(|s| s.parse().ok()) + { + return val; + } + + if let Ok(val) = std::thread::available_parallelism() { + return val; + } + + NonZero::::MIN +} diff --git a/sqlx-core/src/pool/shard.rs b/sqlx-core/src/pool/shard.rs new file mode 100644 index 0000000000..c1964c7c6a --- /dev/null +++ b/sqlx-core/src/pool/shard.rs @@ -0,0 +1,798 @@ +use crate::rt; +use event_listener::{listener, Event, IntoNotification}; +use futures_util::{future, stream, StreamExt}; +use spin::lock_api::Mutex; +use std::future::Future; +use std::num::NonZero; +use std::ops::{Deref, DerefMut}; +use std::pin::pin; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{atomic, Arc}; +use std::task::{ready, Poll}; +use std::time::Duration; +use std::{array, iter}; + +type ShardId = usize; +type ConnectionIndex = usize; + +/// Delay before a task waiting in a call to `acquire()` enters the global wait queue. 
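+///
+/// (Illustrative timeline using the constants defined here: a task polls its
+/// local shard immediately, begins probing sibling shards after
+/// `NON_LOCAL_ACQUIRE_DELAY` (100µs), and only joins the global wait queue
+/// after `GLOBAL_ACQUIRE_DELAY` (10ms).)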
+/// +/// We want tasks to acquire from their local shards where possible, so they don't enter +/// the global queue immediately. +const GLOBAL_ACQUIRE_DELAY: Duration = Duration::from_millis(10); + +/// Delay before attempting to acquire from a non-local shard, +/// as well as the backoff when iterating through shards. +const NON_LOCAL_ACQUIRE_DELAY: Duration = Duration::from_micros(100); + +pub struct Sharded { + shards: Box<[ArcShard]>, + global: Arc>, +} + +type ArcShard = Arc>>]>>; + +struct Global) + Send + Sync + 'static> { + unlock_event: Event>, + disconnect_event: Event>, + min_connections: usize, + num_shards: usize, + do_reconnect: F, +} + +type ArcMutexGuard = lock_api::ArcMutexGuard, Option>; + +struct SlotGuard { + // `Option` allows us to take the guard in the drop handler. + locked: Option>, + shard: ArcShard, + index: ConnectionIndex, + dropped: bool, +} + +pub struct ConnectedSlot(SlotGuard); + +pub struct DisconnectedSlot(SlotGuard); + +// Align to cache lines. +// Simplified from https://docs.rs/crossbeam-utils/0.8.21/src/crossbeam_utils/cache_padded.rs.html#80 +// +// Instead of listing every possible architecture, we just assume 64-bit architectures have 128-byte +// cache lines, which is at least true for newer versions of x86-64 and AArch64. +// A larger alignment isn't harmful as long as we make use of the space. +#[cfg_attr(target_pointer_width = "64", repr(align(128)))] +#[cfg_attr(not(target_pointer_width = "64"), repr(align(64)))] +struct Shard { + shard_id: ShardId, + /// Bitset for all connection indices that are currently in-use. + locked_set: AtomicUsize, + /// Bitset for all connection indices that are currently connected. + connected_set: AtomicUsize, + /// Bitset for all connection indices that have been explicitly leaked. 
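+    ///
+    /// Bit `i` corresponds to connection index `i`, as in the other bitsets;
+    /// e.g. a shard of four slots with slots 0 and 2 leaked has `leaked_set == 0b0101`.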
+ leaked_set: AtomicUsize, + unlock_event: Event>, + disconnect_event: Event>, + leak_event: Event, + global: Arc>, + connections: Ts, +} + +#[derive(Debug)] +struct Params { + shards: usize, + shard_size: usize, + remainder: usize, +} + +const MAX_SHARD_SIZE: usize = if usize::BITS > 64 { + 64 +} else { + usize::BITS as usize +}; + +impl Sharded { + pub fn new( + connections: usize, + shards: NonZero, + min_connections: usize, + do_reconnect: impl Fn(DisconnectedSlot) + Send + Sync + 'static, + ) -> Sharded { + let params = Params::calc(connections, shards.get()); + + let global = Arc::new(Global { + unlock_event: Event::with_tag(), + disconnect_event: Event::with_tag(), + num_shards: params.shards, + min_connections, + do_reconnect, + }); + + let shards = params + .shard_sizes() + .enumerate() + .map(|(shard_id, size)| Shard::new(shard_id, size, global.clone())) + .collect::>(); + + Sharded { shards, global } + } + + #[inline] + pub fn num_shards(&self) -> usize { + self.shards.len() + } + + #[allow(clippy::cast_possible_truncation)] // This is only informational + pub fn count_connected(&self) -> usize { + atomic::fence(Ordering::Acquire); + + self.shards + .iter() + .map(|shard| shard.connected_set.load(Ordering::Relaxed).count_ones() as usize) + .sum() + } + + #[allow(clippy::cast_possible_truncation)] // This is only informational + pub fn count_unlocked(&self, connected: bool) -> usize { + self.shards + .iter() + .map(|shard| shard.unlocked_mask(connected).count_ones()) + .sum() + } + + pub async fn acquire_connected(&self) -> ConnectedSlot { + let guard = self.acquire(true).await; + + assert!( + guard.get().is_some(), + "BUG: expected slot {}/{} to be connected but it wasn't", + guard.shard.shard_id, + guard.index + ); + + ConnectedSlot(guard) + } + + pub fn try_acquire_connected(&self) -> Option> { + todo!() + } + + pub async fn acquire_disconnected(&self) -> DisconnectedSlot { + let guard = self.acquire(false).await; + + assert!( + guard.get().is_none(), + "BUG: expected slot {}/{} NOT to be connected but it WAS", + guard.shard.shard_id, + guard.index + ); + + DisconnectedSlot(guard) + } + + async fn acquire(&self, connected: bool) -> SlotGuard { + if self.shards.len() == 1 { + return self.shards[0].acquire(connected).await; + } + + let thread_id = current_thread_id(); + + let mut acquire_local = pin!(self.shards[thread_id % self.shards.len()].acquire(connected)); + + let mut acquire_nonlocal = pin!(async { + let mut next_shard = thread_id; + + loop { + rt::sleep(NON_LOCAL_ACQUIRE_DELAY).await; + + // Choose shards pseudorandomly by multiplying with a (relatively) large prime. + next_shard = (next_shard.wrapping_mul(547)) % self.shards.len(); + + if let Some(locked) = self.shards[next_shard].try_acquire(connected) { + return locked; + } + } + }); + + let mut acquire_global = pin!(async { + rt::sleep(GLOBAL_ACQUIRE_DELAY).await; + + let event_to_listen = if connected { + &self.global.unlock_event + } else { + &self.global.disconnect_event + }; + + event_listener::listener!(event_to_listen => listener); + listener.await + }); + + // Hand-rolled `select!{}` because there isn't a great cross-runtime solution. + // + // `futures_util::select!{}` is a proc-macro. 
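+        //
+        // The polling order below is deliberate: the local shard is tried first,
+        // then the pseudorandom walk over the other shards, then the global wait
+        // list. All three futures share the same waker, so whichever completes
+        // first wins and the other two are dropped, deregistering any listeners.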
+ std::future::poll_fn(|cx| { + if let Poll::Ready(locked) = acquire_local.as_mut().poll(cx) { + return Poll::Ready(locked); + } + + if let Poll::Ready(locked) = acquire_nonlocal.as_mut().poll(cx) { + return Poll::Ready(locked); + } + + if let Poll::Ready(locked) = acquire_global.as_mut().poll(cx) { + return Poll::Ready(locked); + } + + Poll::Pending + }) + .await + } + + pub fn iter_min_connections(&self) -> impl Iterator> + '_ { + self.shards + .iter() + .flat_map(|shard| shard.iter_min_connections()) + } + + pub fn iter_idle(&self) -> impl Iterator> + '_ { + self.shards.iter().flat_map(|shard| shard.iter_idle()) + } + + pub async fn drain(&self, close: F) + where + F: Fn(ConnectedSlot) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, + T: Send + 'static, + { + let close = Arc::new(close); + + stream::iter(self.shards.iter()) + .for_each_concurrent(None, |shard| { + let shard = shard.clone(); + let close = close.clone(); + + rt::spawn(async move { + shard.drain(&*close).await; + }) + }) + .await; + } +} + +impl Shard>>]> { + fn new(shard_id: ShardId, len: usize, global: Arc>) -> Arc { + // There's no way to create DSTs like this, in `std::sync::Arc`, on stable. + // + // Instead, we coerce from an array. + macro_rules! make_array { + ($($n:literal),+) => { + match len { + $($n => Arc::new(Shard { + shard_id, + locked_set: AtomicUsize::new(0), + connected_set: AtomicUsize::new(0), + leaked_set: AtomicUsize::new(0), + unlock_event: Event::with_tag(), + disconnect_event: Event::with_tag(), + leak_event: Event::with_tag(), + global, + connections: array::from_fn::<_, $n, _>(|_| Arc::new(Mutex::new(None))) + }),)* + _ => unreachable!("BUG: length not supported: {len}"), + } + } + } + + make_array!( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, + 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64 + ) + } + + #[inline] + fn unlocked_mask(&self, connected: bool) -> Mask { + let locked_set = self.locked_set.load(Ordering::Acquire); + let connected_set = self.connected_set.load(Ordering::Relaxed); + + let connected_mask = if connected { + connected_set + } else { + !connected_set + }; + + Mask(!locked_set & connected_mask) + } + + async fn acquire(self: &Arc, connected: bool) -> SlotGuard { + // Attempt an unfair acquire first, before we modify the waitlist. + if let Some(locked) = self.try_acquire(connected) { + return locked; + } + + let event_to_listen = if connected { + &self.unlock_event + } else { + &self.disconnect_event + }; + + event_listener::listener!(event_to_listen => listener); + + let mut listener = pin!(listener); + + loop { + // We need to check again after creating the event listener, + // because in the meantime, a concurrent task may have seen that there were no listeners + // and just unlocked its connection. + match rt::timeout(NON_LOCAL_ACQUIRE_DELAY, listener.as_mut()).await { + Ok(slot) => return slot, + Err(_) => { + if let Some(slot) = self.try_acquire(connected) { + return slot; + } + } + } + } + } + + fn try_acquire(self: &Arc, connected: bool) -> Option> { + // If `locked_set` is constantly changing, don't loop forever. 
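+        // `unlocked_mask()` is a snapshot, so each candidate index is attempted
+        // at most once per call; if another task wins the race for an index,
+        // `try_lock()` returns `None` and we move on to the next set bit.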
+ for index in self.unlocked_mask(connected) { + if let Some(slot) = self.try_lock(index) { + return Some(slot); + } + + std::hint::spin_loop(); + } + + None + } + + fn try_lock(self: &Arc, index: ConnectionIndex) -> Option> { + let locked = self.connections.get(index)?.try_lock_arc()?; + + // The locking of the connection itself must use an `Acquire` fence, + // so additional synchronization is unnecessary. + atomic_set(&self.locked_set, index, true, Ordering::Relaxed); + + Some(SlotGuard { + locked: Some(locked), + shard: self.clone(), + index, + dropped: false, + }) + } + + fn iter_min_connections(self: &Arc) -> impl Iterator> + '_ { + self.unlocked_mask(false) + .filter_map(|index| { + let slot = self.try_lock(index)?; + + // Guard against some weird bug causing this to already be connected + slot.get().is_none().then_some(DisconnectedSlot(slot)) + }) + .take(self.global.shard_min_connections(self.shard_id)) + } + + fn iter_idle(self: &Arc) -> impl Iterator> + '_ { + self.unlocked_mask(true).filter_map(|index| { + let slot = self.try_lock(index)?; + + // Guard against some weird bug causing this to already be connected + slot.get().is_some().then_some(ConnectedSlot(slot)) + }) + } + + fn all_leaked(&self) -> bool { + let all_leaked_mask = (1usize << self.connections.len()) - 1; + let leaked_set = self.leaked_set.load(Ordering::Acquire); + + leaked_set == all_leaked_mask + } + + async fn drain(self: &Arc, close: F) + where + F: Fn(ConnectedSlot) -> Fut, + Fut: Future>, + { + let mut drain_connected = pin!(async { + loop { + let connected = self.acquire(true).await; + DisconnectedSlot::leak(close(ConnectedSlot(connected)).await); + } + }); + + let mut drain_disconnected = pin!(async { + loop { + let disconnected = DisconnectedSlot(self.acquire(false).await); + DisconnectedSlot::leak(disconnected); + } + }); + + let mut drain_leaked = pin!(async { + loop { + listener!(self.leak_event => leaked); + leaked.await; + } + }); + + std::future::poll_fn(|cx| { + // The connection set is drained once all slots are leaked. + if self.all_leaked() { + return Poll::Ready(()); + } + + // These futures shouldn't return `Ready` + let _ = drain_connected.as_mut().poll(cx); + let _ = drain_disconnected.as_mut().poll(cx); + let _ = drain_leaked.as_mut().poll(cx); + + // Check again after driving the `drain` futures forward. 
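+            // The polls above may have leaked the final slot synchronously,
+            // in which case no further `leak_event` notification is coming.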
+ if self.all_leaked() { + Poll::Ready(()) + } else { + Poll::Pending + } + }) + .await; + } +} + +impl Deref for ConnectedSlot { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.0 + .get() + .as_ref() + .expect("BUG: expected slot to be populated, but it wasn't") + } +} + +impl DerefMut for ConnectedSlot { + fn deref_mut(&mut self) -> &mut Self::Target { + self.0 + .get_mut() + .as_mut() + .expect("BUG: expected slot to be populated, but it wasn't") + } +} + +impl ConnectedSlot { + pub fn take(mut this: Self) -> (T, DisconnectedSlot) { + let conn = this + .0 + .get_mut() + .take() + .expect("BUG: expected slot to be populated, but it wasn't"); + + atomic_set( + &this.0.shard.connected_set, + this.0.index, + false, + Ordering::AcqRel, + ); + + (conn, DisconnectedSlot(this.0)) + } +} + +impl DisconnectedSlot { + pub fn put(mut self, connection: T) -> ConnectedSlot { + *self.0.get_mut() = Some(connection); + + atomic_set( + &self.0.shard.connected_set, + self.0.index, + true, + Ordering::AcqRel, + ); + + ConnectedSlot(self.0) + } + + pub fn leak(mut self: Self) { + self.0.locked = None; + + atomic_set( + &self.0.shard.connected_set, + self.0.index, + false, + Ordering::Relaxed, + ); + atomic_set( + &self.0.shard.leaked_set, + self.0.index, + true, + Ordering::AcqRel, + ); + + self.0.shard.leak_event.notify(usize::MAX.tag(self.0.index)); + } + + pub fn should_reconnect(&self) -> bool { + self.0.should_reconnect() + } +} + +impl SlotGuard { + fn get(&self) -> &Option { + self.locked + .as_deref() + .expect("BUG: `SlotGuard.locked` taken") + } + + fn get_mut(&mut self) -> &mut Option { + self.locked + .as_deref_mut() + .expect("BUG: `SlotGuard.locked` taken") + } + + fn should_reconnect(&self) -> bool { + let min_connections = self.shard.global.shard_min_connections(self.shard.shard_id); + + let num_connected = self + .shard + .connected_set + .load(Ordering::Acquire) + .count_ones() as usize; + + num_connected < min_connections + } +} + +impl Drop for SlotGuard { + fn drop(&mut self) { + let Some(locked) = self.locked.take() else { + return; + }; + + let connected = locked.is_some(); + + // Updating the connected flag shouldn't require a fence. + atomic_set( + &self.shard.connected_set, + self.index, + connected, + Ordering::Relaxed, + ); + + // We don't actually unlock the connection unless there's no receivers to accept it. + // If another receiver is waiting for a connection, we can directly pass them the lock. + // + // This prevents drive-by tasks from acquiring connections before waiting tasks + // at high contention, while requiring little synchronization otherwise. + // + // We *could* just pass them the shard ID and/or index, but then we have to handle + // the situation when a receiver was passed a connection that was still marked as locked, + // but was cancelled before it could complete the acquisition. Otherwise, the connection + // would be marked as locked forever, effectively being leaked. + let mut locked = Some(locked); + + // This is a code smell, but it's necessary because `event-listener` has no way to specify + // that a message should *only* be sent once. This means tags either need to be `Clone` + // or provided by a `FnMut()` closure. + // + // Note that there's no guarantee that this closure won't be called more than once by the + // implementation, but the code as of writing should not. 
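+        // The closure below moves the still-held mutex guard into a fresh
+        // `SlotGuard`, handing the connection directly to a waiting task without
+        // it ever being marked unlocked; `dropped: true` keeps that hand-off from
+        // recursing when the new guard is itself dropped.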
+ let mut self_as_tag = || { + let locked = locked + .take() + .expect("BUG: notification sent more than once"); + + SlotGuard { + locked: Some(locked), + shard: self.shard.clone(), + index: self.index, + // To avoid infinite recursion or deadlock, don't send another notification + // if this guard was already dropped once: just unlock it. + dropped: true, + } + }; + + if !self.dropped && connected { + // Check for global waiters first. + if self + .shard + .global + .unlock_event + .notify(1.tag_with(&mut self_as_tag)) + > 0 + { + return; + } + + if self.shard.unlock_event.notify(1.tag_with(&mut self_as_tag)) > 0 { + return; + } + } else if !self.dropped { + if self + .shard + .global + .disconnect_event + .notify(1.tag_with(&mut self_as_tag)) + > 0 + { + return; + } + + if self + .shard + .disconnect_event + .notify(1.tag_with(&mut self_as_tag)) + > 0 + { + return; + } + + if self.should_reconnect() { + (self.shard.global.do_reconnect)(DisconnectedSlot(self_as_tag())); + return; + } + } + + // Be sure to drop the lock guard if it's still held, + // *before* we semantically release the lock in the bitset. + // + // Otherwise, another task could check and see the connection is free, + // but then fail to lock the mutex for it. + drop(locked); + + atomic_set(&self.shard.locked_set, self.index, false, Ordering::AcqRel); + } +} + +impl Global { + fn shard_min_connections(&self, shard_id: ShardId) -> usize { + let min_connections_per_shard = self.min_connections / self.num_shards; + + if (self.min_connections % self.num_shards) < shard_id { + min_connections_per_shard + 1 + } else { + min_connections_per_shard + } + } +} + +impl Params { + fn calc(connections: usize, mut shards: usize) -> Params { + assert_ne!(shards, 0); + + let mut shard_size = connections / shards; + let mut remainder = connections % shards; + + if shard_size == 0 { + tracing::debug!(connections, shards, "more shards than connections; clamping shard size to 1, shard count to connections"); + shards = connections; + shard_size = 1; + remainder = 0; + } else if shard_size >= MAX_SHARD_SIZE { + let new_shards = connections.div_ceil(MAX_SHARD_SIZE); + + tracing::debug!( + connections, + shards, + "shard size exceeds {MAX_SHARD_SIZE}, clamping shard count to {new_shards}" + ); + + shards = new_shards; + shard_size = connections / shards; + remainder = connections % shards; + } + + Params { + shards, + shard_size, + remainder, + } + } + + fn shard_sizes(&self) -> impl Iterator { + iter::repeat_n(self.shard_size + 1, self.remainder).chain(iter::repeat_n( + self.shard_size, + self.shards - self.remainder, + )) + } +} + +fn atomic_set(atomic: &AtomicUsize, index: usize, value: bool, ordering: Ordering) { + if value { + let bit = 1 << index; + atomic.fetch_or(bit, ordering); + } else { + let bit = !(1 << index); + atomic.fetch_and(bit, ordering); + } +} + +fn current_thread_id() -> usize { + // FIXME: this can be replaced when this is stabilized: + // https://doc.rust-lang.org/stable/std/thread/struct.ThreadId.html#method.as_u64 + static THREAD_ID: AtomicUsize = AtomicUsize::new(0); + + thread_local! 
{ + static CURRENT_THREAD_ID: usize = THREAD_ID.fetch_add(1, Ordering::SeqCst); + } + + CURRENT_THREAD_ID.with(|i| *i) +} + +#[derive(Clone, Debug, PartialEq, Eq)] +struct Mask(usize); + +impl Mask { + pub fn count_ones(&self) -> usize { + self.0.count_ones() as usize + } +} + +impl Iterator for Mask { + type Item = usize; + + fn next(&mut self) -> Option { + if self.0 == 0 { + return None; + } + + let index = self.0.trailing_zeros() as usize; + self.0 &= !(1 << index); + + Some(index) + } + + fn size_hint(&self) -> (usize, Option) { + let count = self.0.count_ones() as usize; + (count, Some(count)) + } +} + +#[cfg(test)] +mod tests { + use super::{Mask, Params, MAX_SHARD_SIZE}; + + #[test] + fn test_params() { + for connections in 0..100 { + for shards in 1..32 { + let params = Params::calc(connections, shards); + + let mut sum = 0; + + for (i, size) in params.shard_sizes().enumerate() { + assert!(size <= MAX_SHARD_SIZE, "Params::calc({connections}, {shards}) exceeded MAX_SHARD_SIZE at shard #{i}, size {size}"); + + sum += size; + + assert!(sum <= connections, "Params::calc({connections}, {shards}) exceeded connections at shard #{i}, size {size}"); + } + + assert_eq!( + sum, connections, + "Params::calc({connections}, {shards}) does not add up ({params:?}" + ); + } + } + } + + #[test] + fn test_mask() { + let inputs: &[(usize, &[usize])] = &[ + (0b0, &[]), + (0b1, &[0]), + (0b11, &[0, 1]), + (0b111, &[0, 1, 2]), + (0b1000, &[3]), + (0b1001, &[0, 3]), + (0b1001001, &[0, 3, 6]), + ]; + + for (mask, expected_indices) in inputs { + let actual_indices = Mask(*mask).collect::>(); + + assert_eq!( + actual_indices[..], + expected_indices[..], + "invalid mask: {mask:b}" + ); + } + } +} diff --git a/sqlx-core/src/rt/mod.rs b/sqlx-core/src/rt/mod.rs index 273a1bfcd9..862cf6cedd 100644 --- a/sqlx-core/src/rt/mod.rs +++ b/sqlx-core/src/rt/mod.rs @@ -1,10 +1,13 @@ use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; -use std::task::{Context, Poll}; -use std::time::Duration; +use std::task::{ready, Context, Poll}; +use std::time::{Duration, Instant}; use cfg_if::cfg_if; +use futures_core::Stream; +use futures_util::StreamExt; +use pin_project_lite::pin_project; #[cfg(feature = "_rt-async-io")] pub mod rt_async_io; @@ -51,6 +54,23 @@ pub async fn timeout(duration: Duration, f: F) -> Result(deadline: Instant, f: F) -> Result { + #[cfg(feature = "_rt-tokio")] + if rt_tokio::available() { + return tokio::time::timeout_at(deadline.into(), f) + .await + .map_err(|_| TimeoutError); + } + + cfg_if! { + if #[cfg(feature = "_rt-async-io")] { + rt_async_io::timeout_at(deadline, f).await + } else { + missing_rt((deadline, f)) + } + } +} + pub async fn sleep(duration: Duration) { #[cfg(feature = "_rt-tokio")] if rt_tokio::available() { @@ -66,6 +86,150 @@ pub async fn sleep(duration: Duration) { } } +pub async fn sleep_until(instant: Instant) { + #[cfg(feature = "_rt-tokio")] + if rt_tokio::available() { + return tokio::time::sleep_until(instant.into()).await; + } + + cfg_if! { + if #[cfg(feature = "_rt-async-io")] { + rt_async_io::sleep_until(instant).await + } else { + missing_rt(instant) + } + } +} + +// https://github.com/taiki-e/pin-project-lite/issues/3 +#[cfg(all(feature = "_rt-tokio", feature = "_rt-async-io"))] +pin_project! 
{ + #[project = IntervalProjected] + pub enum Interval { + Tokio { + // Bespoke impl because `tokio::time::Interval` allocates when we could just pin instead + #[pin] + sleep: tokio::time::Sleep, + period: Duration, + }, + AsyncIo { + #[pin] + timer: async_io::Timer, + }, + } +} + +#[cfg(all(feature = "_rt-tokio", not(feature = "_rt-async-io")))] +pin_project! { + #[project = IntervalProjected] + pub enum Interval { + Tokio { + #[pin] + sleep: tokio::time::Sleep, + period: Duration, + }, + } +} + +#[cfg(all(not(feature = "_rt-tokio"), feature = "_rt-async-io"))] +pin_project! { + #[project = IntervalProjected] + pub enum Interval { + AsyncIo { + #[pin] + timer: async_io::Timer, + }, + } +} + +#[cfg(not(any(feature = "_rt-tokio", feature = "_rt-async-io")))] +pub enum Interval {} + +pub fn interval_after(period: Duration) -> Interval { + #[cfg(feature = "_rt-tokio")] + if rt_tokio::available() { + return Interval::Tokio { + sleep: tokio::time::sleep(period), + period, + }; + } + + cfg_if! { + if #[cfg(feature = "_rt-async-io")] { + Interval::AsyncIo { timer: async_io::Timer::interval(period) } + } else { + missing_rt(period) + } + } +} + +impl Interval { + #[inline(always)] + pub fn tick(mut self: Pin<&mut Self>) -> impl Future + use<'_> { + std::future::poll_fn(move |cx| self.as_mut().poll_tick(cx)) + } + + #[inline(always)] + pub fn as_timeout(self: Pin<&mut Self>, fut: F) -> AsTimeout<'_, F> { + AsTimeout { + interval: self, + future: fut, + } + } + + #[inline(always)] + pub fn poll_tick(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + cfg_if! { + if #[cfg(any(feature = "_rt-tokio", feature = "_rt-async-io"))] { + match self.project() { + #[cfg(feature = "_rt-tokio")] + IntervalProjected::Tokio { mut sleep, period } => { + ready!(sleep.as_mut().poll(cx)); + let now = Instant::now(); + sleep.reset((now + *period).into()); + Poll::Ready(now) + } + #[cfg(feature = "_rt-async-io")] + IntervalProjected::AsyncIo { mut timer } => { + Poll::Ready(ready!(timer + .as_mut() + .poll_next(cx)) + .expect("BUG: `async_io::Timer::next()` should always yield")) + } + } + } else { + unreachable!() + } + } + } +} + +pin_project! { + pub struct AsTimeout<'i, F> { + interval: Pin<&'i mut Interval>, + #[pin] + future: F, + } +} + +impl Future for AsTimeout<'_, F> +where + F: Future, +{ + type Output = Option; + + #[inline(always)] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + if let Poll::Ready(out) = this.future.poll(cx) { + return Poll::Ready(Some(out)); + } + + this.interval.as_mut().poll_tick(cx).map(|_| None) + } +} + #[track_caller] pub fn spawn(fut: F) -> JoinHandle where @@ -90,6 +254,29 @@ where } } +pub fn try_spawn(fut: F) -> Option> +where + F: Future + Send + 'static, + F::Output: Send + 'static, +{ + #[cfg(feature = "_rt-tokio")] + if let Ok(handle) = tokio::runtime::Handle::try_current() { + return Some(JoinHandle::Tokio(handle.spawn(fut))); + } + + cfg_if! { + if #[cfg(feature = "_rt-async-global-executor")] { + Some(JoinHandle::AsyncTask(Some(async_global_executor::spawn(fut)))) + } else if #[cfg(feature = "_rt-smol")] { + Some(JoinHandle::AsyncTask(Some(smol::spawn(fut)))) + } else if #[cfg(feature = "_rt-async-std")] { + Some(JoinHandle::AsyncStd(async_std::task::spawn(fut))) + } else { + None + } + } +} + #[track_caller] pub fn spawn_blocking(f: F) -> JoinHandle where @@ -163,7 +350,7 @@ pub fn test_block_on(f: F) -> F::Output { #[track_caller] pub const fn missing_rt(_unused: T) -> ! 
{
     if cfg!(feature = "_rt-tokio") {
-        panic!("this functionality requires a Tokio context")
+        panic!("this functionality requires an active Tokio runtime")
     }
 
     panic!("one of the `runtime` features of SQLx must be enabled")
diff --git a/sqlx-core/src/rt/rt_async_io/mod.rs b/sqlx-core/src/rt/rt_async_io/mod.rs
index 5e4d7074dc..70d01fbecb 100644
--- a/sqlx-core/src/rt/rt_async_io/mod.rs
+++ b/sqlx-core/src/rt/rt_async_io/mod.rs
@@ -1,4 +1,4 @@
 mod socket;
-mod timeout;
+mod time;
 
-pub use timeout::*;
+pub use time::*;
diff --git a/sqlx-core/src/rt/rt_async_io/time.rs b/sqlx-core/src/rt/rt_async_io/time.rs
new file mode 100644
index 0000000000..dbe1d8f725
--- /dev/null
+++ b/sqlx-core/src/rt/rt_async_io/time.rs
@@ -0,0 +1,29 @@
+use crate::ext::future::race;
+use crate::rt::TimeoutError;
+use std::{
+    future::Future,
+    time::{Duration, Instant},
+};
+
+pub async fn sleep(duration: Duration) {
+    async_io::Timer::after(duration).await;
+}
+
+pub async fn sleep_until(deadline: Instant) {
+    async_io::Timer::at(deadline).await;
+}
+
+pub async fn timeout<F: Future>(duration: Duration, future: F) -> Result<F::Output, TimeoutError> {
+    race(future, sleep(duration))
+        .await
+        .map_err(|_| TimeoutError)
+}
+
+pub async fn timeout_at<F: Future>(
+    deadline: Instant,
+    future: F,
+) -> Result<F::Output, TimeoutError> {
+    race(future, sleep_until(deadline))
+        .await
+        .map_err(|_| TimeoutError)
+}
diff --git a/sqlx-core/src/rt/rt_async_io/timeout.rs b/sqlx-core/src/rt/rt_async_io/timeout.rs
deleted file mode 100644
index b4a779074b..0000000000
--- a/sqlx-core/src/rt/rt_async_io/timeout.rs
+++ /dev/null
@@ -1,20 +0,0 @@
-use std::{future::Future, pin::pin, time::Duration};
-
-use futures_util::future::{select, Either};
-
-use crate::rt::TimeoutError;
-
-pub async fn sleep(duration: Duration) {
-    timeout_future(duration).await;
-}
-
-pub async fn timeout<F: Future>(duration: Duration, future: F) -> Result<F::Output, TimeoutError> {
-    match select(pin!(future), timeout_future(duration)).await {
-        Either::Left((result, _)) => Ok(result),
-        Either::Right(_) => Err(TimeoutError),
-    }
-}
-
-fn timeout_future(duration: Duration) -> impl Future {
-    async_io::Timer::after(duration)
-}
diff --git a/sqlx-core/src/rt/rt_tokio/mod.rs b/sqlx-core/src/rt/rt_tokio/mod.rs
index ce699456db..364ce3bfee 100644
--- a/sqlx-core/src/rt/rt_tokio/mod.rs
+++ b/sqlx-core/src/rt/rt_tokio/mod.rs
@@ -1,5 +1,6 @@
 mod socket;
 
+#[inline(always)]
 pub fn available() -> bool {
     tokio::runtime::Handle::try_current().is_ok()
 }
diff --git a/sqlx-core/src/sync.rs b/sqlx-core/src/sync.rs
index ed082f752c..bce8d60c0d 100644
--- a/sqlx-core/src/sync.rs
+++ b/sqlx-core/src/sync.rs
@@ -1,206 +1,102 @@
-use cfg_if::cfg_if;
-
 // For types with identical signatures that don't require runtime support,
 // we can just arbitrarily pick one to use based on what's enabled.
 //
 // We'll generally lean towards Tokio's types as those are more featureful
 // (including `tokio-console` support) and more widely deployed.
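 //
 // The free functions below exist because the two owned-guard APIs differ:
 // Tokio's owned guard comes from `Arc::clone(..).lock_owned()`, while
 // `async-lock` exposes `lock_arc()` directly on the `Arc`, so a single trait
 // bound cannot cover both; `lock_arc`/`try_lock_arc` dispatch at compile time.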
-pub struct AsyncSemaphore { - // We use the semaphore from futures-intrusive as the one from async-lock - // is missing the ability to add arbitrary permits, and is not guaranteed to be fair: - // * https://github.com/smol-rs/async-lock/issues/22 - // * https://github.com/smol-rs/async-lock/issues/23 - // - // We're on the look-out for a replacement, however, as futures-intrusive is not maintained - // and there are some soundness concerns (although it turns out any intrusive future is unsound - // in MIRI due to the necessitated mutable aliasing): - // https://github.com/launchbadge/sqlx/issues/1668 - #[cfg(all( - any( - feature = "_rt-async-global-executor", - feature = "_rt-async-std", - feature = "_rt-smol" - ), - not(feature = "_rt-tokio") - ))] - inner: futures_intrusive::sync::Semaphore, +use std::sync::Arc; +#[cfg(feature = "_rt-tokio")] +pub use tokio::sync::{ + Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard, OwnedMutexGuard as AsyncMutexGuardArc, + RwLock as AsyncRwLock, +}; + +#[cfg(all(feature = "_rt-async-lock", not(feature = "_rt-tokio")))] +pub use async_lock::{ + Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard, MutexGuardArc as AsyncMutexGuardArc, + RwLock as AsyncRwLock, +}; +pub async fn lock_arc(mutex: &Arc>) -> AsyncMutexGuardArc { #[cfg(feature = "_rt-tokio")] - inner: tokio::sync::Semaphore, + return mutex.clone().lock_owned().await; + + #[cfg(all(feature = "_rt-async-lock", not(feature = "_rt-tokio")))] + return mutex.lock_arc().await; + + #[cfg(not(any(feature = "_rt-async-lock", feature = "_rt-tokio")))] + return crate::rt::missing_rt(mutex); } -impl AsyncSemaphore { - #[track_caller] - pub fn new(fair: bool, permits: usize) -> Self { - if cfg!(not(any( - feature = "_rt-async-global-executor", - feature = "_rt-async-std", - feature = "_rt-smol", - feature = "_rt-tokio" - ))) { - crate::rt::missing_rt((fair, permits)); - } +pub fn try_lock_arc(mutex: &Arc>) -> Option> { + #[cfg(feature = "_rt-tokio")] + return mutex.clone().try_lock_owned().ok(); - AsyncSemaphore { - #[cfg(all( - any( - feature = "_rt-async-global-executor", - feature = "_rt-async-std", - feature = "_rt-smol" - ), - not(feature = "_rt-tokio") - ))] - inner: futures_intrusive::sync::Semaphore::new(fair, permits), - #[cfg(feature = "_rt-tokio")] - inner: { - debug_assert!(fair, "Tokio only has fair permits"); - tokio::sync::Semaphore::new(permits) - }, - } + #[cfg(all(feature = "_rt-async-lock", not(feature = "_rt-tokio")))] + return mutex.try_lock_arc(); + + #[cfg(not(any(feature = "_rt-async-lock", feature = "_rt-tokio")))] + return crate::rt::missing_rt(mutex); +} + +#[cfg(not(any(feature = "_rt-async-lock", feature = "_rt-tokio")))] +pub use noop::*; + +#[cfg(not(any(feature = "_rt-async-lock", feature = "_rt-tokio")))] +mod noop { + use crate::rt::missing_rt; + use std::marker::PhantomData; + use std::ops::{Deref, DerefMut}; + use std::sync::Arc; + + pub struct AsyncMutex { + // `Sync` if `T: Send` + _marker: PhantomData>, } - pub fn permits(&self) -> usize { - cfg_if! { - if #[cfg(all( - any( - feature = "_rt-async-global-executor", - feature = "_rt-async-std", - feature = "_rt-smol" - ), - not(feature = "_rt-tokio") - ))] { - self.inner.permits() - } else if #[cfg(feature = "_rt-tokio")] { - self.inner.available_permits() - } else { - crate::rt::missing_rt(()) - } - } + pub struct AsyncMutexGuard<'a, T> { + inner: &'a AsyncMutex, + } + + pub struct AsyncMutexGuardArc { + inner: Arc>, } - pub async fn acquire(&self, permits: u32) -> AsyncSemaphoreReleaser<'_> { - cfg_if! 
{ - if #[cfg(all( - any( - feature = "_rt-async-global-executor", - feature = "_rt-async-std", - feature = "_rt-smol" - ), - not(feature = "_rt-tokio") - ))] { - AsyncSemaphoreReleaser { - inner: self.inner.acquire(permits as usize).await, - } - } else if #[cfg(feature = "_rt-tokio")] { - AsyncSemaphoreReleaser { - inner: self - .inner - // Weird quirk: `tokio::sync::Semaphore` mostly uses `usize` for permit counts, - // but `u32` for this and `try_acquire_many()`. - .acquire_many(permits) - .await - .expect("BUG: we do not expose the `.close()` method"), - } - } else { - crate::rt::missing_rt(permits) - } + impl AsyncMutex { + pub fn new(val: T) -> Self { + missing_rt(val) + } + + pub fn lock(&self) -> AsyncMutexGuard { + missing_rt(self) } } - pub fn try_acquire(&self, permits: u32) -> Option> { - cfg_if! { - if #[cfg(all( - any( - feature = "_rt-async-global-executor", - feature = "_rt-async-std", - feature = "_rt-smol" - ), - not(feature = "_rt-tokio") - ))] { - Some(AsyncSemaphoreReleaser { - inner: self.inner.try_acquire(permits as usize)?, - }) - } else if #[cfg(feature = "_rt-tokio")] { - Some(AsyncSemaphoreReleaser { - inner: self.inner.try_acquire_many(permits).ok()?, - }) - } else { - crate::rt::missing_rt(permits) - } + impl Deref for AsyncMutexGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + missing_rt(self) } } - pub fn release(&self, permits: usize) { - cfg_if! { - if #[cfg(all( - any( - feature = "_rt-async-global-executor", - feature = "_rt-async-std", - feature = "_rt-smol" - ), - not(feature = "_rt-tokio") - ))] { - self.inner.release(permits); - } else if #[cfg(feature = "_rt-tokio")] { - self.inner.add_permits(permits); - } else { - crate::rt::missing_rt(permits); - } + impl DerefMut for AsyncMutexGuard<'_, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + missing_rt(self) } } -} -pub struct AsyncSemaphoreReleaser<'a> { - // We use the semaphore from futures-intrusive as the one from async-std - // is missing the ability to add arbitrary permits, and is not guaranteed to be fair: - // * https://github.com/smol-rs/async-lock/issues/22 - // * https://github.com/smol-rs/async-lock/issues/23 - // - // We're on the look-out for a replacement, however, as futures-intrusive is not maintained - // and there are some soundness concerns (although it turns out any intrusive future is unsound - // in MIRI due to the necessitated mutable aliasing): - // https://github.com/launchbadge/sqlx/issues/1668 - #[cfg(all( - any( - feature = "_rt-async-global-executor", - feature = "_rt-async-std", - feature = "_rt-smol" - ), - not(feature = "_rt-tokio") - ))] - inner: futures_intrusive::sync::SemaphoreReleaser<'a>, + impl Deref for AsyncMutexGuardArc { + type Target = T; - #[cfg(feature = "_rt-tokio")] - inner: tokio::sync::SemaphorePermit<'a>, - - #[cfg(not(any( - feature = "_rt-async-global-executor", - feature = "_rt-async-std", - feature = "_rt-smol", - feature = "_rt-tokio" - )))] - _phantom: std::marker::PhantomData<&'a ()>, -} + fn deref(&self) -> &Self::Target { + missing_rt(self) + } + } -impl AsyncSemaphoreReleaser<'_> { - pub fn disarm(self) { - cfg_if! 
{ - if #[cfg(all( - any( - feature = "_rt-async-global-executor", - feature = "_rt-async-std", - feature = "_rt-smol" - ), - not(feature = "_rt-tokio") - ))] { - let mut this = self; - this.inner.disarm(); - } else if #[cfg(feature = "_rt-tokio")] { - self.inner.forget(); - } else { - crate::rt::missing_rt(()); - } + impl DerefMut for AsyncMutexGuardArc { + fn deref_mut(&mut self) -> &mut Self::Target { + missing_rt(self) } } } diff --git a/sqlx-mysql/src/testing/mod.rs b/sqlx-mysql/src/testing/mod.rs index f509f9da45..f532dcc5a9 100644 --- a/sqlx-mysql/src/testing/mod.rs +++ b/sqlx-mysql/src/testing/mod.rs @@ -1,5 +1,4 @@ use std::future::Future; -use std::ops::Deref; use std::str::FromStr; use std::sync::OnceLock; use std::time::Duration; @@ -108,27 +107,10 @@ async fn test_context(args: &TestArgs) -> Result, Error> { .max_connections(20) // Immediately close master connections. Tokio's I/O streams don't like hopping runtimes. .after_release(|_conn, _| Box::pin(async move { Ok(false) })) - .connect_lazy_with(master_opts); - - let master_pool = match once_lock_try_insert_polyfill(&MASTER_POOL, pool) { - Ok(inserted) => inserted, - Err((existing, pool)) => { - // Sanity checks. - assert_eq!( - existing.connect_options().host, - pool.connect_options().host, - "DATABASE_URL changed at runtime, host differs" - ); - - assert_eq!( - existing.connect_options().database, - pool.connect_options().database, - "DATABASE_URL changed at runtime, database differs" - ); - - existing - } - }; + .connect_lazy_with(master_opts.clone()); + + let master_pool = once_lock_try_insert_polyfill(&MASTER_POOL, pool) + .unwrap_or_else(|(existing, _pool)| existing); let mut conn = master_pool.acquire().await?; @@ -144,7 +126,7 @@ async fn test_context(args: &TestArgs) -> Result, Error> { -- BLOB/TEXT columns can only be used as index keys with a prefix length: -- https://dev.mysql.com/doc/refman/8.4/en/column-indexes.html#column-indexes-prefix primary key(db_name(63)) - ); + ); "#, ) .await?; @@ -172,11 +154,7 @@ async fn test_context(args: &TestArgs) -> Result, Error> { // Close connections ASAP if left in the idle queue. .idle_timeout(Some(Duration::from_secs(1))) .parent(master_pool.clone()), - connect_opts: master_pool - .connect_options() - .deref() - .clone() - .database(&db_name), + connect_opts: master_opts.database(&db_name), db_name, }) } diff --git a/sqlx-postgres/src/error.rs b/sqlx-postgres/src/error.rs index 7b5a03f2b3..ac99bd13cf 100644 --- a/sqlx-postgres/src/error.rs +++ b/sqlx-postgres/src/error.rs @@ -186,7 +186,7 @@ impl DatabaseError for PgDatabaseError { self } - fn is_transient_in_connect_phase(&self) -> bool { + fn is_retryable_connect_error(&self) -> bool { // https://www.postgresql.org/docs/current/errcodes-appendix.html [ // too_many_connections diff --git a/sqlx-postgres/src/options/mod.rs b/sqlx-postgres/src/options/mod.rs index efbc43989b..00ae159759 100644 --- a/sqlx-postgres/src/options/mod.rs +++ b/sqlx-postgres/src/options/mod.rs @@ -127,6 +127,11 @@ impl PgConnectOptions { self } + /// Identical to [Self::host()], but through a mutable reference. + pub fn set_host(&mut self, host: &str) { + host.clone_into(&mut self.host); + } + /// Sets the port to connect to at the server host. /// /// The default port for PostgreSQL is `5432`. @@ -143,6 +148,12 @@ impl PgConnectOptions { self } + /// Identical to [`Self::port()`], but through a mutable reference. 
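+    ///
+    /// # Example
+    ///
+    /// A usage sketch (the import path follows the other examples in this file):
+    ///
+    /// ```rust
+    /// # use sqlx_postgres::PgConnectOptions;
+    /// let mut options = PgConnectOptions::new();
+    /// options.set_port(5433);
+    /// ```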
+    pub fn set_port(&mut self, port: u16) -> &mut Self {
+        self.port = port;
+        self
+    }
+
     /// Sets a custom path to a directory containing a unix domain socket,
     /// switching the connection method from TCP to the corresponding socket.
     ///
@@ -169,6 +180,12 @@ impl PgConnectOptions {
         self
     }
 
+    /// Identical to [`Self::username()`], but through a mutable reference.
+    pub fn set_username(&mut self, username: &str) -> &mut Self {
+        username.clone_into(&mut self.username);
+        self
+    }
+
     /// Sets the password to use if the server demands password authentication.
     ///
     /// # Example
@@ -184,6 +201,12 @@ impl PgConnectOptions {
         self
     }
 
+    /// Identical to [`Self::password()`], but through a mutable reference.
+    pub fn set_password(&mut self, password: &str) -> &mut Self {
+        self.password = Some(password.to_owned());
+        self
+    }
+
     /// Sets the database name. Defaults to be the same as the user name.
     ///
     /// # Example
diff --git a/sqlx-postgres/src/testing/mod.rs b/sqlx-postgres/src/testing/mod.rs
index 3e1cf0ddf7..70b00b6351 100644
--- a/sqlx-postgres/src/testing/mod.rs
+++ b/sqlx-postgres/src/testing/mod.rs
@@ -1,5 +1,4 @@
 use std::future::Future;
-use std::ops::Deref;
 use std::str::FromStr;
 use std::sync::OnceLock;
 use std::time::Duration;
@@ -101,27 +100,10 @@ async fn test_context(args: &TestArgs) -> Result<TestContext<Postgres>, Error> {
         .max_connections(20)
         // Immediately close master connections. Tokio's I/O streams don't like hopping runtimes.
         .after_release(|_conn, _| Box::pin(async move { Ok(false) }))
-        .connect_lazy_with(master_opts);
-
-    let master_pool = match once_lock_try_insert_polyfill(&MASTER_POOL, pool) {
-        Ok(inserted) => inserted,
-        Err((existing, pool)) => {
-            // Sanity checks.
-            assert_eq!(
-                existing.connect_options().host,
-                pool.connect_options().host,
-                "DATABASE_URL changed at runtime, host differs"
-            );
-
-            assert_eq!(
-                existing.connect_options().database,
-                pool.connect_options().database,
-                "DATABASE_URL changed at runtime, database differs"
-            );
-
-            existing
-        }
-    };
+        .connect_lazy_with(master_opts.clone());
+
+    let master_pool = once_lock_try_insert_polyfill(&MASTER_POOL, pool)
+        .unwrap_or_else(|(existing, _pool)| existing);
 
     let mut conn = master_pool.acquire().await?;
 
@@ -177,11 +159,7 @@ async fn test_context(args: &TestArgs) -> Result<TestContext<Postgres>, Error> {
             // Close connections ASAP if left in the idle queue.
             .idle_timeout(Some(Duration::from_secs(1)))
             .parent(master_pool.clone()),
-        connect_opts: master_pool
-            .connect_options()
-            .deref()
-            .clone()
-            .database(&db_name),
+        connect_opts: master_opts.database(&db_name),
         db_name,
     })
 }
diff --git a/sqlx-test/Cargo.toml b/sqlx-test/Cargo.toml
index 32a341adcb..4fdcb3723e 100644
--- a/sqlx-test/Cargo.toml
+++ b/sqlx-test/Cargo.toml
@@ -10,6 +10,7 @@
 sqlx = { default-features = false, path = ".." }
 env_logger = "0.11"
 dotenvy = "0.15.0"
 anyhow = "1.0.26"
+tracing-subscriber = { version = "0.3.20", features = ["env-filter"] }
 
 [lints]
 workspace = true
diff --git a/sqlx-test/src/lib.rs b/sqlx-test/src/lib.rs
index 3744724c12..6a8b9d1120 100644
--- a/sqlx-test/src/lib.rs
+++ b/sqlx-test/src/lib.rs
@@ -1,10 +1,16 @@
 use sqlx::pool::PoolOptions;
 use sqlx::{Connection, Database, Error, Pool};
 use std::env;
+use tracing_subscriber::EnvFilter;
+use tracing_subscriber::fmt::format::FmtSpan;
 
 pub fn setup_if_needed() {
     let _ = dotenvy::dotenv();
-    let _ = env_logger::builder().is_test(true).try_init();
+    let _ = tracing_subscriber::fmt::Subscriber::builder()
+        .with_env_filter(EnvFilter::from_default_env())
+        .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE)
+        // .with_test_writer()
+        .try_init();
 }
 
 // Make a new connection
diff --git a/tests/any/pool.rs b/tests/any/pool.rs
index a4849940b8..d5d47d161f 100644
--- a/tests/any/pool.rs
+++ b/tests/any/pool.rs
@@ -1,41 +1,33 @@
 use sqlx::any::{AnyConnectOptions, AnyPoolOptions};
 use sqlx::Executor;
+use sqlx_core::connection::{ConnectOptions, Connection};
+use sqlx_core::pool::PoolConnectMetadata;
 use sqlx_core::sql_str::AssertSqlSafe;
 use std::sync::{
-    atomic::{AtomicI32, AtomicUsize, Ordering},
+    atomic::{AtomicI32, Ordering},
     Arc, Mutex,
 };
 use std::time::Duration;
 
 #[sqlx_macros::test]
-async fn pool_should_invoke_after_connect() -> anyhow::Result<()> {
+async fn pool_basic_functions() -> anyhow::Result<()> {
     sqlx::any::install_default_drivers();
 
-    let counter = Arc::new(AtomicUsize::new(0));
-
     let pool = AnyPoolOptions::new()
-        .after_connect({
-            let counter = counter.clone();
-            move |_conn, _meta| {
-                let counter = counter.clone();
-                Box::pin(async move {
-                    counter.fetch_add(1, Ordering::SeqCst);
-
-                    Ok(())
-                })
-            }
-        })
+        .max_connections(2)
+        .acquire_timeout(Duration::from_secs(3))
         .connect(&dotenvy::var("DATABASE_URL")?)
         .await?;
 
-    let _ = pool.acquire().await?;
-    let _ = pool.acquire().await?;
-    let _ = pool.acquire().await?;
-    let _ = pool.acquire().await?;
+    let mut conn = pool.acquire().await?;
+
+    conn.ping().await?;
+
+    drop(conn);
 
-    // since connections are released asynchronously,
-    // `.after_connect()` may be called more than once
-    assert!(counter.load(Ordering::SeqCst) >= 1);
+    let b: bool = sqlx::query_scalar("SELECT true").fetch_one(&pool).await?;
+
+    assert!(b);
 
     Ok(())
 }
@@ -74,6 +66,7 @@ async fn pool_should_be_returned_failed_transactions() -> anyhow::Result<()> {
 #[sqlx_macros::test]
 async fn test_pool_callbacks() -> anyhow::Result<()> {
     sqlx::any::install_default_drivers();
+    tracing_subscriber::fmt::init();
 
     #[derive(sqlx::FromRow, Debug, PartialEq, Eq)]
     struct ConnStats {
@@ -84,38 +77,13 @@ async fn test_pool_callbacks() -> anyhow::Result<()> {
 
     sqlx_test::setup_if_needed();
 
-    let conn_options: AnyConnectOptions = std::env::var("DATABASE_URL")?.parse()?;
+    let conn_options: Arc<AnyConnectOptions> = Arc::new(std::env::var("DATABASE_URL")?.parse()?);
 
     let current_id = AtomicI32::new(0);
 
     let pool = AnyPoolOptions::new()
         .max_connections(1)
         .acquire_timeout(Duration::from_secs(5))
-        .after_connect(move |conn, meta| {
-            assert_eq!(meta.age, Duration::ZERO);
-            assert_eq!(meta.idle_for, Duration::ZERO);
-
-            let id = current_id.fetch_add(1, Ordering::AcqRel);
-
-            Box::pin(async move {
-                let statement = format!(
-                    // language=SQL
-                    r#"
-                    CREATE TEMPORARY TABLE conn_stats(
-                        id int primary key,
-                        before_acquire_calls int default 0,
-                        after_release_calls int default 0
-                    );
-                    INSERT INTO conn_stats(id) VALUES ({});
-                    "#,
-                    // Until we have generalized bind parameters
-                    id
-                );
-
-                conn.execute(AssertSqlSafe(statement)).await?;
-                Ok(())
-            })
-        })
         .before_acquire(|conn, meta| {
             // `age` and `idle_for` should both be nonzero
             assert_ne!(meta.age, Duration::ZERO);
@@ -166,7 +134,33 @@ async fn test_pool_callbacks() -> anyhow::Result<()> {
             })
         })
         // Don't establish a connection yet.
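+        // The session setup formerly done in `after_connect` now lives in the
+        // connector closure below, which runs before a connection first joins the pool.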
- .connect_lazy_with(conn_options); + .connect_lazy_with_connector(move |_meta: PoolConnectMetadata| { + let connect_opts = Arc::clone(&conn_options); + let id = current_id.fetch_add(1, Ordering::AcqRel); + + async move { + let mut conn = connect_opts.connect().await?; + + let statement = format!( + // language=SQL + r#" + CREATE TEMPORARY TABLE conn_stats( + id int primary key, + before_acquire_calls int default 0, + after_release_calls int default 0 + ); + INSERT INTO conn_stats(id) VALUES ({}); + "#, + // Until we have generalized bind parameters + id + ); + + sqlx::raw_sql(AssertSqlSafe(statement)) + .execute(&mut conn) + .await?; + Ok(conn) + } + }); // Expected pattern of (id, before_acquire_calls, after_release_calls) let pattern = [ @@ -186,6 +180,8 @@ async fn test_pool_callbacks() -> anyhow::Result<()> { ]; for (id, before_acquire_calls, after_release_calls) in pattern { + eprintln!("ID: {id}, before_acquire calls: {before_acquire_calls}, after_release calls: {after_release_calls}"); + let conn_stats: ConnStats = sqlx::query_as("SELECT * FROM conn_stats") .fetch_one(&pool) .await?; @@ -215,6 +211,7 @@ async fn test_connection_maintenance() -> anyhow::Result<()> { let last_meta = Arc::new(Mutex::new(None)); let last_meta_ = last_meta.clone(); let pool = AnyPoolOptions::new() + .acquire_timeout(Duration::from_secs(1)) .max_lifetime(Duration::from_millis(400)) .min_connections(3) .before_acquire(move |_conn, _meta| { diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 06adf0ca7f..d6bbebf2c5 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -255,6 +255,10 @@ async fn it_works_with_cache_disabled() -> anyhow::Result<()> { #[sqlx_macros::test] async fn it_executes_with_pool() -> anyhow::Result<()> { + setup_if_needed(); + + tracing::info!("starting test"); + let pool = sqlx_test::pool::().await?; let rows = pool.fetch_all("SELECT 1; SElECT 2").await?; @@ -1146,7 +1150,7 @@ async fn test_listener_try_recv_buffered() -> anyhow::Result<()> { assert!(listener.next_buffered().is_none()); // Activate connection. - sqlx::query!("SELECT 1 AS one") + sqlx::query("SELECT 1 AS one") .fetch_all(&mut listener) .await?; @@ -2086,6 +2090,7 @@ async fn test_issue_3052() { } #[sqlx_macros::test] +#[cfg(feature = "chrono")] async fn test_bind_iter() -> anyhow::Result<()> { use sqlx::postgres::PgBindIterExt; use sqlx::types::chrono::{DateTime, Utc};
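
Taken together, the changes above replace `after_connect` with user-supplied connectors. Below is a minimal sketch of the new API surface, modeled on the `connect_lazy_with_connector` call in `tests/any/pool.rs` above; the import paths (notably `sqlx_core::pool::PoolConnectMetadata`) and the acceptance of plain closures as `PoolConnector` implementations are assumptions carried over from that test, not verified documentation.

```rust
use std::sync::Arc;

use sqlx::pool::PoolOptions;
use sqlx::postgres::{PgConnectOptions, Postgres};
use sqlx::{ConnectOptions, Pool};
use sqlx_core::pool::PoolConnectMetadata;

fn lazy_pool(opts: PgConnectOptions) -> Pool<Postgres> {
    // Shared so the connector closure, which may be called many times,
    // can hand a copy of the options to each connection attempt.
    let opts = Arc::new(opts);

    PoolOptions::<Postgres>::new()
        .max_connections(10)
        // The connector runs for every new physical connection, replacing the
        // removed `after_connect` callback for per-connection session setup.
        .connect_lazy_with_connector(move |_meta: PoolConnectMetadata| {
            let opts = Arc::clone(&opts);
            async move {
                let conn = opts.connect().await?;
                // Session setup (e.g. `SET application_name = ...`) goes here,
                // before the connection enters the pool.
                Ok(conn)
            }
        })
}
```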