From 54564bead6612b217b855be56d9d41455e90f667 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Wed, 29 Oct 2025 03:43:48 -0700 Subject: [PATCH 01/24] feat: create `Pool::acquire()` benchmark --- Cargo.lock | 54 +++++-------------- Cargo.toml | 13 ++++- benches/any/pool.rs | 127 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 152 insertions(+), 42 deletions(-) create mode 100644 benches/any/pool.rs diff --git a/Cargo.lock b/Cargo.lock index 7a2979fc80..ebd8daba84 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1021,26 +1021,22 @@ checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" [[package]] name = "criterion" -version = "0.5.1" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +checksum = "e1c047a62b0cc3e145fa84415a3191f628e980b194c2755aa12300a4e6cbd928" dependencies = [ "anes", "cast", "ciborium", "clap", "criterion-plot", - "futures", - "is-terminal", - "itertools 0.10.5", + "itertools 0.13.0", "num-traits", - "once_cell", "oorandom", "plotters", "rayon", "regex", "serde", - "serde_derive", "serde_json", "tinytemplate", "tokio", @@ -1049,12 +1045,12 @@ dependencies = [ [[package]] name = "criterion-plot" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +checksum = "9b1bcc0dc7dfae599d84ad0b1a55f80cde8af3725da8313b528da95ef783e338" dependencies = [ "cast", - "itertools 0.10.5", + "itertools 0.13.0", ] [[package]] @@ -1446,20 +1442,6 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" -[[package]] -name = "futures" -version = "0.3.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" -dependencies = [ - "futures-channel", - "futures-core", - "futures-io", - "futures-sink", - "futures-task", - "futures-util", -] - [[package]] name = "futures-channel" version = "0.3.31" @@ -1990,17 +1972,6 @@ version = "0.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf370abdafd54d13e54a620e8c3e1145f28e46cc9d704bc6d94414559df41763" -[[package]] -name = "is-terminal" -version = "0.4.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9" -dependencies = [ - "hermit-abi", - "libc", - "windows-sys 0.59.0", -] - [[package]] name = "is_terminal_polyfill" version = "1.70.1" @@ -3517,6 +3488,7 @@ dependencies = [ "tempfile", "time", "tokio", + "tracing", "trybuild", "url", ] @@ -4411,9 +4383,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" -version = "0.1.41" +version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ "log", "pin-project-lite", @@ -4423,9 +4395,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.30" +version = "0.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" +checksum = 
"7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", @@ -4434,9 +4406,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.34" +version = "0.1.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" dependencies = [ "once_cell", "valuable", diff --git a/Cargo.toml b/Cargo.toml index 00d5d656c1..35ba67013a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -236,9 +236,11 @@ rand = "0.8.4" rand_xoshiro = "0.6.0" hex = "0.4.3" tempfile = "3.10.1" -criterion = { version = "0.5.1", features = ["async_tokio"] } +criterion = { version = "0.7.0", features = ["async_tokio"] } libsqlite3-sys = { version = "0.30.1" } +tracing = { version = "0.1.44", features = ["attributes"] } + # If this is an unconditional dev-dependency then Cargo will *always* try to build `libsqlite3-sys`, # even when SQLite isn't the intended test target, and fail if the build environment is not set up for compiling C code. [target.'cfg(sqlite_test_sqlcipher)'.dev-dependencies] @@ -454,3 +456,12 @@ required-features = ["postgres"] name = "postgres-rustsec" path = "tests/postgres/rustsec.rs" required-features = ["postgres", "macros", "migrate"] + +# +# Benches +# +[[bench]] +name = "any-pool" +path = "benches/any/pool.rs" +required-features = ["runtime-tokio", "any"] +harness = false diff --git a/benches/any/pool.rs b/benches/any/pool.rs new file mode 100644 index 0000000000..a689058055 --- /dev/null +++ b/benches/any/pool.rs @@ -0,0 +1,127 @@ +use criterion::{criterion_group, criterion_main, Bencher, BenchmarkId, Criterion}; +use sqlx_core::any::AnyPoolOptions; +use std::fmt::{Display, Formatter}; +use std::thread; +use std::time::{Duration, Instant}; +use tracing::Instrument; + +#[derive(Debug)] +struct Input { + threads: usize, + tasks: usize, + pool_size: u32, +} + +impl Display for Input { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "threads: {}, tasks: {}, pool size: {}", + self.threads, self.tasks, self.pool_size + ) + } +} + +fn bench_pool(c: &mut Criterion) { + sqlx::any::install_default_drivers(); + + let database_url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set"); + + let inputs = [ + Input { + threads: 1, + tasks: 2, + pool_size: 20, + }, + Input { + threads: 2, + tasks: 4, + pool_size: 20, + }, + Input { + threads: 4, + tasks: 8, + pool_size: 20, + }, + Input { + threads: 8, + tasks: 16, + pool_size: 20, + }, + Input { + threads: 16, + tasks: 32, + pool_size: 64, + }, + Input { + threads: 16, + tasks: 128, + pool_size: 64, + }, + ]; + + let mut group = c.benchmark_group("Bench Pool"); + + for input in inputs { + group.bench_with_input(BenchmarkId::from_parameter(&input), &input, |b, i| { + bench_pool_with(b, i, &database_url) + }); + } + + group.finish(); +} + +fn bench_pool_with(b: &mut Bencher, input: &Input, database_url: &str) { + let runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .worker_threads(input.threads) + .build() + .unwrap(); + + let pool = runtime.block_on(async { + AnyPoolOptions::new() + .min_connections(input.pool_size) + .max_connections(input.pool_size) + .test_before_acquire(false) + .connect(database_url) + .await + .expect("error connecting to pool") + }); + + for _ in 1..=input.tasks { + let pool = pool.clone(); + + runtime.spawn(async move { while pool.acquire().await.is_ok() {} }); 
+ } + + // Spawn the benchmark loop into the runtime so we're not accidentally including the main thread + b.to_async(&runtime).iter_custom(|iters| { + let pool = pool.clone(); + + async move { + tokio::spawn( + async move { + let start = Instant::now(); + + for _ in 0..iters { + if let Err(e) = pool.acquire().await { + panic!("failed to acquire connection: {e:?}"); + } + } + + start.elapsed() + } + .instrument(tracing::info_span!("iter")), + ) + .await + .expect("panic in task") + } + }); + + runtime.block_on(pool.close()); + // Give the server a second to clean up + thread::sleep(Duration::from_millis(50)); +} + +criterion_group!(benches, bench_pool,); +criterion_main!(benches); From 0601c3a75d809e8a4abc4e693c54a36d06396254 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 18 Oct 2024 19:48:44 -0700 Subject: [PATCH 02/24] breaking(pool): use `usize` for all connection counts --- sqlx-core/src/pool/inner.rs | 8 ++++---- sqlx-core/src/pool/mod.rs | 2 +- sqlx-core/src/pool/options.rs | 12 ++++++------ 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index b698dc9df0..7c6a84adfe 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -11,7 +11,7 @@ use crate::sync::{AsyncSemaphore, AsyncSemaphoreReleaser}; use std::cmp; use std::future::{self, Future}; use std::pin::pin; -use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::{Arc, RwLock}; use std::task::Poll; @@ -26,7 +26,7 @@ pub(crate) struct PoolInner { pub(super) connect_options: RwLock::Options>>, pub(super) idle_conns: ArrayQueue>, pub(super) semaphore: AsyncSemaphore, - pub(super) size: AtomicU32, + pub(super) size: AtomicUsize, pub(super) num_idle: AtomicUsize, is_closed: AtomicBool, pub(super) on_closed: event_listener::Event, @@ -55,7 +55,7 @@ impl PoolInner { connect_options: RwLock::new(Arc::new(connect_options)), idle_conns: ArrayQueue::new(capacity), semaphore: AsyncSemaphore::new(options.fair, semaphore_capacity), - size: AtomicU32::new(0), + size: AtomicUsize::new(0), num_idle: AtomicUsize::new(0), is_closed: AtomicBool::new(false), on_closed: event_listener::Event::new(), @@ -71,7 +71,7 @@ impl PoolInner { pool } - pub(super) fn size(&self) -> u32 { + pub(super) fn size(&self) -> usize { self.size.load(Ordering::Acquire) } diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index f11ff1d76a..6e54089484 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -532,7 +532,7 @@ impl Pool { } /// Returns the number of connections currently active. This includes idle connections. 
- pub fn size(&self) -> u32 { + pub fn size(&self) -> usize { self.0.size() } diff --git a/sqlx-core/src/pool/options.rs b/sqlx-core/src/pool/options.rs index 3d048f1795..2c2c2e1801 100644 --- a/sqlx-core/src/pool/options.rs +++ b/sqlx-core/src/pool/options.rs @@ -74,12 +74,12 @@ pub struct PoolOptions { + Sync, >, >, - pub(crate) max_connections: u32, + pub(crate) max_connections: usize, pub(crate) acquire_time_level: LevelFilter, pub(crate) acquire_slow_level: LevelFilter, pub(crate) acquire_slow_threshold: Duration, pub(crate) acquire_timeout: Duration, - pub(crate) min_connections: u32, + pub(crate) min_connections: usize, pub(crate) max_lifetime: Option, pub(crate) idle_timeout: Option, pub(crate) fair: bool, @@ -170,13 +170,13 @@ impl PoolOptions { /// Be mindful of the connection limits for your database as well as other applications /// which may want to connect to the same database (or even multiple instances of the same /// application in high-availability deployments). - pub fn max_connections(mut self, max: u32) -> Self { + pub fn max_connections(mut self, max: usize) -> Self { self.max_connections = max; self } /// Get the maximum number of connections that this pool should maintain - pub fn get_max_connections(&self) -> u32 { + pub fn get_max_connections(&self) -> usize { self.max_connections } @@ -202,13 +202,13 @@ impl PoolOptions { /// [`max_lifetime`]: Self::max_lifetime /// [`idle_timeout`]: Self::idle_timeout /// [`max_connections`]: Self::max_connections - pub fn min_connections(mut self, min: u32) -> Self { + pub fn min_connections(mut self, min: usize) -> Self { self.min_connections = min; self } /// Get the minimum number of connections to maintain at all times. - pub fn get_min_connections(&self) -> u32 { + pub fn get_min_connections(&self) -> usize { self.min_connections } From 3ec35e06bde31cee43d3f87b0726a92283049fb0 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 18 Oct 2024 19:46:36 -0700 Subject: [PATCH 03/24] WIP pool changes --- sqlx-core/Cargo.toml | 9 +- sqlx-core/src/error.rs | 39 ++- sqlx-core/src/pool/connect.rs | 461 ++++++++++++++++++++++++++++++ sqlx-core/src/pool/connection.rs | 49 ++-- sqlx-core/src/pool/idle.rs | 97 +++++++ sqlx-core/src/pool/inner.rs | 472 ++++++++----------------------- sqlx-core/src/pool/mod.rs | 28 +- sqlx-core/src/pool/options.rs | 108 +++---- sqlx-core/src/rt/mod.rs | 25 +- sqlx-mysql/src/testing/mod.rs | 35 +-- sqlx-postgres/src/error.rs | 2 +- sqlx-postgres/src/options/mod.rs | 23 ++ sqlx-postgres/src/testing/mod.rs | 33 +-- 13 files changed, 855 insertions(+), 526 deletions(-) create mode 100644 sqlx-core/src/pool/connect.rs create mode 100644 sqlx-core/src/pool/idle.rs diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index fff4ef3d24..8512508cba 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -22,11 +22,10 @@ json = ["serde", "serde_json"] # for conditional compilation _rt-async-global-executor = ["async-global-executor", "_rt-async-io", "_rt-async-task"] _rt-async-io = ["async-io", "async-fs"] # see note at async-fs declaration -_rt-async-std = ["async-std", "_rt-async-io"] +_rt-async-std = ["async-std", "_rt-async-io", "ease-off/async-io-2"] _rt-async-task = ["async-task"] _rt-smol = ["smol", "_rt-async-io", "_rt-async-task"] -_rt-tokio = ["tokio", "tokio-stream"] - +_rt-tokio = ["tokio", "tokio-stream", "ease-off/tokio"] _tls-native-tls = ["native-tls"] _tls-rustls-aws-lc-rs = ["_tls-rustls", "rustls/aws-lc-rs", "webpki-roots"] _tls-rustls-ring-webpki = ["_tls-rustls", "rustls/ring", 
"webpki-roots"] @@ -83,7 +82,6 @@ crossbeam-queue = "0.3.2" either = "1.6.1" futures-core = { version = "0.3.19", default-features = false } futures-io = "0.3.24" -futures-intrusive = "0.5.0" futures-util = { version = "0.3.19", default-features = false, features = ["alloc", "sink", "io"] } log = { version = "0.4.18", default-features = false } memchr = { version = "2.4.1", default-features = false } @@ -105,6 +103,9 @@ hashbrown = "0.16.0" thiserror.workspace = true +ease-off = { workspace = true, features = ["futures"] } +pin-project-lite = "0.2.14" + [dev-dependencies] tokio = { version = "1", features = ["rt"] } diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index 8c6f424cdf..ff416ed2f7 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -11,6 +11,9 @@ use crate::database::Database; use crate::type_info::TypeInfo; use crate::types::Type; +#[cfg(doc)] +use crate::pool::{PoolConnector, PoolOptions}; + /// A specialized `Result` type for SQLx. pub type Result = ::std::result::Result; @@ -110,6 +113,19 @@ pub enum Error { #[error("attempted to acquire a connection on a closed pool")] PoolClosed, + /// A custom error that may be returned from a [`PoolConnector`] implementation. + #[error("error returned from pool connector")] + PoolConnector { + #[source] + source: BoxDynError, + + /// If `true`, `PoolConnector::connect()` is called again in an exponential backoff loop + /// up to [`PoolOptions::connect_timeout`]. + /// + /// See [`PoolConnector::connect()`] for details. + retryable: bool, + }, + /// A background worker has crashed. #[error("attempted to communicate with a crashed background worker")] WorkerCrashed, @@ -228,11 +244,6 @@ pub trait DatabaseError: 'static + Send + Sync + StdError { #[doc(hidden)] fn into_error(self: Box) -> Box; - #[doc(hidden)] - fn is_transient_in_connect_phase(&self) -> bool { - false - } - /// Returns the name of the constraint that triggered the error, if applicable. /// If the error was caused by a conflict of a unique index, this will be the index name. /// @@ -270,6 +281,24 @@ pub trait DatabaseError: 'static + Send + Sync + StdError { fn is_check_violation(&self) -> bool { matches!(self.kind(), ErrorKind::CheckViolation) } + + /// Returns `true` if this error can be retried when connecting to the database. + /// + /// Defaults to `false`. + /// + /// For example, the Postgres driver overrides this to return `true` for the following error codes: + /// + /// * `53300 too_many_connections`: returned when the maximum connections are exceeded + /// on the server. Assumed to be the result of a temporary overcommit + /// (e.g. an extra application replica being spun up to replace one that is going down). + /// * This error being consistently logged or returned is a likely indicator of a misconfiguration; + /// the sum of [`PoolOptions::max_connections`] for all replicas should not exceed + /// the maximum connections allowed by the server. + /// * `57P03 cannot_connect_now`: returned when the database server is still starting up + /// and the tcop component is not ready to accept connections yet. 
+    fn is_retryable_connect_error(&self) -> bool {
+        false
+    }
 }

 impl dyn DatabaseError {
diff --git a/sqlx-core/src/pool/connect.rs b/sqlx-core/src/pool/connect.rs
new file mode 100644
index 0000000000..29e59cafee
--- /dev/null
+++ b/sqlx-core/src/pool/connect.rs
@@ -0,0 +1,461 @@
+use crate::connection::{ConnectOptions, Connection};
+use crate::database::Database;
+use crate::pool::connection::{Floating, Live};
+use crate::pool::inner::PoolInner;
+use crate::pool::PoolConnection;
+use crate::rt::JoinHandle;
+use crate::Error;
+use ease_off::EaseOff;
+use event_listener::{Event, EventListener};
+use std::future::Future;
+use std::pin::Pin;
+use std::ptr;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::{Arc, RwLock};
+use std::task::{Context, Poll};
+use std::time::{Duration, Instant};
+use tracing::Instrument;
+
+use std::io;
+
+/// Custom connect callback for [`Pool`][crate::pool::Pool].
+///
+/// Implemented for closures with the signature
+/// `Fn(PoolConnectMetadata) -> impl Future<Output = crate::Result<DB::Connection>>`.
+///
+/// See [`Self::connect()`] for details and implementation advice.
+///
+/// # Example: `after_connect` Replacement
+/// The `after_connect` callback was removed in 0.9.0 as it was redundant with this API.
+///
+/// This example uses Postgres but may be adapted to any driver.
+///
+/// ```rust,no_run
+/// use std::sync::Arc;
+/// use sqlx::PgConnection;
+/// use sqlx::postgres::PgPoolOptions;
+/// use sqlx::Connection;
+///
+/// # async fn _example() -> sqlx::Result<()> {
+/// // `PoolConnector` is implemented for closures but has restrictions on returning borrows
+/// // due to current language limitations.
+/// //
+/// // This example shows how to get around this using `Arc`.
+/// let database_url: Arc<str> = "postgres://...".into();
+///
+/// let pool = PgPoolOptions::new()
+///     .min_connections(5)
+///     .max_connections(30)
+///     .connect_with_connector(move |meta| {
+///         let database_url = database_url.clone();
+///         async move {
+///             println!(
+///                 "opening connection {}, attempt {}; elapsed time: {:?}",
+///                 meta.pool_size,
+///                 meta.num_attempts + 1,
+///                 meta.start.elapsed()
+///             );
+///
+///             let mut conn = PgConnection::connect(&database_url).await?;
+///
+///             // Override the time zone of the connection.
+///             sqlx::raw_sql("SET TIME ZONE 'Europe/Berlin'")
+///                 .execute(&mut conn)
+///                 .await?;
+///
+///             Ok(conn)
+///         }
+///     })
+///     .await?;
+/// # Ok(())
+/// # }
+/// ```
+///
+/// # Example: `set_connect_options` Replacement
+/// `set_connect_options` and `get_connect_options` were removed in 0.9.0 because they complicated
+/// the pool internals. They can be reimplemented by capturing a mutex, or similar, in the callback.
+///
+/// This example uses Postgres and [`tokio::sync::RwLock`] but may be adapted to any driver,
+/// or to `async-std`'s locks, respectively.
+///
+/// ```rust,no_run
+/// use std::sync::Arc;
+/// use tokio::sync::RwLock;
+/// use sqlx::PgConnection;
+/// use sqlx::postgres::PgConnectOptions;
+/// use sqlx::postgres::PgPoolOptions;
+/// use sqlx::ConnectOptions;
+///
+/// # async fn _example() -> sqlx::Result<()> {
+/// // If you do not wish to hold the lock during the connection attempt,
+/// // you could clone the options out of the lock first instead.
+/// let connect_opts: Arc<RwLock<PgConnectOptions>> =
+///     Arc::new(RwLock::new("postgres://...".parse()?));
+/// // We need a copy that will be captured by the closure.
+/// let connect_opts_ = connect_opts.clone();
+///
+/// let pool = PgPoolOptions::new()
+///     .connect_with_connector(move |meta| {
+///         let connect_opts = connect_opts_.clone();
+///         async move {
+///             println!(
+///                 "opening connection {}, attempt {}; elapsed time: {:?}",
+///                 meta.pool_size,
+///                 meta.num_attempts + 1,
+///                 meta.start.elapsed()
+///             );
+///
+///             connect_opts.read().await.connect().await
+///         }
+///     })
+///     .await?;
+///
+/// // Close the connection that was previously opened by `connect_with_connector()`.
+/// pool.acquire().await?.close().await?;
+///
+/// // Simulating a credential rotation
+/// let mut write_connect_opts = connect_opts.write().await;
+/// *write_connect_opts = write_connect_opts
+///     .clone()
+///     .username("new_username")
+///     .password("new password");
+/// drop(write_connect_opts);
+///
+/// // Should use the new credentials.
+/// let conn = pool.acquire().await?;
+///
+/// # Ok(())
+/// # }
+/// ```
+///
+/// # Example: Custom Implementation
+///
+/// Custom implementations of `PoolConnector` trade a little bit of boilerplate for much
+/// more flexibility. Thanks to the signature of `connect()`, they can return a `Future`
+/// type that borrows from `self`.
+///
+/// This example uses Postgres but may be adapted to any driver.
+///
+/// ```rust,no_run
+/// use sqlx::{PgConnection, Postgres};
+/// use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
+/// use sqlx_core::connection::ConnectOptions;
+/// use sqlx_core::pool::{PoolConnectMetadata, PoolConnector};
+///
+/// struct MyConnector {
+///     // A list of servers to connect to in a high-availability configuration.
+///     host_ports: Vec<(String, u16)>,
+///     username: String,
+///     password: String,
+/// }
+///
+/// impl PoolConnector<Postgres> for MyConnector {
+///     // The desugaring of `async fn` is compatible with the signature of `connect()`.
+///     async fn connect(&self, meta: PoolConnectMetadata) -> sqlx::Result<PgConnection> {
+///         self.get_connect_options(meta.num_attempts)
+///             .connect()
+///             .await
+///     }
+/// }
+///
+/// impl MyConnector {
+///     fn get_connect_options(&self, attempt: usize) -> PgConnectOptions {
+///         // Select servers in a round-robin.
+///         let (ref host, port) = self.host_ports[attempt % self.host_ports.len()];
+///
+///         PgConnectOptions::new()
+///             .host(host)
+///             .port(port)
+///             .username(&self.username)
+///             .password(&self.password)
+///     }
+/// }
+///
+/// # async fn _example() -> sqlx::Result<()> {
+/// let pool = PgPoolOptions::new()
+///     .max_connections(25)
+///     .connect_with_connector(MyConnector {
+///         host_ports: vec![
+///             ("db1.postgres.cluster.local".into(), 5432),
+///             ("db2.postgres.cluster.local".into(), 5432),
+///             ("db3.postgres.cluster.local".into(), 5432),
+///             ("db4.postgres.cluster.local".into(), 5432),
+///         ],
+///         username: "my_username".into(),
+///         password: "my password".into(),
+///     })
+///     .await?;
+///
+/// let conn = pool.acquire().await?;
+///
+/// # Ok(())
+/// # }
+/// ```
+pub trait PoolConnector<DB: Database>: Send + Sync + 'static {
+    /// Open a connection for the pool.
+    ///
+    /// Any setup that must be done on the connection should be performed before it is returned.
+    ///
+    /// If this method returns an error that is known to be retryable, it is called again
+    /// in an exponential backoff loop. Retryable errors include, but are not limited to:
+    ///
+    /// * [`io::ErrorKind::ConnectionRefused`]
+    /// * Database errors for which
+    ///   [`is_retryable_connect_error`][crate::error::DatabaseError::is_retryable_connect_error]
+    ///   returns `true`.
+    /// * [`Error::PoolConnector`] with `retryable: true`.
+ /// This error kind is not returned internally and is designed to allow this method to return + /// arbitrary error types not otherwise supported. + /// + /// Manual implementations of this method may also use the signature: + /// ```rust,ignore + /// async fn connect( + /// &self, + /// meta: PoolConnectMetadata + /// ) -> sqlx::Result<{PgConnection, MySqlConnection, SqliteConnection, etc.}> + /// ``` + /// + /// Note: the returned future must be `Send`. + fn connect( + &self, + meta: PoolConnectMetadata, + ) -> impl Future> + Send + '_; +} + +impl PoolConnector for F +where + DB: Database, + F: Fn(PoolConnectMetadata) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, +{ + fn connect( + &self, + meta: PoolConnectMetadata, + ) -> impl Future> + Send + '_ { + self(meta) + } +} + +pub(crate) struct DefaultConnector( + pub <::Connection as Connection>::Options, +); + +impl PoolConnector for DefaultConnector { + fn connect( + &self, + _meta: PoolConnectMetadata, + ) -> impl Future> + Send + '_ { + self.0.connect() + } +} + +/// Metadata passed to [`PoolConnector::connect()`] for every connection attempt. +#[derive(Debug)] +pub struct PoolConnectMetadata { + /// The instant at which the current connection task was started, including all attempts. + /// + /// May be used for reporting purposes, or to implement a custom backoff. + pub start: Instant, + /// The number of attempts that have occurred so far. + pub num_attempts: usize, + pub pool_size: usize, +} + +pub struct DynConnector { + // We want to spawn the connection attempt as a task anyway + connect: Box< + dyn Fn(ConnectPermit, usize) -> JoinHandle>> + + Send + + Sync + + 'static, + >, +} + +impl DynConnector { + pub fn new(connector: impl PoolConnector) -> Self { + let connector = Arc::new(connector); + + Self { + connect: Box::new(move |permit, size| { + crate::rt::spawn(connect_with_backoff(permit, connector.clone(), size)) + }), + } + } + + pub fn connect( + &self, + permit: ConnectPermit, + size: usize, + ) -> JoinHandle>> { + (self.connect)(permit, size) + } +} + +pub struct ConnectionCounter { + connections: AtomicUsize, + connect_available: Event, +} + +impl ConnectionCounter { + pub fn new() -> Self { + Self { + connections: AtomicUsize::new(0), + connect_available: Event::new(), + } + } + + pub fn connections(&self) -> usize { + self.connections.load(Ordering::Acquire) + } + + pub async fn drain(&self) { + while self.connections.load(Ordering::Acquire) > 0 { + self.connect_available.listen().await; + } + } + + /// Attempt to acquire a permit from both this instance, and the parent pool, if applicable. + /// + /// Returns the permit, and the current size of the pool. + pub async fn acquire_permit( + &self, + pool: &Arc>, + ) -> (usize, ConnectPermit) { + // Check that `self` can increase size first before we check the parent. + let (size, permit) = self.acquire_permit_self(pool).await; + + if let Some(parent) = &pool.options.parent_pool { + let (_, permit) = parent.0.counter.acquire_permit_self(&parent.0).await; + + // consume the parent permit + permit.consume(); + } + + (size, permit) + } + + // Separate method because `async fn`s cannot be recursive. + /// Attempt to acquire a [`ConnectPermit`] from this instance and this instance only. + async fn acquire_permit_self( + &self, + pool: &Arc>, + ) -> (usize, ConnectPermit) { + debug_assert!(ptr::addr_eq(self, &pool.counter)); + + let mut should_wait = pool.options.fair && self.connect_available.total_listeners() > 0; + + for attempt in 1usize.. 
{ + if should_wait { + self.connect_available.listen().await; + } + + let res = self.connections.fetch_update( + Ordering::Release, + Ordering::Acquire, + |connections| { + (connections < pool.options.max_connections).then_some(connections + 1) + }, + ); + + if let Ok(prev_size) = res { + let size = prev_size + 1; + + tracing::trace!(target: "sqlx::pool::connect", size, "increased size"); + + return ( + prev_size + 1, + ConnectPermit { + pool: Some(Arc::clone(pool)), + }, + ); + } + + should_wait = true; + + if attempt == 2 { + tracing::warn!( + "unable to acquire a connect permit after sleeping; this may indicate a bug" + ); + } + } + + panic!("BUG: was never able to acquire a connection despite waking many times") + } + + pub fn release_permit(&self, pool: &PoolInner) { + debug_assert!(ptr::addr_eq(self, &pool.counter)); + + self.connections.fetch_sub(1, Ordering::Release); + self.connect_available.notify(1usize); + + if let Some(parent) = &pool.options.parent_pool { + parent.0.counter.release_permit(&parent.0); + } + } +} + +pub struct ConnectPermit { + pool: Option>>, +} + +impl ConnectPermit { + pub fn float_existing(pool: Arc>) -> Self { + Self { pool: Some(pool) } + } + + pub fn pool(&self) -> &Arc> { + self.pool.as_ref().unwrap() + } + + pub fn consume(mut self) { + self.pool = None; + } +} + +impl Drop for ConnectPermit { + fn drop(&mut self) { + if let Some(pool) = self.pool.take() { + pool.counter.release_permit(&pool); + } + } +} + +#[tracing::instrument( + target = "sqlx::pool::connect", + skip_all, + fields(connection = size), + err +)] +async fn connect_with_backoff( + permit: ConnectPermit, + connector: Arc>, + size: usize, +) -> crate::Result> { + if permit.pool().is_closed() { + return Err(Error::PoolClosed); + } + + let mut ease_off = EaseOff::start_timeout(permit.pool().options.connect_timeout); + + for attempt in 1usize.. 
{ + let meta = PoolConnectMetadata { + start: ease_off.started_at(), + num_attempts: attempt, + pool_size: size, + }; + + let conn = ease_off + .try_async(connector.connect(meta)) + .await + .or_retry_if(|e| can_retry_error(e.inner()))?; + + if let Some(conn) = conn { + return Ok(Floating::new_live(conn, permit).reattach()); + } + } + + Err(Error::PoolTimedOut) +} + +fn can_retry_error(e: &Error) -> bool { + match e { + Error::Io(e) if e.kind() == io::ErrorKind::ConnectionRefused => true, + Error::Database(e) => e.is_retryable_connect_error(), + _ => false, + } +} diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs index 7912b12aa1..48d124e38a 100644 --- a/sqlx-core/src/pool/connection.rs +++ b/sqlx-core/src/pool/connection.rs @@ -10,7 +10,8 @@ use crate::connection::Connection; use crate::database::Database; use crate::error::Error; -use super::inner::{is_beyond_max_lifetime, DecrementSizeGuard, PoolInner}; +use super::inner::{is_beyond_max_lifetime, PoolInner}; +use crate::pool::connect::ConnectPermit; use crate::pool::options::PoolConnectionMetadata; const CLOSE_ON_DROP_TIMEOUT: Duration = Duration::from_secs(5); @@ -37,7 +38,7 @@ pub(super) struct Idle { /// RAII wrapper for connections being handled by functions that may drop them pub(super) struct Floating { pub(super) inner: C, - pub(super) guard: DecrementSizeGuard, + pub(super) permit: ConnectPermit, } const EXPECT_MSG: &str = "BUG: inner connection already taken!"; @@ -127,6 +128,10 @@ impl PoolConnection { self.live.take().expect(EXPECT_MSG) } + pub(super) fn into_floating(mut self) -> Floating> { + self.take_live().float(self.pool.clone()) + } + /// Test the connection to make sure it is still live before returning it to the pool. /// /// This effectively runs the drop handler eagerly instead of spawning a task to do it. @@ -215,7 +220,7 @@ impl Live { Floating { inner: self, // create a new guard from a previously leaked permit - guard: DecrementSizeGuard::new_permit(pool), + permit: ConnectPermit::float_existing(pool), } } @@ -242,22 +247,22 @@ impl DerefMut for Idle { } impl Floating> { - pub fn new_live(conn: DB::Connection, guard: DecrementSizeGuard) -> Self { + pub fn new_live(conn: DB::Connection, permit: ConnectPermit) -> Self { Self { inner: Live { raw: conn, created_at: Instant::now(), }, - guard, + permit, } } pub fn reattach(self) -> PoolConnection { - let Floating { inner, guard } = self; + let Floating { inner, permit } = self; - let pool = Arc::clone(&guard.pool); + let pool = Arc::clone(permit.pool()); - guard.cancel(); + permit.consume(); PoolConnection { live: Some(inner), close_on_drop: false, @@ -266,7 +271,7 @@ impl Floating> { } pub fn release(self) { - self.guard.pool.clone().release(self); + self.permit.pool().clone().release(self); } /// Return the connection to the pool. @@ -274,19 +279,19 @@ impl Floating> { /// Returns `true` if the connection was successfully returned, `false` if it was closed. async fn return_to_pool(mut self) -> bool { // Immediately close the connection. 
- if self.guard.pool.is_closed() { + if self.permit.pool().is_closed() { self.close().await; return false; } // If the connection is beyond max lifetime, close the connection and // immediately create a new connection - if is_beyond_max_lifetime(&self.inner, &self.guard.pool.options) { + if is_beyond_max_lifetime(&self.inner, &self.permit.pool().options) { self.close().await; return false; } - if let Some(test) = &self.guard.pool.options.after_release { + if let Some(test) = &self.permit.pool().options.after_release { let meta = self.metadata(); match (test)(&mut self.inner.raw, meta).await { Ok(true) => (), @@ -345,7 +350,7 @@ impl Floating> { pub fn into_idle(self) -> Floating> { Floating { inner: self.inner.into_idle(), - guard: self.guard, + permit: self.permit, } } @@ -358,14 +363,10 @@ impl Floating> { } impl Floating> { - pub fn from_idle( - idle: Idle, - pool: Arc>, - permit: AsyncSemaphoreReleaser<'_>, - ) -> Self { + pub fn from_idle(idle: Idle, pool: Arc>) -> Self { Self { inner: idle, - guard: DecrementSizeGuard::from_permit(pool, permit), + permit: ConnectPermit::float_existing(pool), } } @@ -376,21 +377,21 @@ impl Floating> { pub fn into_live(self) -> Floating> { Floating { inner: self.inner.live, - guard: self.guard, + permit: self.permit, } } - pub async fn close(self) -> DecrementSizeGuard { + pub async fn close(self) -> ConnectPermit { if let Err(error) = self.inner.live.raw.close().await { tracing::debug!(%error, "error occurred while closing the pool connection"); } - self.guard + self.permit } - pub async fn close_hard(self) -> DecrementSizeGuard { + pub async fn close_hard(self) -> ConnectPermit { let _ = self.inner.live.raw.close_hard().await; - self.guard + self.permit } pub fn metadata(&self) -> PoolConnectionMetadata { diff --git a/sqlx-core/src/pool/idle.rs b/sqlx-core/src/pool/idle.rs new file mode 100644 index 0000000000..239313f7ed --- /dev/null +++ b/sqlx-core/src/pool/idle.rs @@ -0,0 +1,97 @@ +use crate::connection::Connection; +use crate::database::Database; +use crate::pool::connection::{Floating, Idle, Live}; +use crate::pool::inner::PoolInner; +use crossbeam_queue::ArrayQueue; +use event_listener::Event; +use futures_util::FutureExt; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +pub struct IdleQueue { + queue: ArrayQueue>, + // Keep a separate count because `ArrayQueue::len()` loops until the head and tail pointers + // stop changing, which may never happen at high contention. + len: AtomicUsize, + release_event: Event, + fair: bool, +} + +impl IdleQueue { + pub fn new(fair: bool, cap: usize) -> Self { + Self { + queue: ArrayQueue::new(cap), + len: AtomicUsize::new(0), + release_event: Event::new(), + fair, + } + } + + pub fn len(&self) -> usize { + self.len.load(Ordering::Acquire) + } + + pub async fn acquire(&self, pool: &Arc>) -> Floating> { + let mut should_wait = self.fair && self.release_event.total_listeners() > 0; + + for attempt in 1usize.. 
{ + if should_wait { + self.release_event.listen().await; + } + + if let Some(conn) = self.try_acquire(pool) { + return conn; + } + + should_wait = true; + + if attempt == 2 { + tracing::warn!( + "unable to acquire a connection after sleeping; this may indicate a bug" + ); + } + } + + panic!("BUG: was never able to acquire a connection despite waking many times") + } + + pub fn try_acquire(&self, pool: &Arc>) -> Option>> { + self.len + .fetch_update(Ordering::Release, Ordering::Acquire, |len| { + len.checked_sub(1) + }) + .ok() + .and_then(|_| { + let conn = self.queue.pop()?; + + Some(Floating::from_idle(conn, Arc::clone(pool))) + }) + } + + pub fn release(&self, conn: Floating>) { + let Floating { + inner: conn, + permit, + } = conn.into_idle(); + + self.queue + .push(conn) + .unwrap_or_else(|_| panic!("BUG: idle queue capacity exceeded")); + + self.len.fetch_add(1, Ordering::Release); + + self.release_event.notify(1usize); + + // Don't decrease the size. + permit.consume(); + } + + pub fn drain(&self, pool: &PoolInner) { + while let Some(conn) = self.queue.pop() { + // Hopefully will send at least a TCP FIN packet. + conn.live.raw.close_hard().now_or_never(); + + pool.counter.release_permit(pool); + } + } +} diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index 7c6a84adfe..c366bb07f1 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -1,33 +1,29 @@ use super::connection::{Floating, Idle, Live}; -use crate::connection::ConnectOptions; -use crate::connection::Connection; use crate::database::Database; use crate::error::Error; -use crate::pool::{deadline_as_timeout, CloseEvent, Pool, PoolOptions}; -use crossbeam_queue::ArrayQueue; - -use crate::sync::{AsyncSemaphore, AsyncSemaphoreReleaser}; +use crate::pool::{CloseEvent, Pool, PoolConnection, PoolConnector, PoolOptions}; use std::cmp; use std::future::{self, Future}; use std::pin::pin; -use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; -use std::sync::{Arc, RwLock}; -use std::task::Poll; +use std::pin::pin; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::task::ready; use crate::logger::private_level_filter_to_trace_level; -use crate::pool::options::PoolConnectionMetadata; +use crate::pool::connect::{ConnectPermit, ConnectionCounter, DynConnector}; +use crate::pool::idle::IdleQueue; use crate::private_tracing_dynamic_event; +use futures_util::future::{self, OptionFuture}; use futures_util::FutureExt; use std::time::{Duration, Instant}; use tracing::Level; pub(crate) struct PoolInner { - pub(super) connect_options: RwLock::Options>>, - pub(super) idle_conns: ArrayQueue>, - pub(super) semaphore: AsyncSemaphore, - pub(super) size: AtomicUsize, - pub(super) num_idle: AtomicUsize, + pub(super) connector: DynConnector, + pub(super) counter: ConnectionCounter, + pub(super) idle: IdleQueue, is_closed: AtomicBool, pub(super) on_closed: event_listener::Event, pub(super) options: PoolOptions, @@ -38,25 +34,12 @@ pub(crate) struct PoolInner { impl PoolInner { pub(super) fn new_arc( options: PoolOptions, - connect_options: ::Options, + connector: impl PoolConnector, ) -> Arc { - let capacity = options.max_connections as usize; - - let semaphore_capacity = if let Some(parent) = &options.parent_pool { - assert!(options.max_connections <= parent.options().max_connections); - assert_eq!(options.fair, parent.options().fair); - // The child pool must steal permits from the parent - 0 - } else { - capacity - }; - let pool = Self { - connect_options: 
RwLock::new(Arc::new(connect_options)), - idle_conns: ArrayQueue::new(capacity), - semaphore: AsyncSemaphore::new(options.fair, semaphore_capacity), - size: AtomicUsize::new(0), - num_idle: AtomicUsize::new(0), + connector: DynConnector::new(connector), + counter: ConnectionCounter::new(), + idle: IdleQueue::new(options.fair, options.max_connections), is_closed: AtomicBool::new(false), on_closed: event_listener::Event::new(), acquire_time_level: private_level_filter_to_trace_level(options.acquire_time_level), @@ -72,16 +55,11 @@ impl PoolInner { } pub(super) fn size(&self) -> usize { - self.size.load(Ordering::Acquire) + self.counter.connections() } pub(super) fn num_idle(&self) -> usize { - // We don't use `self.idle_conns.len()` as it waits for the internal - // head and tail pointers to stop changing for a moment before calculating the length, - // which may take a long time at high levels of churn. - // - // By maintaining our own atomic count, we avoid that issue entirely. - self.num_idle.load(Ordering::Acquire) + self.idle.len() } pub(super) fn is_closed(&self) -> bool { @@ -97,24 +75,11 @@ impl PoolInner { self.mark_closed(); async move { - // For child pools, we need to acquire permits we actually have rather than - // max_connections - let permits_to_acquire = if self.options.parent_pool.is_some() { - // Child pools start with 0 permits, so we acquire based on current size - self.size() - } else { - // Parent pools can acquire all max_connections permits - self.options.max_connections - }; - - let _permits = self.semaphore.acquire(permits_to_acquire).await; - - while let Some(idle) = self.idle_conns.pop() { - let _ = idle.live.raw.close().await; + while let Some(idle) = self.idle.try_acquire(self) { + idle.close().await; } - self.num_idle.store(0, Ordering::Release); - self.size.store(0, Ordering::Release); + self.counter.drain().await; } } @@ -124,56 +89,6 @@ impl PoolInner { } } - /// Attempt to pull a permit from `self.semaphore` or steal one from the parent. - /// - /// If we steal a permit from the parent but *don't* open a connection, - /// it should be returned to the parent. - async fn acquire_permit(self: &Arc) -> Result, Error> { - let parent = self - .parent() - // If we're already at the max size, we shouldn't try to steal from the parent. - // This is just going to cause unnecessary churn in `acquire()`. - .filter(|_| self.size() < self.options.max_connections); - - let mut acquire_self = pin!(self.semaphore.acquire(1).fuse()); - let mut close_event = pin!(self.close_event()); - - if let Some(parent) = parent { - let mut acquire_parent = pin!(parent.0.semaphore.acquire(1)); - let mut parent_close_event = pin!(parent.0.close_event()); - - let mut poll_parent = false; - - future::poll_fn(|cx| { - if close_event.as_mut().poll(cx).is_ready() { - return Poll::Ready(Err(Error::PoolClosed)); - } - - if parent_close_event.as_mut().poll(cx).is_ready() { - // Propagate the parent's close event to the child. - self.mark_closed(); - return Poll::Ready(Err(Error::PoolClosed)); - } - - if let Poll::Ready(permit) = acquire_self.as_mut().poll(cx) { - return Poll::Ready(Ok(permit)); - } - - // Don't try the parent right away. 
- if poll_parent { - acquire_parent.as_mut().poll(cx).map(Ok) - } else { - poll_parent = true; - cx.waker().wake_by_ref(); - Poll::Pending - } - }) - .await - } else { - close_event.do_until(acquire_self).await - } - } - fn parent(&self) -> Option<&Pool> { self.options.parent_pool.as_ref() } @@ -184,117 +99,103 @@ impl PoolInner { return None; } - let permit = self.semaphore.try_acquire(1)?; - - self.pop_idle(permit).ok() - } - - fn pop_idle<'a>( - self: &'a Arc, - permit: AsyncSemaphoreReleaser<'a>, - ) -> Result>, AsyncSemaphoreReleaser<'a>> { - if let Some(idle) = self.idle_conns.pop() { - self.num_idle.fetch_sub(1, Ordering::AcqRel); - Ok(Floating::from_idle(idle, (*self).clone(), permit)) - } else { - Err(permit) - } + self.idle.try_acquire(self) } pub(super) fn release(&self, floating: Floating>) { // `options.after_release` and other checks are in `PoolConnection::return_to_pool()`. - - let Floating { inner: idle, guard } = floating.into_idle(); - - if self.idle_conns.push(idle).is_err() { - panic!("BUG: connection queue overflow in release()"); - } - - // NOTE: we need to make sure we drop the permit *after* we push to the idle queue - // don't decrease the size - guard.release_permit(); - - self.num_idle.fetch_add(1, Ordering::AcqRel); - } - - /// Try to atomically increment the pool size for a new connection. - /// - /// Returns `Err` if the pool is at max capacity already or is closed. - pub(super) fn try_increment_size<'a>( - self: &'a Arc, - permit: AsyncSemaphoreReleaser<'a>, - ) -> Result, AsyncSemaphoreReleaser<'a>> { - let result = self - .size - .fetch_update(Ordering::AcqRel, Ordering::Acquire, |size| { - if self.is_closed() { - return None; - } - - size.checked_add(1) - .filter(|size| size <= &self.options.max_connections) - }); - - match result { - // we successfully incremented the size - Ok(_) => Ok(DecrementSizeGuard::from_permit((*self).clone(), permit)), - // the pool is at max capacity or is closed - Err(_) => Err(permit), - } + self.idle.release(floating); } - pub(super) async fn acquire(self: &Arc) -> Result>, Error> { + pub(super) async fn acquire(self: &Arc) -> Result, Error> { if self.is_closed() { return Err(Error::PoolClosed); } let acquire_started_at = Instant::now(); - let deadline = acquire_started_at + self.options.acquire_timeout; - - let acquired = crate::rt::timeout( - self.options.acquire_timeout, - async { - loop { - // Handles the close-event internally - let permit = self.acquire_permit().await?; - - - // First attempt to pop a connection from the idle queue. - let guard = match self.pop_idle(permit) { - - // Then, check that we can use it... - Ok(conn) => match check_idle_conn(conn, &self.options).await { - - // All good! - Ok(live) => return Ok(live), - - // if the connection isn't usable for one reason or another, - // we get the `DecrementSizeGuard` back to open a new one - Err(guard) => guard, - }, - Err(permit) => if let Ok(guard) = self.try_increment_size(permit) { - // we can open a new connection - guard - } else { - // This can happen for a child pool that's at its connection limit, - // or if the pool was closed between `acquire_permit()` and - // `try_increment_size()`. - tracing::debug!("woke but was unable to acquire idle connection or open new one; retrying"); - // If so, we're likely in the current-thread runtime if it's Tokio, - // and so we should yield to let any spawned return_to_pool() tasks - // execute. 
- crate::rt::yield_now().await; - continue; + + let mut close_event = pin!(self.close_event()); + let mut deadline = pin!(crate::rt::sleep(self.options.acquire_timeout)); + let mut acquire_idle = pin!(self.idle.acquire(self).fuse()); + let mut check_idle = pin!(OptionFuture::from(None)); + let mut acquire_connect_permit = pin!(OptionFuture::from(Some( + self.counter.acquire_permit(self).fuse() + ))); + let mut connect = OptionFuture::from(None); + + // The internal state machine of `acquire()`. + // + // * The initial state is racing to acquire either an idle connection or a new `ConnectPermit`. + // * If we acquire a `ConnectPermit`, we begin the connection loop (with backoff) + // as implemented by `DynConnector`. + // * If we acquire an idle connection, we then start polling `check_idle_conn()`. + let acquired = future::poll_fn(|cx| { + use std::task::Poll::*; + + // First check if the pool is already closed, + // or register for a wakeup if it gets closed. + if let Ready(()) = close_event.poll_unpin(cx) { + return Ready(Err(Error::PoolClosed)); + } + + // Then check if our deadline has elapsed, or schedule a wakeup for when that happens. + if let Ready(()) = deadline.poll_unpin(cx) { + return Ready(Err(Error::PoolTimedOut)); + } + + // Attempt to acquire a connection from the idle queue. + if let Ready(idle) = acquire_idle.poll_unpin(cx) { + check_idle.set(Some(check_idle_conn(idle, &self.options)).into()); + } + + // If we acquired an idle connection, run any checks that need to be done. + // + // Includes `test_on_acquire` and the `before_acquire` callback, if set. + // + // We don't want to race this step if it's already running because canceling it + // will result in the potentially unnecessary closure of a connection. + // + // Instead, we just wait and see what happens. If we already started connecting, + // that'll happen concurrently. + match ready!(check_idle.poll_unpin(cx)) { + // The `.reattach()` call errors with "type annotations needed" if not qualified. + Some(Ok(live)) => return Ready(Ok(Floating::reattach(live))), + Some(Err(permit)) => { + // We don't strictly need to poll `connect` here; all we really want to do + // is to check if it is `None`. But since currently there's no getter for that, + // it doesn't really hurt to just poll it here. + match connect.poll_unpin(cx) { + Ready(None) => { + // If we're not already attempting to connect, + // take the permit returned from closing the connection and + // attempt to open a new one. + connect = Some(self.connector.connect(permit, self.size())).into(); } - }; + // `permit` is dropped in these branches, allowing another task to use it + Ready(Some(res)) => return Ready(res), + Pending => (), + } - // Attempt to connect... - return self.connect(deadline, guard).await; + // Attempt to acquire another idle connection concurrently to opening a new one. + acquire_idle.set(self.idle.acquire(self).fuse()); + // Annoyingly, `OptionFuture` doesn't fuse to `None` on its own + check_idle.set(None.into()); } + None => (), } - ) - .await - .map_err(|_| Error::PoolTimedOut)??; + + if let Ready(Some((size, permit))) = acquire_connect_permit.poll_unpin(cx) { + connect = Some(self.connector.connect(permit, size)).into(); + } + + if let Ready(Some(res)) = connect.poll_unpin(cx) { + // RFC: suppress errors here? 
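+                // As currently written, a connector error (after its retries are
+                // exhausted) is returned to the caller of `acquire()` as-is.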
+ return Ready(res); + } + + Pending + }) + .await?; let acquired_after = acquire_started_at.elapsed(); @@ -322,102 +223,29 @@ impl PoolInner { Ok(acquired) } - pub(super) async fn connect( - self: &Arc, - deadline: Instant, - guard: DecrementSizeGuard, - ) -> Result>, Error> { - if self.is_closed() { - return Err(Error::PoolClosed); - } - - let mut backoff = Duration::from_millis(10); - let max_backoff = deadline_as_timeout(deadline)? / 5; - - loop { - let timeout = deadline_as_timeout(deadline)?; - - // clone the connect options arc so it can be used without holding the RwLockReadGuard - // across an async await point - let connect_options = self - .connect_options - .read() - .expect("write-lock holder panicked") - .clone(); - - // result here is `Result, TimeoutError>` - // if this block does not return, sleep for the backoff timeout and try again - match crate::rt::timeout(timeout, connect_options.connect()).await { - // successfully established connection - Ok(Ok(mut raw)) => { - // See comment on `PoolOptions::after_connect` - let meta = PoolConnectionMetadata { - age: Duration::ZERO, - idle_for: Duration::ZERO, - }; - - let res = if let Some(callback) = &self.options.after_connect { - callback(&mut raw, meta).await - } else { - Ok(()) - }; - - match res { - Ok(()) => return Ok(Floating::new_live(raw, guard)), - Err(error) => { - tracing::error!(%error, "error returned from after_connect"); - // The connection is broken, don't try to close nicely. - let _ = raw.close_hard().await; - - // Fall through to the backoff. - } - } - } - - // an IO error while connecting is assumed to be the system starting up - Ok(Err(Error::Io(e))) if e.kind() == std::io::ErrorKind::ConnectionRefused => (), - - // We got a transient database error, retry. - Ok(Err(Error::Database(error))) if error.is_transient_in_connect_phase() => (), - - // Any other error while connection should immediately - // terminate and bubble the error up - Ok(Err(e)) => return Err(e), - - // timed out - Err(_) => return Err(Error::PoolTimedOut), - } - - // If the connection is refused, wait in exponentially - // increasing steps for the server to come up, - // capped by a factor of the remaining time until the deadline - crate::rt::sleep(backoff).await; - backoff = cmp::min(backoff * 2, max_backoff); - } - } - /// Try to maintain `min_connections`, returning any errors (including `PoolTimedOut`). pub async fn try_min_connections(self: &Arc, deadline: Instant) -> Result<(), Error> { - while self.size() < self.options.min_connections { - // Don't wait for a semaphore permit. - // - // If no extra permits are available then we shouldn't be trying to spin up - // connections anyway. - let Some(permit) = self.semaphore.try_acquire(1) else { - return Ok(()); - }; - - // We must always obey `max_connections`. - let Some(guard) = self.try_increment_size(permit).ok() else { - return Ok(()); - }; - - // We skip `after_release` since the connection was never provided to user code - // besides `after_connect`, if they set it. - self.release(self.connect(deadline, guard).await?); - } + crate::rt::timeout_at(deadline, async { + while self.size() < self.options.min_connections { + // Don't wait for a connect permit. + // + // If no extra permits are available then we shouldn't be trying to spin up + // connections anyway. 
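+                // `now_or_never()` polls the permit future exactly once, so this never waits.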
+                let Some((size, permit)) = self.counter.acquire_permit(self).now_or_never() else {
+                    return Ok(());
+                };
+
+                let conn = self.connector.connect(permit, size).await?;
+
+                // We skip `after_release` since the connection was never provided to user code
+                // besides inside `PoolConnector::connect()`, if they override it.
+                self.release(conn.into_floating());
+            }

-        Ok(())
+            Ok(())
+        })
+        .await
+        .unwrap_or_else(|_| Err(Error::PoolTimedOut))
     }

     /// Attempt to maintain `min_connections`, logging if unable.
@@ -441,11 +269,7 @@ impl<DB: Database> PoolInner<DB> {
 impl<DB: Database> Drop for PoolInner<DB> {
     fn drop(&mut self) {
         self.mark_closed();
-
-        if let Some(parent) = &self.options.parent_pool {
-            // Release the stolen permits.
-            parent.0.semaphore.release(self.semaphore.permits());
-        }
+        self.idle.drain(self);
     }
 }

@@ -469,7 +293,7 @@ fn is_beyond_idle_timeout<DB: Database>(idle: &Idle<DB>, options: &PoolOptions<DB>)
 async fn check_idle_conn<DB: Database>(
     mut conn: Floating<DB, Idle<DB>>,
     options: &PoolOptions<DB>,
-) -> Result<Floating<DB, Live<DB>>, DecrementSizeGuard<DB>> {
+) -> Result<Floating<DB, Live<DB>>, ConnectPermit<DB>> {
     if options.test_before_acquire {
         // Check that the connection is still live
         if let Err(error) = conn.ping().await {
@@ -573,51 +397,3 @@ fn spawn_maintenance_tasks<DB: Database>(pool: &Arc<PoolInner<DB>>) {
             .await;
     });
 }
-
-/// RAII guard returned by `Pool::try_increment_size()` and others.
-///
-/// Will decrement the pool size if dropped, to avoid semantically "leaking" connections
-/// (where the pool thinks it has more connections than it does).
-pub(in crate::pool) struct DecrementSizeGuard<DB: Database> {
-    pub(crate) pool: Arc<PoolInner<DB>>,
-    cancelled: bool,
-}
-
-impl<DB: Database> DecrementSizeGuard<DB> {
-    /// Create a new guard that will release a semaphore permit on-drop.
-    pub fn new_permit(pool: Arc<PoolInner<DB>>) -> Self {
-        Self {
-            pool,
-            cancelled: false,
-        }
-    }
-
-    pub fn from_permit(pool: Arc<PoolInner<DB>>, permit: AsyncSemaphoreReleaser<'_>) -> Self {
-        // here we effectively take ownership of the permit
-        permit.disarm();
-        Self::new_permit(pool)
-    }
-
-    /// Release the semaphore permit without decreasing the pool size.
-    ///
-    /// If the permit was stolen from the pool's parent, it will be returned to the child's semaphore.
-    fn release_permit(self) {
-        self.pool.semaphore.release(1);
-        self.cancel();
-    }
-
-    pub fn cancel(mut self) {
-        self.cancelled = true;
-    }
-}
-
-impl<DB: Database> Drop for DecrementSizeGuard<DB> {
-    fn drop(&mut self) {
-        if !self.cancelled {
-            self.pool.size.fetch_sub(1, Ordering::AcqRel);
-
-            // and here we release the permit we got on construction
-            self.pool.semaphore.release(1);
-        }
-    }
-}
diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs
index 6e54089484..0caa1161c1 100644
--- a/sqlx-core/src/pool/mod.rs
+++ b/sqlx-core/src/pool/mod.rs
@@ -71,6 +71,7 @@ use crate::error::Error;
 use crate::sql_str::SqlSafeStr;
 use crate::transaction::Transaction;

+pub use self::connect::{PoolConnectMetadata, PoolConnector};
 pub use self::connection::PoolConnection;
 use self::inner::PoolInner;
 #[doc(hidden)]
@@ -83,8 +84,11 @@ mod executor;
 #[macro_use]
 pub mod maybe;

+mod connect;
 mod connection;
 mod inner;
+
+mod idle;
 mod options;

 /// An asynchronous pool of SQLx database connections.
@@ -356,7 +360,7 @@ impl<DB: Database> Pool<DB> {
     /// returning it.
     pub fn acquire(&self) -> impl Future<Output = Result<PoolConnection<DB>, Error>> + 'static {
         let shared = self.0.clone();
-        async move { shared.acquire().await.map(|conn| conn.reattach()) }
+        async move { shared.acquire().await }
     }

     /// Attempts to retrieve a connection from the pool if there is one available.
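A quick sketch of what the `'static` future returned by `acquire()` enables (hypothetical usage; assumes the Tokio runtime and the Postgres driver, and `fetch_in_task` is an illustrative helper, not part of this patch):

```rust
use sqlx::postgres::PgPool;

async fn fetch_in_task(pool: &PgPool) -> sqlx::Result<()> {
    // The future returned by `acquire()` owns a handle to the pool,
    // so it can be spawned as its own task without borrowing `pool`.
    let acquire = pool.acquire();

    tokio::spawn(async move {
        let mut conn = acquire.await?;
        sqlx::query("SELECT 1").execute(&mut *conn).await?;
        sqlx::Result::Ok(())
    })
    .await
    .expect("acquire task panicked")
}
```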
@@ -541,28 +545,6 @@ impl Pool { self.0.num_idle() } - /// Gets a clone of the connection options for this pool - pub fn connect_options(&self) -> Arc<::Options> { - self.0 - .connect_options - .read() - .expect("write-lock holder panicked") - .clone() - } - - /// Updates the connection options this pool will use when opening any future connections. Any - /// existing open connection in the pool will be left as-is. - pub fn set_connect_options(&self, connect_options: ::Options) { - // technically write() could also panic if the current thread already holds the lock, - // but because this method can't be re-entered by the same thread that shouldn't be a problem - let mut guard = self - .0 - .connect_options - .write() - .expect("write-lock holder panicked"); - *guard = Arc::new(connect_options); - } - /// Get the options for this pool pub fn options(&self) -> &PoolOptions { &self.0.options diff --git a/sqlx-core/src/pool/options.rs b/sqlx-core/src/pool/options.rs index 2c2c2e1801..9775799fdf 100644 --- a/sqlx-core/src/pool/options.rs +++ b/sqlx-core/src/pool/options.rs @@ -1,8 +1,9 @@ use crate::connection::Connection; use crate::database::Database; use crate::error::Error; +use crate::pool::connect::DefaultConnector; use crate::pool::inner::PoolInner; -use crate::pool::Pool; +use crate::pool::{Pool, PoolConnector}; use futures_core::future::BoxFuture; use log::LevelFilter; use std::fmt::{self, Debug, Formatter}; @@ -44,14 +45,6 @@ use std::time::{Duration, Instant}; /// the perspectives of both API designer and consumer. pub struct PoolOptions { pub(crate) test_before_acquire: bool, - pub(crate) after_connect: Option< - Arc< - dyn Fn(&mut DB::Connection, PoolConnectionMetadata) -> BoxFuture<'_, Result<(), Error>> - + 'static - + Send - + Sync, - >, - >, pub(crate) before_acquire: Option< Arc< dyn Fn( @@ -79,6 +72,7 @@ pub struct PoolOptions { pub(crate) acquire_slow_level: LevelFilter, pub(crate) acquire_slow_threshold: Duration, pub(crate) acquire_timeout: Duration, + pub(crate) connect_timeout: Duration, pub(crate) min_connections: usize, pub(crate) max_lifetime: Option, pub(crate) idle_timeout: Option, @@ -94,7 +88,6 @@ impl Clone for PoolOptions { fn clone(&self) -> Self { PoolOptions { test_before_acquire: self.test_before_acquire, - after_connect: self.after_connect.clone(), before_acquire: self.before_acquire.clone(), after_release: self.after_release.clone(), max_connections: self.max_connections, @@ -102,6 +95,7 @@ impl Clone for PoolOptions { acquire_slow_threshold: self.acquire_slow_threshold, acquire_slow_level: self.acquire_slow_level, acquire_timeout: self.acquire_timeout, + connect_timeout: self.connect_timeout, min_connections: self.min_connections, max_lifetime: self.max_lifetime, idle_timeout: self.idle_timeout, @@ -143,7 +137,6 @@ impl PoolOptions { pub fn new() -> Self { Self { // User-specifiable routines - after_connect: None, before_acquire: None, after_release: None, test_before_acquire: true, @@ -158,6 +151,7 @@ impl PoolOptions { // to not flag typical time to add a new connection to a pool. acquire_slow_threshold: Duration::from_secs(2), acquire_timeout: Duration::from_secs(30), + connect_timeout: Duration::from_secs(2 * 60), idle_timeout: Some(Duration::from_secs(10 * 60)), max_lifetime: Some(Duration::from_secs(30 * 60)), fair: true, @@ -268,6 +262,23 @@ impl PoolOptions { self.acquire_timeout } + /// Set the maximum amount of time to spend attempting to open a connection. 
+    ///
+    /// This timeout happens independently of [`acquire_timeout`][Self::acquire_timeout].
+    ///
+    /// If shorter than `acquire_timeout`, this will cause the last connection attempt
+    /// within a call to `Pool::acquire()` to be cut short.
+    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
+        self.connect_timeout = timeout;
+        self
+    }
+
+    /// Get the maximum amount of time to spend attempting to open a connection.
+    ///
+    /// This timeout happens independently of [`acquire_timeout`][Self::acquire_timeout].
+    pub fn get_connect_timeout(&self) -> Duration {
+        self.connect_timeout
+    }
+
     /// Set the maximum lifetime of individual connections.
     ///
     /// Any connection with a lifetime greater than this will be closed.
@@ -339,57 +350,6 @@ impl<DB: Database> PoolOptions<DB> {
         self
     }
 
-    /// Perform an asynchronous action after connecting to the database.
-    ///
-    /// If the operation returns with an error then the error is logged, the connection is closed
-    /// and a new one is opened in its place and the callback is invoked again.
-    ///
-    /// This occurs in a backoff loop to avoid high CPU usage and spamming logs during a transient
-    /// error condition.
-    ///
-    /// Note that this may be called for internally opened connections, such as when maintaining
-    /// [`min_connections`][Self::min_connections], that are then immediately returned to the pool
-    /// without invoking [`after_release`][Self::after_release].
-    ///
-    /// # Example: Additional Parameters
-    /// This callback may be used to set additional configuration parameters
-    /// that are not exposed by the database's `ConnectOptions`.
-    ///
-    /// This example is written for PostgreSQL but can likely be adapted to other databases.
-    ///
-    /// ```no_run
-    /// # async fn f() -> Result<(), Box<dyn std::error::Error>> {
-    /// use sqlx::Executor;
-    /// use sqlx::postgres::PgPoolOptions;
-    ///
-    /// let pool = PgPoolOptions::new()
-    ///     .after_connect(|conn, _meta| Box::pin(async move {
-    ///         // When directly invoking `Executor` methods,
-    ///         // it is possible to execute multiple statements with one call.
-    ///         conn.execute("SET application_name = 'your_app'; SET search_path = 'my_schema';")
-    ///             .await?;
-    ///
-    ///         Ok(())
-    ///     }))
-    ///     .connect("postgres:// …").await?;
-    /// # Ok(())
-    /// # }
-    /// ```
-    ///
-    /// For a discussion on why `Box::pin()` is required, see [the type-level docs][Self].
-    pub fn after_connect<F>(mut self, callback: F) -> Self
-    where
-        // We're passing the `PoolConnectionMetadata` here mostly for future-proofing.
-        // `age` and `idle_for` are obviously not useful for fresh connections.
-        for<'c> F: Fn(&'c mut DB::Connection, PoolConnectionMetadata) -> BoxFuture<'c, Result<(), Error>>
-            + 'static
-            + Send
-            + Sync,
-    {
-        self.after_connect = Some(Arc::new(callback));
-        self
-    }
-
     /// Perform an asynchronous action on a previously idle connection before giving it out.
     ///
     /// Alongside the connection, the closure gets [`PoolConnectionMetadata`] which contains
@@ -537,11 +497,25 @@ impl<DB: Database> PoolOptions<DB> {
     pub async fn connect_with(
         self,
         options: <DB::Connection as Connection>::Options,
+    ) -> Result<Pool<DB>, Error> {
+        self.connect_with_connector(DefaultConnector(options)).await
+    }
+
+    /// Create a new pool from this `PoolOptions` and immediately open at least one connection.
+    ///
+    /// This ensures the configuration is correct.
+    ///
+    /// The total number of connections opened is max(1, [min_connections][Self::min_connections]).
+    ///
+    /// See [PoolConnector] for examples.
+    pub async fn connect_with_connector(
+        self,
+        connector: impl PoolConnector<DB>,
     ) -> Result<Pool<DB>, Error> {
         // Don't take longer than `acquire_timeout` starting from when this is called.
let deadline = Instant::now() + self.acquire_timeout; - let inner = PoolInner::new_arc(self, options); + let inner = PoolInner::new_arc(self, connector); if inner.options.min_connections > 0 { // If the idle reaper is spawned then this will race with the call from that task @@ -552,7 +526,7 @@ impl PoolOptions { // If `min_connections` is nonzero then we'll likely just pull a connection // from the idle queue here, but it should at least get tested first. let conn = inner.acquire().await?; - inner.release(conn); + inner.release(conn.into_floating()); Ok(Pool(inner)) } @@ -578,7 +552,11 @@ impl PoolOptions { /// optimistically establish that many connections for the pool. pub fn connect_lazy_with(self, options: ::Options) -> Pool { // `min_connections` is guaranteed by the idle reaper now. - Pool(PoolInner::new_arc(self, options)) + self.connect_lazy_with_connector(DefaultConnector(options)) + } + + pub fn connect_lazy_with_connector(self, connector: impl PoolConnector) -> Pool { + Pool(PoolInner::new_arc(self, connector)) } } diff --git a/sqlx-core/src/rt/mod.rs b/sqlx-core/src/rt/mod.rs index 273a1bfcd9..1da096d93b 100644 --- a/sqlx-core/src/rt/mod.rs +++ b/sqlx-core/src/rt/mod.rs @@ -2,7 +2,7 @@ use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; use std::task::{Context, Poll}; -use std::time::Duration; +use std::time::{Duration, Instant}; use cfg_if::cfg_if; @@ -51,6 +51,29 @@ pub async fn timeout(duration: Duration, f: F) -> Result(deadline: Instant, f: F) -> Result { + #[cfg(feature = "_rt-tokio")] + if rt_tokio::available() { + return tokio::time::timeout_at(deadline.into(), f) + .await + .map_err(|_| TimeoutError(())); + } + + #[cfg(feature = "_rt-async-std")] + { + let Some(duration) = deadline.checked_duration_since(Instant::now()) else { + return Err(TimeoutError(())); + }; + + async_std::future::timeout(duration, f) + .await + .map_err(|_| TimeoutError(())) + } + + #[cfg(not(feature = "_rt-async-std"))] + missing_rt((duration, f)) +} + pub async fn sleep(duration: Duration) { #[cfg(feature = "_rt-tokio")] if rt_tokio::available() { diff --git a/sqlx-mysql/src/testing/mod.rs b/sqlx-mysql/src/testing/mod.rs index f509f9da45..c27dda3ccd 100644 --- a/sqlx-mysql/src/testing/mod.rs +++ b/sqlx-mysql/src/testing/mod.rs @@ -1,5 +1,4 @@ use std::future::Future; -use std::ops::Deref; use std::str::FromStr; use std::sync::OnceLock; use std::time::Duration; @@ -108,27 +107,11 @@ async fn test_context(args: &TestArgs) -> Result, Error> { .max_connections(20) // Immediately close master connections. Tokio's I/O streams don't like hopping runtimes. .after_release(|_conn, _| Box::pin(async move { Ok(false) })) - .connect_lazy_with(master_opts); - - let master_pool = match once_lock_try_insert_polyfill(&MASTER_POOL, pool) { - Ok(inserted) => inserted, - Err((existing, pool)) => { - // Sanity checks. 
- assert_eq!( - existing.connect_options().host, - pool.connect_options().host, - "DATABASE_URL changed at runtime, host differs" - ); - - assert_eq!( - existing.connect_options().database, - pool.connect_options().database, - "DATABASE_URL changed at runtime, database differs" - ); - - existing - } - }; + .connect_lazy_with(master_opts.clone()); + + let master_pool = MASTER_POOL + .try_insert(pool) + .unwrap_or_else(|(existing, _pool)| existing); let mut conn = master_pool.acquire().await?; @@ -144,7 +127,7 @@ async fn test_context(args: &TestArgs) -> Result, Error> { -- BLOB/TEXT columns can only be used as index keys with a prefix length: -- https://dev.mysql.com/doc/refman/8.4/en/column-indexes.html#column-indexes-prefix primary key(db_name(63)) - ); + ); "#, ) .await?; @@ -172,11 +155,7 @@ async fn test_context(args: &TestArgs) -> Result, Error> { // Close connections ASAP if left in the idle queue. .idle_timeout(Some(Duration::from_secs(1))) .parent(master_pool.clone()), - connect_opts: master_pool - .connect_options() - .deref() - .clone() - .database(&db_name), + connect_opts: master_opts.database(&db_name), db_name, }) } diff --git a/sqlx-postgres/src/error.rs b/sqlx-postgres/src/error.rs index 7b5a03f2b3..ac99bd13cf 100644 --- a/sqlx-postgres/src/error.rs +++ b/sqlx-postgres/src/error.rs @@ -186,7 +186,7 @@ impl DatabaseError for PgDatabaseError { self } - fn is_transient_in_connect_phase(&self) -> bool { + fn is_retryable_connect_error(&self) -> bool { // https://www.postgresql.org/docs/current/errcodes-appendix.html [ // too_many_connections diff --git a/sqlx-postgres/src/options/mod.rs b/sqlx-postgres/src/options/mod.rs index efbc43989b..00ae159759 100644 --- a/sqlx-postgres/src/options/mod.rs +++ b/sqlx-postgres/src/options/mod.rs @@ -127,6 +127,11 @@ impl PgConnectOptions { self } + /// Identical to [Self::host()], but through a mutable reference. + pub fn set_host(&mut self, host: &str) { + host.clone_into(&mut self.host); + } + /// Sets the port to connect to at the server host. /// /// The default port for PostgreSQL is `5432`. @@ -143,6 +148,12 @@ impl PgConnectOptions { self } + /// Identical to [`Self::port()`], but through a mutable reference. + pub fn set_port(&mut self, port: u16) -> &mut Self { + self.port = port; + self + } + /// Sets a custom path to a directory containing a unix domain socket, /// switching the connection method from TCP to the corresponding socket. /// @@ -169,6 +180,12 @@ impl PgConnectOptions { self } + /// Identical to [`Self::username()`], but through a mutable reference. + pub fn set_username(&mut self, username: &str) -> &mut Self { + username.clone_into(&mut self.username); + self + } + /// Sets the password to use if the server demands password authentication. /// /// # Example @@ -184,6 +201,12 @@ impl PgConnectOptions { self } + /// Identical to [`Self::password()`]. but through a mutable reference. + pub fn set_password(&mut self, password: &str) -> &mut Self { + self.password = Some(password.to_owned()); + self + } + /// Sets the database name. Defaults to be the same as the user name. 
/// /// # Example diff --git a/sqlx-postgres/src/testing/mod.rs b/sqlx-postgres/src/testing/mod.rs index 3e1cf0ddf7..a7f6a54944 100644 --- a/sqlx-postgres/src/testing/mod.rs +++ b/sqlx-postgres/src/testing/mod.rs @@ -1,5 +1,4 @@ use std::future::Future; -use std::ops::Deref; use std::str::FromStr; use std::sync::OnceLock; use std::time::Duration; @@ -101,27 +100,11 @@ async fn test_context(args: &TestArgs) -> Result, Error> { .max_connections(20) // Immediately close master connections. Tokio's I/O streams don't like hopping runtimes. .after_release(|_conn, _| Box::pin(async move { Ok(false) })) - .connect_lazy_with(master_opts); - - let master_pool = match once_lock_try_insert_polyfill(&MASTER_POOL, pool) { - Ok(inserted) => inserted, - Err((existing, pool)) => { - // Sanity checks. - assert_eq!( - existing.connect_options().host, - pool.connect_options().host, - "DATABASE_URL changed at runtime, host differs" - ); - - assert_eq!( - existing.connect_options().database, - pool.connect_options().database, - "DATABASE_URL changed at runtime, database differs" - ); - - existing - } - }; + .connect_lazy_with(master_opts.clone()); + + let master_pool = MASTER_POOL + .try_insert(pool) + .unwrap_or_else(|(existing, _pool)| existing); let mut conn = master_pool.acquire().await?; @@ -177,11 +160,7 @@ async fn test_context(args: &TestArgs) -> Result, Error> { // Close connections ASAP if left in the idle queue. .idle_timeout(Some(Duration::from_secs(1))) .parent(master_pool.clone()), - connect_opts: master_pool - .connect_options() - .deref() - .clone() - .database(&db_name), + connect_opts: master_opts.database(&db_name), db_name, }) } From dd17a3d98b8c21f8fbe1169be0184969e9973b0d Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 8 Nov 2024 11:56:24 -0800 Subject: [PATCH 04/24] fix(pool): spawn task for `before_acquire` --- sqlx-core/src/pool/inner.rs | 117 ++++++++++++++++++++---------------- 1 file changed, 66 insertions(+), 51 deletions(-) diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index c366bb07f1..b2c9581dc3 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -14,7 +14,9 @@ use std::task::ready; use crate::logger::private_level_filter_to_trace_level; use crate::pool::connect::{ConnectPermit, ConnectionCounter, DynConnector}; use crate::pool::idle::IdleQueue; -use crate::private_tracing_dynamic_event; +use crate::rt::JoinHandle; +use crate::{private_tracing_dynamic_event, rt}; +use either::Either; use futures_util::future::{self, OptionFuture}; use futures_util::FutureExt; use std::time::{Duration, Instant}; @@ -117,7 +119,7 @@ impl PoolInner { let mut close_event = pin!(self.close_event()); let mut deadline = pin!(crate::rt::sleep(self.options.acquire_timeout)); let mut acquire_idle = pin!(self.idle.acquire(self).fuse()); - let mut check_idle = pin!(OptionFuture::from(None)); + let mut before_acquire = OptionFuture::from(None); let mut acquire_connect_permit = pin!(OptionFuture::from(Some( self.counter.acquire_permit(self).fuse() ))); @@ -145,21 +147,25 @@ impl PoolInner { // Attempt to acquire a connection from the idle queue. if let Ready(idle) = acquire_idle.poll_unpin(cx) { - check_idle.set(Some(check_idle_conn(idle, &self.options)).into()); + // If we acquired an idle connection, run any checks that need to be done. + // + // Includes `test_on_acquire` and the `before_acquire` callback, if set. + match finish_acquire(idle) { + // There are checks needed to be done, so they're spawned as a task + // to be cancellation-safe. 
+ Either::Left(check_task) => { + before_acquire = Some(check_task).into(); + } + // The connection is ready to go. + Either::Right(conn) => { + return Ready(Ok(conn)); + } + } } - // If we acquired an idle connection, run any checks that need to be done. - // - // Includes `test_on_acquire` and the `before_acquire` callback, if set. - // - // We don't want to race this step if it's already running because canceling it - // will result in the potentially unnecessary closure of a connection. - // - // Instead, we just wait and see what happens. If we already started connecting, - // that'll happen concurrently. - match ready!(check_idle.poll_unpin(cx)) { - // The `.reattach()` call errors with "type annotations needed" if not qualified. - Some(Ok(live)) => return Ready(Ok(Floating::reattach(live))), + // Poll the task returned by `finish_acquire` + match ready!(before_acquire.poll_unpin(cx)) { + Some(Ok(conn)) => return Ready(Ok(conn)), Some(Err(permit)) => { // We don't strictly need to poll `connect` here; all we really want to do // is to check if it is `None`. But since currently there's no getter for that, @@ -179,7 +185,7 @@ impl PoolInner { // Attempt to acquire another idle connection concurrently to opening a new one. acquire_idle.set(self.idle.acquire(self).fuse()); // Annoyingly, `OptionFuture` doesn't fuse to `None` on its own - check_idle.set(None.into()); + before_acquire = None.into(); } None => (), } @@ -290,42 +296,51 @@ fn is_beyond_idle_timeout(idle: &Idle, options: &PoolOptions timeout) } -async fn check_idle_conn( - mut conn: Floating>, - options: &PoolOptions, -) -> Result>, ConnectPermit> { - if options.test_before_acquire { - // Check that the connection is still live - if let Err(error) = conn.ping().await { - // an error here means the other end has hung up or we lost connectivity - // either way we're fine to just discard the connection - // the error itself here isn't necessarily unexpected so WARN is too strong - tracing::info!(%error, "ping on idle connection returned error"); - // connection is broken so don't try to close nicely - return Err(conn.close_hard().await); - } - } - - if let Some(test) = &options.before_acquire { - let meta = conn.metadata(); - match test(&mut conn.live.raw, meta).await { - Ok(false) => { - // connection was rejected by user-defined hook, close nicely - return Err(conn.close().await); - } - - Err(error) => { - tracing::warn!(%error, "error from `before_acquire`"); +/// Execute `test_before_acquire` and/or `before_acquire` in a background task, if applicable. +/// +/// Otherwise, immediately returns the connection. +fn finish_acquire( + mut conn: Floating> +) -> Either, ConnectPermit>>, PoolConnection> { + let pool = conn.permit.pool(); + + if pool.options.test_before_acquire || pool.options.before_acquire.is_some() { + // Spawn a task so the call may complete even if `acquire()` is cancelled. 
+ return Either::Left(rt::spawn(async move { + // Check that the connection is still live + if let Err(error) = conn.ping().await { + // an error here means the other end has hung up or we lost connectivity + // either way we're fine to just discard the connection + // the error itself here isn't necessarily unexpected so WARN is too strong + tracing::info!(%error, "ping on idle connection returned error"); // connection is broken so don't try to close nicely return Err(conn.close_hard().await); } - Ok(true) => {} - } - } + if let Some(test) = &conn.permit.pool().options.before_acquire { + let meta = conn.metadata(); + match test(&mut conn.inner.live.raw, meta).await { + Ok(false) => { + // connection was rejected by user-defined hook, close nicely + return Err(conn.close().await); + } - // No need to re-connect; connection is alive or we don't care - Ok(conn.into_live()) + Err(error) => { + tracing::warn!(%error, "error from `before_acquire`"); + // connection is broken so don't try to close nicely + return Err(conn.close_hard().await); + } + + Ok(true) => {} + } + } + + Ok(conn.into_live().reattach()) + })); + } + + // No checks are configured, return immediately. + Either::Right(conn.into_live().reattach()) } fn spawn_maintenance_tasks(pool: &Arc>) { @@ -340,7 +355,7 @@ fn spawn_maintenance_tasks(pool: &Arc>) { (None, None) => { if pool.options.min_connections > 0 { - crate::rt::spawn(async move { + rt::spawn(async move { if let Some(pool) = pool_weak.upgrade() { pool.min_connections_maintenance(None).await; } @@ -354,7 +369,7 @@ fn spawn_maintenance_tasks(pool: &Arc>) { // Immediately cancel this task if the pool is closed. let mut close_event = pool.close_event(); - crate::rt::spawn(async move { + rt::spawn(async move { let _ = close_event .do_until(async { // If the last handle to the pool was dropped while we were sleeping @@ -387,10 +402,10 @@ fn spawn_maintenance_tasks(pool: &Arc>) { if let Some(duration) = next_run.checked_duration_since(Instant::now()) { // `async-std` doesn't have a `sleep_until()` - crate::rt::sleep(duration).await; + rt::sleep(duration).await; } else { // `next_run` is in the past, just yield. - crate::rt::yield_now().await; + rt::yield_now().await; } } }) From b9781c2f10858821637927fc2405ba76c3e1c43c Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 8 Nov 2024 15:50:51 -0800 Subject: [PATCH 05/24] refactor(pool): use a unique ID per connection --- sqlx-core/src/pool/connect.rs | 120 +++++++++++++++++++------------ sqlx-core/src/pool/connection.rs | 28 ++++++-- sqlx-core/src/pool/inner.rs | 29 ++++---- 3 files changed, 110 insertions(+), 67 deletions(-) diff --git a/sqlx-core/src/pool/connect.rs b/sqlx-core/src/pool/connect.rs index 29e59cafee..187dab9293 100644 --- a/sqlx-core/src/pool/connect.rs +++ b/sqlx-core/src/pool/connect.rs @@ -7,6 +7,7 @@ use crate::rt::JoinHandle; use crate::Error; use ease_off::EaseOff; use event_listener::{Event, EventListener}; +use std::fmt::{Display, Formatter}; use std::future::Future; use std::pin::Pin; use std::ptr; @@ -246,6 +247,7 @@ impl PoolConnector for DefaultConnector { /// Metadata passed to [`PoolConnector::connect()`] for every connection attempt. #[derive(Debug)] +#[non_exhaustive] pub struct PoolConnectMetadata { /// The instant at which the current connection task was started, including all attempts. /// @@ -253,13 +255,16 @@ pub struct PoolConnectMetadata { pub start: Instant, /// The number of attempts that have occurred so far. pub num_attempts: usize, + /// The current size of the pool. 
pub pool_size: usize, + /// The ID of the connection, unique for the pool. + pub connection_id: ConnectionId, } pub struct DynConnector { // We want to spawn the connection attempt as a task anyway connect: Box< - dyn Fn(ConnectPermit, usize) -> JoinHandle>> + dyn Fn(ConnectionId, ConnectPermit) -> JoinHandle>> + Send + Sync + 'static, @@ -271,53 +276,92 @@ impl DynConnector { let connector = Arc::new(connector); Self { - connect: Box::new(move |permit, size| { - crate::rt::spawn(connect_with_backoff(permit, connector.clone(), size)) + connect: Box::new(move |id, permit| { + crate::rt::spawn(connect_with_backoff(id, permit, connector.clone())) }), } } pub fn connect( &self, + id: ConnectionId, permit: ConnectPermit, - size: usize, ) -> JoinHandle>> { - (self.connect)(permit, size) + (self.connect)(id, permit) } } pub struct ConnectionCounter { - connections: AtomicUsize, + count: AtomicUsize, + next_id: AtomicUsize, connect_available: Event, } +/// An opaque connection ID, unique for every connection attempt with the same pool. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct ConnectionId(usize); + impl ConnectionCounter { pub fn new() -> Self { Self { - connections: AtomicUsize::new(0), + count: AtomicUsize::new(0), + next_id: AtomicUsize::new(1), connect_available: Event::new(), } } pub fn connections(&self) -> usize { - self.connections.load(Ordering::Acquire) + self.count.load(Ordering::Acquire) } pub async fn drain(&self) { - while self.connections.load(Ordering::Acquire) > 0 { + while self.count.load(Ordering::Acquire) > 0 { self.connect_available.listen().await; } } + /// Attempt to acquire a permit from both this instance, and the parent pool, if applicable. + /// + /// Returns the permit, and the ID of the new connection. + pub fn try_acquire_permit( + &self, + pool: &Arc>, + ) -> Option<(ConnectionId, ConnectPermit)> { + debug_assert!(ptr::addr_eq(self, &pool.counter)); + + // Don't skip the queue. + if pool.options.fair && self.connect_available.total_listeners() > 0 { + return None; + } + + let prev_size = self + .count + .fetch_update(Ordering::Release, Ordering::Acquire, |connections| { + (connections < pool.options.max_connections).then_some(connections + 1) + }) + .ok()?; + + let size = prev_size + 1; + + tracing::trace!(target: "sqlx::pool::connect", size, "increased size"); + + Some(( + ConnectionId(self.next_id.fetch_add(1, Ordering::SeqCst)), + ConnectPermit { + pool: Some(Arc::clone(pool)), + }, + )) + } + /// Attempt to acquire a permit from both this instance, and the parent pool, if applicable. /// /// Returns the permit, and the current size of the pool. pub async fn acquire_permit( &self, pool: &Arc>, - ) -> (usize, ConnectPermit) { + ) -> (ConnectionId, ConnectPermit) { // Check that `self` can increase size first before we check the parent. - let (size, permit) = self.acquire_permit_self(pool).await; + let acquired = self.acquire_permit_self(pool).await; if let Some(parent) = &pool.options.parent_pool { let (_, permit) = parent.0.counter.acquire_permit_self(&parent.0).await; @@ -326,7 +370,7 @@ impl ConnectionCounter { permit.consume(); } - (size, permit) + acquired } // Separate method because `async fn`s cannot be recursive. 
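`try_acquire_permit` above is built on `AtomicUsize::fetch_update`, which retries a compare-exchange loop until the closure's answer is applied or the closure returns `None`. A standalone sketch of that bounded-increment idiom (not pool code; names are illustrative):

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

/// Atomically increment `count` only while it is below `max`.
/// Returns the new count, or `None` if the limit was already reached.
fn try_increment(count: &AtomicUsize, max: usize) -> Option<usize> {
    count
        .fetch_update(Ordering::Release, Ordering::Acquire, |n| {
            (n < max).then_some(n + 1)
        })
        .ok()
        .map(|prev| prev + 1)
}
```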
@@ -334,38 +378,13 @@ impl ConnectionCounter { async fn acquire_permit_self( &self, pool: &Arc>, - ) -> (usize, ConnectPermit) { - debug_assert!(ptr::addr_eq(self, &pool.counter)); - - let mut should_wait = pool.options.fair && self.connect_available.total_listeners() > 0; - + ) -> (ConnectionId, ConnectPermit) { for attempt in 1usize.. { - if should_wait { - self.connect_available.listen().await; + if let Some(acquired) = self.try_acquire_permit(pool) { + return acquired; } - let res = self.connections.fetch_update( - Ordering::Release, - Ordering::Acquire, - |connections| { - (connections < pool.options.max_connections).then_some(connections + 1) - }, - ); - - if let Ok(prev_size) = res { - let size = prev_size + 1; - - tracing::trace!(target: "sqlx::pool::connect", size, "increased size"); - - return ( - prev_size + 1, - ConnectPermit { - pool: Some(Arc::clone(pool)), - }, - ); - } - - should_wait = true; + self.connect_available.listen().await; if attempt == 2 { tracing::warn!( @@ -380,7 +399,7 @@ impl ConnectionCounter { pub fn release_permit(&self, pool: &PoolInner) { debug_assert!(ptr::addr_eq(self, &pool.counter)); - self.connections.fetch_sub(1, Ordering::Release); + self.count.fetch_sub(1, Ordering::Release); self.connect_available.notify(1usize); if let Some(parent) = &pool.options.parent_pool { @@ -415,16 +434,22 @@ impl Drop for ConnectPermit { } } +impl Display for ConnectionId { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Display::fmt(&self.0, f) + } +} + #[tracing::instrument( target = "sqlx::pool::connect", - skip_all, - fields(connection = size), + skip_all, + fields(%connection_id), err )] async fn connect_with_backoff( + connection_id: ConnectionId, permit: ConnectPermit, connector: Arc>, - size: usize, ) -> crate::Result> { if permit.pool().is_closed() { return Err(Error::PoolClosed); @@ -436,7 +461,8 @@ async fn connect_with_backoff( let meta = PoolConnectMetadata { start: ease_off.started_at(), num_attempts: attempt, - pool_size: size, + pool_size: permit.pool().size(), + connection_id, }; let conn = ease_off @@ -445,7 +471,7 @@ async fn connect_with_backoff( .or_retry_if(|e| can_retry_error(e.inner()))?; if let Some(conn) = conn { - return Ok(Floating::new_live(conn, permit).reattach()); + return Ok(Floating::new_live(conn, connection_id, permit).reattach()); } } diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs index 48d124e38a..55133bf870 100644 --- a/sqlx-core/src/pool/connection.rs +++ b/sqlx-core/src/pool/connection.rs @@ -11,7 +11,7 @@ use crate::database::Database; use crate::error::Error; use super::inner::{is_beyond_max_lifetime, PoolInner}; -use crate::pool::connect::ConnectPermit; +use crate::pool::connect::{ConnectPermit, ConnectionId}; use crate::pool::options::PoolConnectionMetadata; const CLOSE_ON_DROP_TIMEOUT: Duration = Duration::from_secs(5); @@ -27,6 +27,7 @@ pub struct PoolConnection { pub(super) struct Live { pub(super) raw: DB::Connection, + pub(super) id: ConnectionId, pub(super) created_at: Instant, } @@ -247,10 +248,11 @@ impl DerefMut for Idle { } impl Floating> { - pub fn new_live(conn: DB::Connection, permit: ConnectPermit) -> Self { + pub fn new_live(conn: DB::Connection, id: ConnectionId, permit: ConnectPermit) -> Self { Self { inner: Live { raw: conn, + id, created_at: Instant::now(), }, permit, @@ -381,17 +383,29 @@ impl Floating> { } } - pub async fn close(self) -> ConnectPermit { + pub async fn close(self) -> (ConnectionId, ConnectPermit) { + let connection_id = self.inner.live.id; + 
+ tracing::debug!(%connection_id, "closing connection (gracefully)"); + if let Err(error) = self.inner.live.raw.close().await { - tracing::debug!(%error, "error occurred while closing the pool connection"); + tracing::debug!( + %connection_id, + %error, + "error occurred while closing the pool connection" + ); } - self.permit + (connection_id, self.permit) } - pub async fn close_hard(self) -> ConnectPermit { + pub async fn close_hard(self) -> (ConnectionId, ConnectPermit) { + let connection_id = self.inner.live.id; + + tracing::debug!(%connection_id, "closing connection (hard)"); + let _ = self.inner.live.raw.close_hard().await; - self.permit + (connection_id, self.permit) } pub fn metadata(&self) -> PoolConnectionMetadata { diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index b2c9581dc3..a700f9fa42 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -12,7 +12,7 @@ use std::sync::Arc; use std::task::ready; use crate::logger::private_level_filter_to_trace_level; -use crate::pool::connect::{ConnectPermit, ConnectionCounter, DynConnector}; +use crate::pool::connect::{ConnectPermit, ConnectionCounter, ConnectionId, DynConnector}; use crate::pool::idle::IdleQueue; use crate::rt::JoinHandle; use crate::{private_tracing_dynamic_event, rt}; @@ -166,7 +166,7 @@ impl PoolInner { // Poll the task returned by `finish_acquire` match ready!(before_acquire.poll_unpin(cx)) { Some(Ok(conn)) => return Ready(Ok(conn)), - Some(Err(permit)) => { + Some(Err((id, permit))) => { // We don't strictly need to poll `connect` here; all we really want to do // is to check if it is `None`. But since currently there's no getter for that, // it doesn't really hurt to just poll it here. @@ -175,7 +175,7 @@ impl PoolInner { // If we're not already attempting to connect, // take the permit returned from closing the connection and // attempt to open a new one. - connect = Some(self.connector.connect(permit, self.size())).into(); + connect = Some(self.connector.connect(id, permit)).into(); } // `permit` is dropped in these branches, allowing another task to use it Ready(Some(res)) => return Ready(res), @@ -190,8 +190,8 @@ impl PoolInner { None => (), } - if let Ready(Some((size, permit))) = acquire_connect_permit.poll_unpin(cx) { - connect = Some(self.connector.connect(permit, size)).into(); + if let Ready(Some((id, permit))) = acquire_connect_permit.poll_unpin(cx) { + connect = Some(self.connector.connect(id, permit)).into(); } if let Ready(Some(res)) = connect.poll_unpin(cx) { @@ -237,11 +237,11 @@ impl PoolInner { // // If no extra permits are available then we shouldn't be trying to spin up // connections anyway. - let Some((size, permit)) = self.counter.acquire_permit(self).now_or_never() else { + let Some((id, permit)) = self.counter.acquire_permit(self).now_or_never() else { return Ok(()); }; - let conn = self.connector.connect(permit, size).await?; + let conn = self.connector.connect(id, permit).await?; // We skip `after_release` since the connection was never provided to user code // besides inside `PollConnector::connect()`, if they override it. 
@@ -297,13 +297,16 @@ fn is_beyond_idle_timeout(idle: &Idle, options: &PoolOptions( - mut conn: Floating> -) -> Either, ConnectPermit>>, PoolConnection> { + mut conn: Floating>, +) -> Either< + JoinHandle, (ConnectionId, ConnectPermit)>>, + PoolConnection, +> { let pool = conn.permit.pool(); - + if pool.options.test_before_acquire || pool.options.before_acquire.is_some() { // Spawn a task so the call may complete even if `acquire()` is cancelled. return Either::Left(rt::spawn(async move { @@ -334,11 +337,11 @@ fn finish_acquire( Ok(true) => {} } } - + Ok(conn.into_live().reattach()) })); } - + // No checks are configured, return immediately. Either::Right(conn.into_live().reattach()) } From 8a0f1ab2306a1c19fd88d09d58a7354804084001 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 8 Nov 2024 15:51:04 -0800 Subject: [PATCH 06/24] fix(pool): add timeout to `return_to_pool()` --- sqlx-core/src/pool/connection.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs index 55133bf870..00ffd792a1 100644 --- a/sqlx-core/src/pool/connection.rs +++ b/sqlx-core/src/pool/connection.rs @@ -13,7 +13,10 @@ use crate::error::Error; use super::inner::{is_beyond_max_lifetime, PoolInner}; use crate::pool::connect::{ConnectPermit, ConnectionId}; use crate::pool::options::PoolConnectionMetadata; +use crate::rt; +use std::future::Future; +const RETURN_TO_POOL_TIMEOUT: Duration = Duration::from_secs(5); const CLOSE_ON_DROP_TIMEOUT: Duration = Duration::from_secs(5); /// A connection managed by a [`Pool`][crate::pool::Pool]. @@ -149,7 +152,9 @@ impl PoolConnection { async move { let returned_to_pool = if let Some(floating) = floating { - floating.return_to_pool().await + rt::timeout(RETURN_TO_POOL_TIMEOUT, floating.return_to_pool()) + .await + .unwrap_or(false) } else { false }; From 98235e646efadaf3e7eba8488e45e40c0db43e18 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 8 Nov 2024 15:52:53 -0800 Subject: [PATCH 07/24] feat(pool): add more info to `impl Debug for PoolConnection` --- sqlx-core/src/pool/connection.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs index 00ffd792a1..f4d6e765c6 100644 --- a/sqlx-core/src/pool/connection.rs +++ b/sqlx-core/src/pool/connection.rs @@ -49,8 +49,10 @@ const EXPECT_MSG: &str = "BUG: inner connection already taken!"; impl Debug for PoolConnection { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - // TODO: Show the type name of the connection ? 
- f.debug_struct("PoolConnection").finish() + f.debug_struct("PoolConnection") + .field("database", &DB::NAME) + .field("id", &self.live.as_ref().map(|live| live.id)) + .finish() } } From e210a23b7286b75d24c7d23df70bbec23ae8c7cf Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 8 Nov 2024 15:56:53 -0800 Subject: [PATCH 08/24] fix: compilation error, warnings --- Cargo.lock | 16 +++++++++++++++- sqlx-core/src/error.rs | 6 +++--- sqlx-core/src/pool/connect.rs | 23 ++++++++++------------- sqlx-core/src/pool/connection.rs | 2 -- sqlx-core/src/pool/inner.rs | 2 +- sqlx-core/src/pool/mod.rs | 10 ---------- sqlx-core/src/rt/mod.rs | 2 +- 7 files changed, 30 insertions(+), 31 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ebd8daba84..3a90f55f76 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1242,6 +1242,19 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" +[[package]] +name = "ease-off" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20e90ae5e739d99dc0406f9a4e2307a999625e2414d2ecc4fbb4ded8a3945f77" +dependencies = [ + "async-io", + "pin-project", + "rand", + "thiserror 1.0.69", + "tokio", +] + [[package]] name = "either" version = "1.15.0" @@ -3535,10 +3548,10 @@ dependencies = [ "chrono", "crc", "crossbeam-queue", + "ease-off", "either", "event-listener 5.4.0", "futures-core", - "futures-intrusive", "futures-io", "futures-util", "hashbrown 0.16.0", @@ -3551,6 +3564,7 @@ dependencies = [ "memchr", "native-tls", "percent-encoding", + "pin-project-lite", "rust_decimal", "rustls", "rustls-native-certs", diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index ff416ed2f7..00b1a64064 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -289,13 +289,13 @@ pub trait DatabaseError: 'static + Send + Sync + StdError { /// For example, the Postgres driver overrides this to return `true` for the following error codes: /// /// * `53300 too_many_connections`: returned when the maximum connections are exceeded - /// on the server. Assumed to be the result of a temporary overcommit - /// (e.g. an extra application replica being spun up to replace one that is going down). + /// on the server. Assumed to be the result of a temporary overcommit + /// (e.g. an extra application replica being spun up to replace one that is going down). /// * This error being consistently logged or returned is a likely indicator of a misconfiguration; /// the sum of [`PoolOptions::max_connections`] for all replicas should not exceed /// the maximum connections allowed by the server. /// * `57P03 cannot_connect_now`: returned when the database server is still starting up - /// and the tcop component is not ready to accept connections yet. + /// and the tcop component is not ready to accept connections yet. 
fn is_retryable_connect_error(&self) -> bool { false } diff --git a/sqlx-core/src/pool/connect.rs b/sqlx-core/src/pool/connect.rs index 187dab9293..f1f7ce7d4b 100644 --- a/sqlx-core/src/pool/connect.rs +++ b/sqlx-core/src/pool/connect.rs @@ -1,21 +1,18 @@ use crate::connection::{ConnectOptions, Connection}; use crate::database::Database; -use crate::pool::connection::{Floating, Live}; +use crate::pool::connection::Floating; use crate::pool::inner::PoolInner; use crate::pool::PoolConnection; use crate::rt::JoinHandle; use crate::Error; use ease_off::EaseOff; -use event_listener::{Event, EventListener}; +use event_listener::Event; use std::fmt::{Display, Formatter}; use std::future::Future; -use std::pin::Pin; use std::ptr; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::{Arc, RwLock}; -use std::task::{Context, Poll}; -use std::time::{Duration, Instant}; -use tracing::Instrument; +use std::sync::Arc; +use std::time::Instant; use std::io; @@ -74,7 +71,7 @@ use std::io; /// `set_connect_options` and `get_connect_options` were removed in 0.9.0 because they complicated /// the pool internals. They can be reimplemented by capturing a mutex, or similar, in the callback. /// -/// This example uses Postgres and [`tokio::sync::Mutex`] but may be adapted to any driver +/// This example uses Postgres and [`tokio::sync::RwLock`] but may be adapted to any driver /// or `async-std`, respectively. /// /// ```rust,no_run @@ -197,11 +194,11 @@ pub trait PoolConnector: Send + Sync + 'static { /// /// * [`io::ErrorKind::ConnectionRefused`] /// * Database errors for which - /// [`is_retryable_connect_error`][crate::error::DatabaseError::is_retryable_connect_error] - /// returns `true`. + /// [`is_retryable_connect_error`][crate::error::DatabaseError::is_retryable_connect_error] + /// returns `true`. /// * [`Error::PoolConnector`] with `retryable: true`. - /// This error kind is not returned internally and is designed to allow this method to return - /// arbitrary error types not otherwise supported. + /// This error kind is not returned internally and is designed to allow this method to return + /// arbitrary error types not otherwise supported. /// /// Manual implementations of this method may also use the signature: /// ```rust,ignore @@ -363,7 +360,7 @@ impl ConnectionCounter { // Check that `self` can increase size first before we check the parent. 
        let acquired = self.acquire_permit_self(pool).await;
 
-        if let Some(parent) = &pool.options.parent_pool {
+        if let Some(parent) = pool.parent() {
             let (_, permit) = parent.0.counter.acquire_permit_self(&parent.0).await;
 
             // consume the parent permit
diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs
index f4d6e765c6..76e4e24b03 100644
--- a/sqlx-core/src/pool/connection.rs
+++ b/sqlx-core/src/pool/connection.rs
@@ -4,8 +4,6 @@ use std::ops::{Deref, DerefMut};
 use std::sync::Arc;
 use std::time::{Duration, Instant};
 
-use crate::sync::AsyncSemaphoreReleaser;
-
 use crate::connection::Connection;
 use crate::database::Database;
 use crate::error::Error;
diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs
index a700f9fa42..da9b0add35 100644
--- a/sqlx-core/src/pool/inner.rs
+++ b/sqlx-core/src/pool/inner.rs
@@ -91,7 +91,7 @@ impl<DB: Database> PoolInner<DB> {
         }
     }
 
-    fn parent(&self) -> Option<&Pool<DB>> {
+    pub(super) fn parent(&self) -> Option<&Pool<DB>> {
         self.options.parent_pool.as_ref()
     }
diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs
index 0caa1161c1..978f101da6 100644
--- a/sqlx-core/src/pool/mod.rs
+++ b/sqlx-core/src/pool/mod.rs
@@ -59,7 +59,6 @@ use std::future::Future;
 use std::pin::{pin, Pin};
 use std::sync::Arc;
 use std::task::{ready, Context, Poll};
-use std::time::{Duration, Instant};
 
 use event_listener::EventListener;
 use futures_core::FusedFuture;
@@ -628,15 +627,6 @@ impl FusedFuture for CloseEvent {
     }
 }
 
-/// get the time between the deadline and now and use that as our timeout
-///
-/// returns `Error::PoolTimedOut` if the deadline is in the past
-fn deadline_as_timeout(deadline: Instant) -> Result<Duration, Error> {
-    deadline
-        .checked_duration_since(Instant::now())
-        .ok_or(Error::PoolTimedOut)
-}
-
 #[test]
 #[allow(dead_code)]
 fn assert_pool_traits() {
diff --git a/sqlx-core/src/rt/mod.rs b/sqlx-core/src/rt/mod.rs
index 1da096d93b..0044139f55 100644
--- a/sqlx-core/src/rt/mod.rs
+++ b/sqlx-core/src/rt/mod.rs
@@ -71,7 +71,7 @@ pub async fn timeout_at<F: Future>(deadline: Instant, f: F) -> Result<F::Output, TimeoutError>
     }
 
     #[cfg(not(feature = "_rt-async-std"))]
-    missing_rt((duration, f))
+    missing_rt((deadline, f))
 }

From [sha elided] Mon Sep 17 00:00:00 2001
From: Austin Bonander
Date: Fri, 8 Nov 2024 17:05:19 -0800
Subject: [PATCH 09/24] chore: delete defunct use of `futures-intrusive`

---
 sqlx-core/src/sync.rs | 203 +-----------------------------------------
 1 file changed, 4 insertions(+), 199 deletions(-)

diff --git a/sqlx-core/src/sync.rs b/sqlx-core/src/sync.rs
index ed082f752c..971752f88f 100644
--- a/sqlx-core/src/sync.rs
+++ b/sqlx-core/src/sync.rs
@@ -1,206 +1,11 @@
-use cfg_if::cfg_if;
-
 // For types with identical signatures that don't require runtime support,
 // we can just arbitrarily pick one to use based on what's enabled.
 //
 // We'll generally lean towards Tokio's types as those are more featureful
 // (including `tokio-console` support) and more widely deployed.
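A usage note on the `timeout_at` fix above: a deadline-based timeout composes across retries, because every attempt draws from one absolute instant rather than restarting a relative timer. A minimal sketch using Tokio directly, which is the same call the internal helper delegates to on the Tokio runtime:

```rust
use std::time::{Duration, Instant};

async fn with_deadline() -> Option<&'static str> {
    let deadline = Instant::now() + Duration::from_secs(30);
    // `timeout_at` takes an absolute instant, so any retry loop inside the
    // future spends from the same overall time budget.
    tokio::time::timeout_at(deadline.into(), async { "done" })
        .await
        .ok()
}
```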
-pub struct AsyncSemaphore { - // We use the semaphore from futures-intrusive as the one from async-lock - // is missing the ability to add arbitrary permits, and is not guaranteed to be fair: - // * https://github.com/smol-rs/async-lock/issues/22 - // * https://github.com/smol-rs/async-lock/issues/23 - // - // We're on the look-out for a replacement, however, as futures-intrusive is not maintained - // and there are some soundness concerns (although it turns out any intrusive future is unsound - // in MIRI due to the necessitated mutable aliasing): - // https://github.com/launchbadge/sqlx/issues/1668 - #[cfg(all( - any( - feature = "_rt-async-global-executor", - feature = "_rt-async-std", - feature = "_rt-smol" - ), - not(feature = "_rt-tokio") - ))] - inner: futures_intrusive::sync::Semaphore, - - #[cfg(feature = "_rt-tokio")] - inner: tokio::sync::Semaphore, -} - -impl AsyncSemaphore { - #[track_caller] - pub fn new(fair: bool, permits: usize) -> Self { - if cfg!(not(any( - feature = "_rt-async-global-executor", - feature = "_rt-async-std", - feature = "_rt-smol", - feature = "_rt-tokio" - ))) { - crate::rt::missing_rt((fair, permits)); - } - - AsyncSemaphore { - #[cfg(all( - any( - feature = "_rt-async-global-executor", - feature = "_rt-async-std", - feature = "_rt-smol" - ), - not(feature = "_rt-tokio") - ))] - inner: futures_intrusive::sync::Semaphore::new(fair, permits), - #[cfg(feature = "_rt-tokio")] - inner: { - debug_assert!(fair, "Tokio only has fair permits"); - tokio::sync::Semaphore::new(permits) - }, - } - } - - pub fn permits(&self) -> usize { - cfg_if! { - if #[cfg(all( - any( - feature = "_rt-async-global-executor", - feature = "_rt-async-std", - feature = "_rt-smol" - ), - not(feature = "_rt-tokio") - ))] { - self.inner.permits() - } else if #[cfg(feature = "_rt-tokio")] { - self.inner.available_permits() - } else { - crate::rt::missing_rt(()) - } - } - } - - pub async fn acquire(&self, permits: u32) -> AsyncSemaphoreReleaser<'_> { - cfg_if! { - if #[cfg(all( - any( - feature = "_rt-async-global-executor", - feature = "_rt-async-std", - feature = "_rt-smol" - ), - not(feature = "_rt-tokio") - ))] { - AsyncSemaphoreReleaser { - inner: self.inner.acquire(permits as usize).await, - } - } else if #[cfg(feature = "_rt-tokio")] { - AsyncSemaphoreReleaser { - inner: self - .inner - // Weird quirk: `tokio::sync::Semaphore` mostly uses `usize` for permit counts, - // but `u32` for this and `try_acquire_many()`. - .acquire_many(permits) - .await - .expect("BUG: we do not expose the `.close()` method"), - } - } else { - crate::rt::missing_rt(permits) - } - } - } - - pub fn try_acquire(&self, permits: u32) -> Option> { - cfg_if! { - if #[cfg(all( - any( - feature = "_rt-async-global-executor", - feature = "_rt-async-std", - feature = "_rt-smol" - ), - not(feature = "_rt-tokio") - ))] { - Some(AsyncSemaphoreReleaser { - inner: self.inner.try_acquire(permits as usize)?, - }) - } else if #[cfg(feature = "_rt-tokio")] { - Some(AsyncSemaphoreReleaser { - inner: self.inner.try_acquire_many(permits).ok()?, - }) - } else { - crate::rt::missing_rt(permits) - } - } - } - - pub fn release(&self, permits: usize) { - cfg_if! 
{ - if #[cfg(all( - any( - feature = "_rt-async-global-executor", - feature = "_rt-async-std", - feature = "_rt-smol" - ), - not(feature = "_rt-tokio") - ))] { - self.inner.release(permits); - } else if #[cfg(feature = "_rt-tokio")] { - self.inner.add_permits(permits); - } else { - crate::rt::missing_rt(permits); - } - } - } -} - -pub struct AsyncSemaphoreReleaser<'a> { - // We use the semaphore from futures-intrusive as the one from async-std - // is missing the ability to add arbitrary permits, and is not guaranteed to be fair: - // * https://github.com/smol-rs/async-lock/issues/22 - // * https://github.com/smol-rs/async-lock/issues/23 - // - // We're on the look-out for a replacement, however, as futures-intrusive is not maintained - // and there are some soundness concerns (although it turns out any intrusive future is unsound - // in MIRI due to the necessitated mutable aliasing): - // https://github.com/launchbadge/sqlx/issues/1668 - #[cfg(all( - any( - feature = "_rt-async-global-executor", - feature = "_rt-async-std", - feature = "_rt-smol" - ), - not(feature = "_rt-tokio") - ))] - inner: futures_intrusive::sync::SemaphoreReleaser<'a>, - - #[cfg(feature = "_rt-tokio")] - inner: tokio::sync::SemaphorePermit<'a>, - - #[cfg(not(any( - feature = "_rt-async-global-executor", - feature = "_rt-async-std", - feature = "_rt-smol", - feature = "_rt-tokio" - )))] - _phantom: std::marker::PhantomData<&'a ()>, -} +#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] +pub use async_std::sync::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard}; -impl AsyncSemaphoreReleaser<'_> { - pub fn disarm(self) { - cfg_if! { - if #[cfg(all( - any( - feature = "_rt-async-global-executor", - feature = "_rt-async-std", - feature = "_rt-smol" - ), - not(feature = "_rt-tokio") - ))] { - let mut this = self; - this.inner.disarm(); - } else if #[cfg(feature = "_rt-tokio")] { - self.inner.forget(); - } else { - crate::rt::missing_rt(()); - } - } - } -} +#[cfg(feature = "_rt-tokio")] +pub use tokio::sync::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard}; From eb478cec035bde2c46ddc2ddadd37dd3f3f623d8 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 8 Nov 2024 21:17:33 -0800 Subject: [PATCH 10/24] fix: upgrade `ease-off` --- Cargo.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 35ba67013a..25284eec70 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -189,8 +189,9 @@ uuid = "1.1.2" # Common utility crates cfg-if = "1.0.0" -dotenvy = { version = "0.15.0", default-features = false } thiserror = { version = "2.0.17", default-features = false, features = ["std"] } +dotenvy = { version = "0.15.7", default-features = false } +ease-off = "0.1.6" # Runtimes [workspace.dependencies.async-global-executor] From 928b25691e50e9aaa1b86aa7a40b0c7d39869dfe Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 8 Nov 2024 21:32:51 -0800 Subject: [PATCH 11/24] fix: tests --- sqlx-core/src/pool/connect.rs | 21 ++++---- tests/any/pool.rs | 90 +++++++++++------------------------ 2 files changed, 42 insertions(+), 69 deletions(-) diff --git a/sqlx-core/src/pool/connect.rs b/sqlx-core/src/pool/connect.rs index f1f7ce7d4b..7fdc8c4739 100644 --- a/sqlx-core/src/pool/connect.rs +++ b/sqlx-core/src/pool/connect.rs @@ -33,10 +33,11 @@ use std::io; /// use sqlx::PgConnection; /// use sqlx::postgres::PgPoolOptions; /// use sqlx::Connection; +/// use sqlx::pool::PoolConnectMetadata; /// -/// # async fn _example() -> sqlx::Result<()> { -/// // `PoolConnector` is 
implemented for closures but has restrictions on returning borrows -/// // due to current language limitations. +/// async fn _example() -> sqlx::Result<()> { +/// // `PoolConnector` is implemented for closures but this has restrictions on returning borrows +/// // due to current language limitations. Custom implementations are not subject to this. /// // /// // This example shows how to get around this using `Arc`. /// let database_url: Arc = "postgres://...".into(); @@ -44,7 +45,8 @@ use std::io; /// let pool = PgPoolOptions::new() /// .min_connections(5) /// .max_connections(30) -/// .connect_with_connector(move |meta| { +/// // Type annotation on the argument is required for the trait impl to reseolve. +/// .connect_with_connector(move |meta: PoolConnectMetadata| { /// let database_url = database_url.clone(); /// async move { /// println!( @@ -57,7 +59,9 @@ use std::io; /// let mut conn = PgConnection::connect(&database_url).await?; /// /// // Override the time zone of the connection. -/// sqlx::raw_sql("SET TIME ZONE 'Europe/Berlin'").await?; +/// sqlx::raw_sql("SET TIME ZONE 'Europe/Berlin'") +/// .execute(&mut conn) +/// .await?; /// /// Ok(conn) /// } @@ -76,13 +80,14 @@ use std::io; /// /// ```rust,no_run /// use std::sync::Arc; -/// use tokio::sync::{Mutex, RwLock}; +/// use tokio::sync::RwLock; /// use sqlx::PgConnection; /// use sqlx::postgres::PgConnectOptions; /// use sqlx::postgres::PgPoolOptions; /// use sqlx::ConnectOptions; +/// use sqlx::pool::PoolConnectMetadata; /// -/// # async fn _example() -> sqlx::Result<()> { +/// async fn _example() -> sqlx::Result<()> { /// // If you do not wish to hold the lock during the connection attempt, /// // you could use `Arc` instead. /// let connect_opts: Arc> = Arc::new(RwLock::new("postgres://...".parse()?)); @@ -90,7 +95,7 @@ use std::io; /// let connect_opts_ = connect_opts.clone(); /// /// let pool = PgPoolOptions::new() -/// .connect_with_connector(move |meta| { +/// .connect_with_connector(move |meta: PoolConnectMetadata| { /// let connect_opts_ = connect_opts.clone(); /// async move { /// println!( diff --git a/tests/any/pool.rs b/tests/any/pool.rs index a4849940b8..1cc0838053 100644 --- a/tests/any/pool.rs +++ b/tests/any/pool.rs @@ -1,45 +1,14 @@ use sqlx::any::{AnyConnectOptions, AnyPoolOptions}; use sqlx::Executor; +use sqlx_core::connection::ConnectOptions; +use sqlx_core::pool::PoolConnectMetadata; use sqlx_core::sql_str::AssertSqlSafe; use std::sync::{ - atomic::{AtomicI32, AtomicUsize, Ordering}, + atomic::{AtomicI32, Ordering}, Arc, Mutex, }; use std::time::Duration; -#[sqlx_macros::test] -async fn pool_should_invoke_after_connect() -> anyhow::Result<()> { - sqlx::any::install_default_drivers(); - - let counter = Arc::new(AtomicUsize::new(0)); - - let pool = AnyPoolOptions::new() - .after_connect({ - let counter = counter.clone(); - move |_conn, _meta| { - let counter = counter.clone(); - Box::pin(async move { - counter.fetch_add(1, Ordering::SeqCst); - - Ok(()) - }) - } - }) - .connect(&dotenvy::var("DATABASE_URL")?) 
- .await?; - - let _ = pool.acquire().await?; - let _ = pool.acquire().await?; - let _ = pool.acquire().await?; - let _ = pool.acquire().await?; - - // since connections are released asynchronously, - // `.after_connect()` may be called more than once - assert!(counter.load(Ordering::SeqCst) >= 1); - - Ok(()) -} - // https://github.com/launchbadge/sqlx/issues/527 #[sqlx_macros::test] async fn pool_should_be_returned_failed_transactions() -> anyhow::Result<()> { @@ -84,38 +53,13 @@ async fn test_pool_callbacks() -> anyhow::Result<()> { sqlx_test::setup_if_needed(); - let conn_options: AnyConnectOptions = std::env::var("DATABASE_URL")?.parse()?; + let conn_options: Arc = Arc::new(std::env::var("DATABASE_URL")?.parse()?); let current_id = AtomicI32::new(0); let pool = AnyPoolOptions::new() .max_connections(1) .acquire_timeout(Duration::from_secs(5)) - .after_connect(move |conn, meta| { - assert_eq!(meta.age, Duration::ZERO); - assert_eq!(meta.idle_for, Duration::ZERO); - - let id = current_id.fetch_add(1, Ordering::AcqRel); - - Box::pin(async move { - let statement = format!( - // language=SQL - r#" - CREATE TEMPORARY TABLE conn_stats( - id int primary key, - before_acquire_calls int default 0, - after_release_calls int default 0 - ); - INSERT INTO conn_stats(id) VALUES ({}); - "#, - // Until we have generalized bind parameters - id - ); - - conn.execute(AssertSqlSafe(statement)).await?; - Ok(()) - }) - }) .before_acquire(|conn, meta| { // `age` and `idle_for` should both be nonzero assert_ne!(meta.age, Duration::ZERO); @@ -166,7 +110,31 @@ async fn test_pool_callbacks() -> anyhow::Result<()> { }) }) // Don't establish a connection yet. - .connect_lazy_with(conn_options); + .connect_lazy_with_connector(move |_meta: PoolConnectMetadata| { + let connect_opts = Arc::clone(&conn_options); + let id = current_id.fetch_add(1, Ordering::AcqRel); + + async move { + let mut conn = connect_opts.connect().await?; + + let statement = format!( + // language=SQL + r#" + CREATE TEMPORARY TABLE conn_stats( + id int primary key, + before_acquire_calls int default 0, + after_release_calls int default 0 + ); + INSERT INTO conn_stats(id) VALUES ({}); + "#, + // Until we have generalized bind parameters + id + ); + + conn.execute(&statement[..]).await?; + Ok(conn) + } + }); // Expected pattern of (id, before_acquire_calls, after_release_calls) let pattern = [ From 0d593a039eb7426864dd673cad7ba615cfd7c9e5 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Mon, 11 Nov 2024 12:24:25 -0800 Subject: [PATCH 12/24] fix(pool): don't stop emptying idle queue in `.close()` --- sqlx-core/src/pool/inner.rs | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index da9b0add35..ec86c615ff 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -18,7 +18,7 @@ use crate::rt::JoinHandle; use crate::{private_tracing_dynamic_event, rt}; use either::Either; use futures_util::future::{self, OptionFuture}; -use futures_util::FutureExt; +use futures_util::{select, FutureExt}; use std::time::{Duration, Instant}; use tracing::Level; @@ -76,12 +76,18 @@ impl PoolInner { pub(super) fn close(self: &Arc) -> impl Future + '_ { self.mark_closed(); + // Keep clearing the idle queue as connections are released until the count reaches zero. async move { - while let Some(idle) = self.idle.try_acquire(self) { - idle.close().await; + let mut drained = pin!(self.counter.drain()); + + loop { + select! 
{ + idle = self.idle.acquire(self) => { + idle.close().await; + }, + () = drained.as_mut() => break, + } } - - self.counter.drain().await; } } @@ -117,7 +123,7 @@ impl PoolInner { let acquire_started_at = Instant::now(); let mut close_event = pin!(self.close_event()); - let mut deadline = pin!(crate::rt::sleep(self.options.acquire_timeout)); + let mut deadline = pin!(rt::sleep(self.options.acquire_timeout)); let mut acquire_idle = pin!(self.idle.acquire(self).fuse()); let mut before_acquire = OptionFuture::from(None); let mut acquire_connect_permit = pin!(OptionFuture::from(Some( @@ -131,6 +137,9 @@ impl PoolInner { // * If we acquire a `ConnectPermit`, we begin the connection loop (with backoff) // as implemented by `DynConnector`. // * If we acquire an idle connection, we then start polling `check_idle_conn()`. + // + // This doesn't quite fit into `select!{}` because the set of futures that may be polled + // at a given time is dynamic, so it's actually simpler to hand-roll it. let acquired = future::poll_fn(|cx| { use std::task::Poll::*; From 18908e62066ce81dbce079d36e05b334cfde4bb2 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Mon, 11 Nov 2024 12:27:02 -0800 Subject: [PATCH 13/24] fix(pool): use the correct method in `try_min_connections` --- sqlx-core/src/pool/inner.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index ec86c615ff..48270b989d 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -240,13 +240,13 @@ impl PoolInner { /// Try to maintain `min_connections`, returning any errors (including `PoolTimedOut`). pub async fn try_min_connections(self: &Arc, deadline: Instant) -> Result<(), Error> { - crate::rt::timeout_at(deadline, async { + rt::timeout_at(deadline, async { while self.size() < self.options.min_connections { // Don't wait for a connect permit. // // If no extra permits are available then we shouldn't be trying to spin up // connections anyway. - let Some((id, permit)) = self.counter.acquire_permit(self).now_or_never() else { + let Some((id, permit)) = self.counter.try_acquire_permit(self) else { return Ok(()); }; From 4d73193c3914b6a6532003cbbcb6a2cceb87509e Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Mon, 11 Nov 2024 12:29:18 -0800 Subject: [PATCH 14/24] fix(pool): use `.fuse()` --- sqlx-core/src/pool/inner.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index 48270b989d..4a515342bb 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -78,7 +78,7 @@ impl PoolInner { // Keep clearing the idle queue as connections are released until the count reaches zero. async move { - let mut drained = pin!(self.counter.drain()); + let mut drained = pin!(self.counter.drain()).fuse(); loop { select! 
{ From 782011bec4bb3e8ea10c3c423e3078670752b53e Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Mon, 11 Nov 2024 13:31:47 -0800 Subject: [PATCH 15/24] fix(pool): tweaks and fixes --- sqlx-core/Cargo.toml | 2 +- sqlx-core/src/pool/connect.rs | 16 +++++++++------- sqlx-core/src/pool/idle.rs | 5 ++++- sqlx-core/src/pool/inner.rs | 17 +++++++++++------ 4 files changed, 25 insertions(+), 15 deletions(-) diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 8512508cba..0eadf293c5 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -107,7 +107,7 @@ ease-off = { workspace = true, features = ["futures"] } pin-project-lite = "0.2.14" [dev-dependencies] -tokio = { version = "1", features = ["rt"] } +tokio = { version = "1", features = ["rt", "sync"] } [dev-dependencies.sqlx] # FIXME: https://github.com/rust-lang/cargo/issues/15622 diff --git a/sqlx-core/src/pool/connect.rs b/sqlx-core/src/pool/connect.rs index 7fdc8c4739..ee80591428 100644 --- a/sqlx-core/src/pool/connect.rs +++ b/sqlx-core/src/pool/connect.rs @@ -6,7 +6,7 @@ use crate::pool::PoolConnection; use crate::rt::JoinHandle; use crate::Error; use ease_off::EaseOff; -use event_listener::Event; +use event_listener::{listener, Event}; use std::fmt::{Display, Formatter}; use std::future::Future; use std::ptr; @@ -50,7 +50,7 @@ use std::io; /// let database_url = database_url.clone(); /// async move { /// println!( -/// "opening connection {}, attempt {}; elapsed time: {}", +/// "opening connection {}, attempt {}; elapsed time: {:?}", /// meta.pool_size, /// meta.num_attempts + 1, /// meta.start.elapsed() @@ -96,10 +96,10 @@ use std::io; /// /// let pool = PgPoolOptions::new() /// .connect_with_connector(move |meta: PoolConnectMetadata| { -/// let connect_opts_ = connect_opts.clone(); +/// let connect_opts = connect_opts_.clone(); /// async move { /// println!( -/// "opening connection {}, attempt {}; elapsed time: {}", +/// "opening connection {}, attempt {}; elapsed time: {:?}", /// meta.pool_size, /// meta.num_attempts + 1, /// meta.start.elapsed() @@ -318,7 +318,8 @@ impl ConnectionCounter { pub async fn drain(&self) { while self.count.load(Ordering::Acquire) > 0 { - self.connect_available.listen().await; + listener!(self.connect_available => permit_released); + permit_released.await; } } @@ -386,13 +387,14 @@ impl ConnectionCounter { return acquired; } - self.connect_available.listen().await; - if attempt == 2 { tracing::warn!( "unable to acquire a connect permit after sleeping; this may indicate a bug" ); } + + listener!(self.connect_available => connect_available); + connect_available.await; } panic!("BUG: was never able to acquire a connection despite waking many times") diff --git a/sqlx-core/src/pool/idle.rs b/sqlx-core/src/pool/idle.rs index 239313f7ed..8b07b8e7c4 100644 --- a/sqlx-core/src/pool/idle.rs +++ b/sqlx-core/src/pool/idle.rs @@ -8,6 +8,8 @@ use futures_util::FutureExt; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; +use event_listener::listener; + pub struct IdleQueue { queue: ArrayQueue>, // Keep a separate count because `ArrayQueue::len()` loops until the head and tail pointers @@ -36,7 +38,8 @@ impl IdleQueue { for attempt in 1usize.. 
{ if should_wait { - self.release_event.listen().await; + listener!(self.release_event => release_event); + release_event.await; } if let Some(conn) = self.try_acquire(pool) { diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index 4a515342bb..5eb1d203a7 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -18,7 +18,7 @@ use crate::rt::JoinHandle; use crate::{private_tracing_dynamic_event, rt}; use either::Either; use futures_util::future::{self, OptionFuture}; -use futures_util::{select, FutureExt}; +use futures_util::{FutureExt}; use std::time::{Duration, Instant}; use tracing::Level; @@ -78,14 +78,19 @@ impl PoolInner { // Keep clearing the idle queue as connections are released until the count reaches zero. async move { - let mut drained = pin!(self.counter.drain()).fuse(); + let mut drained = pin!(self.counter.drain()); loop { - select! { - idle = self.idle.acquire(self) => { + let mut acquire_idle = pin!(self.idle.acquire(self)); + + // Not using `futures::select!{}` here because it requires a proc-macro dep, + // and frankly it's a little broken. + match future::select(drained.as_mut(), acquire_idle.as_mut()).await { + // *not* `either::Either`; they rolled their own + future::Either::Left(_) => break, + future::Either::Right((idle, _)) => { idle.close().await; - }, - () = drained.as_mut() => break, + } } } } From 29035dd13e36102f39584c83bebcd10ca04c3134 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Thu, 21 Aug 2025 23:41:23 -0700 Subject: [PATCH 16/24] fix: errors after rebasing --- sqlx-core/src/pool/connection.rs | 1 - sqlx-core/src/pool/inner.rs | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs index 76e4e24b03..2ab315fab7 100644 --- a/sqlx-core/src/pool/connection.rs +++ b/sqlx-core/src/pool/connection.rs @@ -12,7 +12,6 @@ use super::inner::{is_beyond_max_lifetime, PoolInner}; use crate::pool::connect::{ConnectPermit, ConnectionId}; use crate::pool::options::PoolConnectionMetadata; use crate::rt; -use std::future::Future; const RETURN_TO_POOL_TIMEOUT: Duration = Duration::from_secs(5); const CLOSE_ON_DROP_TIMEOUT: Duration = Duration::from_secs(5); diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index 5eb1d203a7..5a37f17ed1 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -4,8 +4,7 @@ use crate::error::Error; use crate::pool::{CloseEvent, Pool, PoolConnection, PoolConnector, PoolOptions}; use std::cmp; -use std::future::{self, Future}; -use std::pin::pin; +use std::future::Future; use std::pin::pin; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; @@ -18,7 +17,7 @@ use crate::rt::JoinHandle; use crate::{private_tracing_dynamic_event, rt}; use either::Either; use futures_util::future::{self, OptionFuture}; -use futures_util::{FutureExt}; +use futures_util::FutureExt; use std::time::{Duration, Instant}; use tracing::Level; From 4edab50827ae954de6b56f84dc9bb70337c64de4 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Fri, 29 Aug 2025 12:04:47 -0700 Subject: [PATCH 17/24] fix errors after rebase --- sqlx-mysql/src/testing/mod.rs | 3 +-- sqlx-postgres/src/testing/mod.rs | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/sqlx-mysql/src/testing/mod.rs b/sqlx-mysql/src/testing/mod.rs index c27dda3ccd..f532dcc5a9 100644 --- a/sqlx-mysql/src/testing/mod.rs +++ b/sqlx-mysql/src/testing/mod.rs @@ -109,8 +109,7 @@ async fn test_context(args: &TestArgs) -> 
Result, Error> { .after_release(|_conn, _| Box::pin(async move { Ok(false) })) .connect_lazy_with(master_opts.clone()); - let master_pool = MASTER_POOL - .try_insert(pool) + let master_pool = once_lock_try_insert_polyfill(&MASTER_POOL, pool) .unwrap_or_else(|(existing, _pool)| existing); let mut conn = master_pool.acquire().await?; diff --git a/sqlx-postgres/src/testing/mod.rs b/sqlx-postgres/src/testing/mod.rs index a7f6a54944..70b00b6351 100644 --- a/sqlx-postgres/src/testing/mod.rs +++ b/sqlx-postgres/src/testing/mod.rs @@ -102,8 +102,7 @@ async fn test_context(args: &TestArgs) -> Result, Error> { .after_release(|_conn, _| Box::pin(async move { Ok(false) })) .connect_lazy_with(master_opts.clone()); - let master_pool = MASTER_POOL - .try_insert(pool) + let master_pool = once_lock_try_insert_polyfill(&MASTER_POOL, pool) .unwrap_or_else(|(existing, _pool)| existing); let mut conn = master_pool.acquire().await?; From 7f6d040cdd7271facf73adb07ef60b0f13fe3617 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Tue, 19 Aug 2025 18:06:28 -0700 Subject: [PATCH 18/24] feat: create sharding structure for pool --- Cargo.lock | 1 + sqlx-core/Cargo.toml | 4 + sqlx-core/src/pool/mod.rs | 2 + sqlx-core/src/pool/shard.rs | 384 ++++++++++++++++++++++++++++++++++++ 4 files changed, 391 insertions(+) create mode 100644 sqlx-core/src/pool/shard.rs diff --git a/Cargo.lock b/Cargo.lock index 3a90f55f76..d4b7cdc2c0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3563,6 +3563,7 @@ dependencies = [ "mac_address", "memchr", "native-tls", + "parking_lot", "percent-encoding", "pin-project-lite", "rust_decimal", diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 0eadf293c5..61c7387a7d 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -106,6 +106,10 @@ thiserror.workspace = true ease-off = { workspace = true, features = ["futures"] } pin-project-lite = "0.2.14" +[dependencies.parking_lot] +version = "0.12.4" +features = ["arc_lock"] + [dev-dependencies] tokio = { version = "1", features = ["rt", "sync"] } diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index 978f101da6..84776d0e22 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -90,6 +90,8 @@ mod inner; mod idle; mod options; +mod shard; + /// An asynchronous pool of SQLx database connections. /// /// Create a pool with [Pool::connect] or [Pool::connect_with] and then call [Pool::acquire] diff --git a/sqlx-core/src/pool/shard.rs b/sqlx-core/src/pool/shard.rs new file mode 100644 index 0000000000..242635e133 --- /dev/null +++ b/sqlx-core/src/pool/shard.rs @@ -0,0 +1,384 @@ +use event_listener::{Event, IntoNotification}; +use parking_lot::Mutex; +use std::future::Future; +use std::pin::pin; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use std::task::Poll; +use std::time::Duration; +use std::{array, iter}; + +type ShardId = usize; +type ConnectionIndex = usize; + +/// Delay before a task waiting in a call to `acquire()` enters the global wait queue. +/// +/// We want tasks to acquire from their local shards where possible, so they don't enter +/// the global queue immediately. +const GLOBAL_QUEUE_DELAY: Duration = Duration::from_millis(5); + +pub struct Sharded { + shards: Box<[ArcShard]>, + global: Arc>, +} + +type ArcShard = Arc>>]>>; + +struct Global { + unlock_event: Event>, + disconnect_event: Event>, +} + +type ArcMutexGuard = parking_lot::ArcMutexGuard>; + +pub struct LockGuard { + // `Option` allows us to drop the guard before sending the notification. 
+ // Otherwise, if the receiver wakes too quickly, it might fail to lock the mutex. + locked: Option>, + shard: ArcShard, + index: ConnectionIndex, +} + +// Align to cache lines. +// Simplified from https://docs.rs/crossbeam-utils/0.8.21/src/crossbeam_utils/cache_padded.rs.html#80 +// +// Instead of listing every possible architecture, we just assume 64-bit architectures have 128-byte +// cache lines, which is at least true for newer versions of x86-64 and AArch64. +// A larger alignment isn't harmful as long as we make use of the space. +#[cfg_attr(target_pointer_width = "64", repr(align(128)))] +#[cfg_attr(not(target_pointer_width = "64"), repr(align(64)))] +struct Shard { + shard_id: ShardId, + /// Bitset for all connection indexes that are currently in-use. + locked_set: AtomicUsize, + /// Bitset for all connection indexes that are currently connected. + connected_set: AtomicUsize, + unlock_event: Event>, + disconnect_event: Event>, + global: Arc>, + connections: Ts, +} + +#[derive(Debug)] +struct Params { + shards: usize, + shard_size: usize, + remainder: usize, +} + +const MAX_SHARD_SIZE: usize = if usize::BITS > 64 { + 64 +} else { + usize::BITS as usize +}; + +impl Sharded { + pub fn new(connections: usize, shards: usize) -> Sharded { + let global = Arc::new(Global { + unlock_event: Event::with_tag(), + disconnect_event: Event::with_tag(), + }); + + let shards = Params::calc(connections, shards) + .shard_sizes() + .enumerate() + .map(|(shard_id, size)| Shard::new(shard_id, size, global.clone())) + .collect::>(); + + Sharded { shards, global } + } + + pub async fn acquire(&self, connected: bool) -> LockGuard { + let mut acquire_local = + pin!(self.shards[thread_id() % self.shards.len()].acquire(connected)); + + let mut acquire_global = pin!(async { + crate::rt::sleep(GLOBAL_QUEUE_DELAY).await; + + let event_to_listen = if connected { + &self.global.unlock_event + } else { + &self.global.disconnect_event + }; + + event_listener::listener!(event_to_listen => listener); + listener.await + }); + + // Hand-rolled `select!{}` because there isn't a great cross-runtime solution. + // + // `futures_util::select!{}` is a proc-macro. + std::future::poll_fn(|cx| { + if let Poll::Ready(locked) = acquire_local.as_mut().poll(cx) { + return Poll::Ready(locked); + } + + if let Poll::Ready(locked) = acquire_global.as_mut().poll(cx) { + return Poll::Ready(locked); + } + + Poll::Pending + }) + .await + } +} + +impl Shard>>]> { + fn new(shard_id: ShardId, len: usize, global: Arc>) -> Arc { + macro_rules! make_array { + ($($n:literal),+) => { + match len { + $($n => Arc::new(Shard { + shard_id, + locked_set: AtomicUsize::new(0), + unlock_event: Event::with_tag(), + connected_set: AtomicUsize::new(0), + disconnect_event: Event::with_tag(), + global, + connections: array::from_fn::<_, $n, _>(|_| Arc::new(Mutex::new(None))) + }),)* + _ => unreachable!("BUG: length not supported: {len}"), + } + } + } + + make_array!( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, + 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64 + ) + } + + async fn acquire(self: &Arc, connected: bool) -> LockGuard { + // Attempt an unfair acquire first, before we modify the waitlist. 
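The body of `acquire()` below follows the check, listen, re-check idiom that `event_listener` requires: a notifier that finds no registered listeners simply releases its connection without handing it off, so a waiter must re-check between registering and awaiting. A minimal sketch of the idiom in isolation, with a hypothetical `try_get()` standing in for `try_acquire()`:

```rust
use event_listener::Event;

// Sketch only: stand-in for a lock-free "fast path" check like `try_acquire()`.
fn try_get() -> Option<u32> {
    None
}

async fn acquire_sketch(event: &Event) -> u32 {
    loop {
        // 1. Optimistic check before touching the wait list.
        if let Some(v) = try_get() {
            return v;
        }

        // 2. Register a listener...
        event_listener::listener!(event => listener);

        // 3. ...then check again: a notifier may have looked for listeners
        //    (and found none) between steps 1 and 2.
        if let Some(v) = try_get() {
            return v;
        }

        // 4. Only now is it safe to wait for a notification.
        listener.await;
    }
}
```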
+        if let Some(locked) = self.try_acquire(connected) {
+            return locked;
+        }
+
+        let event_to_listen = if connected {
+            &self.unlock_event
+        } else {
+            &self.disconnect_event
+        };
+
+        event_listener::listener!(event_to_listen => listener);
+
+        // We need to check again after creating the event listener,
+        // because in the meantime, a concurrent task may have seen that there were no listeners
+        // and just unlocked its connection.
+        if let Some(locked) = self.try_acquire(connected) {
+            return locked;
+        }
+
+        listener.await
+    }
+
+    fn try_acquire(self: &Arc<Self>, connected: bool) -> Option<LockGuard<C>> {
+        let locked_set = self.locked_set.load(Ordering::Acquire);
+        let connected_set = self.connected_set.load(Ordering::Relaxed);
+
+        let connected_mask = if connected {
+            connected_set
+        } else {
+            !connected_set
+        };
+
+        // Choose the first index that is unlocked and whose connected bit matches `connected`.
+        // Bits are indexed from the MSB down, matching `atomic_set()` below.
+        let index = (!locked_set & connected_mask).leading_zeros() as usize;
+
+        self.try_lock(index)
+    }
+
+    fn try_lock(self: &Arc<Self>, index: ConnectionIndex) -> Option<LockGuard<C>> {
+        // `index` may be out of range: `leading_zeros()` returns `usize::BITS` when no
+        // candidate bit is set, and may also name a bit beyond `connections.len()` for a
+        // shard smaller than the word size, so use `get()` rather than indexing.
+        let locked = self.connections.get(index)?.try_lock_arc()?;
+
+        // The locking of the connection itself must use an `Acquire` fence,
+        // so additional synchronization is unnecessary.
+        atomic_set(&self.locked_set, index, true, Ordering::Relaxed);
+
+        Some(LockGuard {
+            locked: Some(locked),
+            shard: self.clone(),
+            index,
+        })
+    }
+}
+
+impl Params {
+    fn calc(connections: usize, mut shards: usize) -> Params {
+        let mut shard_size = connections / shards;
+        let mut remainder = connections % shards;
+
+        if shard_size == 0 {
+            tracing::debug!(connections, shards, "more shards than connections; clamping shard size to 1, shard count to connections");
+            shards = connections;
+            shard_size = 1;
+            remainder = 0;
+        } else if shard_size >= MAX_SHARD_SIZE {
+            let new_shards = connections.div_ceil(MAX_SHARD_SIZE);
+
+            tracing::debug!(connections, shards, "clamping shard count to {new_shards}");
+
+            shards = new_shards;
+            shard_size = connections / shards;
+            remainder = connections % shards;
+        }
+
+        Params {
+            shards,
+            shard_size,
+            remainder,
+        }
+    }
+
+    fn shard_sizes(&self) -> impl Iterator<Item = usize> {
+        iter::repeat_n(self.shard_size + 1, self.remainder).chain(iter::repeat_n(
+            self.shard_size,
+            self.shards - self.remainder,
+        ))
+    }
+}
+
+fn thread_id() -> usize {
+    // FIXME: this can be replaced when this is stabilized:
+    // https://doc.rust-lang.org/stable/std/thread/struct.ThreadId.html#method.as_u64
+    static THREAD_ID: AtomicUsize = AtomicUsize::new(0);
+
+    thread_local! {
+        static CURRENT_THREAD_ID: usize = THREAD_ID.fetch_add(1, Ordering::SeqCst);
+    }
+
+    CURRENT_THREAD_ID.with(|i| *i)
+}
+
+impl<C> Drop for LockGuard<C> {
+    fn drop(&mut self) {
+        let Some(locked) = self.locked.take() else {
+            return;
+        };
+
+        let connected = locked.is_some();
+
+        // Updating the connected flag shouldn't require a fence.
+        atomic_set(
+            &self.shard.connected_set,
+            self.index,
+            connected,
+            Ordering::Relaxed,
+        );
+
+        // If another receiver is waiting for a connection, we can directly pass them the lock.
+        //
+        // This prevents drive-by tasks from acquiring connections before waiting tasks
+        // at high contention, while requiring little synchronization otherwise.
+        //
+        // We *could* just pass them the shard ID and/or index, but then we have to handle
+        // the situation when a receiver was passed a connection that was still marked as locked,
+        // but was cancelled before it could complete the acquisition. Otherwise, the connection
+        // would be marked as locked forever, effectively being leaked.
+
+        let mut locked = Some(locked);
+
+        // This is a code smell, but it's necessary because `event-listener` has no way to specify
+        // that a message should *only* be sent once. This means tags either need to be `Clone`
+        // or provided by a `FnMut()` closure.
+        //
+        // Note that there's no guarantee that this closure won't be called more than once by the
+        // implementation, but the implementation as of writing does not.
+        let mut self_as_tag = || {
+            let locked = locked
+                .take()
+                .expect("BUG: notification sent more than once");
+
+            LockGuard {
+                locked: Some(locked),
+                shard: self.shard.clone(),
+                index: self.index,
+            }
+        };
+
+        if connected {
+            // Check for global waiters first.
+            if self
+                .shard
+                .global
+                .unlock_event
+                .notify(1.tag_with(&mut self_as_tag))
+                > 0
+            {
+                return;
+            }
+
+            if self.shard.unlock_event.notify(1.tag_with(&mut self_as_tag)) > 0 {
+                return;
+            }
+        } else {
+            if self
+                .shard
+                .global
+                .disconnect_event
+                .notify(1.tag_with(&mut self_as_tag))
+                > 0
+            {
+                return;
+            }
+
+            if self
+                .shard
+                .disconnect_event
+                .notify(1.tag_with(&mut self_as_tag))
+                > 0
+            {
+                return;
+            }
+        }
+
+        // Be sure to drop the lock guard if it's still held,
+        // *before* we semantically release the lock in the bitset.
+        //
+        // Otherwise, another task could check and see the connection is free,
+        // but then fail to lock the mutex for it.
+        drop(locked);
+
+        atomic_set(&self.shard.locked_set, self.index, false, Ordering::Release);
+    }
+}
+
+fn atomic_set(atomic: &AtomicUsize, index: usize, value: bool, ordering: Ordering) {
+    // Index bits from the MSB down so that `leading_zeros()` in `try_acquire()`
+    // yields the lowest set index directly.
+    let bit = (1usize << (usize::BITS - 1)) >> index;
+
+    if value {
+        atomic.fetch_or(bit, ordering);
+    } else {
+        atomic.fetch_and(!bit, ordering);
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{Params, MAX_SHARD_SIZE};
+
+    #[test]
+    fn test_params() {
+        for connections in 0..100 {
+            for shards in 1..32 {
+                let params = Params::calc(connections, shards);
+
+                let mut sum = 0;
+
+                for (i, size) in params.shard_sizes().enumerate() {
+                    assert!(size <= MAX_SHARD_SIZE, "Params::calc({connections}, {shards}) exceeded MAX_SHARD_SIZE at shard #{i}, size {size}");
+
+                    sum += size;
+
+                    assert!(sum <= connections, "Params::calc({connections}, {shards}) exceeded connections at shard #{i}, size {size}");
+                }
+
+                assert_eq!(
+                    sum, connections,
+                    "Params::calc({connections}, {shards}) does not add up ({params:?})"
+                );
+            }
+        }
+    }
+}

From 7c5486270b10a11c3a4b233a9a9786f74a6c47c4 Mon Sep 17 00:00:00 2001
From: Austin Bonander
Date: Wed, 3 Sep 2025 02:50:35 -0700
Subject: [PATCH 19/24] WIP feat: integrate sharding into pool

---
 sqlx-core/Cargo.toml          | 18 +++++++-
 sqlx-core/src/pool/connect.rs |  5 +++
 sqlx-core/src/pool/inner.rs   |  3 ++
 sqlx-core/src/pool/options.rs | 81 +++++++++++++++++++++++++++++++++++
 sqlx-core/src/pool/shard.rs   | 60 +++++++++++++++++++++-----
 5 files changed, 154 insertions(+), 13 deletions(-)

diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml
index 61c7387a7d..e9f7085934 100644
--- a/sqlx-core/Cargo.toml
+++ b/sqlx-core/Cargo.toml
@@ -106,8 +106,22 @@ thiserror.workspace = true
 ease-off = { workspace = true, features = ["futures"] }
 pin-project-lite = "0.2.14"
 
-[dependencies.parking_lot]
-version = "0.12.4"
+# N.B.
we don't actually utilize spinlocks, we just need a `Mutex` type with a few requirements: +# * Guards that are `Send` (so `parking_lot` and `std::sync` are non-starters) +# * Guards that can use `Arc` and so don't borrow (which is provided by `lock_api`) +# +# Where we actually use this (in `sqlx-core/src/pool/shard.rs`), we don't rely on the mutex itself for anything but +# safe shared mutability. The `Shard` structure has its own synchronization, and only uses `Mutex::try_lock()`. +# +# We *could* use either `tokio::sync::Mutex` or `async_lock::Mutex` for this, but those have all the code for the +# async support, which we don't need. +[dependencies.spin] +version = "0.10.0" +default-features = false +features = ["mutex", "lock_api", "spin_mutex"] + +[dependencies.lock_api] +version = "0.4.13" features = ["arc_lock"] [dev-dependencies] diff --git a/sqlx-core/src/pool/connect.rs b/sqlx-core/src/pool/connect.rs index ee80591428..63c8798739 100644 --- a/sqlx-core/src/pool/connect.rs +++ b/sqlx-core/src/pool/connect.rs @@ -220,6 +220,11 @@ pub trait PoolConnector: Send + Sync + 'static { ) -> impl Future> + Send + '_; } +/// # Note: Future Changes (FIXME) +/// This could theoretically be replaced with an impl over `AsyncFn` to allow lending closures, +/// except we have no way to put the `Send` bound on the returned future. +/// +/// We need Return Type Notation for that: https://github.com/rust-lang/rust/pull/138424 impl PoolConnector for F where DB: Database, diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index 5a37f17ed1..e3aee6a3af 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -13,6 +13,7 @@ use std::task::ready; use crate::logger::private_level_filter_to_trace_level; use crate::pool::connect::{ConnectPermit, ConnectionCounter, ConnectionId, DynConnector}; use crate::pool::idle::IdleQueue; +use crate::pool::shard::Sharded; use crate::rt::JoinHandle; use crate::{private_tracing_dynamic_event, rt}; use either::Either; @@ -24,6 +25,7 @@ use tracing::Level; pub(crate) struct PoolInner { pub(super) connector: DynConnector, pub(super) counter: ConnectionCounter, + pub(super) sharded: Sharded, pub(super) idle: IdleQueue, is_closed: AtomicBool, pub(super) on_closed: event_listener::Event, @@ -40,6 +42,7 @@ impl PoolInner { let pool = Self { connector: DynConnector::new(connector), counter: ConnectionCounter::new(), + sharded: Sharded::new(options.max_connections, options.shards), idle: IdleQueue::new(options.fair, options.max_connections), is_closed: AtomicBool::new(false), on_closed: event_listener::Event::new(), diff --git a/sqlx-core/src/pool/options.rs b/sqlx-core/src/pool/options.rs index 9775799fdf..0e8e05b4cb 100644 --- a/sqlx-core/src/pool/options.rs +++ b/sqlx-core/src/pool/options.rs @@ -7,6 +7,7 @@ use crate::pool::{Pool, PoolConnector}; use futures_core::future::BoxFuture; use log::LevelFilter; use std::fmt::{self, Debug, Formatter}; +use std::num::NonZero; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -68,6 +69,7 @@ pub struct PoolOptions { >, >, pub(crate) max_connections: usize, + pub(crate) shards: NonZero, pub(crate) acquire_time_level: LevelFilter, pub(crate) acquire_slow_level: LevelFilter, pub(crate) acquire_slow_threshold: Duration, @@ -91,6 +93,7 @@ impl Clone for PoolOptions { before_acquire: self.before_acquire.clone(), after_release: self.after_release.clone(), max_connections: self.max_connections, + shards: self.shards, acquire_time_level: self.acquire_time_level, acquire_slow_threshold: 
self.acquire_slow_threshold, acquire_slow_level: self.acquire_slow_level, @@ -143,6 +146,7 @@ impl PoolOptions { // A production application will want to set a higher limit than this. max_connections: 10, min_connections: 0, + shards: NonZero::::MIN, // Logging all acquires is opt-in acquire_time_level: LevelFilter::Off, // Default to warning, because an acquire timeout will be an error @@ -206,6 +210,58 @@ impl PoolOptions { self.min_connections } + /// Set the number of shards to split the internal structures into. + /// + /// The default value is dynamically determined based on the configured number of worker threads + /// in the current runtime (if that information is available), + /// or [`std::thread::available_parallelism()`], + /// or 1 otherwise. + /// + /// Each shard is assigned an equal share of [`max_connections`][Self::max_connections] + /// and its own queue of tasks waiting to acquire a connection. + /// + /// Then, when accessing the pool, each thread selects a "local" shard based on its + /// [thread ID][std::thread::Thread::id]1. + /// + /// If the number of shards equals the number of threads (which they do by default), + /// and worker threads are spawned sequentially (which they generally are), + /// each thread should access a different shard, which should significantly reduce + /// cache coherence overhead on multicore systems. + /// + /// If the number of shards does not evenly divide `max_connections`, + /// the implementation makes a best-effort to distribute them as evenly as possible + /// (if `remainder = max_connections % shards` and `remainder != 0`, + /// then `remainder` shards will get one additional connection each). + /// + /// The implementation then clamps the number of connections in a shard to the range `[1, 64]`. + /// + /// ### Details + /// When a task calls [`Pool::acquire()`] (or any other method that calls `acquire()`), + /// it will first attempt to acquire a connection from its thread-local shard, or lock an empty + /// slot to open a new connection (acquiring an idle connection and opening a new connection + /// happen concurrently to minimize acquire time). + /// + /// Failing that, it joins the wait list on the shard. Released connections are passed to + /// waiting tasks in a first-come, first-serve order per shard. + /// + /// If the task cannot acquire a connection after a short delay, + /// it tries to acquire a connection from another shard. + /// + /// If the task _still_ cannot acquire a connection after a longer delay, + /// it joins a global wait list. Tasks in the global wait list are the highest priority + /// for released connections, implementing a kind of eventual fairness. + /// + /// 1 because, as of writing, [`std::thread::ThreadId::as_u64`] is unstable, + /// the current implementation assigns each thread its own sequential ID in a `thread_local!()`. + pub fn shards(mut self, shards: NonZero) -> Self { + self.shards = shards; + self + } + + pub fn get_shards(&self) -> usize { + self.shards.get() + } + /// Enable logging of time taken to acquire a connection from the connection pool via /// [`Pool::acquire()`]. 
/// @@ -572,3 +628,28 @@ impl Debug for PoolOptions { .finish() } } + +fn default_shards() -> NonZero { + #[cfg(feature = "_rt-tokio")] + if let Ok(rt) = tokio::runtime::Handle::try_current() { + return rt + .metrics() + .num_workers() + .try_into() + .unwrap_or(NonZero::::MIN); + } + + #[cfg(feature = "_rt-async-std")] + if let Some(val) = std::env::var("ASYNC_STD_THREAD_COUNT") + .ok() + .and_then(|s| s.parse()) + { + return val; + } + + if let Ok(val) = std::thread::available_parallelism() { + return val; + } + + NonZero::::MIN +} diff --git a/sqlx-core/src/pool/shard.rs b/sqlx-core/src/pool/shard.rs index 242635e133..a0bcee223b 100644 --- a/sqlx-core/src/pool/shard.rs +++ b/sqlx-core/src/pool/shard.rs @@ -1,6 +1,6 @@ use event_listener::{Event, IntoNotification}; -use parking_lot::Mutex; use std::future::Future; +use std::num::NonZero; use std::pin::pin; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -8,6 +8,8 @@ use std::task::Poll; use std::time::Duration; use std::{array, iter}; +use spin::lock_api::Mutex; + type ShardId = usize; type ConnectionIndex = usize; @@ -15,7 +17,11 @@ type ConnectionIndex = usize; /// /// We want tasks to acquire from their local shards where possible, so they don't enter /// the global queue immediately. -const GLOBAL_QUEUE_DELAY: Duration = Duration::from_millis(5); +const GLOBAL_QUEUE_DELAY: Duration = Duration::from_millis(10); + +/// Delay before attempting to acquire from a non-local shard, +/// as well as the backoff when iterating through shards. +const NON_LOCAL_ACQUIRE_DELAY: Duration = Duration::from_micros(100); pub struct Sharded { shards: Box<[ArcShard]>, @@ -29,11 +35,10 @@ struct Global { disconnect_event: Event>, } -type ArcMutexGuard = parking_lot::ArcMutexGuard>; +type ArcMutexGuard = lock_api::ArcMutexGuard, Option>; pub struct LockGuard { - // `Option` allows us to drop the guard before sending the notification. - // Otherwise, if the receiver wakes too quickly, it might fail to lock the mutex. + // `Option` allows us to take the guard in the drop handler. locked: Option>, shard: ArcShard, index: ConnectionIndex, @@ -73,13 +78,13 @@ const MAX_SHARD_SIZE: usize = if usize::BITS > 64 { }; impl Sharded { - pub fn new(connections: usize, shards: usize) -> Sharded { + pub fn new(connections: usize, shards: NonZero) -> Sharded { let global = Arc::new(Global { unlock_event: Event::with_tag(), disconnect_event: Event::with_tag(), }); - let shards = Params::calc(connections, shards) + let shards = Params::calc(connections, shards.get()) .shard_sizes() .enumerate() .map(|(shard_id, size)| Shard::new(shard_id, size, global.clone())) @@ -89,8 +94,28 @@ impl Sharded { } pub async fn acquire(&self, connected: bool) -> LockGuard { - let mut acquire_local = - pin!(self.shards[thread_id() % self.shards.len()].acquire(connected)); + if self.shards.len() == 1 { + return self.shards[0].acquire(connected).await; + } + + let thread_id = current_thread_id(); + + let mut acquire_local = pin!(self.shards[thread_id % self.shards.len()].acquire(connected)); + + let mut acquire_nonlocal = pin!(async { + let mut next_shard = thread_id; + + loop { + crate::rt::sleep(NON_LOCAL_ACQUIRE_DELAY).await; + + // Choose shards pseudorandomly by multiplying with a (relatively) large prime. 
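One caveat about the pseudorandom walk implemented by the assignment on the next line: a purely multiplicative step modulo `self.shards.len()` has degenerate orbits. Zero is a fixed point (a walk that reaches shard 0 stays there), and for composite shard counts the walk can visit only a strict subset; with 8 shards, 547 ≡ 3 (mod 8), so starting from shard 1 the sequence alternates 1 → 3 → 1. A sketch of an additive step with full period, offered as a hypothetical alternative rather than what this patch does:

```rust
// Sketch: stepping by a constant coprime to `len` visits every shard once per
// `len` iterations and has no fixed points. 547 is prime, so it is coprime to
// any shard count that is not itself a multiple of 547.
fn next_shard(current: usize, len: usize) -> usize {
    (current + 547) % len
}
```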
+ next_shard = (next_shard.wrapping_mul(547)) % self.shards.len(); + + if let Some(locked) = self.shards[next_shard].try_acquire(connected) { + return locked; + } + } + }); let mut acquire_global = pin!(async { crate::rt::sleep(GLOBAL_QUEUE_DELAY).await; @@ -113,6 +138,10 @@ impl Sharded { return Poll::Ready(locked); } + if let Poll::Ready(locked) = acquire_nonlocal.as_mut().poll(cx) { + return Poll::Ready(locked); + } + if let Poll::Ready(locked) = acquire_global.as_mut().poll(cx) { return Poll::Ready(locked); } @@ -125,6 +154,9 @@ impl Sharded { impl Shard>>]> { fn new(shard_id: ShardId, len: usize, global: Arc>) -> Arc { + // There's no way to create DSTs like this, in `std::sync::Arc`, on stable. + // + // Instead, we coerce from an array. macro_rules! make_array { ($($n:literal),+) => { match len { @@ -206,6 +238,8 @@ impl Shard>>]> { impl Params { fn calc(connections: usize, mut shards: usize) -> Params { + assert_ne!(shards, 0); + let mut shard_size = connections / shards; let mut remainder = connections % shards; @@ -217,7 +251,11 @@ impl Params { } else if shard_size >= MAX_SHARD_SIZE { let new_shards = connections.div_ceil(MAX_SHARD_SIZE); - tracing::debug!(connections, shards, "clamping shard count to {new_shards}"); + tracing::debug!( + connections, + shards, + "shard size exceeds {MAX_SHARD_SIZE}, clamping shard count to {new_shards}" + ); shards = new_shards; shard_size = connections / shards; @@ -239,7 +277,7 @@ impl Params { } } -fn thread_id() -> usize { +fn current_thread_id() -> usize { // FIXME: this can be replaced when this is stabilized: // https://doc.rust-lang.org/stable/std/thread/struct.ThreadId.html#method.as_u64 static THREAD_ID: AtomicUsize = AtomicUsize::new(0); From 23643d7fe256f3e60d86c8a8176c828c16678af7 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Tue, 23 Sep 2025 08:00:00 -0700 Subject: [PATCH 20/24] chore: `Cargo.lock` after rebase --- Cargo.lock | 22 +- Cargo.toml | 1 + sqlx-core/Cargo.toml | 13 +- sqlx-core/src/error.rs | 9 +- sqlx-core/src/pool/connect.rs | 311 ++++++++++-- sqlx-core/src/pool/connection.rs | 395 ++++++--------- sqlx-core/src/pool/idle.rs | 4 +- sqlx-core/src/pool/inner.rs | 384 +++++++------- sqlx-core/src/pool/mod.rs | 4 +- sqlx-core/src/pool/options.rs | 26 +- sqlx-core/src/pool/shard.rs | 476 +++++++++++++++--- sqlx-core/src/rt/mod.rs | 23 +- sqlx-core/src/rt/rt_async_io/mod.rs | 4 +- .../rt/rt_async_io/{timeout.rs => time.rs} | 18 +- sqlx-core/src/sync.rs | 51 +- 15 files changed, 1126 insertions(+), 615 deletions(-) rename sqlx-core/src/rt/rt_async_io/{timeout.rs => time.rs} (54%) diff --git a/Cargo.lock b/Cargo.lock index d4b7cdc2c0..7df78920e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -272,9 +272,9 @@ dependencies = [ [[package]] name = "async-lock" -version = "3.4.0" +version = "3.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff6e472cdea888a4bd64f342f09b3f50e1886d32afe8df3d663c01140b811b18" +checksum = "5fd03604047cee9b6ce9de9f70c6cd540a0520c813cbd49bae61f33ab80ed1dc" dependencies = [ "event-listener 5.4.0", "event-listener-strategy", @@ -1392,7 +1392,7 @@ checksum = "da0e4dd2a88388a1f4ccc7c9ce104604dab68d9f408dc34cd45823d5a9069095" dependencies = [ "futures-core", "futures-sink", - "spin", + "spin 0.9.8", ] [[package]] @@ -2083,7 +2083,7 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" dependencies = [ - "spin", + "spin 0.9.8", ] [[package]] @@ -3465,6 +3465,15 @@ 
dependencies = [ "lock_api", ] +[[package]] +name = "spin" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5fe4ccb98d9c292d56fec89a5e07da7fc4cf0dc11e156b41793132775d3e591" +dependencies = [ + "lock_api", +] + [[package]] name = "spki" version = "0.7.3" @@ -3537,6 +3546,7 @@ dependencies = [ "async-fs", "async-global-executor 3.1.0", "async-io", + "async-lock", "async-std", "async-task", "base64 0.22.1", @@ -3559,13 +3569,14 @@ dependencies = [ "indexmap 2.10.0", "ipnet", "ipnetwork", + "lock_api", "log", "mac_address", "memchr", "native-tls", - "parking_lot", "percent-encoding", "pin-project-lite", + "rand", "rust_decimal", "rustls", "rustls-native-certs", @@ -3574,6 +3585,7 @@ dependencies = [ "sha2", "smallvec", "smol", + "spin 0.10.0", "sqlx", "thiserror 2.0.17", "time", diff --git a/Cargo.toml b/Cargo.toml index 25284eec70..c3cb86c8af 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -192,6 +192,7 @@ cfg-if = "1.0.0" thiserror = { version = "2.0.17", default-features = false, features = ["std"] } dotenvy = { version = "0.15.7", default-features = false } ease-off = "0.1.6" +rand = "0.8.5" # Runtimes [workspace.dependencies.async-global-executor] diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index e9f7085934..dc03c192de 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -20,11 +20,12 @@ any = [] json = ["serde", "serde_json"] # for conditional compilation -_rt-async-global-executor = ["async-global-executor", "_rt-async-io", "_rt-async-task"] -_rt-async-io = ["async-io", "async-fs"] # see note at async-fs declaration -_rt-async-std = ["async-std", "_rt-async-io", "ease-off/async-io-2"] +_rt-async-global-executor = ["async-global-executor", "_rt-async-io", "_rt-async-lock", "_rt-async-task"] +_rt-async-io = ["async-io", "async-fs", "ease-off/async-io-2"] # see note at async-fs declaration +_rt-async-lock = ["async-lock"] +_rt-async-std = ["async-std", "_rt-async-io", "_rt-async-lock"] _rt-async-task = ["async-task"] -_rt-smol = ["smol", "_rt-async-io", "_rt-async-task"] +_rt-smol = ["smol", "_rt-async-io", "_rt-async-lock", "_rt-async-task"] _rt-tokio = ["tokio", "tokio-stream", "ease-off/tokio"] _tls-native-tls = ["native-tls"] _tls-rustls-aws-lc-rs = ["_tls-rustls", "rustls/aws-lc-rs", "webpki-roots"] @@ -72,6 +73,7 @@ uuid = { workspace = true, optional = true } # work around bug in async-fs 2.0.0, which references futures-lite dependency wrongly, see https://github.com/launchbadge/sqlx/pull/3791#issuecomment-3043363281 async-fs = { version = "2.1", optional = true } async-io = { version = "2.4.1", optional = true } +async-lock = { version = "3.4.1", optional = true } async-task = { version = "4.7.1", optional = true } base64 = { version = "0.22.0", default-features = false, features = ["std"] } @@ -101,9 +103,10 @@ indexmap = "2.0" event-listener = "5.2.0" hashbrown = "0.16.0" +rand.workspace = true thiserror.workspace = true -ease-off = { workspace = true, features = ["futures"] } +ease-off = { workspace = true, default-features = false } pin-project-lite = "0.2.14" # N.B. we don't actually utilize spinlocks, we just need a `Mutex` type with a few requirements: diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index 00b1a64064..8dfcc92a99 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -1,12 +1,12 @@ //! Types for working with errors produced by SQLx. 
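Stepping back to the `spin` + `lock_api` combination this series settles on for the shard slots: the point of the pairing is `lock_api`'s `arc_lock` feature, whose `try_lock_arc()` returns an owned guard holding a clone of the `Arc` rather than a borrow, so it can be `Send`, held across `.await` points, and stored in types like `LockGuard`. A sketch, under the assumption (from the two crates' documentation, not from this patch) that `spin::lock_api::Mutex<T>` is `lock_api::Mutex<spin::Mutex<()>, T>`:

```rust
use std::sync::Arc;

type Slot<C> = Arc<spin::lock_api::Mutex<Option<C>>>;
type SlotGuard<C> = lock_api::ArcMutexGuard<spin::Mutex<()>, Option<C>>;

// Non-blocking: returns `None` if another task currently holds the slot.
// The guard owns a clone of the `Arc`, so it does not borrow from `slot`.
fn try_take<C>(slot: &Slot<C>) -> Option<SlotGuard<C>> {
    slot.try_lock_arc()
}
```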
+use crate::database::Database; use std::any::type_name; use std::borrow::Cow; use std::error::Error as StdError; use std::fmt::Display; use std::io; - -use crate::database::Database; +use std::sync::Arc; use crate::type_info::TypeInfo; use crate::types::Type; @@ -104,7 +104,10 @@ pub enum Error { /// /// [`Pool::acquire`]: crate::pool::Pool::acquire #[error("pool timed out while waiting for an open connection")] - PoolTimedOut, + PoolTimedOut { + #[source] + last_connect_error: Option>, + }, /// [`Pool::close`] was called while we were waiting in [`Pool::acquire`]. /// diff --git a/sqlx-core/src/pool/connect.rs b/sqlx-core/src/pool/connect.rs index 63c8798739..52920c6a67 100644 --- a/sqlx-core/src/pool/connect.rs +++ b/sqlx-core/src/pool/connect.rs @@ -1,20 +1,30 @@ use crate::connection::{ConnectOptions, Connection}; use crate::database::Database; -use crate::pool::connection::Floating; +use crate::pool::connection::ConnectionInner; use crate::pool::inner::PoolInner; -use crate::pool::PoolConnection; +use crate::pool::{Pool, PoolConnection}; use crate::rt::JoinHandle; -use crate::Error; +use crate::{rt, Error}; use ease_off::EaseOff; -use event_listener::{listener, Event}; +use event_listener::{listener, Event, EventListener}; use std::fmt::{Display, Formatter}; use std::future::Future; use std::ptr; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; +use std::sync::{Arc, Mutex, RwLock}; use std::time::Instant; +use crate::pool::shard::DisconnectedSlot; +#[cfg(doc)] +use crate::pool::PoolOptions; +use crate::sync::{AsyncMutex, AsyncMutexGuard}; +use ease_off::core::EaseOffCore; use std::io; +use std::ops::ControlFlow; +use std::pin::{pin, Pin}; +use std::task::{ready, Context, Poll}; + +const EASE_OFF: EaseOffCore = ease_off::Options::new().into_core(); /// Custom connect callback for [`Pool`][crate::pool::Pool]. /// @@ -197,7 +207,7 @@ pub trait PoolConnector: Send + Sync + 'static { /// If this method returns an error that is known to be retryable, it is called again /// in an exponential backoff loop. Retryable errors include, but are not limited to: /// - /// * [`io::ErrorKind::ConnectionRefused`] + /// * [`io::Error`] /// * Database errors for which /// [`is_retryable_connect_error`][crate::error::DatabaseError::is_retryable_connect_error] /// returns `true`. @@ -205,6 +215,8 @@ pub trait PoolConnector: Send + Sync + 'static { /// This error kind is not returned internally and is designed to allow this method to return /// arbitrary error types not otherwise supported. /// + /// This behavior may be customized by overriding [`Self::connect_with_control_flow()`]. + /// /// Manual implementations of this method may also use the signature: /// ```rust,ignore /// async fn connect( @@ -218,6 +230,54 @@ pub trait PoolConnector: Send + Sync + 'static { &self, meta: PoolConnectMetadata, ) -> impl Future> + Send + '_; + + /// Open a connection for the pool, or indicate what to do on an error. + /// + /// This method may return one of the following: + /// + /// * `ControlFlow::Break(Ok(_))` with a successfully established connection. + /// * `ControlFlow::Break(Err(_))` with an error to immediately return to the caller. + /// * `ControlFlow::Continue(_)` with a retryable error. + /// The pool will call this method again in an exponential backoff loop until it succeeds, + /// or the [connect timeout][PoolOptions::connect_timeout] + /// or [acquire timeout][PoolOptions::acquire_timeout] is reached. 
+ /// + /// # Default Implementation + /// This method has a provided implementation by default which calls [`Self::connect()`] + /// and then returns `ControlFlow::Continue` if the error is any of the following: + /// + /// * [`io::Error`] + /// * Database errors for which + /// [`is_retryable_connect_error`][crate::error::DatabaseError::is_retryable_connect_error] + /// returns `true`. + /// * [`Error::PoolConnector`] with `retryable: true`. + /// This error kind is not returned internally and is designed to allow this method to return + /// arbitrary error types not otherwise supported. + /// + /// A custom backoff loop may be implemented by overriding this method and retrying internally, + /// only returning `ControlFlow::Break` if/when an error should be propagated out to the caller. + /// + /// If this method is overridden and does not call [`Self::connect()`], then the implementation + /// of the latter can be a stub. It is not called internally. + fn connect_with_control_flow( + &self, + meta: PoolConnectMetadata, + ) -> impl Future, Error>> + Send + '_ { + async { + match self.connect(meta).await { + Err(err @ Error::Io(_)) => ControlFlow::Continue(err), + Err(Error::Database(dbe)) if dbe.is_retryable_connect_error() => { + ControlFlow::Continue(Error::Database(dbe)) + } + Err( + err @ Error::PoolConnector { + retryable: true, .. + }, + ) => ControlFlow::Continue(err), + res => ControlFlow::Break(res), + } + } + } } /// # Note: Future Changes (FIXME) @@ -260,8 +320,12 @@ pub struct PoolConnectMetadata { /// /// May be used for reporting purposes, or to implement a custom backoff. pub start: Instant, + + /// The deadline (`start` plus the [connect timeout][PoolOptions::connect_timeout], if set). + pub deadline: Option, + /// The number of attempts that have occurred so far. - pub num_attempts: usize, + pub num_attempts: u32, /// The current size of the pool. pub pool_size: usize, /// The ID of the connection, unique for the pool. @@ -271,7 +335,12 @@ pub struct PoolConnectMetadata { pub struct DynConnector { // We want to spawn the connection attempt as a task anyway connect: Box< - dyn Fn(ConnectionId, ConnectPermit) -> JoinHandle>> + dyn Fn( + Pool, + ConnectionId, + DisconnectedSlot>, + Arc, + ) -> ConnectTask + Send + Sync + 'static, @@ -283,18 +352,90 @@ impl DynConnector { let connector = Arc::new(connector); Self { - connect: Box::new(move |id, permit| { - crate::rt::spawn(connect_with_backoff(id, permit, connector.clone())) + connect: Box::new(move |pool, id, guard, shared| { + ConnectTask::spawn(pool, id, guard, connector.clone(), shared) }), } } pub fn connect( &self, + pool: Pool, id: ConnectionId, - permit: ConnectPermit, - ) -> JoinHandle>> { - (self.connect)(id, permit) + slot: DisconnectedSlot>, + shared: Arc, + ) -> ConnectTask { + (self.connect)(pool, id, slot, shared) + } +} + +pub struct ConnectTask { + handle: JoinHandle>>, + shared: Arc, +} + +pub struct ConnectTaskShared { + cancel_event: Event, + // Using the normal `std::sync::Mutex` because the critical sections are very short; + // we only hold the lock long enough to insert or take the value. 
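Since `connect_with_control_flow()` above is explicitly designed to be overridden, a sketch of a custom `PoolConnector` that caps retries at three attempts may help. The module paths and exact trait shape here are assumptions based on this patch series, not published API:

```rust
use std::ops::ControlFlow;

use sqlx::pool::{PoolConnectMetadata, PoolConnector};
use sqlx::postgres::{PgConnectOptions, PgConnection, Postgres};
use sqlx::{ConnectOptions, Error};

struct CappedRetries(PgConnectOptions);

impl PoolConnector<Postgres> for CappedRetries {
    async fn connect(&self, _meta: PoolConnectMetadata) -> Result<PgConnection, Error> {
        self.0.connect().await
    }

    async fn connect_with_control_flow(
        &self,
        meta: PoolConnectMetadata,
    ) -> ControlFlow<Result<PgConnection, Error>, Error> {
        let attempt = meta.num_attempts;

        match self.connect(meta).await {
            Ok(conn) => ControlFlow::Break(Ok(conn)),
            // Give up after three attempts instead of waiting out the timeout.
            Err(e) if attempt >= 3 => ControlFlow::Break(Err(e)),
            Err(e) => ControlFlow::Continue(e),
        }
    }
}
```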
+ last_error: Mutex>, +} + +impl Future for ConnectTask { + type Output = crate::Result>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.handle).poll(cx) + } +} + +impl ConnectTask { + fn spawn( + pool: Pool, + id: ConnectionId, + guard: DisconnectedSlot>, + connector: Arc>, + shared: Arc, + ) -> Self { + let handle = crate::rt::spawn(connect_with_backoff( + pool, + id, + connector, + guard, + shared.clone(), + )); + + Self { handle, shared } + } + + pub fn cancel(&self) -> Option { + self.shared.cancel_event.notify(1); + + self.shared + .last_error + .lock() + .unwrap_or_else(|e| e.into_inner()) + .take() + } +} + +impl ConnectTaskShared { + pub fn new_arc() -> Arc { + Arc::new(Self { + cancel_event: Event::new(), + last_error: Mutex::new(None), + }) + } + + pub fn take_error(&self) -> Option { + self.last_error + .lock() + .unwrap_or_else(|e| e.into_inner()) + .take() + } + + fn put_error(&self, error: Error) { + *self.last_error.lock().unwrap_or_else(|e| e.into_inner()) = Some(error); } } @@ -308,6 +449,14 @@ pub struct ConnectionCounter { #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub struct ConnectionId(usize); +impl ConnectionId { + pub(super) fn next() -> ConnectionId { + static NEXT_ID: AtomicUsize = AtomicUsize::new(0); + + ConnectionId(NEXT_ID.fetch_add(1, Ordering::AcqRel)) + } +} + impl ConnectionCounter { pub fn new() -> Self { Self { @@ -456,41 +605,131 @@ impl Display for ConnectionId { err )] async fn connect_with_backoff( + pool: Pool, connection_id: ConnectionId, - permit: ConnectPermit, connector: Arc>, + slot: DisconnectedSlot>, + shared: Arc, ) -> crate::Result> { - if permit.pool().is_closed() { - return Err(Error::PoolClosed); - } + listener!(pool.0.on_closed => closed); + listener!(shared.cancel_event => cancelled); - let mut ease_off = EaseOff::start_timeout(permit.pool().options.connect_timeout); + let start = Instant::now(); + let deadline = pool + .0 + .options + .connect_timeout + .and_then(|timeout| start.checked_add(timeout)); - for attempt in 1usize.. { + for attempt in 1u32.. 
{ let meta = PoolConnectMetadata { - start: ease_off.started_at(), + start, + deadline, num_attempts: attempt, - pool_size: permit.pool().size(), + pool_size: pool.size(), connection_id, }; - let conn = ease_off - .try_async(connector.connect(meta)) - .await - .or_retry_if(|e| can_retry_error(e.inner()))?; + tracing::trace!( + target: "sqlx::pool::connect", + %connection_id, + attempt, + elapsed_seconds=start.elapsed().as_secs_f64(), + "beginning connection attempt" + ); + + let res = connector.connect_with_control_flow(meta).await; + + let now = Instant::now(); + let elapsed = now.duration_since(start); + let elapsed_seconds = elapsed.as_secs_f64(); + + match res { + ControlFlow::Break(Ok(conn)) => { + tracing::trace!( + target: "sqlx::pool::connect", + %connection_id, + attempt, + elapsed_seconds, + "connection established", + ); - if let Some(conn) = conn { - return Ok(Floating::new_live(conn, connection_id, permit).reattach()); - } - } + return Ok(PoolConnection::new( + slot.put(ConnectionInner { + raw: conn, + id: connection_id, + created_at: now, + last_released_at: now, + }), + pool.0.clone(), + )); + } + ControlFlow::Break(Err(e)) => { + tracing::warn!( + target: "sqlx::pool::connect", + %connection_id, + attempt, + elapsed_seconds, + error=?e, + "error connecting to database", + ); - Err(Error::PoolTimedOut) -} + return Err(e); + } + ControlFlow::Continue(e) => { + tracing::warn!( + target: "sqlx::pool::connect", + %connection_id, + attempt, + elapsed_seconds, + error=?e, + "error connecting to database; retrying", + ); + + shared.put_error(e); + } + } -fn can_retry_error(e: &Error) -> bool { - match e { - Error::Io(e) if e.kind() == io::ErrorKind::ConnectionRefused => true, - Error::Database(e) => e.is_retryable_connect_error(), - _ => false, + let wait = EASE_OFF + .nth_retry_at(attempt, now, deadline, &mut rand::thread_rng()) + .map_err(|_| { + Error::PoolTimedOut { + // This should be populated by the caller + last_connect_error: None, + } + })?; + + if let Some(wait) = wait { + tracing::trace!( + target: "sqlx::pool::connect", + %connection_id, + attempt, + elapsed_seconds, + "waiting for {:?}", + wait.duration_since(now), + ); + + let mut sleep = pin!(rt::sleep_until(wait)); + + std::future::poll_fn(|cx| { + if let Poll::Ready(()) = Pin::new(&mut closed).poll(cx) { + return Poll::Ready(Err(Error::PoolClosed)); + } + + if let Poll::Ready(()) = Pin::new(&mut cancelled).poll(cx) { + return Poll::Ready(Err(Error::PoolTimedOut { + last_connect_error: None, + })); + } + + ready!(sleep.as_mut().poll(cx)); + Poll::Ready(Ok(())) + }) + .await?; + } } + + Err(Error::PoolTimedOut { + last_connect_error: None, + }) } diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs index 2ab315fab7..8d115818f4 100644 --- a/sqlx-core/src/pool/connection.rs +++ b/sqlx-core/src/pool/connection.rs @@ -9,8 +9,10 @@ use crate::database::Database; use crate::error::Error; use super::inner::{is_beyond_max_lifetime, PoolInner}; -use crate::pool::connect::{ConnectPermit, ConnectionId}; +use crate::pool::connect::{ConnectPermit, ConnectTaskShared, ConnectionId}; use crate::pool::options::PoolConnectionMetadata; +use crate::pool::shard::{ConnectedSlot, DisconnectedSlot}; +use crate::pool::Pool; use crate::rt; const RETURN_TO_POOL_TIMEOUT: Duration = Duration::from_secs(5); @@ -20,26 +22,16 @@ const CLOSE_ON_DROP_TIMEOUT: Duration = Duration::from_secs(5); /// /// Will be returned to the pool on-drop. 
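With `connect_with_backoff()` now recording failures in `ConnectTaskShared`, callers can surface the underlying cause of a timeout. A sketch of handling the new struct-variant `Error::PoolTimedOut` at an `acquire()` call site (the variant shape is from the `error.rs` change above; the surrounding function is illustrative):

```rust
use sqlx::pool::PoolConnection;
use sqlx::{PgPool, Postgres};

async fn get_conn(pool: &PgPool) -> sqlx::Result<PoolConnection<Postgres>> {
    pool.acquire().await.inspect_err(|e| {
        // `PoolTimedOut` now carries the last error seen by the connect task, if any.
        if let sqlx::Error::PoolTimedOut { last_connect_error: Some(error) } = e {
            tracing::warn!(%error, "pool acquire timed out");
        }
    })
}
```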
pub struct PoolConnection { - live: Option>, - close_on_drop: bool, + conn: Option>>, pub(crate) pool: Arc>, + close_on_drop: bool, } -pub(super) struct Live { +pub(super) struct ConnectionInner { pub(super) raw: DB::Connection, pub(super) id: ConnectionId, pub(super) created_at: Instant, -} - -pub(super) struct Idle { - pub(super) live: Live, - pub(super) idle_since: Instant, -} - -/// RAII wrapper for connections being handled by functions that may drop them -pub(super) struct Floating { - pub(super) inner: C, - pub(super) permit: ConnectPermit, + pub(super) last_released_at: Instant, } const EXPECT_MSG: &str = "BUG: inner connection already taken!"; @@ -48,7 +40,7 @@ impl Debug for PoolConnection { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("PoolConnection") .field("database", &DB::NAME) - .field("id", &self.live.as_ref().map(|live| live.id)) + .field("id", &self.conn.as_ref().map(|live| live.id)) .finish() } } @@ -57,13 +49,13 @@ impl Deref for PoolConnection { type Target = DB::Connection; fn deref(&self) -> &Self::Target { - &self.live.as_ref().expect(EXPECT_MSG).raw + &self.conn.as_ref().expect(EXPECT_MSG).raw } } impl DerefMut for PoolConnection { fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.live.as_mut().expect(EXPECT_MSG).raw + &mut self.conn.as_mut().expect(EXPECT_MSG).raw } } @@ -80,6 +72,14 @@ impl AsMut for PoolConnection { } impl PoolConnection { + pub(super) fn new(live: ConnectedSlot>, pool: Arc>) -> Self { + Self { + conn: Some(live), + close_on_drop: false, + pool, + } + } + /// Close this connection, allowing the pool to open a replacement. /// /// Equivalent to calling [`.detach()`] then [`.close()`], but the connection permit is retained @@ -88,8 +88,8 @@ impl PoolConnection { /// [`.detach()`]: PoolConnection::detach /// [`.close()`]: Connection::close pub async fn close(mut self) -> Result<(), Error> { - let floating = self.take_live().float(self.pool.clone()); - floating.inner.raw.close().await + let (res, _slot) = close(self.take_conn()).await; + res } /// Close this connection on-drop, instead of returning it to the pool. @@ -115,7 +115,8 @@ impl PoolConnection { /// [`max_connections`]: crate::pool::PoolOptions::max_connections /// [`min_connections`]: crate::pool::PoolOptions::min_connections pub fn detach(mut self) -> DB::Connection { - self.take_live().float(self.pool.clone()).detach() + let (conn, _slot) = ConnectedSlot::take(self.take_conn()); + conn.raw } /// Detach this connection from the pool, treating it as permanently checked-out. @@ -124,15 +125,13 @@ impl PoolConnection { /// /// If you don't want to impact the pool's capacity, use [`.detach()`][Self::detach] instead. pub fn leak(mut self) -> DB::Connection { - self.take_live().raw - } - - fn take_live(&mut self) -> Live { - self.live.take().expect(EXPECT_MSG) + let (conn, slot) = ConnectedSlot::take(self.take_conn()); + DisconnectedSlot::leak(slot); + conn.raw } - pub(super) fn into_floating(mut self) -> Floating> { - self.take_live().float(self.pool.clone()) + fn take_conn(&mut self) -> ConnectedSlot> { + self.conn.take().expect(EXPECT_MSG) } /// Test the connection to make sure it is still live before returning it to the pool. @@ -140,48 +139,30 @@ impl PoolConnection { /// This effectively runs the drop handler eagerly instead of spawning a task to do it. 
#[doc(hidden)] pub fn return_to_pool(&mut self) -> impl Future + Send + 'static { - // float the connection in the pool before we move into the task - // in case the returned `Future` isn't executed, like if it's spawned into a dying runtime - // https://github.com/launchbadge/sqlx/issues/1396 - // Type hints seem to be broken by `Option` combinators in IntelliJ Rust right now (6/22). - let floating: Option>> = - self.live.take().map(|live| live.float(self.pool.clone())); - + let conn = self.conn.take(); let pool = self.pool.clone(); async move { - let returned_to_pool = if let Some(floating) = floating { - rt::timeout(RETURN_TO_POOL_TIMEOUT, floating.return_to_pool()) - .await - .unwrap_or(false) - } else { - false + let Some(conn) = conn else { + return; }; - if !returned_to_pool { - pool.min_connections_maintenance(None).await; - } + rt::timeout(RETURN_TO_POOL_TIMEOUT, return_to_pool(conn, &pool)) + .await + // Dropping of the `slot` will check if the connection must be re-established + // but only after trying to pass it to a task that needs it. + .ok(); } } fn take_and_close(&mut self) -> impl Future + Send + 'static { - // float the connection in the pool before we move into the task - // in case the returned `Future` isn't executed, like if it's spawned into a dying runtime - // https://github.com/launchbadge/sqlx/issues/1396 - // Type hints seem to be broken by `Option` combinators in IntelliJ Rust right now (6/22). - let floating = self.live.take().map(|live| live.float(self.pool.clone())); - - let pool = self.pool.clone(); + let conn = self.conn.take(); async move { - if let Some(floating) = floating { + if let Some(conn) = conn { // Don't hold the connection forever if it hangs while trying to close - crate::rt::timeout(CLOSE_ON_DROP_TIMEOUT, floating.close()) - .await - .ok(); + rt::timeout(CLOSE_ON_DROP_TIMEOUT, close(conn)).await.ok(); } - - pool.min_connections_maintenance(None).await; } } } @@ -214,205 +195,21 @@ impl Drop for PoolConnection { } // We still need to spawn a task to maintain `min_connections`. - if self.live.is_some() || self.pool.options.min_connections > 0 { + if self.conn.is_some() || self.pool.options.min_connections > 0 { crate::rt::spawn(self.return_to_pool()); } } } -impl Live { - pub fn float(self, pool: Arc>) -> Floating { - Floating { - inner: self, - // create a new guard from a previously leaked permit - permit: ConnectPermit::float_existing(pool), - } - } - - pub fn into_idle(self) -> Idle { - Idle { - live: self, - idle_since: Instant::now(), - } - } -} - -impl Deref for Idle { - type Target = Live; - - fn deref(&self) -> &Self::Target { - &self.live - } -} - -impl DerefMut for Idle { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.live - } -} - -impl Floating> { - pub fn new_live(conn: DB::Connection, id: ConnectionId, permit: ConnectPermit) -> Self { - Self { - inner: Live { - raw: conn, - id, - created_at: Instant::now(), - }, - permit, - } - } - - pub fn reattach(self) -> PoolConnection { - let Floating { inner, permit } = self; - - let pool = Arc::clone(permit.pool()); - - permit.consume(); - PoolConnection { - live: Some(inner), - close_on_drop: false, - pool, - } - } - - pub fn release(self) { - self.permit.pool().clone().release(self); - } - - /// Return the connection to the pool. - /// - /// Returns `true` if the connection was successfully returned, `false` if it was closed. - async fn return_to_pool(mut self) -> bool { - // Immediately close the connection. 
- if self.permit.pool().is_closed() { - self.close().await; - return false; - } - - // If the connection is beyond max lifetime, close the connection and - // immediately create a new connection - if is_beyond_max_lifetime(&self.inner, &self.permit.pool().options) { - self.close().await; - return false; - } - - if let Some(test) = &self.permit.pool().options.after_release { - let meta = self.metadata(); - match (test)(&mut self.inner.raw, meta).await { - Ok(true) => (), - Ok(false) => { - self.close().await; - return false; - } - Err(error) => { - tracing::warn!(%error, "error from `after_release`"); - // Connection is broken, don't try to gracefully close as - // something weird might happen. - self.close_hard().await; - return false; - } - } - } - - // test the connection on-release to ensure it is still viable, - // and flush anything time-sensitive like transaction rollbacks - // if an Executor future/stream is dropped during an `.await` call, the connection - // is likely to be left in an inconsistent state, in which case it should not be - // returned to the pool; also of course, if it was dropped due to an error - // this is simply a band-aid as SQLx-next connections should be able - // to recover from cancellations - if let Err(error) = self.raw.ping().await { - tracing::warn!( - %error, - "error occurred while testing the connection on-release", - ); - - // Connection is broken, don't try to gracefully close. - self.close_hard().await; - false - } else { - // if the connection is still viable, release it to the pool - self.release(); - true - } - } - - pub async fn close(self) { - // This isn't used anywhere that we care about the return value - let _ = self.inner.raw.close().await; - - // `guard` is dropped as intended - } - - pub async fn close_hard(self) { - let _ = self.inner.raw.close_hard().await; - } - - pub fn detach(self) -> DB::Connection { - self.inner.raw - } - - pub fn into_idle(self) -> Floating> { - Floating { - inner: self.inner.into_idle(), - permit: self.permit, - } - } - +impl ConnectionInner { pub fn metadata(&self) -> PoolConnectionMetadata { PoolConnectionMetadata { age: self.created_at.elapsed(), idle_for: Duration::ZERO, } } -} -impl Floating> { - pub fn from_idle(idle: Idle, pool: Arc>) -> Self { - Self { - inner: idle, - permit: ConnectPermit::float_existing(pool), - } - } - - pub async fn ping(&mut self) -> Result<(), Error> { - self.live.raw.ping().await - } - - pub fn into_live(self) -> Floating> { - Floating { - inner: self.inner.live, - permit: self.permit, - } - } - - pub async fn close(self) -> (ConnectionId, ConnectPermit) { - let connection_id = self.inner.live.id; - - tracing::debug!(%connection_id, "closing connection (gracefully)"); - - if let Err(error) = self.inner.live.raw.close().await { - tracing::debug!( - %connection_id, - %error, - "error occurred while closing the pool connection" - ); - } - (connection_id, self.permit) - } - - pub async fn close_hard(self) -> (ConnectionId, ConnectPermit) { - let connection_id = self.inner.live.id; - - tracing::debug!(%connection_id, "closing connection (hard)"); - - let _ = self.inner.live.raw.close_hard().await; - - (connection_id, self.permit) - } - - pub fn metadata(&self) -> PoolConnectionMetadata { + pub fn idle_metadata(&self) -> PoolConnectionMetadata { // Use a single `now` value for consistency. 
let now = Instant::now(); @@ -420,21 +217,113 @@ impl Floating> { // NOTE: the receiver is the later `Instant` and the arg is the earlier // https://github.com/launchbadge/sqlx/issues/1912 age: now.saturating_duration_since(self.created_at), - idle_for: now.saturating_duration_since(self.idle_since), + idle_for: now.saturating_duration_since(self.last_released_at), } } } -impl Deref for Floating { - type Target = C; +pub(crate) async fn close( + conn: ConnectedSlot>, +) -> (Result<(), Error>, DisconnectedSlot>) { + let connection_id = conn.id; - fn deref(&self) -> &Self::Target { - &self.inner - } + tracing::debug!(target: "sqlx::pool", %connection_id, "closing connection (gracefully)"); + + let (conn, slot) = ConnectedSlot::take(conn); + + let res = conn.raw.close().await.inspect_err(|error| { + tracing::debug!( + target: "sqlx::pool", + %connection_id, + %error, + "error occurred while closing the pool connection" + ); + }); + + (res, slot) +} +pub(crate) async fn close_hard( + conn: ConnectedSlot>, +) -> (Result<(), Error>, DisconnectedSlot>) { + let connection_id = conn.id; + + tracing::debug!( + target: "sqlx::pool", + %connection_id, + "closing connection (forcefully)" + ); + + let (conn, slot) = ConnectedSlot::take(conn); + + let res = conn.raw.close_hard().await.inspect_err(|error| { + tracing::debug!( + target: "sqlx::pool", + %connection_id, + %error, + "error occurred while closing the pool connection" + ); + }); + + (res, slot) } -impl DerefMut for Floating { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.inner +/// Return the connection to the pool. +/// +/// Returns `true` if the connection was successfully returned, `false` if it was closed. +async fn return_to_pool( + mut conn: ConnectedSlot>, + pool: &PoolInner, +) -> Result<(), DisconnectedSlot>> { + // Immediately close the connection. + if pool.is_closed() { + let (_res, slot) = close(conn).await; + return Err(slot); + } + + // If the connection is beyond max lifetime, close the connection and + // immediately create a new connection + if is_beyond_max_lifetime(&conn, &pool.options) { + let (_res, slot) = close(conn).await; + return Err(slot); + } + + if let Some(test) = &pool.options.after_release { + let meta = conn.metadata(); + match (test)(&mut conn.raw, meta).await { + Ok(true) => (), + Ok(false) => { + let (_res, slot) = close(conn).await; + return Err(slot); + } + Err(error) => { + tracing::warn!(%error, "error from `after_release`"); + // Connection is broken, don't try to gracefully close as + // something weird might happen. + let (_res, slot) = close_hard(conn).await; + return Err(slot); + } + } + } + + // test the connection on-release to ensure it is still viable, + // and flush anything time-sensitive like transaction rollbacks + // if an Executor future/stream is dropped during an `.await` call, the connection + // is likely to be left in an inconsistent state, in which case it should not be + // returned to the pool; also of course, if it was dropped due to an error + // this is simply a band-aid as SQLx-next connections should be able + // to recover from cancellations + if let Err(error) = conn.raw.ping().await { + tracing::warn!( + %error, + "error occurred while testing the connection on-release", + ); + + // Connection is broken, don't try to gracefully close. 
+ let (_res, slot) = close_hard(conn).await; + Err(slot) + } else { + // if the connection is still viable, release it to the pool + drop(conn); + Ok(()) } } diff --git a/sqlx-core/src/pool/idle.rs b/sqlx-core/src/pool/idle.rs index 8b07b8e7c4..602ed3c5c8 100644 --- a/sqlx-core/src/pool/idle.rs +++ b/sqlx-core/src/pool/idle.rs @@ -1,6 +1,6 @@ use crate::connection::Connection; use crate::database::Database; -use crate::pool::connection::{Floating, Idle, Live}; +use crate::pool::connection::{Floating, Idle, ConnectionInner}; use crate::pool::inner::PoolInner; use crossbeam_queue::ArrayQueue; use event_listener::Event; @@ -71,7 +71,7 @@ impl IdleQueue { }) } - pub fn release(&self, conn: Floating>) { + pub fn release(&self, conn: Floating>) { let Floating { inner: conn, permit, diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index e3aee6a3af..af9229d47d 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -1,32 +1,35 @@ -use super::connection::{Floating, Idle, Live}; +use super::connection::ConnectionInner; use crate::database::Database; use crate::error::Error; -use crate::pool::{CloseEvent, Pool, PoolConnection, PoolConnector, PoolOptions}; +use crate::pool::{connection, CloseEvent, Pool, PoolConnection, PoolConnector, PoolOptions}; use std::cmp; use std::future::Future; -use std::pin::pin; +use std::pin::{pin, Pin}; +use std::rc::Weak; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use std::task::ready; +use std::task::{ready, Poll}; +use crate::connection::Connection; use crate::logger::private_level_filter_to_trace_level; -use crate::pool::connect::{ConnectPermit, ConnectionCounter, ConnectionId, DynConnector}; -use crate::pool::idle::IdleQueue; -use crate::pool::shard::Sharded; +use crate::pool::connect::{ + ConnectPermit, ConnectTask, ConnectTaskShared, ConnectionCounter, ConnectionId, DynConnector, +}; +use crate::pool::shard::{ConnectedSlot, DisconnectedSlot, Sharded}; use crate::rt::JoinHandle; use crate::{private_tracing_dynamic_event, rt}; use either::Either; +use futures_core::FusedFuture; use futures_util::future::{self, OptionFuture}; -use futures_util::FutureExt; +use futures_util::{stream, FutureExt, TryStreamExt}; use std::time::{Duration, Instant}; use tracing::Level; pub(crate) struct PoolInner { pub(super) connector: DynConnector, pub(super) counter: ConnectionCounter, - pub(super) sharded: Sharded, - pub(super) idle: IdleQueue, + pub(super) sharded: Sharded>, is_closed: AtomicBool, pub(super) on_closed: event_listener::Event, pub(super) options: PoolOptions, @@ -39,19 +42,38 @@ impl PoolInner { options: PoolOptions, connector: impl PoolConnector, ) -> Arc { - let pool = Self { - connector: DynConnector::new(connector), - counter: ConnectionCounter::new(), - sharded: Sharded::new(options.max_connections, options.shards), - idle: IdleQueue::new(options.fair, options.max_connections), - is_closed: AtomicBool::new(false), - on_closed: event_listener::Event::new(), - acquire_time_level: private_level_filter_to_trace_level(options.acquire_time_level), - acquire_slow_level: private_level_filter_to_trace_level(options.acquire_slow_level), - options, - }; - - let pool = Arc::new(pool); + let pool = Arc::::new_cyclic(|pool_weak| { + let pool_weak = pool_weak.clone(); + + let reconnect = move |slot| { + let Some(pool) = pool_weak.upgrade() else { + return; + }; + + pool.connector.connect( + Pool(pool.clone()), + ConnectionId::next(), + slot, + ConnectTaskShared::new_arc(), + ); + }; + + Self { + connector: 
DynConnector::new(connector), + counter: ConnectionCounter::new(), + sharded: Sharded::new( + options.max_connections, + options.shards, + options.min_connections, + reconnect, + ), + is_closed: AtomicBool::new(false), + on_closed: event_listener::Event::new(), + acquire_time_level: private_level_filter_to_trace_level(options.acquire_time_level), + acquire_slow_level: private_level_filter_to_trace_level(options.acquire_slow_level), + options, + } + }); spawn_maintenance_tasks(&pool); @@ -59,11 +81,11 @@ impl PoolInner { } pub(super) fn size(&self) -> usize { - self.counter.connections() + self.sharded.count_connected() } pub(super) fn num_idle(&self) -> usize { - self.idle.len() + self.sharded.count_unlocked(true) } pub(super) fn is_closed(&self) -> bool { @@ -79,23 +101,13 @@ impl PoolInner { self.mark_closed(); // Keep clearing the idle queue as connections are released until the count reaches zero. - async move { - let mut drained = pin!(self.counter.drain()); - - loop { - let mut acquire_idle = pin!(self.idle.acquire(self)); - - // Not using `futures::select!{}` here because it requires a proc-macro dep, - // and frankly it's a little broken. - match future::select(drained.as_mut(), acquire_idle.as_mut()).await { - // *not* `either::Either`; they rolled their own - future::Either::Left(_) => break, - future::Either::Right((idle, _)) => { - idle.close().await; - } - } - } - } + self.sharded.drain(|slot| async move { + let (conn, slot) = ConnectedSlot::take(slot); + + let _ = conn.raw.close().await; + + slot + }) } pub(crate) fn close_event(&self) -> CloseEvent { @@ -109,17 +121,12 @@ impl PoolInner { } #[inline] - pub(super) fn try_acquire(self: &Arc) -> Option>> { + pub(super) fn try_acquire(self: &Arc) -> Option>> { if self.is_closed() { return None; } - self.idle.try_acquire(self) - } - - pub(super) fn release(&self, floating: Floating>) { - // `options.after_release` and other checks are in `PoolConnection::return_to_pool()`. - self.idle.release(floating); + self.sharded.try_acquire_connected() } pub(super) async fn acquire(self: &Arc) -> Result, Error> { @@ -131,91 +138,70 @@ impl PoolInner { let mut close_event = pin!(self.close_event()); let mut deadline = pin!(rt::sleep(self.options.acquire_timeout)); - let mut acquire_idle = pin!(self.idle.acquire(self).fuse()); - let mut before_acquire = OptionFuture::from(None); - let mut acquire_connect_permit = pin!(OptionFuture::from(Some( - self.counter.acquire_permit(self).fuse() - ))); - let mut connect = OptionFuture::from(None); - - // The internal state machine of `acquire()`. - // - // * The initial state is racing to acquire either an idle connection or a new `ConnectPermit`. - // * If we acquire a `ConnectPermit`, we begin the connection loop (with backoff) - // as implemented by `DynConnector`. - // * If we acquire an idle connection, we then start polling `check_idle_conn()`. - // - // This doesn't quite fit into `select!{}` because the set of futures that may be polled - // at a given time is dynamic, so it's actually simpler to hand-roll it. - let acquired = future::poll_fn(|cx| { - use std::task::Poll::*; - - // First check if the pool is already closed, - // or register for a wakeup if it gets closed. - if let Ready(()) = close_event.poll_unpin(cx) { - return Ready(Err(Error::PoolClosed)); - } - // Then check if our deadline has elapsed, or schedule a wakeup for when that happens. 
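(The `Arc::new_cyclic` construction above exists so the `reconnect` callback can hold a `Weak` reference to the pool it lives inside. The same pattern in miniature, with illustrative types that are not from this patch:)

```rust
use std::sync::{Arc, Weak};

struct Pool {
    reconnect: Box<dyn Fn() + Send + Sync>,
}

fn new_pool() -> Arc<Pool> {
    Arc::new_cyclic(|weak: &Weak<Pool>| {
        let weak = weak.clone();
        Pool {
            reconnect: Box::new(move || {
                // `upgrade()` fails once the pool is dropped, so the callback
                // can neither keep the pool alive nor fire after its death.
                if let Some(_pool) = weak.upgrade() {
                    // ... schedule a reconnect against `_pool` ...
                }
            }),
        }
    })
}

fn main() {
    let pool = new_pool();
    (pool.reconnect)(); // upgrades successfully while `pool` is alive
}
```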
- if let Ready(()) = deadline.poll_unpin(cx) { - return Ready(Err(Error::PoolTimedOut)); + let connect_shared = ConnectTaskShared::new_arc(); + + let mut acquire_connected = pin!(self.acquire_connected().fuse()); + + let mut acquire_disconnected = pin!(self.sharded.acquire_disconnected().fuse()); + + let mut connect = future::Fuse::terminated(); + + let acquired = std::future::poll_fn(|cx| loop { + if let Poll::Ready(()) = close_event.as_mut().poll(cx) { + return Poll::Ready(Err(Error::PoolClosed)); } - // Attempt to acquire a connection from the idle queue. - if let Ready(idle) = acquire_idle.poll_unpin(cx) { - // If we acquired an idle connection, run any checks that need to be done. - // - // Includes `test_on_acquire` and the `before_acquire` callback, if set. - match finish_acquire(idle) { - // There are checks needed to be done, so they're spawned as a task - // to be cancellation-safe. - Either::Left(check_task) => { - before_acquire = Some(check_task).into(); - } - // The connection is ready to go. - Either::Right(conn) => { - return Ready(Ok(conn)); - } - } + if let Poll::Ready(()) = deadline.as_mut().poll(cx) { + return Poll::Ready(Err(Error::PoolTimedOut { + last_connect_error: connect_shared.take_error().map(Box::new), + })); } - // Poll the task returned by `finish_acquire` - match ready!(before_acquire.poll_unpin(cx)) { - Some(Ok(conn)) => return Ready(Ok(conn)), - Some(Err((id, permit))) => { - // We don't strictly need to poll `connect` here; all we really want to do - // is to check if it is `None`. But since currently there's no getter for that, - // it doesn't really hurt to just poll it here. - match connect.poll_unpin(cx) { - Ready(None) => { - // If we're not already attempting to connect, - // take the permit returned from closing the connection and - // attempt to open a new one. - connect = Some(self.connector.connect(id, permit)).into(); - } - // `permit` is dropped in these branches, allowing another task to use it - Ready(Some(res)) => return Ready(res), - Pending => (), + if let Poll::Ready(res) = acquire_connected.as_mut().poll(cx) { + match res { + Ok(conn) => { + return Poll::Ready(Ok(conn)); } + Err(slot) => { + if connect.is_terminated() { + connect = self + .connector + .connect( + Pool(self.clone()), + ConnectionId::next(), + slot, + connect_shared.clone(), + ) + .fuse(); + } - // Attempt to acquire another idle connection concurrently to opening a new one. - acquire_idle.set(self.idle.acquire(self).fuse()); - // Annoyingly, `OptionFuture` doesn't fuse to `None` on its own - before_acquire = None.into(); + // Try to acquire another connected connection. + acquire_connected.set(self.acquire_connected().fuse()); + continue; + } } - None => (), } - if let Ready(Some((id, permit))) = acquire_connect_permit.poll_unpin(cx) { - connect = Some(self.connector.connect(id, permit)).into(); + if let Poll::Ready(slot) = acquire_disconnected.as_mut().poll(cx) { + if connect.is_terminated() { + connect = self + .connector + .connect( + Pool(self.clone()), + ConnectionId::next(), + slot, + connect_shared.clone(), + ) + .fuse(); + } } - if let Ready(Some(res)) = connect.poll_unpin(cx) { - // RFC: suppress errors here? - return Ready(res); + if let Poll::Ready(res) = Pin::new(&mut connect).poll(cx) { + return Poll::Ready(res); } - Pending + return Poll::Pending; }) .await?; @@ -245,59 +231,66 @@ impl PoolInner { Ok(acquired) } - /// Try to maintain `min_connections`, returning any errors (including `PoolTimedOut`). 
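(The replacement state machine above is the standard `poll_fn` pattern: poll each sub-future in a fixed priority order and return the first `Ready` result. Reduced to its essence, with nothing pool-specific:)

```rust
use std::future::{poll_fn, Future};
use std::pin::pin;
use std::task::Poll;

/// Resolve to whichever future finishes first; `a` wins ties because it is polled first.
async fn first_of(a: impl Future<Output = u32>, b: impl Future<Output = u32>) -> u32 {
    let mut a = pin!(a);
    let mut b = pin!(b);

    poll_fn(move |cx| {
        if let Poll::Ready(v) = a.as_mut().poll(cx) {
            return Poll::Ready(v);
        }
        if let Poll::Ready(v) = b.as_mut().poll(cx) {
            return Poll::Ready(v);
        }
        Poll::Pending
    })
    .await
}
```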
- pub async fn try_min_connections(self: &Arc, deadline: Instant) -> Result<(), Error> { - rt::timeout_at(deadline, async { - while self.size() < self.options.min_connections { - // Don't wait for a connect permit. - // - // If no extra permits are available then we shouldn't be trying to spin up - // connections anyway. - let Some((id, permit)) = self.counter.try_acquire_permit(self) else { - return Ok(()); - }; - - let conn = self.connector.connect(id, permit).await?; + async fn acquire_connected( + self: &Arc, + ) -> Result, DisconnectedSlot>> { + let connected = self.sharded.acquire_connected().await; - // We skip `after_release` since the connection was never provided to user code - // besides inside `PollConnector::connect()`, if they override it. - self.release(conn.into_floating()); - } + tracing::debug!( + target: "sqlx::pool", + connection_id=%connected.id, + "acquired idle connection" + ); - Ok(()) - }) - .await - .unwrap_or_else(|_| Err(Error::PoolTimedOut)) + match finish_acquire(self, connected) { + Either::Left(task) => task.await, + Either::Right(conn) => Ok(conn), + } } - /// Attempt to maintain `min_connections`, logging if unable. - pub async fn min_connections_maintenance(self: &Arc, deadline: Option) { - let deadline = deadline.unwrap_or_else(|| { - // Arbitrary default deadline if the caller doesn't care. - Instant::now() + Duration::from_secs(300) - }); - - match self.try_min_connections(deadline).await { - Ok(()) => (), - Err(Error::PoolClosed) => (), - Err(Error::PoolTimedOut) => { - tracing::debug!("unable to complete `min_connections` maintenance before deadline") + pub(crate) async fn try_min_connections(self: &Arc) -> Result<(), Error> { + stream::iter( + self.sharded + .iter_min_connections() + .map(Result::<_, Error>::Ok), + ) + .try_for_each_concurrent(None, |slot| async move { + let shared = ConnectTaskShared::new_arc(); + + let res = self + .connector + .connect( + Pool(self.clone()), + ConnectionId::next(), + slot, + shared.clone(), + ) + .await; + + match res { + Ok(conn) => { + drop(conn); + Ok(()) + } + Err(Error::PoolTimedOut { .. }) => Err(Error::PoolTimedOut { + last_connect_error: shared.take_error().map(Box::new), + }), + Err(other) => Err(other), } - Err(error) => tracing::debug!(%error, "error while maintaining min_connections"), - } + }) + .await } } impl Drop for PoolInner { fn drop(&mut self) { self.mark_closed(); - self.idle.drain(self); } } /// Returns `true` if the connection has exceeded `options.max_lifetime` if set, `false` otherwise. pub(super) fn is_beyond_max_lifetime( - live: &Live, + live: &ConnectionInner, options: &PoolOptions, ) -> bool { options @@ -306,60 +299,69 @@ pub(super) fn is_beyond_max_lifetime( } /// Returns `true` if the connection has exceeded `options.idle_timeout` if set, `false` otherwise. -fn is_beyond_idle_timeout(idle: &Idle, options: &PoolOptions) -> bool { +fn is_beyond_idle_timeout( + idle: &ConnectionInner, + options: &PoolOptions, +) -> bool { options .idle_timeout - .is_some_and(|timeout| idle.idle_since.elapsed() > timeout) + .is_some_and(|timeout| idle.last_released_at.elapsed() > timeout) } /// Execute `test_before_acquire` and/or `before_acquire` in a background task, if applicable. /// /// Otherwise, immediately returns the connection. 
fn finish_acquire( - mut conn: Floating>, + pool: &Arc>, + mut conn: ConnectedSlot>, ) -> Either< - JoinHandle, (ConnectionId, ConnectPermit)>>, + JoinHandle, DisconnectedSlot>>>, PoolConnection, > { - let pool = conn.permit.pool(); - if pool.options.test_before_acquire || pool.options.before_acquire.is_some() { + let pool = pool.clone(); + // Spawn a task so the call may complete even if `acquire()` is cancelled. return Either::Left(rt::spawn(async move { // Check that the connection is still live - if let Err(error) = conn.ping().await { + if let Err(error) = conn.raw.ping().await { // an error here means the other end has hung up or we lost connectivity // either way we're fine to just discard the connection // the error itself here isn't necessarily unexpected so WARN is too strong - tracing::info!(%error, "ping on idle connection returned error"); + tracing::info!(%error, connection_id=%conn.id, "ping on idle connection returned error"); + // connection is broken so don't try to close nicely - return Err(conn.close_hard().await); + let (_res, slot) = connection::close_hard(conn).await; + return Err(slot); } - if let Some(test) = &conn.permit.pool().options.before_acquire { - let meta = conn.metadata(); - match test(&mut conn.inner.live.raw, meta).await { + if let Some(test) = &pool.options.before_acquire { + let meta = conn.idle_metadata(); + match test(&mut conn.raw, meta).await { Ok(false) => { // connection was rejected by user-defined hook, close nicely - return Err(conn.close().await); + let (_res, slot) = connection::close(conn).await; + return Err(slot); } Err(error) => { tracing::warn!(%error, "error from `before_acquire`"); + // connection is broken so don't try to close nicely - return Err(conn.close_hard().await); + let (_res, slot) = connection::close_hard(conn).await; + return Err(slot); } Ok(true) => {} } } - Ok(conn.into_live().reattach()) + Ok(PoolConnection::new(conn, pool)) })); } // No checks are configured, return immediately. - Either::Right(conn.into_live().reattach()) + Either::Right(PoolConnection::new(conn, pool.clone())) } fn spawn_maintenance_tasks(pool: &Arc>) { @@ -376,7 +378,13 @@ fn spawn_maintenance_tasks(pool: &Arc>) { if pool.options.min_connections > 0 { rt::spawn(async move { if let Some(pool) = pool_weak.upgrade() { - pool.min_connections_maintenance(None).await; + if let Err(error) = pool.try_min_connections().await { + tracing::error!( + target: "sqlx::pool", + ?error, + "error maintaining min_connections" + ); + } } }); } @@ -401,31 +409,21 @@ fn spawn_maintenance_tasks(pool: &Arc>) { // Go over all idle connections, check for idleness and lifetime, // and if we have fewer than min_connections after reaping a connection, - // open a new one immediately. Note that other connections may be popped from - // the queue in the meantime - that's fine, there is no harm in checking more - for _ in 0..pool.num_idle() { - if let Some(conn) = pool.try_acquire() { - if is_beyond_idle_timeout(&conn, &pool.options) - || is_beyond_max_lifetime(&conn, &pool.options) - { - let _ = conn.close().await; - pool.min_connections_maintenance(Some(next_run)).await; - } else { - pool.release(conn.into_live()); - } + // open a new one immediately. + for conn in pool.sharded.iter_idle() { + if is_beyond_idle_timeout(&conn, &pool.options) + || is_beyond_max_lifetime(&conn, &pool.options) + { + // Dropping the slot will check if the connection needs to be + // re-made. + let _ = connection::close(conn).await; } } // Don't hold a reference to the pool while sleeping. 
drop(pool); - if let Some(duration) = next_run.checked_duration_since(Instant::now()) { - // `async-std` doesn't have a `sleep_until()` - rt::sleep(duration).await; - } else { - // `next_run` is in the past, just yield. - rt::yield_now().await; - } + rt::sleep_until(next_run).await; } }) .await; diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index 84776d0e22..0b8d94521e 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -87,7 +87,7 @@ mod connect; mod connection; mod inner; -mod idle; +// mod idle; mod options; mod shard; @@ -369,7 +369,7 @@ impl Pool { /// Returns `None` immediately if there are no idle connections available in the pool /// or there are tasks waiting for a connection which have yet to wake. pub fn try_acquire(&self) -> Option> { - self.0.try_acquire().map(|conn| conn.into_live().reattach()) + self.0.try_acquire().map(|conn| PoolConnection::new(conn, self.0.clone())) } /// Retrieves a connection and immediately begins a new transaction. diff --git a/sqlx-core/src/pool/options.rs b/sqlx-core/src/pool/options.rs index 0e8e05b4cb..e346956137 100644 --- a/sqlx-core/src/pool/options.rs +++ b/sqlx-core/src/pool/options.rs @@ -1,10 +1,11 @@ use crate::connection::Connection; use crate::database::Database; use crate::error::Error; -use crate::pool::connect::DefaultConnector; +use crate::pool::connect::{ConnectTaskShared, ConnectionId, DefaultConnector}; use crate::pool::inner::PoolInner; use crate::pool::{Pool, PoolConnector}; use futures_core::future::BoxFuture; +use futures_util::{stream, TryStreamExt}; use log::LevelFilter; use std::fmt::{self, Debug, Formatter}; use std::num::NonZero; @@ -74,7 +75,7 @@ pub struct PoolOptions { pub(crate) acquire_slow_level: LevelFilter, pub(crate) acquire_slow_threshold: Duration, pub(crate) acquire_timeout: Duration, - pub(crate) connect_timeout: Duration, + pub(crate) connect_timeout: Option, pub(crate) min_connections: usize, pub(crate) max_lifetime: Option, pub(crate) idle_timeout: Option, @@ -155,7 +156,7 @@ impl PoolOptions { // to not flag typical time to add a new connection to a pool. acquire_slow_threshold: Duration::from_secs(2), acquire_timeout: Duration::from_secs(30), - connect_timeout: Duration::from_secs(2 * 60), + connect_timeout: None, idle_timeout: Some(Duration::from_secs(10 * 60)), max_lifetime: Some(Duration::from_secs(30 * 60)), fair: true, @@ -323,15 +324,15 @@ impl PoolOptions { /// This timeout happens independently of [`acquire_timeout`][Self::acquire_timeout]. /// /// If shorter than `acquire_timeout`, this will cause the last connec - pub fn connect_timeout(mut self, timeout: Duration) -> Self { - self.connect_timeout = timeout; + pub fn connect_timeout(mut self, timeout: impl Into>) -> Self { + self.connect_timeout = timeout.into(); self } /// Get the maximum amount of time to spend attempting to open a connection. /// /// This timeout happens independently of [`acquire_timeout`][Self::acquire_timeout]. - pub fn get_connect_timeout(&self) -> Duration { + pub fn get_connect_timeout(&self) -> Option { self.connect_timeout } @@ -573,17 +574,6 @@ impl PoolOptions { let inner = PoolInner::new_arc(self, connector); - if inner.options.min_connections > 0 { - // If the idle reaper is spawned then this will race with the call from that task - // and may not report any connection errors. 
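(The widened `connect_timeout` setter below relies on the `impl Into<Option<_>>` idiom: existing `Duration` call sites keep compiling while `None` becomes expressible. The idiom in isolation, with a hypothetical `Options` type:)

```rust
use std::time::Duration;

#[derive(Default)]
struct Options {
    connect_timeout: Option<Duration>,
}

impl Options {
    fn connect_timeout(mut self, timeout: impl Into<Option<Duration>>) -> Self {
        self.connect_timeout = timeout.into();
        self
    }
}

fn main() {
    // Both call styles work because `T: Into<Option<T>>`.
    let a = Options::default().connect_timeout(Duration::from_secs(120));
    let b = Options::default().connect_timeout(None);

    assert_eq!(a.connect_timeout, Some(Duration::from_secs(120)));
    assert_eq!(b.connect_timeout, None);
}
```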
- inner.try_min_connections(deadline).await?; - } - - // If `min_connections` is nonzero then we'll likely just pull a connection - // from the idle queue here, but it should at least get tested first. - let conn = inner.acquire().await?; - inner.release(conn.into_floating()); - Ok(Pool(inner)) } @@ -642,7 +632,7 @@ fn default_shards() -> NonZero { #[cfg(feature = "_rt-async-std")] if let Some(val) = std::env::var("ASYNC_STD_THREAD_COUNT") .ok() - .and_then(|s| s.parse()) + .and_then(|s| s.parse().ok()) { return val; } diff --git a/sqlx-core/src/pool/shard.rs b/sqlx-core/src/pool/shard.rs index a0bcee223b..24750e0a32 100644 --- a/sqlx-core/src/pool/shard.rs +++ b/sqlx-core/src/pool/shard.rs @@ -1,15 +1,17 @@ -use event_listener::{Event, IntoNotification}; +use crate::rt; +use event_listener::{listener, Event, IntoNotification}; +use futures_util::{future, stream, StreamExt}; +use spin::lock_api::Mutex; use std::future::Future; use std::num::NonZero; +use std::ops::{Deref, DerefMut}; use std::pin::pin; use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; -use std::task::Poll; +use std::sync::{atomic, Arc}; +use std::task::{ready, Poll}; use std::time::Duration; use std::{array, iter}; -use spin::lock_api::Mutex; - type ShardId = usize; type ConnectionIndex = usize; @@ -17,7 +19,7 @@ type ConnectionIndex = usize; /// /// We want tasks to acquire from their local shards where possible, so they don't enter /// the global queue immediately. -const GLOBAL_QUEUE_DELAY: Duration = Duration::from_millis(10); +const GLOBAL_ACQUIRE_DELAY: Duration = Duration::from_millis(10); /// Delay before attempting to acquire from a non-local shard, /// as well as the backoff when iterating through shards. @@ -30,20 +32,27 @@ pub struct Sharded { type ArcShard = Arc>>]>>; -struct Global { - unlock_event: Event>, - disconnect_event: Event>, +struct Global) + Send + Sync + 'static> { + unlock_event: Event>, + disconnect_event: Event>, + min_connections: usize, + num_shards: usize, + do_reconnect: F, } type ArcMutexGuard = lock_api::ArcMutexGuard, Option>; -pub struct LockGuard { +struct SlotGuard { // `Option` allows us to take the guard in the drop handler. locked: Option>, shard: ArcShard, index: ConnectionIndex, } +pub struct ConnectedSlot(SlotGuard); + +pub struct DisconnectedSlot(SlotGuard); + // Align to cache lines. // Simplified from https://docs.rs/crossbeam-utils/0.8.21/src/crossbeam_utils/cache_padded.rs.html#80 // @@ -54,12 +63,15 @@ pub struct LockGuard { #[cfg_attr(not(target_pointer_width = "64"), repr(align(64)))] struct Shard { shard_id: ShardId, - /// Bitset for all connection indexes that are currently in-use. + /// Bitset for all connection indices that are currently in-use. locked_set: AtomicUsize, - /// Bitset for all connection indexes that are currently connected. + /// Bitset for all connection indices that are currently connected. connected_set: AtomicUsize, - unlock_event: Event>, - disconnect_event: Event>, + /// Bitset for all connection indices that have been explicitly leaked. 
+ leaked_set: AtomicUsize, + unlock_event: Event>, + disconnect_event: Event>, + leak_event: Event, global: Arc>, connections: Ts, } @@ -78,13 +90,23 @@ const MAX_SHARD_SIZE: usize = if usize::BITS > 64 { }; impl Sharded { - pub fn new(connections: usize, shards: NonZero) -> Sharded { + pub fn new( + connections: usize, + shards: NonZero, + min_connections: usize, + do_reconnect: impl Fn(DisconnectedSlot) + Send + Sync + 'static, + ) -> Sharded { + let params = Params::calc(connections, shards.get()); + let global = Arc::new(Global { unlock_event: Event::with_tag(), disconnect_event: Event::with_tag(), + num_shards: params.shards, + min_connections, + do_reconnect, }); - let shards = Params::calc(connections, shards.get()) + let shards = params .shard_sizes() .enumerate() .map(|(shard_id, size)| Shard::new(shard_id, size, global.clone())) @@ -93,7 +115,60 @@ impl Sharded { Sharded { shards, global } } - pub async fn acquire(&self, connected: bool) -> LockGuard { + #[inline] + pub fn num_shards(&self) -> usize { + self.shards.len() + } + + #[allow(clippy::cast_possible_truncation)] // This is only informational + pub fn count_connected(&self) -> usize { + atomic::fence(Ordering::Acquire); + + self.shards + .iter() + .map(|shard| shard.connected_set.load(Ordering::Relaxed).count_ones() as usize) + .sum() + } + + #[allow(clippy::cast_possible_truncation)] // This is only informational + pub fn count_unlocked(&self, connected: bool) -> usize { + self.shards + .iter() + .map(|shard| shard.unlocked_mask(connected).count_ones() as usize) + .sum() + } + + pub async fn acquire_connected(&self) -> ConnectedSlot { + let guard = self.acquire(true).await; + + assert!( + guard.get().is_some(), + "BUG: expected slot {}/{} to be connected but it wasn't", + guard.shard.shard_id, + guard.index + ); + + ConnectedSlot(guard) + } + + pub fn try_acquire_connected(&self) -> Option> { + todo!() + } + + pub async fn acquire_disconnected(&self) -> DisconnectedSlot { + let guard = self.acquire(true).await; + + assert!( + guard.get().is_some(), + "BUG: expected slot {}/{} NOT to be connected but it WAS", + guard.shard.shard_id, + guard.index + ); + + DisconnectedSlot(guard) + } + + async fn acquire(&self, connected: bool) -> SlotGuard { if self.shards.len() == 1 { return self.shards[0].acquire(connected).await; } @@ -106,7 +181,7 @@ impl Sharded { let mut next_shard = thread_id; loop { - crate::rt::sleep(NON_LOCAL_ACQUIRE_DELAY).await; + rt::sleep(NON_LOCAL_ACQUIRE_DELAY).await; // Choose shards pseudorandomly by multiplying with a (relatively) large prime. 
next_shard = (next_shard.wrapping_mul(547)) % self.shards.len(); @@ -118,7 +193,7 @@ impl Sharded { }); let mut acquire_global = pin!(async { - crate::rt::sleep(GLOBAL_QUEUE_DELAY).await; + rt::sleep(GLOBAL_ACQUIRE_DELAY).await; let event_to_listen = if connected { &self.global.unlock_event @@ -150,6 +225,36 @@ impl Sharded { }) .await } + + pub fn iter_min_connections(&self) -> impl Iterator> + '_ { + self.shards + .iter() + .flat_map(|shard| shard.iter_min_connections()) + } + + pub fn iter_idle(&self) -> impl Iterator> + '_ { + self.shards.iter().flat_map(|shard| shard.iter_idle()) + } + + pub async fn drain(&self, close: F) + where + F: Fn(ConnectedSlot) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, + T: Send + 'static, + { + let close = Arc::new(close); + + stream::iter(self.shards.iter()) + .for_each_concurrent(None, |shard| { + let shard = shard.clone(); + let close = close.clone(); + + rt::spawn(async move { + shard.drain(&*close).await; + }) + }) + .await; + } } impl Shard>>]> { @@ -163,9 +268,11 @@ impl Shard>>]> { $($n => Arc::new(Shard { shard_id, locked_set: AtomicUsize::new(0), - unlock_event: Event::with_tag(), connected_set: AtomicUsize::new(0), + leaked_set: AtomicUsize::new(0), + unlock_event: Event::with_tag(), disconnect_event: Event::with_tag(), + leak_event: Event::with_tag(), global, connections: array::from_fn::<_, $n, _>(|_| Arc::new(Mutex::new(None))) }),)* @@ -181,7 +288,27 @@ impl Shard>>]> { ) } - async fn acquire(self: &Arc, connected: bool) -> LockGuard { + #[inline] + fn unlocked_mask(&self, connected: bool) -> Mask { + let locked_set = self.locked_set.load(Ordering::Acquire); + let connected_set = self.connected_set.load(Ordering::Relaxed); + + let connected_mask = if connected { + connected_set + } else { + !connected_set + }; + + Mask(!locked_set & connected_mask) + } + + /// Choose the first index that is unlocked with bit `connected` + #[inline] + fn next_unlocked(&self, connected: bool) -> Option { + self.unlocked_mask(connected).next() + } + + async fn acquire(self: &Arc, connected: bool) -> SlotGuard { // Attempt an unfair acquire first, before we modify the waitlist. if let Some(locked) = self.try_acquire(connected) { return locked; @@ -205,91 +332,178 @@ impl Shard>>]> { listener.await } - fn try_acquire(self: &Arc, connected: bool) -> Option> { - let locked_set = self.locked_set.load(Ordering::Acquire); - let connected_set = self.connected_set.load(Ordering::Relaxed); - - let connected_mask = if connected { - connected_set - } else { - !connected_set - }; - - // Choose the first index that is unlocked with bit `connected` - let index = (!locked_set & connected_mask).leading_zeros() as usize; - - self.try_lock(index) + fn try_acquire(self: &Arc, connected: bool) -> Option> { + self.try_lock(self.next_unlocked(connected)?) } - fn try_lock(self: &Arc, index: ConnectionIndex) -> Option> { - let locked = self.connections[index].try_lock_arc()?; + fn try_lock(self: &Arc, index: ConnectionIndex) -> Option> { + let locked = self.connections.get(index)?.try_lock_arc()?; // The locking of the connection itself must use an `Acquire` fence, // so additional synchronization is unnecessary. 
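(The shard's accounting is two plain `usize` bitsets; acquiring means finding a set bit in `!locked & connected`, or its complement, and locking that index. The same arithmetic in a standalone sketch that does not use the patch's types:)

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

fn set_bit(set: &AtomicUsize, index: usize, value: bool) {
    if value {
        set.fetch_or(1 << index, Ordering::AcqRel);
    } else {
        set.fetch_and(!(1 << index), Ordering::AcqRel);
    }
}

/// Iterate the indices of set bits, least significant first.
fn iter_bits(mut mask: usize) -> impl Iterator<Item = usize> {
    std::iter::from_fn(move || {
        if mask == 0 {
            return None;
        }
        let index = mask.trailing_zeros() as usize;
        mask &= !(1 << index); // clear the bit so iteration advances
        Some(index)
    })
}

fn main() {
    let locked = AtomicUsize::new(0);
    set_bit(&locked, 0, true);
    set_bit(&locked, 3, true);

    // Four slots total: slots 1 and 2 are unlocked.
    let unlocked = !locked.load(Ordering::Acquire) & 0b1111;
    assert_eq!(iter_bits(unlocked).collect::<Vec<_>>(), vec![1, 2]);
}
```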
atomic_set(&self.locked_set, index, true, Ordering::Relaxed); - Some(LockGuard { + Some(SlotGuard { locked: Some(locked), shard: self.clone(), index, }) } + + fn iter_min_connections(self: &Arc) -> impl Iterator> + '_ { + (0..self.connections.len()) + .filter_map(|index| { + let slot = self.try_lock(index)?; + + // Guard against some weird bug causing this to already be connected + slot.get().is_none().then_some(DisconnectedSlot(slot)) + }) + .take(self.global.shard_min_connections(self.shard_id)) + } + + fn iter_idle(self: &Arc) -> impl Iterator> + '_ { + self.unlocked_mask(true).filter_map(|index| { + let slot = self.try_lock(index)?; + + // Guard against some weird bug causing this to already be connected + slot.get().is_some().then_some(ConnectedSlot(slot)) + }) + } + + async fn drain(self: &Arc, close: F) + where + F: Fn(ConnectedSlot) -> Fut, + Fut: Future>, + { + let mut drain_connected = pin!(async { + loop { + let connected = self.acquire(true).await; + DisconnectedSlot::leak(close(ConnectedSlot(connected)).await); + } + }); + + let mut drain_disconnected = pin!(async { + loop { + let disconnected = DisconnectedSlot(self.acquire(false).await); + DisconnectedSlot::leak(disconnected); + } + }); + + let mut drain_leaked = pin!(async { + loop { + listener!(self.leak_event => leaked); + leaked.await; + } + }); + + let finished_mask = (1usize << self.connections.len()) - 1; + + std::future::poll_fn(|cx| { + // The connection set is drained once all slots are leaked. + if self.leaked_set.load(Ordering::Acquire) == finished_mask { + return Poll::Ready(()); + } + + // These futures shouldn't return `Ready` + let _ = drain_connected.as_mut().poll(cx); + let _ = drain_disconnected.as_mut().poll(cx); + let _ = drain_leaked.as_mut().poll(cx); + + Poll::Pending + }) + .await; + } } -impl Params { - fn calc(connections: usize, mut shards: usize) -> Params { - assert_ne!(shards, 0); +impl Deref for ConnectedSlot { + type Target = T; - let mut shard_size = connections / shards; - let mut remainder = connections % shards; + fn deref(&self) -> &Self::Target { + self.0 + .get() + .as_ref() + .expect("BUG: expected slot to be populated, but it wasn't") + } +} - if shard_size == 0 { - tracing::debug!(connections, shards, "more shards than connections; clamping shard size to 1, shard count to connections"); - shards = connections; - shard_size = 1; - remainder = 0; - } else if shard_size >= MAX_SHARD_SIZE { - let new_shards = connections.div_ceil(MAX_SHARD_SIZE); +impl DerefMut for ConnectedSlot { + fn deref_mut(&mut self) -> &mut Self::Target { + self.0 + .get_mut() + .as_mut() + .expect("BUG: expected slot to be populated, but it wasn't") + } +} - tracing::debug!( - connections, - shards, - "shard size exceeds {MAX_SHARD_SIZE}, clamping shard count to {new_shards}" - ); +impl ConnectedSlot { + pub fn take(mut this: Self) -> (T, DisconnectedSlot) { + let conn = this + .0 + .get_mut() + .take() + .expect("BUG: expected slot to be populated, but it wasn't"); - shards = new_shards; - shard_size = connections / shards; - remainder = connections % shards; - } + (conn, DisconnectedSlot(this.0)) + } +} - Params { - shards, - shard_size, - remainder, - } +impl DisconnectedSlot { + pub fn put(mut self, connection: T) -> ConnectedSlot { + *self.0.get_mut() = Some(connection); + ConnectedSlot(self.0) } - fn shard_sizes(&self) -> impl Iterator { - iter::repeat_n(self.shard_size + 1, self.remainder).chain(iter::repeat_n( - self.shard_size, - self.shards - self.remainder, - )) + pub fn leak(mut self: Self) { + 
self.0.locked = None; + + atomic_set( + &self.0.shard.connected_set, + self.0.index, + false, + Ordering::Relaxed, + ); + atomic_set( + &self.0.shard.leaked_set, + self.0.index, + true, + Ordering::Release, + ); + + self.0.shard.leak_event.notify(usize::MAX.tag(self.0.index)); + } + + pub fn should_reconnect(&self) -> bool { + self.0.should_reconnect() } } -fn current_thread_id() -> usize { - // FIXME: this can be replaced when this is stabilized: - // https://doc.rust-lang.org/stable/std/thread/struct.ThreadId.html#method.as_u64 - static THREAD_ID: AtomicUsize = AtomicUsize::new(0); +impl SlotGuard { + fn get(&self) -> &Option { + self.locked + .as_deref() + .expect("BUG: `SlotGuard.locked` taken") + } - thread_local! { - static CURRENT_THREAD_ID: usize = THREAD_ID.fetch_add(1, Ordering::SeqCst); + fn get_mut(&mut self) -> &mut Option { + self.locked + .as_deref_mut() + .expect("BUG: `SlotGuard.locked` taken") } - CURRENT_THREAD_ID.with(|i| *i) + fn should_reconnect(&self) -> bool { + let min_connections = self.shard.global.shard_min_connections(self.shard.shard_id); + + let num_connected = self + .shard + .connected_set + .load(Ordering::Acquire) + .count_ones() as usize; + + num_connected < min_connections + } } -impl Drop for LockGuard { +impl Drop for SlotGuard { fn drop(&mut self) { let Some(locked) = self.locked.take() else { return; @@ -305,6 +519,7 @@ impl Drop for LockGuard { Ordering::Relaxed, ); + // We don't actually unlock the connection unless there's no receivers to accept it. // If another receiver is waiting for a connection, we can directly pass them the lock. // // This prevents drive-by tasks from acquiring connections before waiting tasks @@ -314,7 +529,6 @@ impl Drop for LockGuard { // the situation when a receiver was passed a connection that was still marked as locked, // but was cancelled before it could complete the acquisition. Otherwise, the connection // would be marked as locked forever, effectively being leaked. 
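(The direct handoff described above is built on `event-listener`'s tagged notifications: the notifier attaches a value that the woken listener receives directly. The primitive in isolation, assuming `event-listener` 5.x as already used by this patch; the pool's real tag is a relocked slot guard, not an integer:)

```rust
use event_listener::{Event, IntoNotification};

fn main() {
    let event: Event<u32> = Event::with_tag();

    // Register a listener *before* notifying, or the notification is lost.
    let listener = event.listen();

    // Deliver `42` directly to one waiting listener; returns how many woke.
    assert_eq!(event.notify(1.tag(42u32)), 1);

    assert_eq!(listener.wait(), 42);
}
```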
- let mut locked = Some(locked); // This is a code smell, but it's necessary because `event-listener` has no way to specify @@ -328,7 +542,7 @@ impl Drop for LockGuard { .take() .expect("BUG: notification sent more than once"); - LockGuard { + SlotGuard { locked: Some(locked), shard: self.shard.clone(), index: self.index, @@ -369,6 +583,12 @@ impl Drop for LockGuard { { return; } + + // If this connection is required to satisfy `min_connections` + if self.should_reconnect() { + (self.shard.global.do_reconnect)(DisconnectedSlot(self_as_tag())); + return; + } } // Be sure to drop the lock guard if it's still held, @@ -382,16 +602,110 @@ impl Drop for LockGuard { } } +impl Global { + fn shard_min_connections(&self, shard_id: ShardId) -> usize { + let min_connections_per_shard = self.min_connections / self.num_shards; + + if (self.min_connections % self.num_shards) < shard_id { + min_connections_per_shard + 1 + } else { + min_connections_per_shard + } + } +} + +impl Params { + fn calc(connections: usize, mut shards: usize) -> Params { + assert_ne!(shards, 0); + + let mut shard_size = connections / shards; + let mut remainder = connections % shards; + + if shard_size == 0 { + tracing::debug!(connections, shards, "more shards than connections; clamping shard size to 1, shard count to connections"); + shards = connections; + shard_size = 1; + remainder = 0; + } else if shard_size >= MAX_SHARD_SIZE { + let new_shards = connections.div_ceil(MAX_SHARD_SIZE); + + tracing::debug!( + connections, + shards, + "shard size exceeds {MAX_SHARD_SIZE}, clamping shard count to {new_shards}" + ); + + shards = new_shards; + shard_size = connections / shards; + remainder = connections % shards; + } + + Params { + shards, + shard_size, + remainder, + } + } + + fn shard_sizes(&self) -> impl Iterator { + iter::repeat_n(self.shard_size + 1, self.remainder).chain(iter::repeat_n( + self.shard_size, + self.shards - self.remainder, + )) + } +} + fn atomic_set(atomic: &AtomicUsize, index: usize, value: bool, ordering: Ordering) { if value { - let bit = 1 >> index; + let bit = 1 << index; atomic.fetch_or(bit, ordering); } else { - let bit = !(1 >> index); + let bit = !(1 << index); atomic.fetch_and(bit, ordering); } } +fn current_thread_id() -> usize { + // FIXME: this can be replaced when this is stabilized: + // https://doc.rust-lang.org/stable/std/thread/struct.ThreadId.html#method.as_u64 + static THREAD_ID: AtomicUsize = AtomicUsize::new(0); + + thread_local! { + static CURRENT_THREAD_ID: usize = THREAD_ID.fetch_add(1, Ordering::SeqCst); + } + + CURRENT_THREAD_ID.with(|i| *i) +} + +#[derive(Clone, Debug, PartialEq, Eq)] +struct Mask(usize); + +impl Mask { + pub fn count_ones(&self) -> usize { + self.0.count_ones() as usize + } +} + +impl Iterator for Mask { + type Item = usize; + + fn next(&mut self) -> Option { + if self.0 == 0 { + return None; + } + + let index = self.0.trailing_zeros() as usize; + self.0 &= 1 << index; + + Some(index) + } + + fn size_hint(&self) -> (usize, Option) { + let count = self.0.count_ones() as usize; + (count, Some(count)) + } +} + #[cfg(test)] mod tests { use super::{Params, MAX_SHARD_SIZE}; diff --git a/sqlx-core/src/rt/mod.rs b/sqlx-core/src/rt/mod.rs index 0044139f55..985d9bb607 100644 --- a/sqlx-core/src/rt/mod.rs +++ b/sqlx-core/src/rt/mod.rs @@ -56,18 +56,18 @@ pub async fn timeout_at(deadline: Instant, f: F) -> Result(fut: F) -> JoinHandle where @@ -186,7 +201,7 @@ pub fn test_block_on(f: F) -> F::Output { #[track_caller] pub const fn missing_rt(_unused: T) -> ! 
{ if cfg!(feature = "_rt-tokio") { - panic!("this functionality requires a Tokio context") + panic!("this functionality requires an active Tokio runtime") } panic!("one of the `runtime` features of SQLx must be enabled") diff --git a/sqlx-core/src/rt/rt_async_io/mod.rs b/sqlx-core/src/rt/rt_async_io/mod.rs index 5e4d7074dc..70d01fbecb 100644 --- a/sqlx-core/src/rt/rt_async_io/mod.rs +++ b/sqlx-core/src/rt/rt_async_io/mod.rs @@ -1,4 +1,4 @@ mod socket; -mod timeout; -pub use timeout::*; +mod time; +pub use time::*; diff --git a/sqlx-core/src/rt/rt_async_io/timeout.rs b/sqlx-core/src/rt/rt_async_io/time.rs similarity index 54% rename from sqlx-core/src/rt/rt_async_io/timeout.rs rename to sqlx-core/src/rt/rt_async_io/time.rs index b4a779074b..039610b758 100644 --- a/sqlx-core/src/rt/rt_async_io/timeout.rs +++ b/sqlx-core/src/rt/rt_async_io/time.rs @@ -1,20 +1,24 @@ -use std::{future::Future, pin::pin, time::Duration}; +use std::{ + future::Future, + pin::pin, + time::{Duration, Instant}, +}; use futures_util::future::{select, Either}; use crate::rt::TimeoutError; pub async fn sleep(duration: Duration) { - timeout_future(duration).await; + async_io::Timer::after(duration).await; +} + +pub async fn sleep_until(deadline: Instant) { + async_io::Timer::at(deadline).await; } pub async fn timeout(duration: Duration, future: F) -> Result { - match select(pin!(future), timeout_future(duration)).await { + match select(pin!(future), pin!(sleep(duration))).await { Either::Left((result, _)) => Ok(result), Either::Right(_) => Err(TimeoutError), } } - -fn timeout_future(duration: Duration) -> impl Future { - async_io::Timer::after(duration) -} diff --git a/sqlx-core/src/sync.rs b/sqlx-core/src/sync.rs index 971752f88f..2fd51445b3 100644 --- a/sqlx-core/src/sync.rs +++ b/sqlx-core/src/sync.rs @@ -4,8 +4,51 @@ // We'll generally lean towards Tokio's types as those are more featureful // (including `tokio-console` support) and more widely deployed. 
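(The `timeout.rs` to `time.rs` rename above comes with a true deadline-based sleep; on the `async-io` runtime it is a thin wrapper over `Timer::at`. For reference, the underlying primitive, made self-contained here with `async_io::block_on`:)

```rust
use std::time::{Duration, Instant};

fn main() {
    async_io::block_on(async {
        let deadline = Instant::now() + Duration::from_millis(50);

        // `Timer::at` fires at an `Instant`, unlike `Timer::after(Duration)`.
        async_io::Timer::at(deadline).await;

        assert!(Instant::now() >= deadline);
    });
}
```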
-#[cfg(all(feature = "_rt-async-std", not(feature = "_rt-tokio")))] -pub use async_std::sync::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard}; - #[cfg(feature = "_rt-tokio")] -pub use tokio::sync::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard}; +pub use tokio::sync::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard, RwLock as AsyncRwLock}; + +#[cfg(all(feature = "_rt-async-lock", not(feature = "_rt-tokio")))] +pub use async_lock::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard, RwLock as AsyncRwLock}; + +#[cfg(not(any(feature = "_rt-async-lock", feature = "_rt-tokio")))] +pub use noop::*; + +#[cfg(not(any(feature = "_rt-async-lock", feature = "_rt-tokio")))] +mod noop { + use crate::rt::missing_rt; + use std::marker::PhantomData; + use std::ops::{Deref, DerefMut}; + + pub struct AsyncMutex { + // `Sync` if `T: Send` + _marker: PhantomData>, + } + + pub struct AsyncMutexGuard<'a, T> { + inner: &'a AsyncMutex, + } + + impl AsyncMutex { + pub fn new(val: T) -> Self { + missing_rt(val) + } + + pub fn lock(&self) -> AsyncMutexGuard { + missing_rt(self) + } + } + + impl Deref for AsyncMutexGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + missing_rt(self) + } + } + + impl DerefMut for AsyncMutexGuard<'_, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + missing_rt(self) + } + } +} From dd9cb718de48edbe0f08cc5320689ff04d69173c Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Wed, 29 Oct 2025 06:54:15 -0700 Subject: [PATCH 21/24] fix: bugs in sharded pool --- Cargo.toml | 4 +-- sqlx-core/src/pool/inner.rs | 61 ++++++++++++++++++++--------------- sqlx-core/src/pool/mod.rs | 4 ++- sqlx-core/src/pool/options.rs | 4 +++ sqlx-core/src/pool/shard.rs | 51 ++++++++++++++++++++++------- sqlx-test/Cargo.toml | 1 + sqlx-test/src/lib.rs | 6 +++- tests/any/pool.rs | 33 +++++++++++++++++-- 8 files changed, 121 insertions(+), 43 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c3cb86c8af..a92c67d8c3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -224,7 +224,6 @@ sqlx-sqlite = { workspace = true, optional = true } anyhow = "1.0.52" time_ = { version = "0.3.2", package = "time" } futures-util = { version = "0.3.19", default-features = false, features = ["alloc"] } -env_logger = "0.11" async-std = { workspace = true, features = ["attributes"] } tokio = { version = "1.15.0", features = ["full"] } dotenvy = "0.15.0" @@ -241,7 +240,8 @@ tempfile = "3.10.1" criterion = { version = "0.7.0", features = ["async_tokio"] } libsqlite3-sys = { version = "0.30.1" } -tracing = { version = "0.1.44", features = ["attributes"] } +tracing = "0.1.41" +tracing-subscriber = "0.3.20" # If this is an unconditional dev-dependency then Cargo will *always* try to build `libsqlite3-sys`, # even when SQLite isn't the intended test target, and fail if the build environment is not set up for compiling C code. diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index af9229d47d..eb6e827e8f 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -26,6 +26,8 @@ use futures_util::{stream, FutureExt, TryStreamExt}; use std::time::{Duration, Instant}; use tracing::Level; +const GRACEFUL_CLOSE_TIMEOUT: Duration = Duration::from_secs(5); + pub(crate) struct PoolInner { pub(super) connector: DynConnector, pub(super) counter: ConnectionCounter, @@ -47,6 +49,8 @@ impl PoolInner { let reconnect = move |slot| { let Some(pool) = pool_weak.upgrade() else { + // Prevent an infinite loop on pool drop. 
+ DisconnectedSlot::leak(slot); return; }; @@ -104,7 +108,7 @@ impl PoolInner { self.sharded.drain(|slot| async move { let (conn, slot) = ConnectedSlot::take(slot); - let _ = conn.raw.close().await; + let _ = rt::timeout(GRACEFUL_CLOSE_TIMEOUT, conn.raw.close()).await; slot }) @@ -248,37 +252,42 @@ impl PoolInner { } } - pub(crate) async fn try_min_connections(self: &Arc) -> Result<(), Error> { - stream::iter( - self.sharded - .iter_min_connections() - .map(Result::<_, Error>::Ok), - ) - .try_for_each_concurrent(None, |slot| async move { - let shared = ConnectTaskShared::new_arc(); - - let res = self - .connector - .connect( + pub(crate) async fn try_min_connections( + self: &Arc, + deadline: Option, + ) -> Result<(), Error> { + let shared = ConnectTaskShared::new_arc(); + + let connect_min_connections = + future::try_join_all(self.sharded.iter_min_connections().map(|slot| { + self.connector.connect( Pool(self.clone()), ConnectionId::next(), slot, shared.clone(), ) - .await; - - match res { - Ok(conn) => { - drop(conn); - Ok(()) + })); + + let mut conns = if let Some(deadline) = deadline { + match rt::timeout_at(deadline, connect_min_connections).await { + Ok(Ok(conns)) => conns, + Err(_) | Ok(Err(Error::PoolTimedOut { .. })) => { + return Err(Error::PoolTimedOut { + last_connect_error: shared.take_error().map(Box::new), + }); } - Err(Error::PoolTimedOut { .. }) => Err(Error::PoolTimedOut { - last_connect_error: shared.take_error().map(Box::new), - }), - Err(other) => Err(other), + Ok(Err(e)) => return Err(e), } - }) - .await + } else { + connect_min_connections.await? + }; + + for mut conn in conns { + // Bypass `after_release` + drop(conn.return_to_pool()); + } + + Ok(()) } } @@ -378,7 +387,7 @@ fn spawn_maintenance_tasks(pool: &Arc>) { if pool.options.min_connections > 0 { rt::spawn(async move { if let Some(pool) = pool_weak.upgrade() { - if let Err(error) = pool.try_min_connections().await { + if let Err(error) = pool.try_min_connections(None).await { tracing::error!( target: "sqlx::pool", ?error, diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index 0b8d94521e..7d2b18ed4c 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -369,7 +369,9 @@ impl Pool { /// Returns `None` immediately if there are no idle connections available in the pool /// or there are tasks waiting for a connection which have yet to wake. pub fn try_acquire(&self) -> Option> { - self.0.try_acquire().map(|conn| PoolConnection::new(conn, self.0.clone())) + self.0 + .try_acquire() + .map(|conn| PoolConnection::new(conn, self.0.clone())) } /// Retrieves a connection and immediately begins a new transaction. 
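(`try_min_connections` above now drives all of its connection attempts concurrently under one shared deadline. The same shape with plain Tokio primitives; the patch's `rt::timeout_at` abstracts over runtimes, and the connect logic here is a stand-in:)

```rust
use futures_util::future;
use std::time::{Duration, Instant};

async fn connect(i: u32) -> Result<u32, String> {
    // Stand-in for a real connection attempt.
    Ok(i)
}

#[tokio::main]
async fn main() -> Result<(), String> {
    let deadline = Instant::now() + Duration::from_secs(30);

    // One deadline bounds the whole batch; the first error cancels the rest.
    let conns = tokio::time::timeout_at(
        deadline.into(),
        future::try_join_all((0..3).map(connect)),
    )
    .await
    .map_err(|_| "timed out".to_string())??;

    assert_eq!(conns, vec![0, 1, 2]);
    Ok(())
}
```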
diff --git a/sqlx-core/src/pool/options.rs b/sqlx-core/src/pool/options.rs index e346956137..975583e6f7 100644 --- a/sqlx-core/src/pool/options.rs +++ b/sqlx-core/src/pool/options.rs @@ -574,6 +574,10 @@ impl PoolOptions { let inner = PoolInner::new_arc(self, connector); + if inner.options.min_connections > 0 { + inner.try_min_connections(Some(deadline)).await?; + } + Ok(Pool(inner)) } diff --git a/sqlx-core/src/pool/shard.rs b/sqlx-core/src/pool/shard.rs index 24750e0a32..2385c9982a 100644 --- a/sqlx-core/src/pool/shard.rs +++ b/sqlx-core/src/pool/shard.rs @@ -47,6 +47,7 @@ struct SlotGuard { locked: Option>, shard: ArcShard, index: ConnectionIndex, + dropped: bool, } pub struct ConnectedSlot(SlotGuard); @@ -134,7 +135,7 @@ impl Sharded { pub fn count_unlocked(&self, connected: bool) -> usize { self.shards .iter() - .map(|shard| shard.unlocked_mask(connected).count_ones() as usize) + .map(|shard| shard.unlocked_mask(connected).count_ones()) .sum() } @@ -156,10 +157,10 @@ impl Sharded { } pub async fn acquire_disconnected(&self) -> DisconnectedSlot { - let guard = self.acquire(true).await; + let guard = self.acquire(false).await; assert!( - guard.get().is_some(), + guard.get().is_none(), "BUG: expected slot {}/{} NOT to be connected but it WAS", guard.shard.shard_id, guard.index @@ -347,6 +348,7 @@ impl Shard>>]> { locked: Some(locked), shard: self.clone(), index, + dropped: false, }) } @@ -370,6 +372,13 @@ impl Shard>>]> { }) } + fn all_leaked(&self) -> bool { + let all_leaked_mask = (1usize << self.connections.len()) - 1; + let leaked_set = self.leaked_set.load(Ordering::Acquire); + + leaked_set == all_leaked_mask + } + async fn drain(self: &Arc, close: F) where F: Fn(ConnectedSlot) -> Fut, @@ -396,11 +405,9 @@ impl Shard>>]> { } }); - let finished_mask = (1usize << self.connections.len()) - 1; - std::future::poll_fn(|cx| { // The connection set is drained once all slots are leaked. - if self.leaked_set.load(Ordering::Acquire) == finished_mask { + if self.all_leaked() { return Poll::Ready(()); } @@ -409,7 +416,12 @@ impl Shard>>]> { let _ = drain_disconnected.as_mut().poll(cx); let _ = drain_leaked.as_mut().poll(cx); - Poll::Pending + // Check again after driving the `drain` futures forward. + if self.all_leaked() { + Poll::Ready(()) + } else { + Poll::Pending + } }) .await; } @@ -443,6 +455,13 @@ impl ConnectedSlot { .take() .expect("BUG: expected slot to be populated, but it wasn't"); + atomic_set( + &this.0.shard.connected_set, + this.0.index, + false, + Ordering::AcqRel, + ); + (conn, DisconnectedSlot(this.0)) } } @@ -450,6 +469,14 @@ impl ConnectedSlot { impl DisconnectedSlot { pub fn put(mut self, connection: T) -> ConnectedSlot { *self.0.get_mut() = Some(connection); + + atomic_set( + &self.0.shard.connected_set, + self.0.index, + true, + Ordering::AcqRel, + ); + ConnectedSlot(self.0) } @@ -546,10 +573,13 @@ impl Drop for SlotGuard { locked: Some(locked), shard: self.shard.clone(), index: self.index, + // To avoid infinite recursion or deadlock, don't send another notification + // if this guard was already dropped once: just unlock it. + dropped: true, } }; - if connected { + if !self.dropped && connected { // Check for global waiters first. 
if self .shard @@ -564,7 +594,7 @@ impl Drop for SlotGuard { if self.shard.unlock_event.notify(1.tag_with(&mut self_as_tag)) > 0 { return; } - } else { + } else if !self.dropped { if self .shard .global @@ -584,7 +614,6 @@ impl Drop for SlotGuard { return; } - // If this connection is required to satisfy `min_connections` if self.should_reconnect() { (self.shard.global.do_reconnect)(DisconnectedSlot(self_as_tag())); return; @@ -695,7 +724,7 @@ impl Iterator for Mask { } let index = self.0.trailing_zeros() as usize; - self.0 &= 1 << index; + self.0 &= !(1 << index); Some(index) } diff --git a/sqlx-test/Cargo.toml b/sqlx-test/Cargo.toml index 32a341adcb..4fdcb3723e 100644 --- a/sqlx-test/Cargo.toml +++ b/sqlx-test/Cargo.toml @@ -10,6 +10,7 @@ sqlx = { default-features = false, path = ".." } env_logger = "0.11" dotenvy = "0.15.0" anyhow = "1.0.26" +tracing-subscriber = { version = "0.3.20", features = ["env-filter"] } [lints] workspace = true diff --git a/sqlx-test/src/lib.rs b/sqlx-test/src/lib.rs index 3744724c12..01cdc2977b 100644 --- a/sqlx-test/src/lib.rs +++ b/sqlx-test/src/lib.rs @@ -1,10 +1,14 @@ use sqlx::pool::PoolOptions; use sqlx::{Connection, Database, Error, Pool}; use std::env; +use tracing_subscriber::EnvFilter; pub fn setup_if_needed() { let _ = dotenvy::dotenv(); - let _ = env_logger::builder().is_test(true).try_init(); + let _ = tracing_subscriber::fmt::Subscriber::builder() + .with_env_filter(EnvFilter::from_default_env()) + .with_test_writer() + .finish(); } // Make a new connection diff --git a/tests/any/pool.rs b/tests/any/pool.rs index 1cc0838053..d5d47d161f 100644 --- a/tests/any/pool.rs +++ b/tests/any/pool.rs @@ -1,6 +1,6 @@ use sqlx::any::{AnyConnectOptions, AnyPoolOptions}; use sqlx::Executor; -use sqlx_core::connection::ConnectOptions; +use sqlx_core::connection::{ConnectOptions, Connection}; use sqlx_core::pool::PoolConnectMetadata; use sqlx_core::sql_str::AssertSqlSafe; use std::sync::{ @@ -9,6 +9,29 @@ use std::sync::{ }; use std::time::Duration; +#[sqlx_macros::test] +async fn pool_basic_functions() -> anyhow::Result<()> { + sqlx::any::install_default_drivers(); + + let pool = AnyPoolOptions::new() + .max_connections(2) + .acquire_timeout(Duration::from_secs(3)) + .connect(&dotenvy::var("DATABASE_URL")?) 
+ .await?; + + let mut conn = pool.acquire().await?; + + conn.ping().await?; + + drop(conn); + + let b: bool = sqlx::query_scalar("SELECT true").fetch_one(&pool).await?; + + assert!(b); + + Ok(()) +} + // https://github.com/launchbadge/sqlx/issues/527 #[sqlx_macros::test] async fn pool_should_be_returned_failed_transactions() -> anyhow::Result<()> { @@ -43,6 +66,7 @@ async fn pool_should_be_returned_failed_transactions() -> anyhow::Result<()> { #[sqlx_macros::test] async fn test_pool_callbacks() -> anyhow::Result<()> { sqlx::any::install_default_drivers(); + tracing_subscriber::fmt::init(); #[derive(sqlx::FromRow, Debug, PartialEq, Eq)] struct ConnStats { @@ -131,7 +155,9 @@ async fn test_pool_callbacks() -> anyhow::Result<()> { id ); - conn.execute(&statement[..]).await?; + sqlx::raw_sql(AssertSqlSafe(statement)) + .execute(&mut conn) + .await?; Ok(conn) } }); @@ -154,6 +180,8 @@ async fn test_pool_callbacks() -> anyhow::Result<()> { ]; for (id, before_acquire_calls, after_release_calls) in pattern { + eprintln!("ID: {id}, before_acquire calls: {before_acquire_calls}, after_release calls: {after_release_calls}"); + let conn_stats: ConnStats = sqlx::query_as("SELECT * FROM conn_stats") .fetch_one(&pool) .await?; @@ -183,6 +211,7 @@ async fn test_connection_maintenance() -> anyhow::Result<()> { let last_meta = Arc::new(Mutex::new(None)); let last_meta_ = last_meta.clone(); let pool = AnyPoolOptions::new() + .acquire_timeout(Duration::from_secs(1)) .max_lifetime(Duration::from_millis(400)) .min_connections(3) .before_acquire(move |_conn, _meta| { From 44e40b28168f5ef04373ce887c1da01744f09794 Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Wed, 29 Oct 2025 06:54:21 -0700 Subject: [PATCH 22/24] fixup! benchmark --- benches/any/pool.rs | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/benches/any/pool.rs b/benches/any/pool.rs index a689058055..423b2ce02b 100644 --- a/benches/any/pool.rs +++ b/benches/any/pool.rs @@ -9,7 +9,7 @@ use tracing::Instrument; struct Input { threads: usize, tasks: usize, - pool_size: u32, + pool_size: usize, } impl Display for Input { @@ -24,6 +24,7 @@ impl Display for Input { fn bench_pool(c: &mut Criterion) { sqlx::any::install_default_drivers(); + tracing_subscriber::fmt::try_init().ok(); let database_url = dotenvy::var("DATABASE_URL").expect("DATABASE_URL must be set"); @@ -72,6 +73,14 @@ fn bench_pool(c: &mut Criterion) { } fn bench_pool_with(b: &mut Bencher, input: &Input, database_url: &str) { + let _span = tracing::info_span!( + "bench_pool_with", + threads = input.threads, + tasks = input.tasks, + pool_size = input.pool_size + ) + .entered(); + let runtime = tokio::runtime::Builder::new_multi_thread() .enable_all() .worker_threads(input.threads) @@ -88,10 +97,13 @@ fn bench_pool_with(b: &mut Bencher, input: &Input, database_url: &str) { .expect("error connecting to pool") }); - for _ in 1..=input.tasks { + for num in 1..=input.tasks { let pool = pool.clone(); - runtime.spawn(async move { while pool.acquire().await.is_ok() {} }); + runtime.spawn( + async move { while pool.acquire().await.is_ok() {} } + .instrument(tracing::info_span!("task", num)), + ); } // Spawn the benchmark loop into the runtime so we're not accidentally including the main thread From d905016923aeca4f371e97a3205c45482cf8a02f Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Wed, 29 Oct 2025 11:37:43 -0700 Subject: [PATCH 23/24] fix: debug timeouts in benchmark --- Cargo.toml | 3 ++ sqlx-core/src/pool/inner.rs | 12 +++---- 
sqlx-core/src/pool/shard.rs | 69 +++++++++++++++++++++++++++---------- 3 files changed, 60 insertions(+), 24 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a92c67d8c3..3c3db27d5d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -467,3 +467,6 @@ name = "any-pool" path = "benches/any/pool.rs" required-features = ["runtime-tokio", "any"] harness = false + +[profile.bench] +debug = true diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index eb6e827e8f..046834a4c2 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -156,12 +156,6 @@ impl PoolInner { return Poll::Ready(Err(Error::PoolClosed)); } - if let Poll::Ready(()) = deadline.as_mut().poll(cx) { - return Poll::Ready(Err(Error::PoolTimedOut { - last_connect_error: connect_shared.take_error().map(Box::new), - })); - } - if let Poll::Ready(res) = acquire_connected.as_mut().poll(cx) { match res { Ok(conn) => { @@ -205,6 +199,12 @@ impl PoolInner { return Poll::Ready(res); } + if let Poll::Ready(()) = deadline.as_mut().poll(cx) { + return Poll::Ready(Err(Error::PoolTimedOut { + last_connect_error: connect_shared.take_error().map(Box::new), + })); + } + return Poll::Pending; }) .await?; diff --git a/sqlx-core/src/pool/shard.rs b/sqlx-core/src/pool/shard.rs index 2385c9982a..c1964c7c6a 100644 --- a/sqlx-core/src/pool/shard.rs +++ b/sqlx-core/src/pool/shard.rs @@ -303,12 +303,6 @@ impl Shard>>]> { Mask(!locked_set & connected_mask) } - /// Choose the first index that is unlocked with bit `connected` - #[inline] - fn next_unlocked(&self, connected: bool) -> Option { - self.unlocked_mask(connected).next() - } - async fn acquire(self: &Arc, connected: bool) -> SlotGuard { // Attempt an unfair acquire first, before we modify the waitlist. if let Some(locked) = self.try_acquire(connected) { @@ -323,18 +317,34 @@ impl Shard>>]> { event_listener::listener!(event_to_listen => listener); - // We need to check again after creating the event listener, - // because in the meantime, a concurrent task may have seen that there were no listeners - // and just unlocked its connection. - if let Some(locked) = self.try_acquire(connected) { - return locked; + let mut listener = pin!(listener); + + loop { + // We need to check again after creating the event listener, + // because in the meantime, a concurrent task may have seen that there were no listeners + // and just unlocked its connection. + match rt::timeout(NON_LOCAL_ACQUIRE_DELAY, listener.as_mut()).await { + Ok(slot) => return slot, + Err(_) => { + if let Some(slot) = self.try_acquire(connected) { + return slot; + } + } + } } - - listener.await } fn try_acquire(self: &Arc, connected: bool) -> Option> { - self.try_lock(self.next_unlocked(connected)?) + // If `locked_set` is constantly changing, don't loop forever. + for index in self.unlocked_mask(connected) { + if let Some(slot) = self.try_lock(index) { + return Some(slot); + } + + std::hint::spin_loop(); + } + + None } fn try_lock(self: &Arc, index: ConnectionIndex) -> Option> { @@ -353,7 +363,7 @@ impl Shard>>]> { } fn iter_min_connections(self: &Arc) -> impl Iterator> + '_ { - (0..self.connections.len()) + self.unlocked_mask(false) .filter_map(|index| { let slot = self.try_lock(index)?; @@ -493,7 +503,7 @@ impl DisconnectedSlot { &self.0.shard.leaked_set, self.0.index, true, - Ordering::Release, + Ordering::AcqRel, ); self.0.shard.leak_event.notify(usize::MAX.tag(self.0.index)); @@ -627,7 +637,7 @@ impl Drop for SlotGuard { // but then fail to lock the mutex for it. 
drop(locked); - atomic_set(&self.shard.locked_set, self.index, false, Ordering::Release); + atomic_set(&self.shard.locked_set, self.index, false, Ordering::AcqRel); } } @@ -737,7 +747,7 @@ impl Iterator for Mask { #[cfg(test)] mod tests { - use super::{Params, MAX_SHARD_SIZE}; + use super::{Mask, Params, MAX_SHARD_SIZE}; #[test] fn test_params() { @@ -762,4 +772,27 @@ } } } + + #[test] + fn test_mask() { + let inputs: &[(usize, &[usize])] = &[ + (0b0, &[]), + (0b1, &[0]), + (0b11, &[0, 1]), + (0b111, &[0, 1, 2]), + (0b1000, &[3]), + (0b1001, &[0, 3]), + (0b1001001, &[0, 3, 6]), + ]; + + for (mask, expected_indices) in inputs { + let actual_indices = Mask(*mask).collect::<Vec<_>>(); + + assert_eq!( + actual_indices[..], + expected_indices[..], + "invalid mask: {mask:b}" + ); + } + } } From 0dd92b4594e961fbd7b9536f78bc94871af97d7c Mon Sep 17 00:00:00 2001 From: Austin Bonander Date: Sun, 30 Nov 2025 16:43:48 -0800 Subject: [PATCH 24/24] WIP refactor: replace sharding with single connection set --- sqlx-core/src/ext/future.rs | 38 ++ sqlx-core/src/ext/mod.rs | 2 + sqlx-core/src/pool/connect.rs | 22 +- sqlx-core/src/pool/connection.rs | 113 ++++-- sqlx-core/src/pool/connection_set.rs | 543 +++++++++++++++++++++++++++ sqlx-core/src/pool/inner.rs | 490 ++++++++++++------------ sqlx-core/src/pool/mod.rs | 57 +-- sqlx-core/src/rt/mod.rs | 175 ++++++++- sqlx-core/src/rt/rt_async_io/time.rs | 23 +- sqlx-core/src/rt/rt_tokio/mod.rs | 1 + sqlx-core/src/sync.rs | 52 ++- sqlx-test/src/lib.rs | 6 +- tests/postgres/postgres.rs | 7 +- 13 files changed, 1183 insertions(+), 346 deletions(-) create mode 100644 sqlx-core/src/ext/future.rs create mode 100644 sqlx-core/src/pool/connection_set.rs diff --git a/sqlx-core/src/ext/future.rs b/sqlx-core/src/ext/future.rs new file mode 100644 index 0000000000..138f800118 --- /dev/null +++ b/sqlx-core/src/ext/future.rs @@ -0,0 +1,38 @@ +use pin_project_lite::pin_project; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project!
{ + #[project = RaceProject] + pub struct Race<L, R> { + #[pin] + left: L, + #[pin] + right: R, + } +} + +impl<L, R> Future for Race<L, R> +where + L: Future, + R: Future, +{ + type Output = Result<L::Output, R::Output>; + + #[inline(always)] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let mut this = self.project(); + + if let Poll::Ready(left) = this.left.as_mut().poll(cx) { + return Poll::Ready(Ok(left)); + } + + this.right.as_mut().poll(cx).map(Err) + } +} + +#[inline(always)] +pub fn race<L: Future, R: Future>(left: L, right: R) -> Race<L, R> { + Race { left, right } +} diff --git a/sqlx-core/src/ext/mod.rs b/sqlx-core/src/ext/mod.rs index 98059f8ca0..167c68560a 100644 --- a/sqlx-core/src/ext/mod.rs +++ b/sqlx-core/src/ext/mod.rs @@ -2,3 +2,5 @@ pub mod ustr; #[macro_use] pub mod async_stream; + +pub mod future; diff --git a/sqlx-core/src/pool/connect.rs b/sqlx-core/src/pool/connect.rs index 52920c6a67..5adf82bf9a 100644 --- a/sqlx-core/src/pool/connect.rs +++ b/sqlx-core/src/pool/connect.rs @@ -14,7 +14,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::{Arc, Mutex, RwLock}; use std::time::Instant; -use crate::pool::shard::DisconnectedSlot; +use crate::pool::connection_set::DisconnectedSlot; #[cfg(doc)] use crate::pool::PoolOptions; use crate::sync::{AsyncMutex, AsyncMutexGuard}; @@ -646,7 +646,7 @@ async fn connect_with_backoff( match res { ControlFlow::Break(Ok(conn)) => { - tracing::trace!( + tracing::debug!( target: "sqlx::pool::connect", %connection_id, attempt, @@ -654,18 +654,16 @@ "connection established", ); - return Ok(PoolConnection::new( - slot.put(ConnectionInner { - raw: conn, - id: connection_id, - created_at: now, - last_released_at: now, - }), - pool.0.clone(), - )); + return Ok(PoolConnection::new(slot.put(ConnectionInner { + pool: Arc::downgrade(&pool.0), + raw: conn, + id: connection_id, + created_at: now, + last_released_at: now, + }))); } ControlFlow::Break(Err(e)) => { - tracing::warn!( + tracing::error!( target: "sqlx::pool::connect", %connection_id, attempt, diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs index 8d115818f4..1103374cca 100644 --- a/sqlx-core/src/pool/connection.rs +++ b/sqlx-core/src/pool/connection.rs @@ -1,33 +1,35 @@ use std::fmt::{self, Debug, Formatter}; use std::future::{self, Future}; +use std::io; use std::ops::{Deref, DerefMut}; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use std::time::{Duration, Instant}; use crate::connection::Connection; use crate::database::Database; use crate::error::Error; -use super::inner::{is_beyond_max_lifetime, PoolInner}; +use super::inner::PoolInner; use crate::pool::connect::{ConnectPermit, ConnectTaskShared, ConnectionId}; +use crate::pool::connection_set::{ConnectedSlot, DisconnectedSlot}; use crate::pool::options::PoolConnectionMetadata; -use crate::pool::shard::{ConnectedSlot, DisconnectedSlot}; -use crate::pool::Pool; +use crate::pool::{Pool, PoolOptions}; use crate::rt; const RETURN_TO_POOL_TIMEOUT: Duration = Duration::from_secs(5); -const CLOSE_ON_DROP_TIMEOUT: Duration = Duration::from_secs(5); +const CLOSE_TIMEOUT: Duration = Duration::from_secs(5); /// A connection managed by a [`Pool`][crate::pool::Pool]. /// /// Will be returned to the pool on-drop.
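///
/// # Example
///
/// A minimal usage sketch (assuming an already-configured pool in scope; `?`
/// propagates `sqlx::Error`):
///
/// ```rust,ignore
/// let mut conn = pool.acquire().await?;
/// conn.ping().await?;
/// drop(conn); // returned to the pool by a background task
/// ```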
pub struct PoolConnection { conn: Option>>, - pub(crate) pool: Arc>, close_on_drop: bool, } pub(super) struct ConnectionInner { + // Note: must be `Weak` to prevent a reference cycle + pub(crate) pool: Weak>, pub(super) raw: DB::Connection, pub(super) id: ConnectionId, pub(super) created_at: Instant, @@ -72,11 +74,10 @@ impl AsMut for PoolConnection { } impl PoolConnection { - pub(super) fn new(live: ConnectedSlot>, pool: Arc>) -> Self { + pub(super) fn new(live: ConnectedSlot>) -> Self { Self { conn: Some(live), close_on_drop: false, - pool, } } @@ -140,13 +141,16 @@ impl PoolConnection { #[doc(hidden)] pub fn return_to_pool(&mut self) -> impl Future + Send + 'static { let conn = self.conn.take(); - let pool = self.pool.clone(); async move { let Some(conn) = conn else { return; }; + let Some(pool) = Weak::upgrade(&conn.pool) else { + return; + }; + rt::timeout(RETURN_TO_POOL_TIMEOUT, return_to_pool(conn, &pool)) .await // Dropping of the `slot` will check if the connection must be re-established @@ -161,7 +165,7 @@ impl PoolConnection { async move { if let Some(conn) = conn { // Don't hold the connection forever if it hangs while trying to close - rt::timeout(CLOSE_ON_DROP_TIMEOUT, close(conn)).await.ok(); + rt::timeout(CLOSE_TIMEOUT, close(conn)).await.ok(); } } } @@ -195,7 +199,7 @@ impl Drop for PoolConnection { } // We still need to spawn a task to maintain `min_connections`. - if self.conn.is_some() || self.pool.options.min_connections > 0 { + if self.conn.is_some() { crate::rt::spawn(self.return_to_pool()); } } @@ -220,6 +224,48 @@ impl ConnectionInner { idle_for: now.saturating_duration_since(self.last_released_at), } } + + pub fn is_beyond_max_lifetime(&self, options: &PoolOptions) -> bool { + if let Some(max_lifetime) = options.max_lifetime { + let age = self.created_at.elapsed(); + + if age > max_lifetime { + tracing::info!( + target: "sqlx::pool", + connection_id=%self.id, + ?age, + "connection is beyond `max_lifetime`, closing" + ); + + return true; + } + } + + false + } + + pub fn is_beyond_idle_timeout(&self, options: &PoolOptions) -> bool { + if let Some(idle_timeout) = options.idle_timeout { + let now = Instant::now(); + + let age = now.duration_since(self.created_at); + let idle_duration = now.duration_since(self.last_released_at); + + if idle_duration > idle_timeout { + tracing::info!( + target: "sqlx::pool", + connection_id=%self.id, + ?age, + ?idle_duration, + "connection is beyond `idle_timeout`, closing" + ); + + return true; + } + } + + false + } } pub(crate) async fn close( @@ -231,14 +277,19 @@ pub(crate) async fn close( let (conn, slot) = ConnectedSlot::take(conn); - let res = conn.raw.close().await.inspect_err(|error| { - tracing::debug!( - target: "sqlx::pool", - %connection_id, - %error, - "error occurred while closing the pool connection" - ); - }); + let res = rt::timeout(CLOSE_TIMEOUT, conn.raw.close()) + .await + .unwrap_or_else(|_| { + Err(io::Error::new(io::ErrorKind::TimedOut, "timed out sending close packet").into()) + }) + .inspect_err(|error| { + tracing::debug!( + target: "sqlx::pool", + %connection_id, + %error, + "error occurred while closing the pool connection" + ); + }); (res, slot) } @@ -255,14 +306,19 @@ pub(crate) async fn close_hard( let (conn, slot) = ConnectedSlot::take(conn); - let res = conn.raw.close_hard().await.inspect_err(|error| { - tracing::debug!( - target: "sqlx::pool", - %connection_id, - %error, - "error occurred while closing the pool connection" - ); - }); + let res = rt::timeout(CLOSE_TIMEOUT, conn.raw.close_hard()) + 
.await + .unwrap_or_else(|_| { + Err(io::Error::new(io::ErrorKind::TimedOut, "timed out sending close packet").into()) + }) + .inspect_err(|error| { + tracing::debug!( + target: "sqlx::pool", + %connection_id, + %error, + "error occurred while closing the pool connection" + ); + }); (res, slot) } @@ -282,7 +338,7 @@ async fn return_to_pool( // If the connection is beyond max lifetime, close the connection and // immediately create a new connection - if is_beyond_max_lifetime(&conn, &pool.options) { + if conn.is_beyond_max_lifetime(&pool.options) { let (_res, slot) = close(conn).await; return Err(slot); } @@ -314,6 +370,7 @@ // to recover from cancellations if let Err(error) = conn.raw.ping().await { tracing::warn!( + target: "sqlx::pool", %error, "error occurred while testing the connection on-release", ); diff --git a/sqlx-core/src/pool/connection_set.rs b/sqlx-core/src/pool/connection_set.rs new file mode 100644 index 0000000000..8683f8a902 --- /dev/null +++ b/sqlx-core/src/pool/connection_set.rs @@ -0,0 +1,543 @@ +use crate::ext::future::race; +use crate::rt; +use crate::sync::{AsyncMutex, AsyncMutexGuardArc}; +use event_listener::{listener, Event, EventListener, IntoNotification}; +use futures_core::Stream; +use futures_util::stream::FuturesUnordered; +use futures_util::{FutureExt, StreamExt}; +use std::cmp; +use std::future::Future; +use std::ops::{Deref, DerefMut, RangeInclusive, RangeToInclusive}; +use std::pin::{pin, Pin}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; +use std::sync::Arc; +use std::task::Poll; +use std::time::Duration; + +pub struct ConnectionSet<C> { + global: Arc<Global>, + slots: Box<[Arc<Slot<C>>]>, +} + +pub struct ConnectedSlot<C>(SlotGuard<C>); + +pub struct DisconnectedSlot<C>(SlotGuard<C>); + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +enum AcquirePreference { + Connected, + Disconnected, + Either, +} + +struct Global { + unlock_event: Event<usize>, + disconnect_event: Event<usize>, + num_connected: AtomicUsize, + min_connections: usize, + min_connections_event: Event<()>, +} + +struct SlotGuard<C> { + slot: Arc<Slot<C>>, + // `Option` allows us to take the guard in the drop handler.
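+ // A sketch of the idiom as used by `drop_without_notify()` below: `Option::take()`
+ // yields the guard by value exactly once, even though `Drop` only gets `&mut self`:
+ //
+ // fn drop(&mut self) {
+ //     if let Some(guard) = self.locked.take() {
+ //         // `guard` is owned here, so it can be moved into cleanup logic
+ //         drop(guard);
+ //     }
+ // }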
+ locked: Option<AsyncMutexGuardArc<Option<C>>>, +} + +struct Slot<C> { + // By having each `Slot` hold its own reference to `Global`, we can avoid extra contended clones + // which would sap performance + global: Arc<Global>, + index: usize, + // I'd love to eliminate this redundant `Arc` but it's likely not possible without `unsafe` + connection: Arc<AsyncMutex<Option<C>>>, + unlock_event: Event, + disconnect_event: Event, + connected: AtomicBool, + locked: AtomicBool, + leaked: AtomicBool, +} + +impl<C> ConnectionSet<C> { + pub fn new(size: RangeInclusive<usize>) -> Self { + let global = Arc::new(Global { + unlock_event: Event::with_tag(), + disconnect_event: Event::with_tag(), + num_connected: AtomicUsize::new(0), + min_connections: *size.start(), + min_connections_event: Event::with_tag(), + }); + + ConnectionSet { + // `vec![<value>; size].into()` clones `<value>` instead of repeating it, + // which is *no bueno* when wrapping something in `Arc` + slots: (0..*size.end()) + .map(|index| { + Arc::new(Slot { + global: global.clone(), + index, + connection: Arc::new(AsyncMutex::new(None)), + unlock_event: Event::with_tag(), + disconnect_event: Event::with_tag(), + connected: AtomicBool::new(false), + locked: AtomicBool::new(false), + leaked: AtomicBool::new(false), + }) + }) + .collect(), + global, + } + } + + #[inline(always)] + pub fn num_connected(&self) -> usize { + self.global.num_connected() + } + + pub fn count_idle(&self) -> usize { + self.slots + .iter() + .filter(|slot| slot.is_connected() && !slot.is_locked()) + .count() + } + + pub async fn acquire_connected(&self) -> ConnectedSlot<C> { + self.acquire_inner(AcquirePreference::Connected) + .await + .assert_connected() + } + + pub async fn acquire_disconnected(&self) -> DisconnectedSlot<C> { + self.acquire_inner(AcquirePreference::Disconnected) + .await + .assert_disconnected() + } + + /// Attempt to acquire the connection associated with the current thread. + pub async fn acquire_any(&self) -> Result<ConnectedSlot<C>, DisconnectedSlot<C>> { + self.acquire_inner(AcquirePreference::Either) + .await + .try_connected() + } + + async fn acquire_inner(&self, pref: AcquirePreference) -> SlotGuard<C> { + /// Smallest time-step supported by [`tokio::time::sleep()`]. + /// + /// `async-io` doesn't document a minimum time-step, instead deferring to the platform. + const STEP_INTERVAL: Duration = Duration::from_millis(1); + + const SEARCH_LIMIT: usize = 5; + + let preferred_slot = current_thread_id() % self.slots.len(); + + tracing::trace!(preferred_slot, ?pref, "acquire_inner"); + + // Always try to lock the connection associated with our thread ID + let mut acquire_preferred = pin!(self.slots[preferred_slot].acquire(pref)); + + let mut step_interval = pin!(rt::interval_after(STEP_INTERVAL)); + + let mut intervals_elapsed = 0usize; + + let mut search_slots = FuturesUnordered::new(); + + let mut listen_global = pin!(self.global.listen(pref)); + + let mut search_slot = self.next_slot(preferred_slot); + + std::future::poll_fn(|cx| loop { + if let Poll::Ready(locked) = acquire_preferred.as_mut().poll(cx) { + return Poll::Ready(locked); + } + + // Don't push redundant futures for small sets.
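+ // For example, a set sized at 3 connections resolves to `cmp::min(5, 3) == 3`
+ // below, so at most 3 lock futures are in flight even though `SEARCH_LIMIT` is 5.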
+ let search_limit = cmp::min(SEARCH_LIMIT, self.slots.len()); + + if search_slots.len() < search_limit && step_interval.as_mut().poll_tick(cx).is_ready() + { + intervals_elapsed = intervals_elapsed.saturating_add(1); + + if search_slot != preferred_slot && self.slots[search_slot].matches_pref(pref) { + search_slots.push(self.slots[search_slot].lock()); + } + + search_slot = self.next_slot(search_slot); + } + + if let Poll::Ready(Some(locked)) = Pin::new(&mut search_slots).poll_next(cx) { + if locked.matches_pref(pref) { + return Poll::Ready(locked); + } + + continue; + } + + if intervals_elapsed > search_limit && search_slots.len() < search_limit { + if let Poll::Ready(slot) = listen_global.as_mut().poll(cx) { + if self.slots[slot].matches_pref(pref) { + search_slots.push(self.slots[slot].lock()); + } + + listen_global.as_mut().set(self.global.listen(pref)); + continue; + } + } + + return Poll::Pending; + }) + .await + } + + pub fn try_acquire_connected(&self) -> Option<ConnectedSlot<C>> { + Some( + self.try_acquire(AcquirePreference::Connected)? + .assert_connected(), + ) + } + + pub fn try_acquire_disconnected(&self) -> Option<DisconnectedSlot<C>> { + Some( + self.try_acquire(AcquirePreference::Disconnected)? + .assert_disconnected(), + ) + } + + fn try_acquire(&self, pref: AcquirePreference) -> Option<SlotGuard<C>> { + let mut search_slot = current_thread_id() % self.slots.len(); + + for _ in 0..self.slots.len() { + if let Some(locked) = self.slots[search_slot].try_acquire(pref) { + return Some(locked); + } + + search_slot = self.next_slot(search_slot); + } + + None + } + + pub fn min_connections_listener(&self) -> EventListener { + self.global.min_connections_event.listen() + } + + pub fn iter_idle(&self) -> impl Iterator<Item = ConnectedSlot<C>> + '_ { + self.slots.iter().filter_map(|slot| { + Some( + slot.try_acquire(AcquirePreference::Connected)? + .assert_connected(), + ) + }) + } + + pub async fn drain(&self, ref close: impl AsyncFn(ConnectedSlot<C>) -> DisconnectedSlot<C>) { + let mut closing = FuturesUnordered::new(); + + // We could try to be more efficient by only populating the `FuturesUnordered` for + // connected slots, but then we'd have to handle a disconnected slot becoming connected, + // which could happen concurrently. + // + // However, we don't *need* to be efficient when shutting down the pool. + for slot in &self.slots { + closing.push(async { + let locked = slot.lock().await; + + let slot = match locked.try_connected() { + Ok(connected) => close(connected).await, + Err(disconnected) => disconnected, + }; + + // The pool is shutting down; don't wake any tasks that might have been interested + slot.leak(); + }); + } + + while closing.next().await.is_some() {} + } + + #[inline(always)] + fn next_slot(&self, slot: usize) -> usize { + // By adding a number that is coprime to `slots.len()` before taking the modulo, + // we can visit each slot in a pseudo-random order, spreading the demand evenly.
+ // + // Interestingly, this pattern returns to the original slot after `slots.len()` iterations, + // because of congruence: https://en.wikipedia.org/wiki/Modular_arithmetic#Congruence + (slot + 547) % self.slots.len() + } +} + +impl AcquirePreference { + #[inline(always)] + fn wants_connected(&self, is_connected: bool) -> bool { + match (self, is_connected) { + (Self::Connected, true) => true, + (Self::Disconnected, false) => true, + (Self::Either, _) => true, + _ => false, + } + } +} + +impl<C> Slot<C> { + #[inline(always)] + fn matches_pref(&self, pref: AcquirePreference) -> bool { + !self.is_leaked() && pref.wants_connected(self.is_connected()) + } + + #[inline(always)] + fn is_connected(&self) -> bool { + self.connected.load(Ordering::Relaxed) + } + + #[inline(always)] + fn is_locked(&self) -> bool { + self.locked.load(Ordering::Relaxed) + } + + #[inline(always)] + fn is_leaked(&self) -> bool { + self.leaked.load(Ordering::Relaxed) + } + + #[inline(always)] + fn set_is_connected(&self, connected: bool) { + let was_connected = self.connected.swap(connected, Ordering::Acquire); + + match (connected, was_connected) { + (true, false) => { + // Became connected; ensure this is synchronized with `connected` + self.global.num_connected.fetch_add(1, Ordering::Release); + } + (false, true) => { + self.global.num_connected.fetch_sub(1, Ordering::Release); + } + _ => (), + } + } + + async fn acquire(self: &Arc<Self>, pref: AcquirePreference) -> SlotGuard<C> { + loop { + if self.matches_pref(pref) { + tracing::trace!(slot_index=%self.index, "waiting for lock"); + + let locked = self.lock().await; + + if locked.matches_pref(pref) { + return locked; + } + } + + match pref { + AcquirePreference::Connected => { + listener!(self.unlock_event => listener); + listener.await; + } + AcquirePreference::Disconnected => { + listener!(self.disconnect_event => listener); + listener.await + } + AcquirePreference::Either => { + listener!(self.unlock_event => unlock_listener); + listener!(self.disconnect_event => disconnect_listener); + race(unlock_listener, disconnect_listener).await.ok(); + } + } + } + } + + fn try_acquire(self: &Arc<Self>, pref: AcquirePreference) -> Option<SlotGuard<C>> { + if self.matches_pref(pref) { + let locked = self.try_lock()?; + + if locked.matches_pref(pref) { + return Some(locked); + } + } + + None + } + + async fn lock(self: &Arc<Self>) -> SlotGuard<C> { + let locked = crate::sync::lock_arc(&self.connection).await; + + self.locked.store(true, Ordering::Relaxed); + + SlotGuard { + slot: self.clone(), + locked: Some(locked), + } + } + + fn try_lock(self: &Arc<Self>) -> Option<SlotGuard<C>> { + let locked = crate::sync::try_lock_arc(&self.connection)?; + + self.locked.store(true, Ordering::Relaxed); + + Some(SlotGuard { + slot: self.clone(), + locked: Some(locked), + }) + } +} + +impl<C> SlotGuard<C> { + #[inline(always)] + fn get(&self) -> &Option<C> { + self.locked.as_ref().expect(EXPECT_LOCKED) + } + + #[inline(always)] + fn get_mut(&mut self) -> &mut Option<C> { + self.locked.as_mut().expect(EXPECT_LOCKED) + } + + #[inline(always)] + fn matches_pref(&self, pref: AcquirePreference) -> bool { + !self.slot.is_leaked() && pref.wants_connected(self.is_connected()) + } + + #[inline(always)] + fn is_connected(&self) -> bool { + self.get().is_some() + } + + fn try_connected(self) -> Result<ConnectedSlot<C>, DisconnectedSlot<C>> { + if self.is_connected() { + Ok(ConnectedSlot(self)) + } else { + Err(DisconnectedSlot(self)) + } + } + + fn assert_connected(self) -> ConnectedSlot<C> { + assert!(self.is_connected()); + ConnectedSlot(self) + } + + fn assert_disconnected(self) -> DisconnectedSlot<C> {
assert!(!self.is_connected()); + + DisconnectedSlot(self) + } + + /// Updates `Slot::connected` without notifying the `ConnectionSet`. + /// + /// Returns `Some(connected)` or `None` if this guard was already dropped. + fn drop_without_notify(&mut self) -> Option { + self.locked.take().map(|locked| { + let connected = locked.is_some(); + self.slot.set_is_connected(connected); + self.slot.locked.store(false, Ordering::Release); + connected + }) + } +} + +const EXPECT_LOCKED: &str = "BUG: `SlotGuard::locked` should not be `None` in normal operation"; +const EXPECT_CONNECTED: &str = "BUG: `ConnectedSlot` expects `Slot::connection` to be `Some`"; + +impl ConnectedSlot { + pub fn take(mut self) -> (C, DisconnectedSlot) { + let conn = self.0.get_mut().take().expect(EXPECT_CONNECTED); + (conn, self.0.assert_disconnected()) + } +} + +impl Deref for ConnectedSlot { + type Target = C; + + #[inline(always)] + fn deref(&self) -> &Self::Target { + self.0.get().as_ref().expect(EXPECT_CONNECTED) + } +} + +impl DerefMut for ConnectedSlot { + #[inline(always)] + fn deref_mut(&mut self) -> &mut Self::Target { + self.0.get_mut().as_mut().expect(EXPECT_CONNECTED) + } +} + +impl DisconnectedSlot { + pub fn put(mut self, conn: C) -> ConnectedSlot { + *self.0.get_mut() = Some(conn); + ConnectedSlot(self.0) + } + + pub fn leak(mut self) { + self.0.slot.leaked.store(true, Ordering::Release); + self.0.drop_without_notify(); + } +} + +impl Drop for SlotGuard { + fn drop(&mut self) { + let Some(connected) = self.drop_without_notify() else { + return; + }; + + let event = if connected { + &self.slot.global.unlock_event + } else { + &self.slot.global.disconnect_event + }; + + if event.notify(1.tag(self.slot.index).additional()) != 0 { + return; + } + + if connected { + self.slot.unlock_event.notify(1); + return; + } + + if self.slot.disconnect_event.notify(1) != 0 { + return; + } + + if self.slot.global.num_connected() < self.slot.global.min_connections { + self.slot.global.min_connections_event.notify(1); + } + } +} + +impl Global { + #[inline(always)] + fn num_connected(&self) -> usize { + self.num_connected.load(Ordering::Relaxed) + } + + async fn listen(&self, pref: AcquirePreference) -> usize { + match pref { + AcquirePreference::Either => race(self.listen_unlocked(), self.listen_disconnected()) + .await + .unwrap_or_else(|slot| slot), + AcquirePreference::Connected => self.listen_unlocked().await, + AcquirePreference::Disconnected => self.listen_disconnected().await, + } + } + + async fn listen_unlocked(&self) -> usize { + listener!(self.unlock_event => listener); + listener.await + } + + async fn listen_disconnected(&self) -> usize { + listener!(self.disconnect_event => listener); + listener.await + } +} + +fn current_thread_id() -> usize { + // FIXME: this can be replaced when this is stabilized: + // https://doc.rust-lang.org/stable/std/thread/struct.ThreadId.html#method.as_u64 + static THREAD_ID: AtomicUsize = AtomicUsize::new(0); + + thread_local! { + // `SeqCst` is possibly too strong since we don't need synchronization with + // any other variable. I'm not confident enough in my understanding of atomics to be certain, + // especially with regards to weakly ordered architectures. + // + // However, this is literally only done once on each thread, so it doesn't really matter. 
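+ //
+ // A sketch of how the ID is consumed (see `acquire_inner()` above): each thread
+ // derives a stable "home" slot via `current_thread_id() % self.slots.len()`, so
+ // uncontended acquires tend to reuse the same connection per thread.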
+ static CURRENT_THREAD_ID: usize = THREAD_ID.fetch_add(1, Ordering::SeqCst); + } + + CURRENT_THREAD_ID.with(|i| *i) +} diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index 046834a4c2..1ae687f1d1 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -5,33 +5,30 @@ use crate::pool::{connection, CloseEvent, Pool, PoolConnection, PoolConnector, P use std::cmp; use std::future::Future; +use std::ops::ControlFlow; use std::pin::{pin, Pin}; -use std::rc::Weak; use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; -use std::task::{ready, Poll}; +use std::sync::{Arc, Weak}; +use std::task::{Context, Poll}; use crate::connection::Connection; +use crate::ext::future::race; use crate::logger::private_level_filter_to_trace_level; -use crate::pool::connect::{ - ConnectPermit, ConnectTask, ConnectTaskShared, ConnectionCounter, ConnectionId, DynConnector, -}; -use crate::pool::shard::{ConnectedSlot, DisconnectedSlot, Sharded}; -use crate::rt::JoinHandle; +use crate::pool::connect::{ConnectTaskShared, ConnectionCounter, ConnectionId, DynConnector}; +use crate::pool::connection_set::{ConnectedSlot, ConnectionSet, DisconnectedSlot}; use crate::{private_tracing_dynamic_event, rt}; -use either::Either; -use futures_core::FusedFuture; -use futures_util::future::{self, OptionFuture}; -use futures_util::{stream, FutureExt, TryStreamExt}; +use event_listener::listener; +use futures_util::future::{self}; use std::time::{Duration, Instant}; use tracing::Level; const GRACEFUL_CLOSE_TIMEOUT: Duration = Duration::from_secs(5); +const TEST_BEFORE_ACQUIRE_TIMEOUT: Duration = Duration::from_secs(60); pub(crate) struct PoolInner { pub(super) connector: DynConnector, pub(super) counter: ConnectionCounter, - pub(super) sharded: Sharded>, + pub(super) connections: ConnectionSet>, is_closed: AtomicBool, pub(super) on_closed: event_listener::Event, pub(super) options: PoolOptions, @@ -44,39 +41,15 @@ impl PoolInner { options: PoolOptions, connector: impl PoolConnector, ) -> Arc { - let pool = Arc::::new_cyclic(|pool_weak| { - let pool_weak = pool_weak.clone(); - - let reconnect = move |slot| { - let Some(pool) = pool_weak.upgrade() else { - // Prevent an infinite loop on pool drop. 
- DisconnectedSlot::leak(slot); - return; - }; - - pool.connector.connect( - Pool(pool.clone()), - ConnectionId::next(), - slot, - ConnectTaskShared::new_arc(), - ); - }; - - Self { - connector: DynConnector::new(connector), - counter: ConnectionCounter::new(), - sharded: Sharded::new( - options.max_connections, - options.shards, - options.min_connections, - reconnect, - ), - is_closed: AtomicBool::new(false), - on_closed: event_listener::Event::new(), - acquire_time_level: private_level_filter_to_trace_level(options.acquire_time_level), - acquire_slow_level: private_level_filter_to_trace_level(options.acquire_slow_level), - options, - } + let pool = Arc::new(Self { + connector: DynConnector::new(connector), + counter: ConnectionCounter::new(), + connections: ConnectionSet::new(options.min_connections..=options.max_connections), + is_closed: AtomicBool::new(false), + on_closed: event_listener::Event::new(), + acquire_time_level: private_level_filter_to_trace_level(options.acquire_time_level), + acquire_slow_level: private_level_filter_to_trace_level(options.acquire_slow_level), + options, }); spawn_maintenance_tasks(&pool); @@ -85,11 +58,11 @@ impl PoolInner { } pub(super) fn size(&self) -> usize { - self.sharded.count_connected() + self.connections.num_connected() } pub(super) fn num_idle(&self) -> usize { - self.sharded.count_unlocked(true) + self.connections.count_idle() } pub(super) fn is_closed(&self) -> bool { @@ -105,11 +78,8 @@ impl PoolInner { self.mark_closed(); // Keep clearing the idle queue as connections are released until the count reaches zero. - self.sharded.drain(|slot| async move { - let (conn, slot) = ConnectedSlot::take(slot); - - let _ = rt::timeout(GRACEFUL_CLOSE_TIMEOUT, conn.raw.close()).await; - + self.connections.drain(async |slot| { + let (_res, slot) = connection::close(slot).await; slot }) } @@ -130,7 +100,7 @@ impl PoolInner { return None; } - self.sharded.try_acquire_connected() + self.connections.try_acquire_connected() } pub(super) async fn acquire(self: &Arc) -> Result, Error> { @@ -140,74 +110,43 @@ impl PoolInner { let acquire_started_at = Instant::now(); - let mut close_event = pin!(self.close_event()); - let mut deadline = pin!(rt::sleep(self.options.acquire_timeout)); - - let connect_shared = ConnectTaskShared::new_arc(); - - let mut acquire_connected = pin!(self.acquire_connected().fuse()); + // Lazily allocated `Arc` + let mut connect_shared = None; - let mut acquire_disconnected = pin!(self.sharded.acquire_disconnected().fuse()); + let res = { + // Pinned to the stack without allocating + listener!(self.on_closed => close_listener); + let mut deadline = pin!(rt::sleep(self.options.acquire_timeout)); + let mut acquire_inner = pin!(self.acquire_inner(&mut connect_shared)); - let mut connect = future::Fuse::terminated(); - - let acquired = std::future::poll_fn(|cx| loop { - if let Poll::Ready(()) = close_event.as_mut().poll(cx) { - return Poll::Ready(Err(Error::PoolClosed)); - } - - if let Poll::Ready(res) = acquire_connected.as_mut().poll(cx) { - match res { - Ok(conn) => { - return Poll::Ready(Ok(conn)); - } - Err(slot) => { - if connect.is_terminated() { - connect = self - .connector - .connect( - Pool(self.clone()), - ConnectionId::next(), - slot, - connect_shared.clone(), - ) - .fuse(); - } - - // Try to acquire another connected connection. 
- acquire_connected.set(self.acquire_connected().fuse()); - continue; - } + std::future::poll_fn(|cx| { + if self.is_closed() { + return Poll::Ready(Err(Error::PoolClosed)); } - } - if let Poll::Ready(slot) = acquire_disconnected.as_mut().poll(cx) { - if connect.is_terminated() { - connect = self - .connector - .connect( - Pool(self.clone()), - ConnectionId::next(), - slot, - connect_shared.clone(), - ) - .fuse(); - } - } + // The result doesn't matter so much as the wakeup + let _ = Pin::new(&mut close_listener).poll(cx); - if let Poll::Ready(res) = Pin::new(&mut connect).poll(cx) { - return Poll::Ready(res); - } + if let Poll::Ready(()) = deadline.as_mut().poll(cx) { + return Poll::Ready(Err(Error::PoolTimedOut { + last_connect_error: None, + })); + } - if let Poll::Ready(()) = deadline.as_mut().poll(cx) { - return Poll::Ready(Err(Error::PoolTimedOut { - last_connect_error: connect_shared.take_error().map(Box::new), - })); - } + acquire_inner.as_mut().poll(cx) + }) + .await + }; - return Poll::Pending; - }) - .await?; + let acquired = res.map_err(|e| match e { + Error::PoolTimedOut { + last_connect_error: None, + } => Error::PoolTimedOut { + last_connect_error: connect_shared + .and_then(|shared| Some(shared.take_error()?.into())), + }, + e => e, + })?; let acquired_after = acquire_started_at.elapsed(); @@ -235,20 +174,36 @@ impl PoolInner { Ok(acquired) } - async fn acquire_connected( + async fn acquire_inner( self: &Arc, - ) -> Result, DisconnectedSlot>> { - let connected = self.sharded.acquire_connected().await; + connect_shared: &mut Option>, + ) -> Result, Error> { + tracing::trace!("waiting for any connection"); + + let disconnected = match self.connections.acquire_any().await { + Ok(conn) => match finish_acquire(self, conn).await { + Ok(conn) => return Ok(conn), + Err(slot) => slot, + }, + Err(slot) => slot, + }; - tracing::debug!( - target: "sqlx::pool", - connection_id=%connected.id, - "acquired idle connection" + let mut connect_task = self.connector.connect( + Pool(self.clone()), + ConnectionId::next(), + disconnected, + connect_shared.insert(ConnectTaskShared::new_arc()).clone(), ); - match finish_acquire(self, connected) { - Either::Left(task) => task.await, - Either::Right(conn) => Ok(conn), + loop { + match race(&mut connect_task, self.connections.acquire_connected()).await { + Ok(Ok(conn)) => return Ok(conn), + Ok(Err(e)) => return Err(e), + Err(conn) => match finish_acquire(self, conn).await { + Ok(conn) => return Ok(conn), + Err(_) => continue, + }, + } } } @@ -258,17 +213,20 @@ impl PoolInner { ) -> Result<(), Error> { let shared = ConnectTaskShared::new_arc(); - let connect_min_connections = - future::try_join_all(self.sharded.iter_min_connections().map(|slot| { - self.connector.connect( - Pool(self.clone()), - ConnectionId::next(), - slot, - shared.clone(), - ) - })); - - let mut conns = if let Some(deadline) = deadline { + let connect_min_connections = future::try_join_all( + (self.connections.num_connected()..self.options.min_connections) + .filter_map(|_| self.connections.try_acquire_disconnected()) + .map(|slot| { + self.connector.connect( + Pool(self.clone()), + ConnectionId::next(), + slot, + shared.clone(), + ) + }), + ); + + let conns = if let Some(deadline) = deadline { match rt::timeout_at(deadline, connect_min_connections).await { Ok(Ok(conns)) => conns, Err(_) | Ok(Err(Error::PoolTimedOut { .. })) => { @@ -297,144 +255,192 @@ impl Drop for PoolInner { } } -/// Returns `true` if the connection has exceeded `options.max_lifetime` if set, `false` otherwise. 
-pub(super) fn is_beyond_max_lifetime<DB: Database>( - live: &ConnectionInner<DB>, - options: &PoolOptions<DB>, -) -> bool { - options - .max_lifetime - .is_some_and(|max| live.created_at.elapsed() > max) -} - -/// Returns `true` if the connection has exceeded `options.idle_timeout` if set, `false` otherwise. -fn is_beyond_idle_timeout<DB: Database>( - idle: &ConnectionInner<DB>, - options: &PoolOptions<DB>, -) -> bool { - options - .idle_timeout - .is_some_and(|timeout| idle.last_released_at.elapsed() > timeout) -} - /// Execute `test_before_acquire` and/or `before_acquire` in a background task, if applicable. /// /// Otherwise, immediately returns the connection. -fn finish_acquire<DB: Database>( +async fn finish_acquire<DB: Database>( pool: &Arc<PoolInner<DB>>, mut conn: ConnectedSlot<ConnectionInner<DB>>, -) -> Either< - JoinHandle<Result<PoolConnection<DB>, DisconnectedSlot<ConnectionInner<DB>>>>, - PoolConnection<DB>, -> { - if pool.options.test_before_acquire || pool.options.before_acquire.is_some() { - let pool = pool.clone(); +) -> Result<PoolConnection<DB>, DisconnectedSlot<ConnectionInner<DB>>> { + struct SpawnOnDrop<F: Future + Send + 'static>(Option<Pin<Box<F>>>) + where + F::Output: Send + 'static; + + impl<F: Future + Send + 'static> Future for SpawnOnDrop<F> + where + F::Output: Send + 'static, + { + type Output = F::Output; + + #[inline(always)] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { + let poll = self + .0 + .as_mut() + .expect("BUG: inner future taken") + .as_mut() + .poll(cx); + + if poll.is_ready() { + // Clear the slot so `Drop` doesn't re-spawn a completed future, + // which would panic when the runtime polled it again. + self.0 = None; + } + + poll + } + } - // Spawn a task so the call may complete even if `acquire()` is cancelled. - return Either::Left(rt::spawn(async move { - // Check that the connection is still live + impl<F: Future + Send + 'static> Drop for SpawnOnDrop<F> + where + F::Output: Send + 'static, + { + fn drop(&mut self) { + // Only re-spawn the check if it was dropped before completing. + if let Some(fut) = self.0.take() { + rt::try_spawn(fut); + } + } + } + + async fn finish_inner<DB: Database>( + conn: &mut ConnectedSlot<ConnectionInner<DB>>, + pool: &PoolInner<DB>, + ) -> ControlFlow<()> { + // Check that the connection is still live + if pool.options.test_before_acquire { if let Err(error) = conn.raw.ping().await { // an error here means the other end has hung up or we lost connectivity // either way we're fine to just discard the connection // the error itself here isn't necessarily unexpected so WARN is too strong tracing::info!(%error, connection_id=%conn.id, "ping on idle connection returned error"); - - // connection is broken so don't try to close nicely - let (_res, slot) = connection::close_hard(conn).await; - return Err(slot); + return ControlFlow::Break(()); } + } + + if let Some(test) = &pool.options.before_acquire { + let meta = conn.idle_metadata(); + match test(&mut conn.raw, meta).await { + Ok(false) => { + // connection was rejected by user-defined hook, close nicely + tracing::debug!(connection_id=%conn.id, "connection rejected by `before_acquire`"); + return ControlFlow::Break(()); + } + + Err(error) => { + tracing::warn!(%error, "error from `before_acquire`"); + return ControlFlow::Break(()); + } - if let Some(test) = &pool.options.before_acquire { - let meta = conn.idle_metadata(); - match test(&mut conn.raw, meta).await { - Ok(false) => { - // connection was rejected by user-defined hook, close nicely - let (_res, slot) = connection::close(conn).await; - return Err(slot); - } + Ok(true) => (), + } + } - Err(error) => { - tracing::warn!(%error, "error from `before_acquire`"); + // Checks passed + ControlFlow::Continue(()) + } - // connection is broken so don't try to close nicely - let (_res, slot) = connection::close_hard(conn).await; - return Err(slot); - } + if pool.options.test_before_acquire || pool.options.before_acquire.is_some() { + let pool = pool.clone(); - Ok(true) => {} + // Spawn a task on-drop so the call may complete even if `acquire()` is cancelled.
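+ // If the `acquire()` future is dropped mid-check, the `Drop` impl above hands the
+ // boxed check to `rt::try_spawn()`, so the ping/`before_acquire` call still runs
+ // to completion and the slot is released back to the set rather than left locked.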
+ conn = SpawnOnDrop(Some(Box::pin(async move { + match rt::timeout(TEST_BEFORE_ACQUIRE_TIMEOUT, finish_inner(&mut conn, &pool)).await { + Ok(ControlFlow::Continue(())) => { + Ok(conn) + } + Ok(ControlFlow::Break(())) => { + // Connection rejected by user-defined hook, attempt to close nicely + let (_res, slot) = connection::close(conn).await; + Err(slot) + } + Err(_) => { + tracing::info!(connection_id=%conn.id, "`before_acquire` checks timed out, closing connection"); + let (_res, slot) = connection::close_hard(conn).await; + Err(slot) } } - - Ok(PoolConnection::new(conn, pool)) - })); + }))).await?; } - // No checks are configured, return immediately. - Either::Right(PoolConnection::new(conn, pool.clone())) + tracing::debug!( + target: "sqlx::pool", + connection_id=%conn.id, + "acquired idle connection" + ); + + Ok(PoolConnection::new(conn)) } fn spawn_maintenance_tasks(pool: &Arc>) { - // NOTE: use `pool_weak` for the maintenance tasks - // so they don't keep `PoolInner` from being dropped. - let pool_weak = Arc::downgrade(pool); + if pool.options.min_connections > 0 { + // NOTE: use `pool_weak` for the maintenance tasks + // so they don't keep `PoolInner` from being dropped. + let pool_weak = Arc::downgrade(pool); + let mut close_event = pool.close_event(); + + rt::spawn(async move { + close_event + .do_until(check_min_connections(pool_weak)) + .await + .ok(); + }); + } - let period = match (pool.options.max_lifetime, pool.options.idle_timeout) { + let check_interval = match (pool.options.max_lifetime, pool.options.idle_timeout) { (Some(it), None) | (None, Some(it)) => it, - (Some(a), Some(b)) => cmp::min(a, b), - - (None, None) => { - if pool.options.min_connections > 0 { - rt::spawn(async move { - if let Some(pool) = pool_weak.upgrade() { - if let Err(error) = pool.try_min_connections(None).await { - tracing::error!( - target: "sqlx::pool", - ?error, - "error maintaining min_connections" - ); - } - } - }); - } - - return; - } + (None, None) => return, }; - // Immediately cancel this task if the pool is closed. + let pool_weak = Arc::downgrade(pool); let mut close_event = pool.close_event(); rt::spawn(async move { let _ = close_event - .do_until(async { - // If the last handle to the pool was dropped while we were sleeping - while let Some(pool) = pool_weak.upgrade() { - if pool.is_closed() { - return; - } - - let next_run = Instant::now() + period; - - // Go over all idle connections, check for idleness and lifetime, - // and if we have fewer than min_connections after reaping a connection, - // open a new one immediately. - for conn in pool.sharded.iter_idle() { - if is_beyond_idle_timeout(&conn, &pool.options) - || is_beyond_max_lifetime(&conn, &pool.options) - { - // Dropping the slot will check if the connection needs to be - // re-made. - let _ = connection::close(conn).await; - } - } - - // Don't hold a reference to the pool while sleeping. - drop(pool); - - rt::sleep_until(next_run).await; - } - }) + .do_until(check_idle_conns(pool_weak, check_interval)) .await; }); } + +async fn check_idle_conns(pool_weak: Weak>, check_interval: Duration) { + let mut interval = pin!(rt::interval_after(check_interval)); + + while let Some(pool) = pool_weak.upgrade() { + if pool.is_closed() { + return; + } + + // Go over all idle connections, check for idleness and lifetime, + // and if we have fewer than min_connections after reaping a connection, + // open a new one immediately. 
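+ // (The re-open does not happen in this loop: when a reaped slot drops the set below
+ // `min_connections`, `SlotGuard`'s `Drop` impl notifies `min_connections_event`,
+ // which wakes `check_min_connections()` below.)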
+ for conn in pool.connections.iter_idle() { + if conn.is_beyond_idle_timeout(&pool.options) + || conn.is_beyond_max_lifetime(&pool.options) + { + // Dropping the slot will check if the connection needs to be re-made. + let _ = connection::close(conn).await; + } + } + + // Don't hold a reference to the pool while sleeping. + drop(pool); + + interval.as_mut().tick().await; + } +} + +async fn check_min_connections(pool_weak: Weak>) { + while let Some(pool) = pool_weak.upgrade() { + if pool.is_closed() { + return; + } + + match pool.try_min_connections(None).await { + Ok(()) => { + let listener = pool.connections.min_connections_listener(); + + // Important: don't hold a strong ref while sleeping + drop(pool); + + listener.await; + } + Err(e) => { + tracing::warn!( + target: "sqlx::pool::maintenance", + min_connections=pool.options.min_connections, + num_connected=pool.connections.num_connected(), + "unable to maintain `min_connections`: {e:?}", + ); + } + } + } +} diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index 7d2b18ed4c..224ee8ffb6 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -56,20 +56,19 @@ use std::fmt; use std::future::Future; -use std::pin::{pin, Pin}; +use std::pin::Pin; use std::sync::Arc; use std::task::{ready, Context, Poll}; -use event_listener::EventListener; -use futures_core::FusedFuture; -use futures_util::FutureExt; - use crate::connection::Connection; use crate::database::Database; use crate::error::Error; +use crate::ext::future::race; use crate::sql_str::SqlSafeStr; use crate::transaction::Transaction; - +use event_listener::EventListener; +use futures_core::FusedFuture; +use tracing::Instrument; pub use self::connect::{PoolConnectMetadata, PoolConnector}; pub use self::connection::PoolConnection; use self::inner::PoolInner; @@ -90,7 +89,9 @@ mod inner; // mod idle; mod options; -mod shard; +// mod shard; + +mod connection_set; /// An asynchronous pool of SQLx database connections. /// @@ -362,16 +363,21 @@ impl Pool { pub fn acquire(&self) -> impl Future, Error>> + 'static { let shared = self.0.clone(); async move { shared.acquire().await } + .instrument(tracing::error_span!(target: "sqlx::pool", "acquire")) } /// Attempts to retrieve a connection from the pool if there is one available. /// /// Returns `None` immediately if there are no idle connections available in the pool /// or there are tasks waiting for a connection which have yet to wake. + /// + /// # Note: Bypasses `before_acquire` + /// Since this function is not `async`, it cannot await the future returned by + /// [`before_acquire`][PoolOptions::before_acquire] without blocking. + /// + /// Instead, it simply returns the connection immediately. pub fn try_acquire(&self) -> Option> { - self.0 - .try_acquire() - .map(|conn| PoolConnection::new(conn, self.0.clone())) + self.0.try_acquire().map(|conn| PoolConnection::new(conn)) } /// Retrieves a connection and immediately begins a new transaction. @@ -577,42 +583,19 @@ impl CloseEvent { /// /// Cancels the future and returns `Err(PoolClosed)` if/when the pool is closed. /// If the pool was already closed, the future is never run. + #[inline(always)] pub async fn do_until(&mut self, fut: Fut) -> Result { - // Check that the pool wasn't closed already. - // - // We use `poll_immediate()` as it will use the correct waker instead of - // a no-op one like `.now_or_never()`, but it won't actually suspend execution here. 
- futures_util::future::poll_immediate(&mut *self) - .await - .map_or(Ok(()), |_| Err(Error::PoolClosed))?; - - let mut fut = pin!(fut); - - // I find that this is clearer in intent than `futures_util::future::select()` - // or `futures_util::select_biased!{}` (which isn't enabled anyway). - std::future::poll_fn(|cx| { - // Poll `fut` first as the wakeup event is more likely for it than `self`. - if let Poll::Ready(ret) = fut.as_mut().poll(cx) { - return Poll::Ready(Ok(ret)); - } - - // Can't really factor out mapping to `Err(Error::PoolClosed)` though it seems like - // we should because that results in a different `Ok` type each time. - // - // Ideally we'd map to something like `Result` but using `!` as a type - // is not allowed on stable Rust yet. - self.poll_unpin(cx).map(|_| Err(Error::PoolClosed)) - }) - .await + race(fut, self).await.map_err(|_| Error::PoolClosed) } } impl Future for CloseEvent { type Output = (); + #[inline(always)] fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { if let Some(listener) = &mut self.listener { - ready!(listener.poll_unpin(cx)); + ready!(Pin::new(listener).poll(cx)); } // `EventListener` doesn't like being polled after it yields, and even if it did it diff --git a/sqlx-core/src/rt/mod.rs b/sqlx-core/src/rt/mod.rs index 985d9bb607..862cf6cedd 100644 --- a/sqlx-core/src/rt/mod.rs +++ b/sqlx-core/src/rt/mod.rs @@ -1,10 +1,13 @@ use std::future::Future; use std::marker::PhantomData; use std::pin::Pin; -use std::task::{Context, Poll}; +use std::task::{ready, Context, Poll}; use std::time::{Duration, Instant}; use cfg_if::cfg_if; +use futures_core::Stream; +use futures_util::StreamExt; +use pin_project_lite::pin_project; #[cfg(feature = "_rt-async-io")] pub mod rt_async_io; @@ -59,19 +62,13 @@ pub async fn timeout_at(deadline: Instant, f: F) -> Result Interval { + #[cfg(feature = "_rt-tokio")] + if rt_tokio::available() { + return Interval::Tokio { + sleep: tokio::time::sleep(period), + period, + }; + } + + cfg_if! { + if #[cfg(feature = "_rt-async-io")] { + Interval::AsyncIo { timer: async_io::Timer::interval(period) } + } else { + missing_rt(period) + } + } +} + +impl Interval { + #[inline(always)] + pub fn tick(mut self: Pin<&mut Self>) -> impl Future + use<'_> { + std::future::poll_fn(move |cx| self.as_mut().poll_tick(cx)) + } + + #[inline(always)] + pub fn as_timeout(self: Pin<&mut Self>, fut: F) -> AsTimeout<'_, F> { + AsTimeout { + interval: self, + future: fut, + } + } + + #[inline(always)] + pub fn poll_tick(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + cfg_if! { + if #[cfg(any(feature = "_rt-tokio", feature = "_rt-async-io"))] { + match self.project() { + #[cfg(feature = "_rt-tokio")] + IntervalProjected::Tokio { mut sleep, period } => { + ready!(sleep.as_mut().poll(cx)); + let now = Instant::now(); + sleep.reset((now + *period).into()); + Poll::Ready(now) + } + #[cfg(feature = "_rt-async-io")] + IntervalProjected::AsyncIo { mut timer } => { + Poll::Ready(ready!(timer + .as_mut() + .poll_next(cx)) + .expect("BUG: `async_io::Timer::next()` should always yield")) + } + } + } else { + unreachable!() + } + } + } +} + +pin_project! 
{ + pub struct AsTimeout<'i, F> { + interval: Pin<&'i mut Interval>, + #[pin] + future: F, + } +} + +impl Future for AsTimeout<'_, F> +where + F: Future, +{ + type Output = Option; + + #[inline(always)] + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + + if let Poll::Ready(out) = this.future.poll(cx) { + return Poll::Ready(Some(out)); + } + + this.interval.as_mut().poll_tick(cx).map(|_| None) + } +} + #[track_caller] pub fn spawn(fut: F) -> JoinHandle where @@ -128,6 +254,29 @@ where } } +pub fn try_spawn(fut: F) -> Option> +where + F: Future + Send + 'static, + F::Output: Send + 'static, +{ + #[cfg(feature = "_rt-tokio")] + if let Ok(handle) = tokio::runtime::Handle::try_current() { + return Some(JoinHandle::Tokio(handle.spawn(fut))); + } + + cfg_if! { + if #[cfg(feature = "_rt-async-global-executor")] { + Some(JoinHandle::AsyncTask(Some(async_global_executor::spawn(fut)))) + } else if #[cfg(feature = "_rt-smol")] { + Some(JoinHandle::AsyncTask(Some(smol::spawn(fut)))) + } else if #[cfg(feature = "_rt-async-std")] { + Some(JoinHandle::AsyncStd(async_std::task::spawn(fut))) + } else { + None + } + } +} + #[track_caller] pub fn spawn_blocking(f: F) -> JoinHandle where diff --git a/sqlx-core/src/rt/rt_async_io/time.rs b/sqlx-core/src/rt/rt_async_io/time.rs index 039610b758..dbe1d8f725 100644 --- a/sqlx-core/src/rt/rt_async_io/time.rs +++ b/sqlx-core/src/rt/rt_async_io/time.rs @@ -1,13 +1,10 @@ +use crate::ext::future::race; +use crate::rt::TimeoutError; use std::{ future::Future, - pin::pin, time::{Duration, Instant}, }; -use futures_util::future::{select, Either}; - -use crate::rt::TimeoutError; - pub async fn sleep(duration: Duration) { async_io::Timer::after(duration).await; } @@ -17,8 +14,16 @@ pub async fn sleep_until(deadline: Instant) { } pub async fn timeout(duration: Duration, future: F) -> Result { - match select(pin!(future), pin!(sleep(duration))).await { - Either::Left((result, _)) => Ok(result), - Either::Right(_) => Err(TimeoutError), - } + race(future, sleep(duration)) + .await + .map_err(|_| TimeoutError) +} + +pub async fn timeout_at( + deadline: Instant, + future: F, +) -> Result { + race(future, sleep_until(deadline)) + .await + .map_err(|_| TimeoutError) } diff --git a/sqlx-core/src/rt/rt_tokio/mod.rs b/sqlx-core/src/rt/rt_tokio/mod.rs index ce699456db..364ce3bfee 100644 --- a/sqlx-core/src/rt/rt_tokio/mod.rs +++ b/sqlx-core/src/rt/rt_tokio/mod.rs @@ -1,5 +1,6 @@ mod socket; +#[inline(always)] pub fn available() -> bool { tokio::runtime::Handle::try_current().is_ok() } diff --git a/sqlx-core/src/sync.rs b/sqlx-core/src/sync.rs index 2fd51445b3..bce8d60c0d 100644 --- a/sqlx-core/src/sync.rs +++ b/sqlx-core/src/sync.rs @@ -4,11 +4,40 @@ // We'll generally lean towards Tokio's types as those are more featureful // (including `tokio-console` support) and more widely deployed. 
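//
// A usage sketch of the `lock_arc`/`try_lock_arc` helpers added below (names and
// values here are illustrative only):
//
// let shared = Arc::new(AsyncMutex::new(0_u32));
// let guard = lock_arc(&shared).await; // guard owns a strong reference to the mutex
// assert!(try_lock_arc(&shared).is_none()); // mutex is already held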
+use std::sync::Arc; #[cfg(feature = "_rt-tokio")] -pub use tokio::sync::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard, RwLock as AsyncRwLock}; +pub use tokio::sync::{ + Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard, OwnedMutexGuard as AsyncMutexGuardArc, + RwLock as AsyncRwLock, +}; #[cfg(all(feature = "_rt-async-lock", not(feature = "_rt-tokio")))] -pub use async_lock::{Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard, RwLock as AsyncRwLock}; +pub use async_lock::{ + Mutex as AsyncMutex, MutexGuard as AsyncMutexGuard, MutexGuardArc as AsyncMutexGuardArc, + RwLock as AsyncRwLock, +}; + +pub async fn lock_arc(mutex: &Arc>) -> AsyncMutexGuardArc { + #[cfg(feature = "_rt-tokio")] + return mutex.clone().lock_owned().await; + + #[cfg(all(feature = "_rt-async-lock", not(feature = "_rt-tokio")))] + return mutex.lock_arc().await; + + #[cfg(not(any(feature = "_rt-async-lock", feature = "_rt-tokio")))] + return crate::rt::missing_rt(mutex); +} + +pub fn try_lock_arc(mutex: &Arc>) -> Option> { + #[cfg(feature = "_rt-tokio")] + return mutex.clone().try_lock_owned().ok(); + + #[cfg(all(feature = "_rt-async-lock", not(feature = "_rt-tokio")))] + return mutex.try_lock_arc(); + + #[cfg(not(any(feature = "_rt-async-lock", feature = "_rt-tokio")))] + return crate::rt::missing_rt(mutex); +} #[cfg(not(any(feature = "_rt-async-lock", feature = "_rt-tokio")))] pub use noop::*; @@ -18,6 +47,7 @@ mod noop { use crate::rt::missing_rt; use std::marker::PhantomData; use std::ops::{Deref, DerefMut}; + use std::sync::Arc; pub struct AsyncMutex { // `Sync` if `T: Send` @@ -28,6 +58,10 @@ mod noop { inner: &'a AsyncMutex, } + pub struct AsyncMutexGuardArc { + inner: Arc>, + } + impl AsyncMutex { pub fn new(val: T) -> Self { missing_rt(val) @@ -51,4 +85,18 @@ mod noop { missing_rt(self) } } + + impl Deref for AsyncMutexGuardArc { + type Target = T; + + fn deref(&self) -> &Self::Target { + missing_rt(self) + } + } + + impl DerefMut for AsyncMutexGuardArc { + fn deref_mut(&mut self) -> &mut Self::Target { + missing_rt(self) + } + } } diff --git a/sqlx-test/src/lib.rs b/sqlx-test/src/lib.rs index 01cdc2977b..6a8b9d1120 100644 --- a/sqlx-test/src/lib.rs +++ b/sqlx-test/src/lib.rs @@ -2,13 +2,15 @@ use sqlx::pool::PoolOptions; use sqlx::{Connection, Database, Error, Pool}; use std::env; use tracing_subscriber::EnvFilter; +use tracing_subscriber::fmt::format::FmtSpan; pub fn setup_if_needed() { let _ = dotenvy::dotenv(); let _ = tracing_subscriber::fmt::Subscriber::builder() .with_env_filter(EnvFilter::from_default_env()) - .with_test_writer() - .finish(); + .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) + // .with_test_writer() + .try_init(); } // Make a new connection diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index 06adf0ca7f..d6bbebf2c5 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -255,6 +255,10 @@ async fn it_works_with_cache_disabled() -> anyhow::Result<()> { #[sqlx_macros::test] async fn it_executes_with_pool() -> anyhow::Result<()> { + setup_if_needed(); + + tracing::info!("starting test"); + let pool = sqlx_test::pool::().await?; let rows = pool.fetch_all("SELECT 1; SElECT 2").await?; @@ -1146,7 +1150,7 @@ async fn test_listener_try_recv_buffered() -> anyhow::Result<()> { assert!(listener.next_buffered().is_none()); // Activate connection. 
- sqlx::query!("SELECT 1 AS one") + sqlx::query("SELECT 1 AS one") .fetch_all(&mut listener) .await?; @@ -2086,6 +2090,7 @@ async fn test_issue_3052() { } #[sqlx_macros::test] +#[cfg(feature = "chrono")] async fn test_bind_iter() -> anyhow::Result<()> { use sqlx::postgres::PgBindIterExt; use sqlx::types::chrono::{DateTime, Utc};
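A closing note on the `race()` helper this series leans on throughout: it is
left-biased, polling `left` before `right` on every wakeup, which is what lets
`CloseEvent::do_until()` map the right-hand branch to `Err(PoolClosed)`. A
standalone sketch of that bias (illustrative only, using the `race()` function
from `ext/future.rs` above inside any async context):

    // Both sides are immediately ready, but `left` is polled first and wins the tie.
    let winner = race(async { "left" }, async { "right" }).await;
    assert_eq!(winner, Ok("left"));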